From e387616fa6396d144cf80ae31c1338d18904d6ea Mon Sep 17 00:00:00 2001 From: knikolaou Date: Wed, 15 May 2024 22:20:15 +0200 Subject: [PATCH] Add shape check to BaseMeasurement class --- .../measurements/test_base_measurement.py | 14 ++++++++++++++ papyrus/measurements/base_measurement.py | 12 ++++++++++++ 2 files changed, 26 insertions(+) diff --git a/CI/unit_tests/measurements/test_base_measurement.py b/CI/unit_tests/measurements/test_base_measurement.py index abb64f8..a051042 100644 --- a/CI/unit_tests/measurements/test_base_measurement.py +++ b/CI/unit_tests/measurements/test_base_measurement.py @@ -26,6 +26,7 @@ import numpy as np import pytest +from numpy.testing import assert_raises from papyrus.measurements import BaseMeasurement @@ -105,3 +106,16 @@ def test_call(self): assert np.allclose(result, a + b + c) result = measurement(a, b=b) assert np.allclose(result, a + b) + + # Test error handling for wrong size of arguments + a = np.array([1, 2, 3]) + b = np.array([[4, 5, 6], [7, 8, 9]]) + c = np.array([[10, 11, 12], [13, 14, 15]]) + with assert_raises(ValueError): + measurement(a, b, c) + with assert_raises(ValueError): + measurement(a=a, b=b, c=c) + with assert_raises(ValueError): + measurement(a, b=b, c=c) + with assert_raises(ValueError): + measurement(a, b, c=c) diff --git a/papyrus/measurements/base_measurement.py b/papyrus/measurements/base_measurement.py index e9bc8d4..aa14f35 100644 --- a/papyrus/measurements/base_measurement.py +++ b/papyrus/measurements/base_measurement.py @@ -130,9 +130,21 @@ def __call__(self, *args: np.ndarray, **kwargs: np.ndarray) -> np.ndarray: """ # Get the number of arguments num_args = len(args) + # Get the keys and values of the keyword arguments if any keys = list(kwargs.keys()) vals = list(kwargs.values()) + + # Assert whether the length of dimension 0 of all inputs is the same + try: + inputs = args + tuple(vals) + assert all([len(i) == len(inputs[0]) for i in inputs]) + except AssertionError: + raise ValueError( + f"The first dimension of all inputs to the {self.name} measurement " + "must be the same." + ) + # Zip the arguments and values z = zip(*args, *vals)