-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
154 lines (121 loc) · 7.39 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
import utils
class YOLOLoss(torch.nn.Module):
    """YOLO-style sum-of-squared-errors detection loss.

    Channel layout along the last axis, as indexed by ``forward``:

    * predictions: ``class_amount`` class scores, followed by
      ``bbox_pred_amount`` groups of 5 channels ``[obj, box(4)]``;
    * targets: ``class_amount`` class scores, then ``[obj, box(4)]``.

    Inside each box the channels at positions 2:4 are width and height; they are
    compared through a square root.

    :param grid_size: number of grid cells per side (stored; not used by ``forward``)
    :param bbox_pred_amount: number of boxes predicted per cell
    :param class_amount: number of classes
    :param lambda_coord: weight of the box-coordinate loss
    :param lambda_obj: weight of the objectness loss in cells that contain an object
    :param lambda_noobj: weight of the objectness loss in cells without an object
    :param lambda_classification: weight of the classification loss
    :param reduction: reduction passed to ``torch.nn.MSELoss``. Only reductions that
        return a scalar ('sum', 'mean') work, because the partial losses are
        flattened tensors of different lengths before they are added together.
    :param eps: added under the square root of predicted width/height
    """

    def __init__(self, grid_size, bbox_pred_amount, class_amount, lambda_coord=5, lambda_obj=1,
                 lambda_noobj=0.5, lambda_classification=1, reduction='sum', eps=1e-5):
        super().__init__()
        self.grid_size = grid_size
        self.bbox_pred_amount = bbox_pred_amount
        self.class_amount = class_amount
        self.lambda_coord = lambda_coord
        self.lambda_obj = lambda_obj
        self.lambda_noobj = lambda_noobj
        self.lambda_classification = lambda_classification
        self.reduction = reduction
        self.eps = eps
        # Built once here: changing self.reduction after construction has no effect.
        self.mse = torch.nn.MSELoss(reduction=self.reduction)

    def _obj_index(self, bbox_pred_number):
        """Return the channel index of the objectness score of predictor ``bbox_pred_number``."""
        return self.class_amount + 5 * bbox_pred_number

    def _pred_obj(self, predictions, bbox_pred_number):
        """Return the objectness score of predictor ``bbox_pred_number``, shape (..., 1)."""
        index = self._obj_index(bbox_pred_number)
        return predictions[..., index:index + 1]

    def _pred_bbox(self, predictions, bbox_pred_number):
        """Return the 4 box channels of predictor ``bbox_pred_number``, shape (..., 4)."""
        start = self._obj_index(bbox_pred_number) + 1
        return predictions[..., start:start + 4]

    def forward(self, predictions, targets):
        """
        Calculates YOLOLoss for given predictions and targets

        :param predictions: tensor with shape (batch_size, grid_size, grid_size, (class_amount + 5 * bbox_pred_amount))
        :param targets: tensor with shape (batch_size, grid_size, grid_size, (class_amount + 5))
        :return: loss value (scalar tensor for reduction 'sum' or 'mean')
        :raises ValueError: if the last dimension of ``predictions`` or ``targets`` is too
            small for the configured ``class_amount`` / ``bbox_pred_amount``
        """
        pred_channels = self.class_amount + 5 * self.bbox_pred_amount
        target_channels = self.class_amount + 5
        # Out-of-range slices would silently return fewer channels, so fail early and clearly.
        if predictions.shape[-1] < pred_channels:
            raise ValueError(
                f"predictions must have at least {pred_channels} channels in the last dimension "
                f"(class_amount + 5 * bbox_pred_amount), got {predictions.shape[-1]}"
            )
        if targets.shape[-1] < target_channels:
            raise ValueError(
                f"targets must have at least {target_channels} channels in the last dimension "
                f"(class_amount + 5), got {targets.shape[-1]}"
            )

        # Ground truth: "object present in cell" indicator and the box.
        obj_presented_target = targets[..., self.class_amount:self.class_amount + 1]  # (batch, S, S, 1)
        bbox_targets = targets[..., self.class_amount + 1:self.class_amount + 5]  # (batch, S, S, 4)

        # Responsible predictor = the one whose box has the highest IoU with the ground truth.
        # Only the argmax index is used, so no autograd graph is needed for the IoUs.
        with torch.no_grad():
            bboxes_ious = torch.stack([
                utils.bboxes_iou(self._pred_bbox(predictions, number), bbox_targets)
                for number in range(self.bbox_pred_amount)
            ])  # (bbox_pred_amount, batch, S, S, 1), assuming bboxes_iou keeps a trailing dim of 1
            _, bbox_responsible_number = torch.max(bboxes_ious, dim=0)

        # Box and objectness of the responsible predictor in every cell. torch.where (rather than
        # mask multiplication) keeps non-finite values of non-responsible predictors out.
        bbox_pred_responsible = predictions.new_zeros((*predictions.shape[:-1], 4))
        obj_presented_pred_responsible = predictions.new_zeros((*predictions.shape[:-1], 1))
        for number in range(self.bbox_pred_amount):
            bbox_responsible_mask = bbox_responsible_number == number
            bbox_pred_responsible = torch.where(
                bbox_responsible_mask, self._pred_bbox(predictions, number), bbox_pred_responsible
            )
            obj_presented_pred_responsible = torch.where(
                bbox_responsible_mask, self._pred_obj(predictions, number), obj_presented_pred_responsible
            )

        # Width/height are compared through a square root. Predictions may be negative (e.g. early
        # in training), so use a signed sqrt: sign(w) * sqrt(|w| + eps); eps avoids sqrt(0).
        pred_wh = bbox_pred_responsible[..., 2:4]
        pred_wh = torch.sign(pred_wh) * torch.sqrt(torch.abs(pred_wh) + self.eps)
        bbox_pred_responsible = torch.cat([bbox_pred_responsible[..., 0:2], pred_wh], dim=-1)
        bbox_targets = torch.cat([bbox_targets[..., 0:2], torch.sqrt(bbox_targets[..., 2:4])], dim=-1)

        # Coordinate loss, only over cells where an object exists.
        loss_coords = self.mse(
            torch.flatten(bbox_pred_responsible * obj_presented_target),
            torch.flatten(bbox_targets * obj_presented_target),
        )

        # Objectness loss of the responsible predictor, only over cells where an object exists.
        loss_obj_presented = self.mse(
            torch.flatten(obj_presented_pred_responsible * obj_presented_target),
            torch.flatten(obj_presented_target),
        )

        # No-object loss: every predictor's objectness is pushed toward 0 in cells without an object.
        noobj_mask = 1 - obj_presented_target
        loss_no_obj_presented = 0
        for number in range(self.bbox_pred_amount):
            noobj_pred = noobj_mask * self._pred_obj(predictions, number)
            loss_no_obj_presented = loss_no_obj_presented + self.mse(
                torch.flatten(noobj_pred),
                torch.flatten(torch.zeros_like(noobj_pred)),
            )

        # Classification loss, only over cells where an object exists.
        loss_classification = self.mse(
            torch.flatten(obj_presented_target * predictions[..., :self.class_amount]),
            torch.flatten(obj_presented_target * targets[..., :self.class_amount]),
        )

        return (
            self.lambda_coord * loss_coords
            + self.lambda_obj * loss_obj_presented
            + self.lambda_noobj * loss_no_obj_presented
            + self.lambda_classification * loss_classification
        )
