
Source code for mmdet.models.dense_heads.clr_head.clr_head

# Copyright (c) VBTI. All rights reserved.
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.logging import MMLogger
from mmengine.structures import InstanceData

from mmdet.registry import MODELS
from .lane import Lane
from .losses.lineiou_loss import liou_loss
from .utils.dynamic_assign import assign
from .utils.roi_gather import LinearModule, ROIGather
from .utils.seg_decoder import SegDecoder

try:
    from mmcv.ops import line_nms as nms
except ImportError:
    logger: MMLogger = MMLogger.get_current_instance()
    logger.warning(
        'NMS is not compiled. Most likely CUDA was not available when the '
        'package was installed.\n'
        'This model cannot work without CUDA, so you need to recompile '
        'the package with CUDA support.\n'
        'To make sure the package is compiled with CUDA support, use '
        'FORCE_CUDA=1 mim install -e .\n')
    nms = None


def visualize_lane_priors_tensor(prior_tensors, num_points=72):
    """Visualize lane priors on a blank canvas using matplotlib.

    Args:
    - prior_tensors: torch.Tensor of shape (num_priors, num_components) where
        - num_components = 6 + num_offsets
        - prior_tensors[:, :2]: background/foreground scores
        - prior_tensors[:, 2]: normalized start y
        - prior_tensors[:, 3]: normalized start x
        - prior_tensors[:, 4]: normalized angle theta
        - prior_tensors[:, 5]: length
        - prior_tensors[:, 6:]: x offsets
    - num_points: int, number of points to use for visualization (default 72)
    """
    # Layout per prior: 2 scores, 1 start_y, 1 start_x, 1 theta, 1 length,
    # num_offsets coordinates; score[0] = negative prob, score[1] = positive
    # prob. Draw on a blank white canvas at the network input resolution.
    image_np = np.ones((320, 800, 3))
    H, W = image_np.shape[:2]

    plt.figure(figsize=(12, 8))
    plt.imshow(image_np)
    prior_tensors = prior_tensors.detach().cpu()
    for prior in prior_tensors:
        y = prior[2].item()
        x = prior[3].item()
        theta = prior[4].item()
        # theta is stored normalized in [0, 1]; convert to radians
        theta = theta * math.pi
        # Generate normalized y coordinates and the matching x coordinates
        # along the prior line
        y_coords = np.linspace(0, 1, num_points)
        x_coords = x * (W - 1) + (
            ((1 - y_coords - y) * H) / np.tan(theta + 1e-5))

        y_coords = (y_coords + y) * H

        # Plot the lane prior
        plt.plot(x_coords, y_coords, '-', linewidth=2)
        # Plot start point
        plt.plot(x_coords[0], y_coords[0], 'ro', markersize=5)

    plt.xlim(0, W)
    plt.ylim(H, 0)
    plt.title('Lane Priors Visualization')
    plt.show()
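
# Usage sketch (illustrative, not from the original module): the visualizer
# only needs a (num_priors, 6 + num_offsets) tensor, so a handful of
# hand-made priors is enough to eyeball the geometry:
#
#   priors = torch.zeros(4, 6 + 72)
#   priors[:, 2] = torch.tensor([0.0, 0.1, 0.0, 0.2])  # normalized start_y
#   priors[:, 3] = torch.tensor([0.2, 0.4, 0.6, 0.8])  # normalized start_x
#   priors[:, 4] = torch.tensor([0.3, 0.4, 0.6, 0.7])  # normalized theta
#   visualize_lane_priors_tensor(priors, num_points=72)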


@MODELS.register_module()
class CLRHead(nn.Module):

    def __init__(
        self,
        num_points: int = 72,
        prior_feat_channels: int = 64,
        fc_hidden_dim: int = 64,
        num_priors: int = 192,
        num_fc: int = 2,
        refine_layers: int = 3,
        sample_points: int = 36,
        img_w: int = 800,
        img_h: int = 320,
        num_classes: int = 5,
        cut_height: int = 0,
        cls_loss: dict | None = None,
        cls_loss_weight: float = 2.,
        xyt_loss_weight: float = 0.5,
        iou_loss_weight: float = 2.,
        seg_loss_weight: float = 1.,
        num_classes_seg: int = 5,
        train_cfg=None,
        test_cfg=None,
    ):
        super(CLRHead, self).__init__()
        self.img_w = img_w
        self.img_h = img_h
        self.num_classes = num_classes
        self.n_strips = num_points - 1
        self.n_offsets = num_points
        self.num_priors = num_priors
        self.sample_points = sample_points
        self.refine_layers = refine_layers
        self.fc_hidden_dim = fc_hidden_dim
        # NOTE: this overwrites ``num_classes`` set above; from here on
        # ``self.num_classes`` holds the number of segmentation classes.
        self.num_classes = num_classes_seg
        self.ignore_label = 255
        self.bg_weight = 0.4
        self.cut_height = cut_height
        self.cls_loss_weight = cls_loss_weight
        self.xyt_loss_weight = xyt_loss_weight
        self.iou_loss_weight = iou_loss_weight
        self.seg_loss_weight = seg_loss_weight

        if cls_loss is None:
            cls_loss = dict(type='FocalLoss', alpha=0.25, gamma=2.0)
        self.cls_criterion = MODELS.build(cls_loss)

        self.register_buffer(
            name='sample_x_indexs',
            tensor=(torch.linspace(
                0, 1, steps=self.sample_points, dtype=torch.float32) *
                    self.n_strips).long())
        self.register_buffer(
            name='prior_feat_ys',
            tensor=torch.flip(
                (1 - self.sample_x_indexs.float() / self.n_strips),
                dims=[-1]))
        self.register_buffer(
            name='prior_ys',
            tensor=torch.linspace(
                1, 0, steps=self.n_offsets, dtype=torch.float32))

        self.prior_feat_channels = prior_feat_channels

        self.prior_embeddings = nn.Embedding(self.num_priors, 3)
        self._init_prior_embeddings()
        init_priors, priors_on_featmap = \
            self.generate_priors_from_embeddings()
        self.register_buffer(name='priors', tensor=init_priors)
        self.register_buffer(
            name='priors_on_featmap', tensor=priors_on_featmap)

        # generate xys for feature map
        self.seg_decoder = SegDecoder(self.img_h, self.img_w,
                                      self.num_classes,
                                      self.prior_feat_channels,
                                      self.refine_layers)

        reg_modules = list()
        cls_modules = list()
        for _ in range(num_fc):
            reg_modules += [*LinearModule(self.fc_hidden_dim)]
            cls_modules += [*LinearModule(self.fc_hidden_dim)]
        self.reg_modules = nn.ModuleList(reg_modules)
        self.cls_modules = nn.ModuleList(cls_modules)

        self.roi_gather = ROIGather(self.prior_feat_channels, self.num_priors,
                                    self.sample_points, self.fc_hidden_dim,
                                    self.refine_layers)

        self.reg_layers = nn.Linear(
            self.fc_hidden_dim, self.n_offsets + 1 + 2 +
            1)  # n offsets + 1 length + start_x + start_y + theta
        self.cls_layers = nn.Linear(self.fc_hidden_dim, 2)

        weights = torch.ones(self.num_classes)
        weights[0] = self.bg_weight
        self.criterion = torch.nn.NLLLoss(
            ignore_index=self.ignore_label, weight=weights)

        # init the weights here
        self.init_weights()

    # function to init layer weights
    def init_weights(self):
        # initialize heads
        for m in self.cls_layers.parameters():
            nn.init.normal_(m, mean=0., std=1e-3)
        for m in self.reg_layers.parameters():
            nn.init.normal_(m, mean=0., std=1e-3)

    def get_device(self) -> str:
        return next(self.parameters()).device.type
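
    # Construction sketch (hypothetical config, not from this file): the head
    # is registered in MMDetection's MODELS registry, so it can be built from
    # a config dict; the arguments shown are the defaults above.
    #
    #   head = MODELS.build(
    #       dict(
    #           type='CLRHead',
    #           num_points=72,
    #           prior_feat_channels=64,
    #           num_priors=192,
    #           refine_layers=3,
    #           img_w=800,
    #           img_h=320))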
    def pool_prior_features(self, batch_features, num_priors, prior_xs):
        """Pool prior feature from feature map.

        Args:
            batch_features (Tensor): Input feature maps, shape: (B, C, H, W)
        """
        batch_size = batch_features.shape[0]

        prior_xs = prior_xs.view(batch_size, num_priors, -1, 1)
        prior_ys = self.prior_feat_ys.repeat(batch_size * num_priors).view(
            batch_size, num_priors, -1, 1)

        # grid_sample expects coordinates in [-1, 1]
        prior_xs = prior_xs * 2. - 1.
        prior_ys = prior_ys * 2. - 1.
        grid = torch.cat((prior_xs, prior_ys), dim=-1)
        feature = F.grid_sample(
            batch_features, grid, align_corners=True).permute(0, 2, 1, 3)

        feature = feature.reshape(batch_size * num_priors,
                                  self.prior_feat_channels,
                                  self.sample_points, 1)
        return feature
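
    # Shape walkthrough (sketch, derived from the code above): with B=2
    # images, C=prior_feat_channels=64, num_priors=192 and sample_points=36,
    #
    #   feats = torch.rand(2, 64, 40, 100)   # (B, C, H, W)
    #   xs = torch.rand(2, 192, 36)          # sampled normalized xs
    #   out = head.pool_prior_features(feats, 192, xs)
    #   # out.shape == (2 * 192, 64, 36, 1)
    #
    # i.e. one (C, sample_points, 1) feature column per prior per image.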
    def generate_priors_from_embeddings(self):
        predictions = self.prior_embeddings.weight  # (num_priors, 3)

        # 2 scores, 1 start_y, 1 start_x, 1 theta, 1 length, 72 coordinates,
        # score[0] = negative prob, score[1] = positive prob
        priors = predictions.new_zeros(
            (self.num_priors, 2 + 2 + 2 + self.n_offsets),
            device=predictions.device)

        priors[:, 2:5] = predictions.clone()
        priors[:, 6:] = (
            priors[:, 3].unsqueeze(1).clone().repeat(1, self.n_offsets) *
            (self.img_w - 1) +
            ((1 - self.prior_ys.repeat(self.num_priors, 1) -
              priors[:, 2].unsqueeze(1).clone().repeat(1, self.n_offsets)) *
             self.img_h / torch.tan(priors[:, 4].unsqueeze(1).clone().repeat(
                 1, self.n_offsets) * math.pi + 1e-5))) / (self.img_w - 1)

        # init priors on feature map
        priors_on_featmap = priors.clone()[..., 6 + self.sample_x_indexs]

        return priors, priors_on_featmap

    def _init_prior_embeddings(self):
        # [start_y, start_x, theta] -> all normalized
        bottom_priors_nums = self.num_priors * 3 // 4
        left_priors_nums, _ = self.num_priors // 8, self.num_priors // 8

        strip_size = 0.5 / (left_priors_nums // 2 - 1)
        bottom_strip_size = 1 / (bottom_priors_nums // 4 + 1)
        for i in range(left_priors_nums):
            nn.init.constant_(self.prior_embeddings.weight[i, 0],
                              (i // 2) * strip_size)
            nn.init.constant_(self.prior_embeddings.weight[i, 1], 0.)
            nn.init.constant_(self.prior_embeddings.weight[i, 2],
                              0.16 if i % 2 == 0 else 0.32)

        for i in range(left_priors_nums,
                       left_priors_nums + bottom_priors_nums):
            nn.init.constant_(self.prior_embeddings.weight[i, 0], 0.)
            nn.init.constant_(self.prior_embeddings.weight[i, 1],
                              ((i - left_priors_nums) // 4 + 1) *
                              bottom_strip_size)
            nn.init.constant_(self.prior_embeddings.weight[i, 2],
                              0.2 * (i % 4 + 1))

        for i in range(left_priors_nums + bottom_priors_nums,
                       self.num_priors):
            nn.init.constant_(
                self.prior_embeddings.weight[i, 0],
                ((i - left_priors_nums - bottom_priors_nums) // 2) *
                strip_size)
            nn.init.constant_(self.prior_embeddings.weight[i, 1], 1.)
            nn.init.constant_(self.prior_embeddings.weight[i, 2],
                              0.68 if i % 2 == 0 else 0.84)
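
    # Prior geometry (restating the code above): each prior is a straight
    # line through (start_x, start_y) at angle theta (all normalized), and
    # its x at a normalized row y is
    #
    #   x_px = start_x * (img_w - 1) \
    #          + (1 - y - start_y) * img_h / tan(theta * pi)
    #
    # which is then divided by (img_w - 1) to stay normalized. For example,
    # start_x=0.5, start_y=0, theta=0.5 (a vertical line, tan -> inf) gives
    # x = 0.5 for every row.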
    def forward(self,
                inputs,
                batch_data_samples,
                training: bool = False,
                **kwargs):
        """Take pyramid features as input to perform Cross Layer Refinement
        and finally output the prediction lanes.

        Each feature is a 4D tensor.

        Args:
            inputs: input features (list[Tensor])

        Return:
            prediction_list: each layer's prediction result
            seg: segmentation result for auxiliary loss
        """
        batch_features = list(inputs[len(inputs) - self.refine_layers:])
        batch_features.reverse()
        batch_size = batch_features[-1].shape[0]

        if training:
            self.priors, self.priors_on_featmap = \
                self.generate_priors_from_embeddings()
            # visualize_lane_priors_tensor(self.priors, num_points=72)

        priors, priors_on_featmap = self.priors.repeat(
            batch_size, 1,
            1), self.priors_on_featmap.repeat(batch_size, 1, 1)

        predictions_lists = []

        # iterative refine
        prior_features_stages = []
        for stage in range(self.refine_layers):
            num_priors = priors_on_featmap.shape[1]
            prior_xs = torch.flip(priors_on_featmap, dims=[2])

            batch_prior_features = self.pool_prior_features(
                batch_features[stage], num_priors, prior_xs)
            prior_features_stages.append(batch_prior_features)

            fc_features = self.roi_gather(prior_features_stages,
                                          batch_features[stage], stage)
            fc_features = fc_features.view(num_priors, batch_size,
                                           -1).reshape(
                                               batch_size * num_priors,
                                               self.fc_hidden_dim)

            cls_features = fc_features.clone()
            reg_features = fc_features.clone()
            for cls_layer in self.cls_modules:
                cls_features = cls_layer(cls_features)
            for reg_layer in self.reg_modules:
                reg_features = reg_layer(reg_features)

            cls_logits = self.cls_layers(cls_features)
            reg = self.reg_layers(reg_features)

            cls_logits = cls_logits.reshape(
                batch_size, -1, cls_logits.shape[1])  # (B, num_priors, 2)
            reg = reg.reshape(batch_size, -1, reg.shape[1])

            predictions = priors.clone()
            predictions[:, :, :2] = cls_logits

            predictions[:, :,
                        2:5] += reg[:, :, :3]  # also reg theta angle here
            predictions[:, :, 5] = reg[:, :, 3]  # length

            def tran_tensor(t):
                return t.unsqueeze(2).clone().repeat(1, 1, self.n_offsets)

            predictions[..., 6:] = (
                tran_tensor(predictions[..., 3]) * (self.img_w - 1) +
                ((1 - self.prior_ys.repeat(batch_size, num_priors, 1) -
                  tran_tensor(predictions[..., 2])) * self.img_h /
                 torch.tan(tran_tensor(predictions[..., 4]) * math.pi +
                           1e-5))) / (self.img_w - 1)

            prediction_lines = predictions.clone()
            predictions[..., 6:] += reg[..., 4:]

            predictions_lists.append(predictions)

            if stage != self.refine_layers - 1:
                priors = prediction_lines.detach().clone()
                priors_on_featmap = priors[..., 6 + self.sample_x_indexs]

        if training:
            seg_features = torch.cat([
                F.interpolate(
                    feature,
                    size=[
                        batch_features[-1].shape[2],
                        batch_features[-1].shape[3]
                    ],
                    mode='bilinear',
                    align_corners=False) for feature in batch_features
            ],
                                     dim=1)
            seg = self.seg_decoder(seg_features)
            output = {'predictions_lists': predictions_lists, 'seg': seg}
            return output, batch_data_samples

        return predictions_lists[-1]
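
    # Inference-call sketch (the feature shapes are assumptions for an
    # 800x320 input with an FPN-style neck; only the last ``refine_layers``
    # levels are used):
    #
    #   feats = [torch.rand(1, 64, 40, 100),
    #            torch.rand(1, 64, 20, 50),
    #            torch.rand(1, 64, 10, 25)]
    #   head.eval()
    #   with torch.no_grad():
    #       preds = head(feats, batch_data_samples=None)
    #   # preds: (1, num_priors, 6 + num_points) final-stage predictions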
    def predictions_to_pred(self, predictions, scores, org_width):
        """Convert predictions to internal Lane structure for evaluation."""
        self.prior_ys = self.prior_ys.to(predictions.device)
        self.prior_ys = self.prior_ys.double()
        lanes = []
        for lane, score in zip(predictions, scores, strict=True):
            lane_xs = lane[6:]  # normalized value
            start = min(
                max(0, int(round(lane[2].item() * self.n_strips))),
                self.n_strips)
            length = int(round(lane[5].item()))
            end = start + length - 1
            end = min(end, len(self.prior_ys) - 1)
            # if the prediction does not start at the bottom of the image,
            # extend its prediction until the x is outside the image
            mask = ~((((lane_xs[:start] >= 0.) & (lane_xs[:start] <= 1.)
                       ).cpu().numpy()[::-1].cumprod()[::-1]).astype(
                           np.bool_))
            lane_xs[end + 1:] = -2
            lane_xs[:start][mask] = -2
            lane_ys = self.prior_ys[lane_xs >= 0]
            lane_xs = lane_xs[lane_xs >= 0]
            lane_xs = lane_xs.flip(0).double()
            lane_ys = lane_ys.flip(0)

            lane_ys = (lane_ys * (org_width - self.cut_height) +
                       self.cut_height) / org_width
            if len(lane_xs) <= 1:
                continue
            points = torch.stack(
                (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)),
                dim=1).squeeze(2)
            lane_prediction = Lane(
                points=points.cpu().numpy(),
                metadata={
                    'start_x': lane[3].cpu().item(),
                    'start_y': lane[2].cpu().item(),
                    'conf': score.cpu().item()
                })
            lanes.append(lane_prediction)
        return lanes
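
    # Each returned ``Lane`` stores normalized (x, y) points plus metadata,
    # e.g. (sketch of reading one back):
    #
    #   lane = lanes[0]
    #   xs, ys = lane.points[:, 0], lane.points[:, 1]  # both in [0, 1]
    #   print(lane.metadata['conf'])
    #
    # ``predict`` below rescales these points to original image pixels.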
    def loss(self, x, batch, training: bool = True, **kwargs):
        output, batch = self.forward(
            x, batch_data_samples=batch, training=training)
        predictions_lists = output['predictions_lists']

        targets = torch.stack([torch.tensor(b.lines_array)
                               for b in batch]).clone()
        targets = targets.to(x[0].device)
        target_batch_seg = torch.stack(
            [b.gt_sem_seg.sem_seg.squeeze(0) for b in batch]).clone().long()
        target_batch_seg = target_batch_seg.to(x[0].device)

        cls_loss = 0
        reg_xytl_loss = 0
        iou_loss = 0
        for stage in range(self.refine_layers):
            predictions_list = predictions_lists[stage]
            for predictions, target in zip(predictions_list, targets):
                target = target[target[:, 1] == 1]

                if len(target) == 0:
                    # If there are no targets, all predictions have to be
                    # negatives (i.e., 0 confidence)
                    cls_target = predictions.new_zeros(
                        predictions.shape[0]).long()
                    cls_pred = predictions[:, :2]
                    cls_loss = cls_loss + self.cls_criterion(
                        cls_pred, cls_target).sum()
                    continue

                with torch.no_grad():
                    matched_row_inds, matched_col_inds = assign(
                        predictions, target, self.img_w, self.img_h)

                # classification targets
                cls_target = predictions.new_zeros(
                    predictions.shape[0]).long()
                cls_target[matched_row_inds] = 1
                cls_pred = predictions[:, :2]

                # regression targets -> [start_y, start_x, theta]
                # (all transformed to absolute values), only on matched pairs
                reg_yxtl = predictions[matched_row_inds, 2:6]
                reg_yxtl[:, 0] *= self.n_strips
                reg_yxtl[:, 1] *= (self.img_w - 1)
                reg_yxtl[:, 2] *= 180
                reg_yxtl[:, 3] *= self.n_strips
                target_yxtl = target[matched_col_inds, 2:6].clone()

                # regression targets -> S coordinates
                # (all transformed to absolute values)
                reg_pred = predictions[matched_row_inds, 6:]
                reg_pred *= (self.img_w - 1)
                reg_targets = target[matched_col_inds, 6:].clone()

                with torch.no_grad():
                    predictions_starts = torch.clamp(
                        (predictions[matched_row_inds, 2] *
                         self.n_strips).round().long(), 0,
                        self.n_strips)  # ensure the prediction starts are
                    # valid
                    target_starts = (target[matched_col_inds, 2] *
                                     self.n_strips).round().long()
                    target_yxtl[:, -1] -= (predictions_starts -
                                           target_starts)  # reg length

                # Loss calculation
                cls_loss = cls_loss + self.cls_criterion(
                    cls_pred, cls_target).sum() / target.shape[0]

                target_yxtl[:, 0] *= self.n_strips
                target_yxtl[:, 2] *= 180
                reg_xytl_loss = reg_xytl_loss + F.smooth_l1_loss(
                    reg_yxtl, target_yxtl, reduction='none').mean()

                iou_loss = iou_loss + liou_loss(
                    reg_pred, reg_targets, self.img_w, length=15)

        # extra segmentation loss
        seg_loss = self.criterion(
            F.log_softmax(output['seg'], dim=1), target_batch_seg)

        cls_loss /= (len(targets) * self.refine_layers)
        reg_xytl_loss /= (len(targets) * self.refine_layers)
        iou_loss /= (len(targets) * self.refine_layers)

        loss = cls_loss * self.cls_loss_weight \
            + reg_xytl_loss * self.xyt_loss_weight \
            + seg_loss * self.seg_loss_weight \
            + iou_loss * self.iou_loss_weight

        return_value = {
            'loss': loss,
            'cls_loss': cls_loss * self.cls_loss_weight,
            'reg_xytl_loss': reg_xytl_loss * self.xyt_loss_weight,
            'seg_loss': seg_loss * self.seg_loss_weight,
            'iou_loss': iou_loss * self.iou_loss_weight
        }
        return return_value
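
    # Loss composition (restating the weights above, with the constructor
    # defaults):
    #
    #   loss = 2.0 * cls_loss       # focal classification loss
    #        + 0.5 * reg_xytl_loss  # smooth-L1 on start_y/start_x/theta/len
    #        + 1.0 * seg_loss       # auxiliary NLL segmentation loss
    #        + 2.0 * iou_loss       # line-IoU loss on the 72 x-coordinates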
    def get_lanes(self,
                  output,
                  org_widths,
                  as_lanes=True,
                  conf_threshold=0.05,
                  nms_threshold=0.5,
                  top_k=5):
        """Convert model output to lanes."""
        softmax = nn.Softmax(dim=1)

        decoded = []
        for predictions, org_width in zip(output, org_widths, strict=True):
            # filter out predictions with confidence below the threshold
            scores = softmax(predictions[:, :2])[:, 1]
            keep_inds = scores >= conf_threshold
            predictions = predictions[keep_inds]
            scores = scores[keep_inds]

            if predictions.shape[0] == 0:
                decoded.append([])
                continue

            if nms is not None and self.get_device() != 'cpu':
                nms_predictions = predictions.detach().clone()
                nms_predictions = torch.cat(
                    [nms_predictions[..., :4], nms_predictions[..., 5:]],
                    dim=-1)
                nms_predictions[...,
                                4] = nms_predictions[..., 4] * self.n_strips
                nms_predictions[..., 5:] = nms_predictions[..., 5:] * (
                    self.img_w - 1)

                keep, num_to_keep, _ = nms(
                    nms_predictions,
                    scores,
                    overlap=nms_threshold,
                    top_k=top_k)
                keep = keep[:num_to_keep]
                predictions = predictions[keep]
                scores = scores[keep]

            if predictions.shape[0] == 0:
                decoded.append([])
                continue

            predictions[:, 5] = torch.round(predictions[:, 5] *
                                            self.n_strips)
            if as_lanes:
                pred = self.predictions_to_pred(predictions, scores,
                                                org_width)
            else:
                pred = predictions
            decoded.append(pred)

        return decoded
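
    # Post-processing sketch (thresholds are the defaults above): with CUDA
    # and ``line_nms`` compiled, lanes above conf 0.05 are suppressed at line
    # IoU 0.5 and at most ``top_k=5`` survive; on CPU the NMS step is skipped
    # entirely, so expect duplicate lanes there.
    #
    #   lanes = head.get_lanes(preds, org_widths=(1640,), conf_threshold=0.4)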
    def predict(self, x, batched_data_samples, *args, **kwargs):
        """Perform inference on the input features.

        Args:
            x: input features (list[Tensor])

        Return:
            lanes: list of lanes
        """
        output = self.forward(x, batched_data_samples)
        org_widths, org_heights = zip(*[(sample.org_width, sample.org_height)
                                        for sample in batched_data_samples])
        batched_lane = self.get_lanes(output, org_widths, as_lanes=True)

        preds = []
        for lanes, org_width, org_height in zip(batched_lane, org_widths,
                                                org_heights):
            instance = InstanceData()
            instance_points = []
            for lane in lanes:
                points = lane.points
                # scale normalized points back to original image pixels
                points[..., 0] *= (org_width - 1)
                points[..., 1] *= (org_height - 1)
                instance_points.append(points)
            instance.lines = instance_points
            instance.line_conf = [lane.metadata['conf'] for lane in lanes]
            preds.append(instance)
        return preds
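
# End-to-end sketch (the ``sample`` object is a stand-in with only the two
# attributes ``predict`` reads; in practice these come from the dataset
# pipeline):
#
#   from types import SimpleNamespace
#   sample = SimpleNamespace(org_width=1640, org_height=590)
#   results = head.predict(feats, [sample])
#   # results[0].lines: list of (N, 2) pixel-coordinate arrays
#   # results[0].line_conf: per-lane confidences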