4
4
import torch .nn .functional as F
5
5
6
6
from ...utils import box_coder_utils , common_utils , loss_utils
7
- from ..model_utils .model_nms_utils import class_agnostic_nms
7
+ from ..model_utils .model_nms_utils import class_agnostic_nms , fast_bev_nms
8
8
from .target_assigner .proposal_target_layer import ProposalTargetLayer
9
9
10
10
@@ -63,7 +63,7 @@ def proposal_layer(self, batch_dict, nms_config):
63
63
"""
64
64
batch_size = batch_dict ['batch_size' ]
65
65
batch_box_preds = batch_dict ['batch_box_preds' ]
66
- batch_cls_preds = batch_dict ['batch_cls_preds' ]
66
+ batch_cls_preds = torch . sigmoid ( batch_dict ['batch_cls_preds' ])
67
67
rois = batch_box_preds .new_zeros ((batch_size , nms_config .NMS_POST_MAXSIZE , batch_box_preds .shape [- 1 ]))
68
68
roi_scores = batch_box_preds .new_zeros ((batch_size , nms_config .NMS_POST_MAXSIZE ))
69
69
roi_labels = batch_box_preds .new_zeros ((batch_size , nms_config .NMS_POST_MAXSIZE ), dtype = torch .long )
@@ -83,9 +83,22 @@ def proposal_layer(self, batch_dict, nms_config):
83
83
if nms_config .MULTI_CLASSES_NMS :
84
84
raise NotImplementedError
85
85
else :
86
- selected , selected_scores = class_agnostic_nms (
87
- box_scores = cur_roi_scores , box_preds = box_preds , nms_config = nms_config
88
- )
86
+ if self .training :
87
+ selected , selected_scores = class_agnostic_nms (
88
+ box_scores = cur_roi_scores , box_preds = box_preds , nms_config = nms_config
89
+ )
90
+ else :
91
+ if nms_config .get ("USE_FAST_NMS" , False ):
92
+ selected , selected_scores = fast_bev_nms (
93
+ box_scores = cur_roi_scores , box_preds = box_preds , nms_config = nms_config , score_thresh = nms_config .SCORE_THRESH
94
+ )
95
+ else :
96
+ selected , selected_scores = class_agnostic_nms (
97
+ box_scores = cur_roi_scores , box_preds = box_preds , nms_config = nms_config
98
+ )
99
+ # selected, selected_scores = class_agnostic_nms(
100
+ # box_scores=cur_roi_scores, box_preds=box_preds, nms_config=nms_config
101
+ # )
89
102
90
103
rois [index , :len (selected ), :] = box_preds [selected ]
91
104
roi_scores [index , :len (selected )] = cur_roi_scores [selected ]
@@ -189,36 +202,6 @@ def get_box_reg_layer_loss(self, forward_ret_dict):
189
202
190
203
rcnn_loss_reg += loss_corner
191
204
tb_dict ['rcnn_loss_corner' ] = loss_corner .item ()
192
-
193
- if loss_cfgs .GRID_3D_IOU_LOSS and fg_sum > 0 :
194
- fg_rcnn_reg = rcnn_reg .view (rcnn_batch_size , - 1 )[fg_mask ]
195
- fg_roi_boxes3d = roi_boxes3d .view (- 1 , code_size )[fg_mask ]
196
-
197
- fg_roi_boxes3d = fg_roi_boxes3d .view (1 , - 1 , code_size )
198
- batch_anchors = fg_roi_boxes3d .clone ().detach ()
199
- roi_ry = fg_roi_boxes3d [:, :, 6 ].view (- 1 )
200
- roi_xyz = fg_roi_boxes3d [:, :, 0 :3 ].view (- 1 , 3 )
201
- batch_anchors [:, :, 0 :3 ] = 0
202
- rcnn_boxes3d = self .box_coder .decode_torch (
203
- fg_rcnn_reg .view (batch_anchors .shape [0 ], - 1 , code_size ), batch_anchors
204
- ).view (- 1 , code_size )
205
-
206
- rcnn_boxes3d = common_utils .rotate_points_along_z (
207
- rcnn_boxes3d .unsqueeze (dim = 1 ), roi_ry
208
- ).squeeze (dim = 1 )
209
- rcnn_boxes3d [:, 0 :3 ] += roi_xyz
210
-
211
- loss_iou3d = loss_utils .get_gridify_iou3d_loss (
212
- gt_of_rois_src [fg_mask ][:, :7 ],
213
- rcnn_boxes3d [:, :7 ]
214
-
215
- )
216
-
217
- loss_iou3d = loss_iou3d .mean ()
218
- loss_iou3d = loss_iou3d * loss_cfgs .LOSS_WEIGHTS ['rcnn_iou3d_weight' ]
219
-
220
- rcnn_loss_reg += loss_iou3d
221
- tb_dict ['rcnn_loss_iou3d' ] = loss_iou3d .item ()
222
205
else :
223
206
raise NotImplementedError
224
207
0 commit comments