Fix data loader for QM9 dataset
- Add a vocab pregenerator for the QM9, PubChem, and ZINC datasets
- Update the vocab JSON file
- Add sensible defaults for LC systems
szaman19 committed Feb 16, 2024
1 parent e19ecc0 commit 9eddb19
Showing 4 changed files with 189 additions and 47 deletions.
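The "sensible defaults" in the commit message are environment-variable fallbacks pointing at the LC (Livermore Computing) filesystem; the same variables redirect the loader and the vocab pregenerator elsewhere. A minimal sketch of the override pattern, not part of the commit (the /data/... paths are purely illustrative):

# Sketch only: point the data loader and GenerateVocab.py at local copies of the data
# before importing/running them. The variable names below all appear in the diffs;
# the paths are made up for illustration.
import os

os.environ["QM9_DATA_DIR"] = "/data/qm9"                   # directory holding QM9_Pretokenized.npy
os.environ["QM9_SEQUENCE_LENGTH"] = "32"                   # sequence length read by the loader
os.environ["QM9_FILE"] = "/data/qm9/QM9_smiles.txt"        # raw SMILES for GenerateVocab.py --qm9
os.environ["ZINC_DIR"] = "/data/zinc"                      # *.smi shards for --zinc
os.environ["PUBCHEM_FILE"] = "/data/pubchem/CID_SMILES_CANONICAL.smi"  # for --pubchem

import QM9  # reads QM9_DATA_DIR and QM9_SEQUENCE_LENGTH at import time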
26 changes: 12 additions & 14 deletions applications/FLASK/Transformer/datasets/QM9.py
@@ -1,55 +1,53 @@
"""
The QM9 dataset, stored as pre-tokenized binary files for optimized processing.
"""

import os
import os.path
import pickle

import numpy as np
from pretokenize.SMILES_tokenizer import MolTokenizer

sequence_length = int(os.getenv('QM9_SEQUENCE_LENGTH', default='32'))
sequence_length = int(os.getenv("QM9_SEQUENCE_LENGTH", default="32"))

# ----------------------------------------------
# Setup
# ----------------------------------------------

# Load the datasets
data_dir = os.getenv(
    'QM9_DATA_DIR',
    '/p/vast1/lbann/datasets/FLASK/qm9')
data_dir = os.getenv("QM9_DATA_DIR", "/p/vast1/lbann/datasets/FLASK/QM9")

tokenizer = MolTokenizer("SMILES_vocab.json")
tokenizer.load_vocab_file()

dataset_train = np.load(os.path.join(data_dir, 'QM9_Pretokenize.py'))
dataset_train = np.load(os.path.join(data_dir, "QM9_Pretokenized.npy"))

_vocab_size = 46

pad_index = tokenizer.token_to_id('<pad>')
bos_index = tokenizer.token_to_id('<bos>')
eos_index = tokenizer.token_to_id('<eos>')

# ----------------------------------------------
# Sample access functions
# ----------------------------------------------


def num_train_samples():
    return dataset_train.shape[0]


def get_train_sample(i):
    data = dataset_train[i]
    return data

    return

def sample_dims():
    return (2 * sequence_length + 1, )
    return (2 * sequence_length + 1,)


def vocab_size():
    return _vocab_size


if __name__ == '__main__':
    print('Training samples:', num_train_samples())
    print('Training sample 101:')
if __name__ == "__main__":
    print("Training samples:", num_train_samples())
    print("Training sample 101:")
    print(get_train_sample(101))
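For reference, a short sketch (not part of the commit) of how the loader's sample-access functions are exercised; it assumes QM9_DATA_DIR points at a directory containing QM9_Pretokenized.npy and that the pretokenize package is importable:

import QM9  # this module; loads QM9_Pretokenized.npy at import time

print("vocab size:", QM9.vocab_size())              # 46 in this commit
print("training samples:", QM9.num_train_samples())
print("declared sample dims:", QM9.sample_dims())   # (2 * sequence_length + 1,)
print("sample 0:", QM9.get_train_sample(0))         # one row of the pretokenized array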
100 changes: 100 additions & 0 deletions applications/FLASK/Transformer/datasets/pretokenize/GenerateVocab.py
@@ -0,0 +1,100 @@
from multiprocessing import Pool
from SMILES_tokenizer import MolTokenizer
from glob import glob
import argparse
import os
import numpy as np
from tqdm import tqdm


parser = argparse.ArgumentParser(
    description="Generate vocab files for different datasets"
)

parser.add_argument(
    "--qm9", action="store_true", help="Generate vocab file for QM9 dataset"
)
parser.add_argument(
    "--zinc", action="store_true", help="Generate vocab file for ZINC dataset"
)
parser.add_argument(
    "--pubchem", action="store_true", help="Generate vocab file for PubChem dataset"
)

args = parser.parse_args()


def join_vocabs(list_of_vocab_dicts):
    """
    Given a list of vocab dictionaries, join them together
    such that all unique tokens are present in the final vocab dictionary
    """
    final_vocab = {}
    counter = 0
    for vocab in list_of_vocab_dicts:
        for token in vocab.keys():
            if token not in final_vocab.keys():
                final_vocab[token] = counter
                counter += 1
    return final_vocab


def generate_zinc_vocab_dict(smi_file):
    tokenizer = MolTokenizer()
    with open(smi_file, "r") as f:
        data = f.readlines()
    for i in tqdm(range(1, len(data))):
        line = data[i].split(" ")
        # tokenize() records unseen tokens in tokenizer.vocab_dict;
        # _tokenize() only splits the string and would leave the vocab empty.
        _ = tokenizer.tokenize(line[0])
    return tokenizer.vocab_dict


def main():

    if args.qm9:
        print("Generating vocab file for QM9 dataset")
        tokenizer = MolTokenizer("QM9_vocab.json")
        default_file = "/p/vast1/lbann/datasets/FLASK/QM9/QM9_smiles.txt"
        qm9_file = os.getenv("QM9_FILE", default_file)
        with open(qm9_file, "r") as smiles_data:
            smiles_data = smiles_data.readlines()
        for line in tqdm(smiles_data):
            tokens = tokenizer.tokenize(line)
        tokenizer.generate_vocab_file("QM9_vocab.json")
        print("QM9 vocab file generated")

    if args.zinc:
        print("Generating vocab file for ZINC dataset")
        default_dir = "/p/vast1/lbann/datasets/FLASK/ZINC"
        zinc_dir = os.getenv("ZINC_DIR", default_dir)
        zinc_files = glob(f"{zinc_dir}/*.smi")

        print(len(zinc_files))

        with Pool(20) as p:
            zinc_vocab_dicts = p.map(generate_zinc_vocab_dict, zinc_files)

        final_vocab = join_vocabs(zinc_vocab_dicts)

        final_tokenizer = MolTokenizer("ZINC_SMILES_vocab.json")
        final_tokenizer.load_vocab_dict(final_vocab)
        final_tokenizer.generate_vocab_file("ZINC_SMILES_vocab.json")
        print("ZINC vocab file generated")

    if args.pubchem:
        print("Generating vocab file for PubChem dataset")
        default_file = "/p/vast1/lbann/datasets/FLASK/pubchem/CID_SMILES_CANONICAL.smi"
        pubchem_file = os.getenv("PUBCHEM_FILE", default_file)
        with open(pubchem_file, "r") as smiles_data:
            smiles_data = smiles_data.readlines()
        tokenizer = MolTokenizer("PubChem_SMILES_vocab.json")
        for line in tqdm(smiles_data):
            smiles = line.split(" ")[1]
            tokens = tokenizer.tokenize(smiles)

        tokenizer.generate_vocab_file("PubChem_SMILES_vocab.json")
        print("PubChem vocab file generated")


if __name__ == "__main__":
    main()
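The script is driven by the three flags above (for example, python GenerateVocab.py --qm9), with QM9_FILE, ZINC_DIR, and PUBCHEM_FILE overriding the LC default paths. Below is a small self-contained sketch (not part of the commit) of the merge step used for the ZINC shards, showing how join_vocabs re-numbers tokens coming from per-worker vocabularies:

# Standalone illustration of the join_vocabs merge: overlapping per-worker
# vocabularies with clashing ids become one vocabulary with fresh, contiguous ids.
def join_vocabs(list_of_vocab_dicts):
    final_vocab = {}
    counter = 0
    for vocab in list_of_vocab_dicts:
        for token in vocab:
            if token not in final_vocab:
                final_vocab[token] = counter
                counter += 1
    return final_vocab

worker_a = {"<pad>": 0, "<eos>": 1, "C": 2, "O": 3}
worker_b = {"<pad>": 0, "<eos>": 1, "N": 2, "C": 3}
print(join_vocabs([worker_a, worker_b]))
# {'<pad>': 0, '<eos>': 1, 'C': 2, 'O': 3, 'N': 4}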
applications/FLASK/Transformer/datasets/pretokenize/SMILES_tokenizer.py
@@ -7,32 +7,45 @@
PATTERN = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"


class MolTokenizer():
    def __init__(self, vocab_file: str = '',
                 do_lower_case=False,
                 unk_token='<pad>',
                 sep_token='<eos>',
                 pad_token='<pad>',
                 cls_token='<bos>',
                 mask_token='<mask>',
                 **kwargs):
class MolTokenizer:
    def __init__(
        self,
        vocab_file: str = "",
        do_lower_case=False,
        unk_token="<pad>",
        sep_token="<eos>",
        pad_token="<pad>",
        cls_token="<bos>",
        mask_token="<mask>",
        **kwargs
    ):
        self.unk_token = unk_token
        self.pad_token = pad_token
        self.sep_token = sep_token
        self.cls_token = cls_token
        self.mask_token = mask_token
        self.mask_token = mask_token
        self.regex_tokenizer = re.compile(PATTERN)
        self.wordpiece_tokenizer = None
        self.basic_tokenizer = None
        self.vocab_file = vocab_file
        self.vocab_dict = {}
        self.vocab_dict[self.pad_token] = 0
        self.vocab_dict[self.sep_token] = 1
        self.vocab_dict[self.cls_token] = 2
        self.vocab_dict[self.mask_token] = 3
        self.counter = 4

    def load_vocab_file(self):
        if os.path.exists(self.vocab_file):
            with open(self.vocab_file, 'r') as f:
                self.vocab = json.load(f)
            with open(self.vocab_file, "r") as f:
                self.vocab_dict = json.load(f)
        else:
            raise NameError("Vocab file not found")

    def load_vocab_dict(self, vocab_dict):
        self.vocab_dict = vocab_dict
        self.counter = len(vocab_dict)

    def _tokenize(self, text):
        split_tokens = self.regex_tokenizer.findall(text)
        return split_tokens
@@ -41,30 +54,61 @@ def tokenize(self, text):
        split_tokens = self._tokenize(text)
        output = np.zeros(len(split_tokens))
        for i, each in enumerate(split_tokens):
            output[i] = self.vocab[each]
            if each not in self.vocab_dict.keys():
                self.vocab_dict[each] = self.counter
                self.counter += 1
            output[i] = self.vocab_dict[each]
        return output

    def encode(self, token):
        return self.vocab[token]
        return self.vocab_dict[token]

    def convert_tokens_to_string(self, tokens):
        out_string = "".join(tokens).strip()
        return out_string

    def generate_vocab_file(self, dataset_name):
        vocab_dict = {}
        vocab_dict[self.pad_token] = 0
        vocab_dict[self.sep_token] = 1
        vocab_dict[self.cls_token] = 2
        vocab_dict[self.mask_token] = 3

        counter = 4
        with open(dataset_name, 'r') as f:
            for smiles in f:
                for token in self._tokenize(smiles.strip()):
                    if not token in vocab_dict.keys():
                        vocab_dict[token] = counter
                        counter += 1

        with open(self.vocab_file, 'w') as f:
            f.write(json.dumps(vocab_dict))
        return out_string

    def generate_vocab_file(self, vocab_file_name):
        with open(vocab_file_name, "w") as f:
            f.write(json.dumps(self.vocab_dict))

    def token_to_id(self, token):
        return self.vocab_dict[token]


if __name__ == "__main__":
from tqdm import tqdm

tokenizer = MolTokenizer()
# with open("CID-SMILES-CANONICAL.smi", "r") as f:
# data = f.readlines()
# max_len = 0
# for i in tqdm(range(len(data))):
# line = data[i].split(" ")
# print(line[1])
# tokens = tokenizer._tokenize(line[1])
# print(tokens)
# if len(tokens) > max_len:
# max_len = len(tokens)
# break
# print(max_len)

from glob import glob
from multiprocessing import Pool

zinc_files = glob("ZINC/*.smi")
print(len(zinc_files))

def find_max_per_file(_fname):
with open(_fname, "r") as f:
data = f.readlines()
max_len = 0
for i in tqdm(range(1, len(data))):
line = data[i].split(" ")
tokens = tokenizer._tokenize(line[0])
if len(tokens) > max_len:
max_len = len(tokens)
return max_len

with Pool(20) as p:
max_len = p.map(find_max_per_file, zinc_files)
print(max(max_len))
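Finally, a short sketch (not part of the commit) of the intended round trip with the rewritten tokenizer: tokenize() grows vocab_dict on the fly, generate_vocab_file() writes it out as JSON, and a fresh tokenizer reloads it with load_vocab_file():

from SMILES_tokenizer import MolTokenizer  # assumes this file is on the path

tok = MolTokenizer("demo_vocab.json")              # "demo_vocab.json" is an illustrative name
ids = tok.tokenize("CC(=O)Oc1ccccc1C(=O)O")        # unseen SMILES tokens get new ids
tok.generate_vocab_file("demo_vocab.json")         # dump vocab_dict as JSON

fresh = MolTokenizer("demo_vocab.json")
fresh.load_vocab_file()                            # read the JSON back into vocab_dict
print(ids)                                         # numpy array of token ids
print(fresh.token_to_id("<bos>"))                  # special tokens keep ids 0-3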
