Skip to content

Commit

Permalink
new channel features
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk committed Dec 11, 2023
1 parent 083810d commit 74a743d
Showing 1 changed file with 88 additions and 60 deletions.
148 changes: 88 additions & 60 deletions plot_bispectrum_thomas.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,20 @@
PATH_OUT = r"E:\scratch"
RUN_NAME = "TestBispectrum"

raw = mne.io.read_raw_brainvision(r"C:\Users\ICN_admin\Documents\Datasets\Berlin\sub-002\ses-EcogLfpMedOff01\ieeg\sub-002_ses-EcogLfpMedOff01_task-SelfpacedRotationR_acq-StimOff_run-01_ieeg.vhdr")
#raw = mne.io.read_raw_brainvision(r"C:\Users\ICN_admin\Documents\Datasets\Berlin\sub-003\ses-EcogLfpMedOff01\ieeg\sub-003_ses-EcogLfpMedOff01_task-SelfpacedRotationR_acq-StimOff_run-01_ieeg.vhdr")
#raw = mne.io.read_raw_brainvision(r"C:\Users\ICN_admin\Documents\Datasets\Berlin\sub-005\ses-EcogLfpMedOff01\ieeg\sub-005_ses-EcogLfpMedOff01_task-SelfpacedRotationR_acq-StimOff_run-01_ieeg.vhdr")
#raw = mne.io.read_raw_brainvision(r"C:\Users\ICN_admin\Documents\Datasets\Berlin\sub-005\ses-EcogLfpMedOff02\ieeg\sub-005_ses-EcogLfpMedOff02_task-SelfpacedRotationR_acq-StimOn_run-01_ieeg.vhdr")
raw = mne.io.read_raw_brainvision(
r"C:\Users\ICN_admin\Documents\Datasets\Berlin\sub-002\ses-EcogLfpMedOff01\ieeg\sub-002_ses-EcogLfpMedOff01_task-SelfpacedRotationR_acq-StimOff_run-01_ieeg.vhdr"
)
# raw = mne.io.read_raw_brainvision(r"C:\Users\ICN_admin\Documents\Datasets\Berlin\sub-003\ses-EcogLfpMedOff01\ieeg\sub-003_ses-EcogLfpMedOff01_task-SelfpacedRotationR_acq-StimOff_run-01_ieeg.vhdr")
# raw = mne.io.read_raw_brainvision(r"C:\Users\ICN_admin\Documents\Datasets\Berlin\sub-005\ses-EcogLfpMedOff01\ieeg\sub-005_ses-EcogLfpMedOff01_task-SelfpacedRotationR_acq-StimOff_run-01_ieeg.vhdr")
# raw = mne.io.read_raw_brainvision(r"C:\Users\ICN_admin\Documents\Datasets\Berlin\sub-005\ses-EcogLfpMedOff02\ieeg\sub-005_ses-EcogLfpMedOff02_task-SelfpacedRotationR_acq-StimOn_run-01_ieeg.vhdr")


raw.pick(["ECOG_L_1_SMC_AT", "SQUARED_ROTATION"])
#raw.pick(["ECOG_L_1_SMC_AT"])
#raw.pick(["ECOG_R_1_SMC_AT", "rota_squared"])
raw.pick(["ECOG_L_5_SMC_AT", "SQUARED_ROTATION"])
# raw.pick(["ECOG_L_1_SMC_AT"])
# raw.pick(["ECOG_R_1_SMC_AT", "rota_squared"])
sfreq = raw.info["sfreq"]

data = raw.get_data()#[:, :int(sfreq)*100]
data = raw.get_data() # [:, :int(sfreq)*100]


line_noise = 50
Expand All @@ -66,29 +68,29 @@
target_keywords=["SQUARED_ROTATION"],
)

nm_channels.loc[nm_channels["name"] == "ECOG_L_1_SMC_AT", "used"] = 1
nm_channels.loc[nm_channels["name"] == "ECOG_L_6_SMC_AT", "used"] = 1
nm_channels.loc[nm_channels["name"] == "SQUARED_ROTATION", "target"] = 1

# %%
# This example contains the grip force movement traces, we'll use the *MOV_RIGHT* channel as a decoding target channel.
# Let's check some of the raw feature and time series traces:

#plt.figure(figsize=(12, 4), dpi=300)
#plt.subplot(121)
#plt.plot(raw.times, data[-1, :])
#plt.xlabel("Time [s]")
#plt.ylabel("a.u.")
#plt.title("Movement label")
#plt.xlim(0, 20)
# plt.figure(figsize=(12, 4), dpi=300)
# plt.subplot(121)
# plt.plot(raw.times, data[-1, :])
# plt.xlabel("Time [s]")
# plt.ylabel("a.u.")
# plt.title("Movement label")
# plt.xlim(0, 20)

