Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7016,6 +7016,29 @@ def test_parallelogram_to_bounding_boxes(input_size, device):
actual = _parallelogram_to_bounding_boxes(parallelogram)
torch.testing.assert_close(actual, expected)

# Test the transformation of a simple parallelogram.
# 1
# 1-2 / 2
# / / -> / /
# 4-3 4 /
# 3
#
# 1
# 1-2 \ 2
# \ \ -> \ \
# 4-3 4 \
# 3
parallelogram = torch.tensor(
[[0, 4, 3, 1, 5, 1, 2, 4], [0, 1, 2, 1, 5, 4, 3, 4]],
dtype=torch.float32,
)
expected = torch.tensor(
[[0, 4, 4, 0, 5, 1, 1, 5], [0, 1, 1, 0, 5, 4, 4, 5]],
dtype=torch.float32,
)
actual = _parallelogram_to_bounding_boxes(parallelogram)
torch.testing.assert_close(actual, expected)


@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
Expand Down
67 changes: 29 additions & 38 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,54 +451,45 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
torch.Tensor: Tensor of same shape as input containing the rectangle coordinates.
The output maintains the same dtype as the input.
"""
original_shape = parallelogram.shape
dtype = parallelogram.dtype
acceptable_dtypes = [torch.float32, torch.float64]
need_cast = dtype not in acceptable_dtypes
if need_cast:
# Up-cast to avoid overflow for square operations
parallelogram = parallelogram.to(torch.float32)
out_boxes = parallelogram.clone()

# Calculate parallelogram diagonal vectors
dx13 = parallelogram[..., 4] - parallelogram[..., 0]
dy13 = parallelogram[..., 5] - parallelogram[..., 1]
dx42 = parallelogram[..., 2] - parallelogram[..., 6]
dy42 = parallelogram[..., 3] - parallelogram[..., 7]
dx12 = parallelogram[..., 2] - parallelogram[..., 0]
dy12 = parallelogram[..., 1] - parallelogram[..., 3]
diag13 = torch.sqrt(dx13**2 + dy13**2)
diag24 = torch.sqrt(dx42**2 + dy42**2)
mask = diag13 > diag24

# Calculate rotation angle in radians
r_rad = torch.atan2(dy12, dx12)
cos, sin = torch.cos(r_rad), torch.sin(r_rad)

# Calculate width using the angle between diagonal and rotation
w = torch.where(
mask,
diag13 * torch.abs(torch.sin(torch.atan2(dx13, dy13) - r_rad)),
diag24 * torch.abs(torch.sin(torch.atan2(dx42, dy42) - r_rad)),
)

delta_x = w * cos
delta_y = w * sin
# Update coordinates to form a rectangle
# Keeping the points (x1, y1) and (x3, y3) unchanged.
out_boxes[..., 2] = torch.where(mask, parallelogram[..., 0] + delta_x, parallelogram[..., 2])
out_boxes[..., 3] = torch.where(mask, parallelogram[..., 1] - delta_y, parallelogram[..., 3])
out_boxes[..., 6] = torch.where(mask, parallelogram[..., 4] - delta_x, parallelogram[..., 6])
out_boxes[..., 7] = torch.where(mask, parallelogram[..., 5] + delta_y, parallelogram[..., 7])

# Keeping the points (x2, y2) and (x4, y4) unchanged.
out_boxes[..., 0] = torch.where(~mask, parallelogram[..., 2] - delta_x, parallelogram[..., 0])
out_boxes[..., 1] = torch.where(~mask, parallelogram[..., 3] + delta_y, parallelogram[..., 1])
out_boxes[..., 4] = torch.where(~mask, parallelogram[..., 6] + delta_x, parallelogram[..., 4])
out_boxes[..., 5] = torch.where(~mask, parallelogram[..., 7] - delta_y, parallelogram[..., 5])
x1, y1, x2, y2, x3, y3, x4, y4 = parallelogram.unbind(-1)
cx = (x1 + x3) / 2
cy = (y1 + y3) / 2

# Calculate width, height, and rotation angle of the parallelogram
wp = torch.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
hp = torch.sqrt((x4 - x1) ** 2 + (y4 - y1) ** 2)
r12 = torch.atan2(y1 - y2, x2 - x1)
r14 = torch.atan2(y1 - y4, x4 - x1)
r_rad = r12 - r14
sign = torch.where(r_rad > torch.pi / 2, -1, 1)
cos, sin = r_rad.cos(), r_rad.sin()

# Calculate width, height, and rotation angle of the rectangle
w = torch.where(wp < hp, wp * sin, wp + hp * cos * sign)
h = torch.where(wp > hp, hp * sin, hp + wp * cos * sign)
r_rad = torch.where(hp > wp, r14 + torch.pi / 2, r12)
cos, sin = r_rad.cos(), r_rad.sin()

x1 = cx - w / 2 * cos - h / 2 * sin
y1 = cy - h / 2 * cos + w / 2 * sin
x2 = cx + w / 2 * cos - h / 2 * sin
y2 = cy - h / 2 * cos - w / 2 * sin
x3 = cx + w / 2 * cos + h / 2 * sin
y3 = cy + h / 2 * cos - w / 2 * sin
x4 = cx - w / 2 * cos + h / 2 * sin
y4 = cy + h / 2 * cos + w / 2 * sin
out_boxes = torch.stack((x1, y1, x2, y2, x3, y3, x4, y4), dim=-1).reshape(original_shape)

if need_cast:
out_boxes = out_boxes.to(dtype)

return out_boxes


Expand Down
Loading