From 4d45a29d2b8453ef74a8198b147e7a01dbfb4f47 Mon Sep 17 00:00:00 2001 From: Alessandro Vullo Date: Tue, 13 Aug 2024 15:28:41 +0100 Subject: [PATCH] =?UTF-8?q?Handle=20the=20case=20where=20either=20query=20?= =?UTF-8?q?points=20or=20observations=20have=20unspec=E2=80=A6=20(#866)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/test_data.py | 21 +++++++++++++++++++++ trieste/data.py | 1 + 2 files changed, 22 insertions(+) diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index 78c1c6df42..349d3d442c 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -15,6 +15,7 @@ import copy +import numpy as np import numpy.testing as npt import pytest import tensorflow as tf @@ -68,6 +69,26 @@ def test_dataset_raises_for_different_leading_shapes( Dataset(query_points, observations) +def test_dataset_does_not_raise_with_unspecified_leading_dimension() -> None: + query_points = tf.zeros((2, 2)) + observations = tf.zeros((2, 1)) + + query_points_var = tf.Variable( + initial_value=np.zeros((0, 2)), + shape=(None, 2), + dtype=tf.float64, + ) + observations_var = tf.Variable( + initial_value=np.zeros((0, 1)), + shape=(None, 1), + dtype=tf.float64, + ) + + Dataset(query_points=query_points_var, observations=observations) + Dataset(query_points=query_points, observations=observations_var) + Dataset(query_points=query_points_var, observations=observations_var) + + @pytest.mark.parametrize( "query_points_shape, observations_shape", [ diff --git a/trieste/data.py b/trieste/data.py index 6c979a30e0..5897efdff7 100644 --- a/trieste/data.py +++ b/trieste/data.py @@ -53,6 +53,7 @@ def __post_init__(self) -> None: self.query_points.shape[:-1] != self.observations.shape[:-1] # can't check dynamic shapes, so trust that they're ok (if not, they'll fail later) and None not in self.query_points.shape[:-1] + and None not in self.observations.shape[:-1] ): raise ValueError( f"Leading shapes of query_points and observations must match. Got shapes"