Supports multimetrics in VizierGaussianProcess
PiperOrigin-RevId: 708138650
vizier-team authored and copybara-github committed Dec 20, 2024
1 parent 1985bd9 commit 0b61e28
Showing 4 changed files with 107 additions and 68 deletions.
15 changes: 2 additions & 13 deletions vizier/_src/algorithms/designers/gp/gp_models.py
@@ -28,7 +28,6 @@
from vizier._src.algorithms.designers.gp import transfer_learning as vtl
from vizier._src.jax import stochastic_process_model as sp
from vizier._src.jax import types
- from vizier._src.jax.models import multitask_tuned_gp_models
from vizier._src.jax.models import tuned_gp_models
from vizier.jax import optimizers
from vizier.utils import profiler
@@ -155,19 +154,9 @@ def get_vizier_gp_coroutine(
Returns:
The model coroutine.
"""
- # Construct the multi-task GP.
- if data.labels.shape[-1] > 1:
- gp_coroutine = multitask_tuned_gp_models.VizierMultitaskGaussianProcess(
- _feature_dim=types.ContinuousAndCategorical[int](
- data.features.continuous.padded_array.shape[-1],
- data.features.categorical.padded_array.shape[-1],
- ),
- _num_tasks=data.labels.shape[-1],
- )
- return sp.StochasticProcessModel(gp_coroutine).coroutine
-
return tuned_gp_models.VizierGaussianProcess.build_model(
- data.features, linear_coef=linear_coef
+ data,
+ linear_coef=linear_coef,
).coroutine


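Aside: after this change, `get_vizier_gp_coroutine` no longer branches to a separate multi-task model; `VizierGaussianProcess.build_model` receives the full `ModelData` and infers the number of metrics from `data.labels.shape[-1]`. A minimal sketch of calling the unified entry point; the toy shapes and values are illustrative assumptions, not part of this commit:

```python
import numpy as np
from vizier._src.algorithms.designers.gp import gp_models
from vizier._src.jax import types

# Two metric columns in the labels -> the same call now yields a
# multi-task GP coroutine; one column would yield a single-output GP.
x = np.random.uniform(size=(10, 4))
y = np.random.uniform(size=(10, 2))
data = types.ModelData(
    features=types.ModelInput(
        continuous=types.PaddedArray.from_array(
            x, target_shape=(10, 4), fill_value=np.nan
        ),
        categorical=types.PaddedArray.from_array(
            np.zeros((10, 0), dtype=types.INT_DTYPE),
            target_shape=(10, 0),
            fill_value=-1,
        ),
    ),
    labels=types.PaddedArray.from_array(
        y, target_shape=(10, 2), fill_value=np.nan
    ),
)
coroutine = gp_models.get_vizier_gp_coroutine(data, linear_coef=0.1)
```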
2 changes: 1 addition & 1 deletion vizier/_src/algorithms/designers/gp_ucb_pe.py
@@ -620,7 +620,7 @@ def _build_gp_model_and_optimize_parameters(
initial values picked by the GP model.
"""
coroutine = self._gp_model_class.build_model( # pytype: disable=attribute-error
- data.features
+ data
).coroutine
model = sp.CoroutineWithData(coroutine, data)

45 changes: 33 additions & 12 deletions vizier/_src/jax/models/tuned_gp_models.py
@@ -33,6 +33,7 @@

tfb = tfp.bijectors
tfd = tfp.distributions
+ tfde = tfp.experimental.distributions
tfpk = tfp.math.psd_kernels
tfpke = tfp.experimental.psd_kernels

@@ -86,26 +87,34 @@ class VizierGaussianProcess(sp.ModelCoroutine[tfd.GaussianProcess]):
"""

_dim: types.ContinuousAndCategorical[int] = struct.field(pytree_node=False)
+ _num_metrics: int = struct.field(pytree_node=False)
_use_retrying_cholesky: bool = struct.field(
pytree_node=False, default=True, kw_only=True
)
_boundary_epsilon: float = struct.field(default=1e-12, kw_only=True)
_linear_coef: Optional[float] = struct.field(default=None, kw_only=True)

+ def __attrs_post_init__(self):
+ if self._num_metrics < 1:
+ raise ValueError(
+ f'Number of metrics must be at least 1, got: {self._num_metrics}'
+ )
+
@classmethod
def build_model(
cls,
- features: types.ModelInput,
+ data: types.ModelData,
*,
use_retrying_cholesky: bool = True,
linear_coef: Optional[float] = None,
) -> sp.StochasticProcessModel:
"""Returns the model and loss function."""
gp_coroutine = VizierGaussianProcess(
_dim=types.ContinuousAndCategorical[int](
- features.continuous.padded_array.shape[-1],
- features.categorical.padded_array.shape[-1],
+ data.features.continuous.padded_array.shape[-1],
+ data.features.categorical.padded_array.shape[-1],
),
+ _num_metrics=data.labels.shape[-1],
_use_retrying_cholesky=use_retrying_cholesky,
_linear_coef=linear_coef,
)
@@ -214,8 +223,11 @@ def __call__(
# output a shape of `[batch_shape, 1]`, ensuring that batch dimensions
# line up properly.
mean_fn_constant = yield sp.ModelParameter(
- init_fn=lambda k: jax.random.normal(key=k, shape=[1]),
- regularizer=lambda x: 0.5 * jnp.squeeze(x, axis=-1) ** 2,
+ init_fn=lambda k: jax.random.normal(
+ key=k,
+ shape=[1] if self._num_metrics == 1 else [1, self._num_metrics],
+ ),
+ regularizer=lambda x: 0.5 * jnp.sum(x**2),
name='mean_fn',
)

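Aside: a quick self-contained sketch (plain jax/jnp; `num_metrics` is an arbitrary illustration value) of what the reshaped mean constant and generalized regularizer above do. With multiple metrics the constant carries one entry per task, and `jnp.sum` keeps the regularizer a scalar for either shape:

```python
import jax
import jax.numpy as jnp

num_metrics = 3  # arbitrary illustration value
mean_fn_constant = jax.random.normal(
    key=jax.random.PRNGKey(0),
    shape=[1] if num_metrics == 1 else [1, num_metrics],
)
regularizer = 0.5 * jnp.sum(mean_fn_constant**2)
print(mean_fn_constant.shape)  # (1, 3): one mean value per metric
print(regularizer.shape)       # (): scalar for any number of metrics
```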
@@ -256,10 +268,19 @@ def __call__(
)
cholesky_fn = lambda matrix: retrying_cholesky(matrix)[0]

- return tfd.GaussianProcess(
- kernel,
- index_points=inputs,
- observation_noise_variance=observation_noise_variance,
- cholesky_fn=cholesky_fn,
- mean_fn=mean_fn,
- )
+ if self._num_metrics > 1:
+ return tfde.MultiTaskGaussianProcess(
+ tfpke.Independent(self._num_metrics, kernel),
+ index_points=inputs,
+ observation_noise_variance=observation_noise_variance,
+ cholesky_fn=cholesky_fn,
+ mean_fn=mean_fn,
+ )
+ else:
+ return tfd.GaussianProcess(
+ kernel,
+ index_points=inputs,
+ observation_noise_variance=observation_noise_variance,
+ cholesky_fn=cholesky_fn,
+ mean_fn=mean_fn,
+ )
113 changes: 71 additions & 42 deletions vizier/_src/jax/models/tuned_gp_models_test.py
@@ -19,7 +19,6 @@
from absl import logging
import equinox as eqx
import jax
- from jax import config
import numpy as np
from tensorflow_probability.substrates import jax as tfp
from vizier._src.jax import stochastic_process_model as sp
@@ -28,13 +27,14 @@
from vizier.jax import optimizers

from absl.testing import absltest
+ from absl.testing import parameterized

tfb = tfp.bijectors


- class VizierGpTest(absltest.TestCase):
+ class VizierGpTest(parameterized.TestCase):

- def _generate_xys(self):
+ def _generate_xys(self, num_metrics: int):
x_obs = np.array(
[
[
@@ -120,46 +120,56 @@ def _generate_xys(self):
],
dtype=np.float64,
)
- y_obs = np.array(
- [
- 0.55552674,
- -0.29054829,
- -0.04703586,
- 0.0217839,
- 0.15445438,
- 0.46654119,
- 0.12255823,
- -0.19540335,
- -0.11772564,
- -0.44447326,
- ],
- dtype=np.float64,
- )[:, np.newaxis]
+ y_obs = np.tile(
+ np.array(
+ [
+ 0.55552674,
+ -0.29054829,
+ -0.04703586,
+ 0.0217839,
+ 0.15445438,
+ 0.46654119,
+ 0.12255823,
+ -0.19540335,
+ -0.11772564,
+ -0.44447326,
+ ],
+ dtype=np.float64,
+ )[:, np.newaxis],
+ (1, num_metrics),
+ )
return x_obs, y_obs

# TODO: Define generic assertions for loss values/masking in
# coroutines.
- def test_masking_works(self):
- # Mask three dimensions and four observations.
- x_obs, y_obs = self._generate_xys()
+ @parameterized.parameters(
+ # Pads two observations.
+ dict(num_metrics=1, num_obs=12),
+ # No observations are padded because multimetric GP does not support
+ # observation padding.
+ dict(num_metrics=2, num_obs=10),
+ )
+ def test_masking_works(self, num_metrics: int, num_obs: int):
+ x_obs, y_obs = self._generate_xys(num_metrics)
data = types.ModelData(
features=types.ModelInput(
# Pads three continuous dimensions.
continuous=types.PaddedArray.from_array(
- x_obs, target_shape=(12, 9), fill_value=1.0
+ x_obs, target_shape=(num_obs, 9), fill_value=1.0
),
categorical=types.PaddedArray.from_array(
np.zeros((9, 0), dtype=types.INT_DTYPE),
- target_shape=(12, 2),
+ target_shape=(num_obs, 2),
fill_value=1,
),
),
labels=types.PaddedArray.from_array(
- y_obs, target_shape=(12, 1), fill_value=np.nan
+ y_obs, target_shape=(num_obs, num_metrics), fill_value=np.nan
),
)
model1 = sp.CoroutineWithData(
tuned_gp_models.VizierGaussianProcess(
- types.ContinuousAndCategorical[int](9, 2)
+ types.ContinuousAndCategorical[int](9, 2), num_metrics
),
data=data,
)
@@ -173,7 +183,7 @@ def test_masking_works(self):
)
model2 = sp.CoroutineWithData(
tuned_gp_models.VizierGaussianProcess(
- types.ContinuousAndCategorical[int](9, 2)
+ types.ContinuousAndCategorical[int](9, 2), num_metrics
),
data=modified_data,
)
@@ -205,37 +215,44 @@ def test_masking_works(self):
model2.loss_with_aux(optimal_params2)[0],
)

- def test_good_log_likelihood(self):
+ @parameterized.parameters(
+ # Pads two observations.
+ dict(num_metrics=1, num_obs=12),
+ # No observations are padded because multimetric GP does not support
+ # observation padding.
+ dict(num_metrics=2, num_obs=10),
+ )
+ def test_good_log_likelihood(self, num_metrics: int, num_obs: int):
# We use a fixed random seed for sampling categorical data (and continuous
# data from `_generate_xys`, above) so that the same data is used for every
# test run.
rng, init_rng, cat_rng = jax.random.split(jax.random.PRNGKey(2), 3)
- x_cont_obs, y_obs = self._generate_xys()
+ x_cont_obs, y_obs = self._generate_xys(num_metrics)
data = types.ModelData(
features=types.ModelInput(
continuous=types.PaddedArray.from_array(
- x_cont_obs, target_shape=(12, 9), fill_value=np.nan
+ x_cont_obs, target_shape=(num_obs, 9), fill_value=np.nan
),
categorical=types.PaddedArray.from_array(
jax.random.randint(
cat_rng,
- shape=(12, 3),
+ shape=(num_obs, 3),
minval=0,
maxval=3,
dtype=types.INT_DTYPE,
),
- target_shape=(12, 5),
+ target_shape=(num_obs, 5),
fill_value=-1,
),
),
labels=types.PaddedArray.from_array(
- y_obs, target_shape=(12, 1), fill_value=np.nan
+ y_obs, target_shape=(num_obs, num_metrics), fill_value=np.nan
),
)
target_loss = -0.2
model = sp.CoroutineWithData(
tuned_gp_models.VizierGaussianProcess(
- types.ContinuousAndCategorical[int](9, 5)
+ types.ContinuousAndCategorical[int](9, 5), num_metrics
),
data=data,
)
@@ -251,37 +268,46 @@ def test_good_log_likelihood_linear(self):
logging.info('Loss: %s', metrics['loss'])
self.assertLess(np.min(metrics['loss']), target_loss)

- def test_good_log_likelihood_linear(self):
+ @parameterized.parameters(
+ # Pads two observations.
+ dict(num_metrics=1, num_obs=12),
+ # No observations are padded because multimetric GP does not support
+ # observation padding.
+ dict(num_metrics=2, num_obs=10),
+ )
+ def test_good_log_likelihood_linear(self, num_metrics: int, num_obs: int):
# We use a fixed random seed for sampling categorical data (and continuous
# data from `_generate_xys`, above) so that the same data is used for every
# test run.
rng, init_rng, cat_rng = jax.random.split(jax.random.PRNGKey(2), 3)
- x_cont_obs, y_obs = self._generate_xys()
+ x_cont_obs, y_obs = self._generate_xys(num_metrics)
data = types.ModelData(
features=types.ModelInput(
continuous=types.PaddedArray.from_array(
- x_cont_obs, target_shape=(12, 9), fill_value=np.nan
+ x_cont_obs, target_shape=(num_obs, 9), fill_value=np.nan
),
categorical=types.PaddedArray.from_array(
jax.random.randint(
cat_rng,
- shape=(12, 3),
+ shape=(num_obs, 3),
minval=0,
maxval=3,
dtype=types.INT_DTYPE,
),
- target_shape=(12, 5),
+ target_shape=(num_obs, 5),
fill_value=-1,
),
),
labels=types.PaddedArray.from_array(
- y_obs, target_shape=(12, 1), fill_value=np.nan
+ y_obs, target_shape=(num_obs, num_metrics), fill_value=np.nan
),
)
target_loss = -0.2
model = sp.CoroutineWithData(
tuned_gp_models.VizierGaussianProcess(
- types.ContinuousAndCategorical[int](9, 5), _linear_coef=1.0
+ types.ContinuousAndCategorical[int](9, 5),
+ num_metrics,
+ _linear_coef=1.0,
),
data=data,
)
@@ -327,11 +353,14 @@ def test_good_log_likelihood_linear(self):
),
)
y_pred_mean = predictive.predict(pred_features).mean()
- self.assertEqual(y_pred_mean.shape, (best_n, n_pred_features))
+ self.assertEqual(
+ y_pred_mean.shape,
+ (best_n, n_pred_features) + ((num_metrics,) if num_metrics > 1 else ()),
+ )


if __name__ == '__main__':
# Jax disables float64 computations by default and will silently convert
# float64s to float32s. We must explicitly enable float64.
- config.update('jax_enable_x64', True)
+ jax.config.update('jax_enable_x64', True)
absltest.main()

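Aside: a small sketch of the label layout these updated tests build (values are toy assumptions, not from the test data above): a single metric column tiled across metrics via `np.tile`, matching `_generate_xys`. Note the parameterizations pad observations only in the single-metric case, since per the test comments the multimetric GP does not support observation padding:

```python
import numpy as np

col = np.array([0.5, -0.3, 0.1], dtype=np.float64)[:, np.newaxis]  # shape (3, 1)
y_obs = np.tile(col, (1, 2))  # tile across num_metrics == 2
print(y_obs.shape)  # (3, 2): identical label values for each metric
```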