|
347 | 347 | }, |
348 | 348 | { |
349 | 349 | "cell_type": "code", |
350 | | - "execution_count": 3, |
| 350 | + "execution_count": null, |
351 | 351 | "metadata": { |
352 | 352 | "id": "ScWILAgIGt5O" |
353 | 353 | }, |
|
359 | 359 | "from tqdm import tqdm\n", |
360 | 360 | "\n", |
361 | 361 | "import transformer_lens\n", |
362 | | - "from transformer_lens import HookedTransformer, ActivationCache\n", |
| 362 | + "from transformer_lens import ActivationCache\n", |
| 363 | + "from transformer_lens.model_bridge import TransformerBridge\n", |
363 | 364 | "from neel_plotly import line, imshow, scatter" |
364 | 365 | ] |
365 | 366 | }, |
|
479 | 480 | }, |
480 | 481 | { |
481 | 482 | "cell_type": "code", |
482 | | - "execution_count": 7, |
| 483 | + "execution_count": null, |
483 | 484 | "metadata": { |
484 | 485 | "id": "5ikyL8-S7u2Z" |
485 | 486 | }, |
|
493 | 494 | "import numpy as np\n", |
494 | 495 | "import torch\n", |
495 | 496 | "\n", |
496 | | - "from transformer_lens import HookedTransformer, ActivationCache\n", |
497 | 497 | "# from transformer_lens.utils import is_lower_triangular, is_square\n", |
498 | 498 | "\n", |
499 | 499 | "HeadName = Literal[\"previous_token_head\", \"duplicate_token_head\", \"induction_head\"]\n", |
|
515 | 515 | "\n", |
516 | 516 | "\n", |
517 | 517 | "def detect_head(\n", |
518 | | - " model: HookedTransformer,\n", |
| 518 | + " model: TransformerBridge,\n", |
519 | 519 | " seq: Union[str, List[str]],\n", |
520 | 520 | " detection_pattern: Union[torch.Tensor, HeadName],\n", |
521 | 521 | " heads: Optional[Union[List[LayerHeadTuple], LayerToHead]] = None,\n", |
|
566 | 566 | " --------\n", |
567 | 567 | " .. code-block:: python\n", |
568 | 568 | "\n", |
569 | | - " >>> from transformer_lens import HookedTransformer, utils\n", |
| 569 | + " >>> from transformer_lens import utils\n", |
| 570 | + " >>> from transformer_lens.model_bridge import TransformerBridge\n", |
570 | 571 | " >>> from transformer_lens.head_detector import detect_head\n", |
571 | 572 | " >>> import plotly.express as px\n", |
572 | 573 | "\n", |
573 | 574 | " >>> def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n", |
574 | 575 | " >>> px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n", |
575 | 576 | "\n", |
576 | | - " >>> model = HookedTransformer.from_pretrained(\"gpt2-small\")\n", |
| 577 | + " >>> model = TransformerBridge.boot_transformers(\"gpt2\")\n", |
577 | 578 | " >>> sequence = \"This is a test sequence. This is a test sequence.\"\n", |
578 | 579 | "\n", |
579 | 580 | " >>> attention_score = detect_head(model, sequence, \"previous_token_head\")\n", |
|
777 | 778 | }, |
778 | 779 | { |
779 | 780 | "cell_type": "code", |
780 | | - "execution_count": 8, |
| 781 | + "execution_count": null, |
781 | 782 | "metadata": { |
782 | 783 | "colab": { |
783 | 784 | "base_uri": "https://localhost:8080/" |
|
802 | 803 | } |
803 | 804 | ], |
804 | 805 | "source": [ |
805 | | - "model = HookedTransformer.from_pretrained(\"gpt2-small\", device=device)" |
| 806 | + "model = TransformerBridge.boot_transformers(\"gpt2\", device=device)" |
806 | 807 | ] |
807 | 808 | }, |
808 | 809 | { |
|
0 commit comments