diff --git a/src/heretic/main.py b/src/heretic/main.py index 0abd11d7..b0ea017d 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -623,7 +623,7 @@ def objective(trial: Trial) -> tuple[float, float]: for name, value in get_trial_parameters(settings, trial).items(): print(f" * {name} = [bold]{value}[/]") if settings.use_ara: - print("* Reloading model...") + print("* Resetting model...") model.reset_model() print("* Abliterating (Arbitrary-Rank Ablation)...") model.ara_abliterate(good_module_io, bad_module_io, ara_parameters) @@ -811,7 +811,7 @@ def count_completed_trials() -> int: for name, value in get_trial_parameters(settings, trial).items(): print(f" * {name} = [bold]{value}[/]") if settings.use_ara: - print("* Reloading model...") + print("* Resetting model...") model.reset_model() print("* Abliterating (Arbitrary-Rank Ablation)...") model.ara_abliterate( diff --git a/src/heretic/model.py b/src/heretic/model.py index 108e9dfd..fd03bc7b 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -12,6 +12,7 @@ import torch.nn.functional as F from peft import LoraConfig, PeftModel, get_peft_model from peft.tuners.lora.layer import Linear +from psutil import virtual_memory from torch import FloatTensor, LongTensor, Tensor from torch.nn import Module, ModuleList from torch.optim import LBFGS @@ -71,6 +72,131 @@ class ARAParameters: ModuleIO: TypeAlias = list[dict[str, dict[int, tuple[Tensor, Tensor]]]] +# Headroom (in bytes) we keep free on any device when planning a snapshot, +# so that we don't starve the LBFGS optimizer or the inference path. +_SNAPSHOT_SAFETY_MARGIN_BYTES = 2 * 1024**3 # 2 GB + + +class ARAWeightCache: + """Caches original weights of ARA-touched modules between trials. + + Holds the invariant: every entry stores the pristine, original weight + of its module. To preserve this, the caller must: + 1. Restore from cache before any new optimization (model is clean). + 2. Drop entries for modules not used on the next trial (free RAM). + 3. Snapshot newly-targeted modules BEFORE mutating them. + + Storage tier is per source device: snapshots stay on the source GPU + when there is room (intra-device copy on restore, fastest), otherwise + spill to system RAM. If neither tier can host the snapshot, + try_snapshot returns False and the caller is expected to set + needs_reload so the next reset_model() falls back to a disk reload. + """ + + def __init__(self): + self._snapshots: dict[int, Tensor] = {} + self._modules: dict[int, Module] = {} + self.needs_reload = False + + def __len__(self) -> int: + return len(self._snapshots) + + def __contains__(self, module: Module) -> bool: + return id(module) in self._snapshots + + def cached_modules(self) -> list[Module]: + return list(self._modules.values()) + + def restore_all(self) -> None: + with torch.no_grad(): + for mid, snapshot in self._snapshots.items(): + target = cast(Tensor, self._modules[mid].weight) + if snapshot.device == target.device: + target.data.copy_(snapshot) + else: + target.data.copy_( + snapshot.to(target.device, non_blocking=True) + ) + + def free(self, modules: list[Module]) -> None: + for module in modules: + mid = id(module) + self._snapshots.pop(mid, None) + self._modules.pop(mid, None) + + def clear(self) -> None: + # Required after any full reload: cached module ids are stale. + self._snapshots.clear() + self._modules.clear() + self.needs_reload = False + + @staticmethod + def _module_size_bytes(module: Module) -> int: + weight = cast(Tensor, module.weight) + return weight.numel() * weight.element_size() + + @classmethod + def estimate_total_size(cls, modules: list[Module]) -> int: + return sum(cls._module_size_bytes(m) for m in modules) + + @staticmethod + def _gpu_has_room(device: torch.device, size_bytes: int) -> bool: + if device.type != "cuda": + return False + try: + free, _ = torch.cuda.mem_get_info(device) + except Exception: + return False + return free >= size_bytes + _SNAPSHOT_SAFETY_MARGIN_BYTES + + def try_snapshot(self, modules: list[Module]) -> bool: + """Either every module is snapshotted (returns True) or none is + (returns False). Partial snapshots are never taken, to keep the + cache invariant clean.""" + if not modules: + return True + + # Tally size per source device so we make a single accept/reject + # decision per device. + size_per_device: dict[torch.device, int] = {} + for module in modules: + weight = cast(Tensor, module.weight) + size_per_device[weight.device] = ( + size_per_device.get(weight.device, 0) + + self._module_size_bytes(module) + ) + + keep_on_source_device: set[torch.device] = set() + cpu_required_bytes = 0 + for device, size in size_per_device.items(): + if device.type == "cuda" and self._gpu_has_room(device, size): + keep_on_source_device.add(device) + else: + cpu_required_bytes += size + + if cpu_required_bytes > 0: + available = virtual_memory().available + if ( + available + < cpu_required_bytes + _SNAPSHOT_SAFETY_MARGIN_BYTES + ): + return False + + with torch.no_grad(): + for module in modules: + weight = cast(Tensor, module.weight) + target_device = ( + weight.device + if weight.device in keep_on_source_device + else torch.device("cpu") + ) + snapshot = weight.detach().to(target_device, copy=True) + self._snapshots[id(module)] = snapshot + self._modules[id(module)] = module + + return True + + class Model: model: PreTrainedModel | PeftModel tokenizer: PreTrainedTokenizerBase @@ -80,6 +206,14 @@ def __init__(self, settings: Settings): self.settings = settings self.response_prefix = "" self.needs_reload = False + # Caches original weights of ARA-touched modules so reset_model() + # can restore them in O(memcpy) instead of reloading from disk. + # Only populated when settings.use_ara is True. + self._ara_cache = ARAWeightCache() + # Tracks the dtype of the currently loaded model so reset_model() + # can recover even if self.model has been nulled out by an + # interrupted previous reset (e.g. Ctrl+C during disk reload). + self._dtype: torch.dtype | None = None print() print(f"Loading model [bold]{settings.model}[/]...") @@ -161,6 +295,10 @@ def __init__(self, settings: Settings): if self.model is None: raise Exception("Failed to load model with all configured dtypes.") + # Remember the dtype that worked, so reset_model() can recover + # the disk-reload path even if self.model has been nulled out. + self._dtype = self.model.dtype + if not settings.use_ara: self._apply_lora() @@ -302,27 +440,70 @@ def reset_model(self): Resets the model to a clean state for the next trial or evaluation. Behavior: - - Fast path: If the same model is loaded and doesn't need full reload, - resets LoRA adapter weights to zero (identity transformation). - - Slow path: If switching models or after merge_and_unload(), - performs full model reload with quantization config. + - Recovery: If self.model is None (e.g. a previous reset was + interrupted between purging self.model and from_pretrained), do + a full disk reload using the last known good dtype. + - Fast path (non-ARA): If the same model is loaded and doesn't + need full reload, resets LoRA adapter weights to zero. + - Fast path (ARA): If the same model is loaded and the snapshot + cache is healthy, restores cached original weights via memcpy. + No disk I/O. + - Slow path: If switching models, after merge_and_unload(), or + when the ARA cache reports it can't fully back the next trial, + performs a full model reload with quantization config. """ + # Recovery path: a previous reset_model() may have been + # interrupted (Ctrl+C, OOM) between ``self.model = None`` and + # the from_pretrained call below, leaving us with no live model. + # Reload using the dtype we recorded at the last successful load. + if self.model is None: + assert self._dtype is not None, ( + "self.model is None and no dtype recorded — this should " + "be unreachable because __init__ always records a dtype " + "after a successful load." + ) + self._reload_from_disk(self._dtype) + return + current_model = getattr(self.model.config, "name_or_path", None) + same_model = current_model == self.settings.model + + # Fast path (non-ARA): zero out LoRA adapters. if ( - current_model == self.settings.model + same_model and not self.needs_reload and not self.settings.use_ara ): - # Reset LoRA adapters to zero (identity transformation) for name, module in self.model.named_modules(): if "lora_B" in name and hasattr(module, "weight"): torch.nn.init.zeros_(module.weight) return - dtype = self.model.dtype + # Fast path (ARA): restore from in-memory snapshot cache. + if ( + same_model + and not self.needs_reload + and self.settings.use_ara + and not self._ara_cache.needs_reload + ): + if len(self._ara_cache) > 0: + self._ara_cache.restore_all() + return + + # Slow path. + self._reload_from_disk(self.model.dtype) + + def _reload_from_disk(self, dtype: torch.dtype) -> None: + """Drop self.model and re-instantiate it from disk with the given + dtype. Used by the slow path of reset_model() and by the recovery + path when self.model has been nulled out.""" + print("* Reloading weights from disk...") # Purge existing model object from memory to make space. self.model = None # ty:ignore[invalid-assignment] + # Module IDs captured in the ARA cache reference the model object + # we just dropped, so the cache must be cleared before any reload. + self._ara_cache.clear() empty_cache() quantization_config = self._get_quantization_config(str(dtype).split(".")[-1]) @@ -344,6 +525,7 @@ def reset_model(self): if not self.settings.use_ara: self._apply_lora() + self._dtype = self.model.dtype self.needs_reload = False def get_layers(self) -> ModuleList: @@ -569,12 +751,84 @@ def abliterate( weight_A.data = lora_A.to(weight_A.dtype) weight_B.data = lora_B.to(weight_B.dtype) + def _collect_ara_target_modules( + self, parameters: ARAParameters + ) -> list[Module]: + """Enumerate every module that ara_abliterate will mutate for + these parameters. Used by the snapshot cache planner. + """ + target_modules: list[Module] = [] + for layer_index in range( + parameters.start_layer_index, + parameters.end_layer_index, + ): + for modules in self.get_layer_modules(layer_index).values(): + target_modules.extend(modules) + return target_modules + + def _prepare_ara_cache(self, target_modules: list[Module]) -> None: + """Diff the snapshot cache against this trial's targets: + + 1. Drop snapshots for modules we won't touch (free memory). + 2. Snapshot newly-targeted modules (which currently hold + originals, because reset_model() either restored them from + the cache or freshly loaded them from disk). + + On insufficient memory, prints a warning and sets needs_reload + so the next reset_model() falls back to a disk reload. + """ + cache = self._ara_cache + target_ids = {id(m) for m in target_modules} + + # Step 1: drop entries no longer needed. + to_drop = [ + m for m in cache.cached_modules() if id(m) not in target_ids + ] + if to_drop: + cache.free(to_drop) + empty_cache() + + # Step 2: snapshot newly-targeted modules. + to_add = [m for m in target_modules if m not in cache] + if not to_add: + return + + size_gb = cache.estimate_total_size(to_add) / (1024**3) + + if cache.try_snapshot(to_add): + # Only print when the cache does meaningful work, to avoid + # noise on small models where the snapshot is essentially free. + if size_gb >= 0.1: + print( + f"* Snapshotted [bold]{len(to_add)}[/] module(s) " + f"([bold]{size_gb:.2f} GB[/]) for fast reset" + ) + return + + # Could not fit the planned snapshot on any tier. Run the trial + # anyway, but flag the cache so the next reset_model() reloads + # from disk to recover the originals we're about to overwrite. + print( + f"* [yellow]Not enough RAM to snapshot all targeted layers " + f"(would need ~[bold]{size_gb:.2f} GB[/]). " + f"Weights will be reloaded from disk after this iteration...[/]" + ) + cache.needs_reload = True + def ara_abliterate( self, good_module_io: ModuleIO, bad_module_io: ModuleIO, parameters: ARAParameters, ): + # Plan/refresh the snapshot cache BEFORE any mutation, so the + # invariant 'cache holds originals' is preserved. reset_model() + # must have been called just before this, leaving every module + # pristine. + self._prepare_ara_cache( + self._collect_ara_target_modules(parameters) + ) + for layer_index in range( parameters.start_layer_index, parameters.end_layer_index,