Skip to content

Commit

Permalink
new load (in tb.py) & obligatory debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
sophie460 committed Oct 10, 2024
1 parent 2728e3b commit 1e674ce
Show file tree
Hide file tree
Showing 6 changed files with 1,010 additions and 14 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ wav2vec2-demo/
testing.py
*.lock
wav2vec2-pretrained-demo/
tb.py
runs/
logging_events/

Expand Down
20 changes: 11 additions & 9 deletions bids_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,21 @@
import torch
from pathlib import Path
from torch.utils.data import Dataset
from datasets import DatasetDicct, concatenate_datasets
from transformers import Wav2Vec2FeatureExtractor
import numpy

class BIDSBrainVisionDataset(Dataset):
def __init__(self, directory, channel_names, target_name, window_size=2.0, overlap=0.0, preload=True): #preload=False for goin easy on the RAM
def __init__(self, directory, channel_names, target_name, window_size=2.0, overlap=0.0, preload=True, feature_extractor=None): #preload=False for goin easy on the RAM
self.directory = Path(directory)
self.channel_names = channel_names
self.target_name = target_name
self.window_size = window_size
self.overlap = overlap
self.preload = preload
self.feature_extractor = feature_extractor

self.filepaths = list(self.directory.glob("*.vhdr"))

self.windows = []
self._prepare_dataset()

Expand All @@ -96,26 +98,26 @@ def _load_brainvision_file(self, filepath):

target, _ = raw[self.target_name, :]

x_data = torch.stack(ecogs, dim=1).unsqueeze(0) ####maybe worng stack (after permuting?)
x_data = torch.stack(ecogs, dim=0).unsqueeze(0) ####maybe worng stack (after permuting?)
x_data = (x_data - x_data.mean()) / x_data.std()
x_data = x_data.squeeze(1)
# x_data = x_data.squeeze(1)
# x_data = x_data.permute(0, 2, 1)
# x_data = x_data.mean(dim=1, keepdim=True) ####################################################

y_data = torch.tensor(target.T, dtype=torch.float32).reshape(1, 1, -1)
y_data = torch.tensor(target.T, dtype=torch.float32).reshape(1, -1)
#ydata: torch.Size([1, 1, 130001]), xdata: torch.Size([1, 6, 130001])
return x_data, y_data, raw.info['sfreq']

def _sliding_windows(self, data, window_size, overlap, sfreq):
step = int(window_size * sfreq)
overlap_step = int(overlap * sfreq)
data_length = data.shape[2]
data_length = data.shape[1]
windows = []

for x in range(0, data_length - step + 1, step - overlap_step):
stop = x + step
print(f"window from {x} to {stop}")
windows.append(data[:, :, x:stop])
# print(f"window from {x} to {stop}")
windows.append(data[:, x:stop])
# print(f"window size: {window_size}")
print(f"number of windows={len(windows)}")
return windows
Expand All @@ -136,7 +138,7 @@ def __len__(self):
return len(self.windows)

def __getitem__(self, idx):
return self.windows[idx]
x_window, y_window = self.windows[idx]

channel_names = ['ECOG_RIGHT_0', 'ECOG_RIGHT_1', 'ECOG_RIGHT_2', 'ECOG_RIGHT_3', 'ECOG_RIGHT_4', 'ECOG_RIGHT_5']
target_name = 'MOV_LEFT_CLEAN'
Expand Down
Loading

0 comments on commit 1e674ce

Please sign in to comment.