Shortcuts

Source code for mmdet.models.losses.varifocal_loss

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmdet.registry import MODELS
from .utils import weight_reduce_loss


def varifocal_loss(pred: Tensor,
                   target: Tensor,
                   weight: Optional[Tensor] = None,
                   alpha: float = 0.75,
                   gamma: float = 2.0,
                   iou_weighted: bool = True,
                   reduction: str = 'mean',
                   avg_factor: Optional[int] = None) -> Tensor:
    """Compute `Varifocal Loss <https://arxiv.org/abs/2008.13367>`_.

    The loss is an element-wise sigmoid binary cross entropy between
    ``pred`` and ``target``, scaled by a focal weight that treats positive
    entries (``target > 0``) and negative entries (``target <= 0``)
    differently.

    Args:
        pred (Tensor): The prediction (logits) with shape (N, C), C is the
            number of classes.
        target (Tensor): The learning target of the iou-aware
            classification score with shape (N, C), C is the number of
            classes. It is cast to the dtype of ``pred``.
        weight (Tensor, optional): The weight of loss for each
            prediction, forwarded to ``weight_reduce_loss``.
            Defaults to None.
        alpha (float, optional): A balance factor for the negative part of
            Varifocal Loss, which is different from the alpha of Focal
            Loss. Defaults to 0.75.
        gamma (float, optional): The exponent applied to
            ``|sigmoid(pred) - target|`` when building the modulating
            factor of the negative part. Defaults to 2.0.
        iou_weighted (bool, optional): If True, positive entries are
            weighted by their own target value; otherwise they are
            weighted by 1. Defaults to True.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and
            "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        Tensor: Loss tensor.
    """
    # The loss is computed element-wise, so both tensors must line up.
    assert pred.size() == target.size()

    target = target.type_as(pred)
    prob = pred.sigmoid()

    # Float masks selecting the positive / negative entries of the target.
    pos_mask = (target > 0.0).float()
    neg_mask = (target <= 0.0).float()

    # Negative entries: alpha * |p - q| ** gamma down-weights easy negatives.
    neg_term = alpha * (prob - target).abs().pow(gamma) * neg_mask

    # Positive entries: weighted by the target itself, or by a constant 1.
    if iou_weighted:
        pos_term = target * pos_mask
    else:
        pos_term = pos_mask
    focal_weight = pos_term + neg_term

    elementwise = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    return weight_reduce_loss(elementwise, weight, reduction, avg_factor)


@MODELS.register_module()
class VarifocalLoss(nn.Module):
    """Module wrapper around :func:`varifocal_loss`.

    Stores the loss hyper-parameters at construction time and applies them,
    together with ``loss_weight``, on every forward call.
    """

    def __init__(self,
                 use_sigmoid: bool = True,
                 alpha: float = 0.75,
                 gamma: float = 2.0,
                 iou_weighted: bool = True,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_

        Args:
            use_sigmoid (bool, optional): Whether the prediction is
                used for sigmoid or softmax. Only ``True`` is accepted.
                Defaults to True.
            alpha (float, optional): A balance factor for the negative part
                of Varifocal Loss, which is different from the alpha of
                Focal Loss. Must be non-negative. Defaults to 0.75.
            gamma (float, optional): The gamma for calculating the
                modulating factor. Defaults to 2.0.
            iou_weighted (bool, optional): Whether to weight the loss of the
                positive examples with the iou target. Defaults to True.
            reduction (str, optional): The method used to reduce the loss
                into a scalar. Defaults to 'mean'. Options are "none",
                "mean" and "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
        """
        super().__init__()
        assert use_sigmoid is True, \
            'Only sigmoid varifocal loss supported now.'
        assert alpha >= 0.0
        self.use_sigmoid = use_sigmoid
        self.alpha = alpha
        self.gamma = gamma
        self.iou_weighted = iou_weighted
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): The prediction with shape (N, C), C is the
                number of classes.
            target (Tensor): The learning target of the iou-aware
                classification score with shape (N, C), C is
                the number of classes.
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            Tensor: The calculated loss

        Raises:
            NotImplementedError: If ``self.use_sigmoid`` is falsy.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # A per-call override takes precedence over the configured reduction.
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            loss_cls = self.loss_weight * varifocal_loss(
                pred,
                target,
                weight,
                alpha=self.alpha,
                gamma=self.gamma,
                iou_weighted=self.iou_weighted,
                reduction=reduction,
                avg_factor=avg_factor)
        else:
            raise NotImplementedError
        return loss_cls