#plt.subplot(122)
#for idx, ch_name in enumerate(nm_channels.query("used == 1").name):
# plt.subplot(122)
# for idx, ch_name in enumerate(nm_channels.query("used == 1").name):
# plt.plot(raw.times, data[idx, :] + idx * 300, label=ch_name)
#plt.legend(bbox_to_anchor=(1, 0.5), loc="center left")
#plt.title("ECoG + STN-LFP time series")
#plt.xlabel("Time [s]")
#plt.ylabel("Voltage a.u.")
#plt.xlim(0, 20)
# plt.legend(bbox_to_anchor=(1, 0.5), loc="center left")
# plt.title("ECoG + STN-LFP time series")
# plt.xlabel("Time [s]")
# plt.ylabel("Voltage a.u.")
# plt.xlim(0, 20)

# %%
settings = nm_settings.get_default_settings()
Expand Down Expand Up @@ -119,14 +121,27 @@
folder_name=RUN_NAME,
)

features_plt = features[np.logical_and(features["time"] > int(sfreq)*50, features["time"] < int(sfreq)*60)]
features_plt = features[
np.logical_and(
features["time"] > int(sfreq) * 50, features["time"] < int(sfreq) * 60
)
]
plt.figure()
plt.subplot(211)
plt.plot(data[0, int(sfreq)*50 : int(sfreq)*60])
plt.plot(data[0, int(sfreq) * 50 : int(sfreq) * 60])
plt.subplot(212)
plt.plot(features_plt["ECOG_L_1_SMC_AT_Bispectrum_phase_mean_whole_fband_range"], label="phase")
plt.plot(features_plt["ECOG_L_1_SMC_AT_Bispectrum_real_mean_whole_fband_range"], label="real")
plt.plot(features_plt["ECOG_L_1_SMC_AT_Bispectrum_imag_mean_whole_fband_range"], label="imag")
plt.plot(
features_plt["ECOG_L_6_SMC_AT_Bispectrum_phase_mean_whole_fband_range"],
label="phase",
)
plt.plot(
features_plt["ECOG_L_6_SMC_AT_Bispectrum_real_mean_whole_fband_range"],
label="real",
)
plt.plot(
features_plt["ECOG_L_6_SMC_AT_Bispectrum_imag_mean_whole_fband_range"],
label="imag",
)
plt.legend()
plt.show()

Expand All @@ -153,22 +168,22 @@
feature_reader._get_target_ch()

features_to_plt = [
"ECOG_L_1_SMC_AT_Bispectrum_phase_sum_whole_fband_range",
"ECOG_L_1_SMC_AT_Bispectrum_absolute_mean_theta",
"ECOG_L_1_SMC_AT_Bispectrum_absolute_mean_alpha",
"ECOG_L_1_SMC_AT_Bispectrum_absolute_mean_low beta",
"ECOG_L_1_SMC_AT_Bispectrum_real_sum_low beta",
"ECOG_L_1_SMC_AT_Bispectrum_absolute_mean_high beta",
"ECOG_L_1_SMC_AT_Bispectrum_real_mean_high beta",
"ECOG_L_1_SMC_AT_Bispectrum_real_mean_theta",
"ECOG_L_1_SMC_AT_Bispectrum_real_mean_alpha",
#"ECOG_L_1_SMC_AT_Bispectrum_real_mean_low beta",
"ECOG_L_1_SMC_AT_Bispectrum_real_var_alpha",
][::-1]
"ECOG_L_6_SMC_AT_Bispectrum_phase_sum_whole_fband_range",
"ECOG_L_6_SMC_AT_Bispectrum_absolute_mean_theta",
"ECOG_L_6_SMC_AT_Bispectrum_absolute_mean_alpha",
"ECOG_L_6_SMC_AT_Bispectrum_absolute_mean_low beta",
"ECOG_L_6_SMC_AT_Bispectrum_real_sum_low beta",
"ECOG_L_6_SMC_AT_Bispectrum_absolute_mean_high beta",
"ECOG_L_6_SMC_AT_Bispectrum_real_mean_high beta",
"ECOG_L_6_SMC_AT_Bispectrum_real_mean_theta",
"ECOG_L_6_SMC_AT_Bispectrum_real_mean_alpha",
# "ECOG_L_1_SMC_AT_Bispectrum_real_mean_low beta",
"ECOG_L_6_SMC_AT_Bispectrum_real_var_alpha",
][::-1]

# %%
feature_reader.plot_target_averaged_channel(
ch="ECOG_L_1_SMC_AT",
ch="ECOG_L_6_SMC_AT",
list_feature_keywords=None,
features_to_plt=features_to_plt,
epoch_len=6,
Expand All @@ -179,16 +194,26 @@
)


df_plt = feature_reader.feature_arr[features_to_plt+["SQUARED_ROTATION"]].astype("float").T
df_plt = (
feature_reader.feature_arr[features_to_plt + ["SQUARED_ROTATION"]]
.astype("float")
.T
)

