Skip to content

Commit

Permalink
and again update
Browse files Browse the repository at this point in the history
  • Loading branch information
Sophie committed Dec 5, 2024
1 parent 3b65af5 commit 45c593a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 20 deletions.
40 changes: 21 additions & 19 deletions timeseries_transformer/new_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,30 +481,20 @@ def __len__(self):

# dummy = {"bids": Dummy_imputation}

def npy_data(
root_dir, output_file, masking_ratio, mean_mask_length, mode, distribution, exclude_feats):
def npy_data(root_dir, output_file, masking_ratio, mean_mask_length, mode, distribution, exclude_feats):
"""
Preprocess `.npy` files, compute masks, and save the resulting dataset with metadata for use with `ImputationDataset`.
Args:
root_dir: Directory containing `.npy` files.
output_file: Path to save the preprocessed data.
masking_ratio: Fraction of data to mask.
mean_mask_length: Average length of masked segments.
mode: Masking mode ('separate' or 'block').
distribution: Mask length distribution ('geometric' or other).
exclude_feats: List of features/channels to exclude from masking.
"""
all_data = {}
feature_df = []

all_data = {}

for file_name in tqdm(os.listdir(root_dir), desc="Preprocessing .npy files"):
if file_name.endswith('.npy'):
file_path = os.path.join(root_dir, file_name)
key = os.path.splitext(file_name)[0]

array = np.load(file_path)
if array.shape[1] != 4:
raise ValueError(f"File {file_name} does not have 4 channel")
raise ValueError(f"File {file_name} does not have 4 channels")

transposed = array.T
time_points = transposed.shape[1]
Expand All @@ -516,14 +506,26 @@ def npy_data(
clip_key = f"{key}_{i}"

mask = noise_mask(clip.T, masking_ratio, mean_mask_length, mode, distribution, exclude_feats)

with open(output_file, 'wb') as f:
pickle.dump({"feature_df": pd.concat(all_data.values), "FileID": clip_key, "mask": mask}, f) #"data": all_data,
#'builtin_function_or_method' object is not iterable

all_data[clip_key] = {
"feature_df": pd.DataFrame(clip),
"mask": mask
}

with open(output_file, 'wb') as f:
pickle.dump(
{
"feature_df": pd.concat([entry["feature_df"] for entry in all_data.values()]),
"FileID": list(all_data.keys()),
"mask": [entry["mask"] for entry in all_data.values()]
},
f
)

print(f"Preprocessed data saved to {output_file}")



class newImputationDataset(Dataset):
"""
Dataset class to handle preprocessed `.npy` data with masks and metadata.
Expand All @@ -537,7 +539,7 @@ def __init__(self, preprocessed_file):
with open(preprocessed_file, 'rb') as f:
dataset = pickle.load(f)

self.data = dataset["data"]
self.data = dataset["feature_df"]
# self.feature_df = pd.concat(self.data.values(), ignore_index=True)
self.all_IDs = list(self.data.keys())

Expand Down
2 changes: 1 addition & 1 deletion timeseries_transformer/running.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# from utils import utils, analysis
from loss import l2_reg_loss
from mvts_transformer.src.datasets.dataset import ImputationDataset, TransductionDataset, ClassiregressionDataset, collate_unsuperv, collate_superv

from mvts_transformer.src.utils import utils

logger = logging.getLogger('__main__')

Expand Down

0 comments on commit 45c593a

Please sign in to comment.