-
Notifications
You must be signed in to change notification settings - Fork 135
[Part 2] Image Transformer #721
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jonpsy
wants to merge
8
commits into
tensorflow:master
Choose a base branch
from
jonpsy:transformer
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+529
−0
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
34223da
init push
jonpsy a1624c4
Image transformer logic
jonpsy 775e853
Build file.
jonpsy 6bd88dc
testing build
jonpsy 0c80524
Input images
jonpsy a97c5fe
Test logic
jonpsy 06d6cc3
esrgan models
jonpsy a3cf1f7
protos
jonpsy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
123 changes: 123 additions & 0 deletions
123
tensorflow_lite_support/cc/task/vision/image_transformer.cc
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "tensorflow_lite_support/cc/task/vision/image_transformer.h" | ||
|
||
#include "external/com_google_absl/absl/strings/str_format.h" | ||
#include "external/com_google_absl/absl/strings/string_view.h" | ||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers | ||
#include "tensorflow_lite_support/cc/port/status_macros.h" | ||
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" | ||
|
||
namespace tflite { | ||
namespace task { | ||
namespace vision { | ||
|
||
namespace { | ||
|
||
using ::absl::StatusCode; | ||
using ::tflite::support::CreateStatusWithPayload; | ||
using ::tflite::support::StatusOr; | ||
using ::tflite::support::TfLiteSupportStatus; | ||
using ::tflite::task::core::AssertAndReturnTypedTensor; | ||
using ::tflite::task::core::TaskAPIFactory; | ||
using ::tflite::task::core::TfLiteEngine; | ||
using ::tflite::task::vision::FrameBuffer; | ||
} // namespace | ||
|
||
/* static */ | ||
StatusOr<std::unique_ptr<ImageTransformer>> ImageTransformer::CreateFromOptions( | ||
const ImageTransformerOptions& options, | ||
std::unique_ptr<tflite::OpResolver> resolver) { | ||
RETURN_IF_ERROR(SanityCheckOptions(options)); | ||
|
||
// Copy options to ensure the ExternalFile outlives the constructed object. | ||
auto options_copy = absl::make_unique<ImageTransformerOptions>(options); | ||
|
||
std::unique_ptr<ImageTransformer> image_transformer; | ||
|
||
ASSIGN_OR_RETURN(image_transformer, | ||
TaskAPIFactory::CreateFromBaseOptions<ImageTransformer>( | ||
&options_copy->base_options(), std::move(resolver))); | ||
|
||
RETURN_IF_ERROR(image_transformer->Init(std::move(options_copy))); | ||
return image_transformer; | ||
} | ||
|
||
/* static */
absl::Status ImageTransformer::SanityCheckOptions(
    const ImageTransformerOptions& options) {
  // No option combinations to validate yet: ImageTransformerOptions only
  // carries base_options, which are checked by the task API factory.
  return absl::OkStatus();
}
|
||
// Initializes the transformer from the provided options, whose ownership is
// transferred to this object. Order matters: pre-init must run before the
// input/output tensor checks, and the postprocessor is built last, once the
// engine's tensors have been validated.
absl::Status ImageTransformer::Init(
    std::unique_ptr<ImageTransformerOptions> options) {
  // Set options.
  options_ = std::move(options);

  // Perform pre-initialization actions (by default, sets the process engine
  // for image pre-processing to kLibyuv as a sane default).
  RETURN_IF_ERROR(PreInit());

  // Sanity check and set inputs and outputs.
  RETURN_IF_ERROR(CheckAndSetInputs());
  RETURN_IF_ERROR(CheckAndSetOutputs());

  RETURN_IF_ERROR(PostInit());

  // Build the image postprocessor over tensor index 0 on both sides
  // (presumably input index 0 / output index 0 — confirm against
  // ImagePostprocessor::Create's parameter order).
  ASSIGN_OR_RETURN(postprocessor_, processor::ImagePostprocessor::Create(
                                       GetTfLiteEngine(), {0}, {0}));

  return absl::OkStatus();
}
|
||
// Pre-initialization hook: selects libyuv as the frame-buffer processing
// engine before any tensor checks run. Subclasses may override.
absl::Status ImageTransformer::PreInit() {
  SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
  return absl::OkStatus();
}
|
||
// Post-initialization hook; intentionally a no-op here, provided as an
// override point for subclasses.
absl::Status ImageTransformer::PostInit() {
  // Nothing to do.
  return absl::OkStatus();
}
|
||
// Output-tensor sanity checks. Currently a no-op: output validation is
// deferred to the ImagePostprocessor created in Init().
absl::Status ImageTransformer::CheckAndSetOutputs() {
  // Nothing to do.
  return absl::OkStatus();
}
|
||
// Transforms the whole frame: builds a region of interest covering the
// entire buffer (origin fields keep their proto default of 0) and forwards
// to the ROI overload.
StatusOr<FrameBuffer> ImageTransformer::Transform(
    const FrameBuffer& frame_buffer) {
  BoundingBox whole_frame;
  whole_frame.set_width(frame_buffer.dimension().width);
  whole_frame.set_height(frame_buffer.dimension().height);
  return Transform(frame_buffer, whole_frame);
}
|
||
// Transforms the given region of interest by delegating straight to the base
// task API's inference-with-fallback path; pre/post-processing hooks are
// invoked from there.
StatusOr<FrameBuffer> ImageTransformer::Transform(
    const FrameBuffer& frame_buffer, const BoundingBox& roi) {
  return InferWithFallback(frame_buffer, roi);
}
|
||
// Converts the raw model output into a FrameBuffer. All three parameters are
// deliberately unused: the work is delegated to the ImagePostprocessor built
// in Init(), which already holds the engine's output tensor.
StatusOr<FrameBuffer> ImageTransformer::Postprocess(
    const std::vector<const TfLiteTensor*>& /*output_tensors*/,
    const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
  ASSIGN_OR_RETURN(auto postprocessed_output, postprocessor_->Postprocess());
  return postprocessed_output;
}
} // namespace vision | ||
} // namespace task | ||
} // namespace tflite |
138 changes: 138 additions & 0 deletions
138
tensorflow_lite_support/cc/task/vision/image_transformer.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_TRANSFORMER_H_ | ||
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_TRANSFORMER_H_ | ||
|
||
#include <memory> | ||
#include <vector> | ||
|
||
#include "tensorflow/lite/core/api/op_resolver.h" | ||
#include "tensorflow/lite/core/shims/cc/kernels/register.h" | ||
#include "tensorflow_lite_support/cc/port/statusor.h" | ||
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" | ||
#include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h" | ||
#include "tensorflow_lite_support/cc/task/vision/proto/image_transformer_options_proto_inc.h" | ||
#include "tensorflow_lite_support/cc/task/processor/image_postprocessor.h" | ||
|
||
namespace tflite { | ||
namespace task { | ||
namespace vision { | ||
|
||
// Performs transformation on images (e.g. super-resolution, as exercised by
// the ESRGAN models in image_transformer_test.cc).
//
// The API expects a TFLite model with optional, but strongly recommended,
// TFLite Model Metadata.
//
// Input tensor:
//   (kTfLiteUInt8/kTfLiteFloat32)
//   - image input of size `[batch x height x width x channels]`.
//   - batch inference is not supported (`batch` is required to be 1).
//   - only RGB inputs are supported (`channels` is required to be 3).
//   - if type is kTfLiteFloat32, NormalizationOptions are required to be
//     attached to the metadata for input normalization.
// Output tensor:
//   (kTfLiteUInt8/kTfLiteFloat32)
//   - NOTE(review): the output is converted back into a FrameBuffer by the
//     attached processor::ImagePostprocessor; confirm the exact tensor
//     shape/metadata requirements against ImagePostprocessor::Create.
class ImageTransformer : public BaseVisionTaskApi<FrameBuffer> {
 public:
  // Inherit the base class constructors.
  using BaseVisionTaskApi::BaseVisionTaskApi;

