diff --git a/vizier/_src/algorithms/designers/gp_ucb_pe.py b/vizier/_src/algorithms/designers/gp_ucb_pe.py index 761a3f7af..35be9fd0f 100644 --- a/vizier/_src/algorithms/designers/gp_ucb_pe.py +++ b/vizier/_src/algorithms/designers/gp_ucb_pe.py @@ -322,6 +322,10 @@ def score_with_aux( is_missing = ( features.continuous.is_missing[0] | features.categorical.is_missing[0] ) + logging.info( + 'self.predictive.predictives.precomputed_solve_on_observation.ndim: %s', + self.predictive.predictives.precomputed_solve_on_observation.ndim, + ) gprm_threshold = self.predictive.predict(features) threshold = _compute_ucb_threshold( gprm_threshold, is_missing, self.ucb_coefficient @@ -457,6 +461,9 @@ class VizierGPUCBPEBandit(vza.Designer): Attributes: problem: Must be a flat study with a single metric. acquisition_optimizer: + gp_model_class: The GP model class, which must implement a `build_model` + class method that takes `ModelInput` and returns a + `StochasticProcessModel`. metadata_ns: Metadata namespace that this designer writes to. use_trust_region: Uses trust region. ard_optimizer: An optimizer object, which should return a batch of @@ -475,6 +482,10 @@ class VizierGPUCBPEBandit(vza.Designer): kw_only=True, factory=lambda: VizierGPUCBPEBandit.default_acquisition_optimizer_factory, ) + _gp_model_class: sp.ModelCoroutine[tfd.GaussianProcess] = attr.field( + kw_only=True, + factory=lambda: tuned_gp_models.VizierGaussianProcess, + ) _metadata_ns: str = attr.field( default='google_gp_ucb_pe_bandit', kw_only=True ) @@ -611,7 +622,7 @@ def _build_gp_model_and_optimize_parameters( `data.labels`. If `data.features` is empty, the returned parameters are initial values picked by the GP model. """ - coroutine = tuned_gp_models.VizierGaussianProcess.build_model( + coroutine = self._gp_model_class.build_model( # pytype: disable=attribute-error data.features ).coroutine model = sp.CoroutineWithData(coroutine, data) diff --git a/vizier/_src/jax/stochastic_process_model.py b/vizier/_src/jax/stochastic_process_model.py index fc40f454f..6530e8d96 100644 --- a/vizier/_src/jax/stochastic_process_model.py +++ b/vizier/_src/jax/stochastic_process_model.py @@ -854,7 +854,11 @@ def predict_with_aux( self.predictives.precomputed_solve_on_observation.ndim > 1 ) expand_x = len(xs.continuous.shape) == 3 and has_batched_hparams + logging.info( + 'has_batched_hparams: %s, expand_x: %s', has_batched_hparams, expand_x + ) dist = self.predictives._predict(xs, expand_batch_dim=expand_x) # pylint: disable=protected-access + logging.info('dist.batch_shape: %s', dist.batch_shape) if has_batched_hparams: return ( tfd.MixtureSameFamily(