@@ -491,6 +491,151 @@ def affine_transform(
491
491
}
492
492
493
493
494
+ def perspective_transform (
495
+ images ,
496
+ start_points ,
497
+ end_points ,
498
+ interpolation = "bilinear" ,
499
+ fill_value = 0 ,
500
+ data_format = None ,
501
+ ):
502
+ data_format = backend .standardize_data_format (data_format )
503
+ if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS .keys ():
504
+ raise ValueError (
505
+ "Invalid value for argument `interpolation`. Expected of one "
506
+ f"{ set (AFFINE_TRANSFORM_INTERPOLATIONS .keys ())} . Received: "
507
+ f"interpolation={ interpolation } "
508
+ )
509
+
510
+ if len (images .shape ) not in (3 , 4 ):
511
+ raise ValueError (
512
+ "Invalid images rank: expected rank 3 (single image) "
513
+ "or rank 4 (batch of images). Received input with shape: "
514
+ f"images.shape={ images .shape } "
515
+ )
516
+
517
+ if start_points .shape [- 2 :] != (4 , 2 ) or start_points .ndim not in (2 , 3 ):
518
+ raise ValueError (
519
+ "Invalid start_points shape: expected (4,2) for a single image"
520
+ f" or (N,4,2) for a batch. Received shape: { start_points .shape } "
521
+ )
522
+ if end_points .shape [- 2 :] != (4 , 2 ) or end_points .ndim not in (2 , 3 ):
523
+ raise ValueError (
524
+ "Invalid end_points shape: expected (4,2) for a single image"
525
+ f" or (N,4,2) for a batch. Received shape: { end_points .shape } "
526
+ )
527
+ if start_points .shape != end_points .shape :
528
+ raise ValueError (
529
+ "start_points and end_points must have the same shape."
530
+ f" Received start_points.shape={ start_points .shape } , "
531
+ f"end_points.shape={ end_points .shape } "
532
+ )
533
+
534
+ need_squeeze = False
535
+ if len (images .shape ) == 3 :
536
+ images = jnp .expand_dims (images , axis = 0 )
537
+ need_squeeze = True
538
+
539
+ if len (start_points .shape ) == 2 :
540
+ start_points = jnp .expand_dims (start_points , axis = 0 )
541
+ if len (end_points .shape ) == 2 :
542
+ end_points = jnp .expand_dims (end_points , axis = 0 )
543
+
544
+ if data_format == "channels_first" :
545
+ images = jnp .transpose (images , (0 , 2 , 3 , 1 ))
546
+
547
+ batch_size , height , width , channels = images .shape
548
+ transforms = compute_homography_matrix (
549
+ jnp .asarray (start_points ), jnp .asarray (end_points )
550
+ )
551
+
552
+ x , y = jnp .meshgrid (jnp .arange (width ), jnp .arange (height ), indexing = "xy" )
553
+ grid = jnp .stack ([x .ravel (), y .ravel (), jnp .ones_like (x ).ravel ()], axis = 0 )
554
+
555
+ def transform_coordinates (transform ):
556
+ denom = transform [6 ] * grid [0 ] + transform [7 ] * grid [1 ] + 1.0
557
+ x_in = (
558
+ transform [0 ] * grid [0 ] + transform [1 ] * grid [1 ] + transform [2 ]
559
+ ) / denom
560
+ y_in = (
561
+ transform [3 ] * grid [0 ] + transform [4 ] * grid [1 ] + transform [5 ]
562
+ ) / denom
563
+ return jnp .stack ([y_in , x_in ], axis = 0 )
564
+
565
+ transformed_coords = jax .vmap (transform_coordinates )(transforms )
566
+
567
+ def interpolate_image (image , coords ):
568
+ def interpolate_channel (channel_img ):
569
+ return jax .scipy .ndimage .map_coordinates (
570
+ channel_img ,
571
+ coords ,
572
+ order = AFFINE_TRANSFORM_INTERPOLATIONS [interpolation ],
573
+ mode = "constant" ,
574
+ cval = fill_value ,
575
+ ).reshape (height , width )
576
+
577
+ return jax .vmap (interpolate_channel , in_axes = 0 )(
578
+ jnp .moveaxis (image , - 1 , 0 )
579
+ )
580
+
581
+ output = jax .vmap (interpolate_image , in_axes = (0 , 0 ))(
582
+ images , transformed_coords
583
+ )
584
+ output = jnp .moveaxis (output , 1 , - 1 )
585
+
586
+ if data_format == "channels_first" :
587
+ output = jnp .transpose (output , (0 , 3 , 1 , 2 ))
588
+ if need_squeeze :
589
+ output = jnp .squeeze (output , axis = 0 )
590
+
591
+ return output
592
+
593
+
594
+ def compute_homography_matrix (start_points , end_points ):
595
+ start_x , start_y = start_points [..., 0 ], start_points [..., 1 ]
596
+ end_x , end_y = end_points [..., 0 ], end_points [..., 1 ]
597
+
598
+ zeros = jnp .zeros_like (end_x )
599
+ ones = jnp .ones_like (end_x )
600
+
601
+ x_rows = jnp .stack (
602
+ [
603
+ end_x ,
604
+ end_y ,
605
+ ones ,
606
+ zeros ,
607
+ zeros ,
608
+ zeros ,
609
+ - start_x * end_x ,
610
+ - start_x * end_y ,
611
+ ],
612
+ axis = - 1 ,
613
+ )
614
+ y_rows = jnp .stack (
615
+ [
616
+ zeros ,
617
+ zeros ,
618
+ zeros ,
619
+ end_x ,
620
+ end_y ,
621
+ ones ,
622
+ - start_y * end_x ,
623
+ - start_y * end_y ,
624
+ ],
625
+ axis = - 1 ,
626
+ )
627
+
628
+ coefficient_matrix = jnp .concatenate ([x_rows , y_rows ], axis = 1 )
629
+
630
+ target_vector = jnp .expand_dims (
631
+ jnp .concatenate ([start_x , start_y ], axis = - 1 ), axis = - 1
632
+ )
633
+
634
+ homography_matrix = jnp .linalg .solve (coefficient_matrix , target_vector )
635
+
636
+ return homography_matrix .squeeze (- 1 )
637
+
638
+
494
639
def map_coordinates (
495
640
inputs , coordinates , order , fill_mode = "constant" , fill_value = 0.0
496
641
):
0 commit comments