Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve ndarray handling of objects, nested arrays, and mixed types #554

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions autograd/numpy/numpy_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ class ArrayBox(Box):
def __getitem__(A, idx):
return A[idx]

def item(self):
return self[(0,) * len(self.shape)]

# Constants w.r.t float data just pass though
shape = property(lambda self: self._value.shape)
ndim = property(lambda self: self._value.ndim)
Expand Down
15 changes: 14 additions & 1 deletion autograd/numpy/numpy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,23 @@ def column_stack(tup):
return concatenate(arrays, 1)


def _maybe_unwrap(a):
"""Unwrap scalar arrays that do not contain sequences."""
from autograd.numpy.numpy_boxes import ArrayBox

if not a.shape: # it is a scalar array
if isinstance(a, ArrayBox):
if not isinstance(a._value, (list, tuple)):
return a.item()
else:
return a.item()
return a


def array(A, *args, **kwargs):
t = builtins.type(A)
if t in (list, tuple):
return array_from_args(args, kwargs, *map(array, A))
return array_from_args(args, kwargs, *map(_maybe_unwrap, map(array, A)))
else:
return _array_from_scalar_or_array(args, kwargs, A)

Expand Down
60 changes: 60 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,63 @@ def test_flatten_complex():
val = 1 + 1j
flat, unflatten = flatten(val)
assert np.all(val == unflatten(flat))


### Some tests for retrieval of objects from object-like arrays


def test_object_array():
x = object()
a = np.array([x])
assert a.item() is x


# Nested lists of objects and object arrays
def test_object_array_nested():
x = object()
y = object()
a = np.array([[x], [y]])
ab = np.array([[x, y], [y, x]])

assert a[0, 0] is x
assert a[1, 0] is y

assert ab[0, 0] is x
assert ab[0, 1] is y
assert ab[1, 0] is y
assert ab[1, 1] is x

# Test mixed nesting; we use object arrays
# for inhomogeneous shapes
b = np.array([x, [y]], dtype=object)
assert b[0] is x
assert b[1][0] is y


def test_zero_dim_arrays():
# 1. numeric scalar array
x = np.array(5)
arr = np.array([x])
assert arr[0] == 5

# 2. boolean scalar array
y = np.array(True)
arr2 = np.array([y])
assert arr2[0] == True # noqa: E712 because np.True_ is not a bool


def test_mixed_object_arrays():
x = object()
y = "string"
z = 42
arr = np.array([x, y, z])
assert arr[0] is x
assert arr[1] == "string"
assert arr[2] == 42


def test_object_array_empty():
a = np.array([])
assert a.shape == (0,)
b = np.array([[]])
assert b.shape == (1, 0)
Loading