feat: add implementation of Lambda regularization #216
base: main
Conversation
Neonkraft commented on Feb 7, 2025 (edited)
- Implemented Lambda-DARTS for DARTS, NB201 and TNB101.
- Added tests.
@@ -457,6 +458,9 @@ def _train_epoch(  # noqa: C901
        if isinstance(unwrapped_network, LayerAlignmentScoreSupport):
            unwrapped_network.update_layer_alignment_scores()

        if isinstance(unwrapped_network, LambdaDARTSSupport):
            unwrapped_network.add_lambda_regularization(base_inputs, base_targets)
You missed passing the criterion here along with its parameters.
Nice catch, thanks!
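For reference, the fix amounts to threading the criterion through the call from the diff above. A sketch, assuming a `criterion` object is available in `_train_epoch`; the exact signature is whatever `add_lambda_regularization` ends up taking:

```python
# Sketch of the fix: pass the training criterion through to the lambda regularization
# hook. unwrapped_network, base_inputs, base_targets come from the diff above;
# the extra criterion argument is the point of the review comment.
if isinstance(unwrapped_network, LambdaDARTSSupport):
    unwrapped_network.add_lambda_regularization(base_inputs, base_targets, criterion)
```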
@@ -34,6 +38,7 @@ def __init__(
        lora_toggler: LoRAToggler | None = None,
        is_arch_attention_enabled: bool = False,
        regularizer: Regularizer | None = None,
        lambda_regularizer: LambdaReg | None = None,
I noticed that the lambda DARTS regularization is always turned on. I suppose you have written the function disable_lambda_darts(); you could disable it by default, and then, when accepting a lambda_regularizer, enable it if it is not None.
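Something along these lines (a sketch of the suggestion; `disable_lambda_darts()` is the method mentioned above, `enable_lambda_darts()` is assumed to exist as its counterpart):

```python
# Sketch of the suggested behaviour: lambda regularization stays off unless a
# regularizer is actually supplied to the constructor.
self.lambda_regularizer = lambda_regularizer
if lambda_regularizer is not None:
    self.enable_lambda_darts()
else:
    self.disable_lambda_darts()
```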
LambdaReg has enabled=True by default, but LambdaReg is only instantiated when the user sets use_lambda_regularizer to True in the profile. Further, the user can configure the lambda_regularizer and disable it.

I found another bug when looking into this, actually: the code would crash when use_lambda_regularizer=False. I've fixed that.
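To make that concrete, the behaviour is roughly the following. Only `LambdaReg`, `enabled`, and `use_lambda_regularizer` are names from this discussion; the helper and the second profile key are illustrative:

```python
# Illustrative sketch of the behaviour described above, not the PR's exact code.
from dataclasses import dataclass


@dataclass
class LambdaReg:
    enabled: bool = True  # on by default once instantiated


def build_lambda_regularizer(profile: dict) -> LambdaReg | None:
    """Hypothetical helper: a LambdaReg only exists if the profile opts in."""
    if not profile.get("use_lambda_regularizer", False):
        return None
    # The user can still configure the regularizer and disable it explicitly.
    return LambdaReg(enabled=profile.get("lambda_regularizer_enabled", True))
```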
Yup, that's what I was talking about earlier. I was testing the case when use_lambda_regularizer=False, and found that the regularisation was still occurring.
I'm happy with the workflow of the lambda DARTS, except for one thing that is missing in our code. While calculating the forward and backward gradients, they don't use the softmax at all (they save the softmaxed parameters and restore them later), but we never do that. They do that so that a softmax backward does not end up in their gradients, as far as I understand, right? But we still account for it in our code. Do you see the problem? Apart from that, there are some minor changes that need to be made.
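To make the distinction concrete, here is a toy example (unrelated to the actual search space) of gradients taken with and without the softmax backward in the chain:

```python
# Gradient w.r.t. the already-softmaxed weights (no softmax Jacobian in the chain)
# versus gradient w.r.t. the raw alphas (softmax Jacobian included).
import torch
import torch.nn.functional as F

alpha = torch.randn(3, requires_grad=True)
v = torch.tensor([1.0, 2.0, 3.0])

# (a) Treat the softmaxed weights as the leaf: softmax backward is NOT applied.
w = F.softmax(alpha, dim=-1).detach().requires_grad_()
(w * v).sum().backward()
print(w.grad)      # just v

# (b) Differentiate through the softmax back to the raw alphas.
(F.softmax(alpha, dim=-1) * v).sum().backward()
print(alpha.grad)  # v multiplied by the softmax Jacobian
```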
When they "save" the arch parameters, they also overwrite their values with the softmax-normalized values:
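Roughly, the pattern is the following (paraphrased, not the exact snippet; accessor names may differ):

```python
# Paraphrased sketch of the save/softmax/restore pattern being discussed (not verbatim).
import torch
import torch.nn.functional as F


def softmax_arch_parameters(self) -> None:
    # Keep a copy of the raw alphas, then overwrite them in place with their
    # softmax-normalized values. Because this happens under no_grad, the softmax
    # itself is not recorded in the computational graph.
    self._saved_arch_parameters = [p.detach().clone() for p in self.arch_parameters()]
    with torch.no_grad():
        for p in self.arch_parameters():
            p.copy_(F.softmax(p, dim=-1))


def restore_arch_parameters(self) -> None:
    # Put the raw (pre-softmax) alphas back after the extra forward/backward passes.
    with torch.no_grad():
        for saved, p in zip(self._saved_arch_parameters, self.arch_parameters()):
            p.copy_(saved)
```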
Later, in the forward pass, they discriminate between the model step and the architect step as follows:
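Schematically (again paraphrased, not their exact code; `_forward_with_weights` is a stand-in for however the network consumes the edge weights):

```python
# Paraphrased sketch of the forward pass being described (illustrative only).
import torch.nn.functional as F


def forward(self, x, update_type: str = "alpha"):
    if update_type == "weight":
        # Model step: the alphas were already overwritten with their softmaxed
        # values by softmax_arch_parameters(), so they are used as-is and no
        # softmax backward appears in the graph.
        weights = list(self.arch_parameters())
    else:
        # Architect step: softmax the raw alphas inside the graph.
        weights = [F.softmax(p, dim=-1) for p in self.arch_parameters()]
    return self._forward_with_weights(x, weights)
```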
It's not clear to me why they do it this way. As far as I can see, removing "updateType" and having the following code should be identical:
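i.e. something like this (sketch, same stand-in helper as above):

```python
# Sketch of the simplification: always softmax inside forward and drop the
# update_type flag entirely.
import torch.nn.functional as F


def forward(self, x):
    weights = [F.softmax(p, dim=-1) for p in self.arch_parameters()]
    return self._forward_with_weights(x, weights)
```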
This ultimately boils down to the implementation of Equations 12 and 13 in the paper. We need the gradients of the arch weights for each cell to calculate the perturbations. These perturbations will be applied in the additional forward and backward passes, and the result from that will be used to calculate the lambda regularization terms, which will be applied directly to the parameters of the model. Can you think of any reason for doing it the way the authors have, as opposed to the way given above?
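For concreteness, the overall flow I mean is roughly the following. All names are illustrative, and this sketches the general perturb-and-finite-difference scheme rather than the paper's exact Equations 12 and 13:

```python
# Illustrative sketch of the flow described above; not the PR's or the paper's exact code.
# Assumes every selected parameter participates in the loss and that
# model.arch_parameters() returns the per-cell architecture weights.
import torch


def apply_lambda_regularization(model, criterion, inputs, targets,
                                epsilon: float = 1e-3, strength: float = 0.25) -> None:
    alphas = list(model.arch_parameters())
    alpha_ids = {id(a) for a in alphas}
    params = [p for p in model.parameters() if p.requires_grad and id(p) not in alpha_ids]

    # 1) Gradients of the arch weights for each cell; these define the perturbation direction.
    loss = criterion(model(inputs), targets)
    alpha_grads = torch.autograd.grad(loss, alphas)

    def perturb(sign: float) -> None:
        with torch.no_grad():
            for a, g in zip(alphas, alpha_grads):
                a.add_(sign * epsilon * g / (g.norm() + 1e-8))

    # 2) Additional forward/backward passes at alpha + eps*d and alpha - eps*d.
    perturb(+1.0)
    grads_pos = torch.autograd.grad(criterion(model(inputs), targets), params)
    perturb(-2.0)
    grads_neg = torch.autograd.grad(criterion(model(inputs), targets), params)
    perturb(+1.0)  # restore the original alphas

    # 3) Finite-difference regularization term added directly to the model parameters' grads.
    for p, gp, gn in zip(params, grads_pos, grads_neg):
        reg = strength * (gp - gn) / (2.0 * epsilon)
        p.grad = reg if p.grad is None else p.grad + reg
```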
As I understand it, when they save the softmaxed parameters, they make sure (at that point) that the softmax backward operation does not enter the computational graph. That's the only difference I can point out. That means that, in their implementation, they only wanted to consider the chain up to the softmaxed arch weights. I agree that the formulation looks the same for both us and them, but the question remains: how does this one difference really affect the formulation of Equations 12 and 13?
I'm not sure what you mean by this. Can you explain?
I meant that the gradient of the sigmoid for an input x is d(sigmoid(x))/dx = sigmoid(x) * (1 - sigmoid(x)). So with their code, they are only considering the first term.
The sigma in the equation denotes the softmax operation, not the sigmoid activation function.
Yeah, but the expression remains the same for the softmax as well.
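For reference, the standard softmax Jacobian has the analogous form, with the sigmoid-like term on the diagonal:

```math
\frac{\partial\,\sigma(x)_i}{\partial x_j} \;=\; \sigma(x)_i\,\bigl(\delta_{ij} - \sigma(x)_j\bigr),
\qquad\text{so the diagonal entries are }\;\sigma(x)_i\,\bigl(1 - \sigma(x)_i\bigr).
```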