Skip to content

Commit

Permalink
Implement DataTree.isel and DataTree.sel (#9588)
Browse files Browse the repository at this point in the history
* Implement DataTree.isel and DataTree.sel

* add api docs

* fix CI failures

* add docstrings for DataTree.isel and DataTree.sel

* Add comments

* add another indexing test
  • Loading branch information
shoyer authored Oct 10, 2024
1 parent 4c3c22b commit c057d13
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 21 deletions.
15 changes: 8 additions & 7 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -761,16 +761,17 @@ Compare one ``DataTree`` object to another.
DataTree.equals
DataTree.identical

.. Indexing
.. --------
Indexing
--------

.. Index into all nodes in the subtree simultaneously.
Index into all nodes in the subtree simultaneously.

.. .. autosummary::
.. :toctree: generated/
.. autosummary::
:toctree: generated/

DataTree.isel
DataTree.sel

.. DataTree.isel
.. DataTree.sel
.. DataTree.drop_sel
.. DataTree.drop_isel
.. DataTree.head
Expand Down
188 changes: 186 additions & 2 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
from xarray.core.merge import dataset_update_method
from xarray.core.options import OPTIONS as XR_OPTS
from xarray.core.treenode import NamedNode, NodePath
from xarray.core.types import Self
from xarray.core.utils import (
Default,
FilteredMapping,
Frozen,
_default,
drop_dims_from_indexers,
either_dict_or_kwargs,
maybe_wrap_array,
)
Expand All @@ -54,7 +56,12 @@

from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes
from xarray.core.merge import CoercibleMapping, CoercibleValue
from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes
from xarray.core.types import (
ErrorOptions,
ErrorOptionsWithWarn,
NetcdfWriteModes,
ZarrWriteModes,
)

# """
# DEVELOPERS' NOTE
Expand Down Expand Up @@ -1081,7 +1088,7 @@ def from_dict(
d: Mapping[str, Dataset | DataTree | None],
/,
name: str | None = None,
) -> DataTree:
) -> Self:
"""
Create a datatree from a dictionary of data objects, organised by paths into the tree.
Expand Down Expand Up @@ -1601,3 +1608,180 @@ def to_zarr(
compute=compute,
**kwargs,
)

def _selective_indexing(
self,
func: Callable[[Dataset, Mapping[Any, Any]], Dataset],
indexers: Mapping[Any, Any],
missing_dims: ErrorOptionsWithWarn = "raise",
) -> Self:
"""Apply an indexing operation over the subtree, handling missing
dimensions and inherited coordinates gracefully by only applying
indexing at each node selectively.
"""
all_dims = set()
for node in self.subtree:
all_dims.update(node._node_dims)
indexers = drop_dims_from_indexers(indexers, all_dims, missing_dims)

result = {}
for node in self.subtree:
node_indexers = {k: v for k, v in indexers.items() if k in node.dims}
node_result = func(node.dataset, node_indexers)
# Indexing datasets corresponding to each node results in redundant
# coordinates when indexes from a parent node are inherited.
# Ideally, we would avoid creating such coordinates in the first
# place, but that would require implementing indexing operations at
# the Variable instead of the Dataset level.
for k in node_indexers:
if k not in node._node_coord_variables and k in node_result.coords:
# We remove all inherited coordinates. Coordinates
# corresponding to an index would be de-duplicated by
# _deduplicate_inherited_coordinates(), but indexing (e.g.,
# with a scalar) can also create scalar coordinates, which
# need to be explicitly removed.
del node_result.coords[k]
result[node.path] = node_result
return type(self).from_dict(result, name=self.name)

def isel(
self,
indexers: Mapping[Any, Any] | None = None,
drop: bool = False,
missing_dims: ErrorOptionsWithWarn = "raise",
**indexers_kwargs: Any,
) -> Self:
"""Returns a new data tree with each array indexed along the specified
dimension(s).
This method selects values from each array using its `__getitem__`
method, except this method does not require knowing the order of
each array's dimensions.
Parameters
----------
indexers : dict, optional
A dict with keys matching dimensions and values given
by integers, slice objects or arrays.
indexer can be a integer, slice, array-like or DataArray.
If DataArrays are passed as indexers, xarray-style indexing will be
carried out. See :ref:`indexing` for the details.
One of indexers or indexers_kwargs must be provided.
drop : bool, default: False
If ``drop=True``, drop coordinates variables indexed by integers
instead of making them scalar.
missing_dims : {"raise", "warn", "ignore"}, default: "raise"
What to do if dimensions that should be selected from are not present in the
Dataset:
- "raise": raise an exception
- "warn": raise a warning, and ignore the missing dimensions
- "ignore": ignore the missing dimensions
**indexers_kwargs : {dim: indexer, ...}, optional
The keyword arguments form of ``indexers``.
One of indexers or indexers_kwargs must be provided.
Returns
-------
obj : DataTree
A new DataTree with the same contents as this data tree, except each
array and dimension is indexed by the appropriate indexers.
If indexer DataArrays have coordinates that do not conflict with
this object, then these coordinates will be attached.
In general, each array's data will be a view of the array's data
in this dataset, unless vectorized indexing was triggered by using
an array indexer, in which case the data will be a copy.
See Also
--------
DataTree.sel
Dataset.isel
"""

def apply_indexers(dataset, node_indexers):
return dataset.isel(node_indexers, drop=drop)

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
return self._selective_indexing(
apply_indexers, indexers, missing_dims=missing_dims
)

def sel(
self,
indexers: Mapping[Any, Any] | None = None,
method: str | None = None,
tolerance: int | float | Iterable[int | float] | None = None,
drop: bool = False,
**indexers_kwargs: Any,
) -> Self:
"""Returns a new data tree with each array indexed by tick labels
along the specified dimension(s).
In contrast to `DataTree.isel`, indexers for this method should use
labels instead of integers.
Under the hood, this method is powered by using pandas's powerful Index
objects. This makes label based indexing essentially just as fast as
using integer indexing.
It also means this method uses pandas's (well documented) logic for
indexing. This means you can use string shortcuts for datetime indexes
(e.g., '2000-01' to select all values in January 2000). It also means
that slices are treated as inclusive of both the start and stop values,
unlike normal Python indexing.
Parameters
----------
indexers : dict, optional
A dict with keys matching dimensions and values given
by scalars, slices or arrays of tick labels. For dimensions with
multi-index, the indexer may also be a dict-like object with keys
matching index level names.
If DataArrays are passed as indexers, xarray-style indexing will be
carried out. See :ref:`indexing` for the details.
One of indexers or indexers_kwargs must be provided.
method : {None, "nearest", "pad", "ffill", "backfill", "bfill"}, optional
Method to use for inexact matches:
* None (default): only exact matches
* pad / ffill: propagate last valid index value forward
* backfill / bfill: propagate next valid index value backward
* nearest: use nearest valid index value
tolerance : optional
Maximum distance between original and new labels for inexact
matches. The values of the index at the matching locations must
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
drop : bool, optional
If ``drop=True``, drop coordinates variables in `indexers` instead
of making them scalar.
**indexers_kwargs : {dim: indexer, ...}, optional
The keyword arguments form of ``indexers``.
One of indexers or indexers_kwargs must be provided.
Returns
-------
obj : DataTree
A new DataTree with the same contents as this data tree, except each
variable and dimension is indexed by the appropriate indexers.
If indexer DataArrays have coordinates that do not conflict with
this object, then these coordinates will be attached.
In general, each array's data will be a view of the array's data
in this dataset, unless vectorized indexing was triggered by using
an array indexer, in which case the data will be a copy.
See Also
--------
DataTree.isel
Dataset.sel
"""

def apply_indexers(dataset, node_indexers):
# TODO: reimplement in terms of map_index_queries(), to avoid
# redundant look-ups of integer positions from labels (via indexes)
# on child nodes.
return dataset.sel(
node_indexers, method=method, tolerance=tolerance, drop=drop
)

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
return self._selective_indexing(apply_indexers, indexers)
92 changes: 80 additions & 12 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,6 @@ def test_ipython_key_completions(self, create_test_datatree):
var_keys = list(dt.variables.keys())
assert all(var_key in key_completions for var_key in var_keys)

@pytest.mark.xfail(reason="sel not implemented yet")
def test_operation_with_attrs_but_no_data(self):
# tests bug from xarray-datatree GH262
xs = xr.Dataset({"testvar": xr.DataArray(np.ones((2, 3)))})
Expand Down Expand Up @@ -1561,26 +1560,95 @@ def test_filter(self):
assert_identical(elders, expected)


class TestDSMethodInheritance:
@pytest.mark.xfail(reason="isel not implemented yet")
def test_dataset_method(self):
ds = xr.Dataset({"a": ("x", [1, 2, 3])})
dt = DataTree.from_dict(
class TestIndexing:

def test_isel_siblings(self):
tree = DataTree.from_dict(
{
"/": ds,
"/results": ds,
"/first": xr.Dataset({"a": ("x", [1, 2])}),
"/second": xr.Dataset({"b": ("x", [1, 2, 3])}),
}
)

expected = DataTree.from_dict(
{
"/": ds.isel(x=1),
"/results": ds.isel(x=1),
"/first": xr.Dataset({"a": 2}),
"/second": xr.Dataset({"b": 3}),
}
)
actual = tree.isel(x=-1)
assert_equal(actual, expected)

result = dt.isel(x=1)
assert_equal(result, expected)
expected = DataTree.from_dict(
{
"/first": xr.Dataset({"a": ("x", [1])}),
"/second": xr.Dataset({"b": ("x", [1])}),
}
)
actual = tree.isel(x=slice(1))
assert_equal(actual, expected)

actual = tree.isel(x=[0])
assert_equal(actual, expected)

actual = tree.isel(x=slice(None))
assert_equal(actual, tree)

def test_isel_inherited(self):
tree = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2]}),
"/child": xr.Dataset({"foo": ("x", [3, 4])}),
}
)

expected = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": 2}),
"/child": xr.Dataset({"foo": 4}),
}
)
actual = tree.isel(x=-1)
assert_equal(actual, expected)

expected = DataTree.from_dict(
{
"/child": xr.Dataset({"foo": 4}),
}
)
actual = tree.isel(x=-1, drop=True)
assert_equal(actual, expected)

expected = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1]}),
"/child": xr.Dataset({"foo": ("x", [3])}),
}
)
actual = tree.isel(x=[0])
assert_equal(actual, expected)

actual = tree.isel(x=slice(None))
assert_equal(actual, tree)

def test_sel(self):
tree = DataTree.from_dict(
{
"/first": xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"x": [1, 2, 3]}),
"/second": xr.Dataset({"b": ("x", [4, 5])}, coords={"x": [2, 3]}),
}
)
expected = DataTree.from_dict(
{
"/first": xr.Dataset({"a": 2}, coords={"x": 2}),
"/second": xr.Dataset({"b": 4}, coords={"x": 2}),
}
)
actual = tree.sel(x=2)
assert_equal(actual, expected)


class TestDSMethodInheritance:

@pytest.mark.xfail(reason="reduce methods not implemented yet")
def test_reduce_method(self):
Expand Down

0 comments on commit c057d13

Please sign in to comment.