43 changes: 39 additions & 4 deletions doc/start_page.md
@@ -169,23 +169,58 @@ SynapseNet provides functionality for training a UNet for segmentation tasks usi
In this case, you have to provide data **and** (manual) annotations for the structure(s) you want to segment.
This functionality is implemented in `synapse_net.training.supervised_training`. You can find an example script that shows how to use it [here](https://github.com/computational-cell-analytics/synapse-net/blob/main/examples/network_training.py).

We also provide a command line function to run supervised training: `synapse_net.run_supervised_training`. Run
We also provide a command line function to run supervised training: `synapse_net.run_supervised_training`.
It enables training on data and labels stored in local files. Multiple file formats, such as mrc, tif, and hdf5, are supported.

For example, to train a network for vesicle segmentation from mrc files stored in separate folders for training and validation data:
```bash
# -n: the name of the model checkpoint.
# --train_folder / --label_folder: the tomograms and vesicle annotations for training.
# --image_file_pattern / --label_file_pattern: adjust these if you use a different file type than mrc.
# --val_folder / --val_label_folder: the tomograms and vesicle annotations for validation.
# --patch_shape: the patch shape in ZYX.
# --batch_size: the batch size for training.
# --initial_model: the model to use for weight initialization.
# --n_iterations: the number of iterations to train for.
synapse_net.run_supervised_training \
    -n my-vesicle-model \
    --train_folder /path/to/train/tomograms \
    --image_file_pattern "*.mrc" \
    --label_folder /path/to/train/labels \
    --label_file_pattern "*.mrc" \
    --val_folder /path/to/val/tomograms \
    --val_label_folder /path/to/val/labels \
    --patch_shape 48 256 256 \
    --batch_size 2 \
    --initial_model vesicles_3d \
    --n_iterations 25000
```
In this case, the model is initialized with the weights of the 3d vesicle segmentation model due to the choice of `initial_model`. You can also choose a model for a different task here, e.g. `mitochondria`, or leave out this argument to train a randomly initialized model.
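The `--initial_model` argument follows a simple resolution rule, which this PR implements in `_parse_checkpoint`: no value means random initialization, an existing filesystem path is used directly as a checkpoint, and anything else is treated as the name of a pretrained model to fetch. A minimal sketch of that logic; the `fetch_pretrained` callback stands in for SynapseNet's `get_model_path` and is an assumption for illustration:

```python
import os


def resolve_initial_model(initial_model, fetch_pretrained):
    """Resolve the --initial_model argument to a checkpoint path.

    None -> None (train from random initialization).
    Existing filesystem path -> use it as the checkpoint directly.
    Anything else -> treat it as a pretrained model name and fetch it.
    """
    if initial_model is None:
        return None
    if os.path.exists(initial_model):
        return initial_model
    return fetch_pretrained(initial_model)


# Example: "vesicles_3d" is not a local path, so it is resolved by name.
path = resolve_initial_model("vesicles_3d", lambda name: f"/cache/{name}.pt")
print(path)  # /cache/vesicles_3d.pt
```

This keeps the CLI flexible: the same flag accepts either your own checkpoint or one of the pretrained model names listed by `synapse_net.run_supervised_training -h`.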

Run
```bash
synapse_net.run_supervised_training -h
```
for more information and instructions on how to use it.
for more information and instructions on how to use the command.


### Domain Adaptation

SynapseNet provides functionality for (unsupervised) domain adaptation.
This functionality is implemented through a student-teacher training approach that can improve segmentation for data from a different condition (for example, a different sample preparation, imaging technique, or specimen), **without requiring additional annotated structures**.
Domain adaptation is implemented in `synapse_net.training.domain_adaptation`. You can find an example script that shows how to use it [here](https://github.com/computational-cell-analytics/synapse-net/blob/main/examples/domain_adaptation.py).
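In the student-teacher scheme behind `mean_teacher_adaptation`, the student network trains on pseudo labels predicted by a teacher network, while the teacher's weights track an exponential moving average (EMA) of the student's. The update is one line per parameter; a toy sketch on plain lists, with the momentum values chosen for illustration rather than taken from SynapseNet's defaults:

```python
def ema_update(teacher_params, student_params, momentum=0.999):
    """One mean-teacher step: move each teacher parameter a small
    way towards the corresponding student parameter."""
    return [momentum * t + (1.0 - momentum) * s
            for t, s in zip(teacher_params, student_params)]


teacher = [0.0, 1.0]
student = [1.0, 1.0]
teacher = ema_update(teacher, student, momentum=0.5)
print(teacher)  # [0.5, 1.0] -- the teacher drifts towards the student
```

With a momentum close to 1 (as in practice), the teacher changes slowly, which smooths out noise in the student's weights and stabilizes the pseudo labels it produces.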

We also provide a command line function to run domain adaptation: `synapse_net.run_domain_adaptation`. Run
We also provide a command line function to run domain adaptation: `synapse_net.run_domain_adaptation`.
It enables training on data in local files. Multiple file formats, such as mrc, tif, and hdf5, are supported.

For example, to adapt the network for vesicle segmentation based on mrc files:
```bash
# -n: the name of the model checkpoint.
# --input_folder: the folder with the tomograms to train on.
# --file_pattern: adjust this if you use a different file type than mrc.
# --source_model: the pretrained model to adapt, here 3D vesicle segmentation.
# --patch_shape: the patch shape in ZYX.
# --n_iterations: the number of iterations to train for.
synapse_net.run_domain_adaptation \
    -n my-adapted-vesicle-model \
    --input_folder /path/to/tomograms \
    --file_pattern "*.mrc" \
    --source_model vesicles_3d \
    --patch_shape 48 256 256 \
    --n_iterations 10000
```

Run
```bash
synapse_net.run_domain_adaptation -h
```
for more information and instructions on how to use it.
for more information and instructions on how to use the command.

> Note: Domain adaptation only works if the initial model already finds some of the structures in the data from a new condition. If it does not work, you will have to train a network on annotated data.
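The note above follows from how the pseudo labels are generated: the teacher's predictions are thresholded into labels, and only pixels where the teacher is confident contribute to the student's loss (in the implementation, `DefaultPseudoLabeler` with a `confidence_threshold`). If the initial model finds nothing, there is no confident signal to train on. A toy 1D sketch of this filtering, with the probability and threshold values chosen for illustration:

```python
def pseudo_label_with_mask(probs, confidence_threshold=0.9):
    """Turn foreground probabilities into binary pseudo labels plus a
    mask selecting pixels confident enough to train on."""
    labels = [1.0 if p > 0.5 else 0.0 for p in probs]
    mask = [1.0 if max(p, 1.0 - p) >= confidence_threshold else 0.0
            for p in probs]
    return labels, mask


probs = [0.98, 0.55, 0.03, 0.40]
labels, mask = pseudo_label_with_mask(probs, confidence_threshold=0.9)
print(labels)  # [1.0, 1.0, 0.0, 0.0]
print(mask)    # [1.0, 0.0, 1.0, 0.0] -- uncertain pixels are masked out
```

Confident foreground and confident background both provide training signal; the uncertain middle range is excluded so the student does not learn from the teacher's noise.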

2 changes: 1 addition & 1 deletion environment.yaml
@@ -13,7 +13,7 @@ dependencies:
- python-elf
- pytorch
- tensorboard
- torch_em
- torch_em >=0.8.1
- torchvision
- trimesh
- zarr <3
2 changes: 1 addition & 1 deletion synapse_net/__version__.py
@@ -1 +1 @@
__version__ = "0.3.0"
__version__ = "0.4.0"
10 changes: 10 additions & 0 deletions synapse_net/inference/inference.py
@@ -63,6 +63,16 @@ def _get_model_registry():
return models


def get_available_models() -> List[str]:
"""Get the names of all available pretrained models.
Returns:
The list of available model names.
"""
model_registry = _get_model_registry()
return list(model_registry.urls.keys())


def get_model_path(model_type: str) -> str:
"""Get the local path to a pretrained model.
66 changes: 42 additions & 24 deletions synapse_net/training/domain_adaptation.py
@@ -15,9 +15,10 @@
from .supervised_training import (
get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
)
from ..inference.inference import get_model_path, compute_scale_from_voxel_size
from ..inference.inference import get_model_path, compute_scale_from_voxel_size, get_available_models
from ..inference.util import _Scaler


def mean_teacher_adaptation(
name: str,
unsupervised_train_paths: Tuple[str],
@@ -41,6 +42,7 @@ def mean_teacher_adaptation(
patch_sampler: Optional[callable] = None,
pseudo_label_sampler: Optional[callable] = None,
device: int = 0,
check: bool = False,
) -> None:
"""Run domain adaptation to transfer a network trained on a source domain for a supervised
segmentation task to perform this task on a different target domain.
@@ -85,12 +87,13 @@
based on the patch_shape and size of the volumes used for training.
n_samples_val: The number of val samples per epoch. By default this will be estimated
based on the patch_shape and size of the volumes used for validation.
train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
patch_sampler: Accept or reject patches based on a condition.
pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
device: GPU ID for training.
"""
pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
device: GPU ID for training.
check: Whether to check the training and validation loaders instead of running training.
""" # noqa
assert (supervised_train_paths is None) == (supervised_val_paths is None)
is_2d, _ = _determine_ndim(patch_shape)

@@ -119,23 +122,23 @@
pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
loss = self_training.DefaultSelfTrainingLoss()
loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

unsupervised_train_loader = get_unsupervised_loader(
data_paths=unsupervised_train_paths,
raw_key=raw_key,
patch_shape=patch_shape,
batch_size=batch_size,
n_samples=n_samples_train,
sample_mask_paths=train_mask_paths,
data_paths=unsupervised_train_paths,
raw_key=raw_key,
patch_shape=patch_shape,
batch_size=batch_size,
n_samples=n_samples_train,
sample_mask_paths=train_mask_paths,
sampler=patch_sampler
)
unsupervised_val_loader = get_unsupervised_loader(
data_paths=unsupervised_val_paths,
raw_key=raw_key,
patch_shape=patch_shape,
batch_size=batch_size,
n_samples=n_samples_val,
sample_mask_paths=val_mask_paths,
data_paths=unsupervised_val_paths,
raw_key=raw_key,
patch_shape=patch_shape,
batch_size=batch_size,
n_samples=n_samples_val,
sample_mask_paths=val_mask_paths,
sampler=patch_sampler
)

@@ -153,6 +156,15 @@
supervised_train_loader = None
supervised_val_loader = None

if check:
from torch_em.util.debug import check_loader
check_loader(unsupervised_train_loader, n_samples=4)
check_loader(unsupervised_val_loader, n_samples=4)
if supervised_train_loader is not None:
check_loader(supervised_train_loader, n_samples=4)
check_loader(supervised_val_loader, n_samples=4)
return

device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
trainer = self_training.MeanTeacherTrainer(
name=name,
@@ -178,8 +190,8 @@
sampler=pseudo_label_sampler,
)
trainer.fit(n_iterations)


# TODO patch shapes for other models
PATCH_SHAPES = {
"vesicles_3d": [48, 256, 256],
@@ -248,6 +260,7 @@ def _parse_patch_shape(patch_shape, model_name):
patch_shape = PATCH_SHAPES[model_name]
return patch_shape


def main():
"""@private
"""
@@ -267,11 +280,13 @@ def main():
parser.add_argument("--file_pattern", default="*",
help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.") # noqa
available_models = get_available_models()
parser.add_argument(
"--source_model",
default="vesicles_3d",
help="The source model used for weight initialization of teacher and student model. "
"By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used."
"By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used.\n"
f"The following source models are available: {available_models}"
)
parser.add_argument(
"--resize_training_data", action="store_true",
@@ -289,14 +304,16 @@
parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.") # noqa
parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.") # noqa
parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.") # noqa
parser.add_argument("--save_root", help="Root path for saving the checkpoint and log dir.")

args = parser.parse_args()

source_checkpoint = get_model_path(args.source_model)
patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)
with tempfile.TemporaryDirectory() as tmp_dir:
unsupervised_train_paths, unsupervised_val_paths = _get_paths(
args.input, args.pattern, args.resize_training_data, args.source_model, tmp_dir, args.val_fraction,
args.input_folder, args.file_pattern, args.resize_training_data,
args.source_model, tmp_dir, args.val_fraction,
)
unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key)

@@ -312,4 +329,5 @@
n_samples_train=args.n_samples_train,
n_samples_val=args.n_samples_val,
check=args.check,
)
save_root=args.save_root,
)
31 changes: 26 additions & 5 deletions synapse_net/training/supervised_training.py
@@ -7,6 +7,8 @@
from sklearn.model_selection import train_test_split
from torch_em.model import AnisotropicUNet, UNet2d

from synapse_net.inference.inference import get_model_path, get_available_models


def get_3d_model(
out_channels: int,
@@ -263,14 +265,13 @@ def supervised_training(
return

is_2d, _ = _determine_ndim(patch_shape)
if is_2d:
if checkpoint_path is not None:
model = torch_em.util.load_model(checkpoint=checkpoint_path)
elif is_2d:
model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
else:
model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

if checkpoint_path:
model = torch_em.util.load_model(checkpoint=checkpoint_path)

loss, metric = None, None
# No ignore label -> we can use default loss.
if ignore_label is None and not mask_channel:
@@ -368,7 +369,15 @@ def _parse_input_files(args):
return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key


# TODO enable initialization with a pre-trained model.
def _parse_checkpoint(initial_model):
if initial_model is None:
return None
if os.path.exists(initial_model):
return initial_model
model_path = get_model_path(initial_model)
return model_path


def main():
"""@private
"""
@@ -404,6 +413,16 @@ def main():
parser.add_argument("--val_label_folder",
help="The input folder with the validation labels. If not given the training data will be split for validation.") # noqa

# Optional: choose a model for initializing the weights.
available_models = get_available_models()
parser.add_argument(
"--initial_model",
help="Choose a model checkpoint for weight initialization.\n"
"This may either be the path to an existing model checkpoint or the name of a pretrained model.\n"
f"The following pretrained models are available: {available_models}.\n"
"If not given, the model will be randomly initialized."
)

# More optional argument:
parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.") # noqa
@@ -416,11 +435,13 @@

train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\
_parse_input_files(args)
checkpoint_path = _parse_checkpoint(args.initial_model)

supervised_training(
name=args.name, train_paths=train_image_paths, val_paths=val_image_paths,
train_label_paths=train_label_paths, val_label_paths=val_label_paths,
raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val,
check=args.check, n_iterations=args.n_iterations, save_root=args.save_root,
checkpoint_path=checkpoint_path
)