-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
79 lines (64 loc) · 2.77 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import fire
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from chess_dataset import ChessDataset
from pgn_to_array_converter import download_pgn, PgnToArrayConverter
from train_chess_cnn_keras import define_model, define_callbacks
def get_data(download_url, rating=2700, overwrite=True):
    """Fetch a PGN archive and run it through the array converter.

    :param download_url: URL handed to ``download_pgn``.
    :param rating: Rating threshold handed to ``PgnToArrayConverter``.
    :param overwrite: Forwarded to ``download_pgn`` as its ``overwrite`` flag.
    """
    # Step 1: obtain the PGN file (download_pgn returns its location).
    pgn_path = download_pgn(download_url, overwrite=overwrite)

    # Step 2: build the converter with the rating filter, then convert.
    print(f"\nRating threshold set to: {rating}")
    converter = PgnToArrayConverter(pgn_path, rating=rating)
    print("\nConverting PGN to numpy arrays")
    converter.pgn_to_arrays()
def train_cnn(model_save_name=None, epochs=20, batch_size=256):
    """Train the chess CNN on the prepared dataset, save it, and return it.

    :param model_save_name: Path passed to ``model.save``. When ``None`` (the
        default) it falls back to ``"chess_model.h5"`` instead of handing
        ``None`` to ``model.save``, which cannot save to a missing path.
    :param epochs: Number of training epochs.
    :param batch_size: Batch size for both the training and the validation
        ``ChessDataset``.
    :return: The trained model built by ``define_model``.
    """
    if model_save_name is None:
        # Same default file name the CLI entry point uses for ``model_name``.
        model_save_name = "chess_model.h5"
    print(f"Beginning training of Chess CNN of selected games for {epochs} epochs")
    # NOTE(review): presumably an 8x8 board with 8 feature planes and an
    # 8-value output; confirm against define_model / ChessDataset.
    input_shape = (8, 8, 8)
    model = define_model(input_shape, (8,))
    model.fit(
        ChessDataset(batch_size=batch_size),
        epochs=epochs,
        validation_data=ChessDataset(batch_size=batch_size, validation=True),
        # NOTE(review): `workers` is accepted by tf.keras 2.x fit() but was
        # removed from fit() in Keras 3; revisit if TensorFlow is upgraded.
        workers=4,
        shuffle=True,
        callbacks=define_callbacks(),
    )
    model.save(model_save_name, overwrite=True)
    return model
# TensorFlow currently hoards GPU memory after training
# print("Clearing GPU Memory")
# tf.config.set_visible_devices([], 'GPU')
# tf.keras.backend.clear_session()
# def play_cnn(model_path):
# print("\nRunning Basic Chess GUI")
# Path("model_name.txt").open("w").write(model_path)
# cmd = "python kivy_gui.py"
# time.sleep(5)
# sp.run(cmd.split())
def download_train_play_chess_cnn(
    download_url: str = "https://database.nikonoel.fr/lichess_elite_2021-01.zip",
    rating: int = 2700,
    model_name: str = "chess_model.h5",
    epochs: int = 20,
    overwrite: bool = True,
    skip_data: bool = False,
    batch_size: int = 256,
):
    """
    Run the pipeline: download a PGN zip from the Lichess Elite Database
    (https://database.nikonoel.fr), convert it to arrays, and train a CNN
    using Keras (https://keras.io).

    Playing the engine with Kivy (https://kivy.org) is currently disabled:
    TensorFlow does not release GPU memory after training, so run
    "python kivy_gui.py" separately once this command has finished.

    :param download_url: URL of the PGN zip, e.g.
        https://database.nikonoel.fr/lichess_elite_2021-01.zip. For other
        sources, modify the download_pgn function yourself.
    :param rating: Rating threshold for games to be trained on
    :param model_name: Path the trained model is saved to
    :param epochs: Number of epochs to train the Chess CNN for
    :param overwrite: Overwrite existing files when downloading
    :param skip_data: Skip downloading and converting the data; make sure the
        converted data already exists
    :param batch_size: Batch size used for training and validation
    """
    if not skip_data:
        get_data(download_url, rating, overwrite)
    train_cnn(model_name, epochs, batch_size)
    # Playing is disabled because of the TensorFlow GPU-memory issue noted above.
    # play_cnn(model_name)
if __name__ == "__main__":
    # Fire turns the pipeline's keyword arguments into command-line flags.
    fire.Fire(component=download_train_play_chess_cnn)