[Merge] PR #21 - learning-agent into taodav/main
camall3n authored Apr 22, 2024
2 parents e579f3e + ac05b27 commit 16a72e6
Showing 31 changed files with 2,317 additions and 171 deletions.
277 changes: 257 additions & 20 deletions grl/agent/actorcritic.py

Large diffs are not rendered by default.

19 changes: 17 additions & 2 deletions grl/agent/td_lambda.py
@@ -39,7 +39,17 @@ def _reset_q_values(self):
def _reset_eligibility(self):
self.eligibility = np.zeros((self.n_actions, self.n_obs))

def update(self, obs, action, reward, terminal, next_obs, next_action):
def update(
self,
obs,
action,
reward,
terminal,
next_obs,
next_action,
aug_obs=None, # memory-augmented observation
next_aug_obs=None, # and next observation (O x M)
):
# Because mdp.step() terminates with probability (1-γ),
# we have already factored in the γ that we would normally
# use to decay the eligibility.
@@ -50,6 +60,11 @@ def update(self, obs, action, reward, terminal, next_obs, next_action):
# probability γ.
#
# Thus we simply decay eligibility by λ.
if aug_obs is not None:
obs = aug_obs
if next_aug_obs is not None:
next_obs = next_aug_obs

self.eligibility *= self.lambda_
if self.trace_type == 'accumulating':
self.eligibility[action, obs] += 1
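
The new aug_obs / next_aug_obs arguments simply replace the raw observation index before the eligibility and Q-value updates, so the same (A x O) tables work provided they were sized for the memory-augmented observation space. A minimal sketch of how a caller might form that index (the augment_obs helper and the obs * n_mem_states + mem convention are illustrative assumptions, not part of this PR):

# Hypothetical helper: flatten (obs, mem) into one index in an
# observation space of size n_obs * n_mem_states (the "O x M" noted above).
def augment_obs(obs: int, mem: int, n_mem_states: int) -> int:
    return obs * n_mem_states + mem

# Sketch of a call site, assuming the agent's tables were built with
# n_obs = base_n_obs * n_mem_states:
# agent.update(obs, action, reward, terminal, next_obs, next_action,
#              aug_obs=augment_obs(obs, mem, n_mem_states),
#              next_aug_obs=augment_obs(next_obs, next_mem, n_mem_states))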
@@ -84,7 +99,7 @@ def run_td_lambda_on_mdp(
alpha=1,
n_episodes=1000,
):
# If AMDP, convert to pi_ground
# If POMDP, convert to pi_ground
if hasattr(mdp, 'phi'):
pi_ground = mdp.get_ground_policy(pi)
else:
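For context on get_ground_policy in the hunk above: it converts an observation-conditioned policy into a state-conditioned one using the POMDP's observation function phi. A minimal sketch of the idea, assuming phi is an |S| x |O| row-stochastic matrix (the actual implementation lives on the mdp/pomdp class and may differ):

import numpy as np

def ground_policy_sketch(pi: np.ndarray, phi: np.ndarray) -> np.ndarray:
    # pi: (|O|, |A|) policy over observations
    # phi: (|S|, |O|) observation probabilities, one distribution per state
    # Averaging pi over each state's observation distribution gives a
    # (|S|, |A|) policy over ground states.
    return phi @ pi
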
12 changes: 12 additions & 0 deletions grl/environment/policy_lib.py
@@ -36,6 +36,16 @@ def tiger_alt_start_cam(**kwargs) -> jnp.ndarray:
pi_params = reverse_softmax(pi_phi)
return pi_params

def tiger_alt_start_known_ld(**kwargs) -> jnp.ndarray:
pi_phi = jnp.array([
[.9, 0.05, 0.05],
[.5, .125, .375],
[.25, .125, .625],
[1, 0, 0],
])
pi_params = reverse_softmax(pi_phi)
return pi_params

def get_start_pi(pi_name: str, pi_phi: jnp.ndarray = None, **kwargs):
if pi_phi is not None:
return reverse_softmax(pi_phi)
@@ -46,5 +56,7 @@ def get_start_pi(pi_name: str, pi_phi: jnp.ndarray = None, **kwargs):

except KeyError as _:
raise KeyError(f"No policy of the name {pi_name} found in policy_lib")
else:
print(f'Loaded policy "{pi_name}"')

return pi_params
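
The new tiger_alt_start_known_ld policy is specified as probabilities and converted to logits with reverse_softmax, so it plugs into the same softmax-parameterized machinery as learned policies. A minimal sketch of what such an inverse could look like (log-probabilities with a small floor to handle exact zeros; the actual grl.utils implementation may differ):

import numpy as np

def reverse_softmax_sketch(probs: np.ndarray, eps: float = 1e-20) -> np.ndarray:
    # Softmax is invariant to a per-row additive constant, so element-wise
    # logs are one valid preimage; clipping avoids -inf for rows like [1, 0, 0].
    return np.log(np.clip(probs, eps, 1.0))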
190 changes: 120 additions & 70 deletions grl/environment/pomdp_files/hallway.POMDP

Large diffs are not rendered by default.

154 changes: 154 additions & 0 deletions grl/environment/pomdp_files/tmaze5-fixed.POMDP
@@ -0,0 +1,154 @@
# Converted POMDP file for tmaze_hyperparams
# with tmaze hallway length = 5

discount: 0.9
values: reward
states: 15
actions: 4
observations: 5

start:
0.5 0.5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0

T: 0
1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0

T: 1
1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0

T: 2
0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0

T: 3
1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0


O: 0
1.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 1.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 1.0 0.0
0.0 0.0 0.0 0.0 0.0 1.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 1.0

O: 1
1.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 1.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 1.0 0.0
0.0 0.0 0.0 0.0 0.0 1.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 1.0

O: 2
1.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 1.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 1.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 1.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 1.0 0.0
0.0 0.0 0.0 0.0 0.0 1.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 1.0

O: 3
1.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 1.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 1.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 1.0 0.0
0.0 0.0 0.0 0.0 0.0 1.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 1.0

R: 0 : 12 : 14 : * 4.0
R: 0 : 13 : 14 : * -0.1

R: 1 : 12 : 14 : * -0.1
R: 1 : 13 : 14 : * 4.0
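
In the Cassandra-style .POMDP format these lines read R: <action> : <start-state> : <end-state> : <observation> <value>, with * as a wildcard. As an illustrative sketch (variable names assumed, indexing R[action, start_state, next_state]), the four entries above correspond to a dense reward tensor like:

import numpy as np

# 4 actions, 15 states; every transition not listed has zero reward.
R = np.zeros((4, 15, 15))
R[0, 12, 14] = 4.0
R[0, 13, 14] = -0.1
R[1, 12, 14] = -0.1
R[1, 13, 14] = 4.0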
2 changes: 1 addition & 1 deletion grl/mc.py
@@ -51,7 +51,7 @@ def mc(mdp,
if mc_states not in ['all', 'first']:
raise ValueError("mc_states must be either 'all' or 'first'")

# If AMDP, convert to pi_ground
# If POMDP, convert to pi_ground
if hasattr(mdp, 'phi'):
pi_ground = mdp.get_ground_policy(pi)
else:
2 changes: 0 additions & 2 deletions grl/memory/lib.py
@@ -108,7 +108,6 @@ def all_n_state_deterministic_memory(n_mem_states: int):
all_mem_funcs = id[all_idxes]
return all_mem_funcs


def generate_random_uniform_memory_fn(n_mem_states: int, n_obs: int, n_actions: int):
T_mem = np.zeros((n_actions, n_obs, n_mem_states, n_mem_states))

@@ -163,7 +162,6 @@ def tiger_alt_start_1bit_optimal():
T_mem = np.stack([T_mem_listen, T_mem_listen, T_mem_listen])
return T_mem


"""
1 bit memory functions with three obs: r, b, t
and 2 actions: up, down
2 changes: 1 addition & 1 deletion grl/memory_iteration.py
@@ -287,7 +287,7 @@ def get_measure():
print(f"Learnt memory for iteration {mem_it}: \n"
f"{agent.memory}")

# Make a NEW memory AMDP
# Make a NEW memory POMDP
pomdp = memory_cross_product(agent.mem_params, init_pomdp)

if pi_per_step > 0:
15 changes: 8 additions & 7 deletions grl/run.py
@@ -36,13 +36,13 @@ def add_tmaze_hyperparams(parser: argparse.ArgumentParser):
# Args
parser = argparse.ArgumentParser()
# yapf:disable
parser.add_argument('--spec', default='example_11', type=str,
parser.add_argument('--spec', default='tmaze_5_two_thirds_up', type=str,
help='name of POMDP spec; evals Pi_phi policies by default')
parser.add_argument('--mi_iterations', type=int, default=1,
help='For memory iteration, how many iterations of memory iterations do we do?')
parser.add_argument('--mi_steps', type=int, default=50000,
parser.add_argument('--mi_steps', type=int, default=20000,
help='For memory iteration, how many steps of memory improvement do we do per iteration?')
parser.add_argument('--pi_steps', type=int, default=50000,
parser.add_argument('--pi_steps', type=int, default=10000,
help='For memory iteration, how many steps of policy improvement do we do per iteration?')
parser.add_argument('--policy_optim_alg', type=str, default='policy_iter',
help='policy improvement algorithm to use. "policy_iter" - policy iteration, "policy_grad" - policy gradient, '
@@ -147,10 +147,11 @@ def add_tmaze_hyperparams(parser: argparse.ArgumentParser):
logging.info(f'Pi_phi_x:\n {pi_dict["Pi_phi_x"]}')
if 'Pi_phi' in pi_dict and pi_dict['Pi_phi'] is not None:
logging.info(f'Pi_phi:\n {pi_dict["Pi_phi"]}')
if args.init_pi is not None:
pi_params = get_start_pi(args.init_pi,
pi_phi=pi_dict['Pi_phi'][0],
pomdp=pomdp)
if args.init_pi is not None:
try:
pi_params = get_start_pi(args.init_pi, pi_phi=pi_dict['Pi_phi'][0])
except TypeError:
pi_params = get_start_pi(args.init_pi, pi_phi=None)

results_path = results_path(args)

80 changes: 80 additions & 0 deletions grl/utils/augmented_policy.py
@@ -0,0 +1,80 @@
import numpy as np

from grl.utils import softmax, reverse_softmax

# mem_probs are shape AOM->M
# policy_probs are shape OM->(A*M)

# P(m' | o, a, m) = P(a, m' | o, m) / P(a | o, m)
# = P(a, m' | o, m) / sum_m' P(a, m' | o, m)
#
# P(m', a | o, m) = P(m' | o, a, m) * P(a | o, m)

A = 3
M = 2
O = 4
aug_policy_probs = softmax(np.random.normal(size=np.prod([A, O, M, M])).reshape([O, M, A*M]), axis=-1)

def deconstruct_aug_policy(aug_policy_probs):
O, M, AM = aug_policy_probs.shape
A = AM // M
aug_policy_probs_omam = aug_policy_probs.reshape([O, M, A, M])
action_policy_probs_oma1 = aug_policy_probs_omam.sum(-1, keepdims=1) # (O, M, A, 1)
# pr(^|*)
action_policy_probs = action_policy_probs_oma1.squeeze(-1)
assert np.allclose(action_policy_probs.sum(-1), 1)

aug_policy_logits_omam = reverse_softmax(aug_policy_probs_omam)
action_policy_logits_oma1 = reverse_softmax(action_policy_probs_oma1)
mem_logits_omam = (aug_policy_logits_omam - action_policy_logits_oma1) # (O, M, A, M)
mem_probs_omam = softmax(mem_logits_omam, -1) # (O, M, A, M)
# pr(^|*)

mem_probs = np.moveaxis(mem_probs_omam, -2, 0) # (A, O, M, M)
assert np.allclose(mem_probs.sum(-1), 1)

mem_logits = reverse_softmax(mem_probs)
return mem_logits, action_policy_probs

def construct_aug_policy(mem_logits, action_policy_probs):
A, O, M, _ = mem_logits.shape
mem_probs = softmax(mem_logits, axis=-1) # (A, O, M, M)
mem_probs_omam = np.moveaxis(mem_probs, 0, -2) # (O, M, A, M)

action_policy_probs_oma1 = action_policy_probs[..., None] # (O, M, A, 1)

aug_policy_probs_omam = (mem_probs_omam * action_policy_probs_oma1)
aug_policy_probs = aug_policy_probs_omam.reshape([O, M, A*M])
assert np.allclose(aug_policy_probs.sum(-1), 1)

return aug_policy_probs

mem_logits, action_policy_probs = deconstruct_aug_policy(aug_policy_probs)
aug_policy_probs_reconstructed = construct_aug_policy(mem_logits, action_policy_probs)
assert np.allclose(aug_policy_probs_reconstructed, aug_policy_probs)

#%%

import numpy as np

from grl.environment.spec import load_pomdp
from grl.memory import get_memory
from grl.utils.augmented_policy import construct_aug_policy, deconstruct_aug_policy


env, info = load_pomdp('tmaze_5_two_thirds_up', memory_id='18')
pi = info['Pi_phi'][0]
mem_params = get_memory('18',
n_obs=env.observation_space.n,
n_actions=env.action_space.n,
n_mem_states=2)

inp_aug_pi = np.expand_dims(pi, axis=1).repeat(mem_params.shape[-1], axis=1)

aug_policy = construct_aug_policy(mem_params, inp_aug_pi)
mem_logits_reconstr, deconstr_aug_pi = deconstruct_aug_policy(aug_policy)

softmax(mem_logits_reconstr, -1).round(3)
softmax(mem_params, -1).round(3)

print()