@@ -66,7 +66,7 @@
 from preprocessing.encoding.datasets.patched_wsi_inference import PatchedWSIInference
 from utils.file_handling import load_wsi_files_from_csv
 from utils.logger import Logger
-from utils.tools import unflatten_dict
+from utils.tools import unflatten_dict, get_size_of_dict
 
 warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning)
 pandarallel.initialize(progress_bar=False, nb_workers=12)
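The new import pulls in get_size_of_dict from utils.tools; its body is not part of this diff. A minimal sketch of what such a helper could look like, assuming a shallow sum of sys.getsizeof over the dict, its keys, and its values (the name get_size_of_dict_sketch marks it as a hypothetical stand-in, not the repo's implementation):

import sys

def get_size_of_dict_sketch(d: dict) -> int:
    # Shallow estimate: container overhead plus each key/value object.
    # The real utils.tools helper may recurse into nested containers.
    size = sys.getsizeof(d)
    for key, value in d.items():
        size += sys.getsizeof(key)
        size += sys.getsizeof(value)
    return size

print(get_size_of_dict_sketch({"bbox": [0, 0, 10, 10]}) / (1024 * 1024))  # size in MiB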
@@ -89,7 +89,6 @@
     5: "Epithelial",
 }
 
-
 class CellSegmentationInference:
     def __init__(
         self,
@@ -300,11 +299,15 @@ def process_wsi(
             "metadata": {"wsi_metadata": wsi.metadata, "nuclei_types": nuclei_types},
         }
         processed_patches = []
+
+        memory_usage = 0
+        cell_count = 0
 
         with torch.no_grad():
-            for batch in tqdm.tqdm(
-                wsi_inference_dataloader, total=len(wsi_inference_dataloader)
-            ):
+
+            pbar = tqdm.tqdm(wsi_inference_dataloader, total=len(wsi_inference_dataset))
+
+            for batch in wsi_inference_dataloader:
                 patches = batch[0].to(self.device)
 
                 metadata = batch[1]
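The rewritten loop no longer wraps the dataloader iteration in tqdm; the bar is created once with total=len(wsi_inference_dataset) and advanced by hand, so progress is counted in patches rather than batches. A self-contained sketch of that manual-update pattern, with toy lists standing in for the dataset and dataloader:

import tqdm

dataset = list(range(100))  # stand-in for wsi_inference_dataset
dataloader = [dataset[i:i + 8] for i in range(0, len(dataset), 8)]  # stand-in for the batched dataloader

pbar = tqdm.tqdm(total=len(dataset))  # progress measured in patches, not batches
for batch in dataloader:
    for patch in batch:
        pbar.update(1)  # advance once per processed patch
pbar.close()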
@@ -323,6 +326,7 @@ def process_wsi(
                 for idx, (patch_instance_types, patch_metadata) in enumerate(
                     zip(instance_types, metadata)
                 ):
+                    pbar.update(1)
                     # add global patch metadata
                     patch_cell_detection = {}
                     patch_cell_detection["patch_metadata"] = patch_metadata
@@ -368,10 +372,6 @@ def process_wsi(
                                 cell["bbox"], 1024, 64
                             ),
                             "offset_global": offset_global.tolist()
-                            # optional: Local positional information
-                            # "bbox_local": cell["bbox"].tolist(),
-                            # "centroid_local": cell["centroid"].tolist(),
-                            # "contour_local": cell["contour"].tolist(),
                         }
                         cell_detection = {
                             "bbox": bbox_global.tolist(),
@@ -413,6 +413,14 @@ def process_wsi(
                         graph_data["positions"].append(torch.Tensor(centroid_global))
                         graph_data["contours"].append(torch.Tensor(contour_global))
 
+                        cell_count = cell_count + 1
+                        # dict sizes
+                        memory_usage = memory_usage + get_size_of_dict(cell_dict) / (1024 * 1024) + get_size_of_dict(cell_detection) / (1024 * 1024)  # + sys.getsizeof(cell_token)/(1024*1024)
+                        # pytorch
+                        memory_usage = memory_usage + (cell_token.nelement() * cell_token.element_size()) / (1024 * 1024) + centroid_global.nbytes / (1024 * 1024) + contour_global.nbytes / (1024 * 1024)
+
+                        pbar.set_postfix(Cells=cell_count, Memory=f"{memory_usage:.2f} MB")
+
         # post processing
         self.logger.info(f"Detected cells before cleaning: {len(cell_dict_wsi)}")
         keep_idx = self.post_process_edge_cells(cell_list=cell_dict_wsi)
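The added bookkeeping combines three byte sources per cell: plain dicts via get_size_of_dict, torch tensors via nelement() * element_size(), and numpy arrays via .nbytes, each converted to MiB with /(1024 * 1024) and surfaced through pbar.set_postfix. A minimal sketch with dummy data; the shapes below are assumptions for illustration, not values taken from the model:

import numpy as np
import torch
import tqdm

cell_token = torch.zeros(64)        # assumed embedding length
centroid_global = np.zeros(2)       # x, y
contour_global = np.zeros((32, 2))  # assumed contour length

tensor_mib = (cell_token.nelement() * cell_token.element_size()) / (1024 * 1024)
numpy_mib = (centroid_global.nbytes + contour_global.nbytes) / (1024 * 1024)

pbar = tqdm.tqdm(total=1)
pbar.set_postfix(Cells=1, Memory=f"{tensor_mib + numpy_mib:.2f} MB")
pbar.update(1)
pbar.close()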