1+ import torch
2+ import torch .nn as nn
3+
4+ from arm .network_utils import Conv3DInceptionBlock , DenseBlock , SpatialSoftmax3D , \
5+ Conv3DInceptionBlockUpsampleBlock , Conv3DBlock
6+
7+
class Qattention3DNet(nn.Module):
    """Encoder-decoder 3D conv network over voxel grids with an optional dense head.

    Construction is two-phase: ``__init__`` only records hyper-parameters and
    ``build()`` (callable exactly once) creates the layers. The depth of the
    network depends on ``voxel_size``: a third level is added when it is
    greater than 8 and a fourth level when it is greater than 16.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 out_dense: int,
                 voxel_size: int,
                 low_dim_size: int,
                 kernels: int,
                 timesteps: int,
                 norm: str = None,
                 activation: str = 'relu',
                 dense_feats: int = 32,
                 include_prev_layer=False,):
        """Record the hyper-parameters; no layers are created here.

        Args:
            in_channels: Channels of each per-timestep voxel grid passed to
                the input inception block.
            out_channels: Channels of the final 3D convolution (``trans``).
            out_dense: Output size of the dense head; the head is skipped
                when this is not greater than 0.
            voxel_size: Side length used to size the spatial-softmax layers;
                also selects the network depth (see class docstring).
            low_dim_size: Per-timestep size of the ``proprio`` input. It is
                multiplied by ``timesteps``; the proprio branch is skipped
                when the product is not greater than 0.
            kernels: Base channel width requested from the conv blocks.
            timesteps: Number of voxel grids stacked along dim 1 of ``ins``.
            norm: Normalisation name forwarded to the conv blocks.
            activation: Activation name forwarded to the conv/dense blocks.
            dense_feats: Hidden width of the dense head.
            include_prev_layer: When true, ``prev_layer_voxel_grid`` is
                processed by its own inception block and concatenated.
        """
        super(Qattention3DNet, self).__init__()
        self._in_channels = in_channels
        self._out_channels = out_channels
        self._norm = norm
        self._activation = activation
        self._kernels = kernels
        self._timesteps = timesteps
        # The proprio input covers all timesteps at once.
        self._low_dim_size = low_dim_size * timesteps
        self._build_calls = 0
        self._voxel_size = voxel_size
        self._dense_feats = dense_feats
        self._out_dense = out_dense
        self._include_prev_layer = include_prev_layer

    @staticmethod
    def _spatial_mean(volume):
        """Average a tensor over its three trailing dimensions, one at a time."""
        return volume.mean(-1).mean(-1).mean(-1)

    def build(self):
        """Create all layers. Must be called exactly once before ``forward``.

        Raises:
            RuntimeError: If called more than once.
        """
        use_residual = False
        self._build_calls += 1
        if self._build_calls != 1:
            raise RuntimeError('Build needs to be called once.')

        spatial_size = self._voxel_size
        self._input_preprocess = Conv3DInceptionBlock(
            self._in_channels, self._kernels, norm=self._norm,
            activation=self._activation)

        # The same preprocess block is applied to every timestep and the
        # results are concatenated channel-wise in forward().
        d0_ins = self._input_preprocess.out_channels * self._timesteps
        if self._include_prev_layer:
            PREV_VOXEL_CHANNELS = 0
            self._input_preprocess_prev_layer = Conv3DInceptionBlock(
                self._in_channels + PREV_VOXEL_CHANNELS, self._kernels,
                norm=self._norm, activation=self._activation)
            d0_ins += self._input_preprocess_prev_layer.out_channels

        if self._low_dim_size > 0:
            self._proprio_preprocess = DenseBlock(
                self._low_dim_size, self._kernels, None, self._activation)
            d0_ins += self._kernels

        self._down0 = Conv3DInceptionBlock(
            d0_ins, self._kernels, norm=self._norm,
            activation=self._activation, residual=use_residual)
        self._ss0 = SpatialSoftmax3D(
            spatial_size, spatial_size, spatial_size,
            self._down0.out_channels)
        spatial_size //= 2
        self._down1 = Conv3DInceptionBlock(
            self._down0.out_channels, self._kernels * 2, norm=self._norm,
            activation=self._activation, residual=use_residual)
        self._ss1 = SpatialSoftmax3D(
            spatial_size, spatial_size, spatial_size,
            self._down1.out_channels)
        spatial_size //= 2

        # Each level adds 4 features per channel to the dense-head input
        # (presumably 3 from SpatialSoftmax3D plus 1 from the global max-pool
        # -- TODO confirm against SpatialSoftmax3D).
        flat_size = (self._down0.out_channels * 4
                     + self._down1.out_channels * 4)

        # k1 is the input width of _up1: d1 alone for the shallow network,
        # cat([d1, u2]) once a third level exists.
        k1 = self._down1.out_channels
        if self._voxel_size > 8:
            k1 += self._kernels
            self._down2 = Conv3DInceptionBlock(
                self._down1.out_channels, self._kernels * 4, norm=self._norm,
                activation=self._activation, residual=use_residual)
            flat_size += self._down2.out_channels * 4
            self._ss2 = SpatialSoftmax3D(
                spatial_size, spatial_size, spatial_size,
                self._down2.out_channels)
            spatial_size //= 2
            # k2 is the input width of _up2: d2 alone, or cat([d2, u3]) when
            # a fourth level exists.
            k2 = self._down2.out_channels
            if self._voxel_size > 16:
                # forward() concatenates d2 with u3, and _up3 is built to
                # emit self._kernels channels, so the width grows by
                # self._kernels (mirroring k1 above) rather than doubling.
                # NOTE(review): relies on the upsample block emitting the
                # requested channel count, as k1 already assumes.
                k2 += self._kernels
                self._down3 = Conv3DInceptionBlock(
                    self._down2.out_channels, self._kernels, norm=self._norm,
                    activation=self._activation, residual=use_residual)
                flat_size += self._down3.out_channels * 4
                self._ss3 = SpatialSoftmax3D(
                    spatial_size, spatial_size, spatial_size,
                    self._down3.out_channels)
                self._up3 = Conv3DInceptionBlockUpsampleBlock(
                    self._kernels, self._kernels, 2, norm=self._norm,
                    activation=self._activation, residual=use_residual)
            self._up2 = Conv3DInceptionBlockUpsampleBlock(
                k2, self._kernels, 2, norm=self._norm,
                activation=self._activation, residual=use_residual)

        self._up1 = Conv3DInceptionBlockUpsampleBlock(
            k1, self._kernels, 2, norm=self._norm,
            activation=self._activation, residual=use_residual)

        self._global_maxp = nn.AdaptiveMaxPool3d(1)
        self._local_maxp = nn.MaxPool3d(3, 2, padding=1)
        self._final = Conv3DBlock(
            self._kernels * 2, self._kernels, kernel_sizes=3,
            strides=1, norm=self._norm, activation=self._activation)
        self._final2 = Conv3DBlock(
            self._kernels, self._out_channels, kernel_sizes=3,
            strides=1, norm=None, activation=None)

        self._ss_final = SpatialSoftmax3D(
            self._voxel_size, self._voxel_size, self._voxel_size,
            self._kernels)
        flat_size += self._kernels * 4

        if self._out_dense > 0:
            self._dense0 = DenseBlock(
                flat_size, self._dense_feats, None, self._activation)
            self._dense1 = DenseBlock(
                self._dense_feats, self._dense_feats, None, self._activation)
            self._dense2 = DenseBlock(
                self._dense_feats, self._out_dense, None, None)

    def forward(self, ins, proprio, prev_layer_voxel_grid):
        """Run the network and cache intermediate activations.

        Args:
            ins: 6-D tensor unpacked as ``(b, t, c, d, h, w)``; ``t`` must
                equal the ``timesteps`` given at construction.
            proprio: Low-dimensional input, only used when the configured
                low-dim size is greater than 0 (presumably shaped
                ``(b, low_dim_size * timesteps)`` -- TODO confirm).
            prev_layer_voxel_grid: Extra voxel grid, only used when
                ``include_prev_layer`` was set.

        Returns:
            ``(trans, rot_and_grip_out)``: the output of the final 3D conv and
            the dense-head output, the latter being ``None`` when
            ``out_dense`` is not greater than 0.

        Raises:
            RuntimeError: If dim 1 of ``ins`` does not match ``timesteps``.

        Side effects:
            Stores pooled intermediate activations in ``self.latent_dict``.
        """
        b, t, _, d, h, w = ins.shape
        if t != self._timesteps:
            # Fail early with a clear message; otherwise the mismatch only
            # shows up as a channel-count error inside _down0.
            raise RuntimeError(
                'Expected %d timesteps on dim 1 of `ins`, got %d.'
                % (self._timesteps, t))

        # Shared preprocess per timestep, stacked along the channel axis.
        x = torch.cat(
            [self._input_preprocess(x_) for x_ in ins.unbind(1)], 1)

        if self._include_prev_layer:
            y = self._input_preprocess_prev_layer(prev_layer_voxel_grid)
            x = torch.cat([x, y], dim=1)

        if self._low_dim_size > 0:
            p = self._proprio_preprocess(proprio)
            # Broadcast the proprio features over the whole volume; expand
            # creates a view, and torch.cat below materialises it once.
            p = p[:, :, None, None, None].expand(-1, -1, d, h, w)
            x = torch.cat([x, p], dim=1)

        d0 = self._down0(x)
        ss0 = self._ss0(d0)
        maxp0 = self._global_maxp(d0).view(b, -1)
        # `u` always holds the tensor that will be fed into _up1.
        d1 = u = self._down1(self._local_maxp(d0))
        ss1 = self._ss1(d1)
        maxp1 = self._global_maxp(d1).view(b, -1)

        feats = [ss0, maxp0, ss1, maxp1]

        if self._voxel_size > 8:
            d2 = u = self._down2(self._local_maxp(d1))
            feats.extend(
                [self._ss2(d2), self._global_maxp(d2).view(b, -1)])
            if self._voxel_size > 16:
                d3 = self._down3(self._local_maxp(d2))
                feats.extend(
                    [self._ss3(d3), self._global_maxp(d3).view(b, -1)])
                u3 = self._up3(d3)
                u = torch.cat([d2, u3], dim=1)
            u2 = self._up2(u)
            u = torch.cat([d1, u2], dim=1)

        u1 = self._up1(u)
        f1 = self._final(torch.cat([d0, u1], dim=1))
        trans = self._final2(f1)

        feats.extend(
            [self._ss_final(f1), self._global_maxp(f1).view(b, -1)])

        self.latent_dict = {
            'd0': self._spatial_mean(d0),
            'd1': self._spatial_mean(d1),
            'u1': self._spatial_mean(u1),
            'trans_out': trans,
        }

        rot_and_grip_out = None
        if self._out_dense > 0:
            dense0 = self._dense0(torch.cat(feats, 1))
            dense1 = self._dense1(dense0)
            rot_and_grip_out = self._dense2(dense1)
            self.latent_dict.update({
                'dense0': dense0,
                'dense1': dense1,
                'dense2': rot_and_grip_out,
            })

        if self._voxel_size > 8:
            self.latent_dict.update({
                'd2': self._spatial_mean(d2),
                'u2': self._spatial_mean(u2),
            })
        if self._voxel_size > 16:
            self.latent_dict.update({
                'd3': self._spatial_mean(d3),
                'u3': self._spatial_mean(u3),
            })

        return trans, rot_and_grip_out
0 commit comments