
FIX: correct dask array handling in _calc_idxminmax #3922

Merged
merged 23 commits into pydata:master on May 13, 2020

Conversation

@kmuehlbauer (Contributor) commented Mar 31, 2020

Fixes the dask array handling for the idxmin/idxmax implementation in #3871.

  • Closes #xxxx
  • Tests added
  • Passes isort -rc . && black . && mypy . && flake8
  • Fully documented, including whats-new.rst for all changes and api.rst for new API

@pep8speaks commented Mar 31, 2020

Hello @kmuehlbauer! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found:

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2020-05-09 14:03:57 UTC

if isinstance(array.data, dask_array_type):
    res = array.map_blocks(
        lambda a, b: a[b], coordarray, indx, dtype=indx.dtype
    ).compute()
Contributor

What breaks if you don't compute?

Contributor Author

Further down, this will break:

# The dim is gone but we need to remove the corresponding coordinate.
del res.coords[dim]

# Copy attributes from argmin/argmax, if any
res.attrs = indx.attrs

Contributor Author

My dask knowledge is lacking, so unfortunately I wasn't able to come up with a better solution.

Contributor

Ah, this was quite broken. I just pushed a commit; please see if that works. Clearly we need to add some tests with dask-backed objects.

Contributor Author

@dcherian Thanks, I'll have a look first thing next morning (CET).

Member

+1 for adding dask tests

@kmuehlbauer (Contributor Author)

@dcherian This seemed to work until computation:

IndexError: Unlabeled multi-dimensional array cannot be used for indexing: array_bin

where array_bin is the dimension over which idxmax is calculated. I've tried to wrap my head around this to no avail. I'll keep trying, but I'd appreciate any hints...

Full Traceback:

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/xarray/core/dataarray.py in compute(self, **kwargs)
    839         """
    840         new = self.copy(deep=False)
--> 841         return new.load(**kwargs)
    842 
    843     def persist(self, **kwargs) -> "DataArray":

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/xarray/core/dataarray.py in load(self, **kwargs)
    813         dask.array.compute
    814         """
--> 815         ds = self._to_temp_dataset().load(**kwargs)
    816         new = self._from_temp_dataset(ds)
    817         self._variable = new._variable

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/xarray/core/dataset.py in load(self, **kwargs)
    654 
    655             # evaluate all the dask arrays simultaneously
--> 656             evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    657 
    658             for k, data in zip(lazy_data, evaluated_data):

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/dask/base.py in compute(*args, **kwargs)
    435     keys = [x.__dask_keys__() for x in collections]
    436     postcomputes = [x.__dask_postcompute__() for x in collections]
--> 437     results = schedule(dsk, keys, **kwargs)
    438     return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    439 

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
     74                 pools[thread][num_workers] = pool
     75 
---> 76     results = get_async(
     77         pool.apply_async,
     78         len(pool._pool),

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
    484                         _execute_task(task, data)  # Re-execute locally
    485                     else:
--> 486                         raise_exception(exc, tb)
    487                 res, worker_id = loads(res_info)
    488                 state["cache"][key] = res

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/dask/local.py in reraise(exc, tb)
    314     if exc.__traceback__ is not tb:
    315         raise exc.with_traceback(tb)
--> 316     raise exc
    317 
    318 

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    220     try:
    221         task, data = loads(task_info)
--> 222         result = _execute_task(task, data)
    223         id = get_id()
    224         result = dumps((result, id))

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    119         # temporaries by their reference count and can execute certain
    120         # operations in-place.
--> 121         return func(*(_execute_task(a, cache) for a in args))
    122     elif not ishashable(arg):
    123         return arg

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/dask/optimization.py in __call__(self, *args)
    980         if not len(args) == len(self.inkeys):
    981             raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 982         return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))
    983 
    984     def __reduce__(self):

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/dask/core.py in get(dsk, out, cache)
    149     for key in toposort(dsk):
    150         task = dsk[key]
--> 151         result = _execute_task(task, cache)
    152         cache[key] = result
    153     result = _execute_task(out, cache)

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
    119         # temporaries by their reference count and can execute certain
    120         # operations in-place.
--> 121         return func(*(_execute_task(a, cache) for a in args))
    122     elif not ishashable(arg):
    123         return arg

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/xarray/core/computation.py in <lambda>(ind, coord)
   1387         res = indx.copy(
   1388             data=indx.data.map_blocks(
-> 1389                 lambda ind, coord: coord[(ind,)], coordarray, dtype=coordarray.dtype
   1390             )
   1391         )

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/xarray/core/dataarray.py in __getitem__(self, key)
    642         else:
    643             # xarray-style array indexing
--> 644             return self.isel(indexers=self._item_key_to_dict(key))
    645 
    646     def __setitem__(self, key: Any, value: Any) -> None:

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/xarray/core/dataarray.py in isel(self, indexers, drop, **indexers_kwargs)
   1020         indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
   1021         if any(is_fancy_indexer(idx) for idx in indexers.values()):
-> 1022             ds = self._to_temp_dataset()._isel_fancy(indexers, drop=drop)
   1023             return self._from_temp_dataset(ds)
   1024 

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/xarray/core/dataset.py in _isel_fancy(self, indexers, drop)
   1962         # Note: we need to preserve the original indexers variable in order to merge the
   1963         # coords below
-> 1964         indexers_list = list(self._validate_indexers(indexers))
   1965 
   1966         variables: Dict[Hashable, Variable] = {}

/home/kai/miniconda/envs/wradlib_38_01/lib/python3.8/site-packages/xarray/core/dataset.py in _validate_indexers(self, indexers)
   1805 
   1806                 if v.ndim > 1:
-> 1807                     raise IndexError(
   1808                         "Unlabeled multi-dimensional array cannot be "
   1809                         "used for indexing: {}".format(k)

IndexError: Unlabeled multi-dimensional array cannot be used for indexing: array_bin

@max-sixty (Collaborator)

@kmuehlbauer what's the code that's running to generate that traceback? I can try and help in lieu of @dcherian

Thanks for giving this a go

And CC @toddrjen if they have any insight

@kmuehlbauer (Contributor Author) commented Apr 3, 2020

@max-sixty Thanks! I'd really appreciate your help. I've tracked the likely source down to a dimensionality problem. I've created a minimal example below using the current idxmax implementation from above, copying only the dask-related lines of code:

import numpy as np
import dask.array as da
import xarray as xr

# create a dask-backed 3d array
darray = da.from_array(
    np.random.RandomState(0).randn(10 * 20 * 30).reshape(10, 20, 30),
    chunks=(10, 20, 30),
    name="data_arr",
)
array = xr.DataArray(darray, dims=["x", "y", "z"])
array = array.assign_coords(
    {
        "x": (["x"], np.arange(10)),
        "y": (["y"], np.arange(20)),
        "z": (["z"], np.arange(30)),
    }
)

func = lambda x, *args, **kwargs: x.argmax(*args, **kwargs)
indx = func(array, dim="z", axis=None, keep_attrs=True, skipna=False)
coordarray = array["z"]
res = indx.copy(
    data=indx.data.map_blocks(
        lambda ind, coord: coord[(ind,)], coordarray, dtype=coordarray.dtype
    )
)
print(res)
# the following line breaks
print(res.compute())

# using only a 2d array, everything works as intended
array2d = array.sel(y=0, drop=True)
indx = func(array2d, dim="z", axis=None, keep_attrs=True, skipna=False)
coordarray = array["z"]
res = indx.copy(
    data=indx.data.map_blocks(
        lambda ind, coord: coord[(ind,)], coordarray, dtype=coordarray.dtype
    )
)
print(res)
# this works for 2d data
print(res.compute())

@keewis (Collaborator) commented Apr 3, 2020

The issue is that argmax produces an array of indexes with ndim - 1 dimensions. For 2D input data the result is 1D, and since indexing only allows 1D unlabeled arrays, your code works. For ndim > 2 we'd be trying to index with arrays of ndim > 1, so the indexing fails.

To get your 3D example (and potentially every N-D example) to work, simply fall back to the wrapped array's integer indexing (using coordarray.data):

In [2]: darray = da.from_array( 
   ...:     np.random.RandomState(0).randn(10 * 20 * 30).reshape(10, 20, 30), 
   ...:     chunks=(1, 20, 30),  # so we actually have multiple blocks
   ...:     name='data_arr' 
   ...: ) 
   ...: array = xr.DataArray( 
   ...:     darray, 
   ...:     dims=["x", "y", 'z'], 
   ...:     coords={"x": np.arange(10), "y": np.arange(20), "z": np.arange(30)}, 
   ...: ) 
   ...: array
Out[2]: 
<xarray.DataArray 'data_arr' (x: 10, y: 20, z: 30)>
dask.array<data_arr, shape=(10, 20, 30), dtype=float64, chunksize=(1, 20, 30), chunktype=numpy.ndarray>
Coordinates:
  * x        (x) int64 0 1 2 3 4 5 6 7 8 9
  * y        (y) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
  * z        (z) int64 0 1 2 3 4 5 6 7 8 9 10 ... 20 21 22 23 24 25 26 27 28 29

In [3]: indx = array.argmin(dim='z', keep_attrs=True, skipna=False) 
   ...: res = indx.copy( 
   ...:     data=indx.data.map_blocks( 
   ...:         lambda ind, coord: coord[(ind,)], 
   ...:         array.z.data, 
   ...:         dtype=array.z.dtype 
   ...:     ) 
   ...: )

In [4]: res.compute()
Out[4]: 
<xarray.DataArray 'data_arr' (x: 10, y: 20)>
array([[20,  3,  3, 11, 20, 17,  3, 27, 24,  1,  7,  4, 22, 14,  7, 18,
         5, 18,  7, 19],
       [10, 21, 25,  3, 15, 25, 28,  8, 10,  9, 13,  3, 24, 17, 19, 23,
        12, 19, 19, 28],
       [ 1, 26, 10,  9, 16,  8, 17,  8,  6, 24, 28, 13, 23, 22, 26, 13,
        28, 11,  6, 16],
       [ 6,  9, 26, 27,  1,  2, 21,  8, 10, 19, 14, 14, 20, 25, 24,  4,
        18, 12, 20,  2],
       [22,  5, 12, 17, 13, 23, 23,  8, 27, 22,  1, 19, 26, 16, 12, 17,
        19, 28,  8, 12],
       [20,  8, 25, 13,  4, 12, 23, 13, 27, 18, 15, 28, 10, 10,  0, 12,
         5, 14,  5, 27],
       [29,  0, 19,  7, 15,  2,  8,  8, 13,  4, 12,  1,  7, 19, 14,  0,
         3,  7, 12,  9],
       [ 9,  8,  4,  9, 17,  6,  7,  5, 29,  0, 15, 28, 22,  6, 24, 24,
        20,  0, 24, 23],
       [ 1, 19, 12, 20,  4, 26,  5, 13, 21, 26, 25, 10,  5,  1, 11, 21,
         6, 18,  4, 21],
       [15, 27, 13,  7, 25,  3, 14, 14, 17, 15, 11,  4, 16, 22, 22, 23,
         0, 16, 26, 13]])
Coordinates:
  * x        (x) int64 0 1 2 3 4 5 6 7 8 9
  * y        (y) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19

Note that in this case res == indx since z = np.arange(30)

@kmuehlbauer (Contributor Author)

@keewis Thanks a bunch for the explanation. Would we be on the safe side if we used your proposed N-D approach? It also works in the 2D case.

@keewis (Collaborator) commented Apr 3, 2020

👍

@kmuehlbauer (Contributor Author)

@keewis I checked with my datasets; it works like a charm. I'll try to add dask tests for this as @shoyer suggested. Where should these tests go? Currently the idxmax/idxmin tests are in test_dataarray and test_dataset:

def test_idxmin(self, x, minindex, maxindex, nanindex):

def test_idxmax(self, x, minindex, maxindex, nanindex):

expected = Dataset({"x": -10})
actual = ds.idxmin()
assert_identical(expected, actual)
expected = Dataset({"x": 10})
actual = ds.idxmax()
assert_identical(expected, actual)

Any pointers?

@keewis (Collaborator) commented Apr 3, 2020

I'd put the dask tests directly below the idxmin / idxmax tests, decorated with requires_dask.

Edit: that's easy for test_dataarray.py, but not so much for test_dataset.py. Since Dataset delegates to the DataArray methods, we might get away with not adding dask tests for Dataset, but I'm not really sure.

@kmuehlbauer (Contributor Author)

@keewis I've started by adding the dask tests to the existing idxmin/idxmax tests (test_dataarray). This works locally, but one test (where the data is datetimes) fails in the dask variant. Any immediate thoughts on this?

@keewis (Collaborator) commented Apr 3, 2020

It seems you can't use argmin with dask arrays if the dtype is datetime64 (M8):

In [24]: time = np.asarray(pd.date_range("2019-07-17", periods=10)) 
    ...: array = xr.DataArray( 
    ...:     time, 
    ...:     dims="x", 
    ...:     coords={"x": np.arange(time.size) * 4}, 
    ...: ).chunk({}) 
    ...: array                                                                                                                                                                                                                  
Out[24]: 
<xarray.DataArray (x: 10)>
dask.array<xarray-<this-array>, shape=(10,), dtype=datetime64[ns], chunksize=(10,), chunktype=numpy.ndarray>
Coordinates:
  * x        (x) int64 0 4 8 12 16 20 24 28 32 36

In [25]: array.compute().argmin(dim="x")                                                                                                                                                                                        
Out[25]: 
<xarray.DataArray ()>
array(0)

In [26]: array.argmin(dim="x")                                                                                                                                                                                                  
---------------------------------------------------------------------------
UFuncTypeError                            Traceback (most recent call last)
<ipython-input-26-e665d5b1b9b4> in <module>
----> 1 array.argmin(dim="x")

.../xarray/core/common.py in wrapped_func(self, dim, axis, skipna, **kwargs)
     44 
     45             def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs):
---> 46                 return self.reduce(func, dim, axis, skipna=skipna, **kwargs)
     47 
     48         else:

.../xarray/core/dataarray.py in reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
   2260         """
   2261 
-> 2262         var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
   2263         return self._replace_maybe_drop_dims(var)
   2264 

.../xarray/core/variable.py in reduce(self, func, dim, axis, keep_attrs, keepdims, allow_lazy, **kwargs)
   1573 
   1574         if axis is not None:
-> 1575             data = func(input_data, axis=axis, **kwargs)
   1576         else:
   1577             data = func(input_data, **kwargs)

.../xarray/core/duck_array_ops.py in f(values, axis, skipna, **kwargs)
    302 
    303         try:
--> 304             return func(values, axis=axis, **kwargs)
    305         except AttributeError:
    306             if not isinstance(values, dask_array_type):

.../xarray/core/duck_array_ops.py in f(*args, **kwargs)
     45             else:
     46                 wrapped = getattr(eager_module, name)
---> 47             return wrapped(*args, **kwargs)
     48 
     49     else:

~/.conda/envs/xarray/lib/python3.8/site-packages/dask/array/reductions.py in wrapped(x, axis, split_every, out)
   1002 
   1003     def wrapped(x, axis=None, split_every=None, out=None):
-> 1004         return arg_reduction(
   1005             x, chunk, combine, agg, axis, split_every=split_every, out=out
   1006         )

~/.conda/envs/xarray/lib/python3.8/site-packages/dask/array/reductions.py in arg_reduction(x, chunk, combine, agg, axis, split_every, out)
    980     tmp = Array(graph, name, chunks, dtype=x.dtype)
    981     dtype = np.argmin([1]).dtype
--> 982     result = _tree_reduce(tmp, agg, axis, False, dtype, split_every, combine)
    983     return handle_out(out, result)
    984 

~/.conda/envs/xarray/lib/python3.8/site-packages/dask/array/reductions.py in _tree_reduce(x, aggregate, axis, keepdims, dtype, split_every, combine, name, concatenate, reduced_meta)
    243     if concatenate:
    244         func = compose(func, partial(_concatenate2, axes=axis))
--> 245     return partial_reduce(
    246         func,
    247         x,

~/.conda/envs/xarray/lib/python3.8/site-packages/dask/array/reductions.py in partial_reduce(func, x, split_every, keepdims, dtype, name, reduced_meta)
    314     if is_arraylike(meta) and meta.ndim != len(out_chunks):
    315         if len(out_chunks) == 0:
--> 316             meta = meta.sum()
    317         else:
    318             meta = meta.reshape((0,) * len(out_chunks))

~/.conda/envs/xarray/lib/python3.8/site-packages/numpy/core/_methods.py in _sum(a, axis, dtype, out, keepdims, initial, where)
     36 def _sum(a, axis=None, dtype=None, out=None, keepdims=False,
     37          initial=_NoValue, where=True):
---> 38     return umr_sum(a, axis, dtype, out, keepdims, initial, where)
     39 
     40 def _prod(a, axis=None, dtype=None, out=None, keepdims=False,

UFuncTypeError: ufunc 'add' cannot use operands with types dtype('<M8[ns]') and dtype('<M8[ns]')

I guess that's a dask bug?

Edit: you can reproduce it without xarray, so definitely a bug in dask:

MWE with only numpy / dask.array:
In [32]: time = np.asarray(pd.date_range("2019-07-17", periods=10)) 
    ...: np.argmin(da.from_array(time))                                                                                                                                                                                         
---------------------------------------------------------------------------
UFuncTypeError                            Traceback (most recent call last)
<ipython-input-32-190cb901ff65> in <module>
      1 time = np.asarray(pd.date_range("2019-07-17", periods=10))
----> 2 np.argmin(da.from_array(time))

<__array_function__ internals> in argmin(*args, **kwargs)

~/.conda/envs/xarray/lib/python3.8/site-packages/dask/array/core.py in __array_function__(self, func, types, args, kwargs)
   1348         if da_func is func:
   1349             return handle_nonmatching_names(func, args, kwargs)
-> 1350         return da_func(*args, **kwargs)
   1351 
   1352     @property

~/.conda/envs/xarray/lib/python3.8/site-packages/dask/array/reductions.py in wrapped(x, axis, split_every, out)
   1002 
   1003     def wrapped(x, axis=None, split_every=None, out=None):
-> 1004         return arg_reduction(
   1005             x, chunk, combine, agg, axis, split_every=split_every, out=out
   1006         )

~/.conda/envs/xarray/lib/python3.8/site-packages/dask/array/reductions.py in arg_reduction(x, chunk, combine, agg, axis, split_every, out)
    980     tmp = Array(graph, name, chunks, dtype=x.dtype)
    981     dtype = np.argmin([1]).dtype
--> 982     result = _tree_reduce(tmp, agg, axis, False, dtype, split_every, combine)
    983     return handle_out(out, result)
    984 

~/.conda/envs/xarray/lib/python3.8/site-packages/dask/array/reductions.py in _tree_reduce(x, aggregate, axis, keepdims, dtype, split_every, combine, name, concatenate, reduced_meta)
    243     if concatenate:
    244         func = compose(func, partial(_concatenate2, axes=axis))
--> 245     return partial_reduce(
    246         func,
    247         x,

~/.conda/envs/xarray/lib/python3.8/site-packages/dask/array/reductions.py in partial_reduce(func, x, split_every, keepdims, dtype, name, reduced_meta)
    314     if is_arraylike(meta) and meta.ndim != len(out_chunks):
    315         if len(out_chunks) == 0:
--> 316             meta = meta.sum()
    317         else:
    318             meta = meta.reshape((0,) * len(out_chunks))

~/.conda/envs/xarray/lib/python3.8/site-packages/numpy/core/_methods.py in _sum(a, axis, dtype, out, keepdims, initial, where)
     36 def _sum(a, axis=None, dtype=None, out=None, keepdims=False,
     37          initial=_NoValue, where=True):
---> 38     return umr_sum(a, axis, dtype, out, keepdims, initial, where)
     39 
     40 def _prod(a, axis=None, dtype=None, out=None, keepdims=False,

UFuncTypeError: ufunc 'add' cannot use operands with types dtype('<M8[ns]') and dtype('<M8[ns]')

@kmuehlbauer (Contributor Author)

@keewis OK, how should I handle this? Shall we xfail these tests then?

@keewis (Collaborator) commented Apr 3, 2020

I think so? For now, xfail if use_dask and x.dtype.kind == "M"
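
A minimal sketch of how that guard might look inside the existing parametrized tests (the use_dask parameter is an assumption about the fixture layout; the other arguments mirror the test signatures quoted above):

import pytest

def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask):
    # hypothetical guard: dask argmin/argmax currently break on datetime64 input
    if use_dask and x.dtype.kind == "M":
        pytest.xfail(reason="dask argmin/argmax fail for datetime64 dtype")
    ...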

@kmuehlbauer (Contributor Author)

It seems that everything goes well, apart from the upstream-dev run.

If this is ready for merge, should I extend the idxmin/idxmax section of whats-new.rst? And how should I distribute the credit among all contributors @dcherian, @keewis, @max-sixty?

@keewis (Collaborator) commented Apr 3, 2020

I'd say add a new entry. Also, I think we're all just reviewers so adding just your name should be fine.

@kmuehlbauer (Contributor Author)

This is ready from my end for final review.

Should be merged before #3936, IMHO.

@dcherian (Contributor)

If your tests are passing now, it's likely that they're eagerly computing things to make them work. We should add the with raise_if_dask_computes() context (imported from test_dask.py) when testing with dask variables, but it looks non-trivial given how the tests are structured at the moment.
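
For illustration, a hedged sketch of what such a check could look like (the array, its shapes, and its names are made up; raise_if_dask_computes comes from xarray's test_dask.py, as mentioned above):

import numpy as np
import dask.array as da
import xarray as xr

from xarray.tests.test_dask import raise_if_dask_computes

# a hypothetical dask-backed DataArray
array = xr.DataArray(
    da.from_array(np.random.randn(10, 30), chunks=(5, 30)),
    dims=["x", "z"],
    coords={"z": np.arange(30)},
)

# idxmax should build a lazy graph without triggering any compute
with raise_if_dask_computes():
    result = array.idxmax(dim="z")

assert isinstance(result.data, da.Array)  # still lazy
result.compute()  # evaluation happens only here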

@kmuehlbauer (Contributor Author)

@dcherian Thanks for the suggestion of the dask compute context, I'll take a look tomorrow.

Nevertheless, I've debugged locally, and the res output of idxmin/idxmax does hold dask data.

Anyway, I'll revert to the former working state and leave a comment referencing this PR in the code.

@dcherian (Contributor)

I suggest updating the tests before reverting anything. This solution may work ...

@kmuehlbauer (Contributor Author)

@dcherian I've tried to apply the with raise_if_dask_computes() context, but as you already mentioned, it is quite hard to incorporate into the tests. In several places we would need to introduce an if/else clause checking for use_dask, and the computation count inside idxmin/idxmax also varies with the source data layout (because argmin/argmax is used internally).

I'm now totally unsure how to proceed from here. Any guidance is very much appreciated.

@dcherian (Contributor)

@kmuehlbauer I've pushed a commit adding the decorator to just the 2D test_idxmax. The decorator does nothing for numpy, so all the numpy tests pass, but the dask tests now fail (because it is indeed computing things). Prior to this commit the 2D tests were passing, so I guess we should go back to the map_blocks solution.

"because you can't index a NumPy array with a dask array?"

I think @shoyer is right here.

@kmuehlbauer (Contributor Author)

@dcherian Thanks for explaining the decorator a bit more; it's indeed simpler than I thought. I'll revert to the map_blocks solution. I won't have time today, though, so this will have to wait a bit.

@dcherian (Contributor)

Not a problem. Thanks for working on this!

@shoyer (Member) commented Apr 16, 2020

A simpler option than using map_blocks might be to cast the array being indexed into a Dask array if needed, using dask.array.asarray(). Then both objects would be dask arrays, so indexing should work.
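
As an illustration of the idea, a standalone sketch with made-up names and shapes (not the PR's code): once both operands are dask arrays, integer fancy indexing stays lazy.

import numpy as np
import dask.array as da

coord = da.asarray(np.arange(30))  # coordinate values, promoted to dask
ind = da.from_array(np.random.randint(0, 30, size=(10, 20)), chunks=(5, 20))

# index along the single axis with the flattened indexes, then restore the shape
res = coord[ind.ravel()].reshape(ind.shape)  # still a lazy dask array
print(res.compute().shape)  # (10, 20)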

@kmuehlbauer (Contributor Author)

@shoyer Thanks for the hint.

I'm currently experimenting with the different possibilities.

non-dask:

<xarray.DataArray (y: 3, x: 7)>
array([[  0,   1,   2,   0,  -2,  -4,   2],
       [  1,   1,   1,   1,   1,   1,   1],
       [  0,   0, -10,   5,  20,   0,   0]])
Coordinates:
  * x        (x) int64 0 4 8 12 16 20 24
  * y        (y) int64 1 0 -1
Attributes:
    attr1:    value1
    attr2:    2929

dask:

<xarray.DataArray (y: 3, x: 7)>
dask.array<xarray-<this-array>, shape=(3, 7), dtype=int64, chunksize=(3, 7), chunktype=numpy.ndarray>
Coordinates:
  * x        (x) int64 0 4 8 12 16 20 24
  * y        (y) int64 1 0 -1
Attributes:
    attr1:    value1
    attr2:    2929

The relevant code inside idxmin/idxmax:

# This will run argmin or argmax.
# indx will be dask if array is dask
indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)

# separated out for debugging:
# with the current test layout, coords will not be dask since the array's coords are not dask
coords = array[dim]

# try to make coords a dask array as per @shoyer's suggestion;
# the line below fails silently, and it cannot be forced even by trying
# something like dask.array.asarray(), which errors out with
# "Cannot assign to the .data attribute of dimension coordinate a.k.a IndexVariable 'x'"
if isinstance(indx.data, dask_array_type):
    coords = coords.chunk({})

res = coords[(indx,)]

It seems the map_blocks approach is the only one that works throughout the tests, except for one: it fails when array is set to array.astype("object") in the test fixture. Reason: the dask array gets computed within argmin/argmax.

I'll revert to map_blocks now and add the with raise_if_dask_computes() context where necessary. Any hints on the dask compute error are appreciated.

@kmuehlbauer (Contributor Author)

Error log of the compute error:
https://dev.azure.com/xarray/xarray/_build/results?buildId=2629&view=logs&j=78b48a04-306f-5a15-9ac3-dd2fdb28db5e&t=5160aa4e-6217-5012-6424-4f17180b374b&l=412

xarray/tests/test_dask.py:52: RuntimeError
_______ TestReduce2D.test_idxmax[True-x2-minindex2-maxindex2-nanindex2] ________

self = <xarray.tests.test_dataarray.TestReduce2D object at 0x7fbf71ecba58>
xarray/core/variable.py:1579: in reduce
    data = func(input_data, axis=axis, **kwargs)
xarray/core/duck_array_ops.py:304: in f
    return func(values, axis=axis, **kwargs)
xarray/core/nanops.py:104: in nanargmax
    return _nan_argminmax_object("argmax", fill_value, a, axis=axis)
xarray/core/nanops.py:57: in _nan_argminmax_object
    if (valid_count == 0).any():
/usr/share/miniconda/envs/xarray-tests/lib/python3.6/site-packages/dask/array/core.py:1375: in __bool__
    return bool(self.compute())
/usr/share/miniconda/envs/xarray-tests/lib/python3.6/site-packages/dask/base.py:175: in compute
    (result,) = compute(self, traverse=False, **kwargs)
/usr/share/miniconda/envs/xarray-tests/lib/python3.6/site-packages/dask/base.py:446: in compute
    results = schedule(dsk, keys, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <xarray.tests.test_dask.CountingScheduler object at 0x7fbf718e9160>
dsk = {('all-all-aggregate-any-getitem-invert-sum-sum-aggregate-any-aggregate-a1eb5c32ca3225c37fddde3d1aa0d2f0',): (Compose(... -4.0, 2.0],
       [-4.0, nan, 2.0, nan, -2.0, -4.0, 2.0],
       [nan, nan, nan, nan, nan, nan, nan]], dtype=object)}
keys = [[('any-aggregate-a1eb5c32ca3225c37fddde3d1aa0d2f0',)]], kwargs = {}

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            raise RuntimeError(
                "Too many computes. Total: %d > max: %d."
>               % (self.total_computes, self.max_computes)
            )
E           RuntimeError: Too many computes. Total: 1 > max: 0.

@dcherian (Contributor)

Yeah, interestingly we don't raise an error when trying to chunk IndexVariables.

I've pushed a commit where we extract the underlying numpy array, chunk that, index it, and then wrap it up in a DataArray o_O.
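
Reconstructed from that description, the approach presumably looks something like this (a sketch reusing the variable names from the snippets above, not the verbatim commit):

if isinstance(array.data, dask_array_type):
    import dask.array

    # extract the underlying numpy array of the coordinate and chunk it
    # like `array` along `dim`
    chunks = dict(zip(array.dims, array.chunks))
    dask_coord = dask.array.from_array(array[dim].data, chunks=chunks[dim])

    # lazy integer indexing with the flattened argmin/argmax result,
    # reshaped back and wrapped up in a DataArray again
    res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape))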

@dcherian (Contributor) commented Apr 18, 2020

The compute error is from here:

def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):
    """ In house nanargmin, nanargmax for object arrays. Always return integer
    type
    """
    valid_count = count(value, axis=axis)
    value = fillna(value, fill_value)
    data = _dask_or_eager_func(func)(value, axis=axis, **kwargs)

    # TODO This will evaluate dask arrays and might be costly.
    if (valid_count == 0).any():
        raise ValueError("All-NaN slice encountered")

    return data

I think we'll have to rethink the skipna conditions for dask arrays so that the compute doesn't happen.

Or figure out why we do this check in the first place. Hmm...

dcherian mentioned this pull request on May 5, 2020.
Merge of upstream/master into fix-idxminmax-dask-issues:

* upstream/master: (22 commits)
  support darkmode (pydata#4036)
  Use literal syntax instead of function calls to create the data structure (pydata#4038)
  Add template xarray object kwarg to map_blocks (pydata#3816)
  Transpose coords by default (pydata#3824)
  Remove broken test for Panel with to_pandas() (pydata#4028)
  Allow warning with cartopy in docs plotting build (pydata#4032)
  Support overriding existing variables in to_zarr() without appending (pydata#4029)
  chore: Remove unnecessary comprehension (pydata#4026)
  fix to_netcdf docstring typo (pydata#4021)
  Pint support for DataArray (pydata#3643)
  Apply blackdoc to the documentation (pydata#4012)
  ensure Variable._repr_html_ works (pydata#3973)
  Fix handling of abbreviated units like msec (pydata#3998)
  full_like: error on non-scalar fill_value (pydata#3979)
  Fix some code quality and bug-risk issues (pydata#3999)
  DOC: add pandas.DataFrame.to_xarray (pydata#3994)
  Better chunking error messages for zarr backend (pydata#3983)
  Silence sphinx warnings (pydata#3990)
  Fix distributed tests on upstream-dev (pydata#3989)
  Add multi-dimensional extrapolation example and mention different behavior of kwargs in interp (pydata#3956)
  ...
@dcherian (Contributor) commented May 9, 2020

The test fails for object arrays because we compute eagerly in nanops._nan_argminmax_object to raise an error for all-NaN slices.

To solve this we could:

  1. fix #3884 (Allow for All-NaN in argmax, argmin) so that nanargmin / nanargmax never raise an error for all-NaN slices, or
  2. figure out some clever way to raise the error at compute time rather than at graph construction time (see the sketch below).

For now, I bumped up max_computes to 1 for object arrays.
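
For option 2, one conceivable direction (a rough sketch reusing the names from _nan_argminmax_object above; not what this PR implements) would be to fold the check into the graph so it runs per block at compute time:

import dask.array as da

def _check_all_nan(data, count):
    # executed inside a dask task, i.e. at compute time
    if (count == 0).any():
        raise ValueError("All-NaN slice encountered")
    return data

# `data` and `valid_count` share the reduced shape, so a blockwise check
# lines up chunk for chunk (assuming matching chunk layouts)
data = da.map_blocks(_check_all_nan, data, valid_count, dtype=data.dtype)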

dcherian requested a review from shoyer on May 9, 2020 at 14:05.
@kmuehlbauer (Contributor Author)

Thanks @dcherian for getting back to this. Unfortunately this adventure went beyond my capabilities, but I hope to catch up on learning xarray's internals.

dcherian merged commit c73e958 into pydata:master on May 13, 2020.
dcherian added a commit that referenced this pull request on May 25, 2020:
* Added chunks='auto' option in dataset.py

* FIX: correct dask array handling in _calc_idxminmax (#3922)

* FIX: correct dask array handling in _calc_idxminmax

* FIX: remove unneeded import, reformat via black

* fix idxmax, idxmin with dask arrays

* FIX: use array[dim].data in `_calc_idxminmax` as per @keewis suggestion, attach dim name to result

* ADD: add dask tests to `idxmin`/`idxmax` dataarray tests

* FIX: add back fixture line removed by accident

* ADD: complete dask handling in `idxmin`/`idxmax` tests in test_dataarray, xfail dask tests for dtype datetime64 (M)

* ADD: add "support dask handling for idxmin/idxmax" in whats-new.rst

* MIN: reintroduce changes added by #3953

* MIN: change if-clause to use `and` instead of `&` as per review-comment

* MIN: change if-clause to use `and` instead of `&` as per review-comment

* WIP: remove dask handling entirely for debugging purposes

* Test for dask computes

* WIP: re-add dask handling (map_blocks-approach), add `with raise_if_dask_computes()` context to idxmin-tests

* Use dask indexing instead of map_blocks.

* Better chunk choice.

* Return -1 for _nan_argminmax_object if all NaNs along dim

* Revert "Return -1 for _nan_argminmax_object if all NaNs along dim"

This reverts commit 58901b9.

* Raise error for object arrays

* No error for object arrays. Instead expect 1 compute in tests.

Co-authored-by: dcherian <deepak@cherian.net>

* fix the failing flake8 CI (#4057)

* rename d and l to dim and length

* Fixed typo in rasterio docs (#4063)

* Added chunks='auto' option in dataset.py

Added changes to whats-new.rst

* Added chunks='auto' option in dataset.py

Added changes to whats-new.rst

* Error fix, catch chunks=None

* Minor reformatting + flake8 changes

* Added isinstance(chunks, (Number, str)) in dataset.py, passing

* format changes

* added auto-chunk test for dataarrays

* Assert chunk sizes equal in auto-chunk test

Co-authored-by: Kai Mühlbauer <kmuehlbauer@users.noreply.github.com>
Co-authored-by: dcherian <deepak@cherian.net>
Co-authored-by: keewis <keewis@users.noreply.github.com>
Co-authored-by: clausmichele <31700619+clausmichele@users.noreply.github.com>
Co-authored-by: Keewis <keewis@posteo.de>
kmuehlbauer deleted the fix-idxminmax-dask-issues branch on May 25, 2023.