Skip to content

Commit

Permalink
Merge pull request #274 from neuromodulation/fix_data_selection
Browse files Browse the repository at this point in the history
fix ch_used data selection
  • Loading branch information
timonmerk authored Dec 14, 2023
2 parents b544190 + 9025865 commit a5389b2
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 80 deletions.
1 change: 1 addition & 0 deletions examples/plot_1_example_BIDS.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
settings = nm_settings.get_default_settings()
settings = nm_settings.set_settings_fast_compute(settings)

settings["features"]["welch"] = True
settings["features"]["fft"] = True
settings["features"]["bursts"] = True
settings["features"]["sharpwave_analysis"] = True
Expand Down
18 changes: 10 additions & 8 deletions py_neuromodulation/nm_rereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@


class ReReferencer:

ref_matrix: np.ndarray


def __init__(
self,
sfreq: int | float,
Expand All @@ -28,11 +26,15 @@ def __init__(
ValueError: rereferencing using undefined channel
ValueError: rereferencing to same channel
"""
(channels_used,) = np.where((nm_channels.used == 1))
nm_channels = nm_channels[nm_channels["used"] == 1].reset_index(
drop=True
)
# (channels_used,) = np.where((nm_channels.used == 1))

ch_names = nm_channels["name"].tolist()

if len(ch_names) == 1:
# no re-referencing is being performed when there is a single channel present only
if nm_channels.shape[0] in (0, 1):
self.ref_matrix = None
return

Expand All @@ -48,8 +50,8 @@ def __init__(
ref_matrix = np.zeros((len(nm_channels), len(nm_channels)))
for ind in range(len(nm_channels)):
ref_matrix[ind, ind] = 1
if ind not in channels_used:
continue
# if ind not in channels_used:
# continue
ref = refs[ind]
if ref.lower() == "none" or pd.isnull(ref):
ref_idx = None
Expand Down Expand Up @@ -84,10 +86,10 @@ def process(self, data: np.ndarray) -> np.ndarray:
shape(n_channels, n_samples) - data to be rereferenced.
Returns:
reref_data (numpy ndarray):
reref_data (numpy ndarray):
shape(n_channels, n_samples) - rereferenced data
"""
if self.ref_matrix is not None:
return self.ref_matrix @ data
else:
return data
return data
11 changes: 8 additions & 3 deletions py_neuromodulation/nm_run_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def process(self, data: np.ndarray) -> pd.Series:

nan_channels = np.isnan(data).any(axis=1)

data = np.nan_to_num(data) # [self.feature_idx, :]needs to be before preprocessing
data = np.nan_to_num(data)[self.feature_idx, :]

for processor in self.preprocessors:
data = processor.process(data)
Expand All @@ -341,8 +341,13 @@ def process(self, data: np.ndarray) -> pd.Series:

# normalize features
if self.settings["postprocessing"]["feature_normalization"]:
normed_features = self.feature_normalizer.process(np.fromiter(features_dict.values(), dtype="float"))
features_dict = {key: normed_features[idx] for idx, key in enumerate(features_dict.keys())}
normed_features = self.feature_normalizer.process(
np.fromiter(features_dict.values(), dtype="float")
)
features_dict = {
key: normed_features[idx]
for idx, key in enumerate(features_dict.keys())
}

features_current = pd.Series(
data=list(features_dict.values()),
Expand Down
19 changes: 11 additions & 8 deletions tests/test_initalization_offline_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@


def test_stream_init():
"""Test if stream initialization with passed data will setup nm_channels correctly
"""
"""Test if stream initialization with passed data will setup nm_channels correctly"""
np.random.seed(0)
data = np.random.random((10, 1000))
sfreq = 100
Expand All @@ -15,17 +14,17 @@ def test_stream_init():
assert stream.nm_channels.shape[0] == 10
assert stream.settings["sampling_rate_features_hz"] == 11


def test_stream_init_no_sfreq():
"""Check if stream initialization without sfreq will raise an error
"""
"""Check if stream initialization without sfreq will raise an error"""
np.random.seed(0)
data = np.random.random((10, 1000))
with pytest.raises(Exception):
nm.Stream(data=data, sampling_rate_features_hz=11)


def test_init_warning_no_used_channel():
"""Check if a warning is raised when a stream is initialized with nm_channels, but no row has used == 1 and target == 0
"""
"""Check if a warning is raised when a stream is initialized with nm_channels, but no row has used == 1 and target == 0"""
np.random.seed(0)
data = np.random.random((10, 1000))
sfreq = 1000
Expand All @@ -34,5 +33,9 @@ def test_init_warning_no_used_channel():
channels["used"] = 0

with pytest.raises(Exception):
nm.Stream(sfreq=sfreq, data=data, nm_channels=channels, sampling_rate_features_hz=11)

nm.Stream(
sfreq=sfreq,
data=data,
nm_channels=channels,
sampling_rate_features_hz=11,
)
87 changes: 47 additions & 40 deletions tests/test_rereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
nm_define_nmchannels,
)


@pytest.fixture
def setup():
"""This test function sets a data batch and automatic initialized M1 datafram
Expand All @@ -27,7 +28,13 @@ def setup():
fs (float): example sampling frequency
"""

RUN_NAME, PATH_RUN, PATH_BIDS, PATH_OUT, datatype = nm_IO.get_paths_example_data()
(
RUN_NAME,
PATH_RUN,
PATH_BIDS,
PATH_OUT,
datatype,
) = nm_IO.get_paths_example_data()

(
raw,
Expand All @@ -41,24 +48,17 @@ def setup():
)

settings = nm_settings.get_default_settings()
settings = nm_settings.set_settings_fast_compute(
settings
)
settings = nm_settings.set_settings_fast_compute(settings)

generator = nm_generator.raw_data_generator(
data, settings, math.floor(sfreq)
)
data_batch = next(generator, None)

return [
raw.ch_names,
raw.get_channel_types(),
raw.info["bads"],
data_batch
]

def test_rereference_not_used_channels_no_reref(setup):
return [raw.ch_names, raw.get_channel_types(), raw.info["bads"], data_batch]


def test_rereference_not_used_channels_no_reref(setup):
ch_names, ch_types, bads, data_batch = setup

nm_channels = nm_define_nmchannels.set_channels(
Expand All @@ -70,20 +70,21 @@ def test_rereference_not_used_channels_no_reref(setup):
used_types=("ecog", "dbs", "seeg"),
target_keywords=("MOV_RIGHT",),
)

re_referencer = ReReferencer(1, nm_channels)
ref_dat = re_referencer.process(data_batch)

# select here data that will is selected, this operation takes place in the nm_run_analysis
data_used = data_batch[nm_channels["used"] == 1]

ref_dat = re_referencer.process(data_used)

for no_ref_idx in np.where(
(nm_channels.rereference == "None") & nm_channels.used
== 1
(nm_channels.rereference == "None") & nm_channels.used == 1
)[0]:
assert_allclose(
ref_dat[no_ref_idx, :], data_batch[no_ref_idx, :]
)
assert_allclose(ref_dat[no_ref_idx, :], data_batch[no_ref_idx, :])

def test_rereference_car(setup):

def test_rereference_car(setup):
ch_names, ch_types, bads, data_batch = setup

nm_channels = nm_define_nmchannels.set_channels(
Expand All @@ -95,13 +96,15 @@ def test_rereference_car(setup):
used_types=("ecog", "dbs", "seeg"),
target_keywords=("MOV_RIGHT",),
)

re_referencer = ReReferencer(1, nm_channels)
ref_dat = re_referencer.process(data_batch)

data_used = data_batch[nm_channels["used"] == 1]

ref_dat = re_referencer.process(data_used)

for ecog_ch_idx in np.where(
(nm_channels["type"] == "ecog")
& (nm_channels.rereference == "average")
(nm_channels["type"] == "ecog") & (nm_channels.rereference == "average")
)[0]:
assert_allclose(
ref_dat[ecog_ch_idx, :],
Expand All @@ -112,8 +115,8 @@ def test_rereference_car(setup):
].mean(axis=0),
)

def test_rereference_bp(setup):

def test_rereference_bp(setup):
ch_names, ch_types, bads, data_batch = setup

nm_channels = nm_define_nmchannels.set_channels(
Expand All @@ -125,9 +128,12 @@ def test_rereference_bp(setup):
used_types=("ecog", "dbs", "seeg"),
target_keywords=("MOV_RIGHT",),
)

re_referencer = ReReferencer(1, nm_channels)
ref_dat = re_referencer.process(data_batch)

data_used = data_batch[nm_channels["used"] == 1]

ref_dat = re_referencer.process(data_used)

for bp_reref_idx in [
ch_idx
Expand All @@ -137,15 +143,14 @@ def test_rereference_bp(setup):
# bp_reref_idx is the channel index of the rereference anode
# referenced_bp_channel is the channel index which is the rereference cathode
referenced_bp_channel = np.where(
nm_channels.iloc[bp_reref_idx]["rereference"]
== nm_channels.name
nm_channels.iloc[bp_reref_idx]["rereference"] == nm_channels.name
)[0][0]
assert_allclose(
ref_dat[bp_reref_idx, :],
data_batch[bp_reref_idx, :]
- data_batch[referenced_bp_channel, :],
data_batch[bp_reref_idx, :] - data_batch[referenced_bp_channel, :],
)


def test_rereference_wrong_rererference_column_name(setup):
ch_names, ch_types, bads, data_batch = setup

Expand All @@ -158,13 +163,13 @@ def test_rereference_wrong_rererference_column_name(setup):
used_types=("ecog", "dbs", "seeg"),
target_keywords=("SQUARED_ROTATION",),
)

nm_channels.loc[0, "rereference"] = "hallo"
with pytest.raises(Exception) as e_info:
re_referencer = ReReferencer(1, nm_channels)

def test_rereference_muliple_channels(setup):

def test_rereference_muliple_channels(setup):
ch_names, ch_types, bads, data_batch = setup

nm_channels = nm_define_nmchannels.set_channels(
Expand All @@ -176,19 +181,22 @@ def test_rereference_muliple_channels(setup):
used_types=("ecog", "dbs", "seeg"),
target_keywords=("MOV_RIGHT",),
)

nm_channels.loc[0, "rereference"] = "LFP_RIGHT_1&LFP_RIGHT_2"

re_referencer = ReReferencer(1, nm_channels)
ref_dat = re_referencer.process(data_batch)

data_used = data_batch[nm_channels["used"] == 1]

ref_dat = re_referencer.process(data_used)

assert_allclose(
ref_dat[0, :],
data_batch[0, :] - (data_batch[1, :] + data_batch[2, :])/2
ref_dat[0, :],
data_batch[0, :] - (data_batch[1, :] + data_batch[2, :]) / 2,
)

def test_rereference_same_channel(setup):

def test_rereference_same_channel(setup):
ch_names, ch_types, bads, data_batch = setup

nm_channels = nm_define_nmchannels.set_channels(
Expand All @@ -202,7 +210,6 @@ def test_rereference_same_channel(setup):
)

nm_channels.loc[0, "rereference"] = nm_channels.loc[0, "name"]

with pytest.raises(Exception):
re_referencer = ReReferencer(1, nm_channels)

Loading

0 comments on commit a5389b2

Please sign in to comment.