-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathselfplay.py
253 lines (220 loc) · 10.2 KB
/
selfplay.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
from copy import deepcopy
import random
import numpy as np
from mcts import MCTS
from config import Config
class SelfPlay(object):
    """Generate training examples by letting a network play against itself through MCTS."""

    # Fraction of the move-sampling distribution taken from Dirichlet noise (when enabled).
    _DIRICHLET_WEIGHT = 0.25
    # Once more than this many moves have been played, the game is stopped and scored 0 (draw).
    _MAX_EPISODE_STEPS = 200

    def __init__(self, game, nnet):
        """
        game: the game to play; it is deep-copied so self-play never touches the caller's board.
        nnet: network handed to MCTS together with Config.c_puct and Config.num_sims.
        """
        self.game = deepcopy(game)
        self.nnet = nnet
        self.mcts = MCTS(self.game, self.nnet, Config.c_puct, Config.num_sims)
        self.temp_threshold = Config.temp_threshold
        self.use_dirichlet = Config.use_dirichlet

    def generate_play_data(self):
        """
        Play one self-play game and return its augmented training examples.

        Returns a list of tuples (state, p_placement, p_capture, v, action_type):
          * action_type is 1 for a 'PUT' move and 0 otherwise; the policy of the other
            kind is an all-zero vector of the matching action-space size.
          * v = winner * player_value, i.e. the final result from the perspective of the
            player recorded for that state (0 if the move cap was hit without a winner).
        Every recorded position also yields an opponent-perspective copy (last layer
        flipped, value negated) plus each board symmetry from game.get_symmetries and
        its opponent copy.
        """
        history = []
        self.game.reset_board()
        episode_step = 0
        board_state, player_value = self.game.get_current_state()
        while True:
            episode_step += 1
            # Temperature decays linearly with the move number and is clamped at 0.
            temp = max(self.temp_threshold - episode_step, 0)
            action_type, actions, probs = self.mcts.get_action_prob(board_state, temp=temp)
            # Normalise into a fresh float array. In-place `/=` would also rewrite the array
            # MCTS returned, and it fails for list or integer-valued policies.
            probs = np.asarray(probs, dtype=float)
            probs = probs / np.sum(probs)
            history.append((board_state, action_type, probs, player_value))

            # Pick a move (possibly with exploration noise) and advance game and search tree.
            action = self._sample_action(board_state, action_type, actions, probs)
            board_state, player_value = self.game.get_next_state(action, action_type)
            self.mcts.move_root(action, player_value)

            winner = self.game.get_game_ended(board_state)
            if winner != 0 or episode_step > self._MAX_EPISODE_STEPS:
                return self._build_examples(history, winner)

    def _sample_action(self, board_state, action_type, actions, probs):
        """Sample one entry of `actions` from `probs`, optionally mixed with Dirichlet noise."""
        sample_p = probs
        if self.use_dirichlet:
            valid_placement, valid_capture = self.game.get_valid_actions(board_state)
            # Noise only covers legal moves of the kind MCTS chose this turn.
            valid = valid_placement if action_type == 'PUT' else valid_capture
            dir_alpha = 1.0 / np.sum(valid)
            noise = np.random.dirichlet(dir_alpha * np.ones(len(probs)))
            noise *= np.asarray(valid).flatten()
            noise /= np.sum(noise)
            sample_p = (1 - self._DIRICHLET_WEIGHT) * probs + self._DIRICHLET_WEIGHT * noise
        return actions[np.random.choice(len(actions), p=sample_p)]

    def _build_examples(self, history, winner):
        """Expand recorded (state, action_type, pi, player_value) entries into training tuples."""
        null_cap_pi = np.zeros(self.game.get_capture_action_size())
        null_put_pi = np.zeros(self.game.get_placement_action_size())
        examples = []
        for state, action_type, pi, player_value in history:
            v = winner * player_value
            # The original position first, then every symmetric variant with a translated policy.
            variants = [(state, pi)]
            for symmetry_type, sym_state in self.game.get_symmetries(state):
                sym_pi = self.game.translate_action_symmetry(
                    action_type, symmetry_type, pi).flatten()
                variants.append((sym_state, sym_pi))
            for variant_state, variant_pi in variants:
                if action_type == 'PUT':
                    p_placement, p_capture, type_flag = variant_pi, null_cap_pi, 1
                else:
                    p_placement, p_capture, type_flag = null_put_pi, variant_pi, 0
                examples.append((variant_state, p_placement, p_capture, v, type_flag))
                examples.append((self._flip_player(variant_state),
                                 p_placement, p_capture, -v, type_flag))
        return examples

    @staticmethod
    def _flip_player(state):
        """Return a copy of `state` whose last layer x is replaced by 1 - x.

        NOTE(review): the last layer is presumably the side-to-move indicator; confirm
        against the game's state encoding.
        """
        flipped = np.copy(state)
        flipped[-1] = (flipped[-1] - 1) * -1
        return flipped
