Skip to content

Commit 007eb0d

Browse files
committed
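fix bug: SAC/TD3/DDPG — initialize final-layer weights with uniform_(-init_w, init_w) instead of uniform_(init_w, init_w); reshape the SAC_v1 reward batch to (batch_size, 1)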
1 parent ce8f79f commit 007eb0d

File tree

4 files changed: 5 additions (+5) and 5 deletions (-5).

Char05 DDPG/DDPG.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(self, state_dim, action_dim, init_w=3e-3):
         self.l2 = nn.Linear(256, 256)
         self.l3 = nn.Linear(256, action_dim)

-        self.l3.weight.data.uniform_(init_w, init_w)
+        self.l3.weight.data.uniform_(-init_w, init_w)
         self.l3.bias.data.uniform_(-init_w, init_w)

     def forward(self, state):

Char09 SAC/SAC_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self, state_dim, action_dim, init_w=3e-3):
         self.l2 = nn.Linear(256, 256)

         self.mu_head = nn.Linear(256, action_dim)
-        self.mu_head.weight.data.uniform_(init_w, init_w)
+        self.mu_head.weight.data.uniform_(-init_w, init_w)
         self.mu_head.bias.data.uniform_(-init_w, init_w)

         self.log_std_head = nn.Linear(256, action_dim)
@@ -210,7 +210,7 @@ def update(self):

         state = torch.tensor(state, dtype=torch.float).to(device)
         action = torch.tensor(action, dtype=torch.float).to(device)
-        reward = torch.tensor(reward, dtype=torch.float).view(batch_size, -1).to(device)
+        reward = torch.tensor(reward, dtype=torch.float).view(batch_size, 1).to(device)
         next_state = torch.tensor(next_state, dtype=torch.float).to(device)
         done = torch.tensor(done, dtype=torch.float).to(device).view(batch_size, -1).to(device)


Char09 SAC/SAC_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self, state_dim, action_dim, init_w=3e-3):
         self.l2 = nn.Linear(256, 256)

         self.mu_head = nn.Linear(256, action_dim)
-        self.mu_head.weight.data.uniform_(init_w, init_w)
+        self.mu_head.weight.data.uniform_(-init_w, init_w)
         self.mu_head.bias.data.uniform_(-init_w, init_w)

         self.log_std_head = nn.Linear(256, action_dim)

Char10 TD3/TD3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self, state_dim, action_dim, init_w=3e-3):
         self.l2 = nn.Linear(256, 256)
         self.l3 = nn.Linear(256, action_dim)

-        self.l3.weight.data.uniform_(init_w, init_w)
+        self.l3.weight.data.uniform_(-init_w, init_w)
         self.l3.bias.data.uniform_(-init_w, init_w)

     def forward(self, state):

0 commit comments

Comments
 (0)