From dfe0b55617fa6d1be2480c8fd86a70c6c99935e5 Mon Sep 17 00:00:00 2001 From: Muhammad Aniq Date: Tue, 16 Sep 2025 11:53:32 +0100 Subject: [PATCH 1/4] feat: add HuberEWCPlugin and PseudoAnnotationPlugin with tests and examples --- .../avalanche/training/plugins/__init__.py | 2 + .../avalanche/training/plugins/incdet_ewc.py | 97 +++++++++++++ .../training/plugins/pseudo_annotation.py | 95 +++++++++++++ .../tests/plugins/test_incdet_ewc_plugin.py | 134 ++++++++++++++++++ .../plugins/test_pseudo_annotation_plugin.py | 39 +++++ .../strategies/test_incdet_ewc_integration.py | 57 ++++++++ examples/incdet_ewc_cifar.py | 87 ++++++++++++ examples/incdet_ewc_splitmnist.py | 61 ++++++++ run_incdet_ewc.py | 106 ++++++++++++++ 9 files changed, 678 insertions(+) create mode 100644 avalanche/avalanche/training/plugins/__init__.py create mode 100644 avalanche/avalanche/training/plugins/incdet_ewc.py create mode 100644 avalanche/avalanche/training/plugins/pseudo_annotation.py create mode 100644 avalanche/tests/plugins/test_incdet_ewc_plugin.py create mode 100644 avalanche/tests/plugins/test_pseudo_annotation_plugin.py create mode 100644 avalanche/tests/strategies/test_incdet_ewc_integration.py create mode 100644 examples/incdet_ewc_cifar.py create mode 100644 examples/incdet_ewc_splitmnist.py create mode 100644 run_incdet_ewc.py diff --git a/avalanche/avalanche/training/plugins/__init__.py b/avalanche/avalanche/training/plugins/__init__.py new file mode 100644 index 000000000..18a5dd49f --- /dev/null +++ b/avalanche/avalanche/training/plugins/__init__.py @@ -0,0 +1,2 @@ +from .incdet_ewc import HuberEWCPlugin +from .pseudo_annotation import PseudoAnnotationPlugin diff --git a/avalanche/avalanche/training/plugins/incdet_ewc.py b/avalanche/avalanche/training/plugins/incdet_ewc.py new file mode 100644 index 000000000..6a33afee9 --- /dev/null +++ b/avalanche/avalanche/training/plugins/incdet_ewc.py @@ -0,0 +1,97 @@ +################################################################################ +# Copyright (c) 2025. +# Copyrights licensed under the MIT License. +# See the accompanying LICENSE file for terms. +# +# Date: 2025-09-13 +# Author(s): Muhammad Aniq, GPT-5 assistant +# Website: avalanche.continualai.org +################################################################################ +from typing import Optional + +import torch +import torch.nn.functional as F # noqa: F401 (kept for future extensions) + +from avalanche.training.plugins.ewc import EWCPlugin + + +class HuberEWCPlugin(EWCPlugin): + """ + Elastic Weight Consolidation with Huber loss regularization. + + This plugin extends the standard :class:`EWCPlugin` by replacing the + quadratic penalty with a Huber penalty to improve stability when the + regularization strength is large. The importance matrix computation and + state handling are inherited from :class:`EWCPlugin`. + + The Huber loss is applied to the scaled parameter difference + z = sqrt(importance) * (theta - theta_old) with threshold ``beta``. 
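+
+    Concretely, each element of ``z`` contributes
+
+        H(z) = 0.5 * z**2                 if |z| <= beta
+        H(z) = beta * (|z| - 0.5 * beta)  otherwise
+
+    which matches :meth:`_huber_sum` below; the penalty added to the
+    training loss is ``ewc_lambda * sum(H(z))``. Because ``|dH/dz|`` never
+    exceeds ``beta``, the penalty gradient stays bounded even under large
+    parameter drift, unlike the quadratic EWC penalty.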
+ """ + + def __init__( + self, + ewc_lambda: float, + *, + beta: float = 1.0, + mode: str = "separate", + decay_factor: Optional[float] = None, + keep_importance_data: bool = False, + ): + super().__init__( + ewc_lambda=ewc_lambda, + mode=mode, + decay_factor=decay_factor, + keep_importance_data=keep_importance_data, + ) + self.beta = float(beta) + + @staticmethod + def _huber_sum(x: torch.Tensor, beta: float) -> torch.Tensor: + """Element-wise Huber with threshold beta, summed over all elements.""" + abs_x = x.abs() + quad = 0.5 * (x**2) + lin = beta * (abs_x - 0.5 * beta) + return torch.where(abs_x <= beta, quad, lin).sum() + + def before_backward(self, strategy, **kwargs): + """ + Compute Huber-based EWC penalty and add it to strategy.loss. + """ + exp_counter = strategy.clock.train_exp_counter + if exp_counter == 0: + return + + device = strategy.device + penalty = torch.tensor(0.0, device=device) + + if self.mode == "separate": + for experience in range(exp_counter): + for k, cur_param in strategy.model.named_parameters(): + if k not in self.saved_params[experience]: + continue + saved_param = self.saved_params[experience][k] + imp = self.importances[experience][k] + new_shape = cur_param.shape + delta = cur_param - saved_param.expand(new_shape) + # Scale by sqrt(importance) + scaled_delta = imp.expand(new_shape).sqrt() * delta + # Use custom huber to avoid version mismatches + penalty = penalty + self._huber_sum(scaled_delta, self.beta) + elif self.mode == "online": + prev_exp = exp_counter - 1 + for k, cur_param in strategy.model.named_parameters(): + if k not in self.saved_params[prev_exp]: + continue + saved_param = self.saved_params[prev_exp][k] + imp = self.importances[prev_exp][k] + new_shape = cur_param.shape + delta = cur_param - saved_param.expand(new_shape) + scaled_delta = imp.expand(new_shape).sqrt() * delta + penalty = penalty + self._huber_sum(scaled_delta, self.beta) + else: + raise ValueError("Wrong EWC mode.") + + strategy.loss += self.ewc_lambda * penalty + + +__all__ = ["HuberEWCPlugin"] diff --git a/avalanche/avalanche/training/plugins/pseudo_annotation.py b/avalanche/avalanche/training/plugins/pseudo_annotation.py new file mode 100644 index 000000000..89da6dba2 --- /dev/null +++ b/avalanche/avalanche/training/plugins/pseudo_annotation.py @@ -0,0 +1,95 @@ +################################################################################ +# Copyright (c) 2025. +# Copyrights licensed under the MIT License. +# See the accompanying LICENSE file for terms. +# +# Date: 2025-09-13 +# Author(s): Muhammad Aniq, GPT-5 assistant +# Website: avalanche.continualai.org +################################################################################ +from copy import deepcopy +from typing import List + +import torch +from torch.utils.data import TensorDataset # noqa: F401 (kept for reference) + +from avalanche.benchmarks.utils.classification_dataset import ( + _make_taskaware_tensor_classification_dataset, +) +from avalanche.training.plugins.strategy_plugin import SupervisedPlugin + + +class PseudoAnnotationPlugin(SupervisedPlugin): + """ + Simple pseudo-annotation plugin for classification benchmarks. + + Before training on a new experience, use the previous model to generate + high-confidence predictions on current data for classes seen so far and + append those pseudo-labeled samples to the adapted dataset. + + Note: This is a minimal classification-oriented implementation (not object + detection). 
It assumes inputs are tensors after transforms and that the + benchmark exposes class timelines. + """ + + def __init__(self, confidence_thresh: float = 0.9): + super().__init__() + self.confidence_thresh = float(confidence_thresh) + self._prev_model = None + self._seen_classes: set[int] = set() + + def after_training_exp(self, strategy, **kwargs): + self._prev_model = deepcopy(strategy.model) + self._prev_model.to(strategy.device) + self._prev_model.eval() + + if hasattr(strategy.experience, "classes_in_this_experience"): + self._seen_classes.update( + map(int, strategy.experience.classes_in_this_experience) + ) + + @torch.no_grad() + def after_train_dataset_adaptation(self, strategy, **kwargs): + if self._prev_model is None: + return + + if not hasattr(strategy.experience, "classes_in_this_experience"): + return + + new_classes = set(map(int, strategy.experience.classes_in_this_experience)) + old_classes = sorted(list(self._seen_classes - new_classes)) + if len(old_classes) == 0: + return + + # Use the adapted dataset (already set to train transforms) + from avalanche.training.utils import load_all_dataset + + assert strategy.adapted_dataset is not None + x, *rest = load_all_dataset(strategy.adapted_dataset) + x = x.to(strategy.device) + + logits = self._prev_model(x) + probs = torch.softmax(logits, dim=1) + + xs: List[torch.Tensor] = [] + ys: List[int] = [] + for idx in range(probs.shape[0]): + p = probs[idx] + for c in old_classes: + if c < p.numel() and float(p[c]) >= self.confidence_thresh: + xs.append(x[idx].detach().cpu()) + ys.append(int(c)) + + if len(ys) == 0: + return + + x_t = torch.stack(xs, dim=0) + y_t = torch.tensor(ys, dtype=torch.long) + pseudo_ds = _make_taskaware_tensor_classification_dataset( + x_t, y_t, task_labels=0 + ) + + strategy.adapted_dataset = strategy.adapted_dataset.concat(pseudo_ds) + + +__all__ = ["PseudoAnnotationPlugin"] diff --git a/avalanche/tests/plugins/test_incdet_ewc_plugin.py b/avalanche/tests/plugins/test_incdet_ewc_plugin.py new file mode 100644 index 000000000..b440cf1cb --- /dev/null +++ b/avalanche/tests/plugins/test_incdet_ewc_plugin.py @@ -0,0 +1,134 @@ +import unittest +import os +import importlib.util +import sys +import types +import torch +from torch.nn import Linear + + +def _inject_pkg_stubs(): + for name in ["avalanche", "avalanche.training", "avalanche.training.plugins"]: + if name not in sys.modules: + mod = types.ModuleType(name) + setattr(mod, "__path__", []) + sys.modules[name] = mod + # stub strategy_plugin with minimal SupervisedPlugin base + sp_name = "avalanche.training.plugins.strategy_plugin" + if sp_name not in sys.modules: + sp_mod = types.ModuleType(sp_name) + class _SupervisedPlugin: + def __init__(self): + pass + sp_mod.SupervisedPlugin = _SupervisedPlugin + sys.modules[sp_name] = sp_mod + # stub models.utils to avoid pulling full models + mu_name = "avalanche.models.utils" + if mu_name not in sys.modules: + mu_mod = types.ModuleType(mu_name) + def _forward(m, x, t=None): + return m(x) + mu_mod.avalanche_forward = _forward + sys.modules[mu_name] = mu_mod + # stub training.utils to avoid importing benchmarks + tu_name = "avalanche.training.utils" + if tu_name not in sys.modules: + tu_mod = types.ModuleType(tu_name) + class ParamData: + def __init__(self, name, shape=None, init_tensor=None, device="cpu"): + self.name = name + self.shape = shape + self.device = device + self._data = init_tensor if init_tensor is not None else torch.zeros(()) + @property + def data(self): + return self._data + def expand(self, new_shape, 
padding_fn=torch.zeros): + if self._data.numel() == 0: + self._data = padding_fn(new_shape) + return self._data + def copy_params_dict(model, copy_grad=False): + out = {} + for k, p in model.named_parameters(): + tensor = p.grad.data.clone() if copy_grad and p.grad is not None else p.data.clone() + out[k] = ParamData(k, p.shape, init_tensor=tensor, device=p.device) + return out + def zerolike_params_dict(model): + out = {} + for k, p in model.named_parameters(): + out[k] = ParamData(k, p.shape, init_tensor=torch.zeros_like(p.data), device=p.device) + return out + tu_mod.ParamData = ParamData + tu_mod.copy_params_dict = copy_params_dict + tu_mod.zerolike_params_dict = zerolike_params_dict + sys.modules[tu_name] = tu_mod + + +def _preload_ewc_module(): + here = os.path.dirname(__file__) + ewc_path = os.path.abspath( + os.path.join(here, "..", "..", "avalanche", "training", "plugins", "ewc.py") + ) + spec = importlib.util.spec_from_file_location("avalanche.training.plugins.ewc", ewc_path) + module = importlib.util.module_from_spec(spec) + assert spec is not None and spec.loader is not None + spec.loader.exec_module(module) + sys.modules["avalanche.training.plugins.ewc"] = module + + +def _inject_ex_model_stub(): + mod = types.ModuleType("avalanche.benchmarks.classic.ex_model") + class _Dummy: + pass + mod.LeNet5 = _Dummy + mod.SlimResNet18 = _Dummy + sys.modules["avalanche.benchmarks.classic.ex_model"] = mod + + +def _load_huber_plugin(): + _inject_pkg_stubs() + _inject_ex_model_stub() + _preload_ewc_module() + here = os.path.dirname(__file__) + plugin_path = os.path.abspath( + os.path.join(here, "..", "..", "avalanche", "training", "plugins", "incdet_ewc.py") + ) + spec = importlib.util.spec_from_file_location("incdet_ewc_local", plugin_path) + module = importlib.util.module_from_spec(spec) + assert spec is not None and spec.loader is not None + spec.loader.exec_module(module) + return module.HuberEWCPlugin + + +class HuberEWCPluginTest(unittest.TestCase): + def test_penalty_nonzero_after_prev_exp(self): + torch.manual_seed(0) + model = Linear(5, 3) + HuberEWCPlugin = _load_huber_plugin() + plugin = HuberEWCPlugin(ewc_lambda=1.0, beta=0.5) + + class _Clock: + def __init__(self): + self.train_exp_counter = 0 + + class _Strategy: + def __init__(self, model: torch.nn.Module): + self.model = model + self.clock = _Clock() + self.device = torch.device("cpu") + self.loss = torch.tensor(0.0) + + strategy = _Strategy(model) + + plugin.saved_params[0] = {k: torch.zeros_like(p.data) for k, p in model.named_parameters()} + plugin.importances[0] = {k: torch.ones_like(p.data) for k, p in model.named_parameters()} + + strategy.clock.train_exp_counter = 1 + + plugin.before_backward(strategy) + + self.assertGreater(float(strategy.loss.item()), 0.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/avalanche/tests/plugins/test_pseudo_annotation_plugin.py b/avalanche/tests/plugins/test_pseudo_annotation_plugin.py new file mode 100644 index 000000000..003cc937f --- /dev/null +++ b/avalanche/tests/plugins/test_pseudo_annotation_plugin.py @@ -0,0 +1,39 @@ +import unittest +import torch +from torch.nn import Linear + +from avalanche.training.plugins.pseudo_annotation import PseudoAnnotationPlugin + + +class PseudoAnnotationPluginTest(unittest.TestCase): + def test_no_prev_model_noop(self): + plugin = PseudoAnnotationPlugin(confidence_thresh=0.99) + + class _Strategy: + def __init__(self): + self.device = torch.device("cpu") + self.adapted_dataset = None + self.experience = type("E", (), 
{"classes_in_this_experience": [0]}) + + strat = _Strategy() + # Should not raise + plugin.after_train_dataset_adaptation(strat) + + def test_after_training_sets_prev_model(self): + plugin = PseudoAnnotationPlugin(confidence_thresh=0.99) + + class _Strategy: + def __init__(self): + self.model = Linear(4, 2) + self.device = torch.device("cpu") + self.experience = type("E", (), {"classes_in_this_experience": [0, 1]}) + + strat = _Strategy() + plugin.after_training_exp(strat) + self.assertIsNotNone(plugin._prev_model) + + +if __name__ == "__main__": + unittest.main() + + diff --git a/avalanche/tests/strategies/test_incdet_ewc_integration.py b/avalanche/tests/strategies/test_incdet_ewc_integration.py new file mode 100644 index 000000000..a7c108728 --- /dev/null +++ b/avalanche/tests/strategies/test_incdet_ewc_integration.py @@ -0,0 +1,57 @@ +import unittest +import torch +from torch.optim import SGD +from torch.nn import CrossEntropyLoss, Linear +from torch.utils.data import TensorDataset + +from avalanche.training.supervised.strategy_wrappers import Naive +from avalanche.training.plugins import HuberEWCPlugin +from avalanche.training.plugins.pseudo_annotation import PseudoAnnotationPlugin +from avalanche.benchmarks.utils.classification_dataset import _make_taskaware_tensor_classification_dataset +from avalanche.benchmarks.scenarios.deprecated.generators import dataset_benchmark + + +class IncDetEWCIntegrationTest(unittest.TestCase): + def test_tiny_synthetic_two_exps(self): + device = torch.device("cpu") + # two tiny experiences of different classes + x0 = torch.randn(20, 10) + y0 = torch.zeros(20, dtype=torch.long) + x1 = torch.randn(20, 10) + y1 = torch.ones(20, dtype=torch.long) + + ds0 = _make_taskaware_tensor_classification_dataset(x0, y0, task_labels=0) + ds1 = _make_taskaware_tensor_classification_dataset(x1, y1, task_labels=0) + + bench = dataset_benchmark([ds0, ds1], [ds0, ds1]) + + model = Linear(10, 2) + optimizer = SGD(model.parameters(), lr=0.01) + criterion = CrossEntropyLoss() + + huber_plugin = HuberEWCPlugin(ewc_lambda=1.0, beta=1.0) + pseudo_plugin = PseudoAnnotationPlugin(confidence_thresh=0.99) + + strategy = Naive( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=8, + train_epochs=1, + eval_mb_size=8, + device=device, + plugins=[huber_plugin, pseudo_plugin], + ) + + for i, exp in enumerate(bench.train_stream): + strategy.train(exp, num_workers=0, pin_memory=False) + if i == 1: + break + + res = strategy.eval(bench.test_stream[:1], num_workers=0, pin_memory=False) + self.assertTrue(isinstance(res, dict)) + self.assertTrue(len(res) > 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/incdet_ewc_cifar.py b/examples/incdet_ewc_cifar.py new file mode 100644 index 000000000..4e8e0124e --- /dev/null +++ b/examples/incdet_ewc_cifar.py @@ -0,0 +1,87 @@ +import argparse +import torch +from torch.optim import SGD +from torch.nn import CrossEntropyLoss + +from avalanche.benchmarks.classic.ccifar10 import SplitCIFAR10 +from avalanche.benchmarks.classic.ccifar100 import SplitCIFAR100 +from avalanche.models import SimpleCNN +from avalanche.training.supervised.strategy_wrappers import Naive +from avalanche.training.plugins import EvaluationPlugin +from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics +from avalanche.logging import InteractiveLogger + +from avalanche.training.plugins import HuberEWCPlugin, EWCPlugin + + +def _resolve_device(): + if torch.backends.mps.is_available(): + return torch.device("mps") + return 
torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def run(benchmark_name: str, use_huber: bool, device, mb_size: int, epochs: int): + if benchmark_name == "cifar10": + bench = SplitCIFAR10(n_experiences=5, seed=42) + num_classes = 10 + elif benchmark_name == "cifar100": + bench = SplitCIFAR100(n_experiences=10, seed=42) + num_classes = 100 + else: + raise ValueError("benchmark_name must be 'cifar10' or 'cifar100'") + + # SimpleCNN has built-ins for CIFAR + model = SimpleCNN(num_classes=num_classes) + optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9) + criterion = CrossEntropyLoss() + + eval_plugin = EvaluationPlugin( + accuracy_metrics(epoch=True, experience=True, stream=True), + loss_metrics(epoch=True, experience=True, stream=True), + loggers=[InteractiveLogger()], + ) + + plugin = ( + HuberEWCPlugin(ewc_lambda=50.0, beta=0.5) + if use_huber + else EWCPlugin(ewc_lambda=50.0) + ) + + strategy = Naive( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=mb_size, + train_epochs=epochs, + eval_mb_size=mb_size, + device=device, + plugins=[plugin], + evaluator=eval_plugin, + ) + + for exp in bench.train_stream: + print(f"Start of experience {exp.current_experience}") + strategy.train(exp, num_workers=2, pin_memory=False) + print("End of experience", exp.current_experience) + print("Evaluation on test stream") + strategy.eval(bench.test_stream, num_workers=2, pin_memory=False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--benchmark", type=str, default="cifar10", choices=["cifar10", "cifar100"] + ) + parser.add_argument("--use_huber", type=int, default=1) + parser.add_argument("--train_mb_size", type=int, default=128) + parser.add_argument("--train_epochs", type=int, default=2) + args = parser.parse_args() + + device = _resolve_device() + run( + args.benchmark, + bool(args.use_huber), + device, + args.train_mb_size, + args.train_epochs, + ) diff --git a/examples/incdet_ewc_splitmnist.py b/examples/incdet_ewc_splitmnist.py new file mode 100644 index 000000000..9e5d69a40 --- /dev/null +++ b/examples/incdet_ewc_splitmnist.py @@ -0,0 +1,61 @@ +import torch +from torch.optim import SGD +from torch.nn import CrossEntropyLoss + +from avalanche.benchmarks.classic.cmnist import SplitMNIST +from avalanche.models import SimpleMLP +from avalanche.training.supervised.strategy_wrappers import Naive +from avalanche.training.plugins import EvaluationPlugin +from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics +from avalanche.logging import InteractiveLogger + +from avalanche.training.plugins import HuberEWCPlugin +from avalanche.training.plugins.pseudo_annotation import PseudoAnnotationPlugin + + +def _resolve_device() -> torch.device: + if torch.backends.mps.is_available(): + return torch.device("mps") + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def main(): + device = _resolve_device() + benchmark = SplitMNIST(n_experiences=2, seed=42) + + model = SimpleMLP(num_classes=benchmark.n_classes, input_size=28 * 28) + optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9) + criterion = CrossEntropyLoss() + + eval_plugin = EvaluationPlugin( + accuracy_metrics(epoch=True, experience=True, stream=True), + loss_metrics(epoch=True, experience=True, stream=True), + loggers=[InteractiveLogger()], + ) + + huber_plugin = HuberEWCPlugin(ewc_lambda=1000.0, beta=0.5) + pseudo_plugin = PseudoAnnotationPlugin(confidence_thresh=0.95) + + strategy = Naive( + model=model, + 
optimizer=optimizer, + criterion=criterion, + train_mb_size=64, + train_epochs=1, + eval_mb_size=64, + device=device, + plugins=[huber_plugin, pseudo_plugin], + evaluator=eval_plugin, + ) + + for experience in benchmark.train_stream: + print(f"Start of experience {experience.current_experience}") + print("Current classes:", experience.classes_in_this_experience) + strategy.train(experience, num_workers=0, pin_memory=False) + print("End of experience", experience.current_experience) + print("Evaluation on test stream") + strategy.eval(benchmark.test_stream, num_workers=0, pin_memory=False) + + +if __name__ == "__main__": + main() diff --git a/run_incdet_ewc.py b/run_incdet_ewc.py new file mode 100644 index 000000000..12946e9f3 --- /dev/null +++ b/run_incdet_ewc.py @@ -0,0 +1,106 @@ +import argparse +import torch +from torch.optim import SGD +from torch.nn import CrossEntropyLoss + +from avalanche.benchmarks.classic.cmnist import SplitMNIST +from avalanche.models import SimpleMLP +from avalanche.training.supervised.strategy_wrappers import Naive +from avalanche.training.plugins import EvaluationPlugin +from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics +from avalanche.logging import InteractiveLogger + +from avalanche.training.plugins import HuberEWCPlugin +from avalanche.training.plugins.pseudo_annotation import PseudoAnnotationPlugin + + +def _resolve_device(preference: str) -> torch.device: + pref = preference.lower() + if pref == "auto": + if torch.backends.mps.is_available(): + return torch.device("mps") + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + if pref == "mps": + return ( + torch.device("mps") + if torch.backends.mps.is_available() + else torch.device("cpu") + ) + if pref == "cuda": + return ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + return torch.device("cpu") + + +def main(args: argparse.Namespace): + if args.seed is not None: + torch.manual_seed(int(args.seed)) + + device = _resolve_device(args.device) + + benchmark = SplitMNIST(n_experiences=args.n_experiences, seed=args.seed or 42) + + model = SimpleMLP(num_classes=benchmark.n_classes, input_size=28 * 28) + optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum) + criterion = CrossEntropyLoss() + + eval_plugin = EvaluationPlugin( + accuracy_metrics(epoch=True, experience=True, stream=True), + loss_metrics(epoch=True, experience=True, stream=True), + loggers=[InteractiveLogger()], + ) + + huber_plugin = HuberEWCPlugin(ewc_lambda=args.ewc_lambda, beta=args.beta) + pseudo_plugin = PseudoAnnotationPlugin(confidence_thresh=args.confidence_thresh) + + strategy = Naive( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=args.train_mb_size, + train_epochs=args.train_epochs, + eval_mb_size=args.eval_mb_size, + device=device, + plugins=[huber_plugin, pseudo_plugin], + evaluator=eval_plugin, + ) + + for experience in benchmark.train_stream: + print(f"Start of experience {experience.current_experience}") + print("Current classes:", experience.classes_in_this_experience) + strategy.train( + experience, + num_workers=args.num_workers, + pin_memory=bool(args.pin_memory), + ) + print("End of experience", experience.current_experience) + print("Evaluation on test stream") + strategy.eval( + benchmark.test_stream, + num_workers=args.num_workers, + pin_memory=bool(args.pin_memory), + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--device", 
type=str, default="auto", choices=["auto", "mps", "cuda", "cpu"] + ) + parser.add_argument("--train_mb_size", type=int, default=64) + parser.add_argument("--eval_mb_size", type=int, default=64) + parser.add_argument("--train_epochs", type=int, default=1) + parser.add_argument("--lr", type=float, default=0.01) + parser.add_argument("--momentum", type=float, default=0.9) + parser.add_argument("--ewc_lambda", type=float, default=1000.0) + parser.add_argument("--beta", type=float, default=0.5) + parser.add_argument("--confidence_thresh", type=float, default=0.95) + parser.add_argument("--n_experiences", type=int, default=2) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--num_workers", type=int, default=0) + parser.add_argument("--pin_memory", type=int, default=0) + args = parser.parse_args() + main(args) From 6c278fb170fa9269630c802ccfb05fa709e42f59 Mon Sep 17 00:00:00 2001 From: Muhammad Aniq Date: Tue, 16 Sep 2025 11:55:15 +0100 Subject: [PATCH 2/4] chore: move plugins and tests to correct paths --- avalanche/avalanche/training/plugins/__init__.py | 2 -- avalanche/training/plugins/__init__.py | 2 ++ avalanche/{avalanche => }/training/plugins/incdet_ewc.py | 0 avalanche/{avalanche => }/training/plugins/pseudo_annotation.py | 0 {avalanche/tests => tests}/plugins/test_incdet_ewc_plugin.py | 0 .../tests => tests}/plugins/test_pseudo_annotation_plugin.py | 0 .../tests => tests}/strategies/test_incdet_ewc_integration.py | 0 7 files changed, 2 insertions(+), 2 deletions(-) delete mode 100644 avalanche/avalanche/training/plugins/__init__.py rename avalanche/{avalanche => }/training/plugins/incdet_ewc.py (100%) rename avalanche/{avalanche => }/training/plugins/pseudo_annotation.py (100%) rename {avalanche/tests => tests}/plugins/test_incdet_ewc_plugin.py (100%) rename {avalanche/tests => tests}/plugins/test_pseudo_annotation_plugin.py (100%) rename {avalanche/tests => tests}/strategies/test_incdet_ewc_integration.py (100%) diff --git a/avalanche/avalanche/training/plugins/__init__.py b/avalanche/avalanche/training/plugins/__init__.py deleted file mode 100644 index 18a5dd49f..000000000 --- a/avalanche/avalanche/training/plugins/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .incdet_ewc import HuberEWCPlugin -from .pseudo_annotation import PseudoAnnotationPlugin diff --git a/avalanche/training/plugins/__init__.py b/avalanche/training/plugins/__init__.py index d912fbf2d..045f139e6 100644 --- a/avalanche/training/plugins/__init__.py +++ b/avalanche/training/plugins/__init__.py @@ -27,3 +27,5 @@ from .update_fecam import * from .feature_distillation import * from .il2m import IL2MPlugin +from .incdet_ewc import HuberEWCPlugin +from .pseudo_annotation import PseudoAnnotationPlugin diff --git a/avalanche/avalanche/training/plugins/incdet_ewc.py b/avalanche/training/plugins/incdet_ewc.py similarity index 100% rename from avalanche/avalanche/training/plugins/incdet_ewc.py rename to avalanche/training/plugins/incdet_ewc.py diff --git a/avalanche/avalanche/training/plugins/pseudo_annotation.py b/avalanche/training/plugins/pseudo_annotation.py similarity index 100% rename from avalanche/avalanche/training/plugins/pseudo_annotation.py rename to avalanche/training/plugins/pseudo_annotation.py diff --git a/avalanche/tests/plugins/test_incdet_ewc_plugin.py b/tests/plugins/test_incdet_ewc_plugin.py similarity index 100% rename from avalanche/tests/plugins/test_incdet_ewc_plugin.py rename to tests/plugins/test_incdet_ewc_plugin.py diff --git 
a/avalanche/tests/plugins/test_pseudo_annotation_plugin.py b/tests/plugins/test_pseudo_annotation_plugin.py similarity index 100% rename from avalanche/tests/plugins/test_pseudo_annotation_plugin.py rename to tests/plugins/test_pseudo_annotation_plugin.py diff --git a/avalanche/tests/strategies/test_incdet_ewc_integration.py b/tests/strategies/test_incdet_ewc_integration.py similarity index 100% rename from avalanche/tests/strategies/test_incdet_ewc_integration.py rename to tests/strategies/test_incdet_ewc_integration.py From 65554bcae65405278a499595c6a5054666efa735 Mon Sep 17 00:00:00 2001 From: Muhammad Aniq Date: Tue, 16 Sep 2025 11:58:22 +0100 Subject: [PATCH 3/4] docs: add IncDet-EWC how-to and link in SUMMARY --- docs/gitbook/SUMMARY.md | 1 + docs/gitbook/how-tos/incdet_ewc.md | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 docs/gitbook/how-tos/incdet_ewc.md diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md index cb9e0d37f..8da09f025 100644 --- a/docs/gitbook/SUMMARY.md +++ b/docs/gitbook/SUMMARY.md @@ -22,6 +22,7 @@ * [Contribute to Avalanche](from-zero-to-hero-tutorial/09_contribute-to-avalanche.md) ## How-Tos +* [IncDet‑EWC (Huber EWC + Pseudo‑Annotation)](how-tos/incdet_ewc.md) * [AvalancheDataset](how-tos/avalanchedataset/README.md) * [avalanche-datasets](how-tos/avalanchedataset/avalanche-datasets.md) diff --git a/docs/gitbook/how-tos/incdet_ewc.md b/docs/gitbook/how-tos/incdet_ewc.md new file mode 100644 index 000000000..c1ea2ce8e --- /dev/null +++ b/docs/gitbook/how-tos/incdet_ewc.md @@ -0,0 +1,24 @@ +# IncDet‑EWC (Huber EWC + Pseudo‑Annotation) + +Use `HuberEWCPlugin` with `PseudoAnnotationPlugin` to mitigate forgetting and leverage confident predictions on past classes. + +## Quickstart (SplitMNIST) +```bash +python run_incdet_ewc.py --device auto --train_mb_size 64 --eval_mb_size 64 \ + --train_epochs 1 --lr 0.01 --momentum 0.9 --ewc_lambda 1000.0 --beta 0.5 \ + --confidence_thresh 0.95 --n_experiences 2 --seed 42 +``` + +## CIFAR +```bash +python examples/incdet_ewc_cifar.py --benchmark cifar10 --use_huber 1 --train_mb_size 128 --train_epochs 2 +``` + +## Minimal API +```python +from avalanche.training.plugins import HuberEWCPlugin +from avalanche.training.plugins.pseudo_annotation import PseudoAnnotationPlugin +plugins = [HuberEWCPlugin(ewc_lambda=1000.0, beta=0.5), + PseudoAnnotationPlugin(confidence_thresh=0.95)] +# pass plugins=plugins into your Naive(...) strategy +``` From 02978bb43312131fcf22bc8c2a4057de393a3e46 Mon Sep 17 00:00:00 2001 From: Muhammad Aniq Date: Tue, 16 Sep 2025 13:43:14 +0100 Subject: [PATCH 4/4] docs: format paper summary, fix lists/code fences, add references --- docs/gitbook/how-tos/incdet_ewc.md | 32 ++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/docs/gitbook/how-tos/incdet_ewc.md b/docs/gitbook/how-tos/incdet_ewc.md index c1ea2ce8e..39ee2d1e1 100644 --- a/docs/gitbook/how-tos/incdet_ewc.md +++ b/docs/gitbook/how-tos/incdet_ewc.md @@ -1,5 +1,23 @@ # IncDet‑EWC (Huber EWC + Pseudo‑Annotation) +## Paper summary +This paper presents IncDet, a framework that successfully adapts Elastic Weight Consolidation (EWC) to the task of incremental object detection. While EWC is effective in general incremental learning, it has previously been shown to fail when directly applied to object detection. + +The authors identify two core issues responsible for this failure through controlled experiments: +1. 
Missing Old Class Annotations: When training on a new set of classes, images may contain objects from old classes that are not annotated. This causes the model to incorrectly learn to classify these old-class objects as background, leading to catastrophic forgetting.
+2. Unstable Training: The quadratic regularization loss used in EWC can cause gradient explosion when trying to balance performance between old and new classes, leading to unstable training.
+
+To address these problems, the paper proposes two corresponding solutions:
+1. Pseudo Annotation: To compensate for the missing labels, the old model is used to predict bounding boxes for old-class objects in the new training images. These "pseudo" annotations are then combined with the ground-truth annotations for the new classes, preventing the model from misclassifying old objects as background.
+2. Huber Regularization: A novel Huber regularization loss replaces EWC's original quadratic loss. It adaptively clips the gradient of each parameter based on its importance to the old tasks, which prevents gradient explosion and allows stable training with a better trade-off between remembering old classes and learning new ones.
+
+These solutions are integrated into the IncDet framework, a general and flexible pipeline for incremental object detection. The process involves:
+- Initial Training: A base model is trained on an initial set of classes.
+- Predict & Aggregate: The trained model generates pseudo-annotations for old classes on new images, which are then aggregated with the manual annotations for the new classes.
+- Incremental Fine-tuning: The model is fine-tuned on the combined annotations with Huber regularization, learning the new classes while retaining knowledge of the old ones. This cycle repeats as more classes are added.
+
+The framework was implemented on both Fast R-CNN and Faster R-CNN, demonstrating its versatility. Experiments on the PASCAL VOC and COCO datasets show that IncDet achieves new state-of-the-art results, surpassing previous methods both in final performance and in closing the gap to joint training on all classes. The proposed method is also more compute- and memory-efficient during training than prior auxiliary-based approaches.
+
 Use `HuberEWCPlugin` with `PseudoAnnotationPlugin` to mitigate forgetting and leverage confident predictions on past classes.
 
 ## Quickstart (SplitMNIST)
 ```bash
 python run_incdet_ewc.py --device auto --train_mb_size 64 --eval_mb_size 64 \
   --train_epochs 1 --lr 0.01 --momentum 0.9 --ewc_lambda 1000.0 --beta 0.5 \
   --confidence_thresh 0.95 --n_experiences 2 --seed 42
 ```
 
 ## CIFAR
 ```bash
 python examples/incdet_ewc_cifar.py --benchmark cifar10 --use_huber 1 --train_mb_size 128 --train_epochs 2
 ```
 
 ## Minimal API
 ```python
-from avalanche.training.plugins import HuberEWCPlugin
-from avalanche.training.plugins.pseudo_annotation import PseudoAnnotationPlugin
-plugins = [HuberEWCPlugin(ewc_lambda=1000.0, beta=0.5),
-           PseudoAnnotationPlugin(confidence_thresh=0.95)]
+from avalanche.training.plugins import HuberEWCPlugin, PseudoAnnotationPlugin
+
+plugins = [
+    HuberEWCPlugin(ewc_lambda=1000.0, beta=0.5),
+    PseudoAnnotationPlugin(confidence_thresh=0.95),
+]
 # pass plugins=plugins into your Naive(...) strategy
 ```
+
+## References
+- IncDet: Liu et al., "IncDet: In Defense of Elastic Weight Consolidation for Incremental Object Detection", IEEE TNNLS, 2020. https://ieeexplore.ieee.org/document/9127478 (DOI: https://doi.org/10.1109/TNNLS.2020.3002583)
+- EWC: Kirkpatrick et al., "Overcoming catastrophic forgetting in neural networks", PNAS, 2017. https://www.pnas.org/doi/10.1073/pnas.1611835114
\ No newline at end of file