-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtree_loss.py
71 lines (63 loc) · 2.82 KB
/
tree_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from numpy.core.fromnumeric import shape
import torch
import torch.nn as nn
class TreeLoss(nn.Module):
    """Hierarchical (tree-structured) softmax loss over a fixed label state space.

    Every row of ``stateSpace`` is one legal assignment of the hierarchy's
    nodes: row 0 is the empty assignment (all zeros), and each remaining row
    activates one node together with all of its ancestors.  A sample's score
    for a state is the sum of its per-node scores ``fs`` over the nodes active
    in that state, and a node's marginal is the total ``exp(score)`` mass of
    the states that contain it.
    """

    def __init__(self, hierarchy, total_nodes, levels, device):
        """Build the state space for ``hierarchy`` and place it on ``device``.

        Args:
            hierarchy: iterable of index paths; see ``generateStateSpace``.
            total_nodes: number of nodes (columns of the state space).
            levels: depth of the hierarchy, 2 or 3.
            device: device the state-space matrix is moved to.

        Raises:
            ValueError: if the hierarchy does not yield a valid state space.
        """
        super().__init__()
        state_space = self.generateStateSpace(hierarchy, total_nodes, levels)
        # Non-persistent buffer: follows module.to()/.cuda() like parameters do,
        # but is not written to state_dict, so existing checkpoints still load.
        self.register_buffer("stateSpace", state_space.to(device), persistent=False)

    def forward(self, fs, labels, device):
        """Return the mean negative log marginal probability of ``labels``.

        For sample ``i`` the loss is ``logsumexp`` of the state scores over all
        states minus ``logsumexp`` over the states that activate ``labels[i]``,
        i.e. ``-log(marginal / z)`` evaluated in log space so that large
        scores cannot overflow ``exp``.

        Args:
            fs: per-node scores of shape ``(batch, total_nodes)``.
            labels: one node index per sample (list, array or tensor; a
                ``(batch, 1)`` tensor is also accepted).
            device: device of the returned loss.

        Returns:
            Scalar float64 tensor, differentiable with respect to ``fs``.

        Raises:
            ValueError: if the number of labels differs from the batch size.
        """
        state_space = self.stateSpace.to(device=fs.device, dtype=fs.dtype)
        # scores[s, i]: summed node scores of sample i over the nodes active in state s.
        scores = torch.mm(state_space, fs.T)
        labels = torch.as_tensor(labels, dtype=torch.long, device=scores.device).reshape(-1)
        if labels.numel() != fs.shape[0]:
            raise ValueError(
                f"expected {fs.shape[0]} labels (one per sample), got {labels.numel()}"
            )
        # contains_label[s, i]: state s activates the label node of sample i.
        contains_label = state_space[:, labels] > 0
        log_z = torch.logsumexp(scores, dim=0)
        log_marginal = torch.logsumexp(
            scores.masked_fill(~contains_label, float("-inf")), dim=0
        )
        loss = (log_z - log_marginal).to(device=device, dtype=torch.float64)
        return loss.mean()

    def inference(self, fs, device, normalize=False):
        """Return per-node marginal scores for every sample, without gradients.

        Entry ``[i, j]`` sums ``exp(score)`` of sample ``i`` over all states
        that activate node ``j``.  By default (the historical behavior) this is
        NOT divided by the partition function, so very large scores can
        overflow to ``inf``.  Pass ``normalize=True`` to get marginal
        probabilities computed with a numerically stable softmax.

        Args:
            fs: per-node scores of shape ``(batch, total_nodes)``.
            device: device of the returned tensor.
            normalize: divide each sample's marginals by its partition function.

        Returns:
            float64 tensor of shape ``(batch, total_nodes)``.
        """
        with torch.no_grad():
            state_space = self.stateSpace.to(device=fs.device, dtype=fs.dtype)
            scores = torch.mm(state_space, fs.T)
            if normalize:
                weights = torch.softmax(scores, dim=0)
            else:
                weights = torch.exp(scores)
            # Summing the weights of the states that contain node j, for every
            # (sample, node) pair, is a single (batch, S) @ (S, nodes) matmul.
            membership = (state_space > 0).to(weights.dtype)
            p_margin = torch.mm(weights.T, membership)
            return p_margin.to(device=device, dtype=torch.float64)

    def generateStateSpace(self, hierarchy, total_nodes, levels):
        """Return the ``(total_nodes + 1, total_nodes)`` 0/1 state-space matrix.

        Only parent-child and mutual-exclusion relations are encoded.  Each
        ``path`` in ``hierarchy`` lists the leaf first and then its ancestors
        from the top level down.  For ``levels == 2`` it is ``[leaf, parent]``
        (originally documented as ``[species, families]``).  For
        ``levels == 3`` it is ``[leaf, top, middle]`` (originally documented as
        ``[species, orders, families]``).

        Row 0 stays all zeros (the empty state).  Every non-leaf node gets one
        row the first time it is seen, holding itself plus its ancestors.
        Every path adds one row for its leaf, holding the leaf plus its
        ancestors.  A valid hierarchy therefore fills exactly ``total_nodes``
        rows after row 0.

        Raises:
            ValueError: if ``levels`` is not 2 or 3, or the number of generated
                states differs from ``total_nodes``.
        """
        if levels not in (2, 3):
            raise ValueError(f"Invalid StateSpace!!! levels must be 2 or 3, got {levels}")
        states = []   # active node indices of each state, in row order from row 1
        seen = set()  # non-leaf nodes that already own a state
        for path in hierarchy:
            # Ancestors top-down, then the leaf.  int() makes tensor/numpy
            # indices deduplicate by value instead of by object identity.
            chain = [int(node) for node in path[1:levels]] + [int(path[0])]
            for depth, node in enumerate(chain[:-1], start=1):
                if node not in seen:
                    seen.add(node)
                    states.append(chain[:depth])
            states.append(chain)
        if len(states) != total_nodes:
            raise ValueError(
                f"Invalid StateSpace!!! hierarchy yields {len(states)} states, "
                f"expected total_nodes={total_nodes}"
            )
        stateSpace = torch.zeros(total_nodes + 1, total_nodes)
        for row, nodes in enumerate(states, start=1):
            stateSpace[row, nodes] = 1
        return stateSpace