Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Imitation learning with dagger #906

Open
wants to merge 390 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
390 commits
Select commit Hold shift + click to select a range
c373e94
datapip pipeline implemented
brentgryffindor Apr 6, 2020
a88c209
get up to date with i210_dev
brentgryffindor May 19, 2020
89f8d1d
remove dupe imports
liljonnystyle May 20, 2020
306a01f
remove blank lines after docstrings
liljonnystyle May 20, 2020
0d5fa6b
add back ray import
liljonnystyle May 20, 2020
0ade197
remove whitespace
liljonnystyle May 20, 2020
1111e9a
moved imports under functions in train.py (#903)
chendiw Apr 21, 2020
a4c7d67
get not departed vehicles (#922)
Yasharzf May 8, 2020
36e8851
changed _departed_ids, and _arrived_ids in the update function (#926)
Yasharzf May 9, 2020
ebb2921
Add an on ramp option
eugenevinitsky Mar 18, 2020
e4c02bb
Increased inflows to 10800 to match density in Bennis ring
eugenevinitsky Mar 19, 2020
505d646
Upgrade the network to not have keepclear value on the junctions
eugenevinitsky Mar 19, 2020
7d52445
Add 1 lane highway network for Benni
eugenevinitsky Mar 25, 2020
c3b2a51
multiple runs issue solved, testing added
brentgryffindor Apr 11, 2020
dc881e0
added more support for lambda function
brentgryffindor Apr 22, 2020
ee1188e
fix windoes line ending issue with experiment.py
brentgryffindor Apr 23, 2020
65c9ee0
fix style issue
brentgryffindor Apr 23, 2020
5a3ff57
reorganized file locations
brentgryffindor Apr 23, 2020
ddc53fb
fix some more style issues
brentgryffindor Apr 23, 2020
e7ac1a9
fix one more style issue
brentgryffindor Apr 23, 2020
c970219
added new two new quries
brentgryffindor May 10, 2020
3b10524
including next_V for testing only
brentgryffindor May 11, 2020
638f9b4
change the bucket to a common bucket
brentgryffindor May 18, 2020
bc8584a
removed the old tests
brentgryffindor May 18, 2020
0ee6646
Add an on ramp option
eugenevinitsky Mar 18, 2020
3af5595
datapip pipeline implemented
brentgryffindor Apr 6, 2020
8d4ad29
multiple runs issue solved, testing added
brentgryffindor Apr 11, 2020
aa14dbf
added more support for lambda function
brentgryffindor Apr 22, 2020
00a526b
fix windoes line ending issue with experiment.py
brentgryffindor Apr 23, 2020
de35f90
fix style issue
brentgryffindor Apr 23, 2020
979d047
reorganized file locations
brentgryffindor Apr 23, 2020
fdd983e
fix some more style issues
brentgryffindor Apr 23, 2020
6af7e02
added auto upload to s3 feature for the reply scipt and fix some othe…
brentgryffindor May 19, 2020
72d4733
fix trailing white space style issue
brentgryffindor May 19, 2020
420ea3f
some minor issue fixed
brentgryffindor May 19, 2020
e45eb92
reformatting energy queries
liljonnystyle May 19, 2020
d578e63
rename vehicle power demand query
liljonnystyle May 19, 2020
32c0528
move partition condition to cte's
liljonnystyle May 19, 2020
c7cd963
fix some query string formatting issue
brentgryffindor May 19, 2020
b5be92a
fix some style issue
brentgryffindor May 19, 2020
6884960
get up to date with i210_dev
brentgryffindor May 19, 2020
7e549be
update lambda function, change partition into multi-column
brentgryffindor May 20, 2020
a799abd
remove dupe imports
liljonnystyle May 20, 2020
f4fa426
remove blank lines after docstrings
liljonnystyle May 20, 2020
2563818
add back ray import
liljonnystyle May 20, 2020
498e08a
remove whitespace
liljonnystyle May 20, 2020
d7da535
style fixed
brentgryffindor May 20, 2020
3df2312
specify power demand model names
liljonnystyle May 20, 2020
28d4f73
fix bug in vehicle power demand
liljonnystyle May 25, 2020
0779832
Add several accelerations (with/without noise, with/without failsafes…
liljonnystyle May 21, 2020
b3f15a3
update queries with new column names
liljonnystyle May 21, 2020
d66a0ab
fix flake8 issues
liljonnystyle May 21, 2020
38af177
remove trailing whitespaces
liljonnystyle May 21, 2020
fceedf8
Add several accelerations (with/without noise, with/without failsafes…
liljonnystyle May 21, 2020
df182ad
fix accel with noise with failsafe output
liljonnystyle May 25, 2020
d888405
fix rebase errors
liljonnystyle May 25, 2020
27e2960
fix merge conflicts
liljonnystyle May 26, 2020
69f6f55
rm deleted file
liljonnystyle May 26, 2020
4f2f23e
add return carriage to eof
liljonnystyle May 26, 2020
d2ba069
revert accidental change
liljonnystyle May 26, 2020
8eee772
rename trajectory table
liljonnystyle May 26, 2020
b5f5424
Merge branch 'i210_dev' into jl-more-accel-outputs
liljonnystyle May 26, 2020
db0442b
Ported to Keras, initial implementation of loading to RLLib
akashvelu May 26, 2020
39ad373
Merge branch 'i210_dev' of https://github.com/flow-project/flow into …
akashvelu May 26, 2020
ed065b3
Bug fixes for starting training from imitation model
akashvelu May 26, 2020
cc0aa32
Minor cleanup
akashvelu May 26, 2020
3c6dcf7
added apply acceleratino function which uses setSpeed() method instea…
Yasharzf May 26, 2020
3a2e135
Minor cleanup
akashvelu May 26, 2020
ddf6a24
added failsafe methods for max accel/decel and speed limit, and all
Yasharzf May 26, 2020
53cf035
removed json file which was added by mistake
Yasharzf May 26, 2020
b16d949
leader utils added
brentgryffindor May 26, 2020
86acecc
merge conflict resovled
brentgryffindor May 26, 2020
b49dbce
fixed merge conflicts
Yasharzf May 26, 2020
528f0aa
fixed docstrings
Yasharzf May 26, 2020
cbf6a42
removed duplicated print
Yasharzf May 26, 2020
288a1cf
Removed usage of rllib.utils.freamwork
akashvelu May 26, 2020
c1db60a
Changed location of h5 file
akashvelu May 27, 2020
8645811
minor docstring formatting
liljonnystyle May 27, 2020
6f8d878
fixed a monor error in energy query, added network in metadata
brentgryffindor May 27, 2020
db33f7c
fix a minor mistake in docstring
brentgryffindor May 27, 2020
089822a
flake8 fix
brentgryffindor May 27, 2020
c3756f8
Fixed trajectory_table_path
akashvelu May 27, 2020
7f68c50
Fixed trajectory_table_path
akashvelu May 27, 2020
7ab2b3e
Merge pull request #952 from flow-project/av-emission-path
akashvelu May 27, 2020
13a797c
Merge branch 'i210_dev' of https://github.com/flow-project/flow into …
akashvelu May 27, 2020
1669787
addressing comments
liljonnystyle May 27, 2020
243c895
Merge pull request #939 from flow-project/jl-more-accel-outputs
liljonnystyle May 27, 2020
4b853b5
Function for ppo architecture
akashvelu May 27, 2020
5c0923d
Load weights for training in train.py
akashvelu May 27, 2020
6ad0b0d
Merge branch 'i210_dev' into datapipeline_dev_v2
liljonnystyle May 27, 2020
c785944
Code structure changes
akashvelu May 28, 2020
1857f83
fixed naming convention
brentgryffindor May 28, 2020
87ebf59
do repair partition for all new data upon arrival
brentgryffindor May 28, 2020
64b8a47
Merge branch 'datapipeline_dev_v2' of https://github.com/flow-project…
brentgryffindor May 28, 2020
05e793a
added leaderboard chart aggregation
brentgryffindor May 28, 2020
7a75c6e
update lambda function, added some comments
brentgryffindor May 28, 2020
fef3a83
Combine imitation and PPO training into one step
akashvelu May 28, 2020
8fac720
minor change to get_table_disk
brentgryffindor May 28, 2020
9537bb4
fix minor path issue
brentgryffindor May 28, 2020
ff90e8d
move deleting leaderboard_chart_agg to after downloading
brentgryffindor May 28, 2020
f7a278c
Network update (#953)
eugenevinitsky May 28, 2020
4ebcc06
seperated speed limit check, modified orders
Yasharzf May 28, 2020
bd7622a
fix get metadata
brentgryffindor May 28, 2020
9444107
Merge branch 'i210_dev' into jl-more-accel-outputs
liljonnystyle May 28, 2020
11639bf
Merge pull request #954 from flow-project/jl-more-accel-outputs
liljonnystyle May 28, 2020
e9c0438
fix get metadata
brentgryffindor May 28, 2020
fd29e0f
Code cleanup
akashvelu May 28, 2020
2b6cc08
test with cluster
akashvelu May 28, 2020
8ef0179
Merge i210_dev
akashvelu May 28, 2020
39e4bc4
fix pathname
liljonnystyle May 29, 2020
d6b6c18
fix pathname
liljonnystyle May 29, 2020
3dfafe1
Bug fix
akashvelu May 29, 2020
ba67961
Minor cleanup
akashvelu May 29, 2020
82b252e
added edge_id, lane_id, and distance
brentgryffindor May 30, 2020
65791df
added netwokr name translation
brentgryffindor May 30, 2020
f1ef8e2
fix some query bugs
liljonnystyle May 30, 2020
4b5cb41
update values for warm-up time and horizon
liljonnystyle May 30, 2020
4c0358a
leaderboard chart agg query fixes
liljonnystyle May 30, 2020
ce18a36
remove unnecessary references to "x"
liljonnystyle May 30, 2020
6c7c3c0
fix some error in query
brentgryffindor May 30, 2020
8c52264
merge conflict resolved
brentgryffindor May 30, 2020
7743445
Merge branch 'i210_dev' into datapipeline_dev_v2
brentgryffindor May 30, 2020
f0aa7b4
added metadata as a table, update realized_accel at timestep 0, fixed…
brentgryffindor May 31, 2020
b46d5f9
minor re-formats
liljonnystyle Jun 1, 2020
15b646b
update docstring for write_dict_to_csv
liljonnystyle Jun 1, 2020
17802fd
simplify network_name_translate
liljonnystyle Jun 1, 2020
a0b60c5
modify constraint specifications for run_query
liljonnystyle Jun 1, 2020
79afdae
rename loc_filter, add time filters
liljonnystyle Jun 1, 2020
09a26a3
rename constraints to filters
liljonnystyle Jun 1, 2020
f7a4d6e
rename constraints to filters
liljonnystyle Jun 1, 2020
8305165
tweak queries for styling
liljonnystyle Jun 1, 2020
ac8c545
remove outer joins to avoid edge cases
liljonnystyle Jun 1, 2020
f9f75af
rename query_date to submission-date
liljonnystyle Jun 1, 2020
660891a
write simulation result to disk every 100 time step
Jun 1, 2020
95933fa
merge conflict resolved
brentgryffindor Jun 1, 2020
d617c2f
reformat upload_to_s3, add missing comma
liljonnystyle Jun 1, 2020
a7eda70
fix i210 replay data collection
brentgryffindor Jun 1, 2020
248fe57
Merge branch 'datapipeline_dev_v2' of https://github.com/flow-project…
brentgryffindor Jun 1, 2020
f5f000e
remove extra comma
brentgryffindor Jun 1, 2020
9fa0027
Merge pull request #949 from flow-project/datapipeline_dev_v2
brentgryffindor Jun 1, 2020
1e42556
fix network name mapping for highway-single
brentgryffindor Jun 1, 2020
bff9e47
Load weights into rllib
akashvelu Jun 2, 2020
1ae0081
fix data collection issue in i210_replay
brentgryffindor Jun 2, 2020
f7d9ec1
fix the network name mapping for i210 single lane
brentgryffindor Jun 2, 2020
db29793
new merge
akashvelu Jun 2, 2020
a0c3904
Minor changes to support cusotm PPO
eugenevinitsky Jun 2, 2020
29ed0e7
Merge branch 'i210_dev' of github.com:flow-project/flow into i210_dev
eugenevinitsky Jun 2, 2020
d10f8e5
Value function learning
akashvelu Jun 3, 2020
7fa3e3a
Tensorboard plotting for loss
akashvelu Jun 3, 2020
91fab74
Bug fixes
akashvelu Jun 3, 2020
d38839f
Code cleanup
akashvelu Jun 4, 2020
b50c9a5
Merge branch 'master' of https://github.com/flow-project/flow into i2…
AboudyKreidieh Jun 10, 2020
c27bfe5
pep8 (mostly
AboudyKreidieh Jun 10, 2020
4e8769b
pydocstyle (mostly
AboudyKreidieh Jun 10, 2020
81e8d6a
Test file changes
akashvelu Jun 11, 2020
240cc05
Code cleanup
akashvelu Jun 14, 2020
29b06d8
Merge branch 'master' of https://github.com/flow-project/flow into i2…
AboudyKreidieh Jun 15, 2020
d9470cb
some cleanup
AboudyKreidieh Jun 15, 2020
1f6ceee
removed unused simulation
AboudyKreidieh Jun 15, 2020
0253fb1
Merge branch 'master' of https://github.com/flow-project/flow into i2…
AboudyKreidieh Jun 16, 2020
7fd11b1
Add queries for safety metrics reporting
liljonnystyle Jun 16, 2020
c7937ff
fix typo
liljonnystyle Jun 16, 2020
d523b74
filter warmup steps and ghost edges from safety calculation
liljonnystyle Jun 16, 2020
946938a
invert safety_rate definition
liljonnystyle Jun 16, 2020
10bf24f
Flag reconfig
akashvelu Jun 17, 2020
aa72b2e
flag cleanup
akashvelu Jun 17, 2020
e209de2
Reorganize method arguments
akashvelu Jun 17, 2020
5ce3c4d
Argument reorganizing
akashvelu Jun 17, 2020
3857a44
Merge branch 'i210_dev' of github.com:flow-project/flow into i210_dev
eugenevinitsky Jun 17, 2020
1aae8f8
Cleanup and rearrange args
akashvelu Jun 18, 2020
f7451d0
Custom PPO to log value function predictions:
akashvelu Jun 18, 2020
9519c1f
Merged i210 dev
akashvelu Jun 18, 2020
5ac9828
cleanup to the multi-agent trainer (#971)
AboudyKreidieh Jun 18, 2020
5746937
Metadata Configuration (#957)
brentgryffindor Jun 18, 2020
2d96460
Cleanup to train.py
akashvelu Jun 18, 2020
25e623a
timespace diagram merge bug fix
akashvelu Jun 19, 2020
2e76a4c
reduce time-bins to 10s
liljonnystyle Jun 19, 2020
0c7de60
reduce time-bins in more places
liljonnystyle Jun 19, 2020
71dee84
docstring fix
akashvelu Jun 19, 2020
973528e
Merge pull request #974 from flow-project/av-time-space-fix2
akashvelu Jun 19, 2020
dfb1c07
add query to count vehicles in domain at every timestep
liljonnystyle Jun 19, 2020
23c55fe
fix typo in window function
liljonnystyle Jun 19, 2020
18c0d9e
add imitation custom models
akashvelu Jun 19, 2020
024cb93
code cleanup
akashvelu Jun 19, 2020
16a9ced
Merge branch 'i210_dev' of https://github.com/flow-project/flow into …
akashvelu Jun 19, 2020
7ac4c32
implement _get_abs_pos() for HighwayNetwork
liljonnystyle Jun 19, 2020
588bffd
Merge pull request #978 from flow-project/jl-highway-tsd
liljonnystyle Jun 19, 2020
5de54b7
remove trailing whitespaces
liljonnystyle Jun 19, 2020
38a6d70
remove unused import
liljonnystyle Jun 19, 2020
3791048
fix flake8 issues
liljonnystyle Jun 19, 2020
ed01357
remove unused error variable
liljonnystyle Jun 19, 2020
b9fd3be
add expected blank line before function
liljonnystyle Jun 19, 2020
62ee8a0
add specified exception to try
liljonnystyle Jun 19, 2020
885ab6f
custom ppo for vf plotting edits
akashvelu Jun 19, 2020
01676d9
correct some docstring inconsistencies
liljonnystyle Jun 19, 2020
85fdd63
i210 imitation model file
akashvelu Jun 20, 2020
dd24fb0
Add query to produce max score line in leaderboard
liljonnystyle Jun 21, 2020
4e6a9b2
Add I210 edgestarts
liljonnystyle Jun 22, 2020
61c0885
Merge branch 'i210_dev' of github.com:flow-project/flow into i210_dev
eugenevinitsky Jun 22, 2020
3a5508c
Replace strategic mode with the new name, sumo_default
eugenevinitsky Jun 22, 2020
55e10fd
Merge branch 'i210_dev' into jl-safety-metrics
liljonnystyle Jun 23, 2020
3725e83
remove trailing whitespace
liljonnystyle Jun 23, 2020
f87f67a
fix CASE syntax error
liljonnystyle Jun 23, 2020
2a0d9cc
Merge pull request #972 from flow-project/jl-safety-metrics
liljonnystyle Jun 23, 2020
54ce4ec
reduce time-bins to 10s
liljonnystyle Jun 19, 2020
be5b853
reduce time-bins in more places
liljonnystyle Jun 19, 2020
55848a4
Merge branch 'jl-time-bins' of https://github.com/flow-project/flow i…
liljonnystyle Jun 23, 2020
85f65c5
Merge pull request #975 from flow-project/jl-time-bins
liljonnystyle Jun 23, 2020
2708287
Merge branch 'i210_dev' into jl-veh-counts
liljonnystyle Jun 23, 2020
e3de3db
fix groupby/window fn error
liljonnystyle Jun 23, 2020
afbfb6d
Merge pull request #976 from flow-project/jl-veh-counts
liljonnystyle Jun 23, 2020
7f406cb
fix is_baseline data type
liljonnystyle Jun 23, 2020
ce4cdb6
Merge pull request #983 from flow-project/jl-leaderboard-best
liljonnystyle Jun 23, 2020
5d897af
change schema, vehicle_counts -> vehicle_count
liljonnystyle Jun 23, 2020
7ddf890
fix some query bugs
brentgryffindor Jun 24, 2020
9dd65c8
Code cleanup
akashvelu Jun 25, 2020
739c2ca
test files synced to i210_dev
akashvelu Jun 25, 2020
ddce32e
Cleanup code
akashvelu Jun 26, 2020
4e6302e
Handle case with vehicle in no-control edge
akashvelu Jun 26, 2020
97cfdee
grey out warmup period and ghost cells
liljonnystyle Jul 2, 2020
e22189e
fix rectangle positioning for both networks
liljonnystyle Jul 2, 2020
06ff2d9
Merge pull request #989 from flow-project/jl-tsd-mask
liljonnystyle Jul 2, 2020
c830f78
Reward options in I210-dev
eugenevinitsky Jul 2, 2020
cb74b8c
fix pydocstyle
liljonnystyle Jul 6, 2020
eb2416b
add docstring
liljonnystyle Jul 6, 2020
6ed00e3
remove excess whitespace
liljonnystyle Jul 6, 2020
b80e563
only call get_configuration() if to_aws
liljonnystyle Jul 7, 2020
7c9a48a
Energy class for inventorying multiple energy models (#944)
liljonnystyle Jul 8, 2020
5b7e8b2
Time-Space Diagrams automatically to S3 (#993)
liljonnystyle Jul 8, 2020
c4ba7ad
Query Prereq Check (#987)
brentgryffindor Jul 8, 2020
bb1f4f5
remove extra whitespace
liljonnystyle Jul 8, 2020
9f1a834
whitespace linting
liljonnystyle Jul 8, 2020
220994e
Update energy query with new power demand model (#996)
liljonnystyle Jul 9, 2020
f1ded54
Power-Demand Model fix (#995)
liljonnystyle Jul 9, 2020
f63cc37
convert tacoma fc to gallons per hour
liljonnystyle Jul 9, 2020
c2836e8
comment on road grade; exception handling on unpickling
liljonnystyle Jul 9, 2020
29eb5a0
Add learning rate as a parameter, override import_from_h5 method usin…
akashvelu Jul 10, 2020
97333cf
add --multi_node flag
nathanlct Jul 10, 2020
f7e1d78
Merge pull request #998 from flow-project/i210_add_multinode
nathanlct Jul 10, 2020
3ac508a
Ak/i210 master merge (#994)
AboudyKreidieh Jul 11, 2020
bb94c27
remove line from testing
liljonnystyle Jul 11, 2020
d373965
fix toyota temp file removal
liljonnystyle Jul 11, 2020
ab6732e
fix fc <> power unit conversion
liljonnystyle Jul 11, 2020
c0de59b
make default highway single penetration rate 0
liljonnystyle Jul 11, 2020
5f6acc2
use 1609.34 meters per mile
liljonnystyle Jul 11, 2020
7a773e3
fix av routing controller if no on-ramp
liljonnystyle Jul 11, 2020
0e8be95
Time-Space Diagram offset axes (#999)
liljonnystyle Jul 12, 2020
6c68800
Move imitation to algorithms folder
akashvelu Jul 13, 2020
d73612f
Merge i210 into branch
akashvelu Jul 13, 2020
6aca7c5
Revert model architecture and # rollouts to previous defaults
akashvelu Jul 13, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions flow/controllers/dagger/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os
import time
import numpy as np
import tensorflow as tf
from trainer import Trainer
from flow.controllers.car_following_models import IDMController


class Runner(object):
""" Class to run imitation learning (training and evaluation) """

def __init__(self, params):

# initialize trainer
self.params = params
self.trainer = Trainer(params)

def run_training_loop(self):

self.trainer.run_training_loop(n_iter=self.params['n_iter'])

def evaluate(self):
self.trainer.evaluate_controller(num_trajs=self.params['num_eval_episodes'])

def save_controller_network(self):
self.trainer.save_controller_network()


def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--ep_len', type=int, default=3000)

parser.add_argument('--num_agent_train_steps_per_iter', type=int, default=1000) # number of gradient steps for training policy (per iter in n_iter)
parser.add_argument('--n_iter', '-n', type=int, default=5)

parser.add_argument('--batch_size', type=int, default=1000) # training data collected (in the env) during each iteration
parser.add_argument('--init_batch_size', type=int, default=3000)

parser.add_argument('--train_batch_size', type=int,
default=100) # number of sampled data points to be used per gradient/train step

parser.add_argument('--num_layers', type=int, default=3) # depth, of policy to be learned
parser.add_argument('--size', type=int, default=64) # width of each layer, of policy to be learned
parser.add_argument('--learning_rate', '-lr', type=float, default=5e-3) # learning rate for supervised learning
parser.add_argument('--replay_buffer_size', type=int, default=1000000)
parser.add_argument('--save_path', type=str, default='')
parser.add_argument('--save_model', type=int, default=0)
parser.add_argument('--num_eval_episodes', type=int, default=10)
parser.add_argument('--inject_noise', type=int, default=0)
parser.add_argument('--noise_variance',type=float, default=0.5)
parser.add_argument('--vehicle_id', type=str, default='rl_0')

args = parser.parse_args()

# convert args to dictionary
params = vars(args)
assert args.n_iter>1, ('DAgger needs >1 iteration')


# run training
train = Runner(params)
train.run_training_loop()

# evaluate
train.evaluate()
print("DONE")

if params['save_model'] == 1:
train.save_controller_network()

# tensorboard
if params['save_model'] == 1:
writer = tf.summary.FileWriter('./graphs2', tf.get_default_graph())


if __name__ == "__main__":
main()
179 changes: 179 additions & 0 deletions flow/controllers/dagger/trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import time
from collections import OrderedDict
import pickle
import numpy as np
import tensorflow as tf
import gym
import os
from flow.utils.registry import make_create_env
from bottleneck_env import flow_params
from imitating_controller import ImitatingController
from flow.controllers.car_following_models import IDMController
from flow.controllers.velocity_controllers import FollowerStopper
from flow.core.params import SumoCarFollowingParams
from utils import *

class Trainer(object):
"""
Class to initialize and run training for imitation learning (with DAgger)
"""

def __init__(self, params):
self.params = params
self.sess = create_tf_session()

create_env, _ = make_create_env(flow_params)
self.env = create_env()
self.env.reset()

print(self.env.k.vehicle.get_ids())
assert self.params['vehicle_id'] in self.env.k.vehicle.get_ids()
self.vehicle_id = self.params['vehicle_id']

obs_dim = self.env.observation_space.shape[0]

action_dim = (1,)[0]
self.params['action_dim'] = action_dim
self.params['obs_dim'] = obs_dim

car_following_params = SumoCarFollowingParams()
self.controller = ImitatingController(self.vehicle_id, self.sess, self.params['action_dim'], self.params['obs_dim'], self.params['num_layers'], self.params['size'], self.params['learning_rate'], self.params['replay_buffer_size'], car_following_params = car_following_params, inject_noise=self.params['inject_noise'], noise_variance=self.params['noise_variance'])
# self.expert_controller = IDMController(self.vehicle_id, car_following_params = car_following_params)
self.expert_controller = FollowerStopper(self.vehicle_id, car_following_params = car_following_params)

tf.global_variables_initializer().run(session=self.sess)


def run_training_loop(self, n_iter):
"""
Trains controller for n_iter iterations

Args:
param n_iter: number of iterations to execute training
"""

# init vars at beginning of training
self.total_envsteps = 0
self.start_time = time.time()

for itr in range(n_iter):
print("\n\n********** Iteration %i ************"%itr)

# collect trajectories, to be used for training
if itr == 0:
# first iteration is standard behavioral cloning
training_returns = self.collect_training_trajectories(itr, self.params['init_batch_size'])
else:
training_returns = self.collect_training_trajectories(itr, self.params['batch_size'])

paths, envsteps_this_batch = training_returns
self.total_envsteps += envsteps_this_batch

# add collected data to replay buffer
self.controller.add_to_replay_buffer(paths)

# train controller (using sampled data from replay buffer)
loss = self.train_controller()

def collect_training_trajectories(self, itr, batch_size):
"""
Collect (state, action, reward, next_state, terminal) tuples for training

Args:
itr: iteration of training during which functino is called
batch_size: number of tuples to collect
Returns:
paths: list of trajectories
envsteps_this_batch: the sum over the numbers of environment steps in paths
"""

if itr == 0:
collect_controller = self.expert_controller
else:
collect_controller = self.controller

print("\nCollecting data to be used for training...")
paths, envsteps_this_batch = sample_trajectories(self.env, self.vehicle_id, collect_controller, self.expert_controller, batch_size, self.params['ep_len'])

return paths, envsteps_this_batch

def train_controller(self):
"""
Trains controller using data sampled from replay buffer
"""

print('Training controller using sampled data from replay buffer')
for train_step in range(self.params['num_agent_train_steps_per_iter']):
ob_batch, ac_batch, expert_ac_batch, re_batch, next_ob_batch, terminal_batch = self.controller.sample_data(self.params['train_batch_size'])
self.controller.train(ob_batch, expert_ac_batch)

def evaluate_controller(self, num_trajs = 10):
"""
Evaluates a trained controller on similarity with expert with respect to action taken and total reward per rollout

Args:
num_trajs: number of trajectories to evaluate performance on
"""

print("\n\n********** Evaluation ************ \n")

trajectories = sample_n_trajectories(self.env, self.vehicle_id, self.controller, self.expert_controller, num_trajs, self.params['ep_len'])

average_imitator_reward = 0
total_imitator_steps = 0
average_imitator_reward_per_rollout = 0

action_errors = np.array([])
average_action_expert = 0
average_action_imitator = 0

# compare actions taken in each step of trajectories
for traj in trajectories:
imitator_actions = traj['actions']
expert_actions = traj['expert_actions']

average_action_expert += np.sum(expert_actions)
average_action_imitator += np.sum(imitator_actions)

action_error = np.linalg.norm(imitator_actions - expert_actions) / len(imitator_actions)
action_errors = np.append(action_errors, action_error)

average_imitator_reward += np.sum(traj['rewards'])
total_imitator_steps += len(traj['rewards'])
average_imitator_reward_per_rollout += np.sum(traj['rewards'])

average_imitator_reward = average_imitator_reward / total_imitator_steps
average_imitator_reward_per_rollout = average_imitator_reward_per_rollout / len(trajectories)

average_action_expert = average_action_expert / total_imitator_steps
average_action_imitator = average_action_imitator / total_imitator_steps


expert_trajectories = sample_n_trajectories(self.env, self.vehicle_id, self.expert_controller, self.expert_controller, num_trajs, self.params['ep_len'])

average_expert_reward = 0
total_expert_steps = 0
average_expert_reward_per_rollout = 0

# compare reward accumulated in trajectories collected via expert vs. via imitator
for traj in expert_trajectories:
average_expert_reward += np.sum(traj['rewards'])
total_expert_steps += len(traj['rewards'])
average_expert_reward_per_rollout += np.sum(traj['rewards'])

average_expert_reward_per_rollout = average_expert_reward_per_rollout / len(expert_trajectories)
average_expert_reward = average_expert_reward / total_expert_steps

print("\nAVERAGE REWARD PER STEP EXPERT: ", average_expert_reward)
print("AVERAGE REWARD PER STEP IMITATOR: ", average_imitator_reward)
print("AVERAGE REWARD PER STEP DIFFERENCE: ", np.abs(average_expert_reward - average_imitator_reward), "\n")

print("AVERAGE REWARD PER ROLLOUT EXPERT: ", average_expert_reward_per_rollout)
print("AVERAGE REWARD PER ROLLOUT IMITATOR: ", average_imitator_reward_per_rollout)
print("AVERAGE REWARD PER ROLLOUT DIFFERENCE: ", np.abs(average_expert_reward_per_rollout - average_imitator_reward_per_rollout), "\n")

print("MEAN ACTION ERROR: ", np.mean(action_errors), "\n")

def save_controller_network(self):
print("Saving tensorflow model to: ", self.params['save_path'])
self.controller.save_network(self.params['save_path'])
Loading