# Source code for vision_explanation_methods.error_labeling.error_labeling

# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.

"""Defines the Error Labeling Manager class."""

from copy import deepcopy
from enum import Enum

import numpy as np
import torchvision
from torch import Tensor

# String constant 'labels'. It is not referenced by the code visible in this
# module (presumably a dictionary key shared with callers -- TODO confirm).
LABELS = 'labels'


class ErrorLabelType(Enum):
    """
    Kinds of labels assigned to a (ground truth, detection) pair.

    Every member except MATCH describes an error. If none of the error
    labels applies, the detection is not an error; it is a correct
    prediction.
    """

    # The ground truth object has no corresponding detection.
    MISSING = "missing"

    # The model predicted a detection where there was nothing. Such a
    # prediction must have an IoU score of 0 with all ground truth
    # detections.
    BACKGROUND = "background"

    # Correct class, but the bounding box does not overlap the ground
    # truth enough (according to the IoU threshold).
    LOCALIZATION = "localization"

    # Correct bounding box, but the predicted class is wrong.
    CLASS_NAME = "class_name"

    # Both the predicted class and the bounding box are wrong.
    CLASS_LOCALIZATION = "class_localization"

    # Correct class and correct bounding box, but the IoU score is lower
    # than that of another detection.
    DUPLICATE_DETECTION = "duplicate_detection"

    # Not an error: the detection is a correct prediction.
    MATCH = "match"