  // Creates an ImageTransformer from the provided options. A non-default
  // OpResolver can be specified in order to support custom Ops or specify a
  // subset of built-in Ops.
  static tflite::support::StatusOr<std::unique_ptr<ImageTransformer>>
  CreateFromOptions(
      const ImageTransformerOptions& options,
      std::unique_ptr<tflite::OpResolver> resolver =
          absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());

  // Performs actual transformation on the provided FrameBuffer.
  //
  // The FrameBuffer can be of any size and any of the supported formats, i.e.
  // RGBA, RGB, NV12, NV21, YV12, YV21. It is automatically pre-processed
  // before inference in order to (and in this order):
  // - resize it (with bilinear interpolation, aspect-ratio *not* preserved)
  //   to the dimensions of the model input tensor,
  // - convert it to the colorspace of the input tensor (i.e. RGB, which is
  //   the only supported colorspace for now),
  // - rotate it according to its `Orientation` so that inference is performed
  //   on an "upright" image.
  tflite::support::StatusOr<FrameBuffer> Transform(
      const FrameBuffer& frame_buffer);

  // Same as above, except that the transformation is performed based on the
  // input region of interest. Cropping according to this region of interest
  // is prepended to the pre-processing operations.
  //
  // IMPORTANT: as a consequence of cropping occurring first, the provided
  // region of interest is expressed in the unrotated frame of reference
  // coordinates system, i.e. in `[0, frame_buffer.width) x [0,
  // frame_buffer.height)`, which are the dimensions of the underlying
  // `frame_buffer` data before any `Orientation` flag gets applied. Also, the
  // region of interest is not clamped, so this method will return a non-ok
  // status if the region is out of these bounds.
  tflite::support::StatusOr<FrameBuffer> Transform(
      const FrameBuffer& frame_buffer, const BoundingBox& roi);

