From f3c5de8e78c2e79c5675ca72e881f698782ee500 Mon Sep 17 00:00:00 2001 From: ghecko Date: Sun, 26 Apr 2026 12:31:57 +0200 Subject: [PATCH 1/2] perf(ara): cache original weights to skip disk reload between trials Each Optuna trial in ARA mode previously fell into the slow path of reset_model() and reloaded the full model from disk, because ara_abliterate mutates module.weight in-place via LBFGS and bypasses the LoRA adapter mechanism. For large models this dominated trial time (the 'Loading weights' progress bar visible in trial traces). Add ARAWeightCache, an in-memory snapshot of the original weights of ARA-targeted modules: - reset_model() restores from cache via memcpy when the cache is healthy (i.e. its needs_reload flag is not set), reducing per-trial reset cost from minutes to milliseconds. - ara_abliterate() plans the cache before mutating: drop snapshots no longer needed, snapshot newly-targeted modules. Diff is computed from start_layer_index/end_layer_index, so the cache only holds what's actually in scope. - Storage tier is per source device: snapshots stay on the source GPU when there is room (intra-device copy is fastest), else spill to CPU. A 2 GB safety margin is reserved on each tier to avoid starving the LBFGS optimizer. - If neither tier can fit the planned snapshot, prints a warning and flags the cache so the next reset_model() falls back to a disk reload, recovering the originals of the modules mutated without backup. No partial snapshots are ever taken, to keep the cache invariant clean (every entry holds the pristine original weight). The non-ARA LoRA fast path is unchanged. 
--- src/heretic/main.py | 4 +- src/heretic/model.py | 229 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 225 insertions(+), 8 deletions(-) diff --git a/src/heretic/main.py b/src/heretic/main.py index 0abd11d7..b0ea017d 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -623,7 +623,7 @@ def objective(trial: Trial) -> tuple[float, float]: for name, value in get_trial_parameters(settings, trial).items(): print(f" * {name} = [bold]{value}[/]") if settings.use_ara: - print("* Reloading model...") + print("* Resetting model...") model.reset_model() print("* Abliterating (Arbitrary-Rank Ablation)...") model.ara_abliterate(good_module_io, bad_module_io, ara_parameters) @@ -811,7 +811,7 @@ def count_completed_trials() -> int: for name, value in get_trial_parameters(settings, trial).items(): print(f" * {name} = [bold]{value}[/]") if settings.use_ara: - print("* Reloading model...") + print("* Resetting model...") model.reset_model() print("* Abliterating (Arbitrary-Rank Ablation)...") model.ara_abliterate( diff --git a/src/heretic/model.py b/src/heretic/model.py index 108e9dfd..3149bc2f 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -12,6 +12,7 @@ import torch.nn.functional as F from peft import LoraConfig, PeftModel, get_peft_model from peft.tuners.lora.layer import Linear +from psutil import virtual_memory from torch import FloatTensor, LongTensor, Tensor from torch.nn import Module, ModuleList from torch.optim import LBFGS @@ -71,6 +72,127 @@ class ARAParameters: ModuleIO: TypeAlias = list[dict[str, dict[int, tuple[Tensor, Tensor]]]] +# Headroom (in bytes) we keep free on any device when planning a snapshot, +# so that we don't starve the LBFGS optimizer or the inference path. +_SNAPSHOT_SAFETY_MARGIN_BYTES = 2 * 1024**3 # 2 GB + + +class ARAWeightCache: + """Caches original weights of ARA-touched modules between trials. + + Holds the invariant: every entry stores the pristine, original weight of + its module. 
To preserve this, the caller must: + 1. Restore from cache before any new optimization (model becomes clean). + 2. Drop entries for modules not used on the next trial (free memory). + 3. Snapshot newly-targeted modules BEFORE mutating them. + + Storage tier is per source device: snapshots stay on the source GPU when + there is room (intra-device copy on restore, fastest), otherwise they + spill to system RAM. If neither tier can host the snapshot, try_snapshot + returns False and the caller is expected to set ``needs_reload`` so that + the next reset_model() falls back to a full disk reload. + """ + + def __init__(self): + self._snapshots: dict[int, Tensor] = {} + self._modules: dict[int, Module] = {} + self.needs_reload = False + + def __len__(self) -> int: + return len(self._snapshots) + + def __contains__(self, module: Module) -> bool: + return id(module) in self._snapshots + + def cached_modules(self) -> list[Module]: + return list(self._modules.values()) + + def restore_all(self) -> None: + with torch.no_grad(): + for mid, snapshot in self._snapshots.items(): + target = cast(Tensor, self._modules[mid].weight) + if snapshot.device == target.device: + target.data.copy_(snapshot) + else: + target.data.copy_(snapshot.to(target.device, non_blocking=True)) + + def free(self, modules: list[Module]) -> None: + for module in modules: + mid = id(module) + self._snapshots.pop(mid, None) + self._modules.pop(mid, None) + + def clear(self) -> None: + # Required after any full reload: cached module ``id``s are stale. 
+ self._snapshots.clear() + self._modules.clear() + self.needs_reload = False + + @staticmethod + def _module_size_bytes(module: Module) -> int: + weight = cast(Tensor, module.weight) + return weight.numel() * weight.element_size() + + @classmethod + def estimate_total_size(cls, modules: list[Module]) -> int: + return sum(cls._module_size_bytes(m) for m in modules) + + @staticmethod + def _gpu_has_room(device: torch.device, size_bytes: int) -> bool: + if device.type != "cuda": + return False + try: + free, _ = torch.cuda.mem_get_info(device) + except Exception: + return False + return free >= size_bytes + _SNAPSHOT_SAFETY_MARGIN_BYTES + + def try_snapshot(self, modules: list[Module]) -> bool: + """Either every module is snapshotted (returns True) or none is + (returns False). Partial snapshots are never taken, to keep the + cache invariant clean.""" + if not modules: + return True + + # Tally size per source device so we make a single accept/reject + # decision per device (multiple weights on the same GPU could each + # individually "fit" while collectively overflowing). 
+ size_per_device: dict[torch.device, int] = {} + for module in modules: + weight = cast(Tensor, module.weight) + size_per_device[weight.device] = ( + size_per_device.get(weight.device, 0) + + self._module_size_bytes(module) + ) + + keep_on_source_device: set[torch.device] = set() + cpu_required_bytes = 0 + for device, size in size_per_device.items(): + if device.type == "cuda" and self._gpu_has_room(device, size): + keep_on_source_device.add(device) + else: + cpu_required_bytes += size + + if cpu_required_bytes > 0: + available = virtual_memory().available + if available < cpu_required_bytes + _SNAPSHOT_SAFETY_MARGIN_BYTES: + return False + + with torch.no_grad(): + for module in modules: + weight = cast(Tensor, module.weight) + target_device = ( + weight.device + if weight.device in keep_on_source_device + else torch.device("cpu") + ) + snapshot = weight.detach().to(target_device, copy=True) + self._snapshots[id(module)] = snapshot + self._modules[id(module)] = module + + return True + + class Model: model: PreTrainedModel | PeftModel tokenizer: PreTrainedTokenizerBase @@ -80,6 +202,10 @@ def __init__(self, settings: Settings): self.settings = settings self.response_prefix = "" self.needs_reload = False + # Caches original weights of ARA-touched modules so reset_model() + # can restore them in O(memcpy) instead of reloading from disk. + # Only populated when settings.use_ara is True. + self._ara_cache = ARAWeightCache() print() print(f"Loading model [bold]{settings.model}[/]...") @@ -302,27 +428,50 @@ def reset_model(self): Resets the model to a clean state for the next trial or evaluation. Behavior: - - Fast path: If the same model is loaded and doesn't need full reload, - resets LoRA adapter weights to zero (identity transformation). - - Slow path: If switching models or after merge_and_unload(), - performs full model reload with quantization config. 
+ - Fast path (non-ARA): If the same model is loaded and doesn't need + full reload, resets LoRA adapter weights to zero (identity). + - Fast path (ARA): If the same model is loaded and the snapshot cache + is healthy, restores cached original weights via memcpy. No disk I/O. + - Slow path: If switching models, after merge_and_unload(), or when + the ARA cache reports it can't fully back the next trial, performs + a full model reload with quantization config. """ current_model = getattr(self.model.config, "name_or_path", None) + same_model = current_model == self.settings.model + + # Fast path (non-ARA): zero out LoRA adapters. if ( - current_model == self.settings.model + same_model and not self.needs_reload and not self.settings.use_ara ): - # Reset LoRA adapters to zero (identity transformation) for name, module in self.model.named_modules(): if "lora_B" in name and hasattr(module, "weight"): torch.nn.init.zeros_(module.weight) return + # Fast path (ARA): restore from in-memory snapshot cache. + if ( + same_model + and not self.needs_reload + and self.settings.use_ara + and not self._ara_cache.needs_reload + ): + if len(self._ara_cache) > 0: + self._ara_cache.restore_all() + return + + # Slow path: full reload from disk. This is the only path that costs + # real time (tens of seconds to minutes for large models). + print("* Reloading weights from disk...") + dtype = self.model.dtype # Purge existing model object from memory to make space. self.model = None # ty:ignore[invalid-assignment] + # Module IDs captured in the ARA cache reference the model object we + # just dropped, so the cache must be cleared before any reload. 
+ self._ara_cache.clear() empty_cache() quantization_config = self._get_quantization_config(str(dtype).split(".")[-1]) @@ -569,12 +718,80 @@ def abliterate( weight_A.data = lora_A.to(weight_A.dtype) weight_B.data = lora_B.to(weight_B.dtype) + def _collect_ara_target_modules( + self, parameters: ARAParameters + ) -> list[Module]: + """Enumerate every module that ara_abliterate will mutate for these + parameters. Used by the snapshot cache planner.""" + target_modules: list[Module] = [] + for layer_index in range( + parameters.start_layer_index, + parameters.end_layer_index, + ): + for modules in self.get_layer_modules(layer_index).values(): + target_modules.extend(modules) + return target_modules + + def _prepare_ara_cache(self, target_modules: list[Module]) -> None: + """Diff the snapshot cache against this trial's targets: + + 1. Drop snapshots for modules we won't touch (free memory). + 2. Snapshot newly-targeted modules (which currently hold originals, + because reset_model() either restored them from the cache or + freshly loaded them from disk). + + On insufficient memory, prints a warning and sets ``needs_reload`` + so the next reset_model() falls back to a disk reload. + """ + cache = self._ara_cache + target_ids = {id(m) for m in target_modules} + + # Step 1: drop entries no longer needed. + to_drop = [ + m for m in cache.cached_modules() if id(m) not in target_ids + ] + if to_drop: + cache.free(to_drop) + empty_cache() + + # Step 2: snapshot newly-targeted modules. + to_add = [m for m in target_modules if m not in cache] + if not to_add: + return + + size_gb = cache.estimate_total_size(to_add) / (1024**3) + + if cache.try_snapshot(to_add): + # Only mention the cache when it does meaningful work, to avoid + # noise on small models where the snapshot is essentially free. 
+ if size_gb >= 0.1: + print( + f"* Snapshotted [bold]{len(to_add)}[/] module(s) " + f"([bold]{size_gb:.2f} GB[/]) for fast reset" + ) + return + + # Could not fit the planned snapshot on any tier. Run the trial + # anyway, but flag the cache so the next reset_model() reloads from + # disk to recover the originals we're about to overwrite. + print( + f"* [yellow]Not enough RAM to snapshot all targeted layers " + f"(would need ~[bold]{size_gb:.2f} GB[/]). " + f"Weights will be reloaded from disk after this iteration...[/]" + ) + cache.needs_reload = True + def ara_abliterate( self, good_module_io: ModuleIO, bad_module_io: ModuleIO, parameters: ARAParameters, ): + # Plan/refresh the snapshot cache BEFORE any mutation, so the + # invariant "cache holds originals" is preserved. reset_model() must + # have been called just before this, leaving every module pristine. + self._prepare_ara_cache(self._collect_ara_target_modules(parameters)) + for layer_index in range( parameters.start_layer_index, parameters.end_layer_index, From df8fed9e51f2cb24e5e9386b4254ba3776742c14 Mon Sep 17 00:00:00 2001 From: ghecko Date: Sun, 26 Apr 2026 13:40:51 +0200 Subject: [PATCH 2/2] Also fix a pre-existing AttributeError when reset_model() is called after an interrupted previous reset (e.g. Ctrl+C during from_pretrained). Track the active dtype on self._dtype and recover via _reload_from_disk() when self.model is None. --- src/heretic/model.py | 115 ++++++++++++++++++++++++++++--------------- 1 file changed, 76 insertions(+), 39 deletions(-) diff --git a/src/heretic/model.py b/src/heretic/model.py index 3149bc2f..fd03bc7b 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -80,17 +80,17 @@ class ARAParameters: class ARAWeightCache: """Caches original weights of ARA-touched modules between trials. - Holds the invariant: every entry stores the pristine, original weight of - its module. To preserve this, the caller must: - 1. 
Restore from cache before any new optimization (model becomes clean). - 2. Drop entries for modules not used on the next trial (free memory). + Holds the invariant: every entry stores the pristine, original weight + of its module. To preserve this, the caller must: + 1. Restore from cache before any new optimization (model is clean). + 2. Drop entries for modules not used on the next trial (free RAM). 3. Snapshot newly-targeted modules BEFORE mutating them. - Storage tier is per source device: snapshots stay on the source GPU when - there is room (intra-device copy on restore, fastest), otherwise they - spill to system RAM. If neither tier can host the snapshot, try_snapshot - returns False and the caller is expected to set ``needs_reload`` so that - the next reset_model() falls back to a full disk reload. + Storage tier is per source device: snapshots stay on the source GPU + when there is room (intra-device copy on restore, fastest), otherwise + spill to system RAM. If neither tier can host the snapshot, + try_snapshot returns False and the caller is expected to set + needs_reload so the next reset_model() falls back to a disk reload. """ def __init__(self): @@ -114,7 +114,9 @@ def restore_all(self) -> None: if snapshot.device == target.device: target.data.copy_(snapshot) else: - target.data.copy_(snapshot.to(target.device, non_blocking=True)) + target.data.copy_( + snapshot.to(target.device, non_blocking=True) + ) def free(self, modules: list[Module]) -> None: for module in modules: @@ -123,7 +125,7 @@ def free(self, modules: list[Module]) -> None: self._modules.pop(mid, None) def clear(self) -> None: - # Required after any full reload: cached module ``id``s are stale. + # Required after any full reload: cached module ids are stale. 
self._snapshots.clear() self._modules.clear() self.needs_reload = False @@ -155,8 +157,7 @@ def try_snapshot(self, modules: list[Module]) -> bool: return True # Tally size per source device so we make a single accept/reject - # decision per device (multiple weights on the same GPU could each - # individually "fit" while collectively overflowing). + # decision per device. size_per_device: dict[torch.device, int] = {} for module in modules: weight = cast(Tensor, module.weight) @@ -175,7 +176,10 @@ def try_snapshot(self, modules: list[Module]) -> bool: if cpu_required_bytes > 0: available = virtual_memory().available - if available < cpu_required_bytes + _SNAPSHOT_SAFETY_MARGIN_BYTES: + if ( + available + < cpu_required_bytes + _SNAPSHOT_SAFETY_MARGIN_BYTES + ): return False with torch.no_grad(): @@ -206,6 +210,10 @@ def __init__(self, settings: Settings): # can restore them in O(memcpy) instead of reloading from disk. # Only populated when settings.use_ara is True. self._ara_cache = ARAWeightCache() + # Tracks the dtype of the currently loaded model so reset_model() + # can recover even if self.model has been nulled out by an + # interrupted previous reset (e.g. Ctrl+C during disk reload). + self._dtype: torch.dtype | None = None print() print(f"Loading model [bold]{settings.model}[/]...") @@ -287,6 +295,10 @@ def __init__(self, settings: Settings): if self.model is None: raise Exception("Failed to load model with all configured dtypes.") + # Remember the dtype that worked, so reset_model() can recover + # the disk-reload path even if self.model has been nulled out. + self._dtype = self.model.dtype + if not settings.use_ara: self._apply_lora() @@ -428,14 +440,31 @@ def reset_model(self): Resets the model to a clean state for the next trial or evaluation. Behavior: - - Fast path (non-ARA): If the same model is loaded and doesn't need - full reload, resets LoRA adapter weights to zero (identity). 
- - Fast path (ARA): If the same model is loaded and the snapshot cache - is healthy, restores cached original weights via memcpy. No disk I/O. - - Slow path: If switching models, after merge_and_unload(), or when - the ARA cache reports it can't fully back the next trial, performs - a full model reload with quantization config. + - Recovery: If self.model is None (e.g. a previous reset was + interrupted between purging self.model and from_pretrained), do + a full disk reload using the last known good dtype. + - Fast path (non-ARA): If the same model is loaded and doesn't + need full reload, resets LoRA adapter weights to zero. + - Fast path (ARA): If the same model is loaded and the snapshot + cache is healthy, restores cached original weights via memcpy. + No disk I/O. + - Slow path: If switching models, after merge_and_unload(), or + when the ARA cache reports it can't fully back the next trial, + performs a full model reload with quantization config. """ + # Recovery path: a previous reset_model() may have been + # interrupted (Ctrl+C, OOM) between ``self.model = None`` and + # the from_pretrained call below, leaving us with no live model. + # Reload using the dtype we recorded at the last successful load. + if self.model is None: + assert self._dtype is not None, ( + "self.model is None and no dtype recorded — this should " + "be unreachable because __init__ always records a dtype " + "after a successful load." + ) + self._reload_from_disk(self._dtype) + return + current_model = getattr(self.model.config, "name_or_path", None) same_model = current_model == self.settings.model @@ -461,16 +490,19 @@ def reset_model(self): self._ara_cache.restore_all() return - # Slow path: full reload from disk. This is the only path that costs - # real time (tens of seconds to minutes for large models). - print("* Reloading weights from disk...") + # Slow path. 
+ self._reload_from_disk(self.model.dtype) - dtype = self.model.dtype + def _reload_from_disk(self, dtype: torch.dtype) -> None: + """Drop self.model and re-instantiate it from disk with the given + dtype. Used by the slow path of reset_model() and by the recovery + path when self.model has been nulled out.""" + print("* Reloading weights from disk...") # Purge existing model object from memory to make space. self.model = None # ty:ignore[invalid-assignment] - # Module IDs captured in the ARA cache reference the model object we - # just dropped, so the cache must be cleared before any reload. + # Module IDs captured in the ARA cache reference the model object + # we just dropped, so the cache must be cleared before any reload. self._ara_cache.clear() empty_cache() @@ -493,6 +525,7 @@ def reset_model(self): if not self.settings.use_ara: self._apply_lora() + self._dtype = self.model.dtype self.needs_reload = False def get_layers(self) -> ModuleList: @@ -721,8 +754,9 @@ def abliterate( def _collect_ara_target_modules( self, parameters: ARAParameters ) -> list[Module]: - """Enumerate every module that ara_abliterate will mutate for these - parameters. Used by the snapshot cache planner.""" + """Enumerate every module that ara_abliterate will mutate for + these parameters. Used by the snapshot cache planner. + """ target_modules: list[Module] = [] for layer_index in range( parameters.start_layer_index, @@ -736,11 +770,11 @@ def _prepare_ara_cache(self, target_modules: list[Module]) -> None: """Diff the snapshot cache against this trial's targets: 1. Drop snapshots for modules we won't touch (free memory). - 2. Snapshot newly-targeted modules (which currently hold originals, - because reset_model() either restored them from the cache or - freshly loaded them from disk). + 2. Snapshot newly-targeted modules (which currently hold + originals, because reset_model() either restored them from + the cache or freshly loaded them from disk). 
- On insufficient memory, prints a warning and sets ``needs_reload`` + On insufficient memory, prints a warning and sets needs_reload so the next reset_model() falls back to a disk reload. """ cache = self._ara_cache @@ -762,7 +796,7 @@ def _prepare_ara_cache(self, target_modules: list[Module]) -> None: size_gb = cache.estimate_total_size(to_add) / (1024**3) if cache.try_snapshot(to_add): - # Only mention the cache when it does meaningful work, to avoid + # Only print when the cache does meaningful work, to avoid # noise on small models where the snapshot is essentially free. if size_gb >= 0.1: print( @@ -772,8 +806,8 @@ def _prepare_ara_cache(self, target_modules: list[Module]) -> None: return # Could not fit the planned snapshot on any tier. Run the trial - # anyway, but flag the cache so the next reset_model() reloads from - # disk to recover the originals we're about to overwrite. + # anyway, but flag the cache so the next reset_model() reloads + # from disk to recover the originals we're about to overwrite. print( f"* [yellow]Not enough RAM to snapshot all targeted layers " f"(would need ~[bold]{size_gb:.2f} GB[/]). " @@ -788,9 +822,12 @@ def ara_abliterate( parameters: ARAParameters, ): # Plan/refresh the snapshot cache BEFORE any mutation, so the - # invariant "cache holds originals" is preserved. reset_model() must - # have been called just before this, leaving every module pristine. - self._prepare_ara_cache(self._collect_ara_target_modules(parameters)) + # invariant 'cache holds originals' is preserved. reset_model() + # must have been called just before this, leaving every module + # pristine. + self._prepare_ara_cache( + self._collect_ara_target_modules(parameters) + ) for layer_index in range( parameters.start_layer_index,