Skip to content

Commit 89b212f

Browse files
authored
Merge pull request #399 from facebookresearch/yvsriram/ovmm_im_fixes
Get 2d association for IM working in OVMM envs
2 parents 601debd + 3ce6e45 commit 89b212f

File tree

7 files changed

+78
-42
lines changed

7 files changed

+78
-42
lines changed

projects/habitat_ovmm/configs/agent/heuristic_instance_tracking_agent.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ SEMANTIC_MAP:
2424
dilate_size: 3
2525
dilate_iter: 1
2626
record_instance_ids: True
27+
instance_association: map_overlap
2728
max_instances: 0
2829

2930
SKILLS:

src/home_robot/home_robot/agent/objectnav_agent/objectnav_agent.py

+4
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,8 @@ def reset_vectorized(self):
293293
self.last_poses = [np.zeros(3)] * self.num_environments
294294
self.semantic_map.init_map_and_pose()
295295
self.episode_panorama_start_steps = self.panorama_start_steps
296+
if self.record_instance_ids:
297+
self.instance_memory.reset()
296298
self.planner.reset()
297299

298300
def reset_vectorized_for_env(self, e: int):
@@ -302,6 +304,8 @@ def reset_vectorized_for_env(self, e: int):
302304
self.last_poses[e] = np.zeros(3)
303305
self.semantic_map.init_map_and_pose_for_env(e)
304306
self.episode_panorama_start_steps = self.panorama_start_steps
307+
if self.record_instance_ids:
308+
self.instance_memory.reset_for_env(e)
305309
self.planner.reset()
306310

307311
# ---------------------------------------------------------------------

src/home_robot/home_robot/mapping/semantic/categorical_2d_semantic_map_module.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(
7070
evaluate_instance_tracking: bool = False,
7171
instance_memory: Optional[InstanceMemory] = None,
7272
max_instances: int = 0,
73-
instance_association: str = "bbox_iou",
73+
instance_association: str = "map_overlap",
7474
dilation_for_instances: int = 5,
7575
padding_for_instance_overlap: int = 5,
7676
):
@@ -816,6 +816,7 @@ def _update_global_map_instances_for_one_channel(
816816
extended_dilated_local_map,
817817
global_instances_within_local,
818818
max_instance_id,
819+
torch.unique(extended_local_map),
819820
)
820821

821822
# Update the global map with the associated instances from the local map
@@ -838,21 +839,23 @@ def _get_local_to_global_instance_mapping(
838839
extended_local_labels: Tensor,
839840
global_instances_within_local: Tensor,
840841
max_instance_id: int,
842+
local_instance_ids: Tensor,
841843
) -> dict:
842844
"""
843845
Creates a mapping of local instance IDs to global instance IDs.
844846
845847
Args:
846848
extended_local_labels: Labels of instances in the extended local map.
847849
global_instances_within_local: Instances from the global map within the local map's region.
848-
850+
max_instance_id: The number of instance ids that are used up
851+
local_instance_ids: The local instance ids for which local to global mapping is to be determined
849852
Returns:
850853
A mapping of local instance IDs to global instance IDs.
851854
"""
852855
instance_mapping = {}
853856

