Refactor plotting #33

Closed
wants to merge 74 commits into from
Commits (74)
7281402
integrate POPGym, with integration tests.
Jun 27, 2023
a8531ad
error type
Jun 27, 2023
86537a3
addressing all PR comments except popgym test assertions
Jun 27, 2023
7937247
fix popgym tests
taodav Jul 7, 2023
8280ebc
yapf
taodav Jul 8, 2023
9f08e0d
temporarily remove [navigation] from popgym, since mazelib doesn't wo…
taodav Jul 8, 2023
097958d
bump yapf version
taodav Jul 8, 2023
1787a7d
env init bug
taodav Jul 8, 2023
83fed95
[WIP] add hyperparam file for POPGym, need to deal with other action …
taodav Jul 10, 2023
baf0333
[WIP] add wrappers for different obs spaces. Missing final wrapper fo…
taodav Jul 12, 2023
6455fef
leave out Pendulum for now, due to continuous action space
taodav Jul 13, 2023
002b65f
popgym tests
taodav Jul 13, 2023
ab0b34c
refactor observation and action wrappers
taodav Jul 17, 2023
98adfb3
reduce number of runs
taodav Jul 17, 2023
e9950d5
add back in flatten multi discrete wrapper
taodav Jul 17, 2023
c92d070
change popgym sweep td seeds down to 3
taodav Jul 17, 2023
ec31152
add back array casting observation wrapper
taodav Jul 17, 2023
b39afb9
add 3 envs to pesky memory leak
taodav Jul 19, 2023
9186d4f
add binning optimization for cache misses for online training
taodav Jul 20, 2023
11d87bd
allow cache misses for len(buffer) < 10
taodav Jul 20, 2023
43e1770
revert popgym_sweep_mc to all envs
taodav Jul 20, 2023
c7390a2
add reduce_eval_size script
taodav Jul 24, 2023
36c6d38
reduce filesize for results
taodav Jul 24, 2023
facdd07
GET RID OF ONLINE REWARD SAVINGgit add -u .
taodav Jul 28, 2023
b4e044b
add script for reduce online logging size
taodav Jul 28, 2023
caca4ec
add memoryless runs
taodav Aug 30, 2023
df14346
add new and improved write_jobs script
taodav Aug 30, 2023
5882f4a
remove --hparam from write_job scripts
taodav Aug 30, 2023
6ae76aa
set entry to grl.run
taodav Aug 30, 2023
52c8b4b
Merge pull request #15 from taodav/integrate_popgym
taodav Sep 1, 2023
e11accc
remove pynbs, back to hydrogen
taodav Sep 1, 2023
5f83d71
add discrete and random uniform memory
taodav Sep 21, 2023
622a5a8
revert back to mi_perf plotting
taodav Sep 25, 2023
66f6a32
fix policy_grad
taodav Sep 25, 2023
356306c
add optimal tiger memory, and change --account to --partition in onag…
taodav Sep 25, 2023
fc24a06
remove double network
taodav Sep 25, 2023
eb6ed2e
Fix pomdp solver results
taodav Sep 28, 2023
8f81a8f
add random kitchen sink initialization
taodav Sep 27, 2023
93fa86f
Merge pull request #16 from taodav/analytical_kitchen_sink
taodav Sep 28, 2023
bc00273
allow for mi_steps = 0, and add memoryless kitchen sink hyperparams
taodav Oct 9, 2023
fd1c0ef
mi_performance updated for memoryless random kitchen sinks.
taodav Oct 10, 2023
0dd9bef
add tiger counting memory
taodav Oct 16, 2023
35e1978
trajectory logging
taodav Oct 17, 2023
e180264
working (with many samples) mem_traj_logging, with memory cross produ…
taodav Oct 18, 2023
ff77a58
add more things for .POMDP file parsing (for parsing heaving + hell)
taodav Nov 13, 2023
bd16557
add hallway kitchen sinks
taodav Nov 14, 2023
a3116f7
move hallway kitchen sinks
taodav Nov 14, 2023
6ba225b
bump plotting
taodav Nov 14, 2023
6c45de1
halfway through implementing pg for mem augmented pi
taodav Nov 14, 2023
52630f6
add mem_pg loss
taodav Nov 15, 2023
b2ed32f
running, but not working policy optimization
taodav Nov 15, 2023
704d7d8
still a strange bug with mem_pg
taodav Nov 15, 2023
6ab774c
fix policy_mem_grad, tested on tmaze, cheese and hallway
taodav Nov 15, 2023
4b7dda3
running unrolled policy_mem_grad
taodav Nov 18, 2023
9f04bc9
working pg_mem_unrolled
taodav Nov 18, 2023
f3857f5
Add size annotations for unrolled_mem_pg
camall3n Nov 20, 2023
462bf03
add final_discrep_kitchen_sinks_pg
taodav Nov 20, 2023
eceef03
working (?) td(0)
taodav Dec 4, 2023
4eec97f
add magnitude pg runs
taodav Dec 4, 2023
dfa8e19
change name to bellman residual, add alpha = 0 run.
taodav Dec 6, 2023
e301b29
add 'residual' argument for bellman err and mstd err
taodav Dec 6, 2023
dfea2df
implemented TD error
taodav Dec 6, 2023
1988ed7
add kitchen sink policies for other objective types, also add tde_kit…
taodav Dec 19, 2023
97c8f83
add script for testing multiple-step bellman vs multiple-step bellman…
taodav Jan 11, 2024
4b0194a
Fix MSTDE (technically mean-squared sarsa error)
camall3n Jan 13, 2024
d3d3103
remove unused arguments for MSTDE
taodav Jan 16, 2024
7cb517d
add functionality for error_type in MSTDE
taodav Jan 16, 2024
73ce282
Merge pull request #17 from camall3n/fix-mstde
taodav Jan 16, 2024
f1f4ce8
remove alpha from kitchen sinks MSTDE runs
taodav Jan 16, 2024
c813741
add mem_lambda_tde_pg
taodav Jan 22, 2024
278532e
fix missing passing of optimizer params!
taodav Jan 23, 2024
9f2a350
fix terminals
taodav Jan 30, 2024
c0f32d1
add stuff
taodav Jan 31, 2024
0679ded
refactored and working multi-dir plotting
taodav Feb 1, 2024
95 changes: 67 additions & 28 deletions grl/agent/analytical.py
@@ -8,9 +8,11 @@
import optax

from grl.mdp import POMDP
from grl.utils.loss import policy_discrep_loss, pg_objective_func
from grl.utils.loss import mem_discrep_loss, mem_magnitude_td_loss, obs_space_mem_discrep_loss
from grl.utils.math import glorot_init
from grl.utils.augment_policy import construct_aug_policy
from grl.utils.loss import policy_discrep_loss, pg_objective_func, \
mem_pg_objective_func, unrolled_mem_pg_objective_func
from grl.utils.loss import mem_discrep_loss, mem_bellman_loss, mem_tde_loss, obs_space_mem_discrep_loss
from grl.utils.math import glorot_init, reverse_softmax
from grl.utils.optimizer import get_optimizer
from grl.vi import policy_iteration_step

@@ -29,6 +31,7 @@ def __init__(self,
value_type: str = 'v',
error_type: str = 'l2',
objective: str = 'discrep',
residual: bool = False,
lambda_0: float = 0.,
lambda_1: float = 1.,
alpha: float = 1.,
@@ -43,7 +46,7 @@
:param mem_params: Memory parameters (optional)
:param value_type: If we optimize lambda discrepancy, what type of lambda discrepancy do we optimize? (v | q)
:param error_type: lambda discrepancy error type (l2 | abs)
:param objective: What objective are we trying to minimize? (discrep | magnitude)
:param objective: What objective are we trying to minimize? (discrep | bellman | tde)
:param pi_softmax_temp: When we take the softmax over pi_params, what is the softmax temperature?
:param policy_optim_alg: What type of policy optimization do we do? (pi | pg)
(discrep_max: discrepancy maximization | discrep_min: discrepancy minimization
@@ -58,13 +61,18 @@
self.og_n_obs = self.pi_params.shape[0]

self.pg_objective_func = jit(pg_objective_func)
if self.policy_optim_alg == 'policy_mem_grad':
self.pg_objective_func = jit(mem_pg_objective_func)
elif self.policy_optim_alg == 'policy_mem_grad_unrolled':
self.pg_objective_func = jit(unrolled_mem_pg_objective_func)

self.policy_iteration_update = jit(policy_iteration_step, static_argnames=['eps'])
self.epsilon = epsilon

self.val_type = value_type
self.error_type = error_type
self.objective = objective
self.residual = residual
self.lambda_0 = lambda_0
self.lambda_1 = lambda_1
self.alpha = alpha
@@ -77,19 +85,29 @@

self.new_mem_pi = new_mem_pi

self.optim_str = optim_str
# initialize optimizers
self.pi_lr = pi_lr
self.pi_optim = get_optimizer(optim_str, self.pi_lr)
self.pi_optim_state = self.pi_optim.init(self.pi_params)

self.mem_params = None
if mem_params is not None:
self.mem_params = mem_params

if self.policy_optim_alg in ['policy_mem_grad', 'policy_mem_grad_unrolled']:
mem_probs, pi_probs = softmax(self.mem_params, -1), softmax(self.pi_params, -1)
aug_policy = construct_aug_policy(mem_probs, pi_probs)
self.pi_aug_params = reverse_softmax(aug_policy)

self.mi_lr = mi_lr
self.mem_optim = get_optimizer(optim_str, self.mi_lr)
self.mem_optim_state = self.mem_optim.init(self.mem_params)

# initialize optimizers
self.optim_str = optim_str
self.pi_lr = pi_lr
self.pi_optim = get_optimizer(optim_str, self.pi_lr)

pi_params_to_optimize = self.pi_params
if self.policy_optim_alg in ['policy_mem_grad', 'policy_mem_grad_unrolled']:
pi_params_to_optimize = self.pi_aug_params
self.pi_optim_state = self.pi_optim.init(pi_params_to_optimize)

self.pi_softmax_temp = pi_softmax_temp

self.rand_key = rand_key
@@ -113,19 +131,25 @@ def init_and_jit_objectives(self):
self.policy_discrep_objective_func = jit(partial_policy_discrep_loss)

mem_loss_fn = mem_discrep_loss
partial_kwargs = {
'value_type': self.val_type,
'error_type': self.error_type,
'lambda_0': self.lambda_0,
'lambda_1': self.lambda_1,
'alpha': self.alpha,
'flip_count_prob': self.flip_count_prob
}
if hasattr(self, 'objective'):
if self.objective == 'magnitude':
mem_loss_fn = mem_magnitude_td_loss
if self.objective == 'bellman':
mem_loss_fn = mem_bellman_loss
partial_kwargs['residual'] = self.residual
elif self.objective == 'tde':
mem_loss_fn = mem_tde_loss
partial_kwargs['residual'] = self.residual
elif self.objective == 'obs_space':
mem_loss_fn = obs_space_mem_discrep_loss

partial_mem_discrep_loss = partial(mem_loss_fn,
value_type=self.val_type,
error_type=self.error_type,
lambda_0=self.lambda_0,
lambda_1=self.lambda_1,
alpha=self.alpha,
flip_count_prob=self.flip_count_prob)
partial_mem_discrep_loss = partial(mem_loss_fn, **partial_kwargs)
self.memory_objective_func = jit(partial_mem_discrep_loss)

@property
@@ -143,6 +167,7 @@ def reset_pi_params(self, pi_shape: Sequence[int] = None):
if pi_shape is None:
pi_shape = self.pi_params.shape
self.pi_params = glorot_init(pi_shape)
self.pi_optim_state = self.pi_optim.init(self.pi_params)

def new_pi_over_mem(self):
if self.pi_params.shape[0] != self.og_n_obs:
@@ -169,7 +194,7 @@ def policy_gradient_update(self, params: jnp.ndarray, optim_state: jnp.ndarray,
params_grad = -params_grad
updates, optimizer_state = self.pi_optim.update(params_grad, optim_state, params)
params = optax.apply_updates(params, updates)
return v_0, td_v_vals, td_q_vals, params
return v_0, td_v_vals, td_q_vals, params, optimizer_state

@partial(jit, static_argnames=['self', 'sign'])
def policy_discrep_update(self,
@@ -187,12 +212,15 @@ def policy_discrep_update(self,
updates, optimizer_state = self.pi_optim.update(params_grad, optim_state, params)
params = optax.apply_updates(params, updates)

return loss, mc_vals, td_vals, params
return loss, mc_vals, td_vals, params, optimizer_state

def policy_improvement(self, pomdp: POMDP):
if self.policy_optim_alg == 'policy_grad':
v_0, prev_td_v_vals, prev_td_q_vals, new_pi_params = \
self.policy_gradient_update(self.pi_params, self.pi_optim_state, pomdp)
if self.policy_optim_alg in ['policy_grad', 'policy_mem_grad', 'policy_mem_grad_unrolled']:
policy_params = self.pi_params
if self.policy_optim_alg in ['policy_mem_grad', 'policy_mem_grad_unrolled']:
policy_params = self.pi_aug_params
v_0, prev_td_v_vals, prev_td_q_vals, new_pi_params, new_optim_state = \
self.policy_gradient_update(policy_params, self.pi_optim_state, pomdp)
output = {
'v_0': v_0,
'prev_td_q_vals': prev_td_q_vals,
@@ -201,17 +229,23 @@
elif self.policy_optim_alg == 'policy_iter':
new_pi_params, prev_td_v_vals, prev_td_q_vals = self.policy_iteration_update(
self.pi_params, pomdp, eps=self.epsilon)
new_optim_state = self.pi_optim_state
output = {'prev_td_q_vals': prev_td_q_vals, 'prev_td_v_vals': prev_td_v_vals}
elif self.policy_optim_alg == 'discrep_max' or self.policy_optim_alg == 'discrep_min':
loss, mc_vals, td_vals, new_pi_params = self.policy_discrep_update(
loss, mc_vals, td_vals, new_pi_params, new_optim_state = self.policy_discrep_update(
self.pi_params,
self.pi_optim_state,
pomdp,
sign=(self.policy_optim_alg == 'discrep_max'))
output = {'loss': loss, 'mc_vals': mc_vals, 'td_vals': td_vals}
else:
raise NotImplementedError
self.pi_params = new_pi_params

if self.policy_optim_alg in ['policy_mem_grad', 'policy_mem_grad_unrolled']:
self.pi_aug_params = new_pi_params
else:
self.pi_params = new_pi_params
self.pi_optim_state = new_optim_state
return output

@partial(jit, static_argnames=['self'])
Expand All @@ -224,13 +258,14 @@ def memory_update(self, params: jnp.ndarray, optim_state: jnp.ndarray, pi_params
updates, optimizer_state = self.mem_optim.update(params_grad, optim_state, params)
params = optax.apply_updates(params, updates)

return loss, params
return loss, params, optimizer_state

def memory_improvement(self, pomdp: POMDP):
assert self.mem_params is not None, 'I have no memory params'
loss, new_mem_params = self.memory_update(self.mem_params, self.mem_optim_state,
loss, new_mem_params, new_mem_optim_state = self.memory_update(self.mem_params, self.mem_optim_state,
self.pi_params, pomdp)
self.mem_params = new_mem_params
self.mem_optim_state = new_mem_optim_state
return loss

def __getstate__(self) -> dict:
@@ -254,6 +289,10 @@ def __setstate__(self, state: dict):

# restore jitted functions
self.pg_objective_func = jit(pg_objective_func)
if self.policy_optim_alg == 'policy_mem_grad':
self.pg_objective_func = jit(mem_pg_objective_func)
elif self.policy_optim_alg == 'policy_mem_grad_unrolled':
self.pg_objective_func = jit(unrolled_mem_pg_objective_func)
self.policy_iteration_update = jit(policy_iteration_step, static_argnames=['eps'])

if 'optim_str' not in state:
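A note on the `reverse_softmax` / `construct_aug_policy` machinery introduced above for the `policy_mem_grad` and `policy_mem_grad_unrolled` branches: the memory-augmented policy is built from probability tables and then mapped back to logits so it can be trained with the same optimizer as `pi_params`. The sketch below only illustrates that mapping; it assumes nothing about `construct_aug_policy` beyond what the diff shows, and the actual `reverse_softmax` in `grl.utils.math` may differ.

```python
import jax.numpy as jnp
from jax.nn import softmax

def reverse_softmax_sketch(probs: jnp.ndarray, eps: float = 1e-12) -> jnp.ndarray:
    # Log-probabilities are one valid preimage of softmax: softmax(log(p)) == p,
    # so a probability table can be converted back into trainable logits.
    return jnp.log(probs + eps)

# Round trip: the derived logits reproduce the original (obs x action) policy.
pi_probs = jnp.array([[0.7, 0.2, 0.1],
                      [0.1, 0.1, 0.8]])
pi_logits = reverse_softmax_sketch(pi_probs)
assert jnp.allclose(softmax(pi_logits, axis=-1), pi_probs, atol=1e-6)
```

This is also why the diff initializes `self.pi_optim_state` from `pi_aug_params` rather than `pi_params` when a memory-augmented policy-gradient algorithm is selected, and why `policy_improvement` writes the updated parameters back to `self.pi_aug_params` in that case: the optimizer state has to match the shape of whichever parameter table is actually being trained.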
71 changes: 60 additions & 11 deletions grl/environment/__init__.py
@@ -1,28 +1,77 @@
from argparse import Namespace

import jax
import gymnasium as gym
import numpy as np
from numpy import random
import popgym
from popgym.wrappers import Flatten

from .rocksample import RockSample
from .spec import load_spec, load_pomdp
from .wrappers import OneHotObservationWrapper, OneHotActionConcatWrapper
from .wrappers import OneHotObservationWrapper, OneHotActionConcatWrapper, \
FlattenMultiDiscreteActionWrapper, DiscreteObservationWrapper, \
ContinuousToDiscrete, ArrayObservationWrapper

def get_popgym_env(args: Namespace, rand_key: random.RandomState = None, **kwargs):
# check to see if name exists
env_names = set([e["id"] for e in popgym.envs.ALL.values()])
if args.spec not in env_names:
raise AttributeError(f"spec {args.spec} not found")
# wrappers fail unless disable_env_checker=True
env = gym.make(args.spec, disable_env_checker=True)
env.reset(seed=args.seed)
env.rand_key = rand_key
env.gamma = args.gamma

return env

def get_env(args: Namespace,
rand_state: np.random.RandomState = None,
rand_key: jax.random.PRNGKey = None,
action_bins: int = 6,
**kwargs):
"""
:param action_bins: If we have a continuous action space, how many bins do we discretize it into?
"""
# First we check our POMDP specs
try:
env, _ = load_pomdp(args.spec, rand_key=rand_state, **kwargs)

# TODO: some features are already encoded in a one-hot manner.
if args.feature_encoding == 'one_hot':
env = OneHotObservationWrapper(env)
except AttributeError:
if args.spec == 'rocksample':
env = RockSample(rand_key=rand_key, **kwargs)
else:
raise NotImplementedError
# try to load from popgym
# validate input: we need a custom gamma for popgym args as they don't come with a gamma
if args.gamma is None:
raise AttributeError("Can't load non-native environments without passing in gamma!")
try:
env, _ = load_pomdp(args.spec, rand_key=rand_state, **kwargs)

except AttributeError:
# try to load from popgym
# validate input: we need a custom gamma for popgym args as they don't come with a gamma
if args.gamma is None:
raise AttributeError(
"Can't load non-native environments without passing in gamma!")
try:
env = get_popgym_env(args, rand_key=rand_state, **kwargs)

env = Flatten(env)
# also might need to preprocess our observation spaces
if isinstance(env.observation_space, gym.spaces.Discrete)\
and args.feature_encoding != 'one_hot':
env = DiscreteObservationWrapper(env)
if isinstance(env.observation_space, gym.spaces.Tuple):
env = ArrayObservationWrapper(env)

# preprocess continuous action spaces
if isinstance(env.action_space, gym.spaces.Box):
env = ContinuousToDiscrete(env, action_bins)
elif isinstance(env.action_space, gym.spaces.MultiDiscrete):
env = FlattenMultiDiscreteActionWrapper(env)

except AttributeError:
# don't have anything else implemented
raise NotImplementedError

if args.feature_encoding == 'one_hot':
env = OneHotObservationWrapper(env)

if args.action_cond == 'cat':
env = OneHotActionConcatWrapper(env)
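For orientation on the new `get_env` path, here is a hedged usage sketch. The spec id and hyperparameter values are placeholders (the POPGym id shown is an assumption; check `popgym.envs.ALL` for the registered names); only the argument names come from the code above.

```python
from argparse import Namespace
from grl.environment import get_env

# Hypothetical values for illustration; 'popgym-RepeatPreviousEasy-v0' is an
# assumed POPGym id -- substitute one of the ids listed in popgym.envs.ALL.
args = Namespace(spec='popgym-RepeatPreviousEasy-v0',
                 seed=0,
                 gamma=0.99,                  # required: POPGym envs carry no gamma of their own
                 feature_encoding='one_hot',  # adds OneHotObservationWrapper
                 action_cond='cat')           # adds OneHotActionConcatWrapper
env = get_env(args, rand_state=None, action_bins=6)
obs, info = env.reset()                       # gymnasium-style reset assumed for the wrappers
```

Note the ordering in the POPGym branch: Box action spaces are binned and MultiDiscrete spaces flattened before the one-hot encoding wrappers are applied, so the encoding wrappers only ever see a discrete action space.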
Binary file modified grl/environment/pomdp_files/4x3.95-pomdp-solver-results.npy
Binary file not shown.
Binary file not shown.
Binary file modified grl/environment/pomdp_files/example_7-pomdp-solver-results.npy
Binary file not shown.
Binary file modified grl/environment/pomdp_files/network-pomdp-solver-results.npy
Binary file not shown.
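Finally, since `get_env` documents an `action_bins` parameter for discretizing continuous action spaces: below is a hypothetical sketch of what a ContinuousToDiscrete-style wrapper can look like, assuming a simple per-dimension grid. The repo's actual `ContinuousToDiscrete` wrapper may use a different scheme; this is only to make the idea concrete.

```python
import gymnasium as gym
import numpy as np

class ContinuousToDiscreteSketch(gym.ActionWrapper):
    """Hypothetical stand-in: discretize each Box dimension into evenly spaced bins."""

    def __init__(self, env: gym.Env, action_bins: int = 6):
        super().__init__(env)
        low, high = env.action_space.low, env.action_space.high
        self._grid = np.linspace(low, high, action_bins)   # shape: (bins, dims)
        self._bins = action_bins
        self._dims = low.shape[0]
        # Joint discrete space over all dimensions (grows as bins ** dims).
        self.action_space = gym.spaces.Discrete(action_bins ** self._dims)

    def action(self, act: int) -> np.ndarray:
        # Decode the flat index into one bin index per dimension, then look up
        # the corresponding continuous value for each dimension.
        idxs = np.unravel_index(act, (self._bins,) * self._dims)
        return np.array([self._grid[i, d] for d, i in enumerate(idxs)], dtype=np.float32)
```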