|
| 1 | +from torch.utils.data import Dataset |
| 2 | +from data_classes import PointCloud, Box |
| 3 | +from pyquaternion import Quaternion |
| 4 | +import numpy as np |
| 5 | +import pandas as pd |
| 6 | +import os |
| 7 | +import torch |
| 8 | +from tqdm import tqdm |
| 9 | +import kitty_utils as utils |
| 10 | +from kitty_utils import getModel |
| 11 | +from searchspace import KalmanFiltering |
| 12 | +import logging |
| 13 | +from functools import partial |
| 14 | +import copy |
| 15 | + |
class kittiDataset():
    """Reader for a KITTI-style folder with velodyne/label_02/calib children.

    ``path`` is the root folder; scans are read from ``<path>/velodyne``,
    labels from ``<path>/label_02`` and calibration from ``<path>/calib``.
    """

    def __init__(self, path):
        self.KITTI_Folder = path
        self.KITTI_velo = os.path.join(self.KITTI_Folder, "velodyne")
        self.KITTI_label = os.path.join(self.KITTI_Folder, "label_02")

    def getSceneID(self, split):
        """Return the list of scene numbers selected by the ``split`` name.

        Matching is case-insensitive and substring based: "TRAIN", "VALID"
        and "TEST" choose the subset, an additional "TINY" reduces it to a
        single scene. Any other value selects all 21 scenes.
        """
        split_name = split.upper()
        if "TRAIN" in split_name:  # Training set
            if "TINY" in split_name:
                sceneID = [0]
            else:
                sceneID = list(range(0, 17))
        elif "VALID" in split_name:  # Validation set
            if "TINY" in split_name:
                sceneID = [18]
            else:
                sceneID = list(range(17, 19))
        elif "TEST" in split_name:  # Testing set
            if "TINY" in split_name:
                sceneID = [19]
            else:
                sceneID = list(range(19, 21))
        else:  # Full dataset
            sceneID = list(range(21))
        return sceneID

    def getBBandPC(self, anno):
        """Return ``(PC, box)`` for one annotation row.

        Reads ``calib/<scene>.txt``, extends the 3x4 ``Tr_velo_cam`` entry
        with a ``[0, 0, 0, 1]`` row and hands it to
        ``getPCandBBfromPandas`` as the point-cloud transform.
        """
        calib_path = os.path.join(self.KITTI_Folder, 'calib',
                                  anno['scene'] + ".txt")
        calib = self.read_calib_file(calib_path)
        transf_mat = np.vstack((calib["Tr_velo_cam"], np.array([0, 0, 0, 1])))
        PC, box = self.getPCandBBfromPandas(anno, transf_mat)
        return PC, box

    def getListOfAnno(self, sceneID, category_name="Car"):
        """Return one list of annotation rows per (scene, track_id) pair.

        Only sub-directories of the velodyne folder whose name is a number
        contained in ``sceneID`` are considered, and only label rows whose
        ``type`` equals ``category_name`` are kept. Each returned row is a
        pandas Series with an extra leading ``scene`` field.
        """
        wanted = set(sceneID)
        # Non-numeric directory names are skipped instead of crashing int().
        # Sorting makes the tracklet order independent of the arbitrary
        # order in which os.listdir returns entries.
        list_of_scene = sorted(
            path for path in os.listdir(self.KITTI_velo)
            if os.path.isdir(os.path.join(self.KITTI_velo, path))
            and path.isdecimal()
            and int(path) in wanted
        )
        list_of_tracklet_anno = []
        for scene in list_of_scene:
            label_file = os.path.join(self.KITTI_label, scene + ".txt")
            df = pd.read_csv(
                label_file,
                sep=' ',
                names=[
                    "frame", "track_id", "type", "truncated", "occluded",
                    "alpha", "bbox_left", "bbox_top", "bbox_right",
                    "bbox_bottom", "height", "width", "length", "x", "y", "z",
                    "rotation_y"
                ])
            df = df[df["type"] == category_name]
            df.insert(loc=0, column="scene", value=scene)
            for track_id in df.track_id.unique():
                df_tracklet = df[df["track_id"] == track_id]
                df_tracklet = df_tracklet.reset_index(drop=True)
                tracklet_anno = [anno for _, anno in df_tracklet.iterrows()]
                list_of_tracklet_anno.append(tracklet_anno)

        return list_of_tracklet_anno

    def getPCandBBfromPandas(self, box, calib):
        """Build the bounding box and load the point cloud for one label row.

        The box centre is the labelled ``(x, y, z)`` with ``y`` moved up by
        half the height, its size is ``[width, length, height]`` and its
        orientation is a rotation of ``rotation_y`` about the y axis
        composed with a quarter turn about the x axis.

        The scan is read from ``velodyne/<scene>/<frame:06>.bin`` as float32
        rows of four values and transformed by ``calib``. If loading fails
        for any reason a single-point placeholder cloud at the origin is
        returned instead, so callers always get a ``PointCloud``.
        """
        center = [box["x"], box["y"] - box["height"] / 2, box["z"]]
        size = [box["width"], box["length"], box["height"]]
        orientation = Quaternion(
            axis=[0, 1, 0], radians=box["rotation_y"]) * Quaternion(
                axis=[1, 0, 0], radians=np.pi / 2)
        BB = Box(center, size, orientation)

        velodyne_path = None
        try:
            # VELODYNE PointCloud
            velodyne_path = os.path.join(self.KITTI_velo, box["scene"],
                                         '{:06}.bin'.format(box["frame"]))
            PC = PointCloud(
                np.fromfile(velodyne_path, dtype=np.float32).reshape(-1, 4).T)
            PC.transform(calib)
        except Exception as err:
            # Best-effort fallback kept from the original code: some scans
            # are missing (0001/[000177-000180].bin). `except Exception`
            # no longer swallows KeyboardInterrupt / SystemExit, and the
            # failure is logged so unexpected errors stay visible.
            logging.warning("Could not load point cloud %s (%s); "
                            "using a placeholder", velodyne_path, err)
            PC = PointCloud(np.array([[0, 0, 0]]).T)

        return PC, BB

    def read_calib_file(self, filepath):
        """Read in a calibration file and parse into a dictionary.

        Each non-empty line contributes ``first_token -> 3x4 float array``.
        Lines whose remaining tokens are not floats, or do not hold exactly
        twelve values, are ignored. Blank lines are skipped.
        """
        data = {}
        with open(filepath, 'r') as f:
            for line in f:
                values = line.split()
                if not values:
                    # A blank line has no key; indexing it would raise
                    # IndexError, which the handler below does not catch.
                    continue
                # Non-float values (e.g. dates) and entries that are not
                # 3x4 both raise ValueError and are not needed here.
                try:
                    data[values[0]] = np.array(
                        [float(x) for x in values[1:]]).reshape(3, 4)
                except ValueError:
                    pass
        return data
