value_iteration.py
import numpy as np
from GridWorld import standard_grid, windy_grid, ACTION_SPACE
from policy_evaluation_prob import get_trans_p_and_rewards
from Policy_evaluation_deterministic import print_values, print_policy
import math

threshold = math.exp(-3)  # convergence threshold (e^-3, roughly 0.05)
gamma = 0.9               # discount factor
if __name__ == '__main__':
    grid = windy_grid()
    trans_prob, rewards = get_trans_p_and_rewards(grid)

    # initialize V(s) = 0 for every state
    V = {}
    for s in grid.all_states():
        V[s] = 0

    # value iteration: repeat the Bellman optimality update until V stops changing
    while True:
        biggest_change = 0
        for s in grid.actions.keys():
            old_v = V[s]
            new_v = float('-inf')
            if not grid.is_terminal(s):
                for a in ACTION_SPACE:
                    v = 0
                    for s2 in grid.all_states():
                        r = rewards.get((s, a, s2), 0)
                        v += trans_prob.get((s, a, s2), 0) * (r + gamma * V[s2])
                    if v > new_v:
                        new_v = v
                V[s] = new_v
                biggest_change = max(biggest_change, np.abs(old_v - V[s]))
        if biggest_change < threshold:
            break
    # extract the greedy policy with respect to the converged V
    policy = {}
    for s in grid.actions.keys():
        new_v = float('-inf')
        if not grid.is_terminal(s):
            for a in ACTION_SPACE:
                v = 0
                for s2 in grid.all_states():
                    r = rewards.get((s, a, s2), 0)
                    v += trans_prob.get((s, a, s2), 0) * (r + gamma * V[s2])
                if v > new_v:
                    new_a = a
                    new_v = v
            policy[s] = new_a

    print_policy(policy, grid)
    print_values(V, grid)
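
For reference, the inner loops above apply the Bellman optimality backup V(s) <- max_a sum_{s2} p(s2 | s, a) * (r(s, a, s2) + gamma * V(s2)). The snippet below is a minimal, self-contained sketch of a single such backup on a hypothetical two-state MDP; the states, actions, transition probabilities, and rewards are made up for illustration and are not part of the GridWorld helpers imported above, but they use the same (s, a, s2)-keyed dictionary format.

# Hypothetical two-state MDP, keyed the same way as trans_prob / rewards above.
gamma = 0.9
states = ['s0', 's1']
actions = ['left', 'right']
trans_prob = {
    ('s0', 'left', 's0'): 1.0,
    ('s0', 'right', 's1'): 1.0,
    ('s1', 'left', 's0'): 1.0,
    ('s1', 'right', 's1'): 1.0,
}
rewards = {('s0', 'right', 's1'): 1.0}  # reward only for moving s0 -> s1

V = {s: 0.0 for s in states}

def bellman_backup(s, V):
    # V(s) <- max over a of sum over s2 of p(s2|s,a) * (r(s,a,s2) + gamma * V(s2))
    best = float('-inf')
    for a in actions:
        v = 0.0
        for s2 in states:
            r = rewards.get((s, a, s2), 0)
            v += trans_prob.get((s, a, s2), 0) * (r + gamma * V[s2])
        best = max(best, v)
    return best

V['s0'] = bellman_backup('s0', V)
print(V['s0'])  # 1.0: taking 'right' from s0 earns reward 1 plus gamma * V(s1) = 0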