Commit 1d2caa2 (1 parent: 623cdcf): added multi-agent adversarial training demo, game of tag
Showing 16 changed files with 1,391 additions and 5 deletions.
@@ -0,0 +1,49 @@
# Game of Tag
This directory contains a multi-agent adversarial training demo. In the demo, there is a predator vehicle and a prey vehicle.
The predator vehicle's goal is to catch the prey, and the prey vehicle's goal is to avoid getting caught.

## Run training
python examples/game_of_tag/game_of_tag.py examples/game_of_tag/scenarios/game_of_tag_demo_map/

## Run checkpoint
python examples/game_of_tag/run_checkpoint.py examples/game_of_tag/scenarios/game_of_tag_demo_map/
## Setup:
### Rewards
The reward formula is 0.5/(distance - COLLIDE_DISTANCE)^2, capped at 10.

- COLLIDE_DISTANCE is the observed distance when the two vehicles collide. Since each vehicle's position is measured at its center, the distance at the moment of collision is not exactly 0.
### Common Reward:
Off road: -10

#### Prey:
Collision with predator: -10
Distance to predator (d): -0.5/(d - COLLIDE_DISTANCE)^2
#### Predator:
Collision with prey: 10
Distance to prey (d): 0.5/(d - COLLIDE_DISTANCE)^2
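For concreteness, a minimal sketch of the capped distance reward is shown below. This is illustrative only: the demo's real reward adapters live in `tag_adapters.py`, and the `COLLIDE_DISTANCE` value here is a placeholder, not the repo's constant.

```
# Illustrative sketch; the actual reward adapters are defined in tag_adapters.py.
COLLIDE_DISTANCE = 0.35  # placeholder for the observed center-to-center distance at collision


def capped_distance_reward(d: float, cap: float = 10.0) -> float:
    """0.5 / (d - COLLIDE_DISTANCE)^2, capped at `cap`; grows as the vehicles get closer."""
    return min(cap, 0.5 / (d - COLLIDE_DISTANCE) ** 2)


# The predator would receive this value; keeping the 1v1 game zero-sum means
# the prey would receive its negative.
```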
### Action:
Speed selection in m/s: [0, 3, 6, 9]

Lane change selection relative to current lane: [-1, 0, 1]
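As a rough illustration of how these discrete choices could be turned into a SMARTS `LaneWithContinuousSpeed` action, consider the sketch below. The names `SPEEDS`, `LANE_CHANGES`, and `action_adapter_sketch` are hypothetical; the demo's actual `ACTION_SPACE` and `action_adapter` are defined in `tag_adapters.py` and may differ in detail.

```
import gym

# Hypothetical sketch; see tag_adapters.py for the demo's real action space and adapter.
SPEEDS = [0, 3, 6, 9]      # m/s
LANE_CHANGES = [-1, 0, 1]  # relative to the current lane

ACTION_SPACE_SKETCH = gym.spaces.MultiDiscrete([len(SPEEDS), len(LANE_CHANGES)])


def action_adapter_sketch(model_action):
    """Map a (speed index, lane-change index) pair to a LaneWithContinuousSpeed action."""
    speed_idx, lane_idx = model_action
    return [SPEEDS[speed_idx], LANE_CHANGES[lane_idx]]
```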
## Output a model:
Currently RLlib does not provide an implementation for exporting a PyTorch model.

Replace `export_model`'s implementation in `ray/rllib/policy/torch_policy.py` with the following:
```
torch.save(self.model.state_dict(), f"{export_dir}/model.pt")
```
Then follow the steps in game_of_tag.py to export the model.
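In context, the patched method might look like the sketch below. This assumes the method being replaced has the signature `export_model(self, export_dir)`; adjust it to match your RLlib version.

```
# Sketch of the patched method in ray/rllib/policy/torch_policy.py (signature assumed).
def export_model(self, export_dir):
    import os
    import torch

    os.makedirs(export_dir, exist_ok=True)
    torch.save(self.model.state_dict(), f"{export_dir}/model.pt")
```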
## Possible next steps
- Increase the number of agents to 2 predators and 2 prey.
  This requires modelling the reward so that the game remains zero sum. The complication can be seen in how to model
  the distance reward between 2 predators and 1 prey: if the reward only comes from the nearest predator to the nearest
  prey, the sum of predator and prey rewards will no longer be 0, because 2 predators would each receive full reward
  from 1 prey while that prey only receives full reward from 1 predator. Handling this requires the predators to know
  about each other (or the prey to know about other prey), and the prey to know about multiple predators.
- Add an attribute to the observations that indicates whether the ego car is in front of or behind the target vehicle;
  this may help the ego vehicle decide whether it should slow down or speed up.
@@ -0,0 +1,270 @@
"""Let's play tag! | ||
A predator-prey multi-agent example built on top of RLlib to facilitate further | ||
developments on multi-agent support for HiWay (including design, performance, | ||
research, and scaling). | ||
The predator and prey use separate policies. A predator "catches" its prey when | ||
it collides into the other vehicle. There can be multiple predators and | ||
multiple prey in a map. Social vehicles act as obstacles where both the | ||
predator and prey must avoid them. | ||
""" | ||
import argparse | ||
import os | ||
import random | ||
import multiprocessing | ||
import ray | ||
|
||
|
||
import numpy as np | ||
from typing import List | ||
from ray import tune | ||
from ray.rllib.utils import try_import_tf | ||
from ray.rllib.models import ModelCatalog | ||
from ray.tune import Stopper | ||
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork | ||
from ray.tune.schedulers import PopulationBasedTraining | ||
from ray.rllib.agents.ppo import PPOTrainer | ||
from pathlib import Path | ||
|
||
from smarts.env.rllib_hiway_env import RLlibHiWayEnv | ||
from smarts.core.agent import AgentSpec, Agent | ||
from smarts.core.controllers import ActionSpaceType | ||
from smarts.core.agent_interface import AgentInterface, AgentType, DoneCriteria | ||
from smarts.core.utils.file import copy_tree | ||
|
||
|
||
from examples.game_of_tag.tag_adapters import * | ||
from examples.game_of_tag.model import CustomFCModel | ||
|
||
|
||
# Add custom metrics to your tensorboard using these callbacks
# see: https://ray.readthedocs.io/en/latest/rllib-training.html#callbacks-and-custom-metrics
def on_episode_start(info):
    episode = info["episode"]
    print("episode {} started".format(episode.episode_id))


def on_episode_step(info):
    episode = info["episode"]
    single_agent_id = list(episode._agent_to_last_obs)[0]
    obs = episode.last_raw_obs_for(single_agent_id)


def on_episode_end(info):
    episode = info["episode"]


def explore(config):
    # ensure we collect enough timesteps to do sgd
    if config["train_batch_size"] < config["sgd_minibatch_size"] * 2:
        config["train_batch_size"] = config["sgd_minibatch_size"] * 2
    # ensure we run at least one sgd iter
    if config["num_sgd_iter"] < 1:
        config["num_sgd_iter"] = 1
    return config


PREDATOR_POLICY = "predator_policy"
PREY_POLICY = "prey_policy"


def policy_mapper(agent_id):
    if agent_id in PREDATOR_IDS:
        return PREDATOR_POLICY
    elif agent_id in PREY_IDS:
        return PREY_POLICY


class TimeStopper(Stopper):
    def __init__(self):
        self._start = time.time()
        # Currently will see obvious tag behaviour in 6 hours
        self._deadline = 48 * 60 * 60  # train for 48 hours

    def __call__(self, trial_id, result):
        return False

    def stop_all(self):
        return time.time() - self._start > self._deadline

tf = try_import_tf()

ModelCatalog.register_custom_model("CustomFCModel", CustomFCModel)

rllib_agents = {}

# Both predator and prey agents use the same interface: lane-following with
# continuous speed, plus neighborhood-vehicle and waypoint observations.
shared_interface = AgentInterface(
    max_episode_steps=1500,
    neighborhood_vehicles=True,
    waypoints=True,
    action=ActionSpaceType.LaneWithContinuousSpeed,
)
# Only collisions end an episode; going off-route or the wrong way does not.
shared_interface.done_criteria = DoneCriteria(
    off_route=False,
    wrong_way=False,
    collision=True,
)

for agent_id in PREDATOR_IDS:
    rllib_agents[agent_id] = {
        "agent_spec": AgentSpec(
            interface=shared_interface,
            agent_builder=lambda: TagModelAgent(
                os.path.join(os.path.dirname(os.path.realpath(__file__)), "model"),
                OBSERVATION_SPACE,
            ),
            observation_adapter=observation_adapter,
            reward_adapter=predator_reward_adapter,
            action_adapter=action_adapter,
        ),
        "observation_space": OBSERVATION_SPACE,
        "action_space": ACTION_SPACE,
    }

for agent_id in PREY_IDS:
    rllib_agents[agent_id] = {
        "agent_spec": AgentSpec(
            interface=shared_interface,
            agent_builder=lambda: TagModelAgent(
                os.path.join(os.path.dirname(os.path.realpath(__file__)), "model"),
                OBSERVATION_SPACE,
            ),
            observation_adapter=observation_adapter,
            reward_adapter=prey_reward_adapter,
            action_adapter=action_adapter,
        ),
        "observation_space": OBSERVATION_SPACE,
        "action_space": ACTION_SPACE,
    }

def build_tune_config(scenario, headless=True, sumo_headless=False):
    rllib_policies = {
        policy_mapper(agent_id): (
            None,
            rllib_agent["observation_space"],
            rllib_agent["action_space"],
            {"model": {"custom_model": "CustomFCModel"}},
        )
        for agent_id, rllib_agent in rllib_agents.items()
    }

    tune_config = {
        "env": RLlibHiWayEnv,
        "framework": "torch",
        "log_level": "WARN",
        "num_workers": 3,
        "explore": True,
        "horizon": 10000,
        "env_config": {
            "seed": 42,
            "sim_name": "game_of_tag_works?",
            "scenarios": [os.path.abspath(scenario)],
            "headless": headless,
            "sumo_headless": sumo_headless,
            "agent_specs": {
                agent_id: rllib_agent["agent_spec"]
                for agent_id, rllib_agent in rllib_agents.items()
            },
        },
        "multiagent": {
            "policies": rllib_policies,
            "policies_to_train": [PREDATOR_POLICY, PREY_POLICY],
            "policy_mapping_fn": policy_mapper,
        },
        "callbacks": {
            "on_episode_start": on_episode_start,
            "on_episode_step": on_episode_step,
            "on_episode_end": on_episode_end,
        },
    }
    return tune_config

def main(args):
    pbt = PopulationBasedTraining(
        time_attr="time_total_s",
        metric="episode_reward_mean",
        mode="max",
        perturbation_interval=300,
        resample_probability=0.25,
        # Specifies the mutations of these hyperparams
        hyperparam_mutations={
            "lambda": lambda: random.uniform(0.9, 1.0),
            "clip_param": lambda: random.uniform(0.01, 0.5),
            "kl_coeff": lambda: 0.3,
            "lr": [1e-3],
            "sgd_minibatch_size": lambda: 128,
            "train_batch_size": lambda: 4000,
            "num_sgd_iter": lambda: 30,
        },
        custom_explore_fn=explore,
    )
    local_dir = os.path.expanduser(args.result_dir)

    tune_config = build_tune_config(args.scenario)

    tune.run(
        PPOTrainer,  # RLlib supports using PPO in a multi-agent setting
        name="lets_play_tag",
        stop=TimeStopper(),
        # XXX: Every X iterations perform a _ray actor_ checkpoint (this is
        # different than _exporting_ a TF/PT checkpoint).
        checkpoint_freq=5,
        checkpoint_at_end=True,
        # XXX: Beware, resuming after changing tune params will not pick up
        # the new arguments as they are stored alongside the checkpoint.
        resume=args.resume_training,
        # restore="path_to_training_checkpoint/checkpoint_x/checkpoint-x",
        local_dir=local_dir,
        reuse_actors=True,
        max_failures=0,
        export_formats=["model", "checkpoint"],
        config=tune_config,
        scheduler=pbt,
    )

    # # To output a model:
    # # 1: comment out tune.run and uncomment the following code
    # # 2: replace the checkpoint path with your training checkpoint path
    # # 3: inject code into RLlib according to README.md and run
    # checkpoint_path = os.path.join(
    #     os.path.dirname(os.path.realpath(__file__)), "models/checkpoint_360/checkpoint-360"
    # )
    # ray.init(num_cpus=2)
    # training_agent = PPOTrainer(env=RLlibHiWayEnv, config=tune_config)
    # training_agent.restore(checkpoint_path)
    # prefix = "model.ckpt"
    # model_dir = os.path.join(
    #     os.path.dirname(os.path.realpath(__file__)), "models/predator_model"
    # )
    # training_agent.export_policy_model(model_dir, PREDATOR_POLICY)
    # model_dir = os.path.join(
    #     os.path.dirname(os.path.realpath(__file__)), "models/prey_model"
    # )
    # training_agent.export_policy_model(model_dir, PREY_POLICY)

if __name__ == "__main__":
    parser = argparse.ArgumentParser("rllib-example")
    parser.add_argument(
        "scenario",
        type=str,
        help="Scenario to run (see scenarios/ for some samples you can use)",
    )
    parser.add_argument(
        "--resume_training",
        default=False,
        action="store_true",
        help="Resume the last trained example",
    )
    parser.add_argument(
        "--result_dir",
        type=str,
        default="~/ray_results",
        help="Directory containing results (and checkpointing)",
    )
    args = parser.parse_args()
    main(args)
@@ -0,0 +1,37 @@
import gym
import torch
from torch import nn
from torch.distributions.normal import Normal
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFCNet


class CustomFCModel(TorchModelV2, nn.Module):
    """Example of interpreting repeated observations."""

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config,
        name: str,
    ):
        super(CustomFCModel, self).__init__(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name,
        )
        nn.Module.__init__(self)

        # Delegate to RLlib's built-in torch fully connected network.
        self.model = TorchFCNet(
            obs_space, action_space, num_outputs, model_config, name
        )

    def forward(self, input_dict, state, seq_lens):
        return self.model.forward(input_dict, state, seq_lens)

    def value_function(self):
        return self.model.value_function()
The remaining changed files are an empty file and binary files (not shown), including examples/game_of_tag/models/checkpoint_360/checkpoint-360.tune_metadata (+216 Bytes).