@@ -75,8 +75,8 @@ def test_MutSmiReg(smi_tkz):
75
75
assert out .shape == (2 , 1 )
76
76
77
77
78
- def test_MutSmisPairwise (smi_tkz ):
79
- ds = datasets .MutSmisPairwise (smi_tkz )
78
+ def test_MutSmisPairwiseRank (smi_tkz ):
79
+ ds = datasets .MutSmisPairwiseRank (smi_tkz )
80
80
lsmis = [["CC[N+]CCBr" , "Cc1ccc1" ], ["CCC[N+]CCBr" , "CCc1ccc1" ]]
81
81
lvals = [[0.88 , 0.89 ], [0.82 , 0.9 ]]
82
82
muts = [[random .choice ([0 , 1 ]) for _ in range (52 )],
@@ -87,3 +87,17 @@ def test_MutSmisPairwise(smi_tkz):
87
87
assert smi_tgt .shape == (2 , 2 , 200 )
88
88
assert mut_x .shape == (2 , 52 )
89
89
assert out .shape == (2 ,)
90
+
91
+
92
+ def test_MutSmisListwiseRank (smi_tkz ):
93
+ ds = datasets .MutSmisListwiseRank (smi_tkz )
94
+ lsmis = [["CC[N+]CCBr" , "Cc1ccc1" , "Cc1ccc1" ], ["CCC[N+]CCBr" , "CCc1ccc1" , "Cc1ccc1" ]]
95
+ lvals = [[0.88 , 0.89 , 0.89 ], [0.82 , 0.9 , 0.9 ]]
96
+ muts = [[random .choice ([0 , 1 ]) for _ in range (52 )],
97
+ [random .choice ([0 , 1 ]) for _ in range (52 )]]
98
+ with pytest .raises (AssertionError ):
99
+ ds (muts , lsmis , lvals [:1 ])
100
+ mut_x , smi_tgt , out = ds (muts , lsmis , lvals )
101
+ assert smi_tgt .shape == (2 , 3 , 200 )
102
+ assert mut_x .shape == (2 , 52 )
103
+ assert out .shape == (2 , 3 )
0 commit comments