plt.figure(figsize=(15, 5), dpi=300)
plt.imshow(stats.zscore(df_plt, axis=1), aspect="auto")
plt.clim(-3, 3)
ytick_labelsize = 12
plt.yticks(np.arange(len(features_to_plt+["SQUARED_ROTATION"])), features_to_plt+["SQUARED_ROTATION"], size=ytick_labelsize)
plt.yticks(
np.arange(len(features_to_plt + ["SQUARED_ROTATION"])),
features_to_plt + ["SQUARED_ROTATION"],
size=ytick_labelsize,
)

tick_num = np.arange(0, df_plt.shape[1], int(df_plt.shape[1] / 100))
tick_labels = np.array(np.rint(feature_reader.feature_arr["time"].iloc[tick_num] / 1000), dtype=int)
tick_labels = np.array(
np.rint(feature_reader.feature_arr["time"].iloc[tick_num] / 1000), dtype=int
)
plt.xticks(tick_num, tick_labels)
plt.xlabel("Time [s]")
plt.ylabel("Features")
Expand All @@ -203,7 +228,7 @@
from py_neuromodulation import nm_plots

nm_plots.plot_all_features(
df=feature_reader.feature_arr[features_to_plt+["time"]],
df=feature_reader.feature_arr[features_to_plt + ["time"]],
time_limit_low_s=15,
time_limit_high_s=20,
normalize=True,
Expand All @@ -212,8 +237,8 @@
ytick_labelsize=12,
feature_file=feature_reader.feature_file,
OUT_PATH=feature_reader.feature_dir,
#clim_low=1,
#clim_high=-1,
# clim_low=1,
# clim_high=-1,
)


Expand All @@ -222,15 +247,15 @@
ytick_labelsize=10,
clim_low=-2,
clim_high=2,
ch_used="ECOG_L_1_SMC_AT",
ch_used="ECOG_L_6_SMC_AT",
time_limit_low_s=50,
time_limit_high_s=70,
normalize=True,
save=True,
)

feature_reader.plot_target_averaged_channel(
ch="ECOG_L_1_SMC_AT",
ch="ECOG_L_6_SMC_AT",
list_feature_keywords=None,
epoch_len=7,
threshold=0.5,
Expand All @@ -250,13 +275,12 @@
)



# %%
nm_plots.plot_corr_matrix(
feature=feature_reader.feature_arr.filter(regex="ECOG_L_1_SMC_AT"),
ch_name="ECOG_L_1_SMC_AT",
feature=feature_reader.feature_arr.filter(regex="ECOG_L_6_SMC_AT"),
ch_name="ECOG_L_6_SMC_AT",
feature_names=feature_reader.feature_arr.filter(
regex="ECOG_L_1_SMC_AT"
regex="ECOG_L_6_SMC_AT"
).columns,
feature_file=feature_reader.feature_file,
show_plot=True,
Expand All @@ -273,22 +297,26 @@
#
# Here, we show an example using the XGBOOST classifier. The used labels came from a continuous grip force movement target, named "MOV_RIGHT".
#
# First we initialize the :class:`~nm_decode.Decoder` class, which the specified *validation method*, here being a simple 3-fold cross validation,
# First we initialize the :class:`~nm_decode.Decoder` class, which the specified *validation method*, here being a simple 3-fold cross validation,
# the evaluation metric, used machine learning model, and the channels we want to evaluate performances for.
#
# There are many more implemented methods, but we will here limit it to the ones presented.

model = linear_model.LogisticRegression()
feature_reader.feature_arr['SQUARED_ROTATION'] = feature_reader.feature_arr['SQUARED_ROTATION'].astype(int)>0.5
feature_reader.feature_arr["SQUARED_ROTATION"] = (
feature_reader.feature_arr["SQUARED_ROTATION"].astype(int) > 0.5
)

feature_reader.decoder = nm_decode.Decoder(
features=feature_reader.feature_arr[[f for f in feature_reader.feature_arr.columns if "imag" in f]],
features=feature_reader.feature_arr[
[f for f in feature_reader.feature_arr.columns if "time" not in f]
],
label=np.array(feature_reader.label).astype(int),
label_name=feature_reader.label_name,
used_chs=feature_reader.used_chs,
model=model,
eval_method=metrics.balanced_accuracy_score,
cv_method=model_selection.KFold(n_splits=3, shuffle=True),
cv_method=model_selection.KFold(n_splits=3, shuffle=False),
)

# %%
Expand All @@ -312,10 +340,10 @@
x_col="sub",
y_col="performance_test",
hue="ch_type",
PATH_SAVE=PATH_OUT / RUN_NAME / (RUN_NAME + "_decoding_performance.png"),
figsize_tuple=(8, 5)
# PATH_SAVE=PATH_OUT / RUN_NAME / (RUN_NAME + "_decoding_performance.png"),
figsize_tuple=(8, 5),
)
ax.set_ylabel(r"$R^2$ Correlation")
ax.set_xlabel("Subject 000")
ax.set_title("Performance comparison Movement decoding")
plt.tight_layout()
plt.tight_layout()

0 comments on commit 74a743d

Please sign in to comment.