Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ __pycache__
**/.vscode test/**
**/.vscode-smoke/**
**/.venv*/
venv
bin/**
build/**
obj/**
Expand Down
27 changes: 27 additions & 0 deletions integration_tests/test_save_img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

import numpy as np
import pytest

from keras.utils import img_to_array
from keras.utils import load_img
from keras.utils import save_img


@pytest.mark.parametrize(
"shape, name",
[
((50, 50, 3), "rgb.jpg"),
((50, 50, 4), "rgba.jpg"),
],
)
def test_save_jpg(tmp_path, shape, name):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use a TestCase subclass and use self.assert... methods instead of naked assert staments

img = np.random.randint(0, 256, size=shape, dtype=np.uint8)
path = tmp_path / name
save_img(path, img, file_format="jpg")
assert os.path.exists(path)

# Check that the image was saved correctly and converted to RGB if needed.
loaded_img = load_img(path)
loaded_array = img_to_array(loaded_img)
assert loaded_array.shape == (50, 50, 3)
18 changes: 16 additions & 2 deletions keras/src/applications/imagenet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,28 @@ def _preprocess_tensor_input(x, data_format, mode):

# Zero-center by mean pixel
if data_format == "channels_first":
mean_tensor = ops.reshape(mean_tensor, (1, 3) + (1,) * (ndim - 2))
if ndim == 3:
mean_tensor = ops.reshape(mean_tensor, (3, 1, 1))
elif ndim == 4:
mean_tensor = ops.reshape(mean_tensor, (1, 3, 1, 1))
else:
raise ValueError(f"Unsupported shape for channels_first: {x.shape}")
Comment on lines +281 to +286
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While this logic correctly handles 3D and 4D inputs, the error message could be more informative. According to the Keras API design guidelines, a good error message should explain what was expected and how the user can fix it.1 This message only states what was received. Consider clarifying that only 3D and 4D tensors are supported for channels_first.

Additionally, you can make the code slightly cleaner by checking for the invalid ndim case first.

Suggested change
if ndim == 3:
mean_tensor = ops.reshape(mean_tensor, (3, 1, 1))
elif ndim == 4:
mean_tensor = ops.reshape(mean_tensor, (1, 3, 1, 1))
else:
raise ValueError(f"Unsupported shape for channels_first: {x.shape}")
if ndim not in (3, 4):
raise ValueError(
f"Unsupported tensor rank: {ndim}. With `data_format='channels_first'`, "
"`preprocess_input` only supports 3D (single image) and 4D (batch of "
f"images) tensors. Received tensor with shape: {x.shape}"
)
if ndim == 3:
mean_tensor = ops.reshape(mean_tensor, (3, 1, 1))
elif ndim == 4:
mean_tensor = ops.reshape(mean_tensor, (1, 3, 1, 1))

Style Guide References

Footnotes

  1. The style guide states that error messages should be contextual, informative, and actionable, explaining what happened, what was expected, and how to fix it.

else:
mean_tensor = ops.reshape(mean_tensor, (1,) * (ndim - 1) + (3,))
x += mean_tensor
if std is not None:
std_tensor = ops.convert_to_tensor(np.array(std), dtype=x.dtype)
if data_format == "channels_first":
std_tensor = ops.reshape(std_tensor, (-1, 1, 1))
if ndim == 3:
std_tensor = ops.reshape(std_tensor, (3, 1, 1))
elif ndim == 4:
std_tensor = ops.reshape(std_tensor, (1, 3, 1, 1))
else:
raise ValueError(
f"Unsupported shape for channels_first: {x.shape}"
)
else:
std_tensor = ops.reshape(std_tensor, (1,) * (ndim - 1) + (3,))
Comment on lines 292 to +302
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic for reshaping std_tensor is a duplicate of the logic used for mean_tensor above. To improve maintainability and reduce code duplication, consider refactoring this. You could determine the reshape_shape once at the start of the data_format == 'channels_first' block and reuse it for both tensors. This would also apply to the error handling logic.

x /= std_tensor
return x

Expand Down
7 changes: 5 additions & 2 deletions keras/src/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,13 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
**kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
"""
data_format = backend.standardize_data_format(data_format)
# Normalize jpg → jpeg
if file_format is not None and file_format.lower() == "jpg":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please revert changes in this file

file_format = "jpeg"
img = array_to_img(x, data_format=data_format, scale=scale)
if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"):
if img.mode == "RGBA" and file_format == "jpeg":
warnings.warn(
"The JPG format does not support RGBA images, converting to RGB."
"The JPEG format does not support RGBA images, converting to RGB."
)
img = img.convert("RGB")
img.save(path, format=file_format, **kwargs)
Expand Down