From 06bb261648f7569070a4aec787a4e4f01a677f8c Mon Sep 17 00:00:00 2001 From: Zhi Tian Date: Mon, 25 May 2020 21:11:41 +0930 Subject: [PATCH] fixed fcos demo score threshold bugs --- adet/modeling/fcos/fcos_outputs.py | 8 +++++--- demo/demo.py | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/adet/modeling/fcos/fcos_outputs.py b/adet/modeling/fcos/fcos_outputs.py index a2d61987d..cdcd1fcbf 100644 --- a/adet/modeling/fcos/fcos_outputs.py +++ b/adet/modeling/fcos/fcos_outputs.py @@ -428,13 +428,15 @@ def forward_for_single_feature_map( # if self.thresh_with_ctr is True, we multiply the classification # scores with centerness scores before applying the threshold. if self.thresh_with_ctr: - box_cls = box_cls * ctrness[:, :, None] + # sqrt is used to calibrate the scores, which does not affect the COCO AP. + box_cls = torch.sqrt(box_cls * ctrness[:, :, None]) candidate_inds = box_cls > self.pre_nms_thresh pre_nms_top_n = candidate_inds.view(N, -1).sum(1) pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n) if not self.thresh_with_ctr: - box_cls = box_cls * ctrness[:, :, None] + # sqrt is used to calibrate the scores, which does not affect the COCO AP. + box_cls = torch.sqrt(box_cls * ctrness[:, :, None]) results = [] for i in range(N): @@ -473,7 +475,7 @@ def forward_for_single_feature_map( boxlist = Instances(image_sizes[i]) boxlist.pred_boxes = Boxes(detections) - boxlist.scores = torch.sqrt(per_box_cls) + boxlist.scores = per_box_cls boxlist.pred_classes = per_class boxlist.locations = per_locations if top_feat is not None: diff --git a/demo/demo.py b/demo/demo.py index 89c0aaa00..bf0fd5aa2 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -26,6 +26,7 @@ def setup_cfg(args): cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold cfg.MODEL.FCOS.INFERENCE_TH_TEST = args.confidence_threshold + cfg.MODEL.FCOS.THRESH_WITH_CTR = True cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold cfg.freeze() return cfg