Skip to content

Commit

Permalink
Test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705520724
  • Loading branch information
vizier-team authored and copybara-github committed Dec 12, 2024
1 parent 1fc8537 commit cdcabbb
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
13 changes: 12 additions & 1 deletion vizier/_src/algorithms/designers/gp_ucb_pe.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,10 @@ def score_with_aux(
is_missing = (
features.continuous.is_missing[0] | features.categorical.is_missing[0]
)
logging.info(
'self.predictive.predictives.precomputed_solve_on_observation.ndim: %s',
self.predictive.predictives.precomputed_solve_on_observation.ndim,
)
gprm_threshold = self.predictive.predict(features)
threshold = _compute_ucb_threshold(
gprm_threshold, is_missing, self.ucb_coefficient
Expand Down Expand Up @@ -457,6 +461,9 @@ class VizierGPUCBPEBandit(vza.Designer):
Attributes:
problem: Must be a flat study with a single metric.
acquisition_optimizer:
gp_model_class: The GP model class, which must implement a `build_model`
class method that takes `ModelInput` and returns a
`StochasticProcessModel`.
metadata_ns: Metadata namespace that this designer writes to.
use_trust_region: Uses trust region.
ard_optimizer: An optimizer object, which should return a batch of
Expand All @@ -475,6 +482,10 @@ class VizierGPUCBPEBandit(vza.Designer):
kw_only=True,
factory=lambda: VizierGPUCBPEBandit.default_acquisition_optimizer_factory,
)
_gp_model_class: sp.ModelCoroutine[tfd.GaussianProcess] = attr.field(
kw_only=True,
factory=lambda: tuned_gp_models.VizierGaussianProcess,
)
_metadata_ns: str = attr.field(
default='google_gp_ucb_pe_bandit', kw_only=True
)
Expand Down Expand Up @@ -611,7 +622,7 @@ def _build_gp_model_and_optimize_parameters(
`data.labels`. If `data.features` is empty, the returned parameters are
initial values picked by the GP model.
"""
coroutine = tuned_gp_models.VizierGaussianProcess.build_model(
coroutine = self._gp_model_class.build_model( # pytype: disable=attribute-error
data.features
).coroutine
model = sp.CoroutineWithData(coroutine, data)
Expand Down
4 changes: 4 additions & 0 deletions vizier/_src/jax/stochastic_process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,11 @@ def predict_with_aux(
self.predictives.precomputed_solve_on_observation.ndim > 1
)
expand_x = len(xs.continuous.shape) == 3 and has_batched_hparams
logging.info(
'has_batched_hparams: %s, expand_x: %s', has_batched_hparams, expand_x
)
dist = self.predictives._predict(xs, expand_batch_dim=expand_x) # pylint: disable=protected-access
logging.info('dist.batch_shape: %s', dist.batch_shape)
if has_batched_hparams:
return (
tfd.MixtureSameFamily(
Expand Down

0 comments on commit cdcabbb

Please sign in to comment.