diff --git a/synapse_net/training/supervised_training.py b/synapse_net/training/supervised_training.py
index 1c46323..3d5f324 100644
--- a/synapse_net/training/supervised_training.py
+++ b/synapse_net/training/supervised_training.py
@@ -267,7 +267,7 @@ def supervised_training(
         model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
     else:
         model = get_3d_model(out_channels=out_channels, in_channels=in_channels)
-    
+
     if checkpoint_path:
         model = torch_em.util.load_model(checkpoint=checkpoint_path)
 
@@ -362,8 +362,8 @@ def _parse_input_files(args):
     else:
         if args.val_label_folder is None:
             raise ValueError("You have passed a val_folder, but not a val_label_folder.")
-        val_image_paths = _parse_input_folder(args.val_image_folder, args.image_file_pattern, raw_key)
-        val_label_paths = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)
+        val_image_paths, _ = _parse_input_folder(args.val_folder, args.image_file_pattern, raw_key)
+        val_label_paths, _ = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)
 
     return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key
 
@@ -410,6 +410,8 @@ def main():
     parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
     parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
     parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
+    parser.add_argument("--n_iterations", type=int, default=int(1e5), help="The maximal number of iterations to train for.")  # noqa
+    parser.add_argument("--save_root", help="Root path for saving the checkpoint and log dir.")
     args = parser.parse_args()
 
     train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\
@@ -420,5 +422,5 @@
         train_label_paths=train_label_paths, val_label_paths=val_label_paths,
         raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
         n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val,
-        check=args.check,
+        check=args.check, n_iterations=args.n_iterations, save_root=args.save_root,
     )
diff --git a/test/training/test_training_cli.py b/test/training/test_training_cli.py
new file mode 100644
index 0000000..dd81471
--- /dev/null
+++ b/test/training/test_training_cli.py
@@ -0,0 +1,96 @@
+import os
+import unittest
+from shutil import rmtree
+from subprocess import run
+
+from skimage.data import binary_blobs
+from skimage.measure import label
+
+
+class TestTrainingCLI(unittest.TestCase):
+    tmp_folder = "./tmp_data"
+
+    # Create test folder and sample data.
+    def setUp(self):
+        n_train = 5
+        n_val = 2
+
+        self.train_data = [
+            binary_blobs(length=128, n_dim=3, volume_fraction=0.15).astype("uint8") for _ in range(n_train)
+        ]
+        self.val_data = [
+            binary_blobs(length=128, n_dim=3, volume_fraction=0.15).astype("uint8") for _ in range(n_val)
+        ]
+
+        self.train_labels = [label(data).astype("uint16") for data in self.train_data]
+        self.val_labels = [label(data).astype("uint16") for data in self.val_data]
+
+        os.makedirs(self.tmp_folder, exist_ok=True)
+
+    def tearDown(self):
+        try:
+            rmtree(self.tmp_folder)
+        except OSError:
+            pass
+
+    def _test_supervised_training(
+        self,
+        train_image_folder,
+        train_label_folder,
+        val_image_folder,
+        val_label_folder,
+        file_pattern,
+    ):
+        name = "test-model"
+        cmd = [
+            "synapse_net.run_supervised_training",
+            "-n", name,
+            "--train_folder", train_image_folder,
+            "--image_file_pattern", file_pattern,
+            "--label_folder", train_label_folder,
+            "--label_file_pattern", file_pattern,
+            "--val_folder", val_image_folder,
+            "--val_label_folder", val_label_folder,
+            "--patch_shape", "64", "64", "64",
+            "--batch_size", "1",
+            "--n_samples_train", "5",
+            "--n_samples_val", "1",
+            "--n_iterations", "6",
+            "--save_root", self.tmp_folder,
+        ]
+        run(cmd)
+
+        # Check that the checkpoint exists.
+        ckpt_path = os.path.join(self.tmp_folder, "checkpoints", name, "latest.pt")
+        self.assertTrue(os.path.exists(ckpt_path))
+
+    def test_supervised_training_mrc(self):
+        import mrcfile
+
+        # Create MRC train and val data.
+        def write_data(data, labels, out_root):
+            data_out, label_out = os.path.join(out_root, "volumes"), os.path.join(out_root, "labels")
+            os.makedirs(data_out, exist_ok=True)
+            os.makedirs(label_out, exist_ok=True)
+            for i, (vol, seg) in enumerate(zip(data, labels)):
+                fname = f"tomo-{i}.mrc"
+                with mrcfile.new(os.path.join(data_out, fname), overwrite=True) as f:
+                    f.set_data(vol)
+                with mrcfile.new(os.path.join(label_out, fname), overwrite=True) as f:
+                    f.set_data(seg)
+            return data_out, label_out
+
+        train_image_folder, train_label_folder = write_data(
+            self.train_data, self.train_labels, os.path.join(self.tmp_folder, "train")
+        )
+        val_image_folder, val_label_folder = write_data(
+            self.val_data, self.val_labels, os.path.join(self.tmp_folder, "val")
+        )
+
+        self._test_supervised_training(
+            train_image_folder, train_label_folder, val_image_folder, val_label_folder, file_pattern="*.mrc",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()