 protected:
  // The options used to build this ImageTransformer.
  std::unique_ptr<ImageTransformerOptions> options_;

  // Post-processing to transform the raw model outputs into image results.
  // The tensor/frame/roi arguments are unused by this implementation (see
  // image_transformer.cc); the work is delegated to `postprocessor_`.
  tflite::support::StatusOr<FrameBuffer> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors,
      const FrameBuffer& frame_buffer, const BoundingBox& roi) override;

  // Performs sanity checks on the provided ImageTransformerOptions.
  static absl::Status SanityCheckOptions(const ImageTransformerOptions& options);

  // Initializes the ImageTransformer from the provided
  // ImageTransformerOptions, whose ownership is transferred to this object.
  absl::Status Init(std::unique_ptr<ImageTransformerOptions> options);

  // Performs pre-initialization actions (default: select the libyuv
  // processing engine).
  virtual absl::Status PreInit();
  // Performs post-initialization actions (default: no-op).
  virtual absl::Status PostInit();

 private:
  // Performs sanity checks on the model outputs and extracts their metadata.
  // Currently a no-op; validation is deferred to the postprocessor.
  absl::Status CheckAndSetOutputs();

  // Converts the raw output tensor back into a FrameBuffer; created in Init().
  std::unique_ptr<processor::ImagePostprocessor> postprocessor_;
};
|
||
} // namespace vision | ||
} // namespace task | ||
} // namespace tflite | ||
|
||
#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_TRANSFORMER_H_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
28 changes: 28 additions & 0 deletions
28
tensorflow_lite_support/cc/task/vision/proto/image_transformer_options.proto
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
syntax = "proto2"; | ||
|
||
package tflite.task.vision; | ||
|
||
import "tensorflow_lite_support/cc/task/core/proto/base_options.proto"; | ||
|
||
// Options for setting up an ImageTransformer.
// Next Id: 2.
message ImageTransformerOptions {
  // Base options for configuring Task library, such as specifying the TfLite
  // model file with metadata, accelerator options, etc.
  optional tflite.task.core.BaseOptions base_options = 1;
}
23 changes: 23 additions & 0 deletions
23
tensorflow_lite_support/cc/task/vision/proto/image_transformer_options_proto_inc.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_TRANSFORMER_OPTIONS_PROTO_INC_H_ | ||
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_TRANSFORMER_OPTIONS_PROTO_INC_H_ | ||
|
||
#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h" | ||
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" | ||
|
||
#include "tensorflow_lite_support/cc/task/vision/proto/image_transformer_options.pb.h" | ||
#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_TRANSFORMER_OPTIONS_PROTO_INC_H_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
145 changes: 145 additions & 0 deletions
145
tensorflow_lite_support/cc/test/task/vision/image_transformer_test.cc
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "tensorflow_lite_support/cc/task/vision/image_transformer.h" | ||
|
||
#include <memory> | ||
|
||
#include "absl/status/status.h" // from @com_google_absl | ||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||
#include "tensorflow_lite_support/cc/port/gmock.h" | ||
#include "tensorflow_lite_support/cc/port/gtest.h" | ||
#include "tensorflow_lite_support/cc/port/status_matchers.h" | ||
#include "tensorflow_lite_support/cc/task/core/task_utils.h" | ||
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" | ||
#include "tensorflow_lite_support/cc/test/test_utils.h" | ||
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" | ||
|
||
namespace tflite { | ||
namespace task { | ||
namespace vision { | ||
namespace { | ||
|
||
using ::tflite::support::StatusOr; | ||
using ::tflite::task::JoinPath; | ||
using ::tflite::task::core::TfLiteEngine; | ||
|
||
|
||
constexpr char kTestDataDirectory[] = | ||
"/tensorflow_lite_support/cc/test/testdata/task/" | ||
"vision/"; | ||
|
||
constexpr char kESRGANModelWithInputAndOutputMetaData[] = "esrgan_with_input_and_output_metadata.tflite"; | ||
constexpr char kESRGANModelWithInputMetaData[] = "esrgan_with_input_metadata.tflite"; | ||
|
||
// Decodes the image named `image_name` from the test data directory.
StatusOr<ImageData> LoadImage(std::string image_name) {
  const std::string image_path =
      JoinPath("./" /*test src dir*/, kTestDataDirectory, image_name);
  return DecodeImageFromFile(image_path);
}
|
||
// Fixture for tests exercising ImageTransformer's post-processing paths.
class PostprocessorTest : public tflite_shims::testing::Test {};
|
||
// Verifies that a model carrying metadata for both its input and output
// tensors loads and runs end-to-end on a real image.
TEST_F(PostprocessorTest, FloatSucceedsWithFullMetadata) {
  SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
                               LoadImage("husky_downsampled.jpg"));

