Skip to content

Commit 0879445

Browse files
committed
add kernel fmri experiments
1 parent e4830bd commit 0879445

7 files changed

+788
-95
lines changed

fmri/analyze_fmri.ipynb

Lines changed: 52 additions & 65 deletions
Large diffs are not rendered by default.

fmri/analyze_fmri_kernel.ipynb

Lines changed: 415 additions & 0 deletions
Large diffs are not rendered by default.

fmri/preprocess_fmri.ipynb

Lines changed: 109 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 6,
5+
"execution_count": 1,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -20,7 +20,7 @@
2020
"import h5py\n",
2121
"from copy import deepcopy\n",
2222
"from skimage.filters import gabor_kernel\n",
23-
"import gabor_feats\n",
23+
"# import gabor_feats\n",
2424
"from sklearn.linear_model import RidgeCV\n",
2525
"import seaborn as sns\n",
2626
"from scipy.io import loadmat\n",
@@ -166,17 +166,9 @@
166166
},
167167
{
168168
"cell_type": "code",
169-
"execution_count": 168,
169+
"execution_count": null,
170170
"metadata": {},
171-
"outputs": [
172-
{
173-
"name": "stdout",
174-
"output_type": "stream",
175-
"text": [
176-
"[\"/roi/FFAlh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/FFArh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/IPlh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/IPrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/MTlh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/MTplh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/MTprh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/MTrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/OBJlh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/OBJrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/PPAlh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/PPArh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/RSCrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/STSrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/VOlh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/VOrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/latocclh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/latoccrh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v1lh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v1rh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v2lh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v2rh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v3alh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v3arh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v3blh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v3brh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v3lh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v3rh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v4lh (EArray(18, 64, 64), zlib(3)) ''\", \"/roi/v4rh (EArray(18, 64, 64), zlib(3)) ''\"]\n"
177-
]
178-
}
179-
],
171+
"outputs": [],
180172
"source": [
181173
"f = tables.open_file(oj(out_dir, 'VoxelResponses_subject1.mat'))\n",
182174
"xs = []\n",
@@ -514,13 +506,115 @@
514506
"cell_type": "markdown",
515507
"metadata": {},
516508
"source": [
517-
"# visualize features / decompositions\n",
509+
"# kernel features"
510+
]
511+
},
512+
{
513+
"cell_type": "code",
514+
"execution_count": 30,
515+
"metadata": {},
516+
"outputs": [],
517+
"source": [
518+
"X = np.array(loadmat(oj(out_dir, 'mot_energy_feats_st.mat'))['S_fin'])\n",
519+
"X_test = np.array(loadmat(oj(out_dir, 'mot_energy_feats_sv.mat'))['S_fin'])"
520+
]
521+
},
522+
{
523+
"cell_type": "code",
524+
"execution_count": 31,
525+
"metadata": {},
526+
"outputs": [],
527+
"source": [
528+
"from jax import random\n",
529+
"from jax.experimental import stax\n",
530+
"from jax import random\n",
531+
"from neural_tangents import stax\n",
532+
"\n",
533+
"# kernel function\n",
534+
"init_fn, apply_fn, kernel_fn = stax.serial(\n",
535+
" stax.Dense(512), stax.Relu(),\n",
536+
" stax.Dense(512), stax.Relu(),\n",
537+
" stax.Dense(1)\n",
538+
")"
539+
]
540+
},
541+
{
542+
"cell_type": "code",
543+
"execution_count": 22,
544+
"metadata": {},
545+
"outputs": [],
546+
"source": [
547+
"# training kernel mat\n",
548+
"kernel = kernel_fn(X, X, 'ntk')\n",
549+
"fname = oj(out_dir, f'mot_energy_feats_kernel_mat_ntk.pkl')\n",
550+
"if not os.path.exists(fname):\n",
551+
" save_pkl(kernel, fname)\n",
552+
" \n",
553+
"# training kernel mat\n",
554+
"kernel_test = kernel_fn(X_test, X, 'ntk')\n",
555+
"fname = oj(out_dir, f'mot_energy_feats_kernel_test_with_train_ntk.pkl')\n",
556+
"if not os.path.exists(fname):\n",
557+
" save_pkl(kernel_test, fname)\n",
558+
" \n",
559+
"# save out eigenvals\n",
560+
"fname = oj(out_dir, f'eigenvals_eigenvecs_mot_energy_kernel_ntk.pkl')\n",
561+
"if not os.path.exists(fname):\n",
562+
" kernel = load_pkl(oj(out_dir, f'mot_energy_feats_kernel_mat_ntk.pkl'))\n",
563+
" eigenvals, eigenvecs = npl.eig(kernel)\n",
564+
" save_pkl((eigenvals, eigenvecs), fname)"
565+
]
566+
},
567+
{
568+
"cell_type": "code",
569+
"execution_count": null,
570+
"metadata": {},
571+
"outputs": [
572+
{
573+
"name": "stderr",
574+
"output_type": "stream",
575+
"text": [
576+
" 30%|███ | 6/20 [46:02<1:56:09, 497.85s/it]"
577+
]
578+
}
579+
],
580+
"source": [
581+
"# save kernel pinvs\n",
582+
"reg_params = np.logspace(3, 6, 20).round().astype(int)\n",
583+
"kernel = load_pkl(oj(out_dir, f'mot_energy_feats_kernel_mat_ntk.pkl'))\n",
584+
"for reg_param in tqdm(reg_params):\n",
585+
" fname = oj(out_dir, f'pinv_mot_energy_kernel_ntk_{reg_param}.pkl')\n",
586+
" if not os.path.exists(fname):\n",
587+
" inv = npl.pinv(kernel + reg_param * np.eye(kernel.shape[0]))\n",
588+
" save_pkl(inv, fname)"
589+
]
590+
},
591+
{
592+
"cell_type": "code",
593+
"execution_count": null,
594+
"metadata": {},
595+
"outputs": [],
596+
"source": [
597+
"# need to save kernel matrix\n",
598+
"# need to save test-time kernel mat\n",
599+
"\n",
600+
"\n",
601+
"# save eigenvalues\n",
602+
"\n",
603+
"# make new script\n",
604+
"# need to switch to use Kernel ridge + eigenvalues"
605+
]
606+
},
607+
{
608+
"cell_type": "markdown",
609+
"metadata": {},
610+
"source": [
611+
"# visualize preprocessed features\n",
518612
"**load and look at features**"
519613
]
520614
},
521615
{
522616
"cell_type": "code",
523-
"execution_count": 4,
617+
"execution_count": null,
524618
"metadata": {},
525619
"outputs": [],
526620
"source": [
@@ -579,7 +673,7 @@
579673
"name": "python",
580674
"nbconvert_exporter": "python",
581675
"pygments_lexer": "ipython3",
582-
"version": "3.7.5"
676+
"version": "3.8.3"
583677
}
584678
},
585679
"nbformat": 4,

fmri/run.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,24 @@ def load_h5(fname):
3939
f.close()
4040
return data
4141

42+
def load_pkl(fname):
43+
return pkl.load(open(fname, "rb" ))
44+
4245
def save_pkl(d, fname):
4346
if os.path.exists(fname):
4447
os.remove(fname)
4548
with open(fname, 'wb') as f:
4649
pkl.dump(d, f)
4750

48-
def get_roi_and_idx(run):
51+
def get_roi_and_idx(run, out_dir, sigmas):
4952
# select roi + i (which is the roi_idx)
5053
rois = ['v1lh', 'v2lh', 'v4lh', 'v1rh', 'v2rh', 'v4rh']
5154
roi = rois[run % len(rois)]
5255

5356
f = tables.open_file(oj(out_dir, 'VoxelResponses_subject1.mat'), 'r')
5457
roi_idxs_all = f.get_node(f'/roi/{roi}')[:].flatten().nonzero()[0] # structure containing volume matrices (64x64x18) with indices corresponding to each roi in each hemisphere
55-
roi_idxs = np.array([roi_idx for roi_idx in roi_idxs_all if ~np.isnan(sigmas[roi_idx])])
58+
roi_idxs = np.array([roi_idx for roi_idx in roi_idxs_all
59+
if ~np.isnan(sigmas[roi_idx])])
5660

5761
i = roi_idxs[run // len(rois)] # i is the roi idx
5862
return roi, i
@@ -77,6 +81,7 @@ def get_roi_and_idx(run):
7781
save_dir = oj(out_dir, 'dec14_baselines_ard')
7882
suffix = '_feats' # _feats, '' for pixels
7983
norm = '_norm' # ''
84+
reg_params = np.logspace(3, 6, 20).round().astype(int) # reg values to try (must match preprocess_fmri)
8085
print('saving to', save_dir)
8186

8287

@@ -97,6 +102,14 @@ def get_roi_and_idx(run):
97102
else:
98103
X_train = np.array(loadmat(oj(out_dir, 'mot_energy_feats_st.mat'))['S_fin'])
99104
X_test = np.array(loadmat(oj(out_dir, 'mot_energy_feats_sv.mat'))['S_fin'])
105+
if use_small:
106+
# (U, alphas, _) = pkl.load(open(oj(out_dir, f'decomp_mot_energy_small.pkl'), 'rb'))
107+
(eigenvals, eigenvecs) = pkl.load(open(oj(out_dir, f'eigenvals_eigenvecs_mot_energy_small.pkl'), 'rb'))
108+
Y_train = Y_train[:, :720]
109+
else:
110+
# (U, alphas, _) = pkl.load(open(oj(out_dir, f'decomp_mot_energy.pkl'), 'rb'))
111+
(eigenvals, eigenvecs) = pkl.load(open(oj(out_dir, f'eigenvals_eigenvecs_mot_energy.pkl'), 'rb'))
112+
100113

101114
'''
102115
# load the raw responses
@@ -108,18 +121,12 @@ def get_roi_and_idx(run):
108121
Y_train = load_h5(oj(out_dir, 'rt_norm.h5')) # training responses: 73728 (voxels) x 7200 (timepoints)
109122
Y_test = load_h5(oj(out_dir, 'rv_norm.h5') )
110123
sigmas = load_h5(oj(out_dir, f'out_rva_sigmas_norm.h5')) # stddev across repeats
111-
if use_small:
112-
# (U, alphas, _) = pkl.load(open(oj(out_dir, f'decomp_mot_energy_small.pkl'), 'rb'))
113-
(eigenvals, eigenvecs) = pkl.load(open(oj(out_dir, f'eigenvals_eigenvecs_mot_energy_small.pkl'), 'rb'))
114-
Y_train = Y_train[:, :720]
115-
else:
116-
# (U, alphas, _) = pkl.load(open(oj(out_dir, f'decomp_mot_energy.pkl'), 'rb'))
117-
(eigenvals, eigenvecs) = pkl.load(open(oj(out_dir, f'eigenvals_eigenvecs_mot_energy.pkl'), 'rb'))
124+
118125

119126

120127
# loop over individual neurons
121128
for run in runs:
122-
roi, i = get_roi_and_idx(run)
129+
roi, i = get_roi_and_idx(run, out_dir, sigmas)
123130
results = {}
124131
os.makedirs(save_dir, exist_ok=True)
125132
print('fitting', roi, 'idx', i)
@@ -143,9 +150,6 @@ def get_roi_and_idx(run):
143150
print('\tskipping this voxel!')
144151
continue
145152

146-
# reg values to try
147-
reg_params = np.logspace(3, 6, 20).round().astype(int)
148-
149153
# fit ard + mdl-rs
150154
baselines = {}
151155
for model_type, model_name in zip([ARDRegression], ['ard']):
@@ -245,4 +249,4 @@ def get_roi_and_idx(run):
245249
**baselines,
246250
}
247251
pkl.dump(results, open(oj(save_dir, f'ridge_{i}.pkl'), 'wb'))
248-
print('\tdone!')
252+
print(f'\tsuccesfully finished run {run}!')

0 commit comments

Comments
 (0)