Add learning agent #31

Merged · 280 commits · Apr 22, 2024

Changes from all commits (280 commits)
59b626f
Fix formatting
camall3n May 13, 2023
7a4b4f4
Change var name discs -> discreps
camall3n May 13, 2023
647ed3f
Add discrep progress plot for queue search
camall3n May 13, 2023
7d96616
Refactor optimize_memory_optuna & prepare for other optimizers
camall3n May 13, 2023
1b80306
Refactor learning agent out of search node
camall3n May 13, 2023
cffb1a8
Refactor search node into utils/discrete_search
camall3n May 13, 2023
06f60d9
Refactor fifo-queue search into learning agent
camall3n May 13, 2023
ac8dd88
Refactor simulated annealing into learning agent
camall3n May 13, 2023
a78d295
Refactor all optimizers into optimize_memory()
camall3n May 13, 2023
f7f9671
Hook discrete mem optimization script up to args.mem_optimizer
camall3n May 13, 2023
8d8d386
Clean up discrete results folder & hook into study_name
camall3n May 13, 2023
46cefb5
Fix simulated annealing bug where improvements led to overflow
camall3n May 13, 2023
61676b5
Fix bug in simulated annealing where obs index could go out of bounds
camall3n May 13, 2023
e315044
Add plot script for discrete experiments
camall3n May 13, 2023
b316256
Fix bug where study_dir wasn't getting set
camall3n May 13, 2023
b62c9a8
Fix bug where study_dir wasn't getting created
camall3n May 13, 2023
0fe6ccb
Make policy iteration sane again
camall3n May 13, 2023
7f27917
Switch to hardmax
camall3n May 13, 2023
9a03988
Add tqdm for search-queue and annealing
camall3n May 13, 2023
92d481f
Consolidate tqdm bars for filling replay buffer
camall3n May 13, 2023
0b6959a
Add option for pruning queue search when parent node is suboptimal
camall3n May 13, 2023
01b27fa
Add hash function to SearchNode and test with PriorityQueue
camall3n May 13, 2023
cdc12be
Integrate priority queue into queue-based mem search
camall3n May 13, 2023
a78afce
Disable priority queue by default
camall3n May 13, 2023
75bb211
[WIP] tune annealing to work for tmaze & ~cheese
camall3n May 14, 2023
d08dc11
[WIP] slightly better settings(?) and add n_repeats
camall3n May 14, 2023
6184457
Stop analytical policy iteration if policy stops changing
camall3n May 14, 2023
9c26bc2
[snapshot] slightly too short to get good cheese performance
camall3n May 14, 2023
121775e
[snapshot] works for tmaze, cheese, 4x3, shuttle; scale by initial la…
camall3n May 14, 2023
3bdd55d
Remove unnecessary deepcopy
camall3n May 14, 2023
cbc4847
Save policy history, online PI, hook up annealing params, fix mem res…
camall3n May 15, 2023
a635b00
Optimize deterministic binary memory to avoid np.random.choice
camall3n May 15, 2023
080781e
Restore reset_before_converging=True
camall3n May 15, 2023
3d57fe5
Move annealing args to argparse
camall3n May 15, 2023
5b12879
Manually normalize reward scale
camall3n May 15, 2023
e2ad82d
Fix bug in best_discrep reporting
camall3n May 15, 2023
e187494
Add seed argument
camall3n May 15, 2023
403ab26
Add annealing args
camall3n May 15, 2023
9edd17d
Add new plotting scripts
camall3n May 15, 2023
7da6720
Update plotting script
camall3n May 15, 2023
5464b17
Remove test code
camall3n May 15, 2023
07d7415
Restore option to run with priority queue
camall3n May 16, 2023
4dbc02e
Fix broken arg name
camall3n Jun 10, 2023
42d373c
Add reward ranges for slippery tmaze
camall3n Jun 10, 2023
7281402
integrate POPGym, with integration tests.
Jun 27, 2023
a8531ad
error type
Jun 27, 2023
86537a3
addressing all PR comments except popgym test assertions
Jun 27, 2023
7937247
fix popgym tests
taodav Jul 7, 2023
8280ebc
yapf
taodav Jul 8, 2023
9f08e0d
temporarily remove [navigation] from popgym, since mazelib doesn't wo…
taodav Jul 8, 2023
097958d
bump yapf version
taodav Jul 8, 2023
1787a7d
env init bug
taodav Jul 8, 2023
83fed95
[WIP] add hyperparam file for POPGym, need to deal with other action …
taodav Jul 10, 2023
baf0333
[WIP] add wrappers for different obs spaces. Missing final wrapper fo…
taodav Jul 12, 2023
6455fef
leave out Pendulum for now, due to continuous action space
taodav Jul 13, 2023
002b65f
popgym tests
taodav Jul 13, 2023
ab0b34c
refactor observation and action wrappers
taodav Jul 17, 2023
98adfb3
reduce number of runs
taodav Jul 17, 2023
e9950d5
add back in flatten multi discrete wrapper
taodav Jul 17, 2023
c92d070
change popgym sweep td seeds down to 3
taodav Jul 17, 2023
ec31152
add back array casting observation wrapper
taodav Jul 17, 2023
b39afb9
add 3 envs to pesky memory leak
taodav Jul 19, 2023
9186d4f
add binning optimization for cache misses for online training
taodav Jul 20, 2023
11d87bd
allow cache misses for len(buffer) < 10
taodav Jul 20, 2023
43e1770
revert popgym_sweep_mc to all envs
taodav Jul 20, 2023
c7390a2
add reduce_eval_size script
taodav Jul 24, 2023
36c6d38
reduce filesize for results
taodav Jul 24, 2023
facdd07
GET RID OF ONLINE REWARD SAVINGgit add -u .
taodav Jul 28, 2023
b4e044b
add script for reduce online logging size
taodav Jul 28, 2023
caca4ec
add memoryless runs
taodav Aug 30, 2023
df14346
add new and improved write_jobs script
taodav Aug 30, 2023
5882f4a
remove --hparam from write_job scripts
taodav Aug 30, 2023
6ae76aa
set entry to grl.run
taodav Aug 30, 2023
52c8b4b
Merge pull request #15 from taodav/integrate_popgym
taodav Sep 1, 2023
e11accc
remove pynbs, back to hydrogen
taodav Sep 1, 2023
ac48d6f
Fix undefined var error
camall3n Sep 13, 2023
26843ff
Add new forms of memory function locality
camall3n Sep 13, 2023
aeffe79
Reconnect annealing params for discrete optim (see 403ab26 and 07d7415)
camall3n Sep 15, 2023
f126741
Ensure mem fn doesn't get worse after optimization
camall3n Sep 15, 2023
3f88cdf
Disable override forcing use of annealing optimizer
camall3n Sep 15, 2023
bd38da1
Disable override forcing n_memory_trials=400
camall3n Sep 15, 2023
0ae7dc2
Add exit when annealing tmin > tmax
camall3n Sep 15, 2023
e04ce26
Fix pandas pivot named argument bug
camall3n Sep 16, 2023
3871714
Update paths, print n_runs and param_counts
camall3n Sep 16, 2023
21d7368
Add barplot averaging over all runs in each env
camall3n Sep 19, 2023
5dc13e3
[WIP] Add script to debug discrete optimization progress
camall3n Sep 19, 2023
65a8211
[Sync] with origin/main
camall3n Sep 19, 2023
c6513fb
Fix class/module names: AMDP->POMDP; agents->agent
camall3n Sep 19, 2023
8a6be00
Fix formatting (and force push to retry stochastic RTOL test)
camall3n Sep 19, 2023
5f83d71
add discrete and random uniform memory
taodav Sep 21, 2023
3c605eb
Add value-error implementation
camall3n Sep 19, 2023
613fbad
Fix formatting
camall3n Sep 19, 2023
eae2717
Standardize and fix comments for zeroing terminal state counts
camall3n Sep 19, 2023
c5c6ff6
[Merge] branch 'value-error' into learning-agent
camall3n Sep 21, 2023
957cfc5
Fix AMDP->POMDP comment
camall3n Sep 21, 2023
ebc3dff
Refactor n_annealing_repeats outer loop
camall3n Sep 21, 2023
71f0bfb
Add option to sample annealing hyperparams for each annealing repetition
camall3n Sep 21, 2023
6fa043c
Squash bar plot and change colors to cool / warm
camall3n Jun 10, 2023
ca99e16
Update colors to be more muted
camall3n Jun 12, 2023
afe9a3c
Update .py file for custom barplot ordering
camall3n Sep 23, 2023
a4afd2c
[Merge] 'improved-bar-plot' into learning-agent
camall3n Sep 23, 2023
dc95ee8
Add other bar plot result dirs + comments
camall3n Sep 23, 2023
df680b6
Add discrete optimization results as hatched bars in mi_performance
camall3n Sep 24, 2023
5cafd3a
Fix formatting
camall3n Sep 24, 2023
09e4ed7
[WIP] Add policy_optimization arg (td/mc/none)
camall3n Sep 23, 2023
93ea329
Add arg --n_random_policies; add script to determine default value
camall3n Sep 24, 2023
888bc05
[Merge] 'random-policies' into learning-agent
camall3n Sep 24, 2023
02c941a
Add learning agent support for non-binary memory
camall3n Sep 25, 2023
64fbaef
Move reward scaling into env loop
camall3n Sep 25, 2023
00ce59a
Log n_memory_states
camall3n Sep 25, 2023
622a5a8
revert back to mi_perf plotting
taodav Sep 25, 2023
66f6a32
fix policy_grad
taodav Sep 25, 2023
356306c
add optimal tiger memory, and change --account to --partition in onag…
taodav Sep 25, 2023
59e0a96
Update plot script to add more hashed lines
camall3n Sep 25, 2023
fc24a06
remove double network
taodav Sep 25, 2023
11a5d2b
Add arg support for policy gradient
camall3n Sep 26, 2023
e66f21d
fix policy_grad
taodav Sep 25, 2023
b5e5df6
Add reset_pi_params?
camall3n Sep 26, 2023
4f1976b
Disable jit
camall3n Sep 26, 2023
e0772a4
Add commented-out jax debug breakpoints
camall3n Sep 26, 2023
6271e45
set default hyperparams for analytical memory iteration to final_anal…
taodav Sep 25, 2023
5ae1abc
Switch to 64-bit jax to fix NaN issue
camall3n Sep 26, 2023
164520f
Scale did_change calculation with policy LR; increase n_pi_iterations
camall3n Sep 26, 2023
c4b74cb
Remove pi_improvement did_change stopping condition entirely
camall3n Sep 26, 2023
c86d069
Change pi_lr to 0.01
camall3n Sep 26, 2023
0dece79
[Merge] 'debug-pg' into learning-agent
camall3n Sep 26, 2023
320022a
Disable reward scaling by default
camall3n Sep 26, 2023
975ae88
Change number of policy iterations from 100k -> 5k
camall3n Sep 27, 2023
58ad570
[Bugfix] Update mem_aug_mdp after memory optimization
camall3n Sep 27, 2023
14749ca
Take larger LD policy of PI/PG-optimal and random high-LD
camall3n Sep 27, 2023
3afdf7d
Update mi_performance plotting script for locality06
camall3n Sep 27, 2023
3b862d5
Update mi_performance script for locality07
camall3n Sep 28, 2023
4c3a11e
Add support for running tiger with known LD policy
camall3n Sep 25, 2023
bdecb12
add optimal tiger memory, and change --account to --partition in onag…
taodav Sep 25, 2023
435d700
Add script to see if tiger optimal mem improves LD
camall3n Sep 28, 2023
4322373
[Merge] 'tiger-known-ld' into learning-agent
camall3n Sep 28, 2023
2fbc6b3
Add tmaze5-fixed.POMDP
camall3n Sep 28, 2023
cb08184
Fix calibration of tmaze
camall3n Sep 28, 2023
1059fc5
Update barplot hack indices & fix formatting
camall3n Sep 28, 2023
7cc968f
Clean up bar plots by policy_optim_alg
camall3n Sep 28, 2023
eb6ed2e
Fix pomdp solver results
taodav Sep 28, 2023
8f81a8f
add random kitchen sink initialization
taodav Sep 27, 2023
93fa86f
Merge pull request #16 from taodav/analytical_kitchen_sink
taodav Sep 28, 2023
96bb4bb
Update plots for kitchen sink experiment
camall3n Sep 28, 2023
40e7bb0
Improve barplot
camall3n Sep 28, 2023
aa3df9d
Fix title / ylabel
camall3n Sep 28, 2023
4ccad3a
Reset hatch.color to black
camall3n Sep 28, 2023
bc00273
allow for mi_steps = 0, and add memoryless kitchen sink hyperparams
taodav Oct 9, 2023
fd1c0ef
mi_performance updated for memoryless random kitchen sinks.
taodav Oct 10, 2023
0dd9bef
add tiger counting memory
taodav Oct 16, 2023
35e1978
trajectory logging
taodav Oct 17, 2023
e180264
working (with many samples) mem_traj_logging, with memory cross produ…
taodav Oct 18, 2023
550e1e6
Add value error logging
camall3n Nov 12, 2023
d0ad107
Improve LD vs. value_err correlation plots
camall3n Nov 13, 2023
f3143d0
Add conversions to/from augmented policies (OM->AM = OMA->M * OM->A)
camall3n Nov 13, 2023
54c80f2
Auto-calculate shapes for augmented policy conversions
camall3n Nov 13, 2023
aa9e6e7
[WIP] Add draft for augmented policy gradient objective function
camall3n Nov 13, 2023
a1e8164
Tighten ylims for correlation plots
camall3n Nov 13, 2023
ff77a58
add more things for .POMDP file parsing (for parsing heaving + hell)
taodav Nov 13, 2023
bd16557
add hallway kitchen sinks
taodav Nov 14, 2023
a3116f7
move hallway kitchen sinks
taodav Nov 14, 2023
6ba225b
bump plotting
taodav Nov 14, 2023
dd449c9
Switch to log-probs to avoid div by zero; add David's test case
camall3n Nov 14, 2023
a0a3b57
Add debug code to visualize mem fn probs
camall3n Nov 14, 2023
6c45de1
halfway through implementing pg for mem augmented pi
taodav Nov 14, 2023
52630f6
add mem_pg loss
taodav Nov 15, 2023
b2ed32f
running, but not working policy optimization
taodav Nov 15, 2023
704d7d8
still a strange bug with mem_pg
taodav Nov 15, 2023
6ab774c
fix policy_mem_grad, tested on tmaze, cheese and hallway
taodav Nov 15, 2023
da169a1
Add comment explaining Hallway domain
camall3n Nov 16, 2023
dfbde37
Remove trailing whitespaces
camall3n Nov 16, 2023
bcfa068
Fix citation
camall3n Nov 16, 2023
ab6ac8e
Rename observations to match indexing
camall3n Nov 16, 2023
8ae78d3
Add comment about unused observations
camall3n Nov 16, 2023
4b7dda3
running unrolled policy_mem_grad
taodav Nov 18, 2023
9f04bc9
working pg_mem_unrolled
taodav Nov 18, 2023
f3857f5
Add size annotations for unrolled_mem_pg
camall3n Nov 20, 2023
462bf03
add final_discrep_kitchen_sinks_pg
taodav Nov 20, 2023
eceef03
working (?) td(0)
taodav Dec 4, 2023
4eec97f
add magnitude pg runs
taodav Dec 4, 2023
dfa8e19
change name to bellman residual, add alpha = 0 run.
taodav Dec 6, 2023
e301b29
add 'residual' argument for bellman err and mstd err
taodav Dec 6, 2023
dfea2df
implemented TD error
taodav Dec 6, 2023
1988ed7
add kitchen sink policies for other objective types, also add tde_kit…
taodav Dec 19, 2023
97c8f83
add script for testing multiple-step bellman vs multiple-step bellman…
taodav Jan 11, 2024
4b0194a
Fix MSTDE (technically mean-squared sarsa error)
camall3n Jan 13, 2024
d3d3103
remove unused arguments for MSTDE
taodav Jan 16, 2024
7cb517d
add functionality for error_type in MSTDE
taodav Jan 16, 2024
73ce282
Merge pull request #17 from camall3n/fix-mstde
taodav Jan 16, 2024
f1f4ce8
remove alpha from kitchen sinks MSTDE runs
taodav Jan 16, 2024
c813741
add mem_lambda_tde_pg
taodav Jan 22, 2024
278532e
fix missing passing of optimizer params!
taodav Jan 23, 2024
9f2a350
fix terminals
taodav Jan 30, 2024
c0f32d1
add stuff
taodav Jan 31, 2024
e90f77a
pass in terminal_mask for tree_unflatten in POMDP
taodav Jan 31, 2024
0679ded
refactored and working multi-dir plotting
taodav Feb 1, 2024
06230bc
Merge pull request #18 from taodav/refactor_plotting
taodav Feb 1, 2024
659f47d
first step of batch_run
taodav Jan 26, 2024
7efabf8
running batch_run! need to still write tests
taodav Jan 28, 2024
761f529
about to add logging info
taodav Jan 29, 2024
1fa19c4
just missing value error
taodav Jan 29, 2024
9986b7b
finished implementing measures
taodav Jan 30, 2024
b8e7866
fix value error count probability dist
taodav Jan 30, 2024
ddfde4f
add logging measures
taodav Jan 31, 2024
be1cbf8
running batch_run with logging
taodav Jan 31, 2024
07c3150
renamed to batch_run_pg
taodav Jan 31, 2024
e66371a
split 8 mem state runs into two seperate runs
taodav Jan 31, 2024
97dee03
missing comma
taodav Jan 31, 2024
447a322
split seeds up even more
taodav Feb 1, 2024
7cd0fcd
okay first... only do 2 and 4 mem bits
taodav Feb 1, 2024
5d20e83
change run_gpu_locally for batch_run_pg
taodav Feb 1, 2024
20256ce
about to reparse results from batch_run
taodav Feb 2, 2024
3da11e0
add reorg script, with WIP parse_experiments for batch_runs
taodav Feb 5, 2024
b9595e7
parsing experiments
taodav Feb 5, 2024
3271e3e
add parser and plotter for batch_run
taodav Feb 5, 2024
e644cda
plotting progress
taodav Feb 6, 2024
5d0358b
need to normalize
taodav Feb 6, 2024
2e8ac49
debugging initial policy in batch_run
taodav Feb 6, 2024
67f084d
fix plotting
taodav Feb 6, 2024
b1ee203
change the initial policy improvement step to pg
taodav Feb 6, 2024
c25a79f
add 3 bit runs for batch_run_pg
taodav Feb 6, 2024
eedcdf3
[WIP] Add script for comparing TDE vs LD vs Value Error
camall3n Jan 13, 2024
039c2d9
Update compare script
camall3n Feb 12, 2024
92213e6
revert back to non-jitted memory improvement, and add mem_tde_01 runs
taodav Feb 12, 2024
a1af23b
Clean up TD-error script, add bar plots
camall3n Feb 13, 2024
c493163
add interleave hyperparams
taodav Feb 21, 2024
582622d
remove unneeded arguments
taodav Feb 21, 2024
4b990d3
add log_every
taodav Feb 21, 2024
b36b272
add --objective
taodav Feb 21, 2024
bbb71f2
new interleave runs
taodav Feb 21, 2024
47aa192
add batch_run fixes
taodav Feb 21, 2024
df7b0ad
Merge pull request #19 from taodav/interleave
taodav Feb 21, 2024
542d613
fix interleave
taodav Feb 21, 2024
d2dbdc7
fix batch_run
taodav Feb 22, 2024
94cb140
add residual arg
taodav Feb 22, 2024
91be4d3
add plotting grads script
taodav Feb 23, 2024
dcdec22
fix policy iteration
taodav Feb 23, 2024
b52422b
Add example 26 (and alternate)
camall3n Mar 8, 2024
d00d808
Add script to check example 26
camall3n Mar 8, 2024
596084c
add initial implementation for four_tmaze
taodav Mar 8, 2024
715baee
add compass world
taodav Mar 11, 2024
181cf10
move goal position to middle of left wall for compass world
taodav Mar 11, 2024
19ceab2
add batch_run_kitchen_sinks
taodav Mar 15, 2024
0ab6922
running batch kitchen sinks
taodav Mar 15, 2024
ddff824
add batch_run_kitchen
taodav Mar 15, 2024
e714078
[Merge] taodav/main into learning-agent
camall3n Apr 11, 2024
38a9533
Remove duplicate value_error implementation
camall3n Apr 11, 2024
e579f3e
[Merge] PR #20 - Add script to compare error signals
camall3n Apr 19, 2024
737a50a
[Sync] with taodav/main
camall3n Apr 19, 2024
ac05b27
Add hill climbing for MSTDE
camall3n Apr 22, 2024
277 changes: 257 additions & 20 deletions grl/agent/actorcritic.py

Large diffs are not rendered by default.

120 changes: 84 additions & 36 deletions grl/agent/analytical.py
@@ -8,12 +8,28 @@
import optax

from grl.mdp import POMDP
from grl.utils.loss import policy_discrep_loss, pg_objective_func
from grl.utils.loss import mem_discrep_loss, mem_magnitude_td_loss, obs_space_mem_discrep_loss
from grl.utils.math import glorot_init
from grl.utils.policy import construct_aug_policy
from grl.utils.loss import policy_discrep_loss, pg_objective_func, \
mem_pg_objective_func, unrolled_mem_pg_objective_func
from grl.utils.loss import mem_discrep_loss, mem_bellman_loss, mem_tde_loss, obs_space_mem_discrep_loss
from grl.utils.math import glorot_init, reverse_softmax
from grl.utils.optimizer import get_optimizer
from grl.vi import policy_iteration_step

def new_pi_over_mem(pi_params: jnp.ndarray, add_n_mem_states: int,
new_mem_pi: str = 'repeat'):
old_pi_params_shape = pi_params.shape

pi_params = pi_params.repeat(add_n_mem_states, axis=0)

if new_mem_pi == 'random':
# randomly init policy for new memory state
new_mem_params = glorot_init(old_pi_params_shape)
pi_params = pi_params.at[1::2].set(new_mem_params)

return pi_params


class AnalyticalAgent:
"""
Analytical agent that learns optimal policy params based on an
@@ -29,6 +45,7 @@ def __init__(self,
value_type: str = 'v',
error_type: str = 'l2',
objective: str = 'discrep',
residual: bool = False,
lambda_0: float = 0.,
lambda_1: float = 1.,
alpha: float = 1.,
@@ -43,7 +60,7 @@ def __init__(self,
:param mem_params: Memory parameters (optional)
:param value_type: If we optimize lambda discrepancy, what type of lambda discrepancy do we optimize? (v | q)
:param error_type: lambda discrepancy error type (l2 | abs)
:param objective: What objective are we trying to minimize? (discrep | magnitude)
:param objective: What objective are we trying to minimize? (discrep | bellman | tde)
:param pi_softmax_temp: When we take the softmax over pi_params, what is the softmax temperature?
:param policy_optim_alg: What type of policy optimization do we do? (pi | pg)
(discrep_max: discrepancy maximization | discrep_min: discrepancy minimization
@@ -58,13 +75,18 @@ def __init__(self,
self.og_n_obs = self.pi_params.shape[0]

self.pg_objective_func = jit(pg_objective_func)
if self.policy_optim_alg == 'policy_mem_grad':
self.pg_objective_func = jit(mem_pg_objective_func)
elif self.policy_optim_alg == 'policy_mem_grad_unrolled':
self.pg_objective_func = jit(unrolled_mem_pg_objective_func)

self.policy_iteration_update = jit(policy_iteration_step, static_argnames=['eps'])
self.epsilon = epsilon

self.val_type = value_type
self.error_type = error_type
self.objective = objective
self.residual = residual
self.lambda_0 = lambda_0
self.lambda_1 = lambda_1
self.alpha = alpha
@@ -77,19 +99,29 @@ def __init__(self,

self.new_mem_pi = new_mem_pi

self.optim_str = optim_str
# initialize optimizers
self.pi_lr = pi_lr
self.pi_optim = get_optimizer(optim_str, self.pi_lr)
self.pi_optim_state = self.pi_optim.init(self.pi_params)

self.mem_params = None
if mem_params is not None:
self.mem_params = mem_params

if self.policy_optim_alg in ['policy_mem_grad', 'policy_mem_grad_unrolled']:
mem_probs, pi_probs = softmax(self.mem_params, -1), softmax(self.pi_params, -1)
aug_policy = construct_aug_policy(mem_probs, pi_probs)
self.pi_aug_params = reverse_softmax(aug_policy)

self.mi_lr = mi_lr
self.mem_optim = get_optimizer(optim_str, self.mi_lr)
self.mem_optim_state = self.mem_optim.init(self.mem_params)

# initialize optimizers
self.optim_str = optim_str
self.pi_lr = pi_lr
self.pi_optim = get_optimizer(optim_str, self.pi_lr)

pi_params_to_optimize = self.pi_params
if self.policy_optim_alg in ['policy_mem_grad', 'policy_mem_grad_unrolled']:
pi_params_to_optimize = self.pi_aug_params
self.pi_optim_state = self.pi_optim.init(pi_params_to_optimize)

self.pi_softmax_temp = pi_softmax_temp

self.rand_key = rand_key
@@ -113,19 +145,25 @@ def init_and_jit_objectives(self):
self.policy_discrep_objective_func = jit(partial_policy_discrep_loss)

mem_loss_fn = mem_discrep_loss
partial_kwargs = {
'value_type': self.val_type,
'error_type': self.error_type,
'lambda_0': self.lambda_0,
'lambda_1': self.lambda_1,
'alpha': self.alpha,
'flip_count_prob': self.flip_count_prob
}
if hasattr(self, 'objective'):
if self.objective == 'magnitude':
mem_loss_fn = mem_magnitude_td_loss
if self.objective == 'bellman':
mem_loss_fn = mem_bellman_loss
partial_kwargs['residual'] = self.residual
elif self.objective == 'tde':
mem_loss_fn = mem_tde_loss
partial_kwargs['residual'] = self.residual
elif self.objective == 'obs_space':
mem_loss_fn = obs_space_mem_discrep_loss

partial_mem_discrep_loss = partial(mem_loss_fn,
value_type=self.val_type,
error_type=self.error_type,
lambda_0=self.lambda_0,
lambda_1=self.lambda_1,
alpha=self.alpha,
flip_count_prob=self.flip_count_prob)
partial_mem_discrep_loss = partial(mem_loss_fn, **partial_kwargs)
self.memory_objective_func = jit(partial_mem_discrep_loss)

@property
@@ -143,21 +181,17 @@ def reset_pi_params(self, pi_shape: Sequence[int] = None):
if pi_shape is None:
pi_shape = self.pi_params.shape
self.pi_params = glorot_init(pi_shape)
self.pi_optim_state = self.pi_optim.init(self.pi_params)

def new_pi_over_mem(self):
if self.pi_params.shape[0] != self.og_n_obs:
raise NotImplementedError(
"Have not implemented adding bits to already existing memory.")

add_n_mem_states = self.mem_params.shape[-1]
old_pi_params_shape = self.pi_params.shape

self.pi_params = self.pi_params.repeat(add_n_mem_states, axis=0)

if self.new_mem_pi == 'random':
# randomly init policy for new memory state
new_mem_params = glorot_init(old_pi_params_shape)
self.pi_params = self.pi_params.at[1::2].set(new_mem_params)
self.pi_params = new_pi_over_mem(self.pi_params,
add_n_mem_states=add_n_mem_states,
new_mem_pi=self.new_mem_pi)

@partial(jit, static_argnames=['self'])
def policy_gradient_update(self, params: jnp.ndarray, optim_state: jnp.ndarray, pomdp: POMDP):
@@ -169,7 +203,7 @@ def policy_gradient_update(self, params: jnp.ndarray, optim_state: jnp.ndarray,
params_grad = -params_grad
updates, optimizer_state = self.pi_optim.update(params_grad, optim_state, params)
params = optax.apply_updates(params, updates)
return v_0, td_v_vals, td_q_vals, params
return v_0, td_v_vals, td_q_vals, params, optimizer_state

@partial(jit, static_argnames=['self', 'sign'])
def policy_discrep_update(self,
@@ -187,12 +221,15 @@ def policy_discrep_update(self,
updates, optimizer_state = self.pi_optim.update(params_grad, optim_state, params)
params = optax.apply_updates(params, updates)

return loss, mc_vals, td_vals, params
return loss, mc_vals, td_vals, params, optimizer_state

def policy_improvement(self, pomdp: POMDP):
if self.policy_optim_alg == 'policy_grad':
v_0, prev_td_v_vals, prev_td_q_vals, new_pi_params = \
self.policy_gradient_update(self.pi_params, self.pi_optim_state, pomdp)
if self.policy_optim_alg in ['policy_grad', 'policy_mem_grad', 'policy_mem_grad_unrolled']:
policy_params = self.pi_params
if self.policy_optim_alg in ['policy_mem_grad', 'policy_mem_grad_unrolled']:
policy_params = self.pi_aug_params
v_0, prev_td_v_vals, prev_td_q_vals, new_pi_params, new_optim_state= \
self.policy_gradient_update(policy_params, self.pi_optim_state, pomdp)
output = {
'v_0': v_0,
'prev_td_q_vals': prev_td_q_vals,
Expand All @@ -201,17 +238,23 @@ def policy_improvement(self, pomdp: POMDP):
elif self.policy_optim_alg == 'policy_iter':
new_pi_params, prev_td_v_vals, prev_td_q_vals = self.policy_iteration_update(
self.pi_params, pomdp, eps=self.epsilon)
new_optim_state = self.pi_optim_state
output = {'prev_td_q_vals': prev_td_q_vals, 'prev_td_v_vals': prev_td_v_vals}
elif self.policy_optim_alg == 'discrep_max' or self.policy_optim_alg == 'discrep_min':
loss, mc_vals, td_vals, new_pi_params = self.policy_discrep_update(
loss, mc_vals, td_vals, new_pi_params, new_optim_state = self.policy_discrep_update(
self.pi_params,
self.pi_optim_state,
pomdp,
sign=(self.policy_optim_alg == 'discrep_max'))
output = {'loss': loss, 'mc_vals': mc_vals, 'td_vals': td_vals}
else:
raise NotImplementedError
self.pi_params = new_pi_params

if self.policy_optim_alg in ['policy_mem_grad', 'policy_mem_grad_unrolled']:
self.pi_aug_params = new_pi_params
else:
self.pi_params = new_pi_params
self.pi_optim_state = new_optim_state
return output

@partial(jit, static_argnames=['self'])
@@ -224,13 +267,14 @@ def memory_update(self, params: jnp.ndarray, optim_state: jnp.ndarray, pi_params
updates, optimizer_state = self.mem_optim.update(params_grad, optim_state, params)
params = optax.apply_updates(params, updates)

return loss, params
return loss, params, optimizer_state

def memory_improvement(self, pomdp: POMDP):
assert self.mem_params is not None, 'I have no memory params'
loss, new_mem_params = self.memory_update(self.mem_params, self.mem_optim_state,
loss, new_mem_params, new_mem_optim_state = self.memory_update(self.mem_params, self.mem_optim_state,
self.pi_params, pomdp)
self.mem_params = new_mem_params
self.mem_optim_state = new_mem_optim_state
return loss

def __getstate__(self) -> dict:
@@ -254,6 +298,10 @@ def __setstate__(self, state: dict):

# restore jitted functions
self.pg_objective_func = jit(pg_objective_func)
if self.policy_optim_alg == 'policy_mem_grad':
self.pg_objective_func = jit(mem_pg_objective_func)
elif self.policy_optim_alg == 'policy_mem_grad_unrolled':
self.pg_objective_func = jit(unrolled_mem_pg_objective_func)
self.policy_iteration_update = jit(policy_iteration_step, static_argnames=['eps'])

if 'optim_str' not in state:
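A note on the recurring pattern in the analytical.py diff above: `policy_gradient_update`, `policy_discrep_update`, and `memory_update` now return the optax optimizer state alongside the loss and params, and `policy_improvement` / `memory_improvement` store it. Below is a minimal, self-contained sketch of that pattern; the quadratic loss, shapes, and learning rate are placeholders for illustration, not this repo's actual objectives.

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
    # Placeholder objective; stands in for e.g. the memory discrepancy loss.
    return jnp.mean((batch @ params) ** 2)

optimizer = optax.adam(1e-2)

@jax.jit
def update(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    # Returning new_opt_state is the point: if the caller keeps reusing the
    # initial state, Adam's moment estimates are silently reset every step.
    return loss, new_params, new_opt_state

params = jnp.zeros(4)
opt_state = optimizer.init(params)
batch = jnp.ones((8, 4))
for _ in range(3):
    loss, params, opt_state = update(params, opt_state, batch)
```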
19 changes: 17 additions & 2 deletions grl/agent/td_lambda.py
@@ -39,7 +39,17 @@ def _reset_q_values(self):
def _reset_eligibility(self):
self.eligibility = np.zeros((self.n_actions, self.n_obs))

def update(self, obs, action, reward, terminal, next_obs, next_action):
def update(
self,
obs,
action,
reward,
terminal,
next_obs,
next_action,
aug_obs=None, # memory-augmented observation
next_aug_obs=None, # and next observation (O x M)
):
# Because mdp.step() terminates with probability (1-γ),
# we have already factored in the γ that we would normally
# use to decay the eligibility.
@@ -50,6 +60,11 @@ def update(self, obs, action, reward, terminal, next_obs, next_action):
# probability γ.
#
# Thus we simply decay eligibility by λ.
if aug_obs is not None:
obs = aug_obs
if next_aug_obs is not None:
next_obs = next_aug_obs

self.eligibility *= self.lambda_
if self.trace_type == 'accumulating':
self.eligibility[action, obs] += 1
@@ -84,7 +99,7 @@ def run_td_lambda_on_mdp(
alpha=1,
n_episodes=1000,
):
# If AMDP, convert to pi_ground
# If POMDP, convert to pi_ground
if hasattr(mdp, 'phi'):
pi_ground = mdp.get_ground_policy(pi)
else:
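The td_lambda.py diff adds optional `aug_obs` / `next_aug_obs` arguments to `update()`, which replace `obs` / `next_obs` so the Q-table and eligibility trace are indexed by memory-augmented observations. A hypothetical caller-side sketch follows; the flat `(obs, mem)` index encoding is an assumption for illustration, not necessarily the convention used elsewhere in this PR.

```python
# Assumed sizes, for illustration only.
n_obs, n_mem_states = 5, 2

def augment(obs: int, mem: int) -> int:
    """Flatten an (observation, memory-state) pair into one index over O x M."""
    assert 0 <= obs < n_obs and 0 <= mem < n_mem_states
    return obs * n_mem_states + mem

# A caller tracking the current/next memory state could then pass, e.g.:
#   agent.update(obs, action, reward, terminal, next_obs, next_action,
#                aug_obs=augment(obs, mem),
#                next_aug_obs=augment(next_obs, next_mem))
print(augment(3, 1))  # -> 7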
71 changes: 60 additions & 11 deletions grl/environment/__init__.py
@@ -1,28 +1,77 @@
from argparse import Namespace

import jax
import gymnasium as gym
import numpy as np
from numpy import random
import popgym
from popgym.wrappers import Flatten

from .rocksample import RockSample
from .spec import load_spec, load_pomdp
from .wrappers import OneHotObservationWrapper, OneHotActionConcatWrapper
from .wrappers import OneHotObservationWrapper, OneHotActionConcatWrapper, \
FlattenMultiDiscreteActionWrapper, DiscreteObservationWrapper, \
ContinuousToDiscrete, ArrayObservationWrapper

def get_popgym_env(args: Namespace, rand_key: random.RandomState = None, **kwargs):
# check to see if name exists
env_names = set([e["id"] for e in popgym.envs.ALL.values()])
if args.spec not in env_names:
raise AttributeError(f"spec {args.spec} not found")
# wrappers fail unless disable_env_checker=True
env = gym.make(args.spec, disable_env_checker=True)
env.reset(seed=args.seed)
env.rand_key = rand_key
env.gamma = args.gamma

return env

def get_env(args: Namespace,
rand_state: np.random.RandomState = None,
rand_key: jax.random.PRNGKey = None,
action_bins: int = 6,
**kwargs):
"""
:param action_bins: If we have a continous action space, how many bins do we discretize to?
"""
# First we check our POMDP specs
try:
env, _ = load_pomdp(args.spec, rand_key=rand_state, **kwargs)

# TODO: some features are already encoded in a one-hot manner.
if args.feature_encoding == 'one_hot':
env = OneHotObservationWrapper(env)
except AttributeError:
if args.spec == 'rocksample':
env = RockSample(rand_key=rand_key, **kwargs)
else:
raise NotImplementedError
# try to load from popgym
# validate input: we need a custom gamma for popgym args as they don't come with a gamma
if args.gamma is None:
raise AttributeError("Can't load non-native environments without passing in gamma!")
try:
env, _ = load_pomdp(args.spec, rand_key=rand_state, **kwargs)

except AttributeError:
# try to load from popgym
# validate input: we need a custom gamma for popgym args as they don't come with a gamma
if args.gamma is None:
raise AttributeError(
"Can't load non-native environments without passing in gamma!")
try:
env = get_popgym_env(args, rand_key=rand_state, **kwargs)

env = Flatten(env)
# also might need to preprocess our observation spaces
if isinstance(env.observation_space, gym.spaces.Discrete)\
and args.feature_encoding != 'one_hot':
env = DiscreteObservationWrapper(env)
if isinstance(env.observation_space, gym.spaces.Tuple):
env = ArrayObservationWrapper(env)

# preprocess continous action spaces
if isinstance(env.action_space, gym.spaces.Box):
env = ContinuousToDiscrete(env, action_bins)
elif isinstance(env.action_space, gym.spaces.MultiDiscrete):
env = FlattenMultiDiscreteActionWrapper(env)

except AttributeError:
# don't have anything else implemented
raise NotImplementedError

if args.feature_encoding == 'one_hot':
env = OneHotObservationWrapper(env)

if args.action_cond == 'cat':
env = OneHotActionConcatWrapper(env)
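Rough usage sketch for the updated `get_env()`: POPGym specs are only reached when the native POMDP loader fails, and they require an explicit `gamma` since POPGym environments do not define one. The spec id and argument values below are illustrative and untested against this branch.

```python
from argparse import Namespace

from grl.environment import get_env

args = Namespace(
    spec='popgym-RepeatPreviousEasy-v0',  # assumed id; check popgym.envs.ALL for valid names
    gamma=0.99,                           # required for non-native environments
    seed=2024,
    feature_encoding='one_hot',           # wraps with OneHotObservationWrapper
    action_cond=None,                     # 'cat' would add OneHotActionConcatWrapper
)

env = get_env(args)
obs, info = env.reset()
```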