Skip to content

Commit 70c1af4

Browse files
committed
before UAI
1 parent 6d9eecf commit 70c1af4

File tree

17 files changed

+180
-762
lines changed

17 files changed

+180
-762
lines changed

.config

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"exp_path": "pj/qmc"}
1+
{"exp_path": "pj/qmc/qmc"}

Makefile

Lines changed: 0 additions & 3 deletions
This file was deleted.

arqmc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pathlib import Path
88
from ipdb import launch_ipdb_on_exception
99

10-
import exp_utils.run
10+
import exps.utils.run
1111
from envs import Brownian, LQR
1212
#from rqmc_distributions.dist_rqmc import Uniform_RQMC, Normal_RQMC
1313
from rqmc_distributions import Normal_RQMC, Uniform_RQMC
@@ -31,7 +31,7 @@ def parse_args(args=None):
3131
parser.add_argument('--algos', type=str, nargs='+', default=['mc', 'rqmc', 'arqmc'])
3232
parser.add_argument('--exp_name', type=str, default=None)
3333
parser.add_argument('--seed', type=int, default=None)
34-
return exp_utils.run.parse_args(parser, args, exp_name_attr='exp_name')
34+
return exps.utils.run.parse_args(parser, args, exp_name_attr='exp_name')
3535

3636
### tasks ### (estimate cost, learn)
3737

@@ -88,7 +88,8 @@ def brownian(args):
8888
states = [env.reset() for env in envs]
8989
dones = [False for _ in range(args.n_trajs)]
9090
uniform_noises = ssj_uniform(args.n_trajs, 1) # n_trajs , action_dim
91-
noises = uniform2normal(random_shift(np.expand_dims(uniform_noises, 1).repeat(args.horizon, 1), 0))
91+
noises = uniform2normal(random_shift(np.expand_dims(uniform_noises, 1).repeat(args.horizon, 1), 0)) # n_trajs, horizon, action_dim
92+
import ipdb; ipdb.set_trace()
9293
for j in range(args.horizon):
9394
if np.all(dones): break
9495
envs, states, dones, returns = zip(*sorted(zip(envs, states, dones, returns), key=lambda x: np.inf if x[2] else x[1]))

environment.yml

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ channels:
44
- https://repo.continuum.io/pkgs/free
55
- defaults
66
dependencies:
7-
- blas=1.0=mkl
7+
#- blas=1.0=mkl
88
- ca-certificates=2019.1.23
99
- certifi=2018.11.29
1010
- cffi=1.12.1
@@ -19,9 +19,9 @@ dependencies:
1919
- libpng=1.6.36
2020
- libstdcxx-ng=8.2.0
2121
- libtiff=4.0.10
22-
- mkl=2019.1
23-
- mkl_fft=1.0.10
24-
- mkl_random=1.0.2
22+
#- mkl=2019.1
23+
#- mkl_fft=1.0.10
24+
#- mkl_random=1.0.2
2525
- ncurses=6.1
2626
- ninja=1.8.2
2727
- numpy=1.16.2
@@ -59,7 +59,7 @@ dependencies:
5959
- filelock==3.0.10
6060
- future==0.17.1
6161
- glfw==1.8.1
62-
- gym==0.12.0
62+
- gym==0.15.3
6363
- idna==2.8
6464
- imageio==2.5.0
6565
- ipdb==0.11
@@ -80,7 +80,7 @@ dependencies:
8080
- markupsafe==1.1.1
8181
- matplotlib==3.0.3
8282
- mistune==0.8.4
83-
- mujoco-py==2.0.2.2
83+
#- mujoco-py==2.0.2.2
8484
- nbconvert==5.4.1
8585
- nbformat==4.4.0
8686
- notebook==5.7.6
@@ -89,7 +89,7 @@ dependencies:
8989
- pandas==0.24.1
9090
- pandocfilters==1.4.2
9191
- parso==0.3.4
92-
- particles==0.1
92+
#- particles==0.1
9393
- pexpect==4.6.0
9494
- pickleshare==0.7.5
9595
- prometheus-client==0.6.0
@@ -118,4 +118,5 @@ dependencies:
118118
- wcwidth==0.1.7
119119
- webencodings==0.5.1
120120
- widgetsnbextension==3.4.2
121-
121+
- py4j
122+
- randopt

envs/gridworld/maps/four_ends.txt

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
###############################
2+
###############################
3+
############### ###############
4+
############### ###############
5+
############### ###############
6+
############### ###############
7+
############### ###############
8+
############### ###############
9+
############### ###############
10+
############### ###############
11+
############### ###############
12+
############### ###############
13+
############### ###############
14+
############### ###############
15+
############### ###############
16+
# #
17+
############### ###############
18+
############### ###############
19+
############### ###############
20+
############### ###############
21+
############### ###############
22+
############### ###############
23+
############### ###############
24+
############### ###############
25+
############### ###############
26+
############### ###############
27+
############### ###############
28+
############### ###############
29+
############### ###############
30+
############### ###############
31+
###############################

