Skip to content

Commit

Permalink
Add shape check to BaseMeasurement class
Browse files Browse the repository at this point in the history
  • Loading branch information
knikolaou committed May 15, 2024
1 parent 658fe50 commit e387616
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
14 changes: 14 additions & 0 deletions CI/unit_tests/measurements/test_base_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import numpy as np
import pytest
from numpy.testing import assert_raises

from papyrus.measurements import BaseMeasurement

Expand Down Expand Up @@ -105,3 +106,16 @@ def test_call(self):
assert np.allclose(result, a + b + c)
result = measurement(a, b=b)
assert np.allclose(result, a + b)

# Test error handling for wrong size of arguments
a = np.array([1, 2, 3])
b = np.array([[4, 5, 6], [7, 8, 9]])
c = np.array([[10, 11, 12], [13, 14, 15]])
with assert_raises(ValueError):
measurement(a, b, c)
with assert_raises(ValueError):
measurement(a=a, b=b, c=c)
with assert_raises(ValueError):
measurement(a, b=b, c=c)
with assert_raises(ValueError):
measurement(a, b, c=c)
12 changes: 12 additions & 0 deletions papyrus/measurements/base_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,21 @@ def __call__(self, *args: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
"""
# Get the number of arguments
num_args = len(args)

# Get the keys and values of the keyword arguments if any
keys = list(kwargs.keys())
vals = list(kwargs.values())

# Assert whether the length of dimension 0 of all inputs is the same
try:
inputs = args + tuple(vals)
assert all([len(i) == len(inputs[0]) for i in inputs])
except AssertionError:
raise ValueError(
f"The first dimension of all inputs to the {self.name} measurement "
"must be the same."
)

# Zip the arguments and values
z = zip(*args, *vals)

Expand Down

0 comments on commit e387616

Please sign in to comment.