Skip to content

Commit 1b0c5be

Browse files
committed
update 1 file and create 2 files: add data_postprocessing script
1 parent d032718 commit 1b0c5be

File tree

3 files changed

+324
-7
lines changed

3 files changed

+324
-7
lines changed

README.md

+15-7
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,24 @@ conda activate cam
3333
pip install -r requirements.txt
3434
```
3535

36-
3. Data Preprocessing:
37-
Please make sure you have gone through [FreeSurfer's](https://surfer.nmr.mgh.harvard.edu/fswiki/recon-all) `recon-all pipeline` to extract the cortical surface features. The surface features should be found under each subject's `surf` directory.
36+
# Data
37+
To easily demonstrate the usage of CAM, we provide a toy dataset in the `data` directory. The toy dataset contains 10 subjects from [IXI dataset](https://brain-development.org/ixi-dataset/). For each subject we will extract 4 cortical surface features using FreeSurfer (Curvature, Sulci, Thickness, Volume).
3838

39+
1. Data Preprocessing:
40+
Please make sure you have run [FreeSurfer's](https://surfer.nmr.mgh.harvard.edu/fswiki/recon-all) `recon-all` pipeline to extract the cortical surface features. The surface features should be found under each subject's `surf` directory. You can find the already processed data in the `data/freesurfer` directory.
3941

40-
4. Data Postprocessing:
42+
2. Data Postprocessing:
4143
Here we provide a simple script to convert the surface features to a numpy array.
4244
```bash
43-
python data_postprocessing.py --data_dir /path/to/your/freesurfer/output --output_dir /path/to/your/postprocessed/data
44-
```
45+
# training set
46+
python src/data_postprocessing.py --freesurfer_dir data/freesurfer/ --subject_list data/train_subjects.txt --output_dir data/sphere/train/ --in_ch thickness volume curv sulc --annot_file aparc --hemi lh
4547

48+
# validation set
49+
python src/data_postprocessing.py --freesurfer_dir data/freesurfer/ --subject_list data/val_subjects.txt --output_dir data/sphere/val/ --in_ch thickness volume curv sulc --annot_file aparc --hemi lh
4650

47-
# Data
48-
To easily demonstrate the usage of CAM, we provide a toy dataset in the `data` directory. The toy dataset contains 10 subjects, each with 3 cortical surface features extracted by FreeSurfer (Thickness, Sulc, Curvature).
51+
# testing set
52+
python src/data_postprocessing.py --freesurfer_dir data/freesurfer/ --subject_list data/test_subjects.txt --output_dir data/sphere/test/ --in_ch thickness volume curv sulc --annot_file aparc --hemi lh
53+
```
4954

5055

5156
# Training
@@ -73,4 +78,7 @@ If you find this repository useful for your research, please use the following.
7378

