Skip to content

Commit

Permalink
updates to demo_cross_validate.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
PatWalters committed Dec 22, 2024
1 parent c251860 commit 6350e96
Show file tree
Hide file tree
Showing 4 changed files with 727 additions and 228 deletions.
44 changes: 44 additions & 0 deletions notebooks/catboost_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python

from catboost import CatBoostRegressor
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd
import useful_rdkit_utils as uru

class CatBoostWrapper:
    """Train and apply a CatBoost regressor on Morgan count fingerprints.

    Input frames must carry a ``SMILES`` column. Featurization adds two
    columns to the frame it is given, in place: ``mol`` (the parsed RDKit
    molecules) and ``self.fp_name`` (one fingerprint array per row).
    """

    def __init__(self, y_col: str) -> None:
        """Create the regressor and the fingerprint generator.

        Args:
            y_col: Name of the column holding the regression target.
        """
        self.cb = CatBoostRegressor(verbose=False)
        self.y_col = y_col
        # Name of the column the fingerprints are stored under.
        self.fp_name = "fp"
        self.fg = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=1024)

    def _featurize(self, df: pd.DataFrame) -> np.ndarray:
        """Parse SMILES, store molecules and fingerprints on *df*, return the stacked fingerprints.

        Args:
            df: Frame with a ``SMILES`` column; gains ``mol`` and
                ``self.fp_name`` columns in place.

        Returns:
            The per-row fingerprint arrays stacked with ``np.stack``.

        Raises:
            ValueError: If any SMILES string cannot be parsed.
        """
        mols = df.SMILES.apply(Chem.MolFromSmiles)
        # MolFromSmiles yields None for unparsable input; fail here with the
        # offending strings rather than inside the fingerprint generator.
        unparsed = mols.isna()
        if unparsed.any():
            examples = df.SMILES[unparsed].tolist()[:5]
            raise ValueError(
                f"Could not parse {int(unparsed.sum())} SMILES, e.g. {examples}"
            )
        df['mol'] = mols
        df[self.fp_name] = mols.apply(self.fg.GetCountFingerprintAsNumPy)
        # Read the column through self.fp_name so fit and predict always agree.
        return np.stack(df[self.fp_name])

    def fit(self, train: pd.DataFrame) -> None:
        """Fit the regressor on *train*.

        Args:
            train: Frame with ``SMILES`` and ``self.y_col`` columns; gains
                ``mol`` and ``self.fp_name`` columns in place.

        Raises:
            ValueError: If any SMILES string cannot be parsed.
        """
        features = self._featurize(train)
        self.cb.fit(features, train[self.y_col])

    def predict(self, test: pd.DataFrame):
        """Predict the target for every row of *test*.

        Args:
            test: Frame with a ``SMILES`` column; gains ``mol`` and
                ``self.fp_name`` columns in place.

        Returns:
            Whatever the underlying regressor's ``predict`` returns for the
            stacked fingerprints.

        Raises:
            ValueError: If any SMILES string cannot be parsed.
        """
        return self.cb.predict(self._featurize(test))

    def validate(self, train: pd.DataFrame, test: pd.DataFrame):
        """Fit on *train*, then return the predictions for *test*.

        Args:
            train: Training frame, passed to :meth:`fit`.
            test: Evaluation frame, passed to :meth:`predict`.

        Returns:
            The predictions produced by :meth:`predict`.
        """
        self.fit(train)
        return self.predict(test)



def main():
    """Run a demo: load the logS CSV, split it, fit on one part, print predictions for the other."""
    data_url = "https://raw.githubusercontent.com/PatWalters/datafiles/refs/heads/main/biogen_logS.csv"
    frame = pd.read_csv(data_url)
    # Split with train_test_split's default settings.
    train_part, test_part = train_test_split(frame)
    model = CatBoostWrapper("logS")
    predictions = model.validate(train_part, test_part)
    print(predictions)


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Loading

0 comments on commit 6350e96

Please sign in to comment.