Multiagent training and test scripts
JacopoPan committed Nov 5, 2020
1 parent 99db5b1 commit b08eafd
Showing 7 changed files with 228 additions and 76 deletions.
107 changes: 71 additions & 36 deletions experiments/learning/multiagent.py
@@ -2,6 +2,7 @@
import time
import argparse
from datetime import datetime
import subprocess
import pdb
import math
import numpy as np
@@ -17,6 +18,7 @@
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
import ray
from ray import tune
from ray.tune.logger import DEFAULT_LOGGERS
from ray.tune import register_env
from ray.rllib.agents import ppo
from ray.rllib.agents.ppo import PPOTrainer, PPOTFPolicy
@@ -32,21 +34,16 @@
from gym_pybullet_drones.envs.multi_agent_rl.FlockAviary import FlockAviary
from gym_pybullet_drones.envs.multi_agent_rl.LeaderFollowerAviary import LeaderFollowerAviary
from gym_pybullet_drones.envs.multi_agent_rl.MeetupAviary import MeetupAviary
from gym_pybullet_drones.envs.single_agent_rl.BaseSingleAgentAviary import ActionType, ObservationType
from gym_pybullet_drones.utils.Logger import Logger
from gym_pybullet_drones.utils.utils import *

######################################################################################################################################################
class FillInActions(DefaultCallbacks):
def on_postprocess_trajectory(self, worker, episode, agent_id, policy_id, policies, postprocessed_batch, original_batches, **kwargs):
action_vec_size = 4
to_update = postprocessed_batch[SampleBatch.CUR_OBS]
other_id = 1 if agent_id == 0 else 0
action_encoder = ModelCatalog.get_preprocessor_for_space(
Box(-np.inf, np.inf, (action_vec_size,), np.float32) # Unbounded
)
_, opponent_batch = original_batches[other_id]
opponent_actions = np.array([ action_encoder.transform(a) for a in opponent_batch[SampleBatch.ACTIONS] ])
to_update[:, -action_vec_size:] = opponent_actions
####################
AGGR_PHY_STEPS = 5
####################

ACTION_VEC_SIZE = -1
OWN_OBS_VEC_SIZE = -1

######################################################################################################################################################
class CustomTorchCentralizedCriticModel(TorchModelV2, nn.Module):
@@ -66,7 +63,7 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name):
TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
nn.Module.__init__(self)
self.action_model = FullyConnectedNetwork(
Box(low=-1, high=1, shape=(20, )), # DOUBLE CHECK / MODIFY THIS # Box(low=0, high=1, shape=(6, )), # one-hot encoded Discrete(6)
Box(low=-1, high=1, shape=(OWN_OBS_VEC_SIZE, )),
action_space,
num_outputs,
model_config,
@@ -89,19 +86,30 @@ def value_function(self):
value_out, _ = self.value_model({ "obs": self._model_in[0] }, self._model_in[1], self._model_in[2])
return torch.reshape(value_out, [-1])

######################################################################################################################################################
class FillInActions(DefaultCallbacks):
def on_postprocess_trajectory(self, worker, episode, agent_id, policy_id, policies, postprocessed_batch, original_batches, **kwargs):
to_update = postprocessed_batch[SampleBatch.CUR_OBS]
other_id = 1 if agent_id == 0 else 0
action_encoder = ModelCatalog.get_preprocessor_for_space(
Box(-np.inf, np.inf, (ACTION_VEC_SIZE,), np.float32) # Unbounded
)
_, opponent_batch = original_batches[other_id]
opponent_actions = np.array([ action_encoder.transform(a) for a in opponent_batch[SampleBatch.ACTIONS] ])
to_update[:, -ACTION_VEC_SIZE:] = opponent_actions

######################################################################################################################################################
def central_critic_observer(agent_obs, **kw):
action_vec_size = 4
new_obs = {
0: {
"own_obs": agent_obs[0],
"opponent_obs": agent_obs[1],
"opponent_action": np.zeros(action_vec_size), # Filled in by FillInActions
"opponent_action": np.zeros(ACTION_VEC_SIZE), # Filled in by FillInActions
},
1: {
"own_obs": agent_obs[1],
"opponent_obs": agent_obs[0],
"opponent_action": np.zeros(action_vec_size), # Filled in by FillInActions
"opponent_action": np.zeros(ACTION_VEC_SIZE), # Filled in by FillInActions
},
}
return new_obs
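
For intuition, here is a minimal, framework-free sketch (illustrative only, not part of this commit) of how central_critic_observer and FillInActions fit together: the observer appends a zeroed opponent-action placeholder to each agent's observation, and the callback later overwrites those trailing ACTION_VEC_SIZE entries with the actions actually taken by the other agent. The 12/1 sizes below assume the KIN observation and one_d_rpm actions configured further down in this file.

import numpy as np

OWN_OBS_VEC_SIZE = 12   # KIN observation, per the constants set below
ACTION_VEC_SIZE = 1     # e.g. one_d_rpm actions

def flatten_cc_obs(own_obs, opponent_obs):
    # What each policy sees after central_critic_observer: own observation,
    # the opponent's observation, and a zero placeholder for the opponent's action.
    return np.concatenate([own_obs, opponent_obs, np.zeros(ACTION_VEC_SIZE)])

def fill_in_opponent_actions(flat_obs_batch, opponent_actions):
    # What FillInActions.on_postprocess_trajectory does: overwrite the trailing
    # ACTION_VEC_SIZE slots of every row with the action the other agent took.
    flat_obs_batch[:, -ACTION_VEC_SIZE:] = opponent_actions
    return flat_obs_batch

rng = np.random.default_rng(0)
own, opp = rng.normal(size=OWN_OBS_VEC_SIZE), rng.normal(size=OWN_OBS_VEC_SIZE)
batch = np.stack([flatten_cc_obs(own, opp)])                 # one timestep for agent 0
opp_actions = rng.uniform(-1, 1, size=(1, ACTION_VEC_SIZE))  # agent 1's action at that step
assert np.allclose(fill_in_opponent_actions(batch, opp_actions)[:, -ACTION_VEC_SIZE:], opp_actions)
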
@@ -111,44 +119,71 @@ def central_critic_observer(agent_obs, **kw):

#### Define and parse (optional) arguments for the script ##########################################
parser = argparse.ArgumentParser(description='Multi-agent reinforcement learning experiments script')
parser.add_argument('--num_drones', default=2, type=int, help='Number of drones (default: 2)', metavar='')
parser.add_argument('--env', default='flock', type=str, choices=['leaderfollower', 'flock', 'meetup'], help='Help (default: ..)', metavar='')
parser.add_argument('--obs', default='kin', type=ObservationType, help='Help (default: ..)', metavar='')
parser.add_argument('--act', default='one_d_rpm', type=ActionType, help='Help (default: ..)', metavar='')
parser.add_argument('--algo', default='cc', type=str, choices=['cc'], help='Help (default: ..)', metavar='')
parser.add_argument('--workers', default=0, type=int, help='Help (default: ..)', metavar='')
ARGS = parser.parse_args()
filename = os.path.dirname(os.path.abspath(__file__))+'/save-'+ARGS.env+'-'+datetime.now().strftime("%m.%d.%Y_%H.%M.%S")

#### Save directory ################################################################################
filename = os.path.dirname(os.path.abspath(__file__))+'/results/save-'+ARGS.env+'-'+str(ARGS.num_drones)+'-'+ARGS.algo+'-'+ARGS.obs.value+'-'+ARGS.act.value+'-'+datetime.now().strftime("%m.%d.%Y_%H.%M.%S")
if not os.path.exists(filename): os.makedirs(filename+'/')

#### Print out current git commit hash #############################################################
git_commit = subprocess.check_output(["git", "describe", "--tags"]).strip(); print(git_commit)
with open(filename+'/git_commit.txt', 'w+') as f: f.write(str(git_commit))

#### Constants, and errors #########################################################################
if ARGS.obs==ObservationType.KIN: OWN_OBS_VEC_SIZE = 12
elif ARGS.obs==ObservationType.RGB: print("[ERROR] ObservationType.RGB for multi-agent systems not yet implemented"); exit()
else: print("[ERROR] unknown ObservationType"); exit()
if ARGS.act in [ActionType.ONE_D_RPM, ActionType.ONE_D_DYN, ActionType.ONE_D_PID]: ACTION_VEC_SIZE = 1
elif ARGS.act in [ActionType.RPM, ActionType.DYN]: ACTION_VEC_SIZE = 4
elif ARGS.act==ActionType.PID: ACTION_VEC_SIZE = 3
else: print("[ERROR] unknown ActionType"); exit()

#### Uncomment to debug slurm scripts ##############################################################
# exit()

#### Initialize Ray Tune ###########################################################################
ray.shutdown()
ray.init(ignore_reinit_error=True)


#### Register the custom centralized critic model ##################################################
ModelCatalog.register_custom_model( "cc_model", CustomTorchCentralizedCriticModel )
ModelCatalog.register_custom_model("cc_model", CustomTorchCentralizedCriticModel)

#### Register the environment ######################################################################
temp_env_name = "this-aviary-v0"
if ARGS.env=='flock': register_env(temp_env_name, lambda _: FlockAviary(num_drones=ARGS.num_drones, aggregate_phy_steps=AGGR_PHY_STEPS, obs=ARGS.obs, act=ARGS.act))
else: print("[ERROR] not yet implemented"); exit()

#### Unused env to extract correctly sized action and observation spaces ###########################
temp_env = FlockAviary(num_drones=ARGS.num_drones)
if ARGS.env=='flock': temp_env = FlockAviary(num_drones=ARGS.num_drones, aggregate_phy_steps=AGGR_PHY_STEPS, obs=ARGS.obs, act=ARGS.act)
else: print("[ERROR] not yet implemented"); exit()
observer_space = Dict({
"own_obs": temp_env.observation_space[0], # Box(-1.0, 1.0, (20,), np.float32) or Dict(neighbors:MultiBinary(2), state:Box(-1.0, 1.0, (20,), np.float32))
"own_obs": temp_env.observation_space[0],
"opponent_obs": temp_env.observation_space[0],
"opponent_action": temp_env.action_space[0],
})
action_space = temp_env.action_space[0] # Box(-1.0, 1.0, (4,), np.float32)

#### Register the environment ######################################################################
register_env("this-flock-aviary-v0", lambda _: FlockAviary(num_drones=ARGS.num_drones))
action_space = temp_env.action_space[0]

#### Set up the trainer's config ###################################################################
config = ppo.DEFAULT_CONFIG.copy() # For the default config, see github.com/ray-project/ray/blob/master/rllib/agents/trainer.py
config = {
"env": "this-flock-aviary-v0",
"env": temp_env_name,
"num_workers": 0+ARGS.workers,
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0
"batch_mode": "complete_episodes",
"callbacks": FillInActions,
# "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0
"num_workers": 0,
"model": {
"custom_model": "cc_model",
},
"framework": "torch",
}

#### Set up the model parameters of the trainer's config ###########################################
config["model"] = {
"custom_model": "cc_model",
}

#### Set up the multiagent parameters of the trainer's config ######################################
config["multiagent"] = {
@@ -173,14 +208,14 @@ def central_critic_observer(agent_obs, **kw):
stop=stop,
config=config,
verbose=True,
checkpoint_at_end=True
checkpoint_at_end=True,
local_dir=filename,
)
# check_learning_achieved(results, 1.0)

#### Save agent #################################################################################
checkpoints = results.get_trial_checkpoints_paths(trial=results.get_best_trial('episode_reward_mean',mode='max'), metric='episode_reward_mean')
checkpoint_path = checkpoints[0][0]
print(checkpoint_path)
with open(filename+'/checkpoint.txt', 'w+') as f: f.write(checkpoints[0][0])

#### Shut down Ray #################################################################################
ray.shutdown()
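
The commit message also mentions test scripts, which are among the changed files not expanded above. Purely as a hedged sketch (not the repository's test code), evaluating a checkpoint written by this training script with the RLlib API available around Ray 1.0 could look roughly like the following; the checkpoint path and environment arguments are illustrative, and the "multiagent" and "model" sections of the config (collapsed in the diff) would have to be rebuilt exactly as at training time for restore() to succeed.

# Hedged sketch only -- not the test script added by this commit.
import ray
from ray.tune import register_env
from ray.rllib.agents import ppo

from gym_pybullet_drones.envs.multi_agent_rl.FlockAviary import FlockAviary
from gym_pybullet_drones.envs.single_agent_rl.BaseSingleAgentAviary import ActionType, ObservationType

AGGR_PHY_STEPS = 5

ray.init(ignore_reinit_error=True)
register_env("this-aviary-v0",
             lambda _: FlockAviary(num_drones=2, aggregate_phy_steps=AGGR_PHY_STEPS,
                                   obs=ObservationType.KIN, act=ActionType.ONE_D_RPM))

config = ppo.DEFAULT_CONFIG.copy()
config["framework"] = "torch"
config["num_workers"] = 0
# NOTE: the real config must also reproduce the "multiagent" and "model"
# sections used for training (collapsed in the diff above); omitted here.

agent = ppo.PPOTrainer(config=config, env="this-aviary-v0")
agent.restore("results/save-flock-2-cc-kin-one_d_rpm-<timestamp>/checkpoint_X/checkpoint-X")  # illustrative path

env = FlockAviary(num_drones=2, aggregate_phy_steps=AGGR_PHY_STEPS,
                  obs=ObservationType.KIN, act=ActionType.ONE_D_RPM)
obs = env.reset()
for _ in range(10):
    # With separate per-agent policies, a policy_id would be passed to compute_action().
    action = {i: agent.compute_action(obs[i]) for i in obs.keys()}
    obs, reward, done, info = env.step(action)

ray.shutdown()
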
(Diffs of the remaining 6 changed files are not shown here.)