Skip to content

Commit

Permalink
add transform option 'resize-pad' to dataset_tool.py
Browse files Browse the repository at this point in the history
First scales the image to fit within the specified bounds, then pads the scaled image if one axis is smaller than required.

This operation preserves the image aspect ratio and content, without cropping or warping.
  • Loading branch information
nmichlo committed Nov 9, 2021
1 parent a5a69f5 commit 87f04fa
Showing 1 changed file with 41 additions and 2 deletions.
43 changes: 41 additions & 2 deletions dataset_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,55 @@ def center_crop_wide(width, height, img):
canvas[(width - height) // 2 : (width + height) // 2, :] = img
return canvas

def _to_rgb(img):
    """Return `img` as a 3-channel array of shape (H, W, 3).

    Accepts a 2-D (H, W) array or a 3-D (H, W, C) array with 1, 3 or 4
    channels. A single channel is repeated three times; for 4-channel
    input the last channel (presumably alpha) is dropped; 3-channel input
    is returned unchanged.

    Raises:
        ValueError: if `img` is neither 2-D nor 3-D, or if its channel
            count is not 1, 3 or 4.
    """
    # add channel dim
    if img.ndim == 2:
        img = img[:, :, None]
    # Explicit raise instead of assert so the check survives `python -O`.
    if img.ndim != 3:
        raise ValueError(
            f'input image has incorrect number of dimensions, required 2 (H, W) or 3 (H, W, C), got: {img.shape}'
        )
    # to RGB
    channels = img.shape[-1]
    if channels == 1:
        img = img.repeat(3, axis=-1)
    elif channels == 4:
        img = img[:, :, :3]
    elif channels != 3:
        raise ValueError(f'input image must have 1, 3 or 4 channels, got: {img.shape}')
    return img

def resize_pad(width, height, img):
    """Resize `img` to fit inside `width` x `height`, then pad it to exactly that size.

    The image is first converted with `_to_rgb`, then rescaled by a single
    ratio for both axes (so the aspect ratio is kept) such that neither
    axis exceeds the target. Any remaining space is filled by `np.pad`
    with its default settings, split evenly between the two sides of the
    axis (the extra pixel, if any, goes to the bottom/right).

    Args:
        width: required output width in pixels.
        height: required output height in pixels.
        img: input image array, 2-D (H, W) or 3-D (H, W, C).

    Returns:
        An array whose first two dimensions are (height, width).
    """
    img = _to_rgb(img)
    # exit early
    img_h, img_w = img.shape[:2]
    if width == img_w and height == img_h:
        return img
    # get scale size, avoiding precision errors
    scale_ratio = max(img_h / height, img_w / width)
    # Clamp to at least one pixel: for extreme aspect ratios the shorter side
    # would otherwise truncate to 0 and the resize would be asked for an
    # empty image.
    scale_h = max(1, int(round(img_h / scale_ratio, 5)))
    scale_w = max(1, int(round(img_w / scale_ratio, 5)))
    assert scale_h <= height and scale_w <= width, f'scaled shape {scale_w}x{scale_h} is not smaller than or equal to the required shape: {width}x{height} this is a bug:'
    # scale image
    img = scale(scale_w, scale_h, img)
    # pad the image if needed
    pad_h, pad_w = height - scale_h, width - scale_w
    if pad_h != 0 or pad_w != 0:
        # Integer split equivalent to floor(pad/2) before and ceil(pad/2) after.
        top, left = pad_h // 2, pad_w // 2
        pad_dims = [(top, pad_h - top), (left, pad_w - left), (0, 0)]  # (H,W,C)
        img = np.pad(img, pad_dims)
    # check the shape
    assert img.shape[0] == height and img.shape[1] == width, f'output shape {img.shape[1]}x{img.shape[0]} does not match required shape: {width}x{height} this is a bug!'
    return img

if transform is None:
return functools.partial(scale, output_width, output_height)
if transform == 'center-crop':
if (output_width is None) or (output_height is None):
error ('must specify --resolution=WxH when using ' + transform + 'transform')
error ('must specify --resolution=WxH when using ' + transform + ' transform')
return functools.partial(center_crop, output_width, output_height)
if transform == 'center-crop-wide':
if (output_width is None) or (output_height is None):
error ('must specify --resolution=WxH when using ' + transform + ' transform')
return functools.partial(center_crop_wide, output_width, output_height)
if transform == 'resize-pad':
if (output_width is None) or (output_height is None):
error ('must specify --resolution=WxH when using ' + transform + ' transform')
return functools.partial(resize_pad, output_width, output_height)
assert False, 'unknown transform'

#----------------------------------------------------------------------------
Expand Down Expand Up @@ -321,7 +360,7 @@ def folder_write_bytes(fname: str, data: Union[bytes, str]):
@click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
@click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide', 'resize-pad']))
@click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
def convert_dataset(
ctx: click.Context,
Expand Down

0 comments on commit 87f04fa

Please sign in to comment.