Skip to content

Commit

Permalink
Handle the case where either query points or observations have unspec… (
Browse files Browse the repository at this point in the history
  • Loading branch information
avullo authored Aug 13, 2024
1 parent 43e6f01 commit 4d45a29
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tests/unit/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import copy

import numpy as np
import numpy.testing as npt
import pytest
import tensorflow as tf
Expand Down Expand Up @@ -68,6 +69,26 @@ def test_dataset_raises_for_different_leading_shapes(
Dataset(query_points, observations)


def test_dataset_does_not_raise_with_unspecified_leading_dimension() -> None:
query_points = tf.zeros((2, 2))
observations = tf.zeros((2, 1))

query_points_var = tf.Variable(
initial_value=np.zeros((0, 2)),
shape=(None, 2),
dtype=tf.float64,
)
observations_var = tf.Variable(
initial_value=np.zeros((0, 1)),
shape=(None, 1),
dtype=tf.float64,
)

Dataset(query_points=query_points_var, observations=observations)
Dataset(query_points=query_points, observations=observations_var)
Dataset(query_points=query_points_var, observations=observations_var)


@pytest.mark.parametrize(
"query_points_shape, observations_shape",
[
Expand Down
1 change: 1 addition & 0 deletions trieste/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __post_init__(self) -> None:
self.query_points.shape[:-1] != self.observations.shape[:-1]
# can't check dynamic shapes, so trust that they're ok (if not, they'll fail later)
and None not in self.query_points.shape[:-1]
and None not in self.observations.shape[:-1]
):
raise ValueError(
f"Leading shapes of query_points and observations must match. Got shapes"
Expand Down

0 comments on commit 4d45a29

Please sign in to comment.