Skip to content

Commit 3c82ace

Browse files
committed
init
0 parents  commit 3c82ace

38 files changed

+8288
-0
lines changed

.gitattributes

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Auto detect text files and perform LF normalization
2+
* text=auto

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
*.pyc

README.md

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Code for the Bayesian Behaviors framework
2+
3+
This repository contains the source code for generating the simulation results of the paper "Synergizing habits and goals with variational Bayes", published in *Nature Communications* (link to be updated).
4+
5+
## Installation
6+
7+
Tested using Python 3.7.7 on Ubuntu 20.04 and Windows 11
8+
9+
### Install Requirements (typically takes a few minutes)
10+
11+
```bash
12+
pip install -r requirements.txt
13+
```
14+
15+
You also need to install PyTorch. Please install a PyTorch version >= 1.11 that matches your CUDA version, following the instructions at <https://pytorch.org/>.
16+
17+
## How to run training and inference (Python, PyTorch)
18+
19+
### Habitization Experiment (Results for Figures 2, 3, 4)
20+
21+
```bash
22+
python run_habitization_experiment.py --seed 42 --verbose 1 --gui 0
23+
```
24+
25+
Set `--gui 1` if you want to see the visualized environment.
26+
27+
The default arguments (hyperparameters) are the same as those used in the paper. For details on the arguments used to train the habitual behavior, see `run_habitization_experiment.py`.
28+
29+
To run the models with different training steps in stage 2 (Figure 3), use the `--stage_3_start_step` argument.
30+
31+
### Flexible Goal-Directed Planning Experiment (Results for Figure 5)
32+
33+
```bash
34+
python run_planning_experiment.py --seed 42 --verbose 1 --gui 0
35+
```
36+
37+
### Data format
38+
39+
Either program takes less than one day on a decent GPU. The result data will be saved in `./data/` and `./details/` (and in `./planning/` for the planning experiment) as .mat files, which you can load using MATLAB or scipy:
40+
41+
```python
42+
import scipy.io as sio
43+
data = sio.loadmat("xxx.mat")
44+
```
45+
46+
The PyTorch model of the trained agent will also be saved at `./data/`, which can be loaded by `torch.load()`.
47+
48+
49+
50+
## Tutorial on plotting the quantitative results in the article (MATLAB)
51+
52+
To replicate the plots, please ensure you have MATLAB version R2022b or later, and download the simulated result data from TODO.
53+
(You may also train your own models using the guideline above).
54+
55+
To start, change the MATLAB working directory to `./data_analysis`.
56+
57+
### Figure 2b
58+
59+
```matlab
60+
plot_adaptation_readaptation_progress("DATAPATH/BB_habit_automaticity/search_mpz_0.1_s3s_420000/details/")
61+
```
62+
63+
Please modify DATAPATH to the data folder you downloaded.
64+
65+
### Figure 2c-h
66+
67+
```matlab
68+
fig2_habitization_analysis("DATAPATH/BB_habit_automaticity/search_mpz_0.1_s3s_420000/data/")
69+
```
70+
71+
### Figure 3
72+
73+
```matlab
74+
fig3_extinction_analysis("DATAPATH/BB_habit_automaticity/")
75+
```
76+
77+
### Figure 4
78+
79+
```matlab
80+
fig4_devaluation_analysis("DATAPATH/BB_habitization/")
81+
```
82+
83+
### Figure 5b
84+
85+
```matlab
86+
plot_adaptation_progress("DATAPATH/BB_planning/search_mpz_0.1/details/")
87+
```
88+
89+
### Figure 5c
90+
91+
```matlab
92+
plot_diversity_statistics("DATAPATH/BB_planning/search_mpz_0.1/details/")
93+
```
94+
95+
### Figure 5d,e
96+
97+
```matlab
98+
plot_planning_details("DATAPATH/BB_planning/search_mpz_0.1/planning/")
99+
```
100+
101+
## Citation
102+
103+
To be updated

base_modules.py

+196
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
class UnsqueezeModule(nn.Module):
    """Module wrapper around ``torch.unsqueeze``.

    Wrapping the call in an ``nn.Module`` lets it be placed inside module
    containers such as ``nn.Sequential``.

    Args:
        dim: Index at which the size-one dimension is inserted.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` with a singleton dimension inserted at ``self.dim``."""
        expanded = torch.unsqueeze(x, self.dim)
        return expanded
