Skip to content

Commit 2b84e03

Browse files
djl11seanpmorgan
authored andcommitted
Fix bug in "xy" mode for bilinear interpolation (#845)
* Fix bug in "xy" mode for bilinear interpolation * fixed error in interp implementation, and added test for non-square images. * removed erroneous non-square test, and modified small_grid tests to be non-square.
1 parent 06d686e commit 2b84e03

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

tensorflow_addons/image/dense_image_warp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ def interpolate_bilinear(grid, query_points, indexing="ij", name=None):
116116
index_order = [0, 1] if indexing == "ij" else [1, 0]
117117
unstacked_query_points = tf.unstack(query_points, axis=2, num=2)
118118

119-
for dim in index_order:
119+
for i, dim in enumerate(index_order):
120120
with tf.name_scope("dim-" + str(dim)):
121121
queries = unstacked_query_points[dim]
122122

123-
size_in_indexing_dimension = grid_shape[dim + 1]
123+
size_in_indexing_dimension = grid_shape[i + 1]
124124

125125
# max_floor is size_in_indexing_dimension - 2 so that max_floor + 1
126126
# is still a valid index into the grid.

tensorflow_addons/image/dense_image_warp_test.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,28 @@
3030
@test_utils.run_all_in_graph_and_eager_modes
3131
class InterpolateBilinearTest(tf.test.TestCase):
3232
def test_interpolate_small_grid_ij(self):
33-
grid = tf.constant([[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]],
34-
shape=[1, 3, 3, 1])
35-
query_points = tf.constant([[0., 0.], [1., 0.], [2., 0.5], [1.5, 1.5]],
36-
shape=[1, 4, 2])
37-
expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1])
33+
grid = tf.constant(
34+
[[0., 1., 2.], [3., 4., 5.], [6., 7., 8.], [9., 10., 11.]],
35+
shape=[1, 4, 3, 1])
36+
query_points = tf.constant(
37+
[[0., 0.], [1., 0.], [2., 0.5], [1.5, 1.5], [3., 2.]],
38+
shape=[1, 5, 2])
39+
expected_results = np.reshape(
40+
np.array([0., 3., 6.5, 6., 11.]), [1, 5, 1])
3841

3942
interp = interpolate_bilinear(grid, query_points)
4043

4144
self.assertAllClose(expected_results, interp)
4245

4346
def test_interpolate_small_grid_xy(self):
44-
grid = tf.constant([[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]],
45-
shape=[1, 3, 3, 1])
47+
grid = tf.constant(
48+
[[0., 1., 2.], [3., 4., 5.], [6., 7., 8.], [9., 10., 11.]],
49+
shape=[1, 4, 3, 1])
4650
query_points = tf.constant(
47-
[[0., 0.], [0., 1.], [0.5, 2.0], [1.5, 1.5]], shape=[1, 4, 2])
48-
expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1])
51+
[[0., 0.], [0., 1.], [0.5, 2.0], [1.5, 1.5], [2., 3.]],
52+
shape=[1, 5, 2])
53+
expected_results = np.reshape(
54+
np.array([0., 3., 6.5, 6., 11.]), [1, 5, 1])
4955

5056
interp = interpolate_bilinear(grid, query_points, indexing="xy")
5157

0 commit comments

Comments
 (0)