-
Notifications
You must be signed in to change notification settings - Fork 4
/
get_reward_ranges.py
56 lines (46 loc) · 1.71 KB
/
get_reward_ranges.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import numpy as np
import json
import os
from argparse import ArgumentParser
from rdkit.Chem import MolFromSmiles
from tensorflow.python.keras.models import load_model
from tqdm import tqdm
from molgym.envs.rewards.mpnn import MPNNReward
from molgym.envs.rewards.rdkit import LogP, QEDReward, SAScore, CycleLength
from molgym.mpnn.layers import custom_objects
# Make all of the reward functions
from molgym.utils.conversions import convert_smiles_to_nx
# Build every reward function whose statistics this script reports.
# The MPNN-based reward is backed by a trained Keras model plus the atom and
# bond type lists stored next to it in notebooks/mpnn-training.
mpnn_dir = os.path.join('notebooks', 'mpnn-training')
model = load_model(
    os.path.join(mpnn_dir, 'best_model.h5'),
    custom_objects=custom_objects,
)

# Atom/bond type lists handed to MPNNReward (presumably the featurization
# vocabularies the model was trained with -- confirm against the notebook)
with open(os.path.join(mpnn_dir, 'atom_types.json')) as fp:
    atom_types = json.load(fp)
with open(os.path.join(mpnn_dir, 'bond_types.json')) as fp:
    bond_types = json.load(fp)

# Keyword form of the mapping; insertion order (and therefore the order the
# rewards are evaluated and written out) matches the listing below.
rewards = dict(
    logP=LogP(),
    ic50=MPNNReward(model, atom_types, bond_types, maximize=True),
    QED=QEDReward(),
    SA=SAScore(),
    cycles=CycleLength(),
)
if __name__ == "__main__":
# Parse the inputs
parser = ArgumentParser()
parser.add_argument("smiles_file")
args = parser.parse_args()
# Load in the molecules
with open(args.smiles_file) as fp:
mols = [x.strip() for x in fp]
# Get only the molecules that parse with RDKit
mols = [x for x in mols if MolFromSmiles(x) is not None]
# Compute the reward function statistics for all the rewards
stats = {}
for name, reward in rewards.items():
data = [reward(convert_smiles_to_nx(mol)) for mol in tqdm(mols, desc=name)]
stats[name] = {
'mean': np.mean(data),
'scale': np.std(data)
}
# Save as a json file
with open('reward_ranges.json', 'w') as fp:
json.dump(stats, fp, indent=2)