-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
41 lines (37 loc) · 1.13 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from torch import save
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from tictactoe.real_ai.boards import *
from tictactoe.real_ai.module import *
# Number of iterations the training loop runs.
NUM_BATCHES = 8000

# The original author notes that the settings below control over-fitting and
# recorded the pairs (16, 8) and (32, 7) next to them -- presumably earlier
# (BATCH_SIZE, ROUND_LIMIT) choices; TODO confirm.

# Size handed to Boards(...) each time a fresh set of boards is created.
BATCH_SIZE = 16
# A fresh Boards instance is created every ROUND_LIMIT iterations.
ROUND_LIMIT = 6
# Threshold the zero-loss counter is compared against in the training loop,
# expressed as a fraction of NUM_BATCHES (currently 0.0, i.e. zero).
ZERO_LIMIT = 0.0 * NUM_BATCHES
def main():
    """Train a ``RealAI`` model against ``Boards`` suggestions and save it.

    Runs ``NUM_BATCHES`` iterations. Every ``ROUND_LIMIT`` iterations a fresh
    ``Boards(BATCH_SIZE)`` is created. Each iteration feeds
    ``boards.get_merged()`` to the model and scores the output against
    ``boards.get_suggestions()`` with cross-entropy. Batches whose loss is
    exactly zero are skipped once the zero-loss counter has reached
    ``ZERO_LIMIT``. The final ``state_dict`` is written to
    ``./model/23mxx.pth``.
    """
    # Local import keeps this block self-contained; the file's import header
    # does not bring in ``os``.
    import os

    model_path = "./model/23mxx.pth"
    zero_loss_batches = 0
    model = RealAI()
    loss_function = CrossEntropyLoss()
    optimizer = Adam(params=model.parameters(), lr=1e-3)
    boards = None

    for step in range(NUM_BATCHES):
        # step == 0 always satisfies this, so ``boards`` is set before use.
        if step % ROUND_LIMIT == 0:
            boards = Boards(BATCH_SIZE)

        # NOTE(review): presumably lets the opponent move before the model is
        # asked for its answer -- confirm against Boards.
        boards.opponent_go()
        output = model(boards.get_merged())
        suggestions = boards.get_suggestions()
        loss = loss_function(output, suggestions)
        loss_value = loss.item()

        if loss_value == 0:
            if zero_loss_batches >= ZERO_LIMIT:
                print("Reached limit.")
                # The iteration is consumed, not retried: the original's
                # ``epoch -= 1`` had no effect because ``for`` rebinds the
                # loop variable on every pass, so it has been dropped.
                continue
            zero_loss_batches += 1

        print(loss_value)
        # Kept from the original; a constant offset does not change the
        # gradients computed by backward().
        loss += .001
        boards.go(suggestions)

        # Clear gradients from the previous step. backward() accumulates into
        # ``.grad``, so without this every update would apply the running sum
        # of all earlier gradients instead of the current batch's gradient.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"{zero_loss_batches}/{NUM_BATCHES}")

    # Make sure the target directory exists so a finished training run is not
    # lost to a missing ``./model`` folder.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    save(model.state_dict(), model_path)


if __name__ == '__main__':
    main()