Skip to content

Commit

Permalink
Update test_ld_zero_conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
camall3n committed Aug 7, 2024
1 parent 9eedb9e commit 4b22bf8
Showing 1 changed file with 70 additions and 13 deletions.
83 changes: 70 additions & 13 deletions scripts/knife_edge/test_ld_zero_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,24 +76,82 @@ def get_max_diffs(pi, pomdp):
td_policy = np.einsum("sw,wa,wt->sta", pomdp.phi, pi, W)
pi_s = pomdp.get_ground_policy(pi)
Pi_s = np.eye(len(pi_s))[..., None] * pi_s[None,...]
td_policy.shape
Pi_s.shape
Pi = np.eye(len(pi))[..., None] * pi[None,...]
Phi = pomdp.phi
T = np.moveaxis(pomdp.T, 0, 1)

_, a = pi_s.shape
s, o = Phi.shape
sa = s*a
oa = o*a
# np.allclose(W @ pomdp.T @ Phi)
plt.imshow(Phi @ W)

IA = np.eye(len(pomdp.T))
W_A = np.kron(W, IA).reshape(o, a, s, a)
I_sa = np.eye(sa).reshape(s, a, s, a)
I_oa = np.eye(oa).reshape(o, a, o, a)

def dot(x, *args):
    """Chain single-axis tensor contractions from left to right.

    Each step contracts the last axis of the running result with the first
    axis of the next operand using ``np.tensordot(..., axes=1)``.

    Args:
        x: Left-most operand.
        *args: Further operands, contracted in the order given.

    Returns:
        The fully contracted result, or ``x`` unchanged when no further
        operands are supplied.
    """
    # Iterate over the operands directly instead of re-slicing the tuple
    # on every step.
    for y in args:
        x = np.tensordot(x, y, axes=1)
    return x

def ddot(x, *args):
    """Chain double-axis tensor contractions from left to right.

    Each step contracts the last two axes of the running result with the
    first two axes of the next operand using ``np.tensordot(..., axes=2)``.
    For 4-D operands this matches ``np.einsum('ijkl,klmn->ijmn', x, y)``.

    Args:
        x: Left-most operand.
        *args: Further operands, contracted in the order given.

    Returns:
        The fully contracted result, or ``x`` unchanged when no further
        operands are supplied.
    """
    # Iterate over the operands directly instead of re-slicing the tuple
    # on every step.
    for y in args:
        x = np.tensordot(x, y, axes=2)
    return x

def dpow(a, exp):
    """Raise ``a`` to a positive integer power under double-axis contraction.

    The result is ``a`` contracted with itself ``exp - 1`` times using
    ``np.tensordot(..., axes=2)``, i.e. the same product that
    ``ddot(a, a, ..., a)`` with ``exp`` copies of ``a`` produces.

    Args:
        a: Operand to be multiplied with itself.
        exp: Number of factors; must be at least 1.

    Returns:
        ``a`` itself when ``exp == 1``, otherwise the repeated contraction.

    Raises:
        ValueError: If ``exp`` is less than 1. No identity element is
            constructed here, so a zeroth power is not defined.
    """
    if exp < 1:
        raise ValueError(f"dpow requires exp >= 1, got {exp}")
    x = a
    # Left-to-right accumulation without materialising a list of copies.
    for _ in range(exp - 1):
        x = np.tensordot(x, a, axes=2)
    return x

W.shape
T.shape
T_sasa = dot(T, Pi_s)
T_oaoa = dot(W,T,Phi,Pi)

T1_sasa = dot(T, Pi_s)
T2_sasa = dot(T, td_policy)

sr1_sasa = np.linalg.tensorinv((I_sa - pomdp.gamma * T1_sasa))
sr2_sasa = np.linalg.tensorinv((I_sa - pomdp.gamma * T2_sasa))

R_sa = np.einsum("ast,ast->sa", pomdp.T, pomdp.R)
np.allclose(dot(W, ddot(sr1_sasa)), dot(W, ddot(sr2_sasa)))

