diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..bd71867
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,5 @@
+[flake8]
+ignore = E203, E266, E501, W503
+max-line-length = 88
+max-complexity = 18
+select = B,C,E,F,W,T4,B9
diff --git a/.isort.cfg b/.isort.cfg
new file mode 100644
index 0000000..04455a8
--- /dev/null
+++ b/.isort.cfg
@@ -0,0 +1,7 @@
+[settings]
+known_third_party = deepreg,tensorflow,torchio,tqdm
+multi_line_output = 3
+include_trailing_comma = True
+force_grid_wrap = 0
+use_parentheses = True
+line_length = 88
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..67b4bf1
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,47 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.4.0
+    hooks:
+      - id: check-ast # Simply check whether the files parse as valid python
+      - id: check-case-conflict # Check for files that would conflict in case-insensitive filesystems
+      - id: check-builtin-literals # Require literal syntax when initializing empty or zero Python builtin types
+      - id: check-docstring-first # Check a common error of defining a docstring after code
+      - id: check-merge-conflict # Check for files that contain merge conflict strings
+      - id: check-yaml # Check yaml files
+      - id: check-vcs-permalinks # Ensure that links to vcs websites are permalinks
+      - id: debug-statements # Check for debugger imports and py37+ `breakpoint()` calls in python source
+      - id: detect-private-key # Detect the presence of private keys
+      - id: end-of-file-fixer # Ensure that a file is either empty, or ends with one newline
+      - id: mixed-line-ending # Replace or check mixed line endings
+      - id: trailing-whitespace # This hook trims trailing whitespace
+      - id: file-contents-sorter # Sort the lines in specified files
+        files: .*requirements.*\.txt$
+  - repo: https://github.com/asottile/seed-isort-config
+    rev: v2.2.0
+    hooks:
+      - id: seed-isort-config
+  - repo: https://github.com/timothycrosley/isort
+    rev: 5.7.0
+    hooks:
+      - id: isort
+  - repo: https://github.com/psf/black
+    rev: 20.8b1
+    hooks:
+      - id: black
+        language_version: python3.7
+  - repo: https://github.com/pre-commit/mirrors-prettier
+    rev: v2.2.1
+    hooks:
+      - id: prettier
+  - repo: https://gitlab.com/pycqa/flake8
+    rev: 3.8.4
+    hooks:
+      - id: flake8
+  - repo: https://github.com/pycqa/pydocstyle
+    rev: 5.1.1 # pick a git hash / tag to point to
+    hooks:
+      - id: pydocstyle
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v0.800
+    hooks:
+      - id: mypy
diff --git a/.prettierignore b/.prettierignore
new file mode 100644
index 0000000..4467cb7
--- /dev/null
+++ b/.prettierignore
@@ -0,0 +1 @@
+docs/joss_paper/paper.md
diff --git a/.prettierrc b/.prettierrc
new file mode 100644
index 0000000..6d2f1b0
--- /dev/null
+++ b/.prettierrc
@@ -0,0 +1,6 @@
+{
+  "printWidth": 88,
+  "proseWrap": "always",
+  "useTabs": false,
+  "tabWidth": 2
+}
diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 0000000..7a9f360
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,336 @@
+[MASTER]
+
+# Specify a configuration file.
+#rcfile=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Profiled execution.
+profile=no
+
+# Add files or directories to the denylist. They should be base names, not
+# paths.
+ignore=CVS
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# List of plugins (as comma separated values of python module names) to load,
+# usually to register additional checkers.
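+# docparams checks that docstrings document the callable's parameters; with
+# accept-no-param-doc=no below, missing parameter documentation is reported.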
+load-plugins=pylint.extensions.docparams
+accept-no-param-doc=no
+
+[MESSAGES CONTROL]
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifiers separated by comma (,) or put this option
+# multiple times. See also the "--disable" option for examples.
+enable=indexing-exception,old-raise-syntax
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once). You can also use "--disable=all" to
+# disable everything first and then re-enable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use "--disable=all --enable=classes
+# --disable=W".
+disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager,invalid-sequence-index,unexpected-keyword-arg,no-value-for-parameter
+
+
+# Set the cache size for astng objects.
+cache-size=500
+
+
+[REPORTS]
+
+# Set the output format. Available formats are text, parseable, colorized, msvs
+# (visual studio) and html. You can also give a reporter class, eg
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Put messages in a separate file for each module / package specified on the
+# command line instead of printing them on stdout. Reports (if any) will be
+# written in a file named "pylint_global.[txt|html]".
+files-output=no
+
+# Tells whether to display a full report or only the messages.
+reports=yes
+
+# Python expression which should return a note less than 10 (10 is the highest
+# note). You have access to the variables errors, warning, statement which
+# respectively contain the number of errors / warnings messages and the total
+# number of statements analyzed. This is used by the global evaluation report
+# (RP0004).
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+# Add a comment according to your evaluation note. This is used by the global
+# evaluation report (RP0004).
+comment=no
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details.
+#msg-template=
+
+
+[TYPECHECK]
+
+# Tells whether missing members accessed in mixin class should be ignored. A
+# mixin class is detected if its name ends with "mixin" (case insensitive).
+ignore-mixin-members=yes
+
+# List of class names for which member attributes should not be checked
+# (useful for classes with attributes dynamically set).
+ignored-classes=SQLObject
+
+# When zope mode is activated, add a predefined set of Zope acquired attributes
+# to generated-members.
+zope=no
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E0201 when accessed. Python regular
+# expressions are accepted.
+generated-members=REQUEST,acl_users,aq_parent
+
+# List of decorators that create context managers from functions, such as
+# contextlib.contextmanager.
+contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
+
+
+[VARIABLES]
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# A regular expression matching the beginning of the name of dummy variables
+# (i.e. not used).
+dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid to define new builtins when possible.
+additional-builtins=
+
+
+[BASIC]
+
+# Required attributes for module, separated by a comma
+required-attributes=
+
+# List of builtins function names that should not be used, separated by a comma
+bad-functions=apply,input,reduce
+
+
+# Disable the report(s) with the given id(s).
+# All non-Google reports are disabled by default.
+disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923
+
+# Regular expression which should only match correct module names
+module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
+
+# Regular expression which should only match correct module level names
+const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
+
+# Regular expression which should only match correct class names
+class-rgx=^_?[A-Z][a-zA-Z0-9]*$
+
+# Regular expression which should only match correct function names
+function-rgx=^(?:(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
+
+# Regular expression which should only match correct method names
+method-rgx=^(?:(?P<exempt>__[a-z0-9_]+__|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
+
+# Regular expression which should only match correct instance attribute names
+attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
+
+# Regular expression which should only match correct argument names
+argument-rgx=^[a-z][a-z0-9_]*$
+
+# Regular expression which should only match correct variable names
+variable-rgx=^[a-z][a-z0-9_]*$
+
+# Regular expression which should only match correct attribute names in class
+# bodies
+class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
+
+# Regular expression which should only match correct list comprehension /
+# generator expression variable names
+inlinevar-rgx=^[a-z][a-z0-9_]*$
+
+# Good variable names which should always be accepted, separated by a comma
+good-names=main,_
+
+# Bad variable names which should always be refused, separated by a comma
+bad-names=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=(__.*__|main)
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=1
+
+
+[FORMAT]
+
+# Maximum number of characters on a single line.
+max-line-length=88
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=(?x)
+  (^\s*(import|from)\s
+   |\$Id:\s\/\/depot\/.+#\d+\s\$
+   |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+')
+   |^\s*\#\ LINT\.ThenChange
+   |^[^#]*\#\ type:\ [a-zA-Z_][a-zA-Z0-9_.,[\] ]*$
+   |pylint
+   |"""
+   |\#
+   |lambda
+   |(https?|ftp):)
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=y
+
+# List of optional constructs for which whitespace checking is disabled
+no-space-check=
+
+# Maximum number of lines in a module
+max-module-lines=1000
+
+# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
+# tab).
+indent-string='    '
+
+
+[SIMILARITIES]
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=
+
+
+[IMPORTS]
+
+# Deprecated modules which should not be used, separated by a comma
+deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets
+
+# Create a graph of every (i.e. internal and external) dependencies in the
+# given file (report RP0402 must not be disabled)
+import-graph=
+
+# Create a graph of external dependencies in the given file (report RP0402 must
+# not be disabled)
+ext-import-graph=
+
+# Create a graph of internal dependencies in the given file (report RP0402 must
+# not be disabled)
+int-import-graph=
+
+
+[CLASSES]
+
+# List of interface methods to ignore, separated by a comma. This is used for
+# instance to not check methods defined in Zope's Interface base class.
+ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,__new__,setUp
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls,class_
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+
+[DESIGN]
+
+# Maximum number of arguments for function / method
+max-args=5
+
+# Argument names that match this expression will be ignored. Defaults to names
+# with leading underscore.
+ignored-argument-names=_.*
+
+# Maximum number of locals for function / method body
+max-locals=15
+
+# Maximum number of return / yield for function / method body
+max-returns=6
+
+# Maximum number of branches for function / method body
+max-branches=12
+
+# Maximum number of statements in function / method body
+max-statements=50
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when being caught. Defaults to
+# "Exception".
+overgeneral-exceptions=Exception,StandardError,BaseException
+
+
+[AST]
+
+# Maximum line length for lambdas
+short-func-length=1
+
+# List of module members that should be marked as deprecated.
+# All of the string functions are listed in 4.1.4 Deprecated string functions
+# in the Python 2.4 docs.
+deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc
+
+
+[DOCSTRING]
+
+default-docstring-type=sphinx
+# List of exceptions that do not need to be mentioned in the Raises section of
+# a docstring.
+ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError,ValueError
+
+
+[TOKENS]
+
+# Number of spaces of indent required when the last token on the preceding line
+# is an open (, [, or {.
+indent-after-paren=4
+
+
+[GOOGLE LINES]
+
+# Regexp for a proper copyright notice.
+copyright=Copyright \d{4} The TensorFlow Authors\. +All [Rr]ights [Rr]eserved\.
diff --git a/benchmark/__init__.py b/benchmark/__init__.py
new file mode 100644
index 0000000..5e58a39
--- /dev/null
+++ b/benchmark/__init__.py
@@ -0,0 +1 @@
+"""Benchmark against other methods."""
diff --git a/benchmark/balakrishnan2019/__init__.py b/benchmark/balakrishnan2019/__init__.py
new file mode 100644
index 0000000..3261ebc
--- /dev/null
+++ b/benchmark/balakrishnan2019/__init__.py
@@ -0,0 +1 @@
+"""Reproduce https://arxiv.org/abs/1809.05231."""
diff --git a/benchmark/balakrishnan2019/config_balakrishnan_2019.yaml b/benchmark/balakrishnan2019/config_balakrishnan_2019.yaml
new file mode 100644
index 0000000..781e5a3
--- /dev/null
+++ b/benchmark/balakrishnan2019/config_balakrishnan_2019.yaml
@@ -0,0 +1,50 @@
+dataset:
+  dir:
+    train: "/raid/candi/Yunguan/DeepReg/neuroimaging/preprocessed" # required
+    valid:
+    test:
+  format: "nifti"
+  type: "unpaired" # paired / unpaired / grouped
+  labeled: false # whether to use the labels if available, "true" or "false"
+  image_shape: [192, 224, 192]
+
+train:
+  # define neural network structure
+  method: "ddf" # options include "ddf", "dvf", "conditional"
+  backbone:
+    name: "vm_balakrishnan_2019" # options include "local", "unet" and "global"
+    num_channel_initial: 16 # number of initial channels, controls the size of the network
+    depth: 4
+    concat_skip: true
+    encode_num_channels: [16, 32, 32, 32, 32]
+    decode_num_channels: [32, 32, 32, 32, 32]
+
+  # define the loss function for training
+  loss:
+    image:
+      name: "lncc" # options include "lncc", "ssd" and "gmi", for local normalised cross-correlation, sum of squared distance and global mutual information
+      weight: 1.0
+    label:
+      weight: 0.0
+      name: "dice" # options include "dice", "cross-entropy", "mean-squared", "generalised_dice" and "jaccard"
+    regularization:
+      weight: 1.0 # weight of regularization loss
+      name: "gradient" # options include "bending", "gradient"
+
+  # define the optimizer
+  optimizer:
+    name: "adam" # options include "adam", "sgd" and "rms"
+    adam:
+      learning_rate: 1.0e-4
+
+  # define the hyper-parameters for preprocessing
+  preprocess:
+    data_augmentation:
+      name: "affine"
+    batch_size: 2
+    shuffle_buffer_num_batch: 1 # shuffle_buffer_size = batch_size * shuffle_buffer_num_batch
+
+  # other training hyper-parameters
+  epochs: 2 # number of training epochs
+  save_period: 2 # the model will be saved every `save_period` epochs
+  update_freq: 50
diff --git a/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py
new file mode 100644
index 0000000..808c4a2
--- /dev/null
+++ b/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py
@@ -0,0 +1,233 @@
+"""This script provides an example of using a custom backbone for training."""
+
+import argparse
+from typing import Tuple, Union
+
+import tensorflow as tf
+import tensorflow.keras.layers as tfkl
+
+from deepreg.model.backbone import UNet
+from deepreg.registry import REGISTRY
+from deepreg.train import train
+
+
+@REGISTRY.register_backbone(name="vm_balakrishnan_2019")
+class VoxelMorphBalakrishnan2019(UNet):
+    """Reproduce https://arxiv.org/abs/1809.05231."""
+
+    def __init__(self, **kwargs):
+        """
+        Init.
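+
+        A VoxelMorph-style backbone built on DeepReg's UNet: down-sampling is
+        done by strided convolutions, up-sampling by ``UpSampling3D``, and a
+        final three-channel convolution predicts the dense displacement field
+        from the up-sampled features concatenated with the input images.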
+
+        Args:
+            **kwargs: additional arguments passed to the parent UNet.
+        """
+        super().__init__(**kwargs)
+
+        self._out_ddf_upsampling = tf.keras.layers.UpSampling3D(size=2)
+        self._out_ddf_conv = tfkl.Conv3D(
+            filters=3,
+            kernel_size=3,
+            padding="same",
+            activation=self.get_activation(),
+        )
+
+    def build_encode_conv_block(
+        self, filters: int, kernel_size: int, padding: str
+    ) -> Union[tf.keras.Model, tfkl.Layer]:
+        """
+        Build a conv block for down-sampling.
+
+        :param filters: number of channels for output
+        :param kernel_size: arg for conv3d
+        :param padding: arg for conv3d
+        :return: a block consisting of one or multiple layers
+        """
+        return tfkl.Conv3D(
+            filters=filters,
+            kernel_size=kernel_size,
+            padding=padding,
+            strides=2,
+            activation=self.get_activation(),
+        )
+
+    def build_down_sampling_block(
+        self, filters: int, kernel_size: int, padding: str, strides: int
+    ) -> Union[tf.keras.Model, tfkl.Layer]:
+        """
+        Return an identity layer, as down-sampling is done by the strided conv block.
+
+        :param filters: number of channels for output, arg for conv3d
+        :param kernel_size: arg for pool3d or conv3d
+        :param padding: arg for pool3d or conv3d
+        :param strides: arg for pool3d or conv3d
+        :return: a block consisting of one or multiple layers
+        """
+        return tfkl.Lambda(lambda x: x)
+
+    def build_bottom_block(
+        self, filters: int, kernel_size: int, padding: str
+    ) -> Union[tf.keras.Model, tfkl.Layer]:
+        """
+        Return a down-sampling conv layer.
+
+        :param filters: number of channels for output
+        :param kernel_size: arg for conv3d
+        :param padding: arg for conv3d
+        :return: a block consisting of one or multiple layers
+        """
+        return tfkl.Conv3D(
+            filters=filters,
+            kernel_size=kernel_size,
+            padding=padding,
+            strides=2,
+            activation=self.get_activation(),
+        )
+
+    def build_up_sampling_block(
+        self,
+        filters: int,
+        output_padding: int,
+        kernel_size: int,
+        padding: str,
+        strides: int,
+        output_shape: tuple,
+    ) -> Union[tf.keras.Model, tfkl.Layer]:
+        """
+        Build a block for up-sampling.
+
+        This block changes the tensor shape (width, height, depth),
+        but it does not change the number of channels.
+
+        :param filters: number of channels for output
+        :param output_padding: padding for output
+        :param kernel_size: arg for deconv3d
+        :param padding: arg for deconv3d
+        :param strides: arg for deconv3d
+        :param output_shape: shape of the output tensor
+        :return: a block consisting of one or multiple layers
+        """
+        return tf.keras.layers.UpSampling3D(size=strides)
+
+    def build_decode_conv_block(
+        self, filters: int, kernel_size: int, padding: str
+    ) -> Union[tf.keras.Model, tfkl.Layer]:
+        """
+        Build a conv block for up-sampling.
+
+        :param filters: number of channels for output
+        :param kernel_size: arg for conv3d
+        :param padding: arg for conv3d
+        :return: a block consisting of one or multiple layers
+        """
+        return tfkl.Conv3D(
+            filters=filters,
+            kernel_size=kernel_size,
+            padding=padding,
+            strides=1,
+            activation=self.get_activation(),
+        )
+
+    def build_output_block(
+        self,
+        image_size: Tuple[int],
+        extract_levels: Tuple[int],
+        out_channels: int,
+        out_kernel_initializer: str,
+        out_activation: str,
+    ) -> Union[tf.keras.Model, tfkl.Layer]:
+        """
+        Build a block for output.
+
+        The input to this block is a list of tensors.
+
+        :param image_size: such as (dim1, dim2, dim3)
+        :param extract_levels: number of extraction levels
+        :param out_channels: number of channels for the extractions
+        :param out_kernel_initializer: initializer to use for kernels
+        :param out_activation: activation to use at end layer
+        :return: a block consisting of one or multiple layers
+        """
+
+        class OutputBlock(tf.keras.Model):
+            def __init__(self, num_channel_initial, activation):
+                super().__init__()
+                self.conv1 = tfkl.Conv3D(
+                    filters=num_channel_initial,
+                    kernel_size=3,
+                    padding="same",
+                    activation=activation,
+                )
+                # initialized with small random weights so the initial flow is near zero
+                self.conv2 = tfkl.Conv3D(
+                    filters=num_channel_initial,
+                    kernel_size=3,
+                    padding="same",
+                    kernel_initializer=tf.keras.initializers.RandomNormal(
+                        mean=0.0, stddev=1e-5
+                    ),
+                )
+
+            def call(self, inputs, training=None, mask=None):
+                x = inputs[0]
+                x = self.conv1(x)
+                x = self.conv2(x)
+                return x
+
+        return OutputBlock(
+            num_channel_initial=self.num_channel_initial,
+            activation=self.get_activation(),
+        )
+
+    def call(self, inputs: tf.Tensor, training=None, mask=None) -> tf.Tensor:
+        """
+        Build the VoxelMorph graph based on built layers.
+
+        :param inputs: image batch, shape = (batch, f_dim1, f_dim2, f_dim3, ch)
+        :param training: None or bool.
+        :param mask: None or tf.Tensor.
+        :return: shape = (batch, f_dim1, f_dim2, f_dim3, out_channels)
+        """
+        output = super().call(inputs=inputs, training=training, mask=mask)
+        # up-sample again and fuse with the input images before predicting the DDF
+        output = self._out_ddf_upsampling(output)
+        output = tf.concat([inputs, output], axis=4)
+        output = self._out_ddf_conv(output)
+        return output
+
+    def get_activation(self) -> tf.keras.layers.Layer:
+        """Return activation layer."""
+        return tf.keras.layers.LeakyReLU(alpha=0.2)
+
+
+def main(args=None):
+    """
+    Launch training.
+
+    Args:
+        args: command line arguments, parsed from sys.argv when None.
+    """
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--gpu",
+        "-g",
+        help="GPU index for training. "
+        '-g "" for using CPU, '
+        '-g "0" for using GPU 0, '
+        '-g "0,1" for using GPU 0 and 1.',
+        type=str,
+        required=True,
+    )
+    args = parser.parse_args(args)
+
+    config_path = "config_balakrishnan_2019.yaml"
+    train(
+        gpu=args.gpu,
+        config_path=config_path,
+        gpu_allow_growth=True,
+        ckpt_path="",
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..9d4ed68
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+mypy
+pre-commit>=2.10.0
+torchio
+tqdm
diff --git a/scripts/neuro_imaging_preprocess.py b/scripts/neuro_imaging_preprocess.py
new file mode 100644
index 0000000..b94ec7a
--- /dev/null
+++ b/scripts/neuro_imaging_preprocess.py
@@ -0,0 +1,101 @@
+"""
+This script performs skull stripping on the affine-aligned datasets.
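+
+Skull stripping keeps only voxels whose GIF parcellation label is a brain
+structure (label >= SMALLEST_BRAIN_LABEL, i.e. 24 in the colour table) and
+zeroes all other voxels in the MR image.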
+
+The data are stored in multiple folders:
+- matrices, storing the affine matrices for the affine registration
+- mri, storing the MR images
+- gif_parcellation, storing the parcellation of the MR images
+- reference, storing one image-parcellation pair
+
+The preprocessed files will be saved under
+- preprocessed/images
+- preprocessed/labels
+- preprocessed/reference
+"""
+import glob
+import os
+
+import torchio as tio
+from tqdm import tqdm
+
+SMALLEST_BRAIN_LABEL = 24  # from colour table
+data_folder_path = "/raid/candi/Yunguan/DeepReg/neuroimaging"
+output_folder_path = f"{data_folder_path}/preprocessed"
+
+for folder_name in ["images", "labels"]:
+    _path = f"{output_folder_path}/{folder_name}"
+    if not os.path.exists(_path):
+        os.makedirs(_path)
+
+# get file paths
+image_file_paths = glob.glob(f"{data_folder_path}/mri/*.nii.gz")
+label_file_paths = glob.glob(f"{data_folder_path}/gif_parcellation/*.nii.gz")
+matrix_file_paths = glob.glob(f"{data_folder_path}/matrices/*.txt")
+
+assert len(image_file_paths) == len(label_file_paths) == len(matrix_file_paths)
+num_images = len(image_file_paths)
+
+image_file_paths = sorted(image_file_paths)
+label_file_paths = sorted(label_file_paths)
+matrix_file_paths = sorted(matrix_file_paths)
+
+# get unique IDs
+image_file_names = [
+    os.path.split(x)[1].replace(".nii.gz", "") for x in image_file_paths
+]
+label_file_names = [
+    os.path.split(x)[1].replace(".nii.gz", "") for x in label_file_paths
+]
+matrix_file_names = [os.path.split(x)[1].replace(".txt", "") for x in matrix_file_paths]
+
+# images have suffix "_t1_pre_on_mni"
+# labels have suffix "_t1_pre_NeuroMorph_Parcellation" or "-T1_NeuroMorph_Parcellation"
+# matrices have suffix "_t1_pre_to_mni"
+# verify that the sorted file names match
+for i in range(num_images):
+    image_fname = image_file_names[i]
+    label_fname = label_file_names[i]
+    label_fname = label_fname.replace(
+        "_t1_pre_NeuroMorph_Parcellation", "_t1_pre_on_mni"
+    )
+    label_fname = label_fname.replace("-T1_NeuroMorph_Parcellation", "_t1_pre_on_mni")
+    matrix_fname = matrix_file_names[i]
+    matrix_fname = matrix_fname.replace("_t1_pre_to_mni", "_t1_pre_on_mni")
+    assert image_fname == label_fname == matrix_fname
+
+
+def preprocess(image_path: str, label_path: str, matrix_path: str):
+    """
+    Preprocess one data sample.
+
+    Args:
+        image_path: file path for image
+        label_path: file path for parcellation
+        matrix_path: file path for affine matrix
+    """
+    name = os.path.split(image_path)[1].replace("_pre_on_mni.nii.gz", "")
+    out_image_path = f"{output_folder_path}/images/{name}.nii.gz"
+    out_label_path = f"{output_folder_path}/labels/{name}.nii.gz"
+
+    # resample parcellation to MNI space using the stored affine matrix
+    matrix = tio.io.read_matrix(matrix_path)
+    parcellation = tio.LabelMap(label_path, to_mni=matrix)
+    resample = tio.Resample(image_path, pre_affine_name="to_mni")
+    parcellation_mni = resample(parcellation)
+    parcellation_mni.save(out_label_path)
+
+    # get brain mask
+    extract_brain = tio.Lambda(lambda x: (x >= SMALLEST_BRAIN_LABEL))
+    brain_mask = extract_brain(parcellation_mni)
+
+    # skull stripping: zero all voxels outside the brain mask
+    mri = tio.ScalarImage(image_path)
+    mri.data[~brain_mask.data.bool()] = 0
+    mri.save(out_image_path)
+
+
+for image_path, label_path, matrix_path in tqdm(
+    zip(image_file_paths, label_file_paths, matrix_file_paths),
+    total=num_images,
+):
+    preprocess(image_path=image_path, label_path=label_path, matrix_path=matrix_path)