Skip to content

Commit

Permalink
Update __array__ signatures with copy (#9529)
Browse files Browse the repository at this point in the history
* Update __array__ with copy

* Update common.py

* Update indexing.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* copy only available from np2

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Raise if copy=false

* Update groupby.py

* Update test_namedarray.py

* Update pyproject.toml

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Illviljan and pre-commit-ci[bot] authored Sep 25, 2024
1 parent 52f13d4 commit ea06c6f
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 31 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,6 @@ filterwarnings = [
"default:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning",
"default:Duplicate dimension names present:UserWarning:xarray.namedarray.core",
"default:::xarray.tests.test_strategies", # TODO: remove once we know how to deal with a changed signature in protocols
"ignore:__array__ implementation doesn't accept a copy keyword, so passing copy=False failed.",
]

log_cli_level = "INFO"
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def __complex__(self: Any) -> complex:
return complex(self.values)

def __array__(
self: Any, dtype: DTypeLike | None = None, copy: bool | None = None
self: Any, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
if not copy:
if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
Expand Down
5 changes: 4 additions & 1 deletion xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from xarray.core.dataset import calculate_dimensions

if TYPE_CHECKING:
import numpy as np
import pandas as pd

from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes
Expand Down Expand Up @@ -737,7 +738,9 @@ def __bool__(self) -> bool:
def __iter__(self) -> Iterator[str]:
return itertools.chain(self._data_variables, self._children) # type: ignore[arg-type]

def __array__(self, dtype=None, copy=None):
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
raise TypeError(
"cannot directly convert a DataTree into a "
"numpy array. Instead, create an xarray.DataArray "
Expand Down
6 changes: 5 additions & 1 deletion xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,11 @@ def values(self) -> range:
def data(self) -> range:
return range(self.size)

def __array__(self) -> np.ndarray:
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
if copy is False:
raise NotImplementedError(f"An array copy is necessary, got {copy = }.")
return np.arange(self.size)

@property
Expand Down
43 changes: 27 additions & 16 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import numpy as np
import pandas as pd
from packaging.version import Version

from xarray.core import duck_array_ops
from xarray.core.nputils import NumpyVIndexAdapter
Expand Down Expand Up @@ -505,9 +506,14 @@ class ExplicitlyIndexed:

__slots__ = ()

def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
# Leave casting to an array up to the underlying array type.
return np.asarray(self.get_duck_array(), dtype=dtype)
if Version(np.__version__) >= Version("2.0.0"):
return np.asarray(self.get_duck_array(), dtype=dtype, copy=copy)
else:
return np.asarray(self.get_duck_array(), dtype=dtype)

def get_duck_array(self):
return self.array
Expand All @@ -520,11 +526,6 @@ def get_duck_array(self):
key = BasicIndexer((slice(None),) * self.ndim)
return self[key]

def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
# This is necessary because we apply the indexing key in self.get_duck_array()
# Note this is the base class for all lazy indexing classes
return np.asarray(self.get_duck_array(), dtype=dtype)

def _oindex_get(self, indexer: OuterIndexer):
raise NotImplementedError(
f"{self.__class__.__name__}._oindex_get method should be overridden"
Expand Down Expand Up @@ -570,8 +571,13 @@ def __init__(self, array, indexer_cls: type[ExplicitIndexer] = BasicIndexer):
self.array = as_indexable(array)
self.indexer_cls = indexer_cls

def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
return np.asarray(self.get_duck_array(), dtype=dtype)
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
if Version(np.__version__) >= Version("2.0.0"):
return np.asarray(self.get_duck_array(), dtype=dtype, copy=copy)
else:
return np.asarray(self.get_duck_array(), dtype=dtype)

def get_duck_array(self):
return self.array.get_duck_array()
Expand Down Expand Up @@ -830,9 +836,6 @@ def __init__(self, array):
def _ensure_cached(self):
self.array = as_indexable(self.array.get_duck_array())

def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
return np.asarray(self.get_duck_array(), dtype=dtype)

def get_duck_array(self):
self._ensure_cached()
return self.array.get_duck_array()
Expand Down Expand Up @@ -1674,15 +1677,21 @@ def __init__(self, array: pd.Index, dtype: DTypeLike = None):
def dtype(self) -> np.dtype:
return self._dtype

def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
if dtype is None:
dtype = self.dtype
array = self.array
if isinstance(array, pd.PeriodIndex):
with suppress(AttributeError):
# this might not be public API
array = array.astype("object")
return np.asarray(array.values, dtype=dtype)

if Version(np.__version__) >= Version("2.0.0"):
return np.asarray(array.values, dtype=dtype, copy=copy)
else:
return np.asarray(array.values, dtype=dtype)

def get_duck_array(self) -> np.ndarray:
return np.asarray(self)
Expand Down Expand Up @@ -1831,15 +1840,17 @@ def __init__(
super().__init__(array, dtype)
self.level = level

def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
if dtype is None:
dtype = self.dtype
if self.level is not None:
return np.asarray(
self.array.get_level_values(self.level).values, dtype=dtype
)
else:
return super().__array__(dtype)
return super().__array__(dtype, copy=copy)

def _convert_scalar(self, item):
if isinstance(item, tuple) and self.level is not None:
Expand Down
6 changes: 3 additions & 3 deletions xarray/namedarray/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ def __getitem__(

@overload
def __array__(
self, dtype: None = ..., /, *, copy: None | bool = ...
self, dtype: None = ..., /, *, copy: bool | None = ...
) -> np.ndarray[Any, _DType_co]: ...
@overload
def __array__(
self, dtype: _DType, /, *, copy: None | bool = ...
self, dtype: _DType, /, *, copy: bool | None = ...
) -> np.ndarray[Any, _DType]: ...

def __array__(
self, dtype: _DType | None = ..., /, *, copy: None | bool = ...
self, dtype: _DType | None = ..., /, *, copy: bool | None = ...
) -> np.ndarray[Any, _DType] | np.ndarray[Any, _DType_co]: ...

# TODO: Should return the same subclass but with a new dtype generic.
Expand Down
12 changes: 9 additions & 3 deletions xarray/tests/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def __init__(self, array):
def get_duck_array(self):
raise UnexpectedDataAccess("Tried accessing data")

def __array__(self, dtype: np.typing.DTypeLike = None):
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
raise UnexpectedDataAccess("Tried accessing data")

def __getitem__(self, key):
Expand All @@ -49,7 +51,9 @@ def __init__(self, array: np.ndarray):
def __getitem__(self, key):
return type(self)(self.array[key])

def __array__(self, dtype: np.typing.DTypeLike = None):
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
raise UnexpectedDataAccess("Tried accessing data")

def __array_namespace__(self):
Expand Down Expand Up @@ -140,7 +144,9 @@ def __repr__(self: Any) -> str:
def get_duck_array(self):
raise UnexpectedDataAccess("Tried accessing data")

def __array__(self, dtype: np.typing.DTypeLike = None):
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
raise UnexpectedDataAccess("Tried accessing data")

def __getitem__(self, key) -> "ConcatenatableArray":
Expand Down
6 changes: 4 additions & 2 deletions xarray/tests/test_assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,11 @@ def dims(self):
warnings.warn("warning in test", stacklevel=2)
return super().dims

def __array__(self, dtype=None, copy=None):
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
warnings.warn("warning in test", stacklevel=2)
return super().__array__()
return super().__array__(dtype, copy=copy)

a = WarningVariable("x", [1])
b = WarningVariable("x", [2])
Expand Down
4 changes: 3 additions & 1 deletion xarray/tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,9 @@ def test_lazy_array_wont_compute() -> None:
from xarray.core.indexing import LazilyIndexedArray

class LazilyIndexedArrayNotComputable(LazilyIndexedArray):
def __array__(self, dtype=None, copy=None):
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
raise NotImplementedError("Computing this array is not possible.")

arr = LazilyIndexedArrayNotComputable(np.array([1, 2]))
Expand Down
11 changes: 9 additions & 2 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import pytest
from packaging.version import Version

from xarray.core.indexing import ExplicitlyIndexed
from xarray.namedarray._typing import (
Expand Down Expand Up @@ -53,8 +54,14 @@ def shape(self) -> _Shape:
class CustomArray(
CustomArrayBase[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co]
):
def __array__(self) -> np.ndarray[Any, np.dtype[np.generic]]:
return np.array(self.array)
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray[Any, np.dtype[np.generic]]:

if Version(np.__version__) >= Version("2.0.0"):
return np.asarray(self.array, dtype=dtype, copy=copy)
else:
return np.asarray(self.array, dtype=dtype)


class CustomArrayIndexable(
Expand Down

0 comments on commit ea06c6f

Please sign in to comment.