Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add visibility parameter to draw_keypoints() #8225

Merged
merged 17 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion gallery/others/plot_visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def show(imgs):
show(res)

# %%
# As we see the keypoints appear as colored circles over the image.
# As we see, the keypoints appear as colored circles over the image.
# The coco keypoints for a person are ordered and represent the following list.\

coco_keypoints = [
Expand Down Expand Up @@ -460,3 +460,63 @@ def show(imgs):

res = draw_keypoints(person_int, keypoints, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
show(res)

# %%
# That looks pretty good.
#
# .. _draw_keypoints_with_visibility:
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
#
# Keypoint Visibility
bmmtstb marked this conversation as resolved.
Show resolved Hide resolved
# ^^^^^^^^^^^^^^^^^^^
# Let's have a look at the results, another keypoint prediction module produced, and show the connectivity:

prediction = torch.tensor(
[[[208.0176, 214.2409, 1.0000],
[000.0000, 000.0000, 0.0000],
[197.8246, 210.6392, 1.0000],
[000.0000, 000.0000, 0.0000],
[178.6378, 217.8425, 1.0000],
[221.2086, 253.8591, 1.0000],
[160.6502, 269.4662, 1.0000],
[243.9929, 304.2822, 1.0000],
[138.4654, 328.8935, 1.0000],
[277.5698, 340.8990, 1.0000],
[153.4551, 374.5145, 1.0000],
[000.0000, 000.0000, 0.0000],
[226.0053, 370.3125, 1.0000],
[221.8081, 455.5516, 1.0000],
[273.9723, 448.9486, 1.0000],
[193.6275, 546.1933, 1.0000],
[273.3727, 545.5930, 1.0000]]]
)

res = draw_keypoints(person_int, prediction, connectivity=connect_skeleton, colors="blue", radius=4, width=3)
show(res)

# %%
# What happened there?
# The model, which predicted the new keypoints,
# can't detect the three points that are hidden on the upper left body of the skateboarder.
# More precisely, the model predicted that `(x, y, vis) = (0, 0, 0)` for the left_eye, left_ear, and left_hip.
# So we definitely don't want to display those keypoints and connections, and you don't have to.
# Looking at the parameters of :func:`~torchvision.utils.draw_keypoints`,
# we can see that we can pass a visibility tensor as an additional argument.
# Given the models' prediction, we have the visibility as the third keypoint dimension, we just need to extract it.
# Let's split the ``prediction`` into the keypoint coordinates and their respective visibility,
# and pass both of them as arguments to :func:`~torchvision.utils.draw_keypoints`.

coordinates, visibility = prediction.split([2, 1], dim=-1)
visibility = visibility.bool()

res = draw_keypoints(
person_int, coordinates, visibility=visibility, connectivity=connect_skeleton, colors="blue", radius=4, width=3
)
show(res)

# %%
# We can see that the undetected keypoints are not draw and the invisible keypoint connections were skipped.
# This can reduce the noise on images with multiple detections, or in cases like ours,
# when the keypoint-prediction model missed some detections.
# Most torch keypoint-prediction models return the visibility for every prediction, ready for you to use it.
# The :func:`~torchvision.models.detection.keypointrcnn_resnet50_fpn` model,
# which we used in the first case, does so too.
48 changes: 48 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,42 @@ def test_draw_keypoints_colored(colors):
assert_equal(img, img_cp)


@pytest.mark.parametrize("connectivity", [None, [(0, 1)], [(0, 1), (1, 2)]])
@pytest.mark.parametrize(
"vis",
[
None,
torch.ones((2, 3), dtype=torch.bool),
torch.ones((2, 3), dtype=torch.int),
torch.ones((2, 3, 1), dtype=torch.float),
],
)
def test_draw_keypoints_visibility(connectivity, vis):
bmmtstb marked this conversation as resolved.
Show resolved Hide resolved
# Keypoints is declared on top as global variable
keypoints_cp = keypoints.clone()
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
img_cp = img.clone()
if vis is None:
vis_cp = vis
else:
vis_cp = vis.clone()

result = utils.draw_keypoints(
image=img,
keypoints=keypoints,
visibility=vis,
connectivity=connectivity,
)
assert result.size(0) == 3
assert_equal(keypoints, keypoints_cp)
assert_equal(img, img_cp)
if vis_cp is None:
assert vis is None
else:
assert_equal(vis, vis_cp.squeeze_(-1))
assert vis.dtype == vis_cp.dtype


def test_draw_keypoints_errors():
h, w = 10, 10
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
Expand All @@ -379,6 +415,18 @@ def test_draw_keypoints_errors():
with pytest.raises(ValueError, match="keypoints must be of shape"):
invalid_keypoints = torch.tensor([[10, 10, 10, 10], [5, 6, 7, 8]], dtype=torch.float)
utils.draw_keypoints(image=img, keypoints=invalid_keypoints)
with pytest.raises(ValueError, match=re.escape("visibility must be of shape (num_instances, K)")):
one_dim_visibility = torch.tensor([True, True, True], dtype=torch.bool)
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=one_dim_visibility)
with pytest.raises(ValueError, match=re.escape("visibility must be of shape (num_instances, K)")):
three_dim_visibility = torch.ones((2, 3, 4), dtype=torch.bool)
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=three_dim_visibility)
with pytest.raises(ValueError, match="keypoints and visibility must have the same dimensionality"):
vis_wrong_n = torch.ones((3, 3), dtype=torch.bool)
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=vis_wrong_n)
with pytest.raises(ValueError, match="keypoints and visibility must have the same dimensionality"):
vis_wrong_k = torch.ones((2, 4), dtype=torch.bool)
utils.draw_keypoints(image=img, keypoints=keypoints, visibility=vis_wrong_k)


