
Commit 81c7ddf

Initial QTE commit.
1 parent de2db8f commit 81c7ddf

8 files changed: +865 −3 lines changed

README.md

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,7 @@ Codebase of Q-attention, coarse-to-fine Q-attention, and other variants. Code fr
 - [Q-attention: Enabling Efficient Learning for Vision-based Robotic Manipulation](https://arxiv.org/abs/2105.14829) (ARM system)
 - [Coarse-to-Fine Q-attention: Efficient Learning for Visual Robotic Manipulation via Discretisation](https://arxiv.org/abs/2106.12534) (C2F-ARM system)
 - [Coarse-to-Fine Q-attention with Learned Path Ranking](https://arxiv.org/abs/2204.01571) (C2F-ARM+LPR system)
+- [Coarse-to-Fine Q-attention with Tree Expansion](https://arxiv.org/abs/2204.12471)

 ![task grid image missing](readme_files/variants.png)

@@ -42,3 +43,8 @@ To launch **C2F-ARM+LPR**:
 ```bash
 python launch.py method=LPR rlbench.task=take_lid_off_saucepan rlbench.demo_path=/mnt/my/save/dir framework.gpu=0
 ```
+
+To launch **C2F-ARM+QTE**:
+```bash
+python launch.py method=QTE rlbench.task=take_lid_off_saucepan rlbench.demo_path=/mnt/my/save/dir framework.gpu=0
+```

arm/qte/__init__.py

Lines changed: 1 addition & 0 deletions
import arm.qte.launch_utils

arm/qte/launch_utils.py

Lines changed: 72 additions & 0 deletions
from omegaconf import DictConfig

from arm.qte.networks import Qattention3DNet
from arm.qte.qattention_agent import QAttentionAgent
from arm.c2farm.qattention_stack_agent import QAttentionStackAgent
from arm.preprocess_agent import PreprocessAgent


def create_agent(cfg: DictConfig, env, depth_0bounds=None, cam_resolution=None):
    VOXEL_FEATS = 3
    LATENT_SIZE = 64
    cam_resolution = cam_resolution or [128, 128]

    include_prev_layer = False

    num_rotation_classes = int(360. // cfg.method.rotation_resolution)
    qattention_agents = []
    # One Q-attention agent per voxelisation depth (coarse to fine); only the
    # last depth gets a dense head for rotation/gripper outputs.
    for depth, vox_size in enumerate(cfg.method.voxel_sizes):
        last = depth == len(cfg.method.voxel_sizes) - 1
        unet3d = Qattention3DNet(
            in_channels=VOXEL_FEATS + 3 + 1 + 3,
            out_channels=1,
            voxel_size=vox_size,
            timesteps=cfg.replay.timesteps,
            out_dense=((num_rotation_classes * 3) + 2) if last else 0,
            kernels=LATENT_SIZE,
            norm=None if 'None' in cfg.method.norm else cfg.method.norm,
            dense_feats=128,
            activation=cfg.method.activation,
            low_dim_size=env.low_dim_state_len,
            include_prev_layer=include_prev_layer and depth > 0)

        qattention_agent = QAttentionAgent(
            layer=depth,
            coordinate_bounds=depth_0bounds,
            unet3d=unet3d,
            camera_names=cfg.rlbench.cameras,
            voxel_size=vox_size,
            bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None,
            image_crop_size=cfg.method.image_crop_size,
            tau=cfg.method.tau,
            lr=cfg.method.lr,
            lambda_trans_qreg=cfg.method.lambda_trans_qreg,
            lambda_rot_qreg=cfg.method.lambda_rot_qreg,
            include_low_dim_state=True,
            image_resolution=cam_resolution,
            batch_size=cfg.replay.batch_size,
            timesteps=cfg.replay.timesteps,
            voxel_feature_size=VOXEL_FEATS,
            exploration_strategy=cfg.method.exploration_strategy,
            lambda_weight_l2=cfg.method.lambda_weight_l2,
            num_rotation_classes=num_rotation_classes,
            rotation_resolution=cfg.method.rotation_resolution,
            grad_clip=0.01,
            gamma=0.99,
            tree_search_breadth=cfg.method.tree_search_breadth,
            tree_during_update=cfg.method.tree_during_update,
            tree_during_act=cfg.method.tree_during_act)
        qattention_agents.append(qattention_agent)

    # Give each layer a reference to the next (finer) Q-attention layer.
    for i in range(len(qattention_agents) - 1):
        qattention_agents[i].give_next_layer_qattention(qattention_agents[i + 1])

    # Stack the per-depth agents and wrap them in a shared preprocessing agent.
    rotation_agent = QAttentionStackAgent(
        qattention_agents=qattention_agents,
        rotation_resolution=cfg.method.rotation_resolution,
        camera_names=cfg.rlbench.cameras,
    )
    preprocess_agent = PreprocessAgent(pose_agent=rotation_agent)
    return preprocess_agent
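
For orientation, here is a minimal wiring sketch (not part of the commit) of how `create_agent` could be driven outside of `launch.py`. All config values are hypothetical placeholders; the keys mirror exactly what `create_agent` reads above, and the stub environment stands in for the repo's RLBench wrapper, of which construction only uses `low_dim_state_len`.

```python
# Hypothetical sketch: placeholder values, keys mirroring create_agent above.
from omegaconf import OmegaConf

from arm.qte.launch_utils import create_agent


class _StubEnv:
    # Stand-in for the RLBench environment wrapper; create_agent only reads
    # `low_dim_state_len` at construction time.
    low_dim_state_len = 4


cfg = OmegaConf.create({
    'method': {
        'voxel_sizes': [16, 16],      # two coarse-to-fine Q-attention depths
        'rotation_resolution': 5,     # 360 / 5 = 72 rotation classes per axis
        'bounds_offset': [0.15],      # used for depths > 0
        'image_crop_size': 64,
        'tau': 0.005,
        'lr': 0.0005,
        'lambda_trans_qreg': 1e-6,
        'lambda_rot_qreg': 1e-6,
        'lambda_weight_l2': 1e-6,
        'norm': 'None',               # mapped to norm=None in create_agent
        'activation': 'relu',
        'exploration_strategy': 'gaussian',
        'tree_search_breadth': 4,
        'tree_during_update': True,
        'tree_during_act': True,
    },
    'replay': {'batch_size': 128, 'timesteps': 1},
    'rlbench': {'cameras': ['front']},
})

agent = create_agent(cfg, _StubEnv(),
                     depth_0bounds=[-0.3, -0.5, 0.6, 0.7, 0.5, 1.6])
```

The `tree_*` entries feed the tree-expansion options that `QAttentionAgent` receives above; everything else follows the C2F-ARM launcher pattern.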

arm/qte/networks.py

Lines changed: 200 additions & 0 deletions
import torch
import torch.nn as nn

from arm.network_utils import Conv3DInceptionBlock, DenseBlock, SpatialSoftmax3D, \
    Conv3DInceptionBlockUpsampleBlock, Conv3DBlock


class Qattention3DNet(nn.Module):
    """U-Net-style 3D network over a voxel grid. The decoder produces per-voxel
    (translation) Q-values; an optional dense head produces rotation/gripper
    Q-values from pooled encoder/decoder features."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 out_dense: int,
                 voxel_size: int,
                 low_dim_size: int,
                 kernels: int,
                 timesteps: int,
                 norm: str = None,
                 activation: str = 'relu',
                 dense_feats: int = 32,
                 include_prev_layer: bool = False):
        super(Qattention3DNet, self).__init__()
        self._in_channels = in_channels
        self._out_channels = out_channels
        self._norm = norm
        self._activation = activation
        self._kernels = kernels
        self._timesteps = timesteps
        self._low_dim_size = low_dim_size * timesteps
        self._build_calls = 0
        self._voxel_size = voxel_size
        self._dense_feats = dense_feats
        self._out_dense = out_dense
        self._include_prev_layer = include_prev_layer

    def build(self):
        use_residual = False
        self._build_calls += 1
        if self._build_calls != 1:
            raise RuntimeError('Build needs to be called once.')

        spatial_size = self._voxel_size
        self._input_preprocess = Conv3DInceptionBlock(
            self._in_channels, self._kernels, norm=self._norm,
            activation=self._activation)

        # Each of the `timesteps` voxel grids is preprocessed separately and
        # stacked along the channel axis in forward().
        d0_ins = self._input_preprocess.out_channels * self._timesteps
        if self._include_prev_layer:
            PREV_VOXEL_CHANNELS = 0
            self._input_preprocess_prev_layer = Conv3DInceptionBlock(
                self._in_channels + PREV_VOXEL_CHANNELS, self._kernels,
                norm=self._norm, activation=self._activation)
            d0_ins += self._input_preprocess_prev_layer.out_channels

        if self._low_dim_size > 0:
            self._proprio_preprocess = DenseBlock(
                self._low_dim_size, self._kernels, None, self._activation)
            d0_ins += self._kernels

        self._down0 = Conv3DInceptionBlock(
            d0_ins, self._kernels, norm=self._norm,
            activation=self._activation, residual=use_residual)
        self._ss0 = SpatialSoftmax3D(
            spatial_size, spatial_size, spatial_size,
            self._down0.out_channels)
        spatial_size //= 2
        self._down1 = Conv3DInceptionBlock(
            self._down0.out_channels, self._kernels * 2, norm=self._norm,
            activation=self._activation, residual=use_residual)
        self._ss1 = SpatialSoftmax3D(
            spatial_size, spatial_size, spatial_size,
            self._down1.out_channels)
        spatial_size //= 2

        # Per stage, a spatial-softmax (3 coords per channel) and a global-max
        # (1 per channel) summary are flattened for the dense head.
        flat_size = self._down0.out_channels * 4 + self._down1.out_channels * 4

        # Deeper encoder/decoder stages are only built for larger voxel grids.
        k1 = self._down1.out_channels
        if self._voxel_size > 8:
            k1 += self._kernels
            self._down2 = Conv3DInceptionBlock(
                self._down1.out_channels, self._kernels * 4, norm=self._norm,
                activation=self._activation, residual=use_residual)
            flat_size += self._down2.out_channels * 4
            self._ss2 = SpatialSoftmax3D(
                spatial_size, spatial_size, spatial_size,
                self._down2.out_channels)
            spatial_size //= 2
            k2 = self._down2.out_channels
            if self._voxel_size > 16:
                k2 *= 2
                self._down3 = Conv3DInceptionBlock(
                    self._down2.out_channels, self._kernels, norm=self._norm,
                    activation=self._activation, residual=use_residual)
                flat_size += self._down3.out_channels * 4
                self._ss3 = SpatialSoftmax3D(
                    spatial_size, spatial_size, spatial_size,
                    self._down3.out_channels)
                self._up3 = Conv3DInceptionBlockUpsampleBlock(
                    self._kernels, self._kernels, 2, norm=self._norm,
                    activation=self._activation, residual=use_residual)
            self._up2 = Conv3DInceptionBlockUpsampleBlock(
                k2, self._kernels, 2, norm=self._norm,
                activation=self._activation, residual=use_residual)

        self._up1 = Conv3DInceptionBlockUpsampleBlock(
            k1, self._kernels, 2, norm=self._norm,
            activation=self._activation, residual=use_residual)

        self._global_maxp = nn.AdaptiveMaxPool3d(1)
        self._local_maxp = nn.MaxPool3d(3, 2, padding=1)
        self._final = Conv3DBlock(
            self._kernels * 2, self._kernels, kernel_sizes=3,
            strides=1, norm=self._norm, activation=self._activation)
        self._final2 = Conv3DBlock(
            self._kernels, self._out_channels, kernel_sizes=3,
            strides=1, norm=None, activation=None)

        self._ss_final = SpatialSoftmax3D(
            self._voxel_size, self._voxel_size, self._voxel_size,
            self._kernels)
        flat_size += self._kernels * 4

        if self._out_dense > 0:
            self._dense0 = DenseBlock(
                flat_size, self._dense_feats, None, self._activation)
            self._dense1 = DenseBlock(
                self._dense_feats, self._dense_feats, None, self._activation)
            self._dense2 = DenseBlock(
                self._dense_feats, self._out_dense, None, None)

    def forward(self, ins, proprio, prev_layer_voxel_grid):
        b, t, _, d, h, w = ins.shape
        # Preprocess each timestep's voxel grid, then stack along channels.
        x = torch.cat([self._input_preprocess(x_) for x_ in ins.unbind(1)], 1)

        if self._include_prev_layer:
            y = self._input_preprocess_prev_layer(prev_layer_voxel_grid)
            x = torch.cat([x, y], dim=1)

        if self._low_dim_size > 0:
            # Tile proprioception features over the voxel grid.
            p = self._proprio_preprocess(proprio)
            p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(
                1, 1, d, h, w)
            x = torch.cat([x, p], dim=1)

        d0 = self._down0(x)
        ss0 = self._ss0(d0)
        maxp0 = self._global_maxp(d0).view(b, -1)
        d1 = u = self._down1(self._local_maxp(d0))
        ss1 = self._ss1(d1)
        maxp1 = self._global_maxp(d1).view(b, -1)

        feats = [ss0, maxp0, ss1, maxp1]

        if self._voxel_size > 8:
            d2 = u = self._down2(self._local_maxp(d1))
            feats.extend([self._ss2(d2), self._global_maxp(d2).view(b, -1)])
            if self._voxel_size > 16:
                d3 = self._down3(self._local_maxp(d2))
                feats.extend([self._ss3(d3), self._global_maxp(d3).view(b, -1)])
                u3 = self._up3(d3)
                u = torch.cat([d2, u3], dim=1)
            u2 = self._up2(u)
            u = torch.cat([d1, u2], dim=1)

        u1 = self._up1(u)
        f1 = self._final(torch.cat([d0, u1], dim=1))
        trans = self._final2(f1)

        feats.extend([self._ss_final(f1), self._global_maxp(f1).view(b, -1)])

        # Pooled latents from each stage, kept for inspection.
        self.latent_dict = {
            'd0': d0.mean(-1).mean(-1).mean(-1),
            'd1': d1.mean(-1).mean(-1).mean(-1),
            'u1': u1.mean(-1).mean(-1).mean(-1),
            'trans_out': trans,
        }

        rot_and_grip_out = None
        if self._out_dense > 0:
            dense0 = self._dense0(torch.cat(feats, 1))
            dense1 = self._dense1(dense0)
            rot_and_grip_out = self._dense2(dense1)
            self.latent_dict.update({
                'dense0': dense0,
                'dense1': dense1,
                'dense2': rot_and_grip_out,
            })

        if self._voxel_size > 8:
            self.latent_dict.update({
                'd2': d2.mean(-1).mean(-1).mean(-1),
                'u2': u2.mean(-1).mean(-1).mean(-1),
            })
        if self._voxel_size > 16:
            self.latent_dict.update({
                'd3': d3.mean(-1).mean(-1).mean(-1),
                'u3': u3.mean(-1).mean(-1).mean(-1),
            })

        return trans, rot_and_grip_out
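
As a rough shape sanity check (again, not part of the commit; all sizes are hypothetical but mirror what launch_utils.py passes in), the sketch below builds a small network directly and runs a dummy forward pass:

```python
# Hypothetical shape check; sizes mirror launch_utils.py.
import torch

from arm.qte.networks import Qattention3DNet

net = Qattention3DNet(
    in_channels=10,            # VOXEL_FEATS + 3 + 1 + 3, as in launch_utils.py
    out_channels=1,
    out_dense=(72 * 3) + 2,    # 72 rotation classes per axis + 2 gripper values
    voxel_size=16,
    low_dim_size=4,
    kernels=64,
    timesteps=1)
net.build()  # must be called exactly once before the first forward pass

ins = torch.zeros(1, 1, 10, 16, 16, 16)  # (b, t, c, d, h, w)
proprio = torch.zeros(1, 4)              # low_dim_size * timesteps features
trans, rot_and_grip = net(ins, proprio, None)
# Expected: trans is a (1, 1, 16, 16, 16) per-voxel Q-volume;
# rot_and_grip is (1, 218) from the dense head.
```

With `voxel_size=16`, the `_down3`/`_up3` stage is skipped entirely, which is why `build` only constructs it when `voxel_size > 16`.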