  // The FrameBuffer borrows rgb_image's pixel storage; rgb_image must stay
  // alive until inference completes.
  std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
      rgb_image.pixel_data,
      FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
  ImageTransformerOptions options;
  options.mutable_base_options()->mutable_model_file()->set_file_name(
      JoinPath("./" /*test src dir*/, kTestDataDirectory,
               kESRGANModelWithInputAndOutputMetaData));
  SUPPORT_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<ImageTransformer> image_transformer,
      ImageTransformer::CreateFromOptions(options));

  StatusOr<FrameBuffer> result_or = image_transformer->Transform(*frame_buffer);
  // Release the decoded input pixels before asserting on the result.
  ImageDataFree(&rgb_image);
  SUPPORT_ASSERT_OK(result_or);
}
|
||
// Verifies that a model with metadata for the input tensor only (no output
// metadata) still loads and runs end-to-end.
TEST_F(PostprocessorTest, FloatSucceedsWithPartialMetadata) {
  SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData rgb_image,
                               LoadImage("husky_downsampled.jpg"));

  // The FrameBuffer borrows rgb_image's pixel storage; rgb_image must stay
  // alive until inference completes.
  std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
      rgb_image.pixel_data,
      FrameBuffer::Dimension{rgb_image.width, rgb_image.height});
  ImageTransformerOptions options;
  options.mutable_base_options()->mutable_model_file()->set_file_name(
      JoinPath("./" /*test src dir*/, kTestDataDirectory,
               kESRGANModelWithInputMetaData));

