Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

draw_keypoints() float support #8276

Merged
merged 6 commits into from
Mar 4, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
method description update
GsnMithra committed Feb 17, 2024
commit 1f2b77ca13a205a166ce27fc111b699eb7c5d5b2
6 changes: 3 additions & 3 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
@@ -336,13 +336,13 @@ def draw_keypoints(

"""
Draws Keypoints on given RGB image.
The values of the input image should be uint8 between 0 and 255.
The image values should be uint8 in [0, 255] or float in [0, 1].
Keypoints can be drawn for multiple instances at a time.
This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint.
Args:
image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances,
in the format [x, y].
connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints
@@ -363,7 +363,7 @@ def draw_keypoints(
For more details, see :ref:`draw_keypoints_with_visibility`.
Returns:
img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
img (Tensor[C, H, W]): Image Tensor with keypoints drawn.
"""

if not torch.jit.is_scripting() and not torch.jit.is_tracing():