Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added transform option 'resize-pad' to dataset_tool.py #64

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions dataset_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,55 @@ def center_crop_wide(width, height, img):
canvas[(width - height) // 2 : (width + height) // 2, :] = img
return canvas

def _to_rgb(img):
# add channel dim
if img.ndim == 2:
img = img[:, :, None]
assert img.ndim == 3, f'input image has incorrect number of dimensions, required 2 (H, W) or 3 (H, W, C), got: {img.shape}'
# to RGB
if img.shape[-1] == 1:
img = img.repeat(3, axis=-1)
elif img.shape[-1] == 4:
img = img[:, :, :3]
assert img.shape[-1] == 3, f'input image must have 1 or 3 channels, got: {img.shape}'
return img

def resize_pad(width, height, img):
img = _to_rgb(img)
# exit early
img_h, img_w = img.shape[:2]
if width == img_w and height == img_h:
return img
# get scale size, avoiding precision errors
scale_ratio = max(img_h / height, img_w / width)
scale_h = int(round(img_h / scale_ratio, 5))
scale_w = int(round(img_w / scale_ratio, 5))
assert scale_h <= height and scale_w <= width, f'scaled shape {scale_w}x{scale_h} is not smaller than or equal to the required shape: {width}x{height} this is a bug:'
# scale image
img = scale(scale_w, scale_h, img)
# pad the image if needed
pad_h, pad_w = height - scale_h, width - scale_w
if pad_h != 0 or pad_w != 0:
pad_dims = [[np.floor(pad_h/2), np.ceil(pad_h/2)], [np.floor(pad_w/2), np.ceil(pad_w/2)], [0, 0]] # (H,W,C)
img = np.pad(img, np.array(pad_dims).astype('int'))
# check the shape
assert img.shape[0] == height and img.shape[1] == width, f'output shape {img.shape[1]}x{img.shape[0]} does not match required shape: {width}x{height} this is a bug!'
return img

if transform is None:
return functools.partial(scale, output_width, output_height)
if transform == 'center-crop':
if (output_width is None) or (output_height is None):
error ('must specify --resolution=WxH when using ' + transform + 'transform')
error ('must specify --resolution=WxH when using ' + transform + ' transform')
return functools.partial(center_crop, output_width, output_height)
if transform == 'center-crop-wide':
if (output_width is None) or (output_height is None):
error ('must specify --resolution=WxH when using ' + transform + ' transform')
return functools.partial(center_crop_wide, output_width, output_height)
if transform == 'resize-pad':
if (output_width is None) or (output_height is None):
error ('must specify --resolution=WxH when using ' + transform + ' transform')
return functools.partial(resize_pad, output_width, output_height)
assert False, 'unknown transform'

#----------------------------------------------------------------------------
Expand Down Expand Up @@ -321,7 +360,7 @@ def folder_write_bytes(fname: str, data: Union[bytes, str]):
@click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
@click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide', 'resize-pad']))
@click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
def convert_dataset(
ctx: click.Context,
Expand Down