envs/gridworld/pointmass.py

Lines changed: 80 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,83 +5,87 @@
55
import matplotlib.pyplot as plt
66
from .utils import Render, color_interpolate
77

8+
89
CUR_DIR = os.path.dirname(__file__)
910

11+
1012
def read_map(filename):
1113
m = []
1214
with open(filename) as f:
1315
for row in f:
1416
m.append(list(row.rstrip()))
1517
return m
1618

19+
1720
def get_grid_position(x, y):
1821
return int(x), int(y)
1922

23+
2024
def sample_pos(m, exclude=set(), rng=np.random):
2125
while True:
2226
x = rng.uniform(len(m))
2327
y = rng.uniform(len(m[0]))
2428
x, y = get_grid_position(x, y)
2529
if (m[x][y] != '#') and ((x, y) not in exclude): return np.array([x, y])
2630

31+
2732
colormap = {
2833
' ': color_interpolate(0.0, plt.cm.Greys(0.02), plt.cm.Greys(0.2)),
2934
'@': color_interpolate(0.0, plt.cm.Greys(0.02), plt.cm.Greys(0.2)),
3035
'#': color_interpolate(0.0, plt.cm.Greys(0.12), plt.cm.Greys(0.3)),
3136
}
3237

38+
3339
# push everything into wrapper (sample init position, sample goal, change map etc)
3440
# for MDP, state should contains all information, which means that to simulate in parallel n rollouts, you only need one environment and n states
3541
class PointMass(gym.Env):
3642
def __init__(
3743
self,
3844
map_name,
39-
goal=None,
4045
init_pos=None,
4146
n_sub_steps=10,
42-
done_threshold=0.8,
4347
seed=0,
4448
):
4549
self.map = read_map(os.path.join(CUR_DIR, 'maps', '{}.txt'.format(map_name)))
4650
self.row, self.col = len(self.map), len(self.map[0])
4751
self.seed(seed)
48-
if goal is None:
49-
goal = sample_pos(self.map, rng=self.rng)
50-
self.goal = goal
51-
if init_pos is None:
52-
init_pos = sample_pos(self.map, {tuple(goal)})
53-
assert init_pos[0] > 0 and init_pos[0] < self.row and init_pos[1] > 0 and init_pos[1] < self.col
54-
self.init_pos = np.asarray(init_pos)
52+
if init_pos is not None:
53+
self.load_params({'init_pos': init_pos})
54+
else:
55+
self._init_pos = None
5556
self.n_sub_steps = n_sub_steps
5657
self.done_threshold = done_threshold
5758

5859
self.observation_space = gym.spaces.Box(np.array([0.0, 0.0]), np.array([self.row, self.col]))
5960
self.action_space = gym.spaces.Box(np.array([-1.0, -1.0]), np.array([1.0, 1.0]))
6061
self._render = None
6162

63+
def load_params(self, params):
64+
if 'init_pos' in params:
65+
assert 0 < init_pos[0] < self.row and 0 < init_pos[1] < self.col
66+
self._init_pos = np.array(params['init_pos'])
67+
6268
def seed(self, seed=None):
6369
self.rng, _ = seeding.np_random(seed)
6470

6571
def reset(self):
66-
self.pos = self.init_pos
72+
assert self._init_pos is not None
73+
self.pos = self._init_pos
6774
return self.pos
6875

6976
def _is_blocked(self, pos):
7077
x, y = get_grid_position(*pos)
7178
return self.map[x][y] == '#'
7279

73-
def step(self, action):
74-
assert not self._is_blocked(self.pos), 'start position in the wall'
75-
action = np.clip(action, self.action_space.low, self.action_space.high)
80+
def transition(self, pos, action):
81+
assert not self._is_blocked(pos), 'start position in the wall'
82+
action = np.clip(action, self.action_space.low, self.action_space.high) # might cause problem
7683
dpos = 1.0 / self.n_sub_steps
7784
for _ in range(self.n_sub_steps):
78-
next_pos = self.pos + action * dpos
85+
next_pos = pos + action * dpos
7986
if self._is_blocked(next_pos): break
80-
self.pos = next_pos
81-
dist = np.linalg.norm(self.goal - self.pos)
82-
r = -dist
83-
done = dist < self.done_threshold
84-
return self.pos, r, done, {}
87+
pos = next_pos
88+
return pos
8589

