1
+ import shutil
1
2
from pathlib import Path
2
3
from typing import cast
3
4
@@ -79,13 +80,17 @@ def cache_setup(tmp_path_factory, mock_dataset: torch.Tensor, model: PreTrainedM
79
80
hookpoint_to_sparse_encode , _ = load_hooks_sparse_coders (model , run_cfg_gemma )
80
81
# Define cache config and initialize cache
81
82
log_path = Path .cwd () / "results" / "test" / "log"
83
+ shutil .rmtree (log_path , ignore_errors = True )
82
84
log_path .mkdir (parents = True , exist_ok = True )
83
85
84
- cache = LatentCache (
85
- model ,
86
- hookpoint_to_sparse_encode ,
87
- batch_size = cache_cfg .batch_size ,
88
- log_path = log_path ,
86
+ cache , empty_cache = (
87
+ LatentCache (
88
+ model ,
89
+ hookpoint_to_sparse_encode ,
90
+ batch_size = cache_cfg .batch_size ,
91
+ log_path = log_path ,
92
+ )
93
+ for _ in range (2 )
89
94
)
90
95
91
96
# Generate mock tokens and run the cache
@@ -104,60 +109,9 @@ def cache_setup(tmp_path_factory, mock_dataset: torch.Tensor, model: PreTrainedM
104
109
)
105
110
return {
106
111
"cache" : cache ,
112
+ "empty_cache" : empty_cache ,
107
113
"tokens" : tokens ,
108
114
"cache_cfg" : cache_cfg ,
109
115
"temp_dir" : temp_dir ,
110
116
"firing_counts" : hookpoint_firing_counts ,
111
117
}
112
-
113
-
114
def test_hookpoint_firing_counts_initialization(cache_setup):
    """Ensure that hookpoint_firing_counts is initialized as an empty dictionary."""
    firing_counts = cache_setup["cache"].hookpoint_firing_counts
    assert isinstance(firing_counts, dict)
    # run() has not been called on this cache yet, so nothing may be recorded.
    assert not firing_counts
121
-
122
-
123
def test_hookpoint_firing_counts_updates(cache_setup):
    """Ensure that hookpoint_firing_counts is properly updated after running the cache."""
    cache = cache_setup["cache"]
    cache.run(cache_setup["cache_cfg"].n_tokens, cache_setup["tokens"])

    firing_counts = cache.hookpoint_firing_counts
    assert (
        len(firing_counts) > 0
    ), "hookpoint_firing_counts should not be empty after run()"
    # Every recorded entry must be a non-negative 1D tensor of per-latent counts.
    for hookpoint, counts in firing_counts.items():
        assert isinstance(
            counts, torch.Tensor
        ), f"Counts for {hookpoint} should be a torch.Tensor"
        assert counts.ndim == 1, f"Counts for {hookpoint} should be a 1D tensor"
        assert (counts >= 0).all(), f"Counts for {hookpoint} should be non-negative"
140
-
141
-
142
def test_hookpoint_firing_counts_persistence(cache_setup):
    """
    Ensure that hookpoint_firing_counts are correctly saved and loaded.

    Saves the counts via the cache, then reloads the serialized file and
    checks that the keys and the per-hookpoint tensors round-trip exactly.
    """
    cache = cache_setup["cache"]
    cache.save_firing_counts()

    # Must mirror the log_path the fixture passes to the cache
    # (Path.cwd()/"results"/"test"/"log"); the previous hard-coded
    # "results"/"log" path omitted the "test" segment.
    # NOTE(review): assumes save_firing_counts() writes into log_path —
    # confirm against LatentCache.save_firing_counts.
    firing_counts_path = (
        Path.cwd() / "results" / "test" / "log" / "hookpoint_firing_counts.pt"
    )
    assert firing_counts_path.exists(), "Firing counts file should exist after saving"

    # weights_only=True restricts unpickling to tensor data (no arbitrary objects).
    loaded_counts = torch.load(firing_counts_path, weights_only=True)
    assert isinstance(
        loaded_counts, dict
    ), "Loaded firing counts should be a dictionary"
    assert (
        loaded_counts.keys() == cache.hookpoint_firing_counts.keys()
    ), "Loaded firing counts keys should match saved keys"

    for hookpoint, counts in loaded_counts.items():
        assert torch.equal(
            counts, cache.hookpoint_firing_counts[hookpoint]
        ), f"Mismatch in firing counts for {hookpoint}"
0 commit comments