diff --git a/PythonAPI/pycocotools/coco.py b/PythonAPI/pycocotools/coco.py index 58bc7eaa..b4abba81 100644 --- a/PythonAPI/pycocotools/coco.py +++ b/PythonAPI/pycocotools/coco.py @@ -198,6 +198,29 @@ def getImgIds(self, imgIds=[], catIds=[]): ids &= set(self.catToImgs[catId]) return list(ids) + #get IDs of images which contain one or more of each category + def getImgIdsUnion(self, imgIds=[], catIds=[]): + ''' + Get img ids that satisfy given filter conditions. + :param imgIds (int array) : get imgs for given ids + :param catIds (int array) : get imgs with all given cats + :return: ids (int array) : integer array of img ids + ''' + imgIds = imgIds if _isArrayLike(imgIds) else [imgIds] + catIds = catIds if _isArrayLike(catIds) else [catIds] + + if len(imgIds) == len(catIds) == 0: + ids = self.imgs.keys() + else: + ids = set(imgIds) + for i, catId in enumerate(catIds): + if i == 0 and len(ids) == 0: + ids = set(self.catToImgs[catId]) + else: + ids |= set(self.catToImgs[catId]) + return list(ids) + + def loadAnns(self, ids=[]): """ Load anns with the specified ids.