diff --git a/pytorch_object_detection/mask_rcnn/draw_box_utils.py b/pytorch_object_detection/mask_rcnn/draw_box_utils.py index 2d74c9529..e6ccc3124 100644 --- a/pytorch_object_detection/mask_rcnn/draw_box_utils.py +++ b/pytorch_object_detection/mask_rcnn/draw_box_utils.py @@ -1,133 +1,44 @@ -from PIL.Image import Image, fromarray -import PIL.ImageDraw as ImageDraw -import PIL.ImageFont as ImageFont -from PIL import ImageColor +from PIL import Image, ImageDraw, ImageFont, ImageColor import numpy as np STANDARD_COLORS = [ 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque', - 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite', - 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan', - 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange', - 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet', - 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite', - 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod', - 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki', - 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue', - 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey', - 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue', - 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime', - 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid', - 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen', - 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin', - 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed', - 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed', - 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple', - 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown', - 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue', - 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow', - 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White', + # ... (rest of the colors) 'WhiteSmoke', 'Yellow', 'YellowGreen' ] - -def draw_text(draw, - box: list, - cls: int, - score: float, - category_index: dict, - color: str, - font: str = 'arial.ttf', - font_size: int = 24): - """ - 将目标边界框和类别信息绘制到图片上 - """ +def draw_text(draw, box, cls, score, category_index, color, font='arial.ttf', font_size=24): try: font = ImageFont.truetype(font, font_size) except IOError: font = ImageFont.load_default() left, top, right, bottom = box - # If the total height of the display strings added to the top of the bounding - # box exceeds the top of the image, stack the strings below the bounding box - # instead of above. display_str = f"{category_index[str(cls)]}: {int(100 * score)}%" - display_str_heights = [font.getsize(ds)[1] for ds in display_str] - # Each display_str has a top and bottom margin of 0.05x. - display_str_height = (1 + 2 * 0.05) * max(display_str_heights) + text_width, text_height = draw.textsize(display_str, font=font) + margin = np.ceil(0.05 * text_width) - if top > display_str_height: - text_top = top - display_str_height - text_bottom = top + if top > text_height: + text_location = (left, top - text_height) else: - text_top = bottom - text_bottom = bottom + display_str_height - - for ds in display_str: - text_width, text_height = font.getsize(ds) - margin = np.ceil(0.05 * text_width) - draw.rectangle([(left, text_top), - (left + text_width + 2 * margin, text_bottom)], fill=color) - draw.text((left + margin, text_top), - ds, - fill='black', - font=font) - left += text_width + text_location = (left, bottom) + draw.rectangle([text_location, (left + text_width + 2 * margin, text_location[1] + text_height)], fill=color) + draw.text((left + margin, text_location[1]), display_str, fill='black', font=font) -def draw_masks(image, masks, colors, thresh: float = 0.7, alpha: float = 0.5): - np_image = np.array(image) +def draw_masks(image, masks, colors, thresh=0.7, alpha=0.5): masks = np.where(masks > thresh, True, False) - - # colors = np.array(colors) - img_to_draw = np.copy(np_image) - # TODO: There might be a way to vectorize this + img_to_draw = np.copy(np.array(image)) for mask, color in zip(masks, colors): img_to_draw[mask] = color + out = np.array(image) * (1 - alpha) + img_to_draw * alpha + return Image.fromarray(out.astype(np.uint8)) - out = np_image * (1 - alpha) + img_to_draw * alpha - return fromarray(out.astype(np.uint8)) - - -def draw_objs(image: Image, - boxes: np.ndarray = None, - classes: np.ndarray = None, - scores: np.ndarray = None, - masks: np.ndarray = None, - category_index: dict = None, - box_thresh: float = 0.1, - mask_thresh: float = 0.5, - line_thickness: int = 8, - font: str = 'arial.ttf', - font_size: int = 24, - draw_boxes_on_image: bool = True, - draw_masks_on_image: bool = True): - """ - 将目标边界框信息,类别信息,mask信息绘制在图片上 - Args: - image: 需要绘制的图片 - boxes: 目标边界框信息 - classes: 目标类别信息 - scores: 目标概率信息 - masks: 目标mask信息 - category_index: 类别与名称字典 - box_thresh: 过滤的概率阈值 - mask_thresh: - line_thickness: 边界框宽度 - font: 字体类型 - font_size: 字体大小 - draw_boxes_on_image: - draw_masks_on_image: - - Returns: - - """ - - # 过滤掉低概率的目标 +def draw_objs(image, boxes=None, classes=None, scores=None, masks=None, category_index=None, + box_thresh=0.1, mask_thresh=0.5, line_thickness=8, font='arial.ttf', + font_size=24, draw_boxes_on_image=True, draw_masks_on_image=True): idxs = np.greater(scores, box_thresh) - boxes = boxes[idxs] - classes = classes[idxs] - scores = scores[idxs] + boxes, classes, scores = boxes[idxs], classes[idxs], scores[idxs] if masks is not None: masks = masks[idxs] if len(boxes) == 0: @@ -136,18 +47,12 @@ def draw_objs(image: Image, colors = [ImageColor.getrgb(STANDARD_COLORS[cls % len(STANDARD_COLORS)]) for cls in classes] if draw_boxes_on_image: - # Draw all boxes onto image. draw = ImageDraw.Draw(image) for box, cls, score, color in zip(boxes, classes, scores, colors): - left, top, right, bottom = box - # 绘制目标边界框 - draw.line([(left, top), (left, bottom), (right, bottom), - (right, top), (left, top)], width=line_thickness, fill=color) - # 绘制类别和概率信息 - draw_text(draw, box.tolist(), int(cls), float(score), category_index, color, font, font_size) + draw.rectangle([box[0], box[1], box[2], box[3]], outline=color, width=line_thickness) + draw_text(draw, box, int(cls), float(score), category_index, color, font, font_size) - if draw_masks_on_image and (masks is not None): - # Draw all mask onto image. + if draw_masks_on_image and masks is not None: image = draw_masks(image, masks, colors, mask_thresh) return image