-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Fix: handle 3D input for channels_first in preprocess_input() #21749 #21754
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
2b65f89
2a5bb21
6e0340b
7b84f95
49fc1b7
544e7f1
a935e1c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ __pycache__ | |
| **/.vscode test/** | ||
| **/.vscode-smoke/** | ||
| **/.venv*/ | ||
| venv | ||
| bin/** | ||
| build/** | ||
| obj/** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| import os | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from keras.utils import img_to_array | ||
| from keras.utils import load_img | ||
| from keras.utils import save_img | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "shape, name", | ||
| [ | ||
| ((50, 50, 3), "rgb.jpg"), | ||
| ((50, 50, 4), "rgba.jpg"), | ||
| ], | ||
| ) | ||
| def test_save_jpg(tmp_path, shape, name): | ||
| img = np.random.randint(0, 256, size=shape, dtype=np.uint8) | ||
| path = tmp_path / name | ||
| save_img(path, img, file_format="jpg") | ||
| assert os.path.exists(path) | ||
|
|
||
| # Check that the image was saved correctly and converted to RGB if needed. | ||
| loaded_img = load_img(path) | ||
| loaded_array = img_to_array(loaded_img) | ||
| assert loaded_array.shape == (50, 50, 3) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -278,14 +278,28 @@ def _preprocess_tensor_input(x, data_format, mode): | |||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Zero-center by mean pixel | ||||||||||||||||||||||||||||||||||
| if data_format == "channels_first": | ||||||||||||||||||||||||||||||||||
| mean_tensor = ops.reshape(mean_tensor, (1, 3) + (1,) * (ndim - 2)) | ||||||||||||||||||||||||||||||||||
| if ndim == 3: | ||||||||||||||||||||||||||||||||||
| mean_tensor = ops.reshape(mean_tensor, (3, 1, 1)) | ||||||||||||||||||||||||||||||||||
| elif ndim == 4: | ||||||||||||||||||||||||||||||||||
| mean_tensor = ops.reshape(mean_tensor, (1, 3, 1, 1)) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| raise ValueError(f"Unsupported shape for channels_first: {x.shape}") | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+281
to
+286
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While this logic correctly handles 3D and 4D inputs, the error message could be more informative. According to the Keras API design guidelines, a good error message should explain what was expected and how the user can fix it.1 This message only states what was received. Consider clarifying that only 3D and 4D tensors are supported for Additionally, you can make the code slightly cleaner by checking for the invalid
Suggested change
Style Guide ReferencesFootnotes
|
||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| mean_tensor = ops.reshape(mean_tensor, (1,) * (ndim - 1) + (3,)) | ||||||||||||||||||||||||||||||||||
| x += mean_tensor | ||||||||||||||||||||||||||||||||||
| if std is not None: | ||||||||||||||||||||||||||||||||||
| std_tensor = ops.convert_to_tensor(np.array(std), dtype=x.dtype) | ||||||||||||||||||||||||||||||||||
| if data_format == "channels_first": | ||||||||||||||||||||||||||||||||||
| std_tensor = ops.reshape(std_tensor, (-1, 1, 1)) | ||||||||||||||||||||||||||||||||||
| if ndim == 3: | ||||||||||||||||||||||||||||||||||
| std_tensor = ops.reshape(std_tensor, (3, 1, 1)) | ||||||||||||||||||||||||||||||||||
| elif ndim == 4: | ||||||||||||||||||||||||||||||||||
| std_tensor = ops.reshape(std_tensor, (1, 3, 1, 1)) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||
| f"Unsupported shape for channels_first: {x.shape}" | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| std_tensor = ops.reshape(std_tensor, (1,) * (ndim - 1) + (3,)) | ||||||||||||||||||||||||||||||||||
|
Comment on lines
292
to
+302
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic for reshaping |
||||||||||||||||||||||||||||||||||
| x /= std_tensor | ||||||||||||||||||||||||||||||||||
| return x | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -175,10 +175,13 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs): | |
| **kwargs: Additional keyword arguments passed to `PIL.Image.save()`. | ||
| """ | ||
| data_format = backend.standardize_data_format(data_format) | ||
| # Normalize jpg → jpeg | ||
| if file_format is not None and file_format.lower() == "jpg": | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please revert changes in this file |
||
| file_format = "jpeg" | ||
| img = array_to_img(x, data_format=data_format, scale=scale) | ||
| if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"): | ||
| if img.mode == "RGBA" and file_format == "jpeg": | ||
| warnings.warn( | ||
| "The JPG format does not support RGBA images, converting to RGB." | ||
| "The JPEG format does not support RGBA images, converting to RGB." | ||
| ) | ||
| img = img.convert("RGB") | ||
| img.save(path, format=file_format, **kwargs) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use a TestCase subclass and use
self.assert...methods instead of naked assert staments