@pytest.mark.parametrize("batch", (True, False))
Expand Down
53 changes: 43 additions & 10 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def draw_segmentation_masks(
def draw_keypoints(
image: torch.Tensor,
keypoints: torch.Tensor,
visibility: Optional[torch.Tensor] = None,
bmmtstb marked this conversation as resolved.
Show resolved Hide resolved
connectivity: Optional[List[Tuple[int, int]]] = None,
colors: Optional[Union[str, Tuple[int, int, int]]] = None,
radius: int = 2,
Expand All @@ -336,13 +337,25 @@ def draw_keypoints(
"""
Draws Keypoints on given RGB image.
The values of the input image should be uint8 between 0 and 255.
Keypoints can be drawn for multiple instances at a time.

This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint.

Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances,
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances,
in the format [x, y].
connectivity (List[Tuple[int, int]]]): A List of tuple where,
each tuple contains pair of keypoints to be connected.
visibility (Tensor): Tensor of shape (num_instances, K) specifying the visibility of the K
bmmtstb marked this conversation as resolved.
Show resolved Hide resolved
keypoints for each of the N instances.
True means that the respective keypoint is visible and should be drawn.
False means invisible, so neither the point nor possible connections containing it are drawn.
The input tensor will be cast to bool.
Default ``None`` means that all the keypoints are visible.
connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints
to be connected.
If at least one of the two connected keypoints has a ``visibility`` of False,
this specific connection is not drawn.
Exclusions due to invisibility are computed per-instance.
colors (str, Tuple): The color can be represented as
PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
radius (int): Integer denoting radius of keypoint.
Expand All @@ -354,6 +367,7 @@ def draw_keypoints(

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(draw_keypoints)
# validate image
if not isinstance(image, torch.Tensor):
raise TypeError(f"The image must be a tensor, got {type(image)}")
elif image.dtype != torch.uint8:
Expand All @@ -363,24 +377,43 @@ def draw_keypoints(
elif image.size()[0] != 3:
raise ValueError("Pass an RGB image. Other Image formats are not supported")

# validate keypoints
if keypoints.ndim != 3:
raise ValueError("keypoints must be of shape (num_instances, K, 2)")

# validate visibility
if visibility is None: # set default
visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool)
# If the last dimension is 1, e.g., after calling split([2, 1], dim=-1) on the output of a keypoint-prediction
# model, make sure visibility has shape (num_instances, K).
# Iff K = 1, this has unwanted behavior, but K=1 does not really make sense in the first place.
visibility.squeeze_(-1)
bmmtstb marked this conversation as resolved.
Show resolved Hide resolved
if visibility.ndim != 2:
raise ValueError("visibility must be of shape (num_instances, K)")
bmmtstb marked this conversation as resolved.
Show resolved Hide resolved
if visibility.shape != keypoints.shape[:-1]:
raise ValueError("keypoints and visibility must have the same dimensionality for num_instances and K.")
bmmtstb marked this conversation as resolved.
Show resolved Hide resolved

ndarr = image.permute(1, 2, 0).cpu().numpy()
img_to_draw = Image.fromarray(ndarr)
draw = ImageDraw.Draw(img_to_draw)
img_kpts = keypoints.to(torch.int64).tolist()

for kpt_id, kpt_inst in enumerate(img_kpts):
for inst_id, kpt in enumerate(kpt_inst):
x1 = kpt[0] - radius
x2 = kpt[0] + radius
y1 = kpt[1] - radius
y2 = kpt[1] + radius
img_vis = visibility.cpu().bool().tolist()

for kpt_id, (kpt_inst, vis_inst) in enumerate(zip(img_kpts, img_vis)):
for inst_id, (kpt_coord, kp_vis) in enumerate(zip(kpt_inst, vis_inst)):
bmmtstb marked this conversation as resolved.
Show resolved Hide resolved
if not kp_vis: # skip drawing ellipse if the current keypoint is invisible
bmmtstb marked this conversation as resolved.
Show resolved Hide resolved
continue
x1 = kpt_coord[0] - radius
x2 = kpt_coord[0] + radius
y1 = kpt_coord[1] - radius
y2 = kpt_coord[1] + radius
draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0)

if connectivity:
for connection in connectivity:
# connection is skipped if one of the keypoints is not visible
if not vis_inst[connection[0]] or not vis_inst[connection[1]]:
continue
start_pt_x = kpt_inst[connection[0]][0]
start_pt_y = kpt_inst[connection[0]][1]

Expand Down
Loading