Skip to content

Commit e4830bd

Browse files
committed
nice fmri fig
1 parent 0c3e0d4 commit e4830bd

21 files changed

+119
-392
lines changed
File renamed without changes.

fmri/analyze_fmri.ipynb

Lines changed: 91 additions & 253 deletions
Large diffs are not rendered by default.

fmri/matplotlibrc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
figure.autolayout : True
2+
3+
axes.titlesize : 20
4+
axes.spines.top: False
5+
axes.spines.right: False
6+
axes.labelsize : 15 ## fontsize of the x any y labels
7+
8+
font.size : 15
9+
10+
xtick.labelsize : 15 ## fontsize of the tick labels
11+
ytick.labelsize : 15 ## fontsize of the tick labels

fmri/run.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,13 @@ def get_roi_and_idx(run):
6868
runs = [int(sys.argv[-1])]
6969
else:
7070
runs = list(range(300)) # this number determines which neuron we will pick
71-
print('runs', runs)
71+
print('\nruns', runs)
7272

7373
# fit linear models
7474
use_sigmas = False
7575
use_small = False
7676
out_dir = '/scratch/users/vision/data/gallant/vim_2_crcns'
77-
save_dir = oj(out_dir, 'dec13_baselines')
77+
save_dir = oj(out_dir, 'dec14_baselines_ard')
7878
suffix = '_feats' # _feats, '' for pixels
7979
norm = '_norm' # ''
8080
print('saving to', save_dir)
@@ -140,14 +140,17 @@ def get_roi_and_idx(run):
140140

141141
# only fit voxels with no missing vals
142142
if not (n_train == y_train.size and num_test == y_test.size):
143+
print('\tskipping this voxel!')
143144
continue
144145

145146
# reg values to try
146147
reg_params = np.logspace(3, 6, 20).round().astype(int)
147148

148149
# fit ard + mdl-rs
149150
baselines = {}
150-
for model_type, model_name in zip([ARDRegression, RidgeULNML], ['ard', 'mdl-rs']):
151+
for model_type, model_name in zip([ARDRegression], ['ard']):
152+
# for model_type, model_name in zip([ARDRegression, RidgeULNML], ['ard', 'mdl-rs']):
153+
print('\tfitting', model_name)
151154
model = model_type()
152155
model.fit(X_train, y_train)
153156
preds_train = model.predict(X_train)
@@ -159,6 +162,7 @@ def get_roi_and_idx(run):
159162
baselines[f'{model_name}_corr'] = np.corrcoef(y_test, preds)[0, 1]
160163

161164
# fit ridge cv
165+
print('\tfitting ridgecv...')
162166
m = RidgeCV(alphas=reg_params, store_cv_values=True)
163167
m.fit(X_train, y_train)
164168
preds_train = m.predict(X_train)
@@ -168,7 +172,7 @@ def get_roi_and_idx(run):
168172
mse = metrics.mean_squared_error(y_test, preds)
169173
r2 = metrics.r2_score(y_test, preds)
170174
corr = np.corrcoef(y_test, preds)[0, 1]
171-
print('RidgeCV corr', corr)
175+
print('\tRidgeCV corr', corr)
172176

173177

174178
# fit mdl comp
@@ -240,4 +244,5 @@ def get_roi_and_idx(run):
240244
**r,
241245
**baselines,
242246
}
243-
pkl.dump(results, open(oj(save_dir, f'ridge_{i}.pkl'), 'wb'))
247+
pkl.dump(results, open(oj(save_dir, f'ridge_{i}.pkl'), 'wb'))
248+
print('\tdone!')
-17 KB
Binary file not shown.
-11.6 KB
Binary file not shown.
-14.4 KB
Binary file not shown.

lib/pymdlrs/src/ulnml/least_square_regression.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from numpy import ndarray
44
import scipy.linalg as linalg
55
from typing import Union, Tuple
6+
from tqdm import tqdm
67
import warnings
78

89
from sklearn.base import BaseEstimator
@@ -51,7 +52,7 @@ def fit(self, X: ndarray, y: ndarray) -> 'RidgeULNML':
5152

5253
self.call_before_fit(X, y)
5354

54-
for i in range(self.n_iter):
55+
for i in tqdm(range(self.n_iter)):
5556
self.beta_ = self.fit_beta(C, b, self.lam_)
5657
self.sigma2_ = self.fit_sigma2(X, y, self.beta_, self.lam_)
5758
self.lam_ = self.fit_lam(

readme.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ Official code for using / reproducing MDL-COMP from the paper "Revisiting comple
44

55
# Reproducing the results in the paper
66
- most of the results can be produced by simply running the notebooks
7-
- the experiments with real-data are more in depth and require running the `submit_real_data_jobs.py` file (which is a script that calls `fit.py` with the appropriate hyperparameters) before running the notebook to view the analysis
7+
- the experiments with real-data are more in depth and require running `scripts/submit_real_data_jobs.py` (which is a script that calls `src/fit.py` with the appropriate hyperparameters) before running the notebook to view the analysis
88

9-
![](https://csinva.github.io/mdl-complexity/results/fig_iid_mse.svg)
9+
![](https://csinva.github.io/mdl-complexity/reports/fig_iid_mse.svg)
1010

1111

1212
## Calculating MDL-COMP
File renamed without changes.

fmri/submit_fmri.py renamed to scripts/submit_fmri.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
# run
13-
s = Slurm("fmri", {"partition": partition, "time": "1-0"})
13+
s = Slurm("fmri", {"partition": partition, "time": "2-0"})
1414
ks = sorted(params_to_vary.keys())
1515
vals = [params_to_vary[k] for k in ks]
1616
param_combinations = list(itertools.product(*vals)) # list of tuples
@@ -19,7 +19,7 @@
1919

2020
# iterate
2121
for i in range(len(param_combinations)):
22-
param_str = 'module load python; python3 run.py '
22+
param_str = 'module load python; python3 ../fmri/run.py '
2323
for j, key in enumerate(ks):
2424
param_str += key + ' ' + str(param_combinations[i][j]) + ' '
2525
print(param_str)

submit_real_data_jobs.py renamed to scripts/submit_real_data_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@
137137

138138
# iterate
139139
for i in range(len(params_full)):
140-
param_str = 'python3 fit.py '
140+
param_str = 'python3 ../src/fit.py '
141141
for j, key in enumerate(ks):
142142
param_str += key + ' ' + str(params_full[i][j]) + ' '
143143
subprocess.call(param_str, shell=True)

config.py renamed to src/config.py

File renamed without changes.

data.py renamed to src/data.py

File renamed without changes.

fit.py renamed to src/fit.py

File renamed without changes.

params.py renamed to src/params.py

File renamed without changes.
File renamed without changes.
File renamed without changes.

style.py renamed to src/style.py

File renamed without changes.

submit_real_data_jobs_slurm.py

Lines changed: 0 additions & 128 deletions
This file was deleted.

0 commit comments

Comments
 (0)