Pancreatic Ductal Adenocarcinoma Segmentation using a 3D U-Net and Utilizing Secondary Tumour-Indicative Features
Resources shared as part of the papers:
- Improved Pancreatic Tumor Detection by Utilizing Clinically-Relevant Secondary Features - MICCAI Cancer Prevention through Early Detection (conference)
- Computer-Aided Detection for Pancreatic Cancer Diagnosis: Radiological Challenges and Future Directions - Journal of Clinical Medicine (journal)
- Clinical Segmentation for Improved Pancreatic Ductal Adenocarcinoma Detection and Segmentation - SPIE Medical Imaging (conference)
- Improved Pancreatic Cancer Detection and Localization on CT scans: A Computer-Aided Detection model utilizing Secondary Features - Cancers (journal)
For details, refer to the papers above.
Downsample the CT scans to twice the preferred voxel spacing (i.e. half the target resolution). The coarse pancreatic segmentation model can then be trained with:
train3dunet --config resources\3DUnet_pancreas\train\train_config_Pancreas.yaml
Coarse segmentations of the full CT scan, used to localize the pancreas efficiently, are then obtained by running:
predict3dunet --config resources\3DUnet_pancreas\test\pancreas_coarse.yml
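The text does not say how the 2x downsampling is performed. Below is a minimal NumPy sketch that assumes the scan has already been resampled to the preferred spacing. CT intensities are averaged over 2x2x2 blocks. Binary masks keep a voxel if any voxel in its block is set. The helper `downsample_2x` is hypothetical; in practice an interpolation-based resampler could be used instead.

```python
import numpy as np


def downsample_2x(volume: np.ndarray, is_label: bool = False) -> np.ndarray:
    """Halve the resolution of a (Z, Y, X) volume, i.e. double its voxel spacing.

    Assumes `volume` is already at the preferred spacing (hypothetical helper).
    CT intensities are averaged over 2x2x2 blocks; binary label masks take the
    block maximum so that small structures are not lost.
    """
    # Trim a trailing slice on any odd-sized axis so each axis divides by 2.
    z, y, x = (s - s % 2 for s in volume.shape)
    v = volume[:z, :y, :x]
    # Expose each 2x2x2 block as axes 1, 3 and 5: (Z/2, 2, Y/2, 2, X/2, 2).
    blocks = v.reshape(z // 2, 2, y // 2, 2, x // 2, 2)
    if is_label:
        return blocks.max(axis=(1, 3, 5))
    return blocks.mean(axis=(1, 3, 5), dtype=np.float32)
```

Block averaging acts as a simple anti-aliasing filter for intensities. Using the maximum for masks avoids erasing thin structures that nearest-neighbour sampling could skip.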
The fine pancreatic segmentation model can then be trained with:
train3dunet --config resources\3DUnet_pancreas\train\train_config_Pancreas.yaml
This model is trained on crops around the pancreas at the desired high resolution, which speeds up both training and inference. Fine segmentations of the cropped CT scan are obtained by running:
predict3dunet --config resources\3DUnet_pancreas\test\pancreas_fine.yml
These fine segmentations are later used as input to the tumor segmentation model.
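Cropping around the pancreas can be done roughly as follows. Take the bounding box of the coarse mask, scale it to the high-resolution grid, and pad it with a safety margin. The sketch assumes a scale factor of 2 between the grids, as in the downsampling step. The helper `pancreas_bbox` and the default margin of 8 voxels are hypothetical choices, not values taken from the configs.

```python
import numpy as np


def pancreas_bbox(coarse_mask, scale=2, margin=8, full_shape=None):
    """Map the bounding box of a coarse (Z, Y, X) mask to high-resolution slices.

    scale: resolution factor between the fine and coarse grids (assumed 2).
    margin: extra high-resolution voxels kept on each side (hypothetical value).
    full_shape: shape of the high-resolution scan, used to clip the box.
    """
    coords = np.argwhere(coarse_mask > 0)
    if coords.size == 0:
        raise ValueError("coarse mask is empty; the pancreas was not found")
    # Inclusive min / exclusive max on the coarse grid, scaled to the fine grid.
    lo = coords.min(axis=0) * scale - margin
    hi = (coords.max(axis=0) + 1) * scale + margin
    lo = np.maximum(lo, 0)
    if full_shape is not None:
        hi = np.minimum(hi, full_shape)
    return tuple(slice(int(a), int(b)) for a, b in zip(lo, hi))
```

A crop is then `ct_fine[pancreas_bbox(coarse_mask, full_shape=ct_fine.shape)]`, where `ct_fine` is the high-resolution scan.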
The models that each segment a single secondary feature, the common bile duct (CBD) or the pancreatic duct (PD), can be trained with:
train3dunet --config resources\3DUnet_pancreas\train\train_config_CBD.yaml
train3dunet --config resources\3DUnet_pancreas\train\train_config_PD.yaml
They are trained on the high-resolution crops around the pancreas. To predict these ducts, run:
predict3dunet --config resources\3DUnet_pancreas\test\cbd_fine.yml
predict3dunet --config resources\3DUnet_pancreas\test\pd_fine.yml
The outputs of these models are used by the tumor segmentation model.
The model that segments both secondary features simultaneously can be trained with:
train3dunet --config resources\3DUnet_pancreas\train\train_config_PD_CBD.yaml
It is trained on the high-resolution crops around the pancreas. To predict both ducts, run:
predict3dunet --config resources\3DUnet_pancreas\test\cbd_pd_fine.yml
The output of this model is used by the tumor segmentation model.
The PDAC segmentation model is trained on crops around the pancreas, concatenated with the segmentation maps of the pancreas, common bile duct and pancreatic duct. It is trained with:
train3dunet --config resources\3DUnet_pancreas\train\train_config_Tumor.yaml
At test time, run:
predict3dunet --config resources\3DUnet_pancreas\test\tumor.yml
We use the Medical Segmentation Decathlon dataset (Task 07: Pancreas & Tumour). For this work, a few cases were supplemented with additional annotations of the pancreatic duct, common bile duct, pancreas and pancreatic tumour. The new annotations and the corresponding CT volumes (NIfTI) can be downloaded here.
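The tumor model's input is a CT crop combined with the pancreas, CBD and PD segmentation maps. One way to build it is to stack them into a multi-channel 3D array, the (C, Z, Y, X) layout that pytorch-3dunet expects for multi-channel `raw` data. This is a sketch under assumptions. The channel order (CT, pancreas, CBD, PD) and the helper `stack_tumor_inputs` are hypothetical. The actual preprocessing is defined by the configs and may differ, for example in intensity normalization.

```python
import numpy as np


def stack_tumor_inputs(ct_crop, pancreas, cbd, pd):
    """Stack a CT crop and three segmentation maps into a (4, Z, Y, X) input.

    All inputs are (Z, Y, X) arrays on the same high-resolution crop.
    The channel order used here is an assumption; it only has to match
    the order the tumor model was trained with.
    """
    maps = [ct_crop, pancreas, cbd, pd]
    shapes = {m.shape for m in maps}
    if len(shapes) != 1:
        raise ValueError(f"all inputs must share one (Z, Y, X) shape, got {shapes}")
    # New leading channel axis: the multi-channel 3D layout (C, Z, Y, X).
    return np.stack([m.astype(np.float32) for m in maps], axis=0)
```

With this layout, the model's input channel count would be 4.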
The approach is based on https://github.com/wolny/pytorch-3dunet. For any further details feel free to reach out.
PyTorch implementation of 3D U-Net and its variants:
- UNet3D: standard 3D U-Net based on 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation
- ResidualUNet3D: residual 3D U-Net based on Superhuman Accuracy on the SNEMI3D Connectomics Challenge
- ResidualUNetSE3D: similar to ResidualUNet3D with the addition of Squeeze-and-Excitation blocks, based on Deep Learning Semantic Segmentation for High-Resolution Medical Volumes. Original squeeze-and-excite paper: Squeeze-and-Excitation Networks
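To illustrate what a squeeze-and-excitation block adds, here is a NumPy forward pass of a channel SE block as described in the Squeeze-and-Excitation Networks paper, applied to a (C, Z, Y, X) feature map. The function name, the weight shapes and the reduction ratio r are assumptions for this sketch. It is not the package's PyTorch implementation.

```python
import numpy as np


def squeeze_excite(features, w1, b1, w2, b2):
    """Channel squeeze-and-excitation on a (C, Z, Y, X) feature map.

    w1: (C // r, C) and w2: (C, C // r) are the weights of the two fully
    connected layers, with r the reduction ratio; b1, b2 are their biases.
    """
    # Squeeze: global average pooling gives one descriptor per channel, shape (C,).
    z = features.mean(axis=(1, 2, 3))
    # Excitation: FC -> ReLU -> FC -> sigmoid gives per-channel gates in (0, 1).
    s = np.maximum(w1 @ z + b1, 0.0)
    gates = 1.0 / (1.0 + np.exp(-(w2 @ s + b2)))
    # Scale: reweight every channel by its gate.
    return features * gates[:, None, None, None]
```

Because the gates are learned from global context, the block can emphasize informative channels at negligible cost compared with the convolutions.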
The code allows for training the U-Net for both semantic segmentation (binary and multi-class) and regression problems (e.g. denoising, learning deconvolutions).
2D U-Net is also supported; see 2DUnet_confocal or 2DUnet_dsb2018 for example configurations.
Just make sure to keep the singleton z-dimension in your H5 dataset (i.e. (1, Y, X) instead of (Y, X)), because data loading/data augmentation requires tensors of rank 3.
The 2D U-Net itself uses standard 2D convolutional layers instead of 3D convolutions with kernel size (1, 3, 3) for performance reasons.
The input data should be stored in HDF5 files. The HDF5 files for training should contain two datasets: raw and label. Optionally, when training with PixelWiseCrossEntropyLoss, one should also provide a weight dataset.
The raw dataset should contain the input data, while the label dataset contains the ground-truth labels. The optional weight dataset should contain the values for weighting the loss function in different regions of the input and should be of the same size as the label dataset.
The format of the raw/label datasets depends on whether the problem is 2D or 3D and whether the data is single-channel or multi-channel, see the table below:
| | 2D | 3D |
|---|---|---|
| single-channel | (1, Y, X) | (Z, Y, X) |
| multi-channel | (C, 1, Y, X) | (C, Z, Y, X) |
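A minimal sketch for bringing in-memory NumPy arrays into the layouts listed in the table. It inserts the singleton z-dimension for 2D data and checks the rank. `to_h5_layout` is a hypothetical helper, not part of pytorch-3dunet.

```python
import numpy as np


def to_h5_layout(arr, is_2d, multichannel):
    """Return `arr` in the raw/label layout expected in the HDF5 files.

    2D single-channel (Y, X)   -> (1, Y, X)
    2D multi-channel (C, Y, X) -> (C, 1, Y, X)
    3D data must already be (Z, Y, X) or (C, Z, Y, X) and is returned unchanged.
    """
    if is_2d:
        # Insert the singleton z-dimension needed by data loading/augmentation.
        arr = arr[:, None] if multichannel else arr[None]
    expected_ndim = 4 if multichannel else 3
    if arr.ndim != expected_ndim:
        raise ValueError(f"expected a {expected_ndim}D array, got shape {arr.shape}")
    return arr
```

The result can then be written as the raw or label dataset with h5py, e.g. `f.create_dataset("raw", data=arr)` on an open `h5py.File`.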
Prerequisites:
- NVIDIA GPU
- CUDA and cuDNN
pytorch-3dunet is a cross-platform package and runs on Windows and OS X as well.
- The easiest way to install the pytorch-3dunet package is via conda/mamba:
conda install -c conda-forge mamba
mamba create -n pytorch-3dunet -c pytorch -c nvidia -c conda-forge pytorch pytorch-cuda=12.1 pytorch-3dunet
conda activate pytorch-3dunet
After installation, the following commands are available within the conda environment: train3dunet for training the network and predict3dunet for prediction (see below).
- One can also install directly from source:
python setup.py install
Make sure that the installed PyTorch is compatible with your CUDA version; otherwise, training/prediction will fail to run on the GPU.
Given that the pytorch-3dunet package was installed via conda as described above, one can train the network by simply invoking:
train3dunet --config <CONFIG>
where CONFIG is the path to a YAML configuration file, which specifies all aspects of the training procedure.
To train on your own data, just provide the paths to your HDF5 training and validation datasets in the config.
- sample config for 3D semantic segmentation (cell boundary segmentation): train_config_segmentation.yaml
- sample config for 3D regression task (denoising): train_config_regression.yaml
- more configs can be found in resources directory
One can monitor the training progress with TensorBoard by running tensorboard --logdir <checkpoint_dir>/logs/ (you need tensorflow installed in your conda env), where checkpoint_dir is the path to the checkpoint directory specified in the config.
- When training with binary-based losses, i.e. BCEWithLogitsLoss, DiceLoss, BCEDiceLoss, GeneralizedDiceLoss, the target data has to be 4D, i.e. (C, Z, Y, X) with one target binary mask per channel.
- When training with WeightedCrossEntropyLoss, CrossEntropyLoss, PixelWiseCrossEntropyLoss, the target dataset has to be 3D, i.e. (Z, Y, X) class indices; see also the PyTorch documentation for CE loss: https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html
- final_sigmoid in the model config section applies only at inference time (validation, test):
  - When training with BCEWithLogitsLoss, DiceLoss, BCEDiceLoss, GeneralizedDiceLoss, set final_sigmoid=True
  - When training with cross-entropy-based losses (WeightedCrossEntropyLoss, CrossEntropyLoss, PixelWiseCrossEntropyLoss), set final_sigmoid=False so that Softmax normalization is applied to the output.
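The two target formats and the effect of final_sigmoid can be illustrated with NumPy. `to_one_hot` converts (Z, Y, X) class indices into the 4D binary masks used by the binary-based losses. `normalize_output` mimics the inference-time normalization: a per-channel sigmoid when final_sigmoid is true, a softmax over channels otherwise. Both helpers are hypothetical and do not reproduce the package's internal code.

```python
import numpy as np


def to_one_hot(labels, num_classes):
    """(Z, Y, X) integer class indices -> (C, Z, Y, X) float binary masks."""
    classes = np.arange(num_classes)[:, None, None, None]
    return (classes == labels[None]).astype(np.float32)


def normalize_output(logits, final_sigmoid):
    """Normalize (C, Z, Y, X) logits at inference time.

    final_sigmoid=True  -> independent sigmoid per channel (BCE/Dice-style losses)
    final_sigmoid=False -> softmax across the channel axis (cross-entropy losses)
    """
    if final_sigmoid:
        return 1.0 / (1.0 + np.exp(-logits))
    # Subtract the per-voxel maximum for numerical stability before exponentiating.
    e = np.exp(logits - logits.max(axis=0, keepdims=True))
    return e / e.sum(axis=0, keepdims=True)
```

With softmax the channels form a distribution per voxel, which matches the mutually exclusive classes assumed by cross-entropy. Sigmoid outputs are independent per channel, which matches one binary mask per channel.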
Given that the pytorch-3dunet package was installed via conda as described above, one can run the prediction via:
predict3dunet --config <CONFIG>
To predict on your own data, just provide the path to your model as well as the paths to the HDF5 test files (see example test_config_segmentation.yaml).
- If you're running prediction for a large dataset, consider using LazyHDF5Dataset and LazyPredictor in the config. This saves memory by loading data on the fly, at the cost of slower prediction. See test_config_lazy for an example config.
- If your model predicts multiple classes (see e.g. train_config_multiclass), consider saving only the final segmentation instead of the probability maps, which can be time- and space-consuming. To do so, set save_segmentation: true in the predictor section of the config (see test_config_multiclass).
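Conceptually, saving only the final segmentation replaces the (C, Z, Y, X) probability maps with their per-voxel argmax. The NumPy sketch below uses a hypothetical helper `probs_to_segmentation` and assumes fewer than 256 classes. It also shows the storage saving: C float32 channels take 4C bytes per voxel, while a uint8 label map takes 1 byte.

```python
import numpy as np


def probs_to_segmentation(probs):
    """Collapse (C, Z, Y, X) class probabilities into a (Z, Y, X) label map.

    Assumes fewer than 256 classes so that labels fit in uint8.
    """
    return np.argmax(probs, axis=0).astype(np.uint8)
```

For example, with 3 float32 probability channels the label map is 12 times smaller.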
By default, if multiple GPUs are available, training/prediction will be run on all the GPUs using DataParallel.
If training/prediction on all available GPUs is not desirable, restrict the number of GPUs using CUDA_VISIBLE_DEVICES, e.g.
CUDA_VISIBLE_DEVICES=0,1 train3dunet --config <CONFIG>
or
CUDA_VISIBLE_DEVICES=0,1 predict3dunet --config <CONFIG>
Supported loss functions for semantic segmentation:
- BCEWithLogitsLoss (binary cross-entropy)
- DiceLoss (standard DiceLoss defined as 1 - DiceCoefficient, used for binary semantic segmentation; when more than 2 classes are present in the ground truth, it computes the DiceLoss per channel and averages the values)
- BCEDiceLoss (linear combination of BCE and Dice losses, i.e. alpha * BCE + beta * Dice; alpha and beta can be specified in the loss section of the config)
- CrossEntropyLoss (class weights can be specified via weight: [w_1, ..., w_k] in the loss section of the config)
- PixelWiseCrossEntropyLoss (per-pixel weights can be specified in order to give more gradient to important/under-represented regions in the ground truth; a weight dataset has to be provided in the H5 files for training and validation; see the sample config in train_config.yml)
- WeightedCrossEntropyLoss (see 'Weighted cross-entropy (WCE)' in the paper below for a detailed explanation)
- GeneralizedDiceLoss (see 'Generalized Dice Loss (GDL)' in the paper below for a detailed explanation). Note: use this loss function only if the labels in the training dataset are very imbalanced, e.g. one class having at least 3 orders of magnitude more voxels than the others. Otherwise, use the standard DiceLoss.
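The DiceLoss and BCEDiceLoss definitions above can be written out in NumPy as follows. This is a sketch of the formulas only. The helpers are hypothetical and take probabilities (already passed through a sigmoid). The package's own implementations operate on logits and may differ in details such as smoothing terms or squared denominators.

```python
import numpy as np


def dice_loss(probs, target, eps=1e-6):
    """1 - Dice coefficient, computed per channel and averaged over channels.

    probs, target: (C, Z, Y, X) arrays; probs are sigmoid outputs in [0, 1].
    """
    axes = tuple(range(1, probs.ndim))
    intersection = (probs * target).sum(axis=axes)
    denominator = probs.sum(axis=axes) + target.sum(axis=axes)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return float(1.0 - dice.mean())


def bce_dice_loss(probs, target, alpha=1.0, beta=1.0, eps=1e-7):
    """alpha * BCE + beta * Dice loss on probabilities."""
    # Clip to avoid log(0) in the binary cross-entropy term.
    p = np.clip(probs, eps, 1.0 - eps)
    bce = -np.mean(target * np.log(p) + (1.0 - target) * np.log(1.0 - p))
    return float(alpha * bce + beta * dice_loss(probs, target))
```

A perfect prediction gives a Dice loss of 0, and a completely disjoint one gives a loss close to 1.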
For a detailed explanation of some of the supported loss functions see: Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations.
Supported loss functions for regression:
- MSELoss (mean squared error loss)
- L1Loss (mean absolute error loss)
- SmoothL1Loss (less sensitive to outliers than MSELoss)
- WeightedSmoothL1Loss (extension of the SmoothL1Loss which allows weighting voxel values above/below a given threshold differently)
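Below is a NumPy sketch of the Smooth L1 loss, together with one plausible reading of a threshold-weighted variant. Voxels whose target lies below (or at/above) a threshold have their loss multiplied by a weight. The helper names and parameters (`threshold`, `weight`, `below`) are hypothetical and are not the package's config keys. The exact weighting rule of WeightedSmoothL1Loss may differ.

```python
import numpy as np


def smooth_l1(pred, target, beta=1.0):
    """Element-wise Smooth L1: quadratic for |d| < beta, linear beyond it."""
    d = np.abs(pred - target)
    return np.where(d < beta, 0.5 * d ** 2 / beta, d - 0.5 * beta)


def weighted_smooth_l1(pred, target, threshold, weight, below=True, beta=1.0):
    """Mean Smooth L1 where voxels whose target is below (or at/above)
    `threshold` have their loss multiplied by `weight` (hypothetical rule)."""
    loss = smooth_l1(pred, target, beta)
    mask = target < threshold if below else target >= threshold
    return float(np.mean(np.where(mask, weight * loss, loss)))
```

The linear tail beyond beta is what makes Smooth L1 less sensitive to outliers than MSE: large residuals grow linearly rather than quadratically.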
Supported evaluation metrics for semantic segmentation:
- MeanIoU (mean intersection over union)
- DiceCoefficient (computes the per-channel Dice coefficient and returns the average)
If a 3D U-Net was trained to predict cell boundaries, one can use the following semantic instance segmentation metrics (computed by running connected components on the thresholded boundary map and comparing the resulting instances to the ground-truth instance segmentation):
- BoundaryAveragePrecision (Average Precision applied to the boundary probability maps: thresholds the network output, runs connected components to get the segmentation and computes AP between the resulting segmentation and the ground truth)
- AdaptedRandError (see http://brainiac2.mit.edu/SNEMI3D/evaluation for a detailed explanation)
- AveragePrecision (see https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric)
If not specified, MeanIoU is used by default.
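For reference, MeanIoU and DiceCoefficient on already-binarized (C, Z, Y, X) masks reduce to the NumPy sketch below. Both are computed per channel and averaged. The helpers are hypothetical. The package's metrics also handle thresholding of the network output, which is omitted here.

```python
import numpy as np


def per_channel_dice(pred, target, eps=1e-6):
    """Dice coefficient per channel of binary (C, Z, Y, X) masks, averaged."""
    axes = tuple(range(1, pred.ndim))
    inter = np.logical_and(pred, target).sum(axis=axes)
    total = pred.sum(axis=axes) + target.sum(axis=axes)
    return float(np.mean((2.0 * inter + eps) / (total + eps)))


def mean_iou(pred, target, eps=1e-6):
    """Intersection over union per channel of binary (C, Z, Y, X) masks, averaged."""
    axes = tuple(range(1, pred.ndim))
    inter = np.logical_and(pred, target).sum(axis=axes)
    union = np.logical_or(pred, target).sum(axis=axes)
    return float(np.mean((inter + eps) / (union + eps)))
```

For two 2-voxel masks that share one voxel, Dice is 2 * 1 / 4 = 0.5, while IoU is 1 / 3.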
Supported evaluation metrics for regression:
- PSNR (peak signal-to-noise ratio)
- MSE (mean squared error)
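PSNR is computed from the MSE as 10 * log10(data_range^2 / MSE), in decibels. A minimal NumPy sketch follows. The helper name is hypothetical, and `data_range` (the maximum possible signal range) must be chosen by the user.

```python
import numpy as np


def psnr(pred, target, data_range):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
    if mse == 0:
        # Identical signals: no noise, so PSNR is unbounded.
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))
```

For example, a constant error of 0.1 on data with range 1 gives MSE = 0.01 and PSNR = 20 dB.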