Skip to content

Commit 9155056

Browse files
authored
Merge pull request #14 from js-ish/feat-listrank
improve learn to rank
2 parents ea089c5 + 290064b commit 9155056

File tree

7 files changed

+64
-38
lines changed

7 files changed

+64
-38
lines changed

dooc/datasets.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -145,16 +145,16 @@ def __call__(
145145
Mutations(Individual Sample) and Smiles Interaction
146146
147147
MutSmiReg
148-
MutSmis{Pair/List}
149-
MutsSmi{Pair/List}
148+
MutSmis{Pair/List}wiseRank
149+
MutsSmi{Pair/List}wiseRank
150150
"""
151151

152152

153153
class MutSmiReg(_DrugcellAdamr2MutSmi):
154154
pass
155155

156156

157-
class MutSmisPairwise(_DrugcellAdamr2MutSmis):
157+
class MutSmisPairwiseRank(_DrugcellAdamr2MutSmis):
158158
def __call__(
159159
self,
160160
muts: typing.Sequence[list],
@@ -163,6 +163,10 @@ def __call__(
163163
seq_len: int = 200
164164
) -> typing.Tuple[torch.Tensor]:
165165
mut_x, smi_tgt, rout = super().__call__(muts, lsmiles, lvalues, seq_len)
166-
out = torch.zeros(rout.size(0), dtype=torch.long, device=self.device)
167-
out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1
166+
out = torch.zeros(rout.size(0), dtype=rout.dtype, device=self.device)
167+
out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1.0
168168
return mut_x, smi_tgt, out
169+
170+
171+
class MutSmisListwiseRank(_DrugcellAdamr2MutSmis):
172+
pass

dooc/loss.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
class ListNetLoss(nn.Module):
6+
def __init__(self, reduction: str = 'mean') -> None:
7+
super().__init__()
8+
assert reduction in ['mean', 'sum']
9+
self.reduction = reduction
10+
11+
def forward(self, predict: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
12+
out = - (target.softmax(dim=-1) * predict.log_softmax(dim=-1))
13+
return getattr(out, self.reduction)()

dooc/models.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
Mutations(Individual Sample) and Smiles Interaction
1010
1111
MutSmiReg
12-
MutSmis{Pair/List}
13-
MutsSmi{Pair/List}
12+
MutSmisRank
13+
MutsSmiRank
1414
"""
1515

1616

@@ -20,22 +20,25 @@ def __init__(self, mut_conf: drugcell.DrugcellConfig = dnets.Drugcell.DEFAULT_CO
2020
super().__init__(mut_conf, smi_conf)
2121
self.reg = heads.RegHead(self.smi_conf.d_model)
2222

23-
def forward(self, *args, **kwargs) -> torch.Tensor:
24-
return self.reg(super().forward(*args, **kwargs)) # [b, 1]
23+
def forward(
24+
self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor:
25+
return self.reg(super().forward(mut_x, smi_tgt)) # [b, 1]
2526

2627

27-
class MutSmisPairwise(dnets.DrugcellAdamr2MutSmisXattn):
28+
class MutSmisRank(dnets.DrugcellAdamr2MutSmisXattn):
2829

2930
def __init__(self, mut_conf: drugcell.DrugcellConfig = dnets.Drugcell.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
3031
super().__init__(mut_conf, smi_conf)
31-
self.pairwise_rank = heads.PairwiseRankHead(self.smi_conf.d_model)
32+
self.reg = heads.RegHead(self.smi_conf.d_model)
3233

33-
def forward(self, *args, **kwargs) -> torch.Tensor:
34-
return self.pairwise_rank(super().forward(*args, **kwargs)) # [b, 2]
34+
def forward(
35+
self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor:
36+
return self.reg(super().forward(mut_x, smi_tgt)).squeeze(-1) # [b, n]
3537

36-
def forward_cmp(self, *args, **kwargs) -> float:
38+
def forward_cmp(self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> float:
3739
"""
3840
for infer, no batch dim
3941
"""
40-
out = self.forward(*args, **kwargs)
41-
return (out[1] - out[0]).item()
42+
assert mut_x.dim() == 1 and smi_tgt.dim() == 2
43+
out = self.forward(mut_x, smi_tgt) # [2]
44+
return (out[0] - out[1]).item()

dooc/nets/heads.py

-19
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,3 @@ def __init__(self, in_features: int):
1515

1616
def forward(self, x: torch.Tensor) -> torch.Tensor:
1717
return self.mlp(x)
18-
19-
20-
class PairwiseRankHead(nn.Module):
21-
def __init__(self, d_features: int):
22-
super().__init__()
23-
self.mlp = nn.Sequential(
24-
nn.Flatten(-2),
25-
nn.Linear(d_features * 2, d_features),
26-
nn.ReLU(),
27-
nn.Dropout(0.1),
28-
nn.Linear(d_features, 2)
29-
)
30-
31-
def forward(self, x: torch.Tensor) -> torch.Tensor:
32-
"""
33-
x: [b, 2, d_features]
34-
"""
35-
assert x.size(-2) == 2
36-
return self.mlp(x) # [b, 2] 1: x1 > x2, 0: x1 <= x2

tests/test_datasets.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ def test_MutSmiReg(smi_tkz):
7575
assert out.shape == (2, 1)
7676

7777

78-
def test_MutSmisPairwise(smi_tkz):
79-
ds = datasets.MutSmisPairwise(smi_tkz)
78+
def test_MutSmisPairwiseRank(smi_tkz):
79+
ds = datasets.MutSmisPairwiseRank(smi_tkz)
8080
lsmis = [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]]
8181
lvals = [[0.88, 0.89], [0.82, 0.9]]
8282
muts = [[random.choice([0, 1]) for _ in range(52)],
@@ -87,3 +87,17 @@ def test_MutSmisPairwise(smi_tkz):
8787
assert smi_tgt.shape == (2, 2, 200)
8888
assert mut_x.shape == (2, 52)
8989
assert out.shape == (2,)
90+
91+
92+
def test_MutSmisListwiseRank(smi_tkz):
93+
ds = datasets.MutSmisListwiseRank(smi_tkz)
94+
lsmis = [["CC[N+]CCBr", "Cc1ccc1", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1", "Cc1ccc1"]]
95+
lvals = [[0.88, 0.89, 0.89], [0.82, 0.9, 0.9]]
96+
muts = [[random.choice([0, 1]) for _ in range(52)],
97+
[random.choice([0, 1]) for _ in range(52)]]
98+
with pytest.raises(AssertionError):
99+
ds(muts, lsmis, lvals[:1])
100+
mut_x, smi_tgt, out = ds(muts, lsmis, lvals)
101+
assert smi_tgt.shape == (2, 3, 200)
102+
assert mut_x.shape == (2, 52)
103+
assert out.shape == (2, 3)

tests/test_loss.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch
2+
from dooc import loss
3+
4+
def test_ListNetLoss():
5+
predict = torch.randn(5, 3)
6+
target = torch.randn(5, 3)
7+
loss_mean = loss.ListNetLoss(reduction='mean')
8+
mean = loss_mean(predict, target)
9+
loss_sum = loss.ListNetLoss(reduction='sum')
10+
sum = loss_sum(predict, target)
11+
assert sum / 15 == mean

tests/test_pipelines.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class Pointwise(pipelines._MutSmi, pipelines._MutSmisRank):
2727
assert len(out) == 3
2828
assert out[1] == "CC[N+](C)(C)Cc1ccccc1Br"
2929

30-
model = models.MutSmisPairwise()
30+
model = models.MutSmisRank()
3131
pipeline = pipelines.MutSmisRank(smi_tokenizer=smi_tkz, model=model)
3232
out = pipeline(mutation, smiles)
3333
assert isinstance(out, list)

0 commit comments

Comments
 (0)