@@ -671,32 +671,39 @@ def hflip(img: Tensor) -> Tensor:
671
671
return F_t .hflip (img )
672
672
673
673
674
- def _get_perspective_coeffs (startpoints : List [List [int ]], endpoints : List [List [int ]]) -> List [float ]:
674
+ def _get_perspective_coeffs (startpoints : List [List [int ]] | Tensor , endpoints : List [List [int ]] | Tensor ) -> List [float ]:
675
675
"""Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
676
676
677
677
In Perspective Transform each pixel (x, y) in the original image gets transformed as,
678
678
(x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
679
679
680
680
Args:
681
- startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
681
+ startpoints (list of list of ints or Tensor ): List or Tensor containing four lists of two integers corresponding to four corners
682
682
``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
683
- endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
683
+ endpoints (list of list of ints or Tensor ): List or Tensor containing four lists of two integers corresponding to four corners
684
684
``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
685
685
686
686
Returns:
687
687
octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
688
688
"""
689
+
690
+ startpoints = startpoints if isinstance (startpoints , Tensor ) else torch .tensor (startpoints , dtype = torch .float64 )
691
+ endpoints = endpoints if isinstance (endpoints , Tensor ) else torch .tensor (endpoints , dtype = torch .float64 )
692
+
689
693
if len (startpoints ) != 4 or len (endpoints ) != 4 :
690
694
raise ValueError (
691
695
f"Please provide exactly four corners, got { len (startpoints )} startpoints and { len (endpoints )} endpoints."
692
696
)
693
- a_matrix = torch .zeros (2 * len (startpoints ), 8 , dtype = torch .float64 )
694
697
695
- for i , (p1 , p2 ) in enumerate (zip (endpoints , startpoints )):
696
- a_matrix [2 * i , :] = torch .tensor ([p1 [0 ], p1 [1 ], 1 , 0 , 0 , 0 , - p2 [0 ] * p1 [0 ], - p2 [0 ] * p1 [1 ]])
697
- a_matrix [2 * i + 1 , :] = torch .tensor ([0 , 0 , 0 , p1 [0 ], p1 [1 ], 1 , - p2 [1 ] * p1 [0 ], - p2 [1 ] * p1 [1 ]])
698
+ a_matrix = torch .zeros (2 * len (startpoints ), 8 , dtype = torch .float64 , device = startpoints .device )
699
+ a_matrix [::2 , :2 ] = endpoints
700
+ a_matrix [1 ::2 , 3 :5 ] = endpoints
701
+ a_matrix [::2 , 2 ] = 1
702
+ a_matrix [1 ::2 , 5 ] = 1
703
+ a_matrix [::2 , 6 :] = - startpoints [:, 0 :1 ] * endpoints
704
+ a_matrix [1 ::2 , 6 :] = - startpoints [:, 1 :2 ] * endpoints
698
705
699
- b_matrix = torch . tensor ( startpoints , dtype = torch .float64 ).view (8 )
706
+ b_matrix = startpoints . to ( dtype = torch .float64 ).view (8 )
700
707
# do least squares in double precision to prevent numerical issues
701
708
res = torch .linalg .lstsq (a_matrix , b_matrix , driver = "gels" ).solution .to (torch .float32 )
702
709
@@ -706,8 +713,8 @@ def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[i
706
713
707
714
def perspective (
708
715
img : Tensor ,
709
- startpoints : List [List [int ]],
710
- endpoints : List [List [int ]],
716
+ startpoints : List [List [int ]] | Tensor ,
717
+ endpoints : List [List [int ]] | Tensor ,
711
718
interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
712
719
fill : Optional [List [float ]] = None ,
713
720
) -> Tensor :
@@ -717,9 +724,9 @@ def perspective(
717
724
718
725
Args:
719
726
img (PIL Image or Tensor): Image to be transformed.
720
- startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
727
+ startpoints (list of list of ints or Tensor ): List or Tensor containing four lists of two integers corresponding to four corners
721
728
``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
722
- endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
729
+ endpoints (list of list of ints or Tensor ): List or Tensor containing four lists of two integers corresponding to four corners
723
730
``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
724
731
interpolation (InterpolationMode): Desired interpolation enum defined by
725
732
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
0 commit comments