class ErrorLabeling:
    """
    Defines a wrapper class of Error Labeling for vision scenario.

    Only supported for object detection at this point.
    """

    def __init__(self,
                 task_type: str,
                 pred_y: list,
                 true_y: list,
                 iou_threshold: float = 0.5):
        """
        Create an ErrorLabeling object.

        :param task_type: The task to run.
        :type task_type: str
        :param pred_y: predicted detections, nested list of 6 floats
            (class, bounding box, conf score). The bounding box will be
            located at indexes 1-4.
        :type pred_y: list
        :param true_y: ground truth detections, nested list of 6 floats
            (class, bounding box, is crowded). The bounding box will be
            located at indexes 1-4.
        :type true_y: list
        :param iou_threshold: required minimum for bounding box overlap
        :type iou_threshold: float
        """
        self._is_run = False
        self._is_added = False
        self._task_type = task_type
        self._pred_y = pred_y
        self._true_y = true_y
        self._iou_threshold = iou_threshold

    def compute_error_labels(self):
        """
        Compute labels for errors in an object detection prediction.

        Rows correspond to ground truth detections and columns to
        predictions, in the same order as ``pred_y``.

        Note: if a row does not have a match, that means that there is a
        missing gt detection.

        :return: 2d matrix of error labels. When there are no ground
            truth detections, a list holding one ``[BACKGROUND]`` row per
            prediction is returned instead.
        :rtype: NDArray
        """
        if len(self._true_y) == 0:
            return [[ErrorLabelType.BACKGROUND]
                    for _ in range(len(self._pred_y))]

        match_matrix = np.full((len(self._true_y), len(self._pred_y)), None)

        # original_indices[k] is the position in pred_y of the prediction
        # with the k-th highest confidence score (sorted() is stable, so
        # ties keep their input order).
        original_indices = sorted(range(len(self._pred_y)),
                                  key=lambda i: self._pred_y[i][-1],
                                  reverse=True)

        # Process predictions by decreasing conf score; this is to stay
        # consistent with NMS and MAP algorithms.
        sorted_list = [self._pred_y[i] for i in original_indices]

        for gt_index, gt in enumerate(self._true_y):
            gt_box = Tensor(gt[1:5]).unsqueeze(0).view(-1, 4)
            for detect_index, detect in enumerate(sorted_list):
                iou_score = torchvision.ops.box_iou(
                    Tensor(detect[1:5]).unsqueeze(0).view(-1, 4),
                    gt_box)
                class_differs = detect[0] != gt[0]

                if iou_score.item() == 0:
                    # if iou is 0, then prediction is detecting the
                    # background
                    label = ErrorLabelType.BACKGROUND
                elif self._iou_threshold <= iou_score:
                    # the detection and ground truth bb's are overlapping
                    if class_differs:
                        # the bboxes line up, but labels do not
                        label = ErrorLabelType.CLASS_NAME
                    elif ErrorLabelType.MATCH in match_matrix[gt_index]:
                        # class name and bbox correct, but there is
                        # already a match with a higher confidence score
                        # (this is why it was important to visit the
                        # predictions by descending confidence score)
                        label = ErrorLabelType.DUPLICATE_DETECTION
                    else:
                        # bboxes overlap and class names agree (1st time)
                        label = ErrorLabelType.MATCH
                elif class_differs:
                    # the bboxes don't line up, and labels do not
                    label = ErrorLabelType.CLASS_LOCALIZATION
                else:
                    # the bboxes don't line up, but the labels are correct
                    label = ErrorLabelType.LOCALIZATION

                match_matrix[gt_index][detect_index] = label

            # Put the columns back in pred_y order (so no longer ordered
            # by descending conf scores).
            match_matrix[gt_index] = self._restore_original_order(
                match_matrix[gt_index], original_indices)

        return match_matrix

    @staticmethod
    def _restore_original_order(sorted_row, original_indices):
        """
        Undo the confidence sort for one row of labels.

        ``sorted_row[k]`` belongs to the prediction that originally sat at
        position ``original_indices[k]``, so each value is written back to
        that position (the inverse of the sort permutation). Indexing the
        row with ``original_indices`` directly would apply the sort a
        second time instead of reverting it.

        :param sorted_row: labels ordered by descending confidence
        :type sorted_row: sequence
        :param original_indices: original position of each sorted entry
        :type original_indices: list
        :return: labels in the original prediction order
        :rtype: list
        """
        restored = [None] * len(original_indices)
        for sorted_pos, original_pos in enumerate(original_indices):
            restored[original_pos] = sorted_row[sorted_pos]
        return restored

    def compute_error_list(self):
        """
        Determine a complete list of errors encountered during prediction.

        Note that it is possible to have more errors than actual objects
        in an image (because we account for missing detections and
        duplicate detections).

        :return: list of ErrorLabelType members, one per error
        :rtype: list
        """
        if len(self._true_y) == 0:
            # Without ground truth every prediction is a background error
            # and nothing can be missing. The (n_pred x 1) list returned
            # by compute_error_labels for this case must not go through
            # the rows-minus-columns MISSING computation below.
            return [ErrorLabelType.BACKGROUND
                    for _ in range(len(self._pred_y))]

        match_matrix = self.compute_error_labels()
        error_arr = self._remove_matches(deepcopy(match_matrix))

        dup_count = np.count_nonzero(
            match_matrix == ErrorLabelType.DUPLICATE_DETECTION)
        error_list = [ErrorLabelType.DUPLICATE_DETECTION] * dup_count

        if len(error_arr) == 0:
            return error_list

        # More unmatched ground truth rows than unmatched detection
        # columns: the surplus rows are reported as missing.
        missing_count = len(error_arr) - len(error_arr[0])
        if missing_count > 0:
            error_list.extend([ErrorLabelType.MISSING] * missing_count)

        order_of_errors = [ErrorLabelType.CLASS_NAME,
                           ErrorLabelType.LOCALIZATION,
                           ErrorLabelType.CLASS_LOCALIZATION,
                           ErrorLabelType.BACKGROUND]
        for err in order_of_errors:
            # NOTE(review): each pass walks the matrix as it was when the
            # pass started, while error_arr is rebound to the shrunken
            # copy; the indices handed to _remove_rows_cols therefore refer
            # to the pre-removal layout. Kept as in the original.
            for gt_index, gt in enumerate(error_arr):
                for detect_index, detect in enumerate(gt):
                    if detect == err:
                        error_list.append(err)
                        error_arr = self._remove_rows_cols(
                            error_arr, {gt_index}, {detect_index})
                        # NOTE(review): the nesting level of this check was
                        # lost in the collapsed source; it is assumed to
                        # directly follow the removal -- confirm.
                        if len(error_arr) == 0:
                            break
        return error_list

    def _remove_matches(self, arr: np.array):
        """
        Remove match rows and columns from a error labeling matrix.

        Every row and every column that contains a MATCH entry is dropped.

        :param arr: np 2d array
        :type arr: np.array
        :return: nested list with the matched rows and columns removed
        :rtype: list
        """
        rows_to_delete = set()
        cols_to_delete = set()
        for row, row_items in enumerate(arr):
            for col, value in enumerate(row_items):
                if value == ErrorLabelType.MATCH:
                    rows_to_delete.add(row)
                    cols_to_delete.add(col)

        return self._remove_rows_cols(arr, rows_to_delete, cols_to_delete)

    def _remove_rows_cols(self,
                          arr: np.array,
                          rows_to_delete: set,
                          cols_to_delete: set):
        """
        Remove rows and columns from a given array.

        Indexes that do not exist in ``arr`` are ignored.

        :param arr: np 2d array
        :type arr: np.array
        :param rows_to_delete: unique set of indexes of rows to remove
        :type rows_to_delete: set
        :param cols_to_delete: unique set of indexes of cols to remove
        :type cols_to_delete: set
        :return: nested list with the given rows and columns removed
        :rtype: list
        """
        return [
            [value for col_index, value in enumerate(row)
             if col_index not in cols_to_delete]
            for row_index, row in enumerate(arr)
            if row_index not in rows_to_delete
        ]