Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/opendilab/DI-engine into de…
Browse files Browse the repository at this point in the history
…v-ckpt
  • Loading branch information
puyuan1996 committed Oct 28, 2024
2 parents b09ffda + 3898386 commit 6b9f509
Show file tree
Hide file tree
Showing 17 changed files with 512 additions and 68 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ It provides **python-first** and **asynchronous-native** task and middleware abs
- Offline RL algorithms: BCQ, CQL, TD3BC, Decision Transformer, EDAC, Diffuser, Decision Diffuser, SO2
- Model-based RL algorithms: SVG, STEVE, MBPO, DDPPO, DreamerV3
- Exploration algorithms: HER, RND, ICM, NGU
- LLM + RL Algorithms: PPO-max, DPO, PromptPG
- LLM + RL Algorithms: PPO-max, DPO, PromptPG, PromptAWR
- Other algorithms: such as PER, PLR, PCGrad
- MCTS + RL algorithms: AlphaZero, MuZero, please refer to [LightZero](https://github.com/opendilab/LightZero)
- Generative Model + RL algorithms: Diffusion-QL, QGPO, SRPO, please refer to [GenerativeRL](https://github.com/opendilab/GenerativeRL)
Expand Down Expand Up @@ -283,6 +283,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 54 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
| 55 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
| 56 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
| 57 | [AWR](https://arxiv.org/pdf/1910.00177) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/ibc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/prompt_awr.py) | python3 -u tabmwp_awr_config.py |

</details>

Expand Down Expand Up @@ -491,7 +492,7 @@ We appreciate all the feedbacks and contributions to improve DI-engine, both alg
```latex
@misc{ding,
title={DI-engine: A Universal AI System/Engine for Decision Intelligence},
author={Yazhe Niu, Jingxin Xu, Yuan Pu, Yunpeng Nie, Jinouwen Zhang, Shuai Hu, Liangxuan Zhao, Ming Zhang, Yu Liu},
author={Niu, Yazhe and Xu, Jingxin and Pu, Yuan and Nie, Yunpeng and Zhang, Jinouwen and Hu, Shuai and Zhao, Liangxuan and Zhang, Ming and Liu, Yu},
publisher={GitHub},
howpublished={\url{https://github.com/opendilab/DI-engine}},
year={2021},
Expand Down
51 changes: 39 additions & 12 deletions ding/model/template/language_transformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Dict
from typing import List, Dict, Optional
import torch
from torch import nn

Expand All @@ -15,31 +15,44 @@ class LanguageTransformer(nn.Module):
"""
Overview:
The LanguageTransformer network. Download a pre-trained language model and add head on it.
In the default case, we use BERT model as the text encoder, whose bi-directional character is good
for obtaining the embedding of the whole sentence.
Interfaces:
``__init__``, ``forward``
"""
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

def __init__(
self,
model_name: str = "bert-base-uncased",
add_linear: bool = False,
embedding_size: int = 128,
freeze_encoder: bool = True
freeze_encoder: bool = True,
hidden_dim: int = 768,
norm_embedding: bool = False
) -> None:
"""
Overview:
Init the LanguageTransformer Model according to input arguments.
Arguments:
- model_name (:obj:`str`): The base language model name in huggingface, such as "bert-base-uncased".
- add_linear (:obj:`bool`): Whether to add a linear layer on the top of language model, defaults to be \
``False``.
``False``.
- embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
- freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
defaults to be ``True``.
defaults to be ``True``.
- hidden_dim (:obj:`int`): The embedding dimension of the encoding model (e.g. BERT). This value should \
correspond to the model you use. For bert-base-uncased, this value is 768.
- norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors. Default to be ``False``.
"""
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForTokenClassification.from_pretrained(model_name)
in_channel = hidden_dim if not add_linear else embedding_size
self.value_head = nn.Linear(in_channel, 1)
self.norm = nn.Identity() if not norm_embedding else nn.LayerNorm(
normalized_shape=in_channel, elementwise_affine=False
)

# Freeze transformer encoder and only train the linear layer
if freeze_encoder:
Expand All @@ -49,9 +62,7 @@ def __init__(
if add_linear:
# Add a small, adjustable linear layer on top of language model tuned through RL
self.embedding_size = embedding_size
self.linear = nn.Linear(
self.model.config.hidden_size, embedding_size
) # 768 for bert-base-uncased, distilbert-base-uncased
self.linear = nn.Linear(self.model.config.hidden_size, embedding_size)
else:
self.linear = None

Expand All @@ -66,19 +77,27 @@ def _calc_embedding(self, x: list) -> torch.Tensor:
last_hidden_states = output.hidden_states[-1]
# Get [CLS] hidden states
sentence_embedding = last_hidden_states[:, 0, :] # len(input_list) x hidden_size
sentence_embedding = self.norm(sentence_embedding)

if self.linear:
sentence_embedding = self.linear(sentence_embedding) # len(input_list) x embedding_size

return sentence_embedding

def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dict:
def forward(
self,
train_samples: List[str],
candidate_samples: Optional[List[str]] = None,
mode: str = 'compute_actor'
) -> Dict:
"""
Overview:
LanguageTransformer forward computation graph, input two lists of strings and predict their matching scores.
Different ``mode`` will forward with different network modules to get different outputs.
Arguments:
- train_samples (:obj:`List[str]`): One list of strings.
- candidate_samples (:obj:`List[str]`): The other list of strings to calculate the matching scores.
- candidate_samples (:obj:`Optional[List[str]]`): The other list of strings to calculate matching scores.
- - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
Returns:
- output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
corresponding ``torch.distributions.Categorical`` object.
Expand All @@ -96,7 +115,15 @@ def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dic
>>> scores = model(ctxt_list, cands_list)
>>> assert scores.shape == (1, 3)
"""
assert mode in self.mode
prompt_embedding = self._calc_embedding(train_samples)
cands_embedding = self._calc_embedding(candidate_samples)
scores = torch.mm(prompt_embedding, cands_embedding.t())
return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}

res_dict = {}
if mode in ['compute_actor', 'compute_actor_critic']:
cands_embedding = self._calc_embedding(candidate_samples)
scores = torch.mm(prompt_embedding, cands_embedding.t())
res_dict.update({'dist': torch.distributions.Categorical(logits=scores), 'logit': scores})
if mode in ['compute_critic', 'compute_actor_critic']:
value = self.value_head(prompt_embedding)
res_dict.update({'value': value})
return res_dict
78 changes: 44 additions & 34 deletions ding/model/template/mavac.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Dict, Optional
from typing import Union, Dict, Tuple, Optional
import torch
import torch.nn as nn

Expand Down Expand Up @@ -35,6 +35,7 @@ def __init__(
norm_type: Optional[str] = None,
sigma_type: Optional[str] = 'independent',
bound_type: Optional[str] = None,
encoder: Optional[Tuple[torch.nn.Module, torch.nn.Module]] = None,
) -> None:
"""
Overview:
Expand Down Expand Up @@ -66,6 +67,9 @@ def __init__(
to ``independent``, which means state-independent sigma parameters.
- bound_type (:obj:`Optional[str]`): The type of action bound methods in continuous action space, defaults \
to ``None``, which means no bound.
- encoder (:obj:`Optional[Tuple[torch.nn.Module, torch.nn.Module]]`): The encoder module list, defaults \
to ``None``, you can define your own actor and critic encoder module and pass it into MAVAC to \
deal with different observation space.
"""
super(MAVAC, self).__init__()
agent_obs_shape: int = squeeze(agent_obs_shape)
Expand All @@ -74,42 +78,38 @@ def __init__(
self.global_obs_shape, self.agent_obs_shape, self.action_shape = global_obs_shape, agent_obs_shape, action_shape
self.action_space = action_space
# Encoder Type
# We directly connect the Head after a Liner layer instead of using the 3-layer FCEncoder.
# In SMAC task it can obviously improve the performance.
# Users can change the model according to their own needs.
self.actor_encoder = nn.Identity()
self.critic_encoder = nn.Identity()
# Head Type
self.critic_head = nn.Sequential(
nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
RegressionHead(
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
if encoder:
self.actor_encoder, self.critic_encoder = encoder
else:
# We directly connect the Head after a Liner layer instead of using the 3-layer FCEncoder.
# In SMAC task it can obviously improve the performance.
# Users can change the model according to their own needs.
self.actor_encoder = nn.Sequential(
nn.Linear(agent_obs_shape, actor_head_hidden_size),
activation,
)
self.critic_encoder = nn.Sequential(
nn.Linear(global_obs_shape, critic_head_hidden_size),
activation,
)
# Head Type
self.critic_head = RegressionHead(
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
)
assert self.action_space in ['discrete', 'continuous'], self.action_space
if self.action_space == 'discrete':
self.actor_head = nn.Sequential(
nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
DiscreteHead(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
activation=activation,
norm_type=norm_type
)
self.actor_head = DiscreteHead(
actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
)
elif self.action_space == 'continuous':
self.actor_head = nn.Sequential(
nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
ReparameterizationHead(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
sigma_type=sigma_type,
activation=activation,
norm_type=norm_type,
bound_type=bound_type
)
self.actor_head = ReparameterizationHead(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
sigma_type=sigma_type,
activation=activation,
norm_type=norm_type,
bound_type=bound_type
)
# must use list, not nn.ModuleList
self.actor = [self.actor_encoder, self.actor_head]
Expand Down Expand Up @@ -261,7 +261,7 @@ def compute_actor_critic(self, x: Dict) -> Dict:
- value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
Shapes:
- logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size and N is ``action_shape`` \
and M is ``agent_num``.
and M is ``agent_num``.
- value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch sizeand M is ``agent_num``.
Examples:
Expand All @@ -275,6 +275,16 @@ def compute_actor_critic(self, x: Dict) -> Dict:
>>> assert outputs['value'].shape == torch.Size([10, 8])
>>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
"""
logit = self.compute_actor(x)['logit']
value = self.compute_critic(x)['value']
x_actor = self.actor_encoder(x['agent_state'])
x_critic = self.critic_encoder(x['global_state'])

if self.action_space == 'discrete':
action_mask = x['action_mask']
x = self.actor_head(x_actor)
logit = x['logit']
logit[action_mask == 0.0] = -99999999
elif self.action_space == 'continuous':
x = self.actor_head(x_actor)
logit = x
value = self.critic_head(x_critic)['pred']
return {'logit': logit, 'value': value}
34 changes: 29 additions & 5 deletions ding/model/template/tests/test_language_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,33 @@ def check_model(self):
cands_list = [problems[pid] for pid in cand_pids]

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
scores = model(ctxt_list, cands_list)
assert scores.shape == (1, 3)
output = model(ctxt_list, cands_list, mode='compute_actor')
assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
assert output['logit'].shape == (1, 3)

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, embedding_size=256)
scores = model(ctxt_list, cands_list)
assert scores.shape == (1, 3)
output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and len(output.keys()) == 1
assert output['value'].shape == (1, )

output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
output.keys()
) == 3
assert output['value'].shape == (1, )
assert output['logit'].shape == (1, 3)

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, norm_embedding=True)
output = model(ctxt_list, cands_list, mode='compute_actor')
assert 'dist' in output.keys() and 'logit' in output.keys() and len(output.keys()) == 2
assert output['logit'].shape == (1, 3)

output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and len(output.keys()) == 1
assert output['value'].shape == (1, )

output = model(ctxt_list, cands_list, mode='compute_critic')
assert 'value' in output.keys() and 'dist' in output.keys() and 'logit' in output.keys() and len(
output.keys()
) == 3
assert output['value'].shape == (1, )
assert output['logit'].shape == (1, 3)
37 changes: 37 additions & 0 deletions ding/model/template/tests/test_mavac.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import numpy as np
import torch
import torch.nn as nn
from itertools import product

from ding.model import mavac
Expand Down Expand Up @@ -50,3 +51,39 @@ def test_vac(self, agent_obs_shape, global_obs_shape):
value = model(data, mode='compute_critic')['value']
assert value.shape == (B, agent_num)
self.output_check(model.critic, value, action_shape)

def test_vac_with_encoder(self, agent_obs_shape, global_obs_shape):
data = {
'agent_state': torch.randn(B, agent_num, agent_obs_shape),
'global_state': torch.randn(B, agent_num, global_obs_shape),
'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
}

actor_size, critic_size = 128, 128
encoder = [nn.Linear(agent_obs_shape, actor_size), nn.Linear(global_obs_shape, critic_size)]
model = MAVAC(
agent_obs_shape,
global_obs_shape,
action_shape,
agent_num,
encoder=encoder,
actor_head_hidden_size=actor_size,
critic_head_hidden_size=critic_size
)

logit = model(data, mode='compute_actor_critic')['logit']
value = model(data, mode='compute_actor_critic')['value']

outputs = value.sum() + logit.sum()
self.output_check(model, outputs, action_shape)

for p in model.parameters():
p.grad = None
logit = model(data, mode='compute_actor')['logit']
self.output_check(model.actor, logit, model.action_shape)

for p in model.parameters():
p.grad = None
value = model(data, mode='compute_critic')['value']
assert value.shape == (B, agent_num)
self.output_check(model.critic, value, action_shape)
Loading

0 comments on commit 6b9f509

Please sign in to comment.