feature(luyd): fix dt new pipeline of mujoco (#754)
* fix dt in mujoco

* Fix according to comment
AltmanD authored Dec 11, 2023
1 parent 5788265 commit b959eb1
Showing 8 changed files with 30 additions and 73 deletions.
9 changes: 6 additions & 3 deletions ding/framework/middleware/functional/data_processor.py
@@ -191,7 +191,10 @@ def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable:
     def producer(queue, dataset, batch_size, device):
         torch.set_num_threads(4)
         nonlocal stream
-        idx_iter = iter(range(len(dataset)))
+        idx_iter = iter(range(len(dataset) - batch_size))
+
+        if len(dataset) < batch_size:
+            logging.warning('batch_size is too large!!!!')
         with torch.cuda.stream(stream):
             while True:
                 if queue.full():
@@ -201,7 +204,7 @@ def producer(queue, dataset, batch_size, device):
                         start_idx = next(idx_iter)
                     except StopIteration:
                         del idx_iter
-                        idx_iter = iter(range(len(dataset)))
+                        idx_iter = iter(range(len(dataset) - batch_size))
                         start_idx = next(idx_iter)
                     data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx + batch_size)]
                     data = [[i[j] for i in data] for j in range(len(data[0]))]
@@ -211,7 +214,7 @@ def producer(queue, dataset, batch_size, device):
     queue = Queue(maxsize=50)
     device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
     producer_thread = Thread(
-        target=producer, args=(queue, dataset, cfg.policy.batch_size, device), name='cuda_fetcher_producer'
+        target=producer, args=(queue, dataset, cfg.policy.learn.batch_size, device), name='cuda_fetcher_producer'
     )

     def _fetch(ctx: "OfflineRLContext"):
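
For reference, a minimal standalone sketch of the producer-thread prefetch pattern this hunk adjusts; make_prefetcher and the plain-list dataset are illustrative, not DI-engine API. Capping the start index at len(dataset) - batch_size keeps every slice start_idx : start_idx + batch_size inside the dataset, and the explicit warning covers the degenerate case where the dataset is smaller than one batch.

    # Illustrative sketch only; make_prefetcher is a hypothetical helper, not DI-engine API.
    import logging
    import time
    from queue import Queue
    from threading import Thread

    def make_prefetcher(dataset, batch_size, maxsize=50):
        if len(dataset) < batch_size:
            logging.warning('batch_size is larger than the dataset')
        queue = Queue(maxsize=maxsize)
        last_start = max(len(dataset) - batch_size, 1)  # cap so start_idx + batch_size stays in range

        def producer():
            idx_iter = iter(range(last_start))
            while True:
                if queue.full():
                    time.sleep(0.1)
                    continue
                try:
                    start_idx = next(idx_iter)
                except StopIteration:
                    idx_iter = iter(range(last_start))  # start another pass over the data
                    start_idx = next(idx_iter)
                queue.put([dataset[idx] for idx in range(start_idx, start_idx + batch_size)])

        Thread(target=producer, name='prefetch_producer', daemon=True).start()
        return queue

    # usage: batches = make_prefetcher(list(range(100)), batch_size=8); first = batches.get()
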
6 changes: 4 additions & 2 deletions ding/policy/dt.py
@@ -136,7 +136,7 @@ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]:
         if self._basic_discrete_env:
             actions = actions.to(torch.long)
             actions = actions.squeeze(-1)
-            action_target = torch.clone(actions).detach().to(self._device)
+        action_target = torch.clone(actions).detach().to(self._device)

         if self._atari_env:
             state_preds, action_preds, return_preds = self._learn_model.forward(
@@ -291,7 +291,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                 self.states[i, self.t[i]] = data[i]['obs'].to(self._device)
             else:
                 self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std
-            self.running_rtg[i] = self.running_rtg[i] - data[i]['reward'].to(self._device)
+            self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'] / self.rtg_scale).to(self._device)
             self.rewards_to_go[i, self.t[i]] = self.running_rtg[i]

             if self.t[i] <= self.context_len:
@@ -328,6 +328,8 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                    act[i] = torch.multinomial(probs[i], num_samples=1)
            else:
                act = torch.argmax(logits, axis=1).unsqueeze(1)
+        else:
+            act = logits
         for i in data_id:
             self.actions[i, self.t[i]] = act[i]  # TODO: self.actions[i] should be a queue when exceed max_t
             self.t[i] += 1
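
The second hunk keeps the evaluation-time return-to-go in the same rescaled units the model was trained on, and the new else branch appears to let continuous-action environments use the predicted action directly instead of argmax or sampling. A small worked example of the return-to-go bookkeeping, with made-up numbers:

    # Toy numbers, just to show the unit convention the fix enforces.
    rtg_scale = 1000
    rtg_target = 3600.0                   # desired episode return in raw env units
    running_rtg = rtg_target / rtg_scale  # conditioning value starts in scaled units

    for reward in [1.2, 0.9, 1.1]:        # made-up per-step rewards
        running_rtg = running_rtg - reward / rtg_scale  # subtract rewards in the same scaled units

    print(running_rtg)                    # about 3.5968, still comparable to the training returns_to_go
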
48 changes: 1 addition & 47 deletions ding/utils/data/dataset.py
@@ -389,13 +389,10 @@ def __init__(self, cfg: dict) -> None:

             self.trajectories = paths

-            # calculate min len of traj, state mean and variance
-            # and returns_to_go for all traj
-            min_len = 10 ** 6
+            # calculate state mean and variance and returns_to_go for all traj
             states = []
             for traj in self.trajectories:
                 traj_len = traj['observations'].shape[0]
-                min_len = min(min_len, traj_len)
                 states.append(traj['observations'])
                 # calculate returns to go and rescale them
                 traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale
@@ -408,46 +405,6 @@ def __init__(self, cfg: dict) -> None:
             for traj in self.trajectories:
                 traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std

-            # self.trajectories = {}
-            # exp_key = ['rewards', 'terminals', 'timeouts']
-            # for k in dataset.keys():
-            #     logging.info(f'Load {k} data.')
-            #     if k in exp_key:
-            #         self.trajectories[k] = np.expand_dims(dataset[k][:], axis=1)
-            #     else:
-            #         self.trajectories[k] = dataset[k][:]
-
-            # # used for input normalization
-            # states = np.concatenate(self.trajectories['observations'], axis=0)
-            # self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
-
-            # # normalize states
-            # self.trajectories['observations'] = (self.trajectories['observations'] - self.state_mean) / self.state_std
-            # self.trajectories['returns_to_go'] = discount_cumsum(self.trajectories['rewards'], 1.0) / rtg_scale
-
-            # datalen = self.trajectories['rewards'].shape[0]
-
-            # use_timeouts = False
-            # if 'timeouts' in dataset:
-            #     use_timeouts = True
-
-            # data_ = collections.defaultdict(list)
-            # episode_step = 0
-            # trajectories_tmp = []
-            # for i in range(datalen):
-            #     done_bool = bool(self.trajectories['terminals'][i])
-            #     final_timestep = (episode_step == 1000-1)
-            #     for k in ['observations', 'actions', 'returns_to_go']:
-            #         data_[k].append(self.trajectories[k][i])
-            #     if done_bool or final_timestep:
-            #         episode_step = 0
-            #         episode_data = {}
-            #         for k in data_:
-            #             episode_data[k] = np.array(data_[k])
-            #         trajectories_tmp.append(episode_data)
-            #         data_ = collections.defaultdict(list)
-            #     episode_step += 1
-            # self.trajectories = trajectories_tmp
         elif 'pkl' in dataset_path:
             if 'dqn' in dataset_path:
                 # load dataset
@@ -493,11 +450,8 @@ def __init__(self, cfg: dict) -> None:
                 with open(dataset_path, 'rb') as f:
                     self.trajectories = pickle.load(f)

-                min_len = 10 ** 6
                 states = []
                 for traj in self.trajectories:
-                    traj_len = traj['observations'].shape[0]
-                    min_len = min(min_len, traj_len)
                     states.append(traj['observations'])
                     # calculate returns to go and rescale them
                     traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale
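
The surviving preprocessing does two things per trajectory: compute returns-to-go with discount_cumsum (gamma = 1) divided by rtg_scale, and normalize observations with the dataset mean and std. A standalone sketch of both steps; this discount_cumsum is a plain NumPy re-implementation for illustration, not necessarily the helper dataset.py imports.

    import numpy as np

    def discount_cumsum(x, gamma):
        # out[t] = x[t] + gamma * out[t + 1], computed right to left
        out = np.zeros_like(x, dtype=np.float64)
        out[-1] = x[-1]
        for t in reversed(range(len(x) - 1)):
            out[t] = x[t] + gamma * out[t + 1]
        return out

    rtg_scale = 1000
    rewards = np.array([1.0, 2.0, 3.0])
    print(discount_cumsum(rewards, 1.0) / rtg_scale)    # [0.006 0.005 0.003]

    states = np.random.randn(500, 11)                   # stand-in for concatenated observations
    state_mean, state_std = states.mean(axis=0), states.std(axis=0) + 1e-6
    normalized = (states - state_mean) / state_std      # same transform the loop above applies per trajectory
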
2 changes: 1 addition & 1 deletion dizoo/d4rl/config/hopper_expert_dt_config.py
@@ -14,7 +14,7 @@
     dataset=dict(
         env_type='mujoco',
         rtg_scale=1000,
-        context_len=30,
+        context_len=20,
         data_dir_prefix='d4rl/hopper_expert-v2.pkl',
     ),
     policy=dict(
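
Lowering the dataset context_len to 20 brings it in line with the context_len=20 used in the model block of these configs (visible in the walker2d diff below). A decision transformer consumes (return-to-go, state, action) triples, so a context of K timesteps becomes 3 * K transformer tokens; a back-of-the-envelope check:

    # Token count for one DT forward pass at the corrected context length.
    context_len = 20        # K timesteps kept in the conditioning window
    tokens_per_step = 3     # one token each for return-to-go, state, action
    print(tokens_per_step * context_len)   # 60 tokens per sequence
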
6 changes: 3 additions & 3 deletions dizoo/d4rl/config/hopper_medium_dt_config.py
@@ -14,8 +14,8 @@
     dataset=dict(
         env_type='mujoco',
         rtg_scale=1000,
-        context_len=30,
-        data_dir_prefix='d4rl/hopper_medium-v2.pkl',
+        context_len=20,
+        data_dir_prefix='d4rl/hopper_medium_expert-v2.pkl',
     ),
     policy=dict(
         cuda=True,
@@ -47,7 +47,7 @@
             data_type='d4rl_trajectory',
             unroll_len=1,
         ),
-        eval=dict(evaluator=dict(eval_freq=100, ), ),
+        eval=dict(evaluator=dict(eval_freq=1000, ), ),
     ),
 )

6 changes: 3 additions & 3 deletions dizoo/d4rl/config/hopper_medium_expert_dt_config.py
@@ -2,7 +2,7 @@
 from copy import deepcopy

 hopper_dt_config = dict(
-    exp_name='dt_log/d4rl/hopper/hopper_medium_expert_dt_seed0',
+    exp_name='dt_log/d4rl/hopper/hopper_medium_expert_dt',
     env=dict(
         env_id='Hopper-v3',
         collector_env_num=1,
@@ -14,8 +14,8 @@
     dataset=dict(
         env_type='mujoco',
         rtg_scale=1000,
-        context_len=30,
-        data_dir_prefix='d4rl/hopper_medium_expert-v2.pkl',
+        context_len=20,
+        data_dir_prefix='d4rl/hopper_medium_expert.pkl',
     ),
     policy=dict(
         cuda=True,
14 changes: 7 additions & 7 deletions dizoo/d4rl/config/walker2d_medium_dt_config.py
@@ -2,9 +2,9 @@
 from copy import deepcopy

 walk2d_dt_config = dict(
-    exp_name='dt_log/d4rl/walk2d/walk2d_medium_dt_seed0',
+    exp_name='dt_log/d4rl/walk2d/walk2d_medium_dt',
     env=dict(
-        env_id='Walk2d-v3',
+        env_id='Walker2d-v3',
         collector_env_num=1,
         evaluator_env_num=8,
         use_act_scale=True,
@@ -14,16 +14,16 @@
     dataset=dict(
         env_type='mujoco',
         rtg_scale=1000,
-        context_len=30,
-        data_dir_prefix='d4rl/walk2d_medium-v2.pkl',
+        context_len=20,
+        data_dir_prefix='d4rl/walker2d_medium-v2.pkl',
     ),
     policy=dict(
         cuda=True,
         stop_value=5000,
         state_mean=None,
         state_std=None,
         evaluator_env_num=8,
-        env_name='Walk2d-v3',
+        env_name='Walker2d-v3',
         rtg_target=5000,  # max target return to go
         max_eval_ep_len=1000,  # max lenght of one episode
         wt_decay=1e-4,
@@ -32,8 +32,8 @@
         weight_decay=0.1,
         clip_grad_norm_p=0.25,
         model=dict(
-            state_dim=11,
-            act_dim=3,
+            state_dim=17,
+            act_dim=6,
             n_blocks=3,
             h_dim=128,
             context_len=20,
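
The corrected model dimensions follow the environment: Walker2d-v3 observations are 17-dimensional and actions 6-dimensional, whereas 11 and 3 are Hopper-v3's sizes. A quick sanity check, assuming gym with the MuJoCo environments installed:

    import gym

    env = gym.make('Walker2d-v3')
    print(env.observation_space.shape)   # (17,)
    print(env.action_space.shape)        # (6,)
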
12 changes: 5 additions & 7 deletions dizoo/d4rl/entry/d4rl_dt_mujoco.py
@@ -10,7 +10,7 @@
 from ding.config import compile_config
 from ding.framework import task, ding_init
 from ding.framework.context import OfflineRLContext
-from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger, termination_checker
+from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher_from_mem, offline_logger, termination_checker
 from ding.utils import set_pkg_seed
 from dizoo.d4rl.envs import D4RLEnv
 from dizoo.d4rl.config.hopper_medium_dt_config import main_config, create_config
@@ -32,16 +32,14 @@ def main():

         dataset = create_dataset(cfg)
         # env_data_stats = dataset.get_d4rl_dataset_stats(cfg.policy.dataset_name)
-        env_data_stats = dataset.get_state_stats()
-        cfg.policy.state_mean, cfg.policy.state_std = np.array(env_data_stats['state_mean']
-                                                               ), np.array(env_data_stats['state_std'])
+        cfg.policy.state_mean, cfg.policy.state_std = dataset.get_state_stats()
         model = DecisionTransformer(**cfg.policy.model)
         policy = DTPolicy(cfg.policy, model=model)
         task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
-        task.use(offline_data_fetcher(cfg, dataset))
+        task.use(offline_data_fetcher_from_mem(cfg, dataset))
         task.use(trainer(cfg, policy.learn_mode))
-        task.use(termination_checker(max_train_iter=1e5))
-        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+        task.use(termination_checker(max_train_iter=5e4))
+        task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
         task.use(offline_logger())
         task.run()
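
The one-line replacement assumes dataset.get_state_stats() already returns array-like mean and std, so the extra np.array wrapping in the old code is no longer needed. A hypothetical minimal dataset illustrating that contract; ToyTrajectoryDataset is made up for this sketch and is not the object create_dataset returns.

    import numpy as np

    class ToyTrajectoryDataset:
        def __init__(self, observations: np.ndarray):
            self.state_mean = observations.mean(axis=0)
            self.state_std = observations.std(axis=0) + 1e-6

        def get_state_stats(self):
            # already numpy arrays, ready to be assigned to cfg.policy.state_mean / state_std
            return self.state_mean, self.state_std

    dataset = ToyTrajectoryDataset(np.random.randn(1000, 17))
    state_mean, state_std = dataset.get_state_stats()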
