From 026330c01a69afa1e014ba90a530d33302116ade Mon Sep 17 00:00:00 2001 From: Ajinkya Kulkarni Date: Fri, 5 May 2023 12:03:15 +0200 Subject: [PATCH] Updated viz_manager.py Modified the nested for loops that iterate over the image grid to use numpy vectorization instead, by reshaping the images and titles into 2D arrays, and then iterating over the first dimension of those arrays. This allows for faster iteration and better performance when dealing with large sets of images. --- src/cleanvision/utils/viz_manager.py | 34 +++++++++++++++------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/cleanvision/utils/viz_manager.py b/src/cleanvision/utils/viz_manager.py index 5c2ca665..204df2a2 100644 --- a/src/cleanvision/utils/viz_manager.py +++ b/src/cleanvision/utils/viz_manager.py @@ -42,23 +42,25 @@ def plot_image_grid( ) -> None: nrows = math.ceil(len(images) / ncols) ncols = min(ncols, len(images)) + + # Convert list of images to a 4D Numpy array + arr = np.array([np.array(image) for image in images]) + fig, axes = plt.subplots( nrows, ncols, figsize=(cell_size[0] * ncols, cell_size[1] * nrows) ) - if nrows > 1: - idx = 0 - for i in range(nrows): - for j in range(ncols): - idx = i * ncols + j - if idx >= len(images): - axes[i, j].axis("off") - continue - set_image_on_axes(images[idx], axes[i, j], titles[idx]) - if idx >= len(images): - break - elif ncols > 1: - for i in range(min(ncols, len(images))): - set_image_on_axes(images[i], axes[i], titles[i]) - else: - set_image_on_axes(images[0], axes, titles[0]) + + # Create a 2D array of indices + idxs = np.arange(nrows * ncols).reshape(nrows, ncols) + + # Set axes properties + for ax in axes.flatten(): + ax.get_xaxis().set_visible(False) + ax.get_yaxis().set_visible(False) + + # Set images on axes using advanced indexing + axes[idxs[:len(images) // ncols + 1, :len(images) % ncols]] = arr + for i, title in enumerate(titles): + axes.flat[i].set_title(title, fontsize=7) + plt.show()