diff --git a/torchvision/utils.py b/torchvision/utils.py index ea1c17230b7..1ad78fe14c3 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -336,13 +336,13 @@ def draw_keypoints( """ Draws Keypoints on given RGB image. - The values of the input image should be uint8 between 0 and 255. + The image values should be uint8 in [0, 255] or float in [0, 1]. Keypoints can be drawn for multiple instances at a time. This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint. Args: - image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float. keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances, in the format [x, y]. connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints @@ -363,7 +363,7 @@ def draw_keypoints( For more details, see :ref:`draw_keypoints_with_visibility`. Returns: - img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. + img (Tensor[C, H, W]): Image Tensor with keypoints drawn. """ if not torch.jit.is_scripting() and not torch.jit.is_tracing():