VoxelNeXt, a fully sparse 3D object detection network

GitHub - dvlab-research/VoxelNeXt: VoxelNeXt: Fully Sparse VoxelNet for 3D Object Detection and Tracking (CVPR 2023)

https://arxiv.org/abs/2303.11301

Summary

Current 3D object detection models follow 2D methods in the detection part: 3D boxes are predicted from preset anchors or centers on a dense feature map. The innovation of this paper is to exploit the sparsity of the point cloud: after features are extracted with spconv, they are not converted to dense feature maps, and 3D boxes are predicted directly on the sparse features. Experiments verify that this achieves good results on commonly used public datasets.

1 Introduction

Take the commonly used CenterPoint model as an example: its pipeline includes a sparse-to-dense conversion. Although it works effectively, this brings the following problems: wasted computing resources, a complex pipeline, and the need for NMS post-processing.

 The method proposed in this paper eliminates center anchors, the sparse-to-dense conversion, the RPN, NMS and other steps, and predicts directly and only at sparse feature positions.

Figure: FLOPs comparison between VoxelNeXt and CenterPoint.

Figure: latency of VoxelNeXt compared with CenterPoint and FSD under different detection ranges. VoxelNeXt is particularly well suited to long-range detection.

2. Related work

       Lidar Detectors

        Current 3D detectors usually follow 2D detectors, such as the R-CNN series and the CenterPoint series. Although 3D point clouds are inherently sparse compared with 2D data, current detectors still make predictions on dense feature maps. This paper changes this and performs object prediction directly on sparse features.

         Sparse Detectors

           The paper analyzes several sparse detectors: Waymo's RSN first segments foreground points on the range image and then detects objects on the sparse foreground points; SWFormer and FSD are other attempts at sparse detection, but their pipelines are complicated. This paper uses common sparse convolution to simplify the pipeline as much as possible.
Related architectures: PillarNet, RSN.

        Sparse Convolution Network

          Because of its efficiency, sparse convolution is now the mainstream choice for 3D network backbones, but it is generally not used directly in the detection head. There are some attempts at optimization, such as using a transformer to enlarge the receptive field; this paper instead uses additional downsampling to enlarge the receptive field.

        3D Object Tracking        

          It is common to track detection results with a Kalman filter, and there are also methods such as CenterTrack that directly predict velocities. This paper associates objects through their query voxels, which effectively handles the offset between the query voxel and the object center.

3. Fully Sparse Voxel-based Network

        Figure: schematic diagram of the VoxelNeXt network structure.

3.1 backbone adaptation

additional downsampling

On top of the original downsampling stages with strides {1, 2, 4, 8} producing features {F1, F2, F3, F4}, two further stages with strides {16, 32} produce {F5, F6}. F5 and F6 are then aligned to the spatial resolution of F4, and the three are combined to generate Fc.

 Here F denotes sparse features and P denotes their 3D voxel coordinates. Fc is the concatenation of F4, F5 and F6, and the coordinates P5 and P6 are rescaled to the resolution of P4.

# two extra downsampling stages on top of the stride-8 feature x_conv4
x_conv5 = self.conv5(x_conv4)  # stride 16 (F5)
x_conv6 = self.conv6(x_conv5)  # stride 32 (F6)

# rescale the spatial indices of F5 and F6 to the stride-8 resolution of F4
x_conv5.indices[:, 1:] *= 2
x_conv6.indices[:, 1:] *= 4
# stack the voxels of F4, F5 and F6 into a single sparse tensor Fc
x_conv4 = x_conv4.replace_feature(torch.cat([x_conv4.features, x_conv5.features, x_conv6.features]))
x_conv4.indices = torch.cat([x_conv4.indices, x_conv5.indices, x_conv6.indices])

sparse height compression

The conventional approach is to convert the sparse features to dense ones and then fold the z dimension into the channel dimension.

Here, the sparse features are projected directly onto the BEV plane, and features that fall on the same BEV location are summed, which is very efficient.

def bev_out(self, x_conv):
        features_cat = x_conv.features
        # drop the z coordinate: keep only (batch, y, x) for each voxel
        indices_cat = x_conv.indices[:, [0, 2, 3]]
        spatial_shape = x_conv.spatial_shape[1:]

        # voxels that project to the same BEV location share one output index;
        # their features are summed with index_add_
        indices_unique, _inv = torch.unique(indices_cat, dim=0, return_inverse=True)
        features_unique = features_cat.new_zeros((indices_unique.shape[0], features_cat.shape[1]))
        features_unique.index_add_(0, _inv, features_cat)

        # rebuild a 2D sparse tensor on the BEV plane
        x_out = spconv.SparseConvTensor(
            features=features_unique,
            indices=indices_unique,
            spatial_shape=spatial_shape,
            batch_size=x_conv.batch_size
        )
        return x_out

spatially voxel pruning

During downsampling, unimportant background voxels are pruned. This not only highlights the foreground but also improves computational efficiency.
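The excerpted code above does not include the pruning step. As a minimal sketch of the idea only, assuming a magnitude-based importance score and a fixed keep ratio (both the helper name and the ratio are assumptions; the actual method suppresses the dilation of unimportant voxels during downsampling rather than simply dropping them):

import torch
import spconv.pytorch as spconv

def prune_background_voxels(x: spconv.SparseConvTensor, keep_ratio: float = 0.5) -> spconv.SparseConvTensor:
    # hypothetical helper: score each voxel by the magnitude of its feature
    # vector and keep only the top keep_ratio fraction of voxels
    scores = x.features.abs().sum(dim=1)                 # (N,)
    num_keep = max(1, int(scores.numel() * keep_ratio))
    keep_idx = torch.topk(scores, num_keep).indices
    return spconv.SparseConvTensor(
        features=x.features[keep_idx],
        indices=x.indices[keep_idx],
        spatial_shape=x.spatial_shape,
        batch_size=x.batch_size,
    )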

3.2 sparse head

        1. class head

Prediction: N x F features => N x K class scores (N sparse voxels, K classes).

Target: the voxel closest to the center of each gt box is assigned as the positive sample.

Loss: focal loss.

Inference: sparse max pooling is used instead of NMS to remove duplicates. Since the voxels themselves are sparse, it only operates on non-empty locations. A remaining question is what happens when objects are very close to each other.

 Experiments show that the query voxel is not necessarily at the center of the box, and may not even lie inside the box.
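In the repo this duplicate removal is done with sparse max pooling on the class heatmap. As a plain-PyTorch reference sketch of the same idea (slow but easy to read; the function name and window size are illustrative assumptions, not the actual head code), a voxel is kept only if no neighbouring non-empty voxel has a higher score:

import torch

def sparse_local_max_mask(indices: torch.Tensor, scores: torch.Tensor, window: int = 1) -> torch.Tensor:
    # indices: (N, 3) integer coordinates (batch, y, x) of the non-empty voxels
    # scores:  (N,) per-voxel score, e.g. the max class score of the sparse heatmap
    # returns a boolean mask marking voxels that are local maxima in their window
    coord_to_idx = {tuple(c): i for i, c in enumerate(indices.tolist())}
    keep = torch.ones(scores.shape[0], dtype=torch.bool)
    for i, (b, y, x) in enumerate(indices.tolist()):
        for dy in range(-window, window + 1):
            for dx in range(-window, window + 1):
                j = coord_to_idx.get((b, y + dy, x + dx))
                if j is not None and scores[j] > scores[i]:
                    keep[i] = False
                    break
            if not keep[i]:
                break
    return keep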

        2. regression head

Positive voxel filtering: N -> n (only the positive voxels are regressed).

Prediction: n x F => n x 2 (dx, dy), n x 1 (z), n x 3 (w, h, l), n x 2 (cos, sin).

Loss: L1 loss.
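For reference, a CenterPoint-style masked L1 regression loss over the assigned targets might look roughly like the following minimal sketch (not copied from the repo; per-dimension code weights are omitted):

import torch
import torch.nn.functional as F

def masked_l1_reg_loss(pred, target, mask):
    # pred, target: (B, num_max_objs, code_size); mask: (B, num_max_objs), 1 for valid gt
    mask = mask.float().unsqueeze(-1)
    loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
    return loss / (mask.sum() + 1e-4)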

Related code:

The forward network structure is largely the same as the earlier CenterHead, except that the convolutions change from dense 2D conv to 2D submanifold sparse conv (SubMConv2d). The heatmap branch is still named hm.

import torch.nn as nn
import spconv.pytorch as spconv
from torch.nn.init import kaiming_normal_

class SeparateHead(nn.Module):
    def __init__(self, input_channels, sep_head_dict, kernel_size, init_bias=-2.19, use_bias=False):
        super().__init__()
        self.sep_head_dict = sep_head_dict

        for cur_name in self.sep_head_dict:
            output_channels = self.sep_head_dict[cur_name]['out_channels']
            num_conv = self.sep_head_dict[cur_name]['num_conv']

            fc_list = []
            for k in range(num_conv - 1):
                fc_list.append(spconv.SparseSequential(
                    spconv.SubMConv2d(input_channels, input_channels, kernel_size, padding=int(kernel_size//2), bias=use_bias, indice_key=cur_name),
                    nn.BatchNorm1d(input_channels),
                    nn.ReLU()
                ))
            fc_list.append(spconv.SubMConv2d(input_channels, output_channels, 1, bias=True, indice_key=cur_name+'out'))
            fc = nn.Sequential(*fc_list)
            if 'hm' in cur_name:
                fc[-1].bias.data.fill_(init_bias)
            else:
                for m in fc.modules():
                    if isinstance(m, spconv.SubMConv2d):
                        kaiming_normal_(m.weight.data)
                        if hasattr(m, "bias") and m.bias is not None:
                            nn.init.constant_(m.bias, 0)

            self.__setattr__(cur_name, fc)

    def forward(self, x):
        ret_dict = {}
        for cur_name in self.sep_head_dict:
            ret_dict[cur_name] = self.__getattr__(cur_name)(x).features

        return ret_dict

Target encoding: previously this produced a dense heatmap (hm) plus the encoded target boxes for each gt; now it produces a sparse heatmap defined only on the non-empty voxels, together with the corresponding encoded target boxes.

def assign_target_of_single_head(
            self, num_classes, gt_boxes, num_voxels, spatial_indices, spatial_shape, feature_map_stride, num_max_objs=500,
            gaussian_overlap=0.1, min_radius=2
    ):
        """
        Args:
            gt_boxes: (N, 8)
            feature_map_size: (2), [x, y]

        Returns:

        """
        heatmap = gt_boxes.new_zeros(num_classes, num_voxels)

        ret_boxes = gt_boxes.new_zeros((num_max_objs, gt_boxes.shape[-1] - 1 + 1))
        inds = gt_boxes.new_zeros(num_max_objs).long()
        mask = gt_boxes.new_zeros(num_max_objs).long()

        x, y, z = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2]
        coord_x = (x - self.point_cloud_range[0]) / self.voxel_size[0] / feature_map_stride
        coord_y = (y - self.point_cloud_range[1]) / self.voxel_size[1] / feature_map_stride

        coord_x = torch.clamp(coord_x, min=0, max=spatial_shape[1] - 0.5)  # bugfixed: 1e-6 does not work for center.int()
        coord_y = torch.clamp(coord_y, min=0, max=spatial_shape[0] - 0.5)  #

        center = torch.cat((coord_x[:, None], coord_y[:, None]), dim=-1)
        center_int = center.int()
        center_int_float = center_int.float()

        dx, dy, dz = gt_boxes[:, 3], gt_boxes[:, 4], gt_boxes[:, 5]
        dx = dx / self.voxel_size[0] / feature_map_stride
        dy = dy / self.voxel_size[1] / feature_map_stride

        radius = centernet_utils.gaussian_radius(dx, dy, min_overlap=gaussian_overlap)
        radius = torch.clamp_min(radius.int(), min=min_radius)

        for k in range(min(num_max_objs, gt_boxes.shape[0])):
            if dx[k] <= 0 or dy[k] <= 0:
                continue

            if not (0 <= center_int[k][0] <= spatial_shape[1] and 0 <= center_int[k][1] <= spatial_shape[0]):
                continue

            cur_class_id = (gt_boxes[k, -1] - 1).long()
            
            # the voxel nearest to the gt center is selected as the query voxel
            # inds is updated to the index of this voxel
            distance = self.distance(spatial_indices, center[k])
            inds[k] = distance.argmin()
            mask[k] = 1
            
            
            # draw the gaussian heatmap directly on the sparse hm
            if 'gt_center' in self.gaussian_type:
                centernet_utils.draw_gaussian_to_heatmap_voxels(heatmap[cur_class_id], distance, radius[k].item() * self.gaussian_ratio)

            if 'nearst' in self.gaussian_type:
                centernet_utils.draw_gaussian_to_heatmap_voxels(heatmap[cur_class_id], self.distance(spatial_indices, spatial_indices[inds[k]]), radius[k].item() * self.gaussian_ratio)
            
            # dx, dy: offset between the gt center and the spatial indices of the proxy (query) voxel
            ret_boxes[k, 0:2] = center[k] - spatial_indices[inds[k]][:2]
            ret_boxes[k, 2] = z[k]
            ret_boxes[k, 3:6] = gt_boxes[k, 3:6].log()
            ret_boxes[k, 6] = torch.cos(gt_boxes[k, 6])
            ret_boxes[k, 7] = torch.sin(gt_boxes[k, 6])
            if gt_boxes.shape[1] > 8:
                ret_boxes[k, 8:] = gt_boxes[k, 7:-1]

        return heatmap, ret_boxes, inds, mask

hm and box decode

def decode_bbox_from_voxels_nuscenes(batch_size, indices, obj, rot_cos, rot_sin,
                            center, center_z, dim, vel=None, iou=None, point_cloud_range=None, voxel_size=None, voxels_3d=None,
                            feature_map_stride=None, K=100, score_thresh=None, post_center_limit_range=None, add_features=None):
    # indices: (N, 3) sparse voxel coordinates; column 0 is the batch index
    batch_idx = indices[:, 0]
    spatial_indices = indices[:, 1:]
    # top-K scoring voxels per sample, taken directly over the sparse heatmap
    scores, inds, class_ids = _topk_1d(None, batch_size, batch_idx, obj, K=K, nuscenes=True)

    center = gather_feat_idx(center, inds, batch_size, batch_idx)
    rot_sin = gather_feat_idx(rot_sin, inds, batch_size, batch_idx)
    rot_cos = gather_feat_idx(rot_cos, inds, batch_size, batch_idx)
    center_z = gather_feat_idx(center_z, inds, batch_size, batch_idx)
    dim = gather_feat_idx(dim, inds, batch_size, batch_idx)
    spatial_indices = gather_feat_idx(spatial_indices, inds, batch_size, batch_idx)

    if not add_features is None:
        add_features = [gather_feat_idx(add_feature, inds, batch_size, batch_idx) for add_feature in add_features]

    if not isinstance(feature_map_stride, int):
        feature_map_stride = gather_feat_idx(feature_map_stride.unsqueeze(-1), inds, batch_size, batch_idx)

    angle = torch.atan2(rot_sin, rot_cos)
    # convert voxel index + predicted offset back to metric x/y coordinates
    xs = (spatial_indices[:, :, -1:] + center[:, :, 0:1]) * feature_map_stride * voxel_size[0] + point_cloud_range[0]
    ys = (spatial_indices[:, :, -2:-1] + center[:, :, 1:2]) * feature_map_stride * voxel_size[1] + point_cloud_range[1]
    #zs = (spatial_indices[:, :, 0:1]) * feature_map_stride * voxel_size[2] + point_cloud_range[2] + center_z

    box_part_list = [xs, ys, center_z, dim, angle]

    if not vel is None:
        vel = gather_feat_idx(vel, inds, batch_size, batch_idx)
        box_part_list.append(vel)

    if not iou is None:
        iou = gather_feat_idx(iou, inds, batch_size, batch_idx)
        iou = torch.clamp(iou, min=0, max=1.)

    final_box_preds = torch.cat((box_part_list), dim=-1)
    final_scores = scores.view(batch_size, K)
    final_class_ids = class_ids.view(batch_size, K)
    if not add_features is None:
        add_features = [add_feature.view(batch_size, K, add_feature.shape[-1]) for add_feature in add_features]

    assert post_center_limit_range is not None
    mask = (final_box_preds[..., :3] >= post_center_limit_range[:3]).all(2)
    mask &= (final_box_preds[..., :3] <= post_center_limit_range[3:]).all(2)

    if score_thresh is not None:
        mask &= (final_scores > score_thresh)

    ret_pred_dicts = []
    for k in range(batch_size):
        cur_mask = mask[k]
        cur_boxes = final_box_preds[k, cur_mask]
        cur_scores = final_scores[k, cur_mask]
        cur_labels = final_class_ids[k, cur_mask]
        cur_add_features = [add_feature[k, cur_mask] for add_feature in add_features] if not add_features is None else None
        cur_iou = iou[k, cur_mask] if not iou is None else None

        ret_pred_dicts.append({
            'pred_boxes': cur_boxes,
            'pred_scores': cur_scores,
            'pred_labels': cur_labels,
            'pred_ious': cur_iou,
            'add_features': cur_add_features,
        })
    return ret_pred_dicts

3.3 object tracking

voxel association

   The query voxel serves as a proxy for the object center, and objects are associated across frames by the L2 distance between their query voxels.
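As a rough illustration of this association step (not the authors' tracking code; the helper name and distance threshold are assumptions), greedy nearest-neighbour matching on query-voxel BEV positions could look like:

import torch

def associate_by_query_voxel(prev_xy: torch.Tensor, curr_xy: torch.Tensor, max_dist: float = 2.0):
    # prev_xy: (M, 2) query-voxel BEV positions of existing tracks
    # curr_xy: (N, 2) query-voxel BEV positions of current detections
    # returns a list of (track_idx, det_idx) matches
    if prev_xy.numel() == 0 or curr_xy.numel() == 0:
        return []
    dist = torch.cdist(prev_xy, curr_xy)          # (M, N) pairwise L2 distances
    matches = []
    used_tracks, used_dets = set(), set()
    # greedily take the globally closest unmatched pair until the threshold is exceeded
    for flat in torch.argsort(dist.flatten()):
        i, j = divmod(int(flat), dist.shape[1])
        if i in used_tracks or j in used_dets:
            continue
        if dist[i, j] > max_dist:
            break
        matches.append((i, j))
        used_tracks.add(i)
        used_dets.add(j)
    return matches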

        


Source: blog.csdn.net/huang_victor/article/details/130065986