if __name__ == "__main__":
    # Smoke-test the loss on random data and print intermediate shapes and the result.
    import __main__
    print("Run of", __main__.__file__)

    # Seed every RNG the demo could touch so the printed loss is reproducible.
    import random
    import numpy as np
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)

    # Loss configuration: 7x7 grid, 2 boxes per cell, 20 classes.
    grid_size, bbox_pred_amount, class_amount = 7, 2, 20
    criterion = YOLOLoss(grid_size, bbox_pred_amount, class_amount, reduction="sum")

    # Random network output, reshaped to (batch, S, S, class_amount + 5 * bbox_pred_amount).
    batch_size = 8
    pred_channels = class_amount + 5 * bbox_pred_amount
    flat_predictions = torch.randn(batch_size, grid_size * grid_size * pred_channels)
    predictions = flat_predictions.reshape(-1, grid_size, grid_size, pred_channels)
    print("predictions", predictions.shape)

    # Random targets: objectness channel binarised, then everything made non-negative.
    targets = torch.randn(batch_size, grid_size, grid_size, class_amount + 5)
    obj_index = class_amount
    targets[..., obj_index] = (targets[..., obj_index] > 0.5).float()
    targets = targets.abs()
    print("targets", targets.shape)
    print("targets[0, 0, 0, :]", targets[0, 0, 0, :])

    print("loss", criterion(predictions, targets))