淘先锋技术网

首页 1 2 3 4 5 6 7

点云中的数据增强方法

本文基于相机坐标系展示（KITTI 中的标签定义在相机坐标系下）。

一、旋转（在相机坐标系中绕 y 轴旋转）

import numpy as np

def rotation_points_single_angle(points, angle, axis=0):
    """Rotate row-vector points about one coordinate axis.

    Computes ``points @ M`` where ``M`` is a 3x3 matrix built from
    ``sin(angle)`` / ``cos(angle)`` and cast to ``points.dtype``.

    Sign convention (follows directly from the matrices below):
      * axis=1: right-handed rotation by +angle about y
        (x' = x*cos + z*sin, z' = -x*sin + z*cos).
      * axis=0 and axis=2 (or -1): rotation by -angle about x / z.

    Args:
        points: array of shape [N, 3].
        angle: rotation angle in radians (fed to np.sin / np.cos).
        axis: 0, 1, 2, or -1 (alias for 2).

    Returns:
        The rotated [N, 3] array, with the dtype of ``points``.

    Raises:
        ValueError: if ``axis`` is not one of the values above.
    """
    sin_a = np.sin(angle)
    cos_a = np.cos(angle)

    if axis == 0:
        rows = [[1, 0, 0],
                [0, cos_a, -sin_a],
                [0, sin_a, cos_a]]
    elif axis == 1:
        rows = [[cos_a, 0, -sin_a],
                [0, 1, 0],
                [sin_a, 0, cos_a]]
    elif axis in (2, -1):
        rows = [[cos_a, -sin_a, 0],
                [sin_a, cos_a, 0],
                [0, 0, 1]]
    else:
        raise ValueError("axis should in range")

    # Right-multiplying row vectors by M applies M.T to each point.
    return points @ np.array(rows, dtype=points.dtype)



def global_rotation(gt_boxes, points, max_angle=np.pi / 2):
    """Rotate the whole scene by one random angle about the camera y axis.

    The angle is drawn uniformly from [-max_angle, max_angle] **in radians**.
    The same rotation is applied to the point cloud and to the box centres,
    and the angle is added to every box heading. Both arrays are modified in
    place and also returned.

    Args:
        gt_boxes: array whose columns 0-2 are the box centre (x, y, z) and
            whose column 6 is the heading -- layout assumed from the indices
            used here; confirm against the label loader.
        points: array whose columns 0-2 are (x, y, z); extra columns such as
            intensity are left untouched.
        max_angle: half-width of the sampling range in radians. The default
            pi/2 corresponds to the +-90 degrees the original code intended.

    Returns:
        (gt_boxes, points)
    """
    # np.sin / np.cos (inside rotation_points_single_angle) expect radians.
    # The old uniform(-90, 90) therefore rotated by up to ~14 full turns and
    # added that degree-scale number to radian headings.
    noise_rotation = np.random.uniform(-max_angle, max_angle)
    points[:, :3] = rotation_points_single_angle(
        points[:, :3], noise_rotation, axis=1)
    gt_boxes[:, :3] = rotation_points_single_angle(
        gt_boxes[:, :3], noise_rotation, axis=1)
    # For axis=1 the helper rotates by +angle about y, so the heading
    # (rotation about y) shifts by the same +angle.
    gt_boxes[:, 6] += noise_rotation
    return gt_boxes, points

 输出:

二、镜像

x 轴镜像和 z 轴镜像（两种镜像各自以给定概率独立触发）

import numpy as np
def random_flip(gt_boxes, points, probability=0.5):
    """Randomly mirror boxes and points across the camera x and/or z axis.

    Two independent draws (each True with ``probability``) decide whether
    to negate the x coordinate and/or the z coordinate. Box headings are
    updated to match the mirrored geometry. Both arrays are modified in
    place and also returned.

    Heading updates assume the KITTI-style convention in which a box with
    heading ``ry`` (rotation about the camera y axis) faces along
    (cos ry, 0, -sin ry) -- the same convention under which the x-flip's
    ``pi - ry`` is correct.

    Args:
        gt_boxes: array with box centre x in column 0, z in column 2 and
            heading in column 6 (layout assumed from the indices used here).
        points: array with x in column 0 and z in column 2.
        probability: chance of each individual flip.

    Returns:
        (gt_boxes, points)
    """
    # Draw x first, then z, so the random stream is consumed as before.
    flip_x = np.random.choice(
        [False, True], replace=False, p=[1 - probability, probability])
    flip_z = np.random.choice(
        [False, True], replace=False, p=[1 - probability, probability])
    if flip_z:
        # Negating z maps (cos ry, 0, -sin ry) to (cos ry, 0, sin ry),
        # i.e. ry -> -ry. The former "-ry + pi" turned every box around.
        gt_boxes[:, 2] = -gt_boxes[:, 2]
        gt_boxes[:, 6] = -gt_boxes[:, 6]
        points[:, 2] = -points[:, 2]
    if flip_x:
        # Negating x maps the forward direction to (-cos ry, 0, -sin ry),
        # i.e. ry -> pi - ry.
        gt_boxes[:, 0] = -gt_boxes[:, 0]
        gt_boxes[:, 6] = -gt_boxes[:, 6] + np.pi
        points[:, 0] = -points[:, 0]
    return gt_boxes, points

x轴镜像:

三、真值提取（提取标注框内的点云）


def points_in_box_aabb(points, corners):
    """Return the rows of ``points`` strictly inside the AABB of ``corners``.

    The axis-aligned bounding box is the per-axis min/max of the corner
    coordinates (columns 0-2). Points lying exactly on a face are excluded,
    matching the strict ``>`` / ``<`` comparisons of the original snippet.
    Extra point columns (e.g. intensity) are carried through unchanged.

    NOTE: if a box's edges are not aligned with the coordinate axes (e.g. a
    heading that is not a multiple of 90 degrees), the AABB of its corners is
    larger than the box, so points near the corners of such boxes are
    over-selected.

    Args:
        points: (N, C) array with C >= 3; columns 0-2 are x, y, z.
        corners: (K, >=3) array of box corners (8 corners for a 3D box).

    Returns:
        (M, C) array with the selected rows of ``points``.
    """
    corner_xyz = np.asarray(corners)[:, :3]
    lower = corner_xyz.min(axis=0)
    upper = corner_xyz.max(axis=0)
    xyz = points[:, :3]
    inside = np.all((xyz > lower) & (xyz < upper), axis=1)
    return points[inside]


if __name__ == '__main__':
    import time

    # Demo inputs -- replace with real data:
    #   pt:  point cloud in camera coordinates, one point per row (x, y, z, ...)
    #   det: per-label box corners; det[i] is the (8, 3) corner array of box i
    pt = np.random.uniform(-20.0, 20.0, size=(100000, 4))
    det = np.array([[[x, y, z]
                     for x in (-1.0, 1.0)
                     for y in (-1.0, 1.0)
                     for z in (5.0, 9.0)]])

    t = time.perf_counter()
    dets = det[0]
    print(dets.min(axis=0), dets.max(axis=0))  # per-axis box bounds
    nn = points_in_box_aabb(pt, dets)
    print(time.perf_counter() - t)
运行耗时（约为其他提取方法的百分之一）：
0.0018105506896972656 秒