8690
def render(self, repeat=32):
8791
self.init_render(repeat)
@@ -101,6 +105,63 @@ def init_render(self, repeat):
101105
self._render = Render(size=(self.col * repeat, self.row * repeat))
102106
return self
103107

108+
109+
class ReachPointMass(PointMass):
110+
def __init__(
111+
self,
112+
map_name,
113+
goal=None,
114+
init_pos=None,
115+
n_sub_steps=10,
116+
done_threshold=0.8,
117+
seed=0,
118+
):
119+
super().__init__(map_name, n_sub_steps=n_sub_steps, seed=seed)
120+
self.done_threshold = done_threshold
121+
if init_pos is not None:
122+
self.load_params({'init_pos': init_pos})
123+
else:
124+
self._init_pos = None
125+
if goal is not None:
126+
self.load_params({'goal': goal})
127+
else:
128+
self._goal = None
129+
130+
def load_params(self, params):
131+
super().load_params(params)
132+
if 'goal' in params:
133+
self._goal = goal
134+
if self._init_pos is not None and self._goal is not None:
135+
assert get_grid_position(*self._init_pos) != get_grid_position(*self._goal)
136+
137+
def step(self, action):
138+
self.pos = self.transition(self.pos, action)
139+
dist = np.linalg.norm(self.goal - self.pos)
140+
r = -dist
141+
done = dist < self.done_threshold
142+
return self.pos, r, done, {}
143+
144+
145+
class GaussianMixtureRewardPointMass(PointMass):
146+
def __init__(
147+
self,
148+
map_name,
149+
gaussians=[], # c, mean, sigma
150+
init_pos=None,
151+
n_sub_steps=10,
152+
seed=0,
153+
):
154+
super().__init__(map_name, init_pos=init_pos, n_sub_steps=n_sub_steps, seed=seed)
155+
self.gaussians = gaussians
156+
157+
def step(self, action):
158+
self.pos = self.transition(self.pos, action)
159+
r = 0.0
160+
for c, mean, sigma in self.gaussians:
161+
r += c * np.exp(-(self.pos - mean).square().sum() / sigma)
162+
return self.pos, r, False, {}
163+
164+
104165
class GaussianActionNoiseWrapper(gym.Wrapper):
105166
def __init__(self, env, scale, seed=None):
106167
super().__init__(self, env)

experiments/cost_lqr.run

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +0,0 @@
1-
--env lqr --exp_name cost_lqr/H_[horizon]-T[n_trajs] --n_runs 30 --seed 0 --sorter value norm group permute --horizon 20 --n_trajs 32
2-
--env lqr --exp_name cost_lqr/H_[horizon]-T[n_trajs] --n_runs 30 --seed 0 --sorter value norm group permute --horizon 20 --n_trajs 64
3-
--env lqr --exp_name cost_lqr/H_[horizon]-T[n_trajs] --n_runs 30 --seed 0 --sorter value norm group permute --horizon 20 --n_trajs 128
4-
--env lqr --exp_name cost_lqr/H_[horizon]-T[n_trajs] --n_runs 30 --seed 0 --sorter value norm group permute --horizon 20 --n_trajs 256
5-
--env lqr --exp_name cost_lqr/H_[horizon]-T[n_trajs] --n_runs 30 --seed 0 --sorter value norm group permute --horizon 20 --n_trajs 512
6-
--env lqr --exp_name cost_lqr/H_[horizon]-T[n_trajs] --n_runs 30 --seed 0 --sorter value norm group permute --horizon 20 --n_trajs 1024
7-
--env lqr --exp_name cost_lqr/H_[horizon]-T[n_trajs] --n_runs 30 --seed 0 --sorter value norm group permute --horizon 40 --n_trajs 32
8-
--env lqr --exp_name cost_lqr/H_[horizon]-T[n_trajs] --n_runs 30 --seed 0 --sorter value norm group permute --horizon 40 --n_trajs 64
9-
--env lqr --exp_name cost_lqr/H_[horizon]-T[n_trajs] --n_runs 30 --seed 0 --sorter value norm group permute --horizon 40 --n_trajs 128
10-
--env lqr --exp_name cost_lqr/H_[horizon]-T[n_trajs] --n_runs 30 --seed 0 --sorter value norm group permute --horizon 40 --n_trajs 256
11-
--env lqr --exp_name cost_lqr/H_[horizon]-T[n_trajs] --n_runs 30 --seed 0 --sorter value norm group permute --horizon 40 --n_trajs 512
12-
--env lqr --exp_name cost_lqr/H_[horizon]-T[n_trajs] --n_runs 30 --seed 0 --sorter value norm group permute --horizon 40 --n_trajs 1024

exps_utils/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)