Skip to content

Commit 4f48d20

Browse files
committed
fix typing errors and add ignore stock to retro even if in stock
1 parent 8debce2 commit 4f48d20

File tree

5 files changed

+62
-39
lines changed

5 files changed

+62
-39
lines changed

aizynthfinder/aizynthfinder.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Dict
1+
from typing import Optional, Dict, Any
22
import time
33

44
from .analysis.routes import RouteCollection
@@ -22,6 +22,8 @@ def __init__(self, configfile: str) -> None:
2222
self.scorers = self.config.scorers
2323

2424
self._target_mol: Optional[Molecule] = None
25+
self.analysis: TreeAnalysis | None = None
26+
self.tree: MctsSearchTree | None = None
2527

2628
@property
2729
def target_smiles(self) -> str:
@@ -43,6 +45,8 @@ def target_mol(self, mol: Molecule) -> None:
4345
self._target_mol = mol
4446

4547
def build_routes(self, scorer: str = "state score") -> None:
48+
if self.tree is None:
49+
return None
4650
self.analysis = TreeAnalysis(self.tree, scorer=self.scorers[scorer])
4751
config_selection = RouteSelectionArguments(
4852
nmin=self.config.post_processing.min_routes,
@@ -53,7 +57,7 @@ def build_routes(self, scorer: str = "state score") -> None:
5357
config_selection)
5458