  SUPPORT_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<ImageTransformer> image_transformer,
      ImageTransformer::CreateFromOptions(options));

  StatusOr<FrameBuffer> result_or = image_transformer->Transform(*frame_buffer);
  // Release the decoded input pixels before asserting on the result.
  ImageDataFree(&rgb_image);
  SUPPORT_ASSERT_OK(result_or);
}
|
||
// Fixture for end-to-end super-resolution quality tests.
class SuperResolutionTest : public tflite_shims::testing::Test {};
|
||
// Calculate the peak signal-to-noise ratio. | ||
// Original code: https://www.geeksforgeeks.org/python-peak-signal-to-noise-ratio-psnr/. | ||
double PSNR(const FrameBuffer& enhancedImage, const FrameBuffer& testImage) { | ||
int imageSize = testImage.dimension().width * testImage.dimension().height; | ||
const uint8* enhancedImagePtr = enhancedImage.plane(0).buffer; | ||
const uint8* testImagePtr = testImage.plane(0).buffer; | ||
double mse = 0.0; | ||
for (int i = 0; i < imageSize; ++i, ++enhancedImagePtr, ++testImagePtr) { | ||
mse += std::pow(static_cast<double>(*enhancedImagePtr) - static_cast<double>(*testImagePtr), 2); | ||
} | ||
mse /= imageSize; | ||
|
||
// Zero MSE means no noise is present in the signal. | ||
double psnr = mse == 0 ? 100.0 : 20 * std::log10(255.0 / std::sqrt(mse)); | ||
|
||
return psnr; | ||
} | ||
|
||
// Use a bi-cubically downsampled image as input to the model and compare
// the model output with the original image.
TEST_F(SuperResolutionTest, GoldenImageComparisonTest) {
  SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData husky_downsampled,
                               LoadImage("husky_downsampled.jpg"));
  SUPPORT_ASSERT_OK_AND_ASSIGN(ImageData husky_original,
                               LoadImage("husky_original.jpg"));

  // The FrameBuffers borrow the ImageData pixel storage, so the ImageData
  // objects must outlive them (they are freed only after the comparison).
  std::unique_ptr<FrameBuffer> husky_downsampled_buffer = CreateFromRgbRawBuffer(
      husky_downsampled.pixel_data,
      FrameBuffer::Dimension{husky_downsampled.width, husky_downsampled.height});

  std::unique_ptr<FrameBuffer> husky_original_buffer = CreateFromRgbRawBuffer(
      husky_original.pixel_data,
      FrameBuffer::Dimension{husky_original.width, husky_original.height});

  ImageTransformerOptions options;
  options.mutable_base_options()->mutable_model_file()->set_file_name(
      JoinPath("./" /*test src dir*/, kTestDataDirectory,
               kESRGANModelWithInputAndOutputMetaData));
  SUPPORT_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<ImageTransformer> image_transformer,
      ImageTransformer::CreateFromOptions(options));

  StatusOr<FrameBuffer> result_or =
      image_transformer->Transform(*husky_downsampled_buffer);
  SUPPORT_ASSERT_OK(result_or);
  // Exact floating-point equality on an accumulated PSNR is brittle across
  // platforms/compilers (FMA contraction, libm differences); compare against
  // the golden value with a small tolerance instead of EXPECT_DOUBLE_EQ.
  EXPECT_NEAR(PSNR(result_or.value(), *husky_original_buffer),
              25.073790631326489, 1e-4);
  ImageDataFree(&husky_downsampled);
  ImageDataFree(&husky_original);
}
|
||
} // namespace | ||
} // namespace vision
} // namespace task | ||
} // namespace tflite |
Binary file added
BIN
+4.76 MB
...ow_lite_support/cc/test/testdata/task/vision/esrgan_with_input_and_output_metadata.tflite
Binary file not shown.
Binary file added
BIN
+4.76 MB
tensorflow_lite_support/cc/test/testdata/task/vision/esrgan_with_input_metadata.tflite
Binary file not shown.
Binary file added
BIN
+1.44 KB
tensorflow_lite_support/cc/test/testdata/task/vision/husky_downsampled.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+7.94 KB
tensorflow_lite_support/cc/test/testdata/task/vision/husky_original.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Building on your idea of storing "husky_test.png" (the output I got from dumping the pipeline output into a file => reading it in an ipynb => saving it as a png), I thought of using that as the "golden image" to compare with
result_or.value()
. The problem is LoadImage,
due to which the pixels are a bit off. So I think the PSNR solution is the best stand-in solution at the moment.