class Arena(object):
    """Pit two MCTS agents against each other using greedy (temp=0) move selection."""

    def __init__(self, game, player_agent1, player_agent2):
        """
        player_agent1 and player_agent2 are two MCTS instances which have the newest and
        previous nnets' policy_fn.
        """
        self.player1 = player_agent1
        self.player2 = player_agent2
        self.game = game

    def match(self, logging=False):
        """
        Play one game in which self.player1 moves as player value 1 and self.player2 as -1.

        Returns the non-zero value of self.game.get_game_ended() that ends the game:
        1 if player1 won, -1 if player2 won (play_matches counts any other value as a draw).
        If `logging` is true, each move is printed as "<player_value>:\t <action>".
        """
        self.game.reset_board()
        self.player1.reset(1)
        self.player2.reset(-1)
        while self.game.get_game_ended() == 0:
            state, player_value = self.game.get_current_state()
            agent = self.player1 if player_value == 1 else self.player2
            action_type, actions, probs = agent.get_action_prob(state, temp=0)
            # Greedy choice: the single most probable action.
            action = actions[np.argmax(probs)]
            if logging:
                action_log = self.game.action_to_str(action_type, action)
                print("{}:\t {}".format(player_value, action_log))
            _, player_value = self.game.get_next_state(action, action_type)
            # Keep both search trees in sync; player2's tree receives the negated value.
            self.player1.move_root(action, player_value)
            self.player2.move_root(action, -player_value)
        return self.game.get_game_ended()

    def play_matches(self, num_games, logging=True):
        """
        Play `num_games` games and return (player1_win, player2_win, draw).

        num_games // 2 games are played with player_agent1 moving first; the remaining
        games (one more if num_games is odd) have player_agent2 moving first.
        player1/player2 in the result always refer to player_agent1/player_agent2, and
        the original seating is restored afterwards, even if a game raises.
        """
        first_half = num_games // 2
        second_half = num_games - first_half

        # player_agent1 moves first.
        player1_win, player2_win, draw = self._play_series(
            first_half, "\nFirst player: AI 1", logging)

        # Swap seats so player_agent2 moves first; its wins now show up as "first-seat" wins.
        self.player1, self.player2 = self.player2, self.player1
        try:
            p2_wins, p1_wins, draws = self._play_series(
                second_half, "\nFirst player: AI 2", logging)
        finally:
            self.player1, self.player2 = self.player2, self.player1

        return player1_win + p1_wins, player2_win + p2_wins, draw + draws

    def _play_series(self, num_games, header, logging):
        """Play `num_games` with the current seating; return (seat1_wins, seat2_wins, draws)."""
        seat1_wins, seat2_wins, draws = 0, 0, 0
        for _ in range(num_games):
            if logging:
                print(header)
            winner = self.match(logging=logging)
            if winner == 1:
                seat1_wins += 1
            elif winner == -1:
                seat2_wins += 1
            else:
                draws += 1
        return seat1_wins, seat2_wins, draws
