Skip to content

Commit

Permalink
[Merge] PR #20 - Add script to compare error signals
Browse files Browse the repository at this point in the history
  • Loading branch information
camall3n authored Apr 19, 2024
2 parents ddff824 + d00d808 commit e579f3e
Show file tree
Hide file tree
Showing 5 changed files with 439 additions and 0 deletions.
136 changes: 136 additions & 0 deletions grl/environment/examples_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,142 @@ def example_22():

return to_dict(T, R, 0.999, p0, phi, Pi_phi, Pi_phi_x)

def example_26():
# [r1, b1, w1, r2, b2 w2]
T_stay = np.array([
[0, 1, 0, 0, 0, 0.],
[0, 0, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 1, 0, 0],
])
T_flip = np.array([
[0, 0, 0, 0, 1, 0.],
[0, 0, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 1, 0, 0],
])
T = np.array([T_stay, T_flip])

R = np.array([
[0, 0, 0, 0, 0, 0.],
[0, 0,-1, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0],
])
R = np.array([R, R])

p0 = np.zeros(len(T[0]))
p0[0] = 0.5
p0[3] = 0.5

phi = np.array([
# r, b, w
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[0, 0, 1]
])

p = .25
Pi_phi = [
np.array([
[p, 1 - p], # up, down
[p, 1 - p],
[p, 1 - p],
]),
]

Pi_phi_x = [
np.array([
[p, 1 - p], #r0
[p, 1 - p], #r1
[p, 1 - p], #b0
[p, 1 - p], #b1
[p, 1 - p], #w0
[p, 1 - p], #w1
]),
]

return to_dict(T, R, 0.9, p0, phi, Pi_phi, Pi_phi_x)

def example_26a():
# [r1, b1, w1, r2, b2 w2]
# but rewards are observable: distinguishes w1 from w2
T_stay = np.array([
[0, 1, 0, 0, 0, 0.],
[0, 0, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 1, 0, 0],
])
T_flip = np.array([
[0, 0, 0, 0, 1, 0.],
[0, 0, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 1, 0, 0],
])
T = np.array([T_stay, T_flip])

R = np.array([
[0, 0, 0, 0, 0, 0.],
[0, 0,-1, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0],
])
R = np.array([R, R])

p0 = np.zeros(len(T[0]))
p0[0] = 0.5
p0[3] = 0.5

phi = np.array([
# r, b, w1, w2
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
])

p = .25
Pi_phi = [
np.array([
[p, 1 - p], # up, down
[p, 1 - p],
[p, 1 - p],
[p, 1 - p],
]),
]

Pi_phi_x = [
np.array([
[p, 1 - p], #r0
[p, 1 - p], #r1
[p, 1 - p], #b0
[p, 1 - p], #b1
[p, 1 - p], #w1_0
[p, 1 - p], #w1_1
[p, 1 - p], #w2_0
[p, 1 - p], #w2_1
]),
]

return to_dict(T, R, 0.9, p0, phi, Pi_phi, Pi_phi_x)

def simple_chain(n: int = 10):
T = np.zeros((n, n))
states = np.arange(n)
Expand Down
14 changes: 14 additions & 0 deletions grl/memory/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,20 @@ def tiger_alt_start_1bit_optimal():
])
memory_12 = np.stack([mem_12, mem_12])

mem_101 = np.array([
# always hold the bit!
[ # red
#s0, s1
[1, 0.],
[0, 1],
],
[ # terminal
[1, 0],
[0, 1],
],
])
memory_101 = np.stack([mem_101, mem_101])

mem_102 = np.array([
# always flip the bit!
[ # red
Expand Down
63 changes: 63 additions & 0 deletions scripts/check_example_16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import argparse
import logging
import pathlib
from time import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import jax
from jax.config import config
from jax.nn import softmax

from grl.environment import load_pomdp
from grl.environment.policy_lib import get_start_pi
from grl.utils.loss import mstd_err, discrep_loss, value_error
from grl.utils.file_system import results_path, numpyify_and_save
from grl.memory import get_memory, memory_cross_product
from grl.memory_iteration import run_memory_iteration
from grl.utils.math import reverse_softmax
from grl.utils.mdp import functional_get_occupancy
from grl.utils.policy import construct_aug_policy
from grl.vi import td_pe
from grl.utils.policy_eval import analytical_pe, lstdq_lambda

#%%

np.set_printoptions(precision=8)

spec = 'example_16_terminal'
seed = 42

np.set_printoptions(precision=8, suppress=True)
config.update('jax_platform_name', 'cpu')
config.update("jax_enable_x64", True)

rand_key = None
np.random.seed(seed)
rand_key = jax.random.PRNGKey(seed)

pomdp, pi_dict = load_pomdp(spec, rand_key)
pomdp.gamma = 0.99999

if 'Pi_phi' in pi_dict and pi_dict['Pi_phi'] is not None:
pi_phi = pi_dict['Pi_phi'][0]
print(f'Pi_phi:\n {pi_phi}')

p = 0.5
pi_phi = np.array(
[[p, 1-p],
[p, 1-p]]
)

state_vals, mc_vals, td_vals, info = analytical_pe(pi_phi, pomdp)
state_vals['v']
state_vals['q']
mc_vals['v']
mc_vals['q']
td_vals['v']
td_vals['q']

discrep_loss(pi_phi, pomdp, alpha=0)

value_error(pi_phi, pomdp)
73 changes: 73 additions & 0 deletions scripts/check_example_26.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import argparse
import logging
import pathlib
from time import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import jax
from jax.config import config
from jax.nn import softmax

from grl.environment import load_pomdp
from grl.environment.policy_lib import get_start_pi
from grl.utils.loss import mstd_err, discrep_loss, value_error
from grl.utils.file_system import results_path, numpyify_and_save
from grl.memory import get_memory, memory_cross_product
from grl.memory_iteration import run_memory_iteration
from grl.utils.math import reverse_softmax
from grl.utils.mdp import functional_get_occupancy
from grl.utils.policy import construct_aug_policy
from grl.vi import td_pe
from grl.utils.policy_eval import analytical_pe, lstdq_lambda

#%%

np.set_printoptions(precision=8)

spec = 'example_26a'
seed = 42

np.set_printoptions(precision=8, suppress=True)
config.update('jax_platform_name', 'cpu')
config.update("jax_enable_x64", True)

rand_key = None
np.random.seed(seed)
rand_key = jax.random.PRNGKey(seed)

pomdp, pi_dict = load_pomdp(spec, rand_key)
pomdp.gamma = 0.99999

if 'Pi_phi' in pi_dict and pi_dict['Pi_phi'] is not None:
pi_phi = pi_dict['Pi_phi'][0]
print(f'Pi_phi:\n {pi_phi}')

p = 0.1
lds = []
ps = np.linspace(0,1,500)
for p in ps:

pi_phi = np.array(
[[p, 1-p],
[p, 1-p],
[p, 1-p],
[p, 1-p]]
)

state_vals, mc_vals, td_vals, info = analytical_pe(pi_phi, pomdp)
# state_vals['v']
# state_vals['q']
# mc_vals['v']
# mc_vals['q']
# td_vals['v']
# td_vals['q']

lds.append(discrep_loss(pi_phi, pomdp, alpha=0)[0])

# value_error(pi_phi, pomdp)

plt.semilogy(ps, lds)
plt.xlabel("Pr(stay)")
plt.ylabel("log LD")
Loading

0 comments on commit e579f3e

Please sign in to comment.