diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py index 794a7f94fa..25e8b67c4d 100755 --- a/ding/policy/__init__.py +++ b/ding/policy/__init__.py @@ -1,5 +1,5 @@ from .base_policy import Policy, CommandModePolicy, create_policy, get_policy_cls -from .common_utils import single_env_forward_wrapper, single_env_forward_wrapper_ttorch +from .common_utils import single_env_forward_wrapper, single_env_forward_wrapper_ttorch, default_preprocess_learn from .dqn import DQNSTDIMPolicy, DQNPolicy from .mdqn import MDQNPolicy from .iqn import IQNPolicy @@ -17,8 +17,8 @@ from .pg import PGPolicy from .a2c import A2CPolicy from .ppo import PPOPolicy, PPOPGPolicy, PPOOffPolicy -from .sac import SACPolicy, SACDiscretePolicy, SQILSACPolicy -from .cql import CQLPolicy, CQLDiscretePolicy +from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy +from .cql import CQLPolicy, DiscreteCQLPolicy from .edac import EDACPolicy from .impala import IMPALAPolicy from .ngu import NGUPolicy diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index 9f08698888..3ff99c7b43 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -1,10 +1,10 @@ +from typing import Optional, List, Dict, Any, Tuple, Union from abc import ABC, abstractmethod from collections import namedtuple -from typing import Optional, List, Dict, Any, Tuple, Union +from easydict import EasyDict -import torch import copy -from easydict import EasyDict +import torch from ding.model import create_model from ding.utils import import_module, allreduce, broadcast, get_rank, allreduce_async, synchronize, deep_merge_dicts, \ @@ -12,11 +12,28 @@ class Policy(ABC): + """ + Overview: + The basic class of Reinforcement Learning (RL) and Imitation Learning (IL) policy in DI-engine. + Property: + ``cfg``, ``learn_mode``, ``collect_mode``, ``eval_mode`` + """ @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + Get the default config of policy. This method is used to create the default config of policy. + Returns: + - cfg (:obj:`EasyDict`): The default config of corresponding policy. For the derived policy class, \ + it will recursively merge the default config of base class and its own default config. + + .. tip:: + This method will deepcopy the ``config`` attribute of the class and return the result. So users don't need \ + to worry about the modification of the returned config. + """ if cls == Policy: - raise RuntimeError + raise RuntimeError("Basic class Policy doesn't have completed default_config") base_cls = cls.__base__ if base_cls == Policy: @@ -64,20 +81,47 @@ def default_config(cls: type) -> EasyDict: ) total_field = set(['learn', 'collect', 'eval']) config = dict( + # (bool) Whether the learning policy is the same as the collecting data policy (on-policy). on_policy=False, + # (bool) Whether to use cuda in policy. cuda=False, + # (bool) Whether to use data parallel multi-gpu mode in policy. multi_gpu=False, + # (bool) Whether to synchronize update the model parameters after allreduce the gradients of model parameters. bp_update_sync=True, + # (bool) Whether to enable infinite trajectory length in data collecting. traj_len_inf=False, + # neural network model config model=dict(), ) def __init__( self, - cfg: dict, - model: Optional[Union[type, torch.nn.Module]] = None, + cfg: EasyDict, + model: Optional[torch.nn.Module] = None, enable_field: Optional[List[str]] = None ) -> None: + """ + Overview: + Initialize policy instance according to input configures and model. 
This method will initialize differnent \ + fields in policy, including ``learn``, ``collect``, ``eval``. The ``learn`` field is used to train the \ + policy, the ``collect`` field is used to collect data for training, and the ``eval`` field is used to \ + evaluate the policy. The ``enable_field`` is used to specify which field to initialize, if it is None, \ + then all fields will be initialized. + Arguments: + - cfg (:obj:`EasyDict`): The final merged config used to initialize policy. For the default config, \ + see the ``config`` attribute and its comments of policy class. + - model (:obj:`torch.nn.Module`): The neural network model used to initialize policy. If it \ + is None, then the model will be created according to ``default_model`` method and ``cfg.model`` field. \ + Otherwise, the model will be set to the ``model`` instance created by outside caller. + - enable_field (:obj:`Optional[List[str]]`): The field list to initialize. If it is None, then all fields \ + will be initialized. Otherwise, only the fields in ``enable_field`` will be initialized, which is \ + beneficial to save resources. + + .. note:: + For the derived policy class, it should implement the ``_init_learn``, ``_init_collect``, ``_init_eval`` \ + method to initialize the corresponding field. + """ self._cfg = cfg self._on_policy = self._cfg.on_policy if enable_field is None: @@ -94,6 +138,7 @@ def __init__( multi_gpu = self._cfg.multi_gpu self._rank = get_rank() if multi_gpu else 0 if self._cuda: + # model.cuda() is an in-place operation. model.cuda() if multi_gpu: bp_update_sync = self._cfg.bp_update_sync @@ -102,6 +147,7 @@ def __init__( else: self._rank = 0 if self._cuda: + # model.cuda() is an in-place operation. model.cuda() self._model = model self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if self._cuda else 'cpu' @@ -110,13 +156,26 @@ def __init__( self._rank = 0 self._device = 'cpu' + # call the initialization method of different modes, such as ``_init_learn``, ``_init_collect``, ``_init_eval`` for field in self._enable_field: getattr(self, '_init_' + field)() def _init_multi_gpu_setting(self, model: torch.nn.Module, bp_update_sync: bool) -> None: + """ + Overview: + Initialize multi-gpu data parallel training setting, including broadcast model parameters at the beginning \ + of the training, and prepare the hook function to allreduce the gradients of model parameters. + Arguments: + - model (:obj:`torch.nn.Module`): The neural network model to be trained. + - bp_update_sync (:obj:`bool`): Whether to synchronize update the model parameters after allreduce the \ + gradients of model parameters. Async update can be parallel in different network layers like pipeline \ + so that it can save time. + """ for name, param in model.state_dict().items(): assert isinstance(param.data, torch.Tensor), type(param.data) broadcast(param.data, 0) + # here we manually set the gradient to zero tensor at the beginning of the training, which is necessary for + # the case that different GPUs have different computation graph. 
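# --- Editor's illustrative sketch (not part of this diff): how the Policy API documented
# above is typically used. Take the recursively merged ``default_config()``, override a
# few fields, and initialize only the modes you need via ``enable_field``. ``DQNPolicy``
# and the model sizes below are assumptions for illustration, not prescribed values.
from ding.policy import DQNPolicy

cfg = DQNPolicy.default_config()       # deepcopy of the class-level ``config`` (an EasyDict)
cfg.cuda = False
cfg.model.obs_shape = 4                # hypothetical CartPole-like dimensions
cfg.model.action_shape = 2

# Only build the eval field, which skips the optimizer and the collect wrappers.
policy = DQNPolicy(cfg, model=None, enable_field=['eval'])
print(policy.eval_mode)                # namedtuple of eval-mode interfaces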
for name, param in model.named_parameters(): setattr(param, 'grad', torch.zeros_like(param)) if not bp_update_sync: @@ -134,7 +193,23 @@ def hook(*ignore): grad_acc = p_tmp.grad_fn.next_functions[0][0] grad_acc.register_hook(make_hook(name, p)) - def _create_model(self, cfg: dict, model: Optional[torch.nn.Module] = None) -> torch.nn.Module: + def _create_model(self, cfg: EasyDict, model: Optional[torch.nn.Module] = None) -> torch.nn.Module: + """ + Overview: + Create or validate the neural network model according to input configures and model. If the input model is \ + None, then the model will be created according to ``default_model`` method and ``cfg.model`` field. \ + Otherwise, the model will be verified as an instance of ``torch.nn.Module`` and set to the ``model`` \ + instance created by outside caller. + Arguments: + - cfg (:obj:`EasyDict`): The final merged config used to initialize policy. + - model (:obj:`torch.nn.Module`): The neural network model used to initialize policy. User can refer to \ + the default model defined in corresponding policy to customize its own model. + Returns: + - model (:obj:`torch.nn.Module`): The created neural network model. The different modes of policy will \ + add distinct wrappers and plugins to the model, which is used to train, collect and evaluate. + Raises: + - RuntimeError: If the input model is not None and is not an instance of ``torch.nn.Module``. + """ if model is None: model_cfg = cfg.model if 'type' not in model_cfg: @@ -154,18 +229,77 @@ def cfg(self) -> EasyDict: @abstractmethod def _init_learn(self) -> None: + """ + Overview: + Initialize the learn mode of policy, including related attributes and modules. This method will be \ + called in ``__init__`` method if ``learn`` field is in ``enable_field``. Almost different policies have \ + its own learn mode, so this method must be overrided in subclass. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. + """ raise NotImplementedError @abstractmethod def _init_collect(self) -> None: + """ + Overview: + Initialize the collect mode of policy, including related attributes and modules. This method will be \ + called in ``__init__`` method if ``collect`` field is in ``enable_field``. Almost different policies have \ + its own collect mode, so this method must be overrided in subclass. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_collect`` \ + and ``_load_state_dict_collect`` methods. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. + """ raise NotImplementedError @abstractmethod def _init_eval(self) -> None: + """ + Overview: + Initialize the eval mode of policy, including related attributes and modules. This method will be \ + called in ``__init__`` method if ``eval`` field is in ``enable_field``. Almost different policies have \ + its own eval mode, so this method must be overrided in subclass. + + .. 
note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_eval`` \ + and ``_load_state_dict_eval`` methods. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. + """ raise NotImplementedError @property def learn_mode(self) -> 'Policy.learn_function': # noqa + """ + Overview: + Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple \ + to define immutable interfaces and restrict the usage of policy in different mode. Moreover, derived \ + subclass can override the interfaces to customize its own learn mode. + Returns: + - interfaces (:obj:`Policy.learn_function`): The interfaces of learn mode of policy, it is a namedtuple \ + whose values of distinct fields are different internal methods. + Examples: + >>> policy = Policy(cfg, model) + >>> policy_learn = policy.learn_mode + >>> train_output = policy_learn.forward(data) + >>> state_dict = policy_learn.state_dict() + """ return Policy.learn_function( self._forward_learn, self._reset_learn, @@ -179,6 +313,21 @@ def learn_mode(self) -> 'Policy.learn_function': # noqa @property def collect_mode(self) -> 'Policy.collect_function': # noqa + """ + Overview: + Return the interfaces of collect mode of policy, which is used to train the model. Here we use namedtuple \ + to define immutable interfaces and restrict the usage of policy in different mode. Moreover, derived \ + subclass can override the interfaces to customize its own collect mode. + Returns: + - interfaces (:obj:`Policy.collect_function`): The interfaces of collect mode of policy, it is a \ + namedtuple whose values of distinct fields are different internal methods. + Examples: + >>> policy = Policy(cfg, model) + >>> policy_collect = policy.collect_mode + >>> obs = env_manager.ready_obs + >>> inference_output = policy_collect.forward(obs) + >>> next_obs, rew, done, info = env_manager.step(inference_output.action) + """ return Policy.collect_function( self._forward_collect, self._process_transition, @@ -192,6 +341,21 @@ def collect_mode(self) -> 'Policy.collect_function': # noqa @property def eval_mode(self) -> 'Policy.eval_function': # noqa + """ + Overview: + Return the interfaces of eval mode of policy, which is used to train the model. Here we use namedtuple \ + to define immutable interfaces and restrict the usage of policy in different mode. Moreover, derived \ + subclass can override the interfaces to customize its own eval mode. + Returns: + - interfaces (:obj:`Policy.eval_function`): The interfaces of eval mode of policy, it is a namedtuple \ + whose values of distinct fields are different internal methods. + Examples: + >>> policy = Policy(cfg, model) + >>> policy_eval = policy.eval_mode + >>> obs = env_manager.ready_obs + >>> inference_output = policy_eval.forward(obs) + >>> next_obs, rew, done, info = env_manager.step(inference_output.action) + """ return Policy.eval_function( self._forward_eval, self._reset_eval, @@ -202,9 +366,32 @@ def eval_mode(self) -> 'Policy.eval_function': # noqa ) def _set_attribute(self, name: str, value: Any) -> None: + """ + Overview: + In order to control the access of the policy attributes, we expose different modes to outside rather than \ + directly use the policy instance. And we also provide a method to set the attribute of the policy in \ + different modes. 
And the new attribute will named as ``_{name}``. + Arguments: + - name (:obj:`str`): The name of the attribute. + - value (:obj:`Any`): The value of the attribute. + """ setattr(self, '_' + name, value) def _get_attribute(self, name: str) -> Any: + """ + Overview: + In order to control the access of the policy attributes, we expose different modes to outside rather than \ + directly use the policy instance. And we also provide a method to get the attribute of the policy in \ + different modes. + Arguments: + - name (:obj:`str`): The name of the attribute. + Returns: + - value (:obj:`Any`): The value of the attribute. + + .. note:: + DI-engine's policy will first try to access `_get_{name}` method, and then try to access `_{name}` \ + attribute. If both of them are not found, it will raise a ``NotImplementedError``. + """ if hasattr(self, '_get_' + name): return getattr(self, '_get_' + name)() elif hasattr(self, '_' + name): @@ -213,9 +400,27 @@ def _get_attribute(self, name: str) -> Any: raise NotImplementedError def __repr__(self) -> str: + """ + Overview: + Get the string representation of the policy. + Returns: + - repr (:obj:`str`): The string representation of the policy. + """ return "DI-engine DRL Policy\n{}".format(repr(self._model)) def sync_gradients(self, model: torch.nn.Module) -> None: + """ + Overview: + Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training. + Arguments: + - model (:obj:`torch.nn.Module`): The model to synchronize gradients. + + .. note:: + This method is only used in multi-gpu training, and it shoule be called after ``backward`` method and \ + before ``step`` method. The user can also use ``bp_update_sync`` config to control whether to synchronize \ + gradients allreduce and optimizer updates. + """ + if self._bp_update_sync: for name, param in model.named_parameters(): if param.requires_grad: @@ -225,28 +430,99 @@ def sync_gradients(self, model: torch.nn.Module) -> None: # don't need to implement default_model method by force def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + + .. note:: + The user can define and use customized network model but must obey the same inferface definition indicated \ + by import_names path. For example about DQN, its registered name is ``dqn`` and the import_names is \ + ``ding.model.template.q_learning.DQN`` + """ raise NotImplementedError # *************************************** learn function ************************************ @abstractmethod - def _forward_learn(self, data: dict) -> Dict[str, Any]: + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss value, policy entropy, q value, priority, \ + and so on. This method is left to be implemented by the subclass, and more arguments can be added in \ + ``data`` item if necessary. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. 
For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, in the ``_forward_learn`` method, data should be stacked in \ + the batch dimension by some utility functions such as ``default_preprocess_learn``. + Returns: + - output (:obj:`Dict[int, Any]`): The training information of policy forward, including some metrics for \ + monitoring training such as loss, priority, q value, policy entropy, and some data for next step \ + training such as priority. Note the output data item should be Python native scalar rather than \ + PyTorch tensor, which is convenient for the outside to use. + """ raise NotImplementedError # don't need to implement _reset_learn method by force def _reset_learn(self, data_id: Optional[List[int]] = None) -> None: + """ + Overview: + Reset some stateful variables for learn mode when necessary, such as the hidden state of RNN or the \ + memory bank of some special algortihms. If ``data_id`` is None, it means to reset all the stateful \ + varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ + different trajectories in ``data_id`` will have different hidden state in RNN. + Arguments: + - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ + specified by ``data_id``. + + .. note:: + This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary. + """ pass def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + + .. tip:: + The default implementation is ``['cur_lr', 'total_loss']``. Other derived classes can overwrite this \ + method to add their own keys if necessary. + """ return ['cur_lr', 'total_loss'] def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ return { 'model': self._learn_model.state_dict(), 'optimizer': self._optimizer.state_dict(), } def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + + .. tip:: + If you want to only load some parts of model, you can simply set the ``strict`` argument in \ + load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ + complicated operation. + """ self._learn_model.load_state_dict(state_dict['model']) self._optimizer.load_state_dict(state_dict['optimizer']) @@ -260,34 +536,123 @@ def _get_batch_size(self) -> Union[int, Dict[str, int]]: # *************************************** collect function ************************************ @abstractmethod - def _forward_collect(self, data: dict, **kwargs) -> dict: + def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: + """ + Overview: + Policy forward function of collect mode (collecting training data by interacting with envs). 
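# --- Editor's illustrative sketch (not part of this diff): the rough shape of a subclass
# ``_forward_learn`` that follows the contract described above, i.e. stack the list of
# samples, compute a loss, and return plain Python scalars for the logger. ``MyPolicy``,
# the toy MSE objective and the ``'logit'`` output key are assumptions.
import torch.nn.functional as F
from ding.policy import Policy
from ding.policy.common_utils import default_preprocess_learn
from ding.torch_utils import to_device

class MyPolicy(Policy):  # other abstract methods omitted in this sketch
    def _forward_learn(self, data):
        data = default_preprocess_learn(data)          # list of sample dicts -> one batched dict
        if self._cuda:
            data = to_device(data, self._device)
        pred = self._learn_model.forward(data['obs'])['logit'].mean(dim=-1)
        loss = F.mse_loss(pred, data['reward'].view(-1))   # toy regression objective
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        # return native scalars so text/tensorboard loggers can record them directly
        return {'cur_lr': self._optimizer.defaults['lr'], 'total_loss': loss.item()}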
Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs, or the action logits to calculate the loss in learn \ + mode. This method is left to be implemented by the subclass, and more arguments can be added in ``kwargs`` \ + part if necessary. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ + dict is the same as the input data, i.e. environment id. + """ raise NotImplementedError @abstractmethod - def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + def _process_transition( + self, obs: Union[torch.Tensor, Dict[str, torch.Tensor]], policy_output: Dict[str, torch.Tensor], + timestep: namedtuple + ) -> Dict[str, torch.Tensor]: + """ + Overview: + Process and pack one timestep transition data into a dict, such as a dict with ``obs``, ``action``, ``reward``, ``next_obs`` and ``done`` entries. Some policies \ + need to do some special processing and pack their own necessary attributes (e.g. hidden state and logit), \ + so this method is left to be implemented by the subclass. + Arguments: + - obs (:obj:`Union[torch.Tensor, Dict[str, torch.Tensor]]`): The observation of the current timestep. + - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. Usually, it contains the action and the logit of the action. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. + Returns: + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. + """ raise NotImplementedError @abstractmethod - def _get_train_sample(self, data: list) -> Union[None, List[Any]]: + def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + For a given trajectory (transitions, a list of transition) data, process it into a list of samples that \ + can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) \ + or some multi-timestep transitions (DRQN). This method is usually used in collectors to execute necessary \ + RL data preprocessing before training, which can help the learner amortize the relevant time consumption. \ + In addition, you can also implement this method as an identity function and do the data processing \ + in ``self._forward_learn`` method. + Arguments: + - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transition), each element is in \ + the same format as the return value of ``self._process_transition`` method. + Returns: + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is in a similar format \ + to the input transitions, but may contain more data for training, such as nstep reward, advantage, etc. + + .. note:: + We will vectorize the ``process_transition`` and ``get_train_sample`` methods in the following release version. \ + The user can also customize this data processing procedure by overriding these two methods and the collector \ + itself. + """ raise NotImplementedError # don't need to implement _reset_collect method by force def _reset_collect(self, data_id: Optional[List[int]] = None) -> None: + """ + Overview: + Reset some stateful variables for collect mode when necessary, such as the hidden state of RNN or the \ + memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \ + variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ + different environments/episodes during collecting in ``data_id`` will have different hidden states in RNN. + Arguments: + - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ + specified by ``data_id``. + + .. note:: + This method is not mandatory to be implemented. The sub-class can override this method if necessary. + """ pass def _state_dict_collect(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of collect mode, usually only including the model, which is necessary for distributed \ + training scenarios to auto-recover collectors. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy collect state, for saving and restoring. + + .. tip:: + Not all scenarios need to auto-recover collectors; sometimes we can directly shut down the crashed \ + collector and start a new one. + """ return {'model': self._collect_model.state_dict()} def _load_state_dict_collect(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy collect mode, such as loading a pretrained state_dict, an auto-recovered \ + checkpoint, or a model replica from the learner in distributed training scenarios. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy collect state saved before. + + .. tip:: + If you want to only load some parts of the model, you can simply set the ``strict`` argument in \ + load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ + complicated operation. + """ self._collect_model.load_state_dict(state_dict['model'], strict=True) - def _get_n_sample(self): + def _get_n_sample(self) -> Union[int, None]: if 'n_sample' in self._cfg: return self._cfg.n_sample else: # for compatibility return self._cfg.collect.get('n_sample', None) # for some adaptive data collecting cases - def _get_n_episode(self): + def _get_n_episode(self) -> Union[int, None]: if 'n_episode' in self._cfg: return self._cfg.n_episode else: # for compatibility @@ -296,54 +661,201 @@ def _get_n_episode(self): # *************************************** eval function ************************************ @abstractmethod - def _forward_eval(self, data: dict) -> Dict[str, Any]: + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ + Overview: + Policy forward function of eval mode (evaluating policy performance, such as interacting with envs or \ + computing metrics on a validation dataset). Forward means that the policy gets some necessary data (mainly \ + observation) from the envs and then returns the output data, such as the action to interact with the envs. \ + This method is left to be implemented by the subclass. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. 
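# --- Editor's illustrative sketch (not part of this diff): typical implementations of the
# two collect-mode hooks documented above, written as they would appear inside a Policy
# subclass (``self._unroll_len`` is assumed to be set in ``_init_collect``). The packed
# keys mirror the BC policy further down in this diff; nstep=1 and gamma=1 are assumptions.
from ding.rl_utils import get_nstep_return_data, get_train_sample

def _process_transition(self, obs, policy_output, timestep):
    # pack one interaction step into a flat dict that the collector can cache
    return {
        'obs': obs,
        'next_obs': timestep.obs,
        'action': policy_output['action'],
        'reward': timestep.reward,
        'done': timestep.done,
    }

def _get_train_sample(self, transitions):
    # compute 1-step returns, then cut the trajectory into unroll_len-sized samples
    transitions = get_nstep_return_data(transitions, 1, 1)
    return get_train_sample(transitions, self._unroll_len)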
+ Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + """ raise NotImplementedError # don't need to implement _reset_eval method by force def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + """ + Overview: + Reset some stateful variables for eval mode when necessary, such as the hidden state of RNN or the \ + memory bank of some special algortihms. If ``data_id`` is None, it means to reset all the stateful \ + varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ + different environments/episodes in evaluation in ``data_id`` will have different hidden state in RNN. + Arguments: + - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ + specified by ``data_id``. + + .. note:: + This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary. + """ pass def _state_dict_eval(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of eval mode, only including model in usual, which is necessary for distributed \ + training scenarios to auto-recover evaluators. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy eval state, for saving and restoring. + + .. tip:: + Not all the scenarios need to auto-recover evaluators, sometimes, we can directly shutdown the crashed \ + evaluator and renew a new one. + """ return {'model': self._eval_model.state_dict()} def _load_state_dict_eval(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy eval mode, such as load auto-recover \ + checkpoint, or model replica from learner in distributed training scenarios. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy eval state saved before. + + .. tip:: + If you want to only load some parts of model, you can simply set the ``strict`` argument in \ + load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ + complicated operation. + """ self._eval_model.load_state_dict(state_dict['model'], strict=True) class CommandModePolicy(Policy): + """ + Overview: + Policy with command mode, which can be used in old version of DI-engine pipeline: ``serial_pipeline``. \ + ``CommandModePolicy`` uses ``_get_setting_learn``, ``_get_setting_collect``, ``_get_setting_eval`` methods \ + to exchange information between different workers. + + Interface: + ``_init_command``, ``_get_setting_learn``, ``_get_setting_collect``, ``_get_setting_eval`` + Property: + ``command_mode`` + """ command_function = namedtuple('command_function', ['get_setting_learn', 'get_setting_collect', 'get_setting_eval']) total_field = set(['learn', 'collect', 'eval', 'command']) @property def command_mode(self) -> 'Policy.command_function': # noqa + """ + Overview: + Return the interfaces of command mode of policy, which is used to train the model. Here we use namedtuple \ + to define immutable interfaces and restrict the usage of policy in different mode. Moreover, derived \ + subclass can override the interfaces to customize its own command mode. + Returns: + - interfaces (:obj:`Policy.command_function`): The interfaces of command mode, it is a namedtuple \ + whose values of distinct fields are different internal methods. 
+ Examples: + >>> policy = CommandModePolicy(cfg, model) + >>> policy_command = policy.command_mode + >>> settings = policy_command.get_setting_learn(command_info) + """ return CommandModePolicy.command_function( self._get_setting_learn, self._get_setting_collect, self._get_setting_eval ) @abstractmethod def _init_command(self) -> None: + """ + Overview: + Initialize the command mode of policy, including related attributes and modules. This method will be \ + called in ``__init__`` method if ``command`` field is in ``enable_field``. Almost every policy has \ + its own command mode, so this method must be overridden in the subclass. + + .. note:: + If you want to set some special member variables in ``_init_command`` method, you'd better name them \ + with prefix ``_command_`` to avoid conflict with other modes, such as ``self._command_attr1``. + """ raise NotImplementedError # *************************************** command function ************************************ @abstractmethod - def _get_setting_learn(self, command_info: dict) -> dict: + def _get_setting_learn(self, command_info: Dict[str, Any]) -> Dict[str, Any]: + """ + Overview: + According to ``command_info``, i.e., global training information (e.g. training iteration, collected env \ + step, evaluation results, etc.), return the setting of learn mode, which contains dynamically changed \ + hyperparameters for learn mode, such as ``batch_size``, ``learning_rate``, etc. + Arguments: + - command_info (:obj:`Dict[str, Any]`): The global training information, which is defined in ``commander``. + Returns: + - setting (:obj:`Dict[str, Any]`): The latest setting of learn mode, which is usually used as extra \ + arguments of the ``policy._forward_learn`` method. + """ raise NotImplementedError @abstractmethod - def _get_setting_collect(self, command_info: dict) -> dict: + def _get_setting_collect(self, command_info: Dict[str, Any]) -> Dict[str, Any]: + """ + Overview: + According to ``command_info``, i.e., global training information (e.g. training iteration, collected env \ + step, evaluation results, etc.), return the setting of collect mode, which contains dynamically changed \ + hyperparameters for collect mode, such as ``eps``, ``temperature``, etc. + Arguments: + - command_info (:obj:`Dict[str, Any]`): The global training information, which is defined in ``commander``. + Returns: + - setting (:obj:`Dict[str, Any]`): The latest setting of collect mode, which is usually used as extra \ + arguments of the ``policy._forward_collect`` method. + """ raise NotImplementedError @abstractmethod - def _get_setting_eval(self, command_info: dict) -> dict: + def _get_setting_eval(self, command_info: Dict[str, Any]) -> Dict[str, Any]: + """ + Overview: + According to ``command_info``, i.e., global training information (e.g. training iteration, collected env \ + step, evaluation results, etc.), return the setting of eval mode, which contains dynamically changed \ + hyperparameters for eval mode, such as ``temperature``, etc. + Arguments: + - command_info (:obj:`Dict[str, Any]`): The global training information, which is defined in ``commander``. + Returns: + - setting (:obj:`Dict[str, Any]`): The latest setting of eval mode, which is usually used as extra \ + arguments of the ``policy._forward_eval`` method. 
+ """ raise NotImplementedError -def create_policy(cfg: dict, **kwargs) -> Policy: - cfg = EasyDict(cfg) +def create_policy(cfg: EasyDict, **kwargs) -> Policy: + """ + Overview: + Create a policy instance according to ``cfg`` and other kwargs. + Arguments: + - cfg (:obj:`EasyDict`): Final merged policy config. + ArgumentsKeys: + - type (:obj:`str`): Policy type set in ``POLICY_REGISTRY.register`` method , such as ``dqn`` . + - import_names (:obj:`List[str]`): A list of module names (paths) to import before creating policy, such \ + as ``ding.policy.dqn`` . + Returns: + - policy (:obj:`Policy`): The created policy instance. + + .. tip:: + ``kwargs`` contains other arguments that need to be passed to the policy constructor. You can refer to \ + the ``__init__`` method of the corresponding policy class for details. + + .. note:: + For more details about how to merge config, please refer to the system document of DI-engine \ + (`en link <../03_system/config.html>`_). + """ import_module(cfg.get('import_names', [])) return POLICY_REGISTRY.build(cfg.type, cfg=cfg, **kwargs) def get_policy_cls(cfg: EasyDict) -> type: + """ + Overview: + Get policy class according to ``cfg``, which is used to access related class variables/methods. + Arguments: + - cfg (:obj:`EasyDict`): Final merged policy config. + ArgumentsKeys: + - type (:obj:`str`): Policy type set in ``POLICY_REGISTRY.register`` method , such as ``dqn`` . + - import_names (:obj:`List[str]`): A list of module names (paths) to import before creating policy, such \ + as ``ding.policy.dqn`` . + Returns: + - policy (:obj:`type`): The policy class. + """ import_module(cfg.get('import_names', [])) return POLICY_REGISTRY.get(cfg.type) diff --git a/ding/policy/bc.py b/ding/policy/bc.py index 2ef8ce672f..0c95b8abec 100644 --- a/ding/policy/bc.py +++ b/ding/policy/bc.py @@ -20,6 +20,11 @@ @POLICY_REGISTRY.register('bc') class BehaviourCloningPolicy(Policy): + """ + Overview: + Behaviour Cloning (BC) policy class, which supports both discrete and continuous action space. \ + The policy is trained by supervised learning, and the data is a offline dataset collected by expert. + """ config = dict( type='bc', @@ -52,18 +57,46 @@ class BehaviourCloningPolicy(Policy): max=0.5, ), ), - eval=dict(), - other=dict(replay_buffer=dict(replay_buffer_size=10000, )), + eval=dict(), # for compatibility ) def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + + .. note:: + The user can define and use customized network model but must obey the same inferface definition indicated \ + by import_names path. For example about discrete BC, its registered name is ``discrete_bc`` and the \ + import_names is ``ding.model.template.bc``. + """ if self._cfg.continuous: return 'continuous_bc', ['ding.model.template.bc'] else: return 'discrete_bc', ['ding.model.template.bc'] - def _init_learn(self): - assert self._cfg.learn.optimizer in ['SGD', 'Adam'] + def _init_learn(self) -> None: + """ + Overview: + Initialize the learn mode of policy, including related attributes and modules. For BC, it mainly contains \ + optimizer, algorithm-specific arguments such as lr_scheduler, loss, etc. 
\ + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. + """ + assert self._cfg.learn.optimizer in ['SGD', 'Adam'], self._cfg.learn.optimizer if self._cfg.learn.optimizer == 'SGD': self._optimizer = SGD( self._model.parameters(), @@ -103,20 +136,38 @@ def lr_scheduler_fn(epoch): elif self._cfg.loss_type == 'mse_loss': self._loss = nn.MSELoss() else: - raise KeyError + raise KeyError("not support loss type: {}".format(self._cfg.loss_type)) else: if not self._cfg.learn.ce_label_smooth: self._loss = nn.CrossEntropyLoss() else: self._loss = LabelSmoothCELoss(0.1) - if self._cfg.learn.show_accuracy: - # accuracy statistics for debugging in discrete action space env, e.g. for gfootball - self.total_accuracy_in_dataset = [] - self.action_accuracy_in_dataset = {k: [] for k in range(self._cfg.action_shape)} + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss and time. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For BC, each element in list is a dict containing at least the following keys: ``obs``, ``action``. + Returns: + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. - def _forward_learn(self, data): - if not isinstance(data, dict): + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + """ + if isinstance(data, list): data = default_collate(data) if self._cuda: data = to_device(data, self._device) @@ -125,10 +176,10 @@ def _forward_learn(self, data): obs, action = data['obs'], data['action'].squeeze() if self._cfg.continuous: if self._cfg.learn.tanh_mask: - ''' + """tanh_mask We mask the action out of range of [tanh(-1),tanh(1)], model will learn information and produce action in [-1,1]. So the action won't always converge to -1 or 1. 
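# --- Editor's illustrative sketch (not part of this diff): driving the BC learn mode
# documented above with an offline expert dataset. The ``train_bc`` helper, epoch count
# and dataset/batch layout are assumptions; ``policy_cfg`` is the fully merged BC config.
from torch.utils.data import DataLoader
from ding.policy.bc import BehaviourCloningPolicy

def train_bc(policy_cfg, expert_dataset, epochs=10):
    policy = BehaviourCloningPolicy(policy_cfg, enable_field=['learn'])
    loader = DataLoader(expert_dataset, batch_size=policy_cfg.learn.batch_size, shuffle=True)
    log_vars = {}
    for _ in range(epochs):
        for batch in loader:                    # dict with at least 'obs' and 'action'
            log_vars = policy.learn_mode.forward(batch)
    return policy, log_vars                     # e.g. log_vars['total_loss'], log_vars['cur_lr']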
- ''' + """ mu = self._eval_model.forward(data['obs'])['action'] bound = 1 - 2 / (math.exp(2) + 1) # tanh(1): (e-e**(-1))/(e+e**(-1)) mask = mu.ge(-bound) & mu.le(bound) @@ -183,28 +234,57 @@ def _forward_learn(self, data): 'sync_time': sync_time, } - def _monitor_vars_learn(self): + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time'] def _init_eval(self): + """ + Overview: + Initialize the eval mode of policy, including related attributes and modules. For BC, it contains the \ + eval model to greedily select action with argmax q_value mechanism for discrete action space. + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. + """ if self._cfg.continuous: self._eval_model = model_wrap(self._model, wrapper_name='base') else: self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample') self._eval_model.reset() - def _forward_eval(self, data): - gfootball_flag = False + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ + Overview: + Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + """ tensor_input = isinstance(data, torch.Tensor) if tensor_input: data = default_collate(list(data)) else: data_id = list(data.keys()) - if data_id == ['processed_obs', 'raw_obs']: - # for gfootball - gfootball_flag = True - data = {0: data} - data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: data = to_device(data, self._device) @@ -213,22 +293,20 @@ def _forward_eval(self, data): output = self._eval_model.forward(data) if self._cuda: output = to_device(output, 'cpu') - if tensor_input or gfootball_flag: + if tensor_input: return output else: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} def _init_collect(self) -> None: - r""" + """ Overview: - Collect mode init method. Called by ``self.__init__``. - Init traj and unroll length, collect model. 
- Enable the eps_greedy_sample + BC policy uses offline dataset so it does not need to collect data. However, sometimes we need to use the \ + trained BC policy to collect data for other purposes. """ self._unroll_len = self._cfg.collect.unroll_len if self._cfg.continuous: - # self._collect_model = model_wrap(self._model, wrapper_name='base') self._collect_model = model_wrap( self._model, wrapper_name='action_noise', @@ -244,14 +322,6 @@ def _init_collect(self) -> None: self._collect_model.reset() def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: - r""" - Overview: - Forward function for collect mode with eps_greedy - Arguments: - - data (:obj:`dict`): Dict type data, including at least ['obs']. - Returns: - - data (:obj:`dict`): The collected data - """ data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: @@ -268,43 +338,16 @@ def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: - r""" - Overview: - Generate dict type transition data from inputs. - Arguments: - - obs (:obj:`Any`): Env observation - - model_output (:obj:`dict`): Output of collect model, including at least ['action'] - - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \ - (here 'obs' indicates obs after env step). - Returns: - - transition (:obj:`dict`): Dict type transition data. - """ + def _process_transition(self, obs: Any, policy_output: dict, timestep: namedtuple) -> dict: transition = { 'obs': obs, 'next_obs': timestep.obs, - 'action': model_output['action'], + 'action': policy_output['action'], 'reward': timestep.reward, 'done': timestep.done, } return EasyDict(transition) def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - Overview: - For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \ - can be used for training directly. A train sample can be a processed transition(DQN with nstep TD) \ - or some continuous transitions(DRQN). - Arguments: - - data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \ - format as the return value of ``self._process_transition`` method. - Returns: - - samples (:obj:`dict`): The list of training samples. - - .. note:: - We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. \ - And the user can customize the this data processing procecure by overriding this two methods and collector \ - itself. 
- """ data = get_nstep_return_data(data, 1, 1) return get_train_sample(data, self._unroll_len) diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 4fe7cab6d7..2d5e3271dd 100644 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -24,7 +24,7 @@ from .td3 import TD3Policy from .td3_vae import TD3VAEPolicy from .td3_bc import TD3BCPolicy -from .sac import SACPolicy, SACDiscretePolicy +from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy from .mbpolicy.mbsac import MBSACPolicy, STEVESACPolicy from .mbpolicy.dreamer import DREAMERPolicy from .qmix import QMIXPolicy @@ -42,10 +42,9 @@ from .r2d3 import R2D3Policy from .d4pg import D4PGPolicy -from .cql import CQLPolicy, CQLDiscretePolicy +from .cql import CQLPolicy, DiscreteCQLPolicy from .dt import DTPolicy from .pdqn import PDQNPolicy -from .sac import SQILSACPolicy from .madqn import MADQNPolicy from .bdq import BDQPolicy from .bcq import BCQPolicy @@ -316,8 +315,8 @@ class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy): pass -@POLICY_REGISTRY.register('cql_discrete_command') -class CQLDiscreteCommandModePolicy(CQLDiscretePolicy, EpsCommandModePolicy): +@POLICY_REGISTRY.register('discrete_cql_command') +class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy): pass @@ -376,8 +375,8 @@ class PDQNCommandModePolicy(PDQNPolicy, EpsCommandModePolicy): pass -@POLICY_REGISTRY.register('sac_discrete_command') -class SACDiscreteCommandModePolicy(SACDiscretePolicy, EpsCommandModePolicy): +@POLICY_REGISTRY.register('discrete_sac_command') +class DiscreteSACCommandModePolicy(DiscreteSACPolicy, EpsCommandModePolicy): pass diff --git a/ding/policy/common_utils.py b/ding/policy/common_utils.py index ac88f38c0d..fd2c7d3d61 100644 --- a/ding/policy/common_utils.py +++ b/ding/policy/common_utils.py @@ -1,4 +1,4 @@ -from typing import List, Any +from typing import List, Any, Dict, Callable import torch import numpy as np import treetensor.torch as ttorch @@ -12,10 +12,24 @@ def default_preprocess_learn( use_priority: bool = False, use_nstep: bool = False, ignore_done: bool = False, -) -> dict: +) -> Dict[str, torch.Tensor]: + """ + Overview: + Default data pre-processing in policy's ``_forward_learn`` method, including stacking batch data, preprocess \ + ignore done, nstep and priority IS weight. + Arguments: + - data (:obj:`List[Any]`): The list of a training batch samples, each sample is a dict of PyTorch Tensor. + - use_priority_IS_weight (:obj:`bool`): Whether to use priority IS weight correction, if True, this function \ + will set the weight of each sample to the priority IS weight. + - use_priority (:obj:`bool`): Whether to use priority, if True, this function will set the priority IS weight. + - use_nstep (:obj:`bool`): Whether to use nstep TD error, if True, this function will reshape the reward. + - ignore_done (:obj:`bool`): Whether to ignore done, if True, this function will set the done to 0. + Returns: + - data (:obj:`Dict[str, torch.Tensor]`): The preprocessed dict data whose values can be directly used for \ + the following model forward and loss computation. 
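# --- Editor's illustrative sketch (not part of this diff): what ``default_preprocess_learn``
# roughly does to a training batch. A list of per-sample dicts becomes one dict of stacked
# tensors; the shapes and keys below are assumptions for a discrete-action sample.
import torch
from ding.policy.common_utils import default_preprocess_learn

samples = [
    {
        'obs': torch.randn(4),
        'next_obs': torch.randn(4),
        'action': torch.tensor([1]),     # int64 -> collated with cat_1dim=True
        'reward': torch.tensor([0.5]),
        'done': torch.tensor(False),
    } for _ in range(8)
]
batch = default_preprocess_learn(samples)
print(batch['obs'].shape)     # torch.Size([8, 4]) -- stacked along the batch dimension
print(batch['action'].shape)  # 1-dim discrete actions are concatenated, not stacked
print(batch['weight'])        # None unless priority IS weights are enabled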
+ """ # data preprocess - if data[0]['action'].dtype in [np.int8, np.int16, np.int32, np.int64, torch.int8, torch.int16, torch.int32, - torch.int64]: + if data[0]['action'].dtype in [np.int64, torch.int64]: data = default_collate(data, cat_1dim=True) # for discrete action else: data = default_collate(data, cat_1dim=False) # for continuous action @@ -42,7 +56,7 @@ def default_preprocess_learn( else: data['weight'] = data.get('weight', None) if use_nstep: - # Reward reshaping for n-step + # reward reshaping for n-step reward = data['reward'] if len(reward.shape) == 1: reward = reward.unsqueeze(1) @@ -55,7 +69,23 @@ def default_preprocess_learn( return data -def single_env_forward_wrapper(forward_fn): +def single_env_forward_wrapper(forward_fn: Callable) -> Callable: + """ + Overview: + Wrap policy to support gym-style interaction between policy and single environment. + Arguments: + - forward_fn (:obj:`Callable`): The original forward function of policy. + Returns: + - wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy. + Examples: + >>> env = gym.make('CartPole-v0') + >>> policy = DQNPolicy(...) + >>> forward_fn = single_env_forward_wrapper(policy.eval_mode.forward) + >>> obs = env.reset() + >>> action = forward_fn(obs) + >>> next_obs, rew, done, info = env.step(action) + + """ def _forward(obs): obs = {0: unsqueeze(to_tensor(obs))} @@ -66,7 +96,24 @@ def _forward(obs): return _forward -def single_env_forward_wrapper_ttorch(forward_fn, cuda=True): +def single_env_forward_wrapper_ttorch(forward_fn: Callable, cuda: bool = True) -> Callable: + """ + Overview: + Wrap policy to support gym-style interaction between policy and single environment for treetensor (ttorch) data. + Arguments: + - forward_fn (:obj:`Callable`): The original forward function of policy. + - cuda (:obj:`bool`): Whether to use cuda in policy, if True, this function will move the input data to cuda. + Returns: + - wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy. + + Examples: + >>> env = gym.make('CartPole-v0') + >>> policy = PPOFPolicy(...) + >>> forward_fn = single_env_forward_wrapper_ttorch(policy.eval) + >>> obs = env.reset() + >>> action = forward_fn(obs) + >>> next_obs, rew, done, info = env.step(action) + """ def _forward(obs): # unsqueeze means add batch dim, i.e. (O, ) -> (1, O) diff --git a/ding/policy/cql.py b/ding/policy/cql.py index 8a850f24d7..0910ebe262 100644 --- a/ding/policy/cql.py +++ b/ding/policy/cql.py @@ -12,147 +12,118 @@ from ding.utils import POLICY_REGISTRY from ding.utils.data import default_collate, default_decollate from .sac import SACPolicy -from .dqn import DQNPolicy +from .qrdqn import QRDQNPolicy from .common_utils import default_preprocess_learn @POLICY_REGISTRY.register('cql') class CQLPolicy(SACPolicy): """ - Overview: - Policy class of CQL algorithm. - - Config: - == ==================== ======== ============= ================================= ======================= - ID Symbol Type Default Value Description Other(Shape) - == ==================== ======== ============= ================================= ======================= - 1 ``type`` str td3 | RL policy register name, refer | this arg is optional, - | to registry ``POLICY_REGISTRY`` | a placeholder - 2 ``cuda`` bool True | Whether to use cuda for network | - 3 | ``random_`` int 10000 | Number of randomly collected | Default to 10000 for - | ``collect_size`` | training samples in replay | SAC, 25000 for DDPG/ - | | buffer when training starts. | TD3. 
- 4 | ``model.policy_`` int 256 | Linear layer size for policy | - | ``embedding_size`` | network. | - 5 | ``model.soft_q_`` int 256 | Linear layer size for soft q | - | ``embedding_size`` | network. | - 6 | ``model.value_`` int 256 | Linear layer size for value | Defalut to None when - | ``embedding_size`` | network. | model.value_network - | | | is False. - 7 | ``learn.learning`` float 3e-4 | Learning rate for soft q | Defalut to 1e-3, when - | ``_rate_q`` | network. | model.value_network - | | | is True. - 8 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to 1e-3, when - | ``_rate_policy`` | network. | model.value_network - | | | is True. - 9 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to None when - | ``_rate_value`` | network. | model.value_network - | | | is False. - 10 | ``learn.alpha`` float 0.2 | Entropy regularization | alpha is initiali- - | | coefficient. | zation for auto - | | | `alpha`, when - | | | auto_alpha is True - 11 | ``learn.repara_`` bool True | Determine whether to use | - | ``meterization`` | reparameterization trick. | - 12 | ``learn.`` bool False | Determine whether to use | Temperature parameter - | ``auto_alpha`` | auto temperature parameter | determines the - | | `alpha`. | relative importance - | | | of the entropy term - | | | against the reward. - 13 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only - | ``ignore_done`` | done flag. | in halfcheetah env. - 14 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation - | ``target_theta`` | target network. | factor in polyak aver - | | | aging for target - | | | networks. - == ==================== ======== ============= ================================= ======================= - """ + Overview: + Policy class of CQL algorithm for continuous control. Paper link: https://arxiv.org/abs/2006.04779. + + Config: + == ==================== ======== ============= ================================= ======================= + ID Symbol Type Default Value Description Other(Shape) + == ==================== ======== ============= ================================= ======================= + 1 ``type`` str cql | RL policy register name, refer | this arg is optional, + | to registry ``POLICY_REGISTRY`` | a placeholder + 2 ``cuda`` bool True | Whether to use cuda for network | + 3 | ``random_`` int 10000 | Number of randomly collected | Default to 10000 for + | ``collect_size`` | training samples in replay | SAC, 25000 for DDPG/ + | | buffer when training starts. | TD3. + 4 | ``model.policy_`` int 256 | Linear layer size for policy | + | ``embedding_size`` | network. | + 5 | ``model.soft_q_`` int 256 | Linear layer size for soft q | + | ``embedding_size`` | network. | + 6 | ``model.value_`` int 256 | Linear layer size for value | Defalut to None when + | ``embedding_size`` | network. | model.value_network + | | | is False. + 7 | ``learn.learning`` float 3e-4 | Learning rate for soft q | Defalut to 1e-3, when + | ``_rate_q`` | network. | model.value_network + | | | is True. + 8 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to 1e-3, when + | ``_rate_policy`` | network. | model.value_network + | | | is True. + 9 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to None when + | ``_rate_value`` | network. | model.value_network + | | | is False. + 10 | ``learn.alpha`` float 0.2 | Entropy regularization | alpha is initiali- + | | coefficient. 
| zation for auto + | | | `alpha`, when + | | | auto_alpha is True + 11 | ``learn.repara_`` bool True | Determine whether to use | + | ``meterization`` | reparameterization trick. | + 12 | ``learn.`` bool False | Determine whether to use | Temperature parameter + | ``auto_alpha`` | auto temperature parameter | determines the + | | `alpha`. | relative importance + | | | of the entropy term + | | | against the reward. + 13 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only + | ``ignore_done`` | done flag. | in halfcheetah env. + 14 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation + | ``target_theta`` | target network. | factor in polyak aver + | | | aging for target + | | | networks. + == ==================== ======== ============= ================================= ======================= + """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). - type='sac', - # (bool) Whether to use cuda for network. + type='cql', + # (bool) Whether to use cuda for policy. cuda=False, - # (bool type) on_policy: Determine whether on-policy or off-policy. + # (bool) on_policy: Determine whether on-policy or off-policy. # on-policy setting influences the behaviour of buffer. - # Default False in SAC. on_policy=False, - multi_agent=False, - # (bool type) priority: Determine whether to use priority in buffer sample. - # Default False in SAC. + # (bool) priority: Determine whether to use priority in buffer sample. priority=False, # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. priority_IS_weight=False, # (int) Number of training samples(randomly collected) in replay buffer when training starts. - # Default 10000 in SAC. random_collect_size=10000, model=dict( # (bool type) twin_critic: Determine whether to use double-soft-q-net for target q computation. # Please refer to TD3 about Clipped Double-Q Learning trick, which learns two Q-functions instead of one . # Default to True. twin_critic=True, - - # (bool type) value_network: Determine whether to use value network as the - # original SAC paper (arXiv 1801.01290). - # using value_network needs to set learning_rate_value, learning_rate_q, - # and learning_rate_policy in `cfg.policy.learn`. - # Default to False. - # value_network=False, - # (str type) action_space: Use reparameterization trick for continous action action_space='reparameterization', - # (int) Hidden size for actor network head. actor_head_hidden_size=256, - # (int) Hidden size for critic network head. critic_head_hidden_size=256, ), + # learn_mode config learn=dict( - - # How many updates(iterations) to train after collector's one collection. + # (int) How many updates (iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... update_per_collect=1, # (int) Minibatch size for gradient descent. batch_size=256, - - # (float type) learning_rate_q: Learning rate for soft q network. - # Default to 3e-4. - # Please set to 1e-3, when model.value_network is True. + # (float) learning_rate_q: Learning rate for soft q network. learning_rate_q=3e-4, - # (float type) learning_rate_policy: Learning rate for policy network. - # Default to 3e-4. - # Please set to 1e-3, when model.value_network is True. + # (float) learning_rate_policy: Learning rate for policy network. learning_rate_policy=3e-4, - # (float type) learning_rate_value: Learning rate for value network. 
- # `learning_rate_value` should be initialized, when model.value_network is True. - # Please set to 3e-4, when model.value_network is True. - learning_rate_value=3e-4, - - # (float type) learning_rate_alpha: Learning rate for auto temperature parameter `\alpha`. - # Default to 3e-4. + # (float) learning_rate_alpha: Learning rate for auto temperature parameter ``alpha``. learning_rate_alpha=3e-4, - # (float type) target_theta: Used for soft update of the target network, + # (float) target_theta: Used for soft update of the target network, # aka. Interpolation factor in polyak averaging for target networks. - # Default to 0.005. target_theta=0.005, # (float) discount factor for the discounted sum of rewards, aka. gamma. discount_factor=0.99, - - # (float type) alpha: Entropy regularization coefficient. + # (float) alpha: Entropy regularization coefficient. # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details. # If auto_alpha is set to `True`, alpha is initialization for auto `\alpha`. # Default to 0.2. alpha=0.2, - - # (bool type) auto_alpha: Determine whether to use auto temperature parameter `\alpha` . + # (bool) auto_alpha: Determine whether to use auto temperature parameter `\alpha` . # Temperature parameter determines the relative importance of the entropy term against the reward. # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details. # Default to False. # Note that: Using auto alpha needs to set learning_rate_alpha in `cfg.policy.learn`. auto_alpha=True, - # (bool type) log_space: Determine whether to use auto `\alpha` in log space. + # (bool) log_space: Determine whether to use auto `\alpha` in log space. log_space=True, # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum) # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers. @@ -162,46 +133,44 @@ class CQLPolicy(SACPolicy): # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``), # when the episode step is greater than max episode step. ignore_done=False, - # (float) Weight uniform initialization range in the last output layer + # (float) Weight uniform initialization range in the last output layer. init_w=3e-3, - # (int) The numbers of action sample each at every state s from a uniform-at-random + # (int) The numbers of action sample each at every state s from a uniform-at-random. num_actions=10, # (bool) Whether use lagrange multiplier in q value loss. with_lagrange=False, - # (float) The threshold for difference in Q-values + # (float) The threshold for difference in Q-values. lagrange_thresh=-1, # (float) Loss weight for conservative item. min_q_weight=1.0, # (bool) Whether to use entropy in target q. with_q_entropy=False, ), - collect=dict( - # (int) Cut trajectories into pieces with length "unroll_len". - unroll_len=1, - ), - eval=dict(), - other=dict( - replay_buffer=dict( - # (int type) replay_buffer_size: Max size of replay buffer. - replay_buffer_size=1000000, - # (int type) max_use: Max use times of one data in the buffer. - # Data will be removed once used for too many times. - # Default to infinite. - # max_use=256, - ), - ), + eval=dict(), # for compatibility ) def _init_learn(self) -> None: - r""" + """ Overview: - Learn mode init method. Called by ``self.__init__``. - Init q, value and policy's optimizers, algorithm config, main and target models. + Initialize the learn mode of policy, including related attributes and modules. 
For SAC, it mainly \ + contains three optimizers, algorithm-specific arguments such as gamma, min_q_weight, with_lagrange and \ + with_q_entropy, main and target model. Especially, the ``auto_alpha`` mechanism for balancing max entropy \ + target is also initialized here. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ - # Init self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight - self._value_network = False self._twin_critic = self._cfg.model.twin_critic self._num_actions = self._cfg.learn.num_actions @@ -235,11 +204,6 @@ def _init_learn(self) -> None: self._model.critic_head[-1].last.bias.data.uniform_(-init_w, init_w) # Optimizers - if self._value_network: - self._optimizer_value = Adam( - self._model.value_critic.parameters(), - lr=self._cfg.learn.learning_rate_value, - ) self._optimizer_q = Adam( self._model.critic.parameters(), lr=self._cfg.learn.learning_rate_q, @@ -291,14 +255,30 @@ def _init_learn(self) -> None: self._forward_learn_cnt = 0 - def _forward_learn(self, data: dict) -> Dict[str, Any]: - r""" + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ Overview: - Forward and backward function of learn mode. + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, action, priority. Arguments: - - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs'] + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For CQL, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``. Returns: - - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. 
For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. """ loss_dict = {} data = default_preprocess_learn( @@ -325,37 +305,29 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: q_value = self._learn_model.forward(data, mode='compute_critic')['q_value'] # 2. predict target value - if self._value_network: - # predict v value - v_value = self._learn_model.forward(obs, mode='compute_value_critic')['v_value'] - with torch.no_grad(): - next_v_value = self._target_model.forward(next_obs, mode='compute_value_critic')['v_value'] - target_q_value = next_v_value - else: - # target q value. - with torch.no_grad(): - (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit'] - - dist = Independent(Normal(mu, sigma), 1) - pred = dist.rsample() - next_action = torch.tanh(pred) - y = 1 - next_action.pow(2) + 1e-6 - next_log_prob = dist.log_prob(pred).unsqueeze(-1) - next_log_prob = next_log_prob - torch.log(y).sum(-1, keepdim=True) - - next_data = {'obs': next_obs, 'action': next_action} - target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value'] - # the value of a policy according to the maximum entropy objective - if self._twin_critic: - # find min one as target q value - if self._with_q_entropy: - target_q_value = torch.min(target_q_value[0], - target_q_value[1]) - self._alpha * next_log_prob.squeeze(-1) - else: - target_q_value = torch.min(target_q_value[0], target_q_value[1]) + with torch.no_grad(): + (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit'] + + dist = Independent(Normal(mu, sigma), 1) + pred = dist.rsample() + next_action = torch.tanh(pred) + y = 1 - next_action.pow(2) + 1e-6 + next_log_prob = dist.log_prob(pred).unsqueeze(-1) + next_log_prob = next_log_prob - torch.log(y).sum(-1, keepdim=True) + + next_data = {'obs': next_obs, 'action': next_action} + target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value'] + # the value of a policy according to the maximum entropy objective + if self._twin_critic: + # find min one as target q value + if self._with_q_entropy: + target_q_value = torch.min(target_q_value[0], + target_q_value[1]) - self._alpha * next_log_prob.squeeze(-1) else: - if self._with_q_entropy: - target_q_value = target_q_value - self._alpha * next_log_prob.squeeze(-1) + target_q_value = torch.min(target_q_value[0], target_q_value[1]) + else: + if self._with_q_entropy: + target_q_value = target_q_value - self._alpha * next_log_prob.squeeze(-1) # 3. compute q loss if self._twin_critic: @@ -450,20 +422,6 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: if self._twin_critic: new_q_value = torch.min(new_q_value[0], new_q_value[1]) - # 7. (optional)compute value loss - if self._value_network: - # new_q_value: (bs, ), log_prob: (bs, act_shape) -> target_v_value: (bs, ) - if self._with_q_entropy: - target_v_value = (new_q_value.unsqueeze(-1) - self._alpha * log_prob).mean(dim=-1) - else: - target_v_value = new_q_value.unsqueeze(-1).mean(dim=-1) - loss_dict['value_loss'] = F.mse_loss(v_value, target_v_value.detach()) - - # update value network - self._optimizer_value.zero_grad() - loss_dict['value_loss'].backward() - self._optimizer_value.step() - # 8. 
compute policy loss policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean() @@ -511,8 +469,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: **loss_dict } - def _get_policy_actions(self, data: Dict, num_actions=10, epsilon: float = 1e-6) -> List: - + def _get_policy_actions(self, data: Dict, num_actions: int = 10, epsilon: float = 1e-6) -> List: # evaluate to get action distribution obs = data['obs'] obs = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1]) @@ -528,7 +485,7 @@ def _get_policy_actions(self, data: Dict, num_actions=10, epsilon: float = 1e-6) return action, log_prob.view(-1, num_actions, 1) - def _get_q_value(self, data: Dict, keep=True) -> torch.Tensor: + def _get_q_value(self, data: Dict, keep: bool = True) -> torch.Tensor: new_q_value = self._learn_model.forward(data, mode='compute_critic')['q_value'] if self._twin_critic: new_q_value = [value.view(-1, self._num_actions, 1) for value in new_q_value] @@ -539,17 +496,18 @@ def _get_q_value(self, data: Dict, keep=True) -> torch.Tensor: return new_q_value -@POLICY_REGISTRY.register('cql_discrete') -class CQLDiscretePolicy(DQNPolicy): +@POLICY_REGISTRY.register('discrete_cql') +class DiscreteCQLPolicy(QRDQNPolicy): """ - Overview: - Policy class of CQL algorithm in discrete environments. + Overview: + Policy class of discrete CQL algorithm in discrete action space environments. + Paper link: https://arxiv.org/abs/2006.04779. """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). - type='cql_discrete', - # (bool) Whether to use cuda for network. + type='discrete_cql', + # (bool) Whether to use cuda for policy. cuda=False, # (bool) Whether the RL algorithm is on-policy or off-policy. on_policy=False, @@ -559,53 +517,43 @@ class CQLDiscretePolicy(DQNPolicy): discount_factor=0.97, # (int) N-step reward for target q_value estimation nstep=1, + # learn_mode config learn=dict( - # How many updates(iterations) to train after collector's one collection. + # (int) How many updates (iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... update_per_collect=1, + # (int) Minibatch size for one gradient descent. batch_size=64, + # (float) Learning rate for soft q network. learning_rate=0.001, - # ============================================================== - # The following configs are algorithm-specific - # ============================================================== # (int) Frequence of target network update. target_update_freq=100, - # (bool) Whether ignore done(usually for max step termination env) + # (bool) Whether ignore done(usually for max step termination env). ignore_done=False, # (float) Loss weight for conservative item. min_q_weight=1.0, ), - # collect_mode config - collect=dict( - # (int) Cut trajectories into pieces with length "unroll_len". - unroll_len=1, - ), - eval=dict(), - # other config - other=dict( - # Epsilon greedy with decay. - eps=dict( - # (str) Decay type. Support ['exp', 'linear']. - type='exp', - start=0.95, - end=0.1, - # (int) Decay length(env step) - decay=10000, - ), - replay_buffer=dict(replay_buffer_size=10000, ) - ), + eval=dict(), # for compatibility ) - def default_model(self) -> Tuple[str, List[str]]: - return 'qrdqn', ['ding.model.template.q_learning'] - def _init_learn(self) -> None: - r""" - Overview: - Learn mode init method. Called by ``self.__init__``. 
- Init the optimizer, algorithm config, main and target models. - """ + """ + Overview: + Initialize the learn mode of policy, including related attributes and modules. For DiscreteCQL, it mainly \ + contains the optimizer, algorithm-specific arguments such as gamma, nstep and min_q_weight, main and \ + target model. This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. + """ self._min_q_weight = self._cfg.learn.min_q_weight self._priority = self._cfg.priority # Optimizer @@ -626,15 +574,32 @@ def _init_learn(self) -> None: self._learn_model.reset() self._target_model.reset() - def _forward_learn(self, data: dict) -> Dict[str, Any]: - r""" - Overview: - Forward and backward function of learn mode. - Arguments: - - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs'] - Returns: - - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. - """ + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, action, priority. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For DiscreteCQL, each element in list is a dict containing at least the following keys: ``obs``, \ + ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight`` \ + and ``value_gamma`` for nstep return computation. + Returns: + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. 
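+
+        Examples:
+            The following is a minimal, hedged sketch (``cfg``, ``model`` and the tensor shapes are \
+            illustrative placeholders, not values taken from this repo's tests):
+            >>> policy = DiscreteCQLPolicy(cfg, model=model, enable_field=['learn'])
+            >>> transition = {'obs': torch.randn(4), 'next_obs': torch.randn(4), 'action': torch.tensor(0),
+            >>>                'reward': torch.ones(1), 'done': False}
+            >>> info = policy._forward_learn([transition for _ in range(64)])
+            >>> # info contains logged scalars such as 'total_loss' and 'cur_lr'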
+ """ data = default_preprocess_learn( data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True ) @@ -701,70 +666,12 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: # '[histogram]action_distribution': data['action'], } - def _state_dict_learn(self) -> Dict[str, Any]: - return { - 'model': self._learn_model.state_dict(), - 'target_model': self._target_model.state_dict(), - 'optimizer': self._optimizer.state_dict(), - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - self._learn_model.load_state_dict(state_dict['model']) - self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer.load_state_dict(state_dict['optimizer']) - - def _init_collect(self) -> None: - r""" - Overview: - Collect mode init method. Called by ``self.__init__``. - Init traj and unroll length, collect model. - Enable the eps_greedy_sample - """ - self._unroll_len = self._cfg.collect.unroll_len - self._gamma = self._cfg.discount_factor # necessary for parallel - self._nstep = self._cfg.nstep # necessary for parallel - self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample') - self._collect_model.reset() - - def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: + def _monitor_vars_learn(self) -> List[str]: """ Overview: - Forward computation graph of collect mode(collect training data), with eps_greedy for exploration. - Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. - - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step. + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting policy_output(action) for the interaction with \ - env and the constructing of transition. - ArgumentsKeys: - - necessary: ``obs`` - ReturnsKeys - - necessary: ``logit``, ``action`` + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. """ - data_id = list(data.keys()) - data = default_collate(list(data.values())) - if self._cuda: - data = to_device(data, self._device) - self._collect_model.eval() - with torch.no_grad(): - output = self._collect_model.forward(data, eps=eps) - if self._cuda: - output = to_device(output, 'cpu') - output = default_decollate(output) - return {i: d for i, d in zip(data_id, output)} - - def _get_train_sample(self, data: list) -> Union[None, List[Any]]: - r""" - Overview: - Get the trajectory and the n step return data, then sample from the n_step return data - Arguments: - - data (:obj:`list`): The trajectory's cache - Returns: - - samples (:obj:`dict`): The training samples generated - """ - data = get_nstep_return_data(data, self._nstep, gamma=self._gamma) - return get_train_sample(data, self._unroll_len) - - def _monitor_vars_learn(self) -> List[str]: return ['cur_lr', 'total_loss', 'q_target', 'q_value'] diff --git a/ding/policy/ddpg.py b/ding/policy/ddpg.py index 8629cca4af..2e253370b8 100644 --- a/ding/policy/ddpg.py +++ b/ding/policy/ddpg.py @@ -14,17 +14,11 @@ @POLICY_REGISTRY.register('ddpg') class DDPGPolicy(Policy): - r""" + """ Overview: - Policy class of DDPG algorithm. 
- - https://arxiv.org/pdf/1509.02971.pdf - - Property: - learn_mode, collect_mode, eval_mode + Policy class of DDPG algorithm. Paper link: https://arxiv.org/abs/1509.02971. Config: - == ==================== ======== ============= ================================= ======================= ID Symbol Type Default Value Description Other(Shape) == ==================== ======== ============= ================================= ======================= @@ -68,30 +62,28 @@ class DDPGPolicy(Policy): config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). type='ddpg', - # (bool) Whether to use cuda for network. + # (bool) Whether to use cuda in policy. cuda=False, - # (bool type) on_policy: Determine whether on-policy or off-policy. - # on-policy setting influences the behaviour of buffer. - # Default False in DDPG. + # (bool) Whether learning policy is the same as collecting data policy(on-policy). Default False in DDPG. on_policy=False, - # (bool) Whether use priority(priority sample, IS weight, update priority) - # Default False in DDPG. + # (bool) Whether to enable priority experience sample. priority=False, # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. priority_IS_weight=False, # (int) Number of training samples(randomly collected) in replay buffer when training starts. # Default 25000 in DDPG/TD3. random_collect_size=25000, - # (bool) Whether to need policy data in process transition + # (bool) Whether to need policy data in process transition. transition_with_policy_data=False, - # (str) Action space type - action_space='continuous', # ['continuous', 'hybrid'] - # (bool) Whether use batch normalization for reward + # (str) Action space type, including ['continuous', 'hybrid']. + action_space='continuous', + # (bool) Whether use batch normalization for reward. reward_batch_norm=False, - # (bool) Whether to enable multi-agent training setting + # (bool) Whether to enable multi-agent training setting. multi_agent=False, + # learn_mode config learn=dict( - # How many updates(iterations) to train after collector's one collection. + # (int) How many updates(iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. # collect data -> update policy-> collect data -> ... update_per_collect=1, @@ -109,7 +101,7 @@ class DDPGPolicy(Policy): # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``), # when the episode step is greater than max episode step. ignore_done=False, - # (float type) target_theta: Used for soft update of the target network, + # (float) target_theta: Used for soft update of the target network, # aka. Interpolation factor in polyak averaging for target networks. # Default to 0.005. target_theta=0.005, @@ -124,39 +116,55 @@ class DDPGPolicy(Policy): # Default True for TD3, False for DDPG. noise=False, ), + # collect_mode config collect=dict( - # (int) Only one of [n_sample, n_episode] shoule be set + # (int) How many training samples collected in one collection procedure. + # Only one of [n_sample, n_episode] shoule be set. # n_sample=1, - # (int) Cut trajectories into pieces with length "unroll_len". + # (int) Split episodes or trajectories into pieces with length `unroll_len`. unroll_len=1, # (float) It is a must to add noise during collection. So here omits "noise" and only set "noise_sigma". noise_sigma=0.1, ), - eval=dict( - evaluator=dict( - # (int) Evaluate every "eval_freq" training iterations. 
- eval_freq=5000, - ), - ), + eval=dict(), # for compability other=dict( replay_buffer=dict( - # (int) Maximum size of replay buffer. + # (int) Maximum size of replay buffer. Usually, larger buffer size is better. replay_buffer_size=100000, ), ), ) def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + """ if self._cfg.multi_agent: return 'continuous_maqac', ['ding.model.template.maqac'] else: return 'continuous_qac', ['ding.model.template.qac'] def _init_learn(self) -> None: - r""" + """ Overview: - Learn mode init method. Called by ``self.__init__``. - Init actor and critic optimizers, algorithm config, main and target models. + Initialize the learn mode of policy, including related attributes and modules. For DDPG, it mainly \ + contains two optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target model. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight @@ -177,7 +185,9 @@ def _init_learn(self) -> None: # main and target models self._target_model = copy.deepcopy(self._model) + self._learn_model = model_wrap(self._model, wrapper_name='base') if self._cfg.action_space == 'hybrid': + self._learn_model = model_wrap(self._learn_model, wrapper_name='hybrid_argmax_sample') self._target_model = model_wrap(self._target_model, wrapper_name='hybrid_argmax_sample') self._target_model = model_wrap( self._target_model, @@ -196,22 +206,39 @@ def _init_learn(self) -> None: }, noise_range=self._cfg.learn.noise_range ) - self._learn_model = model_wrap(self._model, wrapper_name='base') - if self._cfg.action_space == 'hybrid': - self._learn_model = model_wrap(self._learn_model, wrapper_name='hybrid_argmax_sample') self._learn_model.reset() self._target_model.reset() self._forward_learn_cnt = 0 # count iterations - def _forward_learn(self, data: dict) -> Dict[str, Any]: - r""" + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ Overview: - Forward and backward function of learn mode. + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, action, priority. Arguments: - - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs'] + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. 
Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For DDPG, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \ + and ``logit`` which is used for hybrid action space. Returns: - - info_dict (:obj:`Dict[str, Any]`): Including at least actor and critic lr, different losses. + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for DDPGPolicy: ``ding.policy.tests.test_ddpg``. """ loss_dict = {} data = default_preprocess_learn( @@ -314,6 +341,12 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: } def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizers. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ return { 'model': self._learn_model.state_dict(), 'target_model': self._target_model.state_dict(), @@ -322,16 +355,33 @@ def _state_dict_learn(self) -> Dict[str, Any]: } def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + + .. tip:: + If you want to only load some parts of model, you can simply set the ``strict`` argument in \ + load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ + complicated operation. + """ self._learn_model.load_state_dict(state_dict['model']) self._target_model.load_state_dict(state_dict['target_model']) self._optimizer_actor.load_state_dict(state_dict['optimizer_actor']) self._optimizer_critic.load_state_dict(state_dict['optimizer_critic']) def _init_collect(self) -> None: - r""" + """ Overview: - Collect mode init method. Called by ``self.__init__``. - Init traj and unroll length, collect model. + Initialize the collect mode of policy, including related attributes and modules. For DDPG, it contains the \ + collect_model to balance the exploration and exploitation with the perturbed noise mechanism, and other \ + algorithm-specific arguments such as unroll_len. \ + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. 
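+
+        Examples:
+            A hedged usage sketch (``cfg`` and ``model`` are placeholders, and the observation shape is \
+            illustrative; real entry points usually go through a collector rather than calling forward directly):
+            >>> policy = DDPGPolicy(cfg, model=model, enable_field=['collect'])
+            >>> obs = {0: torch.randn(17)}  # env_id -> observation
+            >>> output = policy.collect_mode.forward(obs)
+            >>> action = output[0]['action']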
""" self._unroll_len = self._cfg.collect.unroll_len # collect model @@ -349,18 +399,28 @@ def _init_collect(self) -> None: self._collect_model = model_wrap(self._collect_model, wrapper_name='hybrid_eps_greedy_multinomial_sample') self._collect_model.reset() - def _forward_collect(self, data: dict, **kwargs) -> dict: - r""" + def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: + """ Overview: - Forward function of collect mode. + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs. - ReturnsKeys - - necessary: ``action`` - - optional: ``logit`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ + dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for DDPGPolicy: ``ding.policy.tests.test_ddpg``. """ data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -374,55 +434,84 @@ def _forward_collect(self, data: dict, **kwargs) -> dict: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> Dict[str, Any]: - r""" + def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], + timestep: namedtuple) -> Dict[str, torch.Tensor]: + """ Overview: - Generate dict type transition data from inputs. + Process and pack one timestep transition data into a dict, which can be directly used for training and \ + saved in replay buffer. For DDPG, it contains obs, next_obs, action, reward, done. Arguments: - - obs (:obj:`Any`): Env observation - - model_output (:obj:`dict`): Output of collect model, including at least ['action'] - - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \ - (here 'obs' indicates obs after env step, i.e. next_obs). - Return: - - transition (:obj:`Dict[str, Any]`): Dict type transition data. + - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. + - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. 
For DDPG, it contains the action and the logit of the action (in hybrid action space). + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. + Returns: + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. """ transition = { 'obs': obs, 'next_obs': timestep.obs, - 'action': model_output['action'], + 'action': policy_output['action'], 'reward': timestep.reward, 'done': timestep.done, } if self._cfg.action_space == 'hybrid': - transition['logit'] = model_output['logit'] + transition['logit'] = policy_output['logit'] return transition - def _get_train_sample(self, data: list) -> Union[None, List[Any]]: - return get_train_sample(data, self._unroll_len) + def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. In DDPG, a train sample is a processed transition (unroll_len=1). + Arguments: + - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ + the same format as the return value of ``self._process_transition`` method. + Returns: + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \ + as input transitions, but may contain more data for training. + """ + return get_train_sample(transitions, self._unroll_len) def _init_eval(self) -> None: - r""" + """ Overview: - Evaluate mode init method. Called by ``self.__init__``. - Init eval model. Unlike learn and collect model, eval model does not need noise. + Initialize the eval mode of policy, including related attributes and modules. For DDPG, it contains the \ + eval model to greedily select action type with argmax q_value mechanism for hybrid action space. \ + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. """ self._eval_model = model_wrap(self._model, wrapper_name='base') if self._cfg.action_space == 'hybrid': self._eval_model = model_wrap(self._eval_model, wrapper_name='hybrid_argmax_sample') self._eval_model.reset() - def _forward_eval(self, data: dict) -> dict: - r""" + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ Overview: - Forward function of eval mode, similar to ``self._forward_collect``. + Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. 
Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env. - ReturnsKeys - - necessary: ``action`` - - optional: ``logit`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for DDPGPolicy: ``ding.policy.tests.test_ddpg``. """ data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -437,11 +526,12 @@ def _forward_eval(self, data: dict) -> dict: return {i: d for i, d in zip(data_id, output)} def _monitor_vars_learn(self) -> List[str]: - r""" + """ Overview: - Return variables' names if variables are to used in monitor. + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. Returns: - - vars (:obj:`List[str]`): Variables' name list. + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. """ ret = [ 'cur_lr_actor', 'cur_lr_critic', 'critic_loss', 'actor_loss', 'total_loss', 'q_value', 'q_value_twin', diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py index 7adb614818..d1f6fdbb49 100644 --- a/ding/policy/dqn.py +++ b/ding/policy/dqn.py @@ -85,22 +85,23 @@ class DQNPolicy(Policy): config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). type='dqn', - # (bool) Whether use cuda in policy. + # (bool) Whether to use cuda in policy. cuda=False, # (bool) Whether learning policy is the same as collecting data policy(on-policy). on_policy=False, - # (bool) Whether enable priority experience sample. + # (bool) Whether to enable priority experience sample. priority=False, - # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. + # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True. priority_IS_weight=False, # (float) Discount factor(gamma) for returns. discount_factor=0.97, # (int) The number of step for calculating target q_value. nstep=1, model=dict( - #(list(int)) Sequence of ``hidden_size`` of subsequent conv layers and the final dense layer. + # (list(int)) Sequence of ``hidden_size`` of subsequent conv layers and the final dense layer. encoder_hidden_size_list=[128, 128, 64], ), + # learn_mode config learn=dict( # (int) How many updates(iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. @@ -134,7 +135,7 @@ class DQNPolicy(Policy): # (int) Split episodes or trajectories into pieces with length `unroll_len`. unroll_len=1, ), - eval=dict(), + eval=dict(), # for compability # other config other=dict( # Epsilon greedy with decay. @@ -149,7 +150,7 @@ class DQNPolicy(Policy): decay=10000, ), replay_buffer=dict( - # (int) Maximum size of replay buffer. Usually, larger buffer size is good. + # (int) Maximum size of replay buffer. Usually, larger buffer size is better. 
replay_buffer_size=10000, ), ), @@ -158,21 +159,35 @@ class DQNPolicy(Policy): def default_model(self) -> Tuple[str, List[str]]: """ Overview: - Return this algorithm default model setting for demonstration. + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. Returns: - - model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. .. note:: The user can define and use customized network model but must obey the same inferface definition indicated \ - by import_names path. For DQN, ``ding.model.template.q_learning.DQN`` + by import_names path. For example about DQN, its registered name is ``dqn`` and the import_names is \ + ``ding.model.template.q_learning``. """ return 'dqn', ['ding.model.template.q_learning'] def _init_learn(self) -> None: """ Overview: - Learn mode init method. Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \ - and target model. + Initialize the learn mode of policy, including related attributes and modules. For DQN, it mainly contains \ + optimizer, algorithm-specific arguments such as nstep and gamma, main and target model. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight @@ -204,23 +219,36 @@ def _init_learn(self) -> None: self._learn_model.reset() self._target_model.reset() - def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: """ Overview: - Forward computation graph of learn mode(updating policy). + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, q value, priority. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \ - np.ndarray or dict/list combinations. + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For DQN, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \ + and ``value_gamma``. 
Returns: - - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \ - recorded in text log and tensorboard, values are python scalar or a list of scalars. - ArgumentsKeys: - - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done`` - - optional: ``value_gamma`` - ReturnsKeys: - - necessary: ``cur_lr``, ``total_loss``, ``priority`` - - optional: ``action_distribution`` + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``. """ + # Data preprocessing operations, such as stack data, cpu to cuda device data = default_preprocess_learn( data, use_priority=self._priority, @@ -230,9 +258,7 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: ) if self._cuda: data = to_device(data, self._device) - # ==================== # Q-learning forward - # ==================== self._learn_model.train() self._target_model.train() # Current q value (main model) @@ -249,18 +275,14 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: value_gamma = data.get('value_gamma') loss, td_error_per_sample = q_nstep_td_error(data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma) - # ==================== - # Q-learning update - # ==================== + # Update network parameters self._optimizer.zero_grad() loss.backward() if self._cfg.multi_gpu: self.sync_gradients(self._learn_model) self._optimizer.step() - # ============= - # after update - # ============= + # Postprocessing operations, such as updating target model, return logged values and priority. self._target_model.update(self._learn_model.state_dict()) return { 'cur_lr': self._optimizer.defaults['lr'], @@ -273,14 +295,21 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: } def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ return ['cur_lr', 'total_loss', 'q_value', 'target_q_value'] def _state_dict_learn(self) -> Dict[str, Any]: """ Overview: - Return the state_dict of learn mode, usually including model and optimizer. + Return the state_dict of learn mode, usually including model, target_model and optimizer. Returns: - - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. 
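+
+        Examples:
+            A hedged checkpointing sketch using the symmetric ``_load_state_dict_learn`` method \
+            (the file path is a placeholder):
+            >>> ckpt = policy._state_dict_learn()
+            >>> torch.save(ckpt, 'dqn_iteration_1000.pth.tar')
+            >>> policy._load_state_dict_learn(torch.load('dqn_iteration_1000.pth.tar'))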
""" return { 'model': self._learn_model.state_dict(), @@ -293,7 +322,7 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: Overview: Load the state_dict variable into policy learn mode. Arguments: - - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. .. tip:: If you want to only load some parts of model, you can simply set the ``strict`` argument in \ @@ -307,8 +336,18 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: def _init_collect(self) -> None: """ Overview: - Collect mode init method. Called by ``self.__init__``, initialize algorithm arguments and collect_model, \ - enable the eps_greedy_sample for exploration. + Initialize the collect mode of policy, including related attributes and modules. For DQN, it contains the \ + collect_model to balance the exploration and exploitation with epsilon-greedy sample mechanism, and other \ + algorithm-specific arguments such as unroll_len and nstep. + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. + + .. tip:: + Some variables need to initialize independently in different modes, such as gamma and nstep in DQN. This \ + design is for the convenience of parallel execution of different policy modes. """ self._unroll_len = self._cfg.collect.unroll_len self._gamma = self._cfg.discount_factor # necessary for parallel @@ -319,18 +358,27 @@ def _init_collect(self) -> None: def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: """ Overview: - Forward computation graph of collect mode(collect training data), with eps_greedy for exploration. + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \ + exploration, i.e., classic epsilon-greedy exploration strategy. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. - - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + - eps (:obj:`float`): The epsilon value for exploration. Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting policy_output(action) for the interaction with \ - env and the constructing of transition. - ArgumentsKeys: - - necessary: ``obs`` - ReturnsKeys - - necessary: ``logit``, ``action`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ + dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. 
\ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``. """ data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -344,38 +392,40 @@ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Overview: - For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \ - can be used for training directly. A train sample can be a processed transition(DQN with nstep TD) \ - or some continuous transitions(DRQN). + For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. In DQN with nstep TD, a train sample is a processed transition. \ + This method is usually used in collectors to execute necessary \ + RL data preprocessing before training, which can help learner amortize revelant time consumption. \ + In addition, you can also implement this method as an identity function and do the data processing \ + in ``self._forward_learn`` method. Arguments: - - data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \ - format as the return value of ``self._process_transition`` method. + - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ + in the same format as the return value of ``self._process_transition`` method. Returns: - - samples (:obj:`dict`): The list of training samples. - - .. note:: - We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. \ - And the user can customize the this data processing procecure by overriding this two methods and collector \ - itself. + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is similar in format \ + to input transitions, but may contain more data for training, such as nstep reward and target obs. """ - data = get_nstep_return_data(data, self._nstep, gamma=self._gamma) - return get_train_sample(data, self._unroll_len) + transitions = get_nstep_return_data(transitions, self._nstep, gamma=self._gamma) + return get_train_sample(transitions, self._unroll_len) - def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]: + def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], + timestep: namedtuple) -> Dict[str, torch.Tensor]: """ Overview: - Generate a transition(e.g.: ) for this algorithm training. + Process and pack one timestep transition data into a dict, which can be directly used for training and \ + saved in replay buffer. For DQN, it contains obs, next_obs, action, reward, done. Arguments: - - obs (:obj:`Any`): Env observation. - - policy_output (:obj:`Dict[str, Any]`): The output of policy collect mode(``self._forward_collect``),\ - including at least ``action``. 
- - timestep (:obj:`namedtuple`): The output after env step(execute policy output action), including at \ - least ``obs``, ``reward``, ``done``, (here obs indicates obs after env step). + - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. + - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. For DQN, it contains the action and the logit (q_value) of the action. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. Returns: - - transition (:obj:`dict`): Dict type transition data. + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. """ transition = { 'obs': obs, @@ -387,9 +437,15 @@ def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: return transition def _init_eval(self) -> None: - r""" + """ Overview: - Evaluate mode init method. Called by ``self.__init__``, initialize eval_model. + Initialize the eval mode of policy, including related attributes and modules. For DQN, it contains the \ + eval model to greedily select action with argmax q_value mechanism. + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. """ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample') self._eval_model.reset() @@ -397,17 +453,24 @@ def _init_eval(self) -> None: def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: """ Overview: - Forward computation graph of eval mode(evaluate policy performance), at most cases, it is similar to \ - ``self._forward_collect``. + Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env. - ArgumentsKeys: - - necessary: ``obs`` - ReturnsKeys - - necessary: ``action`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. 
note:: + For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``. """ data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -427,6 +490,7 @@ def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = F Calculate priority for replay buffer. Arguments: - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training. + - update_target_model (:obj:`bool`): Whether to update target model. Returns: - priority (:obj:`Dict[str, Any]`): Dict type priority data, values are python scalar or a list of scalars. ArgumentsKeys: @@ -474,7 +538,8 @@ def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = F class DQNSTDIMPolicy(DQNPolicy): """ Overview: - Policy class of DQN algorithm, extended by auxiliary objectives. + Policy class of DQN algorithm, extended by ST-DIM auxiliary objectives. + ST-DIM paper link: https://arxiv.org/abs/1906.08226. Config: == ==================== ======== ============== ======================================== ======================= ID Symbol Type Default Value Description Other(Shape) @@ -529,68 +594,83 @@ class DQNSTDIMPolicy(DQNPolicy): """ config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). type='dqn_stdim', - # (bool) Whether use cuda in policy + # (bool) Whether to use cuda in policy. cuda=False, - # (bool) Whether learning policy is the same as collecting data policy(on-policy) + # (bool) Whether to learning policy is the same as collecting data policy (on-policy). on_policy=False, - # (bool) Whether enable priority experience sample + # (bool) Whether to enable priority experience sample. priority=False, - # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. + # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True. priority_IS_weight=False, - # (float) Discount factor(gamma) for returns + # (float) Discount factor(gamma) for returns. discount_factor=0.97, - # (int) The number of step for calculating target q_value + # (int) The number of step for calculating target q_value. nstep=1, + # (float) The weight of auxiliary loss to main loss. + aux_loss_weight=0.001, + # learn_mode config learn=dict( - # How many updates(iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. # collect data -> update policy-> collect data -> ... update_per_collect=3, - # (int) How many samples in a training batch + # (int) How many samples in a training batch. batch_size=64, - # (float) The step size of gradient descent + # (float) The step size of gradient descent. learning_rate=0.001, - # ============================================================== - # The following configs are algorithm-specific - # ============================================================== # (int) Frequence of target network update. target_update_freq=100, - # (bool) Whether ignore done(usually for max step termination env) + # (bool) Whether ignore done(usually for max step termination env). ignore_done=False, ), # collect_mode config collect=dict( - # (int) Only one of [n_sample, n_episode] shoule be set + # (int) How many training samples collected in one collection procedure. + # Only one of [n_sample, n_episode] shoule be set. # n_sample=8, # (int) Cut trajectories into pieces with length "unroll_len". 
unroll_len=1, ), - eval=dict(), + eval=dict(), # for compability # other config other=dict( # Epsilon greedy with decay. eps=dict( # (str) Decay type. Support ['exp', 'linear']. type='exp', - # (float) Epsilon start value + # (float) Epsilon start value. start=0.95, - # (float) Epsilon end value + # (float) Epsilon end value. end=0.1, - # (int) Decay length(env step) + # (int) Decay length (env step). decay=10000, ), - replay_buffer=dict(replay_buffer_size=10000, ), + replay_buffer=dict( + # (int) Maximum size of replay buffer. Usually, larger buffer size is better. + replay_buffer_size=10000, + ), ), - aux_loss_weight=0.001, ) def _init_learn(self) -> None: """ Overview: - Learn mode init method. Called by ``self.__init__``. - Init the auxiliary model, its optimizer, and the axuliary loss weight to the main loss. + Initialize the learn mode of policy, including related attributes and modules. For DQNSTDIM, it first \ + call super class's ``_init_learn`` method, then initialize extra auxiliary model, its optimizer, and the \ + loss weight. This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ super()._init_learn() x_size, y_size = self._get_encoding_size() @@ -600,12 +680,12 @@ def _init_learn(self) -> None: self._aux_optimizer = Adam(self._aux_model.parameters(), lr=self._cfg.learn.learning_rate) self._aux_loss_weight = self._cfg.aux_loss_weight - def _get_encoding_size(self): + def _get_encoding_size(self) -> Tuple[Tuple[int], Tuple[int]]: """ Overview: Get the input encoding size of the ST-DIM axuiliary model. Returns: - - info_dict (:obj:`[Tuple, Tuple]`): The encoding size without the first (Batch) dimension. + - info_dict (:obj:`Tuple[Tuple[int], Tuple[int]]`): The encoding size without the first (Batch) dimension. """ obs = self._cfg.model.obs_shape if isinstance(obs, int): @@ -620,16 +700,15 @@ def _get_encoding_size(self): x, y = self._model_encode(test_data) return x.size()[1:], y.size()[1:] - def _model_encode(self, data): + def _model_encode(self, data: dict) -> Tuple[torch.Tensor]: """ Overview: Get the encoding of the main model as input for the auxiliary model. Arguments: - data (:obj:`dict`): Dict type data, same as the _forward_learn input. Returns: - - (:obj:`Tuple[Tensor]`): the tuple of two tensors to apply contrastive embedding learning. - In ST-DIM algorithm, these two variables are the dqn encoding of `obs` and `next_obs`\ - respectively. + - (:obj:`Tuple[torch.Tensor]`): the tuple of two tensors to apply contrastive embedding learning. \ + In ST-DIM algorithm, these two variables are the dqn encoding of `obs` and `next_obs` respectively. """ assert hasattr(self._model, "encoder") x = self._model.encoder(data["obs"]) @@ -639,19 +718,28 @@ def _model_encode(self, data): def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: """ Overview: - Forward computation graph of learn mode(updating policy). + Policy forward function of learn mode (training policy and updating parameters). 
Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, q value, priority, aux_loss. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \ - np.ndarray or dict/list combinations. + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For DQNSTDIM, each element in list is a dict containing at least the following keys: ``obs``, \ + ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as \ + ``weight`` and ``value_gamma``. Returns: - - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \ - recorded in text log and tensorboard, values are python scalar or a list of scalars. - ArgumentsKeys: - - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done`` - - optional: ``value_gamma``, ``IS`` - ReturnsKeys: - - necessary: ``cur_lr``, ``total_loss``, ``priority`` - - optional: ``action_distribution`` + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. """ data = default_preprocess_learn( data, @@ -735,6 +823,13 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: } def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. 
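The weighting of the ST-DIM auxiliary objective against the main TD loss (``aux_loss_weight``) follows the usual pattern of summing a weighted contrastive loss with the Bellman loss and stepping both optimizers. A toy sketch with assumed stand-in modules (not the actual DQNSTDIM implementation) is::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    B, obs_dim, feat_dim, n_action = 32, 8, 16, 4
    encoder = nn.Linear(obs_dim, feat_dim)          # stand-in for the q-network encoder
    q_head = nn.Linear(feat_dim, n_action)
    aux_model = nn.Bilinear(feat_dim, feat_dim, 1)  # stand-in for the ST-DIM contrastive head
    main_opt = torch.optim.Adam(list(encoder.parameters()) + list(q_head.parameters()), lr=1e-3)
    aux_opt = torch.optim.Adam(aux_model.parameters(), lr=1e-3)
    aux_loss_weight = 0.001

    obs, next_obs = torch.randn(B, obs_dim), torch.randn(B, obs_dim)
    x, y = encoder(obs), encoder(next_obs)          # analogous to what _model_encode returns
    bellman_loss = q_head(x).pow(2).mean()          # placeholder for the real n-step TD loss
    # InfoNCE-style objective: matched (x_i, y_i) pairs are positives, mismatched pairs negatives
    scores = aux_model(
        x.unsqueeze(1).expand(-1, B, -1).reshape(-1, feat_dim),
        y.unsqueeze(0).expand(B, -1, -1).reshape(-1, feat_dim),
    ).view(B, B)
    aux_loss = F.cross_entropy(scores, torch.arange(B))
    total_loss = bellman_loss + aux_loss_weight * aux_loss
    main_opt.zero_grad()
    aux_opt.zero_grad()
    total_loss.backward()
    main_opt.step()
    aux_opt.step()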
+ """ return ['cur_lr', 'bellman_loss', 'aux_loss_learn', 'aux_loss_eval', 'total_loss', 'q_value'] def _state_dict_learn(self) -> Dict[str, Any]: diff --git a/ding/policy/dt.py b/ding/policy/dt.py index adef441820..145b11f97c 100644 --- a/ding/policy/dt.py +++ b/ding/policy/dt.py @@ -1,7 +1,4 @@ -"""The code is adapted from https://github.com/nikhilbarhate99/min-decision-transformer -""" - -from typing import List, Dict, Any, Tuple +from typing import List, Dict, Any, Tuple, Optional from collections import namedtuple import torch.nn.functional as F import torch @@ -9,16 +6,15 @@ from ding.torch_utils import to_device from ding.utils import POLICY_REGISTRY from ding.utils.data import default_decollate -from ding.torch_utils import one_hot from .base_policy import Policy @POLICY_REGISTRY.register('dt') class DTPolicy(Policy): - r""" + """ Overview: Policy class of Decision Transformer algorithm in discrete environments. - Paper link: https://arxiv.org/abs/2106.01345 + Paper link: https://arxiv.org/abs/2106.01345. """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). @@ -42,13 +38,37 @@ class DTPolicy(Policy): ) def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + + .. note:: + The user can define and use customized network model but must obey the same inferface definition indicated \ + by import_names path. For example about DQN, its registered name is ``dqn`` and the import_names is \ + ``ding.model.template.q_learning``. + """ return 'dt', ['ding.model.template.dt'] def _init_learn(self) -> None: - r""" + """ Overview: - Learn mode init method. Called by ``self.__init__``. - Init the optimizer, algorithm config, main and target models. + Initialize the learn mode of policy, including related attributes and modules. For Decision Transformer, \ + it mainly contains the optimizer, algorithm-specific arguments such as rtg_scale and lr scheduler. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ # rtg_scale: scale of `return to go` # rtg_target: max target of `return to go` @@ -83,14 +103,26 @@ def _init_learn(self) -> None: self.max_env_score = -1.0 - def _forward_learn(self, data: list) -> Dict[str, Any]: - r""" - Overview: - Forward and backward function of learn mode. - Arguments: - - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs'] - Returns: - - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. + def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). 
Forward means \ + that the policy inputs some training batch data from the offline dataset and then returns the output \ + result, including various training information such as loss, current learning rate. + Arguments: + - data (:obj:`List[torch.Tensor]`): The input data used for policy forward, including a series of \ + processed torch.Tensor data, i.e., timesteps, states, actions, returns_to_go, traj_mask. + Returns: + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + """ self._learn_model.train() @@ -145,9 +177,20 @@ def _forward_learn(self, data: list) -> Dict[str, Any]: } def _init_eval(self) -> None: - r""" + """ Overview: - Evaluate mode init method. Called by ``self.__init__``, initialize eval_model. + Initialize the eval mode of policy, including related attributes and modules. For DQN, it contains the \ + eval model, some algorithm-specific parameters such as context_len, max_eval_ep_len, etc. + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. tip:: + For the evaluation of complete episodes, we need to maintain some historical information for transformer \ + inference. These variables need to be initialized in ``_init_eval`` and reset in ``_reset_eval`` when \ + necessary. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. """ self._eval_model = self._model # init data @@ -196,6 +239,22 @@ def _init_eval(self) -> None: ) def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ + Overview: + Policy forward function of eval mode (evaluation policy performance, such as interacting with envs. \ + Forward means that the policy gets some input data (current obs/return-to-go and historical information) \ + from the envs and then returns the output data, such as the action to interact with the envs. \ + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs and \ + reward to calculate running return-to-go. The key of the dict is environment id and the value is the \ + corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + Decision Transformer will do different operations for different types of envs in evaluation. 
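The running return-to-go bookkeeping mentioned above can be sketched as follows; all variable names are assumptions for illustration, not the exact attributes maintained by ``DTPolicy``::

    import torch

    context_len, max_eval_ep_len = 20, 1000
    rtg_target, rtg_scale = 3600, 1000
    rewards_to_go = torch.zeros(max_eval_ep_len, 1)
    rewards_to_go[0] = rtg_target / rtg_scale      # the episode starts from the desired (scaled) return
    t = 0

    reward = 1.0                                   # reward returned by the env at step t
    rewards_to_go[t + 1] = rewards_to_go[t] - reward / rtg_scale
    t += 1

    # the transformer conditions only on the trailing window of length context_len
    start = max(0, t - context_len + 1)
    rtg_window = rewards_to_go[start:t + 1]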
+ """ # save and forward data_id = list(data.keys()) @@ -279,7 +338,17 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def _reset_eval(self, data_id: List[int] = None) -> None: + def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + """ + Overview: + Reset some stateful variables for eval mode when necessary, such as the historical info of transformer \ + for decision transformer. If ``data_id`` is None, it means to reset all the stateful \ + varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ + different environments/episodes in evaluation in ``data_id`` will have different history. + Arguments: + - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ + specified by ``data_id``. + """ # clean data if data_id is None: self.t = [0 for _ in range(self.eval_batch_size)] @@ -339,21 +408,14 @@ def _reset_eval(self, data_id: List[int] = None) -> None: self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self._device) self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device) - def _state_dict_learn(self) -> Dict[str, Any]: - return { - 'model': self._learn_model.state_dict(), - # 'target_model': self._target_model.state_dict(), - 'optimizer': self._optimizer.state_dict(), - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - self._learn_model.load_state_dict(state_dict['model']) - self._optimizer.load_state_dict(state_dict['optimizer']) - - def _load_state_dict_eval(self, state_dict: Dict[str, Any]) -> None: - self._eval_model.load_state_dict(state_dict) - def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ return ['cur_lr', 'action_loss'] def _init_collect(self) -> None: diff --git a/ding/policy/impala.py b/ding/policy/impala.py index 958275fc31..46adeb1204 100644 --- a/ding/policy/impala.py +++ b/ding/policy/impala.py @@ -14,9 +14,9 @@ @POLICY_REGISTRY.register('impala') class IMPALAPolicy(Policy): - r""" + """ Overview: - Policy class of IMPALA algorithm. + Policy class of IMPALA algorithm. Paper link: https://arxiv.org/abs/1802.01561. Config: == ==================== ======== ============== ======================================== ======================= @@ -41,80 +41,117 @@ class IMPALAPolicy(Policy): == ==================== ======== ============== ======================================== ======================= """ config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). type='impala', + # (bool) Whether to use cuda in policy. cuda=False, - # (bool) whether use on-policy training pipeline(behaviour policy and training policy are the same) - # here we follow ppo serial pipeline, the original is False + # (bool) Whether learning policy is the same as collecting data policy(on-policy). on_policy=False, + # (bool) Whether to enable priority experience sample. priority=False, # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. 
priority_IS_weight=False, - # (str) Which kind of action space used in IMPALAPolicy, ['discrete', 'continuous'] + # (str) Which kind of action space used in IMPALAPolicy, ['discrete', 'continuous']. action_space='discrete', - # (int) the trajectory length to calculate v-trace target + # (int) the trajectory length to calculate v-trace target. unroll_len=32, - # (bool) Whether to need policy data in process transition + # (bool) Whether to need policy data in process transition. transition_with_policy_data=True, + # learn_mode config learn=dict( - # (int) collect n_sample data, train model update_per_collect times - # here we follow ppo serial pipeline + # (int) collect n_sample data, train model update_per_collect times. update_per_collect=4, - # (int) the number of data for a train iteration + # (int) the number of data for a train iteration. batch_size=16, + # (float) The step size of gradient descent. learning_rate=0.0005, - # (float) loss weight of the value network, the weight of policy network is set to 1 + # (float) loss weight of the value network, the weight of policy network is set to 1. value_weight=0.5, - # (float) loss weight of the entropy regularization, the weight of policy network is set to 1 + # (float) loss weight of the entropy regularization, the weight of policy network is set to 1. entropy_weight=0.0001, - # (float) discount factor for future reward, defaults int [0, 1] + # (float) discount factor for future reward, defaults int [0, 1]. discount_factor=0.99, - # (float) additional discounting parameter + # (float) additional discounting parameter. lambda_=0.95, - # (float) clip ratio of importance weights + # (float) clip ratio of importance weights. rho_clip_ratio=1.0, - # (float) clip ratio of importance weights + # (float) clip ratio of importance weights. c_clip_ratio=1.0, - # (float) clip ratio of importance sampling + # (float) clip ratio of importance sampling. rho_pg_clip_ratio=1.0, + # (str) The gradient clip operation type used in IMPALA, ['clip_norm', clip_value', 'clip_momentum_norm']. + grad_clip_type=None, + # (float) The gradient clip target value used in IMPALA. + # If ``grad_clip_type`` is 'clip_norm', then the maximum of gradient will be normalized to this value. + clip_value=0.5, + # (str) Optimizer used to train the network, ['adam', 'rmsprop']. + optim='adam', ), + # collect_mode config collect=dict( - # (int) collect n_sample data, train model n_iteration times + # (int) How many training samples collected in one collection procedure. + # Only one of [n_sample, n_episode] shoule be set. # n_sample=16, - collector=dict(collect_print_freq=1000, ), ), - eval=dict(evaluator=dict(eval_freq=1000, ), ), - other=dict(replay_buffer=dict( - replay_buffer_size=1000, - max_use=16, - ), ), + eval=dict(), # for compatibility + other=dict( + replay_buffer=dict( + # (int) Maximum size of replay buffer. Usually, larger buffer size is better. + replay_buffer_size=1000, + # (int) Maximum use times for a sample in buffer. If reaches this value, the sample will be removed. + max_use=16, + ), + ), ) def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + + .. 
note:: + The user can define and use customized network model but must obey the same inferface definition indicated \ + by import_names path. For example about IMPALA , its registered name is ``vac`` and the import_names is \ + ``ding.model.template.vac``. + """ return 'vac', ['ding.model.template.vac'] def _init_learn(self) -> None: - r""" + """ Overview: - Learn mode init method. Called by ``self.__init__``. - Initialize the optimizer, algorithm config and main model. + Initialize the learn mode of policy, including related attributes and modules. For IMPALA, it mainly \ + contains optimizer, algorithm-specific arguments such as loss weight and gamma, main (learn) model. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ - assert self._cfg.action_space in ["continuous", "discrete"] + assert self._cfg.action_space in ["continuous", "discrete"], self._cfg.action_space self._action_space = self._cfg.action_space # Optimizer - grad_clip_type = self._cfg.learn.get("grad_clip_type", None) - clip_value = self._cfg.learn.get("clip_value", None) - optim_type = self._cfg.learn.get("optim", "adam") + optim_type = self._cfg.learn.optim if optim_type == 'rmsprop': self._optimizer = RMSprop(self._model.parameters(), lr=self._cfg.learn.learning_rate) elif optim_type == 'adam': self._optimizer = Adam( self._model.parameters(), - grad_clip_type=grad_clip_type, - clip_value=clip_value, + grad_clip_type=self._cfg.learn.grad_clip_type, + clip_value=self._cfg.learn.clip_value, lr=self._cfg.learn.learning_rate ) else: - raise NotImplementedError + raise NotImplementedError("Now only support rmsprop and adam, but input is {}".format(optim_type)) self._learn_model = model_wrap(self._model, wrapper_name='base') self._action_shape = self._cfg.model.action_shape @@ -141,8 +178,8 @@ def _data_preprocess_learn(self, data: List[Dict[str, Any]]): Convert list trajectory data to to trajectory data, which is a dict of tensors. Arguments: - data (:obj:`List[Dict[str, Any]]`): List type data, a list of data for training. Each list element is a \ - dict, whose values are torch.Tensor or np.ndarray or dict/list combinations, keys include at least 'obs',\ - 'next_obs', 'logit', 'action', 'reward', 'done' + dict, whose values are torch.Tensor or np.ndarray or dict/list combinations, keys include at least \ + 'obs', 'next_obs', 'logit', 'action', 'reward', 'done' Returns: - data (:obj:`dict`): Dict type data. Values are torch.Tensor or np.ndarray or dict/list combinations. \ ReturnsKeys: @@ -181,22 +218,33 @@ def _data_preprocess_learn(self, data: List[Dict[str, Any]]): return data def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: - r""" + """ Overview: - Forward computation graph of learn mode(updating policy). + Policy forward function of learn mode (training policy and updating parameters). 
Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss and current learning rate. Arguments: - - data (:obj:`List[Dict[str, Any]]`): List type data, a list of data for training. Each list element is a \ - dict, whose values are torch.Tensor or np.ndarray or dict/list combinations, keys include at least 'obs',\ - 'next_obs', 'logit', 'action', 'reward', 'done' + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For IMPALA, each element in list is a dict containing at least the following keys: ``obs``, \ + ``action``, ``logit``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such \ + as ``weight``. Returns: - - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \ - recorded in text log and tensorboard, values are python scalar or a list of scalars. - ArgumentsKeys: - - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done`` - - optional: 'collect_iter', 'replay_unique_id', 'replay_buffer_idx', 'priority', 'staleness', 'use', 'IS' - ReturnsKeys: - - necessary: ``cur_lr``, ``total_loss``, ``policy_loss`,``value_loss``,``entropy_loss`` - - optional: ``priority`` + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to unittest for IMPALAPolicy: ``ding.policy.tests.test_impala``. """ data = self._data_preprocess_learn(data) # ==================== @@ -231,23 +279,21 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: 'entropy_loss': vtrace_loss.entropy_loss.item(), } - def _reshape_data(self, output: Dict[str, Any], data: Dict[str, Any]) -> Tuple[Any, Any, Any, Any, Any, Any]: - r""" + def _reshape_data(self, output: Dict[str, Any], data: Dict[str, Any]) -> Tuple: + """ Overview: - Obtain weights for loss calculating, where should be 0 for done positions - Update values and rewards with the weight + Obtain weights for loss calculating, where should be 0 for done positions. Update values and rewards with \ + the weight. Arguments: - output (:obj:`Dict[int, Any]`): Dict type data, output of learn_model forward. \ - Values are torch.Tensor or np.ndarray or dict/list combinations,keys are value, logit. 
- - data (:obj:`Dict[int, Any]`): Dict type data, input of policy._forward_learn \ - Values are torch.Tensor or np.ndarray or dict/list combinations. Keys includes at \ - least ['logit', 'action', 'reward', 'done',] + Values are torch.Tensor or np.ndarray or dict/list combinations,keys are value, logit. + - data (:obj:`Dict[int, Any]`): Dict type data, input of policy._forward_learn Values are torch.Tensor or \ + np.ndarray or dict/list combinations. Keys includes at least ['logit', 'action', 'reward', 'done']. Returns: - - data (:obj:`Tuple[Any]`): Tuple of target_logit, behaviour_logit, actions, \ - values, rewards, weights + - data (:obj:`Tuple[Any]`): Tuple of target_logit, behaviour_logit, actions, values, rewards, weights. ReturnsShapes: - target_logit (:obj:`torch.FloatTensor`): :math:`((T+1), B, Obs_Shape)`, where T is timestep,\ - B is batch size and Obs_Shape is the shape of single env observation. + B is batch size and Obs_Shape is the shape of single env observation. - behaviour_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where N is action dim. - actions (:obj:`torch.LongTensor`): :math:`(T, B)` - values (:obj:`torch.FloatTensor`): :math:`(T+1, B)` @@ -275,37 +321,17 @@ def _reshape_data(self, output: Dict[str, Any], data: Dict[str, Any]) -> Tuple[A rewards = rewards * weights # shape T,B return target_logit, behaviour_logit, actions, values, rewards, weights - def _state_dict_learn(self) -> Dict[str, Any]: - r""" - Overview: - Return the state_dict of learn mode, usually including model and optimizer. - Returns: - - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. + def _init_collect(self) -> None: """ - return { - 'model': self._learn_model.state_dict(), - 'optimizer': self._optimizer.state_dict(), - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - r""" Overview: - Load the state_dict variable into policy learn mode. - Arguments: - - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. - .. tip:: - If you want to only load some parts of model, you can simply set the ``strict`` argument in \ - load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ - complicated operation. - """ - self._learn_model.load_state_dict(state_dict['model']) - self._optimizer.load_state_dict(state_dict['optimizer']) + Initialize the collect mode of policy, including related attributes and modules. For IMPALA, it contains \ + the collect_model to balance the exploration and exploitation (e.g. the multinomial sample mechanism in \ + discrete action space), and other algorithm-specific arguments such as unroll_len. + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. - def _init_collect(self) -> None: - r""" - Overview: - Collect mode init method. Called by ``self.__init__``, initialize algorithm arguments and collect_model. - Use multinomial_sample to choose action. + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. 
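A simplified illustration of the done-masking idea in ``_reshape_data`` (weights are 0 at done positions and multiply both values and rewards); the exact shape alignment in the real method may differ::

    import torch

    T, B = 4, 2                                    # trajectory length and batch size
    done = torch.randint(0, 2, (T, B)).float()     # 1.0 marks a finished-episode position
    values = torch.randn(T + 1, B)                 # includes the bootstrap value at index T
    rewards = torch.randn(T, B)

    weights = 1.0 - done                           # 0 at done positions, 1 elsewhere
    rewards = rewards * weights
    values = torch.cat([values[:-1] * weights, values[-1:]], dim=0)   # bootstrap row kept as-is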
""" assert self._cfg.action_space in ["continuous", "discrete"] self._action_space = self._cfg.action_space @@ -316,18 +342,32 @@ def _init_collect(self) -> None: self._collect_model.reset() - def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Dict[str, Any]]: - r""" + def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ Overview: - Forward computation graph of collect mode(collect training data). + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. Arguments: - - data (:obj:`Dict[int, Any]`): Dict type data, stacked env data for predicting \ - action, values are torch.Tensor or np.ndarray or dict/list combinations,keys \ - are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. Returns: - - output (:obj:`Dict[int, Dict[str,Any]]`): Dict of predicting policy_output(logit, action) for each env. - ReturnsKeys - - necessary: ``logit``, ``action`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data (action logit and value) for learn mode defined in ``self._process_transition`` \ + method. The key of the dict is the same as the input data, i.e. environment id. + + .. tip:: + If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \ + related data as extra keyword arguments of this method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to unittest for IMPALAPolicy: ``ding.policy.tests.test_impala``. """ data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -343,34 +383,34 @@ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Dict[str, Any]]: return output def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - r""" + """ Overview: - For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \ - can be used for training directly. + For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ + can be used for training. In IMPALA, a train sample is processed transitions with unroll_len length. Arguments: - - data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \ - format as the return value of ``self._process_transition`` method. + - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ + the same format as the return value of ``self._process_transition`` method. Returns: - - samples (:obj:`dict`): List of training samples. - .. note:: - We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. 
\ - And the user can customize the this data processing procedure by overriding this two methods and collector \ - itself. + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \ + as input transitions, but may contain more data for training. """ return get_train_sample(data, self._unroll_len) - def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]: - r""" + def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], + timestep: namedtuple) -> Dict[str, torch.Tensor]: + """ Overview: - Generate dict type transition data from inputs. + Process and pack one timestep transition data into a dict, which can be directly used for training and \ + saved in replay buffer. For IMPALA, it contains obs, next_obs, action, reward, done, logit. Arguments: - - obs (:obj:`Any`): Env observation,can be torch.Tensor or np.ndarray or dict/list combinations. - - model_output (:obj:`dict`): Output of collect model, including ['logit','action'] - - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\ - (here 'obs' indicates obs after env step). + - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. + - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. For IMPALA, it contains the action and the logit of the action. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. Returns: - - transition (:obj:`dict`): Dict type transition data, including at least ['obs','next_obs', 'logit',\ - 'action','reward', 'done'] + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. """ transition = { 'obs': obs, @@ -383,12 +423,17 @@ def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: return transition def _init_eval(self) -> None: - r""" + """ Overview: - Evaluate mode init method. Called by ``self.__init__``, initialize eval_model, - and use argmax_sample to choose action. + Initialize the eval mode of policy, including related attributes and modules. For IMPALA, it contains the \ + eval model to select optimial action (e.g. greedily select action with argmax mechanism in discrete action). + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. """ - assert self._cfg.action_space in ["continuous", "discrete"] + assert self._cfg.action_space in ["continuous", "discrete"], self._cfg.action_space self._action_space = self._cfg.action_space if self._action_space == 'continuous': self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample') @@ -398,19 +443,28 @@ def _init_eval(self) -> None: self._eval_model.reset() def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: - r""" + """ Overview: - Forward computation graph of eval mode(evaluate policy performance), at most cases, it is similar to \ - ``self._forward_collect``. 
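A minimal sketch of what cutting a trajectory into ``unroll_len`` pieces means for ``get_train_sample``; here the tail remainder is simply dropped, whereas the real helper supports padding and other residual strategies::

    from typing import Any, Dict, List

    def split_into_unrolls(traj: List[Dict[str, Any]], unroll_len: int) -> List[List[Dict[str, Any]]]:
        # Cut a trajectory into consecutive pieces of length unroll_len.
        n = len(traj) // unroll_len
        return [traj[i * unroll_len:(i + 1) * unroll_len] for i in range(n)]

    traj = [{'obs': i, 'action': 0, 'logit': [0.5, 0.5], 'reward': 1.0, 'done': False} for i in range(70)]
    samples = split_into_unrolls(traj, unroll_len=32)    # 2 samples of length 32, remainder dropped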
+ Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. ``_forward_eval`` in IMPALA often uses deterministic sample to get \ + actions while ``_forward_collect`` usually uses stochastic sample method for balance exploration and \ + exploitation. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env. - ReturnsKeys - - necessary: ``action`` - - optional: ``logit`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + .. note:: + For more detailed examples, please refer to unittest for IMPALAPolicy: ``ding.policy.tests.test_impala``. """ data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -426,13 +480,11 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: return output def _monitor_vars_learn(self) -> List[str]: - r""" + """ Overview: - Return this algorithm default model setting for demonstration. + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. Returns: - - model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names - .. note:: - The user can define and use customized network model but must obey the same interface definition indicated \ - by import_names path. For IMPALA, ``ding.model.interface.IMPALA`` + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. """ return super()._monitor_vars_learn() + ['policy_loss', 'value_loss', 'entropy_loss'] diff --git a/ding/policy/mdqn.py b/ding/policy/mdqn.py index 688c27de64..8842c11102 100644 --- a/ding/policy/mdqn.py +++ b/ding/policy/mdqn.py @@ -16,7 +16,7 @@ class MDQNPolicy(DQNPolicy): """ Overview: Policy class of Munchausen DQN algorithm, extended by auxiliary objectives. - Paper link: https://arxiv.org/abs/2007.14430 + Paper link: https://arxiv.org/abs/2007.14430. Config: == ==================== ======== ============== ======================================== ======================= ID Symbol Type Default Value Description Other(Shape) @@ -70,26 +70,27 @@ class MDQNPolicy(DQNPolicy): == ==================== ======== ============== ======================================== ======================= """ config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). 
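The collect/eval distinction described above (stochastic multinomial sample for exploration vs. deterministic argmax for evaluation) boils down to the following, shown here as a standalone sketch rather than the actual model wrappers::

    import torch

    logit = torch.randn(3, 5)                     # 3 stacked envs, 5 discrete actions

    # collect mode: stochastic (multinomial) sample from the softmax policy for exploration
    probs = torch.softmax(logit, dim=-1)
    collect_action = torch.multinomial(probs, num_samples=1).squeeze(-1)

    # eval mode: deterministic (argmax) sample for measuring policy performance
    eval_action = logit.argmax(dim=-1)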
type='mdqn', - # (bool) Whether use cuda in policy + # (bool) Whether to use cuda in policy. cuda=False, - # (bool) Whether learning policy is the same as collecting data policy(on-policy) + # (bool) Whether learning policy is the same as collecting data policy(on-policy). on_policy=False, - # (bool) Whether enable priority experience sample + # (bool) Whether to enable priority experience sample. priority=False, - # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. + # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True. priority_IS_weight=False, - # (float) Discount factor(gamma) for returns + # (float) Discount factor(gamma) for returns. discount_factor=0.97, - # (float) Entropy factor (tau) for Munchausen DQN + # (float) Entropy factor (tau) for Munchausen DQN. entropy_tau=0.03, - # (float) Discount factor (alpha) for Munchausen term + # (float) Discount factor (alpha) for Munchausen term. m_alpha=0.9, - # (int) The number of step for calculating target q_value + # (int) The number of step for calculating target q_value. nstep=1, + # learn_mode config learn=dict( - - # How many updates(iterations) to train after collector's one collection. + # (int) How many updates(iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. # collect data -> update policy-> collect data -> ... update_per_collect=3, @@ -97,44 +98,63 @@ class MDQNPolicy(DQNPolicy): batch_size=64, # (float) The step size of gradient descent learning_rate=0.001, - # ============================================================== - # The following configs are algorithm-specific - # ============================================================== # (int) Frequence of target network update. target_update_freq=100, - # (bool) Whether ignore done(usually for max step termination env) + # (bool) Whether ignore done(usually for max step termination env). + # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers. + # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000. + # However, interaction with HalfCheetah always gets done with done is False, + # Since we inplace done==True with done==False to keep + # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``), + # when the episode step is greater than max episode step. ignore_done=False, ), # collect_mode config collect=dict( - # (int) Only one of [n_sample, n_episode] shoule be set + # (int) How many training samples collected in one collection procedure. + # Only one of [n_sample, n_episode] shoule be set. n_sample=4, - # (int) Cut trajectories into pieces with length "unroll_len". + # (int) Split episodes or trajectories into pieces with length `unroll_len`. unroll_len=1, ), - eval=dict(), + eval=dict(), # for compability # other config other=dict( # Epsilon greedy with decay. eps=dict( # (str) Decay type. Support ['exp', 'linear']. type='exp', - # (float) Epsilon start value + # (float) Epsilon start value. start=0.95, - # (float) Epsilon end value + # (float) Epsilon end value. end=0.1, - # (int) Decay length(env step) + # (int) Decay length(env step). decay=10000, ), - replay_buffer=dict(replay_buffer_size=10000, ), + replay_buffer=dict( + # (int) Maximum size of replay buffer. Usually, larger buffer size is better. + replay_buffer_size=10000, + ), ), ) def _init_learn(self) -> None: """ Overview: - Learn mode init method. 
Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \ - and target model. + Initialize the learn mode of policy, including related attributes and modules. For MDQN, it contains \ + optimizer, algorithm-specific arguments such as entropy_tau, m_alpha and nstep, main and target model. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight @@ -172,18 +192,31 @@ def _init_learn(self) -> None: def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: """ Overview: - Forward computation graph of learn mode(updating policy). + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, action_gap, clip_frac, priority. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \ - np.ndarray or dict/list combinations. + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For MDQN, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \ + and ``value_gamma``. Returns: - - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \ - recorded in text log and tensorboard, values are python scalar or a list of scalars. - ArgumentsKeys: - - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done`` - - optional: ``value_gamma``, ``IS`` - ReturnsKeys: - - necessary: ``cur_lr``, ``total_loss``, ``priority``, ``action_gap``, ``clip_frac`` + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. 
note:: + For more detailed examples, please refer to our unittest for MDQNPolicy: ``ding.policy.tests.test_mdqn``. """ data = default_preprocess_learn( data, @@ -238,4 +271,11 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: } def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ return ['cur_lr', 'total_loss', 'q_value', 'action_gap', 'clip_frac'] diff --git a/ding/policy/pdqn.py b/ding/policy/pdqn.py index 7047ebc528..6b66e263ab 100644 --- a/ding/policy/pdqn.py +++ b/ding/policy/pdqn.py @@ -14,9 +14,10 @@ @POLICY_REGISTRY.register('pdqn') class PDQNPolicy(Policy): - r""": + """ Overview: Policy class of PDQN algorithm, which extends the DQN algorithm on discrete-continuous hybrid action spaces. + Paper link: https://arxiv.org/abs/1810.06394. Config: == ==================== ======== ============== ======================================== ======================= @@ -58,12 +59,12 @@ class PDQNPolicy(Policy): | ``_sigma`` | during collection 17 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp', | 'linear']. - 18 | ``other.eps. float 0.95 | start value of exploration rate | [0,1] - | start`` - 19 | ``other.eps. float 0.05 | end value of exploration rate | [0,1] - | end`` - 20 | ``other.eps. int 10000 | decay length of exploration | greater than 0. set - | decay`` | decay=10000 means + 18 | ``other.eps.`` float 0.95 | start value of exploration rate | [0,1] + | ``start`` + 19 | ``other.eps.`` float 0.05 | end value of exploration rate | [0,1] + | ``end`` + 20 | ``other.eps.`` int 10000 | decay length of exploration | greater than 0. set + | ``decay`` | decay=10000 means | the exploration rate | decay from start | value to end value @@ -71,73 +72,104 @@ class PDQNPolicy(Policy): == ==================== ======== ============== ======================================== ======================= """ config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). type='pdqn', + # (bool) Whether to use cuda in policy. cuda=False, + # (bool) Whether learning policy is the same as collecting data policy(on-policy). on_policy=False, + # (bool) Whether to enable priority experience sample. priority=False, - # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. + # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True. priority_IS_weight=False, + # (float) Discount factor(gamma) for returns. discount_factor=0.97, + # (int) The number of step for calculating target q_value. nstep=1, + # learn_mode config learn=dict( - - # How many updates(iterations) to train after collector's one collection. + # (int) How many updates(iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. # collect data -> update policy-> collect data -> ... update_per_collect=3, + # (int) How many samples in a training batch. batch_size=64, + # (float) The step size of gradient descent. 
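The role of ``entropy_tau`` and ``m_alpha`` is to parameterize the Munchausen reward augmentation; a standalone sketch of that term is given below. The clipping bound ``l0`` is an assumed value, and the full MDQN loss additionally uses a soft (softmax) target for the next state::

    import torch
    import torch.nn.functional as F

    tau, alpha, l0 = 0.03, 0.9, -1.0              # entropy_tau, m_alpha, assumed clip lower bound
    q = torch.randn(64, 6)                        # q-values of the current state
    action = torch.randint(0, 6, (64,))
    reward = torch.randn(64)

    # Munchausen term: scaled log-policy of the taken action, clipped into (l0, 0]
    log_pi = F.log_softmax(q / tau, dim=-1)
    munchausen = tau * log_pi.gather(1, action.unsqueeze(-1)).squeeze(-1)
    augmented_reward = reward + alpha * munchausen.clamp(min=l0, max=0)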
learning_rate=0.001, - # ============================================================== - # The following configs are algorithm-specific - # ============================================================== # (int) Frequence of target network update. target_theta=0.005, - # (bool) Whether ignore done(usually for max step termination env) + # (bool) Whether to ignore done (usually for max step termination env). + # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers. + # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000. + # However, HalfCheetah episodes end only because of this time limit rather than a real failure state, + # so when the episode step reaches the max episode step we replace done=True with done=False + # to keep the TD-error computation (``gamma * (1 - done) * next_v + reward``) accurate. ignore_done=False, ), # collect_mode config collect=dict( - # (int) Only one of [n_sample, n_episode] shoule be set + # (int) How many training samples collected in one collection procedure. + # Only one of [n_sample, n_episode] should be set. # n_sample=8, - # (int) Cut trajectories into pieces with length "unroll_len". + # (int) Split episodes or trajectories into pieces with length `unroll_len`. unroll_len=1, - # (float) It is a must to add noise during collection. So here omits "noise" and only set "noise_sigma". + # (float) It is a must to add noise during collection. So here we omit ``noise`` and only set ``noise_sigma``. noise_sigma=0.1, ), - eval=dict(), + eval=dict(), # for compatibility # other config other=dict( # Epsilon greedy with decay. eps=dict( # (str) Decay type. Support ['exp', 'linear']. type='exp', + # (float) Epsilon start value. start=0.95, + # (float) Epsilon end value. end=0.1, # (int) Decay length(env step) decay=10000, ), - replay_buffer=dict(replay_buffer_size=10000, ), + replay_buffer=dict( + # (int) Maximum size of replay buffer. Usually, a larger buffer size is better. + replay_buffer_size=10000, + ), ), ) def default_model(self) -> Tuple[str, List[str]]: """ Overview: - Return this algorithm default model setting for demonstration. + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. Returns: - - model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. .. note:: The user can define and use customized network model but must obey the same inferface definition indicated \ - by import_names path. For PDQN, ``ding.model.template.pdqn.PDQN`` + by import_names path. For example, for PDQN, its registered name is ``pdqn`` and the import_names is \ + ``ding.model.template.pdqn``. """ return 'pdqn', ['ding.model.template.pdqn'] def _init_learn(self) -> None: """ Overview: - Learn mode init method. Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \ - and target model. + Initialize the learn mode of policy, including related attributes and modules. For PDQN, it mainly \ + contains two optimizers, algorithm-specific arguments such as nstep and gamma, main and target model. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + ..
note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight @@ -170,19 +202,31 @@ def _init_learn(self) -> None: def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: """ Overview: - Forward computation graph of learn mode(updating policy). + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, q value, target_q_value, priority. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \ - np.ndarray or dict/list combinations. + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For PDQN, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \ + and ``value_gamma``. Returns: - - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \ - recorded in text log and tensorboard, values are python scalar or a list of scalars. - ArgumentsKeys: - - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done`` - - optional: ``value_gamma``, ``IS`` - ReturnsKeys: - - necessary: ``cur_lr``, ``q_loss``, ``continuous_loss``, - ``q_value``, ``priority``, ``reward``, ``target_q_value`` + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for PDQNPolicy: ``ding.policy.tests.test_pdqn``. """ data = default_preprocess_learn( data, @@ -276,7 +320,8 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: def _state_dict_learn(self) -> Dict[str, Any]: """ Overview: - Return the state_dict of learn mode, usually including model and optimizer. + Return the state_dict of learn mode, usually including model, target model, discrete part optimizer, and \ + continuous part optimizer. 
Returns: - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. """ @@ -307,8 +352,19 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: def _init_collect(self) -> None: """ Overview: - Collect mode init method. Called by ``self.__init__``, initialize algorithm arguments and collect_model, \ - enable the eps_greedy_sample for exploration. + Initialize the collect mode of policy, including related attributes and modules. For PDQN, it contains the \ + collect_model to balance the exploration and exploitation with epsilon-greedy sample mechanism and \ + continuous action mechanism, besides, other algorithm-specific arguments such as unroll_len and nstep are \ + also initialized here. + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. + + .. tip:: + Some variables need to initialize independently in different modes, such as gamma and nstep in PDQN. This \ + design is for the convenience of parallel execution of different policy modes. """ self._unroll_len = self._cfg.collect.unroll_len self._gamma = self._cfg.discount_factor # necessary for parallel @@ -329,18 +385,27 @@ def _init_collect(self) -> None: def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: """ Overview: - Forward computation graph of collect mode(collect training data), with eps_greedy for exploration. + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \ + exploration, i.e., classic epsilon-greedy exploration strategy. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. - - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + - eps (:obj:`float`): The epsilon value for exploration. Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting policy_output(action) for the interaction with \ - env and the constructing of transition. - ArgumentsKeys: - - necessary: ``obs`` - ReturnsKeys - - necessary: ``logit``, ``action`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ + dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. 
note:: + For more detailed examples, please refer to our unittest for PDQNPolicy: ``ding.policy.tests.test_pdqn``. """ data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -356,69 +421,86 @@ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Overview: - For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \ - can be used for training directly. A train sample can be a processed transition(DQN with nstep TD) \ - or some continuous transitions(DRQN). + For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. In PDQN, a train sample is a processed transition. \ + This method is usually used in collectors to execute necessary \ + RL data preprocessing before training, which can help learner amortize revelant time consumption. \ + In addition, you can also implement this method as an identity function and do the data processing \ + in ``self._forward_learn`` method. Arguments: - - data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \ - format as the return value of ``self._process_transition`` method. + - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ + the same format as the return value of ``self._process_transition`` method. Returns: - - samples (:obj:`dict`): The list of training samples. - - .. note:: - We will vectorize ``process_transition`` and ``get_train_sample`` method in the following release version. \ - And the user can customize the this data processing procecure by overriding this two methods and collector \ - itself. + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \ + as input transitions, but may contain more data for training, such as nstep reward and target obs. """ - data = get_nstep_return_data(data, self._nstep, gamma=self._gamma) - return get_train_sample(data, self._unroll_len) + transitions = get_nstep_return_data(transitions, self._nstep, gamma=self._gamma) + return get_train_sample(transitions, self._unroll_len) - def _process_transition(self, obs: Any, model_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]: + def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], + timestep: namedtuple) -> Dict[str, torch.Tensor]: """ Overview: - Generate a transition(e.g.: ) for this algorithm training. + Process and pack one timestep transition data into a dict, which can be directly used for training and \ + saved in replay buffer. For PDQN, it contains obs, next_obs, action, reward, done and logit. Arguments: - - obs (:obj:`Any`): Env observation. - - policy_output (:obj:`Dict[str, Any]`): The output of policy collect mode(``self._forward_collect``),\ - including at least ``action``. - - timestep (:obj:`namedtuple`): The output after env step(execute policy output action), including at \ - least ``obs``, ``reward``, ``done``, (here obs indicates obs after env step). + - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. 
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. For PDQN, it contains the hybrid action and the logit (discrete part q_value) of the action. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. Returns: - - transition (:obj:`dict`): Dict type transition data. + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. """ transition = { 'obs': obs, 'next_obs': timestep.obs, - 'action': model_output['action'], - 'logit': model_output['logit'], + 'action': policy_output['action'], + 'logit': policy_output['logit'], 'reward': timestep.reward, 'done': timestep.done, } return transition def _init_eval(self) -> None: - r""" + """ Overview: - Evaluate mode init method. Called by ``self.__init__``, initialize eval_model. + Initialize the eval mode of policy, including related attributes and modules. For PDQN, it contains the \ + eval model to greedily select action with argmax q_value mechanism. + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. """ self._eval_model = model_wrap(self._model, wrapper_name='hybrid_argmax_sample') self._eval_model.reset() - def _forward_eval(self, data: dict) -> dict: - r""" + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ Overview: - Forward function of eval mode, similar to ``self._forward_collect``. + Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env. - ReturnsKeys - - necessary: ``action`` - - optional: ``logit`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for PDQNPolicy: ``ding.policy.tests.test_pdqn``. 
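As a rough illustration of the env-id keyed data flow that the ``_forward_collect`` and ``_forward_eval`` docstrings describe, here is a minimal torch-only sketch. The helper name ``forward_eval_sketch`` and the single q-value head are assumptions for illustration; the real PDQN eval model also produces the continuous action arguments, so this is not the DI-engine implementation::

    import torch

    def forward_eval_sketch(model: torch.nn.Module, data: dict) -> dict:
        # Remember the env ids so that outputs can be routed back to the same envs.
        env_ids = list(data.keys())
        # Stack per-env observations into one batch along dim 0 (simplified collate).
        obs = torch.stack([data[i] for i in env_ids], dim=0)
        with torch.no_grad():
            q_value = model(obs)             # assumed shape: (num_envs, num_discrete_actions)
            action = q_value.argmax(dim=-1)  # greedy selection in eval mode
        # Split the batched output back into one entry per env id (simplified decollate).
        return {i: {'action': action[b]} for b, i in enumerate(env_ids)}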
""" data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -435,10 +517,11 @@ def _forward_eval(self, data: dict) -> dict: return {i: d for i, d in zip(data_id, output)} def _monitor_vars_learn(self) -> List[str]: - r""" + """ Overview: - Return variables' names if variables are to used in monitor. + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. Returns: - - vars (:obj:`List[str]`): Variables' name list. + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. """ return ['cur_lr', 'total_loss', 'q_loss', 'continuous_loss', 'q_value', 'reward', 'target_q_value'] diff --git a/ding/policy/policy_factory.py b/ding/policy/policy_factory.py index d5e0a5da98..ba9b77df29 100644 --- a/ding/policy/policy_factory.py +++ b/ding/policy/policy_factory.py @@ -1,24 +1,39 @@ from typing import Dict, Any, Callable from collections import namedtuple from easydict import EasyDict +import gym import torch from ding.torch_utils import to_device -import gym class PolicyFactory: - r""" + """ Overview: - Pure random policy. Only used for initial sample collecting if `cfg.policy.random_collect_size` > 0. + Policy factory class, used to generate different policies for general purpose. Such as random action policy, \ + which is used for initial sample collecting for better exploration when ``random_collect_size`` > 0. + Interfaces: + ``get_random_policy`` """ @staticmethod def get_random_policy( - policy: 'BasePolicy', # noqa + policy: 'Policy.collect_mode', # noqa action_space: 'gym.spaces.Space' = None, # noqa forward_fn: Callable = None, - ) -> None: + ) -> 'Policy.collect_mode': # noqa + """ + Overview: + According to the given action space, define the forward function of the random policy, then pack it with \ + other interfaces of the given policy, and return the final collect mode interfaces of policy. + Arguments: + - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy. + - action_space (:obj:`gym.spaces.Space`): The action space of the environment, gym-style. + - forward_fn (:obj:`Callable`): It action space is too complex, you can define your own forward function \ + and pass it to this function, note you should set ``action_space`` to ``None`` in this case. + Returns: + - random_policy (:obj:`Policy.collect_mode`): The collect mode intefaces of the random policy. + """ assert not (action_space is None and forward_fn is None) random_collect_function = namedtuple( 'random_collect_function', [ @@ -69,7 +84,23 @@ def reset(*args, **kwargs) -> None: ) -def get_random_policy(cfg: EasyDict, policy: 'Policy.collect_mode', env: 'BaseEnvManager'): # noqa +def get_random_policy( + cfg: EasyDict, + policy: 'Policy.collect_mode', # noqa + env: 'BaseEnvManager' # noqa +) -> 'Policy.collect_mode': # noqa + """ + Overview: + The entry function to get the corresponding random policy. If a policy needs special data items in a \ + transition, then return itself, otherwise, we will use ``PolicyFactory`` to return a general random policy. + Arguments: + - cfg (:obj:`EasyDict`): The EasyDict-type dict configuration. + - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy. + - env (:obj:`BaseEnvManager`): The env manager instance, which is used to get the action space for random \ + action generation. 
+ Returns: + - random_policy (:obj:`Policy.collect_mode`): The collect mode intefaces of the random policy. + """ if cfg.policy.get('transition_with_policy_data', False): return policy else: diff --git a/ding/policy/ppo.py b/ding/policy/ppo.py index 2245a47096..9ebd7b0f73 100644 --- a/ding/policy/ppo.py +++ b/ding/policy/ppo.py @@ -17,9 +17,9 @@ @POLICY_REGISTRY.register('ppo') class PPOPolicy(Policy): - r""" + """ Overview: - Policy class of on policy version PPO algorithm. + Policy class of on-policy version PPO algorithm. Paper link: https://arxiv.org/abs/1707.06347. """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). @@ -28,63 +28,105 @@ class PPOPolicy(Policy): cuda=False, # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can be off-policy used) on_policy=True, - # (bool) Whether to use priority(priority sample, IS weight, update priority) + # (bool) Whether to use priority (priority sample, IS weight, update priority). priority=False, # (bool) Whether to use Importance Sampling Weight to correct biased update due to priority. # If True, priority must be True. priority_IS_weight=False, - # (bool) Whether to recompurete advantages in each iteration of on-policy PPO + # (bool) Whether to recompurete advantages in each iteration of on-policy PPO. recompute_adv=True, # (str) Which kind of action space used in PPOPolicy, ['discrete', 'continuous', 'hybrid'] action_space='discrete', - # (bool) Whether to use nstep return to calculate value target, otherwise, use return = adv + value + # (bool) Whether to use nstep return to calculate value target, otherwise, use return = adv + value. nstep_return=False, - # (bool) Whether to enable multi-agent training, i.e.: MAPPO + # (bool) Whether to enable multi-agent training, i.e.: MAPPO. multi_agent=False, - # (bool) Whether to need policy data in process transition + # (bool) Whether to need policy ``_forward_collect`` output data in process transition. transition_with_policy_data=True, + # learn_mode config learn=dict( + # (int) After collecting n_sample/n_episode data, how many epoches to train models. + # Each epoch means the one entire passing of training data. epoch_per_collect=10, + # (int) How many samples in a training batch. batch_size=64, + # (float) The step size of gradient descent. learning_rate=3e-4, - # ============================================================== - # The following configs is algorithm-specific - # ============================================================== - # (float) The loss weight of value network, policy network weight is set to 1 + # (float) The loss weight of value network, policy network weight is set to 1. value_weight=0.5, - # (float) The loss weight of entropy regularization, policy network weight is set to 1 + # (float) The loss weight of entropy regularization, policy network weight is set to 1. entropy_weight=0.0, - # (float) PPO clip ratio, defaults to 0.2 + # (float) PPO clip ratio, defaults to 0.2. clip_ratio=0.2, - # (bool) Whether to use advantage norm in a whole training batch + # (bool) Whether to use advantage norm in a whole training batch. adv_norm=True, + # (bool) Whether to use value norm with running mean and std in the whole training process. value_norm=True, + # (bool) Whether to enable special network parameters initialization scheme in PPO, such as orthogonal init. ppo_param_init=True, + # (str) The gradient clip operation type used in PPO, ['clip_norm', clip_value', 'clip_momentum_norm']. 
grad_clip_type='clip_norm', + # (float) The gradient clip target value used in PPO. + # If ``grad_clip_type`` is 'clip_norm', then the maximum of gradient will be normalized to this value. grad_clip_value=0.5, + # (bool) Whether ignore done (usually for max step termination env). ignore_done=False, ), + # collect_mode config collect=dict( - # (int) Only one of [n_sample, n_episode] shoule be set + # (int) How many training samples collected in one collection procedure. + # Only one of [n_sample, n_episode] shoule be set. # n_sample=64, - # (int) Cut trajectories into pieces with length "unroll_len". + # (int) Split episodes or trajectories into pieces with length `unroll_len`. unroll_len=1, - # ============================================================== - # The following configs is algorithm-specific - # ============================================================== # (float) Reward's future discount factor, aka. gamma. discount_factor=0.99, # (float) GAE lambda factor for the balance of bias and variance(1-step td and mc) gae_lambda=0.95, ), - eval=dict(), + eval=dict(), # for compability ) + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + + .. note:: + The user can define and use customized network model but must obey the same inferface definition indicated \ + by import_names path. For example about PPO, its registered name is ``ppo`` and the import_names is \ + ``ding.model.template.vac``. + + .. note:: + Because now PPO supports both single-agent and multi-agent usages, so we can implement these functions \ + with the same policy and two different default models, which is controled by ``self._cfg.multi_agent``. + """ + if self._cfg.multi_agent: + return 'mavac', ['ding.model.template.mavac'] + else: + return 'vac', ['ding.model.template.vac'] + def _init_learn(self) -> None: - r""" + """ Overview: - Learn mode init method. Called by ``self.__init__``. - Init the optimizer, algorithm config and the main model. + Initialize the learn mode of policy, including related attributes and modules. For PPO, it mainly contains \ + optimizer, algorithm-specific arguments such as loss weight, clip_ratio and recompute_adv. This method \ + also executes some special network initializations and prepares running mean/std monitor for value. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. 
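To make the role of ``clip_ratio``, ``value_weight`` and ``entropy_weight`` in the config above concrete, here is a minimal sketch of the standard clipped PPO objective. ``ppo_loss_sketch`` is a hypothetical helper, and value clipping, advantage normalization and value norm are intentionally omitted, so read this as the textbook formula rather than DI-engine's own loss code::

    import torch
    import torch.nn.functional as F

    def ppo_loss_sketch(logp_new, logp_old, adv, value, value_target, entropy,
                        clip_ratio=0.2, value_weight=0.5, entropy_weight=0.0):
        # Importance sampling ratio between the current policy and the behavior policy.
        ratio = torch.exp(logp_new - logp_old)
        # Clipped surrogate objective: take the pessimistic minimum of the two terms.
        surr1 = ratio * adv
        surr2 = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * adv
        policy_loss = -torch.min(surr1, surr2).mean()
        # Plain MSE value loss (value clipping and running value norm are omitted here).
        value_loss = F.mse_loss(value, value_target)
        # Entropy bonus encourages exploration; its weight defaults to 0 in the config above.
        entropy_loss = -entropy.mean()
        return policy_loss + value_weight * value_loss + entropy_weight * entropy_loss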
""" self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight @@ -143,16 +185,40 @@ def _init_learn(self) -> None: # Main model self._learn_model.reset() - def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: - r""" + def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ Overview: - Forward and backward function of learn mode. + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, clipfrac, approx_kl. Arguments: - - data (:obj:`dict`): Dict type data + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including the latest \ + collected training samples for on-policy algorithms like PPO. For each element in list, the key of the \ + dict is the name of data items and the value is the corresponding data. Usually, the value is \ + torch.Tensor or np.ndarray or there dict/list combinations. In the ``_forward_learn`` method, data \ + often need to first be stacked in the batch dimension by some utility functions such as \ + ``default_preprocess_learn``. \ + For PPO, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``reward``, ``logit``, ``value``, ``done``. Sometimes, it also contains other keys such as ``weight``. Returns: - - info_dict (:obj:`Dict[str, Any]`): - Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \ - adv_abs_max, approx_kl, clipfrac + - return_infos (:obj:`List[Dict[str, Any]]`): The information list that indicated training result, each \ + training iteration contains append a information dict into the final list. The list will be precessed \ + and recorded in text log and tensorboard. The value of the dict must be python scalar or a list of \ + scalars. For the detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. tip:: + The training procedure of PPO is two for loops. The outer loop trains all the collected training samples \ + with ``epoch_per_collect`` epochs. The inner loop splits all the data into different mini-batch with \ + the length of ``batch_size``. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for PPOPolicy: ``ding.policy.tests.test_ppo``. """ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False) if self._cuda: @@ -272,24 +338,24 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: return_infos.append(return_info) return return_infos - def _state_dict_learn(self) -> Dict[str, Any]: - return { - 'model': self._learn_model.state_dict(), - 'optimizer': self._optimizer.state_dict(), - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - self._learn_model.load_state_dict(state_dict['model']) - self._optimizer.load_state_dict(state_dict['optimizer']) - def _init_collect(self) -> None: - r""" + """ Overview: - Collect mode init method. 
Called by ``self.__init__``. - Init traj and unroll length, collect model. + Initialize the collect mode of policy, including related attributes and modules. For PPO, it contains the \ + collect_model to balance the exploration and exploitation (e.g. the multinomial sample mechanism in \ + discrete action space), and other algorithm-specific arguments such as unroll_len and gae_lambda. + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. + + .. tip:: + Some variables need to initialize independently in different modes, such as gamma and gae_lambda in PPO. \ + This design is for the convenience of parallel execution of different policy modes. """ self._unroll_len = self._cfg.collect.unroll_len - assert self._cfg.action_space in ["continuous", "discrete", "hybrid"] + assert self._cfg.action_space in ["continuous", "discrete", "hybrid"], self._cfg.action_space self._action_space = self._cfg.action_space if self._action_space == 'continuous': self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample') @@ -302,17 +368,32 @@ def _init_collect(self) -> None: self._gae_lambda = self._cfg.collect.gae_lambda self._recompute_adv = self._cfg.recompute_adv - def _forward_collect(self, data: dict) -> dict: - r""" + def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ Overview: - Forward function of collect mode. + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs. - ReturnsKeys - - necessary: ``action`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data (action logit and value) for learn mode defined in ``self._process_transition`` \ + method. The key of the dict is the same as the input data, i.e. environment id. + + .. tip:: + If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \ + related data as extra keyword arguments of this method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for PPOPolicy: ``ding.policy.tests.test_ppo``. 
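The two nested training loops mentioned in the ``_forward_learn`` tip above (``epoch_per_collect`` passes over the freshly collected samples, each pass split into mini-batches of ``batch_size``) can be sketched as follows; ``run_train_step`` is a placeholder for one gradient update, not a DI-engine API::

    import random

    def ppo_training_schedule_sketch(samples, run_train_step, epoch_per_collect=10, batch_size=64):
        # Outer loop: pass over all collected samples ``epoch_per_collect`` times.
        for _ in range(epoch_per_collect):
            random.shuffle(samples)
            # Inner loop: split the shuffled samples into mini-batches of ``batch_size``.
            for start in range(0, len(samples), batch_size):
                minibatch = samples[start:start + batch_size]
                # One gradient update on this mini-batch (e.g. with the loss sketched earlier).
                run_train_step(minibatch)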
""" data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -326,38 +407,54 @@ def _forward_collect(self, data: dict) -> dict: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], + timestep: namedtuple) -> Dict[str, torch.Tensor]: """ Overview: - Generate dict type transition data from inputs. + Process and pack one timestep transition data into a dict, which can be directly used for training and \ + saved in replay buffer. For PPO, it contains obs, next_obs, action, reward, done, logit, value. Arguments: - - obs (:obj:`Any`): Env observation - - model_output (:obj:`dict`): Output of collect model, including at least ['action'] - - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\ - (here 'obs' indicates obs after env step). + - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. + - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. For PPO, it contains the state value, action and the logit of the action. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. Returns: - - transition (:obj:`dict`): Dict type transition data. + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. + + .. note:: + ``next_obs`` is used to calculate nstep return when necessary, so we place in into transition by default. \ + You can delete this field to save memory occupancy if you do not need nstep return. """ transition = { 'obs': obs, 'next_obs': timestep.obs, - 'action': model_output['action'], - 'logit': model_output['logit'], - 'value': model_output['value'], + 'action': policy_output['action'], + 'logit': policy_output['logit'], + 'value': policy_output['value'], 'reward': timestep.reward, 'done': timestep.done, } return transition - def _get_train_sample(self, data: list) -> Union[None, List[Any]]: - r""" + def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ Overview: - Get the trajectory and calculate GAE, return one data to cache for next time calculation + For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. In PPO, a train sample is a processed transition with new computed \ + ``traj_flag`` and ``adv`` field. This method is usually used in collectors to execute necessary \ + RL data preprocessing before training, which can help learner amortize revelant time consumption. \ + In addition, you can also implement this method as an identity function and do the data processing \ + in ``self._forward_learn`` method. Arguments: - - data (:obj:`list`): The trajectory's cache + - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ + the same format as the return value of ``self._process_transition`` method. 
Returns: - - samples (:obj:`dict`): The training samples generated + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \ + as input transitions, but may contain more data for training, such as GAE advantage. """ + data = transitions data = to_device(data, self._device) for transition in data: transition['traj_flag'] = copy.deepcopy(transition['done']) @@ -397,10 +494,15 @@ def _get_train_sample(self, data: list) -> Union[None, List[Any]]: return get_train_sample(data, self._unroll_len) def _init_eval(self) -> None: - r""" + """ Overview: - Evaluate mode init method. Called by ``self.__init__``. - Init eval model with argmax strategy. + Initialize the eval mode of policy, including related attributes and modules. For PPO, it contains the \ + eval model to select optimial action (e.g. greedily select action with argmax mechanism in discrete action). + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. """ assert self._cfg.action_space in ["continuous", "discrete", "hybrid"] self._action_space = self._cfg.action_space @@ -412,17 +514,29 @@ def _init_eval(self) -> None: self._eval_model = model_wrap(self._model, wrapper_name='hybrid_deterministic_argmax_sample') self._eval_model.reset() - def _forward_eval(self, data: dict) -> dict: - r""" + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ Overview: - Forward function of eval mode, similar to ``self._forward_collect``. + Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. ``_forward_eval`` in PPO often uses deterministic sample method to get \ + actions while ``_forward_collect`` usually uses stochastic sample method for balance exploration and \ + exploitation. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env. - ReturnsKeys - - necessary: ``action`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for PPOPolicy: ``ding.policy.tests.test_ppo``. 
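The GAE advantage that ``_get_train_sample`` attaches to each transition, controlled by ``discount_factor`` and ``gae_lambda`` in the collect config, follows the standard recursion sketched below. This is a single-trajectory illustration with assumed 1D time-major tensors, not the vectorized helper used inside DI-engine::

    import torch

    def gae_sketch(reward, value, next_value, done, gamma=0.99, gae_lambda=0.95):
        # Generalized Advantage Estimation over one trajectory (all inputs are 1D tensors of length T).
        T = reward.shape[0]
        adv = torch.zeros(T)
        last_adv = 0.0
        for t in reversed(range(T)):
            not_done = 1.0 - done[t].float()
            # One-step TD residual.
            delta = reward[t] + gamma * next_value[t] * not_done - value[t]
            # Exponentially weighted sum of future residuals, cut off at episode boundaries.
            last_adv = delta + gamma * gae_lambda * not_done * last_adv
            adv[t] = last_adv
        return adv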
""" data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -436,13 +550,14 @@ def _forward_eval(self, data: dict) -> dict: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def default_model(self) -> Tuple[str, List[str]]: - if self._cfg.multi_agent: - return 'mavac', ['ding.model.template.mavac'] - else: - return 'vac', ['ding.model.template.vac'] - def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ variables = super()._monitor_vars_learn() + [ 'policy_loss', 'value_loss', @@ -461,9 +576,10 @@ def _monitor_vars_learn(self) -> List[str]: @POLICY_REGISTRY.register('ppo_pg') class PPOPGPolicy(Policy): - r""" + """ Overview: - Policy class of on policy version PPO algorithm (pure policy gradient). + Policy class of on policy version PPO algorithm (pure policy gradient without value network). + Paper link: https://arxiv.org/abs/1707.06347. """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). @@ -474,46 +590,75 @@ class PPOPGPolicy(Policy): on_policy=True, # (str) Which kind of action space used in PPOPolicy, ['discrete', 'continuous', 'hybrid'] action_space='discrete', - # (bool) Whether to enable multi-agent training, i.e.: MAPPO + # (bool) Whether to enable multi-agent training, i.e.: MAPPO. multi_agent=False, - # (bool) Whether to need policy data in process transition + # (bool) Whether to need policy data in process transition. transition_with_policy_data=True, + # learn_mode config learn=dict( + # (int) After collecting n_sample/n_episode data, how many epoches to train models. + # Each epoch means the one entire passing of training data. epoch_per_collect=10, + # (int) How many samples in a training batch. batch_size=64, + # (float) The step size of gradient descent. learning_rate=3e-4, - # ============================================================== - # The following configs is algorithm-specific - # ============================================================== - # (float) The loss weight of entropy regularization, policy network weight is set to 1 + # (float) The loss weight of entropy regularization, policy network weight is set to 1. entropy_weight=0.0, - # (float) PPO clip ratio, defaults to 0.2 + # (float) PPO clip ratio, defaults to 0.2. clip_ratio=0.2, - # (bool) Whether to use advantage norm in a whole training batch + # (bool) Whether to enable special network parameters initialization scheme in PPO, such as orthogonal init. ppo_param_init=True, + # (str) The gradient clip operation type used in PPO, ['clip_norm', clip_value', 'clip_momentum_norm']. grad_clip_type='clip_norm', + # (float) The gradient clip target value used in PPO. + # If ``grad_clip_type`` is 'clip_norm', then the maximum of gradient will be normalized to this value. grad_clip_value=0.5, + # (bool) Whether ignore done (usually for max step termination env). ignore_done=False, ), + # collect_mode config collect=dict( - # (int) Only one of n_episode shoule be set + # (int) How many training episodes collected in one collection process. Only one of n_episode shoule be set. # n_episode=8, # (int) Cut trajectories into pieces with length "unroll_len". 
unroll_len=1, - # ============================================================== - # The following configs is algorithm-specific - # ============================================================== # (float) Reward's future discount factor, aka. gamma. discount_factor=0.99, ), - eval=dict(), + eval=dict(), # for compability ) def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + """ return 'pg', ['ding.model.template.pg'] def _init_learn(self) -> None: - assert self._cfg.action_space in ["continuous", "discrete", "hybrid"] + """ + Overview: + Initialize the learn mode of policy, including related attributes and modules. For PPOPG, it mainly \ + contains optimizer, algorithm-specific arguments such as loss weight and clip_ratio. This method \ + also executes some special network initializations. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. + """ + assert self._cfg.action_space in ["continuous", "discrete"] self._action_space = self._cfg.action_space if self._cfg.learn.ppo_param_init: for n, m in self._model.named_modules(): @@ -545,7 +690,39 @@ def _init_learn(self) -> None: # Main model self._learn_model.reset() - def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: + def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, clipfrac, approx_kl. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including the latest \ + collected training samples for on-policy algorithms like PPO. For each element in list, the key of the \ + dict is the name of data items and the value is the corresponding data. Usually, the value is \ + torch.Tensor or np.ndarray or there dict/list combinations. In the ``_forward_learn`` method, data \ + often need to first be stacked in the batch dimension by some utility functions such as \ + ``default_preprocess_learn``. \ + For PPOPG, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``return``, ``logit``, ``done``. Sometimes, it also contains other keys such as ``weight``. + Returns: + - return_infos (:obj:`List[Dict[str, Any]]`): The information list that indicated training result, each \ + training iteration contains append a information dict into the final list. The list will be precessed \ + and recorded in text log and tensorboard. The value of the dict must be python scalar or a list of \ + scalars. 
For the detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. tip:: + The training procedure of PPOPG is two for loops. The outer loop trains all the collected training samples \ + with ``epoch_per_collect`` epochs. The inner loop splits all the data into different mini-batch with \ + the length of ``batch_size``. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + """ + data = default_preprocess_learn(data) if self._cuda: data = to_device(data, self._device) @@ -589,7 +766,22 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: return return_infos def _init_collect(self) -> None: - assert self._cfg.action_space in ["continuous", "discrete", "hybrid"] + """ + Overview: + Initialize the collect mode of policy, including related attributes and modules. For PPOPG, it contains \ + the collect_model to balance the exploration and exploitation (e.g. the multinomial sample mechanism in \ + discrete action space), and other algorithm-specific arguments such as unroll_len and gae_lambda. + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. + + .. tip:: + Some variables need to initialize independently in different modes, such as gamma and gae_lambda in PPO. \ + This design is for the convenience of parallel execution of different policy modes. + """ + assert self._cfg.action_space in ["continuous", "discrete"], self._cfg.action_space self._action_space = self._cfg.action_space self._unroll_len = self._cfg.collect.unroll_len if self._action_space == 'continuous': @@ -599,7 +791,30 @@ def _init_collect(self) -> None: self._collect_model.reset() self._gamma = self._cfg.collect.discount_factor - def _forward_collect(self, data: dict) -> dict: + def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ + Overview: + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data (action logit) for learn mode defined in ``self._process_transition`` \ + method. The key of the dict is the same as the input data, i.e. environment id. + + .. tip:: + If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \ + related data as extra keyword arguments of this method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. 
\ + For the data type that is not supported, the main reason is that the corresponding model does not support it. \ + You can implement your own model rather than use the default model. For more information, please raise an \ + issue in the GitHub repo and we will continue to follow up. + """ data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: @@ -612,17 +827,47 @@ def _forward_collect(self, data: dict) -> dict: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], + timestep: namedtuple) -> Dict[str, torch.Tensor]: + """ + Overview: + Process and pack one timestep transition data into a dict, which can be directly used for training and \ + saved in replay buffer. For PPOPG, it contains obs, action, reward, done, logit. + Arguments: + - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. + - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. For PPOPG, it contains the action and the logit of the action. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. + Returns: + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. + """ transition = { 'obs': obs, - 'action': model_output['action'], - 'logit': model_output['logit'], + 'action': policy_output['action'], + 'logit': policy_output['logit'], 'reward': timestep.reward, 'done': timestep.done, } return transition - def _get_train_sample(self, data: list) -> Union[None, List[Any]]: + def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + For a given entire episode of data (a list of transitions), process it into a list of samples that \ + can be used for training directly. In PPOPG, a train sample is a processed transition with a newly \ + computed ``return`` field. This method is usually used in collectors to execute necessary \ + RL data preprocessing before training, which can help the learner amortize the relevant time consumption. \ + In addition, you can also implement this method as an identity function and do the data processing \ + in ``self._forward_learn`` method. + Arguments: + - data (:obj:`List[Dict[str, Any]]`): The episode data (a list of transitions), each element is \ + the same format as the return value of ``self._process_transition`` method. + Returns: + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is in a similar format \ + to the input transitions, but may contain more data for training, such as the discounted episode return. + """ assert data[-1]['done'] is True, "PPO-PG needs a complete epsiode" if self._cfg.learn.ignore_done: @@ -636,7 +881,17 @@ def _get_train_sample(self, data: list) -> Union[None, List[Any]]: return get_train_sample(data, self._unroll_len) def _init_eval(self) -> None: - assert self._cfg.action_space in ["continuous", "discrete", "hybrid"] + """ + Overview: + Initialize the eval mode of policy, including related attributes and modules. For PPOPG, it contains the \ + eval model to select the optimal action (e.g.
greedily select action with argmax mechanism in discrete action). + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. + """ + assert self._cfg.action_space in ["continuous", "discrete"] self._action_space = self._cfg.action_space if self._action_space == 'continuous': self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample') @@ -644,7 +899,30 @@ def _init_eval(self) -> None: self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample') self._eval_model.reset() - def _forward_eval(self, data: dict) -> dict: + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ + Overview: + Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. ``_forward_eval`` in PPO often uses deterministic sample method to get \ + actions while ``_forward_collect`` usually uses stochastic sample method for balance exploration and \ + exploitation. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for PPOPGPolicy: ``ding.policy.tests.test_ppo``. + """ data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: @@ -658,6 +936,13 @@ def _forward_eval(self, data: dict) -> dict: return {i: d for i, d in zip(data_id, output)} def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ return super()._monitor_vars_learn() + [ 'policy_loss', 'entropy_loss', @@ -668,9 +953,10 @@ def _monitor_vars_learn(self) -> List[str]: @POLICY_REGISTRY.register('ppo_offpolicy') class PPOOffPolicy(Policy): - r""" + """ Overview: - Policy class of PPO algorithm. + Policy class of off-policy version PPO algorithm. Paper link: https://arxiv.org/abs/1707.06347. + This version is more suitable for large-scale distributed training. """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). @@ -678,69 +964,103 @@ class PPOOffPolicy(Policy): # (bool) Whether to use cuda for network. 
cuda=False, on_policy=False, - # (bool) Whether to use priority(priority sample, IS weight, update priority) + # (bool) Whether to use priority (priority sample, IS weight, update priority). priority=False, # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. priority_IS_weight=False, - # (str) Which kind of action space used in PPOPolicy, ["continuous", "discrete", "hybrid"] + # (str) Which kind of action space used in PPOPolicy, ["continuous", "discrete", "hybrid"]. action_space='discrete', - # (bool) Whether to use nstep_return for value loss + # (bool) Whether to use nstep_return for value loss. nstep_return=False, + # (int) The timestep of TD (temporal-difference) loss. nstep=3, - # (bool) Whether to need policy data in process transition + # (bool) Whether to need policy data in process transition. transition_with_policy_data=True, + # learn_mode config learn=dict( - # How many updates(iterations) to train after collector's one collection. + # (int) How many updates(iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. # collect data -> update policy-> collect data -> ... update_per_collect=5, + # (int) How many samples in a training batch. batch_size=64, + # (float) The step size of gradient descent. learning_rate=0.001, - separate_optimizer=False, - # ============================================================== - # The following configs is algorithm-specific - # ============================================================== - # (float) The loss weight of value network, policy network weight is set to 1 + # (float) The loss weight of value network, policy network weight is set to 1. value_weight=0.5, - # (float) The loss weight of entropy regularization, policy network weight is set to 1 + # (float) The loss weight of entropy regularization, policy network weight is set to 1. entropy_weight=0.01, - # (float) PPO clip ratio, defaults to 0.2 + # (float) PPO clip ratio, defaults to 0.2. clip_ratio=0.2, - # (bool) Whether to use advantage norm in a whole training batch + # (bool) Whether to use advantage norm in a whole training batch. adv_norm=False, + # (bool) Whether to use value norm with running mean and std in the whole training process. value_norm=True, + # (bool) Whether to enable special network parameters initialization scheme in PPO, such as orthogonal init. ppo_param_init=True, + # (str) The gradient clip operation type used in PPO, ['clip_norm', clip_value', 'clip_momentum_norm']. grad_clip_type='clip_norm', + # (float) The gradient clip target value used in PPO. + # If ``grad_clip_type`` is 'clip_norm', then the maximum of gradient will be normalized to this value. grad_clip_value=0.5, + # (bool) Whether ignore done (usually for max step termination env). ignore_done=False, + # (float) The weight decay (L2 regularization) loss weight, defaults to 0.0. weight_decay=0.0, ), + # collect_mode config collect=dict( - # (int) Only one of [n_sample, n_episode] shoule be set + # (int) How many training samples collected in one collection procedure. + # Only one of [n_sample, n_episode] shoule be set. # n_sample=64, # (int) Cut trajectories into pieces with length "unroll_len". unroll_len=1, - # ============================================================== - # The following configs is algorithm-specific - # ============================================================== # (float) Reward's future discount factor, aka. gamma. 
discount_factor=0.99, - # (float) GAE lambda factor for the balance of bias and variance(1-step td and mc) + # (float) GAE lambda factor for the balance of bias and variance (1-step td and mc). gae_lambda=0.95, ), - eval=dict(), - other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ), + eval=dict(), # for compability + other=dict( + replay_buffer=dict( + # (int) Maximum size of replay buffer. Usually, larger buffer size is better. + replay_buffer_size=10000, + ), + ), ) + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + """ + return 'vac', ['ding.model.template.vac'] + def _init_learn(self) -> None: - r""" + """ Overview: - Learn mode init method. Called by ``self.__init__``. - Init the optimizer, algorithm config and the main model. + Initialize the learn mode of policy, including related attributes and modules. For PPOOff, it mainly \ + contains optimizer, algorithm-specific arguments such as loss weight and clip_ratio. This method \ + also executes some special network initializations and prepares running mean/std monitor for value. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight - assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPO" + assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPOOff" assert self._cfg.action_space in ["continuous", "discrete", "hybrid"] self._action_space = self._cfg.action_space @@ -797,16 +1117,31 @@ def _init_learn(self) -> None: # Main model self._learn_model.reset() - def _forward_learn(self, data: dict) -> Dict[str, Any]: - r""" + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ Overview: - Forward and backward function of learn mode. + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, clipfrac and approx_kl. Arguments: - - data (:obj:`dict`): Dict type data + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. 
\ + For PPOOff, each element in list is a dict containing at least the following keys: ``obs``, ``adv``, \ + ``action``, ``logit``, ``value``, ``done``. Sometimes, it also contains other keys such as ``weight`` \ + and ``value_gamma``. Returns: - - info_dict (:obj:`Dict[str, Any]`): - Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \ - adv_abs_max, approx_kl, clipfrac + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. """ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=self._nstep_return) if self._cuda: @@ -821,7 +1156,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: self._learn_model.train() with torch.no_grad(): - if hasattr(self, "_value_norm") and self._value_norm: + if self._value_norm: unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std data['return'] = unnormalized_return / self._running_mean_std.std self._running_mean_std.update(unnormalized_return.cpu().numpy()) @@ -961,21 +1296,21 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: ) return return_info - def _state_dict_learn(self) -> Dict[str, Any]: - return { - 'model': self._learn_model.state_dict(), - 'optimizer': self._optimizer.state_dict(), - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - self._learn_model.load_state_dict(state_dict['model']) - self._optimizer.load_state_dict(state_dict['optimizer']) - def _init_collect(self) -> None: - r""" + """ Overview: - Collect mode init method. Called by ``self.__init__``. - Init traj and unroll length, collect model. + Initialize the collect mode of policy, including related attributes and modules. For PPOOff, it contains \ + collect_model to balance the exploration and exploitation (e.g. the multinomial sample mechanism in \ + discrete action space), and other algorithm-specific arguments such as unroll_len and gae_lambda. + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. + + .. tip:: + Some variables need to initialize independently in different modes, such as gamma and gae_lambda in PPOOff. + This design is for the convenience of parallel execution of different policy modes. 
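For reference, the value-normalization step in ``_forward_learn`` above (recovering the unnormalized return from ``adv`` and the normalized value, then re-normalizing it as the regression target) amounts to the following arithmetic. This is only a minimal sketch with a toy constant standing in for ``self._running_mean_std.std``; shapes are hypothetical.

import torch

running_std = torch.tensor(2.0)      # stand-in for self._running_mean_std.std, updated online in practice
adv = torch.randn(64)                # GAE advantage computed at collect time
normalized_value = torch.randn(64)   # critic output, trained against normalized returns
unnormalized_return = adv + normalized_value * running_std
normalized_return = unnormalized_return / running_std   # corresponds to data['return'] above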
""" self._unroll_len = self._cfg.collect.unroll_len assert self._cfg.action_space in ["continuous", "discrete", "hybrid"] @@ -991,18 +1326,34 @@ def _init_collect(self) -> None: self._gae_lambda = self._cfg.collect.gae_lambda self._nstep = self._cfg.nstep self._nstep_return = self._cfg.nstep_return + self._value_norm = self._cfg.learn.value_norm - def _forward_collect(self, data: dict) -> dict: - r""" + def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ Overview: - Forward function of collect mode. + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs. - ReturnsKeys - - necessary: ``action`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data (action logit and value) for learn mode defined in ``self._process_transition`` \ + method. The key of the dict is the same as the input data, i.e. environment id. + + .. tip:: + If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \ + related data as extra keyword arguments of this method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for PPOOffPolicy: ``ding.policy.tests.test_ppo``. """ data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -1016,39 +1367,55 @@ def _forward_collect(self, data: dict) -> dict: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], + timestep: namedtuple) -> Dict[str, torch.Tensor]: """ Overview: - Generate dict type transition data from inputs. + Process and pack one timestep transition data into a dict, which can be directly used for training and \ + saved in replay buffer. For PPO, it contains obs, next_obs, action, reward, done, logit, value. Arguments: - - obs (:obj:`Any`): Env observation - - model_output (:obj:`dict`): Output of collect model, including at least ['action'] - - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\ - (here 'obs' indicates obs after env step). + - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. 
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. For PPO, it contains the state value, action and the logit of the action. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. Returns: - - transition (:obj:`dict`): Dict type transition data. + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. + + .. note:: + ``next_obs`` is used to calculate nstep return when necessary, so we place in into transition by default. \ + You can delete this field to save memory occupancy if you do not need nstep return. """ transition = { 'obs': obs, 'next_obs': timestep.obs, - 'logit': model_output['logit'], - 'action': model_output['action'], - 'value': model_output['value'], + 'logit': policy_output['logit'], + 'action': policy_output['action'], + 'value': policy_output['value'], 'reward': timestep.reward, 'done': timestep.done, } return transition - def _get_train_sample(self, data: list) -> Union[None, List[Any]]: - r""" + def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ Overview: - Get the trajectory and calculate GAE, return one data to cache for next time calculation + For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. In PPO, a train sample is a processed transition with new computed \ + ``traj_flag`` and ``adv`` field. This method is usually used in collectors to execute necessary \ + RL data preprocessing before training, which can help learner amortize revelant time consumption. \ + In addition, you can also implement this method as an identity function and do the data processing \ + in ``self._forward_learn`` method. Arguments: - - data (:obj:`list`): The trajectory's cache + - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ + the same format as the return value of ``self._process_transition`` method. Returns: - - samples (:obj:`dict`): The training samples generated + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \ + as input transitions, but may contain more data for training, such as GAE advantage. """ + data = transitions data = to_device(data, self._device) for transition in data: transition['traj_flag'] = copy.deepcopy(transition['done']) @@ -1066,7 +1433,7 @@ def _get_train_sample(self, data: list) -> Union[None, List[Any]]: )['value'] if len(last_value.shape) == 2: # multi_agent case: last_value = last_value.squeeze(0) - if hasattr(self, "_value_norm") and self._value_norm: + if self._value_norm: last_value *= self._running_mean_std.std for i in range(len(data)): data[i]['value'] *= self._running_mean_std.std @@ -1077,7 +1444,7 @@ def _get_train_sample(self, data: list) -> Union[None, List[Any]]: gae_lambda=self._gae_lambda, cuda=False, ) - if hasattr(self, "_value_norm") and self._value_norm: + if self._value_norm: for i in range(len(data)): data[i]['value'] /= self._running_mean_std.std @@ -1087,10 +1454,15 @@ def _get_train_sample(self, data: list) -> Union[None, List[Any]]: return get_nstep_return_data(data, self._nstep) def _init_eval(self) -> None: - r""" + """ Overview: - Evaluate mode init method. Called by ``self.__init__``. 
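As a point of reference for the GAE advantage computed in ``_get_train_sample`` above, a compact single-trajectory version of the recursion looks roughly like the sketch below. It deliberately ignores the done/traj_flag handling that the library's ``gae`` utility performs; names and default hyperparameters mirror the config but are otherwise illustrative.

import torch

def gae_1d(reward: torch.Tensor, value: torch.Tensor, last_value: torch.Tensor,
           gamma: float = 0.99, lam: float = 0.95) -> torch.Tensor:
    # reward, value: shape (T,); last_value: bootstrap value estimate after the final step.
    adv = torch.zeros_like(reward)
    next_value, running = last_value, 0.
    for t in reversed(range(reward.shape[0])):
        delta = reward[t] + gamma * next_value - value[t]   # one-step TD residual
        running = delta + gamma * lam * running             # exponentially weighted sum of residuals
        adv[t] = running
        next_value = value[t]
    return adv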
- Init eval model with argmax strategy. + Initialize the eval mode of policy, including related attributes and modules. For PPOOff, it contains the \ + eval model to select optimial action (e.g. greedily select action with argmax mechanism in discrete action). + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. """ assert self._cfg.action_space in ["continuous", "discrete", "hybrid"] self._action_space = self._cfg.action_space @@ -1102,17 +1474,29 @@ def _init_eval(self) -> None: self._eval_model = model_wrap(self._model, wrapper_name='hybrid_deterministic_argmax_sample') self._eval_model.reset() - def _forward_eval(self, data: dict) -> dict: - r""" + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ Overview: - Forward function of eval mode, similar to ``self._forward_collect``. + Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. ``_forward_eval`` in PPO often uses deterministic sample method to get \ + actions while ``_forward_collect`` usually uses stochastic sample method for balance exploration and \ + exploitation. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env. - ReturnsKeys - - necessary: ``action`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for PPOOffPolicy: ``ding.policy.tests.test_ppo``. """ data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -1126,10 +1510,14 @@ def _forward_eval(self, data: dict) -> dict: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def default_model(self) -> Tuple[str, List[str]]: - return 'vac', ['ding.model.template.vac'] - def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. 
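To make the collect/eval distinction above concrete, the following is a minimal sketch of the difference between stochastic sampling at collect time and argmax selection at eval time for a discrete action head; the batch size and action count are hypothetical.

import torch

logit = torch.randn(4, 6)   # hypothetical batch of 4 envs, 6 discrete actions
# Collect mode: stochastic sampling keeps exploration.
collect_action = torch.distributions.Categorical(logits=logit).sample()
# Eval mode: deterministic argmax exploits the current policy.
eval_action = logit.argmax(dim=-1)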
+ """ variables = super()._monitor_vars_learn() + [ 'policy_loss', 'value', 'value_loss', 'entropy_loss', 'adv_abs_max', 'approx_kl', 'clipfrac' ] @@ -1143,6 +1531,8 @@ class PPOSTDIMPolicy(PPOPolicy): """ Overview: Policy class of on policy version PPO algorithm with ST-DIM auxiliary model. + PPO paper link: https://arxiv.org/abs/1707.06347. + ST-DIM paper link: https://arxiv.org/abs/1906.08226. """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). @@ -1166,6 +1556,8 @@ class PPOSTDIMPolicy(PPOPolicy): multi_agent=False, # (bool) Whether to need policy data in process transition transition_with_policy_data=True, + # (float) The loss weight of the auxiliary model to the main loss. + aux_loss_weight=0.001, aux_model=dict( # (int) the encoding size (of each head) to apply contrastive loss. encode_shape=64, @@ -1176,42 +1568,48 @@ class PPOSTDIMPolicy(PPOPolicy): # (float) a parameter to adjust the polarity between positive and negative samples. temperature=1.0, ), + # learn_mode config learn=dict( + # (int) After collecting n_sample/n_episode data, how many epoches to train models. + # Each epoch means the one entire passing of training data. epoch_per_collect=10, + # (int) How many samples in a training batch. batch_size=64, + # (float) The step size of gradient descent. learning_rate=3e-4, - # ============================================================== - # The following configs is algorithm-specific - # ============================================================== - # (float) The loss weight of value network, policy network weight is set to 1 + # (float) The loss weight of value network, policy network weight is set to 1. value_weight=0.5, - # (float) The loss weight of entropy regularization, policy network weight is set to 1 + # (float) The loss weight of entropy regularization, policy network weight is set to 1. entropy_weight=0.0, - # (float) PPO clip ratio, defaults to 0.2 + # (float) PPO clip ratio, defaults to 0.2. clip_ratio=0.2, - # (bool) Whether to use advantage norm in a whole training batch + # (bool) Whether to use advantage norm in a whole training batch. adv_norm=True, + # (bool) Whether to use value norm with running mean and std in the whole training process. value_norm=True, + # (bool) Whether to enable special network parameters initialization scheme in PPO, such as orthogonal init. ppo_param_init=True, + # (str) The gradient clip operation type used in PPO, ['clip_norm', clip_value', 'clip_momentum_norm']. grad_clip_type='clip_norm', + # (float) The gradient clip target value used in PPO. + # If ``grad_clip_type`` is 'clip_norm', then the maximum of gradient will be normalized to this value. grad_clip_value=0.5, + # (bool) Whether ignore done (usually for max step termination env). ignore_done=False, ), + # collect_mode config collect=dict( - # (int) Only one of [n_sample, n_episode] shoule be set + # (int) How many training samples collected in one collection procedure. + # Only one of [n_sample, n_episode] shoule be set. # n_sample=64, # (int) Cut trajectories into pieces with length "unroll_len". unroll_len=1, - # ============================================================== - # The following configs is algorithm-specific - # ============================================================== # (float) Reward's future discount factor, aka. gamma. 
discount_factor=0.99, - # (float) GAE lambda factor for the balance of bias and variance(1-step td and mc) + # (float) GAE lambda factor for the balance of bias and variance (1-step td and mc). gae_lambda=0.95, ), - eval=dict(), - aux_loss_weight=0.001, + eval=dict(), # for compability ) def _init_learn(self) -> None: @@ -1400,6 +1798,13 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: return return_infos def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, optimizer and aux_optimizer for \ + representation learning. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ return { 'model': self._learn_model.state_dict(), 'optimizer': self._optimizer.state_dict(), @@ -1407,9 +1812,27 @@ def _state_dict_learn(self) -> Dict[str, Any]: } def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + + .. tip:: + If you want to only load some parts of model, you can simply set the ``strict`` argument in \ + load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ + complicated operation. + """ self._learn_model.load_state_dict(state_dict['model']) self._optimizer.load_state_dict(state_dict['optimizer']) self._aux_optimizer.load_state_dict(state_dict['aux_optimizer']) def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ return super()._monitor_vars_learn() + ["aux_loss_learn", "aux_loss_eval"] diff --git a/ding/policy/qmix.py b/ding/policy/qmix.py index 58cd5e074e..ff1d66f7c8 100644 --- a/ding/policy/qmix.py +++ b/ding/policy/qmix.py @@ -1,7 +1,7 @@ -from typing import List, Dict, Any, Tuple, Union, Optional +from typing import List, Dict, Any, Tuple, Optional from collections import namedtuple -import torch import copy +import torch from ding.torch_utils import RMSprop, to_device from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample @@ -15,12 +15,8 @@ class QMIXPolicy(Policy): """ Overview: - Policy class of QMIX algorithm. QMIX is a multi model reinforcement learning algorithm, \ - you can view the paper in the following link https://arxiv.org/abs/1803.11485 - Interface: - _init_learn, _data_preprocess_learn, _forward_learn, _reset_learn, _state_dict_learn, _load_state_dict_learn \ - _init_collect, _forward_collect, _reset_collect, _process_transition, _init_eval, _forward_eval \ - _reset_eval, _get_train_sample, default_model + Policy class of QMIX algorithm. QMIX is a multi-agent reinforcement learning algorithm, \ + you can view the paper in the following link https://arxiv.org/abs/1803.11485. Config: == ==================== ======== ============== ======================================== ======================= ID Symbol Type Default Value Description Other(Shape) @@ -55,48 +51,48 @@ class QMIXPolicy(Policy): priority=False, # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. 
priority_IS_weight=False, + # learn_mode config learn=dict( + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... update_per_collect=20, + # (int) How many samples in a training batch. batch_size=32, + # (float) The step size of gradient descent. learning_rate=0.0005, clip_value=100, - # ============================================================== - # The following configs is algorithm-specific - # ============================================================== - # (float) Target network update momentum parameter. - # in [0, 1]. + # (float) Target network update momentum parameter, in [0, 1]. target_update_theta=0.008, - # (float) The discount factor for future rewards, - # in [0, 1]. + # (float) The discount factor for future rewards, in [0, 1]. discount_factor=0.99, - # (bool) Whether to use double DQN mechanism(target q for surpassing over estimation) + # (bool) Whether to use double DQN mechanism(target q for surpassing over estimation). double_q=False, ), + # collect_mode config collect=dict( - # (int) Only one of [n_sample, n_episode] shoule be set - # n_episode=32, - # (int) Cut trajectories into pieces with length "unroll_len", the length of timesteps + # (int) How many training samples collected in one collection procedure. + # In each collect phase, we collect a total of sequence samples, a sample with length unroll_len. + # n_sample=32, + # (int) Split trajectories into pieces with length ``unroll_len``, the length of timesteps # in each forward when training. In qmix, it is greater than 1 because there is RNN. unroll_len=10, ), - eval=dict(), + eval=dict(), # for compatibility other=dict( eps=dict( - # (str) Type of epsilon decay + # (str) Type of epsilon decay. type='exp', # (float) Start value for epsilon decay, in [0, 1]. - # 0 means not use epsilon decay. start=1, # (float) Start value for epsilon decay, in [0, 1]. end=0.05, - # (int) Decay length(env step) + # (int) Decay length(env step). decay=50000, ), replay_buffer=dict( + # (int) Maximum size of replay buffer. Usually, larger buffer size is better. replay_buffer_size=5000, - # (int) The maximum reuse times of each data - max_reuse=1e+9, - max_staleness=1e+9, ), ), ) @@ -117,17 +113,25 @@ def default_model(self) -> Tuple[str, List[str]]: def _init_learn(self) -> None: """ Overview: - Learn mode init method. Called by ``self.__init__``. - Init the learner model of QMIXPolicy - Arguments: - .. note:: + Initialize the learn mode of policy, including some attributes and modules. For QMIX, it mainly contains \ + optimizer, algorithm-specific arguments such as gamma, main and target model. Because of the use of RNN, \ + all the models should be wrappered with ``hidden_state`` which needs to be initialized with proper size. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. - The _init_learn method takes the argument from the self._cfg.learn in the config file + .. tip:: + For multi-agent algorithm, we often need to use ``agent_num`` to initialize some necessary variables. - - learning_rate (:obj:`float`): The learning rate fo the optimizer - - gamma (:obj:`float`): The discount factor + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. 
note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. - agent_num (:obj:`int`): Since this is a multi-agent algorithm, we need to input the agent num. - - batch_size (:obj:`int`): Need batch size info to init hidden_state plugins """ self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight @@ -163,8 +167,8 @@ def _init_learn(self) -> None: self._learn_model.reset() self._target_model.reset() - def _data_preprocess_learn(self, data: List[Any]) -> dict: - r""" + def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ Overview: Preprocess the data to fit the required data format for learning Arguments: @@ -181,22 +185,35 @@ def _data_preprocess_learn(self, data: List[Any]) -> dict: data['done'] = data['done'].float() return data - def _forward_learn(self, data: dict) -> Dict[str, Any]: - r""" + def _forward_learn(self, data: List[List[Dict[str, Any]]]) -> Dict[str, Any]: + """ Overview: - Forward and backward function of learn mode. + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data (trajectory for QMIX) from the replay buffer and then \ + returns the output result, including various training information such as loss, q value, grad_norm. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \ - np.ndarray or dict/list combinations. + - data (:obj:`List[List[Dict[int, Any]]]`): The input data used for policy forward, including a batch of \ + training samples. For each dict element, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the time and \ + batch dimension by the utility functions ``self._data_preprocess_learn``. \ + For QMIX, each element in list is a trajectory with the length of ``unroll_len``, and the element in \ + trajectory list is a dict containing at least the following keys: ``obs``, ``action``, ``prev_state``, \ + ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \ + and ``value_gamma``. Returns: - - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \ - recorded in text log and tensorboard, values are python scalar or a list of scalars. - ArgumentsKeys: - - necessary: ``obs``, ``next_obs``, ``action``, ``reward``, ``weight``, ``prev_state``, ``done`` - ReturnsKeys: - - necessary: ``cur_lr``, ``total_loss`` - - cur_lr (:obj:`float`): Current learning rate - - total_loss (:obj:`float`): The calculated loss + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. 
\ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for QMIXPolicy: ``ding.policy.tests.test_qmix``. """ data = self._data_preprocess_learn(data) # ==================== @@ -249,21 +266,24 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: } def _reset_learn(self, data_id: Optional[List[int]] = None) -> None: - r""" + """ Overview: - Reset learn model to the state indicated by data_id + Reset some stateful variables for learn mode when necessary, such as the hidden state of RNN or the \ + memory bank of some special algortihms. If ``data_id`` is None, it means to reset all the stateful \ + varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ + different trajectories in ``data_id`` will have different hidden state in RNN. Arguments: - - data_id (:obj:`Optional[List[int]]`): The id that store the state and we will reset\ - the model state to the state indicated by data_id + - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ + (i.e. RNN hidden_state in QMIX) specified by ``data_id``. """ self._learn_model.reset(data_id=data_id) def _state_dict_learn(self) -> Dict[str, Any]: - r""" + """ Overview: - Return the state_dict of learn mode, usually including model and optimizer. + Return the state_dict of learn mode, usually including model, target_model and optimizer. Returns: - - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. """ return { 'model': self._learn_model.state_dict(), @@ -276,7 +296,7 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: Overview: Load the state_dict variable into policy learn mode. Arguments: - - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. .. tip:: If you want to only load some parts of model, you can simply set the ``strict`` argument in \ @@ -288,11 +308,17 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: self._optimizer.load_state_dict(state_dict['optimizer']) def _init_collect(self) -> None: - r""" + """ Overview: - Collect mode init method. Called by ``self.__init__``. - Init traj and unroll length, collect model. - Enable the eps_greedy_sample and the hidden_state plugin. + Initialize the collect mode of policy, including related attributes and modules. For QMIX, it contains the \ + collect_model to balance the exploration and exploitation with epsilon-greedy sample mechanism and \ + maintain the hidden state of rnn. Besides, there are some initialization operations about other \ + algorithm-specific arguments such as burnin_step, unroll_len and nstep. + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. 
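The epsilon-greedy sampling mentioned above is driven by the ``other.eps`` config (type='exp', start=1, end=0.05, decay=50000). A sketch of the implied exponential schedule, assuming the standard form of 'exp' decay over collected env steps:

import math

def exp_decay_eps(env_step: int, start: float = 1.0, end: float = 0.05, decay: int = 50000) -> float:
    # Decays from `start` towards `end` as the number of collected env steps grows.
    return end + (start - end) * math.exp(-env_step / decay)

eps_now = exp_decay_eps(10000)   # roughly 0.83 with the default values above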
""" self._unroll_len = self._cfg.collect.unroll_len self._collect_model = model_wrap( @@ -305,18 +331,34 @@ def _init_collect(self) -> None: self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample') self._collect_model.reset() - def _forward_collect(self, data: dict, eps: float) -> dict: - r""" + def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: + """ Overview: - Forward function for collect mode with eps_greedy + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \ + exploration, i.e., classic epsilon-greedy exploration strategy. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. - - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + - eps (:obj:`float`): The epsilon value for exploration. Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs. - ReturnsKeys - - necessary: ``action`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data (prev_state) for learn mode defined in ``self._process_transition`` method. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + RNN's hidden states are maintained in the policy, so we don't need pass them into data but to reset the \ + hidden states with ``_reset_collect`` method when episode ends. Besides, the previous hidden states are \ + necessary for training, so we need to return them in ``_process_transition`` method. + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for QMIXPolicy: ``ding.policy.tests.test_qmix``. """ data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -332,43 +374,73 @@ def _forward_collect(self, data: dict, eps: float) -> dict: return {i: d for i, d in zip(data_id, output)} def _reset_collect(self, data_id: Optional[List[int]] = None) -> None: - r""" + """ Overview: - Reset collect model to the state indicated by data_id + Reset some stateful variables for eval mode when necessary, such as the hidden state of RNN or the \ + memory bank of some special algortihms. If ``data_id`` is None, it means to reset all the stateful \ + varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ + different environments/episodes in evaluation in ``data_id`` will have different hidden state in RNN. 
Arguments: - - data_id (:obj:`Optional[List[int]]`): The id that store the state and we will reset\ - the model state to the state indicated by data_id + - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ + (i.e., RNN hidden_state in QMIX) specified by ``data_id``. """ self._collect_model.reset(data_id=data_id) - def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: - r""" + def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], + timestep: namedtuple) -> Dict[str, torch.Tensor]: + """ Overview: - Generate dict type transition data from inputs. + Process and pack one timestep transition data into a dict, which can be directly used for training and \ + saved in replay buffer. For QMIX, it contains obs, next_obs, action, prev_state, reward, done. Arguments: - - obs (:obj:`Any`): Env observation - - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state'] - - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\ - (here 'obs' indicates obs after env step). + - obs (:obj:`torch.Tensor`): The env observation of current timestep, usually including ``agent_obs`` \ + and ``global_obs`` in multi-agent environment like MPE and SMAC. + - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. For QMIX, it contains the action and the prev_state of RNN. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. Returns: - - transition (:obj:`dict`): Dict type transition data, including 'obs', 'next_obs', 'prev_state',\ - 'action', 'reward', 'done' + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. """ transition = { 'obs': obs, 'next_obs': timestep.obs, - 'prev_state': model_output['prev_state'], - 'action': model_output['action'], + 'prev_state': policy_output['prev_state'], + 'action': policy_output['action'], 'reward': timestep.reward, 'done': timestep.done, } return transition + def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. In QMIX, a train sample is processed transitions with unroll_len \ + length. This method is usually used in collectors to execute necessary \ + RL data preprocessing before training, which can help learner amortize revelant time consumption. \ + In addition, you can also implement this method as an identity function and do the data processing \ + in ``self._forward_learn`` method. + Arguments: + - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ + the same format as the return value of ``self._process_transition`` method. + Returns: + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each sample is a fixed-length \ + trajectory, and each element in a sample is the similar format as input transitions. + """ + return get_train_sample(transitions, self._unroll_len) + def _init_eval(self) -> None: - r""" + """ Overview: - Evaluate mode init method. Called by ``self.__init__``. 
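Since QMIX trains on fixed-length RNN unrolls, ``_get_train_sample`` above essentially cuts a collected trajectory into pieces of ``unroll_len``. A rough sketch of that splitting is shown below; the real ``get_train_sample`` utility additionally post-processes the leftover tail (for example by reusing the last transitions), so this is illustrative only.

transitions = [{'step': i} for i in range(25)]   # stand-in for 25 transition dicts
unroll_len = 10
samples = [transitions[i:i + unroll_len] for i in range(0, len(transitions), unroll_len)]
# -> pieces of length [10, 10, 5]; the library-side utility handles the short tail.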
- Init eval model with argmax strategy and the hidden_state plugin. + Initialize the eval mode of policy, including related attributes and modules. For QMIX, it contains the \ + eval model to greedily select action with argmax q_value mechanism and maintain the hidden state. + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some special member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. """ self._eval_model = model_wrap( self._model, @@ -381,16 +453,31 @@ def _init_eval(self) -> None: self._eval_model.reset() def _forward_eval(self, data: dict) -> dict: - r""" + """ Overview: - Forward function of eval mode, similar to ``self._forward_collect``. + Policy forward function of eval mode (evaluating policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. ``_forward_eval`` often uses the argmax sample method to get the \ + actions with the highest q_value. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env. - ReturnsKeys - - necessary: ``action`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + RNN's hidden states are maintained in the policy, so we don't need to pass them into the data, but we do \ + need to reset the hidden states with ``_reset_eval`` method when the episode ends. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that is not supported, the main reason is that the corresponding model does not support it. \ + You can implement your own model rather than use the default model. For more information, please raise an \ + issue in the GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for QMIXPolicy: ``ding.policy.tests.test_qmix``. """ data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -406,31 +493,24 @@ def _forward_eval(self, data: dict) -> dict: return {i: d for i, d in zip(data_id, output)} def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: - r""" - Overview: - Reset eval model to the state indicated by data_id - Arguments: - - data_id (:obj:`Optional[List[int]]`): The id that store the state and we will reset\ - the model state to the state indicated by data_id """ - self._eval_model.reset(data_id=data_id) - - def _get_train_sample(self, data: list) -> Union[None, List[Any]]: - r""" Overview: - Get the train sample from trajectory. + Reset some stateful variables for eval mode when necessary, such as the hidden state of RNN or the \ + memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \ + variables.
Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ + different environments/episodes in evaluation in ``data_id`` will have different hidden state in RNN. Arguments: - - data (:obj:`list`): The trajectory's cache - Returns: - - samples (:obj:`dict`): The training samples generated + - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ + (i.e., RNN hidden_state in QMIX) specified by ``data_id``. """ - return get_train_sample(data, self._unroll_len) + self._eval_model.reset(data_id=data_id) def _monitor_vars_learn(self) -> List[str]: - r""" + """ Overview: - Return variables' name if variables are to used in monitor. + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. Returns: - - vars (:obj:`List[str]`): Variables' name list. + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. """ return ['cur_lr', 'total_loss', 'total_q', 'target_total_q', 'grad_norm', 'target_reward_total_q'] diff --git a/ding/policy/r2d2.py b/ding/policy/r2d2.py index a7d49d2bb6..0726c2c820 100644 --- a/ding/policy/r2d2.py +++ b/ding/policy/r2d2.py @@ -15,16 +15,16 @@ @POLICY_REGISTRY.register('r2d2') class R2D2Policy(Policy): - r""" + """ Overview: Policy class of R2D2, from paper `Recurrent Experience Replay in Distributed Reinforcement Learning` . - R2D2 proposes that several tricks should be used to improve upon DRQN, - namely some recurrent experience replay tricks such as burn-in. + R2D2 proposes that several tricks should be used to improve upon DRQN, namely some recurrent experience replay \ + tricks and the burn-in mechanism for off-policy training. Config: == ==================== ======== ============== ======================================== ======================= ID Symbol Type Default Value Description Other(Shape) == ==================== ======== ============== ======================================== ======================= - 1 ``type`` str dqn | RL policy register name, refer to | This arg is optional, + 1 ``type`` str r2d2 | RL policy register name, refer to | This arg is optional, | registry ``POLICY_REGISTRY`` | a placeholder 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff- | erent from modes @@ -68,13 +68,10 @@ class R2D2Policy(Policy): cuda=False, # (bool) Whether the RL algorithm is on-policy or off-policy. on_policy=False, - # (bool) Whether use priority(priority sample, IS weight, update priority) + # (bool) Whether to use priority(priority sample, IS weight, update priority) priority=True, - # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. + # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True. priority_IS_weight=True, - # ============================================================== - # The following configs are algorithm-specific - # ============================================================== # (float) Reward's future discount factor, aka. gamma. 
discount_factor=0.997, # (int) N-step reward for target q_value estimation @@ -85,64 +82,103 @@ class R2D2Policy(Policy): # (int) the trajectory length to unroll the RNN network minus # the timestep of burnin operation learn_unroll_len=80, + # learn_mode config learn=dict( + # (int) The number of training updates (iterations) to perform after each data collection by the collector. + # A larger "update_per_collect" value implies a more off-policy approach. + # The whole pipeline process follows this cycle: collect data -> update policy -> collect data -> ... update_per_collect=1, + # (int) The number of samples in a training batch. batch_size=64, + # (float) The step size of gradient descent, determining the rate of learning. learning_rate=0.0001, - # ============================================================== - # The following configs are algorithm-specific - # ============================================================== # (int) Frequence of target network update. # target_update_freq=100, target_update_theta=0.001, # (bool) whether use value_rescale function for predicted value value_rescale=True, + # (bool) Whether ignore done(usually for max step termination env). + # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers. + # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000. + # However, interaction with HalfCheetah always gets done with done is False, + # Since we inplace done==True with done==False to keep + # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``), + # when the episode step is greater than max episode step. ignore_done=False, ), + # collect_mode config collect=dict( - # NOTE: It is important that set key traj_len_inf=True here, + # (int) How many training samples collected in one collection procedure. + # In each collect phase, we collect a total of sequence samples. + n_sample=32, + # (bool) It is important that set key traj_len_inf=True here, # to make sure self._traj_len=INF in serial_sample_collector.py. # In R2D2 policy, for each collect_env, we want to collect data of length self._traj_len=INF # unless the episode enters the 'done' state. - # In each collect phase, we collect a total of sequence samples. - n_sample=32, traj_len_inf=True, - # `env_num` is used in hidden state, should equal to that one in env config. - # User should specify this value in user config. + # (int) `env_num` is used in hidden state, should equal to that one in env config (e.g. collector_env_num). + # User should specify this value in user config. `None` is a placeholder. env_num=None, ), + # eval_mode config eval=dict( - # `env_num` is used in hidden state, should equal to that one in env config. + # (int) `env_num` is used in hidden state, should equal to that one in env config (e.g. evaluator_env_num). # User should specify this value in user config. env_num=None, ), other=dict( + # Epsilon greedy with decay. eps=dict( + # (str) Type of decay. Supports either 'exp' (exponential) or 'linear'. type='exp', + # (float) Initial value of epsilon at the start. start=0.95, + # (float) Final value of epsilon after decay. end=0.05, + # (int) The number of environment steps over which epsilon should decay. decay=10000, ), - replay_buffer=dict(replay_buffer_size=10000, ), + replay_buffer=dict( + # (int) Maximum size of replay buffer. Usually, larger buffer size is better. 
+ replay_buffer_size=10000, + ), ), ) def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + + .. note:: + The user can define and use customized network model but must obey the same inferface definition indicated \ + by import_names path. For example about R2D2, its registered name is ``drqn`` and the import_names is \ + ``ding.model.template.q_learning``. + """ return 'drqn', ['ding.model.template.q_learning'] def _init_learn(self) -> None: - r""" + """ Overview: - Init the learner model of R2D2Policy - Arguments: - - learning_rate (:obj:`float`): The learning rate fo the optimizer - - gamma (:obj:`float`): The discount factor - - nstep (:obj:`int`): The num of n step return - - value_rescale (:obj:`bool`): Whether to use value rescaled loss in algorithm - - burnin_step (:obj:`int`): The num of step of burnin + Initialize the learn mode of policy, including some attributes and modules. For R2D2, it mainly contains \ + optimizer, algorithm-specific arguments such as burnin_step, value_rescale and gamma, main and target \ + model. Because of the use of RNN, all the models should be wrappered with ``hidden_state`` which needs to \ + be initialized with proper size. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. .. note:: - The _init_learn method takes the argument from the self._cfg.learn in the config file + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight @@ -174,16 +210,15 @@ def _init_learn(self) -> None: self._learn_model.reset() self._target_model.reset() - def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> dict: - r""" + def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + """ Overview: Preprocess the data to fit the required data format for learning Arguments: - - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function + - data (:obj:`List[Dict[str, Any]]`): The data collected from collect function Returns: - - data (:obj:`Dict[str, Any]`): the processed data, including at least \ + - data (:obj:`Dict[str, torch.Tensor]`): The processed data, including at least \ ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight'] - - data_info (:obj:`dict`): the data info, such as replay_buffer_idx, replay_unique_id """ # data preprocess data = timestep_collate(data) @@ -241,18 +276,35 @@ def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> dict: return data - def _forward_learn(self, data: dict) -> Dict[str, Any]: - r""" + def _forward_learn(self, data: List[List[Dict[str, Any]]]) -> Dict[str, Any]: + """ Overview: - Forward and backward function of learn mode. 
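The burn-in mechanism referenced above can be pictured as a simple slicing of the unrolled sequence: the first ``burnin_step`` observations only warm up the RNN hidden state, while the remainder produces the online Q-values and the n-step targets. The sketch below is illustrative, with hypothetical shapes and step counts, and does not claim to reproduce the exact slicing done in ``_data_preprocess_learn``.

import torch

burnin_step, nstep = 40, 2
obs = torch.randn(82, 64, 4)              # (burnin + learn_unroll + nstep, batch, obs_dim), hypothetical
burnin_obs = obs[:burnin_step]            # only warms up the RNN hidden state, no loss is computed on it
main_obs = obs[burnin_step:-nstep]        # produces the online Q-values
target_obs = obs[burnin_step + nstep:]    # produces the n-step bootstrapped targets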
- Acquire the data, calculate the loss and optimize learner model. + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data (trajectory for R2D2) from the replay buffer and then \ + returns the output result, including various training information such as loss, q value, priority. Arguments: - - data (:obj:`dict`): Dict type data, including at least \ - ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight'] + - data (:obj:`List[List[Dict[int, Any]]]`): The input data used for policy forward, including a batch of \ + training samples. For each dict element, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the time and \ + batch dimension by the utility functions ``self._data_preprocess_learn``. \ + For R2D2, each element in list is a trajectory with the length of ``unroll_len``, and the element in \ + trajectory list is a dict containing at least the following keys: ``obs``, ``action``, ``prev_state``, \ + ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \ + and ``value_gamma``. Returns: - - info_dict (:obj:`Dict[str, Any]`): Including cur_lr and total_loss - - cur_lr (:obj:`float`): Current learning rate - - total_loss (:obj:`float`): The calculated loss + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for R2D2Policy: ``ding.policy.tests.test_r2d2``. """ # forward data = self._data_preprocess_learn(data) # output datatype: Dict @@ -343,23 +395,64 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: } def _reset_learn(self, data_id: Optional[List[int]] = None) -> None: + """ + Overview: + Reset some stateful variables for learn mode when necessary, such as the hidden state of RNN or the \ + memory bank of some special algortihms. If ``data_id`` is None, it means to reset all the stateful \ + varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ + different trajectories in ``data_id`` will have different hidden state in RNN. + Arguments: + - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ + (i.e. RNN hidden_state in R2D2) specified by ``data_id``. + """ + self._learn_model.reset(data_id=data_id) def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. 
+ """ return { 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), 'optimizer': self._optimizer.state_dict(), } def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + + .. tip:: + If you want to only load some parts of model, you can simply set the ``strict`` argument in \ + load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ + complicated operation. + """ self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) self._optimizer.load_state_dict(state_dict['optimizer']) def _init_collect(self) -> None: - r""" + """ Overview: - Collect mode init method. Called by ``self.__init__``. - Init traj and unroll length, collect model. + Initialize the collect mode of policy, including related attributes and modules. For R2D2, it contains the \ + collect_model to balance the exploration and exploitation with epsilon-greedy sample mechanism and \ + maintain the hidden state of rnn. Besides, there are some initialization operations about other \ + algorithm-specific arguments such as burnin_step, unroll_len and nstep. + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. + + .. tip:: + Some variables need to initialize independently in different modes, such as gamma and nstep in R2D2. This \ + design is for the convenience of parallel execution of different policy modes. """ self._nstep = self._cfg.nstep self._burnin_step = self._cfg.burnin_step @@ -375,18 +468,34 @@ def _init_collect(self) -> None: self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample') self._collect_model.reset() - def _forward_collect(self, data: dict, eps: float) -> dict: - r""" + def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: + """ Overview: - Forward function for collect mode with eps_greedy + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \ + exploration, i.e., classic epsilon-greedy exploration strategy. Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. - - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + - eps (:obj:`float`): The epsilon value for exploration. Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs. 
- ReturnsKeys
- - necessary: ``action``
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data (prev_state) for learn mode defined in ``self._process_transition`` method. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ RNN's hidden states are maintained in the policy, so we don't need to pass them into data but to reset the \
+ hidden states with the ``_reset_collect`` method when an episode ends. Besides, the previous hidden states are \
+ necessary for training, so we need to return them in ``_process_transition`` method.
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For the data type that is not supported, the main reason is that the corresponding model does not support it. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for R2D2Policy: ``ding.policy.tests.test_r2d2``.
 """
 data_id = list(data.keys())
 data = default_collate(list(data.values()))
@@ -404,62 +513,104 @@ def _forward_collect(self, data: dict, eps: float) -> dict:
 return {i: d for i, d in zip(data_id, output)}
 
 def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
+ """
+ Overview:
+ Reset some stateful variables for collect mode when necessary, such as the hidden state of RNN or the \
+ memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \
+ variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
+ different environments/episodes during collection in ``data_id`` will have different hidden states in RNN.
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \
+ (i.e., RNN hidden_state in R2D2) specified by ``data_id``.
+ """
 self._collect_model.reset(data_id=data_id)
 
- def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
- r"""
+ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+ timestep: namedtuple) -> Dict[str, torch.Tensor]:
+ """
 Overview:
- Generate dict type transition data from inputs.
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For R2D2, it contains obs, action, prev_state, reward, and done.
 Arguments:
- - obs (:obj:`Any`): Env observation
- - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state']
- - timestep (:obj:`namedtuple`): Output after env step, including at least ['reward', 'done'] \
- (here 'obs' indicates obs after env step).
+ - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network given the observation \
+ as input. For R2D2, it contains the action and the prev_state of RNN.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
 Returns:
- - transition (:obj:`dict`): Dict type transition data.
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
 """
 transition = {
 'obs': obs,
- 'action': model_output['action'],
- 'prev_state': model_output['prev_state'],
+ 'action': policy_output['action'],
+ 'prev_state': policy_output['prev_state'],
 'reward': timestep.reward,
 'done': timestep.done,
 }
 return transition
 
- def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
- r"""
+ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
 Overview:
- Get the trajectory and the n step return data, then sample from the n_step return data
+ For a given trajectory (transitions, a list of transition) data, process it into a list of samples that \
+ can be used for training directly. In R2D2, a train sample is a processed trajectory with ``unroll_len`` \
+ length. This method is usually used in collectors to execute necessary \
+ RL data preprocessing before training, which can help the learner amortize the relevant time consumption. \
+ In addition, you can also implement this method as an identity function and do the data processing \
+ in ``self._forward_learn`` method.
 Arguments:
- - data (:obj:`list`): The trajectory's cache
+ - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transition), each element is \
+ in the same format as the return value of ``self._process_transition`` method.
 Returns:
- - samples (:obj:`dict`): The training samples generated
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each sample is a fixed-length \
+ trajectory, and each element in a sample is in a similar format to the input transitions, but may contain \
+ more data for training, such as the nstep reward and the value_gamma factor.
 """
- data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
- return get_train_sample(data, self._unroll_len)
+ transitions = get_nstep_return_data(transitions, self._nstep, gamma=self._gamma)
+ return get_train_sample(transitions, self._unroll_len)
 
 def _init_eval(self) -> None:
- r"""
+ """
 Overview:
- Evaluate mode init method. Called by ``self.__init__``.
- Init eval model with argmax strategy.
+ Initialize the eval mode of policy, including related attributes and modules. For R2D2, it contains the \
+ eval model to greedily select action with the argmax q_value mechanism and maintain the hidden state.
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
 """
 self._eval_model = model_wrap(self._model, wrapper_name='hidden_state', state_num=self._cfg.eval.env_num)
 self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
 self._eval_model.reset()
 
- def _forward_eval(self, data: dict) -> dict:
- r"""
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
 Overview:
- Forward function of eval mode, similar to ``self._forward_collect``.
+ Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs. ``_forward_eval`` often uses the argmax sample method to get the action whose \
+ q_value is the highest.
Arguments: - - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ - values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. Returns: - - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env. - ReturnsKeys - - necessary: ``action`` + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + RNN's hidden states are maintained in the policy, so we don't need pass them into data but to reset the \ + hidden states with ``_reset_eval`` method when the episode ends. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for R2D2Policy: ``ding.policy.tests.test_r2d2``. """ data_id = list(data.keys()) data = default_collate(list(data.values())) @@ -475,9 +626,26 @@ def _forward_eval(self, data: dict) -> dict: return {i: d for i, d in zip(data_id, output)} def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + """ + Overview: + Reset some stateful variables for eval mode when necessary, such as the hidden state of RNN or the \ + memory bank of some special algortihms. If ``data_id`` is None, it means to reset all the stateful \ + varaibles. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \ + different environments/episodes in evaluation in ``data_id`` will have different hidden state in RNN. + Arguments: + - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \ + (i.e., RNN hidden_state in R2D2) specified by ``data_id``. + """ self._eval_model.reset(data_id=data_id) def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ return super()._monitor_vars_learn() + [ 'total_loss', 'priority', 'q_s_taken-a_t0', 'target_q_s_max-a_t0', 'q_s_a-mean_t0' ] diff --git a/ding/policy/sac.py b/ding/policy/sac.py index ebf2845e51..5b5dfe55c8 100644 --- a/ding/policy/sac.py +++ b/ding/policy/sac.py @@ -16,58 +16,21 @@ from .common_utils import default_preprocess_learn -@POLICY_REGISTRY.register('sac_discrete') -class SACDiscretePolicy(Policy): - r""" - Overview: - Policy class of discrete SAC algorithm. Paper link: https://arxiv.org/pdf/1910.07207.pdf. 
- - Config: - == ==================== ======== ============= ================================= ======================= - ID Symbol Type Default Value Description Other - == ==================== ======== ============= ================================= ======================= - 1 ``type`` str sac_discrete | RL policy register name, refer | this arg is optional, - | to registry ``POLICY_REGISTRY`` | a placeholder - 2 ``cuda`` bool True | Whether to use cuda for network | - 3 ``on_policy`` bool False | SACDiscrete is an off-policy | - | algorithm. | - 4 ``priority`` bool False | Whether to use priority | - | sampling in buffer. | - 5 | ``priority_IS_`` bool False | Whether use Importance Sampling | - | ``weight`` | weight to correct biased update | - 6 | ``random_`` int 10000 | Number of randomly collected | Default to 10000 for - | ``collect_size`` | training samples in replay | SAC, 25000 for DDPG/ - | | buffer when training starts. | TD3. - 7 | ``learn.learning`` float 3e-4 | Learning rate for soft q | Defalut to 1e-3 - | ``_rate_q`` | network. | - 8 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to 1e-3 - | ``_rate_policy`` | network. | - 9 | ``learn.alpha`` float 0.2 | Entropy regularization | alpha is initiali- - | | coefficient. | zation for auto - | | | `\alpha`, when - | | | auto_alpha is True - 10 | ``learn.`` bool False | Determine whether to use | Temperature parameter - | ``auto_alpha`` | auto temperature parameter | determines the - | | `\alpha`. | relative importance - | | | of the entropy term - | | | against the reward. - 11 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only - | ``ignore_done`` | done flag. | in env like Pendulum - 12 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation - | ``target_theta`` | target network. | factor in polyak aver - | | | aging for target - | | | networks. - == ==================== ======== ============= ================================= ======================= - """ +@POLICY_REGISTRY.register('discrete_sac') +class DiscreteSACPolicy(Policy): + """ + Overview: + Policy class of discrete SAC algorithm. Paper link: https://arxiv.org/abs/1910.07207. + """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). - type='sac_discrete', + type='discrete_sac', # (bool) Whether to use cuda for network and loss computation. cuda=False, - # (bool) Whether to belong to on-policy or off-policy algorithm, SACDiscrete is an off-policy algorithm. + # (bool) Whether to belong to on-policy or off-policy algorithm, DiscreteSAC is an off-policy algorithm. on_policy=False, - # (bool) Whether to use priority sampling in buffer. Default to False in SACDiscrete. + # (bool) Whether to use priority sampling in buffer. Default to False in DiscreteSAC. priority=False, # (bool) Whether use Importance Sampling weight to correct biased update. If True, priority must be True. priority_IS_weight=False, @@ -82,6 +45,7 @@ class SACDiscretePolicy(Policy): # For more details, please refer to TD3 about Clipped Double-Q Learning trick. twin_critic=True, ), + # learn_mode config learn=dict( # (int) How many updates (iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. @@ -123,8 +87,10 @@ class SACDiscretePolicy(Policy): # (float) Weight uniform initialization max range in the last output layer init_w=3e-3, ), + # collect_mode config collect=dict( # (int) How many training samples collected in one collection procedure. 
+ # Only one of [n_sample, n_episode] shoule be set. n_sample=1, # (int) Split episodes or trajectories into pieces with length `unroll_len`. unroll_len=1, @@ -132,6 +98,7 @@ class SACDiscretePolicy(Policy): # In some algorithm like guided cost learning, we need to use logit to train the reward model. collector_logit=False, ), + eval=dict(), # for compability other=dict( replay_buffer=dict( # (int) Maximum size of replay buffer. Usually, larger buffer size is good @@ -142,6 +109,13 @@ class SACDiscretePolicy(Policy): ) def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + """ if self._cfg.multi_agent: return 'discrete_maqac', ['ding.model.template.maqac'] else: @@ -150,8 +124,21 @@ def default_model(self) -> Tuple[str, List[str]]: def _init_learn(self) -> None: """ Overview: - Learn mode init method. Called by ``self.__init__``. - Init q function and policy's optimizers, algorithm config, main and target models. + Initialize the learn mode of policy, including related attributes and modules. For DiscreteSAC, it mainly \ + contains three optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target \ + model. Especially, the ``auto_alpha`` mechanism for balancing max entropy target is also initialized here. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. """ self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight @@ -170,7 +157,7 @@ def _init_learn(self) -> None: self._gamma = self._cfg.learn.discount_factor if self._cfg.learn.auto_alpha: if self._cfg.learn.target_entropy is None: - assert 'action_shape' in self._cfg.model, "SACDiscrete need network model with action_shape variable" + assert 'action_shape' in self._cfg.model, "DiscreteSAC need network model with action_shape variable" self._target_entropy = -np.prod(self._cfg.model.action_shape) else: self._target_entropy = self._cfg.learn.target_entropy @@ -205,7 +192,35 @@ def _init_learn(self) -> None: self._learn_model.reset() self._target_model.reset() - def _forward_learn(self, data: dict) -> Dict[str, Any]: + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, action, priority. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. 
Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For SAC, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``logit``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight``. + Returns: + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for DiscreteSACPolicy: \ + ``ding.policy.tests.test_discrete_sac``. + """ loss_dict = {} data = default_preprocess_learn( data, @@ -332,8 +347,15 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: } def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizers. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ ret = { 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), 'optimizer_q': self._optimizer_q.state_dict(), 'optimizer_policy': self._optimizer_policy.state_dict(), } @@ -342,13 +364,36 @@ def _state_dict_learn(self) -> Dict[str, Any]: return ret def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + + .. tip:: + If you want to only load some parts of model, you can simply set the ``strict`` argument in \ + load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ + complicated operation. + """ self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) self._optimizer_q.load_state_dict(state_dict['optimizer_q']) self._optimizer_policy.load_state_dict(state_dict['optimizer_policy']) if self._auto_alpha: self._alpha_optim.load_state_dict(state_dict['optimizer_alpha']) def _init_collect(self) -> None: + """ + Overview: + Initialize the collect mode of policy, including related attributes and modules. For SAC, it contains the \ + collect_model to balance the exploration and exploitation with the epsilon and multinomial sample \ + mechanism, and other algorithm-specific arguments such as unroll_len. \ + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. 
+ """ self._unroll_len = self._cfg.collect.unroll_len # Empirically, we found that eps_greedy_multinomial_sample works better than multinomial_sample # and eps_greedy_sample, and we don't divide logit by alpha, @@ -356,7 +401,32 @@ def _init_collect(self) -> None: self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_multinomial_sample') self._collect_model.reset() - def _forward_collect(self, data: dict, eps: float) -> dict: + def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: + """ + Overview: + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \ + exploration, i.e., classic epsilon-greedy exploration strategy. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + - eps (:obj:`float`): The epsilon value for exploration. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ + dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for DiscreteSACPolicy: \ + ``ding.policy.tests.test_discrete_sac``. + """ data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: @@ -369,25 +439,83 @@ def _forward_collect(self, data: dict, eps: float) -> dict: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], + timestep: namedtuple) -> Dict[str, torch.Tensor]: + """ + Overview: + Process and pack one timestep transition data into a dict, which can be directly used for training and \ + saved in replay buffer. For discrete SAC, it contains obs, next_obs, logit, action, reward, done. + Arguments: + - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. + - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. For discrete SAC, it contains the action and the logit of the action. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. + Returns: + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. 
+ """ transition = { 'obs': obs, 'next_obs': timestep.obs, - 'action': model_output['action'], - 'logit': model_output['logit'], + 'action': policy_output['action'], + 'logit': policy_output['logit'], 'reward': timestep.reward, 'done': timestep.done, } return transition - def _get_train_sample(self, data: list) -> Union[None, List[Any]]: - return get_train_sample(data, self._unroll_len) + def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. In discrete SAC, a train sample is a processed transition (unroll_len=1). + Arguments: + - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ + the same format as the return value of ``self._process_transition`` method. + Returns: + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \ + as input transitions, but may contain more data for training. + """ + return get_train_sample(transitions, self._unroll_len) def _init_eval(self) -> None: + """ + Overview: + Initialize the eval mode of policy, including related attributes and modules. For DiscreteSAC, it contains \ + the eval model to greedily select action type with argmax q_value mechanism. + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. + """ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample') self._eval_model.reset() - def _forward_eval(self, data: dict) -> dict: + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ + Overview: + Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for DiscreteSACPolicy: \ + ``ding.policy.tests.test_discrete_sac``. + """ data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: @@ -401,6 +529,13 @@ def _forward_eval(self, data: dict) -> dict: return {i: d for i, d in zip(data_id, output)} def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. 
The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ twin_critic = ['twin_critic_loss'] if self._twin_critic else [] if self._auto_alpha: return super()._monitor_vars_learn() + [ @@ -416,11 +551,11 @@ def _monitor_vars_learn(self) -> List[str]: @POLICY_REGISTRY.register('sac') class SACPolicy(Policy): - r""" - Overview: - Policy class of continuous SAC algorithm. Paper link: https://arxiv.org/pdf/1801.01290.pdf + """ + Overview: + Policy class of continuous SAC algorithm. Paper link: https://arxiv.org/pdf/1801.01290.pdf - Config: + Config: == ==================== ======== ============= ================================= ======================= ID Symbol Type Default Value Description Other == ==================== ======== ============= ================================= ======================= @@ -442,11 +577,11 @@ class SACPolicy(Policy): | ``_rate_policy`` | network. | 9 | ``learn.alpha`` float 0.2 | Entropy regularization | alpha is initiali- | | coefficient. | zation for auto - | | | `\alpha`, when + | | | alpha, when | | | auto_alpha is True 10 | ``learn.`` bool False | Determine whether to use | Temperature parameter | ``auto_alpha`` | auto temperature parameter | determines the - | | `\alpha`. | relative importance + | | alpha. | relative importance | | | of the entropy term | | | against the reward. 11 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only @@ -456,7 +591,7 @@ class SACPolicy(Policy): | | | aging for target | | | networks. == ==================== ======== ============= ================================= ======================= - """ + """ config = dict( # (str) RL policy register name (refer to function "POLICY_REGISTRY"). @@ -482,6 +617,7 @@ class SACPolicy(Policy): # (str) Use reparameterization trick for continous action. action_space='reparameterization', ), + # learn_mode config learn=dict( # (int) How many updates (iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. @@ -523,6 +659,7 @@ class SACPolicy(Policy): # (float) Weight uniform initialization max range in the last output layer. init_w=3e-3, ), + # collect_mode config collect=dict( # (int) How many training samples collected in one collection procedure. n_sample=1, @@ -532,6 +669,7 @@ class SACPolicy(Policy): # In some algorithm like guided cost learning, we need to use logit to train the reward model. collector_logit=False, ), + eval=dict(), # for compability other=dict( replay_buffer=dict( # (int) Maximum size of replay buffer. Usually, larger buffer size is good @@ -542,12 +680,37 @@ class SACPolicy(Policy): ) def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \ + automatically call this method to get the default model setting and create model. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names. + """ if self._cfg.multi_agent: return 'continuous_maqac', ['ding.model.template.maqac'] else: return 'continuous_qac', ['ding.model.template.qac'] def _init_learn(self) -> None: + """ + Overview: + Initialize the learn mode of policy, including related attributes and modules. 
For SAC, it mainly \ + contains three optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target \ + model. Especially, the ``auto_alpha`` mechanism for balancing max entropy target is also initialized here. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. + """ self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight self._twin_critic = self._cfg.model.twin_critic @@ -607,7 +770,34 @@ def _init_learn(self) -> None: self._learn_model.reset() self._target_model.reset() - def _forward_learn(self, data: dict) -> Dict[str, Any]: + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, action, priority. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For SAC, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``. + Returns: + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``. + """ loss_dict = {} data = default_preprocess_learn( data, @@ -730,6 +920,12 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: } def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizers. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. 
+ """ ret = { 'model': self._learn_model.state_dict(), 'target_model': self._target_model.state_dict(), @@ -741,6 +937,17 @@ def _state_dict_learn(self) -> Dict[str, Any]: return ret def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + + .. tip:: + If you want to only load some parts of model, you can simply set the ``strict`` argument in \ + load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ + complicated operation. + """ self._learn_model.load_state_dict(state_dict['model']) self._target_model.load_state_dict(state_dict['target_model']) self._optimizer_q.load_state_dict(state_dict['optimizer_q']) @@ -749,11 +956,46 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: self._alpha_optim.load_state_dict(state_dict['optimizer_alpha']) def _init_collect(self) -> None: + """ + Overview: + Initialize the collect mode of policy, including related attributes and modules. For SAC, it contains the \ + collect_model other algorithm-specific arguments such as unroll_len. \ + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. + """ self._unroll_len = self._cfg.collect.unroll_len self._collect_model = model_wrap(self._model, wrapper_name='base') self._collect_model.reset() - def _forward_collect(self, data: dict) -> dict: + def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: + """ + Overview: + Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ + dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + ``logit`` in SAC means the mu and sigma of Gaussioan distribution. Here we use this name for consistency. + + .. note:: + For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``. 
+ """ data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: @@ -769,7 +1011,23 @@ def _forward_collect(self, data: dict) -> dict: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - def _process_transition(self, obs: Any, policy_output: dict, timestep: namedtuple) -> dict: + def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor], + timestep: namedtuple) -> Dict[str, torch.Tensor]: + """ + Overview: + Process and pack one timestep transition data into a dict, which can be directly used for training and \ + saved in replay buffer. For continuous SAC, it contains obs, next_obs, action, reward, done. The logit \ + will be also added when ``collector_logit`` is True. + Arguments: + - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari. + - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \ + as input. For continuous SAC, it contains the action and the logit (mu and sigma) of the action. + - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \ + except all the elements have been transformed into tensor data. Usually, it contains the next obs, \ + reward, done, info, etc. + Returns: + - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep. + """ if self._cfg.collect.collector_logit: transition = { 'obs': obs, @@ -789,14 +1047,60 @@ def _process_transition(self, obs: Any, policy_output: dict, timestep: namedtupl } return transition - def _get_train_sample(self, data: list) -> Union[None, List[Any]]: - return get_train_sample(data, self._unroll_len) + def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Overview: + For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \ + can be used for training directly. In continuous SAC, a train sample is a processed transition \ + (unroll_len=1). + Arguments: + - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \ + the same format as the return value of ``self._process_transition`` method. + Returns: + - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \ + as input transitions, but may contain more data for training. + """ + return get_train_sample(transitions, self._unroll_len) def _init_eval(self) -> None: + """ + Overview: + Initialize the eval mode of policy, including related attributes and modules. For SAC, it contains the \ + eval model, which is equipped with ``base`` model wrapper to ensure compability. + This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_eval`` method, you'd better name them \ + with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``. + """ self._eval_model = model_wrap(self._model, wrapper_name='base') self._eval_model.reset() - def _forward_eval(self, data: dict) -> dict: + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + """ + Overview: + Policy forward function of eval mode (evaluation policy performance by interacting with envs). 
Forward \ + means that the policy gets some necessary data (mainly observation) from the envs and then returns the \ + action to interact with the envs. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \ + key of the dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + ``logit`` in SAC means the mu and sigma of Gaussioan distribution. Here we use this name for consistency. + + .. note:: + For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``. + """ data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: @@ -812,6 +1116,13 @@ def _forward_eval(self, data: dict) -> dict: return {i: d for i, d in zip(data_id, output)} def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ twin_critic = ['twin_critic_loss'] if self._twin_critic else [] alpha_loss = ['alpha_loss'] if self._auto_alpha else [] return [ @@ -830,8 +1141,32 @@ def _monitor_vars_learn(self) -> List[str]: @POLICY_REGISTRY.register('sqil_sac') class SQILSACPolicy(SACPolicy): + """ + Overview: + Policy class of continuous SAC algorithm with SQIL extension. + SAC paper link: https://arxiv.org/pdf/1801.01290.pdf + SQIL paper link: https://arxiv.org/abs/1905.11108 + """ def _init_learn(self) -> None: + """ + Overview: + Initialize the learn mode of policy, including related attributes and modules. For SAC, it mainly \ + contains three optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target \ + model. Especially, the ``auto_alpha`` mechanism for balancing max entropy target is also initialized here. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. 
+ """ self._priority = self._cfg.priority self._priority_IS_weight = self._cfg.priority_IS_weight self._twin_critic = self._cfg.model.twin_critic @@ -856,7 +1191,7 @@ def _init_learn(self) -> None: self._gamma = self._cfg.learn.discount_factor if self._cfg.learn.auto_alpha: if self._cfg.learn.target_entropy is None: - assert 'action_shape' in self._cfg.model, "SACDiscrete need network model with action_shape variable" + assert 'action_shape' in self._cfg.model, "SQILSACPolicy need network model with action_shape variable" self._target_entropy = -np.prod(self._cfg.model.action_shape) else: self._target_entropy = self._cfg.learn.target_entropy @@ -895,7 +1230,38 @@ def _init_learn(self) -> None: self._monitor_cos = True self._monitor_entropy = True - def _forward_learn(self, data: dict) -> Dict[str, Any]: + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the replay buffer and then returns the output \ + result, including various training information such as loss, action, priority. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For SAC, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``. + Returns: + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + For SQIL + SAC, input data is composed of two parts with the same size: agent data and expert data. \ + Both of them are relabelled with new reward according to SQIL algorithm. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``. + """ loss_dict = {} if self._monitor_cos: agent_data = default_preprocess_learn( @@ -1094,6 +1460,13 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: return var_monitor def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. 
+ """ twin_critic = ['twin_critic_loss'] if self._twin_critic else [] alpha_loss = ['alpha_loss'] if self._auto_alpha else [] cos_similarity = ['cos_similarity'] if self._monitor_cos else [] diff --git a/ding/policy/td3.py b/ding/policy/td3.py index 632e4b3c22..7359190282 100644 --- a/ding/policy/td3.py +++ b/ding/policy/td3.py @@ -5,17 +5,11 @@ @POLICY_REGISTRY.register('td3') class TD3Policy(DDPGPolicy): - r""" + """ Overview: - Policy class of TD3 algorithm. - - Since DDPG and TD3 share many common things, we can easily derive this TD3 + Policy class of TD3 algorithm. Since DDPG and TD3 share many common things, we can easily derive this TD3 \ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and noise in model wrapper. - - https://arxiv.org/pdf/1802.09477.pdf - - Property: - learn_mode, collect_mode, eval_mode + Paper link: https://arxiv.org/pdf/1802.09477.pdf Config: @@ -68,9 +62,7 @@ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and n type='td3', # (bool) Whether to use cuda for network. cuda=False, - # (bool type) on_policy: Determine whether on-policy or off-policy. - # on-policy setting influences the behaviour of buffer. - # Default False in TD3. + # (bool) on_policy: Determine whether on-policy or off-policy. Default False in TD3. on_policy=False, # (bool) Whether use priority(priority sample, IS weight, update priority) # Default False in TD3. @@ -80,6 +72,8 @@ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and n # (int) Number of training samples(randomly collected) in replay buffer when training starts. # Default 25000 in DDPG/TD3. random_collect_size=25000, + # (bool) Whether to need policy data in process transition. + transition_with_policy_data=False, # (str) Action space type action_space='continuous', # ['continuous', 'hybrid'] # (bool) Whether use batch normalization for reward @@ -92,9 +86,9 @@ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and n # Default True for TD3, False for DDPG. twin_critic=True, ), + # learn_mode config learn=dict( - - # How many updates(iterations) to train after collector's one collection. + # (int) How many updates(iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. # collect data -> update policy-> collect data -> ... update_per_collect=1, @@ -112,7 +106,7 @@ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and n # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``), # when the episode step is greater than max episode step. ignore_done=False, - # (float type) target_theta: Used for soft update of the target network, + # (float) target_theta: Used for soft update of the target network, # aka. Interpolation factor in polyak averaging for target networks. # Default to 0.005. target_theta=0.005, @@ -130,30 +124,37 @@ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and n noise_sigma=0.2, # (dict) Limit for range of target policy smoothing noise, aka. noise_clip. noise_range=dict( + # (int) min value of noise min=-0.5, + # (int) max value of noise max=0.5, ), ), + # collect_mode config collect=dict( + # (int) How many training samples collected in one collection procedure. + # Only one of [n_sample, n_episode] shoule be set. # n_sample=1, # (int) Cut trajectories into pieces with length "unroll_len". unroll_len=1, # (float) It is a must to add noise during collection. 
So here omits "noise" and only set "noise_sigma". noise_sigma=0.1, ), - eval=dict( - evaluator=dict( - # (int) Evaluate every "eval_freq" training iterations. - eval_freq=5000, - ), - ), + eval=dict(), # for compability other=dict( replay_buffer=dict( - # (int) Maximum size of replay buffer. + # (int) Maximum size of replay buffer. Usually, larger buffer size is better. replay_buffer_size=100000, ), ), ) def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ + as text logger, tensorboard logger, will use these keys to save the corresponding data. + Returns: + - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. + """ return ["q_value", "loss", "lr", "entropy", "target_q_value", "td_error"] diff --git a/ding/policy/tests/test_cql.py b/ding/policy/tests/test_cql.py index 949e4dce6d..248653da6a 100644 --- a/ding/policy/tests/test_cql.py +++ b/ding/policy/tests/test_cql.py @@ -3,7 +3,7 @@ import pytest import torch from easydict import EasyDict -from ding.policy.cql import CQLPolicy, CQLDiscretePolicy +from ding.policy.cql import CQLPolicy, DiscreteCQLPolicy from ding.utils.data import offline_data_save_type from tensorboardX import SummaryWriter from ding.model.wrapper.model_wrappers import ArgmaxSampleWrapper, EpsGreedySampleWrapper, TargetNetworkWrapper @@ -23,7 +23,7 @@ cfg2.learn.auto_alpha = False cfg2.learn.log_space = False -cfg3 = EasyDict(CQLDiscretePolicy.default_config()) +cfg3 = EasyDict(DiscreteCQLPolicy.default_config()) cfg3.model = {} cfg3.model.obs_shape = obs_space cfg3.model.action_shape = action_space @@ -89,7 +89,7 @@ def get_transition_discrete(size=20): @pytest.mark.parametrize('cfg', [cfg3, cfg4]) @pytest.mark.unittest def test_cql_discrete(cfg): - policy = CQLDiscretePolicy(cfg, enable_field=['collect', 'eval', 'learn']) + policy = DiscreteCQLPolicy(cfg, enable_field=['collect', 'eval', 'learn']) assert type(policy._learn_model) == ArgmaxSampleWrapper assert type(policy._target_model) == TargetNetworkWrapper assert type(policy._collect_model) == EpsGreedySampleWrapper diff --git a/ding/rl_utils/td.py b/ding/rl_utils/td.py index d7274e6a92..7c5b995eaa 100644 --- a/ding/rl_utils/td.py +++ b/ding/rl_utils/td.py @@ -718,10 +718,12 @@ def bdq_nstep_td_error( ) -> torch.Tensor: """ Overview: - Multistep (1 step or n step) td_error for BDQ algorithm, referenced paper "Action Branching Architectures \ - for Deep Reinforcement Learning", link: https://arxiv.org/pdf/1711.08946. + Multistep (1 step or n step) td_error for BDQ algorithm, referenced paper "Action Branching Architectures for \ + Deep Reinforcement Learning", link: https://arxiv.org/pdf/1711.08946. In fact, the original paper only provides the 1-step TD-error calculation method, and here we extend the \ - calculation method of n-step TD-error. 
+        calculation method to n-step TD-error, i.e.:
+        :math:`y_d = \sum_{t=0}^{nstep-1} \gamma^t * r_t + \gamma^{nstep} * Q_d'(s', \arg\max_{a_d} Q_d(s', a_d))`
+        :math:`TD{\text -}error = \frac{1}{D} \sum_{d=1}^{D} (y_d - Q_d(s, a_d))^2`
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
diff --git a/dizoo/classic_control/cartpole/config/cartpole_cql_config.py b/dizoo/classic_control/cartpole/config/cartpole_cql_config.py
index ec804a6339..0b1932e5ad 100644
--- a/dizoo/classic_control/cartpole/config/cartpole_cql_config.py
+++ b/dizoo/classic_control/cartpole/config/cartpole_cql_config.py
@@ -42,7 +42,7 @@
        import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
    ),
    env_manager=dict(type='base'),
-    policy=dict(type='cql_discrete'),
+    policy=dict(type='discrete_cql'),
)
cartpole_discrete_cql_create_config = EasyDict(cartpole_discrete_cql_create_config)
create_config = cartpole_discrete_cql_create_config
diff --git a/dizoo/classic_control/cartpole/config/cartpole_sac_config.py b/dizoo/classic_control/cartpole/config/cartpole_sac_config.py
index 36dcb53be6..74c5281577 100644
--- a/dizoo/classic_control/cartpole/config/cartpole_sac_config.py
+++ b/dizoo/classic_control/cartpole/config/cartpole_sac_config.py
@@ -61,7 +61,7 @@
        import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
    ),
    env_manager=dict(type='base'),
-    policy=dict(type='sac_discrete'),
+    policy=dict(type='discrete_sac'),
)
cartpole_sac_create_config = EasyDict(cartpole_sac_create_config)
create_config = cartpole_sac_create_config
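
Note: the snippet below is only an illustrative, self-contained sketch of the n-step target and branch-averaged TD-error described in the ``bdq_nstep_td_error`` docstring above. The function name, tensor names, and shapes are assumptions made for this example; it does not reproduce DI-engine's actual ``bdq_nstep_td_error`` signature or the ``q_nstep_td_data`` namedtuple.

```python
# Illustrative sketch (not the library implementation) of the BDQ n-step target
# y_d = sum_{t=0}^{nstep-1} gamma^t * r_t + gamma^nstep * Q'_d(s', argmax_a Q_d(s', a))
# and the branch-averaged TD-error 1/D * sum_d (y_d - Q_d(s, a_d))^2.
import torch


def bdq_nstep_td_error_sketch(
    q: torch.Tensor,         # (B, D, A): Q_d(s, a) for each of D action branches
    next_q: torch.Tensor,    # (B, D, A): Q_d(s', a) from the online network (used for argmax)
    target_q: torch.Tensor,  # (B, D, A): Q'_d(s', a) from the target network
    action: torch.Tensor,    # (B, D): executed discrete action per branch (long dtype)
    reward: torch.Tensor,    # (nstep, B): per-step rewards r_0 ... r_{nstep-1}
    done: torch.Tensor,      # (B,): terminal flag (float) of the n-step transition
    gamma: float = 0.99,
) -> torch.Tensor:
    nstep = reward.shape[0]
    # Discounted n-step return: sum_{t=0}^{nstep-1} gamma^t * r_t, shape (B,).
    discounts = gamma ** torch.arange(nstep, dtype=reward.dtype)
    n_step_return = (discounts.unsqueeze(1) * reward).sum(0)
    # Double-DQN-style bootstrap per branch: Q'_d(s', argmax_a Q_d(s', a)), shape (B, D).
    greedy_action = next_q.argmax(dim=-1, keepdim=True)
    bootstrap = target_q.gather(-1, greedy_action).squeeze(-1)
    # Target y_d, broadcast over the D branches and masked by done.
    target = n_step_return.unsqueeze(1) + (gamma ** nstep) * (1 - done).unsqueeze(1) * bootstrap
    # Q_d(s, a_d) for the executed action of each branch, shape (B, D).
    q_taken = q.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    # Mean over the D branches, then over the batch.
    return ((target.detach() - q_taken) ** 2).mean(dim=1).mean()
```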