diff --git a/synapse_net/inference/mitochondria.py b/synapse_net/inference/mitochondria.py index 40cfc4e8..8b1e5646 100644 --- a/synapse_net/inference/mitochondria.py +++ b/synapse_net/inference/mitochondria.py @@ -1,5 +1,5 @@ import time -from typing import Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import elf.parallel as parallel import numpy as np @@ -65,6 +65,7 @@ def segment_mitochondria( ws_halo: Tuple[int, ...] = (48, 48, 48), boundary_threshold: float = 0.25, area_threshold: int = 5000, + preprocess: Optional[Callable] = None, ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: """Segment mitochondria in an input volume. @@ -99,7 +100,10 @@ # Rescale the mask if it was given and run prediction. if mask is not None: mask = scaler.scale_input(mask, is_segmentation=True) - pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose) + pred = get_prediction( + input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose, + preprocess=preprocess + ) # Run segmentation and rescale the result if necessary. foreground, boundaries = pred[:2] diff --git a/synapse_net/inference/util.py b/synapse_net/inference/util.py index 86fa3db3..6c9ab020 100644 --- a/synapse_net/inference/util.py +++ b/synapse_net/inference/util.py @@ -126,6 +126,7 @@ def get_prediction( mask: Optional[ArrayLike] = None, prediction: Optional[ArrayLike] = None, devices: Optional[List[str]] = None, + preprocess: Optional[callable] = None, ) -> ArrayLike: """Run prediction on a given volume. 
@@ -192,7 +193,7 @@ def get_prediction( # print(f"updated_tiling {updated_tiling}") prediction = get_prediction_torch_em( input_volume, updated_tiling, model_path, model, verbose, with_channels, - mask=mask, prediction=prediction, devices=devices, + mask=mask, prediction=prediction, devices=devices, preprocess=preprocess, ) return prediction @@ -208,6 +209,7 @@ def get_prediction_torch_em( mask: Optional[ArrayLike] = None, prediction: Optional[ArrayLike] = None, devices: Optional[List[str]] = None, + preprocess: Optional[callable] = None, ) -> np.ndarray: """Run prediction using torch-em on a given volume. @@ -258,7 +260,10 @@ def get_prediction_torch_em( print("Run prediction with mask.") mask = mask.astype("bool") - preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize + # Keep a user-supplied preprocess function; otherwise fall back to the + # previous default of standardizing chunked (non-numpy) inputs. + if preprocess is None and not isinstance(input_volume, np.ndarray): + preprocess = torch_em.transform.raw.standardize prediction = predict_with_halo( input_volume, model, gpu_ids=devices, block_shape=block_shape, halo=halo, diff --git a/synapse_net/training/supervised_training.py b/synapse_net/training/supervised_training.py index 3a2cebc0..1c463238 100644 --- a/synapse_net/training/supervised_training.py +++ b/synapse_net/training/supervised_training.py @@ -201,6 +201,7 @@ def supervised_training( in_channels: int = 1, out_channels: int = 2, mask_channel: bool = False, + checkpoint_path: Optional[str] = None, **loader_kwargs, ): """Run supervised segmentation training. @@ -243,6 +244,7 @@ def supervised_training( out_channels: The number of output channels of the UNet. mask_channel: Whether the last channels in the labels should be used for masking the loss. This can be used to implement more complex masking operations and is not compatible with `ignore_label`. + checkpoint_path: Path to a checkpoint directory containing 'best.pt'; if given, training continues from this model. 
loader_kwargs: Additional keyword arguments for the dataloader. """ train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size, @@ -265,6 +267,9 @@ model = get_2d_model(out_channels=out_channels, in_channels=in_channels) else: model = get_3d_model(out_channels=out_channels, in_channels=in_channels) + + if checkpoint_path is not None: + model = torch_em.util.load_model(checkpoint=checkpoint_path) loss, metric = None, None # No ignore label -> we can use default loss.