Skip to content

Commit

Permalink
Fix array import stuff (#1094)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Jun 26, 2024
1 parent 3997b82 commit b0ceee9
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions dask_expr/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
from itertools import product
from typing import Union

import dask.array as da
import numpy as np
from dask import istask
from dask.array.core import slices_from_chunks
from dask.array.core import (
_should_delegate,
finalize,
graph_from_arraylike,
normalize_chunks,
slices_from_chunks,
)
from dask.base import DaskMethodsMixin, named_schedulers
from dask.core import flatten
from dask.utils import SerializableLock, cached_cumsum, cached_property, key_split
Expand All @@ -27,7 +32,7 @@ class Array(core.Expr, DaskMethodsMixin):
__dask_optimize__ = staticmethod(lambda dsk, keys, **kwargs: dsk)

def __dask_postcompute__(self):
return da.core.finalize, ()
return finalize, ()

def __dask_postpersist__(self):
state = self.lower_completely()
Expand Down Expand Up @@ -189,7 +194,7 @@ def __rfloordiv__(self, other):
def __array_ufunc__(self, numpy_ufunc, method, *inputs, **kwargs):
out = kwargs.get("out", ())
for x in inputs + out:
if da.core._should_delegate(self, x):
if _should_delegate(self, x):
return NotImplemented

if method == "__call__":
Expand Down Expand Up @@ -441,7 +446,7 @@ class FromArray(IO):

@property
def chunks(self):
return da.core.normalize_chunks(
return normalize_chunks(
self.operand("chunks"), self.array.shape, dtype=self.array.dtype
)

Expand Down Expand Up @@ -469,7 +474,7 @@ def _layer(self):
# No slicing needed
dsk = {(self._name,) + (0,) * self.array.ndim: self.array}
else:
dsk = da.core.graph_from_arraylike(
dsk = graph_from_arraylike(
self.array, chunks=self.chunks, shape=self.array.shape, name=self._name
)
return dict(dsk) # this comes as a legacy HLG for now
Expand Down Expand Up @@ -508,6 +513,9 @@ def _layer(self):


def from_array(x, chunks="auto", lock=None):
if isinstance(x, (list, tuple, memoryview) + np.ScalarType):
x = np.array(x)

return FromArray(x, chunks, lock=lock)


Expand Down

0 comments on commit b0ceee9

Please sign in to comment.