-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP WIP WIP Can test version Can test version modify for dump onnx ready version ready version ready version ready version ready version ready version
- Loading branch information
1 parent
36eec70
commit cf37a4b
Showing
37 changed files
with
3,379 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
import torch.utils.data | ||
import torchvision | ||
|
||
from .coco import build as build_coco | ||
|
||
|
||
def get_coco_api_from_dataset(dataset):
    """Return the ``coco`` attribute of the COCO dataset behind ``dataset``.

    ``torch.utils.data.Subset`` wrappers are peeled off first, however deeply
    they are nested.

    Args:
        dataset: A dataset, possibly wrapped in one or more ``Subset`` objects.

    Returns:
        ``dataset.coco`` of the unwrapped dataset when it is a
        ``torchvision.datasets.CocoDetection``; otherwise ``None``.
    """
    # Unwrap until no Subset layer is left rather than for a fixed number of
    # rounds, so arbitrarily nested subsets are still resolved.
    while isinstance(dataset, torch.utils.data.Subset):
        dataset = dataset.dataset
    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco
    return None
|
||
|
||
def build_dataset(image_set, args):
    """Build the dataset selected by ``args.dataset_file``.

    Args:
        image_set: Split identifier, forwarded untouched to the builder.
        args: Options object; ``dataset_file`` picks the builder and the
            whole object is forwarded to it.

    Returns:
        Whatever the selected builder returns.

    Raises:
        ValueError: If ``args.dataset_file`` is neither ``'coco'`` nor
            ``'coco_panoptic'``.
    """
    name = args.dataset_file
    if name not in ('coco', 'coco_panoptic'):
        raise ValueError(f'dataset {name} not supported')
    if name == 'coco_panoptic':
        # Imported lazily so panopticapi is not required for plain COCO.
        from .coco_panoptic import build as build_coco_panoptic
        return build_coco_panoptic(image_set, args)
    return build_coco(image_set, args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
""" | ||
COCO dataset which returns image_id for evaluation. | ||
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py | ||
""" | ||
from pathlib import Path | ||
|
||
import torch | ||
import torch.utils.data | ||
import torchvision | ||
from pycocotools import mask as coco_mask | ||
|
||
import datasets.transforms as T | ||
|
||
|
||
class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO detection dataset whose targets carry the image id.

    Each sample is passed through ``ConvertCocoPolysToMask`` and then through
    the optional ``transforms`` callable.
    """

    def __init__(self, img_folder, ann_file, transforms, return_masks):
        """Store the transform pipeline and set up the target converter.

        Args:
            img_folder: Image directory, forwarded to the parent class.
            ann_file: Annotation file, forwarded to the parent class.
            transforms: Callable taking ``(img, target)`` and returning the
                transformed pair, or ``None`` to skip this step.
            return_masks: Forwarded to ``ConvertCocoPolysToMask``.
        """
        super().__init__(img_folder, ann_file)
        self.prepare = ConvertCocoPolysToMask(return_masks)
        self._transforms = transforms

    def __getitem__(self, idx):
        """Return the prepared ``(img, target)`` pair for sample ``idx``."""
        img, annotations = super().__getitem__(idx)
        # Bundle the raw annotations with the id of the image they belong to.
        sample = {'image_id': self.ids[idx], 'annotations': annotations}
        img, target = self.prepare(img, sample)
        if self._transforms is None:
            return img, target
        img, target = self._transforms(img, target)
        return img, target
|
||
|
||
def convert_coco_poly_to_mask(segmentations, height, width):
    """Decode per-object segmentations into one stacked mask tensor.

    Args:
        segmentations: Iterable with one segmentation entry per object; each
            entry is handed to ``coco_mask.frPyObjects``.
        height: Mask height.
        width: Mask width.

    Returns:
        The per-object masks stacked along dimension 0, or an empty
        ``(0, height, width)`` uint8 tensor when there are no segmentations.
    """

    def _decode(polygons):
        # Encode the polygons, then decode them back into a dense array.
        rles = coco_mask.frPyObjects(polygons, height, width)
        dense = coco_mask.decode(rles)
        if len(dense.shape) < 3:
            # Add a trailing axis so the reduction over dim 2 below applies.
            dense = dense[..., None]
        # Collapse the last axis: a pixel is set if any slice sets it.
        return torch.as_tensor(dense, dtype=torch.uint8).any(dim=2)

    per_object = [_decode(polygons) for polygons in segmentations]
    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
|
||
|
||
class ConvertCocoPolysToMask:
    """Callable that turns raw COCO annotations into a dict of tensors."""

    def __init__(self, return_masks=False):
        """Remember whether segmentation masks should be produced."""
        self.return_masks = return_masks

    def __call__(self, image, target):
        """Convert the annotations of one image.

        Args:
            image: Object whose ``size`` attribute unpacks to
                ``(width, height)``.
            target: Dict with the keys ``"image_id"`` and ``"annotations"``
                (a list of per-object annotation dicts).

        Returns:
            ``(image, converted)`` where ``image`` is passed through unchanged
            and ``converted`` holds ``boxes``, ``labels``, ``image_id``,
            ``area``, ``iscrowd``, ``orig_size``, ``size`` and, when
            available, ``masks`` and ``keypoints``.
        """
        w, h = image.size
        image_id = torch.tensor([target["image_id"]])

        # Crowd annotations are dropped; a missing flag counts as non-crowd.
        objs = [o for o in target["annotations"] if o.get('iscrowd', 0) == 0]

        # The reshape keeps an (N, 4) layout even when there are no boxes.
        boxes = torch.as_tensor(
            [o["bbox"] for o in objs], dtype=torch.float32
        ).reshape(-1, 4)
        # Add the first two columns onto the last two
        # (presumably [x, y, w, h] -> [x0, y0, x1, y1]).
        boxes[:, 2:] += boxes[:, :2]
        # Clip columns 0/2 to the image width and columns 1/3 to its height.
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        labels = torch.tensor([o["category_id"] for o in objs], dtype=torch.int64)

        masks = None
        if self.return_masks:
            masks = convert_coco_poly_to_mask(
                [o["segmentation"] for o in objs], h, w
            )

        keypoints = None
        if objs and "keypoints" in objs[0]:
            keypoints = torch.as_tensor(
                [o["keypoints"] for o in objs], dtype=torch.float32
            )
            count = keypoints.shape[0]
            if count:
                # Group the flat values into triples per keypoint.
                keypoints = keypoints.view(count, -1, 3)

        # Keep only boxes that still have positive extent on both axes.
        keep = (boxes[:, 1] < boxes[:, 3]) & (boxes[:, 0] < boxes[:, 2])

        converted = {"boxes": boxes[keep], "labels": labels[keep]}
        if masks is not None:
            converted["masks"] = masks[keep]
        converted["image_id"] = image_id
        if keypoints is not None:
            converted["keypoints"] = keypoints[keep]

        # for conversion to coco api
        area = torch.tensor([o["area"] for o in objs])
        iscrowd = torch.tensor([o.get("iscrowd", 0) for o in objs])
        converted["area"] = area[keep]
        converted["iscrowd"] = iscrowd[keep]

        # Both entries start out as (height, width) of the incoming image.
        converted["orig_size"] = torch.as_tensor([int(h), int(w)])
        converted["size"] = torch.as_tensor([int(h), int(w)])

        return image, converted
|
||
|
||
def make_coco_transforms(image_set):
    """Assemble the transform pipeline for the given split.

    Args:
        image_set: ``'train'`` or ``'val'``.

    Returns:
        A ``T.Compose`` pipeline ending in tensor conversion and
        normalization.

    Raises:
        ValueError: If ``image_set`` is neither ``'train'`` nor ``'val'``.
    """
    # Shared tail of both pipelines: tensor conversion, then normalization.
    to_normalized_tensor = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    if image_set == 'val':
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            to_normalized_tensor,
        ])

    if image_set != 'train':
        raise ValueError(f'unknown {image_set}')

    # 480, 512, ..., 800
    scales = list(range(480, 801, 32))

    flip = T.RandomHorizontalFlip()
    plain_resize = T.RandomResize(scales, max_size=1333)
    resize_crop_resize = T.Compose([
        T.RandomResize([400, 500, 600]),
        T.RandomSizeCrop(384, 600),
        T.RandomResize(scales, max_size=1333),
    ])
    # RandomSelect is given the two alternative resize branches.
    resize_choice = T.RandomSelect(plain_resize, resize_crop_resize)
    return T.Compose([flip, resize_choice, to_normalized_tensor])
|
||
|
||
def build(image_set, args):
    """Create the COCO detection dataset for one split.

    Args:
        image_set: Split name, ``"train"`` or ``"val"``; selects both the
            image/annotation paths and the transform pipeline.
        args: Options object providing ``coco_path`` (dataset root) and
            ``masks`` (forwarded as ``return_masks``).

    Returns:
        A ``CocoDetection`` dataset for the requested split.

    Raises:
        AssertionError: If ``args.coco_path`` does not exist.
        KeyError: If ``image_set`` is not ``"train"`` or ``"val"``.
    """
    root = Path(args.coco_path)
    assert root.exists(), f'provided COCO path {root} does not exist'
    mode = 'instances'
    split_paths = {
        "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
        "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
    }

    img_folder, ann_file = split_paths[image_set]
    # Pick the transforms for the requested split, so that "train" receives
    # its augmentation pipeline rather than always the "val" one.
    dataset = CocoDetection(
        img_folder,
        ann_file,
        transforms=make_coco_transforms(image_set),
        return_masks=args.masks,
    )
    return dataset
Oops, something went wrong.