# Source code for mmdet.models.detectors.trident_faster_rcnn
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from .faster_rcnn import FasterRCNN
@MODELS.register_module()
class TridentFasterRCNN(FasterRCNN):
    """Implementation of `TridentNet <https://arxiv.org/abs/1901.01892>`_.

    A ``FasterRCNN`` whose backbone and RoI head both expose ``num_branch``
    and ``test_branch_idx``. Before delegating to the parent detector, the
    list of data samples is repeated once per active branch: ``num_branch``
    times in training mode or when ``test_branch_idx == -1``, and once
    otherwise.
    """

    def __init__(self,
                 backbone: ConfigType,
                 rpn_head: ConfigType,
                 roi_head: ConfigType,
                 train_cfg: ConfigType,
                 test_cfg: ConfigType,
                 neck: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        """Build the detector and check backbone/RoI head branch settings.

        All arguments are forwarded unchanged to ``FasterRCNN.__init__``.

        Raises:
            AssertionError: If the backbone and the RoI head disagree on
                ``num_branch`` or on ``test_branch_idx``.
        """
        super().__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
        # The backbone and the RoI head must be configured consistently,
        # since the values cached below are taken from the backbone only.
        assert self.backbone.num_branch == self.roi_head.num_branch, (
            'backbone and roi_head must share num_branch, got '
            f'{self.backbone.num_branch} and {self.roi_head.num_branch}')
        assert (self.backbone.test_branch_idx ==
                self.roi_head.test_branch_idx), (
                    'backbone and roi_head must share test_branch_idx, got '
                    f'{self.backbone.test_branch_idx} and '
                    f'{self.roi_head.test_branch_idx}')
        self.num_branch = self.backbone.num_branch
        self.test_branch_idx = self.backbone.test_branch_idx

    def _copy_samples_per_branch(
            self, batch_data_samples: SampleList) -> SampleList:
        """Repeat ``batch_data_samples`` once per active branch.

        Args:
            batch_data_samples (SampleList): Data samples of one batch.

        Returns:
            SampleList: The input list repeated ``self.num_branch`` times
            when the module is in training mode or ``test_branch_idx`` is
            ``-1``; a single repetition otherwise. The list is built with
            ``list * int``, so the entries are the same sample objects, not
            copies.
        """
        use_all_branches = self.training or self.test_branch_idx == -1
        num_branch = self.num_branch if use_all_branches else 1
        return batch_data_samples * num_branch

    def _forward(self, batch_inputs: Tensor,
                 batch_data_samples: SampleList) -> tuple:
        """Run the parent ``_forward`` with per-branch data samples.

        Args:
            batch_inputs (Tensor): Batched inputs, passed through unchanged.
            batch_data_samples (SampleList): Data samples of the batch;
                repeated per active branch before the parent call.

        Returns:
            tuple: Whatever ``FasterRCNN._forward`` returns.
        """
        return super()._forward(
            batch_inputs=batch_inputs,
            batch_data_samples=self._copy_samples_per_branch(
                batch_data_samples))

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> dict:
        """Compute losses via the parent with per-branch data samples.

        Args:
            batch_inputs (Tensor): Batched inputs, passed through unchanged.
            batch_data_samples (SampleList): Data samples of the batch;
                repeated per active branch before the parent call.

        Returns:
            dict: Whatever ``FasterRCNN.loss`` returns.
        """
        return super().loss(
            batch_inputs=batch_inputs,
            batch_data_samples=self._copy_samples_per_branch(
                batch_data_samples))

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict via the parent with per-branch data samples.

        Args:
            batch_inputs (Tensor): Batched inputs, passed through unchanged.
            batch_data_samples (SampleList): Data samples of the batch;
                repeated per active branch before the parent call.
            rescale (bool): Forwarded to ``FasterRCNN.predict``.
                Defaults to True.

        Returns:
            SampleList: Whatever ``FasterRCNN.predict`` returns.
        """
        return super().predict(
            batch_inputs=batch_inputs,
            batch_data_samples=self._copy_samples_per_branch(
                batch_data_samples),
            rescale=rescale)

    # TODO need to refactor
    def aug_test(self, imgs, img_metas, rescale=False):
        """Test with augmentations.

        If rescale is False, then returned bboxes and masks will fit the scale
        of imgs[0].
        """
        x = self.extract_feats(imgs)
        # Unlike the other entry points, the branch count here depends only
        # on ``test_branch_idx`` and not on ``self.training``.
        num_branch = self.num_branch if self.test_branch_idx == -1 else 1
        # Repeat each inner meta list once per active branch.
        trident_img_metas = [metas * num_branch for metas in img_metas]
        proposal_list = self.rpn_head.aug_test_rpn(x, trident_img_metas)
        # NOTE(review): the RoI head receives the original ``img_metas``, not
        # ``trident_img_metas``; presumably it handles the branch replication
        # itself -- confirm against the RoI head implementation.
        return self.roi_head.aug_test(
            x, proposal_list, img_metas, rescale=rescale)