Skip to content

Commit

Permalink
add feature options for oscillatory features
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk committed Dec 11, 2023
1 parent eab11ce commit 80d7327
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 39 deletions.
58 changes: 26 additions & 32 deletions py_neuromodulation/nm_oscillatory.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,6 @@ def test_settings_osc(
assert isinstance(
s[osc_feature_name]["log_transform"], bool
), f"log_transform needs to be type bool, got {s[osc_feature_name]['log_transform']}"
assert isinstance(
s[osc_feature_name]["kalman_filter"], bool
), f"kalman_filter needs to be type bool, got {s[osc_feature_name]['kalman_filter']}"

if s[osc_feature_name]["kalman_filter"] is True:
nm_kalmanfilter.test_kf_settings(s, ch_names, sfreq)

assert isinstance(s["frequency_ranges_hz"], dict)

Expand Down Expand Up @@ -95,6 +89,22 @@ def update_KF(self, feature_calc: float, KF_name: str) -> float:
feature_calc = self.KF_dict[KF_name].x[0]
return feature_calc

def estimate_osc_features(self, features_compute: dict, data: np.ndarray, feature_name: np.ndarray, est_name: str):
for feature_est_name in list(self.s[est_name]["features"].keys()):
if self.s[est_name]["features"][feature_est_name] is True:
# switch case for feature_est_name
match feature_est_name:
case "mean":
features_compute[f"{feature_name}_{feature_est_name}"] = np.nanmean(data)
case "median":
features_compute[f"{feature_name}_{feature_est_name}"] = np.nanmedian(data)
case "std":
features_compute[f"{feature_name}_{feature_est_name}"] = np.nanstd(data)
case "max":
features_compute[f"{feature_name}_{feature_est_name}"] = np.nanmax(data)

return features_compute


class FFT(OscillatoryFeature):
def __init__(
Expand All @@ -104,8 +114,6 @@ def __init__(
sfreq: float,
) -> None:
super().__init__(settings, ch_names, sfreq)
if self.s["fft_settings"]["kalman_filter"]:
self.init_KF("fft")

if self.s["fft_settings"]["log_transform"]:
self.log_transform = True
Expand All @@ -132,17 +140,16 @@ def test_settings(s: dict, ch_names: Iterable[str], sfreq: int | float):
def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
data = data[:, self.window_samples :]
Z = np.abs(fft.rfft(data))

if self.log_transform:
Z = np.log10(Z)

for ch_idx, feature_name, idx_range in self.feature_params:
Z_ch = Z[ch_idx, idx_range]
feature_calc = np.mean(Z_ch)

if self.log_transform:
feature_calc = np.log10(feature_calc)
Z_ch = Z[ch_idx, idx_range]

if self.KF_dict:
feature_calc = self.update_KF(feature_calc, feature_name)
features_compute = self.estimate_osc_features(features_compute, Z_ch, feature_name, "fft_settings")

features_compute[feature_name] = feature_calc
return features_compute


Expand All @@ -154,8 +161,6 @@ def __init__(
sfreq: float,
) -> None:
super().__init__(settings, ch_names, sfreq)
if self.s["welch_settings"]["kalman_filter"]:
self.init_KF("welch")

self.log_transform = self.s["welch_settings"]["log_transform"]

Expand Down Expand Up @@ -187,12 +192,8 @@ def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
Z_ch = np.log10(Z_ch)

idx_range = np.where((f >= f_range[0]) & (f <= f_range[1]))[0]
feature_calc = np.mean(Z_ch[idx_range])

if self.KF_dict:
feature_calc = self.update_KF(feature_calc, feature_name)

features_compute[feature_name] = feature_calc
features_compute = self.estimate_osc_features(features_compute, Z_ch[idx_range], feature_name, "welch_settings")

return features_compute

Expand All @@ -205,8 +206,6 @@ def __init__(
sfreq: float,
) -> None:
super().__init__(settings, ch_names, sfreq)
if self.s["stft_settings"]["kalman_filter"]:
self.init_KF("stft")

self.nperseg = int(self.s["stft_settings"]["windowlength_ms"])
self.log_transform = self.s["stft_settings"]["log_transform"]
Expand All @@ -232,18 +231,13 @@ def calc_feature(self, data: np.ndarray, features_compute: dict) -> dict:
boundary="even",
)
Z = np.abs(Zxx)
if self.log_transform:
Z = np.log10(Z)
for ch_idx, feature_name, f_range in self.feature_params:
Z_ch = Z[ch_idx]
idx_range = np.where((f >= f_range[0]) & (f <= f_range[1]))[0]
feature_calc = np.mean(Z_ch[idx_range, :]) # 1. dim: f, 2. dim: t

if self.KF_dict:
feature_calc = self.update_KF(feature_calc, feature_name)

if self.log_transform:
feature_calc = np.log10(feature_calc)

features_compute[feature_name] = feature_calc
features_compute = self.estimate_osc_features(features_compute, Z_ch[idx_range, :], feature_name, "stft_settings")

return features_compute

Expand Down
21 changes: 18 additions & 3 deletions py_neuromodulation/nm_settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,32 @@
"fft_settings": {
"windowlength_ms": 1000,
"log_transform": true,
"kalman_filter": false
"features": {
"mean": true,
"median": false,
"std": false,
"max" : false
}
},
"welch_settings": {
"windowlength_ms": 1000,
"log_transform": true,
"kalman_filter": false
"features": {
"mean": true,
"median": false,
"std": false,
"max" : false
}
},
"stft_settings": {
"windowlength_ms": 500,
"log_transform": true,
"kalman_filter": false
"features": {
"mean": true,
"median": false,
"std": false,
"max" : false
}
},
"bandpass_filter_settings": {
"segment_lengths_ms": {
Expand Down
17 changes: 13 additions & 4 deletions py_neuromodulation/nm_stream_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,13 @@ def _run_offline(

def plot_raw_signal(
self,
sfreq: float,
sfreq: float = None,
data: np.array = None,
lowpass: float = None,
highpass: float = None,
picks: list = None,
plot_time: bool = True,
plot_psd: bool = True,
plot_psd: bool = False,
) -> None:
"""Use MNE-RawArray Plot to investigate PSD or raw_signal plot.
Expand All @@ -156,6 +159,9 @@ def plot_raw_signal(
if data is None and self.data is not None:
data = self.data

if sfreq is None:
sfreq = self.sfreq

if self.nm_channels is not None:
ch_names = self.nm_channels["name"].to_list()
ch_types = self.nm_channels["type"].to_list()
Expand All @@ -165,12 +171,15 @@ def plot_raw_signal(

# create mne.RawArray
info = mne.create_info(
ch_names=ch_names, sfreq=self.sfreq, ch_types=ch_types
ch_names=ch_names, sfreq=sfreq, ch_types=ch_types
)
raw = mne.io.RawArray(data, info)

if picks is not None:
raw = raw.pick(picks)
self.raw = raw
if plot_time:
raw.plot()
raw.plot(highpass=highpass, lowpass=lowpass)
if plot_psd:
raw.plot_psd()

Expand Down

0 comments on commit 80d7327

Please sign in to comment.