7479
# Acknowledgments/References
7580
1. IXI data: https://brain-development.org/ixi-dataset/
81+
2. Sphere postprocessing code borrowed from:
82+
- [surface-vision-transformers](https://github.com/metrics-lab/surface-vision-transformers)
83+
- [SPHARM-Net](https://github.com/Shape-Lab/SPHARM-Net)
7684
3. We would like to thank all participants in this study for making this work possible. This work was supported by the German Research Foundation (DFG) Emmy Noether with reference 513851350 (TW), the Cluster of Excellence with reference 390727645 (TW) and the BMBF-funded de.NBI Cloud within the German Network for Bioinformatics Infrastructure (de.NBI) (031A532B, 031A533A, 031A533B, 031A534A, 031A535A, 031A537A, 031A537B, 031A537C, 031A537D, 031A538A).

src/data_postprocessing.py

+156
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
"""
2+
https://github.com/metrics-lab/surface-vision-transformers/blob/main/tools/preprocessing.py
3+
4+
triangle_indices_ico_6_sub_ico_1 -> ico6_80_561
5+
num_patches: 80
6+
num_vertices: 561
7+
8+
triangle_indices_ico_6_sub_ico_2 -> ico6_320_153
9+
num_patches: 320
10+
num_vertices: 153
11+
"""
12+
13+
# %% import
14+
import argparse
15+
16+
import joblib
17+
import pandas as pd
18+
import pyrootutils
19+
from joblib import Parallel, delayed
20+
from tqdm import tqdm
21+
22+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
23+
from src.utils.feature_extract import RunningStats, get_patch_data
24+
from src.utils.spharmnet.lib.io import read_mesh
25+
26+
# %% args
27+
# ------------------------------------------------------------------------------
28+
def _build_parser():
    """Assemble the command-line interface of the sphere extraction script."""
    cli = argparse.ArgumentParser(description="Extract sphere from freesurfer")

    # paths: every entry is a plain string option (flag, default, help)
    path_options = (
        ("--freesurfer_dir", "data/freesurfer/", "Path to FreeSurfer output directory"),
        ("--subject_list", "data/train_subjects.txt", "List of subjects to process"),
        ("--ico6_sphere_path", "src/utils/ico6.vtk", "Path to ico6 sphere"),
        ("--output_dir", "data/sphere/train/", "FreeSurfer sphere output directory"),
    )
    for flag, default_value, help_text in path_options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    # features
    cli.add_argument(
        "--in_ch",
        nargs="+",
        type=str,
        default=["thickness", "volume", "curv", "sulc"],
        help="List of geometry to process",
    )
    cli.add_argument(
        "--annot_file",
        choices=["aparc", "aparc.a2009s"],
        type=str,
        default="aparc",
        help="Manual labels (e.g. aparc for ?h.aparc.annot)",
    )
    cli.add_argument(
        "--hemi",
        choices=["lh", "rh"],
        type=str,
        default="lh",
        help="Hemisphere for data generation",
    )
    cli.add_argument(
        "--n_jobs",
        type=int,
        default=-1,
        help="# of CPU n_jobs for parallel data generation",
    )
    return cli


parser = _build_parser()
# parse_known_args tolerates arguments this script does not define;
# they are collected in ``unknown`` instead of aborting the run.
args, unknown = parser.parse_known_args()
86+
87+
88+
# %% main
89+
# ------------------------------------------------------------------------------
90+
def main(args):
    """Resample the listed subjects onto ico-6 patches and store one pickle each.

    For every subject named in ``args.subject_list`` the surface features in
    ``args.in_ch`` and the ``args.annot_file`` parcellation of hemisphere
    ``args.hemi`` are extracted with ``get_patch_data`` and written to
    ``<output_dir>/<subject>.pkl`` (keys ``feat_patches``, ``roi_anno``,
    ``structure_map``).  The per-channel mean/std over all successfully
    processed subjects are written to ``<output_dir>/stats.pkl``.

    Subjects for which ``get_patch_data`` reports a failure (it returns a tuple
    of ``None`` values) are skipped and listed at the end instead of crashing
    the whole run.

    Args:
        args: Namespace providing ``output_dir``, ``subject_list``,
            ``ico6_sphere_path``, ``in_ch``, ``annot_file``, ``hemi``,
            ``n_jobs`` and, optionally, ``freesurfer_dir``.  All paths are
            interpreted relative to the project root.

    Raises:
        RuntimeError: If not a single subject could be processed.
    """
    # init
    proj_root_dir = pyrootutils.find_root()
    out_dir = proj_root_dir / args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    # load subject list; blank lines would otherwise resolve to the project root
    with open(proj_root_dir / args.subject_list, "r", encoding="utf-8") as f:
        entries = [line.strip() for line in f if line.strip()]

    # Entries are resolved against the project root first (original behaviour).
    # Only if that directory does not exist are they looked up inside
    # ``--freesurfer_dir``, so that lists of bare subject IDs work as well.
    freesurfer_dir = getattr(args, "freesurfer_dir", None)
    subjects = []
    for entry in entries:
        sub_dir = proj_root_dir / entry
        if freesurfer_dir and not sub_dir.exists():
            candidate = proj_root_dir / freesurfer_dir / entry
            if candidate.exists():
                sub_dir = candidate
        subjects.append(sub_dir)

    # load ico mesh & triangle indices
    ico_v, _ = read_mesh(
        str(proj_root_dir / args.ico6_sphere_path)
    )  # ico_v: ico vertices (40962, 3)
    patch_ids_path = proj_root_dir / "src/utils/ico6_320_153.csv"
    triangle_mesh_indices = pd.read_csv(patch_ids_path)

    # extract feature
    # ------------------------------------------------------------------------------
    print(f"Extractiing {args.subject_list}: {args.in_ch}")
    sphere_data = Parallel(n_jobs=args.n_jobs)(
        delayed(get_patch_data)(
            ico_v=ico_v,
            triangle_mesh_indices=triangle_mesh_indices,
            in_ch=args.in_ch,
            annot_file=args.annot_file,
            sub=sub,
            hemi=args.hemi,
        )
        for sub in tqdm(subjects, desc=f"{args.hemi}")
    )

    # store sphere data & phenotypic data in pkl file
    # ------------------------------------------------------------------------------
    running_stats = {channel: RunningStats() for channel in args.in_ch}
    failed = []
    # Results are paired with the requested subjects by position so that a
    # failed subject can still be named.
    for requested, result in zip(subjects, sphere_data):
        sub_folder, feat_patches, roi_anno, structure_map = result
        if sub_folder is None:
            # get_patch_data already printed the reason for the failure
            failed.append(requested.name)
            continue

        # save to pkl file
        sub = sub_folder.name
        pkl_file = f"{out_dir}/{sub}.pkl"
        joblib.dump(
            {
                "feat_patches": feat_patches,
                "roi_anno": roi_anno,
                "structure_map": structure_map,
            },
            pkl_file,
        )

        # update running stats (mean, std) for each channel
        for channel in args.in_ch:
            running_stats[channel].update(feat_patches[channel])

    if failed:
        print(f"Skipped {len(failed)} subject(s) that could not be processed: {failed}")
    if len(failed) == len(subjects):
        # Without this guard the statistics below would divide by zero.
        raise RuntimeError(
            f"No subject from {args.subject_list} could be processed; nothing to summarise."
        )

    running_stats = {
        channel: {
            "mean": running_stats[channel].get_mean(),
            "std": running_stats[channel].get_std(),
        }
        for channel in args.in_ch
    }
    print(f"Running stats: {running_stats}")
    joblib.dump(running_stats, f"{out_dir}/stats.pkl")


# %% main
if __name__ == "__main__":
    main(args)

src/utils/feature_extract.py

+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
#%% import
2+
import pyrootutils
3+
import numpy as np
4+
import traceback
5+
import pandas as pd
6+
import os
7+
8+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
9+
from src.utils.spharmnet.lib.io import read_annot, read_feat, read_mesh
10+
from src.utils.spharmnet.lib.sphere import TriangleSearch
11+
12+
13+
# %% extract patch data from surface
14+
def _report_failure(message, show_traceback=True):
    """Print a failed extraction step and return the all-``None`` failure tuple.

    When ``show_traceback`` is true this must be called from inside an
    ``except`` block, because ``traceback.print_exc`` prints the exception that
    is currently being handled.
    """
    print(message)
    if show_traceback:
        traceback.print_exc()
    return None, None, None, None


def get_patch_data(
    ico_v: np.ndarray,
    triangle_mesh_indices: pd.DataFrame,
    in_ch: "list[str] | None" = None,
    annot_file: str = "aparc",
    sub: "str | os.PathLike" = "data/freesurfer/sub-IXI031",
    hemi: str = "lh",
) -> "tuple[str | os.PathLike | None, dict[str, np.ndarray] | None, np.ndarray | None, list | None]":
    """Resample one subject's surface features and labels onto ico patches.

    The subject's native ``<hemi>.sphere`` mesh is read, every vertex of
    ``ico_v`` is located in a native triangle, and each requested feature is
    interpolated at the ico vertices using the coefficients returned by
    ``TriangleSearch.query``.  Labels from ``<hemi>.<annot_file>.annot`` are
    transferred by taking, per ico vertex, the label of the triangle corner
    with the largest coefficient.  Features and labels are then gathered into
    patches with ``triangle_mesh_indices``.

    Args:
        ico_v: Vertices of the target sphere (the caller's comment gives
            ``(40962, 3)`` for ico-6).
        triangle_mesh_indices: Table of ico vertex indices; its ``.values`` are
            used to index the resampled arrays and the result is transposed
            (``num_patches x num_vertices`` per the original comment).
        in_ch: Feature file suffixes to read from ``surf/`` (``<hemi>.<name>``).
            ``None`` selects ``area, sphere, thickness, volume, curv, sulc,
            inflated.H``.
        annot_file: Annotation name, read from ``label/<hemi>.<annot_file>.annot``.
        sub: Subject directory containing ``surf`` and ``label``; may be a
            string or a path object.
        hemi: Hemisphere prefix of the FreeSurfer files.

    Returns:
        ``(sub, feat_patches, label_patches, structure_map)`` where ``sub`` is
        the argument exactly as passed in, ``feat_patches`` maps each feature
        name to its patch array, ``label_patches`` holds the per-patch label
        indices and ``structure_map`` is ``list(enumerate(structure_names))``.
        If any step fails, the error is printed and
        ``(None, None, None, None)`` is returned instead of raising.
    """
    # A fresh list per call avoids sharing one mutable default between calls.
    if in_ch is None:
        in_ch = ["area", "sphere", "thickness", "volume", "curv", "sulc", "inflated.H"]

    # paths: os.path.join accepts str and path objects alike, whereas the
    # ``/`` operator fails for a plain string such as the default ``sub``.
    surf_dir = os.path.join(sub, "surf")
    label_dir = os.path.join(sub, "label")

    # load native sphere
    # ------------------------------------------------------------------------------
    try:
        sphere_path = os.path.join(surf_dir, hemi + "." + "sphere")
        native_v, native_f = read_mesh(sphere_path)
    except FileNotFoundError:
        # A missing file is an expected condition: report it without traceback.
        return _report_failure(
            f"\tsub: {sub} | Error: File {sphere_path} not found.",
            show_traceback=False,
        )
    except Exception as e:
        return _report_failure(f"\tsub: {sub} | An error occurred while reading the mesh:\n{e}")

    try:
        tree = TriangleSearch(native_v, native_f)
        triangle_idx, bary_coeff = tree.query(ico_v)
    except Exception as e:
        return _report_failure(
            f"\tsub: {sub} | An error occurred during triangle search and query:\n{e}"
        )

    # extract sphere features
    # ------------------------------------------------------------------------------
    try:
        feat_patches = {feat_name: None for feat_name in in_ch}
        for feat_name in in_ch:
            # load surface feature
            feat_path = os.path.join(surf_dir, hemi + "." + feat_name)
            try:
                feat = read_feat(feat_path)  # one value per native vertex
            except Exception as feat_read_error:
                return _report_failure(
                    f"\tsub: {sub} | Error reading feature '{feat_name}' from '{feat_path}': {feat_read_error}"
                )

            # remesh surf feature: native vertices -> ico vertices
            try:
                # weight the three corner values of each hit triangle and sum them
                feat_remesh = np.multiply(feat[native_f[triangle_idx]], bary_coeff).sum(-1)
                # explicit check instead of ``assert`` so it survives ``python -O``
                if feat_remesh.shape[0] != ico_v.shape[0]:
                    raise ValueError("feat_remesh.shape[0] != ico_v.shape[0]")
            except Exception as feat_processing_error:
                return _report_failure(
                    f"\tsub: {sub} | Error processing feature '{feat_name}': {feat_processing_error}"
                )

            # extract triangle patches
            try:
                feat_patches[feat_name] = feat_remesh[
                    triangle_mesh_indices.values
                ].T  # num_patches x num_vertices
            except Exception as feat_extract_error:
                return _report_failure(
                    f"\tsub: {sub} | Error extracting feature '{feat_name}': {feat_extract_error}"
                )

    except Exception as e:
        return _report_failure(f"\tsub: {sub} | An error occurred during feature extraction:\n{e}")

    # extract labels
    # ------------------------------------------------------------------------------
    try:
        # load annotation
        num_vert = native_v.shape[0]
        label_arr = np.zeros(num_vert, dtype=np.int16)
        annot = os.path.join(label_dir, hemi + "." + annot_file + ".annot")
        try:
            # vertices: annotated vertex indices, label: structure ID per vertex,
            # structure_ls: structure names, structureID_ls: structure IDs
            vertices, label, structure_ls, structureID_ls = read_annot(annot)
        except Exception as annot_read_error:
            return _report_failure(
                f"\tsub: {sub} | Error reading annotation from '{annot}': {annot_read_error}"
            )

        # remesh roi label: native vertices -> ico vertices
        try:
            # map structure IDs to their position in structureID_ls; unknown IDs -> 0
            label = [
                structureID_ls.index(label_id) if label_id in structureID_ls else 0
                for label_id in label
            ]
            label_arr[vertices] = label
            # per ico vertex, take the label of the triangle corner with the
            # largest coefficient
            label_remesh = label_arr[native_f[triangle_idx, np.argmax(bary_coeff, axis=1)]]
            if label_remesh.shape[0] != ico_v.shape[0]:
                raise ValueError("label_remesh.shape[0] != ico_v.shape[0]")
        except Exception as label_processing_error:
            return _report_failure(
                f"\tsub: {sub} | Error processing label: {label_processing_error}"
            )

        # extract triangle patches
        try:
            label_remesh = label_remesh[
                triangle_mesh_indices.values
            ].T  # num_patches x num_vertices
        except Exception as label_extract_error:
            return _report_failure(
                f"\tsub: {sub} | Error extracting label: {label_extract_error}"
            )

    except Exception as e:
        return _report_failure(f"\tsub: {sub} | An error occurred during label extraction:\n{e}")

    # extract structure map: (index, structure name) pairs
    structure_map = list(enumerate(structure_ls))
    return sub, feat_patches, label_remesh, structure_map
132+
133+
134+
#%% calculate running stats
135+
class RunningStats:
    """Accumulate a mean and standard deviation over a stream of arrays.

    Each ``update`` adds the mean and the mean of squares of one array to the
    running sums.  Every update therefore carries the same weight, regardless
    of how many values it contains; the result equals the statistics of the
    pooled values only when all arrays have the same size.

    ``get_mean`` and ``get_std`` raise ``ZeroDivisionError`` if ``update`` has
    never been called.
    """

    def __init__(self):
        self.N = 0  # number of arrays seen so far
        self.mean = 0.0  # sum of the per-array means
        self.M2 = 0.0  # sum of the per-array means of squared values

    def update(self, data):
        """Add one array (or array-like) of values to the running sums."""
        # Work in float64: squaring integer or low-precision data in its own
        # dtype can overflow (e.g. int16 200**2) and corrupt the variance.
        data = np.asarray(data, dtype=np.float64)
        self.N += 1
        self.mean += np.mean(data)
        self.M2 += np.mean(data**2)

    def get_mean(self):
        """Return the average of the per-array means."""
        return self.mean / self.N

    def get_std(self):
        """Return sqrt(E[x^2] - E[x]^2) computed from the running sums."""
        mean = self.mean / self.N
        m2 = self.M2 / self.N
        # Rounding can push the variance of (near-)constant data slightly
        # below zero; clamp it so the result is 0 instead of NaN.
        variance = max(m2 - mean**2, 0.0)
        return np.sqrt(variance)
153+
# %%

0 commit comments

Comments
 (0)