10 changes: 6 additions & 4 deletions synapse_net/training/supervised_training.py
@@ -267,7 +267,7 @@ def supervised_training(
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    if checkpoint_path:
        model = torch_em.util.load_model(checkpoint=checkpoint_path)

@@ -362,8 +362,8 @@ def _parse_input_files(args):
    else:
        if args.val_label_folder is None:
            raise ValueError("You have passed a val_folder, but not a val_label_folder.")
        val_image_paths = _parse_input_folder(args.val_image_folder, args.image_file_pattern, raw_key)
        val_label_paths = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)
        val_image_paths, _ = _parse_input_folder(args.val_folder, args.image_file_pattern, raw_key)
        val_label_paths, _ = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)

    return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key

@@ -410,6 +410,8 @@ def main():
parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.") # noqa
parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.") # noqa
parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.") # noqa
parser.add_argument("--n_iterations", type=int, default=int(1e5), help="The maximal number of iterations to train for.") # noqa
parser.add_argument("--save_root", help="Root path for saving the checkpoint and log dir.")
args = parser.parse_args()

train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\
@@ -420,5 +422,5 @@
        train_label_paths=train_label_paths, val_label_paths=val_label_paths,
        raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
        n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val,
        check=args.check,
        check=args.check, n_iterations=args.n_iterations, save_root=args.save_root,
    )
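
For reference, a minimal sketch of a call that exercises the two new options together with the existing ones; the entry point and flag names are taken from the CLI test below, while the model name and all folder paths are placeholders:

    synapse_net.run_supervised_training \
        -n my-model \
        --train_folder /path/to/train/volumes --image_file_pattern "*.mrc" \
        --label_folder /path/to/train/labels --label_file_pattern "*.mrc" \
        --val_folder /path/to/val/volumes --val_label_folder /path/to/val/labels \
        --patch_shape 64 64 64 --batch_size 1 \
        --n_iterations 25000 --save_root /path/to/save_root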
96 changes: 96 additions & 0 deletions test/training/test_training_cli.py
@@ -0,0 +1,96 @@
import os
import unittest
from shutil import rmtree
from subprocess import run

from skimage.data import binary_blobs
from skimage.measure import label


class TestTrainingCLI(unittest.TestCase):
    tmp_folder = "./tmp_data"

    # Create test folder and sample data.
    def setUp(self):
        n_train = 5
        n_val = 2

        self.train_data = [
            binary_blobs(length=128, n_dim=3, volume_fraction=0.15).astype("uint8") for _ in range(n_train)
        ]
        self.val_data = [
            binary_blobs(length=128, n_dim=3, volume_fraction=0.15).astype("uint8") for _ in range(n_val)
        ]

        self.train_labels = [label(data).astype("uint16") for data in self.train_data]
        self.val_labels = [label(data).astype("uint16") for data in self.val_data]

        os.makedirs(self.tmp_folder, exist_ok=True)

    def tearDown(self):
        try:
            rmtree(self.tmp_folder)
        except OSError:
            pass

    def _test_supervised_training(
        self,
        train_image_folder,
        train_label_folder,
        val_image_folder,
        val_label_folder,
        file_pattern,
    ):
        name = "test-model"
        cmd = [
            "synapse_net.run_supervised_training",
            "-n", name,
            "--train_folder", train_image_folder,
            "--image_file_pattern", file_pattern,
            "--label_folder", train_label_folder,
            "--label_file_pattern", file_pattern,
            "--val_folder", val_image_folder,
            "--val_label_folder", val_label_folder,
            "--patch_shape", "64", "64", "64",
            "--batch_size", "1",
            "--n_samples_train", "5",
            "--n_samples_val", "1",
            "--n_iterations", "6",
            "--save_root", self.tmp_folder,
        ]
        run(cmd)

        # Check that the checkpoint exists.
        ckpt_path = os.path.join(self.tmp_folder, "checkpoints", name, "latest.pt")
        self.assertTrue(os.path.exists(ckpt_path))

    def test_supervised_training_mrc(self):
        import mrcfile

        # Create MRC train and val data.
        def write_data(data, labels, out_root):
            data_out, label_out = os.path.join(out_root, "volumes"), os.path.join(out_root, "labels")
            os.makedirs(data_out, exist_ok=True)
            os.makedirs(label_out, exist_ok=True)
            for i, (data, labels) in enumerate(zip(data, labels)):
                fname = f"tomo-{i}.mrc"
                with mrcfile.new(os.path.join(data_out, fname), overwrite=True) as f:
                    f.set_data(data)
                with mrcfile.new(os.path.join(label_out, fname), overwrite=True) as f:
                    f.set_data(labels)
            return data_out, label_out

        train_image_folder, train_label_folder = write_data(
            self.train_data, self.train_labels, os.path.join(self.tmp_folder, "train")
        )
        val_image_folder, val_label_folder = write_data(
            self.val_data, self.val_labels, os.path.join(self.tmp_folder, "val")
        )

        self._test_supervised_training(
            train_image_folder, train_label_folder, val_image_folder, val_label_folder, file_pattern="*.mrc",
        )


if __name__ == "__main__":
    unittest.main()
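
Since the file ends in a __main__ guard, the test can also be run on its own, e.g.

    python test/training/test_training_cli.py

assuming it is started from the repository root with mrcfile installed and the synapse_net.run_supervised_training entry point available on the PATH.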