Commit ce8f79f
DDPG
1 parent 80b2a51 commit ce8f79f

File tree
4 files changed, +316 -0 lines changed

Char05 DDPG/DDPG.py
Lines changed: 253 additions & 0 deletions

import os
import numpy as np
import copy
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tensorboardX import SummaryWriter
from buffer import ReplayBuffer

'''
Deep Deterministic Policy Gradients (DDPG)
Original paper:
CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING https://arxiv.org/abs/1509.02971
'''

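# DDPG in brief: a deterministic actor mu(s) and a critic Q(s, a) are trained together.
# The critic regresses onto the bootstrapped target
#     y = r + discount * (1 - done) * Q_target(s', mu_target(s')),
# the actor is updated to maximize Q(s, mu(s)), and both target networks track the
# online networks through Polyak averaging with coefficient tau (see update() below).
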
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# parameters
env_name = "Pendulum-v1"
tau = 0.01                # soft (Polyak) update coefficient for the target networks
epsilon = 0.8             # initial probability of adding exploration noise
epsilon_decay = 0.9999    # per-step decay of epsilon (floored at 0.10 in the training loop)
actor_lr = 3e-4
critic_lr = 3e-4
discount = 0.99
buffer_size = 10000
batch_size = 128
max_episode = 40000
max_step_size = 500
seed = 1

render = True
load = False

env = gym.make(env_name)

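# Map an action in [-1, 1] (the actor's tanh output) to the environment's
# action bounds, then clip to stay inside the valid range.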
def envAction(action):

    low = env.action_space.low
    high = env.action_space.high
    action = low + (action + 1.0) * 0.5 * (high - low)
    action = np.clip(action, low, high)

    return action

# Set seeds
env.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

class Actor(nn.Module):

    def __init__(self, state_dim, action_dim, init_w=3e-3):
        super(Actor, self).__init__()

        self.l1 = nn.Linear(state_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, action_dim)

        # initialize the output layer with small weights, as in the DDPG paper
        self.l3.weight.data.uniform_(-init_w, init_w)
        self.l3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        a = F.relu(self.l1(state))
        a = F.relu(self.l2(a))
        a = torch.tanh(self.l3(a))

        return a

class Critic(nn.Module):

    def __init__(self, state_dim, action_dim, init_w=3e-3):
        super(Critic, self).__init__()

        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)
        self.l3.weight.data.uniform_(-init_w, init_w)
        self.l3.bias.data.uniform_(-init_w, init_w)

    def forward(self, state, action):
        sa = torch.cat((state, action), 1)
        q = F.relu(self.l1(sa))
        q = F.relu(self.l2(q))
        q = self.l3(q)

        return q

class DDPG:

    def __init__(self):
        super(DDPG, self).__init__()

        self.actor = Actor(state_dim, action_dim).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.buffer = ReplayBuffer(buffer_size, batch_size)

        self.num_training = 1

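        # TensorBoard logs are written to ./log; view them with `tensorboard --logdir ./log`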
        self.writer = SummaryWriter('./log')

        os.makedirs('./model/', exist_ok=True)

    def act(self, state):
        # greedy action from the current policy, returned as a flat numpy array
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor(state).cpu().data.numpy().flatten()

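    # put() stores one transition in the replay buffer and returns the critic's current
    # Q estimate for that (state, action) pair (the return value is unused in the loop below).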
    def put(self, *transition):
        state, action, reward, next_state, done = transition
        state = torch.FloatTensor(state).to(device).unsqueeze(0)
        action = torch.FloatTensor(action).to(device).unsqueeze(0)
        Q = self.critic(state, action).detach()
        self.buffer.add(transition)

        return Q.cpu().item()

    def update(self):

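        # One DDPG update: (1) fit the critic to the TD target built from the target
        # networks, (2) update the actor to maximize Q(s, mu(s)), (3) soft-update both
        # target networks toward the online networks.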
        if not self.buffer.sample_available():
            return

        state, action, reward, next_state, done = self.buffer.sample()

        # state = (state - self.buffer.state_mean())/(self.buffer.state_std() + 1e-7)
        # next_state = (next_state - self.buffer.state_mean())/(self.buffer.state_std() + 1e-6)
        # reward = reward / (self.buffer.reward_std() + 1e-6)

        state = torch.tensor(state, dtype=torch.float).to(device)
        action = torch.tensor(action, dtype=torch.float).to(device)
        reward = torch.tensor(reward, dtype=torch.float).view(batch_size, -1).to(device)
        next_state = torch.tensor(next_state, dtype=torch.float).to(device)
        done = torch.tensor(done, dtype=torch.float).view(batch_size, -1).to(device)

        with torch.no_grad():
            next_action = self.actor_target(next_state)
            target_Q = self.critic_target(next_state, next_action)
            target_Q = reward + (1 - done) * discount * target_Q

        # Get current Q estimates
        current_Q = self.critic(state, action)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q, target_Q)
        self.writer.add_scalar('Loss/critic_loss', critic_loss, global_step=self.num_training)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Compute actor loss
        actor_loss = -self.critic(state, self.actor(state)).mean()
        self.writer.add_scalar('Loss/actor_loss', actor_loss, global_step=self.num_training)

        # Optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Update the frozen target models (Polyak averaging)
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

        self.num_training += 1

    def save(self):
        torch.save(self.actor.state_dict(), './model/actor.pth')
        torch.save(self.critic.state_dict(), './model/critic.pth')
        print("====================================")
        print("Model has been saved...")
        print("====================================")

    def load(self):
        self.actor.load_state_dict(torch.load('./model/actor.pth'))
        self.critic.load_state_dict(torch.load('./model/critic.pth'))
        print("====================================")
        print("Model has been loaded...")
        print("====================================")

if __name__ == '__main__':
    agent = DDPG()
    state = env.reset()

    if load:
        agent.load()
    if render:
        env.render()

    print("====================================")
    print("Collecting Experience...")
    print("====================================")

    total_step = 0

    for episode in range(max_episode):

        total_reward = 0
        state = env.reset()

        for step in range(max_step_size):

            total_step += 1

            action = agent.act(state)

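            # Exploration: with probability epsilon, perturb the deterministic action
            # with Gaussian noise (std 0.2) and clip back to [-1, 1]; epsilon decays
            # each step toward a floor of 0.10.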
            if epsilon > np.random.random():
                action = (action + np.random.normal(0, 0.2, size=action_dim)).clip(-1, 1)

            next_state, reward, done, _ = env.step(envAction(action))

            # reward trick of BipedalWalker-v3
            # if reward == -100:
            #     reward = -1

            if render:
                env.render()

            agent.put(state, action, reward, next_state, done)

            agent.update()

            total_reward += reward

            state = next_state

            epsilon = max(epsilon_decay * epsilon, 0.10)
            agent.writer.add_scalar('Other/epsilon', epsilon, global_step=total_step)

            if done:
                break

        if episode % 10 == 0:
            agent.save()

        agent.writer.add_scalar('Other/total_reward', total_reward, global_step=episode)
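
DDPG.py imports ReplayBuffer from buffer.py, which is among the other files changed in
this commit but is not shown in this section. A minimal sketch of the interface DDPG.py
relies on (a constructor taking buffer_size and batch_size, plus add, sample_available,
and sample) could look like the following; the normalization helpers referenced in the
commented-out lines (state_mean, state_std, reward_std) are omitted. This is an
illustrative sketch, not the buffer.py from the commit.

    import random
    from collections import deque

    import numpy as np


    class ReplayBuffer:
        def __init__(self, buffer_size, batch_size):
            self.buffer = deque(maxlen=buffer_size)  # oldest transitions are dropped once full
            self.batch_size = batch_size

        def add(self, transition):
            # transition is a (state, action, reward, next_state, done) tuple
            self.buffer.append(transition)

        def sample_available(self):
            # only allow sampling once a full batch can be drawn
            return len(self.buffer) >= self.batch_size

        def sample(self):
            batch = random.sample(self.buffer, self.batch_size)
            state, action, reward, next_state, done = map(np.array, zip(*batch))
            return state, action, reward, next_state, done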
