Skip to content

Commit

Permalink
style(pu): bash format.sh
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Nov 4, 2024
1 parent 0968723 commit 22d5110
Show file tree
Hide file tree
Showing 29 changed files with 137 additions and 168 deletions.
10 changes: 9 additions & 1 deletion ding/entry/serial_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,15 @@ def serial_pipeline(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.resume_training
)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
Expand Down
21 changes: 10 additions & 11 deletions ding/entry/serial_entry_mbrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,15 @@ def mbrl_entry_setup(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.resume_training
)

if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
Expand Down Expand Up @@ -71,16 +79,7 @@ def mbrl_entry_setup(
)

return (
cfg,
policy,
world_model,
env_buffer,
learner,
collector,
collector_env,
evaluator,
commander,
tb_logger,
cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger,
resume_training
)

Expand Down
10 changes: 9 additions & 1 deletion ding/entry/serial_entry_ngu.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,15 @@ def serial_pipeline_ngu(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.resume_training
)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
Expand Down
10 changes: 9 additions & 1 deletion ding/entry/serial_entry_onpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,15 @@ def serial_pipeline_onpolicy(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.resume_training
)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
Expand Down
10 changes: 9 additions & 1 deletion ding/entry/serial_entry_onpolicy_ppg.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,15 @@ def serial_pipeline_onpolicy_ppg(
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.resume_training)
cfg = compile_config(
cfg,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
renew_dir=not cfg.policy.learn.resume_training
)
# Create main components: env, policy
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
Expand Down
4 changes: 2 additions & 2 deletions dizoo/cliffwalking/envs/cliffwalking_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def __init__(self, cfg: dict) -> None:
self._replay_path = None
self._observation_space = gym.spaces.Box(low=0, high=1, shape=(48, ), dtype=np.float32)
self._env = gym.make(
"CliffWalking", render_mode=self._cfg.render_mode, max_episode_steps=self._cfg.max_episode_steps
)
"CliffWalking", render_mode=self._cfg.render_mode, max_episode_steps=self._cfg.max_episode_steps
)
self._action_space = self._env.action_space
self._reward_space = gym.spaces.Box(
low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
Expand Down
8 changes: 3 additions & 5 deletions dizoo/d4rl/config/antmaze_umaze_pd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
value_model_cfg=dict(
model='TemporalValue',
model_cfg=dict(
horizon = 256,
horizon=256,
transition_dim=37,
dim=32,
dim_mults=[1, 2, 4, 8],
Expand Down Expand Up @@ -92,10 +92,8 @@
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='pd',
),
policy=dict(type='pd', ),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
create_config = create_config
6 changes: 2 additions & 4 deletions dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
value_model_cfg=dict(
model='TemporalValue',
model_cfg=dict(
horizon = 4,
horizon=4,
transition_dim=23,
dim=32,
dim_mults=[1, 4, 8],
Expand Down Expand Up @@ -92,9 +92,7 @@
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='pd',
),
policy=dict(type='pd', ),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
Expand Down
8 changes: 2 additions & 6 deletions dizoo/d4rl/config/halfcheetah_medium_expert_qgpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
evaluator_env_num=8,
n_evaluator_episode=8,
),
dataset=dict(
env_id="halfcheetah-medium-expert-v2",
),
dataset=dict(env_id="halfcheetah-medium-expert-v2", ),
policy=dict(
cuda=True,
on_policy=False,
Expand Down Expand Up @@ -44,8 +42,6 @@

create_config = dict(
env_manager=dict(type='base'),
policy=dict(
type='qgpo',
),
policy=dict(type='qgpo', ),
)
create_config = EasyDict(create_config)
6 changes: 2 additions & 4 deletions dizoo/d4rl/config/halfcheetah_medium_pd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
value_model_cfg=dict(
model='TemporalValue',
model_cfg=dict(
horizon = 4,
horizon=4,
transition_dim=23,
dim=32,
dim_mults=[1, 4, 8],
Expand Down Expand Up @@ -92,9 +92,7 @@
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='pd',
),
policy=dict(type='pd', ),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
Expand Down
8 changes: 3 additions & 5 deletions dizoo/d4rl/config/hopper_medium_expert_pd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
value_model_cfg=dict(
model='TemporalValue',
model_cfg=dict(
horizon = 32,
horizon=32,
transition_dim=14,
dim=32,
dim_mults=[1, 2, 4, 8],
Expand Down Expand Up @@ -92,10 +92,8 @@
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='pd',
),
policy=dict(type='pd', ),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
create_config = create_config
8 changes: 2 additions & 6 deletions dizoo/d4rl/config/hopper_medium_expert_qgpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
evaluator_env_num=8,
n_evaluator_episode=8,
),
dataset=dict(
env_id="hopper-medium-expert-v2",
),
dataset=dict(env_id="hopper-medium-expert-v2", ),
policy=dict(
cuda=True,
on_policy=False,
Expand Down Expand Up @@ -44,8 +42,6 @@

create_config = dict(
env_manager=dict(type='base'),
policy=dict(
type='qgpo',
),
policy=dict(type='qgpo', ),
)
create_config = EasyDict(create_config)
8 changes: 3 additions & 5 deletions dizoo/d4rl/config/hopper_medium_pd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
value_model_cfg=dict(
model='TemporalValue',
model_cfg=dict(
horizon = 32,
horizon=32,
transition_dim=14,
dim=32,
dim_mults=[1, 2, 4, 8],
Expand Down Expand Up @@ -92,10 +92,8 @@
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='pd',
),
policy=dict(type='pd', ),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
create_config = create_config
6 changes: 2 additions & 4 deletions dizoo/d4rl/config/maze2d_large_pd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,8 @@
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='pd',
),
policy=dict(type='pd', ),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
create_config = create_config
6 changes: 2 additions & 4 deletions dizoo/d4rl/config/maze2d_medium_pd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,8 @@
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='pd',
),
policy=dict(type='pd', ),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
create_config = create_config
6 changes: 2 additions & 4 deletions dizoo/d4rl/config/maze2d_umaze_pd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,8 @@
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='pd',
),
policy=dict(type='pd', ),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
create_config = create_config
8 changes: 3 additions & 5 deletions dizoo/d4rl/config/walker2d_medium_expert_pd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
value_model_cfg=dict(
model='TemporalValue',
model_cfg=dict(
horizon = 32,
horizon=32,
transition_dim=23,
dim=32,
dim_mults=[1, 2, 4, 8],
Expand Down Expand Up @@ -92,10 +92,8 @@
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='pd',
),
policy=dict(type='pd', ),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
create_config = create_config
8 changes: 2 additions & 6 deletions dizoo/d4rl/config/walker2d_medium_expert_qgpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
evaluator_env_num=8,
n_evaluator_episode=8,
),
dataset=dict(
env_id="walker2d-medium-expert-v2",
),
dataset=dict(env_id="walker2d-medium-expert-v2", ),
policy=dict(
cuda=True,
on_policy=False,
Expand Down Expand Up @@ -44,8 +42,6 @@

create_config = dict(
env_manager=dict(type='base'),
policy=dict(
type='qgpo',
),
policy=dict(type='qgpo', ),
)
create_config = EasyDict(create_config)
8 changes: 3 additions & 5 deletions dizoo/d4rl/config/walker2d_medium_pd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
value_model_cfg=dict(
model='TemporalValue',
model_cfg=dict(
horizon = 32,
horizon=32,
transition_dim=23,
dim=32,
dim_mults=[1, 2, 4, 8],
Expand Down Expand Up @@ -92,10 +92,8 @@
import_names=['dizoo.d4rl.envs.d4rl_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(
type='pd',
),
policy=dict(type='pd', ),
replay_buffer=dict(type='naive', ),
)
create_config = EasyDict(create_config)
create_config = create_config
create_config = create_config
4 changes: 2 additions & 2 deletions dizoo/d4rl/entry/d4rl_pd_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def train(args):
# launch from anywhere
config = Path(__file__).absolute().parent.parent / 'config' / args.config
config = Path(__file__).absolute().parent.parent / 'config' / args.config
config = read_config(str(config))
config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
serial_pipeline_offline(config, seed=args.seed)
Expand All @@ -18,4 +18,4 @@ def train(args):
parser.add_argument('--seed', '-s', type=int, default=10)
parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_pd_config.py')
args = parser.parse_args()
train(args)
train(args)
Loading

0 comments on commit 22d5110

Please sign in to comment.