-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdropper.py
57 lines (48 loc) · 1.74 KB
/
dropper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
import torch.nn as nn
"""
Implements Loss Truncation, stolen from: https://github.com/ddkang/loss_dropper/blob/master/loss_dropper/dropper.py
https://aclanthology.org/2020.acl-main.66.pdf
"""
class LossDropper(nn.Module):
    """Loss Truncation (Kang & Hashimoto, ACL 2020).

    Keeps a ring buffer of the most recent per-example loss values and,
    once enough values have been observed, returns a 0/1 mask that drops
    (masks to 0) the examples whose loss exceeds the ``keepc`` percentile
    of the buffered values.

    Args:
        dropc: fraction of examples to drop (mask out).
        min_count: number of loss values to observe before any dropping.
        recompute: size of the ring buffer; the percentile cutoff is
            recomputed after this many new values arrive.
        verbose: if True, print the cutoff each time it is recomputed.
    """
    def __init__(
        self,
        dropc=0.4,
        min_count=10000,
        recompute=10000,
        verbose=True
    ):
        super().__init__()
        self.keepc = 1. - dropc           # fraction of examples kept
        self.count = 0                    # total loss values seen so far
        self.min_count = min_count
        self.recompute = recompute
        self.last_computed = 0            # values seen since last cutoff recompute
        self.percentile_val = 100000000.  # cutoff; effectively +inf until first recompute
        self.cur_idx = 0                  # ring-buffer write position
        self.verbose = verbose
        # Ring buffer holding the most recent `recompute` loss values.
        self.vals = np.zeros(self.recompute, dtype=np.float32)

    def forward(self, loss):
        """Return a mask in ``loss``'s dtype: 1 = keep the example, 0 = drop it.

        ``loss`` is expected to be a tensor of per-example loss values;
        ``None`` is passed through unchanged.
        """
        if loss is None:
            return loss

        self.last_computed += loss.numel()
        self.count += loss.numel()
        # Buffer a detached, flattened CPU copy so this works for CUDA and
        # multi-dimensional loss tensors alike (the original iterated the raw
        # tensor in the ring-buffer branch, which fails for CUDA tensors and
        # writes sub-tensors for multi-dim losses).
        new_vals = loss.detach().cpu().numpy().flatten()
        if self.count < len(self.vals):
            # Warm-up: buffer not yet full — fill contiguously, keep everything.
            self.vals[self.count - loss.numel():self.count] = new_vals
            self.cur_idx += loss.numel()
            return (loss < np.inf).type(loss.dtype)
        else:
            # Ring-buffer overwrite of the oldest entries.
            for item in new_vals:
                self.vals[self.cur_idx] = item
                self.cur_idx += 1
                if self.cur_idx >= len(self.vals):
                    self.cur_idx = 0
        if self.count < self.min_count:
            # Too few observations to trust the percentile — keep everything.
            return (loss < np.inf).type(loss.dtype)

        if self.last_computed > self.recompute:
            # Refresh the cutoff from the buffered values.
            self.percentile_val = np.percentile(self.vals, self.keepc * 100)
            if self.verbose:
                print('Using cutoff', self.percentile_val)
            self.last_computed = 0

        mask = (loss < self.percentile_val).type(loss.dtype)
        return mask