-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
45 lines (35 loc) · 1.79 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import random
import numpy as np
from debug_settings import *
from environment import TheoremProvingEnvironment
from theorem import CancellationLaw
from agent import Agent
def _print_evaluation(heading, episodes, total_epochs, total_penalties):
    """Print the per-episode averages of one evaluation run.

    Args:
        heading: Line printed above the averages (e.g. "BEFORE TRAINING:").
        episodes: First value returned by ``Agent.evaluate``; used as the
            divisor for both averages.
        total_epochs: Second value returned by ``Agent.evaluate``; reported
            as timesteps.
        total_penalties: Third value returned by ``Agent.evaluate``.

    Raises:
        ZeroDivisionError: If ``episodes`` is 0.
    """
    print(heading)
    print(f"\tAverage timesteps per episode: {total_epochs / episodes}")
    print(f"\tAverage penalties per episode: {total_penalties / episodes}")


# SET_RANDOM_SEED is not defined in this file; presumably it arrives through
# the wildcard import of debug_settings. When it is truthy, both the stdlib
# and NumPy generators are seeded with the same fixed value.
if SET_RANDOM_SEED:
    random.seed(5)
    np.random.seed(5)

# ----------------------------------------------------------------------------------
# Set up environment with a start state including variables, assumptions, and goal
# ----------------------------------------------------------------------------------
env = TheoremProvingEnvironment(thm=CancellationLaw)

# ----------------------------------------------------------------------------------
# Evaluate agent when q-table is empty (should be pretty bad)
# ----------------------------------------------------------------------------------
a = Agent(env)
episodes, total_epochs, total_penalties = a.evaluate(episodes=3)
_print_evaluation("BEFORE TRAINING:", episodes, total_epochs, total_penalties)

# ----------------------------------------------------------------------------------
# Train agent by filling out the q-table
# ----------------------------------------------------------------------------------
print("\nTRAINING...")
a.train(episodes=5)
print("Training finished.\n")
# Debugging aid: print(a.qtable) at this point to inspect the table after training.

# ----------------------------------------------------------------------------------
# Evaluate how well agent was trained by evaluating how well it performs with new qtable
# ----------------------------------------------------------------------------------
episodes, total_epochs, total_penalties = a.evaluate(episodes=3)
_print_evaluation("AFTER TRAINING:", episodes, total_epochs, total_penalties)