Shortcuts

Source code for mmdet.datasets.transforms.augment_wrappers

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import numpy as np
from mmcv.transforms import RandomChoice
from mmcv.transforms.utils import cache_randomness
from mmengine.config import ConfigDict

from mmdet.registry import TRANSFORMS

# AutoAugment uses reinforcement learning to search for
# some widely useful data augmentation strategies,
# here we provide AUTOAUG_POLICIES_V0.
# For AUTOAUG_POLICIES_V0, each tuple is an augmentation
# operation of the form (operation, probability, magnitude).
# Each element in policies is a policy that will be applied
# sequentially on the image.

# RandAugment defines a data augmentation search space, RANDAUG_SPACE,
# sampling 1~3 data augmentations each time, and
# setting the magnitude of each data augmentation randomly,
# which will be applied sequentially on the image.

# Upper end of the integer ``level`` scale: ``level_to_mag`` divides a level
# by this value before scaling it into the ``[min_mag, max_mag]`` range.
_MAX_LEVEL = 10

# AutoAugment "v0" policy table. Every inner list is one policy; every tuple
# in a policy is ``(operation name, probability, level)``. ``policies_v0()``
# turns each tuple into a ``dict(type=..., prob=..., level=...)`` config.
AUTOAUG_POLICIES_V0 = [
    [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
    [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
    [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
    [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
    [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
    [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
    [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
    [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
    [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
    [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
    [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
    [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
    [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
    [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
    [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
    [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
    [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
    [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
    [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
    [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
    [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
    [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
    [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
    [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
    [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]


def policies_v0():
    """Build the AutoAugment v0 policies as transform config dicts.

    Returns:
        list[list[dict]]: One inner list per policy in
        ``AUTOAUG_POLICIES_V0``. Each ``(operation, probability, level)``
        tuple of a policy becomes ``dict(type=..., prob=..., level=...)``.
        Fresh lists and dicts are created on every call.
    """
    return [[
        dict(type=op_args[0], prob=op_args[1], level=op_args[2])
        for op_args in policy_args
    ] for policy_args in AUTOAUG_POLICIES_V0]


# Default search space of RandAugment: every entry is a single-element list
# holding one transform config that only names the transform ``type``.
RANDAUG_SPACE = [[dict(type=op_name)] for op_name in (
    'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', 'Solarize',
    'SolarizeAdd', 'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX',
    'ShearY', 'TranslateX', 'TranslateY')]


def level_to_mag(level: Optional[int], min_mag: float,
                 max_mag: float) -> float:
    """Convert an augmentation level into a magnitude.

    Args:
        level (int, optional): Level to convert; it is divided by
            ``_MAX_LEVEL`` to obtain the interpolation ratio. If None, the
            ratio is drawn with ``np.random.rand()`` instead.
        min_mag (float): Magnitude returned for a ratio of 0.
        max_mag (float): Magnitude returned for a ratio of 1.

    Returns:
        float: ``ratio * (max_mag - min_mag) + min_mag`` rounded to one
        decimal place.
    """
    # The random draw only happens when no explicit level is given.
    ratio = np.random.rand() if level is None else level / _MAX_LEVEL
    return round(ratio * (max_mag - min_mag) + min_mag, 1)


@TRANSFORMS.register_module()
class AutoAugment(RandomChoice):
    """Auto augmentation.

    This data augmentation is proposed in `AutoAugment: Learning
    Augmentation Policies from Data <https://arxiv.org/abs/1805.09501>`_
    and in `Learning Data Augmentation Strategies for Object Detection
    <https://arxiv.org/pdf/1906.11172>`_.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_ignore_flags (bool) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes
    - gt_bboxes_labels
    - gt_masks
    - gt_ignore_flags
    - gt_seg_map

    Added Keys:

    - homography_matrix

    Args:
        policies (List[List[Union[dict, ConfigDict]]]):
            The policies of auto augmentation. Each policy in ``policies``
            is a specific augmentation policy, and is composed by several
            augmentations. When AutoAugment is called, a random policy in
            ``policies`` will be selected to augment images.
            Defaults to policies_v0().
        prob (list[float], optional): The probabilities associated
            with each policy. The length should be equal to the policy
            number and the sum should be 1. If not given, a uniform
            distribution will be assumed. Defaults to None.

    Examples:
        >>> policies = [
        >>>     [
        >>>         dict(type='Sharpness', prob=0.0, level=8),
        >>>         dict(type='ShearX', prob=0.4, level=0,)
        >>>     ],
        >>>     [
        >>>         dict(type='Rotate', prob=0.6, level=10),
        >>>         dict(type='Color', prob=1.0, level=6)
        >>>     ]
        >>> ]
        >>> augmentation = AutoAugment(policies)
        >>> img = np.ones((100, 100, 3))
        >>> gt_bboxes = np.ones((10, 4))
        >>> results = dict(img=img, gt_bboxes=gt_bboxes)
        >>> results = augmentation(results)
    """

    # NOTE: the default ``policies`` value is evaluated once, when the class
    # body is executed, so every instance built without an explicit
    # ``policies`` argument shares the same list object. Treat it as
    # read-only.
    def __init__(self,
                 policies: List[List[Union[dict, ConfigDict]]] = policies_v0(),
                 prob: Optional[List[float]] = None) -> None:
        # Validate the nested structure before handing it to the parent
        # class: a non-empty list of non-empty lists of configs with "type".
        assert isinstance(policies, list) and len(policies) > 0, \
            'Policies must be a non-empty list.'
        for policy in policies:
            assert isinstance(policy, list) and len(policy) > 0, \
                'Each policy in policies must be a non-empty list.'
            for augment in policy:
                assert isinstance(augment, dict) and 'type' in augment, \
                    'Each specific augmentation must be a dict with key' \
                    ' "type".'
        super().__init__(transforms=policies, prob=prob)
        # Keep the raw configs so that ``__repr__`` can print them.
        self.policies = policies

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(policies={self.policies}, ' \
               f'prob={self.prob})'
@TRANSFORMS.register_module()
class RandAugment(RandomChoice):
    """Rand augmentation.

    This data augmentation is proposed in `RandAugment: Practical automated
    data augmentation with a reduced search space
    <https://arxiv.org/abs/1909.13719>`_.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_ignore_flags (bool) (optional)
    - gt_seg_map (np.uint8) (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes
    - gt_bboxes_labels
    - gt_masks
    - gt_ignore_flags
    - gt_seg_map

    Added Keys:

    - homography_matrix

    Args:
        aug_space (List[List[Union[dict, ConfigDict]]]): The augmentation
            space of rand augmentation. Each entry of ``aug_space`` must be
            a list holding exactly one transform config (a dict with key
            ``type``). When RandAugment is called, ``aug_num`` distinct
            entries of ``aug_space`` are sampled and applied to the results
            one after another. Defaults to RANDAUG_SPACE.
        aug_num (int): Number of augmentation to apply sequentially.
            Defaults to 2.
        prob (list[float], optional): The probabilities associated with
            each augmentation. The length should be equal to the
            augmentation space and the sum should be 1. If not given,
            a uniform distribution will be assumed. Defaults to None.

    Examples:
        >>> aug_space = [
        >>>     [dict(type='Sharpness')],
        >>>     [dict(type='ShearX')],
        >>>     [dict(type='Color')],
        >>> ]
        >>> augmentation = RandAugment(aug_space)
        >>> img = np.ones((100, 100, 3))
        >>> gt_bboxes = np.ones((10, 4))
        >>> results = dict(img=img, gt_bboxes=gt_bboxes)
        >>> results = augmentation(results)
    """

    # NOTE: the default ``aug_space`` is the module-level ``RANDAUG_SPACE``
    # object itself (not a copy), so it is shared by every instance that
    # relies on the default. Treat it as read-only.
    def __init__(
            self,
            aug_space: List[List[Union[dict, ConfigDict]]] = RANDAUG_SPACE,
            aug_num: int = 2,
            prob: Optional[List[float]] = None) -> None:
        # Validate the nested structure before handing it to the parent
        # class: a non-empty list of single-element lists of configs.
        assert isinstance(aug_space, list) and len(aug_space) > 0, \
            'Augmentation space must be a non-empty list.'
        for aug in aug_space:
            assert isinstance(aug, list) and len(aug) == 1, \
                'Each augmentation in aug_space must be a list.'
            for transform in aug:
                assert isinstance(transform, dict) and 'type' in transform, \
                    'Each specific transform must be a dict with key' \
                    ' "type".'
        super().__init__(transforms=aug_space, prob=prob)
        # Keep the raw configs so that ``__repr__`` can print them.
        self.aug_space = aug_space
        self.aug_num = aug_num

    @cache_randomness
    def random_pipeline_index(self) -> np.ndarray:
        """Sample which entries of ``self.transforms`` to apply.

        Returns:
            np.ndarray: ``self.aug_num`` indices into ``self.transforms``,
            drawn without replacement and weighted by ``self.prob``.
        """
        indices = np.arange(len(self.transforms))
        # ``replace=False`` guarantees that no transform is picked twice.
        return np.random.choice(
            indices, self.aug_num, p=self.prob, replace=False)

    def transform(self, results: dict) -> dict:
        """Transform function to use RandAugment.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with RandAugment.
        """
        # Apply the sampled transforms one after another, in sampled order.
        for idx in self.random_pipeline_index():
            results = self.transforms[idx](results)
        return results

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(' \
               f'aug_space={self.aug_space}, '\
               f'aug_num={self.aug_num}, ' \
               f'prob={self.prob})'