Skip to content

Commit

Permalink
Cleanup and reorganization, AI vs AI added
Browse files Browse the repository at this point in the history
  • Loading branch information
culk committed Mar 13, 2018
1 parent b4700ff commit 887367f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 28 deletions.
59 changes: 38 additions & 21 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,49 @@ class Config(object):
information parameters. Model objects are passed a Config() object at
instantiation.
"""
models = ['linear', 'dense', 'conv']

hidden_size = 10
# TODO: (feature add) scale the learning rate over time?
lr = 0.01

# Training settings
num_iters = 30
batch_size = 100
epochs = 50
lr = 0.01

num_layers = 5
num_filters = 16
kernel_size = 3
dropout = 0.1

# Neural network settings
models = ['linear', 'dense', 'conv']
model = models[2]
checkpoint_folder = 'checkpoints'
arena_games = 40
arena_threshold = 0.55
# Should be set in a way that encourages exploration in early moves and then
# selects optimal moves later in the game
temp_threshold = 6
num_iters = 30
num_layers = 10
hidden_size = 64 # linear and dense only
num_filters = 64 # conv only
kernel_size = 3 # conv only
dropout = 0.1 # conv only
regularizer = 0.0001 # resnet only
num_residual_blocks = 3 # resnet only

num_episodes = 100
# MCTS settings
num_episodes = 25
num_sims = 10
c_puct = 1
num_sims = 25
# Should be set based on game length to encourage exploration in early moves
temp_threshold = 6

# Unused:
#arena_games = 40
#arena_threshold = 0.55

# For AI vs AI play
class Config1(Config):
# Configuration for the first player in AI vs AI play. Each attribute set
# here overrides the same-named attribute inherited from Config.
# Config.models[2] is 'conv' in the models list.
model = Config.models[2]
# NOTE(review): presumably the folder the network wrapper loads checkpoints
# from -- confirm against model.py.
checkpoint_folder = 'checkpoints_old'

# NOTE(review): main.py pairs this config with 'checkpoint_64_10_29.pth.tar',
# which appears to encode num_filters=64 and num_layers=10 -- confirm.
num_layers = 10 # both
num_filters = 64 # conv only
hidden_size = 64 # linear and dense only

class Config2(Config):
# Configuration for the second player in AI vs AI play. Each attribute set
# here overrides the same-named attribute inherited from Config.
# Config.models[2] is 'conv' in the models list.
model = Config.models[2]
# Same folder name as Config1.
# NOTE(review): presumably the folder the network wrapper loads checkpoints
# from -- confirm against model.py.
checkpoint_folder = 'checkpoints_old'

# NOTE(review): main.py pairs this config with 'checkpoint_16_15_29.pth.tar',
# which appears to encode num_filters=16 and num_layers=15 -- confirm.
num_layers = 15 # both
num_filters = 16 # conv only
hidden_size = 64 # linear and dense only

# NOTE(review): indentation is lost in this view, so it is unclear whether
# the two settings below belong to Config2 or are trailing context from the
# diff. Their values match the 'resnet only' settings in Config.
regularizer = 0.0001
num_residual_blocks = 3
31 changes: 24 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,45 @@
from retrain import Coach, Individual
from selfplay import HumanPlay
from selfplay import Arena, HumanPlay
from mcts import MCTS
from zertz.ZertzGame import ZertzGame as Game
from model import NNetWrapper as NN
from config import Config
from config import Config, Config1, Config2

if __name__ == '__main__':
# Script entry point: builds a Zertz game and a network wrapper, then runs
# one of three modes. Only Option #1 (training) is active; Options #2 and #3
# are left commented out and must be enabled by hand.

# Game settings
rings = 19
marbles = {'w': 10, 'g': 10, 'b': 10}
# NOTE(review): presumably each dict is one alternative winning set of
# marbles -- confirm against ZertzGame.
win_con = [{'w': 2}, {'g': 2}, {'b': 2}, {'w': 1, 'g': 1, 'b': 1}]
# NOTE(review): the meaning of t is not visible here; see the ZertzGame
# constructor.
t = 3

# Setup
game = Game(rings, marbles, win_con, t)
config = Config()
nnet = NN(game, config)
# NOTE(review): trainer is constructed again with identical arguments under
# Option #1 below; this line looks like the removed side of the diff.
trainer = Individual(game, nnet, config)

# load model weights?

# learn
# Option #1: Learn
trainer = Individual(game, nnet, config)
trainer.learn()

# play against AI
# Option #2: Human vs AI
# (loads a saved checkpoint into nnet and plays a human against an MCTS agent)
#nnet.load_checkpoint(filename='checkpoint_64_10_29.pth.tar')
#ai_agent = MCTS(game, nnet, config.c_puct, config.num_sims)
#hp = HumanPlay(game, ai_agent)
#hp.play()

# Option #3: AI vs AI
# (two networks built from Config1 and Config2, each wrapped in its own MCTS
# agent, play 10 matches in an Arena and the win/draw counts are printed)
#config1 = Config1()
#config2 = Config2()
#nnet1 = NN(game, config1)
#nnet2 = NN(game, config2)

#nnet1.load_checkpoint(filename='checkpoint_64_10_29.pth.tar')
#nnet2.load_checkpoint(filename='checkpoint_16_15_29.pth.tar')

#ai_agent1 = MCTS(game, nnet1, config1.c_puct, config1.num_sims)
#ai_agent2 = MCTS(game, nnet2, config2.c_puct, config2.num_sims)

#arena = Arena(game, ai_agent1, ai_agent2)
#ai1_win, ai2_win, draw = arena.play_matches(10)
#print(ai1_win, ai2_win, draw)

0 comments on commit 887367f

Please sign in to comment.