12+
13+
14+
def make_dcnn(feature_size, out_channels):
    """Build the transposed-convolution decoder.

    The stack consists of five ``ConvTranspose2d`` + ``ReLU`` stages followed
    by a final 3x3 ``Conv2d`` without activation. For an input with a 1 x 1
    spatial extent the spatial size grows as
    1x1 -> 1x4 -> 2x8 -> 4x16 -> 8x32 -> 16x64, and the final convolution
    keeps 16 x 64.

    Args:
        feature_size: Number of channels of the input feature map.
        out_channels: Number of channels produced by the last convolution.

    Returns:
        An ``nn.Sequential`` holding the decoder layers.
    """
    stages = [
        nn.ConvTranspose2d(feature_size, 64, kernel_size=(1, 4), stride=1, padding=0),
        nn.ReLU(),
        nn.ConvTranspose2d(64, 16, kernel_size=(2, 4), stride=(1, 2), padding=(0, 1)),
        nn.ReLU(),
        nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        # output_padding=1 makes this stage exactly double the spatial size.
        nn.ConvTranspose2d(8, 8, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.ReLU(),
        # Size-preserving projection to the requested channel count.
        nn.Conv2d(8, out_channels=out_channels, kernel_size=3, padding=1),
    ]
    return nn.Sequential(*stages)
36+
37+
38+
def make_cnn(n_channels):
    """Build the convolutional encoder and report its feature width.

    Five ``Conv2d`` + ``ReLU`` stages are followed by ``nn.Flatten``. With a
    16 x 64 input the spatial size shrinks as
    16x64 -> 8x32 -> 4x16 -> 2x8 -> 1x4 -> 1x1, so the flattened output has
    256 features, which is the value returned as ``phi_size``. ``phi_size``
    is a fixed constant and is not recomputed for other input sizes.

    Args:
        n_channels: Number of channels of the input image.

    Returns:
        A tuple ``(cnn, phi_size)`` with the ``nn.Sequential`` encoder and the
        integer 256.
    """
    encoder = nn.Sequential(
        nn.Conv2d(n_channels, 8, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 16, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 64, kernel_size=(2, 4), stride=2, padding=(0, 1)),
        nn.ReLU(),
        nn.Conv2d(64, 256, kernel_size=(1, 4), stride=(1, 4), padding=0),
        nn.ReLU(),
        nn.Flatten(),
    )
    phi_size = 256

    return encoder, phi_size
56+
57+
58+
def make_mlp(input_size, hidden_layers, output_size, act_fn, last_layer_linear=False):
    """Assemble a fully connected network as an ``nn.Sequential``.

    Every hidden ``nn.Linear`` is followed by a fresh ``act_fn()`` instance.
    The output ``nn.Linear`` is followed by ``act_fn()`` as well unless
    ``last_layer_linear`` is true.

    Args:
        input_size: Width of the input features.
        hidden_layers: Iterable with the width of each hidden layer.
        output_size: Width of the output layer.
        act_fn: Zero-argument callable returning an activation module
            (e.g. ``nn.ReLU``).
        last_layer_linear: If true, no activation is appended after the
            output layer.

    Returns:
        The assembled ``nn.Sequential``.
    """
    layers = []
    fan_in = input_size
    for width in hidden_layers:
        layers.extend([nn.Linear(fan_in, width, bias=True), act_fn()])
        fan_in = width

    layers.append(nn.Linear(fan_in, output_size, bias=True))
    if last_layer_linear:
        return nn.Sequential(*layers)

    layers.append(act_fn())
    return nn.Sequential(*layers)
70+
71+
72+
class ContinuousActionQNetwork(nn.Module):
    """Fully connected network mapping an (input, action) pair to one scalar.

    The input and the action are concatenated along the last dimension and
    passed through ``Linear`` + activation hidden layers and a final linear
    layer of width 1.

    Args:
        input_size: Width of the non-action input.
        action_size: Width of the action input.
        hidden_layers: Widths of the hidden layers; ``[256, 256]`` when None.
        act_fn: Zero-argument callable returning an activation module.
    """

    def __init__(self, input_size, action_size, hidden_layers=None, act_fn=nn.ReLU):
        super().__init__()

        self.input_size = input_size
        self.action_size = action_size
        self.output_size = 1
        self.hidden_layers = [256, 256] if hidden_layers is None else hidden_layers

        # The same layer objects are registered under both `network_modules`
        # and `main_network`.
        self.network_modules = nn.ModuleList()
        fan_in = input_size + action_size
        for width in self.hidden_layers:
            self.network_modules.extend([nn.Linear(fan_in, width), act_fn()])
            fan_in = width
        self.network_modules.append(nn.Linear(fan_in, self.output_size))

        self.main_network = nn.Sequential(*self.network_modules)

    def forward(self, x, a):
        """Return the network output for input ``x`` and action ``a``.

        ``x`` and ``a`` are concatenated along the last dimension, so all
        their leading dimensions must agree.
        """
        joint_input = torch.cat((x, a), dim=-1)
        return self.main_network(joint_input)
100+
101+
102+
class ContinuousActionVNetwork(nn.Module):
    """Fully connected network mapping an input vector to one scalar.

    The input is passed through ``Linear`` + activation hidden layers and a
    final linear layer of width 1.

    Args:
        input_size: Width of the input.
        hidden_layers: Widths of the hidden layers; ``[256, 256]`` when None.
        act_fn: Zero-argument callable returning an activation module.
    """

    def __init__(self, input_size, hidden_layers=None, act_fn=nn.ReLU):
        super().__init__()

        self.input_size = input_size
        self.output_size = 1
        self.hidden_layers = [256, 256] if hidden_layers is None else hidden_layers

        # The same layer objects are registered under both `network_modules`
        # and `main_network`.
        self.network_modules = nn.ModuleList()
        fan_in = input_size
        for width in self.hidden_layers:
            self.network_modules.extend([nn.Linear(fan_in, width), act_fn()])
            fan_in = width
        self.network_modules.append(nn.Linear(fan_in, self.output_size))

        self.main_network = nn.Sequential(*self.network_modules)

    def forward(self, x):
        """Return the network output for input ``x``."""
        return self.main_network(x)
129+
130+
131+
class ContinuousActionPolicyNetwork(nn.Module):
    """Policy network with separate MLPs for the Gaussian mean and log-std.

    Two independent fully connected stacks (``mu_net`` and ``logsig_net``)
    with identical layer widths map the input to the mean and the log
    standard deviation. The log standard deviation is clamped to
    ``logsig_clip``.

    Args:
        input_size: Width of the input.
        output_size: Width of the action vector.
        output_distribution: Name of the output distribution. Only
            ``"Gaussian"`` is implemented by ``forward``; any other value
            makes ``forward`` raise ``NotImplementedError``.
        hidden_layers: Widths of the hidden layers; ``[256, 256]`` when None.
        act_fn: Zero-argument callable returning an activation module.
        logsig_clip: ``[low, high]`` clamp bounds for the log standard
            deviation; ``[-20, 2]`` when None.
    """

    def __init__(self, input_size, output_size, output_distribution="Gaussian", hidden_layers=None, act_fn=nn.ReLU,
                 logsig_clip=None):
        super().__init__()

        if logsig_clip is None:
            logsig_clip = [-20, 2]
        if hidden_layers is None:
            hidden_layers = [256, 256]
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_layers = hidden_layers
        self.logsig_clip = logsig_clip

        self.output_distribution = output_distribution

        self.mu_layers = nn.ModuleList()
        self.logsig_layers = nn.ModuleList()

        # Layers of the two stacks are created alternately so the order of
        # parameter initialisation (and thus seeded reproducibility) is
        # unchanged.
        last_layer_size = input_size
        for layer_size in hidden_layers:
            self.mu_layers.append(nn.Linear(last_layer_size, layer_size))
            self.mu_layers.append(act_fn())
            self.logsig_layers.append(nn.Linear(last_layer_size, layer_size))
            self.logsig_layers.append(act_fn())
            last_layer_size = layer_size
        self.mu_layers.append(nn.Linear(last_layer_size, self.output_size))
        self.logsig_layers.append(nn.Linear(last_layer_size, self.output_size))

        self.mu_net = nn.Sequential(*self.mu_layers)
        self.logsig_net = nn.Sequential(*self.logsig_layers)

    def _clamped_logsig(self, x):
        """Return the log standard deviation for ``x``, clamped to ``logsig_clip``."""
        return self.logsig_net(x).clamp(self.logsig_clip[0], self.logsig_clip[1])

    def _gaussian_parameters(self, x):
        """Return ``(mu, logsig)`` of the Gaussian for input ``x``."""
        return self.mu_net(x), self._clamped_logsig(x)

    def forward(self, x):
        """Return the mean and clamped log-std of the action distribution.

        Raises:
            NotImplementedError: If ``output_distribution`` is not
                ``"Gaussian"``.
        """
        if self.output_distribution != "Gaussian":
            raise NotImplementedError(
                f"output_distribution={self.output_distribution!r} is not implemented; "
                "only 'Gaussian' is supported"
            )

        return self._gaussian_parameters(x)

    def get_log_action_probability(self, x, a):
        """Return the element-wise Gaussian log-density of ``a`` given ``x``.

        The density is that of ``Normal(mu, exp(logsig))`` evaluated at ``a``;
        no tanh change-of-variables term is applied here. The result stays
        attached to the autograd graph.
        """
        mu, logsig = self._gaussian_parameters(x)

        dist = torch.distributions.normal.Normal(loc=mu, scale=torch.exp(logsig))
        return dist.log_prob(a)

    def sample_action(self, x, greedy=False):
        """Return a tanh-squashed action for ``x`` as a NumPy array.

        Args:
            x: Input tensor.
            greedy: If true, return ``tanh(mu)``; otherwise sample from
                ``Normal(mu, exp(logsig))`` and apply ``tanh``.

        Returns:
            A NumPy array with values in ``[-1, 1]``.
        """
        # The result leaves the autograd world immediately, so no graph is
        # needed for any of the computation below.
        with torch.no_grad():
            mu = self.mu_net(x)

            if greedy:
                # The log-std network is not needed for the greedy action.
                return torch.tanh(mu).cpu().numpy()

            logsig = self._clamped_logsig(x)
            dist = torch.distributions.normal.Normal(loc=mu, scale=torch.exp(logsig))
            sampled_u = dist.sample()

            # tanh is applied after moving to the CPU, as before.
            return torch.tanh(sampled_u.cpu()).numpy()

0 commit comments

Comments
 (0)