# Source code for mmdet.datasets.transforms.text_transformers
# Copyright (c) OpenMMLab. All rights reserved.
import json
from mmcv.transforms import BaseTransform
from mmdet.registry import TRANSFORMS
from mmdet.structures.bbox import BaseBoxes
try:
from transformers import AutoTokenizer
from transformers import BertModel as HFBertModel
except ImportError:
AutoTokenizer = None
HFBertModel = None
import random
import re
import numpy as np
def clean_name(name):
    """Normalize a category name for use in a caption.

    Removes any parenthesised suffix (e.g. ``'person(standing)'`` ->
    ``'person'``), turns underscores into spaces, collapses runs of
    spaces into a single space and lowercases the result.

    Args:
        name (str): Raw category name.

    Returns:
        str: Cleaned, lowercased category name.
    """
    name = re.sub(r'\(.*\)', '', name)
    name = re.sub(r'_', ' ', name)
    # Collapse repeated spaces left by the substitutions above.  (The
    # original line replaced a single space with a space — a no-op.)
    name = re.sub(r' +', ' ', name)
    name = name.lower()
    return name
def check_for_positive_overflow(gt_bboxes, gt_labels, text, tokenizer,
                                max_tokens):
    """Drop positive labels whose names would overflow the token budget.

    Builds the positive part of the caption label-by-label (in random
    order, so different labels survive truncation at different epochs)
    and discards boxes whose label did not fit.

    Args:
        gt_bboxes: Ground-truth boxes, indexable by a list of ints.
        gt_labels: Array of integer labels, aligned with ``gt_bboxes``.
        text (dict): Mapping from label id (as str) to category name.
        tokenizer: Tokenizer with a ``tokenize(str) -> list`` method.
        max_tokens (int): Token budget for the positive caption part.

    Returns:
        tuple: ``(kept_bboxes, kept_labels, caption_length)`` where
        ``caption_length`` is the token count of the kept labels only.
    """
    positive_label_list = np.unique(gt_labels).tolist()
    # Random shuffle so we can sample different annotations at
    # different epochs.
    random.shuffle(positive_label_list)

    kept_labels = []
    length = 0
    for label in positive_label_list:
        label_text = clean_name(text[str(label)]) + '. '
        tokenized = tokenizer.tokenize(label_text)
        # Only count tokens of labels that are actually kept.  The
        # original added the overflowing label's tokens to ``length``
        # before breaking, over-reporting the caption length to the
        # caller (which uses it to budget negative labels).
        if length + len(tokenized) > max_tokens:
            break
        length += len(tokenized)
        kept_labels.append(label)

    keep_box_index = [
        i for i in range(len(gt_labels)) if gt_labels[i] in kept_labels
    ]
    keep_gt_labels = [gt_labels[i] for i in keep_box_index]

    return gt_bboxes[keep_box_index], np.array(keep_gt_labels,
                                               dtype=int), length
def generate_senetence_given_labels(positive_label_list, negative_label_list,
                                    text):
    """Build a caption from positive and negative labels.

    Positive and negative labels are shuffled together and joined as
    ``'name. name. ...'``.  For each positive label the character span
    of its name inside the caption is recorded, and its original id is
    remapped to its position in the shuffled sequence.

    Args:
        positive_label_list (list): Labels present in the image.
        negative_label_list (list): Sampled labels not in the image.
        text (dict): Mapping from label id (as str) to category name.

    Returns:
        tuple: ``(label_to_positions, caption, label_remap_dict)``.
    """
    label_to_positions = {}
    label_remap_dict = {}

    shuffled_labels = negative_label_list + positive_label_list
    # Interleave positives and negatives in random order.
    random.shuffle(shuffled_labels)

    pheso_caption = ''
    for new_index, label in enumerate(shuffled_labels):
        span_start = len(pheso_caption)
        pheso_caption += clean_name(text[str(label)])
        span_end = len(pheso_caption)

        if label in positive_label_list:
            label_to_positions[new_index] = [[span_start, span_end]]
            label_remap_dict[int(label)] = new_index

        pheso_caption += '. '

    return label_to_positions, pheso_caption, label_remap_dict
