diff --git a/chabud/callbacks.py b/chabud/callbacks.py
index 0d16e49..c7a2804 100644
--- a/chabud/callbacks.py
+++ b/chabud/callbacks.py
@@ -42,7 +42,8 @@ def on_validation_end(
         for i in range(batch_size):
             log_image = wandb.Image(
-                post_img[i].permute(1, 2, 0).detach().numpy() / 6000,
+                post_img[i][[3, 2, 1], ...].detach().cpu().numpy().transpose(1, 2, 0)
+                / 3000,
                 masks={
                     "prediction": {
                         "mask_data": mask[i].detach().cpu().numpy(),
diff --git a/chabud/datapipe.py b/chabud/datapipe.py
index 1ecb322..c79e612 100644
--- a/chabud/datapipe.py
+++ b/chabud/datapipe.py
@@ -4,6 +4,7 @@
 import os
 from typing import Iterator
 
+import albumentations as A
 import datatree
 import lightning as L
 import numpy as np
@@ -13,7 +14,6 @@
 import xarray as xr
 
 
-# %%
 def _path_fn(urlpath: str) -> str:
     """
     Get the filename from a urlpath and prepend it with 'data' so that it is
@@ -66,8 +66,10 @@ def _train_val_fold(chip: xr.Dataset) -> int:
     Fold 0 is used for validation, Fold 1 and above is for training. See
     https://huggingface.co/datasets/chabud-team/chabud-ecml-pkdd2023/discussions/3
     """
-    if "fold" not in chip.attrs:  # no 'fold' attribute, use for training too
-        return 1  # Training set
+    if (
+        "fold" not in chip.attrs
+    ):  # no 'fold' attribute; assign randomly using a 70/30 train/val split
+        return int(np.random.rand() > 0.3)  # 1 = Training set (p=0.7), 0 = Validation set
     if chip.attrs["fold"] == 0:
         return 0  # Validation set
     elif chip.attrs["fold"] >= 1:
@@ -76,27 +78,68 @@
 
 def _pre_post_mask_tuple(
     dataset: xr.Dataset,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
     """
     From a single xarray.Dataset, split it into a tuple containing the
     pre/post/target tensors and a dictionary object containing metadata.
 
     Returns
     -------
-    data_tuple : tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]
+    data_tuple : tuple[np.ndarray, np.ndarray, np.ndarray, dict]
         A tuple with 4 objects, the pre-event image, the post-event image,
         the mask image, and a Python dict containing metadata (e.g. filename,
         UUID, fold, comments).
""" - # return just the RGB bands for now - pre = dataset.pre_fire.data[[3, 2, 1], ...].astype(dtype="float32") - post = dataset.post_fire.data[[3, 2, 1], ...].astype(dtype="float32") + # # return just the RGB bands for now + # pre = dataset.pre_fire.data[[3, 2, 1], ...].astype(dtype="float32") + # post = dataset.post_fire.data[[3, 2, 1], ...].astype(dtype="float32") + pre = dataset.pre_fire.data.astype(dtype="float32") + post = dataset.post_fire.data.astype(dtype="float32") mask = dataset.mask.data.astype(dtype="uint8") + # pre_g = dataset.pre_fire.data[2, ...].astype(dtype="float32") + # pre_r = dataset.pre_fire.data[3, ...].astype(dtype="float32") + # pre_nir = dataset.pre_fire.data[7, ...].astype(dtype="float32") + # pre_swir = dataset.pre_fire.data[11, ...].astype(dtype="float32") + + # post_g = dataset.post_fire.data[2, ...].astype(dtype="float32") + # post_r = dataset.post_fire.data[3, ...].astype(dtype="float32") + # post_nir = dataset.post_fire.data[7, ...].astype(dtype="float32") + # post_swir = dataset.post_fire.data[11, ...].astype(dtype="float32") + + # # NDVI: nir - r / nir + r + # pre_ndvi = np.nan_to_num( + # (pre_nir - pre_r) / (pre_nir + pre_r), nan=0, posinf=0, neginf=0 + # ) + # # repeat the same for all normalized index + # post_ndvi = np.nan_to_num( + # (post_nir - post_r) / (post_nir + post_r), nan=0, posinf=0, neginf=0 + # ) + + # # NDWI: g - nir / g + nir + # pre_ndwi = np.nan_to_num( + # (pre_g - pre_nir) / (pre_g + pre_nir), nan=0, posinf=0, neginf=0 + # ) + # post_ndwi = np.nan_to_num( + # (post_g - post_nir) / (post_g + post_nir), nan=0, posinf=0, neginf=0 + # ) + + # # NBR: nir - swir / nir + swir + # pre_nbr = np.nan_to_num( + # (pre_nir - pre_swir) / (pre_nir + pre_swir), nan=0, posinf=0, neginf=0 + # ) + # post_nbr = np.nan_to_num( + # (post_nir - post_swir) / (post_nir + post_swir), nan=0, posinf=0, neginf=0 + # ) + + # # combine ndvi, ndwi, nbr into a 3-channel array + # pre = np.stack([pre_ndvi, pre_ndwi, pre_nbr], axis=0) + # post = np.stack([post_ndvi, post_ndwi, post_nbr], axis=0) + return ( - torch.as_tensor(data=pre), - torch.as_tensor(data=post), - torch.as_tensor(data=mask), + pre, + post, + mask, { "filename": os.path.basename(dataset.encoding["source"]), **dataset.attrs, @@ -104,6 +147,34 @@ def _pre_post_mask_tuple( ) +def _apply_augmentation( + sample: tuple[np.ndarray, np.ndarray, np.ndarray, dict] +) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]: + """ + Apply augmentations to a single sample. + """ + aug = A.Compose( + [ + A.HorizontalFlip(p=0.5), + A.VerticalFlip(p=0.5), + A.ShiftScaleRotate( + p=0.5, shift_limit=0.05, scale_limit=0.05, rotate_limit=10 + ), + ], + additional_targets={"post": "image"}, + ) + pre, post, mask, metadata = sample + + # Apply augmentations - albumenations expects channel last + auged = aug(image=pre.transpose(1, 2, 0), post=post.transpose(1, 2, 0), mask=mask) + pre, post, mask = ( + auged["image"].transpose(2, 0, 1), + auged["post"].transpose(2, 0, 1), + auged["mask"], + ) + return (pre, post, mask, metadata) + + def _stack_tensor_collate_fn( samples: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]], ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[dict]]: @@ -111,9 +182,15 @@ def _stack_tensor_collate_fn( Stack a list of torch.Tensor objects into a single torch.Tensor, and combine metadata attributes into a list of dicts. 
""" - pre_tensor: torch.Tensor = torch.stack(tensors=[sample[0] for sample in samples]) - post_tensor: torch.Tensor = torch.stack(tensors=[sample[1] for sample in samples]) - mask_tensor: torch.Tensor = torch.stack(tensors=[sample[2] for sample in samples]) + pre_tensor: torch.Tensor = torch.stack( + tensors=[torch.as_tensor(sample[0]) for sample in samples] + ) + post_tensor: torch.Tensor = torch.stack( + tensors=[torch.as_tensor(sample[1]) for sample in samples] + ) + mask_tensor: torch.Tensor = torch.stack( + tensors=[torch.as_tensor(sample[2]) for sample in samples] + ) metadata: list[dict] = [sample[3] for sample in samples] return pre_tensor, post_tensor, mask_tensor, metadata @@ -139,8 +216,8 @@ def __init__( # From https://huggingface.co/datasets/chabud-team/chabud-ecml-pkdd2023/tree/main "https://huggingface.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/main/train_eval.hdf5", # From https://huggingface.co/datasets/chabud-team/chabud-extra/tree/main - # "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_0.hdf5", - # "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_1.hdf5", + "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_0.hdf5", + "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_1.hdf5", # "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_2.hdf5", # "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_3.hdf5", # "https://huggingface.co/datasets/chabud-team/chabud-extra/resolve/main/california_4.hdf5", @@ -222,8 +299,9 @@ def setup( # Step 4 - Convert from xarray.Dataset to tuple of torch.Tensor objects # Also do shuffling (for train set only), batching, and tensor stacking self.datapipe_train = ( - dp_train.shuffle(buffer_size=100) + dp_train.shuffle(buffer_size=2000) .map(fn=_pre_post_mask_tuple) + .map(fn=_apply_augmentation) .batch(batch_size=self.batch_size) .collate(collate_fn=_stack_tensor_collate_fn) ) diff --git a/chabud/dataset.py b/chabud/dataset.py new file mode 100644 index 0000000..80ebec0 --- /dev/null +++ b/chabud/dataset.py @@ -0,0 +1,93 @@ +import os +from pathlib import Path + +import albumentations as A +from albumentations.pytorch import ToTensorV2 +import lightning as L +import numpy as np +from torch.utils.data import Dataset, DataLoader + + +class ChaBuDDataset(Dataset): + def __init__(self, data_dir: Path, transform=None): + self.data_dir = data_dir + self.uuids = list(data_dir.glob("*.npz")) + self.transform = transform + + def __getitem__(self, idx): + uuid = self.uuids[idx] + event = np.load(uuid) + pre, post, mask = ( + event["pre"].astype(np.float32), + event["post"].astype(np.float32), + event["mask"].astype(np.uint8), + ) + + if self.transform: + tfmed = self.transform( + image=pre.transpose(1, 2, 0), post=post.transpose(1, 2, 0), mask=mask + ) + pre, post, mask = tfmed["image"], tfmed["post"], tfmed["mask"] + + return (pre, post, mask, uuid.stem) + + def __len__(self): + return len(self.uuids) + + +class ChaBuDDataModule(L.LightningDataModule): + def __init__( + self, + data_dir: Path, + batch_size: int = 16, + num_workers: int = 4, + ): + super().__init__() + self.data_dir = data_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.trn_tfm = A.Compose( + [ + A.HorizontalFlip(p=0.5), + A.VerticalFlip(p=0.5), + A.ShiftScaleRotate( + p=0.5, shift_limit=0.05, scale_limit=0.05, rotate_limit=10 + ), + ToTensorV2(), + ], + 
additional_targets={"post": "image"}, + ) + self.val_tfm = A.Compose([ToTensorV2()], additional_targets={"post": "image"}) + self.tst_tfm = A.Compose([ToTensorV2()], additional_targets={"post": "image"}) + + def setup(self, stage: str | None = None) -> None: + self.trn_ds = ChaBuDDataset(self.data_dir / "trn", transform=self.trn_tfm) + self.val_ds = ChaBuDDataset(self.data_dir / "val", transform=self.val_tfm) + self.tst_ds = ChaBuDDataset(self.data_dir / "val_orig", transform=self.tst_tfm) + + def train_dataloader(self): + return DataLoader( + self.trn_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + pin_memory=True, + ) + + def val_dataloader(self): + return DataLoader( + self.val_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + pin_memory=True, + ) + + def test_dataloader(self): + return DataLoader( + self.tst_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + pin_memory=True, + ) diff --git a/chabud/model.py b/chabud/model.py index 0bc493f..6083cb9 100644 --- a/chabud/model.py +++ b/chabud/model.py @@ -20,7 +20,7 @@ ) from chabud.tinycd_model import ChangeClassifier -from chabud.unet_model import UnetChangeClassifier +from chabud.unet_model import DeepLabChangeClassifier class ChaBuDNet(L.LightningModule): @@ -36,6 +36,7 @@ def __init__( lr: float = 1e-3, model_name="tinycd", submission_filepath: str = "submission.csv", + batch_size: int = 8, ): """ Define layers of the ChaBuDNet model. @@ -77,10 +78,10 @@ def __init__( self.model = self._init_model(model_name) # Loss functions - self.criterion = torch.nn.BCEWithLogitsLoss( - pos_weight=torch.tensor(5.0), reduction="mean" - ) - # self.criterion = DiceLoss(mode="binary", from_logits=True, smooth=0.1) + # self.criterion = torch.nn.BCEWithLogitsLoss( + # pos_weight=torch.tensor(5.0), reduction="mean" + # ) + self.criterion = DiceLoss(mode="binary", from_logits=True, smooth=0.1) # self.criterion = FocalLoss(mode="binary", alpha=0.25, gamma=2.0) # Evaluation metrics to know how good the segmentation results are @@ -90,12 +91,13 @@ def _init_model(self, name): if name == "tinycd": return ChangeClassifier( bkbn_name="efficientnet_b4", - pretrained=True, + pretrained=False, output_layer_bkbn="3", + in_channels=12, freeze_backbone=False, ) elif name == "unet": - return UnetChangeClassifier() + return DeepLabChangeClassifier() else: return NotImplementedError(f"model {name} is not available") @@ -116,7 +118,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: def shared_step( self, - batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict], + batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, str], batch_idx: int, phase: str, log: bool = True, @@ -140,7 +142,7 @@ def shared_step( def training_step( self, - batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict], + batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, str], batch_idx: int, ) -> torch.Tensor: """ @@ -150,7 +152,7 @@ def training_step( def validation_step( self, - batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict], + batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, str], batch_idx: int, ) -> torch.Tensor: """ @@ -160,7 +162,7 @@ def validation_step( def test_step( self, - batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict], + batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, str], batch_idx: int, ) -> torch.Tensor: """ @@ -179,11 +181,18 @@ def test_step( # Pass the image through neural network model to get 
predicted images
         logits: torch.Tensor = self(x1=pre_img, x2=post_img).squeeze()
         y_pred: torch.Tensor = F.sigmoid(logits).detach()
+        # Log loss and metric
+        loss: torch.Tensor = self.criterion(logits, mask.float())
+        # IOU expects y_pred between 0 & 1
+        metric: torch.Tensor = self.iou(y_pred, mask)
+        self._log(loss, metric, "test")
+
+        y_pred = (y_pred > 0.5).cpu().numpy()
 
         # Format predicted mask as binary run length encoding vector
         result: list = []
-        for pred_mask, uuid in zip(y_pred, map(lambda x: x["uuid"], metadata)):
-            flat_binary_mask: np.ndarray = (y_pred > 0.5).cpu().flatten().numpy()
+        for pred_mask, uuid in zip(y_pred, metadata):
+            flat_binary_mask: np.ndarray = pred_mask.flatten()
             brle: np.ndarray = trimesh.voxel.runlength.dense_to_brle(
                 dense_data=flat_binary_mask
             )
@@ -203,11 +212,6 @@
             header=True if batch_idx == 0 else False,
         )
 
-        # Log loss and metric
-        loss: torch.Tensor = self.criterion(logits, mask.float())
-        metric: torch.Tensor = self.iou(y_pred, mask)
-        self._log(loss, metric, "test")
-
         return metric
 
     def configure_optimizers(self) -> torch.optim.Optimizer:
@@ -223,7 +227,16 @@
         Documentation at:
         https://lightning.ai/docs/pytorch/2.0.2/common/lightning_module.html#configure-optimizers
         """
-        return torch.optim.Adam(params=self.parameters(), lr=self.hparams.lr)
+        optimizer = torch.optim.Adam(params=self.parameters(), lr=self.hparams.lr)
+        scheduler = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer,
+            milestones=[15, 25],
+            gamma=0.1,
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+        }
 
     def _log(self, loss, metric, phase):
         on_step = True if phase == "train" else False
@@ -235,6 +248,7 @@
             on_epoch=True,
             prog_bar=True,
             logger=True,
+            batch_size=self.hparams.batch_size,
         )
         self.log(
             f"{phase}/iou",
@@ -243,4 +257,5 @@
             on_epoch=True,
             prog_bar=True,
             logger=True,
+            batch_size=self.hparams.batch_size,
         )
diff --git a/chabud/tinycd_model.py b/chabud/tinycd_model.py
index a873bcb..0ede381 100644
--- a/chabud/tinycd_model.py
+++ b/chabud/tinycd_model.py
@@ -12,7 +12,7 @@
 from typing import List
 
 import torchvision
-from torch import Tensor
+from torch import Tensor, nn
 from torch.nn import Module, ModuleList, Sigmoid, BatchNorm2d
 
 from chabud.layers import MixingBlock, MixingMaskAttentionBlock, PixelwiseLinear, UpMask
@@ -22,22 +22,28 @@ class ChangeClassifier(Module):
     def __init__(
         self,
         bkbn_name="efficientnet_b4",
-        pretrained=True,
+        pretrained=False,
         output_layer_bkbn="3",
+        in_channels=3,
         freeze_backbone=False,
     ):
         super().__init__()
 
         # Load the pretrained backbone according to parameters:
         self._backbone = _get_backbone(
-            bkbn_name, pretrained, output_layer_bkbn, freeze_backbone
+            bkbn_name, in_channels, pretrained, output_layer_bkbn, freeze_backbone
         )
 
         # Normalize the input:
-        self._normalize = BatchNorm2d(3)  # 3 number of bands
+        self._normalize = BatchNorm2d(in_channels)  # normalize over the input bands
 
         # Initialize mixing blocks:
-        self._first_mix = MixingMaskAttentionBlock(6, 3, [3, 10, 5], [10, 5, 1])
+        self._first_mix = MixingMaskAttentionBlock(
+            in_channels * 2,
+            in_channels,
+            [in_channels, in_channels * 4, in_channels * 2],
+            [in_channels * 4, in_channels * 2, 1],
+        )
         self._mixing_mask = ModuleList(
             [
                 MixingMaskAttentionBlock(48, 24, [24, 12, 6], [12, 6, 1]),
@@ -81,13 +87,24 @@ def _decode(self, features) -> Tensor:
 
 
 def _get_backbone(
-    bkbn_name, pretrained, output_layer_bkbn, freeze_backbone
+    bkbn_name, in_channels,
pretrained, output_layer_bkbn, freeze_backbone ) -> ModuleList: # The whole model: entire_model = getattr(torchvision.models, bkbn_name)( pretrained=pretrained ).features + # Change the number of input channels to the backbone + first_conv = entire_model[0][0] + entire_model[0][0] = nn.Conv2d( + in_channels=in_channels, + out_channels=first_conv.out_channels, + kernel_size=first_conv.kernel_size, + stride=first_conv.stride, + padding=first_conv.padding, + bias=first_conv.bias, + ) + # Slicing it: derived_model = ModuleList([]) for name, layer in entire_model.named_children(): diff --git a/chabud/unet_model.py b/chabud/unet_model.py index b022699..c75f3d2 100644 --- a/chabud/unet_model.py +++ b/chabud/unet_model.py @@ -3,11 +3,11 @@ import segmentation_models_pytorch as smp -class UnetChangeClassifier(nn.Module): - def __init__(self, in_channels=6, out_channels=1): +class DeepLabChangeClassifier(nn.Module): + def __init__(self, in_channels=12, out_channels=1): super().__init__() - self.model = smp.Unet( - encoder_name="timm-efficientnet-b0", + self.model = smp.DeepLabV3Plus( + encoder_name="timm-efficientnet-b4", encoder_weights="imagenet", in_channels=in_channels, classes=out_channels, @@ -16,6 +16,6 @@ def __init__(self, in_channels=6, out_channels=1): self.normalize = nn.BatchNorm2d(in_channels) def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - x = torch.cat([x1, x2], dim=1) + x = x2 - x1 # post - pre: change detection x = self.normalize(x) return self.model(x) diff --git a/conda-lock.yml b/conda-lock.yml index 094c6b5..6643e99 100644 --- a/conda-lock.yml +++ b/conda-lock.yml @@ -13,7 +13,7 @@ version: 1 metadata: content_hash: - linux-64: 281f4a2f0866ceef71be4179e520faf6280af391c866991f532a011d9bb34e1d + linux-64: ed9a4ab1a0482b65fed226d9bab17fcda53d92563b0366c089a644b01c350850 channels: - url: conda-forge used_env_vars: [] @@ -111,25 +111,25 @@ package: category: main optional: false - name: libgfortran5 - version: 12.2.0 + version: 13.1.0 manager: conda platform: linux-64 dependencies: {} - url: https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-12.2.0-h337968e_19.tar.bz2 + url: https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.1.0-h15d22d2_0.conda hash: - md5: 164b4b1acaedc47ee7e658ae6b308ca3 - sha256: 03ea784edd12037dc3a7a0078ff3f9c3383feabb34d5ba910bb2fd7a21a2d961 + md5: afb656a334c409dd9805508af1c89c7a + sha256: a06235f4c4b85b463d9b8a73c9e10c1b5b4105f8a0ea8ac1f2f5f64edac3dfe7 category: main optional: false - name: libstdcxx-ng - version: 12.2.0 + version: 13.1.0 manager: conda platform: linux-64 dependencies: {} - url: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-12.2.0-h46fd767_19.tar.bz2 + url: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.1.0-hfd8a6a1_0.conda hash: - md5: 1030b1f38c129f2634eae026f704fe60 - sha256: 0289e6a7b9a5249161a3967909e12dcfb4ab4475cdede984635d3fb65c606f08 + md5: 067bcc23164642f4c226da631f2a2e1d + sha256: 6f9eb2d7a96687938c0001166a3b308460a8eb02b10e9d0dd9e251f0219ea05c category: main optional: false - name: python_abi @@ -170,15 +170,15 @@ package: category: main optional: false - name: libgfortran-ng - version: 12.2.0 + version: 13.1.0 manager: conda platform: linux-64 dependencies: - libgfortran5: 12.2.0 - url: https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-12.2.0-h69a702a_19.tar.bz2 + libgfortran5: 13.1.0 + url: https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.1.0-h69a702a_0.conda hash: - md5: cd7a806282c16e1f2d39a7e80d3a3e0d - sha256: 
c7d061f323e80fbc09564179073d8af303bf69b953b0caddcf79b47e352c746f + md5: 506dc07710dd5b0ba63cbf134897fc10 + sha256: 429e1d8a3e70b632df5b876e3fc322a56f769756693daa07114c46fa5098684e category: main optional: false - name: fonts-conda-ecosystem @@ -207,16 +207,16 @@ package: category: main optional: false - name: libgcc-ng - version: 12.2.0 + version: 13.1.0 manager: conda platform: linux-64 dependencies: _libgcc_mutex: '0.1' _openmp_mutex: '>=4.5' - url: https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-12.2.0-h65d4601_19.tar.bz2 + url: https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.1.0-he5830b7_0.conda hash: - md5: e4c94f80aef025c17ab0828cd85ef535 - sha256: f3899c26824cee023f1e360bd0859b0e149e2b3e8b1668bc6dd04bfc70dcd659 + md5: cd93f779ff018dd85c7544c015c9db3c + sha256: fba897a02f35b2b5e6edc43a746d1fa6970a77b422f258246316110af8966911 category: main optional: false - name: aws-c-common @@ -479,17 +479,17 @@ package: category: main optional: false - name: libopenblas - version: 0.3.21 + version: 0.3.23 manager: conda platform: linux-64 dependencies: libgcc-ng: '>=12' libgfortran-ng: '' - libgfortran5: '>=10.4.0' - url: https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.21-pthreads_h78a6416_3.tar.bz2 + libgfortran5: '>=11.3.0' + url: https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.23-pthreads_h80387f5_0.conda hash: - md5: 8c5963a49b6035c40646a763293fbb35 - sha256: 018372af663987265cb3ca8f37ac8c22b5f39219f65a0c162b056a30af11bba0 + md5: 9c5ea51ccb8ffae7d06c645869d24ce6 + sha256: 00aee12d04979d024c7f9cabccff5f5db2852c934397ec863a4abde3e09d5a79 category: main optional: false - name: libsodium @@ -786,10 +786,10 @@ package: libgcc-ng: '>=12' libstdcxx-ng: '>=12' libzlib: '>=1.2.13,<1.3.0a0' - url: https://conda.anaconda.org/conda-forge/linux-64/cudnn-8.8.0.121-h0800d71_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/cudnn-8.8.0.121-h0800d71_1.conda hash: - md5: cbc302fc8abc25e58dcdd13aaf6d3a41 - sha256: 89eea37b95c1ee1ad441ce269493b038172d164d2d3933a75377fd3380633dd2 + md5: 6a64eb12e4d74deba7dff951bfcb0d57 + sha256: 1992e3b655f6c410a478c8252e735bedd9a28e4540cf5617f554d52c7d2b3e0d category: main optional: false - name: expat @@ -810,11 +810,11 @@ package: manager: conda platform: linux-64 dependencies: - libopenblas: '>=0.3.21,<1.0a0' - url: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-16_linux64_openblas.tar.bz2 + libopenblas: '>=0.3.23,<1.0a0' + url: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-17_linux64_openblas.conda hash: - md5: d9b7a8639171f6c6fa0a983edabcfe2b - sha256: 4e4c60d3fe0b95ffb25911dace509e3532979f5deef4364141c533c5ca82dd39 + md5: 57fb44770b1bc832fb2dbefa1bd502de + sha256: 5a9dfeb9ede4b7ac136ac8c0b589309f8aba5ce79d14ca64ad8bffb3876eb04b category: main optional: false - name: libbrotlidec @@ -914,17 +914,17 @@ package: category: main optional: false - name: libssh2 - version: 1.10.0 + version: 1.11.0 manager: conda platform: linux-64 dependencies: libgcc-ng: '>=12' - libzlib: '>=1.2.12,<1.3.0a0' - openssl: '>=3.0.5,<4.0a0' - url: https://conda.anaconda.org/conda-forge/linux-64/libssh2-1.10.0-hf14f497_3.tar.bz2 + libzlib: '>=1.2.13,<1.3.0a0' + openssl: '>=3.1.1,<4.0a0' + url: https://conda.anaconda.org/conda-forge/linux-64/libssh2-1.11.0-h0841786_0.conda hash: - md5: d85acad4b47dff4e3def14a769a97906 - sha256: 9a9a01f35d2d50326eb8ca7c0a92d0c45b2d0f77d9ea117680c70094ff480c0c + md5: 1f5a58e686b13bcfde88b93f547d23fe + sha256: 50e47fd9c4f7bf841a11647ae7486f65220cfc988ec422a4475fe8d5a823824d 
category: main optional: false - name: libxcb @@ -972,7 +972,7 @@ package: category: main optional: false - name: nccl - version: 2.15.5.1 + version: 2.17.1.1 manager: conda platform: linux-64 dependencies: @@ -980,10 +980,10 @@ package: cudatoolkit: '>=11.2,<12' libgcc-ng: '>=12' libstdcxx-ng: '>=12' - url: https://conda.anaconda.org/conda-forge/linux-64/nccl-2.15.5.1-h0800d71_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/nccl-2.17.1.1-h0800d71_0.conda hash: - md5: 502bc5485bb57d716ce13be9ed869938 - sha256: 7499b4a5640b5e88c491bb098ba8953945acffffb577702fb0ec43b92b4514fd + md5: e39703efa65c16c8a4c5bc647cfd593a + sha256: f4e536136b0e24042e335083338cb4a1e246d1eed5d71c318a3ec18db8d5b2a0 category: main optional: false - name: pandoc @@ -1171,10 +1171,10 @@ package: platform: linux-64 dependencies: libblas: 3.9.0 - url: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-16_linux64_openblas.tar.bz2 + url: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-17_linux64_openblas.conda hash: - md5: 20bae26d0a1db73f758fc3754cab4719 - sha256: e4ceab90a49cb3ac1af20177016dc92066aa278eded19646bb928d261b98367f + md5: 7ef0969b00fe3d6eef56a8151d3afb29 + sha256: 535bc0a6bc7641090b1bdd00a001bb6c4ac43bce2a11f238bc6676252f53eb3f category: main optional: false - name: libglib @@ -1218,10 +1218,10 @@ package: platform: linux-64 dependencies: libblas: 3.9.0 - url: https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-16_linux64_openblas.tar.bz2 + url: https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-17_linux64_openblas.conda hash: - md5: 955d993f41f9354bf753d29864ea20ad - sha256: f5f30b8049dfa368599e5a08a4f35cb1966af0abc539d1fd1f50d93db76a74e6 + md5: a2103882c46492e26500fcb56c03de8b + sha256: 45128394d2f4d4caf949c1b02bff1cace3ef2e33762dbe8f0edec7701a16aaa9 category: main optional: false - name: libtiff @@ -1245,16 +1245,16 @@ package: category: main optional: false - name: llvm-openmp - version: 16.0.4 + version: 16.0.5 manager: conda platform: linux-64 dependencies: libzlib: '>=1.2.13,<1.3.0a0' zstd: '>=1.5.2,<1.6.0a0' - url: https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-16.0.4-h4dfa4b3_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/llvm-openmp-16.0.5-h4dfa4b3_0.conda hash: - md5: 68ffdf82a717033ead1c5edbfeff9f54 - sha256: 48df036610d3ee357b408256f222589ea24909572b20e0bf73f9a1a3c42fe255 + md5: 9441a97b74c692d969ff465ac6c0ccea + sha256: 0dacf609b831fc518835fd82d3781c0413dbe171d45f5f21b2f9fb38439e9286 category: main optional: false - name: mpc @@ -1927,10 +1927,10 @@ package: libgcc-ng: '>=12' liblapack: '>=3.9.0,<4.0a0' libstdcxx-ng: '>=12' - url: https://conda.anaconda.org/conda-forge/linux-64/libmagma-2.7.1-hc72dce7_1.conda + url: https://conda.anaconda.org/conda-forge/linux-64/libmagma-2.7.1-hc72dce7_2.conda hash: - md5: 027b2c000121fe224f44fa717f03eca6 - sha256: 022d34448b37bd4ab46143742fa8726e44676729ab574e097f8524a578d40cb8 + md5: fdac3892aa3639be60555c195f5eb5f2 + sha256: ca158b78a54fe82e7d42c7e22ba524f4dfdc0ac37c182f20d6055382b4f47495 category: main optional: false - name: libwebp @@ -1963,17 +1963,17 @@ package: category: main optional: false - name: markupsafe - version: 2.1.2 + version: 2.1.3 manager: conda platform: linux-64 dependencies: libgcc-ng: '>=12' python: '>=3.11,<3.12.0a0' python_abi: 3.11.* - url: https://conda.anaconda.org/conda-forge/linux-64/markupsafe-2.1.2-py311h2582759_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/markupsafe-2.1.3-py311h459d7ec_0.conda hash: - md5: 
adb20bd57069614552adac60a020c36d - sha256: 48ee4934fc6e6ef4b0c66bb6698538beee5a5d94576c1d6b87c21009b84d55bf + md5: 9904dc4adb5d547cb21e136f98cb24b0 + sha256: 747b00706156b61d48565710f38cdb382e22f7db03e5b429532a2d5d5917c313 category: main optional: false - name: mdurl @@ -2122,17 +2122,17 @@ package: category: main optional: false - name: orjson - version: 3.8.14 + version: 3.9.0 manager: conda platform: linux-64 dependencies: libgcc-ng: '>=12' python: '>=3.11,<3.12.0a0' python_abi: 3.11.* - url: https://conda.anaconda.org/conda-forge/linux-64/orjson-3.8.14-py311h34b1e23_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/orjson-3.9.0-py311h34b1e23_0.conda hash: - md5: f468a1be3875ac2859079f6e528a8826 - sha256: d33717ecc920f77c55661624b9ed4887819f9740820c85a4b74b5a59b690eda8 + md5: 5c9b49e7917c1f865cddd67c1410b60a + sha256: e28bcf97d4f1dcd1f24824aadf69e935888055d0a2db4b90c6406759dc27d226 category: main optional: false - name: packaging @@ -2747,15 +2747,15 @@ package: category: main optional: false - name: typing_extensions - version: 4.6.2 + version: 4.6.3 manager: conda platform: linux-64 dependencies: python: '>=3.7' - url: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.6.2-pyha770c72_0.conda + url: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.6.3-pyha770c72_0.conda hash: - md5: 5a4a270e5a3f93846d6bade2f71fa440 - sha256: 8af96d7b665daabe3e60fa9c7457986237db1ad54469b01af3f4736bc18be284 + md5: 4a3014a4d107d15475d106b751c4e352 + sha256: 90a8d56c8015af1575d504d5f77d95a806cd999fc178a06ab51a349f1f744672 category: main optional: false - name: typing_utils @@ -3627,15 +3627,15 @@ package: category: main optional: false - name: typing-extensions - version: 4.6.2 + version: 4.6.3 manager: conda platform: linux-64 dependencies: - typing_extensions: 4.6.2 - url: https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.6.2-hd8ed1ab_0.conda + typing_extensions: 4.6.3 + url: https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.6.3-hd8ed1ab_0.conda hash: - md5: f676553904bb8f7c1dfe71c9db0d9ba7 - sha256: 5c6dcf5ff0d6be8a15d6bf5297867d9cb0154b6b946e8c87f69becf8a356e71b + md5: 3876f650ed7d0f95d70fa4b647621909 + sha256: d2334dab270e13182403cc3a394e3da8e7acb409e94059a6d9223d2ac053f90a category: main optional: false - name: validators @@ -3755,7 +3755,7 @@ package: category: main optional: false - name: cryptography - version: 41.0.0 + version: 41.0.1 manager: conda platform: linux-64 dependencies: @@ -3764,10 +3764,10 @@ package: openssl: '>=3.1.1,<4.0a0' python: '>=3.11,<3.12.0a0' python_abi: 3.11.* - url: https://conda.anaconda.org/conda-forge/linux-64/cryptography-41.0.0-py311h63ff55d_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/cryptography-41.0.1-py311h63ff55d_0.conda hash: - md5: 6bc48185b9486a8590a599e768cefb61 - sha256: 02ce04c837132d2cf6e08f03d34bbfa586bdeebfa897966f9c6da4638413aa9a + md5: 69ad01f66b8efff535d341ba5b283c2c + sha256: 7e8c5469d96b08ef1acd4cc2fc41d40ae76c30e08372cfe9f7a5a236bec2eb65 category: main optional: false - name: dateutils @@ -4537,7 +4537,7 @@ package: category: main optional: false - name: botocore - version: 1.29.144 + version: 1.29.146 manager: conda platform: linux-64 dependencies: @@ -4545,10 +4545,10 @@ package: python: '>=3.7' python-dateutil: '>=2.1,<3.0.0' urllib3: '>=1.25.4,<1.27' - url: https://conda.anaconda.org/conda-forge/noarch/botocore-1.29.144-pyhd8ed1ab_0.conda + url: https://conda.anaconda.org/conda-forge/noarch/botocore-1.29.146-pyhd8ed1ab_0.conda hash: - md5: 
1292fc54c547ab84ea7ded4162c206b7 - sha256: c2169156a4c52fc8a410376afba76f6349f539a5519c18c63de89babacb3c348 + md5: f4b1bce43c105800c8dc104aff45df15 + sha256: 46dc3ded104e3023d3d8acda1ca1a2c6e3d0e21d8fac33443a9dfd31dbc119d2 category: main optional: false - name: dulwich @@ -4602,7 +4602,7 @@ package: category: main optional: false - name: ipython - version: 8.13.2 + version: 8.14.0 manager: conda platform: linux-64 dependencies: @@ -4619,10 +4619,10 @@ package: stack_data: '' traitlets: '>=5' typing_extensions: '' - url: https://conda.anaconda.org/conda-forge/noarch/ipython-8.13.2-pyh41d4057_0.conda + url: https://conda.anaconda.org/conda-forge/noarch/ipython-8.14.0-pyh41d4057_0.conda hash: - md5: e8563c13eee80a5f1c7bdfc2a1b20077 - sha256: 4cdf7c6b93594023db5a7d02bd7e4e295bfede9d81571e7a16d01951a3ff5816 + md5: 0a0b0d8177c4a209017b356439292db8 + sha256: 25f1e5d78f7f063b3e32b939fed1e4d797df12f384c949902a325e6539315f96 category: main optional: false - name: nbclient @@ -4884,11 +4884,11 @@ package: category: main optional: false - name: awscli - version: 1.27.144 + version: 1.27.146 manager: conda platform: linux-64 dependencies: - botocore: 1.29.144 + botocore: 1.29.146 colorama: '>=0.2.5,<0.4.5' docutils: '>=0.10,<0.17' python: '>=3.11,<3.12.0a0' @@ -4896,10 +4896,10 @@ package: pyyaml: '>=3.10,<5.5' rsa: '>=3.1.2,<4.8' s3transfer: '>=0.6.0,<0.7.0' - url: https://conda.anaconda.org/conda-forge/linux-64/awscli-1.27.144-py311h38be061_0.conda + url: https://conda.anaconda.org/conda-forge/linux-64/awscli-1.27.146-py311h38be061_0.conda hash: - md5: c9f25bf53357f2b7036e3053b79d4e2c - sha256: 31fd0cea4e831f69a7023d31851a182124eaf24a557a0bf1f41956115f361bd8 + md5: b615cce69c1e453ad77feb9bdd954525 + sha256: 1822f1716fed5b7f6f426547faf80dd36514571103da988db9b4074f93244f3d category: main optional: false - name: cachecontrol-with-filecache @@ -5320,6 +5320,16 @@ package: sha256: 106caf6167c4597556b31a8d9175a3fdc0356fdcd70ab19973c3b0d4c893c461 category: main optional: false +- name: joblib + version: 1.2.0 + manager: pip + platform: linux-64 + dependencies: {} + url: https://files.pythonhosted.org/packages/91/d4/3b4c8e5a30604df4c7518c562d4bf0502f2fa29221459226e140cf846512/joblib-1.2.0-py3-none-any.whl + hash: + sha256: 091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385 + category: main + optional: false - name: kiwisolver version: 1.4.4 manager: pip @@ -5330,6 +5340,16 @@ package: sha256: 78d6601aed50c74e0ef02f4204da1816147a6d3fbdc8b3872d263338a9052c51 category: main optional: false +- name: lazy-loader + version: '0.2' + manager: pip + platform: linux-64 + dependencies: {} + url: https://files.pythonhosted.org/packages/a1/a8/c41f46b47a381bd60a40c0ef00d2fd1722b743b178f9c1cec0da949043de/lazy_loader-0.2-py3-none-any.whl + hash: + sha256: c35875f815c340f823ce3271ed645045397213f961b40ad0c0d395c3f5218eeb + category: main + optional: false - name: lit version: 16.0.5 manager: pip @@ -5480,6 +5500,16 @@ package: sha256: 997a2cc14023713f423e6d16536d55cb16a3d72850f142e05f82f0d4c76d383b category: main optional: false +- name: threadpoolctl + version: 3.1.0 + manager: pip + platform: linux-64 + dependencies: {} + url: https://files.pythonhosted.org/packages/61/cf/6e354304bcb9c6413c4e02a747b600061c21d38ba51e7e544ac7bc66aecc/threadpoolctl-3.1.0-py3-none-any.whl + hash: + sha256: 8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b + category: main + optional: false - name: contourpy version: 1.0.7 manager: pip @@ -5519,6 +5549,18 @@ package: sha256: 
9fc619170d800ff3793ad37c9757c255c8783051e1b5b00501205eb43ccc4f27 category: main optional: false +- name: imageio + version: 2.31.0 + manager: pip + platform: linux-64 + dependencies: + numpy: '*' + pillow: '>=8.3.2' + url: https://files.pythonhosted.org/packages/f7/9d/47d0a9d0f267e9155963db8608ffbc448f2b5d4e5414d8e608309f422094/imageio-2.31.0-py3-none-any.whl + hash: + sha256: 141bbd97910fad105c179a6b344ae4e7fef0dd85411303c63cd925b4c6163bee + category: main + optional: false - name: munch version: 3.0.0 manager: pip @@ -5530,6 +5572,39 @@ package: sha256: 0e4108418cfea898dcad01ff9569c30ff58f01d6f699331d04364f51623627c0 category: main optional: false +- name: opencv-python-headless + version: 4.7.0.72 + manager: pip + platform: linux-64 + dependencies: + numpy: '>=1.17.3' + url: https://files.pythonhosted.org/packages/3f/45/21fc904365f9cea3559e0192349bfe3ea2dce52672c1d9127c3b59711804/opencv_python_headless-4.7.0.72-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + hash: + sha256: 18dac3147863d2f4beef6b06b784ee115799a7842e2883adc4ae750c432613f9 + category: main + optional: false +- name: pywavelets + version: 1.4.1 + manager: pip + platform: linux-64 + dependencies: + numpy: '>=1.17.3' + url: https://files.pythonhosted.org/packages/de/a1/cd8a30e061f858f219364554b19d4318276c677a51d956c55fb0b134e8b2/PyWavelets-1.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + hash: + sha256: 875d4d620eee655346e3589a16a73790cf9f8917abba062234439b594e706784 + category: main + optional: false +- name: scipy + version: 1.10.1 + manager: pip + platform: linux-64 + dependencies: + numpy: '>=1.19.5,<1.27.0' + url: https://files.pythonhosted.org/packages/21/cd/fe2d4af234b80dc08c911ce63fdaee5badcdde3e9bcd9a68884580652ef0/scipy-1.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + hash: + sha256: 15a35c4242ec5f292c3dd364a7c71a61be87a3d4ddcc693372813c0b73c9af1d + category: main + optional: false - name: setuptools-scm version: 7.1.0 manager: pip @@ -5542,6 +5617,17 @@ package: sha256: 73988b6d848709e2af142aa48c986ea29592bbcfca5375678064708205253d8e category: main optional: false +- name: tifffile + version: 2023.4.12 + manager: pip + platform: linux-64 + dependencies: + numpy: '*' + url: https://files.pythonhosted.org/packages/93/86/2ed10947a1891ceb86b084153fac06877fdec38a5ed69bd9286eefab3d44/tifffile-2023.4.12-py3-none-any.whl + hash: + sha256: 3161954746fe32c4f4244d0fb2eb0a272f3a3760b78882a42faa83ac5e6e0b74 + category: main + optional: false - name: torchvision version: 0.15.1 manager: pip @@ -5615,6 +5701,39 @@ package: sha256: 7e77ead4619a3e11ab3c41982c8ad5b86edffe37c87fd2a37ec3c2cc6470b98a category: main optional: false +- name: scikit-image + version: 0.21.0 + manager: pip + platform: linux-64 + dependencies: + numpy: '>=1.21.1' + scipy: '>=1.8' + networkx: '>=2.8' + pillow: '>=9.0.1' + imageio: '>=2.27' + tifffile: '>=2022.8.12' + pywavelets: '>=1.1.1' + packaging: '>=21' + lazy-loader: '>=0.2' + url: https://files.pythonhosted.org/packages/22/c3/c5f3c351d6337a18d07c3fb04475626c106cd3dc3d59b85ec50d07656db0/scikit_image-0.21.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + hash: + sha256: ff5719c7eb99596a39c3e1d9b564025bae78ecf1da3ee6842d34f6965b5f1474 + category: main + optional: false +- name: scikit-learn + version: 1.2.2 + manager: pip + platform: linux-64 + dependencies: + numpy: '>=1.17.3' + scipy: '>=1.3.2' + joblib: '>=1.1.1' + threadpoolctl: '>=2.0.0' + url: 
https://files.pythonhosted.org/packages/4c/64/a1e6e92b850b39200c82e3bc54d556b2c634b3904c39ac5cdb10b1c5765f/scikit_learn-1.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + hash: + sha256: bf036ea7ef66115e0d49655f16febfa547886deba20149555a41d28f56fd6d3c + category: main + optional: false - name: timm version: 0.9.2 manager: pip @@ -5630,6 +5749,20 @@ package: sha256: 8da40cc58ed32b0622bf87d8714f9b7023398ba4cfa8fa678578d2aefde4a909 category: main optional: false +- name: qudida + version: 0.0.4 + manager: pip + platform: linux-64 + dependencies: + numpy: '>=0.18.0' + scikit-learn: '>=0.19.1' + typing-extensions: '*' + opencv-python-headless: '>=4.0.1' + url: https://files.pythonhosted.org/packages/f0/a1/a5f4bebaa31d109003909809d88aeb0d4b201463a9ea29308d9e4f9e7655/qudida-0.0.4-py3-none-any.whl + hash: + sha256: 4519714c40cd0f2e6c51e1735edae8f8b19f4efe1f33be13e9d644ca5f736dd6 + category: main + optional: false - name: segmentation-models-pytorch version: 0.3.3 manager: pip @@ -5646,3 +5779,19 @@ package: sha256: b4317d6f72cb1caf4b7e1d384096970e202600275f54deb8e774fc04d6c8b82e category: main optional: false +- name: albumentations + version: 1.3.0 + manager: pip + platform: linux-64 + dependencies: + numpy: '>=1.11.1' + scipy: '*' + scikit-image: '>=0.16.1' + pyyaml: '*' + qudida: '>=0.0.4' + opencv-python-headless: '>=4.1.1' + url: https://files.pythonhosted.org/packages/4f/55/3c2ce84c108fc1d422afd6de153e4b0a3e6f96ecec4cb9afcf0284ce3538/albumentations-1.3.0-py3-none-any.whl + hash: + sha256: 294165d87d03bc8323e484927f0a5c1a3c64b0e7b9c32a979582a6c93c363bdf + category: main + optional: false diff --git a/environment.yml b/environment.yml index a17f54f..4f17fad 100644 --- a/environment.yml +++ b/environment.yml @@ -21,5 +21,6 @@ dependencies: - pip: - typeshed-client==2.3.0 - segmentation-models-pytorch==0.3.3 + - albumentations==1.3.0 platforms: - linux-64 diff --git a/trainer.py b/trainer.py index fde1827..0df9a92 100644 --- a/trainer.py +++ b/trainer.py @@ -21,11 +21,12 @@ from lightning.pytorch.cli import ArgsType, LightningCLI from chabud.datapipe import ChaBuDDataPipeModule +from chabud.dataset import ChaBuDDataModule from chabud.model import ChaBuDNet from chabud.callbacks import LogIntermediatePredictions -def main(): +def main(stage: str = "train", ckpt_path: str = None): cwd = os.getcwd() (Path(cwd) / "logs").mkdir(exist_ok=True) @@ -49,7 +50,8 @@ def main(): ckpt_cb = ModelCheckpoint( monitor="val/iou", mode="max", - save_top_k=2, + save_top_k=1, + save_last=True, verbose=True, filename="epoch:{epoch}-step:{step}-loss:{val/loss:.3f}-iou:{val/iou:.3f}", auto_insert_metric_name=False, @@ -57,12 +59,20 @@ def main(): log_preds_cb = LogIntermediatePredictions(logger=wandb_logger) # DATAMODULE - dm = ChaBuDDataPipeModule(batch_size=20) + batch_size = 16 + # dm = ChaBuDDataPipeModule(batch_size=batch_size) + dm = ChaBuDDataModule( + data_dir=Path("./data"), + batch_size=batch_size, + ) dm.setup() # MODEL model = ChaBuDNet( - lr=1e-3, model_name="tinycd", submission_filepath="submission.csv" + lr=1e-3, + model_name="tinycd", + submission_filepath=f"{name}-submission.csv", + batch_size=batch_size, ) debug = False @@ -74,28 +84,32 @@ def main(): devices=1, accelerator="gpu", precision="16-mixed", - max_epochs=2 if debug else 20, + max_epochs=2 if debug else 30, accumulate_grad_batches=1, logger=[ csv_logger, wandb_logger, ], - callbacks=[ckpt_cb, log_preds_cb], + callbacks=[lr_cb, ckpt_cb, log_preds_cb], log_every_n_steps=1, ) - # TRAIN - print("TRAIN") - trainer.fit( - model, - 
train_dataloaders=dm.train_dataloader(), - val_dataloaders=dm.val_dataloader(), - ) + if stage == "train": + # TRAIN + print("TRAIN") + trainer.fit( + model, + train_dataloaders=dm.train_dataloader(), + val_dataloaders=dm.val_dataloader(), + ckpt_path="last", + ) # EVAL device = "cuda" print("EVAL") - model = ChaBuDNet.load_from_checkpoint(ckpt_cb.best_model_path).to(device) + model = ChaBuDNet.load_from_checkpoint( + ckpt_cb.best_model_path if ckpt_path is None else ckpt_path + ).to(device) model.eval() model.freeze() trainer.test(model, dataloaders=dm.test_dataloader()) @@ -122,5 +136,9 @@ def cli_main( if __name__ == "__main__": # cli_main() - main() + # main( + # stage="eval", + # ckpt_path="logs/csv_logger/12channel/version_0/checkpoints/epoch:23-step:432-loss:0.595-iou:0.632.ckpt", + # ) + main(stage="train") print("Done!")
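--
Note (not part of the patch): both `_apply_augmentation` in chabud/datapipe.py and the
transforms in chabud/dataset.py rely on albumentations' `additional_targets` mechanism,
which registers the post-fire image as a second "image" target so that the one randomly
sampled flip/shift/scale/rotate is applied identically to the pre image, the post image,
and the mask. A minimal standalone sketch of that behaviour (assuming albumentations
1.3.0 as pinned in environment.yml; the array shapes are illustrative only):

import albumentations as A
import numpy as np

# p=1.0 makes the flip deterministic so the alignment property is checkable.
aug = A.Compose(
    [A.HorizontalFlip(p=1.0)],
    additional_targets={"post": "image"},  # treat the "post" argument like "image"
)

pre = np.random.rand(4, 4, 3).astype(np.float32)  # channel-last, as the helpers transpose to
post = pre + 1.0
mask = np.eye(4, dtype=np.uint8)

out = aug(image=pre, post=post, mask=mask)

# The same horizontal flip was applied to all three targets, so flipping the
# outputs back recovers the originals and pre/post/mask stay pixel-aligned.
assert np.allclose(out["image"][:, ::-1, :], pre)
assert np.allclose(out["post"][:, ::-1, :], post)
assert np.array_equal(out["mask"][:, ::-1], mask)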