diff --git a/dataset_tool.py b/dataset_tool.py
index 747189fd..94dbe89f 100644
--- a/dataset_tool.py
+++ b/dataset_tool.py
@@ -250,16 +250,59 @@ def center_crop_wide(width, height, img):
         canvas[(width - height) // 2 : (width + height) // 2, :] = img
         return canvas
 
+    def _to_rgb(img):
+        """Return `img` as an (H, W, 3) array: a single channel is replicated, a fourth channel is dropped."""
+        # add channel dim
+        if img.ndim == 2:
+            img = img[:, :, None]
+        assert img.ndim == 3, f'input image has incorrect number of dimensions, required 2 (H, W) or 3 (H, W, C), got: {img.shape}'
+        # to RGB
+        if img.shape[-1] == 1:
+            img = img.repeat(3, axis=-1)
+        elif img.shape[-1] == 4:
+            # drop the fourth channel (assumed to be alpha)
+            img = img[:, :, :3]
+        assert img.shape[-1] == 3, f'input image must have 1, 3 or 4 channels, got: {img.shape}'
+        return img
+
+    def resize_pad(width, height, img):
+        """Scale `img` to fit inside `width` x `height` keeping its aspect ratio, then zero-pad it to exactly that size."""
+        img = _to_rgb(img)
+        # exit early
+        img_h, img_w = img.shape[:2]
+        if width == img_w and height == img_h:
+            return img
+        # get scale size, avoiding precision errors
+        scale_ratio = max(img_h / height, img_w / width)
+        # clamp to 1px: an extreme aspect ratio would otherwise truncate the short side to 0 pixels
+        scale_h = max(1, int(round(img_h / scale_ratio, 5)))
+        scale_w = max(1, int(round(img_w / scale_ratio, 5)))
+        assert scale_h <= height and scale_w <= width, f'scaled shape {scale_w}x{scale_h} is not smaller than or equal to the required shape: {width}x{height} this is a bug:'
+        # scale image
+        img = scale(scale_w, scale_h, img)
+        # pad the image if needed, centring the content (the extra pixel of an odd pad goes on the trailing side)
+        pad_h, pad_w = height - scale_h, width - scale_w
+        if pad_h != 0 or pad_w != 0:
+            pad_dims = [(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2), (0, 0)]  # (H,W,C)
+            img = np.pad(img, pad_dims)
+        # check the shape
+        assert img.shape[0] == height and img.shape[1] == width, f'output shape {img.shape[1]}x{img.shape[0]} does not match required shape: {width}x{height} this is a bug!'
+        return img
+
     if transform is None:
         return functools.partial(scale, output_width, output_height)
     if transform == 'center-crop':
         if (output_width is None) or (output_height is None):
-            error ('must specify --resolution=WxH when using ' + transform + 'transform')
+            error ('must specify --resolution=WxH when using ' + transform + ' transform')
         return functools.partial(center_crop, output_width, output_height)
     if transform == 'center-crop-wide':
         if (output_width is None) or (output_height is None):
             error ('must specify --resolution=WxH when using ' + transform + ' transform')
         return functools.partial(center_crop_wide, output_width, output_height)
+    if transform == 'resize-pad':
+        if (output_width is None) or (output_height is None):
+            error ('must specify --resolution=WxH when using ' + transform + ' transform')
+        return functools.partial(resize_pad, output_width, output_height)
     assert False, 'unknown transform'
 
 #----------------------------------------------------------------------------
@@ -321,7 +364,7 @@ def folder_write_bytes(fname: str, data: Union[bytes, str]):
 @click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
 @click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
 @click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
-@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
+@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide', 'resize-pad']))
 @click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
 def convert_dataset(
     ctx: click.Context,