Bug in keras.applications.vgg16.preprocess_input() when x is 3D and data_format=='channels_first' #21749

@ILCSFNO

Bug Issue

The documentation for keras.applications.vgg16.preprocess_input() describes its data_format and mode arguments as follows:

data_format: Optional data format of the image tensor/array. None, means
    the global setting `keras.backend.image_data_format()` is used
    (unless you changed it, it uses "channels_last").{mode}
    Defaults to `None`.

PREPROCESS_INPUT_MODE_DOC = """
    mode: One of "caffe", "tf" or "torch".
        - caffe: will convert the images from RGB to BGR,
            then will zero-center each color channel with
            respect to the ImageNet dataset,
            without scaling.
        - tf: will scale pixels between -1 and 1,
            sample-wise.
        - torch: will scale pixels between 0 and 1 and then
            will normalize each channel with respect to the
            ImageNet dataset.
        Defaults to `"caffe"`.
    """

The repro below is expected to work, but keras.applications.vgg16.preprocess_input() fails when x is 3D and data_format == 'channels_first' (tested with tf 2.20.0 and the latest keras):

Repro

import tensorflow as tf
import numpy as np
import keras
data_format = 'channels_first'
image = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
image = tf.convert_to_tensor(image, dtype=tf.float32)
image = tf.transpose(image, perm=[2, 0, 1])  # Transpose to channels_first
preprocessed_image = keras.applications.vgg16.preprocess_input(image, data_format=data_format)
print(preprocessed_image.shape)

Output

InvalidArgumentError: {{function_node __wrapped__AddV2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Incompatible shapes: [3,224,224] vs. [1,3,1] [Op:AddV2] name: 

The related code is here:

# Zero-center by mean pixel
if data_format == "channels_first":
    mean_tensor = ops.reshape(mean_tensor, (1, 3) + (1,) * (ndim - 2))
else:
    mean_tensor = ops.reshape(mean_tensor, (1,) * (ndim - 1) + (3,))
x += mean_tensor
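
To make the shape mismatch concrete, here is a minimal NumPy sketch that mirrors only the shape arithmetic of the channels_first branch quoted above. It is not the Keras implementation: `current_mean_shape` is a hypothetical helper name, and the mean values are placeholders chosen for illustration.

```python
import numpy as np

# Placeholder per-channel offsets (illustrative values, not the ImageNet means).
mean = np.array([-1.0, -2.0, -3.0], dtype=np.float32)


def current_mean_shape(ndim):
    # Hypothetical helper: same tuple expression as the channels_first branch above.
    return (1, 3) + (1,) * (ndim - 2)


x3d = np.zeros((3, 224, 224), dtype=np.float32)     # (C, H, W)
x4d = np.zeros((1, 3, 224, 224), dtype=np.float32)  # (N, C, H, W)

shape_3d = current_mean_shape(x3d.ndim)  # (1, 3, 1)
shape_4d = current_mean_shape(x4d.ndim)  # (1, 3, 1, 1)

# 4D input: the mean lines up with the channel axis (axis 1) and broadcasts.
out4d = x4d + mean.reshape(shape_4d)

# 3D input: (3, 224, 224) vs. (1, 3, 1) clash on the middle axis (224 vs. 3).
try:
    x3d + mean.reshape(shape_3d)
    failed = False
except ValueError as exc:
    failed = True
    print("broadcast error:", exc)
```

For a 3D input the expression places the 3 on axis 1, while the channels sit on axis 0, which matches the `[3,224,224] vs. [1,3,1]` shapes in the error output.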

The zero-centering step should handle data_format == 'channels_first' conditionally (3D vs. 4D), just like the RGB-to-BGR conversion does here:

if data_format == "channels_first":
    # 'RGB'->'BGR'
    if len(x.shape) == 3:
        x = ops.stack([x[i, ...] for i in (2, 1, 0)], axis=0)
    else:
        x = ops.stack([x[:, i, :] for i in (2, 1, 0)], axis=1)
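
One possible shape of the fix, sketched in NumPy rather than with `keras.ops`, is to pick the reshape target by rank in the same way the RGB-to-BGR branch above does. `mean_broadcast_shape` is a hypothetical helper and the mean values are placeholders; this is an illustration of the idea under those assumptions, not a patch against the Keras source.

```python
import numpy as np


def mean_broadcast_shape(ndim, data_format):
    """Hypothetical helper: shape that aligns a length-3 mean with the channel axis."""
    if data_format == "channels_first":
        if ndim == 3:
            # (C, H, W): channels are on axis 0.
            return (3, 1, 1)
        # (N, C, ...): channels are on axis 1.
        return (1, 3) + (1,) * (ndim - 2)
    # channels_last: channels are on the final axis for any rank.
    return (1,) * (ndim - 1) + (3,)


# Placeholder per-channel offsets (illustrative values only).
mean = np.array([10.0, 20.0, 30.0], dtype=np.float32)

x3d = np.zeros((3, 4, 5), dtype=np.float32)     # small (C, H, W) input
x4d = np.zeros((2, 3, 4, 5), dtype=np.float32)  # small (N, C, H, W) input

out3d = x3d + mean.reshape(mean_broadcast_shape(x3d.ndim, "channels_first"))
out4d = x4d + mean.reshape(mean_broadcast_shape(x4d.ndim, "channels_first"))

print(out3d.shape, out4d.shape)
```

With this rank check, each channel plane of the 3D input receives its own offset, and the 4D path keeps the reshape it already uses.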

Thanks for noting!
