SAC_BC_N_AntMaze_MD.py
# Imports
import gym
import random
import numpy as np
import copy
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from Algorithms import SAC_BC_N_MSE
import d4rl
# Load environment
env = gym.make('antmaze-medium-diverse-v0')
dataset = d4rl.qlearning_dataset(env)
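# qlearning_dataset returns numpy transition arrays: observations, actions, next_observations, rewards, terminals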
# Set seed
seed = 19636
offset = 100
env.seed(seed)
env.action_space.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
print("Converting data...")
'''
In v0 of the antmaze datasets, the timeout flags are not synced with the actual trajectory ends,
see https://github.com/Farama-Foundation/D4RL/issues/77
These end-of-trajectory transitions can be identified by an (x, y) state/next-state L2 distance > 0.5
and are removed prior to training.
'''
states = dataset["observations"]
next_states = dataset["next_observations"]
distance = np.linalg.norm(states[:, :2] - next_states[:, :2], axis=-1)
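# Normalisation statistics computed over the filtered transitions (1e-3 added to std to avoid division by zero)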
mean = np.mean(dataset["observations"][distance <= 0.5], 0)
std = np.std(dataset["observations"][distance <= 0.5], 0) + 1e-3
states = torch.Tensor((dataset["observations"][distance <= 0.5] - mean) / std)
actions = torch.Tensor(dataset["actions"][distance <= 0.5])
rewards = dataset["rewards"][distance <= 0.5]
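# Rescale the sparse {0, 1} rewards to {-2, +2}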
rewards = 4 * (rewards - 0.5)
rewards = torch.Tensor(rewards)
next_states = torch.Tensor((dataset["next_observations"][distance <= 0.5] - mean) / std)
dones = torch.Tensor(dataset["terminals"][distance <= 0.5])
replay_buffer = [states, actions, rewards, next_states, dones]
print("...data conversion complete")
# Hyperparameters and initialisation
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
min_action = env.action_space.low[0]
max_action = env.action_space.high[0]
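# min_ent = -action_dim is presumably used as the SAC target entropy (the standard -|A| heuristic)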
min_ent = -action_dim
dep_targets = False
higher_bc_period = 50000
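# BC coefficient is 10 * beta for the first 50k gradient steps, then beta (see the training loop below)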
beta = 0.02
num_critics = 10
device = "cuda:0"
agent = SAC_BC_N_MSE.Agent(state_dim, action_dim, min_action, max_action, min_ent, num_critics, device=device)
# Training SAC-BC-N #
'''
Reset the goal each evaluation episode, as per https://github.com/Farama-Foundation/D4RL/pull/128
and lines 189-198 of https://github.com/Farama-Foundation/D4RL/pull/128/commits/724c37483a3ff9d8106107344742566eda4a11d6
'''
epochs = 100
iterations = 10000
grad_steps = 0
evals = 100
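# Each epoch: 'iterations' offline gradient steps followed by 'evals' evaluation episodes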
for epoch in range(epochs):
    if grad_steps < higher_bc_period:
        agent.train_offline(replay_buffer, iterations, 10 * beta, dep_targets)
    else:
        agent.train_offline(replay_buffer, iterations, beta, dep_targets)
    grad_steps += iterations
    # Evaluation (mean) #
    env.seed(seed + offset)
    scores_norm_mean = []
    for eval in range(evals):
        done = False
        state = env.reset()
        score = 0
        goal = env.goal_sampler(np.random)
        env.set_target_goal(goal)
        while not done:
            state = (state - mean) / std
            action = agent.choose_action(state)
            state, reward, done, info = env.step(action)
            score += reward
        score_norm = 100 * env.get_normalized_score(score)
        scores_norm_mean.append(score_norm)
    print("Epoch", epoch, "Grad steps", grad_steps, "Score Norm (Mean) %.2f" % np.mean(scores_norm_mean))