From 87f04fa445134a90a3cc7ffbe29b3d914d8f60a8 Mon Sep 17 00:00:00 2001
From: Nathan Michlo
Date: Tue, 9 Nov 2021 11:39:40 +0200
Subject: [PATCH] add transform option 'resize-pad' to dataset_tool.py

First scales an image within the specified bounds, then pads the scaled
image if one axis is smaller than what is required.

This operation preserves the image aspect ratio and content, without
cropping or warping.
---
 dataset_tool.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 53 insertions(+), 2 deletions(-)

diff --git a/dataset_tool.py b/dataset_tool.py
index 747189fd..94dbe89f 100644
--- a/dataset_tool.py
+++ b/dataset_tool.py
@@ -250,16 +250,67 @@ def center_crop_wide(width, height, img):
         canvas[(width - height) // 2 : (width + height) // 2, :] = img
         return canvas
 
+    def _to_rgb(img):
+        """Return `img` as an (H, W, 3) array.
+
+        Grayscale input, (H, W) or (H, W, 1), is repeated across three
+        channels; the alpha channel of (H, W, 4) input is dropped.
+        """
+        # add channel dim
+        if img.ndim == 2:
+            img = img[:, :, None]
+        assert img.ndim == 3, f'input image has incorrect number of dimensions, required 2 (H, W) or 3 (H, W, C), got: {img.shape}'
+        # to RGB
+        if img.shape[-1] == 1:
+            img = img.repeat(3, axis=-1)
+        elif img.shape[-1] == 4:
+            img = img[:, :, :3]
+        assert img.shape[-1] == 3, f'input image must have 1, 3 or 4 channels, got: {img.shape}'
+        return img
+
+    def resize_pad(width, height, img):
+        """Fit `img` inside `width` x `height` without cropping or warping.
+
+        The image is converted to RGB, scaled to fit within the bounds
+        while keeping its aspect ratio, then zero-padded on both sides of
+        the short axis up to exactly (height, width, 3).
+        """
+        img = _to_rgb(img)
+        # exit early
+        img_h, img_w = img.shape[:2]
+        if width == img_w and height == img_h:
+            return img
+        # get scale size, avoiding precision errors
+        scale_ratio = max(img_h / height, img_w / width)
+        # clamp to one pixel: extreme aspect ratios would otherwise truncate to 0
+        scale_h = max(1, int(round(img_h / scale_ratio, 5)))
+        scale_w = max(1, int(round(img_w / scale_ratio, 5)))
+        assert scale_h <= height and scale_w <= width, f'scaled shape {scale_w}x{scale_h} is not smaller than or equal to the required shape: {width}x{height} this is a bug:'
+        # scale image
+        img = scale(scale_w, scale_h, img)
+        # pad the image if needed, an odd pixel goes to the trailing side (bottom/right)
+        pad_h, pad_w = height - scale_h, width - scale_w
+        if pad_h != 0 or pad_w != 0:
+            pad_dims = [(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2), (0, 0)]  # (H,W,C)
+            img = np.pad(img, pad_dims)
+        # check the shape
+        assert img.shape[0] == height and img.shape[1] == width, f'output shape {img.shape[1]}x{img.shape[0]} does not match required shape: {width}x{height} this is a bug!'
+        return img
+
     if transform is None:
         return functools.partial(scale, output_width, output_height)
     if transform == 'center-crop':
         if (output_width is None) or (output_height is None):
-            error ('must specify --resolution=WxH when using ' + transform + 'transform')
+            error ('must specify --resolution=WxH when using ' + transform + ' transform')
         return functools.partial(center_crop, output_width, output_height)
     if transform == 'center-crop-wide':
         if (output_width is None) or (output_height is None):
             error ('must specify --resolution=WxH when using ' + transform + ' transform')
         return functools.partial(center_crop_wide, output_width, output_height)
+    if transform == 'resize-pad':
+        if (output_width is None) or (output_height is None):
+            error ('must specify --resolution=WxH when using ' + transform + ' transform')
+        return functools.partial(resize_pad, output_width, output_height)
     assert False, 'unknown transform'
 
 #----------------------------------------------------------------------------
@@ -321,7 +372,7 @@ def folder_write_bytes(fname: str, data: Union[bytes, str]):
 @click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
 @click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
 @click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
-@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
+@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide', 'resize-pad']))
 @click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
 def convert_dataset(
     ctx: click.Context,