Skip to content

Commit 15f9128

Browse files
Skeleton for keypoints tutorial (#9209)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent a8dc530 commit 15f9128

File tree

4 files changed

+123
-1
lines changed

4 files changed

+123
-1
lines changed

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(self, src_dir):
8888
"plot_transforms_e2e.py",
8989
"plot_cutmix_mixup.py",
9090
"plot_rotated_box_transforms.py",
91+
"plot_keypoints_transforms.py",
9192
"plot_custom_transforms.py",
9293
"plot_tv_tensors.py",
9394
"plot_custom_tv_tensors.py",

gallery/assets/pottery.jpg

89.8 KB
Loading

gallery/transforms/helpers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import matplotlib.pyplot as plt
22
import torch
3-
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
3+
from torchvision.utils import draw_bounding_boxes, draw_keypoints, draw_segmentation_masks
44
from torchvision import tv_tensors
55
from torchvision.transforms import v2
66
from torchvision.transforms.v2 import functional as F
@@ -18,6 +18,7 @@ def plot(imgs, row_title=None, bbox_width=3, **imshow_kwargs):
1818
for col_idx, img in enumerate(row):
1919
boxes = None
2020
masks = None
21+
points = None
2122
if isinstance(img, tuple):
2223
img, target = img
2324
if isinstance(target, dict):
@@ -30,6 +31,8 @@ def plot(imgs, row_title=None, bbox_width=3, **imshow_kwargs):
3031
# work with this specific format.
3132
if tv_tensors.is_rotated_bounding_format(boxes.format):
3233
boxes = v2.ConvertBoundingBoxFormat("xyxyxyxy")(boxes)
34+
elif isinstance(target, tv_tensors.KeyPoints):
35+
points = target
3336
else:
3437
raise ValueError(f"Unexpected target type: {type(target)}")
3538
img = F.to_image(img)
@@ -44,6 +47,8 @@ def plot(imgs, row_title=None, bbox_width=3, **imshow_kwargs):
4447
img = draw_bounding_boxes(img, boxes, colors="yellow", width=bbox_width)
4548
if masks is not None:
4649
img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)
50+
if points is not None:
51+
img = draw_keypoints(img, points, colors="red", radius=10)
4752

4853
ax = axs[row_idx, col_idx]
4954
ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""
2+
===============================================================
3+
Transforms on KeyPoints
4+
===============================================================
5+
6+
This example illustrates how to define and use keypoints.
7+
For this tutorial, we use this picture of a ceramic figure from the pre-columbian period.
8+
The image is specified "public domain" (https://www.metmuseum.org/art/collection/search/502727).
9+
10+
.. note::
11+
Support for keypoints was released in TorchVision 0.23 and is
12+
currently a BETA feature. We don't expect the API to change, but there may
13+
be some rare edge-cases. If you find any issues, please report them on
14+
our bug tracker: https://github.com/pytorch/vision/issues?q=is:open+is:issue
15+
16+
First, a bit of setup code:
17+
"""
18+
19+
# %%
20+
from PIL import Image
21+
from pathlib import Path
22+
import matplotlib.pyplot as plt
23+
24+
25+
import torch
26+
from torchvision.tv_tensors import KeyPoints
27+
from torchvision.transforms import v2
28+
from helpers import plot
29+
30+
plt.rcParams["figure.figsize"] = [10, 5]
31+
plt.rcParams["savefig.bbox"] = "tight"
32+
33+
# if you change the seed, make sure that the transformed output
34+
# still make sense
35+
torch.manual_seed(0)
36+
37+
# If you're trying to run that on Colab, you can download the assets and the
38+
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
39+
orig_img = Image.open(Path('../assets') / 'pottery.jpg')
40+
41+
# %%
42+
# Creating KeyPoints
43+
# -------------------------------
44+
# Key points are created by instantiating the
45+
# :class:`~torchvision.tv_tensors.KeyPoints` class.
46+
47+
48+
orig_pts = KeyPoints(
49+
[
50+
[
51+
[445, 700], # nose
52+
[320, 660],
53+
[370, 660],
54+
[420, 660], # left eye
55+
[300, 620],
56+
[420, 620], # left eyebrow
57+
[475, 665],
58+
[515, 665],
59+
[555, 655], # right eye
60+
[460, 625],
61+
[560, 600], # right eyebrow
62+
[370, 780],
63+
[450, 760],
64+
[540, 780],
65+
[450, 820], # mouth
66+
],
67+
],
68+
canvas_size=(orig_img.size[1], orig_img.size[0]),
69+
)
70+
71+
plot([(orig_img, orig_pts)])
72+
73+
# %%
74+
# Transforms illustrations
75+
# ------------------------
76+
#
77+
# Using :class:`~torchvision.transforms.RandomRotation`:
78+
rotater = v2.RandomRotation(degrees=(0, 180), expand=True)
79+
rotated_imgs = [rotater((orig_img, orig_pts)) for _ in range(4)]
80+
plot([(orig_img, orig_pts)] + rotated_imgs)
81+
82+
# %%
83+
# Using :class:`~torchvision.transforms.Pad`:
84+
padded_imgs_and_points = [
85+
v2.Pad(padding=padding)(orig_img, orig_pts)
86+
for padding in (30, 50, 100, 200)
87+
]
88+
plot([(orig_img, orig_pts)] + padded_imgs_and_points)
89+
90+
# %%
91+
# Using :class:`~torchvision.transforms.Resize`:
92+
resized_imgs = [
93+
v2.Resize(size=size)(orig_img, orig_pts)
94+
for size in (300, 500, 1000, orig_img.size)
95+
]
96+
plot([(orig_img, orig_pts)] + resized_imgs)
97+
98+
# %%
99+
# Using :class:`~torchvision.transforms.RandomPerspective`:
100+
perspective_transformer = v2.RandomPerspective(distortion_scale=0.6, p=1.0)
101+
perspective_imgs = [perspective_transformer(orig_img, orig_pts) for _ in range(4)]
102+
plot([(orig_img, orig_pts)] + perspective_imgs)
103+
104+
# %%
105+
# Using :class:`~torchvision.transforms.CenterCrop`:
106+
center_crops_and_points = [
107+
v2.CenterCrop(size=size)(orig_img, orig_pts)
108+
for size in (300, 500, 1000, orig_img.size)
109+
]
110+
plot([(orig_img, orig_pts)] + center_crops_and_points)
111+
112+
# %%
113+
# Using :class:`~torchvision.transforms.RandomRotation`:
114+
rotater = v2.RandomRotation(degrees=(0, 180))
115+
rotated_imgs = [rotater((orig_img, orig_pts)) for _ in range(4)]
116+
plot([(orig_img, orig_pts)] + rotated_imgs)

0 commit comments

Comments
 (0)