np.allclose(dpow(T_oaoa, 2), ddot(T_oaoa, T_oaoa))

# pomdp_mdp_predictions = [
# (dot(W, I_sa), ddot(I_oa, W_A)),
# (dot(W, T_sasa), ddot(T_oaoa, W_A)),
# ]
# for i in range(2, 30):
# pomdp_mdp_predictions.append(
# (dot(W, dpow(T_sasa, i)), ddot(dpow(T_oaoa, i), W_A))
# )
#
# for pred_pomdp, pred_mdp in pomdp_mdp_predictions:
# assert np.allclose(pred_pomdp, pred_mdp)
# print('.', end='')
# print()

lambda_0 = 0.0
lambda_1 = 1.0
def get_K(Pi_s, td_policy, lambda_):
    """Blend two operands with mixing weight ``lambda_``.

    Args:
        Pi_s: Operand weighted by ``lambda_``.
        td_policy: Operand weighted by ``1 - lambda_``.
        lambda_: Mixing weight; ``1`` returns ``Pi_s`` and ``0`` returns
            ``td_policy``.

    Returns:
        ``lambda_ * Pi_s + (1 - lambda_) * td_policy``.
    """
    return lambda_ * Pi_s + (1-lambda_) * td_policy
K0 = get_K(Pi_s, td_policy, lambda_0)
K1 = get_K(Pi_s, td_policy, lambda_1)
n_a, n_s, _ = pomdp.T.shape
n_sa = n_a * n_s
T_sasa_0 = np.reshape(np.einsum('aij,jkb->iakb', pomdp.T, K0), (n_sa, n_sa))
T_sasa_1 = np.reshape(np.einsum('aij,jkb->iakb', pomdp.T, K1), (n_sa, n_sa))
R_sa = np.reshape(np.einsum("ast,ast->sa", pomdp.T, pomdp.R), n_sa)
I = np.eye(len(R_sa))
T_sasa_0 = np.reshape(np.einsum('aij,jkb->iakb', pomdp.T, K0), (sa, sa))
T_sasa_1 = np.reshape(np.einsum('aij,jkb->iakb', pomdp.T, K1), (sa, sa))
R_sa = np.reshape(np.einsum("ast,ast->sa", pomdp.T, pomdp.R), sa)
I = np.eye(sa)

Q0_sa = np.reshape(np.linalg.solve((I - pomdp.gamma * T_sasa_0), R_sa), (n_s, n_a))
Q1_sa = np.reshape(np.linalg.solve((I - pomdp.gamma * T_sasa_1), R_sa), (n_s, n_a))
Q0_sa = np.reshape(np.linalg.solve((I - pomdp.gamma * T_sasa_0), R_sa), (s, a))
Q1_sa = np.reshape(np.linalg.solve((I - pomdp.gamma * T_sasa_1), R_sa), (s, a))

Q0_wa = W @ Q0_sa
Q1_wa = W @ Q1_sa
Expand All @@ -106,23 +164,22 @@ def get_K(Pi_s, td_policy, lambda_):
}
return diffs

spec = 'ld_zero_by_t_projection'
spec = 'ld_zero_by_w_projection'

specs_and_n_params = {
'ld_zero_by_k_equality': 3,
'ld_zero_by_t_projection': 3,
'ld_zero_by_r_projection': 1,
'ld_zero_by_w_projection': 1,
}

#%%
data = []
ps = np.linspace(0, 1, n_samples)
for spec, n_params in specs_and_n_params.items():
pomdp, info = load_pomdp(spec)
all_ps = np.reshape(np.meshgrid(*[ps]*n_params), (n_params, -1)).T
for probs in tqdm(all_ps):
pi = get_policy(spec, *probs)
pomdp.phi
diffs = get_max_diffs(pi, pomdp)
diffs['spec'] = spec
data.append(diffs)
Expand Down

0 comments on commit 4b22bf8

Please sign in to comment.