Skip to content

Commit f2f1c46

Browse files
committed
updating loading in head detector demo to use transformer bridge
1 parent 4ece3c4 commit f2f1c46

File tree

2 files changed: +11 additions, −10 deletions

.github/workflows/checks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ jobs:
151151
- "BERT"
152152
- "Exploratory_Analysis_Demo"
153153
# - "Grokking_Demo"
154-
# - "Head_Detector_Demo"
154+
- "Head_Detector_Demo"
155155
- "Interactive_Neuroscope"
156156
# - "LLaMA"
157157
# - "LLaMA2_GPU_Quantized"

demos/Head_Detector_Demo.ipynb

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@
347347
},
348348
{
349349
"cell_type": "code",
350-
"execution_count": 3,
350+
"execution_count": null,
351351
"metadata": {
352352
"id": "ScWILAgIGt5O"
353353
},
@@ -359,7 +359,8 @@
359359
"from tqdm import tqdm\n",
360360
"\n",
361361
"import transformer_lens\n",
362-
"from transformer_lens import HookedTransformer, ActivationCache\n",
362+
"from transformer_lens import ActivationCache\n",
363+
"from transformer_lens.model_bridge import TransformerBridge\n",
363364
"from neel_plotly import line, imshow, scatter"
364365
]
365366
},
@@ -479,7 +480,7 @@
479480
},
480481
{
481482
"cell_type": "code",
482-
"execution_count": 7,
483+
"execution_count": null,
483484
"metadata": {
484485
"id": "5ikyL8-S7u2Z"
485486
},
@@ -493,7 +494,6 @@
493494
"import numpy as np\n",
494495
"import torch\n",
495496
"\n",
496-
"from transformer_lens import HookedTransformer, ActivationCache\n",
497497
"# from transformer_lens.utils import is_lower_triangular, is_square\n",
498498
"\n",
499499
"HeadName = Literal[\"previous_token_head\", \"duplicate_token_head\", \"induction_head\"]\n",
@@ -515,7 +515,7 @@
515515
"\n",
516516
"\n",
517517
"def detect_head(\n",
518-
" model: HookedTransformer,\n",
518+
" model: TransformerBridge,\n",
519519
" seq: Union[str, List[str]],\n",
520520
" detection_pattern: Union[torch.Tensor, HeadName],\n",
521521
" heads: Optional[Union[List[LayerHeadTuple], LayerToHead]] = None,\n",
@@ -566,14 +566,15 @@
566566
" --------\n",
567567
" .. code-block:: python\n",
568568
"\n",
569-
" >>> from transformer_lens import HookedTransformer, utils\n",
569+
" >>> from transformer_lens import utils\n",
570+
" >>> from transformer_lens.model_bridge import TransformerBridge\n",
570571
" >>> from transformer_lens.head_detector import detect_head\n",
571572
" >>> import plotly.express as px\n",
572573
"\n",
573574
" >>> def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
574575
" >>> px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n",
575576
"\n",
576-
" >>> model = HookedTransformer.from_pretrained(\"gpt2-small\")\n",
577+
" >>> model = TransformerBridge.boot_transformers(\"gpt2\")\n",
577578
" >>> sequence = \"This is a test sequence. This is a test sequence.\"\n",
578579
"\n",
579580
" >>> attention_score = detect_head(model, sequence, \"previous_token_head\")\n",
@@ -777,7 +778,7 @@
777778
},
778779
{
779780
"cell_type": "code",
780-
"execution_count": 8,
781+
"execution_count": null,
781782
"metadata": {
782783
"colab": {
783784
"base_uri": "https://localhost:8080/"
@@ -802,7 +803,7 @@
802803
}
803804
],
804805
"source": [
805-
"model = HookedTransformer.from_pretrained(\"gpt2-small\", device=device)"
806+
"model = TransformerBridge.boot_transformers(\"gpt2\", device=device)"
806807
]
807808
},
808809
{

0 commit comments

Comments (0)