Skip to content

Commit

Permalink
add reansform option 'resize-pad' to dataset_tool.py that first scale…
Browse files Browse the repository at this point in the history
…s an image into bounds then pads the image to the specified size
  • Loading branch information
nmichlo committed Nov 9, 2021
1 parent a5a69f5 commit e53d47c
Showing 1 changed file with 41 additions and 2 deletions.
43 changes: 41 additions & 2 deletions dataset_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,55 @@ def center_crop_wide(width, height, img):
canvas[(width - height) // 2 : (width + height) // 2, :] = img
return canvas

def _to_rgb(img):
# add channel dim
if img.ndim == 2:
img = img[:, :, None]
assert img.ndim == 3, f'input image has incorrect number of dimensions, required 2 (H, W) or 3 (H, W, C), got: {img.shape}'
# to RGB
if img.shape[-1] == 1:
img = img.repeat(3, axis=-1)
elif img.shape[-1] == 4:
img = img[:, :, :3]
assert img.shape[-1] == 3, f'input image must have 1 or 3 channels, got: {img.shape}'
return img

def resize_pad(width, height, img):
img = _to_rgb(img)
# exit early
img_h, img_w = img.shape[:2]
if width == img_w and height == img_h:
return img
# get scale size, avoiding precision errors
scale_ratio = max(img_h / height, img_w / width)
scale_h = int(round(img_h / scale_ratio, 5))
scale_w = int(round(img_w / scale_ratio, 5))
assert scale_h <= height and scale_w <= width, f'scaled shape {scale_w}x{scale_h} is not smaller than or equal to the required shape: {width}x{height} this is a bug:'
# scale image
img = scale(scale_w, scale_h, img)
# pad the image if needed
pad_h, pad_w = height - scale_h, width - scale_w
if pad_h != 0 or pad_w != 0:
pad_dims = [[np.floor(pad_h/2), np.ceil(pad_h/2)], [np.floor(pad_w/2), np.ceil(pad_w/2)], [0, 0]] # (H,W,C)
img = np.pad(img, np.array(pad_dims).astype('int'))
# check the shape
assert img.shape[0] == height and img.shape[1] == width, f'output shape {img.shape[1]}x{img.shape[0]} does not match required shape: {width}x{height} this is a bug!'
return img

if transform is None:
return functools.partial(scale, output_width, output_height)
if transform == 'center-crop':
if (output_width is None) or (output_height is None):
error ('must specify --resolution=WxH when using ' + transform + 'transform')
error ('must specify --resolution=WxH when using ' + transform + ' transform')
return functools.partial(center_crop, output_width, output_height)
if transform == 'center-crop-wide':
if (output_width is None) or (output_height is None):
error ('must specify --resolution=WxH when using ' + transform + ' transform')
return functools.partial(center_crop_wide, output_width, output_height)
if transform == 'resize-pad':
if (output_width is None) or (output_height is None):
error ('must specify --resolution=WxH when using ' + transform + ' transform')
return functools.partial(resize_pad, output_width, output_height)
assert False, 'unknown transform'

#----------------------------------------------------------------------------
Expand Down Expand Up @@ -321,7 +360,7 @@ def folder_write_bytes(fname: str, data: Union[bytes, str]):
@click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
@click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide', 'resize-pad']))
@click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
def convert_dataset(
ctx: click.Context,
Expand Down

0 comments on commit e53d47c

Please sign in to comment.