Skip to content

Commit 06f9a99

Browse files
committed
Memory monitoring
1 parent 20e7188 commit 06f9a99

File tree

3 files changed

+1553
-9
lines changed

3 files changed

+1553
-9
lines changed

cell_segmentation/inference/cell_detection.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
from preprocessing.encoding.datasets.patched_wsi_inference import PatchedWSIInference
6767
from utils.file_handling import load_wsi_files_from_csv
6868
from utils.logger import Logger
69-
from utils.tools import unflatten_dict
69+
from utils.tools import unflatten_dict, get_size_of_dict
7070

7171
warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning)
7272
pandarallel.initialize(progress_bar=False, nb_workers=12)
@@ -89,7 +89,6 @@
8989
5: "Epithelial",
9090
}
9191

92-
9392
class CellSegmentationInference:
9493
def __init__(
9594
self,
@@ -300,11 +299,15 @@ def process_wsi(
300299
"metadata": {"wsi_metadata": wsi.metadata, "nuclei_types": nuclei_types},
301300
}
302301
processed_patches = []
302+
303+
memory_usage = 0
304+
cell_count = 0
303305

304306
with torch.no_grad():
305-
for batch in tqdm.tqdm(
306-
wsi_inference_dataloader, total=len(wsi_inference_dataloader)
307-
):
307+
308+
pbar = tqdm.tqdm(wsi_inference_dataloader, total=len(wsi_inference_dataset))
309+
310+
for batch in wsi_inference_dataloader:
308311
patches = batch[0].to(self.device)
309312

310313
metadata = batch[1]
@@ -323,6 +326,7 @@ def process_wsi(
323326
for idx, (patch_instance_types, patch_metadata) in enumerate(
324327
zip(instance_types, metadata)
325328
):
329+
pbar.update(1)
326330
# add global patch metadata
327331
patch_cell_detection = {}
328332
patch_cell_detection["patch_metadata"] = patch_metadata
@@ -368,10 +372,6 @@ def process_wsi(
368372
cell["bbox"], 1024, 64
369373
),
370374
"offset_global": offset_global.tolist()
371-
# optional: Local positional information
372-
# "bbox_local": cell["bbox"].tolist(),
373-
# "centroid_local": cell["centroid"].tolist(),
374-
# "contour_local": cell["contour"].tolist(),
375375
}
376376
cell_detection = {
377377
"bbox": bbox_global.tolist(),
@@ -413,6 +413,14 @@ def process_wsi(
413413
graph_data["positions"].append(torch.Tensor(centroid_global))
414414
graph_data["contours"].append(torch.Tensor(contour_global))
415415

416+
cell_count = cell_count + 1
417+
# dict sizes
418+
memory_usage = memory_usage + get_size_of_dict(cell_dict)/(1024*1024) + get_size_of_dict(cell_detection)/(1024*1024) # + sys.getsizeof(cell_token)/(1024*1024)
419+
# pytorch
420+
memory_usage = memory_usage + (cell_token.nelement() * cell_token.element_size())/(1024*1024) + centroid_global.nbytes/(1024*1024) + contour_global.nbytes/(1024*1024)
421+
422+
pbar.set_postfix(Cells=cell_count, Memory=f"{memory_usage:.2f} MB")
423+
416424
# post processing
417425
self.logger.info(f"Detected cells before cleaning: {len(cell_dict_wsi)}")
418426
keep_idx = self.post_process_edge_cells(cell_list=cell_dict_wsi)

0 commit comments

Comments
 (0)