Imitation Learning #973
base: i210_dev
Changes from 62 commits
@@ -50,6 +50,10 @@ def parse_args(args):
    parser.add_argument(
        '--rl_trainer', type=str, default="rllib",
        help='the RL trainer to use. either rllib or Stable-Baselines')
    parser.add_argument(
        '--load_weights_path', type=str, default=None,
        help='Path to h5 file containing a pretrained model. Relevant for PPO with RLlib'
    )
    parser.add_argument(
        '--algorithm', type=str, default="PPO",
        help='RL algorithm to use. Options are PPO, TD3, and CENTRALIZEDPPO (which uses a centralized value function)'
@@ -111,9 +115,6 @@ def run_model_stablebaseline(flow_params,
    stable_baselines.*
        the trained model
    """
    from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv
    from stable_baselines import PPO2

    if num_cpus == 1:
        constructor = env_constructor(params=flow_params, version=0)()
        # The algorithms require a vectorized environment to run

@@ -175,14 +176,18 @@ def setup_exps_rllib(flow_params,
    alg_run = flags.algorithm.upper()

    if alg_run == "PPO":
        from flow.algorithms.custom_ppo import CustomPPOTrainer
        from flow.controllers.imitation_learning.custom_ppo import CustomPPOTrainer
        from ray.rllib.agents.ppo import DEFAULT_CONFIG
        alg_run = CustomPPOTrainer
        config = deepcopy(DEFAULT_CONFIG)

        alg_run = CustomPPOTrainer

        horizon = flow_params['env'].horizon

        config["num_workers"] = n_cpus
        config["horizon"] = horizon
        config["model"].update({"fcnet_hiddens": [32, 32]})
        config["model"].update({"fcnet_hiddens": [32, 32, 32]})
        config["train_batch_size"] = horizon * n_rollouts
        config["gamma"] = 0.995  # discount rate
        config["use_gae"] = True

@@ -192,6 +197,21 @@ def setup_exps_rllib(flow_params,
        if flags.grid_search:
            config["lambda"] = tune.grid_search([0.5, 0.9])
            config["lr"] = tune.grid_search([5e-4, 5e-5])

        if flags.load_weights_path:
            from flow.controllers.imitation_learning.ppo_model import PPONetwork
            from flow.controllers.imitation_learning.imitation_trainer import Imitation_PPO_Trainable
            from ray.rllib.models import ModelCatalog

            # Register custom model
            ModelCatalog.register_custom_model("PPO_loaded_weights", PPONetwork)
            # set model to the custom model for run
            config['model']['custom_model'] = "PPO_loaded_weights"
            config['model']['custom_options'] = {"h5_load_path": flags.load_weights_path}
            config['observation_filter'] = 'NoFilter'
            # alg run is the Trainable class
            alg_run = Imitation_PPO_Trainable

    elif alg_run == "CENTRALIZEDPPO":
        from flow.algorithms.centralized_PPO import CCTrainer, CentralizedCriticModel
        from ray.rllib.agents.ppo import DEFAULT_CONFIG

@@ -313,7 +333,6 @@ def on_train_result(info):
    register_env(gym_name, create_env)
    return alg_run, gym_name, config


def train_rllib(submodule, flags):
    """Train policies using the PPO algorithm in RLlib."""
    import ray

@@ -337,9 +356,11 @@ def trial_str_creator(trial):
        return "{}_{}".format(trial.trainable_name, trial.experiment_tag)

    if flags.local_mode:
        print("LOCAL MODE")
Review comment: Nit: more informative comment
        ray.init(local_mode=True)
    else:
        ray.init()

    exp_dict = {
        "run_or_experiment": alg_run,
        "name": flags.exp_title,

@@ -466,9 +487,6 @@ def train_h_baselines(flow_params, args, multiagent):

def train_stable_baselines(submodule, flags):
    """Train policies using the PPO algorithm in stable-baselines."""
    from stable_baselines.common.vec_env import DummyVecEnv
    from stable_baselines import PPO2

    flow_params = submodule.flow_params
    # Path to the saved files
    exp_tag = flow_params['exp_tag']
@@ -0,0 +1,193 @@
import logging

from ray.rllib.agents import with_common_config
Review comment: nit: no comment, so I'm not sure what this file is for or how it's different than the other custom ppo
from flow.controllers.imitation_learning.custom_ppo_tf_policy import CustomPPOTFPolicy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
from ray.rllib.utils import try_import_tf

tf = try_import_tf()

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # Should use a critic as a baseline (otherwise don't use value baseline;
    # required for using GAE).
    "use_critic": True,
    # If true, use the Generalized Advantage Estimator (GAE)
    # with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
    "use_gae": True,
    # The GAE(lambda) parameter.
    "lambda": 1.0,
    # Initial coefficient for KL divergence.
    "kl_coeff": 0.2,
    # Size of batches collected from each worker.
    "rollout_fragment_length": 200,
    # Number of timesteps collected for each SGD round. This defines the size
    # of each SGD epoch.
    "train_batch_size": 4000,
    # Total SGD batch size across all devices for SGD. This defines the
    # minibatch size within each epoch.
    "sgd_minibatch_size": 128,
    # Whether to shuffle sequences in the batch when training (recommended).
    "shuffle_sequences": True,
    # Number of SGD iterations in each outer loop (i.e., number of epochs to
    # execute per train batch).
    "num_sgd_iter": 30,
    # Stepsize of SGD.
    "lr": 5e-5,
    # Learning rate schedule.
    "lr_schedule": None,
    # Share layers for value function. If you set this to True, it's important
    # to tune vf_loss_coeff.
    "vf_share_layers": False,
    # Coefficient of the value function loss. IMPORTANT: you must tune this if
    # you set vf_share_layers: True.
    "vf_loss_coeff": 1.0,
    # Coefficient of the entropy regularizer.
    "entropy_coeff": 0.0,
    # Decay schedule for the entropy regularizer.
    "entropy_coeff_schedule": None,
    # PPO clip parameter.
    "clip_param": 0.3,
    # Clip param for the value function. Note that this is sensitive to the
    # scale of the rewards. If your expected V is large, increase this.
    "vf_clip_param": 10.0,
    # If specified, clip the global norm of gradients by this amount.
    "grad_clip": None,
    # Target value for KL divergence.
    "kl_target": 0.01,
    # Whether to rollout "complete_episodes" or "truncate_episodes".
    "batch_mode": "truncate_episodes",
    # Which observation filter to apply to the observation.
    "observation_filter": "NoFilter",
    # Uses the sync samples optimizer instead of the multi-gpu one. This is
    # usually slower, but you might want to try it if you run into issues with
    # the default optimizer.
    "simple_optimizer": False,
    # Use PyTorch as framework?
    "use_pytorch": False
})
# __sphinx_doc_end__
# yapf: enable


def choose_policy_optimizer(workers, config):
    if config["simple_optimizer"]:
        return SyncSamplesOptimizer(
            workers,
            num_sgd_iter=config["num_sgd_iter"],
            train_batch_size=config["train_batch_size"],
            sgd_minibatch_size=config["sgd_minibatch_size"],
            standardize_fields=["advantages"])

    return LocalMultiGPUOptimizer(
        workers,
        sgd_batch_size=config["sgd_minibatch_size"],
        num_sgd_iter=config["num_sgd_iter"],
        num_gpus=config["num_gpus"],
        rollout_fragment_length=config["rollout_fragment_length"],
        num_envs_per_worker=config["num_envs_per_worker"],
        train_batch_size=config["train_batch_size"],
        standardize_fields=["advantages"],
        shuffle_sequences=config["shuffle_sequences"])


def update_kl(trainer, fetches):
    # Single-agent.
    if "kl" in fetches:
        trainer.workers.local_worker().for_policy(
            lambda pi: pi.update_kl(fetches["kl"]))

    # Multi-agent.
    else:

        def update(pi, pi_id):
            if pi_id in fetches:
                pi.update_kl(fetches[pi_id]["kl"])
            else:
                logger.debug("No data for {}, not updating kl".format(pi_id))

        trainer.workers.local_worker().foreach_trainable_policy(update)


def warn_about_bad_reward_scales(trainer, result):
    if result["policy_reward_mean"]:
        return  # Punt on handling multiagent case.

    # Warn about excessively high VF loss.
    learner_stats = result["info"]["learner"]
    if "default_policy" in learner_stats:
        scaled_vf_loss = (trainer.config["vf_loss_coeff"] *
                          learner_stats["default_policy"]["vf_loss"])
        policy_loss = learner_stats["default_policy"]["policy_loss"]
        if trainer.config["vf_share_layers"] and scaled_vf_loss > 100:
            logger.warning(
                "The magnitude of your value function loss is extremely large "
                "({}) compared to the policy loss ({}). This can prevent the "
                "policy from learning. Consider scaling down the VF loss by "
                "reducing vf_loss_coeff, or disabling vf_share_layers.".format(
                    scaled_vf_loss, policy_loss))

    # Warn about bad clipping configs
    if trainer.config["vf_clip_param"] <= 0:
        rew_scale = float("inf")
    else:
        rew_scale = round(
            abs(result["episode_reward_mean"]) /
            trainer.config["vf_clip_param"], 0)
    if rew_scale > 200:
        logger.warning(
            "The magnitude of your environment rewards are more than "
            "{}x the scale of `vf_clip_param`. ".format(rew_scale) +
            "This means that it will take more than "
            "{} iterations for your value ".format(rew_scale) +
            "function to converge. If this is not intended, consider "
            "increasing `vf_clip_param`.")


def validate_config(config):
    if config["entropy_coeff"] < 0:
        raise DeprecationWarning("entropy_coeff must be >= 0")
    if isinstance(config["entropy_coeff"], int):
        config["entropy_coeff"] = float(config["entropy_coeff"])
    if config["sgd_minibatch_size"] > config["train_batch_size"]:
        raise ValueError(
            "Minibatch size {} must be <= train batch size {}.".format(
                config["sgd_minibatch_size"], config["train_batch_size"]))
    if config["batch_mode"] == "truncate_episodes" and not config["use_gae"]:
        raise ValueError(
            "Episode truncation is not supported without a value "
            "function. Consider setting batch_mode=complete_episodes.")
    if config["multiagent"]["policies"] and not config["simple_optimizer"]:
        logger.info(
            "In multi-agent mode, policies will be optimized sequentially "
            "by the multi-GPU optimizer. Consider setting "
            "simple_optimizer=True if this doesn't work for you.")
    if config["simple_optimizer"]:
        logger.warning(
            "Using the simple minibatch optimizer. This will significantly "
            "reduce performance, consider simple_optimizer=False.")
    elif config["use_pytorch"] or (tf and tf.executing_eagerly()):
        config["simple_optimizer"] = True  # multi-gpu not supported


def get_policy_class(config):
    if config.get("use_pytorch") is True:
        from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
        return PPOTorchPolicy
    else:
        return CustomPPOTFPolicy


CustomPPOTrainer = build_trainer(
    name="PPO",
    default_config=DEFAULT_CONFIG,
    default_policy=CustomPPOTFPolicy,
    get_policy_class=get_policy_class,
    make_policy_optimizer=choose_policy_optimizer,
    validate_config=validate_config,
    after_optimizer_step=update_kl,
    after_train_result=warn_about_bad_reward_scales)
Review comment: missing blank line at end of file
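
For reference, a trainer built with `build_trainer` (like `CustomPPOTrainer` above) is driven the same way as any stock RLlib trainer. A minimal usage sketch, assuming ray[rllib] ~0.8; the CartPole environment and the config tweaks are placeholders, not part of this PR:

```python
# Minimal usage sketch; environment and config values are placeholders.
import ray
from flow.controllers.imitation_learning.custom_ppo import (
    CustomPPOTrainer, DEFAULT_CONFIG)

ray.init(local_mode=True)

config = DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["train_batch_size"] = 4000

trainer = CustomPPOTrainer(config=config, env="CartPole-v0")
for i in range(3):
    # One round of sampling + SGD; update_kl runs after each optimizer step.
    result = trainer.train()
    print(i, result["episode_reward_mean"])
```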
Review comment (on the in-function stable_baselines imports): undo please. These are here because we don't want them in the imports at the top, since not everyone installs stable_baselines.
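
The deferred-import pattern the reviewer is defending keeps optional dependencies out of module-level imports, so importing the training script never requires stable_baselines to be installed. A hypothetical sketch of the idea (the function name and training settings are illustrative, not code from this PR):

```python
def run_with_stable_baselines(make_env):
    """Illustrative only: import optional deps inside the code path that needs them."""
    # Imported here rather than at module level, so users who only train with
    # RLlib never need stable_baselines installed.
    from stable_baselines.common.vec_env import DummyVecEnv
    from stable_baselines import PPO2

    env = DummyVecEnv([make_env])  # stable-baselines expects a vectorized env
    model = PPO2('MlpPolicy', env, verbose=1)
    model.learn(total_timesteps=10000)
    return model
```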