26
26
get_box_bounds_from_verts ,
27
27
get_box_verts_from_bounds ,
28
28
)
29
- from home_robot .utils .image import dilate_or_erode_mask
29
+ from home_robot .utils .image import dilate_or_erode_mask , interpolate_image
30
30
from home_robot .utils .point_cloud_torch import get_bounds
31
31
from home_robot .utils .voxel import drop_smallest_weight_points
32
32
@@ -76,6 +76,10 @@ class InstanceView:
76
76
cam_to_world : Tensor = None
77
77
"""[4,4] Tensor pose matrix mapping camera space to world space"""
78
78
79
+ # Where did we observe this from
80
+ pose : Tensor = None
81
+ """ Base pose of the robot when this view was collected"""
82
+
79
83
@cached_property
80
84
def object_coverage (self ):
81
85
return float (self .mask .sum ()) / self .mask .size
@@ -271,7 +275,7 @@ def __init__(
271
275
self ,
272
276
num_envs : int ,
273
277
du_scale : int ,
274
- instance_association : str = "bbox_iou " ,
278
+ instance_association : str = "map_overlap " ,
275
279
instance_association_within_class : bool = True ,
276
280
iou_threshold : float = 0.8 ,
277
281
global_box_nms_thresh : float = 0.0 ,
@@ -647,7 +651,8 @@ def associate_instances_to_memory(self):
647
651
match_within_category = self .instance_association_within_class ,
648
652
)
649
653
if global_instance_id is None :
650
- global_instance_id = len (self .instances [env_id ])
654
+ # start ids from 1
655
+ global_instance_id = len (self .instances [env_id ]) + 1
651
656
self .add_view_to_instance (
652
657
env_id , local_instance_id , global_instance_id
653
658
)
@@ -658,7 +663,8 @@ def associate_instances_to_memory(self):
658
663
match_within_category = self .instance_association_within_class ,
659
664
)
660
665
if global_instance_id is None :
661
- global_instance_id = len (self .instances [env_id ])
666
+ # start ids from 1
667
+ global_instance_id = len (self .instances [env_id ]) + 1
662
668
self .add_view_to_instance (
663
669
env_id , local_instance_id , global_instance_id
664
670
)
@@ -752,6 +758,7 @@ def process_instances_for_env(
752
758
mask_out_object : bool = True ,
753
759
background_instance_label : int = 0 ,
754
760
valid_points : Optional [Tensor ] = None ,
761
+ pose : Optional [Tensor ] = None ,
755
762
):
756
763
"""
757
764
Process instance information in the current frame and add instance views to the list of unprocessed views for future association.
@@ -764,7 +771,7 @@ def process_instances_for_env(
764
771
instance_seg (Tensor): [H, W] tensor of instance ids at each pixel
765
772
point_cloud (Tensor): Point cloud data in world coordinates.
766
773
image (Tensor): [3, H, W] RGB image
767
- pose : 4x4 camera_space_to_world transform
774
+ cam_to_world : 4x4 camera_space_to_world transform
768
775
instance_classes (Optional[Tensor]): [K,] class ids for each instance in instance seg
769
776
class_int = instance_classes[instance_id]
770
777
instance_scores (Optional[Tensor]): [K,] detection confidences for each instance in instance_seg
@@ -773,6 +780,7 @@ def process_instances_for_env(
773
780
# If false does it not save crops? Not black background?
774
781
background_class_label(int): id indicating background points in instance_seg. That view is not saved. (default = 0)
775
782
valid_points (Tensor): [H, W] boolean tensor indicating valid points in the pointcloud
783
+ pose: (Optional[Tensor]): base pose of the agent at this timestep
776
784
Note:
777
785
- The method creates instance views for detected instances within the provided data.
778
786
- If a semantic segmentation tensor is provided, each instance is associated with a semantic category.
@@ -788,35 +796,29 @@ def process_instances_for_env(
788
796
), "Ensure that RGB images are channels-first and in the right format."
789
797
790
798
self .unprocessed_views [env_id ] = {}
791
- # append image to list of images
799
+ # append image to list of images; move tensors to cpu to prevent memory from blowing up
792
800
if self .images [env_id ] is None :
793
- self .images [env_id ] = image .unsqueeze (0 )
801
+ self .images [env_id ] = image .unsqueeze (0 ). detach (). cpu ()
794
802
else :
795
803
self .images [env_id ] = torch .cat (
796
- [self .images [env_id ], image .unsqueeze (0 )], dim = 0
804
+ [self .images [env_id ], image .unsqueeze (0 ). detach (). cpu () ], dim = 0
797
805
)
798
806
if self .point_cloud [env_id ] is None :
799
- self .point_cloud [env_id ] = point_cloud .unsqueeze (0 )
807
+ self .point_cloud [env_id ] = point_cloud .unsqueeze (0 ). detach (). cpu ()
800
808
else :
801
809
self .point_cloud [env_id ] = torch .cat (
802
- [self .point_cloud [env_id ], point_cloud .unsqueeze (0 )], dim = 0
810
+ [self .point_cloud [env_id ], point_cloud .unsqueeze (0 ).detach ().cpu ()],
811
+ dim = 0 ,
803
812
)
804
813
805
- # Valid opints
814
+ # Valid points
806
815
if valid_points is None :
807
- valid_points = torch .full (
808
- image . shape [:, 0 ], True , dtype = torch .bool , device = image .device
816
+ valid_points = torch .full_like (
817
+ image [ 0 ], True , dtype = torch .bool , device = image .device
809
818
)
810
819
if self .du_scale != 1 :
811
- valid_points_downsampled = (
812
- torch .nn .functional .interpolate (
813
- valid_points .unsqueeze (0 ).unsqueeze (0 ).float (),
814
- scale_factor = 1 / self .du_scale ,
815
- mode = "nearest" ,
816
- )
817
- .squeeze (0 )
818
- .squeeze (0 )
819
- .bool ()
820
+ valid_points_downsampled = interpolate_image (
821
+ valid_points , scale_factor = 1 / self .du_scale
820
822
)
821
823
else :
822
824
valid_points_downsampled = valid_points
@@ -863,20 +865,15 @@ def process_instances_for_env(
863
865
864
866
# TODO: If we use du_scale, we should apply this at the beginning to speed things up
865
867
if self .du_scale != 1 :
866
- # downsample mask by du_scale using "NEAREST"
867
- instance_mask_downsampled = (
868
- torch .nn .functional .interpolate (
869
- instance_mask .unsqueeze (0 ).unsqueeze (0 ).float (),
870
- scale_factor = 1 / self .du_scale ,
871
- mode = "nearest" ,
872
- )
873
- .squeeze (0 )
874
- .squeeze (0 )
875
- .bool ()
868
+ instance_mask_downsampled = interpolate_image (
869
+ instance_mask , scale_factor = 1 / self .du_scale
870
+ )
871
+ image_downsampled = interpolate_image (
872
+ image , scale_factor = 1 / self .du_scale
876
873
)
877
-
878
874
else :
879
875
instance_mask_downsampled = instance_mask
876
+ image_downsampled = image
880
877
881
878
# Erode instance masks for point cloud
882
879
# TODO: We can do erosion and masking on the downsampled/cropped image to avoid unnecessary computation
@@ -912,7 +909,9 @@ def process_instances_for_env(
912
909
instance_mask_downsampled & valid_points_downsampled
913
910
)
914
911
point_cloud_instance = point_cloud [point_mask_downsampled ]
915
- point_cloud_rgb_instance = image .permute (1 , 2 , 0 )[point_mask_downsampled ]
912
+ point_cloud_rgb_instance = image_downsampled .permute (1 , 2 , 0 )[
913
+ point_mask_downsampled
914
+ ]
916
915
917
916
n_points = point_mask_downsampled .sum ()
918
917
n_mask = instance_mask_downsampled .sum ()
@@ -941,6 +940,7 @@ def process_instances_for_env(
941
940
category_id = category_id ,
942
941
score = score ,
943
942
bounds = bounds , # .cpu().numpy(),
943
+ pose = pose ,
944
944
)
945
945
# append instance view to list of instance views
946
946
self .unprocessed_views [env_id ][instance_id .item ()] = instance_view
@@ -970,9 +970,10 @@ def process_instances(
970
970
self ,
971
971
instance_channels : Tensor ,
972
972
point_cloud : Tensor ,
973
- pose : torch .Tensor ,
974
973
image : Tensor ,
974
+ cam_to_world : Optional [Tensor ] = None ,
975
975
semantic_channels : Optional [Tensor ] = None ,
976
+ pose : Optional [Tensor ] = None ,
976
977
):
977
978
"""
978
979
Process instance information across environments and associate instance views with global instances.
@@ -1005,9 +1006,10 @@ def process_instances(
1005
1006
env_id ,
1006
1007
instance_segs [env_id ],
1007
1008
point_cloud [env_id ],
1008
- pose [env_id ],
1009
1009
image [env_id ],
1010
+ cam_to_world = cam_to_world [env_id ] if cam_to_world is not None else None ,
1010
1011
semantic_seg = semantic_seg ,
1012
+ pose = pose [env_id ] if pose is not None else None ,
1011
1013
)
1012
1014
self .associate_instances_to_memory ()
1013
1015
0 commit comments