polish(nyz): polish api doc comments problems
PaParaZz1 committed Nov 1, 2023
1 parent c5a4be3 commit e9a978e
Showing 13 changed files with 68 additions and 223 deletions.
14 changes: 7 additions & 7 deletions ding/envs/env/ding_env_wrapper.py
@@ -26,16 +26,16 @@ class DingEnvWrapper(BaseEnv):
def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True, caller: str = 'collector') -> None:
"""
Overview:
-Initialize the DingEnvWrapper. Either an environment instance or a config to create the environment
-instance should be passed in:
-- An environment instance: The `env` parameter must not be `None`, but should be the instance.
-It does not support subprocess environment manager. Thus, it is usually used in simple environments.
-- A config to create an environment instance: The `cfg` parameter must contain `env_id`.
+Initialize the DingEnvWrapper. Either an environment instance or a config to create the environment \
+instance should be passed in. For the former, i.e., an environment instance: The `env` parameter must not \
+be `None`, but should be the instance. It does not support subprocess environment manager. Thus, it is \
+usually used in simple environments. For the latter, i.e., a config to create an environment instance: \
+The `cfg` parameter must contain `env_id`.
Arguments:
- env (:obj:`gym.Env`): An environment instance to be wrapped.
- cfg (:obj:`dict`): The configuration dictionary to create an environment instance.
- seed_api (:obj:`bool`): Whether to use seed API. Defaults to True.
-- caller (:obj:`str`): A string representing the caller of this method, including ``collector`` or
+- caller (:obj:`str`): A string representing the caller of this method, including ``collector`` or \
``evaluator``. Different caller may need different wrappers. Default is 'collector'.
"""
self._env = None
@@ -44,7 +44,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
self._seed_api = seed_api # some env may disable `env.seed` api
self._caller = caller
if self._cfg is None:
-self._cfg = dict()
+self._cfg = {}
self._cfg = EasyDict(self._cfg)
if 'act_scale' not in self._cfg:
self._cfg.act_scale = False
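
A minimal sketch of the two construction modes the polished docstring describes; the ``CartPole-v0`` id and the ``ding.envs`` import path are illustrative assumptions, not part of this commit:

import gym
from easydict import EasyDict
from ding.envs import DingEnvWrapper

# Mode 1: wrap an existing environment instance directly.
env_from_instance = DingEnvWrapper(env=gym.make('CartPole-v0'))

# Mode 2: let the wrapper build the environment; cfg must contain `env_id`.
env_from_cfg = DingEnvWrapper(cfg=EasyDict({'env_id': 'CartPole-v0'}))
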
6 changes: 5 additions & 1 deletion ding/envs/env_manager/base_env_manager.py
@@ -562,6 +562,9 @@ def closed(self) -> bool:
"""
return self._closed

+def random_action(self) -> Dict:
+    return {env_id: self._env_ref.action_space.sample() for env_id in self.ready_obs_id}


@ENV_MANAGER_REGISTRY.register('base_v2')
class BaseEnvManagerV2(BaseEnvManager):
@@ -577,7 +580,8 @@ class BaseEnvManagerV2(BaseEnvManager):
.. note::
For more details about new task pipeline, please refer to the system document of DI-engine \
-(`en link <../03_system/index.html>`_).
+(`system en link <../03_system/index.html>`_).
Interfaces:
reset, step, seed, close, enable_save_replay, launch, default_config, reward_shaping, enable_save_figure
Properties:
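
The ``random_action`` helper added above samples one action per ready environment. A hedged usage sketch follows; the manager construction is illustrative and real configs may need more fields:

import gym
from ding.envs import BaseEnvManager, DingEnvWrapper

env_manager = BaseEnvManager(
    env_fn=[lambda: DingEnvWrapper(gym.make('CartPole-v0')) for _ in range(4)],
    cfg=BaseEnvManager.default_config(),
)
env_manager.launch()
actions = env_manager.random_action()  # {env_id: action_space.sample()}
timesteps = env_manager.step(actions)
env_manager.close()
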
2 changes: 2 additions & 0 deletions ding/envs/env_wrappers/env_wrappers.py
@@ -39,6 +39,8 @@
- GymToGymnasiumWrapper: Adapts environments from the Gym library to be compatible with the Gymnasium library.
- AllinObsWrapper: Consolidates all information into the observation, useful for environments where the agent's
observation should include additional information such as the current score or time remaining.
+- ObsPlusPrevActRewWrapper: This wrapper is used in policy NGU. It sets a dict as the new wrapped observation,
+which includes the current observation, previous action and previous reward.
"""

import copy
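
The newly listed ``ObsPlusPrevActRewWrapper`` packs the previous transition into the observation. A simplified sketch of that idea, not the actual DI-engine implementation:

import gym

class ObsPlusPrevActRewSketch(gym.Wrapper):
    """Return a dict observation carrying the action and reward that produced it."""

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        # No previous transition at episode start, so use neutral placeholders.
        return {'obs': obs, 'prev_action': None, 'prev_reward': 0.0}

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        # The next observation carries the action/reward that led to it.
        return {'obs': obs, 'prev_action': action, 'prev_reward': reward}, reward, done, info
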
1 change: 1 addition & 0 deletions ding/model/common/head.py
@@ -1293,6 +1293,7 @@ def forward(self, key: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
>>> query = torch.randn(4, 64)
>>> logit = head(key, query)
>>> assert logit.shape == torch.Size([4, 5])
+.. note::
In this head, we assume that the ``key`` and ``query`` tensor are both normalized.
"""
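
Following the note added above, both inputs should be normalized before the forward call. A hedged sketch, assuming ``head`` is the module built in the docstring example:

import torch
import torch.nn.functional as F

key = F.normalize(torch.randn(5, 64), dim=-1)    # 5 normalized candidate keys
query = F.normalize(torch.randn(4, 64), dim=-1)  # batch of 4 normalized queries
logit = head(key, query)
assert logit.shape == torch.Size([4, 5])
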
1 change: 1 addition & 0 deletions ding/model/common/utils.py
@@ -21,6 +21,7 @@ def create_model(cfg: EasyDict) -> torch.nn.Module:
>>> 'action_shape': 2,
>>> })
>>> model = create_model(cfg)
+.. tip::
This method will not modify the ``cfg`` , it will deepcopy the ``cfg`` and then modify it.
"""
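
A hedged sketch exercising the tip above; the config fields follow the docstring example and the top-level import path is an assumption:

import copy
from easydict import EasyDict
from ding.model import create_model

cfg = EasyDict({
    'type': 'dqn',
    'import_names': ['ding.model.template.q_learning'],
    'obs_shape': 4,
    'action_shape': 2,
})
snapshot = copy.deepcopy(cfg)
model = create_model(cfg)
assert cfg == snapshot  # create_model deepcopies cfg rather than mutating it
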
1 change: 1 addition & 0 deletions ding/model/template/__init__.py
@@ -26,3 +26,4 @@
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
from .bcq import BCQ
from .edac import EDAC
+from .ebm import EBM, AutoregressiveEBM
33 changes: 4 additions & 29 deletions ding/model/template/acer.py
@@ -85,40 +85,15 @@ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
Use observation to predict output.
Parameter updates with ACER's MLPs forward setup.
Arguments:
-Forward with ``'compute_actor'``:
- inputs (:obj:`torch.Tensor`):
The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``.
Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` depend on ``mode``.
-Forward with ``'compute_critic'``, inputs:`torch.Tensor` Necessary Keys:
-- ``obs`` encoded tensors.
- mode (:obj:`str`): Name of the forward mode.
Returns:
- outputs (:obj:`Dict`): Outputs of network forward.
-Forward with ``'compute_actor'``, Necessary Keys (either):
-- logit (:obj:`torch.Tensor`):
+- logit (:obj:`torch.Tensor`): Logit encoding tensor.
-Forward with ``'compute_critic'``, Necessary Keys:
- q_value (:obj:`torch.Tensor`): Q value tensor.
-Actor Shapes:
+Shapes (Actor):
- obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
- logit (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
-Critic Shapes:
+Shapes (Critic):
- inputs (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``obs_shape``
- q_value (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
-Actor Examples:
->>> # Regression mode
->>> model = ACER(64, 64)
->>> inputs = torch.randn(4, 64)
->>> actor_outputs = model(inputs,'compute_actor')
->>> assert actor_outputs['logit'].shape == torch.Size([4, 64])
-Critic Examples:
->>> inputs = torch.randn(4,N)
->>> model = ACER(obs_shape=(N, ),action_shape=5)
->>> model(inputs, mode='compute_critic')['q_value']
"""
assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
return getattr(self, mode)(inputs)
@@ -127,7 +102,7 @@ def compute_actor(self, inputs: torch.Tensor) -> Dict:
"""
Overview:
Use encoded embedding tensor to predict output.
-Execute parameter updates with ``'compute_actor'`` mode
+Execute parameter updates with ``compute_actor`` mode
Use encoded embedding tensor to predict output.
Arguments:
- inputs (:obj:`torch.Tensor`):
@@ -156,7 +131,7 @@ def compute_actor(self, inputs: torch.Tensor) -> Dict:
def compute_critic(self, inputs: torch.Tensor) -> Dict:
"""
Overview:
-Execute parameter updates with ``'compute_critic'`` mode
+Execute parameter updates with ``compute_critic`` mode
Use encoded embedding tensor to predict output.
Arguments:
- ``obs``, ``action`` encoded tensors.
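
The actor and critic examples removed from the ``forward`` docstring above can be consolidated into one hedged sketch; shapes follow those removed examples and the import path is an assumption:

import torch
from ding.model.template import ACER

model = ACER(obs_shape=64, action_shape=5)
inputs = torch.randn(4, 64)
logit = model(inputs, mode='compute_actor')['logit']        # (4, 5)
q_value = model(inputs, mode='compute_critic')['q_value']   # (4, 5)
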
