Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions delphi/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from simple_parsing import ArgumentParser
from torch import Tensor
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
PreTrainedModel,
Expand All @@ -27,7 +27,12 @@
from delphi.latents.neighbours import NeighbourCalculator
from delphi.log.result_analysis import log_results
from delphi.pipeline import Pipe, Pipeline, process_wrapper
from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator
from delphi.scorers import (
DetectionScorer,
FuzzingScorer,
OpenAISimulator,
SurprisalInterventionScorer,
)
from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders
from delphi.utils import assert_type, load_tokenized_data

Expand All @@ -40,7 +45,7 @@ def load_artifacts(run_cfg: RunConfig):
else:
dtype = "auto"

model = AutoModel.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
run_cfg.model,
device_map={"": "cuda"},
quantization_config=(
Expand Down Expand Up @@ -118,6 +123,8 @@ async def process_cache(
hookpoints: list[str],
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
latent_range: Tensor | None,
model,
hookpoint_to_sparse_encode,
):
"""
Converts SAE latent activations in on-disk cache in the `latents_path` directory
Expand Down Expand Up @@ -219,6 +226,12 @@ def none_postprocessor(result):
)
)

def custom_serializer(obj):
    """Fallback encoder handed to ``orjson.dumps`` through ``default=``.

    Tensors are converted to (nested) Python lists with ``Tensor.tolist()``.
    Every other type is rejected with ``TypeError`` so the caller sees an
    encoding failure instead of a silently wrong value.
    """
    # Guard clause: anything that is not a Tensor is unsupported here.
    if not isinstance(obj, Tensor):
        raise TypeError
    return obj.tolist()

# Builds the record from result returned by the pipeline
def scorer_preprocess(result):
if isinstance(result, list):
Expand All @@ -230,11 +243,18 @@ def scorer_preprocess(result):
return record

# Saves the score to a file
def scorer_postprocess(result, score_dir):
# An empty result list is skipped; for a non-empty list only the first entry is written.

def scorer_postprocess(result, score_dir, scorer_name=None):
if isinstance(result, list):
if not result:
return
result = result[0]

safe_latent_name = str(result.record.latent).replace("/", "--")

with open(score_dir / f"{safe_latent_name}.txt", "wb") as f:
f.write(orjson.dumps(result.score))
f.write(orjson.dumps(result.score, default=custom_serializer))

scorers = []
for scorer_name in run_cfg.scorers:
Expand All @@ -257,6 +277,16 @@ def scorer_postprocess(result, score_dir):
verbose=run_cfg.verbose,
log_prob=run_cfg.log_probs,
)

elif scorer_name == "surprisal_intervention":
scorer = SurprisalInterventionScorer(
model,
hookpoint_to_sparse_encode,
hookpoints=run_cfg.hookpoints,
n_examples_shown=run_cfg.num_examples_per_scorer_prompt,
verbose=run_cfg.verbose,
log_prob=run_cfg.log_probs,
)
else:
raise ValueError(f"Scorer {scorer_name} not supported")

Expand Down Expand Up @@ -396,6 +426,8 @@ async def run(
hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg)
tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token)

model.tokenizer = tokenizer

nrh = assert_type(
dict,
non_redundant_hookpoints(
Expand All @@ -412,7 +444,6 @@ async def run(
transcode,
)

del model, hookpoint_to_sparse_encode
if run_cfg.constructor_cfg.non_activating_source == "neighbours":
nrh = assert_type(
list,
Expand Down Expand Up @@ -445,8 +476,12 @@ async def run(
nrh,
tokenizer,
latent_range,
model,
hookpoint_to_sparse_encode,
)

del model, hookpoint_to_sparse_encode

if run_cfg.verbose:
log_results(scores_path, visualize_path, run_cfg.hookpoints, run_cfg.scorers)

Expand Down
10 changes: 3 additions & 7 deletions delphi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,14 @@ class RunConfig(Serializable):
the default single token explainer, and 'none' for no explanation generation."""

scorers: list[str] = list_field(
choices=[
"fuzz",
"detection",
"simulation",
],
choices=["fuzz", "detection", "simulation", "surprisal_intervention"],
default=[
"fuzz",
"detection",
],
)
"""Scorer methods to score latent explanations. Options are 'fuzz', 'detection', and
'simulation'."""
"""Scorer methods to score latent explanations. Options are 'fuzz', 'detection',
'simulation' and 'surprisal_intervention'."""

name: str = ""
"""The name of the run. Results are saved in a directory with this name."""
Expand Down
7 changes: 7 additions & 0 deletions delphi/latents/latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ class LatentRecord:
"""Frequency of the latent. Number of activations in a context per total
number of contexts."""

@property
def feature_id(self) -> int:
    """Index of this record's latent, read from ``self.latent.latent_index``."""
    latent = self.latent
    return latent.latent_index

@property
def max_activation(self) -> float:
"""
Expand Down
Loading