diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 80248a69..0fc9bd0e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -24,6 +24,7 @@ repos:
     rev: v1.1.1
     hooks:
       - id: mypy
+        exclude: ^tests/
         additional_dependencies:
           - "attrs"
           - "types-pyyaml"
diff --git a/examples/demo_light.py b/examples/demo_light.py
index e1c85142..30287bd9 100644
--- a/examples/demo_light.py
+++ b/examples/demo_light.py
@@ -1,21 +1,15 @@
 from __future__ import annotations
 
-from confopt.profile import GDASProfile
+from confopt.profile import LambdaDARTSProfile
 from confopt.train import Experiment
 from confopt.enums import SearchSpaceType, DatasetType
 
 if __name__ == "__main__":
-    profile = GDASProfile(
-        searchspace=SearchSpaceType.DARTS,
-        epochs=3,
-        early_stopper="skip_connection"
-    )
-    profile.configure_early_stopper(
-        max_skip_normal = 0,
-        max_skip_reduce = 0,
-        min_epochs = 1,
-        count_discrete = False,
+    profile = LambdaDARTSProfile(
+        searchspace=SearchSpaceType.DARTS, epochs=3,
     )
+    profile.configure_searchspace(C=1)
+
     experiment = Experiment(
         search_space=SearchSpaceType.DARTS,
         dataset=DatasetType.CIFAR10,
diff --git a/src/confopt/enums.py b/src/confopt/enums.py
index dfce624c..446de013 100644
--- a/src/confopt/enums.py
+++ b/src/confopt/enums.py
@@ -19,6 +19,7 @@ class SamplerType(Enum):
     DARTS = "darts"
     DRNAS = "drnas"
     GDAS = "gdas"
+    LAMBDADARTS = "lambda_darts"
     SNAS = "snas"
     REINMAX = "reinmax"
 
diff --git a/src/confopt/profile/__init__.py b/src/confopt/profile/__init__.py
index ca3892ec..52ef6010 100644
--- a/src/confopt/profile/__init__.py
+++ b/src/confopt/profile/__init__.py
@@ -4,6 +4,7 @@
     DiscreteProfile,
     DRNASProfile,
     GDASProfile,
+    LambdaDARTSProfile,
     ReinMaxProfile,
     SNASProfile,
 )
@@ -11,6 +12,7 @@
 __all__ = [
     "BaseProfile",
     "DARTSProfile",
+    "LambdaDARTSProfile",
     "GDASProfile",
     "DRNASProfile",
     "SNASProfile",
diff --git a/src/confopt/profile/profiles.py b/src/confopt/profile/profiles.py
index d78404e9..05a272a0 100644
--- a/src/confopt/profile/profiles.py
+++ b/src/confopt/profile/profiles.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 from abc import ABC
+from dataclasses import asdict
 from typing import Any
 
 from confopt.enums import SamplerType, SearchSpaceType
+from confopt.searchspace.common import LambdaReg
 from confopt.searchspace.darts.core.genotypes import DARTSGenotype
 from confopt.utils import get_num_classes
 
@@ -34,6 +36,37 @@ def _initialize_sampler_config(self) -> None:
         self.sampler_config = darts_config  # type: ignore
 
 
+class LambdaDARTSProfile(DARTSProfile):
+    SAMPLER_TYPE = SamplerType.LAMBDADARTS
+
+    def __init__(
+        self,
+        searchspace: str | SearchSpaceType,
+        epochs: int,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            searchspace,
+            epochs,
+            **kwargs,
+        )
+        self.lambda_regularizer_config = asdict(LambdaReg())
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config["lambda_regularizer"] = self.lambda_regularizer_config
+
+        return config
+
+    def configure_lambda_regularizer(self, **kwargs: Any) -> None:
+        for config_key in kwargs:
+            assert config_key in self.lambda_regularizer_config, (
+                f"{config_key} is not a valid configuration for the "
+                "lambda regularization config"
+            )
+            self.lambda_regularizer_config[config_key] = kwargs[config_key]
+
+
 class GDASProfile(BaseProfile, ABC):
     SAMPLER_TYPE = SamplerType.GDAS
 
diff --git a/src/confopt/searchspace/common/__init__.py b/src/confopt/searchspace/common/__init__.py
index b5640101..7a31b817 100644
--- a/src/confopt/searchspace/common/__init__.py
+++
b/src/confopt/searchspace/common/__init__.py @@ -6,6 +6,8 @@ GradientMatchingScoreSupport, GradientStatsSupport, InsertCellSupport, + LambdaDARTSSupport, + LambdaReg, LayerAlignmentScoreSupport, OperationStatisticsSupport, PerturbationArchSelectionSupport, @@ -26,6 +28,8 @@ "FLOPSRegTermSupport", "GradientMatchingScoreSupport", "GradientStatsSupport", + "LambdaReg", + "LambdaDARTSSupport", "LayerAlignmentScoreSupport", "OperationStatisticsSupport", "ArchAttentionSupport", diff --git a/src/confopt/searchspace/common/base_search.py b/src/confopt/searchspace/common/base_search.py index ad801be0..9bcc3cf0 100644 --- a/src/confopt/searchspace/common/base_search.py +++ b/src/confopt/searchspace/common/base_search.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Literal import numpy as np @@ -168,6 +169,156 @@ def reset_gm_scores(self) -> None: module.running_sim.reset() +@dataclass +class LambdaReg: + epsilon_base: float = 0.001 + epsilon: float = 0.0 + corr_type: str = "corr" + strength: float = 0.125 + enabled: bool = True + + +class LambdaDARTSSupport(ModelWrapper): + def __init__(self, model: nn.Module): + super().__init__(model) + self._assert_model_has_implementation() + self.lambda_reg = LambdaReg() + + def get_cells(self, cell_type: str | None = None) -> list[torch.nn.Module] | None: + return self.model.get_cells(cell_type) + + def _assert_model_has_implementation(self) -> None: + base_error = "LambdaDARTSSupport implementation missing" + + def assert_is_function(fn_name: str) -> None: + assert hasattr( + self.model, fn_name + ), f"{base_error}: {fn_name} method not found in {type(self.model)}" + assert callable( + self.model.get_arch_grads + ), f"'{fn_name}' should be a method" + + assert_is_function("get_arch_grads") + assert_is_function("get_cells") + assert_is_function("set_lambda_perturbations") + + def set_lambda_darts_params(self, lambda_reg: LambdaReg) -> None: + self.lambda_reg = lambda_reg + + def enable_lambda_darts(self) -> None: + self.lambda_reg.enabled = True + + def disable_lambda_darts(self) -> None: + self.lambda_reg.enabled = False + + def get_perturbations(self) -> list[torch.Tensor]: + grads_normal, grads_reduce = self.model.get_arch_grads() + alpha_normal = self.arch_parameters[0] + + def get_perturbation_for_cell( + layer_gradients: list[torch.Tensor], + ) -> list[torch.Tensor]: + with torch.no_grad(): + weight = 1 / ((len(layer_gradients) * (len(layer_gradients) - 1)) / 2) + if self.lambda_reg.corr_type == "corr": + u = [g / g.norm(p=2.0) for g in layer_gradients] + sum_u = sum(u) + identity_matrix = torch.eye(sum_u.shape[0], device=sum_u.device) + P = [ + (1 / g.norm(p=2.0)) * (identity_matrix - torch.ger(u_l, u_l)) + for g, u_l in zip(layer_gradients, u) + ] + perturbations = [ + weight * (P_l @ sum_u).reshape(alpha_normal.shape) for P_l in P + ] + elif self.lambda_reg.corr_type == "signcorr": + perturbations = [] + for i in range(len(layer_gradients)): + _dir: torch.Tensor = 0 + for j in range(len(layer_gradients)): + if i == j: + continue + g, g_ = layer_gradients[i], layer_gradients[j] + dot, abs_dot = torch.dot(g, g_), torch.dot( + torch.abs(g), torch.abs(g_) + ) + _dir += ( + ( + torch.ones_like(g_) + - (dot / abs_dot) * torch.sign(g) * torch.sign(g_) + ) + * g_ + / abs_dot + ) + perturbations.append(weight * _dir.reshape(alpha_normal.shape)) + return perturbations + + pert_normal = get_perturbation_for_cell(grads_normal) + 
pert_reduce = ( + get_perturbation_for_cell(grads_reduce) + if grads_reduce is not None + else None + ) + pert_denom = ( + pert_normal + pert_reduce if pert_reduce is not None else pert_normal + ) + + self.lambda_reg.epsilon = ( + self.lambda_reg.epsilon_base + / torch.cat(pert_denom, dim=0).norm(p=2.0).item() + ) + + idx_normal = 0 + idx_reduce = 0 + pert = [] + + cells = self.get_cells() + + if cells is not None: + for cell in cells: + if pert_reduce is not None and cell.reduction: + pert.append(pert_reduce[idx_reduce] * self.lambda_reg.epsilon) + idx_reduce += 1 + else: + pert.append(pert_normal[idx_normal] * self.lambda_reg.epsilon) + idx_normal += 1 + + return pert + + def add_lambda_regularization( + self, data: torch.Tensor, target: torch.Tensor, criterion: nn.modules.loss._Loss + ) -> None: + if not self.lambda_reg.enabled: + return + + pert = self.get_perturbations() + + loss_fn = criterion + # Calculate forward and backward gradients to compute finite difference + self.model.set_lambda_perturbations(pert) + forward_grads = torch.autograd.grad( + loss_fn(self.model(data)[1], target), + self.model_weight_parameters(), + allow_unused=True, + ) + self.model.set_lambda_perturbations([-p for p in pert]) + backward_grads = torch.autograd.grad( + loss_fn(self.model(data)[1], target), + self.model_weight_parameters(), + allow_unused=True, + ) + + reg_grad = [ + (f - b).div_(2 * self.lambda_reg.epsilon) + if (f is not None and b is not None) + else 0.0 + for f, b in zip(forward_grads, backward_grads) + ] + for param, grad in zip(self.model_weight_parameters(), reg_grad): + if param.grad is not None: + param.grad.data.add_(self.lambda_reg.strength * grad) + + class LayerAlignmentScoreSupport(ModelWrapper): def __init__(self, model: nn.Module): super().__init__(model) diff --git a/src/confopt/searchspace/darts/core/model_search.py b/src/confopt/searchspace/darts/core/model_search.py index c385ac10..5b751bfd 100644 --- a/src/confopt/searchspace/darts/core/model_search.py +++ b/src/confopt/searchspace/darts/core/model_search.py @@ -357,10 +357,25 @@ def __init__( self.multihead_attention = nn.MultiheadAttention( embed_dim=len(self.primitives), num_heads=1 ) - + self.lambda_perturbations = None # mask for pruning self._initialize_parameters() + def get_cells(self, cell_type: str | None) -> torch.nn.Module | None: + assert cell_type in [ + "normal", + "reduce", + None, + ], f"Illegal cell type: {cell_type}" + if cell_type == "normal": + cells = [cell for cell in self.cells if not cell.is_reduction_cell] + elif cell_type == "reduce": + cells = [cell for cell in self.cells if cell.is_reduction_cell] + else: + cells = self.cells + + return cells + def new(self) -> Network: """Get a new object with same arch and beta parameters. 
@@ -460,8 +475,13 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: weights = weights_normal.clone() self.save_weight_grads(weights, cell_type="normal") + if self.lambda_perturbations is not None: + weights = weights - self.lambda_perturbations[_i] + s0, s1 = s1, cell(s0, s1, weights) + self.lambda_perturbations = None + out = self.global_pooling(s1) logits = self.classifier(out.view(out.size(0), -1)) return torch.squeeze(out, dim=(-1, -2)), logits @@ -981,6 +1001,9 @@ def get_weighted_flops(self) -> torch.Tensor: flops += torch.log(total_cell_flops) return flops / len(self.cells) + def set_lambda_perturbations(self, lambda_perturbations: torch.Tensor) -> None: + self.lambda_perturbations = lambda_perturbations + def preserve_grads(m: nn.Module) -> None: ignored_modules = ( diff --git a/src/confopt/searchspace/darts/supernet.py b/src/confopt/searchspace/darts/supernet.py index ca573bf5..cf62ce8e 100644 --- a/src/confopt/searchspace/darts/supernet.py +++ b/src/confopt/searchspace/darts/supernet.py @@ -14,6 +14,7 @@ GradientMatchingScoreSupport, GradientStatsSupport, InsertCellSupport, + LambdaDARTSSupport, LayerAlignmentScoreSupport, OperationStatisticsSupport, PerturbationArchSelectionSupport, @@ -34,6 +35,7 @@ class DARTSSearchSpace( ArchAttentionSupport, GradientMatchingScoreSupport, OperationStatisticsSupport, + LambdaDARTSSupport, LayerAlignmentScoreSupport, DrNASRegTermSupport, FLOPSRegTermSupport, diff --git a/src/confopt/searchspace/nb1shot1/core/model_search.py b/src/confopt/searchspace/nb1shot1/core/model_search.py index bbfee10d..bd821442 100644 --- a/src/confopt/searchspace/nb1shot1/core/model_search.py +++ b/src/confopt/searchspace/nb1shot1/core/model_search.py @@ -657,7 +657,9 @@ def save_weight_grads( grad_hook = weights.register_hook(self.save_gradient()) self.grad_hook_handlers.append(grad_hook) - def get_arch_grads(self, only_first_and_last: bool = False) -> list[torch.Tensor]: + def get_arch_grads( + self, only_first_and_last: bool = False + ) -> tuple[list[torch.Tensor], list[torch.Tensor] | None]: grads = [] if only_first_and_last: grads.append(self.weights_grad[0].reshape(-1)) @@ -666,12 +668,12 @@ def get_arch_grads(self, only_first_and_last: bool = False) -> list[torch.Tensor for alphas in self.weights_grad: grads.append(alphas.reshape(-1)) - return grads + return grads, None def get_mean_layer_alignment_score( self, only_first_and_last: bool = False ) -> float: - grads = self.get_arch_grads(only_first_and_last) + grads, _ = self.get_arch_grads(only_first_and_last) mean_score = calc_layer_alignment_score(grads) if math.isnan(mean_score): diff --git a/src/confopt/searchspace/nb201/core/model_search.py b/src/confopt/searchspace/nb201/core/model_search.py index edb86989..dce28efe 100644 --- a/src/confopt/searchspace/nb201/core/model_search.py +++ b/src/confopt/searchspace/nb201/core/model_search.py @@ -151,6 +151,7 @@ def __init__( self.multihead_attention = nn.MultiheadAttention( embed_dim=len(self.op_names), num_heads=1 ) + self.lambda_perturbations = None def get_weights(self) -> list[nn.Parameter]: """Get a list of learnable parameters in the model. 
(does not include alpha or
@@ -211,6 +212,14 @@ def show_betas(self) -> str:
             nn.functional.softmax(self.beta_parameters, dim=-1).cpu()
         )
 
+    def get_cells(self, cell_type: str | None) -> torch.nn.Module | None:
+        assert (
+            cell_type == "normal" or cell_type is None
+        ), f"Illegal cell type: {cell_type}"
+        cells = [cell for cell in self.cells if isinstance(cell, SearchCell)]
+
+        return cells
+
     def get_message(self) -> str:
         """Gets a message describing the model and its cells.
 
@@ -331,14 +340,22 @@ def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
             self.sampled_weights = [weights]
 
         feature = self.stem(inputs)
-        for _i, cell in enumerate(self.cells):
+        _i = 0
+        for cell in self.cells:
             if isinstance(cell, SearchCell):
                 alphas = weights.clone()
                 self.save_weight_grads(alphas)
+
+                if self.lambda_perturbations is not None:
+                    alphas = alphas - self.lambda_perturbations[_i]
+
                 feature = cell(feature, alphas)
+                _i += 1
             else:
                 feature = cell(feature)
 
+        self.lambda_perturbations = None
+
         out = self.lastact(feature)
         out = self.global_pooling(out)
         out = out.view(out.size(0), -1)
@@ -382,7 +399,9 @@ def edge_normalization_forward(
 
         return out, logits
 
-    def get_arch_grads(self, only_first_and_last: bool = False) -> list[torch.Tensor]:
+    def get_arch_grads(
+        self, only_first_and_last: bool = False
+    ) -> tuple[list[torch.Tensor], list[torch.Tensor] | None]:
         grads = []
         if only_first_and_last:
             grads.append(self.weights_grad[0].reshape(-1))
@@ -391,12 +410,12 @@
             for alphas in self.weights_grad:
                 grads.append(alphas.reshape(-1))
 
-        return grads
+        return grads, None
 
     def get_mean_layer_alignment_score(
         self, only_first_and_last: bool = False
     ) -> float:
-        grads = self.get_arch_grads(only_first_and_last)
+        grads, _ = self.get_arch_grads(only_first_and_last)
         mean_score = calc_layer_alignment_score(grads)
 
         if math.isnan(mean_score):
@@ -503,6 +522,9 @@ def get_weighted_flops(self) -> torch.Tensor:
             flops += torch.log(total_cell_flops)
         return flops / len(self.cells)
 
+    def set_lambda_perturbations(self, lambda_perturbations: torch.Tensor) -> None:
+        self.lambda_perturbations = lambda_perturbations
+
 
 def preserve_grads(m: nn.Module) -> None:
     ignored_modules = (
diff --git a/src/confopt/searchspace/nb201/supernet.py b/src/confopt/searchspace/nb201/supernet.py
index 55602c48..17702402 100644
--- a/src/confopt/searchspace/nb201/supernet.py
+++ b/src/confopt/searchspace/nb201/supernet.py
@@ -13,6 +13,7 @@
     FLOPSRegTermSupport,
     GradientMatchingScoreSupport,
     GradientStatsSupport,
+    LambdaDARTSSupport,
     LayerAlignmentScoreSupport,
     OperationStatisticsSupport,
     PerturbationArchSelectionSupport,
@@ -35,6 +36,7 @@ class NASBench201SearchSpace(
     ArchAttentionSupport,
     GradientMatchingScoreSupport,
     OperationStatisticsSupport,
+    LambdaDARTSSupport,
     LayerAlignmentScoreSupport,
     DrNASRegTermSupport,
     PerturbationArchSelectionSupport,
diff --git a/src/confopt/searchspace/tnb101/core/model_search.py b/src/confopt/searchspace/tnb101/core/model_search.py
index 1571f826..c73b383d 100644
--- a/src/confopt/searchspace/tnb101/core/model_search.py
+++ b/src/confopt/searchspace/tnb101/core/model_search.py
@@ -60,7 +60,6 @@ def __init__(
         self.n_modules = 5
         self.blocks_per_module = [2] * self.n_modules
 
-        # initialize other arguments for intializing a new model
         self.affine = affine
         self.track_running_stats = track_running_stats
         self.dataset = dataset
@@ -93,15 +92,7 @@ def __init__(
             self.cells.append(cell)
             C_in = C_out
         self.num_edge =
len(self.cells[0].edges) - - if dataset == "jigsaw": - self.num_classes = 1000 - elif dataset == "class_object": - self.num_classes = 100 - elif dataset == "class_scene": - self.num_classes = 63 - else: - self.num_classes = num_classes + self.num_classes = self.get_num_classes_from_dataset(num_classes, dataset) self.stem = self._get_stem_for_task(dataset) self.decoder = self._get_decoder_for_task(dataset, C_out) @@ -133,8 +124,19 @@ def __init__( self.num_ops = len(self.op_names) self.num_nodes = self.max_nodes - 1 + self.lambda_perturbations = None self._initialize_projection_params() + def get_num_classes_from_dataset(self, num_classes: int, dataset: str) -> int: + if dataset == "jigsaw": + num_classes = 1000 + elif dataset == "class_object": + num_classes = 100 + elif dataset == "class_scene": + num_classes = 63 + + return num_classes + def arch_parameters(self) -> nn.Parameter: return self._arch_parameters @@ -179,9 +181,13 @@ def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: feature = self.stem(inputs) - for cell in self.cells: + for _i, cell in enumerate(self.cells): weights = alphas.clone() self.save_weight_grads(weights) + + if self.lambda_perturbations is not None: + weights = weights - self.lambda_perturbations[_i] + feature = cell(feature, weights) out = self.decoder(feature) @@ -399,7 +405,9 @@ def save_weight_grads( grad_hook = weights.register_hook(self.save_gradient()) self.grad_hook_handlers.append(grad_hook) - def get_arch_grads(self, only_first_and_last: bool = False) -> list[torch.Tensor]: + def get_arch_grads( + self, only_first_and_last: bool = False + ) -> tuple[list[torch.Tensor], list[torch.Tensor] | None]: grads = [] if only_first_and_last: grads.append(self.weights_grad[0].reshape(-1)) @@ -408,12 +416,20 @@ def get_arch_grads(self, only_first_and_last: bool = False) -> list[torch.Tensor for alphas in self.weights_grad: grads.append(alphas.reshape(-1)) - return grads + return grads, None + + def get_cells(self, cell_type: str | None) -> torch.nn.Module | None: + assert ( + cell_type == "normal" or cell_type is None + ), f"Illegal cell type: {cell_type}" + cells = [cell for cell in self.cells if isinstance(cell, TNB101SearchCell)] + + return cells def get_mean_layer_alignment_score( self, only_first_and_last: bool = False ) -> float: - grads = self.get_arch_grads(only_first_and_last) + grads, _ = self.get_arch_grads(only_first_and_last) mean_score = calc_layer_alignment_score(grads) if math.isnan(mean_score): @@ -468,6 +484,9 @@ def remove_from_projected_weights( ### End of PerturbationArchSelection methods ### + def set_lambda_perturbations(self, lambda_perturbations: torch.Tensor) -> None: + self.lambda_perturbations = lambda_perturbations + class TNB101SearchCell(nn.Module): expansion = 1 diff --git a/src/confopt/searchspace/tnb101/supernet.py b/src/confopt/searchspace/tnb101/supernet.py index 8978a3d6..24ff213f 100644 --- a/src/confopt/searchspace/tnb101/supernet.py +++ b/src/confopt/searchspace/tnb101/supernet.py @@ -13,6 +13,7 @@ FLOPSRegTermSupport, GradientMatchingScoreSupport, GradientStatsSupport, + LambdaDARTSSupport, LayerAlignmentScoreSupport, OperationStatisticsSupport, PerturbationArchSelectionSupport, @@ -35,6 +36,7 @@ class TransNASBench101SearchSpace( GradientStatsSupport, FLOPSRegTermSupport, OperationStatisticsSupport, + LambdaDARTSSupport, LayerAlignmentScoreSupport, FairDARTSRegTermSupport, DrNASRegTermSupport, diff --git a/src/confopt/train/configurable_trainer.py b/src/confopt/train/configurable_trainer.py index 
c569a474..6778f5e3 100644 --- a/src/confopt/train/configurable_trainer.py +++ b/src/confopt/train/configurable_trainer.py @@ -19,6 +19,7 @@ from confopt.searchspace.common.base_search import ( GradientMatchingScoreSupport, GradientStatsSupport, + LambdaDARTSSupport, LayerAlignmentScoreSupport, OperationStatisticsSupport, ) @@ -400,17 +401,17 @@ def _train_epoch( # noqa: C901 oles_frequency: int = 20, oles_threshold: float = 0.4, ) -> tuple[TrainingMetrics, TrainingMetrics]: - data_time, batch_time = AverageMeter(), AverageMeter() - base_losses, base_top1, base_top5 = ( - AverageMeter(), - AverageMeter(), - AverageMeter(), - ) - arch_losses, arch_top1, arch_top5 = ( - AverageMeter(), - AverageMeter(), - AverageMeter(), - ) + ( + data_time, + batch_time, + base_losses, + base_top1, + base_top5, + arch_losses, + arch_top1, + arch_top5, + ) = (AverageMeter() for _ in range(8)) + network.train() unwrapped_network = unwrap_model(network) end = time.time() @@ -482,6 +483,11 @@ def _train_epoch( # noqa: C901 if isinstance(unwrapped_network, LayerAlignmentScoreSupport): unwrapped_network.update_layer_alignment_scores() + if isinstance(unwrapped_network, LambdaDARTSSupport): + unwrapped_network.add_lambda_regularization( + base_inputs, base_targets, criterion + ) + torch.nn.utils.clip_grad_norm_( unwrapped_network.model_weight_parameters(), 5 ) @@ -512,18 +518,6 @@ def _train_epoch( # noqa: C901 end = time.time() if step % print_freq == 0 or step + 1 == len(train_loader): - # Tstr = f"Time {batch_time.val:.2f} ({batch_time.avg:.2f})" \ - # + f"Data {data_time.val:.2f} ({data_time.avg:.2f})" - - # Wstr = f"Base [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1" \ - # + f"{top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f}" \ - # + f"({top5.avg:.2f})]" - - # Astr = f"Arch [Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1" \ - # + f"{top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f}" \ - # + f"({top5.avg:.2f})]" - - # logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr) ... 
if self.debug_mode and step > DEBUG_STEPS: diff --git a/src/confopt/train/experiment.py b/src/confopt/train/experiment.py index 740a7716..81893d85 100644 --- a/src/confopt/train/experiment.py +++ b/src/confopt/train/experiment.py @@ -65,6 +65,7 @@ SearchSpace, TransNASBench101SearchSpace, ) +from confopt.searchspace.common import LambdaReg from confopt.train import ConfigurableTrainer, DiscreteTrainer from confopt.train.projection import PerturbationArchSelection from confopt.train.search_space_handler import SearchSpaceHandler @@ -259,6 +260,7 @@ def _init_components( self._set_lora_toggler(config.get("lora", {}), config.get("lora_extra", {})) self._set_weight_entangler() self._set_regularizer(config.get("regularization", {})) + self._set_lambda_regularizer(config.get("lambda_regularizer", {})) self._set_profile(config) self._set_early_stopper( config["early_stopper"], config.get("early_stopper_config", {}) @@ -314,7 +316,7 @@ def _set_sampler( ) -> None: arch_params = self.search_space.arch_parameters self.sampler: BaseSampler | None = None - if sampler == SamplerType.DARTS: + if sampler in (SamplerType.DARTS, SamplerType.LAMBDADARTS): self.sampler = DARTSSampler(**config, arch_parameters=arch_params) elif sampler == SamplerType.DRNAS: self.sampler = DRNASSampler(**config, arch_parameters=arch_params) @@ -402,6 +404,11 @@ def _set_regularizer(self, config: dict) -> None: loss_weight=config["loss_weight"], ) + def _set_lambda_regularizer(self, config: dict) -> None: + self.lambda_regularizer = ( + None if (config is None or len(config) == 0) else LambdaReg(**config) + ) + def _set_profile(self, config: dict) -> None: assert self.sampler is not None @@ -417,6 +424,7 @@ def _set_profile(self, config: dict) -> None: pruner=self.pruner, is_arch_attention_enabled=config.get("is_arch_attention_enabled", False), regularizer=self.regularizer, + lambda_regularizer=self.lambda_regularizer, use_auxiliary_skip_connection=config.get( "use_auxiliary_skip_connection", False ), diff --git a/src/confopt/train/search_space_handler.py b/src/confopt/train/search_space_handler.py index e720589e..d96de7ab 100644 --- a/src/confopt/train/search_space_handler.py +++ b/src/confopt/train/search_space_handler.py @@ -12,12 +12,16 @@ from confopt.oneshot.weightentangler import WeightEntangler from confopt.searchspace import DARTSSearchSpace from confopt.searchspace.common import ( + LambdaReg, LoRALayer, OperationBlock, OperationChoices, SearchSpace, ) -from confopt.searchspace.common.base_search import ArchAttentionSupport +from confopt.searchspace.common.base_search import ( + ArchAttentionSupport, + LambdaDARTSSupport, +) class SearchSpaceHandler: @@ -34,6 +38,7 @@ def __init__( lora_toggler: LoRAToggler | None = None, is_arch_attention_enabled: bool = False, regularizer: Regularizer | None = None, + lambda_regularizer: LambdaReg | None = None, use_auxiliary_skip_connection: bool = False, ) -> None: self.sampler = sampler @@ -46,6 +51,7 @@ def __init__( self.pruner = pruner self.lora_toggler = lora_toggler self.regularizer = regularizer + self.lambda_regularizer = lambda_regularizer self.is_argmax_sampler = False if isinstance(self.sampler, GDASSampler): @@ -87,6 +93,14 @@ def adapt_search_space(self, search_space: SearchSpace) -> None: ): search_space.set_arch_attention(True) + if ( + isinstance(search_space, LambdaDARTSSupport) + and isinstance(self.lambda_regularizer, LambdaReg) + and self.lambda_regularizer.enabled + ): + search_space.enable_lambda_darts() + 
search_space.set_lambda_darts_params(self.lambda_regularizer) + def perturb_parameter(self, search_space: SearchSpace) -> None: if self.perturbation is not None: self.perturbation._perturb_and_update_alphas() diff --git a/tests/test_lambda_darts.py b/tests/test_lambda_darts.py new file mode 100644 index 00000000..5b06ba11 --- /dev/null +++ b/tests/test_lambda_darts.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import unittest +import torch + +from confopt.searchspace.common import LambdaReg +from confopt.searchspace import ( + DARTSSearchSpace, + NASBench201SearchSpace, + TransNASBench101SearchSpace, + SearchSpace, +) + +DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + +class TestLambda(unittest.TestCase): + + def _forward_and_backward_pass(self, model: SearchSpace) -> SearchSpace: + x = torch.randn(2, 3, 32, 32).to(DEVICE) + y = torch.randint(10, (2,)).to(DEVICE) + criterion = torch.nn.CrossEntropyLoss() + + _, preds = model(x) + + loss = criterion(preds, y) + loss.backward() + + return model + + def test_shapes_darts(self) -> None: + search_space = DARTSSearchSpace().to(DEVICE) + search_space = self._forward_and_backward_pass(search_space) + self._test_shapes(search_space, (6, 2), (14, 8)) + + def test_shapes_nb201(self) -> None: + search_space = NASBench201SearchSpace().to(DEVICE) + search_space = self._forward_and_backward_pass(search_space) + self._test_shapes(search_space, (15, 0), (6, 5)) + + def test_shapes_tnb101(self) -> None: + search_space = TransNASBench101SearchSpace().to(DEVICE) + search_space = self._forward_and_backward_pass(search_space) + self._test_shapes(search_space, (10, 0), (6, 4)) + + def test_perturbations_disabled_darts(self) -> None: + search_space = DARTSSearchSpace().to(DEVICE) + self._test_perturbations_disabled(search_space) + + def test_perturbations_disabled_nb201(self) -> None: + search_space = NASBench201SearchSpace().to(DEVICE) + self._test_perturbations_disabled(search_space) + + def test_perturbations_disabled_tnb101(self) -> None: + search_space = TransNASBench101SearchSpace().to(DEVICE) + self._test_perturbations_disabled(search_space) + + def test_perturbations_enabled_darts(self) -> None: + search_space = DARTSSearchSpace().to(DEVICE) + self._test_perturbations_enabled(search_space) + + def test_perturbations_enabled_nb201(self) -> None: + search_space = NASBench201SearchSpace().to(DEVICE) + self._test_perturbations_enabled(search_space) + + def test_perturbations_enabled_tnb101(self) -> None: + search_space = TransNASBench101SearchSpace().to(DEVICE) + self._test_perturbations_enabled(search_space) + + def _test_shapes( + self, search_space: SearchSpace, n_cells: tuple, grads_shape: tuple + ) -> None: + grads_normal, grads_reduce = search_space.model.get_arch_grads() + has_reduce = n_cells[1] > 0 + + def assert_shape_correct(grads: list, n_cells: int, grads_shape: tuple) -> None: + assert isinstance(grads, list) + assert len(grads) == n_cells + + for grad in grads: + assert grad.shape == (grads_shape[0] * grads_shape[1],) + + assert_shape_correct(grads_normal, n_cells[0], grads_shape) + if has_reduce: + assert_shape_correct(grads_reduce, n_cells[1], grads_shape) + else: + assert grads_reduce is None + + perts = search_space.get_perturbations() + + for p in perts: + assert p.shape == grads_shape + + def _get_grads(self, model: SearchSpace) -> list[torch.Tensor]: + grads = [] + for p in model.model_weight_parameters(): + if p.grad is not None: + grads.append(p.grad.clone()) + + return grads + + 
def _test_perturbations_disabled(self, model: SearchSpace) -> None:
+        x = torch.randn(2, 3, 32, 32).to(DEVICE)
+        y = torch.randint(10, (2,)).to(DEVICE)
+
+        model.disable_lambda_darts()
+
+        criterion = torch.nn.CrossEntropyLoss()
+        _, preds = model(x)
+
+        loss = criterion(preds, y)
+        loss.backward()
+
+        old_grads = self._get_grads(model)
+        model.add_lambda_regularization(x, y, criterion)
+        new_grads = self._get_grads(model)
+
+        for new_grad, old_grad in zip(new_grads, old_grads):
+            assert (new_grad == old_grad).all()
+
+    def _test_perturbations_enabled(self, model: SearchSpace) -> None:
+        x = torch.randn(2, 3, 32, 32).to(DEVICE)
+        y = torch.randint(10, (2,)).to(DEVICE)
+
+        criterion = torch.nn.CrossEntropyLoss()
+        _, preds = model(x)
+
+        loss = criterion(preds, y)
+        loss.backward()
+
+        model.lambda_reg.enabled = True
+
+        old_grads = self._get_grads(model)
+        model.add_lambda_regularization(x, y, criterion)
+        new_grads = self._get_grads(model)
+
+        # The regularizer modifies the weight gradients in place.
+        assert any(
+            not torch.equal(new, old) for new, old in zip(new_grads, old_grads)
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_profiles.py b/tests/test_profiles.py
index 856bae06..6d9c2145 100644
--- a/tests/test_profiles.py
+++ b/tests/test_profiles.py
@@ -1,12 +1,15 @@
+from dataclasses import asdict
 import unittest
 
 from confopt.profile import (
     BaseProfile,
     DARTSProfile,
+    LambdaDARTSProfile,
     DRNASProfile,
     GDASProfile,
     SNASProfile,
 )
+from confopt.searchspace.common.base_search import LambdaReg
 
 
 class TestBaseProfile(unittest.TestCase):
@@ -167,7 +170,22 @@ def test_sampler_post_fn(self) -> None:
             profile.sampler_config["arch_combine_fn"]
             == sampler_config["arch_combine_fn"]
         )
 
+
+class TestLambdaDARTSProfile(unittest.TestCase):
+    def test_initialization(self) -> None:
+        profile = LambdaDARTSProfile(
+            epochs=100,
+            searchspace="nb201",
+            sampler_sample_frequency="step",
+        )
+
+        config = profile.get_config()
+        assert "lambda_regularizer" in config
+        assert config["lambda_regularizer"] == asdict(LambdaReg())
+        assert config["lambda_regularizer"]["enabled"] is True
+
+
 class TestDRNASProfile(unittest.TestCase):
     def test_initialization(self) -> None:
         perturb_config = {"epsilon": 0.5}
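
A minimal usage sketch of the new profile, assuming only what this patch introduces (LambdaDARTSProfile, configure_lambda_regularizer, and the LambdaReg defaults); the chosen values are illustrative, not a recommendation:

from confopt.enums import SearchSpaceType
from confopt.profile import LambdaDARTSProfile

# LAMBDADARTS reuses the DARTS sampler; the profile additionally carries a
# "lambda_regularizer" section built from the LambdaReg dataclass defaults.
profile = LambdaDARTSProfile(searchspace=SearchSpaceType.DARTS, epochs=3)

# Only keys that exist on LambdaReg are accepted:
# epsilon_base, epsilon, corr_type, strength, enabled.
profile.configure_lambda_regularizer(strength=0.25, corr_type="corr")

config = profile.get_config()
assert config["lambda_regularizer"]["strength"] == 0.25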