Commit 4fc9ba6

Enable config tests, fix new dataset format, add tests for it (#145)
fix: update token loader due to new datasets version.
1 parent 1af73f4 commit 4fc9ba6

File tree

3 files changed: +114 / -57 lines


delphi/utils.py

Lines changed: 25 additions & 0 deletions
@@ -1,7 +1,9 @@
 from typing import Any, TypeVar, cast
 
+import datasets
 import numpy as np
 import torch
+from datasets.table import table_iter
 from torch import Tensor
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
@@ -14,10 +16,23 @@ def load_tokenized_data(
     dataset_name: str = "",
     column_name: str = "text",
     seed: int = 22,
+    convert_to_tensor_chunk_size: int = 2**18,
 ):
     """
     Load a huggingface dataset, tokenize it, and shuffle.
     Using this function ensures we are using the same tokens everywhere.
+
+    Args:
+        ctx_len: The context length of the tokens.
+        tokenizer: The tokenizer to use.
+        dataset_repo: The repository of the dataset.
+        dataset_split: The split of the dataset.
+        dataset_name: The name of the dataset.
+        column_name: The name of the column to tokenize.
+        seed: The seed to use for shuffling the dataset.
+        convert_to_tensor_chunk_size: The chunk size to use when converting the dataset
+            from Huggingface's Table format to a tensor. Values around 2**17-2**18 seem to
+            be the fastest.
     """
     from datasets import load_dataset
     from sparsify.data import chunk_and_tokenize
@@ -33,6 +48,16 @@ def load_tokenized_data(
 
     tokens = tokens_ds["input_ids"]
 
+    if isinstance(tokens, datasets.Column):
+        tokens = torch.cat(
+            [
+                torch.from_numpy(np.stack(table_chunk["input_ids"].to_numpy(), axis=0))
+                for table_chunk in table_iter(
+                    tokens.source._data, convert_to_tensor_chunk_size
+                )
+            ]
+        )
+
     return tokens
 
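For reference, the new isinstance(tokens, datasets.Column) branch can be illustrated with a minimal, self-contained sketch on a toy pyarrow table. The toy column, the tiny chunk size, and pa_table (a stand-in for the Arrow table behind tokens.source._data) are made up for illustration and are not part of the commit:

    # Hypothetical sketch of the chunked Arrow-to-tensor conversion used above.
    import numpy as np
    import pyarrow as pa
    import torch
    from datasets.table import table_iter

    # Toy stand-in: 10 rows of "input_ids", each a fixed-length list of 4 token ids.
    pa_table = pa.table({"input_ids": [list(range(i, i + 4)) for i in range(10)]})

    chunk_size = 4  # the real default is 2**18; tiny here so the chunking is visible
    tokens = torch.cat(
        [
            # each chunk is a pyarrow.Table; stack its list column into a 2D int array
            torch.from_numpy(np.stack(chunk["input_ids"].to_numpy(), axis=0))
            for chunk in table_iter(pa_table, chunk_size)
        ]
    )
    assert tokens.shape == (10, 4)

The chunked iteration keeps each numpy/torch conversion step bounded; per the docstring, chunk sizes around 2**17-2**18 seem to be the fastest.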
tests/conftest.py

Lines changed: 11 additions & 57 deletions
@@ -1,3 +1,4 @@
+import shutil
 from pathlib import Path
 from typing import cast
 
@@ -79,13 +80,17 @@ def cache_setup(tmp_path_factory, mock_dataset: torch.Tensor, model: PreTrainedM
     hookpoint_to_sparse_encode, _ = load_hooks_sparse_coders(model, run_cfg_gemma)
     # Define cache config and initialize cache
     log_path = Path.cwd() / "results" / "test" / "log"
+    shutil.rmtree(log_path, ignore_errors=True)
     log_path.mkdir(parents=True, exist_ok=True)
 
-    cache = LatentCache(
-        model,
-        hookpoint_to_sparse_encode,
-        batch_size=cache_cfg.batch_size,
-        log_path=log_path,
+    cache, empty_cache = (
+        LatentCache(
+            model,
+            hookpoint_to_sparse_encode,
+            batch_size=cache_cfg.batch_size,
+            log_path=log_path,
+        )
+        for _ in range(2)
     )
 
     # Generate mock tokens and run the cache
@@ -104,60 +109,9 @@ def cache_setup(tmp_path_factory, mock_dataset: torch.Tensor, model: PreTrainedM
     )
     return {
         "cache": cache,
+        "empty_cache": empty_cache,
         "tokens": tokens,
         "cache_cfg": cache_cfg,
         "temp_dir": temp_dir,
         "firing_counts": hookpoint_firing_counts,
     }
-
-
-def test_hookpoint_firing_counts_initialization(cache_setup):
-    """
-    Ensure that hookpoint_firing_counts is initialized as an empty dictionary.
-    """
-    cache = cache_setup["cache"]
-    assert isinstance(cache.hookpoint_firing_counts, dict)
-    assert len(cache.hookpoint_firing_counts) == 0  # Should be empty before run()
-
-
-def test_hookpoint_firing_counts_updates(cache_setup):
-    """
-    Ensure that hookpoint_firing_counts is properly updated after running the cache.
-    """
-    cache = cache_setup["cache"]
-    tokens = cache_setup["tokens"]
-    cache.run(cache_setup["cache_cfg"].n_tokens, tokens)
-
-    assert (
-        len(cache.hookpoint_firing_counts) > 0
-    ), "hookpoint_firing_counts should not be empty after run()"
-    for hookpoint, counts in cache.hookpoint_firing_counts.items():
-        assert isinstance(
-            counts, torch.Tensor
-        ), f"Counts for {hookpoint} should be a torch.Tensor"
-        assert counts.ndim == 1, f"Counts for {hookpoint} should be a 1D tensor"
-        assert (counts >= 0).all(), f"Counts for {hookpoint} should be non-negative"
-
-
-def test_hookpoint_firing_counts_persistence(cache_setup):
-    """
-    Ensure that hookpoint_firing_counts are correctly saved and loaded.
-    """
-    cache = cache_setup["cache"]
-    cache.save_firing_counts()
-
-    firing_counts_path = Path.cwd() / "results" / "log" / "hookpoint_firing_counts.pt"
-    assert firing_counts_path.exists(), "Firing counts file should exist after saving"
-
-    loaded_counts = torch.load(firing_counts_path, weights_only=True)
-    assert isinstance(
-        loaded_counts, dict
-    ), "Loaded firing counts should be a dictionary"
-    assert (
-        loaded_counts.keys() == cache.hookpoint_firing_counts.keys()
-    ), "Loaded firing counts keys should match saved keys"
-
-    for hookpoint, counts in loaded_counts.items():
-        assert torch.equal(
-            counts, cache.hookpoint_firing_counts[hookpoint]
-        ), f"Mismatch in firing counts for {hookpoint}"
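The fixture change above builds two independent LatentCache instances by unpacking a two-element generator expression: cache is run against the mock tokens later in the fixture, while empty_cache is returned untouched for the firing-count tests. A hypothetical stand-in sketch of that pattern (the _Stub class is made up for illustration and is not part of the delphi codebase):

    # Illustration only: _Stub is a placeholder for LatentCache.
    class _Stub:
        def __init__(self, batch_size: int):
            self.batch_size = batch_size

    # Unpacking a 2-element generator constructs two distinct objects with
    # identical arguments, so mutating one never affects the other.
    ran, empty = (_Stub(batch_size=8) for _ in range(2))
    assert ran is not empty
    assert ran.batch_size == empty.batch_size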

tests/test_latents/test_config.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+from pathlib import Path
+
+import torch
+from transformers import AutoTokenizer
+
+from delphi.utils import load_tokenized_data
+
+
+def test_dataset_is_array():
+    tokens = load_tokenized_data(
+        ctx_len=16,
+        tokenizer=AutoTokenizer.from_pretrained("EleutherAI/pythia-70m"),
+        dataset_repo="NeelNanda/pile-10k",
+        dataset_split="train",
+        dataset_name="",
+        column_name="text",
+        seed=42,
+    )
+    assert isinstance(tokens, torch.Tensor)
+    assert tokens.ndim == 2
+    assert tokens.shape[1] == 16
+    assert tokens.dtype in (torch.int64, torch.int32)
+    assert tokens.min() >= 0
+    assert tokens.max() < 50304
+
+
+def test_hookpoint_firing_counts_initialization(cache_setup):
+    """
+    Ensure that hookpoint_firing_counts is initialized as an empty dictionary.
+    """
+    cache = cache_setup["empty_cache"]
+    assert isinstance(cache.hookpoint_firing_counts, dict)
+    assert len(cache.hookpoint_firing_counts) == 0  # Should be empty before run()
+
+
+def test_hookpoint_firing_counts_updates(cache_setup):
+    """
+    Ensure that hookpoint_firing_counts is properly updated after running the cache.
+    """
+    cache = cache_setup["empty_cache"]
+    tokens = cache_setup["tokens"]
+    cache.run(cache_setup["cache_cfg"].n_tokens, tokens)
+
+    assert (
+        len(cache.hookpoint_firing_counts) > 0
+    ), "hookpoint_firing_counts should not be empty after run()"
+    for hookpoint, counts in cache.hookpoint_firing_counts.items():
+        assert isinstance(
+            counts, torch.Tensor
+        ), f"Counts for {hookpoint} should be a torch.Tensor"
+        assert counts.ndim == 1, f"Counts for {hookpoint} should be a 1D tensor"
+        assert (counts >= 0).all(), f"Counts for {hookpoint} should be non-negative"
+
+
+def test_hookpoint_firing_counts_persistence(cache_setup):
+    """
+    Ensure that hookpoint_firing_counts are correctly saved and loaded.
+    """
+    cache = cache_setup["empty_cache"]
+    cache.save_firing_counts()
+
+    firing_counts_path = (
+        Path.cwd() / "results" / "test" / "log" / "hookpoint_firing_counts.pt"
+    )
+    assert firing_counts_path.exists(), "Firing counts file should exist after saving"
+
+    loaded_counts = torch.load(firing_counts_path, weights_only=True)
+    assert isinstance(
+        loaded_counts, dict
+    ), "Loaded firing counts should be a dictionary"
+    assert (
+        loaded_counts.keys() == cache.hookpoint_firing_counts.keys()
+    ), "Loaded firing counts keys should match saved keys"
+
+    for hookpoint, counts in loaded_counts.items():
+        assert torch.equal(
+            counts, cache.hookpoint_firing_counts[hookpoint]
+        ), f"Mismatch in firing counts for {hookpoint}"
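These firing-count tests are the ones removed from tests/conftest.py above; they now read cache_setup["empty_cache"], so the initialization check runs on a cache that has not been run yet, and the persistence test looks for the firing-counts file under results/test/log, matching the fixture's log_path. Assuming the repository's standard pytest setup, the new module can also be run on its own; a minimal sketch using pytest's Python entry point (equivalent to the usual CLI form pytest tests/test_latents/test_config.py):

    # Assumed invocation sketch; pytest.main returns the exit code.
    # Note: test_dataset_is_array downloads the EleutherAI/pythia-70m tokenizer and
    # the NeelNanda/pile-10k dataset, so it needs network access on first run.
    import pytest

    raise SystemExit(pytest.main(["tests/test_latents/test_config.py", "-q"]))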