@TRANSFORMS.register_module()
class RandomSamplingNegPos(BaseTransform):
    """Randomly sample negative classes and pack them, together with the
    positive classes, into a single detection caption.

    Used for caption-conditioned (grounding-style) detection training:
    object-detection annotations are converted to a text caption plus
    character-span positions for each positive category, and the label
    ids are remapped to their position inside the caption.

    Required Keys:

    - gt_bboxes
    - gt_bboxes_labels
    - text or phrases

    Modified Keys:

    - gt_bboxes
    - gt_bboxes_labels
    - text

    Added Keys:

    - tokens_positive

    Args:
        tokenizer_name (str): Name or path of the HuggingFace tokenizer
            used to measure caption length in tokens.
        num_sample_negative (int): Maximum number of negative classes
            sampled into the caption. Defaults to 85.
        max_tokens (int): Token budget for the generated caption.
            Defaults to 256.
        full_sampling_prob (float): Probability of sampling the full
            ``num_sample_negative`` negatives instead of a random
            number of them. Defaults to 0.5.
        label_map_file (str, optional): Path to a JSON file mapping
            label id (string) to class name; used when ``results``
            carries no ``'text'``. Defaults to None.
    """

    def __init__(self,
                 tokenizer_name,
                 num_sample_negative=85,
                 max_tokens=256,
                 full_sampling_prob=0.5,
                 label_map_file=None):
        if AutoTokenizer is None:
            raise RuntimeError(
                'transformers is not installed, please install it by: '
                'pip install transformers.')

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.num_sample_negative = num_sample_negative
        self.full_sampling_prob = full_sampling_prob
        self.max_tokens = max_tokens
        self.label_map = None
        if label_map_file:
            with open(label_map_file, 'r') as file:
                self.label_map = json.load(file)

    def transform(self, results: dict) -> dict:
        """Dispatch to the grounding or object-detection branch."""
        if 'phrases' in results:
            return self.vg_aug(results)
        else:
            return self.od_aug(results)

    def vg_aug(self, results):
        """Visual-grounding branch: reuse the annotated phrase spans."""
        gt_bboxes = results['gt_bboxes']
        if isinstance(gt_bboxes, BaseBoxes):
            gt_bboxes = gt_bboxes.tensor
        gt_labels = results['gt_bboxes_labels']
        text = results['text'].lower().strip()
        if not text.endswith('.'):
            text = text + '. '
        phrases = results['phrases']
        # TODO: add negative-phrase sampling, mirroring od_aug.
        positive_label_list = np.unique(gt_labels).tolist()
        label_to_positions = {}
        for label in positive_label_list:
            label_to_positions[label] = phrases[label]['tokens_positive']

        results['gt_bboxes'] = gt_bboxes
        results['gt_bboxes_labels'] = gt_labels
        results['text'] = text
        results['tokens_positive'] = label_to_positions
        return results

    def od_aug(self, results):
        """Object-detection branch: build a caption from the positive
        labels plus randomly sampled negatives."""
        gt_bboxes = results['gt_bboxes']
        if isinstance(gt_bboxes, BaseBoxes):
            gt_bboxes = gt_bboxes.tensor
        gt_labels = results['gt_bboxes_labels']

        if 'text' not in results:
            assert self.label_map is not None
            text = self.label_map
        else:
            text = results['text']

        original_box_num = len(gt_labels)
        # If the category name is in the format of 'a/b' (in object365),
        # we randomly select one of them.
        for key, value in text.items():
            if '/' in value:
                text[key] = random.choice(value.split('/')).strip()

        # Drop positives (and their boxes) that do not fit in the token
        # budget; ``positive_caption_length`` is what they consume.
        gt_bboxes, gt_labels, positive_caption_length = \
            check_for_positive_overflow(gt_bboxes, gt_labels,
                                        text, self.tokenizer,
                                        self.max_tokens)

        if len(gt_bboxes) < original_box_num:
            print('WARNING: removed {} boxes due to positive caption overflow'.
                  format(original_box_num - len(gt_bboxes)))

        valid_negative_indexes = list(text.keys())
        positive_label_list = np.unique(gt_labels).tolist()

        full_negative = self.num_sample_negative
        if full_negative > len(valid_negative_indexes):
            full_negative = len(valid_negative_indexes)

        outer_prob = random.random()
        if outer_prob < self.full_sampling_prob:
            # c. probability_full: add both all positive and all negatives
            num_negatives = full_negative
        else:
            # NOTE(review): ``random.random() < 1.0`` is always true, so
            # the ``else`` branch below is dead code; kept as-is to
            # preserve the RNG call sequence — confirm intent upstream.
            if random.random() < 1.0:
                num_negatives = np.random.choice(max(1, full_negative)) + 1
            else:
                num_negatives = full_negative

        # Keep some negatives
        negative_label_list = set()
        if num_negatives != -1:
            if num_negatives > len(valid_negative_indexes):
                num_negatives = len(valid_negative_indexes)
            for i in np.random.choice(
                    valid_negative_indexes, size=num_negatives,
                    replace=False):
                if int(i) not in positive_label_list:
                    negative_label_list.add(i)

        random.shuffle(positive_label_list)
        negative_label_list = list(negative_label_list)
        random.shuffle(negative_label_list)

        # Drop negatives that no longer fit in the remaining budget.
        negative_max_length = self.max_tokens - positive_caption_length
        screened_negative_label_list = []
        for negative_label in negative_label_list:
            label_text = clean_name(text[str(negative_label)]) + '. '
            tokenized = self.tokenizer.tokenize(label_text)
            negative_max_length -= len(tokenized)
            if negative_max_length > 0:
                screened_negative_label_list.append(negative_label)
            else:
                break
        negative_label_list = screened_negative_label_list

        label_to_positions, pheso_caption, label_remap_dict = \
            generate_senetence_given_labels(positive_label_list,
                                            negative_label_list, text)

        # Remap original label ids to their index in the caption.
        if len(gt_labels) > 0:
            gt_labels = np.vectorize(lambda x: label_remap_dict[x])(gt_labels)

        results['gt_bboxes'] = gt_bboxes
        results['gt_bboxes_labels'] = gt_labels
        results['text'] = pheso_caption
        results['tokens_positive'] = label_to_positions
        return results
@TRANSFORMS.register_module()
class LoadTextAnnotations(BaseTransform):
    """Load text annotations into the fields downstream transforms expect.

    For grounding data (``'phrases'`` present) it collects each
    phrase's ``tokens_positive`` spans into ``results['tokens_positive']``.
    For detection data it flattens the ``{label_id: name}`` mapping in
    ``results['text']`` into a plain list of class names.
    """

    def transform(self, results: dict) -> dict:
        """Transform ``results`` in place and return it."""
        if 'phrases' in results:
            # Grounding: one tokens_positive entry per phrase, in the
            # dict's iteration order.
            tokens_positive = [
                phrase['tokens_positive']
                for phrase in results['phrases'].values()
            ]
            results['tokens_positive'] = tokens_positive
        else:
            # Detection: keep only the class names.
            text = results['text']
            results['text'] = list(text.values())
        return results