# thompson_sampling.py
import functools
import math
import random
from typing import List, Optional, Tuple

import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm.auto import tqdm

from disallow_tracker import DisallowTracker
from evaluators import DBEvaluator
from reagent import Reagent
from ts_logger import get_logger
from ts_utils import read_reagents


class ThompsonSampler:
    def __init__(self, mode="maximize", log_filename: Optional[str] = None):
        """
        Basic init
        :param mode: one of "maximize", "minimize", "maximize_boltzmann", or "minimize_boltzmann"
        :param log_filename: Optional filename to write logging to. If None, logging will be output to stdout
        """
        # A list of lists of Reagents. Each component in the reaction will have one list of Reagents in this list
        self.reagent_lists: List[List[Reagent]] = []
        self.reaction = None
        self.evaluator = None
        self.num_prods = 0
        self.logger = get_logger(__name__, filename=log_filename)
        self._disallow_tracker = None
        self.hide_progress = False
        self._mode = mode
        if self._mode == "maximize":
            self.pick_function = np.nanargmax
            self._top_func = max
        elif self._mode == "minimize":
            self.pick_function = np.nanargmin
            self._top_func = min
        elif self._mode == "maximize_boltzmann":
            # See documentation for _boltzmann_reweighted_pick
            self.pick_function = functools.partial(self._boltzmann_reweighted_pick)
            self._top_func = max
        elif self._mode == "minimize_boltzmann":
            # See documentation for _boltzmann_reweighted_pick
            self.pick_function = functools.partial(self._boltzmann_reweighted_pick)
            self._top_func = min
        else:
            raise ValueError(f"{mode} is not a supported argument")
        self._warmup_std = None

    def _boltzmann_reweighted_pick(self, scores: np.ndarray):
        """Rather than choosing the top sampled score, use a reweighted probability.

        Zhao, H., Nittinger, E. & Tyrchan, C. Enhanced Thompson Sampling by Roulette
        Wheel Selection for Screening Ultra-Large Combinatorial Libraries.
        bioRxiv 2024.05.16.594622 (2024) doi:10.1101/2024.05.16.594622
        suggested several modifications to the Thompson Sampling procedure.

        This method implements one of those, namely a Boltzmann-style probability distribution
        over the sampled values. The reagent is chosen based on that distribution rather than
        simply the max sample.
        """
        if self._mode == "minimize_boltzmann":
            scores = -scores
        exp_terms = np.exp(scores / self._warmup_std)
        probs = exp_terms / np.nansum(exp_terms)
        probs[np.isnan(probs)] = 0.0
        return np.random.choice(probs.shape[0], p=probs)
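
    # Worked sketch of the reweighting above (illustrative numbers, not taken from
    # the Zhao et al. paper): with a warm-up std of 1.0 and sampled scores
    # [1.0, 2.0, 3.0], the exponential terms are roughly [2.72, 7.39, 20.09],
    # giving selection probabilities of about [0.09, 0.24, 0.67]. The top-scoring
    # reagent is still favored, but lower-ranked reagents keep a non-zero chance
    # of being picked, trading a little exploitation for extra exploration.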

    def set_hide_progress(self, hide_progress: bool) -> None:
        """
        Hide the progress bars
        :param hide_progress: set to True to hide the progress bars
        """
        self.hide_progress = hide_progress

    def read_reagents(self, reagent_file_list, num_to_select: Optional[int] = None):
        """
        Reads the reagents from reagent_file_list
        :param reagent_file_list: List of reagent filepaths
        :param num_to_select: Max number of reagents to select from the reagents file (for dev purposes only)
        :return: None
        """
        self.reagent_lists = read_reagents(reagent_file_list, num_to_select)
        self.num_prods = math.prod([len(x) for x in self.reagent_lists])
        self.logger.info(f"{self.num_prods:.2e} possible products")
        self._disallow_tracker = DisallowTracker([len(x) for x in self.reagent_lists])

    def get_num_prods(self) -> int:
        """
        Get the total number of possible products
        :return: num_prods
        """
        return self.num_prods

    def set_evaluator(self, evaluator):
        """
        Define the evaluator
        :param evaluator: evaluator class, must define an evaluate method that takes an RDKit molecule
            (a DBEvaluator instead receives the product name, see evaluate below)
        """
        self.evaluator = evaluator

    def set_reaction(self, rxn_smarts):
        """
        Define the reaction
        :param rxn_smarts: reaction SMARTS
        """
        self.reaction = AllChem.ReactionFromSmarts(rxn_smarts)

    def evaluate(self, choice_list: List[int]) -> Tuple[str, str, float]:
        """Evaluate a set of reagents
        :param choice_list: list of reagent ids, one per reaction component
        :return: SMILES for the reaction product, name for the reaction product, score for the reaction product
        """
        selected_reagents = []
        for idx, choice in enumerate(choice_list):
            component_reagent_list = self.reagent_lists[idx]
            selected_reagents.append(component_reagent_list[choice])
        prod = self.reaction.RunReactants([reagent.mol for reagent in selected_reagents])
        product_name = "_".join([reagent.reagent_name for reagent in selected_reagents])
        res = np.nan
        product_smiles = "FAIL"
        if prod:
            prod_mol = prod[0][0]  # RunReactants returns Tuple[Tuple[Mol]]
            Chem.SanitizeMol(prod_mol)
            product_smiles = Chem.MolToSmiles(prod_mol)
            if isinstance(self.evaluator, DBEvaluator):
                res = self.evaluator.evaluate(product_name)
                res = float(res)
            else:
                res = self.evaluator.evaluate(prod_mol)
            if np.isfinite(res):
                for reagent in selected_reagents:
                    reagent.add_score(res)
        return product_smiles, product_name, res
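
    # Illustrative example (hypothetical indices, not from this repository): for a
    # two-component reaction, evaluate([3, 17]) couples reagent 3 of the first
    # component with reagent 17 of the second, names the product
    # "<name_3>_<name_17>", and returns the product SMILES ("FAIL" with a score of
    # np.nan when the reaction templates do not apply), that product name, and the
    # evaluator's score.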

    def warm_up(self, num_warmup_trials=3):
        """Warm-up phase, each reagent is sampled with num_warmup_trials random partners
        :param num_warmup_trials: number of times to sample each reagent
        :return: a list of [score, SMILES, product name] entries from the warm-up evaluations
        """
        # get the list of reagent indices
        idx_list = list(range(0, len(self.reagent_lists)))
        # get the number of reagents for each component in the reaction
        reagent_count_list = [len(x) for x in self.reagent_lists]
        warmup_results = []
        for i in idx_list:
            partner_list = [x for x in idx_list if x != i]
            # The number of reagents for this component
            current_max = reagent_count_list[i]
            # For each reagent...
            for j in tqdm(range(0, current_max), desc=f"Warmup {i + 1} of {len(idx_list)}", disable=self.hide_progress):
                # For each warmup trial...
                for k in range(0, num_warmup_trials):
                    current_list = [DisallowTracker.Empty] * len(idx_list)
                    current_list[i] = DisallowTracker.To_Fill
                    disallow_mask = self._disallow_tracker.get_disallowed_selection_mask(current_list)
                    if j not in disallow_mask:
                        # ok, we can select this reagent
                        current_list[i] = j
                        # Randomly select reagents for each additional component of the reaction
                        for p in partner_list:
                            # tell the disallow tracker which site we are filling
                            current_list[p] = DisallowTracker.To_Fill
                            # get the new disallow mask
                            disallow_mask = self._disallow_tracker.get_disallowed_selection_mask(current_list)
                            selection_scores = np.random.uniform(size=reagent_count_list[p])
                            # null out the disallowed ones
                            selection_scores[list(disallow_mask)] = np.nan
                            # and select a random one
                            current_list[p] = np.nanargmax(selection_scores).item(0)
                        self._disallow_tracker.update(current_list)
                        product_smiles, product_name, score = self.evaluate(current_list)
                        if np.isfinite(score):
                            warmup_results.append([score, product_smiles, product_name])
        warmup_scores = [ws[0] for ws in warmup_results]
        self.logger.info(
            f"warmup score stats: "
            f"cnt={len(warmup_scores)}, "
            f"mean={np.mean(warmup_scores):0.4f}, "
            f"std={np.std(warmup_scores):0.4f}, "
            f"min={np.min(warmup_scores):0.4f}, "
            f"max={np.max(warmup_scores):0.4f}")
        # initialize each reagent
        prior_mean = np.mean(warmup_scores)
        prior_std = np.std(warmup_scores)
        self._warmup_std = prior_std
        for i in range(0, len(self.reagent_lists)):
            for j in range(0, len(self.reagent_lists[i])):
                reagent = self.reagent_lists[i][j]
                try:
                    reagent.init_given_prior(prior_mean=prior_mean, prior_std=prior_std)
                except ValueError:
                    self.logger.info(
                        f"Skipping reagent {reagent.reagent_name} because there were no successful evaluations during warmup")
                    self._disallow_tracker.retire_one_synthon(i, j)
        self.logger.info(f"Top score found during warmup: {max(warmup_scores):.3f}")
        return warmup_results
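
    # After warm-up, each surviving reagent is initialized with a Gaussian prior
    # built from the pooled warm-up statistics (Reagent.init_given_prior above);
    # individual scores observed during warm-up were recorded separately through
    # Reagent.add_score in evaluate(), and how Reagent combines the two is handled
    # inside the Reagent class. search() below then draws one sample from each
    # reagent's N(current_mean, current_std**2) belief and picks a reagent with
    # self.pick_function, which is the Thompson Sampling step.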

    def search(self, num_cycles=25):
        """Run the search
        :param num_cycles: number of search iterations
        :return: a list of [score, SMILES, product name] entries
        """
        out_list = []
        rng = np.random.default_rng()
        for i in tqdm(range(0, num_cycles), desc="Cycle", disable=self.hide_progress):
            selected_reagents = [DisallowTracker.Empty] * len(self.reagent_lists)
            # Select a reagent for each component, according to the choice function
            for cycle_id in random.sample(range(0, len(self.reagent_lists)), len(self.reagent_lists)):
                reagent_list = self.reagent_lists[cycle_id]
                selected_reagents[cycle_id] = DisallowTracker.To_Fill
                disallow_mask = self._disallow_tracker.get_disallowed_selection_mask(selected_reagents)
                stds = np.array([r.current_std for r in reagent_list])
                mu = np.array([r.current_mean for r in reagent_list])
                choice_row = rng.normal(size=len(reagent_list)) * stds + mu
                if disallow_mask:
                    choice_row[np.array(list(disallow_mask))] = np.nan
                selected_reagents[cycle_id] = self.pick_function(choice_row)
            self._disallow_tracker.update(selected_reagents)
            smiles, name, score = self.evaluate(selected_reagents)
            if np.isfinite(score):
                out_list.append([score, smiles, name])
            if i % 100 == 0 and out_list:
                top_score, top_smiles, top_name = self._top_func(out_list)
                self.logger.info(f"Iteration: {i} best score: {top_score:.2f} smiles: {top_smiles} {top_name}")
        return out_list
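

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It wires the
# class together end to end; the reagent file names, the amide-coupling SMARTS,
# and the toy molecular-weight evaluator below are placeholders standing in for
# the reagent files and evaluator classes (see evaluators.py) you would
# normally supply.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from rdkit.Chem import Descriptors

    class _ToyMWEvaluator:
        """Minimal stand-in evaluator: scores a product by its molecular weight."""

        def evaluate(self, mol):
            return Descriptors.MolWt(mol)

    ts = ThompsonSampler(mode="maximize")
    ts.set_evaluator(_ToyMWEvaluator())
    # One SMILES file per reaction component (placeholder paths).
    ts.read_reagents(["acids.smi", "amines.smi"])
    # Example amide-coupling reaction SMARTS (an assumption; adjust to your chemistry).
    ts.set_reaction("[C:1](=[O:2])[OX2H1].[N!H0:3]>>[C:1](=[O:2])[N:3]")
    ts.warm_up(num_warmup_trials=3)
    results = ts.search(num_cycles=1000)
    top_score, top_smiles, top_name = max(results)  # entries are [score, smiles, name]
    print(f"best score {top_score:.3f} for {top_smiles} ({top_name})")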