Skip to content

Commit 86873b5

Browse files
authored
Add perspective_transform for ops (keras-team#20899)
* Add perspective_transform for ops * Add perspective_transform for torch * Add perspective_transform for jax * Add perspective_transform for ops * Add perspective_transform test * Fix failed test cases * Fix failed test on torch ci
1 parent e045b6a commit 86873b5

File tree

8 files changed

+1379
-0
lines changed

8 files changed

+1379
-0
lines changed

keras/api/_tf_keras/keras/ops/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from keras.src.ops.image import hsv_to_rgb
1111
from keras.src.ops.image import map_coordinates
1212
from keras.src.ops.image import pad_images
13+
from keras.src.ops.image import perspective_transform
1314
from keras.src.ops.image import resize
1415
from keras.src.ops.image import rgb_to_grayscale
1516
from keras.src.ops.image import rgb_to_hsv

keras/api/ops/image/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from keras.src.ops.image import hsv_to_rgb
1111
from keras.src.ops.image import map_coordinates
1212
from keras.src.ops.image import pad_images
13+
from keras.src.ops.image import perspective_transform
1314
from keras.src.ops.image import resize
1415
from keras.src.ops.image import rgb_to_grayscale
1516
from keras.src.ops.image import rgb_to_hsv

keras/src/backend/jax/image.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,151 @@ def affine_transform(
491491
}
492492

493493

494+
def perspective_transform(
495+
images,
496+
start_points,
497+
end_points,
498+
interpolation="bilinear",
499+
fill_value=0,
500+
data_format=None,
501+
):
502+
data_format = backend.standardize_data_format(data_format)
503+
if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():
504+
raise ValueError(
505+
"Invalid value for argument `interpolation`. Expected of one "
506+
f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
507+
f"interpolation={interpolation}"
508+
)
509+
510+
if len(images.shape) not in (3, 4):
511+
raise ValueError(
512+
"Invalid images rank: expected rank 3 (single image) "
513+
"or rank 4 (batch of images). Received input with shape: "
514+
f"images.shape={images.shape}"
515+
)
516+
517+
if start_points.shape[-2:] != (4, 2) or start_points.ndim not in (2, 3):
518+
raise ValueError(
519+
"Invalid start_points shape: expected (4,2) for a single image"
520+
f" or (N,4,2) for a batch. Received shape: {start_points.shape}"
521+
)
522+
if end_points.shape[-2:] != (4, 2) or end_points.ndim not in (2, 3):
523+
raise ValueError(
524+
"Invalid end_points shape: expected (4,2) for a single image"
525+
f" or (N,4,2) for a batch. Received shape: {end_points.shape}"
526+
)
527+
if start_points.shape != end_points.shape:
528+
raise ValueError(
529+
"start_points and end_points must have the same shape."
530+
f" Received start_points.shape={start_points.shape}, "
531+
f"end_points.shape={end_points.shape}"
532+
)
533+
534+
need_squeeze = False
535+
if len(images.shape) == 3:
536+
images = jnp.expand_dims(images, axis=0)
537+
need_squeeze = True
538+
539+
if len(start_points.shape) == 2:
540+
start_points = jnp.expand_dims(start_points, axis=0)
541+
if len(end_points.shape) == 2:
542+
end_points = jnp.expand_dims(end_points, axis=0)
543+
544+
if data_format == "channels_first":
545+
images = jnp.transpose(images, (0, 2, 3, 1))
546+
547+
batch_size, height, width, channels = images.shape
548+
transforms = compute_homography_matrix(
549+
jnp.asarray(start_points), jnp.asarray(end_points)
550+
)
551+
552+
x, y = jnp.meshgrid(jnp.arange(width), jnp.arange(height), indexing="xy")
553+
grid = jnp.stack([x.ravel(), y.ravel(), jnp.ones_like(x).ravel()], axis=0)
554+
555+
def transform_coordinates(transform):
556+
denom = transform[6] * grid[0] + transform[7] * grid[1] + 1.0
557+
x_in = (
558+
transform[0] * grid[0] + transform[1] * grid[1] + transform[2]
559+
) / denom
560+
y_in = (
561+
transform[3] * grid[0] + transform[4] * grid[1] + transform[5]
562+
) / denom
563+
return jnp.stack([y_in, x_in], axis=0)
564+
565+
transformed_coords = jax.vmap(transform_coordinates)(transforms)
566+
567+
def interpolate_image(image, coords):
568+
def interpolate_channel(channel_img):
569+
return jax.scipy.ndimage.map_coordinates(
570+
channel_img,
571+
coords,
572+
order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],
573+
mode="constant",
574+
cval=fill_value,
575+
).reshape(height, width)
576+
577+
return jax.vmap(interpolate_channel, in_axes=0)(
578+
jnp.moveaxis(image, -1, 0)
579+
)
580+
581+
output = jax.vmap(interpolate_image, in_axes=(0, 0))(
582+
images, transformed_coords
583+
)
584+
output = jnp.moveaxis(output, 1, -1)
585+
586+
if data_format == "channels_first":
587+
output = jnp.transpose(output, (0, 3, 1, 2))
588+
if need_squeeze:
589+
output = jnp.squeeze(output, axis=0)
590+
591+
return output
592+
593+
594+
def compute_homography_matrix(start_points, end_points):
595+
start_x, start_y = start_points[..., 0], start_points[..., 1]
596+
end_x, end_y = end_points[..., 0], end_points[..., 1]
597+
598+
zeros = jnp.zeros_like(end_x)
599+
ones = jnp.ones_like(end_x)
600+
601+
x_rows = jnp.stack(
602+
[
603+
end_x,
604+
end_y,
605+
ones,
606+
zeros,
607+
zeros,
608+
zeros,
609+
-start_x * end_x,
610+
-start_x * end_y,
611+
],
612+
axis=-1,
613+
)
614+
y_rows = jnp.stack(
615+
[
616+
zeros,
617+
zeros,
618+
zeros,
619+
end_x,
620+
end_y,
621+
ones,
622+
-start_y * end_x,
623+
-start_y * end_y,
624+
],
625+
axis=-1,
626+
)
627+
628+
coefficient_matrix = jnp.concatenate([x_rows, y_rows], axis=1)
629+
630+
target_vector = jnp.expand_dims(
631+
jnp.concatenate([start_x, start_y], axis=-1), axis=-1
632+
)
633+
634+
homography_matrix = jnp.linalg.solve(coefficient_matrix, target_vector)
635+
636+
return homography_matrix.squeeze(-1)
637+
638+
494639
def map_coordinates(
495640
inputs, coordinates, order, fill_mode="constant", fill_value=0.0
496641
):

0 commit comments

Comments
 (0)