854857
# Associate instances in the local map with corresponding instances in the global map
855-
for local_instance_id in torch.unique(extended_local_labels):
858+
for local_instance_id in local_instance_ids:
856859
if local_instance_id == 0:
857860
# ignore 0 as it does not correspond to an instance
858861
continue
@@ -879,7 +882,7 @@ def _get_local_to_global_instance_mapping(
879882
self.instance_memory.add_view_to_instance(
880883
env_id, int(local_instance_id.item()), global_instance_id
881884
)
882-
instance_mapping[0] = 0
885+
instance_mapping[0.0] = 0
883886
return instance_mapping
884887

885888
def _update_global_map_instances(

src/home_robot/home_robot/mapping/semantic/instance_tracking_modules.py

+38-36
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
get_box_bounds_from_verts,
2727
get_box_verts_from_bounds,
2828
)
29-
from home_robot.utils.image import dilate_or_erode_mask
29+
from home_robot.utils.image import dilate_or_erode_mask, interpolate_image
3030
from home_robot.utils.point_cloud_torch import get_bounds
3131
from home_robot.utils.voxel import drop_smallest_weight_points
3232

@@ -76,6 +76,10 @@ class InstanceView:
7676
cam_to_world: Tensor = None
7777
"""[4,4] Tensor pose matrix mapping camera space to world space"""
7878

79+
# Where did we observe this from
80+
pose: Tensor = None
81+
""" Base pose of the robot when this view was collected"""
82+
7983
@cached_property
8084
def object_coverage(self):
8185
return float(self.mask.sum()) / self.mask.size
@@ -271,7 +275,7 @@ def __init__(
271275
self,
272276
num_envs: int,
273277
du_scale: int,
274-
instance_association: str = "bbox_iou",
278+
instance_association: str = "map_overlap",
275279
instance_association_within_class: bool = True,
276280
iou_threshold: float = 0.8,
277281
global_box_nms_thresh: float = 0.0,
@@ -647,7 +651,8 @@ def associate_instances_to_memory(self):
647651
match_within_category=self.instance_association_within_class,
648652
)
649653
if global_instance_id is None:
650-
global_instance_id = len(self.instances[env_id])
654+
# start ids from 1
655+
global_instance_id = len(self.instances[env_id]) + 1
651656
self.add_view_to_instance(
652657
env_id, local_instance_id, global_instance_id
653658
)
@@ -658,7 +663,8 @@ def associate_instances_to_memory(self):
658663
match_within_category=self.instance_association_within_class,
659664
)
660665
if global_instance_id is None:
661-
global_instance_id = len(self.instances[env_id])
666+
# start ids from 1
667+
global_instance_id = len(self.instances[env_id]) + 1
662668
self.add_view_to_instance(
663669
env_id, local_instance_id, global_instance_id
664670
)
@@ -752,6 +758,7 @@ def process_instances_for_env(
752758
mask_out_object: bool = True,
753759
background_instance_label: int = 0,
754760
valid_points: Optional[Tensor] = None,
761+
pose: Optional[Tensor] = None,
755762
):
756763
"""
757764
Process instance information in the current frame and add instance views to the list of unprocessed views for future association.
@@ -764,7 +771,7 @@ def process_instances_for_env(
764771
instance_seg (Tensor): [H, W] tensor of instance ids at each pixel
765772
point_cloud (Tensor): Point cloud data in world coordinates.
766773
image (Tensor): [3, H, W] RGB image
767-
pose: 4x4 camera_space_to_world transform
774+
cam_to_world: 4x4 camera_space_to_world transform
768775
instance_classes (Optional[Tensor]): [K,] class ids for each instance in instance seg
769776
class_int = instance_classes[instance_id]
770777
instance_scores (Optional[Tensor]): [K,] detection confidences for each instance in instance_seg
@@ -773,6 +780,7 @@ def process_instances_for_env(
773780
# If false does it not save crops? Not black background?
774781
background_class_label(int): id indicating background points in instance_seg. That view is not saved. (default = 0)
775782
valid_points (Tensor): [H, W] boolean tensor indicating valid points in the pointcloud
783+
pose (Optional[Tensor]): base pose of the agent at this timestep
776784
Note:
777785
- The method creates instance views for detected instances within the provided data.
778786
- If a semantic segmentation tensor is provided, each instance is associated with a semantic category.
@@ -788,35 +796,29 @@ def process_instances_for_env(
788796
), "Ensure that RGB images are channels-first and in the right format."
789797

790798
self.unprocessed_views[env_id] = {}
791-
# append image to list of images
799+
# append image to list of images; move tensors to cpu to prevent memory from blowing up
792800
if self.images[env_id] is None:
793-
self.images[env_id] = image.unsqueeze(0)
801+
self.images[env_id] = image.unsqueeze(0).detach().cpu()
794802
else:
795803
self.images[env_id] = torch.cat(
796-
[self.images[env_id], image.unsqueeze(0)], dim=0
804+
[self.images[env_id], image.unsqueeze(0).detach().cpu()], dim=0
797805
)
798806
if self.point_cloud[env_id] is None:
799-
self.point_cloud[env_id] = point_cloud.unsqueeze(0)
807+
self.point_cloud[env_id] = point_cloud.unsqueeze(0).detach().cpu()
800808
else:
801809
self.point_cloud[env_id] = torch.cat(
802-
[self.point_cloud[env_id], point_cloud.unsqueeze(0)], dim=0
810+
[self.point_cloud[env_id], point_cloud.unsqueeze(0).detach().cpu()],
811+
dim=0,
803812
)
804813

805-
# Valid opints
814+
# Valid points
806815
if valid_points is None:
807-
valid_points = torch.full(
808-
image.shape[:, 0], True, dtype=torch.bool, device=image.device
816+
valid_points = torch.full_like(
817+
image[0], True, dtype=torch.bool, device=image.device
809818
)
810819
if self.du_scale != 1:
811-
valid_points_downsampled = (
812-
torch.nn.functional.interpolate(
813-
valid_points.unsqueeze(0).unsqueeze(0).float(),
814-
scale_factor=1 / self.du_scale,
815-
mode="nearest",
816-
)
817-
.squeeze(0)
818-
.squeeze(0)
819-
.bool()
820+
valid_points_downsampled = interpolate_image(
821+
valid_points, scale_factor=1 / self.du_scale
820822
)
821823
else:
822824
valid_points_downsampled = valid_points
@@ -863,20 +865,15 @@ def process_instances_for_env(
863865

864866
# TODO: If we use du_scale, we should apply this at the beginning to speed things up
865867
if self.du_scale != 1:
866-
# downsample mask by du_scale using "NEAREST"
867-
instance_mask_downsampled = (
868-
torch.nn.functional.interpolate(
869-
instance_mask.unsqueeze(0).unsqueeze(0).float(),
870-
scale_factor=1 / self.du_scale,
871-
mode="nearest",
872-
)
873-
.squeeze(0)
874-
.squeeze(0)
875-
.bool()
868+
instance_mask_downsampled = interpolate_image(
869+
instance_mask, scale_factor=1 / self.du_scale
870+
)
871+
image_downsampled = interpolate_image(
872+
image, scale_factor=1 / self.du_scale
876873
)
877-
878874
else:
879875
instance_mask_downsampled = instance_mask
876+
image_downsampled = image
880877

881878
# Erode instance masks for point cloud
882879
# TODO: We can do erosion and masking on the downsampled/cropped image to avoid unnecessary computation
@@ -912,7 +909,9 @@ def process_instances_for_env(
912909
instance_mask_downsampled & valid_points_downsampled
913910
)
914911
point_cloud_instance = point_cloud[point_mask_downsampled]
915-
point_cloud_rgb_instance = image.permute(1, 2, 0)[point_mask_downsampled]
912+
point_cloud_rgb_instance = image_downsampled.permute(1, 2, 0)[
913+
point_mask_downsampled
914+
]
916915

917916
n_points = point_mask_downsampled.sum()
918917
n_mask = instance_mask_downsampled.sum()
@@ -941,6 +940,7 @@ def process_instances_for_env(
941940
category_id=category_id,
942941
score=score,
943942
bounds=bounds, # .cpu().numpy(),
943+
pose=pose,
944944
)
945945
# append instance view to list of instance views
946946
self.unprocessed_views[env_id][instance_id.item()] = instance_view
@@ -970,9 +970,10 @@ def process_instances(
970970
self,
971971
instance_channels: Tensor,
972972
point_cloud: Tensor,
973-
pose: torch.Tensor,
974973
image: Tensor,
974+
cam_to_world: Optional[Tensor] = None,
975975
semantic_channels: Optional[Tensor] = None,
976+
pose: Optional[Tensor] = None,
976977
):
977978
"""
978979
Process instance information across environments and associate instance views with global instances.
@@ -1005,9 +1006,10 @@ def process_instances(
10051006
env_id,
10061007
instance_segs[env_id],
10071008
point_cloud[env_id],
1008-
pose[env_id],
10091009
image[env_id],
1010+
cam_to_world=cam_to_world[env_id] if cam_to_world is not None else None,
10101011
semantic_seg=semantic_seg,
1012+
pose=pose[env_id] if pose is not None else None,
10111013
)
10121014
self.associate_instances_to_memory()
10131015

