fix get_lora_target_names function #2167
base: master
@@ -202,7 +204,7 @@ def enable_lora(self, rank, target_names=None):
         of the attention layers.
         """
         if target_names is None:
We don't need this `if` anymore!
keras_hub/src/models/backbone.py
Outdated
@@ -202,7 +204,7 @@ def enable_lora(self, rank, target_names=None):
         of the attention layers.
         """
         if target_names is None:
-            target_names = self.get_lora_target_names()
+            target_names = self.get_lora_target_names(target_names=None)
`get_lora_target_names()` still won't work correctly in the following code; it will return the default values:
model.enable_lora(rank=2, target_names=["query"])
model.get_lora_target_names()
I don't think we should expect the user to pass the target names in order to get the target names. If they already know what the target names are, they won't need to call `get_lora_target_names()`.
Maybe it'd be good to define a new member variable like `lora_target_names`, define a setter/getter for it, and initialize it with the default values.
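To make the concern concrete, here is a minimal sketch using a toy stand-in class. `ToyBackbone` is hypothetical and is not the real `keras_hub` backbone; it only mirrors the shape of the code in the diff, where the getter returns hard-coded defaults and never sees what was passed to `enable_lora`.

```python
class ToyBackbone:
    """Hypothetical stand-in for the backbone; not the keras_hub class."""

    def get_lora_target_names(self):
        # Hard-coded defaults, independent of any earlier enable_lora call.
        return ["query_dense", "value_dense", "query", "value"]

    def enable_lora(self, rank, target_names=None):
        if target_names is None:
            target_names = self.get_lora_target_names()
        # The real method would wrap the matching layers; this toy version
        # only returns the names it resolved.
        return list(target_names)


model = ToyBackbone()
enabled = model.enable_lora(rank=2, target_names=["query"])
reported = model.get_lora_target_names()

print(enabled)   # ['query']
print(reported)  # ['query_dense', 'value_dense', 'query', 'value']
```

The getter reports the defaults even though only `"query"` was requested, which is the mismatch described above.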
done
Thanks, Divya! Left some nit comments!
@@ -203,6 +211,8 @@ def enable_lora(self, rank, target_names=None):
         """
         if target_names is None:
             target_names = self.get_lora_target_names()
+        else:
+            self._lora_target_names = target_names
We can use the setter here.
Let's avoid the setter, so that we avoid statefulness. `backbone.enable_lora(4)` should always mean the model's default targets no matter what the user has done before, and `backbone.enable_lora(4, target_layer_names=["query", "key", "value"])` should mean a custom set of selections.
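A rough sketch of the stateless behavior requested here, again on a hypothetical toy class rather than the real backbone. The names `target_layer_names` and `default_lora_layer_names` are taken from the review suggestions in this thread; everything else is an assumption for illustration.

```python
class StatelessToyBackbone:
    """Hypothetical stand-in; resolves LoRA targets per call, stores nothing."""

    def default_lora_layer_names(self):
        return ["query_dense", "value_dense", "query", "value"]

    def enable_lora(self, rank, target_layer_names=None):
        # Defaults are looked up on every call, so an earlier custom
        # selection cannot leak into a later default call.
        if target_layer_names is None:
            target_layer_names = self.default_lora_layer_names()
        return list(target_layer_names)


backbone = StatelessToyBackbone()
custom = backbone.enable_lora(4, target_layer_names=["query", "key", "value"])
default_after_custom = backbone.enable_lora(4)
```

Because nothing is written to `self`, the second call resolves to the model defaults regardless of the first call.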
+    def set_lora_target_names(self, target_names):
+        """Set the list of layer names which are to be LoRA-fied.
+        """
It'd be good to check that `target_names` is a list, so that it errors out if someone calls `set_lora_target_names("query")`.
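One possible shape for such a check, as a sketch only. `check_target_names` is a hypothetical helper, and the choice of `TypeError` and the message wording are assumptions, not something the PR specifies.

```python
def check_target_names(target_names):
    """Hypothetical helper: accept only a list of string layer names."""
    if not isinstance(target_names, list):
        raise TypeError(
            "`target_names` must be a list of layer names. "
            f"Received: {target_names!r}"
        )
    if not all(isinstance(name, str) for name in target_names):
        raise TypeError(
            "Every entry in `target_names` must be a string. "
            f"Received: {target_names!r}"
        )
    return target_names


try:
    check_target_names("query")  # a bare string, not a list
    rejected = False
except TypeError:
    rejected = True
```

Without a check like this, a bare string would be iterated character by character wherever the names are later looped over.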
+    def set_lora_target_names(self, target_names):
+        """Set the list of layer names which are to be LoRA-fied.
+        """
+        self._lora_target_names = target_names
This assigns the `target_names` reference to `self._lora_target_names`, so if `target_names` is changed outside of this setter, it'll affect `self._lora_target_names`. It'd be better to do something like `self._lora_target_names = list(target_names)` to copy the elements.
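The aliasing effect can be seen in a small standalone sketch. Both holder classes below are hypothetical and exist only to contrast storing the reference with storing a copy.

```python
class AliasingHolder:
    def set_lora_target_names(self, target_names):
        # Stores the caller's list object itself.
        self._lora_target_names = target_names


class CopyingHolder:
    def set_lora_target_names(self, target_names):
        # Stores a new list with the same elements.
        self._lora_target_names = list(target_names)


names = ["query"]
aliased = AliasingHolder()
copied = CopyingHolder()
aliased.set_lora_target_names(names)
copied.set_lora_target_names(names)

# The caller mutates its own list after the setter has run.
names.append("value")

print(aliased._lora_target_names)  # ['query', 'value']
print(copied._lora_target_names)   # ['query']
```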
-        return ["query_dense", "value_dense", "query", "value"]
+        return self._lora_target_names
+
+    def set_lora_target_names(self, target_names):
We should not have a setter like this. In fact, `get_lora_target_names` is really only for model implementations and should probably be renamed `default_lora_layer_names` in this CL (do a find and replace to clean this up everywhere).
@@ -50,6 +50,12 @@ def __init__(self, *args, dtype=None, **kwargs):
             id(layer) for layer in self._flatten_layers()
         )
         self._initialized = True
+        self._lora_target_names = [
We can remove this attr and move this back to `get_default_lora_targets`.
+    def set_lora_target_names(self, target_names):
+        """Set the list of layer names which are to be LoRA-fied.
+        """
+        self._lora_target_names = target_names
+
     def enable_lora(self, rank, target_names=None):
I'd rename this to `target_layer_names`; it's a little more descriptive.
We should also document the args to `enable_lora` here, since this is user facing. In the description for `target_layer_names`, explain that `None` will populate it with the model's default target names, as returned by `backbone.default_lora_layer_names()`.
@@ -188,11 +194,13 @@ def save_to_preset(self, preset_dir):

     def get_lora_target_names(self):
Let's rename this everywhere to `default_lora_layer_names`. Potentially we should make this a property too, without a setter.
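As a sketch of the property idea, under the assumption that the defaults stay hard-coded per model: a Python `property` defined without a setter is read-only, so assigning to it raises `AttributeError`. `PropertyToyBackbone` is hypothetical and not the real backbone class.

```python
class PropertyToyBackbone:
    """Hypothetical stand-in exposing the defaults as a read-only property."""

    @property
    def default_lora_layer_names(self):
        # A fresh list on every access, so callers cannot mutate shared state.
        return ["query_dense", "value_dense", "query", "value"]


backbone = PropertyToyBackbone()
defaults = backbone.default_lora_layer_names

try:
    backbone.default_lora_layer_names = ["query"]
    assignment_blocked = False
except AttributeError:
    # No setter is defined, so plain assignment is rejected.
    assignment_blocked = True
```

A subclass for a specific model could override the property to return its own layer names, which keeps the value an implementation detail rather than user-settable state.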
@divyashreepathihalli are you still working on this?
This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
No description provided.