Shortcuts

Source code for mmdet.engine.hooks.train_augment_det_visualization_hook

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Optional

from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmengine.visualization import Visualizer

from mmdet.models.utils import mask2ndarray
from mmdet.registry import HOOKS
from mmdet.structures.bbox import BaseBoxes


[docs]@HOOKS.register_module() class TrainAugmentDetVisualizationHook(Hook): """Detection Visualization Hook. Used to visualize train augmentation. In the training phase: 1. If ``show`` is True, it means that only the prediction results are visualized without storing data, so ``vis_backends`` needs to be excluded. 2. ``vis_backends`` takes effect if the user does not specify ``show``. You can set ``vis_backends`` to WandbVisBackend or TensorboardVisBackend to store the prediction result in Wandb or Tensorboard. Args: draw (bool): whether to draw prediction results. If it is False, it means that no drawing will be done. Defaults to False. interval (int): The interval of visualization. Defaults to 50. show (bool): Whether to display the drawn image. Default to False. wait_time (float): The interval of show (s). Defaults to 0. backend_args (dict, optional): Arguments to instantiate the corresponding backend. Defaults to None. """ def __init__( # noqa: PLR0913 self, draw: bool = False, interval: int = 50, show: bool = False, wait_time: float = 0.0, backend_args: Optional[dict] = None, ): self._visualizer: Visualizer = Visualizer.get_current_instance() self.interval = interval self.show = show if self.show: # No need to think about vis backends. self._visualizer._vis_backends = {} warnings.warn( 'The show is True, it means that only ' 'the prediction results are visualized ' 'without storing data, so vis_backends ' 'needs to be excluded.', stacklevel=1, ) self.wait_time = wait_time self.backend_args = backend_args self.draw = draw
[docs] def after_train_iter(self, runner: Runner, batch_idx: int, data_batch: Optional[dict] = None, outputs: Optional[dict] = None) -> None: """Regularly check whether the loss is valid every n iterations. Args: runner (:obj:`Runner`): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. data_batch (dict, Optional): Data from dataloader. Defaults to None. outputs (dict, Optional): Outputs from model. Defaults to None. """ if self.draw is False or data_batch is None: return if self.every_n_train_iters(runner, self.interval): # There is no guarantee that the same batch of images # is visualized for each evaluation. total_curr_iter = runner.iter + batch_idx # Visualize only the first data data_sample = data_batch['data_samples'][0] img_path = data_sample.img_path img = data_batch['inputs'][0].permute(1, 2, 0).numpy() gt_instances = data_sample.gt_instances gt_bboxes = gt_instances.get('bboxes', None) if gt_bboxes is not None and isinstance(gt_bboxes, BaseBoxes): gt_instances.bboxes = gt_bboxes.tensor gt_masks = gt_instances.get('masks', None) if gt_masks is not None: masks = mask2ndarray(gt_masks) gt_instances.masks = masks.astype(bool) data_sample.gt_instances = gt_instances self._visualizer.add_datasample( osp.basename(img_path) if self.show else 'train_img', img, data_sample=data_sample, show=self.show, draw_pred=False, wait_time=self.wait_time, step=total_curr_iter, )