diff --git a/doc/start_page.md b/doc/start_page.md
index 0986056d..d34d3106 100644
--- a/doc/start_page.md
+++ b/doc/start_page.md
@@ -169,11 +169,32 @@ SynapseNet provides functionality for training a UNet for segmentation tasks usi
 In this case, you have to provide data **and** (manual) annotations for the structure(s) you want to segment.
 This functionality is implemented in `synapse_net.training.supervised_training`. You can find an example script that shows how to use it [here](https://github.com/computational-cell-analytics/synapse-net/blob/main/examples/network_training.py).
 
-We also provide a command line function to run supervised training: `synapse_net.run_supervised_training`. Run
+We also provide a command line function to run supervised training: `synapse_net.run_supervised_training`.
+It enables training on data and labels stored in files. Multiple file formats are supported, including mrc, tif, and hdf5.
+
+For example, to train a network for vesicle segmentation from mrc files stored in separate folders for training and validation data:
+```bash
+synapse_net.run_supervised_training \
+    -n my-vesicle-model \  # The name of the model checkpoint.
+    --train_folder /path/to/train/tomograms \  # The path to the tomograms to use for training.
+    --image_file_pattern *.mrc \  # For mrc files, replace if you have a different file type.
+    --label_folder /path/to/train/labels \  # The path to the vesicle annotations for training.
+    --label_file_pattern *.mrc \  # For labels stored as mrc, replace if you have a different file type.
+    --val_folder /path/to/val/tomograms \  # The path to the tomograms to use for validation.
+    --val_label_folder /path/to/val/labels \  # The path to the vesicle annotations for validation.
+    --patch_shape 48 256 256 \  # The patch shape in ZYX.
+    --batch_size 2 \  # The batch size for training.
+    --initial_model vesicles_3d \  # The model to use for weight initialization.
+    --n_iterations 25000  # The number of iterations to train for.
+```
+In this case, the model is initialized with the weights of the 3D vesicle segmentation model, due to the choice of `initial_model`. You can also choose a model for a different task here, e.g. `mitochondria`, or leave out this argument to train a randomly initialized model.
+
+Run
 ```bash
 synapse_net.run_supervised_training -h
 ```
-for more information and instructions on how to use it.
+for more information and instructions on how to use the command.
+
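+The same training can also be run directly from Python via `synapse_net.training.supervised_training.supervised_training`, which the command above wraps.
+The following is a minimal sketch of the equivalent call: the paths are placeholders, `raw_key=None` / `label_key=None` assumes plain mrc or tif files (for hdf5 or zarr pass the internal dataset paths instead), and the label-path keyword names are assumed here, so check the linked example script or the function signature if in doubt.
+```python
+from glob import glob
+
+from synapse_net.inference.inference import get_model_path
+from synapse_net.training.supervised_training import supervised_training
+
+# Placeholder folders: point these at your own tomograms and annotations.
+train_images = sorted(glob("/path/to/train/tomograms/*.mrc"))
+train_labels = sorted(glob("/path/to/train/labels/*.mrc"))
+val_images = sorted(glob("/path/to/val/tomograms/*.mrc"))
+val_labels = sorted(glob("/path/to/val/labels/*.mrc"))
+
+supervised_training(
+    name="my-vesicle-model",                         # The name of the model checkpoint.
+    train_paths=train_images, train_label_paths=train_labels,   # Keyword names for the label paths are assumed.
+    val_paths=val_images, val_label_paths=val_labels,
+    raw_key=None, label_key=None,                    # Assumes plain mrc/tif files; use dataset paths for hdf5/zarr.
+    patch_shape=(48, 256, 256),                      # The patch shape in ZYX.
+    batch_size=2,
+    n_iterations=25000,
+    checkpoint_path=get_model_path("vesicles_3d"),   # Optional: initialize from the pretrained 3D vesicle model.
+)
+```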
 
 ### Domain Adaptation
 
@@ -181,11 +202,25 @@ SynapseNet provides functionality for (unsupervised) domain adaptation.
 This functionality is implemented through a student-teacher training approach that can improve segmentation for data from a different condition (for example different sample preparation, imaging technique, or different specimen), **without requiring additional annotated structures**.
 Domain adaptation is implemented in `synapse_net.training.domain_adaptation`. You can find an example script that shows how to use it [here](https://github.com/computational-cell-analytics/synapse-net/blob/main/examples/domain_adaptation.py).
 
-We also provide a command line function to run domain adaptation: `synapse_net.run_domain_adaptation`. Run
+We also provide a command line function to run domain adaptation: `synapse_net.run_domain_adaptation`.
+It enables training on data in local files. Multiple file formats are supported, including mrc, tif, and hdf5.
+
+For example, to adapt the network for vesicle segmentation based on mrc files:
+```bash
+synapse_net.run_domain_adaptation \
+    -n my-adapted-vesicle-model \  # The name of the model checkpoint.
+    --input_folder /path/to/tomograms \  # The folder with the tomograms to train on.
+    --file_pattern *.mrc \  # For mrc files, replace if you have a different file type.
+    --source_model vesicles_3d \  # To adapt the model for 3D vesicle segmentation.
+    --patch_shape 48 256 256 \  # The patch shape for training.
+    --n_iterations 10000  # The number of iterations to train for.
+```
+
+Run
 ```bash
 synapse_net.run_domain_adaptation -h
 ```
-for more information and instructions on how to use it.
+for more information and instructions on how to use the command.
 
 > Note: Domain adaptation only works if the initial model already finds some of the structures in the data from a new condition.
 If it does not work you will have to train a network on annotated data.
diff --git a/environment.yaml b/environment.yaml
index e85fc3c7..04792a74 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -13,7 +13,7 @@ dependencies:
   - python-elf
   - pytorch
   - tensorboard
-  - torch_em
+  - torch_em >=0.8.1
   - torchvision
   - trimesh
   - zarr <3
diff --git a/synapse_net/__version__.py b/synapse_net/__version__.py
index 493f7415..6a9beea8 100644
--- a/synapse_net/__version__.py
+++ b/synapse_net/__version__.py
@@ -1 +1 @@
-__version__ = "0.3.0"
+__version__ = "0.4.0"
diff --git a/synapse_net/inference/inference.py b/synapse_net/inference/inference.py
index 718f17e0..29871479 100644
--- a/synapse_net/inference/inference.py
+++ b/synapse_net/inference/inference.py
@@ -63,6 +63,16 @@ def _get_model_registry():
     return models
 
 
+def get_available_models() -> List[str]:
+    """Get the names of all available pretrained models.
+
+    Returns:
+        The list of available model names.
+    """
+    model_registry = _get_model_registry()
+    return list(model_registry.urls.keys())
+
+
 def get_model_path(model_type: str) -> str:
     """Get the local path to a pretrained model.
 
diff --git a/synapse_net/training/domain_adaptation.py b/synapse_net/training/domain_adaptation.py
index c57c8bf2..46927db3 100644
--- a/synapse_net/training/domain_adaptation.py
+++ b/synapse_net/training/domain_adaptation.py
@@ -15,9 +15,10 @@ from .supervised_training import (
     get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
 )
 
-from ..inference.inference import get_model_path, compute_scale_from_voxel_size
+from ..inference.inference import get_model_path, compute_scale_from_voxel_size, get_available_models
 from ..inference.util import _Scaler
 
+
 def mean_teacher_adaptation(
     name: str,
     unsupervised_train_paths: Tuple[str],
@@ -41,6 +42,7 @@
     patch_sampler: Optional[callable] = None,
     pseudo_label_sampler: Optional[callable] = None,
     device: int = 0,
+    check: bool = False,
 ) -> None:
     """Run domain adaptation to transfer a network trained on a source domain for a supervised
     segmentation task to perform this task on a different target domain.
@@ -85,12 +87,13 @@
             based on the patch_shape and size of the volumes used for training.
         n_samples_val: The number of val samples per epoch. By default this will be estimated
             based on the patch_shape and size of the volumes used for validation.
-        train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
-        val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
+        train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
+        val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
         patch_sampler: Accept or reject patches based on a condition.
-        pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
-        device: GPU ID for training.
-    """
+        pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
+        device: GPU ID for training.
+        check: Whether to check the training and validation loaders instead of running training.
+    """  # noqa
     assert (supervised_train_paths is None) == (supervised_val_paths is None)
     is_2d, _ = _determine_ndim(patch_shape)
 
@@ -119,23 +122,23 @@
     pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
     loss = self_training.DefaultSelfTrainingLoss()
     loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
-
+
     unsupervised_train_loader = get_unsupervised_loader(
-        data_paths=unsupervised_train_paths,
-        raw_key=raw_key,
-        patch_shape=patch_shape,
-        batch_size=batch_size,
-        n_samples=n_samples_train,
-        sample_mask_paths=train_mask_paths,
+        data_paths=unsupervised_train_paths,
+        raw_key=raw_key,
+        patch_shape=patch_shape,
+        batch_size=batch_size,
+        n_samples=n_samples_train,
+        sample_mask_paths=train_mask_paths,
         sampler=patch_sampler
     )
     unsupervised_val_loader = get_unsupervised_loader(
-        data_paths=unsupervised_val_paths,
-        raw_key=raw_key,
-        patch_shape=patch_shape,
-        batch_size=batch_size,
-        n_samples=n_samples_val,
-        sample_mask_paths=val_mask_paths,
+        data_paths=unsupervised_val_paths,
+        raw_key=raw_key,
+        patch_shape=patch_shape,
+        batch_size=batch_size,
+        n_samples=n_samples_val,
+        sample_mask_paths=val_mask_paths,
         sampler=patch_sampler
     )
 
@@ -153,6 +156,15 @@
         supervised_train_loader = None
         supervised_val_loader = None
 
+    if check:
+        from torch_em.util.debug import check_loader
+        check_loader(unsupervised_train_loader, n_samples=4)
+        check_loader(unsupervised_val_loader, n_samples=4)
+        if supervised_train_loader is not None:
+            check_loader(supervised_train_loader, n_samples=4)
+            check_loader(supervised_val_loader, n_samples=4)
+        return
+
     device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
     trainer = self_training.MeanTeacherTrainer(
         name=name,
@@ -178,8 +190,8 @@
         sampler=pseudo_label_sampler,
     )
     trainer.fit(n_iterations)
-
-
+
+
+
 # TODO patch shapes for other models
 PATCH_SHAPES = {
     "vesicles_3d": [48, 256, 256],
@@ -248,6 +260,7 @@
     patch_shape = PATCH_SHAPES[model_name]
     return patch_shape
 
+
 def main():
     """@private
     """
@@ -267,11 +280,13 @@
     parser.add_argument("--file_pattern", default="*",
                         help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
     parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
+    available_models = get_available_models()
     parser.add_argument(
         "--source_model", default="vesicles_3d",
         help="The source model used for weight initialization of teacher and student model. "
-             "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used."
+ "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used.\n" + f"The following source models are available: {available_models}" ) parser.add_argument( "--resize_training_data", action="store_true", @@ -289,6 +304,7 @@ def main(): parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.") # noqa parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.") # noqa parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.") # noqa + parser.add_argument("--save_root", help="Root path for saving the checkpoint and log dir.") args = parser.parse_args() @@ -296,7 +312,8 @@ def main(): patch_shape = _parse_patch_shape(args.patch_shape, args.source_model) with tempfile.TemporaryDirectory() as tmp_dir: unsupervised_train_paths, unsupervised_val_paths = _get_paths( - args.input, args.pattern, args.resize_training_data, args.source_model, tmp_dir, args.val_fraction, + args.input_folder, args.file_pattern, args.resize_training_data, + args.source_model, tmp_dir, args.val_fraction, ) unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key) @@ -312,4 +329,5 @@ def main(): n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val, check=args.check, - ) \ No newline at end of file + save_root=args.save_root, + ) diff --git a/synapse_net/training/supervised_training.py b/synapse_net/training/supervised_training.py index 3d5f3244..680deb82 100644 --- a/synapse_net/training/supervised_training.py +++ b/synapse_net/training/supervised_training.py @@ -7,6 +7,8 @@ from sklearn.model_selection import train_test_split from torch_em.model import AnisotropicUNet, UNet2d +from synapse_net.inference.inference import get_model_path, get_available_models + def get_3d_model( out_channels: int, @@ -263,14 +265,13 @@ def supervised_training( return is_2d, _ = _determine_ndim(patch_shape) - if is_2d: + if checkpoint_path is not None: + model = torch_em.util.load_model(checkpoint=checkpoint_path) + elif is_2d: model = get_2d_model(out_channels=out_channels, in_channels=in_channels) else: model = get_3d_model(out_channels=out_channels, in_channels=in_channels) - if checkpoint_path: - model = torch_em.util.load_model(checkpoint=checkpoint_path) - loss, metric = None, None # No ignore label -> we can use default loss. if ignore_label is None and not mask_channel: @@ -368,7 +369,15 @@ def _parse_input_files(args): return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key -# TODO enable initialization with a pre-trained model. +def _parse_checkpoint(initial_model): + if initial_model is None: + return None + if os.path.exists(initial_model): + return initial_model + model_path = get_model_path(initial_model) + return model_path + + def main(): """@private """ @@ -404,6 +413,16 @@ def main(): parser.add_argument("--val_label_folder", help="The input folder with the validation labels. If not given the training data will be split for validation.") # noqa + # Optional: choose a model for initializing the weights. 
+    available_models = get_available_models()
+    parser.add_argument(
+        "--initial_model",
+        help="Choose a model checkpoint for weight initialization.\n"
+             "This may either be the path to an existing model checkpoint or the name of a pretrained model.\n"
+             f"The following pretrained models are available: {available_models}.\n"
+             "If not given, the model will be randomly initialized."
+    )
+
     # More optional argument:
     parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
     parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
@@ -416,6 +435,7 @@
     train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\
         _parse_input_files(args)
 
+    checkpoint_path = _parse_checkpoint(args.initial_model)
     supervised_training(
         name=args.name,
         train_paths=train_image_paths, val_paths=val_image_paths,
@@ -423,4 +443,5 @@
         raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
         n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val, check=args.check,
         n_iterations=args.n_iterations, save_root=args.save_root,
+        checkpoint_path=checkpoint_path
     )
diff --git a/test/training/test_training_cli.py b/test/training/test_training_cli.py
index dd814710..8b41fe8d 100644
--- a/test/training/test_training_cli.py
+++ b/test/training/test_training_cli.py
@@ -7,7 +7,7 @@
 from skimage.measure import label
 
 
-class TestTrainignCLI(unittest.TestCase):
+class TestTrainingCLI(unittest.TestCase):
     tmp_folder = "./tmp_data"
 
     # Create test folder and sample data.
@@ -37,9 +37,10 @@ def _test_supervised_training(
         self,
         train_image_folder,
         train_label_folder,
-        val_image_folder,
-        val_label_folder,
         file_pattern,
+        val_image_folder=None,
+        val_label_folder=None,
+        initial_model=None,
     ):
         name = "test-model"
         cmd = [
@@ -49,8 +50,6 @@
             "--image_file_pattern", file_pattern,
             "--label_folder", train_label_folder,
             "--label_file_pattern", file_pattern,
-            "--val_folder", val_image_folder,
-            "--val_label_folder", val_label_folder,
            "--patch_shape", "64", "64", "64",
             "--batch_size", "1",
             "--n_samples_train", "5",
@@ -58,39 +57,93 @@
             "--n_iterations", "6",
             "--save_root", self.tmp_folder,
         ]
+        if val_image_folder is not None:
+            assert val_label_folder is not None
+            cmd.extend([
+                "--val_folder", val_image_folder,
+                "--val_label_folder", val_label_folder,
+            ])
+        if initial_model is not None:
+            cmd.extend(["--initial_model", initial_model])
         run(cmd)
 
         # Check that the checkpoint exists.
         ckpt_path = os.path.join(self.tmp_folder, "checkpoints", name, "latest.pt")
         self.assertTrue(os.path.exists(ckpt_path))
 
-    def test_supervised_training_mrc(self):
+    def _write_mrc_data(self, data, out_root, labels=None):
         import mrcfile
 
-        # Create MRC train and val data.
-        def write_data(data, labels, out_root):
-            data_out, label_out = os.path.join(out_root, "volumes"), os.path.join(out_root, "labels")
-            os.makedirs(data_out, exist_ok=True)
+        data_out = os.path.join(out_root, "volumes")
+        os.makedirs(data_out, exist_ok=True)
+
+        if labels is None:
+            labels = [None] * len(data)
+            label_out = None
+        else:
+            label_out = os.path.join(out_root, "labels")
             os.makedirs(label_out, exist_ok=True)
-            for i, (data, labels) in enumerate(zip(data, labels)):
-                fname = f"tomo-{i}.mrc"
-                with mrcfile.new(os.path.join(data_out, fname), overwrite=True) as f:
-                    f.set_data(data)
-                with mrcfile.new(os.path.join(label_out, fname), overwrite=True) as f:
-                    f.set_data(labels)
-            return data_out, label_out
-
-        train_image_folder, train_label_folder = write_data(
-            self.train_data, self.train_labels, os.path.join(self.tmp_folder, "train")
+
+        for i, (datum, lab) in enumerate(zip(data, labels)):
+            fname = f"tomo-{i}.mrc"
+            with mrcfile.new(os.path.join(data_out, fname), overwrite=True) as f:
+                f.set_data(datum)
+
+            if lab is None:
+                continue
+            with mrcfile.new(os.path.join(label_out, fname), overwrite=True) as f:
+                f.set_data(lab)
+
+        return data_out, label_out
+
+    def test_supervised_training_with_val_data(self):
+        train_image_folder, train_label_folder = self._write_mrc_data(
+            self.train_data, os.path.join(self.tmp_folder, "train"), labels=self.train_labels,
+        )
+        val_image_folder, val_label_folder = self._write_mrc_data(
+            self.val_data, os.path.join(self.tmp_folder, "val"), labels=self.val_labels
         )
-        val_image_folder, val_label_folder = write_data(
-            self.val_data, self.val_labels, os.path.join(self.tmp_folder, "val")
+        self._test_supervised_training(
+            train_image_folder, train_label_folder, file_pattern="*.mrc",
+            val_image_folder=val_image_folder, val_label_folder=val_label_folder,
         )
 
+    def test_supervised_training_without_val_data(self):
+        train_image_folder, train_label_folder = self._write_mrc_data(
+            self.train_data, os.path.join(self.tmp_folder, "train"), labels=self.train_labels,
+        )
+        self._test_supervised_training(train_image_folder, train_label_folder, file_pattern="*.mrc")
+
+    def test_supervised_training_with_initialization(self):
+        train_image_folder, train_label_folder = self._write_mrc_data(
+            self.train_data, os.path.join(self.tmp_folder, "train"), labels=self.train_labels,
+        )
         self._test_supervised_training(
-            train_image_folder, train_label_folder, val_image_folder, val_label_folder, file_pattern="*.mrc",
+            train_image_folder, train_label_folder, file_pattern="*.mrc", initial_model="vesicles_3d"
         )
 
+    def test_domain_adaptation(self):
+        train_image_folder, _ = self._write_mrc_data(self.train_data, os.path.join(self.tmp_folder, "train"))
+        name = "test-da-model"
+        cmd = [
+            "synapse_net.run_domain_adaptation",
+            "-n", name,
+            "--input_folder", train_image_folder,
+            "--file_pattern", "*.mrc",
+            "--source_model", "vesicles_3d",
+            "--patch_shape", "64", "64", "64",
+            "--batch_size", "1",
+            "--n_samples_train", "5",
+            "--n_samples_val", "1",
+            "--n_iterations", "6",
+            "--save_root", self.tmp_folder,
+        ]
+        run(cmd)
+
+        # Check that the checkpoint exists.
+        ckpt_path = os.path.join(self.tmp_folder, "checkpoints", name, "latest.pt")
+        self.assertTrue(os.path.exists(ckpt_path))
+
 
 if __name__ == "__main__":
     unittest.main()