Merge branch 'main' into multifi
WardLT committed Nov 17, 2023
2 parents c710cb4 + 8c62f37 commit 2042f98
Showing 7 changed files with 225 additions and 142 deletions.
20 changes: 13 additions & 7 deletions docs/components/score.rst
@@ -74,14 +74,20 @@
Some Scorer classes support using properties computed at lower levels of accuracy
to improve performance.
The strategies employed by each Scorer may be different, but all have the same interface.

- Use the multi-fidelity capability of a Scorer by providing multiple recipes when preprocessing
- for *both* inputs and outputs for training or inference.
- The recipes must be ordered from lowest- to highest-fidelity.
+ Use the multi-fidelity capability of a Scorer by providing
+ values from lower levels of fidelity when training or running inference.

.. code-block:: python
-     outputs = model.transform_outputs(records, [recipe_low, recipe_high])
-     inputs = model.transform_inputs(records, [recipe_low, recipe_high])
+     from examol.score.utils.multifi import collect_outputs
+     fidelities = [RedoxEnergy(1, 'low'), RedoxEnergy(1, 'medium'), RedoxEnergy(1, 'high')]
- The outputs will, by default, contain the recipe computed at each level of fidelity
- with ``np.nan`` values for missing data.
+     # Get the inputs and outputs, as normal
+     inputs = scorer.transform_inputs(records)
+     outputs = scorer.transform_outputs(records, fidelities[-1])  # Train using the highest level
+     # Pass the low-fidelity results to scoring and inference
+     lower_fidelities = collect_outputs(records, fidelities[:-1])
+     scorer.retrain(model_msg, inputs, outputs, lower_fidelities=lower_fidelities)
+     ...
+     scorer.score(model_msg, inputs, lower_fidelities=lower_fidelities)
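
For reference, the `collect_outputs` helper used above (its implementation appears in the base.py diff below) returns one row per record and one column per recipe, with ``np.nan`` marking levels that have not been computed yet. A small illustration with hypothetical property values:

    import numpy as np

    # Hypothetical stand-ins for MoleculeRecord.properties: {recipe name: {level: value}}
    properties = [
        {'redox_energy': {'low': 1.1, 'medium': 1.2, 'high': 1.3}},  # fully computed
        {'redox_energy': {'low': 0.9}},                              # higher levels still pending
    ]

    # The same lookup pattern collect_outputs applies to each record
    matrix = np.array([
        [props.get('redox_energy', {}).get(level, np.nan) for level in ('low', 'medium', 'high')]
        for props in properties
    ])
    print(matrix)
    # [[1.1 1.2 1.3]
    #  [0.9 nan nan]]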
69 changes: 21 additions & 48 deletions examol/score/base.py
@@ -1,28 +1,14 @@
"""Base classes for scoring functions"""
from dataclasses import dataclass
- from typing import Sequence

import numpy as np

+ from examol.score.utils.multifi import collect_outputs
from examol.store.models import MoleculeRecord
from examol.store.recipes import PropertyRecipe


- def collect_outputs(records: list[MoleculeRecord], recipes: Sequence[PropertyRecipe]) -> np.ndarray:
- """Collect the outputs for several recipes for each molecule
-
- Args:
- records: Molecule records to be summarized
- recipes: List of recipes to include
- Returns:
- Matrix where each row is a different molecule, and each column is a different recipe
- """
- return np.array([
- [record.properties.get(recipe.name, {}).get(recipe.level, np.nan) for recipe in recipes]
- for record in records
- ])


# TODO (wardlt): Make this a generic class once we move to Py3.12. https://peps.python.org/pep-0695/
@dataclass
class Scorer:
"""Base class for algorithms which quickly assign a score to a molecule, typically using a machine learning model
@@ -57,57 +43,28 @@ class Scorer:
update_msg = scorer.retrain(model_msg, inputs, outputs) # Run remotely
model = scorer.update(model, update_msg)
**Multi-fidelity scoring**
Multi-fidelity learning methods employ lower-fidelity estimates of a target value to improve the prediction of that value.
- ExaMol supports multi-fidelity through the ability to provide more than one recipe as inputs to
- :meth:`transform_inputs` and :meth:`transform_outputs`.
+ Implementations of Scorers must be designed to support multi-fidelity learning.
"""

- _supports_multi_fidelity: bool = False
- """Whether the class supports multi-fidelity optimization"""

- def transform_inputs(self, record_batch: list[MoleculeRecord], recipes: Sequence[PropertyRecipe] | None = None) -> list:
+ def transform_inputs(self, record_batch: list[MoleculeRecord]) -> list:
"""Form inputs for the model based on the data in a molecule record
Args:
record_batch: List of records to pre-process
- recipes: List of recipes ordered from lowest to highest fidelity.
- Only used in multi-fidelity scoring algorithms
Returns:
List of inputs ready for :meth:`score` or :meth:`retrain`
"""
raise NotImplementedError()

- # TODO (wardlt): I'm not super-happy with multi-fidelity being inferred from input types. What if we want multi-objective learning?
- def transform_outputs(self, records: list[MoleculeRecord], recipe: PropertyRecipe | Sequence[PropertyRecipe]) -> np.ndarray:
+ def transform_outputs(self, records: list[MoleculeRecord], recipe: PropertyRecipe) -> np.ndarray:
"""Gather the target outputs of the model
Args:
records: List of records from which to extract outputs
recipe: Target recipe for the scorer for single-fidelity learning
- or a list of recipes ordered from lowest to highest fidelity
- for multi-fidelity learning.
Returns:
Outputs ready for model training
"""
- # Determine if we are doing single or multi-fidelity learning
- is_single = False
- if isinstance(recipe, PropertyRecipe):
- is_single = True
- recipes = [recipe]
- else:
- if not self._supports_multi_fidelity:  # pragma: no-coverage
- raise ValueError(f'{self.__class__.__name__} does not support multi-fidelity training')
- recipes = recipe
-
- # Gather the outputs
- output = collect_outputs(records, recipes)
- if is_single:
- return output[:, 0]
- return output
+ return collect_outputs(records, [recipe])[:, -1]

def prepare_message(self, model: object, training: bool = False) -> object:
"""Get the model state as a serializable object
@@ -153,3 +110,19 @@ def update(self, model: object, update_msg: object) -> object:
Updated model
"""
raise NotImplementedError()


+ class MultiFidelityScorer(Scorer):
+ """Base class for scorers which support multi-fidelity learning
+
+ All subclasses support a "lower_fidelities" keyword argument to the
+ :meth:`score` and :meth:`retrain` functions that takes any lower-fidelity information available.
+
+ Subclasses should train a multi-fidelity model if provided lower-fidelity data during
+ training and use the lower-fidelity data to enhance prediction accuracy during scoring.
+ """
+
+ def score(self, model_msg: object, inputs: list, lower_fidelities: np.ndarray | None = None, **kwargs) -> np.ndarray:
+ raise NotImplementedError()
+
+ def retrain(self, model_msg: object, inputs: list, outputs: list, lower_fidelities: np.ndarray | None = None, **kwargs) -> object:
+ raise NotImplementedError()
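
To make this contract concrete, a minimal hypothetical subclass could look like the sketch below; the mean/mean-delta "model" is invented purely for illustration, while real implementations such as NFPScorer (next file) train actual ML models:

    import numpy as np
    from examol.score.base import MultiFidelityScorer

    class MeanDeltaScorer(MultiFidelityScorer):
        """Toy scorer whose 'model' is the training-set mean (or mean delta). Illustration only."""

        def transform_inputs(self, record_batch):
            return [record.identifier.smiles for record in record_batch]

        def prepare_message(self, model, training=False):
            return model  # The float is already serializable

        def retrain(self, model_msg, inputs, outputs, lower_fidelities=None, **kwargs):
            if lower_fidelities is None:
                return float(np.mean(outputs))  # Single fidelity: predict the mean target
            # Multi-fidelity: predict the mean offset from the best-known lower level
            return float(np.nanmean(outputs - lower_fidelities[:, -1]))

        def score(self, model_msg, inputs, lower_fidelities=None, **kwargs):
            if lower_fidelities is None:
                return np.full(len(inputs), model_msg)
            # Anchor predictions on the known lower-fidelity values
            return np.nan_to_num(lower_fidelities[:, -1]) + model_msg

        def update(self, model, update_msg):
            return update_msg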
56 changes: 24 additions & 32 deletions examol/score/nfp.py
@@ -12,8 +12,8 @@

from examol.store.models import MoleculeRecord
from examol.utils.conversions import convert_string_to_nx
- from examol.store.recipes import PropertyRecipe
- from .base import Scorer, collect_outputs
+ from .base import MultiFidelityScorer
+ from .utils.multifi import compute_deltas
from .utils.tf import LRLogger, TimeLimitCallback, EpochTimeLogger


@@ -288,7 +288,7 @@ def generator():
return loader


- class NFPScorer(Scorer):
+ class NFPScorer(MultiFidelityScorer):
"""Train message-passing neural networks based on the `NFP <https://github.com/NREL/nfp>`_ library.
NFP uses Keras to define message-passing networks, which is backed by Tensorflow for executing the networks on different hardware.
@@ -316,51 +316,44 @@ def prepare_message(self, model: tf.keras.models.Model, training: bool = False)
else:
return NFPMessage(model)

- def transform_inputs(self, record_batch: list[MoleculeRecord], recipes: list[PropertyRecipe] | None = None) -> list[dict | tuple[dict, np.ndarray]]:
- mol_dicts = [convert_string_to_dict(record.identifier.inchi) for record in record_batch]
+ def transform_inputs(self, record_batch: list[MoleculeRecord]) -> list[dict]:
+ return [convert_string_to_dict(record.identifier.inchi) for record in record_batch]

- # Return only the molecular dicts for single-fidelity runs
- if recipes is None:
- return mol_dicts
-
- # Return both the molecular dictionary and known properties for multi-fidelity
- return list(zip(mol_dicts, collect_outputs(record_batch, recipes)))

- def score(self, model_msg: NFPMessage, inputs: list[dict | tuple[dict, np.ndarray]], batch_size: int = 64, **kwargs) -> np.ndarray:
+ def score(self,
+ model_msg: NFPMessage,
+ inputs: list[dict],
+ batch_size: int = 64,
+ lower_fidelities: np.ndarray | None = None,
+ **kwargs) -> np.ndarray:
"""Assign a score to molecules
Args:
model_msg: Model in a transmittable format
inputs: Batch of inputs ready for the model (in dictionary format)
batch_size: Number of molecules to evaluate at a time
+ lower_fidelities: Properties of the molecule at lower levels, if known
Returns:
The scores for a set of records
"""
model = model_msg.get_model() # Unpack the model

- # Unpack the known values if running multi-fidelity learning
- is_single = isinstance(inputs[0], dict)
- known_outputs = None
- if not is_single:
- inputs, known_outputs = zip(*inputs)
- known_outputs = np.array(known_outputs)
- known_outputs[:, 1:] = np.diff(known_outputs)

# Run inference
loader = make_data_loader(inputs, batch_size=batch_size)
- ml_outputs = model.predict(loader, verbose=False)
- if is_single:
+ ml_outputs = np.squeeze(model.predict(loader, verbose=False))
+ if ml_outputs.ndim == 1:  # Single-fidelity learning
return ml_outputs

- # For multi-fidelity learning, use the known outputs in place of the NN outputs
- best_outputs = np.where(np.isnan(known_outputs), ml_outputs, known_outputs)
- best_outputs = best_outputs.cumsum(axis=1)  # The outputs of the networks are deltas
- return best_outputs[:, -1]  # Return only the highest level of fidelity
+ # For multi-fidelity learning, use the known outputs in place of the NN outputs if we know them
+ if lower_fidelities is not None:
+ known_deltas = compute_deltas(lower_fidelities)
+ ml_outputs[:, :-1] = np.where(np.isnan(known_deltas), ml_outputs[:, :-1], known_deltas)
+ return ml_outputs.sum(axis=1)  # The outputs of the networks are deltas

def retrain(self,
model_msg: dict | NFPMessage,
inputs: list,
outputs: np.ndarray,
+ lower_fidelities: None | np.ndarray = None,
num_epochs: int = 4,
batch_size: int = 32,
validation_split: float = 0.1,
@@ -376,6 +369,7 @@ def retrain(self,
model_msg: Model to be retrained
inputs: Training set inputs, as generated by :meth:`transform_inputs`
outputs: Training Set outputs, as generated by :meth:`transform_outputs`
+ lower_fidelities: Lower-fidelity data, if available
num_epochs: Maximum number of epochs to run
batch_size: Number of molecules per training batch
validation_split: Fraction of molecules used for the training/validation split
@@ -398,7 +392,7 @@ def retrain(self,
raise NotImplementedError(f'Unrecognized message type: {type(model_msg)}')

# Prepare data for single- vs multi-fidelity training
- is_single = isinstance(inputs[0], dict)
+ is_single = lower_fidelities is None
if is_single:
# Nothing special: Use a standard loss function, no preprocessing required
loss = 'mean_squared_error'
@@ -410,11 +404,9 @@ def loss(y_true, y_pred):
is_known = tf.math.is_finite(y_true)
return tf.keras.losses.mean_squared_error(y_true[is_known], y_pred[is_known])

- inputs, _ = zip(*inputs)  # We do not need the known property values for training

# Prepare the outputs
- outputs = outputs.copy()
- outputs[:, 1:] = np.diff(outputs)  # Compute the deltas between successive stages
+ outputs = np.concatenate([lower_fidelities, outputs[:, None]], axis=1)
+ outputs = compute_deltas(outputs)
value_spec = tf.TensorSpec((outputs.shape[1],), dtype=tf.float32)

# Split off a validation set
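The `compute_deltas` helper imported above is not shown in this commit, but judging from the `np.diff` logic it replaces, it plausibly converts a per-fidelity value matrix into step-wise deltas whose row sum recovers the highest-fidelity value. A sketch consistent with that usage (the real implementation lives in examol/score/utils/multifi.py and may differ):

    import numpy as np

    def compute_deltas(outputs: np.ndarray) -> np.ndarray:
        """Convert absolute values per fidelity into step-wise deltas (assumed behavior)

        Column 0 keeps the lowest-fidelity value; column i holds the change from
        level i-1 to level i, so summing a row recovers the highest-fidelity value.
        """
        deltas = outputs.copy()
        deltas[:, 1:] = np.diff(outputs, axis=1)
        return deltas

    # One molecule with values at three fidelity levels
    values = np.array([[1.0, 1.2, 1.5]])
    print(compute_deltas(values))              # [[1.0 0.2 0.3]]
    print(compute_deltas(values).sum(axis=1))  # [1.5]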
104 changes: 79 additions & 25 deletions examol/score/rdkit/__init__.py
@@ -1,9 +1,11 @@
"""Scorers that rely on RDKit and sklearn"""
from concurrent.futures import ProcessPoolExecutor
+ from dataclasses import dataclass
from functools import partial
from typing import Callable, Union

import numpy as np
+ from sklearn import clone
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsRegressor
@@ -12,48 +14,100 @@
from sklearn.pipeline import Pipeline
from sklearn.decomposition import PCA

- from examol.score.base import Scorer
+ from examol.score.base import MultiFidelityScorer
+ from examol.score.utils.multifi import compute_deltas
from examol.score.rdkit.descriptors import compute_morgan_fingerprints, compute_doan_2020_fingerprints
from examol.store.models import MoleculeRecord

+ ModelType = Pipeline | list[Pipeline]
+ """Model is a single pipeline for training and a list of models after training"""
+ InputType = list[str]
+ """Model inputs are the SMILES string of the molecule"""

- class RDKitScorer(Scorer):

+ @dataclass
+ class RDKitScorer(MultiFidelityScorer):
"""Score molecules based on a model defined using RDKit and Scikit-Learn
Models must take a SMILES string as input.
Use the :class:`~.FingerprintTransformer` to transform the SMILES into an RDKit Mol object if needed.
"""
def transform_inputs(self, record_batch: list[MoleculeRecord]) -> list:
return [x.identifier.smiles for x in record_batch]
**Multi Fidelity Learning**
def prepare_message(self, model: Pipeline, training: bool = True) -> Pipeline:
return model
We implement multi-fidelity learning by training separate models for each level of fidelity.
def score(self, model_msg: Pipeline, inputs: list, **kwargs) -> np.ndarray:
return model_msg.predict(inputs)
The model for the lowest level of fidelity is trained to predict the value of the property
and each subsequent model predicts the delta between it and the previous step.
def retrain(self, model_msg: Pipeline, inputs: list, outputs: np.ndarray, bootstrap: bool = True, **kwargs) -> object:
"""Retrain the scorer based on new training records
On inference, we use the known values for either the lowest level of fidelity or
the deltas in place of the predictions from the machine learning models.
"""

Args:
model_msg: Model to be retrained
inputs: Training set inputs, as generated by :meth:`transform_inputs`
outputs: Training Set outputs, as generated by :meth:`transform_outputs`
bootstrap: Whether to sample training data with replacement
Returns:
Message defining how to update the model
"""
# If desired, resample with replacement
+ def transform_inputs(self, record_batch: list[MoleculeRecord]) -> InputType:
+ return [x.identifier.smiles for x in record_batch]

+ def prepare_message(self, model: ModelType, training: bool = True) -> ModelType:
+ if training:
+ # Only send a single model for training
+ return model[0] if isinstance(model, list) else model
+ else:
+ # Send the whole list for inference
+ return model

+ def score(self, model_msg: ModelType, inputs: InputType, lower_fidelities: np.ndarray | None = None, **kwargs) -> np.ndarray:
+ if not isinstance(model_msg, list):
+ # Single fidelity
+ return model_msg.predict(inputs)
+ else:
+ # Get the known deltas then append a NaN to the end (we don't know the last delta)
+ if lower_fidelities is None:
+ deltas = np.empty((len(inputs), len(model_msg))) * np.nan
+ else:
+ known_deltas = compute_deltas(lower_fidelities)
+ deltas = np.concatenate((known_deltas, np.empty_like(known_deltas[:, :1]) * np.nan), axis=1)

+ # Run the model at each level
+ for my_level, my_model in enumerate(model_msg):
+ my_preds = my_model.predict(inputs)
+ is_unknown = np.isnan(deltas[:, my_level])
+ deltas[is_unknown, my_level] = my_preds[is_unknown]

+ # Sum up the deltas
+ return np.sum(deltas, axis=1)

+ def retrain(self, model_msg: Pipeline, inputs: InputType, outputs: np.ndarray,
+ bootstrap: bool = True,
+ lower_fidelities: np.ndarray | None = None) -> ModelType:
if bootstrap:
samples = np.random.randint(0, len(inputs), size=(len(inputs),))
inputs = [inputs[i] for i in samples]
outputs = outputs[samples]

- model_msg.fit(inputs, outputs)
- return model_msg

- def update(self, model: object, update_msg: object) -> object:
+ if lower_fidelities is not None:
+ lower_fidelities = lower_fidelities[samples, :]

+ if lower_fidelities is None:
+ # For single level, train a single model
+ model_msg.fit(inputs, outputs)
+ return model_msg
+ else:
+ # Compute the delta and then train a different model for each delta
+ outputs = np.concatenate([lower_fidelities, outputs[:, None]], axis=1)  # Append target level to end
+ deltas = compute_deltas(outputs)

+ models = []
+ for y in deltas.T:
+ # Remove the missing values
+ mask = np.isfinite(y)
+ my_smiles = [i for m, i in zip(mask, inputs) if m]
+ y = y[mask]

+ # Fit a fresh copy of the model
+ my_model: Pipeline = clone(model_msg)
+ my_model.fit(my_smiles, y)
+ models.append(my_model)
+ return models

+ def update(self, model: ModelType, update_msg: ModelType) -> ModelType:
return update_msg


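Putting the pieces of this class together, a hedged end-to-end sketch of driving the multi-fidelity RDKitScorer might read as follows; `records` and `fidelities` are assumed to come from the documentation example above, and the character-hash featurizer is a self-contained stand-in for a real pipeline built around FingerprintTransformer:

    import numpy as np
    from sklearn.neighbors import KNeighborsRegressor
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import FunctionTransformer

    from examol.score.rdkit import RDKitScorer
    from examol.score.utils.multifi import collect_outputs

    # Stand-in SMILES featurizer so the sketch runs without fingerprint code
    featurize = FunctionTransformer(lambda smiles: np.array([[len(s), hash(s) % 97] for s in smiles]))
    pipeline = Pipeline([('features', featurize), ('knn', KNeighborsRegressor(n_neighbors=2))])

    scorer = RDKitScorer()
    inputs = scorer.transform_inputs(records)                    # SMILES strings
    outputs = scorer.transform_outputs(records, fidelities[-1])  # Highest-fidelity targets
    lower = collect_outputs(records, fidelities[:-1])            # NaN where not yet computed

    model_msg = scorer.prepare_message(pipeline, training=True)
    models = scorer.retrain(model_msg, inputs, outputs, lower_fidelities=lower)  # One model per delta
    preds = scorer.score(models, inputs, lower_fidelities=lower)  # Known deltas override predictions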