| 116 | + |
| 117 | + |
class SiameseDataset(Dataset):
    """Common base for the Siamese training and testing datasets.

    Builds a ``kittiDataset`` for ``path``, resolves the scenes selected by
    ``split`` and collects the annotations of ``category_name`` both per
    tracklet (``list_of_tracklet_anno``) and flattened (``list_of_anno``).
    Subclasses provide ``getitem`` and ``__len__``.
    """

    def __init__(self,
                 input_size,
                 path,
                 split,
                 category_name="Car",
                 regress="GAUSSIAN",
                 offset_BB=0,
                 scale_BB=1.0):
        self.dataset = kittiDataset(path=path)

        self.input_size = input_size
        self.split = split
        self.sceneID = self.dataset.getSceneID(split=split)
        # Expose the loader directly so subclasses can call self.getBBandPC.
        self.getBBandPC = self.dataset.getBBandPC

        self.category_name = category_name
        self.regress = regress
        # These two were accepted but silently discarded before; keep them
        # so the base class honours its own signature.
        self.offset_BB = offset_BB
        self.scale_BB = scale_BB

        self.list_of_tracklet_anno = self.dataset.getListOfAnno(
            self.sceneID, category_name)

        # Flattened view: every annotation of every tracklet, in order.
        self.list_of_anno = [
            anno for tracklet_anno in self.list_of_tracklet_anno
            for anno in tracklet_anno
        ]

    def isTiny(self):
        """Return True when the split name contains "TINY" (any case)."""
        return ("TINY" in self.split.upper())

    def getitem(self, index):
        """Return the sample at ``index``; must be provided by subclasses."""
        raise NotImplementedError(
            "{} must implement getitem()".format(type(self).__name__))

    def __getitem__(self, index):
        return self.getitem(index)
