
feat: add implementation of Lambda regularization #216

Open · wants to merge 14 commits into main

1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -24,6 +24,7 @@ repos:
rev: v1.1.1
hooks:
- id: mypy
exclude: ^tests/
additional_dependencies:
- "attrs"
- "types-pyyaml"
16 changes: 5 additions & 11 deletions examples/demo_light.py
@@ -1,21 +1,15 @@
from __future__ import annotations

from confopt.profile import GDASProfile
from confopt.profile import DARTSProfile, LambdaDARTSProfile
from confopt.train import Experiment
from confopt.enums import SearchSpaceType, DatasetType

if __name__ == "__main__":
profile = GDASProfile(
searchspace=SearchSpaceType.DARTS,
epochs=3,
early_stopper="skip_connection"
)
profile.configure_early_stopper(
max_skip_normal = 0,
max_skip_reduce = 0,
min_epochs = 1,
count_discrete = False,
profile = LambdaDARTSProfile(
searchspace=SearchSpaceType.DARTS, epochs=3,
)
profile.configure_searchspace(C=1)

experiment = Experiment(
search_space=SearchSpaceType.DARTS,
dataset=DatasetType.CIFAR10,
1 change: 1 addition & 0 deletions src/confopt/enums.py
@@ -19,6 +19,7 @@ class SamplerType(Enum):
DARTS = "darts"
DRNAS = "drnas"
GDAS = "gdas"
LAMBDADARTS = "lambda_darts"
SNAS = "snas"
REINMAX = "reinmax"

2 changes: 2 additions & 0 deletions src/confopt/profile/__init__.py
@@ -4,13 +4,15 @@
DiscreteProfile,
DRNASProfile,
GDASProfile,
LambdaDARTSProfile,
ReinMaxProfile,
SNASProfile,
)

__all__ = [
"BaseProfile",
"DARTSProfile",
"LambdaDARTSProfile",
"GDASProfile",
"DRNASProfile",
"SNASProfile",
33 changes: 33 additions & 0 deletions src/confopt/profile/profiles.py
@@ -1,9 +1,11 @@
from __future__ import annotations

from abc import ABC
from dataclasses import asdict
from typing import Any

from confopt.enums import SamplerType, SearchSpaceType
from confopt.searchspace.common import LambdaReg
from confopt.searchspace.darts.core.genotypes import DARTSGenotype
from confopt.utils import get_num_classes

@@ -34,6 +36,37 @@ def _initialize_sampler_config(self) -> None:
self.sampler_config = darts_config # type: ignore


class LambdaDARTSProfile(DARTSProfile):
SAMPLER_TYPE = SamplerType.LAMBDADARTS

def __init__(
self,
searchspace: str | SearchSpaceType,
epochs: int,
**kwargs: Any,
) -> None:
super().__init__(
searchspace,
epochs,
**kwargs,
)
self.lambda_regularizer_config = asdict(LambdaReg())

def get_config(self) -> dict:
config = super().get_config()
config["lambda_regularizer"] = self.lambda_regularizer_config

return config

def configure_lambda_regularizer(self, **kwargs: Any) -> None:
for config_key in kwargs:
assert config_key in self.lambda_regularizer_config, (
f"{config_key} not a valid configuration for the"
+ "lambda regularization config"
)
self.lambda_regularizer_config[config_key] = kwargs[config_key]


class GDASProfile(BaseProfile, ABC):
SAMPLER_TYPE = SamplerType.GDAS

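A minimal usage sketch for the new profile (the override values below are illustrative, not the LambdaReg defaults; the imports follow examples/demo_light.py):

from confopt.profile import LambdaDARTSProfile
from confopt.enums import SearchSpaceType

profile = LambdaDARTSProfile(searchspace=SearchSpaceType.DARTS, epochs=3)

# Only field names of the LambdaReg dataclass are accepted; unknown keys trip the assert.
profile.configure_lambda_regularizer(strength=0.25, corr_type="signcorr")

# The regularizer settings are exported alongside the rest of the profile config.
print(profile.get_config()["lambda_regularizer"])
# expected: {'epsilon_base': 0.001, 'epsilon': 0.0, 'corr_type': 'signcorr', 'strength': 0.25, 'enabled': True}
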
4 changes: 4 additions & 0 deletions src/confopt/searchspace/common/__init__.py
@@ -6,6 +6,8 @@
GradientMatchingScoreSupport,
GradientStatsSupport,
InsertCellSupport,
LambdaDARTSSupport,
LambdaReg,
LayerAlignmentScoreSupport,
OperationStatisticsSupport,
PerturbationArchSelectionSupport,
@@ -26,6 +28,8 @@
"FLOPSRegTermSupport",
"GradientMatchingScoreSupport",
"GradientStatsSupport",
"LambdaReg",
"LambdaDARTSSupport",
"LayerAlignmentScoreSupport",
"OperationStatisticsSupport",
"ArchAttentionSupport",
151 changes: 151 additions & 0 deletions src/confopt/searchspace/common/base_search.py
@@ -1,6 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Literal

import numpy as np
@@ -168,6 +169,156 @@ def reset_gm_scores(self) -> None:
module.running_sim.reset()


@dataclass
class LambdaReg:
epsilon_base: float = 0.001
epsilon: float = 0.0
corr_type: str = "corr"
strength: float = 0.125
enabled: bool = True


class LambdaDARTSSupport(ModelWrapper):
def __init__(self, model: nn.Module):
super().__init__(model)
self._assert_model_has_implementation()
self.lambda_reg = LambdaReg()

def get_cells(self, cell_type: str | None = None) -> list[torch.nn.Module] | None:
return self.model.get_cells(cell_type)

def _assert_model_has_implementation(self) -> None:
base_error = "LambdaDARTSSupport implementation missing"

def assert_is_function(fn_name: str) -> None:
assert hasattr(
self.model, fn_name
), f"{base_error}: {fn_name} method not found in {type(self.model)}"
assert callable(
getattr(self.model, fn_name)
), f"'{fn_name}' should be a method"

assert_is_function("get_arch_grads")
assert_is_function("get_cells")
assert_is_function("set_lambda_perturbations")

def set_lambda_darts_params(self, lambda_reg: LambdaReg) -> None:
self.lambda_reg = lambda_reg

def enable_lambda_darts(self) -> None:
self.lambda_reg.enabled = True

def disable_lambda_darts(self) -> None:
self.lambda_reg.enabled = False

def get_perturbations(self) -> list[torch.Tensor]:
grads_normal, grads_reduce = self.model.get_arch_grads()
alpha_normal = self.arch_parameters[0]

def get_perturbation_for_cell(
layer_gradients: list[torch.Tensor],
) -> list[torch.Tensor]:
with torch.no_grad():
weight = 1 / ((len(layer_gradients) * (len(layer_gradients) - 1)) / 2)
if self.lambda_reg.corr_type == "corr":
u = [g / g.norm(p=2.0) for g in layer_gradients]
sum_u = sum(u)
identity_matrix = torch.eye(sum_u.shape[0], device=sum_u.device)
P = [
(1 / g.norm(p=2.0)) * (identity_matrix - torch.ger(u_l, u_l))
for g, u_l in zip(layer_gradients, u)
]
perturbations = [
weight * (P_l @ sum_u).reshape(alpha_normal.shape) for P_l in P
]
elif self.lambda_reg.corr_type == "signcorr":
perturbations = []
for i in range(len(layer_gradients)):
_dir = torch.zeros_like(layer_gradients[i])
for j in range(len(layer_gradients)):
if i == j:
continue
g, g_ = layer_gradients[i], layer_gradients[j]
dot, abs_dot = torch.dot(g, g_), torch.dot(
torch.abs(g), torch.abs(g_)
)
_dir += (
(
torch.ones_like(g_)
- (dot / abs_dot) * torch.sign(g) * torch.sign(g_)
)
* g_
/ abs_dot
)
perturbations.append(weight * _dir.reshape(alpha_normal.shape))
return perturbations

pert_normal = get_perturbation_for_cell(grads_normal)
pert_reduce = (
get_perturbation_for_cell(grads_reduce)
if grads_reduce is not None
else None
)
pert_denom = (
pert_normal + pert_reduce if pert_reduce is not None else pert_normal
)

self.lambda_reg.epsilon = (
self.lambda_reg.epsilon_base
/ torch.cat(pert_denom, dim=0).norm(p=2.0).item()
)

idx_normal = 0
idx_reduce = 0
pert = []

cells = self.get_cells()

if cells is not None:
for cell in cells:
if pert_reduce is not None and cell.reduction:
pert.append(pert_reduce[idx_reduce] * self.lambda_reg.epsilon)
idx_reduce += 1
else:
pert.append(pert_normal[idx_normal] * self.lambda_reg.epsilon)
idx_normal += 1

return pert

def add_lambda_regularization(
self, data: torch.Tensor, target: torch.Tensor, criterion: nn.modules.loss._Loss
) -> None:
if not self.lambda_reg.enabled:
return

pert = self.get_perturbations()

loss_fn = criterion
# Calculate forward and backward gradients to compute finite difference
self.model.set_lambda_perturbations(pert)
forward_grads = torch.autograd.grad(
loss_fn(self.model(data)[1], target),
self.model_weight_parameters(),
allow_unused=True,
)
self.model.set_lambda_perturbations([-p for p in pert])
backward_grads = torch.autograd.grad(
loss_fn(self.model(data)[1], target),
self.model_weight_parameters(),
allow_unused=True,
)

reg_grad = [
(f - b).div_(2 * self.lambda_reg.epsilon)
if (f is not None and b is not None)
else 0.0
for f, b in zip(forward_grads, backward_grads)
]
for param, grad in zip(self.model_weight_parameters(), reg_grad):
if param.grad is not None:
param.grad.data.add_(self.lambda_reg.strength * grad)


class LayerAlignmentScoreSupport(ModelWrapper):
def __init__(self, model: nn.Module):
super().__init__(model)
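For reference, a sketch of the math that get_perturbations and add_lambda_regularization implement for the default corr_type == "corr", reconstructed from the code above (an interpretation, not the authors' notation): with per-cell architecture gradients $g_l$, $u_l = g_l / \lVert g_l \rVert_2$, and $L$ cells of a given type,

$$\delta_l = \frac{2}{L(L-1)} \, \frac{1}{\lVert g_l \rVert_2} \left( I - u_l u_l^{\top} \right) \sum_{l'} u_{l'}, \qquad \epsilon = \frac{\epsilon_{\text{base}}}{\lVert \delta \rVert_2},$$

$$\nabla_w \mathcal{R} \approx \frac{\nabla_w \mathcal{L}(w, \alpha - \epsilon\delta) - \nabla_w \mathcal{L}(w, \alpha + \epsilon\delta)}{2\epsilon}, \qquad \nabla_w \mathcal{L} \leftarrow \nabla_w \mathcal{L} + \text{strength} \cdot \nabla_w \mathcal{R},$$

where $\delta$ is the concatenation of all per-cell perturbations and the sign convention follows the forward pass in model_search.py below, which subtracts the stored perturbation from the cell weights.
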
25 changes: 24 additions & 1 deletion src/confopt/searchspace/darts/core/model_search.py
@@ -357,10 +357,25 @@ def __init__(
self.multihead_attention = nn.MultiheadAttention(
embed_dim=len(self.primitives), num_heads=1
)

self.lambda_perturbations = None
# mask for pruning
self._initialize_parameters()

def get_cells(self, cell_type: str | None) -> list[torch.nn.Module] | None:
assert cell_type in [
"normal",
"reduce",
None,
], f"Illegal cell type: {cell_type}"
if cell_type == "normal":
cells = [cell for cell in self.cells if not cell.is_reduction_cell]
elif cell_type == "reduce":
cells = [cell for cell in self.cells if cell.is_reduction_cell]
else:
cells = self.cells

return cells

def new(self) -> Network:
"""Get a new object with same arch and beta parameters.

@@ -460,8 +475,13 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
weights = weights_normal.clone()
self.save_weight_grads(weights, cell_type="normal")

if self.lambda_perturbations is not None:
weights = weights - self.lambda_perturbations[_i]

s0, s1 = s1, cell(s0, s1, weights)

self.lambda_perturbations = None

out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0), -1))
return torch.squeeze(out, dim=(-1, -2)), logits
@@ -981,6 +1001,9 @@ def get_weighted_flops(self) -> torch.Tensor:
flops += torch.log(total_cell_flops)
return flops / len(self.cells)

def set_lambda_perturbations(self, lambda_perturbations: list[torch.Tensor]) -> None:
self.lambda_perturbations = lambda_perturbations


def preserve_grads(m: nn.Module) -> None:
ignored_modules = (
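To make the hand-off explicit, a hedged sketch of the contract the two new hooks establish (model, pert and x are placeholders for a Network instance, the per-cell perturbation list produced by LambdaDARTSSupport.get_perturbations, and an input batch):

model.set_lambda_perturbations(pert)   # store one perturbation tensor per cell
out, logits = model(x)                 # cell i runs on (weights - pert[i]); the stored list
                                       # is cleared at the end of forward, so the next pass
                                       # is unperturbed unless the setter is called again
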
2 changes: 2 additions & 0 deletions src/confopt/searchspace/darts/supernet.py
@@ -14,6 +14,7 @@
GradientMatchingScoreSupport,
GradientStatsSupport,
InsertCellSupport,
LambdaDARTSSupport,
LayerAlignmentScoreSupport,
OperationStatisticsSupport,
PerturbationArchSelectionSupport,
@@ -34,6 +35,7 @@ class DARTSSearchSpace(
ArchAttentionSupport,
GradientMatchingScoreSupport,
OperationStatisticsSupport,
LambdaDARTSSupport,
LayerAlignmentScoreSupport,
DrNASRegTermSupport,
FLOPSRegTermSupport,
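Since DARTSSearchSpace now inherits LambdaDARTSSupport, a trainer can configure the regularizer directly on the supernet. A sketch, assuming the class is importable from confopt.searchspace.darts.supernet and can be constructed with defaults (neither is shown in this diff):

from confopt.searchspace.common import LambdaReg
from confopt.searchspace.darts.supernet import DARTSSearchSpace  # assumed import path

supernet = DARTSSearchSpace()                                   # assumed default construction
supernet.set_lambda_darts_params(LambdaReg(strength=0.25))      # illustrative value
supernet.disable_lambda_darts()  # e.g. skip the extra grad computations during warm-up
supernet.enable_lambda_darts()
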
8 changes: 5 additions & 3 deletions src/confopt/searchspace/nb1shot1/core/model_search.py
@@ -657,7 +657,9 @@ def save_weight_grads(
grad_hook = weights.register_hook(self.save_gradient())
self.grad_hook_handlers.append(grad_hook)

def get_arch_grads(self, only_first_and_last: bool = False) -> list[torch.Tensor]:
def get_arch_grads(
self, only_first_and_last: bool = False
) -> tuple[list[torch.Tensor], list[torch.Tensor] | None]:
grads = []
if only_first_and_last:
grads.append(self.weights_grad[0].reshape(-1))
@@ -666,12 +668,12 @@ def get_arch_grads(self, only_first_and_last: bool = False) -> list[torch.Tensor
for alphas in self.weights_grad:
grads.append(alphas.reshape(-1))

return grads
return grads, None

def get_mean_layer_alignment_score(
self, only_first_and_last: bool = False
) -> float:
grads = self.get_arch_grads(only_first_and_last)
grads, _ = self.get_arch_grads(only_first_and_last)
mean_score = calc_layer_alignment_score(grads)

if math.isnan(mean_score):