Shortcuts

Source code for mmdet.models.losses.triplet_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmdet.registry import MODELS


@MODELS.register_module()
class TripletLoss(BaseModule):
    """Triplet loss with hard positive/negative mining.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for
        Person Re-Identification. arXiv:1703.07737.
    Imported from `<https://github.com/KaiyangZhou/deep-person-reid/blob/
        master/torchreid/losses/hard_mine_triplet_loss.py>`_.

    Args:
        margin (float, optional): Margin for triplet loss. Defaults to 0.3.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        hard_mining (bool, optional): Whether to perform hard mining.
            Only ``True`` is implemented; ``forward`` raises
            ``NotImplementedError`` otherwise. Defaults to True.
    """

    def __init__(self,
                 margin: float = 0.3,
                 loss_weight: float = 1.0,
                 hard_mining: bool = True) -> None:
        super().__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        self.loss_weight = loss_weight
        self.hard_mining = hard_mining

    def hard_mining_triplet_loss_forward(
            self, inputs: torch.Tensor,
            targets: torch.LongTensor) -> torch.Tensor:
        """Compute the batch-hard triplet loss.

        For every anchor the furthest sample sharing its label and the
        nearest sample with a different label are selected, and a margin
        ranking loss is applied to the two distances.

        An anchor that has no negative in the batch (e.g. every label is
        identical) gets an infinite negative distance, so its hinge term
        and its gradient are both zero.

        Args:
            inputs (torch.Tensor): feature matrix with shape
                (batch_size, feat_dim).
            targets (torch.LongTensor): ground truth labels with shape
                (batch_size).

        Returns:
            torch.Tensor: triplet loss with hard mining.
        """
        batch_size = inputs.size(0)

        # Pairwise Euclidean distance via ||a||^2 + ||b||^2 - 2 * a.b
        sq_norm = torch.pow(inputs, 2).sum(dim=1, keepdim=True)
        dist = sq_norm + sq_norm.t()  # broadcasts to (batch, batch)
        # ``dist`` is a freshly allocated tensor, so updating it in place
        # does not touch ``inputs``.
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # is_pos[i, j] is True when samples i and j share a label. The
        # diagonal is always True, so every row has at least one positive.
        labels = targets.expand(batch_size, batch_size)
        is_pos = labels.eq(labels.t())

        # For each anchor, find the furthest positive sample and nearest
        # negative sample in the embedding space. Entries of the wrong
        # kind are pushed to -inf / +inf so they can never be selected by
        # the row-wise max / min.
        dist_ap = dist.masked_fill(~is_pos, float('-inf')).amax(dim=1)
        dist_an = dist.masked_fill(is_pos, float('inf')).amin(dim=1)

        # Compute ranking hinge loss: y = 1 asks for dist_an > dist_ap
        y = torch.ones_like(dist_an)
        return self.loss_weight * self.ranking_loss(dist_an, dist_ap, y)

    def forward(self, inputs: torch.Tensor,
                targets: torch.LongTensor) -> torch.Tensor:
        """Compute the triplet loss for a batch of features.

        Args:
            inputs (torch.Tensor): feature matrix with shape
                (batch_size, feat_dim).
            targets (torch.LongTensor): ground truth labels with shape
                (batch_size).

        Returns:
            torch.Tensor: triplet loss.

        Raises:
            NotImplementedError: If ``hard_mining`` is False.
        """
        if self.hard_mining:
            return self.hard_mining_triplet_loss_forward(inputs, targets)
        raise NotImplementedError(
            'TripletLoss without hard mining is not implemented; '
            'set hard_mining=True.')