class HumanPlay(object):
    """Interactive console game between a human and an MCTS-driven AI."""

    def __init__(self, game, ai_agent):
        """
        ai_agent is a MCTS with a trained neural network
        """
        self.game = game
        self.ai = ai_agent
        # Seat order: self.player[0] moves first (and wins when the game value is 1).
        self.player = ['Human', 'AI']
        # Index into self.player of the side to move.
        self.cur_player = 0
        # The AI's search tree is reset lazily on its first move, once its player value is known.
        self.first_ai_turn = True

    @staticmethod
    def _read_line(prompt):
        """Read one line from stdin: raw_input on Python 2, input on Python 3."""
        try:
            reader = raw_input  # noqa: F821 - Python 2 builtin
        except NameError:
            reader = input
        return reader(prompt)

    def match(self):
        """
        Game loop between ai player and human.

        Returns self.game.get_game_ended() once it is non-zero.
        """
        self.game.print_state()
        while self.game.get_game_ended() == 0:
            # Get current player's action
            if self.player[self.cur_player] == 'AI':
                action_type, action = self._ai_action()
            else:
                action_type, action = self._read_human_action()
            # Apply the action and keep the AI's search tree in sync
            board_state, player_value = self.game.get_next_state(action, action_type)
            self.ai.move_root(action, player_value)
            self.game.print_state()
            # The turn passes only when the capture layer is empty.
            # NOTE(review): presumably a non-empty layer means a capture chain must continue.
            if np.sum(board_state[self.game.board._CAPTURE_LAYER]) == 0:
                self.cur_player = (self.cur_player + 1) % 2
        return self.game.get_game_ended()

    def _ai_action(self):
        """Pick the AI's greedy (temp=0) action, print it, and return (action_type, action)."""
        state, player_value = self.game.get_current_state()
        if self.first_ai_turn:
            self.ai.reset(player_value)
            self.first_ai_turn = False
        action_type, actions, probs = self.ai.get_action_prob(state, temp=0)
        action = actions[np.argmax(probs)]
        action_log = self.game.action_to_str(action_type, action)
        print("{}:\t {}".format(self.player[self.cur_player], action_log))
        return action_type, action

    def _read_human_action(self):
        """Prompt until the human enters a legal action; return (action_type, action)."""
        prompt = ("Enter action [i.e. 'PUT w A1 B2' or 'CAP b C4 g C2']\n"
                  + "{}:\t ".format(self.player[self.cur_player]))
        while True:
            action_str = self._read_line(prompt)
            try:
                action_type, action = self.game.str_to_action(action_str)
            except Exception:
                # Unparseable input: reported below and the human is asked again.
                action_type, action = '', None
            if action is not None and self._is_valid(action_type, action):
                return action_type, action
            print("Invalid action: {}".format(action_str))

    def _is_valid(self, action_type, action):
        """Check `action` against get_valid_actions(): element 0 for 'PUT', element 1 for 'CAP'."""
        if action_type == 'PUT':
            mask_index = 0
        elif action_type == 'CAP':
            mask_index = 1
        else:
            return False
        try:
            return bool(self.game.get_valid_actions()[mask_index][action])
        except (IndexError, TypeError, ValueError):
            # Out-of-range or malformed index: treat as illegal instead of crashing.
            return False

    def play(self):
        """
        Ask who moves first, play one game, announce the result and return the game value.

        The seat order and turn pointer are reset on every call, so repeated games start cleanly.
        """
        # Reset the board and ai state
        self.game.reset_board()
        self.first_ai_turn = True
        self.cur_player = 0
        # Determine who plays first
        first = self._read_line("Who plays first? ['h' = human, 'a' = ai, 'r' = random]\n> ")
        if first == 'a':
            self.player = ['AI', 'Human']
        else:
            self.player = ['Human', 'AI']
            if first == 'r':
                random.shuffle(self.player)
        # Play the game
        game_value = self.match()
        # Print out the result (non +/-1 values are draws, as in Arena.play_matches)
        if game_value == 1:
            print("{} Player Wins!".format(self.player[0]))
        elif game_value == -1:
            print("{} Player Wins!".format(self.player[1]))
        else:
            print("Draw!")
        return game_value