Skip to content

Commit 1622eab

Browse files
Make _get_perspective_coeffs device agnostic
1 parent f799a53 commit 1622eab

File tree

2 files changed

+48
-12
lines changed

2 files changed

+48
-12
lines changed

test/test_functional_tensor.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,35 @@ def test_perspective_batch(device, dims_and_points, dt):
434434
)
435435

436436

437+
@pytest.mark.parametrize("device", cpu_and_cuda())
438+
@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective())
439+
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
440+
def test_perspective_tensor_input(device, dims_and_points, dt):
441+
442+
if dt == torch.float16 and device == "cpu":
443+
# skip float16 on CPU case
444+
return
445+
446+
data_dims, (spoints, epoints) = dims_and_points
447+
print(spoints, epoints)
448+
449+
batch_tensors = _create_data_batch(*data_dims, num_samples=4, device=device)
450+
if dt is not None:
451+
batch_tensors = batch_tensors.to(dtype=dt)
452+
453+
# Ignore the equivalence between scripted and regular function on float16 cuda. The pixels at
454+
# the border may be entirely different due to small rounding errors.
455+
scripted_fn_atol = -1 if (dt == torch.float16 and device == "cuda") else 1e-8
456+
_test_fn_on_batch(
457+
batch_tensors,
458+
F.perspective,
459+
scripted_fn_atol=scripted_fn_atol,
460+
startpoints=torch.tensor(spoints, device=device, dtype=dt),
461+
endpoints=torch.tensor(epoints, device=device, dtype=dt),
462+
interpolation=NEAREST,
463+
)
464+
465+
437466
def test_perspective_interpolation_type():
438467
spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
439468
epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]

torchvision/transforms/functional.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -671,32 +671,39 @@ def hflip(img: Tensor) -> Tensor:
671671
return F_t.hflip(img)
672672

673673

674-
def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:
674+
def _get_perspective_coeffs(startpoints: List[List[int]] | Tensor, endpoints: List[List[int]] | Tensor) -> List[float]:
675675
"""Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
676676
677677
In Perspective Transform each pixel (x, y) in the original image gets transformed as,
678678
(x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
679679
680680
Args:
681-
startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
681+
startpoints (list of list of ints or Tensor): List or Tensor containing four lists of two integers corresponding to four corners
682682
``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
683-
endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
683+
endpoints (list of list of ints or Tensor): List or Tensor containing four lists of two integers corresponding to four corners
684684
``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
685685
686686
Returns:
687687
octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
688688
"""
689+
690+
startpoints = startpoints if isinstance(startpoints, Tensor) else torch.tensor(startpoints, dtype=torch.float64)
691+
endpoints = endpoints if isinstance(endpoints, Tensor) else torch.tensor(endpoints, dtype=torch.float64)
692+
689693
if len(startpoints) != 4 or len(endpoints) != 4:
690694
raise ValueError(
691695
f"Please provide exactly four corners, got {len(startpoints)} startpoints and {len(endpoints)} endpoints."
692696
)
693-
a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float64)
694697

695-
for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
696-
a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
697-
a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
698+
a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float64, device=startpoints.device)
699+
a_matrix[::2, :2] = endpoints
700+
a_matrix[1::2, 3:5] = endpoints
701+
a_matrix[::2, 2] = 1
702+
a_matrix[1::2, 5] = 1
703+
a_matrix[::2, 6:] = -startpoints[:, 0:1] * endpoints
704+
a_matrix[1::2, 6:] = -startpoints[:, 1:2] * endpoints
698705

699-
b_matrix = torch.tensor(startpoints, dtype=torch.float64).view(8)
706+
b_matrix = startpoints.to(dtype=torch.float64).view(8)
700707
# do least squares in double precision to prevent numerical issues
701708
res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution.to(torch.float32)
702709

@@ -706,8 +713,8 @@ def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[i
706713

707714
def perspective(
708715
img: Tensor,
709-
startpoints: List[List[int]],
710-
endpoints: List[List[int]],
716+
startpoints: List[List[int]] | Tensor,
717+
endpoints: List[List[int]] | Tensor,
711718
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
712719
fill: Optional[List[float]] = None,
713720
) -> Tensor:
@@ -717,9 +724,9 @@ def perspective(
717724
718725
Args:
719726
img (PIL Image or Tensor): Image to be transformed.
720-
startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
727+
startpoints (list of list of ints or Tensor): List or Tensor containing four lists of two integers corresponding to four corners
721728
``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
722-
endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
729+
endpoints (list of list of ints or Tensor): List or Tensor containing four lists of two integers corresponding to four corners
723730
``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
724731
interpolation (InterpolationMode): Desired interpolation enum defined by
725732
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.

0 commit comments

Comments
 (0)