-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfinalplot_combined_LDsvSD.py
152 lines (124 loc) · 5.08 KB
/
finalplot_combined_LDsvSD.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 25 16:04:54 2022
@author: fm02
"""
### Author: [email protected]
### Plot results from LD vs SD decoding on combined ROIs: ROC AUC scores and per-ROI RMS patterns
import numpy as np
import pandas as pd
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import sem
# One seaborn theme for every figure in this script: "notebook" context,
# "ticks" axes style, sans-serif font, and 300 dpi both on screen and when
# saving.
sns.set_theme(context="notebook",
              style="ticks",
              font="sans-serif",
              rc={"figure.dpi": 300, "savefig.dpi": 300})

# ROI labels; their order here is the order in which ROIs are processed
# and plotted.
kkROI = ["lATL", "rATL", "AG", "PTC", "IFG", "PVA"]
def rms(example):
"""Compute root mean square of each ROI.
Input is a dataframe of length=n_vertices."""
# first transform Series in np array of dimension n_vertics*timepoints, when the input is unstacked.
example = np.vstack(np.array(example))
# create np.array where to store info
rms_example = np.zeros(example.shape[1])
# loop over timepoints
for i in np.arange(0,example.shape[1]):
rms_example[i] = np.sqrt(np.mean(example[:,i]**2))
return rms_example
# Load the decoding results of the combined-ROI LD vs SD analysis.
# NOTE: pickle.load can execute arbitrary code from the file; only load
# result files you produced yourself or otherwise trust.
results_dir = "/imaging/hauk/users/fm02/final_dTtT/combined_ROIs/LDvsSD"
with open(f"{results_dir}/scores.P", 'rb') as f:
    scores = pickle.load(f)
with open(f"{results_dir}/patterns.P", 'rb') as f:
    patterns = pickle.load(f)
# Time axis shared by all plots: -300 up to (but excluding) 900 in steps of
# 4, i.e. 300 samples. Presumably milliseconds relative to time 0 — confirm.
times = np.arange(-300, 900, 4)

# Line colours, selected by integer index in the plotting loops. The final
# hex code is listed twice, so the palette has 7 entries (indices 0-6).
colors = sns.color_palette([
    '#FFBE0B', '#FB5607', '#FF006E', '#8338EC',
    '#3A86FF', '#1D437F', '#1D437F',
])

# Alternative: IBM colour-blind-safe palette
# '#648fff', '#785ef0', '#dc267f', '#fe6100', '#ffb000', '#000000', '#ffffff'
# Turn each task's scores into a NumPy array. Assumed layout: one row per
# participant, one column per timepoint — TODO confirm against the pickle.
for task in scores.keys():
    scores[task] = np.array(scores[task])

# Across-task average performance for each participant: the element-wise
# mean of the 'mlk', 'frt' and 'odr' scores, i.e. "average(SD)". The result
# keeps the participants x timepoints layout and uses every participant
# present in the data instead of a hard-coded 18.
# NOTE(review): assumes `scores` is a dict; a DataFrame column could not hold
# this 2-D array.
scores['avg'] = np.mean([np.stack(scores['mlk']),
                         np.stack(scores['frt']),
                         np.stack(scores['odr'])],
                        axis=0)
# Decoding ROC AUC over time for every entry of `scores` (each task and,
# if present, the across-task 'avg'). Each is drawn as the mean across
# participants with a +/- 1 SEM band.
for k, task in enumerate(scores.keys()):
    # rows = participants, columns = timepoints (presumably — TODO confirm)
    task_scores = np.stack(scores[task])
    mean_auc = task_scores.mean(axis=0)
    sem_auc = sem(task_scores, 0)
    # Use every second palette entry; wrap around instead of raising
    # IndexError if there are more keys than the palette can serve.
    color = colors[(2 * k) % len(colors)]
    sns.lineplot(x=times, y=mean_auc, color=color, label=task)
    plt.fill_between(x=times,
                     y1=mean_auc - sem_auc,
                     y2=mean_auc + sem_auc,
                     color=color, alpha=.1)

# reference lines: time 0 and chance level (AUC = 0.5)
plt.axvline(0, color='k')
plt.title('LD vs average(SD) Decoding ROC AUC')
plt.axhline(.5, color='k', linestyle='--', label='chance')
plt.legend()
#plt.savefig('//cbsu/data/Imaging/hauk/users/fm02/final_dTtT/combined_ROIs/LDvsSD/Figures/average_LDvsSD_accuracy.png', format='png')
plt.show()
# RMS of the decoder patterns across vertices, organised as
# patterns_roi[roi][task] -> list of per-participant RMS time courses.
patterns_roi = {roi: {task: [] for task in ['frt', 'mlk', 'odr']}
                for roi in kkROI}

# Fill the per-task lists: participants outermost, then ROIs, then tasks.
for participant in range(18):
    for roi, per_task in patterns_roi.items():
        for task, rms_list in per_task.items():
            rms_list.append(rms(np.array(patterns[task][participant].loc[roi])))

# Across-task average of the RMS time courses, one entry per participant.
for per_task in patterns_roi.values():
    per_task['avg'] = [
        np.array([per_task['mlk'][p],
                  per_task['frt'][p],
                  per_task['odr'][p]]).mean(axis=0)
        for p in range(18)
    ]
# Across-task average RMS pattern time course for each ROI, averaged over
# participants. Each line carries its own label, so legend entries cannot
# drift out of step with the plotted lines.
for k, roi in enumerate(patterns_roi.keys()):
    mean_rms = np.array(patterns_roi[roi]['avg']).mean(axis=0)  # mean over participants
    sns.lineplot(x=times, y=mean_rms, color=colors[k % len(colors)], label=roi)
plt.axvline(0, color='k')
plt.title('LD vs average(SD) RMS patterns')
plt.legend()
#plt.savefig('/imaging/hauk/users/fm02/final_dTtT/combined_ROIs/LDvsSD/Figures/average_LDvsSD_patterns.png', format='png')
plt.show()
# Across-task average decoding ROC AUC: mean across participants with a
# +/- 1 SEM band. np.stack accepts scores['avg'] whether it is a 2-D array
# or a list/Series of per-participant arrays.
avg_scores = np.stack(scores['avg'])  # rows = participants (presumably)
mean_avg = avg_scores.mean(axis=0)
sem_avg = sem(avg_scores, 0)
sns.lineplot(x=times, y=mean_avg, color='black')
plt.fill_between(x=times,
                 y1=mean_avg - sem_avg,
                 y2=mean_avg + sem_avg,
                 color='k', alpha=.1)
plt.title('LD vs average(SD) Decoding ROC AUC')
# reference lines: time 0 and chance level (AUC = 0.5)
plt.axvline(0, color='k')
plt.axhline(.5, color='k', linestyle='--')
plt.show()