Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perceptual Similarity loss #20844

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open

Conversation

tristan-deep
Copy link

@tristan-deep tristan-deep commented Feb 2, 2025

As described in #20839.

Conversion of weights partly based on an implementation found here.

Added

TODO

  • add to Metrics?
  • include tests
  • upload weights

I uploaded the LPIPS weights to Hugging Face, as I'm not sure how to upload to storage.googleapis.com.

Testing code snippet

Tests both the model and loss object and compares to the torch metrics implementation.

# !pip install torch torchmetrics
# !pip install scikit-image

import torch
from skimage import data
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from keras import ops
from keras.src.applications import lpips
from keras.src.losses import PerceptualSimilarity

## dummy images
image1 = data.camera()
image2 = data.gravel()

expected_lpips_value = 0.7237163782119751
print(f"Expected LPIPS value: {expected_lpips_value}")

image1 = image1[None, ..., None].astype("float32")
image2 = image2[None, ..., None].astype("float32")

image1 = ops.repeat(image1, 3, axis=-1)
image2 = ops.repeat(image2, 3, axis=-1)

# placeholder for weights
# https://huggingface.co/tristan-deep/lpips/blob/main/lpips_vgg16.weights.h5
custom_weight_path = <add path to weights here> 

## Keras model
model = lpips.LPIPS(weights=custom_weight_path)

model.summary()
print("Model loaded")

image1_tensor = lpips.preprocess_input(image1)
image2_tensor = lpips.preprocess_input(image2)

score = model([image1_tensor, image2_tensor])
print(f"Keras LPIPS score: {score}")

## Loss
loss = PerceptualSimilarity(weights=custom_weight_path)
value = loss(image1, image2)
print(f"Keras loss value: {value}")


## torchmetrics check
def preprocess_input_for_torch(x):
    # Convert from [0, 255] to [-1, 1]
    x = x / 127.5 - 1
    # Convert to numpy array
    x = ops.convert_to_numpy(x)
    # Rearrange from NHWC to NCHW format
    x = x.transpose(0, 3, 1, 2)
    # Convert to torch tensor
    x = torch.tensor(x)
    return x


image1_torch = preprocess_input_for_torch(image1)
image2_torch = preprocess_input_for_torch(image2)

# needs to be in range [-1, 1]
lpips = LearnedPerceptualImagePatchSimilarity(
    net_type="vgg",
    normalize=False,
)
score = lpips(image1_torch, image2_torch)

print(f"Torch LPIPS score: {score}")

Copy link

google-cla bot commented Feb 2, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@codecov-commenter
Copy link

codecov-commenter commented Feb 2, 2025

Codecov Report

Attention: Patch coverage is 27.39726% with 53 lines in your changes missing coverage. Please review.

Project coverage is 61.15%. Comparing base (fc1b26d) to head (4709571).
Report is 22 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/applications/lpips.py 24.61% 49 Missing ⚠️
keras/api/_tf_keras/keras/applications/__init__.py 0.00% 2 Missing ⚠️
...api/_tf_keras/keras/applications/lpips/__init__.py 0.00% 2 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (fc1b26d) and HEAD (4709571). Click for more details.

HEAD has 6 uploads less than BASE
Flag BASE (fc1b26d) HEAD (4709571)
keras 5 2
keras-torch 1 0
keras-tensorflow 1 0
keras-jax 1 0
Additional details and impacted files
@@             Coverage Diff             @@
##           master   #20844       +/-   ##
===========================================
- Coverage   82.04%   61.15%   -20.89%     
===========================================
  Files         559      564        +5     
  Lines       52367    52761      +394     
  Branches     8096     8154       +58     
===========================================
- Hits        42964    32268    -10696     
- Misses       7427    18522    +11095     
+ Partials     1976     1971        -5     
Flag Coverage Δ
keras 61.14% <27.39%> (-20.71%) ⬇️
keras-jax ?
keras-numpy 59.07% <27.39%> (+0.09%) ⬆️
keras-openvino 32.67% <27.39%> (+2.85%) ⬆️
keras-tensorflow ?
keras-torch ?
keras.applications 22.86% <24.61%> (?)
keras.applications-numpy 22.86% <24.61%> (?)
keras.applications-openvino 22.86% <24.61%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

from keras.src.models import Functional
from keras.src.utils import file_utils

WEIGHTS_PATH = (
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct weights path should still be added here. Currently I've put a placeholder. I uploaded the LPIPS weights to Hugging Face, as I'm not sure how to upload to storage.googleapis.com.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently added test will fail because it cannot find the weights at that location. Easy way to try out the tests locally is to change this line:

with

model = lpips.LPIPS(weights=<path_to_local_copy>)

@tristan-deep
Copy link
Author

@matt-gardner does it make more sense to leave LPIPS as an application and not include in the losses? That way people can always use it through a custom loss definition, without cluttering / redesigning the loss api for feature extraction losses.

@fchollet
Copy link
Collaborator

Thanks for the PR!

Due to the way this method works, it does not seem to be a good fit for the keras.loss API. In particular losses should not be importing Applications.

I think the best way to demonstrate this loss might be one of:

  • A standalone github repo factored as a reusable template showing how to use LPIPS with any model
  • A keras.io code example showing the same

The easiest way to leverage LPIPS might be to subclass a model and override the compute_loss() method.

@tristan-deep
Copy link
Author

@fchollet Thanks for the quick look! I agree that currently it doesn't really fit in the loss api and have removed it. Do you think LPIPS is still useful as an application within Keras? In that case, how do we handle the upload of the weights here?

As a follow up PR for keras.io, a code example could demonstrate it in a training loop for instance as you suggest.

@fchollet
Copy link
Collaborator

Do you think LPIPS is still useful as an application within Keras? In that case, how do we handle the upload of the weights here?

We aren't adding any new models to the keras.applications API at the moment, since all new models are going to keras_hub instead. So I recommend not adding it there.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants