Skip to content

Commit

Permalink
Add regularizer to covariance computation
Browse files Browse the repository at this point in the history
Avoid errors when we have zero variance
  • Loading branch information
WardLT committed Dec 7, 2023
1 parent 1bf8858 commit 4ed36c5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
13 changes: 10 additions & 3 deletions examol/select/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,18 @@
from examol.store.db.base import MoleculeStore


class _EnsembleCovarianceModel(Model):
class EnsembleCovarianceModel(Model):
"""Model which generates a multivariate Gaussian distribution given samples from an ensemble of models"""

def __init__(self, num_outputs: int):
def __init__(self, num_outputs: int, c: float = 1e-6):
    """Record the number of outputs and prepare the covariance regularizer.

    Args:
        num_outputs: Number of outputs for the model
        c: Regularization parameter used to ensure covariance matrix is positive definite.
            It scales a ``num_outputs`` x ``num_outputs`` identity matrix.
    """
    super().__init__()
    self._num_outputs = num_outputs

    # Scaled identity matrix; stored once so it can be reused on every covariance computation
    identity = torch.eye(num_outputs)
    self.regularization = identity * c

@property
def num_outputs(self) -> int:
Expand Down Expand Up @@ -60,6 +66,7 @@ def posterior(
combined_task_and_samples = torch.flatten(centered, start_dim=1, end_dim=2) # b x (q x o) x s
d = combined_task_and_samples.shape[-1]
cov = 1 / (d - 1) * combined_task_and_samples @ combined_task_and_samples.transpose(-1, -2)
cov += self.regularization

# Make the multivariate normal as an output
posterior = GPyTorchPosterior(
Expand Down Expand Up @@ -118,7 +125,7 @@ def update(self, database: MoleculeStore, recipes: Sequence[PropertyRecipe]):
outputs = _extract_observations(database, recipes)
self.acq_options = self.acq_options_updater(self, outputs)

self.acq_function = self.acq_function_type(model=_EnsembleCovarianceModel(len(recipes)), **self.acq_options)
self.acq_function = self.acq_function_type(model=EnsembleCovarianceModel(len(recipes)), **self.acq_options)

def _assign_score(self, samples: np.ndarray) -> np.ndarray:
# Shape the tensor in the form expected by BOtorch's GPyTorch
Expand Down
18 changes: 17 additions & 1 deletion tests/select/test_botorch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
"""Tests for the BOTorch-based acquisition functions"""
from botorch.acquisition import ExpectedImprovement
import numpy as np
import torch

from examol.select.botorch import BOTorchSequentialSelector, EHVISelector
from examol.select.botorch import BOTorchSequentialSelector, EHVISelector, EnsembleCovarianceModel


def test_ensemble_covar():
    """Check the posterior mean and variance produced from ensemble samples.

    Two inputs are exercised with the same model: one where every property
    varies across the ensemble, and one where the samples have zero variance.
    """
    model = EnsembleCovarianceModel(num_outputs=2)

    # Normal case: non-zero variance for all properties
    varied_samples = torch.tensor([[[1.0, 1.1, 2.0, 2.1]]])
    posterior = model.posterior(varied_samples)
    assert torch.allclose(posterior.distribution.mean, torch.tensor([1.05, 2.05]))
    assert torch.allclose(posterior.distribution.variance, torch.tensor([[0.0050, 0.0050]]), atol=1e-4)

    # Bad case: zero variance
    constant_samples = torch.tensor([[[1.0, 1.0, 2.1, 2.1]]])
    posterior = model.posterior(constant_samples)
    assert torch.allclose(posterior.distribution.mean, torch.tensor([1.0, 2.1]))
    assert torch.allclose(posterior.distribution.variance, torch.tensor([[0.0, 0.0]]), atol=1e-4)


def test_sequential(test_data, recipe):
Expand Down

0 comments on commit 4ed36c5

Please sign in to comment.