
Commit da07901

Update random_perspective to use ops.perspective_transform (keras-team#20915)

* Update get_perspective_matrix method
* Update bbox logic
* Refactor random_perspective
* Apply tensor cast
* Add dtype conversion
* Update base scale factor
* Correct failed test cases
* Remove scale-zero test case
* Update the logic to use perspective_transform in the image layer
* Update test cases

1 parent 86873b5 commit da07901

File tree

3 files changed: +203 −123 lines changed

keras/src/backend/jax/image.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -546,7 +546,8 @@ def perspective_transform(
 
     batch_size, height, width, channels = images.shape
     transforms = compute_homography_matrix(
-        jnp.asarray(start_points), jnp.asarray(end_points)
+        jnp.asarray(start_points, dtype="float32"),
+        jnp.asarray(end_points, dtype="float32"),
     )
 
     x, y = jnp.meshgrid(jnp.arange(width), jnp.arange(height), indexing="xy")
```
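
The explicit `dtype="float32"` is the point of this change: corner points are typically built from Python ints (e.g. `width - 1`), and without a dtype `jnp.asarray` would infer an integer type for the homography computation. A minimal sketch of the difference, with hypothetical corner values for a 224x224 image:

```python
import jax.numpy as jnp

# Corner points built from image dimensions arrive as plain Python ints.
start_points = [[[0, 0], [223, 0], [0, 223], [223, 223]]]

inferred = jnp.asarray(start_points)                   # int32 on most setups
explicit = jnp.asarray(start_points, dtype="float32")  # always float32
print(inferred.dtype, explicit.dtype)
```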

keras/src/layers/preprocessing/image_preprocessing/random_perspective.py

Lines changed: 98 additions & 69 deletions

```diff
@@ -49,7 +49,7 @@ class RandomPerspective(BaseImagePreprocessingLayer):
     def __init__(
         self,
         factor=1.0,
-        scale=0.3,
+        scale=1.0,
         interpolation="bilinear",
         fill_value=0.0,
         seed=None,
```
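
With the new per-corner sampling of `uniform(-0.5 * scale, 0.5 * scale)` further down, `scale=1.0` lets each corner move by up to half its coordinate extent, which is why the base default is raised from 0.3. A hedged usage sketch (assuming the layer is exported as `keras.layers.RandomPerspective`; shapes are hypothetical):

```python
import keras

# scale controls warp strength: smaller values give a milder distortion.
layer = keras.layers.RandomPerspective(factor=1.0, scale=0.5)
images = keras.random.uniform((2, 64, 64, 3))
augmented = layer(images, training=True)
print(augmented.shape)  # (2, 64, 64, 3)
```
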
```diff
@@ -103,6 +103,10 @@ def get_random_transformation(self, data, training=True, seed=None):
             batch_size = 1
         else:
             batch_size = images_shape[0]
+        height, width = (
+            images.shape[self.height_axis],
+            images.shape[self.width_axis],
+        )
 
         seed = seed or self._get_seed_generator(self.backend._backend)
 
```

```diff
@@ -122,32 +126,42 @@ def get_random_transformation(self, data, training=True, seed=None):
         apply_perspective = random_threshold < transformation_probability
 
         perspective_factor = self.backend.random.uniform(
-            minval=-self.scale,
-            maxval=self.scale,
-            shape=[batch_size, 4],
+            shape=(batch_size, 4, 2),
+            minval=-0.5 * self.scale,
+            maxval=0.5 * self.scale,
             seed=seed,
             dtype=self.compute_dtype,
         )
 
+        start_points = self.backend.convert_to_tensor(
+            [
+                [
+                    [0.0, 0.0],
+                    [width - 1, 0.0],
+                    [0.0, height - 1],
+                    [width - 1, height - 1],
+                ]
+            ],
+            dtype=self.compute_dtype,
+        )
+
+        start_points = self.backend.numpy.repeat(
+            start_points, batch_size, axis=0
+        )
+        end_points = start_points + start_points * perspective_factor
+
         return {
             "apply_perspective": apply_perspective,
-            "perspective_factor": perspective_factor,
+            "start_points": start_points,
+            "end_points": end_points,
             "input_shape": images_shape,
         }
 
     def transform_images(self, images, transformation, training=True):
         images = self.backend.cast(images, self.compute_dtype)
         if training and transformation is not None:
-            apply_perspective = transformation["apply_perspective"]
-            perspective_images = self._perspective_inputs(
-                images, transformation
-            )
-
-            images = self.backend.numpy.where(
-                apply_perspective[:, None, None, None],
-                perspective_images,
-                images,
-            )
+            images = self._perspective_inputs(images, transformation)
+            images = self.backend.cast(images, self.compute_dtype)
         return images
 
     def _perspective_inputs(self, inputs, transformation):
```
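
A standalone NumPy sketch of how the transformation dict is now built (NumPy standing in for the `self.backend` ops; dimensions are hypothetical). Because displacements are proportional to each corner's own coordinates, the `(0, 0)` corner never moves:

```python
import numpy as np

height, width, scale, batch_size = 224, 224, 1.0, 2
rng = np.random.default_rng(0)

# The four image corners in (x, y) order, repeated per batch element.
start_points = np.array(
    [[[0.0, 0.0], [width - 1, 0.0], [0.0, height - 1], [width - 1, height - 1]]],
    dtype="float32",
)
start_points = np.repeat(start_points, batch_size, axis=0)  # (batch, 4, 2)

# One (dx, dy) factor per corner, sampled in [-0.5 * scale, 0.5 * scale).
perspective_factor = rng.uniform(-0.5 * scale, 0.5 * scale, (batch_size, 4, 2))
end_points = start_points + start_points * perspective_factor
```
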
```diff
@@ -159,63 +173,36 @@ def _perspective_inputs(self, inputs, transformation):
         if unbatched:
             inputs = self.backend.numpy.expand_dims(inputs, axis=0)
 
-        perspective_factor = self.backend.core.convert_to_tensor(
-            transformation["perspective_factor"], dtype=self.compute_dtype
-        )
-        outputs = self.backend.image.affine_transform(
+        start_points = transformation["start_points"]
+        end_points = transformation["end_points"]
+
+        outputs = self.backend.image.perspective_transform(
             inputs,
-            transform=self._get_perspective_matrix(perspective_factor),
+            start_points,
+            end_points,
             interpolation=self.interpolation,
-            fill_mode="constant",
             fill_value=self.fill_value,
             data_format=self.data_format,
         )
 
+        apply_perspective = transformation["apply_perspective"]
+        outputs = self.backend.numpy.where(
+            apply_perspective[:, None, None, None],
+            outputs,
+            inputs,
+        )
+
         if unbatched:
             outputs = self.backend.numpy.squeeze(outputs, axis=0)
         return outputs
 
-    def _get_perspective_matrix(self, perspectives):
-        perspectives = self.backend.core.convert_to_tensor(
-            perspectives, dtype=self.compute_dtype
-        )
-        num_perspectives = self.backend.shape(perspectives)[0]
-        return self.backend.numpy.concatenate(
-            [
-                self.backend.numpy.ones(
-                    (num_perspectives, 1), dtype=self.compute_dtype
-                )
-                + perspectives[:, :1],
-                perspectives[:, :1],
-                perspectives[:, 2:3],
-                perspectives[:, 1:2],
-                self.backend.numpy.ones(
-                    (num_perspectives, 1), dtype=self.compute_dtype
-                )
-                + perspectives[:, 1:2],
-                perspectives[:, 3:4],
-                self.backend.numpy.zeros((num_perspectives, 2)),
-            ],
-            axis=1,
-        )
-
-    def _get_transformed_coordinates(self, x, y, transform):
-        a0, a1, a2, b0, b1, b2, c0, c1 = self.backend.numpy.split(
-            transform, 8, axis=-1
-        )
-
-        x_transformed = (a1 * (y - b2) - b1 * (x - a2)) / (a1 * b0 - a0 * b1)
-        y_transformed = (b0 * (x - a2) - a0 * (y - b2)) / (a1 * b0 - a0 * b1)
-
-        return x_transformed, y_transformed
-
     def transform_bounding_boxes(
         self,
         bounding_boxes,
         transformation,
         training=True,
     ):
-        if training:
+        if training and transformation is not None:
             if backend_utils.in_tf_graph():
                 self.backend.set_backend("tensorflow")
 
```
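
The layer now delegates the warp to the backend's `perspective_transform` (the `ops.perspective_transform` from the commit title), and the `where`-based pass-through moved in here so unapplied samples keep their original pixels. A hedged sketch of a direct ops-level call, assuming the argument order shown in the diff (values are hypothetical):

```python
import keras
from keras import ops

images = keras.random.uniform((1, 64, 64, 3))
start_points = ops.convert_to_tensor(
    [[[0.0, 0.0], [63.0, 0.0], [0.0, 63.0], [63.0, 63.0]]]
)
end_points = start_points * 1.1  # push corners outward; (0, 0) stays fixed
warped = ops.image.perspective_transform(
    images, start_points, end_points, interpolation="bilinear"
)
```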

```diff
@@ -233,26 +220,33 @@ def transform_bounding_boxes(
         )
 
         boxes = bounding_boxes["boxes"]
-
         x0, y0, x1, y1 = self.backend.numpy.split(boxes, 4, axis=-1)
 
-        perspective_factor = transformation["perspective_factor"]
-        transform = self._get_perspective_matrix(perspective_factor)
+        start_points = transformation["start_points"]
+        end_points = transformation["end_points"]
+        transform = self.backend.image.compute_homography_matrix(
+            start_points, end_points
+        )
         transform = self.backend.numpy.expand_dims(transform, axis=1)
         transform = self.backend.cast(transform, dtype=self.compute_dtype)
 
-        x_1, y_1 = self._get_transformed_coordinates(x0, y0, transform)
-        x_2, y_2 = self._get_transformed_coordinates(x1, y1, transform)
-        x_3, y_3 = self._get_transformed_coordinates(x0, y1, transform)
-        x_4, y_4 = self._get_transformed_coordinates(x1, y0, transform)
+        corners = [
+            self._get_transformed_coordinates(x, y, transform)
+            for x, y in [(x0, y0), (x1, y1), (x0, y1), (x1, y0)]
+        ]
+        x_corners, y_corners = zip(*corners)
 
-        xs = self.backend.numpy.concatenate([x_1, x_2, x_3, x_4], axis=-1)
-        ys = self.backend.numpy.concatenate([y_1, y_2, y_3, y_4], axis=-1)
+        xs = self.backend.numpy.stack(x_corners, axis=-1)
+        ys = self.backend.numpy.stack(y_corners, axis=-1)
 
-        min_x = self.backend.numpy.min(xs, axis=-1)
-        max_x = self.backend.numpy.max(xs, axis=-1)
-        min_y = self.backend.numpy.min(ys, axis=-1)
-        max_y = self.backend.numpy.max(ys, axis=-1)
+        min_x, max_x = (
+            self.backend.numpy.min(xs, axis=-1),
+            self.backend.numpy.max(xs, axis=-1),
+        )
+        min_y, max_y = (
+            self.backend.numpy.min(ys, axis=-1),
+            self.backend.numpy.max(ys, axis=-1),
+        )
 
         min_x = self.backend.numpy.expand_dims(min_x, axis=-1)
         max_x = self.backend.numpy.expand_dims(max_x, axis=-1)
```
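
Why all four corners are mapped: a homography does not keep boxes axis-aligned, so the transformed box is taken as the min/max envelope of the four warped corners. A small NumPy sketch with made-up corner coordinates:

```python
import numpy as np

# Warped x and y coordinates of one box's four corners, shape (num_boxes, 4).
xs = np.array([[10.2, 48.7, 12.9, 45.1]])
ys = np.array([[8.4, 9.6, 40.3, 38.8]])

min_x, max_x = xs.min(axis=-1), xs.max(axis=-1)
min_y, max_y = ys.min(axis=-1), ys.max(axis=-1)
new_boxes = np.stack([min_x, min_y, max_x, max_y], axis=-1)  # "xyxy"
```
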
```diff
@@ -280,8 +274,43 @@ def transform_bounding_boxes(
             bounding_box_format="xyxy",
         )
 
+        self.backend.reset()
+
         return bounding_boxes
 
+    def _get_transformed_coordinates(
+        self, x_coords, y_coords, transformation_matrix
+    ):
+        backend = self.backend
+
+        batch_size = backend.shape(transformation_matrix)[0]
+
+        homogeneous_transform = backend.numpy.concatenate(
+            [transformation_matrix, backend.numpy.ones((batch_size, 1, 1))],
+            axis=-1,
+        )
+        homogeneous_transform = backend.numpy.reshape(
+            homogeneous_transform, (batch_size, 3, 3)
+        )
+
+        inverse_transform = backend.linalg.inv(homogeneous_transform)
+
+        ones_column = backend.numpy.ones_like(x_coords)
+        homogeneous_coords = backend.numpy.concatenate(
+            [x_coords, y_coords, ones_column], axis=-1
+        )
+
+        homogeneous_coords = backend.numpy.moveaxis(homogeneous_coords, -1, -2)
+        transformed_coords = backend.numpy.matmul(
+            inverse_transform, homogeneous_coords
+        )
+        transformed_coords = backend.numpy.moveaxis(transformed_coords, -1, -2)
+
+        x_transformed = transformed_coords[..., 0] / transformed_coords[..., 2]
+        y_transformed = transformed_coords[..., 1] / transformed_coords[..., 2]
+
+        return x_transformed, y_transformed
+
     def transform_labels(self, labels, transformation, training=True):
         return labels
```
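
A compact NumPy sketch of the math in the new `_get_transformed_coordinates`: the eight homography parameters are extended with h33 = 1 into a 3x3 matrix, inverted, applied in homogeneous coordinates, and de-homogenized by the third component (parameter and point values below are hypothetical):

```python
import numpy as np

# Eight homography parameters for a batch of one transform, shape (1, 8).
transform = np.array([[1.0, 0.1, 5.0, 0.05, 1.0, 3.0, 1e-4, 2e-4]])
batch = transform.shape[0]

# Append h33 = 1 and reshape to (batch, 3, 3), then invert.
H = np.concatenate([transform, np.ones((batch, 1))], axis=-1)
H = H.reshape(batch, 3, 3)
H_inv = np.linalg.inv(H)

# Map two points in homogeneous coordinates, then divide by the w component.
x = np.array([10.0, 20.0])
y = np.array([12.0, 24.0])
pts = np.stack([x, y, np.ones_like(x)], axis=0)  # (3, num_points)
mapped = H_inv[0] @ pts                          # (3, num_points)
x_t = mapped[0] / mapped[2]
y_t = mapped[1] / mapped[2]
```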