src/home_robot/home_robot/mapping/voxel/voxel.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -348,12 +348,13 @@ def add(
348348
instance_seg=instance,
349349
point_cloud=full_world_xyz.reshape(H, W, 3),
350350
image=rgb.permute(2, 0, 1),
351-
cam_to_world=base_pose,
351+
cam_to_world=camera_pose,
352352
instance_classes=instance_classes,
353353
instance_scores=instance_scores,
354354
mask_out_object=False, # Save the whole image here? Or is this with background?
355355
background_instance_label=self.background_instance_label,
356356
valid_points=valid_depth,
357+
pose=base_pose,
357358
)
358359
self.instances.associate_instances_to_memory()
359360

src/home_robot/home_robot/utils/image.py

+25
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,28 @@ def get_cropped_image_with_padding(self, image, bbox, padding: float = 1.0):
348348
x:x2,
349349
]
350350
return cropped_image
351+
352+
353+
def interpolate_image(image: Tensor, scale_factor: float = 1.0, mode: str = "nearest"):
    """
    Resize an image or mask by ``scale_factor`` with ``torch.nn.functional.interpolate``.

    ``interpolate`` expects a batched, channels-first floating point tensor, so a batch
    dimension (and, for 2D inputs, a channel dimension) is added temporarily and the
    input is cast to float32 for the resize. Only the dimensions added here are removed
    again, and the result is converted back to the dtype of the input so that boolean
    masks stay boolean while RGB images keep their colour values.

    Args:
        image (Tensor): image of shape [C, H, W] or mask/image of shape [H, W]
        scale_factor (float): multiplier for spatial size
        mode (str): algorithm for interpolation: 'nearest' (default), 'bicubic' or other
            interpolation modes at
            https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html

    Returns:
        Tensor with the same number of dimensions and the same dtype as ``image`` whose
        spatial size is scaled by ``scale_factor``.
    """
    source_dtype = image.dtype
    added_channel_dim = image.dim() == 2
    if added_channel_dim:
        image = image.unsqueeze(0)

    resized = torch.nn.functional.interpolate(
        image.unsqueeze(0).float(),
        scale_factor=scale_factor,
        mode=mode,
    ).squeeze(0)
    if added_channel_dim:
        # Drop only the channel dimension added above; a bare ``squeeze()`` would also
        # collapse a spatial dimension that happens to be downsampled to size 1.
        resized = resized.squeeze(0)

    if source_dtype == torch.bool:
        # Masks: any non-zero value counts as set.
        return resized.bool()
    if source_dtype.is_floating_point:
        return resized.to(source_dtype)
    # Integer images (e.g. uint8 RGB): smooth modes such as 'bicubic' can produce
    # fractional or out-of-range values, so round and clamp before casting back.
    limits = torch.iinfo(source_dtype)
    return resized.round().clamp(limits.min, limits.max).to(source_dtype)

src/home_robot_sim/home_robot_sim/env/habitat_objectnav_env/visualizer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ def _visualize_instance_counts(
501501
'"""
502502
num_instances_per_category = defaultdict(int)
503503
num_views_per_instance = defaultdict(list)
504-
for instance_id, instance in instance_memory.instance_views[0].items():
504+
for instance_id, instance in instance_memory.instances[0].items():
505505
num_instances_per_category[instance.category_id.item()] += 1
506506
num_views_per_instance[instance.category_id.item()].append(
507507
len(instance.instance_views)

0 commit comments

Comments
 (0)