From 4afe54ab1d9508203111cedecb7863fa1e637e37 Mon Sep 17 00:00:00 2001 From: venkatram-dev Date: Tue, 10 Sep 2024 15:52:27 -0700 Subject: [PATCH 1/2] use_pycocotools_for_segment_mask --- torchvision/tv_tensors/_dataset_wrapper.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/torchvision/tv_tensors/_dataset_wrapper.py b/torchvision/tv_tensors/_dataset_wrapper.py index 23683221f60..b73cff08b5b 100644 --- a/torchvision/tv_tensors/_dataset_wrapper.py +++ b/torchvision/tv_tensors/_dataset_wrapper.py @@ -356,19 +356,6 @@ def coco_dectection_wrapper_factory(dataset, target_keys): default={"image_id", "boxes", "labels"}, ) - def segmentation_to_mask(segmentation, *, canvas_size): - from pycocotools import mask - - if isinstance(segmentation, dict): - # if counts is a string, it is already an encoded RLE mask - if not isinstance(segmentation["counts"], str): - segmentation = mask.frPyObjects(segmentation, *canvas_size) - elif isinstance(segmentation, list): - segmentation = mask.merge(mask.frPyObjects(segmentation, *canvas_size)) - else: - raise ValueError(f"COCO segmentation expected to be a dict or a list, got {type(segmentation)}") - return torch.from_numpy(mask.decode(segmentation)) - def wrapper(idx, sample): image_id = dataset.ids[idx] @@ -394,13 +381,13 @@ def wrapper(idx, sample): ), new_format=tv_tensors.BoundingBoxFormat.XYXY, ) - + coco_ann = dataset.coco.imgToAnns[image_id] if "masks" in target_keys: target["masks"] = tv_tensors.Mask( torch.stack( [ - segmentation_to_mask(segmentation, canvas_size=canvas_size) - for segmentation in batched_target["segmentation"] + torch.from_numpy(dataset.coco.annToMask(ann)) + for ann in coco_ann ] ), ) From 3233cf7b303623d331aba21ab0447c200e5ac3fc Mon Sep 17 00:00:00 2001 From: venkatram-dev Date: Tue, 10 Sep 2024 21:41:52 -0700 Subject: [PATCH 2/2] fix ufmt format --- torchvision/tv_tensors/_dataset_wrapper.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/torchvision/tv_tensors/_dataset_wrapper.py b/torchvision/tv_tensors/_dataset_wrapper.py index b73cff08b5b..ed178c8d1f9 100644 --- a/torchvision/tv_tensors/_dataset_wrapper.py +++ b/torchvision/tv_tensors/_dataset_wrapper.py @@ -384,12 +384,7 @@ def wrapper(idx, sample): coco_ann = dataset.coco.imgToAnns[image_id] if "masks" in target_keys: target["masks"] = tv_tensors.Mask( - torch.stack( - [ - torch.from_numpy(dataset.coco.annToMask(ann)) - for ann in coco_ann - ] - ), + torch.stack([torch.from_numpy(dataset.coco.annToMask(ann)) for ann in coco_ann]), ) if "labels" in target_keys: