@@ -201,11 +201,6 @@ def setup_optimization_and_priors(
201201 data = {}
202202 self .shared_intrinsics = shared_intrinsics
203203
204- if shared_intrinsics : # si => must use pinhole
205- assert (
206- self .camera_model == camera_models ["pinhole" ]
207- ), f"Shared intrinsics only supported with pinhole camera model: { self .camera_model } "
208-
209204 self .estimate_gravity = True
210205 if "prior_gravity" in data :
211206 self .estimate_gravity = False
@@ -239,6 +234,10 @@ def setup_optimization_and_priors(
239234 )
240235 )
241236
237+ self .n_intrinsic_params = self .estimate_focal + (
238+ self .camera_model .num_dist_params () if self .camera_has_distortion else 0
239+ )
240+
242241 logger .debug (f"Camera Model: { self .camera_model } " )
243242 logger .debug (f"Optimizing gravity: { self .estimate_gravity } ({ self .gravity_delta_dims } )" )
244243 logger .debug (f"Optimizing focal: { self .estimate_focal } ({ self .focal_delta_dims } )" )
@@ -352,23 +351,35 @@ def calculate_gradient_and_hessian(
352351 # reshape to (1, B * (N_params-1) + 1)
353352 Grad_g = Grad [..., :2 ].reshape (1 , - 1 )
354353 Grad_f = Grad [..., 2 ].reshape (1 , - 1 ).sum (- 1 , keepdim = True )
355- Grad = torch .cat ([Grad_g , Grad_f ], dim = - 1 )
354+ Grad_dist = Grad [..., 3 :].sum (- 2 ).reshape (1 , - 1 )
355+ Grad = torch .cat ([Grad_g , Grad_f , Grad_dist ], dim = - 1 )
356356
357357 Hess = torch .einsum ("...Njk,...Njl->...Nkl" , J , J )
358358 Hess = weights [..., None , None ] * Hess
359359 Hess = Hess .sum (- 3 )
360360
361361 if shared_intrinsics :
362+ """
363+ Hess =
364+ [
365+ diag(H_G ) J_g_intrinsic
366+
367+ J_g_intrinsic^T H_intrinsic
368+ ]
369+ """
370+ B = Hess .shape [0 ]
362371 H_g = torch .block_diag (* list (Hess [..., :2 , :2 ]))
363- J_fg = Hess [..., :2 , 2 ].flatten ()
364- J_gf = Hess [..., 2 , :2 ].flatten ()
365- J_f = Hess [..., 2 , 2 ].sum ()
366- dims = H_g .shape [- 1 ] + 1
372+ J_intrinsics_g = Hess [..., :2 , 2 :].reshape (B * 2 , - 1 )
373+ J_g_intrinsics = Hess [..., 2 :, :2 ].permute (0 , 2 , 1 ).reshape (B * 2 , - 1 ).T
374+
375+ H_intrinsics = Hess [..., 2 :, 2 :].sum (- 3 )
376+
377+ dims = H_g .shape [- 1 ] + self .n_intrinsic_params
367378 Hess = Hess .new_zeros ((dims , dims ), dtype = torch .float32 )
368- Hess [:- 1 , :- 1 ] = H_g
369- Hess [- 1 , :- 1 ] = J_gf
370- Hess [:- 1 , - 1 ] = J_fg
371- Hess [- 1 , - 1 ] = J_f
379+ Hess [: - self . n_intrinsic_params , : - self . n_intrinsic_params ] = H_g
380+ Hess [- self . n_intrinsic_params : , : - self . n_intrinsic_params ] = J_g_intrinsics
381+ Hess [: - self . n_intrinsic_params , - self . n_intrinsic_params : ] = J_intrinsics_g
382+ Hess [- self . n_intrinsic_params : , - self . n_intrinsic_params : ] = H_intrinsics
372383 Hess = Hess .unsqueeze (0 )
373384
374385 return Grad , Hess
@@ -416,7 +427,9 @@ def setup_system(
416427 Hess = J_up .new_zeros (J_up .shape [0 ], n_params , n_params )
417428
418429 if shared_intrinsics :
419- N_params = Grad .shape [0 ] * (n_params - 1 ) + 1
430+ N_params = (
431+ Grad .shape [0 ] * (n_params - self .n_intrinsic_params ) + self .n_intrinsic_params
432+ )
420433 Grad = Grad .new_zeros (1 , N_params )
421434 Hess = Hess .new_zeros (1 , N_params , N_params )
422435
@@ -584,9 +597,10 @@ def optimize(
584597 delta = optimizer_step (Grad , Hess , lamb ) # (B, N_params)
585598
586599 if self .shared_intrinsics :
587- delta_g = delta [..., :- 1 ].reshape (B , 2 )
588- delta_f = delta [..., - 1 ].expand (B , 1 )
589- delta = torch .cat ([delta_g , delta_f ], dim = - 1 )
600+ delta_g = delta [..., : - self .n_intrinsic_params ].reshape (B , 2 )
601+ delta_f = delta [..., - self .n_intrinsic_params ].expand (B , 1 )
602+ delta_dist = delta [..., - self .n_intrinsic_params + 1 :].expand (B , - 1 )
603+ delta = torch .cat ([delta_g , delta_f , delta_dist ], dim = - 1 )
590604
591605 # calculate new cost
592606 camera_opt , gravity_opt = self .update_estimate (camera_opt , gravity_opt , delta )
0 commit comments