polish(rjy): polish pg/iqn/edac policy doc (#764)
* polish(rjy): polish pg/iqn/edac policy doc
nighood authored Jan 29, 2024
1 parent 74c6a1e commit acd23e5
Showing 3 changed files with 259 additions and 40 deletions.
51 changes: 46 additions & 5 deletions ding/policy/edac.py
@@ -20,7 +20,7 @@
class EDACPolicy(SACPolicy):
"""
Overview:
Policy class of EDAC algorithm. https://arxiv.org/pdf/2110.01548.pdf
Policy class of EDAC algorithm. Paper link: https://arxiv.org/pdf/2110.01548.pdf
Config:
== ==================== ======== ============= ================================= =======================
@@ -149,18 +149,59 @@ def default_model(self) -> Tuple[str, List[str]]:
return 'edac', ['ding.model.template.edac']

def _init_learn(self) -> None:
r"""
"""
Overview:
Learn mode init method. Called by ``self.__init__``.
Init q, value and policy's optimizers, algorithm config, main and target models.
Initialize the learn mode of policy, including related attributes and modules. For EDAC, in addition \
to everything that needs to be initialized in SAC, it is also necessary to define \
eta/with_q_entropy/forward_learn_cnt. \
This method will be called in the ``__init__`` method if the ``learn`` field is in ``enable_field``.
.. note::
For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
and ``_load_state_dict_learn`` methods.
.. note::
For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
.. note::
If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
"""
super()._init_learn()
# EDAC special implementation
self._eta = self._cfg.learn.eta
self._with_q_entropy = self._cfg.learn.with_q_entropy
self._forward_learn_cnt = 0
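The ``eta`` and ``with_q_entropy`` attributes above are read straight from the ``learn`` section of the policy config, while the forward counter simply starts at zero. A minimal, hypothetical sketch of the EDAC-specific config fields (values are illustrative, not the shipped defaults):

from easydict import EasyDict

# Hypothetical excerpt of an EDAC config; only the EDAC-specific learn fields are shown.
edac_cfg_excerpt = EasyDict(dict(
    learn=dict(
        eta=1.0,               # read as self._eta, weight of the ensemble-diversity regularizer from the EDAC paper
        with_q_entropy=False,  # read as self._with_q_entropy, whether the entropy bonus enters the Q target
    ),
))

Since ``_init_learn`` only runs when ``'learn'`` is in ``enable_field``, constructing the policy for training looks roughly like ``EDACPolicy(cfg, model=model, enable_field=['learn'])`` (argument names assumed from the common DI-engine ``Policy`` interface).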

def _forward_learn(self, data: dict) -> Dict[str, Any]:
def _forward_learn(self, data: List[Dict[int, Any]]) -> Dict[str, Any]:
"""
Overview:
Policy forward function of learn mode (training policy and updating parameters). Forward means \
that the policy takes a batch of training data from the replay buffer as input and returns the output \
result, including various training information such as loss, action, and priority.
Arguments:
- data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
training samples. For each element in list, the key of the dict is the name of data items and the \
value is the corresponding data. Usually, the value is a torch.Tensor or np.ndarray, or their dict/list \
combinations. In the ``_forward_learn`` method, the data often needs to first be stacked in the batch \
dimension by some utility function such as ``default_preprocess_learn``. \
For EDAC, each element in the list is a dict containing at least the following keys: ``obs``, ``action``, \
``logit``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight``.
Returns:
- info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
recorded in the text log and tensorboard; values must be Python scalars or lists of scalars. For the \
detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
.. note::
The input value can be a torch.Tensor or dict/list combinations, and the current policy supports all of them. \
For data types that are not supported, the main reason is that the corresponding model does not support them. \
You can implement your own model rather than use the default model. For more information, please raise an \
issue in the GitHub repo and we will continue to follow up.
.. note::
For more detailed examples, please refer to our unittest for EDACPolicy: \
``ding.policy.tests.test_edac``.
"""
loss_dict = {}
data = default_preprocess_learn(
data,
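As a rough illustration of the input format described in the docstring above (the 3-dim observation, 1-dim action, and all shapes are assumptions purely for demonstration), a batch passed to ``_forward_learn`` is a list of per-transition dicts:

import torch

def fake_transition():
    # One hypothetical transition with the keys listed in the docstring.
    return {
        'obs': torch.randn(3),
        'action': torch.randn(1),
        'logit': [torch.randn(1), torch.randn(1)],  # e.g. (mu, sigma) of the behaviour policy; exact form depends on the model
        'reward': torch.randn(1),
        'next_obs': torch.randn(3),
        'done': torch.zeros(1),
    }

data = [fake_transition() for _ in range(2)]
# Inside _forward_learn, a helper such as default_preprocess_learn (called in the truncated
# code above) stacks these per-sample dicts along the batch dimension, e.g. obs -> shape (2, 3).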
85 changes: 75 additions & 10 deletions ding/policy/iqn.py
@@ -13,9 +13,13 @@

@POLICY_REGISTRY.register('iqn')
class IQNPolicy(DQNPolicy):
r"""
"""
Overview:
Policy class of IQN algorithm.
Policy class of IQN algorithm. Paper link: https://arxiv.org/pdf/1806.06923.pdf. \
Distributional RL is a direction of RL that is often more stable than traditional RL algorithms. \
The core idea of distributional RL is to estimate the distribution of the action value instead of only \
its expectation. The difference between IQN and DQN is that IQN uses quantile regression to estimate \
the quantile values of the return distribution, while DQN only estimates its expectation (an \
illustrative quantile-loss sketch follows the config excerpt below). \
Config:
== ==================== ======== ============== ======================================== =======================
@@ -98,13 +102,37 @@ class IQNPolicy(DQNPolicy):
)
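The class overview above says IQN estimates quantile values of the return distribution via quantile regression. Below is a rough, self-contained sketch of an IQN-style quantile Huber loss; it is not the DI-engine implementation (the policy relies on the loss utilities in ``ding.rl_utils``), and the tensor shapes and reduction order are simplifying assumptions:

import torch

def quantile_huber_loss(q_tau: torch.Tensor, target: torch.Tensor, tau: torch.Tensor, kappa: float = 1.0) -> torch.Tensor:
    # q_tau:  (B, N)  predicted quantile values at sampled fractions tau of shape (B, N)
    # target: (B, M)  target quantile values from the Bellman backup
    # Pairwise TD errors between every target and every predicted quantile: (B, M, N)
    td = target.unsqueeze(-1) - q_tau.unsqueeze(1)
    abs_td = td.abs()
    # Huber loss with threshold kappa keeps the gradient bounded for large errors.
    huber = torch.where(abs_td <= kappa, 0.5 * td ** 2, kappa * (abs_td - 0.5 * kappa))
    # Asymmetric quantile weight |tau - 1{td < 0}| penalizes over- and under-estimation differently.
    weight = (tau.unsqueeze(1) - (td.detach() < 0).float()).abs()
    # Reduction conventions differ across implementations; average over both quantile dims here.
    return (weight * huber / kappa).mean(dim=(1, 2)).mean()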

def default_model(self) -> Tuple[str, List[str]]:
"""
Overview:
Return this algorithm's default neural network model setting for demonstration. The ``__init__`` method will \
automatically call this method to get the default model setting and create the model.
Returns:
- model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
.. note::
The user can define and use a customized network model, but it must obey the same interface definition indicated \
by the import_names path. For example, for IQN, the registered name is ``iqn`` and the import_names is \
``ding.model.template.q_learning``.
"""
return 'iqn', ['ding.model.template.q_learning']
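The note above says a customized network can replace the default one as long as it follows the same interface. A minimal sketch of doing so explicitly (constructor arguments and variable names here are assumptions; see ``ding.model.template.q_learning`` for the exact signature):

from ding.model.template.q_learning import IQN

# Build the IQN model by hand and pass it to the policy instead of relying on default_model().
model = IQN(obs_shape=4, action_shape=2)
policy = IQNPolicy(cfg, model=model, enable_field=['learn'])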

def _init_learn(self) -> None:
r"""
"""
Overview:
Learn mode init method. Called by ``self.__init__``.
Init the optimizer, algorithm config, main and target models.
Initialize the learn mode of policy, including related attributes and modules. For IQN, it mainly contains \
the optimizer, algorithm-specific arguments such as nstep, kappa and gamma, and the main and target models. \
This method will be called in the ``__init__`` method if the ``learn`` field is in ``enable_field``.
.. note::
For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
and ``_load_state_dict_learn`` methods.
.. note::
For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
.. note::
If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
"""
self._priority = self._cfg.priority
# Optimizer
@@ -126,14 +154,34 @@ def _init_learn(self) -> None:
self._learn_model.reset()
self._target_model.reset()
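For the "main and target model" part of the docstring above, the usual DI-engine DQN-family pattern looks roughly like the sketch below (wrapper names and the config field are assumptions based on that common recipe, not a verbatim excerpt of this file):

from ding.model import model_wrap

# Target network refreshed by periodic parameter assignment; learn model adds greedy action selection.
target_model = model_wrap(
    model, wrapper_name='target', update_type='assign',
    update_kwargs={'freq': cfg.learn.target_update_freq}
)
learn_model = model_wrap(model, wrapper_name='argmax_sample')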

def _forward_learn(self, data: dict) -> Dict[str, Any]:
r"""
def _forward_learn(self, data: List[Dict[int, Any]]) -> Dict[str, Any]:
"""
Overview:
Forward and backward function of learn mode.
Policy forward function of learn mode (training policy and updating parameters). Forward means \
that the policy takes a batch of training data from the replay buffer as input and returns the output \
result, including various training information such as loss and priority.
Arguments:
- data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
- data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
training samples. For each element in list, the key of the dict is the name of data items and the \
value is the corresponding data. Usually, the value is a torch.Tensor or np.ndarray, or their dict/list \
combinations. In the ``_forward_learn`` method, the data often needs to first be stacked in the batch \
dimension by some utility function such as ``default_preprocess_learn``. \
For IQN, each element in the list is a dict containing at least the following keys: ``obs``, ``action``, \
``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
and ``value_gamma``.
Returns:
- info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
- info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
recorded in the text log and tensorboard; values must be Python scalars or lists of scalars. For the \
detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
.. note::
The input value can be a torch.Tensor or dict/list combinations, and the current policy supports all of them. \
For data types that are not supported, the main reason is that the corresponding model does not support them. \
You can implement your own model rather than use the default model. For more information, please raise an \
issue in the GitHub repo and we will continue to follow up.
.. note::
For more detailed examples, please refer to our unittest for IQNPolicy: ``ding.policy.tests.test_iqn``.
"""
data = default_preprocess_learn(
data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
@@ -186,13 +234,30 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
}

def _state_dict_learn(self) -> Dict[str, Any]:
"""
Overview:
Return the state_dict of learn mode, usually including model, target_model and optimizer.
Returns:
- state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
"""
return {
'model': self._learn_model.state_dict(),
'target_model': self._target_model.state_dict(),
'optimizer': self._optimizer.state_dict(),
}

def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
"""
Overview:
Load the state_dict variable into policy learn mode.
Arguments:
- state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
.. tip::
If you want to only load some parts of the model, you can simply set the ``strict`` argument in \
load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
complicated operations.
"""
self._learn_model.load_state_dict(state_dict['model'])
self._target_model.load_state_dict(state_dict['target_model'])
self._optimizer.load_state_dict(state_dict['optimizer'])
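A minimal sketch of exercising the two methods above directly (in a full DI-engine training pipeline checkpointing normally goes through the learner hooks, so treat this as an illustration; ``policy`` and the file name are hypothetical):

import torch

# Save the learn-mode state (model, target model, optimizer) ...
torch.save(policy._state_dict_learn(), 'iqn_ckpt.pth.tar')
# ... and later restore it into a freshly constructed policy.
policy._load_state_dict_learn(torch.load('iqn_ckpt.pth.tar'))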