3
3
# Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology
4
4
# Author: Qingwen Zhang (https://kin-zhang.github.io/)
5
5
#
6
- # This work is licensed under the terms of the MIT license.
7
- # For a copy, see <https://opensource.org/licenses/MIT>.
8
-
6
+ # Description: Define the loss function for training.
9
7
"""
10
8
import torch
11
9
@@ -64,70 +62,4 @@ def ff3dLoss(res_dict):
64
62
is_foreground_class = (classes > 0 ) # 0 is background, ref: FOREGROUND_BACKGROUND_BREAKDOWN
65
63
background_scalar = is_foreground_class .float () * 0.9 + 0.1
66
64
error = error * background_scalar
67
- return error .mean ()
68
-
69
- # ==========================> From AV2.0 Eval Official Scripts.
70
- from typing import Dict , Final
71
- import os , sys
72
- BASE_DIR = os .path .abspath (os .path .join ( os .path .dirname ( __file__ ), '../..' ))
73
- sys .path .append (BASE_DIR )
74
- from scripts .utils .av2_eval import compute_metrics , FOREGROUND_BACKGROUND_BREAKDOWN
75
- import numpy as np
76
-
77
- CLOSE_DISTANCE_THRESHOLD : Final = 35.0
78
- EPS : Final = 1e-6
79
- def compute_epe (res_dict , indices , eps = 1e-8 ):
80
- epe_sum = 0
81
- count_sum = 0
82
- for index in indices :
83
- count = res_dict ['Count' ][index ]
84
- if count != 0 :
85
- epe_sum += res_dict ['EPE' ][index ] * count
86
- count_sum += count
87
- return epe_sum / (count_sum + eps ) if count_sum != 0 else 0.0
88
-
89
- # after ground mask already, not origin N, 3 but without ground points
90
- def evaluate_leaderboard (est_flow , rigid_flow , pc0 , gt_flow , is_valid , pts_ids ):
91
-
92
- # gt_is_dynamic = (gt_flow - rigid_flow).norm(dim=1, p=2) > 0.05
93
- gt_is_dynamic = torch .linalg .vector_norm (gt_flow - rigid_flow , dim = - 1 ) >= 0.05
94
- mask_ = ~ est_flow .isnan ().any (dim = 1 ) & ~ rigid_flow .isnan ().any (dim = 1 ) & ~ pc0 [:, :3 ].isnan ().any (dim = 1 ) & ~ gt_flow .isnan ().any (dim = 1 )
95
- mask_no_nan = mask_ & ~ gt_is_dynamic .isnan () & ~ is_valid .isnan () & ~ pts_ids .isnan ()
96
- est_flow = est_flow [mask_no_nan , :]
97
- rigid_flow = rigid_flow [mask_no_nan , :]
98
- pc0 = pc0 [mask_no_nan , :]
99
- gt_flow = gt_flow [mask_no_nan , :]
100
- gt_is_dynamic = gt_is_dynamic [mask_no_nan ]
101
- is_valid = is_valid [mask_no_nan ]
102
- pts_ids = pts_ids [mask_no_nan ]
103
-
104
- est_is_dynamic = torch .linalg .vector_norm (est_flow - rigid_flow , dim = - 1 ) >= 0.05
105
- is_close = torch .all (torch .abs (pc0 [:, :2 ]) <= CLOSE_DISTANCE_THRESHOLD , dim = 1 )
106
- res_dict = compute_metrics (
107
- est_flow .detach ().cpu ().numpy ().astype (float ),
108
- est_is_dynamic .detach ().cpu ().numpy ().astype (bool ),
109
- gt_flow .detach ().cpu ().numpy ().astype (float ),
110
- pts_ids .detach ().cpu ().numpy ().astype (np .uint8 ),
111
- gt_is_dynamic .detach ().cpu ().numpy ().astype (bool ),
112
- is_close .detach ().cpu ().numpy ().astype (bool ),
113
- is_valid .detach ().cpu ().numpy ().astype (bool ),
114
- FOREGROUND_BACKGROUND_BREAKDOWN ,
115
- )
116
-
117
- # reference: eval.py L503
118
- # we need Dynamic IoU and EPE 3-Way Average to calculate loss!
119
- # weighted: (x[metric_type.value] * x.Count).sum() / total
120
- # 'Class': ['Background', 'Background', 'Background', 'Background', 'Foreground', 'Foreground', 'Foreground', 'Foreground']
121
- # 'Motion': ['Dynamic', 'Dynamic', 'Static', 'Static', 'Dynamic', 'Dynamic', 'Static', 'Static']
122
- # 'Distance': ['Close', 'Far', 'Close', 'Far', 'Close', 'Far', 'Close', 'Far']
123
-
124
- EPE_Background_Static = compute_epe (res_dict , [2 , 3 ])
125
- EPE_Dynamic = compute_epe (res_dict , [4 , 5 ])
126
- EPE_Foreground_Static = compute_epe (res_dict , [6 , 7 ])
127
-
128
- Dynamic_IoU = sum (res_dict ['TP' ]) / (sum (res_dict ['TP' ]) + sum (res_dict ['FP' ]) + sum (res_dict ['FN' ])+ EPS )
129
- # EPE_Dynamic is nan?
130
- if np .isnan (EPE_Dynamic ) or np .isnan (Dynamic_IoU ) or np .isnan (EPE_Background_Static ) or np .isnan (EPE_Foreground_Static ):
131
- print (res_dict )
132
-
133
- return EPE_Background_Static , EPE_Dynamic , EPE_Foreground_Static , Dynamic_IoU
65
+ return error .mean ()
0 commit comments