
Commit

Add hill climbing for MSTDE
camall3n committed Apr 22, 2024
1 parent 737a50a commit ac05b27
Showing 9 changed files with 39 additions and 34 deletions.
22 changes: 14 additions & 8 deletions grl/agent/actorcritic.py
@@ -14,7 +14,7 @@
from tqdm import tqdm
from grl.memory.analytical import memory_cross_product
from grl.utils.discrete_search import SearchNode, generate_hold_mem_fn
from grl.utils.loss import mem_discrep_loss, value_error
from grl.utils.loss import mem_discrep_loss, mem_tde_loss, value_error
from grl.agent.td_lambda import TDLambdaQFunction
from grl.utils.replaymemory import ReplayMemory
from grl.utils.math import arg_hardmax, arg_mellowmax, arg_boltzman, one_hot
@@ -37,7 +37,8 @@ def __init__(
policy_epsilon: float = 0.10,
mellowmax_beta: float = 10.0,
replay_buffer_size: int = 1000000,
mem_optimizer='queue', # [queue, annealing, optuna]
mem_optimizer='queue', # ['queue', 'annealing', 'optuna']
mem_optim_objective='ld', #['ld', 'td']
prune_if_parent_suboptimal=False, # search queue pruning
ignore_queue_priority=False, # search queue priority
annealing_should_sample_hyperparams=False,
@@ -48,7 +49,7 @@ def __init__(
study_name='default_study',
use_existing_study=False,
n_optuna_workers=1,
discrep_loss='abs',
discrep_loss='mse',
disable_importance_sampling=False,
override_mem_eval_with_analytical_env=None,
analytical_lambda_discrep_noise=0.0,
@@ -63,6 +64,7 @@ def __init__(
self.policy_epsilon = policy_epsilon
self.mellowmax_beta = mellowmax_beta
self.mem_optimizer = mem_optimizer
self.mem_optim_objective = mem_optim_objective
self.prune_if_parent_suboptimal = prune_if_parent_suboptimal
self.ignore_queue_priority = ignore_queue_priority
self.annealing_should_sample_hyperparams = annealing_should_sample_hyperparams
@@ -531,16 +533,20 @@ def evaluate_memory_analytical(self):
noise = 0
if self.analytical_lambda_discrep_noise > 0:
noise = np.random.normal(loc=0, scale=self.analytical_lambda_discrep_noise)
discrep = mem_discrep_loss(self.memory_logits, self.policy_probs,
self.override_mem_eval_with_analytical_env)
return discrep + noise
if self.mem_optim_objective == 'ld':
memory_cost = mem_discrep_loss(self.memory_logits, self.policy_probs,
self.override_mem_eval_with_analytical_env)
else:
memory_cost = mem_tde_loss(self.memory_logits, self.policy_probs,
self.override_mem_eval_with_analytical_env)
return memory_cost + noise

def evaluate_memory(self):
if self.override_mem_eval_with_analytical_env is not None:
discrep = self.evaluate_memory_analytical().item()
memory_cost = self.evaluate_memory_analytical().item()
mem_aug_pomdp = memory_cross_product(self.memory_logits, self.override_mem_eval_with_analytical_env)
value_err = value_error(self.policy_probs, mem_aug_pomdp)[0].item()
return discrep, value_err
return memory_cost, value_err

assert len(self.replay.memory) > 0
self.reset_memory_state()
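
The new mem_optim_objective flag lets the discrete memory search score candidate memory functions either by lambda discrepancy ('ld', via mem_discrep_loss) or by TD error ('td', via mem_tde_loss), as evaluate_memory_analytical above now does. Below is a minimal, self-contained sketch of greedy hill climbing over deterministic memory functions with a pluggable objective; the function name, the (actions, obs, mem, mem) tensor shape, and the evaluate_fn callback are illustrative assumptions, not the repository's actual search code.

import numpy as np

def hill_climb_memory(mem_probs, evaluate_fn, n_iterations=100, rng=None):
    """Greedily reroute one (action, obs, mem) -> next_mem transition at a time,
    keeping a change only if it lowers the objective (e.g. MSTDE or lambda discrepancy)."""
    rng = rng or np.random.default_rng()
    best = mem_probs.copy()                    # assumed shape: (A, O, M, M), rows one-hot
    best_cost = evaluate_fn(best)
    n_actions, n_obs, n_mem, _ = best.shape
    for _ in range(n_iterations):
        a, o, m = rng.integers(n_actions), rng.integers(n_obs), rng.integers(n_mem)
        candidate = best.copy()
        candidate[a, o, m] = np.eye(n_mem)[rng.integers(n_mem)]   # new one-hot next-memory row
        cost = evaluate_fn(candidate)
        if cost < best_cost:                   # accept strict improvements only
            best, best_cost = candidate, cost
    return best, best_cost

In the agent above, the objective itself is evaluated analytically by evaluate_memory_analytical, which dispatches to mem_discrep_loss or mem_tde_loss depending on mem_optim_objective.
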
3 changes: 0 additions & 3 deletions grl/agent/analytical.py
@@ -195,11 +195,8 @@ def new_pi_over_mem(self):

@partial(jit, static_argnames=['self'])
def policy_gradient_update(self, params: jnp.ndarray, optim_state: jnp.ndarray, pomdp: POMDP):
# import jax
# jax.debug.breakpoint()
outs, params_grad = value_and_grad(self.pg_objective_func, has_aux=True)(params, pomdp)
v_0, (td_v_vals, td_q_vals) = outs
# jax.debug.breakpoint()

# We add a negative here to params_grad b/c we're trying to
# maximize the PG objective (value of start state).
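
The comment about negating params_grad reflects a standard pattern: the optimizer performs gradient descent, so maximizing the policy-gradient objective means feeding it the negated gradient. A minimal sketch of the idea, assuming an optax optimizer; the objective function here is a made-up stand-in, not the repository's PG objective.

import jax
import jax.numpy as jnp
import optax

def objective(params):
    # Hypothetical concave objective standing in for the start-state value.
    return -jnp.sum((params - 3.0) ** 2)

params = jnp.zeros(4)
optimizer = optax.adam(1e-2)
opt_state = optimizer.init(params)

value, grads = jax.value_and_grad(objective)(params)
updates, opt_state = optimizer.update(-grads, opt_state, params)  # negate: ascend instead of descend
params = optax.apply_updates(params, updates)
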
20 changes: 11 additions & 9 deletions grl/agent/td_lambda.py
@@ -39,15 +39,17 @@ def _reset_q_values(self):
def _reset_eligibility(self):
self.eligibility = np.zeros((self.n_actions, self.n_obs))

def update(self,
obs,
action,
reward,
terminal,
next_obs,
next_action,
aug_obs=None,
next_aug_obs=None):
def update(
self,
obs,
action,
reward,
terminal,
next_obs,
next_action,
aug_obs=None, # memory-augmented observation
next_aug_obs=None, # and next observation (O x M)
):
# Because mdp.step() terminates with probability (1-γ),
# we have already factored in the γ that we would normally
# use to decay the eligibility.
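
The comment kept in update() notes that because mdp.step() terminates with probability (1 - gamma), the discount is already folded into the sampled transitions, so the eligibility trace decays by lambda alone. A stripped-down sketch of that update follows; array shapes match _reset_eligibility above, but the names and hyperparameters are illustrative, and the undiscounted bootstrap term is part of the same termination-absorbs-discounting assumption.

import numpy as np

def td_lambda_update(q, eligibility, obs, action, reward, terminal,
                     next_obs, next_action, lr=0.1, lam=0.9):
    # Decay traces by lambda only; the (1 - gamma) termination probability of
    # mdp.step() already accounts for discounting (see comment above).
    eligibility *= lam
    eligibility[action, obs] += 1.0
    bootstrap = 0.0 if terminal else q[next_action, next_obs]
    td_error = reward + bootstrap - q[action, obs]
    q = q + lr * td_error * eligibility       # traces spread the update to earlier (a, o) pairs
    return q, eligibility
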
4 changes: 0 additions & 4 deletions grl/utils/loss.py
@@ -228,8 +228,6 @@ def pg_objective_func(pi_params: jnp.ndarray, pomdp: POMDP):
Policy gradient objective function:
sum_{s_0} p(s_0) v_pi(s_0)
"""
# import jax
# jax.debug.breakpoint()
pi_abs = nn.softmax(pi_params, axis=-1)
pi_ground = pomdp.phi @ pi_abs

@@ -248,8 +246,6 @@ def augmented_pg_objective_func(augmented_pi_params: jnp.ndarray, pomdp: POMDP):
Policy gradient objective function:
sum_{s_0} p(s_0) v_pi(s_0)
"""
# import jax
# jax.debug.breakpoint()
augmented_pi_probs = nn.softmax(augmented_pi_params)
mem_probs, action_policy_probs = deconstruct_aug_policy(augmented_pi_probs)
mem_logits = reverse_softmax(mem_probs)
2 changes: 1 addition & 1 deletion grl/utils/math.py
@@ -44,7 +44,7 @@ def reverse_softmax(dists: jnp.ndarray, eps: float = 1e-20) -> jnp.ndarray:
"""
# c = jnp.log(jnp.exp(dists).sum(axis=-1))
# params = jnp.log(dists) + c
params = jnp.log(dists + eps)
params = jnp.log(dists + jnp.array(eps, dtype=dists.dtype))
return params

def one_hot(x: np.ndarray, n: int, axis: int = -1) -> np.ndarray:
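
reverse_softmax maps a probability distribution back to logits so that softmax recovers the distribution (up to the eps added for numerical safety); the change above casts eps to the input's dtype, presumably to keep the resulting parameters in the same precision as dists. A quick round-trip check with illustrative values:

import jax.numpy as jnp
from jax import nn

dists = jnp.array([0.7, 0.2, 0.1], dtype=jnp.float32)
eps = 1e-20
params = jnp.log(dists + jnp.array(eps, dtype=dists.dtype))   # reverse_softmax body
recovered = nn.softmax(params, axis=-1)
print(recovered.dtype, recovered)   # float32, approximately [0.7, 0.2, 0.1]
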
2 changes: 1 addition & 1 deletion scripts/batch_run_interleave.py
@@ -108,7 +108,7 @@ def make_experiment(args):
corridor_length=args.tmaze_corridor_length,
discount=args.tmaze_discount,
junction_up_pi=args.tmaze_junction_up_pi)

partial_kwargs = {
'value_type': args.value_type,
'error_type': args.error_type,
2 changes: 2 additions & 0 deletions scripts/learning_agent/memory_iteration.py
@@ -52,6 +52,7 @@ def parse_args():
parser.add_argument('--normalize_reward_range', action='store_true')
parser.add_argument('--mem_optimizer', type=str, default='queue',
choices=['queue', 'annealing', 'optuna'])
parser.add_argument('--mem_optim_objective', type=str, default='ld', choices=['ld', 'td'])
parser.add_argument('--enable_priority_queue', action='store_true')
parser.add_argument('--annealing_should_sample_hyperparams', action='store_true')
parser.add_argument('--annealing_tmax', type=float, default=3.16e-3)
@@ -244,6 +245,7 @@ def main():
policy_epsilon=args.policy_epsilon,
replay_buffer_size=args.replay_buffer_size,
mem_optimizer=args.mem_optimizer,
mem_optim_objective=args.mem_optim_objective,
ignore_queue_priority=(not args.enable_priority_queue),
annealing_should_sample_hyperparams=args.annealing_should_sample_hyperparams,
annealing_tmax=args.annealing_tmax,
2 changes: 1 addition & 1 deletion scripts/plotting/mi_performance_ryt.py
@@ -7,7 +7,7 @@

from argparse import Namespace
from jax.nn import softmax
from jax import config
from jax.config import config
from pathlib import Path
from collections import namedtuple
from tqdm import tqdm
16 changes: 9 additions & 7 deletions scripts/test_discrete_mem_optimization.py
@@ -79,6 +79,7 @@
n_mem_values=args.n_memory_states,
replay_buffer_size=args.replay_buffer_size,
mem_optimizer=args.mem_optimizer,
mem_optim_objective=args.mem_optim_objective,
ignore_queue_priority=(not args.enable_priority_queue),
annealing_should_sample_hyperparams=args.annealing_should_sample_hyperparams,
annealing_tmax=args.annealing_tmax,
@@ -127,7 +128,7 @@
mem_probs = generate_hold_mem_fn(learning_agent.n_actions, learning_agent.n_obs,
learning_agent.n_mem_states)
learning_agent.set_memory(mem_probs, logits=False)
mem_aug_mdp = memory_cross_product(learning_agent.memory_logits, env)
mem_aug_pomdp = memory_cross_product(learning_agent.memory_logits, env)

# Value stuff
def get_start_obs_value(pi, mdp):
@@ -136,8 +137,8 @@ def get_start_obs_value(pi, mdp):
value_fn, _, _ = lstdq_lambda(pi, mdp, lambda_=args.lambda1)
return (value_fn @ (mdp.p0 @ mdp.phi)).item()

start_value = get_start_obs_value(pi_aug, mem_aug_mdp)
initial_discrep = discrep_loss(pi_aug, mem_aug_mdp)[0].item()
start_value = get_start_obs_value(pi_aug, mem_aug_pomdp)
initial_discrep = discrep_loss(pi_aug, mem_aug_pomdp)[0].item()

print(f'Start value: {start_value}')
print(f'Initial discrep: {initial_discrep}')
@@ -162,7 +163,7 @@ def get_start_obs_value(pi, mdp):

# Final performance stuff
learning_agent.reset_policy()
mem_aug_mdp = memory_cross_product(learning_agent.memory_logits, env)
mem_aug_pomdp = memory_cross_product(learning_agent.memory_logits, env)
planning_agent = AnalyticalAgent(
pi_params=learning_agent.policy_logits,
rand_key=jax.random.PRNGKey(args.seed + 10000),
@@ -171,11 +172,11 @@ def get_start_obs_value(pi, mdp):
value_type='q',
policy_optim_alg=args.policy_optim_alg,
)
planning_agent.reset_pi_params((mem_aug_mdp.observation_space.n, mem_aug_mdp.action_space.n))
pi_improvement(planning_agent, mem_aug_mdp, iterations=n_pi_iterations)
planning_agent.reset_pi_params((mem_aug_pomdp.observation_space.n, mem_aug_pomdp.action_space.n))
pi_improvement(planning_agent, mem_aug_pomdp, iterations=n_pi_iterations)
learning_agent.set_policy(planning_agent.pi_params, logits=True)

end_value = get_start_obs_value(learning_agent.policy_probs, mem_aug_mdp)
end_value = get_start_obs_value(learning_agent.policy_probs, mem_aug_pomdp)
print(f'Start value: {start_value}')
print(f'End value: {end_value}')

@@ -190,6 +191,7 @@ def get_start_obs_value(pi, mdp):
'init_policy_randomly': args.init_policy_randomly,
'n_random_policies': args.n_random_policies,
'mem_optimizer': args.mem_optimizer,
'mem_optim_objective': args.mem_optim_objective,
'enable_priority_queue': args.enable_priority_queue,
'tmax': args.annealing_tmax,
'tmin': args.annealing_tmin,
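
In get_start_obs_value above, mdp.p0 @ mdp.phi projects the initial state distribution onto observations, and the dot product with the observation value function gives the expected value of the start observation. A tiny numeric illustration with made-up values:

import numpy as np

p0 = np.array([0.5, 0.5, 0.0])            # initial state distribution
phi = np.array([[1.0, 0.0],               # per-state observation probabilities
                [0.0, 1.0],
                [0.0, 1.0]])
value_fn = np.array([2.0, -1.0])          # value of each observation
start_obs_dist = p0 @ phi                 # -> [0.5, 0.5]
start_value = value_fn @ start_obs_dist   # 0.5 * 2.0 + 0.5 * (-1.0) = 0.5
print(start_value)
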
