Skip to content

Commit

Permalink
Merge pull request #769 from roboflow/bugfix-keypoint-visualization
Browse files Browse the repository at this point in the history
Bugfix: keypoint visualization block
  • Loading branch information
PawelPeczek-Roboflow authored Nov 4, 2024
2 parents f95802b + aedfa78 commit 7247ce1
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ def getAnnotator(

# Function to convert detections to keypoints
def convert_detections_to_keypoints(self, detections):
if len(detections) == 0:
return sv.KeyPoints.empty()
keypoints_xy = detections.data["keypoints_xy"]
keypoints_confidence = detections.data["keypoints_confidence"]
keypoints_class_name = detections.data["keypoints_class_name"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,45 @@ def test_keypoint_visualization_block_nocopy() -> None:
output.get("image").numpy_image.__array_interface__["data"][0]
== start_image.__array_interface__["data"][0]
)

def test_keypoint_visualization_block_no_predictions() -> None:
# given
block = KeypointVisualizationBlockV1()
start_image = np.zeros((1000, 1000, 3), dtype=np.uint8)

empty_predictions = sv.Detections.empty()

output = block.run(
image=WorkflowImageData(
parent_metadata=ImageParentMetadata(parent_id="some"),
numpy_image=start_image,
),
predictions=empty_predictions,
copy_image=True,
annotator_type="edge",
color="#A351FB",
text_color="black",
text_scale=0.5,
text_thickness=1,
text_padding=10,
thickness=2,
radius=10,
)

assert output is not None
assert "image" in output
assert hasattr(output.get("image"), "numpy_image")

# dimensions of output match input
assert output.get("image").numpy_image.shape == (1000, 1000, 3)

# check if the image is unchanged since there were no predictions
assert np.array_equal(
output.get("image").numpy_image, np.zeros((1000, 1000, 3), dtype=np.uint8)
)

# check that the image is copied
assert (
output.get("image").numpy_image.__array_interface__["data"][0]
!= start_image.__array_interface__["data"][0]
)

0 comments on commit 7247ce1

Please sign in to comment.