Shortcuts

Source code for mmdet.datasets.odvg

# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from typing import List, Optional

from mmengine.fileio import get_local_path

from mmdet.registry import DATASETS
from .base_det_dataset import BaseDetDataset


@DATASETS.register_module()
class ODVGDataset(BaseDetDataset):
    """Object detection and visual grounding (ODVG) dataset.

    The annotation file is read as JSON Lines: every non-blank line is one
    JSON record describing an image through ``filename``, ``height`` and
    ``width``. The dataset runs in one of two modes:

    - ``'OD'`` (object detection), selected when ``label_map_file`` is given.
      A record may carry ``detection.instances``, a list of objects with a
      ``bbox`` (``[x1, y1, x2, y2]``) and an integer-like ``label``.
    - ``'VG'`` (visual grounding), used otherwise. A record must carry
      ``grounding`` with a ``caption`` and a list of ``regions``; each region
      has a ``bbox`` (one box or a list of boxes), a ``phrase`` and
      ``tokens_positive``. The region's index becomes the ``bbox_label`` of
      its boxes.

    In both modes, a box is dropped when its overlap with the image has zero
    area or when its width or height is below 1.

    Args:
        data_root (str): Root directory. ``label_map_file`` is joined onto it
            and it is forwarded to the base dataset. Defaults to ''.
        label_map_file (str, optional): Path, relative to ``data_root``, of a
            JSON label map. Providing it switches the dataset to OD mode.
            Defaults to None.
        need_text (bool): Used in OD mode only. If True, the loaded label map
            is attached to every sample as ``text``. Defaults to True.
    """

    def __init__(self,
                 *args,
                 data_root: str = '',
                 label_map_file: Optional[str] = None,
                 need_text: bool = True,
                 **kwargs) -> None:
        # Set the mode-related attributes before calling the base
        # constructor: it may load annotations immediately (presumably via
        # mmengine's ``full_init`` unless ``lazy_init=True``), and
        # ``load_data_list`` reads these attributes.
        self.dataset_mode = 'VG'
        self.need_text = need_text
        if label_map_file:
            label_map_file = osp.join(data_root, label_map_file)
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(label_map_file, 'r', encoding='utf-8') as file:
                self.label_map = json.load(file)
            self.dataset_mode = 'OD'
        super().__init__(*args, data_root=data_root, **kwargs)
        assert self.return_classes is True, \
            'ODVGDataset requires return_classes to be True'

    def load_data_list(self) -> List[dict]:
        """Load the JSON Lines annotation file into a list of data infos.

        Returns:
            List[dict]: One dict per record with ``img_path`` (``filename``
            joined onto ``data_prefix['img']``), ``height``, ``width``,
            ``instances`` and ``dataset_mode``; plus ``text`` (always in VG
            mode, in OD mode only when ``need_text``) and ``phrases`` (VG
            mode only). Each instance has ``ignore_flag`` (always 0),
            ``bbox`` and ``bbox_label``.
        """
        with get_local_path(
                self.ann_file, backend_args=self.backend_args) as local_path:
            # Skip blank lines (e.g. a trailing empty line) rather than
            # letting json.loads fail on them.
            with open(local_path, 'r', encoding='utf-8') as f:
                data_list = [json.loads(line) for line in f if line.strip()]

        out_data_list = []
        for data in data_list:
            data_info = {
                'img_path': osp.join(self.data_prefix['img'],
                                     data['filename']),
                'height': data['height'],
                'width': data['width'],
            }
            if self.dataset_mode == 'OD':
                self._parse_detection(data, data_info)
            else:
                self._parse_grounding(data, data_info)
            data_info['dataset_mode'] = self.dataset_mode
            out_data_list.append(data_info)
        return out_data_list

    @staticmethod
    def _is_valid_bbox(bbox: List[float], width: float,
                       height: float) -> bool:
        """Return whether ``bbox`` should be kept.

        A box ``[x1, y1, x2, y2]`` is kept only if its overlap with the
        rectangle ``[0, width] x [0, height]`` has non-zero area and both of
        its sides are at least 1 long.
        """
        x1, y1, x2, y2 = bbox
        inter_w = max(0, min(x2, width) - max(x1, 0))
        inter_h = max(0, min(y2, height) - max(y1, 0))
        if inter_w * inter_h == 0:
            return False
        return not ((x2 - x1) < 1 or (y2 - y1) < 1)

    def _parse_detection(self, data: dict, data_info: dict) -> None:
        """Add the OD-mode ``text`` and ``instances`` to ``data_info``."""
        if self.need_text:
            data_info['text'] = self.label_map
        anno = data.get('detection', {})
        instances = []
        for obj in anno.get('instances', []):
            bbox = obj['bbox']
            if not self._is_valid_bbox(bbox, data['width'], data['height']):
                continue
            instances.append({
                'ignore_flag': 0,
                'bbox': bbox,
                # Routing through str() accepts both int labels and integer
                # strings, while non-integral values (e.g. 3.5) raise
                # instead of being silently truncated.
                'bbox_label': int(str(obj['label'])),
            })
        data_info['instances'] = instances

    def _parse_grounding(self, data: dict, data_info: dict) -> None:
        """Add the VG-mode fields to ``data_info``.

        Sets ``text`` (the caption), ``instances`` and ``phrases``. A region
        none of whose boxes survive filtering gets no ``phrases`` entry, so
        the keys of ``phrases`` are exactly the labels used by ``instances``.
        """
        anno = data['grounding']
        data_info['text'] = anno['caption']
        instances = []
        phrases = {}
        for label, region in enumerate(anno['regions']):
            boxes = region['bbox']
            phrase = region['phrase']
            tokens_positive = region['tokens_positive']
            # A region holds either a single box or a list of boxes;
            # normalize to a list. An empty list means "no boxes" instead of
            # failing on boxes[0].
            if boxes and not isinstance(boxes[0], list):
                boxes = [boxes]
            kept = [
                box for box in boxes
                if self._is_valid_bbox(box, data['width'], data['height'])
            ]
            if not kept:
                continue
            instances.extend({
                'ignore_flag': 0,
                'bbox': box,
                'bbox_label': label
            } for box in kept)
            phrases[label] = {
                'phrase': phrase,
                'tokens_positive': tokens_positive
            }
        data_info['instances'] = instances
        data_info['phrases'] = phrases