Skip to content

Commit 778d5a2

Browse files
author
liychen
committed
Merge branch 'master' of github.com:lchenat/qmc
2 parents d08d0e1 + 75db514 commit 778d5a2

File tree

3 files changed

+61
-60
lines changed

3 files changed

+61
-60
lines changed

gen_exps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def post_variant(variant):
5353
generate_args('exps/search_learn_{}'.format(kwargs['--env']), args, kwargs, variants, post_variant=post_variant, shuffle=shuffle)
5454

5555
@cmd()
56-
def search_network(touch: int=1, shuffle: int=0):
56+
def search_network_std(touch: int=1, shuffle: int=0):
5757
variants = {
5858
'--n_trajs': [60, 100, 150, 200, 300],
5959
'-lr': [0.0001, 0.0005, 0.001],
@@ -73,7 +73,7 @@ def search_network(touch: int=1, shuffle: int=0):
7373
def post_variant(variant):
7474
variant['--save_fn'] = 'data/search_network/{}-{}-{}-{}'.format(*[variant[k] for k in ['--n_trajs', '-lr', '-H', '--init_scale']])
7575
return variant
76-
generate_args('exps/search_network', args, kwargs, variants, post_variant=post_variant, shuffle=shuffle)
76+
generate_args('exps/search_network_std', args, kwargs, variants, post_variant=post_variant, shuffle=shuffle)
7777

7878

7979
if __name__ == "__main__":

models.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,28 +33,28 @@ def __init__(
3333
state_dim,
3434
action_dim,
3535
mean_network,
36-
learn_std=False,
36+
learn_std=True,
3737
):
3838
super().__init__()
3939
self.mean = mean_network
4040
self.std = torch.zeros(action_dim)
4141
if learn_std: self.std = nn.Parameter(self.std)
42+
self.learn_std = learn_std
4243
self.to(Config.DEVICE)
4344

4445
def distribution(self, obs):
4546
obs = tensor(obs)
46-
mean = self.mean(obs)
47-
dist = torch.distributions.Normal(mean, tensor(torch.ones_like(self.std)))
48-
#mean = torch.tanh(self.mean(obs))
49-
#dist = torch.distributions.Normal(mean, F.softplus(self.std))
50-
#log_prob = dist.log_prob(action).sum(-1).unsqueeze(-1)
47+
if self.learn_std:
48+
dist = torch.distributions.Normal(torch.tanh(self.mean(obs)), F.softplus(self.std))
49+
else:
50+
dist = torch.distributions.Normal(self.mean(obs), self.std)
5151
return dist
5252

5353
def forward(self, obs, noise):
5454
#:: there is an issue with gpu of multiprocessing, unless you want to have one GPU each process, it is not worth it.
5555
obs = tensor(obs)
56-
#mean = torch.tanh(self.mean(obs)) # bounded action!!!
57-
mean = self.mean(obs)
58-
#action = mean + tensor(noise) * F.softplus(self.std)
59-
action = mean + tensor(noise)
56+
if self.learn_std:
57+
action = torch.tanh(self.mean(obs)) + tensor(noise) * F.softplus(self.std)
58+
else:
59+
action = self.mean(obs) + tensor(noise)
6060
return action.cpu().detach().numpy()

0 commit comments

Comments
 (0)