5559
def extract_statistics(self) -> Dict:
56-
if not self.analysis:
60+
if self.analysis is None:
5761
return {}
5862
stats = {
5963
"target":
@@ -81,19 +85,21 @@ def prepare_tree(self) -> None:
8185

8286
def tree_search(self) -> None:
8387
self.prepare_tree()
84-
self.search_stats = {"returned_first": False, "iterations": 0}
88+
assert self.tree is not None
89+
self.search_stats: Dict[str, Any] = {
90+
"returned_first": False,
91+
"iterations": 0
92+
}
8593

8694
time0 = time.time()
8795
time_past = .0
8896
i = 1
8997
while time_past < self.config.time_limit and i <= self.config.iteration_limit:
9098
self.search_stats["iterations"] += 1
9199
is_solved = self.tree.one_iteration()
92-
93100
if is_solved and "first_solution_time" not in self.search_stats:
94101
self.search_stats["first_solution_time"] = time.time() - time0
95102
self.search_stats["first_solution_iteration"] = i
96-
97103
if self.config.return_first and is_solved:
98104
self.search_stats["returned_first"] = True
99105
break

aizynthfinder/analysis/tree_analysis.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def sort(
2828
sorted_items, sorted_scores, _ = self.scorer.sort(nodes)
2929
actions = [node.actions_to() for node in sorted_items]
3030

31-
return self._collect_top_items(sorted_items, sorted_scores, actions, selection)
31+
return self._collect_top_items(sorted_items, sorted_scores, actions,
32+
selection) # type: ignore
3233

3334
def tree_statistics(self) -> Dict:
3435
return self._tree_statistics_mcts()
@@ -47,14 +48,17 @@ def _tree_statistics_mcts(self) -> Dict:
4748
top_state = top_node.state
4849
nodes = list(self.search_tree.graph())
4950
mols_in_stock = ", ".join(
50-
mol.smiles for mol, instock in zip(top_state.mols, top_state.in_stock_list)
51+
mol.smiles
52+
for mol, instock in zip(top_state.mols, top_state.in_stock_list)
5153
if instock)
5254
mols_not_in_stock = ", ".join(
53-
mol.smiles for mol, instock in zip(top_state.mols, top_state.in_stock_list)
55+
mol.smiles
56+
for mol, instock in zip(top_state.mols, top_state.in_stock_list)
5457
if not instock)
5558

56-
policy_used_counts = self._policy_used_statistics(
57-
[node[child]["action"] for node in nodes for child in node.children])
59+
policy_used_counts = self._policy_used_statistics([
60+
node[child]["action"] for node in nodes for child in node.children
61+
])
5862

5963
return {
6064
"number_of_nodes":
@@ -66,7 +70,8 @@ def _tree_statistics_mcts(self) -> Dict:
6670
"number_of_routes":
6771
sum(1 for node in nodes if not node.children),
6872
"number_of_solved_routes":
69-
sum(1 for node in nodes if not node.children and node.state.is_solved),
73+
sum(1 for node in nodes
74+
if not node.children and node.state.is_solved),
7075
"top_score":
7176
self.scorer(top_node),
7277
"is_solved":
@@ -96,7 +101,8 @@ def _collect_top_items(
96101
reactions: Sequence[Union[Iterable[RetroReaction],
97102
Iterable[FixedRetroReaction]]],
98103
selection,
99-
) -> Tuple[Union[Sequence[MctsNode], Sequence[ReactionTree]], Sequence[float]]:
104+
) -> Tuple[Union[Sequence[MctsNode], Sequence[ReactionTree]],
105+
Sequence[float]]:
100106
if len(items) <= selection.nmin:
101107
return items, scores
102108

@@ -130,7 +136,8 @@ def _collect_top_items(
130136

131137
@staticmethod
132138
def _policy_used_statistics(
133-
reactions: Iterable[Union[RetroReaction, FixedRetroReaction]]) -> Dict:
139+
reactions: Iterable[Union[RetroReaction,
140+
FixedRetroReaction]]) -> Dict:
134141
policy_used_counts: Dict = defaultdict(int)
135142
for reaction in reactions:
136143
policy_used = reaction.metadata.get("policy_name")

aizynthfinder/context/config.py

+22-20
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,33 @@
1010

1111
@dataclass
1212
class _PostprocessingConfiguration:
13-
min_routes: int = 1
14-
max_routes: int = 8
15-
all_routes: bool = False
13+
min_routes = 1
14+
max_routes = 8
15+
all_routes = False
1616
route_distance_model: Optional[str] = None
1717

1818

1919
@dataclass
2020
class Configuration:
2121

22-
C: float = 1.4
23-
cutoff_cumulative: float = 0.995
24-
cutoff_number: int = 50
25-
additive_expansion: bool = False
26-
use_rdchiral: bool = True
27-
max_transforms: int = 6
28-
default_prior: float = 0.5
29-
use_prior: bool = True
30-
iteration_limit: int = 100
31-
return_first: bool = False
32-
time_limit: int = 20
33-
filter_cutoff: float = 0.05
34-
exclude_target_from_stock: bool = True
35-
template_column: str = "retro_template"
36-
prune_cycles_in_search: bool = True
37-
search_algorithm: str = "mcts"
22+
C = 1.4
23+
cutoff_cumulative = 0.995
24+
cutoff_number = 50
25+
additive_expansion = False
26+
use_rdchiral = True
27+
max_transforms = 6
28+
default_prior = 0.5
29+
use_prior = True
30+
iteration_limit = 100
31+
return_first = False
32+
time_limit = 20
33+
filter_cutoff = 0.05
34+
exclude_target_from_stock = True
35+
template_column = "retro_template"
36+
prune_cycles_in_search = True
37+
search_algorithm = "mcts"
3838
post_processing = _PostprocessingConfiguration()
39+
ignore_stock = False
3940

4041
def __post_init__(self) -> None:
4142
self._properties: Dict = {}
@@ -54,7 +55,8 @@ def from_dict(cls, source: Dict) -> "Configuration":
5455
config_obj = Configuration()
5556
config_obj._update_from_config(source)
5657

57-
config_obj.expansion_policy.load_from_config(**source.get("policy", {}))
58+
config_obj.expansion_policy.load_from_config(
59+
**source.get("policy", {}))
5860
config_obj.filter_policy.load_from_config(**source.get("filter", {}))
5961
config_obj.stock.load_from_config(**source.get("stock", {}))
6062

aizynthfinder/mcts/state.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@ class MctsState:
1212
def __init__(self, mols: Sequence[TreeMolecule], config: Any) -> None:
1313
self.mols = mols
1414
self.stock = config.stock
15-
self.in_stock_list = [mol in self.stock for mol in self.mols]
15+
if config.ignore_stock:
16+
self.in_stock_list = [False for mol in self.mols]
17+
else:
18+
self.in_stock_list = [mol in self.stock for mol in self.mols]
1619
self.expandable_mols = [
17-
mol for mol, in_stock in zip(self.mols, self.in_stock_list) if not in_stock
20+
mol for mol, in_stock in zip(self.mols, self.in_stock_list)
21+
if not in_stock
1822
]
1923
self._stock_availability: Optional[List[str]] = None
2024
self.is_solved = all(self.in_stock_list)
@@ -39,7 +43,9 @@ def __str__(self) -> str:
3943
string = "%s\n%s\n%s\n%s\nScore: %0.3F Solved: %s" % (
4044
str([mol.smiles for mol in self.mols]),
4145
str([mol.transform for mol in self.mols]),
42-
str([mol.parent.smiles if mol.parent else "-" for mol in self.mols]),
46+
str([
47+
mol.parent.smiles if mol.parent else "-" for mol in self.mols
48+
]),
4349
str(self.in_stock_list),
4450
self.score,
4551
self.is_solved,

usage.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,20 @@
88
finder.expansion_policy.select("uspto")
99
finder.filter_policy.select("uspto")
1010

11-
finder.target_smiles = "Cc1cccc(c1N(CC(=O)Nc2ccc(cc2)c3ncon3)C(=O)C4CCS(=O)(=O)CC4)C"
11+
# Cc1cccc(c1N(CC(=O)Nc2ccc(cc2)c3ncon3)C(=O)C4CCS(=O)(=O)CC4)C
12+
finder.target_smiles = "COc1cc(C=CC(=O)O)cc(OC)c1OC"
1213
# print(psutil.Process().memory_info().rss / (1024 * 1024), "MB")
1314
finder.tree_search()
1415
finder.build_routes()
1516

16-
stats = finder.extract_statistics()
17+
# stats = finder.extract_statistics()
1718
# routes details
1819
# for d in finder.routes.dicts:
1920
# print("---" * 20)
2021
# print(d)
22+
breakpoint()
2123
plt.imshow(finder.routes.images[0])
2224
plt.show()
2325

2426
# stats
25-
print(stats)
27+
# print(stats)

0 commit comments

Comments
 (0)