feat: add implementation of Lambda regularization #216

Open
Neonkraft wants to merge 14 commits into main

Conversation

@Neonkraft (Contributor) commented on Feb 7, 2025:

  • Implemented Lambda-DARTS for DARTS, NB201 and TNB101.
  • Added tests.

@Neonkraft Neonkraft requested a review from abhash-er February 7, 2025 16:23
@Neonkraft Neonkraft changed the title from "feat: add initial implementation of Lambda regularization" to "feat: add implementation of Lambda regularization" on Feb 12, 2025
@@ -457,6 +458,9 @@ def _train_epoch( # noqa: C901
if isinstance(unwrapped_network, LayerAlignmentScoreSupport):
    unwrapped_network.update_layer_alignment_scores()

if isinstance(unwrapped_network, LambdaDARTSSupport):
    unwrapped_network.add_lambda_regularization(base_inputs, base_targets)
Collaborator commented on the diff:

You missed adding the criterion here, with its parameters.

Author replied:

Nice catch, thanks!
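
(For reference, the fix presumably just threads the loss criterion through this call; the exact final signature is an assumption here, not the merged code:)

    # assumed final form of the call, with the criterion passed along
    if isinstance(unwrapped_network, LambdaDARTSSupport):
        unwrapped_network.add_lambda_regularization(base_inputs, base_targets, criterion)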

@@ -34,6 +38,7 @@ def __init__(
    lora_toggler: LoRAToggler | None = None,
    is_arch_attention_enabled: bool = False,
    regularizer: Regularizer | None = None,
    lambda_regularizer: LambdaReg | None = None,
@abhash-er (Collaborator) commented on the diff, Feb 13, 2025:

I noticed that the Lambda-DARTS regularization is always turned on. I suppose you have written the function disable_lambda_darts(). You could disable it by default, and enable it only when a lambda_regularizer is passed (i.e., when it is not None).

@Neonkraft (author) replied:

LambdaReg has enabled=True by default. But LambdaReg is only instantiated when the user sets use_lambda_regularizer to True in the profile. Further, the user can configure the lambda_regularizer and disable it.

I found another bug when looking into this, actually. The code would crash when use_lambda_regularizer=False. I've fixed that.
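
(As a minimal sketch of the behaviour described above, with names assumed from this discussion rather than taken from the actual code:)

    from dataclasses import dataclass

    @dataclass
    class LambdaReg:
        enabled: bool = True       # enabled by default once instantiated
        epsilon: float = 1e-3      # assumed perturbation scale

    def build_lambda_regularizer(profile) -> LambdaReg | None:
        # Only instantiate when the profile explicitly opts in; downstream code
        # must tolerate a None regularizer (the crash mentioned above).
        if not getattr(profile, "use_lambda_regularizer", False):
            return None
        return LambdaReg(**getattr(profile, "lambda_regularizer_config", {}))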

@abhash-er (Collaborator) replied:

Yup, that's what I was talking about earlier. I was testing the case when use_lambda_regularizer=False, and found that the regularisation was still occurring.

@abhash-er (Collaborator) commented on Feb 13, 2025:

I'm happy with the workflow of the Lambda-DARTS regularization, except for one thing that is missing in our code. While calculating the forward and backward grads, they don't use softmax at all (they save the softmaxed params and restore them later), but we never do that. They do that so that there is no softmax backward in their grads, as far as I understand, right? But we still account for it in our code. Do you see the problem?

Apart from that, there are a few minor changes that need to be made.

@Neonkraft (author) commented on Feb 17, 2025:

> While calculating the forward and backward grads, they don't use softmax at all (they save the softmaxed params and restore them later), but we never do that. They do that so that there is no softmax backward in their grads, as far as I understand, right? But we still account for it in our code. Do you see the problem?

When they "save" the arch parameters, they also update their values to have the softmax normalized values:

    def softmax_arch_parameters(self):
        self._save_arch_parameters()
        for p in self._arch_parameters:
            p.data.copy_(F.softmax(p, dim=-1))
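
(For reference, the save/restore pair used here presumably looks roughly like the following; this is a sketch, not the authors' exact code:)

    def _save_arch_parameters(self):
        # keep a copy of the raw (pre-softmax) arch parameters
        self._saved_arch_parameters = [p.clone() for p in self._arch_parameters]

    def _restore_arch_parameters(self):
        # put the raw values back after the perturbed passes
        for i, p in enumerate(self._arch_parameters):
            p.data.copy_(self._saved_arch_parameters[i])
        del self._saved_arch_parameters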

Later, in the forward pass, they distinguish between the model step and the architect step as follows:

    def forward(self, input, updateType='alpha', pert=None):
        s0 = s1 = self.stem(input)
        self.weights['normal'] = []
        self.weights['reduce'] = []
        for i, cell in enumerate(self.cells):
            if cell.reduction:
                if updateType == 'weight':
                    weights = self.alphas_reduce.clone() # Don't have to softmax this because it's already been softmaxed
                else:
                    weights = F.softmax(self.alphas_reduce, dim=-1)
            else:
                if updateType == 'weight':
                    weights = self.alphas_normal.clone() # Here too!
                else:
                    weights = F.softmax(self.alphas_normal, dim=-1)
            if self.training:
                weights.retain_grad()
                self.weights['reduce' if cell.reduction else 'normal'].append(weights)
            if pert:
                weights = weights - pert[i]
            s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0),-1))
        return logits

It's not clear to me why they do it this way. As far as I can see, removing "updateType" and having the following code should be identical:

    for i, cell in enumerate(self.cells):
        if cell.reduction:
            weights = F.softmax(self.alphas_reduce, dim=-1)
        else:
            weights = F.softmax(self.alphas_normal, dim=-1)
        if self.training:
            weights.retain_grad()
            self.weights['reduce' if cell.reduction else 'normal'].append(weights)
        if pert:
            weights = weights - pert[i]

This ultimately boils down to the implementation of Equations 12 and 13 in the paper. We need the gradients of the arch weights for each cell to calculate the perturbations. These perturbations are then applied in the additional forward and backward passes, and the result is used to calculate the lambda regularization terms, which are applied directly to the parameters of the model.
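
(To make the loop concrete, here is a rough sketch of what I mean; all names are illustrative, the cell-ordering bookkeeping is glossed over, and the exact direction and scale of the perturbations and of the correction come from Equations 12 and 13 rather than the stand-ins below:)

    import torch

    def weight_grads_with_pert(model, criterion, inputs, targets, perts):
        # Forward with the per-cell perturbations applied (the `pert` argument of
        # the authors' forward() above) and return d(loss)/d(model weights).
        loss = criterion(model(inputs, pert=perts), targets)
        return torch.autograd.grad(loss, list(model.parameters()), allow_unused=True)

    def lambda_regularization_step(model, criterion, inputs, targets, epsilon=1e-3):
        # 1) Ordinary forward/backward; retain_grad() in the forward pass makes the
        #    per-cell grads of the softmaxed arch weights available afterwards.
        criterion(model(inputs), targets).backward()
        per_cell_weights = model.weights['normal'] + model.weights['reduce']
        grads = [w.grad.detach() for w in per_cell_weights]

        # 2) Turn the per-cell gradients into per-cell perturbations.
        perts = [epsilon * g / (g.norm() + 1e-12) for g in grads]

        # 3) Two extra passes with +pert and -pert; the scaled difference of the
        #    resulting weight gradients gives the lambda regularization term, which
        #    is added directly to the model parameters' gradients before the
        #    optimizer step.
        g_plus = weight_grads_with_pert(model, criterion, inputs, targets, perts)
        g_minus = weight_grads_with_pert(model, criterion, inputs, targets,
                                         [-p for p in perts])
        for p, gp, gm in zip(model.parameters(), g_plus, g_minus):
            if p.grad is not None and gp is not None and gm is not None:
                p.grad.add_((gp - gm) / (2 * epsilon))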

Can you think of any reasoning for doing it the way the authors have, as opposed to the way given above?

@abhash-er (Collaborator) commented on Feb 17, 2025:

> Can you think of any reasoning for doing it the way the authors have, as opposed to the way given above?

As I understand it, when they save the softmaxed parameters, they make sure (at that point) that the softmax backward operation does not enter the computational graph. That's the only difference I can point out.

That means that, in their implementation, they only wanted to differentiate up to the softmaxed arch weights (sigmoid(arch)), and did not want the term (1 - sigmoid(arch)) (when the update type is weight)?

I agree that the formulation looks the same for us and for them. But the question remains: how does this one difference really affect the formulation of Equations 12 and 13?
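
(A tiny, self-contained illustration of that graph difference, independent of the DARTS code; the tensors and loss below are made up purely for demonstration:)

    import torch
    import torch.nn.functional as F

    alpha = torch.randn(3, requires_grad=True)
    c = torch.tensor([1.0, 2.0, 3.0])  # arbitrary downstream coefficients

    # (a) softmax stays inside the graph: the gradient w.r.t. alpha is routed
    #     through the softmax Jacobian.
    grad_a = torch.autograd.grad((F.softmax(alpha, dim=-1) * c).sum(), alpha)[0]

    # (b) the save-and-overwrite trick: the leaf tensor already *holds* the softmaxed
    #     values, so the gradient is taken at that point, with no softmax backward.
    beta = alpha.detach().clone().requires_grad_(True)
    beta.data.copy_(F.softmax(alpha, dim=-1))
    grad_b = torch.autograd.grad((beta * c).sum(), beta)[0]

    print(grad_a)  # Jacobian-weighted gradient
    print(grad_b)  # tensor([1., 2., 3.]) -- just c; the softmax backward is gone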

@Neonkraft (author) replied:

> and did not want the term (1 - sigmoid(arch)) (when the update type is weight)?

I'm not sure what you mean by this. Can you explain?

@abhash-er (Collaborator) commented on Feb 17, 2025:

> and did not want the term (1 - sigmoid(arch)) (when the update type is weight)?

> I'm not sure what you mean by this. Can you explain?

I meant the gradient term of the sigmoid: for an input x, d(sigmoid(x)) = sigmoid(x) * (1 - sigmoid(x)) dx. So with their code, they are only considering the first term.

@Neonkraft (author) replied:

The sigma in the equation denotes the softmax operation, not the sigmoid activation function.

@abhash-er (Collaborator) replied:

> The sigma in the equation denotes the softmax operation, not the sigmoid activation function.

Yeah, but the expression remains the same for softmax as well.
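
(For what it's worth, the diagonal of the softmax Jacobian does have the same form as the sigmoid derivative; with $\sigma_i = \mathrm{softmax}(x)_i$,

    \frac{\partial \sigma_i}{\partial x_j} = \sigma_i \, (\delta_{ij} - \sigma_j),

so the $i = j$ terms are $\sigma_i (1 - \sigma_i)$, while the off-diagonal terms are $-\sigma_i \sigma_j$. That Jacobian factor is exactly what stays in the graph when the softmaxed parameters are not saved and restored.)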
