Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/heretic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ def objective(trial: Trial) -> tuple[float, float]:
for name, value in get_trial_parameters(settings, trial).items():
print(f" * {name} = [bold]{value}[/]")
if settings.use_ara:
print("* Reloading model...")
print("* Resetting model...")
model.reset_model()
print("* Abliterating (Arbitrary-Rank Ablation)...")
model.ara_abliterate(good_module_io, bad_module_io, ara_parameters)
Expand Down Expand Up @@ -811,7 +811,7 @@ def count_completed_trials() -> int:
for name, value in get_trial_parameters(settings, trial).items():
print(f" * {name} = [bold]{value}[/]")
if settings.use_ara:
print("* Reloading model...")
print("* Resetting model...")
model.reset_model()
print("* Abliterating (Arbitrary-Rank Ablation)...")
model.ara_abliterate(
Expand Down
268 changes: 261 additions & 7 deletions src/heretic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch.nn.functional as F
from peft import LoraConfig, PeftModel, get_peft_model
from peft.tuners.lora.layer import Linear
from psutil import virtual_memory
from torch import FloatTensor, LongTensor, Tensor
from torch.nn import Module, ModuleList
from torch.optim import LBFGS
Expand Down Expand Up @@ -71,6 +72,131 @@ class ARAParameters:
ModuleIO: TypeAlias = list[dict[str, dict[int, tuple[Tensor, Tensor]]]]


# Headroom (in bytes) that must remain free on a device (GPU or system RAM)
# after a planned snapshot is placed there, so that we don't starve the LBFGS
# optimizer or the inference path.
_SNAPSHOT_SAFETY_MARGIN_BYTES = 2 * 1024**3 # 2 GiB


class ARAWeightCache:
    """Cache the original weights of ARA-touched modules between trials.

    Invariant: every entry stores the pristine, original weight of its
    module. To preserve this, the caller must:
      1. Restore from the cache before any new optimization (model is clean).
      2. Drop entries for modules not used on the next trial (frees memory).
      3. Snapshot newly-targeted modules BEFORE mutating them.

    Storage tier is chosen per source device: snapshots stay on the source
    GPU when there is room (intra-device copy on restore, fastest), otherwise
    they spill to system RAM. If neither tier can host the snapshot,
    try_snapshot() returns False and the caller is expected to set
    needs_reload so the next reset_model() falls back to a disk reload.
    """

    def __init__(self):
        # Both dicts are keyed by id(module). _modules holds a strong
        # reference to each module, so an id cannot be recycled for a
        # different object while its entry exists.
        self._snapshots: dict[int, Tensor] = {}
        self._modules: dict[int, Module] = {}
        # Set by the caller when a trial ran without a complete snapshot.
        self.needs_reload = False

    def __len__(self) -> int:
        """Return the number of modules that currently have a snapshot."""
        return len(self._snapshots)

    def __contains__(self, module: Module) -> bool:
        """Return True if this exact module object has a snapshot."""
        return id(module) in self._snapshots

    def cached_modules(self) -> list[Module]:
        """Return the modules that currently have a snapshot."""
        return list(self._modules.values())

    def restore_all(self) -> None:
        """Copy every cached snapshot back into its module's weight, in place."""
        with torch.no_grad():
            for mid, snapshot in self._snapshots.items():
                target = cast(Tensor, self._modules[mid].weight)
                # copy_() performs cross-device transfers itself, so a
                # snapshot held in system RAM is written straight into the
                # destination without first allocating a temporary copy on
                # the target device.
                target.data.copy_(snapshot)

    def free(self, modules: list[Module]) -> None:
        """Drop the snapshots of the given modules; unknown modules are ignored."""
        for module in modules:
            mid = id(module)
            self._snapshots.pop(mid, None)
            self._modules.pop(mid, None)

    def clear(self) -> None:
        """Drop every snapshot and reset needs_reload."""
        # Required after any full reload: cached module ids are stale.
        self._snapshots.clear()
        self._modules.clear()
        self.needs_reload = False

    @staticmethod
    def _module_size_bytes(module: Module) -> int:
        """Return the byte size of the module's weight tensor."""
        weight = cast(Tensor, module.weight)
        return weight.numel() * weight.element_size()

    @classmethod
    def estimate_total_size(cls, modules: list[Module]) -> int:
        """Return the summed weight size, in bytes, of the given modules."""
        return sum(cls._module_size_bytes(m) for m in modules)

    @staticmethod
    def _gpu_has_room(device: torch.device, size_bytes: int) -> bool:
        """Return True if a CUDA device can take size_bytes plus the margin.

        Non-CUDA devices always yield False.
        """
        if device.type != "cuda":
            return False
        try:
            free, _ = torch.cuda.mem_get_info(device)
        except Exception:
            # Deliberately best-effort: if the driver cannot be queried we
            # simply treat the device as full and let the caller spill to
            # system RAM instead.
            return False
        return free >= size_bytes + _SNAPSHOT_SAFETY_MARGIN_BYTES

    def try_snapshot(self, modules: list[Module]) -> bool:
        """Snapshot the weights of the given modules, all or nothing.

        Either every module ends up snapshotted (returns True) or the cache
        is left exactly as it was (returns False). Partial snapshots are
        never committed, to keep the cache invariant clean.

        Modules that already have a snapshot are skipped rather than
        re-copied: their current weights may already have been mutated, and
        overwriting the stored tensor would break the 'cache holds
        originals' invariant. Duplicate entries are only counted and copied
        once.
        """
        pending: dict[int, Module] = {}
        for module in modules:
            mid = id(module)
            if mid not in self._snapshots:
                pending[mid] = module
        if not pending:
            return True

        # Tally size per source device so we make a single accept/reject
        # decision per device.
        size_per_device: dict[torch.device, int] = {}
        for module in pending.values():
            device = cast(Tensor, module.weight).device
            size_per_device[device] = (
                size_per_device.get(device, 0)
                + self._module_size_bytes(module)
            )

        keep_on_source_device: set[torch.device] = set()
        cpu_required_bytes = 0
        for device, size in size_per_device.items():
            if self._gpu_has_room(device, size):
                keep_on_source_device.add(device)
            else:
                cpu_required_bytes += size

        if cpu_required_bytes > 0:
            available = virtual_memory().available
            if available < cpu_required_bytes + _SNAPSHOT_SAFETY_MARGIN_BYTES:
                return False

        # Copy into a staging dict first and commit only once every copy has
        # succeeded. If anything raises half-way through, the staged tensors
        # are simply discarded and the cache is untouched.
        staged: dict[int, Tensor] = {}
        try:
            with torch.no_grad():
                for mid, module in pending.items():
                    weight = cast(Tensor, module.weight)
                    target_device = (
                        weight.device
                        if weight.device in keep_on_source_device
                        else torch.device("cpu")
                    )
                    staged[mid] = weight.detach().to(target_device, copy=True)
        except torch.cuda.OutOfMemoryError:
            # The free-memory pre-check is only an estimate (fragmentation,
            # concurrent allocations). Report failure so the caller can fall
            # back to a disk reload instead of crashing the trial.
            staged.clear()
            return False

        self._snapshots.update(staged)
        self._modules.update(pending)
        return True


class Model:
model: PreTrainedModel | PeftModel
tokenizer: PreTrainedTokenizerBase
Expand All @@ -80,6 +206,14 @@ def __init__(self, settings: Settings):
self.settings = settings
self.response_prefix = ""
self.needs_reload = False
# Caches original weights of ARA-touched modules so reset_model()
# can restore them in O(memcpy) instead of reloading from disk.
# Only populated when settings.use_ara is True.
self._ara_cache = ARAWeightCache()
# Tracks the dtype of the currently loaded model so reset_model()
# can recover even if self.model has been nulled out by an
# interrupted previous reset (e.g. Ctrl+C during disk reload).
self._dtype: torch.dtype | None = None

print()
print(f"Loading model [bold]{settings.model}[/]...")
Expand Down Expand Up @@ -161,6 +295,10 @@ def __init__(self, settings: Settings):
if self.model is None:
raise Exception("Failed to load model with all configured dtypes.")

# Remember the dtype that worked, so reset_model() can recover
# the disk-reload path even if self.model has been nulled out.
self._dtype = self.model.dtype

if not settings.use_ara:
self._apply_lora()

Expand Down Expand Up @@ -302,27 +440,70 @@ def reset_model(self):
Resets the model to a clean state for the next trial or evaluation.

Behavior:
- Fast path: If the same model is loaded and doesn't need full reload,
resets LoRA adapter weights to zero (identity transformation).
- Slow path: If switching models or after merge_and_unload(),
performs full model reload with quantization config.
- Recovery: If self.model is None (e.g. a previous reset was
interrupted between purging self.model and from_pretrained), do
a full disk reload using the last known good dtype.
- Fast path (non-ARA): If the same model is loaded and doesn't
need full reload, resets LoRA adapter weights to zero.
- Fast path (ARA): If the same model is loaded and the snapshot
cache is healthy, restores cached original weights via memcpy.
No disk I/O.
- Slow path: If switching models, after merge_and_unload(), or
when the ARA cache reports it can't fully back the next trial,
performs a full model reload with quantization config.
"""
# Recovery path: a previous reset_model() may have been
# interrupted (Ctrl+C, OOM) between ``self.model = None`` and
# the from_pretrained call below, leaving us with no live model.
# Reload using the dtype we recorded at the last successful load.
if self.model is None:
assert self._dtype is not None, (
"self.model is None and no dtype recorded — this should "
"be unreachable because __init__ always records a dtype "
"after a successful load."
)
self._reload_from_disk(self._dtype)
return

current_model = getattr(self.model.config, "name_or_path", None)
Comment thread
ghecko marked this conversation as resolved.
same_model = current_model == self.settings.model

# Fast path (non-ARA): zero out LoRA adapters.
if (
current_model == self.settings.model
same_model
and not self.needs_reload
and not self.settings.use_ara
):
# Reset LoRA adapters to zero (identity transformation)
for name, module in self.model.named_modules():
if "lora_B" in name and hasattr(module, "weight"):
torch.nn.init.zeros_(module.weight)
return

dtype = self.model.dtype
# Fast path (ARA): restore from in-memory snapshot cache.
if (
same_model
and not self.needs_reload
and self.settings.use_ara
and not self._ara_cache.needs_reload
):
if len(self._ara_cache) > 0:
self._ara_cache.restore_all()
return

# Slow path.
self._reload_from_disk(self.model.dtype)

def _reload_from_disk(self, dtype: torch.dtype) -> None:
"""Drop self.model and re-instantiate it from disk with the given
dtype. Used by the slow path of reset_model() and by the recovery
path when self.model has been nulled out."""
print("* Reloading weights from disk...")

# Purge existing model object from memory to make space.
self.model = None # ty:ignore[invalid-assignment]
# Module IDs captured in the ARA cache reference the model object
# we just dropped, so the cache must be cleared before any reload.
self._ara_cache.clear()
empty_cache()

quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])
Expand All @@ -344,6 +525,7 @@ def reset_model(self):
if not self.settings.use_ara:
self._apply_lora()

self._dtype = self.model.dtype
self.needs_reload = False

def get_layers(self) -> ModuleList:
Expand Down Expand Up @@ -569,12 +751,84 @@ def abliterate(
weight_A.data = lora_A.to(weight_A.dtype)
weight_B.data = lora_B.to(weight_B.dtype)

def _collect_ara_target_modules(
    self, parameters: ARAParameters
) -> list[Module]:
    """List the modules in the layer range selected by these parameters.

    Walks the half-open layer range
    [parameters.start_layer_index, parameters.end_layer_index) and flattens
    every module group returned by get_layer_modules() for each layer into
    one list, in layer order. The snapshot cache planner uses the result to
    decide which weights to snapshot before ara_abliterate mutates them.
    """
    layer_indices = range(
        parameters.start_layer_index,
        parameters.end_layer_index,
    )
    return [
        module
        for layer_index in layer_indices
        for module_group in self.get_layer_modules(layer_index).values()
        for module in module_group
    ]

def _prepare_ara_cache(self, target_modules: list[Module]) -> None:
    """Reconcile the snapshot cache with this trial's target modules.

    1. Drop snapshots for modules we won't touch (frees memory).
    2. Snapshot newly-targeted modules (which currently hold originals,
       because reset_model() either restored them from the cache or
       freshly loaded them from disk).

    On insufficient memory, prints a warning, releases whatever the cache
    still holds and sets needs_reload so the next reset_model() falls back
    to a disk reload.
    """
    cache = self._ara_cache
    target_ids = {id(m) for m in target_modules}

    # Step 1: drop entries no longer needed.
    stale = [m for m in cache.cached_modules() if id(m) not in target_ids]
    if stale:
        cache.free(stale)
        empty_cache()

    # Step 2: snapshot newly-targeted modules. De-duplicate by identity so
    # a module listed more than once is only sized and reported once.
    to_add = list(
        {id(m): m for m in target_modules if m not in cache}.values()
    )
    if not to_add:
        return

    size_gb = cache.estimate_total_size(to_add) / (1024**3)

    if cache.try_snapshot(to_add):
        # Only print when the cache does meaningful work, to avoid
        # noise on small models where the snapshot is essentially free.
        if size_gb >= 0.1:
            print(
                f"* Snapshotted [bold]{len(to_add)}[/] module(s) "
                f"([bold]{size_gb:.2f} GB[/]) for fast reset"
            )
        return

    # Could not fit the planned snapshot on any tier. Run the trial
    # anyway, but flag the cache so the next reset_model() reloads
    # from disk to recover the originals we're about to overwrite.
    print(
        f"* [yellow]Not enough RAM to snapshot all targeted layers "
        f"(would need ~[bold]{size_gb:.2f} GB[/]). "
        f"Weights will be reloaded from disk after this iteration...[/]"
    )
    # With needs_reload set, reset_model() skips the restore path and the
    # disk reload clears the cache, so the snapshots still held here will
    # never be read. Release them now: memory is evidently tight and the
    # trial about to run can use it. clear() resets needs_reload, so the
    # flag must be raised afterwards.
    cache.clear()
    empty_cache()
    cache.needs_reload = True

def ara_abliterate(
self,
good_module_io: ModuleIO,
bad_module_io: ModuleIO,
parameters: ARAParameters,
):
# Plan/refresh the snapshot cache BEFORE any mutation, so the
# invariant 'cache holds originals' is preserved. reset_model()
# must have been called just before this, leaving every module
# pristine.
self._prepare_ara_cache(
self._collect_ara_target_modules(parameters)
)

for layer_index in range(
parameters.start_layer_index,
parameters.end_layer_index,
Expand Down