Source code for vision_explanation_methods.evaluation.pointing_game

# Copyright (c) Microsoft Corporation
# Licensed under the MIT License.

"""Defines a variety of explanation evaluation tools."""

from io import BytesIO
from typing import Any, List

import matplotlib.patches as patches
import numpy as np
import requests
import torch
import torchvision.transforms as T
from captum.attr import visualization as viz
from matplotlib import pyplot as pl
from ml_wrappers.common.constants import Device
from ml_wrappers.model.image_model_wrapper import _get_device
from PIL import Image
from torch import Tensor

from ..explanations import drise

NUM_MASKS = 100


[docs]class PointingGame: """A class for the high energy pointing game.""" def __init__(self, model: Any, device=Device.AUTO.value) -> None: """Initialize the PointingGame. Note: gt = ground truth. :param model: mlflow model :type model: Any :param device: optional parameter specifying the device to move the model to. If not specified, then cpu is the default :type device: str """ self._device = torch.device(_get_device(device)) self._model = model
[docs] def pointing_game(self, imagelocation: str, index: int, threshold: float = .8, num_masks: int = NUM_MASKS): """ Calculate the saliency scores for a given object detection prediction. The calculated value is a matrix of saliency scores. Values below the threshold are set to -1. The goal here is to filter out insignificant saliency scores, and identify highly salient pixels. That is why it is called a pointing game - we want to "point", i.e. identify, all highly salient pixels. That way we can easily determine if these highly salient pixels overlap with the gt bounding box. :param imagelocation: Path of the image location :type imagelocation: str :param index: Index of the desired object within the given image to evaluate :type index: int :param threshold: threshold between 0 and 1 to determine saliency of a pixel. If saliency score is below the threshold, then the score is set to -1 :type threshold: float :param num_masks: number of masks to run drise with :type num_masks: int :return: 2d matrix of highly salient pixels :rtype: List[Tensor] """ if not 0 <= threshold <= 1: raise ValueError('Threshold parameter must be a float \ between 0 and 1.') if num_masks < 0: raise ValueError('Number of masks parameter must be a \ positive int.') image_open_pointer = imagelocation if (imagelocation.startswith("http://") or imagelocation.startswith("https://")): response = requests.get(imagelocation) image_open_pointer = BytesIO(response.content) test_image = Image.open(image_open_pointer).convert('RGB') img_input = (T.ToTensor()(test_image) .unsqueeze(0).to(self._device)) detections = self._model.predict(img_input) saliency_scores = drise.DRISE_saliency( model=self._model, # Repeated the tensor to test batching image_tensor=img_input, target_detections=detections, # This is how many masks to run - # more is slower but gives higher quality mask. number_of_masks=num_masks, mask_padding=None, device=self._device, # This is the resolution of the random masks. # High resolutions will give finer masks, but more need to be # run. mask_res=(2, 2), verbose=True # Turns progress bar on/off. ) temp = saliency_scores[0][index]['detection'] temp[temp < threshold] = -1 return temp
[docs] def visualize_highly_salient_pixels(self, img, saliency_scores, gt_bbox: List = None): """ Create figure of highly salient pixels. :param img: PIL test image :type img: PIL.Image :param saliency_scores: 2D matrix representing the saliency scores of each pixel in an image :type saliency_scores: List[Tensor] :param gt_bbox: bounding box for ground truth prediction. if none then no ground truth bounding box is drawn :type gt_bbox: List :return: Overlay of the saliency scores on top of the image :rtype: Figure """ pl.clf() fig, ax = pl.subplots(1, 1, figsize=(10, 10)) viz.visualize_image_attr( np.transpose( saliency_scores.detach().cpu().numpy(), (1, 2, 0)), np.transpose(T.ToTensor()(img).detach().cpu().numpy(), (1, 2, 0)), method="blended_heat_map", sign="positive", show_colorbar=True, cmap=pl.cm.get_cmap("Blues"), title="Pointing Game Visualization", plt_fig_axis=(fig, ax), use_pyplot=False ) if gt_bbox is not None: x, y, width, height = gt_bbox rectangle = patches.Rectangle((x, y), width, height, linewidth=1, edgecolor='r', facecolor='none') # Add the rectangle patch to the axes ax.add_patch(rectangle) return fig
[docs] def calculate_gt_salient_pixel_overlap(self, saliency_scores: List[Tensor], gt_bbox: List): """ Calculate percent of overlap between salient pixels and gt bbox. Formula: number of salient pixels in the gt bbox / number of pixels in the gt bbox :param saliency_scores: 2D matrix representing the saliency scores of each pixel in an image :type saliency_scores: List[Tensor] :param gt_bbox: bounding box for ground truth prediction :type gt_bbox: List :return: return percent of salient pixel overlap with the ground truth :rtype: Float """ saliency_scores = torch.tensor(saliency_scores)[0] gt_bbox = torch.tensor(gt_bbox) gt_mask = torch.zeros_like(saliency_scores, dtype=torch.bool) gt_mask[gt_bbox[1]:gt_bbox[3]+1, gt_bbox[0]:gt_bbox[2]+1] = True positive_mask = torch.gt(saliency_scores, 0) positive_gt_mask = torch.logical_and(positive_mask, gt_mask) good = torch.count_nonzero(positive_gt_mask) total = torch.count_nonzero(gt_mask) return torch.div(good, total).item()