Policy_evaluation_deterministic.py
import numpy as np
from GridWorld import standard_grid, ACTION_SPACE
import math

# Stop iterating once the largest per-state value update falls below this
# threshold (exp(-3) is roughly 0.05).
threshold = math.exp(-3)
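
# Iterative policy evaluation for a fixed deterministic policy on GridWorld.
# The main loop below repeatedly applies the Bellman expectation backup
#     V(s) = sum_a pi(a|s) * sum_{s'} p(s'|s,a) * (r(s,a,s') + gamma * V(s'))
# to every state until the largest change in V(s) drops below `threshold`.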
def print_values(v, g):
    # Print the value function v as a grid, one row of states per line.
    for i in range(g.rows):
        print("----------------------------")
        for j in range(g.cols):
            V = v.get((i, j), 0)
            if V >= 0:
                print(" %.2f|" % V, end="")
            else:
                print("%.2f|" % V, end="")  # negative values carry their own minus sign
        print("")
def print_policy(P, g):
    # Print the policy P as a grid of action labels.
    for i in range(g.rows):
        print("----------------------------")
        for j in range(g.cols):
            a = P.get((i, j), '')
            print(" %s |" % a, end="")
        print("")
if __name__ == '__main__':
    # Build the environment model. Because the grid is deterministic,
    # p(s'|s,a) is 1 for the single successor state and 0 otherwise.
    transition_prob = {}
    rewards = {}
    grid = standard_grid()
    for i in range(grid.rows):
        for j in range(grid.cols):
            s = (i, j)
            if not grid.is_terminal(s):
                for a in ACTION_SPACE:
                    s2 = grid.get_next_state(s, a)
                    transition_prob[(s, a, s2)] = 1
                    if s2 in grid.rewards:
                        rewards[(s, a, s2)] = grid.rewards[s2]
    print(rewards)
    policy = {
        (2, 0): 'U',
        (1, 0): 'U',
        (0, 0): 'R',
        (0, 1): 'R',
        (0, 2): 'R',
        (1, 2): 'U',
        (2, 1): 'R',
        (2, 2): 'U',
        (2, 3): 'L',
    }
    print_policy(policy, grid)
    print("")
    # Initialise V(s) = 0 for every state, including terminal states.
    v = {}
    for s in grid.all_states():
        v[s] = 0

    gamma = 0.9  # discount factor
    it = 0
    while True:
        biggest_change = 0
        for s in grid.all_states():
            old_v = v[s]
            new_v = 0
            # Bellman expectation backup for state s under the fixed policy.
            for a in ACTION_SPACE:
                for s2 in grid.all_states():
                    action_prob = 1 if policy.get(s) == a else 0
                    r = rewards.get((s, a, s2), 0)
                    new_v += action_prob * transition_prob.get((s, a, s2), 0) * (r + gamma * v[s2])
            v[s] = new_v
            biggest_change = max(biggest_change, np.abs(new_v - old_v))
        print(f" itr : {it} , Biggest Change : {biggest_change} \n")
        print_values(v, grid)
        it += 1
        if biggest_change < threshold:
            break
    print("\nTask Completed")