Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Owlv2 Performance #1065

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open

Fix Owlv2 Performance #1065

wants to merge 20 commits into from

Conversation

balthazur
Copy link
Contributor

@balthazur balthazur commented Mar 7, 2025

Description

  • Remove compilation in background as it failed (Fix for Inference Internal)
  • Improve SerializeOwlV2 class to keep reference of base class to enable (fix for slow serialization/training)

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

How has this change been tested, please provide a testcase or example of how you tested the change?

locally and on staging

Any specific deployment considerations

For example, documentation changes, usability, usage/costs, secrets, etc.

Docs

  • Docs updated? What were the changes:

@balthazur balthazur changed the title Draft: Investigage Owlv2 performance Fix Owlv2 Performance Mar 10, 2025
@balthazur balthazur marked this pull request as ready for review March 10, 2025 13:01
@balthazur balthazur self-assigned this Mar 10, 2025
Copy link
Contributor

@isaacrob-roboflow isaacrob-roboflow left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would love to better understand why we have both an owlv2 singleton and a reference dict, as it seems like those are intended to serve similar functions and the fact that both are necessary seems like I'm either missing something or there's a hidden bug

Comment on lines +714 to +717
# Cache of OwlV2 instances to avoid creating new ones for each serialize_training_data call
# This improves performance by reusing model instances across serialization operations
_base_owlv2_instances = {}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems like this dict serves a redundant purpose with the singleton .. might be simpler to maintain this long term if we either take out the singleton or fix the singleton such that this dict doesn't have to exist? or am I misunderstanding what's happening here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This dict is only for class SerializedOwlV2, which is another wrapper around the class OwlV2, and has the classmethod serialize_training_data. This function creates an OWLv2 Instance like owlv2 = OwlV2(model_id=roboflow_id) every time the function is called (everytime a training job comes in), which makes it create a new Instance every time without using the Singleton.

I'll try to think of a better way, but thought it would be fine because its a wrapper class for the serialization progress.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def save_embeddings(training_data, savedirprefix, previous_embeddings_file=None):
    from inference.models.owlv2.owlv2 import SerializedOwlV2
    from inference.core.env import OWLV2_VERSION_ID
    from adapters import logging_adapter

    save_dir=f"/tmp/{savedirprefix}/embeddings"
    os.makedirs(save_dir, exist_ok=True)

    total_images = len(training_data)
    logging_adapter.info(f"Starting embedding generation for {total_images} images")
    
    start_time = time.time()
    embeddings_pt = SerializedOwlV2.serialize_training_data(
        training_data=training_data,
        hf_id=f"google/{OWLV2_VERSION_ID}",
        save_dir=save_dir,
        previous_embeddings_file=previous_embeddings_file
    )
    

This function above calls the serialize_training_data every time, and was creating a new OWLv2 Instance with different context and callstack everytime.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@isaacrob-roboflow Any suggestions how to unify the Singleton with the dict in Inference?

@balthazur
Copy link
Contributor Author

@isaacrob-roboflow I investigated a bit more and tried a few different options. I created a short summary using ChatGPT to bring my messy notes in order:

The Problem

The issue is that the Owlv2Singleton class uses a weakref.WeakValueDictionary() to store model instances. This means that when there's no strong reference to the singleton object, it can be garbage collected.

When serialize_training_data() is called multiple times (for different training jobs), each call creates an OWLv2 instance like:

owlv2 = OwlV2(model_id=roboflow_id)

After the method completes, this instance is no longer referenced and gets garbage collected, along with its caches and potentially the singleton model if nothing else holds a reference to it.

Why Our Solution Works

The class dictionary _base_owlv2_instances solves this by maintaining strong references to the OWLv2 instances:

if roboflow_id in cls._base_owlv2_instances:
    owlv2 = cls._base_owlv2_instances[roboflow_id]
else:
    owlv2 = OwlV2(model_id=roboflow_id)
    cls._base_owlv2_instances[roboflow_id] = owlv2

This ensures that between calls to serialize_training_data(), we reuse the same OWLv2 instance with its caches intact, improving performance significantly.

Why Not Just Fix the Singleton?

We could change the singleton to use a regular dictionary instead of a weak dictionary, and that would help maintain the model. However, there's still an issue:

  • The singleton only holds the heavy model (from Hugging Face).
  • The important caches (image embeddings, etc.) are stored on the OWLv2 instance itself, not in the singleton.
  • We need to maintain a reference to the complete OWLv2 instance to keep those caches.

Current Implementation

The current implementation is straightforward and effective:

  • SerializedOwlV2._base_owlv2_instances maintains OWLv2 instances by model ID.
  • This is specifically used for the serialization service, not affecting the base OWLv2 class's general inference usage.
  • It maintains a persistent OWLv2 instance with all its caches across multiple calls.
  • While the dictionary structure is simple (typically just containing one entry like {'owlv2/owlv2-large-patch14-ensemble': <OwlV2 object>}), using a dictionary gives us flexibility to handle different OWLv2 versions if needed in the future.

For now, this is a reasonable solution that solves the issue without requiring major refactoring. The dictionary is simple but effective at maintaining state between calls.

If you'd like to refactor this later to unify the caching approach, that would be great, but the current implementation works well for our immediate needs. Also, if you have ideas for a better fix please feel free! Happy to apply them over mine.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants