Shortcuts

Source code for mmdet.evaluation.metrics.grefcoco_metric

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence

import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.fileio import get_local_path
from mmengine.logging import MMLogger

from mmdet.datasets.api_wrappers import COCO
from mmdet.registry import METRICS
from ..functional import bbox_overlaps


# Adapted from https://github.com/henghuiding/gRefCOCO/blob/main/mdetr/datasets/refexp.py # noqa
@METRICS.register_module()
class gRefCOCOMetric(BaseMetric):
    """Box-level evaluation metric for gRefCOCO-style predictions.

    For every processed sample the predicted boxes are filtered by score and
    compared against the ground-truth boxes loaded from ``ann_file``. Three
    numbers are reported:

    - ``F1_score``: fraction of samples whose per-sample F1 reaches
      ``thresh_f1``.
    - ``T_acc``: among samples that have a target, the fraction for which at
      least one box survives the score filter.
    - ``N_acc``: among no-target samples (any annotation with
      ``category_id == -1``), the fraction for which no box survives the
      score filter.

    Args:
        ann_file (str, optional): Path to the annotation file. It is passed
            straight to ``get_local_path`` and loaded with ``COCO``, so a
            usable path is expected even though the default is ``None``.
        metric (str): Stored on the instance as ``self.metric``; it is not
            read anywhere else in this class. Defaults to 'bbox'.
        iou_thrs (float): Minimum IoU for a predicted box to be matched to a
            ground-truth box. Defaults to 0.5.
        thresh_score (float): Predictions with a score below this value are
            discarded before matching. Defaults to 0.7.
        thresh_f1 (float): A sample counts as correct when its F1 is greater
            than or equal to this value. Defaults to 1.0.
        **kwargs: Forwarded unchanged to ``BaseMetric.__init__``.
    """

    default_prefix: Optional[str] = 'grefcoco'

    def __init__(self,
                 ann_file: Optional[str] = None,
                 metric: str = 'bbox',
                 iou_thrs: float = 0.5,
                 thresh_score: float = 0.7,
                 thresh_f1: float = 1.0,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.metric = metric
        self.iou_thrs = iou_thrs
        self.thresh_score = thresh_score
        self.thresh_f1 = thresh_f1

        with get_local_path(ann_file) as local_path:
            self.coco = COCO(local_path)

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Collect the predictions of one batch into ``self.results``.

        Args:
            data_batch (dict): The input batch. Not used by this metric.
            data_samples (Sequence[dict]): Model outputs. Each entry must
                provide ``img_id`` and ``pred_instances`` with ``bboxes`` and
                ``scores``; both are moved to CPU before being stored.
        """
        for data_sample in data_samples:
            pred = data_sample['pred_instances']
            self.results.append({
                'img_id': data_sample['img_id'],
                'bboxes': pred['bboxes'].cpu(),
                'scores': pred['scores'].cpu(),
            })

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the gRefCOCO metrics from the collected results.

        Args:
            results (list): The per-sample dicts stored by :meth:`process`.

        Returns:
            Dict[str, float]: ``F1_score``, ``T_acc`` and ``N_acc``. A ratio
            whose denominator is zero (for example ``N_acc`` when no
            no-target sample was evaluated) is reported as ``0.0``.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        correct_image = 0
        num_image = 0
        # Counters keep the naming of the reference implementation, where the
        # "positive" class is *no target*:
        #   TP / FN -> no-target sample without / with surviving predictions
        #   TN / FP -> targeted sample with / without surviving predictions
        nt = {'TP': 0, 'TN': 0, 'FP': 0, 'FN': 0}

        for result in results:
            img_id = result['img_id']

            # NOTE(review): only the first annotation id is loaded, as in the
            # reference implementation — confirm this matches the annotation
            # layout produced for this dataset.
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            target = self.coco.loadAnns(ann_ids[0])

            converted_bbox_all = []
            no_target_flag = False
            for one_target in target:
                if one_target['category_id'] == -1:
                    no_target_flag = True
                target_bbox = one_target['bbox']
                # (x, y, w, h) -> (x1, y1, x2, y2)
                converted_bbox = [
                    target_bbox[0],
                    target_bbox[1],
                    target_bbox[2] + target_bbox[0],
                    target_bbox[3] + target_bbox[1],
                ]
                converted_bbox_all.append(
                    np.array(converted_bbox).reshape(-1, 4))
            gt_bbox_all = np.concatenate(converted_bbox_all, axis=0)

            # Keep only sufficiently confident predictions.
            keep = result['scores'] >= self.thresh_score
            filtered_boxes = result['bboxes'][keep]

            iou = torch.from_numpy(
                bbox_overlaps(filtered_boxes.numpy(), gt_bbox_all))

            num_prediction = filtered_boxes.shape[0]
            num_gt = gt_bbox_all.shape[0]

            if no_target_flag:
                # A no-target sample is right only if nothing is predicted.
                if num_prediction >= 1:
                    nt['FN'] += 1
                    f_1 = 0.
                else:
                    nt['TP'] += 1
                    f_1 = 1.0
            else:
                if num_prediction >= 1:
                    nt['TN'] += 1
                else:
                    nt['FP'] += 1

                # Greedy one-to-one matching: repeatedly take the highest
                # remaining IoU, then zero its row and column so neither the
                # prediction nor the ground-truth box can be matched again.
                TP = 0
                for _ in range(min(num_prediction, num_gt)):
                    top_value, top_index = torch.topk(iou.flatten(0, 1), 1)
                    if top_value < self.iou_thrs:
                        break
                    # Decode the flat index into (prediction, gt) positions.
                    pred_idx, gt_idx = divmod(int(top_index[0]), num_gt)
                    TP += 1
                    iou[pred_idx, :] = 0.0
                    iou[:, gt_idx] = 0.0

                FP = num_prediction - TP
                FN = num_gt - TP
                f_1 = 2 * TP / (2 * TP + FP + FN)

            if f_1 >= self.thresh_f1:
                correct_image += 1
            num_image += 1

        # Guard every denominator the same way: a split containing only
        # targeted (or only no-target) samples would otherwise raise
        # ZeroDivisionError here.
        metrics = {
            'F1_score': correct_image / max(num_image, 1),
            'T_acc': nt['TN'] / max(nt['TN'] + nt['FP'], 1),
            'N_acc': nt['TP'] / max(nt['TP'] + nt['FN'], 1)
        }
        logger.info(metrics)
        return metrics