diff --git a/autograd/numpy/numpy_boxes.py b/autograd/numpy/numpy_boxes.py index b9c73963..cc913449 100644 --- a/autograd/numpy/numpy_boxes.py +++ b/autograd/numpy/numpy_boxes.py @@ -18,6 +18,9 @@ class ArrayBox(Box): def __getitem__(A, idx): return A[idx] + def item(self): + return self[(0,) * len(self.shape)] + # Constants w.r.t float data just pass though shape = property(lambda self: self._value.shape) ndim = property(lambda self: self._value.ndim) diff --git a/autograd/numpy/numpy_wrapper.py b/autograd/numpy/numpy_wrapper.py index baa0aed3..5f37bdeb 100644 --- a/autograd/numpy/numpy_wrapper.py +++ b/autograd/numpy/numpy_wrapper.py @@ -65,10 +65,23 @@ def column_stack(tup): return concatenate(arrays, 1) +def _maybe_unwrap(a): + """Unwrap scalar arrays that do not contain sequences.""" + from autograd.numpy.numpy_boxes import ArrayBox + + if not a.shape: # it is a scalar array + if isinstance(a, ArrayBox): + if not isinstance(a._value, (list, tuple)): + return a.item() + else: + return a.item() + return a + + def array(A, *args, **kwargs): t = builtins.type(A) if t in (list, tuple): - return array_from_args(args, kwargs, *map(array, A)) + return array_from_args(args, kwargs, *map(_maybe_unwrap, map(array, A))) else: return _array_from_scalar_or_array(args, kwargs, A) diff --git a/tests/test_misc.py b/tests/test_misc.py index 5cffd3e5..f2945dd3 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -106,3 +106,63 @@ def test_flatten_complex(): val = 1 + 1j flat, unflatten = flatten(val) assert np.all(val == unflatten(flat)) + + +### Some tests for retrieval of objects from object-like arrays + + +def test_object_array(): + x = object() + a = np.array([x]) + assert a.item() is x + + +# Nested lists of objects and object arrays +def test_object_array_nested(): + x = object() + y = object() + a = np.array([[x], [y]]) + ab = np.array([[x, y], [y, x]]) + + assert a[0, 0] is x + assert a[1, 0] is y + + assert ab[0, 0] is x + assert ab[0, 1] is y + assert ab[1, 0] is y + assert ab[1, 1] is x + + # Test mixed nesting; we use object arrays + # for inhomogeneous shapes + b = np.array([x, [y]], dtype=object) + assert b[0] is x + assert b[1][0] is y + + +def test_zero_dim_arrays(): + # 1. numeric scalar array + x = np.array(5) + arr = np.array([x]) + assert arr[0] == 5 + + # 2. boolean scalar array + y = np.array(True) + arr2 = np.array([y]) + assert arr2[0] == True # noqa: E712 because np.True_ is not a bool + + +def test_mixed_object_arrays(): + x = object() + y = "string" + z = 42 + arr = np.array([x, y, z]) + assert arr[0] is x + assert arr[1] == "string" + assert arr[2] == 42 + + +def test_object_array_empty(): + a = np.array([]) + assert a.shape == (0,) + b = np.array([[]]) + assert b.shape == (1, 0)