-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Add save/load to Embeddings #8818
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Adds save/load functionality to the Embeddings class, enabling persistence of embeddings indices to disk for fast loading without recomputing embeddings.
- Implements
save()
,load()
, andfrom_saved()
methods for the Embeddings class - Adds comprehensive test coverage for save/load functionality including error handling
- Removes the TODO comment about adding save/load methods
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
dspy/retrievers/embeddings.py | Implements save/load methods with pickle for config, numpy for embeddings, and FAISS index persistence |
tests/retrievers/test_embeddings.py | Adds comprehensive tests for save/load functionality and error cases |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
self.index = None | ||
|
||
# Reinitialize the search function | ||
self.search_fn = Unbatchify(self._batch_forward) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to reinitialize the search_fn?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup, this is because from_saved
bypasses __init__
# but we can still save the embeddings for brute force search | ||
pass | ||
|
||
def load(self, path: str, embedder): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do we want this load
to be called? Do we want users to first create an Embedding
instance then call embedding.load()
, or make it a class method that return a loaded embedding?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is actually only for consistency with other APIs like module.load. I guess mostly people will use from_saved. Do you think we should make load a classmethod?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gotcha, I am asking because of the line self.search_fn = Unbatchify(self._batch_forward)
, if we do:
embedder = dspy.Embeddings(...)
embedder.load(...)
do we still need this self.search_fn = Unbatchify(self._batch_forward)
?
LGTM after #8818 (comment) is resolved, thank you! |
Closes #8807