Source code for mmdet.evaluation.metrics.refexp_metric
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence

import numpy as np
from mmengine.evaluator import BaseMetric
from mmengine.fileio import get_local_path
from mmengine.logging import MMLogger

from mmdet.datasets.api_wrappers import COCO
from mmdet.registry import METRICS
from ..functional import bbox_overlaps

@METRICS.register_module()
class RefExpMetric(BaseMetric):
    """Referring expression comprehension metric.

    Computes Precision@k (for each k in ``topk``) on the RefCOCO, RefCOCO+
    and RefCOCOg splits. The annotation file is expected to be in COCO
    format with exactly one ground-truth box per image and a per-image
    ``dataset_name`` field naming the split.
    """

    default_prefix: Optional[str] = 'refexp'

    def __init__(self,
                 ann_file: Optional[str] = None,
                 metric: str = 'bbox',
                 topk=(1, 5, 10),
                 iou_thrs: float = 0.5,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.metric = metric
        self.topk = topk
        self.iou_thrs = iou_thrs
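
        # Load the COCO-format referring-expression annotations up front.
        # `get_local_path` yields a local copy when the file lives on a
        # remote storage backend, so `COCO` can always read it from disk.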
        with get_local_path(ann_file) as local_path:
            self.coco = COCO(local_path)

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of predictions.

        Only the image id and the predicted boxes/scores (moved to CPU) are
        kept in ``self.results`` for later use in ``compute_metrics``.
        """
        for data_sample in data_samples:
            result = dict()
            pred = data_sample['pred_instances']
            result['img_id'] = data_sample['img_id']
            result['bboxes'] = pred['bboxes'].cpu().numpy()
            result['scores'] = pred['scores'].cpu().numpy()
            self.results.append(result)

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute Precision@k for each RefCOCO split from the processed
        results."""
        logger: MMLogger = MMLogger.get_current_instance()
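
        # One hit counter per (split, k) pair plus a per-split sample count.
        # The image-level `dataset_name` field in the annotation file decides
        # which split each sample is attributed to.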
        dataset2score = {
            'refcoco': {k: 0.0 for k in self.topk},
            'refcoco+': {k: 0.0 for k in self.topk},
            'refcocog': {k: 0.0 for k in self.topk},
        }
        dataset2count = {'refcoco': 0.0, 'refcoco+': 0.0, 'refcocog': 0.0}
        for result in results:
            img_id = result['img_id']
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            assert len(ann_ids) == 1
            img_info = self.coco.loadImgs(img_id)[0]
            target = self.coco.loadAnns(ann_ids[0])
            target_bbox = target[0]['bbox']
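
            # COCO stores boxes as [x, y, w, h]; convert to [x1, y1, x2, y2]
            # so the box can be fed to `bbox_overlaps` directly.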
            converted_bbox = [
                target_bbox[0],
                target_bbox[1],
                target_bbox[2] + target_bbox[0],
                target_bbox[3] + target_bbox[1],
            ]
            iou = bbox_overlaps(result['bboxes'],
                                np.array(converted_bbox).reshape(-1, 4))
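
            # A sample scores a hit at rank k when any of its top-k predicted
            # boxes reaches the IoU threshold; this relies on the predictions
            # being ordered by descending score.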
            for k in self.topk:
                if max(iou[:k]) >= self.iou_thrs:
                    dataset2score[img_info['dataset_name']][k] += 1.0
            dataset2count[img_info['dataset_name']] += 1.0
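
        # Turn hit counts into precisions. A split that never occurs in the
        # annotation file keeps a zero count, and the resulting division
        # error is only reported rather than raised.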
        for key, value in dataset2score.items():
            for k in self.topk:
                try:
                    value[k] /= dataset2count[key]
                except Exception as e:
                    print(e)
        results = {}
        mean_precision = 0.0
        for key, value in dataset2score.items():
            # Precision is non-decreasing in k, so sorting the per-k values
            # yields them in Precision@1, @5, @10 order.
            results[key] = sorted([v for k, v in value.items()])
            mean_precision += sum(results[key])
            logger.info(
                f' Dataset: {key} - Precision @ 1, 5, 10: {results[key]}')

        # `mean_precision` key is used for saving the best checkpoint;
        # it averages over 3 splits x 3 values of k = 9 numbers.
        out_results = {'mean_precision': mean_precision / 9.0}
        for i, k in enumerate(self.topk):
            out_results[f'refcoco_precision@{k}'] = results['refcoco'][i]
        for i, k in enumerate(self.topk):
            out_results[f'refcoco+_precision@{k}'] = results['refcoco+'][i]
        for i, k in enumerate(self.topk):
            out_results[f'refcocog_precision@{k}'] = results['refcocog'][i]
        return out_results
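

# A minimal usage sketch (not part of this module): in an MMDetection config
# the metric is typically plugged in as the validation/test evaluator. The
# annotation path below is a placeholder, not a path shipped with this file.
#
#     val_evaluator = dict(
#         type='RefExpMetric',
#         ann_file='path/to/refexp_val.json',  # placeholder, COCO format
#         metric='bbox',
#         iou_thrs=0.5,
#         topk=(1, 5, 10))
#     test_evaluator = val_evaluator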