Skip to content

Commit 13ebe6e

Browse files
committed
allow custom cache dir when initializing EEGDashDataset
1 parent a98275d commit 13ebe6e

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

src/eegdash/main.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -280,15 +280,17 @@ def collection(self):
280280
return self.__collection
281281

282282
class EEGDashDataset(BaseConcatDataset):
283-
CACHE_DIR = '.eegdash_cache'
283+
# CACHE_DIR = '.eegdash_cache'
284284
def __init__(
285285
self,
286286
query:dict=None,
287287
data_dir:str | list =None,
288288
dataset:str | list =None,
289289
description_fields: list[str]=['subject', 'session', 'run', 'task', 'age', 'gender', 'sex'],
290+
cache_dir:str='.eegdash_cache',
290291
**kwargs
291292
):
293+
self.cache_dir = cache_dir
292294
if query:
293295
datasets = self.find_datasets(query, description_fields, **kwargs)
294296
elif data_dir:
@@ -301,6 +303,7 @@ def __init__(
301303
datasets.extend(self.load_bids_dataset(dataset[i], data_dir[i], description_fields))
302304
# convert to list using get_item on each element
303305
super().__init__(datasets)
306+
304307

305308
def find_key_in_nested_dict(self, data, target_key):
306309
if isinstance(data, dict):
@@ -321,7 +324,7 @@ def find_datasets(self, query:dict, description_fields:list[str], **kwargs):
321324
value = self.find_key_in_nested_dict(record, field)
322325
if value:
323326
description[field] = value
324-
datasets.append(EEGDashBaseDataset(record, self.CACHE_DIR, description=description, **kwargs))
327+
datasets.append(EEGDashBaseDataset(record, self.cache_dir, description=description, **kwargs))
325328
return datasets
326329

327330
def load_bids_dataset(self, dataset, data_dir, description_fields: list[str],raw_format='eeglab', **kwargs):
@@ -334,7 +337,7 @@ def get_base_dataset_from_bids_file(bids_dataset, bids_file):
334337
value = self.find_key_in_nested_dict(record, field)
335338
if value:
336339
description[field] = value
337-
return EEGDashBaseDataset(record, self.CACHE_DIR, description=description, **kwargs)
340+
return EEGDashBaseDataset(record, self.cache_dir, description=description, **kwargs)
338341

339342
bids_dataset = EEGBIDSDataset(
340343
data_dir=data_dir,

0 commit comments

Comments
 (0)