diff --git a/torchvision/tv_tensors/_dataset_wrapper.py b/torchvision/tv_tensors/_dataset_wrapper.py
index 23683221f60..ed178c8d1f9 100644
--- a/torchvision/tv_tensors/_dataset_wrapper.py
+++ b/torchvision/tv_tensors/_dataset_wrapper.py
@@ -356,19 +356,6 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
         default={"image_id", "boxes", "labels"},
     )
 
-    def segmentation_to_mask(segmentation, *, canvas_size):
-        from pycocotools import mask
-
-        if isinstance(segmentation, dict):
-            # if counts is a string, it is already an encoded RLE mask
-            if not isinstance(segmentation["counts"], str):
-                segmentation = mask.frPyObjects(segmentation, *canvas_size)
-        elif isinstance(segmentation, list):
-            segmentation = mask.merge(mask.frPyObjects(segmentation, *canvas_size))
-        else:
-            raise ValueError(f"COCO segmentation expected to be a dict or a list, got {type(segmentation)}")
-        return torch.from_numpy(mask.decode(segmentation))
-
     def wrapper(idx, sample):
         image_id = dataset.ids[idx]
 
@@ -394,15 +381,10 @@ def wrapper(idx, sample):
                 ),
                 new_format=tv_tensors.BoundingBoxFormat.XYXY,
             )
-
+        coco_ann = dataset.coco.imgToAnns[image_id]
         if "masks" in target_keys:
             target["masks"] = tv_tensors.Mask(
-                torch.stack(
-                    [
-                        segmentation_to_mask(segmentation, canvas_size=canvas_size)
-                        for segmentation in batched_target["segmentation"]
-                    ]
-                ),
+                torch.stack([torch.from_numpy(dataset.coco.annToMask(ann)) for ann in coco_ann]),
             )
 
         if "labels" in target_keys:
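
Note (not part of the patch): the change swaps the hand-rolled segmentation_to_mask helper for pycocotools' COCO.annToMask, which runs the same frPyObjects/merge/decode chain internally via annToRLE. Below is a minimal sanity sketch that re-implements the removed helper inline and checks it against annToMask on one annotation; the annotation file path is a placeholder, not something from the diff.

# Sanity sketch: removed helper vs. COCO.annToMask.
# "annotations/instances_val2017.json" is a hypothetical path for illustration.
import torch
from pycocotools import mask as mask_utils
from pycocotools.coco import COCO

coco = COCO("annotations/instances_val2017.json")  # placeholder annotation file
ann = coco.loadAnns(coco.getAnnIds())[0]
img = coco.imgs[ann["image_id"]]
canvas_size = (img["height"], img["width"])

# Inline copy of the removed segmentation_to_mask logic, for comparison only.
segmentation = ann["segmentation"]
if isinstance(segmentation, dict):
    # if counts is a string, it is already an encoded RLE mask
    if not isinstance(segmentation["counts"], str):
        segmentation = mask_utils.frPyObjects(segmentation, *canvas_size)
elif isinstance(segmentation, list):
    # list of polygons: convert each to RLE and merge into one mask
    segmentation = mask_utils.merge(mask_utils.frPyObjects(segmentation, *canvas_size))
old_mask = torch.from_numpy(mask_utils.decode(segmentation))

# The replacement path used by the patch.
new_mask = torch.from_numpy(coco.annToMask(ann))
assert torch.equal(old_mask, new_mask)

One behavioral note: as in the original code, torch.stack raises on an empty list, so images without annotations fail the "masks" branch both before and after this change.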