| 151 | + |
| 152 | + |
class SiameseTrain(SiameseDataset):
    """Training dataset producing (search, template) point-cloud pairs.

    Every annotation yields ``num_candidates_perframe`` samples: candidate 0
    uses the unperturbed ground-truth box, the others use a randomly offset
    box. All cropped point clouds and boxes are preloaded in ``__init__``.
    """

    def __init__(self,
                 input_size,
                 path,
                 split="",
                 category_name="Car",
                 regress="GAUSSIAN",
                 sigma_Gaussian=1,
                 offset_BB=0,
                 scale_BB=1.0):
        super(SiameseTrain, self).__init__(
            input_size=input_size,
            path=path,
            split=split,
            category_name=category_name,
            regress=regress,
            offset_BB=offset_BB,
            scale_BB=scale_BB)

        self.sigma_Gaussian = sigma_Gaussian
        self.offset_BB = offset_BB
        self.scale_BB = scale_BB

        self.num_candidates_perframe = 4

        # One cropped point cloud and one box per flattened annotation.
        logging.info("preloading PC...")
        self.list_of_PCs = [None] * len(self.list_of_anno)
        self.list_of_BBs = [None] * len(self.list_of_anno)
        for index, anno in enumerate(tqdm(self.list_of_anno)):
            PC, box = self.getBBandPC(anno)
            new_PC = utils.cropPC(PC, box, offset=10)

            self.list_of_PCs[index] = new_PC
            self.list_of_BBs[index] = box
        logging.info("PC preloaded!")

        # One aggregated model per tracklet; each annotation is also tagged
        # with its tracklet number and its position inside that tracklet.
        logging.info("preloading Model..")
        self.model_PC = [None] * len(self.list_of_tracklet_anno)
        for i, list_of_anno in enumerate(tqdm(self.list_of_tracklet_anno)):
            PCs = []
            BBs = []
            for cnt, anno in enumerate(list_of_anno):
                this_PC, this_BB = self.getBBandPC(anno)
                PCs.append(this_PC)
                BBs.append(this_BB)
                anno["model_idx"] = i
                anno["relative_idx"] = cnt

            self.model_PC[i] = getModel(
                PCs, BBs, offset=self.offset_BB, scale=self.scale_BB)
        logging.info("Model preloaded!")

    def __getitem__(self, index):
        return self.getitem(index)

    def getPCandBBfromIndex(self, anno_idx):
        """Return the preloaded ``(point cloud, box)`` of one annotation."""
        this_PC = self.list_of_PCs[anno_idx]
        this_BB = self.list_of_BBs[anno_idx]
        return this_PC, this_BB

    def getitem(self, index):
        """Build one training sample.

        Returns a dict with keys ``'search'``, ``'template'``,
        ``'cls_label'`` and ``'reg_label'``. When the cropped search region
        has at most 10 points, or the template at most 20, a different
        randomly chosen index is returned instead.
        """
        anno_idx = self.getAnnotationIndex(index)
        sample_idx = self.getSearchSpaceIndex(index)

        def random_box(box, center_offset, w_ratio, h_ratio, flag):
            # Optional box jitter; returns the input untouched when disabled.
            if not flag:
                return box
            box = copy.deepcopy(box)
            box.center[0] += center_offset[0] * box.wlh[1]
            box.center[1] += center_offset[1] * box.wlh[0]
            box.wlh[0] *= w_ratio
            box.wlh[1] *= h_ratio
            return box

        # The random parameters are drawn once so that every box of this
        # sample gets the same jitter. `< 0.0` never holds for a uniform
        # draw in [0, 1), so the jitter is currently switched off.
        random_box_func = partial(random_box, **dict(
            center_offset=[np.random.uniform(-0.4, 0.4),
                           np.random.uniform(-0.4, 0.4)],
            w_ratio=np.random.uniform(0.3, 1.0),
            h_ratio=np.random.uniform(0.3, 1.0),
            flag=np.random.uniform() < 0.0  # prob
        ))

        if sample_idx == 0:
            # Candidate 0 is always the exact ground-truth box.
            sample_offsets = np.zeros(4)
        else:
            gaussian = KalmanFiltering(bnd=[1, 1, 1, 1])
            sample_offsets = gaussian.sample(1)[0]
            sample_offsets[1] /= 2.0
            sample_offsets[0] *= 2

        this_anno = self.list_of_anno[anno_idx]
        this_PC, this_BB = self.getPCandBBfromIndex(anno_idx)
        # Random bbox
        sample_BB = utils.getOffsetBB(this_BB, sample_offsets)

        sample_BB = random_box_func(box=sample_BB)

        sample_PC, sample_label, sample_reg = utils.cropAndCenterPC_label(
            this_PC, sample_BB, this_BB, sample_offsets,
            offset=self.offset_BB, scale=self.scale_BB)

        if sample_PC.nbr_points() <= 10:
            # Too few points in the search region: draw another sample.
            return self.getitem(np.random.randint(0, self.__len__()))

        # Same remark as for the box jitter: currently always False.
        random_downsample = np.random.uniform() < 0.0

        def _random_sample_pts(pc, num):
            # Resample `num` point columns with replacement (in place).
            p = np.array(pc.points, dtype=np.float32)
            if p.shape[1] < 10:
                return pc
            new_idx = np.random.randint(
                low=0, high=p.shape[1], size=num, dtype=np.int64)
            p = p[:, new_idx]
            pc.points = p
            return pc

        if random_downsample:
            random_downsample_pc_func = partial(
                _random_sample_pts,
                num=np.random.randint(
                    min(128, sample_PC.points.shape[1] - 1),
                    sample_PC.points.shape[1]))
            sample_PC = random_downsample_pc_func(sample_PC)
        sample_PC, sample_label, sample_reg = utils.regularizePCwithlabel(
            sample_PC, sample_label, sample_reg, self.input_size)

        # The template comes from the previous frame of the SAME tracklet.
        # For the first frame of a tracklet there is no previous frame, so
        # the frame itself is used. (The former code used the global index 0
        # here, which belongs to an unrelated object for every tracklet but
        # the first one.)
        if this_anno["relative_idx"] == 0:
            prev_idx = anno_idx
        else:
            prev_idx = anno_idx - 1
        gt_PC_pre, gt_BB_pre = self.getPCandBBfromIndex(prev_idx)

        gt_BB_pre = random_box_func(box=gt_BB_pre)

        if sample_idx == 0:
            samplegt_offsets = np.zeros(4)
        else:
            samplegt_offsets = np.random.uniform(low=-0.3, high=0.3, size=4)
            samplegt_offsets[0] *= 2

        gt_BB_pre = utils.getOffsetBB(gt_BB_pre, samplegt_offsets)
        gt_PC = getModel([gt_PC_pre], [gt_BB_pre],
                         offset=self.offset_BB, scale=self.scale_BB)
        if random_downsample:
            gt_PC = random_downsample_pc_func(gt_PC)

        if gt_PC.nbr_points() <= 20:
            # Too few points in the template: draw another sample.
            return self.getitem(np.random.randint(0, self.__len__()))
        gt_PC = utils.regularizePC(gt_PC, self.input_size)

        ret = {
            'search': sample_PC,
            'template': gt_PC,
            'cls_label': sample_label,  # whether in box
            'reg_label': sample_reg  # box
        }
        return ret

    def __len__(self):
        nb_anno = len(self.list_of_anno)
        return nb_anno * self.num_candidates_perframe

    def getAnnotationIndex(self, index):
        """Map a sample index to the index of its annotation."""
        # Floor division keeps the arithmetic in integers.
        return int(index // self.num_candidates_perframe)

    def getSearchSpaceIndex(self, index):
        """Map a sample index to its candidate number within the frame."""
        return int(index % self.num_candidates_perframe)
| 322 | + |
| 323 | + |
class SiameseTest(SiameseDataset):
    """Evaluation dataset: one item per tracklet instead of per frame."""

    def __init__(self,
                 input_size,
                 path,
                 split="",
                 category_name="Car",
                 regress="GAUSSIAN",
                 offset_BB=0,
                 scale_BB=1.0):
        base_kwargs = dict(
            input_size=input_size,
            path=path,
            split=split,
            category_name=category_name,
            regress=regress,
            offset_BB=offset_BB,
            scale_BB=scale_BB,
        )
        super(SiameseTest, self).__init__(**base_kwargs)
        self.split = split
        self.offset_BB = offset_BB
        self.scale_BB = scale_BB

    def getitem(self, index):
        """Return ``(PCs, BBs, annotations)`` for the tracklet at ``index``.

        The point clouds and boxes are loaded frame by frame through
        ``getBBandPC``, in annotation order.
        """
        tracklet = self.list_of_tracklet_anno[index]
        loaded = [self.getBBandPC(frame_anno) for frame_anno in tracklet]
        PCs = [cloud for cloud, _ in loaded]
        BBs = [bbox for _, bbox in loaded]
        return PCs, BBs, tracklet

    def __len__(self):
        """Number of tracklets available for evaluation."""
        return len(self.list_of_tracklet_anno)
| 358 | + |
| 359 | + |
if __name__ == '__main__':
    # Manual smoke test: build the tiny training split and fetch a few
    # samples. The dataset path is machine specific.
    dataset_Training = SiameseTrain(
        input_size=2048,
        path='/data/qihaozhe/Kitty_data/training',
        split='Tiny_Train',
        category_name='Car',
        offset_BB=0,
        scale_BB=1.15)
    for sample_index in (201, 30, 100, 120, 200):
        aa = dataset_Training.getitem(sample_index)
    # getitem returns a dict, so the former positional access aa[2] raised
    # KeyError; the third element of the old tuple was the regression label.
    # NOTE(review): .numpy() assumes the label is a torch tensor - confirm.
    asdf = aa['reg_label'].numpy()
0 commit comments