|
2 | 2 | "cells": [
|
3 | 3 | {
|
4 | 4 | "cell_type": "code",
|
5 |
| - "execution_count": 6, |
| 5 | + "execution_count": 1, |
6 | 6 | "metadata": {},
|
7 | 7 | "outputs": [],
|
8 | 8 | "source": [
|
|
20 | 20 | "import h5py\n",
|
21 | 21 | "from copy import deepcopy\n",
|
22 | 22 | "from skimage.filters import gabor_kernel\n",
|
23 |
| - "import gabor_feats\n", |
| 23 | + "# import gabor_feats\n", |
24 | 24 | "from sklearn.linear_model import RidgeCV\n",
|
25 | 25 | "import seaborn as sns\n",
|
26 | 26 | "from scipy.io import loadmat\n",
|
|
166 | 166 | },
|
167 | 167 | {
|
168 | 168 | "cell_type": "code",
|
169 |
| - "execution_count": 168, |
| 169 | + "execution_count": null, |
170 | 170 | "metadata": {},
|
171 |
| - "outputs": [ |
172 |
| - { |
173 |
| - "name": "stdout", |
174 |
| - "output_type": "stream", |
175 |
| - "text": [ |
176 |
| - "[\"/roi/FFAlh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/FFArh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/IPlh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/IPrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/MTlh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/MTplh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/MTprh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/MTrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/OBJlh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/OBJrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/PPAlh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/PPArh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/RSCrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/STSrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/VOlh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/VOrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/latocclh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/latoccrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v1lh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v1rh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v2lh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v2rh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v3alh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v3arh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v3blh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v3brh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v3lh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v3rh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v4lh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v4rh (EArray(18, 64, 64), zlib(3)) ''\"]\n" |
177 |
| - ] |
178 |
| - } |
179 |
| - ], |
| 171 | + "outputs": [], |
180 | 172 | "source": [
|
181 | 173 | "f = tables.open_file(oj(out_dir, 'VoxelResponses_subject1.mat'))\n",
|
182 | 174 | "xs = []\n",
|
|
514 | 506 | "cell_type": "markdown",
|
515 | 507 | "metadata": {},
|
516 | 508 | "source": [
|
517 |
| - "# visualize features / decompositions\n", |
| 509 | + "# kernel features" |
| 510 | + ] |
| 511 | + }, |
| 512 | + { |
| 513 | + "cell_type": "code", |
| 514 | + "execution_count": 30, |
| 515 | + "metadata": {}, |
| 516 | + "outputs": [], |
| 517 | + "source": [ |
| 518 | + "X = np.array(loadmat(oj(out_dir, 'mot_energy_feats_st.mat'))['S_fin'])\n", |
| 519 | + "X_test = np.array(loadmat(oj(out_dir, 'mot_energy_feats_sv.mat'))['S_fin'])" |
| 520 | + ] |
| 521 | + }, |
| 522 | + { |
| 523 | + "cell_type": "code", |
| 524 | + "execution_count": 31, |
| 525 | + "metadata": {}, |
| 526 | + "outputs": [], |
| 527 | + "source": [ |
| 528 | + "from jax import random\n", |
| 529 | + "from jax.experimental import stax\n", |
| 530 | + "from jax import random\n", |
| 531 | + "from neural_tangents import stax\n", |
| 532 | + "\n", |
| 533 | + "# kernel function\n", |
| 534 | + "init_fn, apply_fn, kernel_fn = stax.serial(\n", |
| 535 | + " stax.Dense(512), stax.Relu(),\n", |
| 536 | + " stax.Dense(512), stax.Relu(),\n", |
| 537 | + " stax.Dense(1)\n", |
| 538 | + ")" |
| 539 | + ] |
| 540 | + }, |
| 541 | + { |
| 542 | + "cell_type": "code", |
| 543 | + "execution_count": 22, |
| 544 | + "metadata": {}, |
| 545 | + "outputs": [], |
| 546 | + "source": [ |
| 547 | + "# training kernel mat\n", |
| 548 | + "kernel = kernel_fn(X, X, 'ntk')\n", |
| 549 | + "fname = oj(out_dir, f'mot_energy_feats_kernel_mat_ntk.pkl')\n", |
| 550 | + "if not os.path.exists(fname):\n", |
| 551 | + " save_pkl(kernel, fname)\n", |
| 552 | + " \n", |
| 553 | + "# training kernel mat\n", |
| 554 | + "kernel_test = kernel_fn(X_test, X, 'ntk')\n", |
| 555 | + "fname = oj(out_dir, f'mot_energy_feats_kernel_test_with_train_ntk.pkl')\n", |
| 556 | + "if not os.path.exists(fname):\n", |
| 557 | + " save_pkl(kernel_test, fname)\n", |
| 558 | + " \n", |
| 559 | + "# save out eigenvals\n", |
| 560 | + "fname = oj(out_dir, f'eigenvals_eigenvecs_mot_energy_kernel_ntk.pkl')\n", |
| 561 | + "if not os.path.exists(fname):\n", |
| 562 | + " kernel = load_pkl(oj(out_dir, f'mot_energy_feats_kernel_mat_ntk.pkl'))\n", |
| 563 | + " eigenvals, eigenvecs = npl.eig(kernel)\n", |
| 564 | + " save_pkl((eigenvals, eigenvecs), fname)" |
| 565 | + ] |
| 566 | + }, |
| 567 | + { |
| 568 | + "cell_type": "code", |
| 569 | + "execution_count": null, |
| 570 | + "metadata": {}, |
| 571 | + "outputs": [ |
| 572 | + { |
| 573 | + "name": "stderr", |
| 574 | + "output_type": "stream", |
| 575 | + "text": [ |
| 576 | + " 30%|███ | 6/20 [46:02<1:56:09, 497.85s/it]" |
| 577 | + ] |
| 578 | + } |
| 579 | + ], |
| 580 | + "source": [ |
| 581 | + "# save kernel pinvs\n", |
| 582 | + "reg_params = np.logspace(3, 6, 20).round().astype(int)\n", |
| 583 | + "kernel = load_pkl(oj(out_dir, f'mot_energy_feats_kernel_mat_ntk.pkl'))\n", |
| 584 | + "for reg_param in tqdm(reg_params):\n", |
| 585 | + " fname = oj(out_dir, f'pinv_mot_energy_kernel_ntk_{reg_param}.pkl')\n", |
| 586 | + " if not os.path.exists(fname):\n", |
| 587 | + " inv = npl.pinv(kernel + reg_param * np.eye(kernel.shape[0]))\n", |
| 588 | + " save_pkl(inv, fname)" |
| 589 | + ] |
| 590 | + }, |
| 591 | + { |
| 592 | + "cell_type": "code", |
| 593 | + "execution_count": null, |
| 594 | + "metadata": {}, |
| 595 | + "outputs": [], |
| 596 | + "source": [ |
| 597 | + "# need to save kernel matrix\n", |
| 598 | + "# need to save test-time kernel mat\n", |
| 599 | + "\n", |
| 600 | + "\n", |
| 601 | + "# save eigenvalues\n", |
| 602 | + "\n", |
| 603 | + "# make new script\n", |
| 604 | + "# need to switch to use Kernel ridge + eigenvalues" |
| 605 | + ] |
| 606 | + }, |
| 607 | + { |
| 608 | + "cell_type": "markdown", |
| 609 | + "metadata": {}, |
| 610 | + "source": [ |
| 611 | + "# visualize preprocessed features\n", |
518 | 612 | "**load and look at features**"
|
519 | 613 | ]
|
520 | 614 | },
|
521 | 615 | {
|
522 | 616 | "cell_type": "code",
|
523 |
| - "execution_count": 4, |
| 617 | + "execution_count": null, |
524 | 618 | "metadata": {},
|
525 | 619 | "outputs": [],
|
526 | 620 | "source": [
|
|
579 | 673 | "name": "python",
|
580 | 674 | "nbconvert_exporter": "python",
|
581 | 675 | "pygments_lexer": "ipython3",
|
582 |
| - "version": "3.7.5" |
| 676 | + "version": "3.8.3" |
583 | 677 | }
|
584 | 678 | },
|
585 | 679 | "nbformat": 4,
|
|
0 commit comments