Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/heretic/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ class Settings(BaseSettings):
exclude=True,
)

reproduce: str | None = Field(
Comment thread
p-e-w marked this conversation as resolved.
default=None,
description=(
"If this path or URL to a reproduce.json file is set, load reproduction information "
"from that file, and attempt to reproduce the abliterated model it originated from."
),
exclude=True,
)

dtypes: list[str] = Field(
default=[
# In practice, "auto" almost always means bfloat16.
Expand Down Expand Up @@ -161,13 +170,6 @@ class Settings(BaseSettings):
),
)

trust_remote_code: bool | None = Field(
default=None,
description="Whether to trust remote code when loading the model.",
# For security reasons, we don't store this setting.
exclude=True,
)

batch_size: int = Field(
default=0, # auto
description="Number of input sequences to process in parallel (0 = auto).",
Expand Down
34 changes: 31 additions & 3 deletions src/heretic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def _is_help_invocation() -> bool:
from .config import QuantizationMethod
from .evaluator import Evaluator
from .model import AbliterationParameters, Model, get_model_class
from .reproduce import collect_reproducibles
from .reproduce import (
check_environment,
collect_reproducibles,
load_reproduction_information,
)
from .system import empty_cache, get_accelerator_info
from .utils import (
format_duration,
Expand Down Expand Up @@ -113,7 +117,9 @@ def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
settings.model,
device_map="meta",
torch_dtype=torch.bfloat16,
trust_remote_code=model.trusted_models.get(settings.model),
trust_remote_code=True
if settings.model in model.trusted_models
else None,
**model.revision_kwargs,
)
footprint_bytes = meta_model.get_memory_footprint()
Expand Down Expand Up @@ -175,6 +181,7 @@ def run():
len(sys.argv) > 1
# Heretic is being invoked in standard (model processing) mode.
and "--collect-reproducibles" not in sys.argv
and "--reproduce" not in sys.argv
# No model has been explicitly provided.
and "--model" not in sys.argv
# The last argument is a parameter value rather than a flag (such as "--help").
Expand All @@ -185,7 +192,9 @@ def run():

# Work around the "model" argument being required
# when Heretic is invoked in a non-processing mode.
if "--collect-reproducibles" in sys.argv and "--model" not in sys.argv:
if (
"--collect-reproducibles" in sys.argv or "--reproduce" in sys.argv
) and "--model" not in sys.argv:
sys.argv.extend(["--model", ""])

try:
Expand All @@ -208,6 +217,25 @@ def run():
collect_reproducibles(settings.collect_reproducibles)
return

if settings.reproduce is not None:
print(f"Loading reproduction information from [bold]{settings.reproduce}[/]...")
# FIXME: "Reproduction"/"reproducibility" name inconsistency!
reproduction_information = load_reproduction_information(settings.reproduce)

if reproduction_information["version"] not in ["1"]:
print(
(
f"[red]Unsupported file format version: [bold]{reproduction_information['version']}[/].[/] "
"Try loading the file with a newer version of Heretic."
)
)
return

if not check_environment(reproduction_information):
return

return

if settings.seed is None:
settings.seed = random.randint(0, 2**32 - 1)

Expand Down
23 changes: 13 additions & 10 deletions src/heretic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def __init__(self, settings: Settings):

self.tokenizer = AutoTokenizer.from_pretrained(
settings.model,
trust_remote_code=settings.trust_remote_code,
**self.revision_kwargs,
)

Expand All @@ -90,10 +89,8 @@ def __init__(self, settings: Settings):
if settings.max_memory
else None
)
self.trusted_models = {settings.model: settings.trust_remote_code}

if self.settings.evaluate_model is not None:
self.trusted_models[settings.evaluate_model] = settings.trust_remote_code
self.trusted_models = set()

for dtype in settings.dtypes:
print(f"* Trying dtype [bold]{dtype}[/]...")
Expand All @@ -112,15 +109,17 @@ def __init__(self, settings: Settings):
dtype=dtype,
device_map=settings.device_map,
max_memory=self.max_memory,
trust_remote_code=self.trusted_models.get(settings.model),
trust_remote_code=True
if settings.model in self.trusted_models
else None,
**self.revision_kwargs,
**extra_kwargs,
)

# If we reach this point and the model requires trust_remote_code,
# either the user accepted, or settings.trust_remote_code is True.
if self.trusted_models.get(settings.model) is None:
self.trusted_models[settings.model] = True
# the user must have agreed when prompted to execute remote code,
# because from_pretrained raises an exception otherwise.
self.trusted_models.add(settings.model)

# A test run can reveal dtype-related problems such as the infamous
# "RuntimeError: probability tensor contains either `inf`, `nan` or element < 0"
Expand Down Expand Up @@ -264,7 +263,9 @@ def get_merged_model(self) -> PreTrainedModel:
self.settings.model,
torch_dtype=self.model.dtype,
device_map="cpu",
trust_remote_code=self.trusted_models.get(self.settings.model),
trust_remote_code=True
if self.settings.model in self.trusted_models
else None,
**self.revision_kwargs,
)

Expand Down Expand Up @@ -326,7 +327,9 @@ def reset_model(self):
dtype=dtype,
device_map=self.settings.device_map,
max_memory=self.max_memory,
trust_remote_code=self.trusted_models.get(self.settings.model),
trust_remote_code=True
if self.settings.model in self.trusted_models
else None,
**self.revision_kwargs,
**extra_kwargs,
)
Expand Down
Loading