From 8e7fb5336f9dbea00bc4bc1c7cb9bc8dfe74d3eb Mon Sep 17 00:00:00 2001 From: John Omotani Date: Sun, 5 Apr 2020 14:07:11 +0100 Subject: [PATCH 01/44] DataArray.indices_min() and DataArray.indices_max() methods These return dicts of the indices of the minimum or maximum of a DataArray over several dimensions. --- xarray/core/dataarray.py | 220 ++++++++++ xarray/tests/test_dataarray.py | 777 +++++++++++++++++++++++++++++++++ 2 files changed, 997 insertions(+) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 63cba53b689..da6e9a6f815 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3726,6 +3726,226 @@ def idxmax( # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names str = property(StringAccessor) + def _unravel_argminmax( + self, + argminmax: Hashable, + dim: Union[Hashable, Sequence[Hashable], None], + keep_attrs: bool, + skipna: Optional[bool], + ) -> Dict[Hashable, "DataArray"]: + """Apply argmin or argmax over one or more dimensions, returning the result as a + dict of DataArray that can be passed directly to isel. + """ + if dim is None: + dim = self.dims + if isinstance(dim, Hashable): + dim = (dim,) + + # Get a name for the new dimension that does not conflict with any existing + # dimension + newdimname = "_unravel_argminmax_dim_0" + count = 1 + while newdimname in self.dims: + newdimname = "_unravel_argminmax_dim_{}".format(count) + count += 1 + + stacked = self.stack({newdimname: dim}) + + result_dims = stacked.dims[:-1] + reduce_shape = tuple(self.sizes[d] for d in dim) + + result_flat_indices = getattr(stacked, str(argminmax))(axis=-1, skipna=skipna) + + result_unravelled_indices = np.unravel_index(result_flat_indices, reduce_shape) + + result = { + d: DataArray(i, dims=result_dims) + for d, i in zip(dim, result_unravelled_indices) + } + + if keep_attrs: + for da in result.values(): + da.attrs = self.attrs + + return result + + def indices_min( + self, + dim: Union[Hashable, Sequence[Hashable]] = None, + keep_attrs: bool = False, + skipna: bool = None, + ) -> Dict[Hashable, "DataArray"]: + """Indices of the minimum of the DataArray over one or more dimensions. Result + returned as dict of DataArrays, which can be passed directly to isel(). + + If there are multiple minima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : hashable or sequence of hashable, optional + The dimensions over which to find the minimum. By default, finds minimum over + all dimensions. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : dict of DataArray + + See also + -------- + DataArray.argmin, DataArray.idxmin + + Examples + -------- + >>> array = xr.DataArray([0, 2, -1, 3], dims="x") + >>> array.min() + + array(-1) + >>> array.argmin() + + array(2) + >>> array.indices_min() + {'x': + array(2)} + >>> array.isel(array.indices_min()) + array(-1) + + >>> array = xr.DataArray([[[3, 2, 1], [3, 1, 2], [2, 1, 3]], + ... [[1, 3, 2], [2, -5, 1], [2, 3, 1]]], + ... dims=("x", "y", "z")) + >>> array.min(dim="x") + + array([[ 1, 2, 1], + [ 2, -5, 1], + [ 2, 1, 1]]) + Dimensions without coordinates: y, z + >>> array.argmin(dim="x") + + array([[1, 0, 0], + [1, 1, 1], + [0, 0, 1]]) + Dimensions without coordinates: y, z + >>> array.indices_min(dim="x") + {'x': + array([[1, 0, 0], + [1, 1, 1], + [0, 0, 1]]) + Dimensions without coordinates: y, z} + >>> array.min(dim=("x", "z")) + + array([ 1, -5, 1]) + Dimensions without coordinates: y + >>> array.indices_min(dim=("x", "z")) + {'x': + array([0, 1, 0]) + Dimensions without coordinates: y, 'z': + array([2, 1, 1]) + Dimensions without coordinates: y} + >>> array.isel(array.indices_min(dim=("x", "z"))) + + array([ 1, -5, 1]) + Dimensions without coordinates: y + """ + return self._unravel_argminmax("argmin", dim, keep_attrs, skipna) + + def indices_max( + self, + dim: Union[Hashable, Sequence[Hashable]] = None, + keep_attrs: bool = False, + skipna: bool = None, + ) -> Dict[Hashable, "DataArray"]: + """Indices of the maximum of the DataArray over one or more dimensions. Result + returned as dict of DataArrays, which can be passed directly to isel(). + + If there are multiple maxima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : hashable or sequence of hashable, optional + The dimensions over which to find the maximum. By default, finds maximum over + all dimensions. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : dict of DataArray + + See also + -------- + DataArray.argmax, DataArray.idxmax + + Examples + -------- + >>> array = xr.DataArray([0, 2, -1, 3], dims="x") + >>> array.max() + + array(3) + >>> array.argmax() + + array(3) + >>> array.indices_max() + {'x': + array(3)} + >>> array.isel(array.indices_max()) + + array(3) + + >>> array = xr.DataArray([[[3, 2, 1], [3, 1, 2], [2, 1, 3]], + ... [[1, 3, 2], [2, 5, 1], [2, 3, 1]]], + ... dims=("x", "y", "z")) + >>> array.max(dim="x") + + array([[3, 3, 2], + [3, 5, 2], + [2, 3, 3]]) + Dimensions without coordinates: y, z + >>> array.argmax(dim="x") + + array([[0, 1, 1], + [0, 1, 0], + [0, 1, 0]]) + Dimensions without coordinates: y, z + >>> array.indices_max(dim="x") + {'x': + array([[0, 1, 1], + [0, 1, 0], + [0, 1, 0]]) + Dimensions without coordinates: y, z} + >>> array.max(dim=("x", "z")) + + array([3, 5, 3]) + Dimensions without coordinates: y + >>> array.indices_max(dim=("x", "z")) + {'x': + array([0, 1, 0]) + Dimensions without coordinates: y, 'z': + array([0, 1, 2]) + Dimensions without coordinates: y} + >>> array.isel(array.indices_max(dim=("x", "z"))) + + array([3, 5, 3]) + Dimensions without coordinates: y + """ + return self._unravel_argminmax("argmax", dim, keep_attrs, skipna) + # priority most be higher than Variable to properly work with binary ufuncs ops.inject_all_ops_and_reduce_methods(DataArray, priority=60) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index cf31182ed30..e8b73a93c86 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4714,6 +4714,72 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): result7 = ar0.idxmax(fill_value=-1j) assert_identical(result7, expected7) + def test_indices_min(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + indarr = xr.DataArray(np.arange(x.size, dtype=np.intp), dims=["x"]) + + if np.isnan(minindex): + with pytest.raises(ValueError): + ar.argmin() + return + + expected0 = {"x": indarr[minindex]} + result0 = ar.indices_min() + for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.indices_min(keep_attrs=True) + expected1 = deepcopy(expected0) + for da in expected1.values(): + da.attrs = self.attrs + for key in expected1: + assert_identical(result1[key], expected1[key]) + + result2 = ar.indices_min(skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = {"x": indarr.isel(x=nanindex, drop=True)} + expected2["x"].attrs = {} + else: + expected2 = expected0 + + for key in expected2: + assert_identical(result2[key], expected2[key]) + + def test_indices_max(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + indarr = xr.DataArray(np.arange(x.size, dtype=np.intp), dims=["x"]) + + if np.isnan(minindex): + with pytest.raises(ValueError): + ar.argmin() + return + + expected0 = {"x": indarr[maxindex]} + result0 = ar.indices_max() + for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.indices_max(keep_attrs=True) + expected1 = deepcopy(expected0) + for da in expected1.values(): + da.attrs = self.attrs + for key in expected1: + assert_identical(result1[key], expected1[key]) + + result2 = ar.indices_max(skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = {"x": indarr.isel(x=nanindex, drop=True)} + expected2["x"].attrs = {} + else: + expected2 = expected0 + + for key in expected2: + assert_identical(result2[key], expected2[key]) + @pytest.mark.parametrize( "x, minindex, maxindex, nanindex", @@ -5157,6 +5223,717 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): result7 = ar0.idxmax(dim="x", fill_value=-5j) assert_identical(result7, expected7) + def test_indices_min(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + indarr = np.tile(np.arange(x.shape[1], dtype=np.intp), [x.shape[0], 1]) + indarr = xr.DataArray(indarr, dims=ar.dims, coords=ar.coords) + + if np.isnan(minindex).any(): + with pytest.raises(ValueError): + ar.indices_min(dim="x") + return + + expected0 = [ + indarr.isel(y=yi, drop=True).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex) + ] + expected0 = {"x": xr.concat(expected0, dim="y")} + + result0 = ar.indices_min(dim="x") + for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.indices_min(dim="x", keep_attrs=True) + expected1 = deepcopy(expected0) + expected1["x"].attrs = self.attrs + for key in expected1: + assert_identical(result1[key], expected1[key]) + + minindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(minindex, nanindex) + ] + expected2 = [ + indarr.isel(y=yi, drop=True).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex) + ] + expected2 = {"x": xr.concat(expected2, dim="y")} + expected2["x"].attrs = {} + + result2 = ar.indices_min(dim="x", skipna=False) + + for key in expected2: + assert_identical(result2[key], expected2[key]) + + result3 = ar.indices_min() + min_xind = ar.isel(expected0).argmin() + expected3 = { + "y": DataArray(min_xind), + "x": DataArray(minindex[min_xind.item()]), + } + + for key in expected3: + assert_identical(result3[key], expected3[key]) + + def test_indices_max(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + indarr = np.tile(np.arange(x.shape[1], dtype=np.intp), [x.shape[0], 1]) + indarr = xr.DataArray(indarr, dims=ar.dims, coords=ar.coords) + + if np.isnan(maxindex).any(): + with pytest.raises(ValueError): + ar.indices_max(dim="x") + return + + expected0 = [ + indarr.isel(y=yi, drop=True).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex) + ] + expected0 = {"x": xr.concat(expected0, dim="y")} + + result0 = ar.indices_max(dim="x") + for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.indices_max(dim="x", keep_attrs=True) + expected1 = deepcopy(expected0) + expected1["x"].attrs = self.attrs + for key in expected1: + assert_identical(result1[key], expected1[key]) + + maxindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(maxindex, nanindex) + ] + expected2 = [ + indarr.isel(y=yi, drop=True).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex) + ] + expected2 = {"x": xr.concat(expected2, dim="y")} + expected2["x"].attrs = {} + + result2 = ar.indices_max(dim="x", skipna=False) + + for key in expected2: + assert_identical(result2[key], expected2[key]) + + result3 = ar.indices_max() + max_xind = ar.isel(expected0).argmax() + expected3 = { + "y": DataArray(max_xind), + "x": DataArray(maxindex[max_xind.item()]), + } + + for key in expected3: + assert_identical(result3[key], expected3[key]) + + +@pytest.mark.parametrize( + "x, minindices_x, minindices_y, minindices_z, minindices_xy, " + "minindices_xz, minindices_yz, minindices_xyz, maxindices_x, " + "maxindices_y, maxindices_z, maxindices_xy, maxindices_xz, maxindices_yz, " + "maxindices_xyz, nanindices_x, nanindices_y, nanindices_z, nanindices_xy, " + "nanindices_xz, nanindices_yz, nanindices_xyz", + [ + ( + np.array( + [ + [[0, 1, 2, 0], [-2, -4, 2, 0]], + [[1, 1, 1, 1], [1, 1, 1, 1]], + [[0, 0, -10, 5], [20, 0, 0, 0]], + ] + ), + {"x": np.array([[0, 2, 2, 0], [0, 0, 2, 0]])}, + {"y": np.array([[1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 1]])}, + {"z": np.array([[0, 1], [0, 0], [2, 1]])}, + {"x": np.array([0, 0, 2, 0]), "y": np.array([1, 1, 0, 0])}, + {"x": np.array([2, 0]), "z": np.array([2, 1])}, + {"y": np.array([1, 0, 0]), "z": np.array([1, 0, 2])}, + {"x": np.array(2), "y": np.array(0), "z": np.array(2)}, + {"x": np.array([[1, 0, 0, 2], [2, 1, 0, 1]])}, + {"y": np.array([[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 1, 0]])}, + {"z": np.array([[2, 2], [0, 0], [3, 0]])}, + {"x": np.array([2, 0, 0, 2]), "y": np.array([1, 0, 0, 0])}, + {"x": np.array([2, 2]), "z": np.array([3, 0])}, + {"y": np.array([0, 0, 1]), "z": np.array([2, 0, 0])}, + {"x": np.array(2), "y": np.array(1), "z": np.array(0)}, + {"x": np.array([[None, None, None, None], [None, None, None, None]])}, + { + "y": np.array( + [ + [None, None, None, None], + [None, None, None, None], + [None, None, None, None], + ] + ) + }, + {"z": np.array([[None, None], [None, None], [None, None]])}, + { + "x": np.array([None, None, None, None]), + "y": np.array([None, None, None, None]), + }, + {"x": np.array([None, None]), "z": np.array([None, None])}, + {"y": np.array([None, None, None]), "z": np.array([None, None, None])}, + {"x": np.array(None), "y": np.array(None), "z": np.array(None)}, + ), + ( + np.array( + [ + [[2.0, 1.0, 2.0, 0.0], [-2.0, -4.0, 2.0, 0.0]], + [[-4.0, np.NaN, 2.0, np.NaN], [-2.0, -4.0, 2.0, 0.0]], + [[np.NaN] * 4, [np.NaN] * 4], + ] + ), + {"x": np.array([[1, 0, 0, 0], [0, 0, 0, 0]])}, + { + "y": np.array( + [[1, 1, 0, 0], [0, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + ) + }, + {"z": np.array([[3, 1], [0, 1], [np.NaN, np.NaN]])}, + {"x": np.array([1, 0, 0, 0]), "y": np.array([0, 1, 0, 0])}, + {"x": np.array([1, 0]), "z": np.array([0, 1])}, + {"y": np.array([1, 0, np.NaN]), "z": np.array([1, 0, np.NaN])}, + {"x": np.array(0), "y": np.array(1), "z": np.array(1)}, + {"x": np.array([[0, 0, 0, 0], [0, 0, 0, 0]])}, + { + "y": np.array( + [[0, 0, 0, 0], [1, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + ) + }, + {"z": np.array([[0, 2], [2, 2], [np.NaN, np.NaN]])}, + {"x": np.array([0, 0, 0, 0]), "y": np.array([0, 0, 0, 0])}, + {"x": np.array([0, 0]), "z": np.array([2, 2])}, + {"y": np.array([0, 0, np.NaN]), "z": np.array([0, 2, np.NaN])}, + {"x": np.array(0), "y": np.array(0), "z": np.array(0)}, + {"x": np.array([[2, 1, 2, 1], [2, 2, 2, 2]])}, + { + "y": np.array( + [[None, None, None, None], [None, 0, None, 0], [0, 0, 0, 0]] + ) + }, + {"z": np.array([[None, None], [1, None], [0, 0]])}, + {"x": np.array([2, 1, 2, 1]), "y": np.array([0, 0, 0, 0])}, + {"x": np.array([1, 2]), "z": np.array([1, 0])}, + {"y": np.array([None, 0, 0]), "z": np.array([None, 1, 0])}, + {"x": np.array(1), "y": np.array(0), "z": np.array(1)}, + ), + ( + np.array( + [ + [[2.0, 1.0, 2.0, 0.0], [-2.0, -4.0, 2.0, 0.0]], + [[-4.0, np.NaN, 2.0, np.NaN], [-2.0, -4.0, 2.0, 0.0]], + [[np.NaN] * 4, [np.NaN] * 4], + ] + ).astype("object"), + {"x": np.array([[1, 0, 0, 0], [0, 0, 0, 0]])}, + { + "y": np.array( + [[1, 1, 0, 0], [0, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + ) + }, + {"z": np.array([[3, 1], [0, 1], [np.NaN, np.NaN]])}, + {"x": np.array([1, 0, 0, 0]), "y": np.array([0, 1, 0, 0])}, + {"x": np.array([1, 0]), "z": np.array([0, 1])}, + {"y": np.array([1, 0, np.NaN]), "z": np.array([1, 0, np.NaN])}, + {"x": np.array(0), "y": np.array(1), "z": np.array(1)}, + {"x": np.array([[0, 0, 0, 0], [0, 0, 0, 0]])}, + { + "y": np.array( + [[0, 0, 0, 0], [1, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + ) + }, + {"z": np.array([[0, 2], [2, 2], [np.NaN, np.NaN]])}, + {"x": np.array([0, 0, 0, 0]), "y": np.array([0, 0, 0, 0])}, + {"x": np.array([0, 0]), "z": np.array([2, 2])}, + {"y": np.array([0, 0, np.NaN]), "z": np.array([0, 2, np.NaN])}, + {"x": np.array(0), "y": np.array(0), "z": np.array(0)}, + {"x": np.array([[2, 1, 2, 1], [2, 2, 2, 2]])}, + { + "y": np.array( + [[None, None, None, None], [None, 0, None, 0], [0, 0, 0, 0]] + ) + }, + {"z": np.array([[None, None], [1, None], [0, 0]])}, + {"x": np.array([2, 1, 2, 1]), "y": np.array([0, 0, 0, 0])}, + {"x": np.array([1, 2]), "z": np.array([1, 0])}, + {"y": np.array([None, 0, 0]), "z": np.array([None, 1, 0])}, + {"x": np.array(1), "y": np.array(0), "z": np.array(1)}, + ), + ( + np.array( + [ + [["2015-12-31", "2020-01-02"], ["2020-01-01", "2016-01-01"]], + [["2020-01-02", "2020-01-02"], ["2020-01-02", "2020-01-02"]], + [["1900-01-01", "1-02-03"], ["1900-01-02", "1-02-03"]], + ], + dtype="datetime64[ns]", + ), + {"x": np.array([[2, 2], [2, 2]])}, + {"y": np.array([[0, 1], [0, 0], [0, 0]])}, + {"z": np.array([[0, 1], [0, 0], [1, 1]])}, + {"x": np.array([2, 2]), "y": np.array([0, 0])}, + {"x": np.array([2, 2]), "z": np.array([1, 1])}, + {"y": np.array([0, 0, 0]), "z": np.array([0, 0, 1])}, + {"x": np.array(2), "y": np.array(0), "z": np.array(1)}, + {"x": np.array([[2, 2], [2, 2]])}, + {"y": np.array([[0, 1], [0, 0], [0, 0]])}, + {"z": np.array([[1, 0], [0, 0], [1, 1]])}, + {"x": np.array([2, 2]), "y": np.array([0, 0])}, + {"x": np.array([2, 1]), "z": np.array([2, 1])}, + {"y": np.array([0, 0, 0]), "z": np.array([0, 0, 1])}, + {"x": np.array(2), "y": np.array(0), "z": np.array(1)}, + {"x": np.array([[None, None], [None, None]])}, + {"y": np.array([[None, None], [None, None], [None, None]])}, + {"z": np.array([[None, None], [None, None], [None, None]])}, + {"x": np.array([None, None]), "y": np.array([None, None])}, + {"x": np.array([None, None]), "z": np.array([None, None])}, + {"y": np.array([None, None, None]), "z": np.array([None, None, None])}, + {"x": np.array(None), "y": np.array(None), "z": np.array(None)}, + ), + ], +) +class TestReduce3D(TestReduce): + def test_indices_min( + self, + x, + minindices_x, + minindices_y, + minindices_z, + minindices_xy, + minindices_xz, + minindices_yz, + minindices_xyz, + maxindices_x, + maxindices_y, + maxindices_z, + maxindices_xy, + maxindices_xz, + maxindices_yz, + maxindices_xyz, + nanindices_x, + nanindices_y, + nanindices_z, + nanindices_xy, + nanindices_xz, + nanindices_yz, + nanindices_xyz, + ): + + ar = xr.DataArray( + x, + dims=["x", "y", "z"], + coords={ + "x": np.arange(x.shape[0]) * 4, + "y": 1 - np.arange(x.shape[1]), + "z": 2 + 3 * np.arange(x.shape[2]), + }, + attrs=self.attrs, + ) + xindarr = np.tile( + np.arange(x.shape[0], dtype=np.intp)[:, np.newaxis, np.newaxis], + [1, x.shape[1], x.shape[2]], + ) + xindarr = xr.DataArray(xindarr, dims=ar.dims, coords=ar.coords) + yindarr = np.tile( + np.arange(x.shape[1], dtype=np.intp)[np.newaxis, :, np.newaxis], + [x.shape[0], 1, x.shape[2]], + ) + yindarr = xr.DataArray(yindarr, dims=ar.dims, coords=ar.coords) + zindarr = np.tile( + np.arange(x.shape[2], dtype=np.intp)[np.newaxis, np.newaxis, :], + [x.shape[0], x.shape[1], 1], + ) + zindarr = xr.DataArray(zindarr, dims=ar.dims, coords=ar.coords) + + for inds in [ + minindices_x, + minindices_y, + minindices_z, + minindices_xy, + minindices_xz, + minindices_yz, + minindices_xyz, + ]: + if np.array([np.isnan(i) for i in inds.values()]).any(): + with pytest.raises(ValueError): + ar.indices_min(dim=[d for d in inds]) + return + + result0 = ar.indices_min(dim="x") + expected0 = { + key: xr.DataArray(value, dims=("y", "z")) + for key, value in minindices_x.items() + } + for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.indices_min(dim="y") + expected1 = { + key: xr.DataArray(value, dims=("x", "z")) + for key, value in minindices_y.items() + } + for key in expected1: + assert_identical(result1[key], expected1[key]) + + result2 = ar.indices_min(dim="z") + expected2 = { + key: xr.DataArray(value, dims=("x", "y")) + for key, value in minindices_z.items() + } + for key in expected2: + assert_identical(result2[key], expected2[key]) + + result3 = ar.indices_min(dim=("x", "y")) + expected3 = { + key: xr.DataArray(value, dims=("z")) for key, value in minindices_xy.items() + } + for key in expected3: + assert_identical(result3[key], expected3[key]) + + result4 = ar.indices_min(dim=("x", "z")) + expected4 = { + key: xr.DataArray(value, dims=("y")) for key, value in minindices_xz.items() + } + for key in expected4: + assert_identical(result4[key], expected4[key]) + + result5 = ar.indices_min(dim=("y", "z")) + expected5 = { + key: xr.DataArray(value, dims=("x")) for key, value in minindices_yz.items() + } + for key in expected5: + assert_identical(result5[key], expected5[key]) + + result6 = ar.indices_min() + expected6 = {key: xr.DataArray(value) for key, value in minindices_xyz.items()} + for key in expected6: + assert_identical(result6[key], expected6[key]) + + minindices_x = { + key: xr.where( + nanindices_x[key] == None, minindices_x[key], nanindices_x[key], + ) # noqa: E711 + for key in minindices_x + } + expected7 = { + key: xr.DataArray(value, dims=("y", "z")) + for key, value in minindices_x.items() + } + + result7 = ar.indices_min(dim="x", skipna=False) + for key in expected7: + assert_identical(result7[key], expected7[key]) + + minindices_y = { + key: xr.where( + nanindices_y[key] == None, minindices_y[key], nanindices_y[key], + ) # noqa: E711 + for key in minindices_y + } + expected8 = { + key: xr.DataArray(value, dims=("x", "z")) + for key, value in minindices_y.items() + } + + result8 = ar.indices_min(dim="y", skipna=False) + for key in expected8: + assert_identical(result8[key], expected8[key]) + + minindices_z = { + key: xr.where( + nanindices_z[key] == None, minindices_z[key], nanindices_z[key], + ) # noqa: E711 + for key in minindices_z + } + expected9 = { + key: xr.DataArray(value, dims=("x", "y")) + for key, value in minindices_z.items() + } + + result9 = ar.indices_min(dim="z", skipna=False) + for key in expected9: + assert_identical(result9[key], expected9[key]) + + minindices_xy = { + key: xr.where( + nanindices_xy[key] == None, minindices_xy[key], nanindices_xy[key], + ) # noqa: E711 + for key in minindices_xy + } + expected10 = { + key: xr.DataArray(value, dims="z") for key, value in minindices_xy.items() + } + + result10 = ar.indices_min(dim=("x", "y"), skipna=False) + for key in expected10: + assert_identical(result10[key], expected10[key]) + + minindices_xz = { + key: xr.where( + nanindices_xz[key] == None, minindices_xz[key], nanindices_xz[key], + ) # noqa: E711 + for key in minindices_xz + } + expected11 = { + key: xr.DataArray(value, dims="y") for key, value in minindices_xz.items() + } + + result11 = ar.indices_min(dim=("x", "z"), skipna=False) + for key in expected11: + assert_identical(result11[key], expected11[key]) + + minindices_yz = { + key: xr.where( + nanindices_yz[key] == None, minindices_yz[key], nanindices_yz[key], + ) # noqa: E711 + for key in minindices_yz + } + expected12 = { + key: xr.DataArray(value, dims="x") for key, value in minindices_yz.items() + } + + result12 = ar.indices_min(dim=("y", "z"), skipna=False) + for key in expected12: + assert_identical(result12[key], expected12[key]) + + minindices_xyz = { + key: xr.where( + nanindices_xyz[key] == None, minindices_xyz[key], nanindices_xyz[key], + ) # noqa: E711 + for key in minindices_xyz + } + expected13 = {key: xr.DataArray(value) for key, value in minindices_xyz.items()} + + result13 = ar.indices_min(skipna=False) + for key in expected13: + assert_identical(result13[key], expected13[key]) + + def test_indices_max( + self, + x, + minindices_x, + minindices_y, + minindices_z, + minindices_xy, + minindices_xz, + minindices_yz, + minindices_xyz, + maxindices_x, + maxindices_y, + maxindices_z, + maxindices_xy, + maxindices_xz, + maxindices_yz, + maxindices_xyz, + nanindices_x, + nanindices_y, + nanindices_z, + nanindices_xy, + nanindices_xz, + nanindices_yz, + nanindices_xyz, + ): + + ar = xr.DataArray( + x, + dims=["x", "y", "z"], + coords={ + "x": np.arange(x.shape[0]) * 4, + "y": 1 - np.arange(x.shape[1]), + "z": 2 + 3 * np.arange(x.shape[2]), + }, + attrs=self.attrs, + ) + xindarr = np.tile( + np.arange(x.shape[0], dtype=np.intp)[:, np.newaxis, np.newaxis], + [1, x.shape[1], x.shape[2]], + ) + xindarr = xr.DataArray(xindarr, dims=ar.dims, coords=ar.coords) + yindarr = np.tile( + np.arange(x.shape[1], dtype=np.intp)[np.newaxis, :, np.newaxis], + [x.shape[0], 1, x.shape[2]], + ) + yindarr = xr.DataArray(yindarr, dims=ar.dims, coords=ar.coords) + zindarr = np.tile( + np.arange(x.shape[2], dtype=np.intp)[np.newaxis, np.newaxis, :], + [x.shape[0], x.shape[1], 1], + ) + zindarr = xr.DataArray(zindarr, dims=ar.dims, coords=ar.coords) + + for inds in [ + maxindices_x, + maxindices_y, + maxindices_z, + maxindices_xy, + maxindices_xz, + maxindices_yz, + maxindices_xyz, + ]: + if np.array([np.isnan(i) for i in inds.values()]).any(): + with pytest.raises(ValueError): + ar.indices_max(dim=[d for d in inds]) + return + + result0 = ar.indices_max(dim="x") + expected0 = { + key: xr.DataArray(value, dims=("y", "z")) + for key, value in maxindices_x.items() + } + for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.indices_max(dim="y") + expected1 = { + key: xr.DataArray(value, dims=("x", "z")) + for key, value in maxindices_y.items() + } + for key in expected1: + assert_identical(result1[key], expected1[key]) + + result2 = ar.indices_max(dim="z") + expected2 = { + key: xr.DataArray(value, dims=("x", "y")) + for key, value in maxindices_z.items() + } + for key in expected2: + assert_identical(result2[key], expected2[key]) + + result3 = ar.indices_max(dim=("x", "y")) + expected3 = { + key: xr.DataArray(value, dims=("z")) for key, value in maxindices_xy.items() + } + for key in expected3: + assert_identical(result3[key], expected3[key]) + + result4 = ar.indices_max(dim=("x", "z")) + expected4 = { + key: xr.DataArray(value, dims=("y")) for key, value in maxindices_xz.items() + } + for key in expected4: + assert_identical(result4[key], expected4[key]) + + result5 = ar.indices_max(dim=("y", "z")) + expected5 = { + key: xr.DataArray(value, dims=("x")) for key, value in maxindices_yz.items() + } + for key in expected5: + assert_identical(result5[key], expected5[key]) + + result6 = ar.indices_max() + expected6 = {key: xr.DataArray(value) for key, value in maxindices_xyz.items()} + for key in expected6: + assert_identical(result6[key], expected6[key]) + + maxindices_x = { + key: xr.where( + nanindices_x[key] == None, maxindices_x[key], nanindices_x[key], + ) # noqa: E711 + for key in maxindices_x + } + expected7 = { + key: xr.DataArray(value, dims=("y", "z")) + for key, value in maxindices_x.items() + } + + result7 = ar.indices_max(dim="x", skipna=False) + for key in expected7: + assert_identical(result7[key], expected7[key]) + + maxindices_y = { + key: xr.where( + nanindices_y[key] == None, maxindices_y[key], nanindices_y[key], + ) # noqa: E711 + for key in maxindices_y + } + expected8 = { + key: xr.DataArray(value, dims=("x", "z")) + for key, value in maxindices_y.items() + } + + result8 = ar.indices_max(dim="y", skipna=False) + for key in expected8: + assert_identical(result8[key], expected8[key]) + + maxindices_z = { + key: xr.where( + nanindices_z[key] == None, maxindices_z[key], nanindices_z[key], + ) # noqa: E711 + for key in maxindices_z + } + expected9 = { + key: xr.DataArray(value, dims=("x", "y")) + for key, value in maxindices_z.items() + } + + result9 = ar.indices_max(dim="z", skipna=False) + for key in expected9: + assert_identical(result9[key], expected9[key]) + + maxindices_xy = { + key: xr.where( + nanindices_xy[key] == None, maxindices_xy[key], nanindices_xy[key], + ) # noqa: E711 + for key in maxindices_xy + } + expected10 = { + key: xr.DataArray(value, dims="z") for key, value in maxindices_xy.items() + } + + result10 = ar.indices_max(dim=("x", "y"), skipna=False) + for key in expected10: + assert_identical(result10[key], expected10[key]) + + maxindices_xz = { + key: xr.where( + nanindices_xz[key] == None, maxindices_xz[key], nanindices_xz[key], + ) # noqa: E711 + for key in maxindices_xz + } + expected11 = { + key: xr.DataArray(value, dims="y") for key, value in maxindices_xz.items() + } + + result11 = ar.indices_max(dim=("x", "z"), skipna=False) + for key in expected11: + assert_identical(result11[key], expected11[key]) + + maxindices_yz = { + key: xr.where( + nanindices_yz[key] == None, maxindices_yz[key], nanindices_yz[key], + ) # noqa: E711 + for key in maxindices_yz + } + expected12 = { + key: xr.DataArray(value, dims="x") for key, value in maxindices_yz.items() + } + + result12 = ar.indices_max(dim=("y", "z"), skipna=False) + for key in expected12: + assert_identical(result12[key], expected12[key]) + + maxindices_xyz = { + key: xr.where( + nanindices_xyz[key] == None, maxindices_xyz[key], nanindices_xyz[key], + ) # noqa: E711 + for key in maxindices_xyz + } + expected13 = {key: xr.DataArray(value) for key, value in maxindices_xyz.items()} + + result13 = ar.indices_max(skipna=False) + for key in expected13: + assert_identical(result13[key], expected13[key]) + @pytest.fixture(params=[1]) def da(request): From 2b06811783c1ef3f479e65d458a445ea4f52c491 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Sun, 5 Apr 2020 19:56:40 +0100 Subject: [PATCH 02/44] Update whats-new.rst and api.rst with indices_min(), indices_max() --- doc/api.rst | 2 ++ doc/whats-new.rst | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index b37c84e7a81..c96a5079694 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -366,6 +366,8 @@ Computation :py:attr:`~DataArray.argmin` :py:attr:`~DataArray.idxmax` :py:attr:`~DataArray.idxmin` +:py:attr:`~DataArray.indices_max` +:py:attr:`~DataArray.indices_min` :py:attr:`~DataArray.max` :py:attr:`~DataArray.mean` :py:attr:`~DataArray.median` diff --git a/doc/whats-new.rst b/doc/whats-new.rst index dac63600679..3ef2b9e40c8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,6 +29,11 @@ Breaking changes New Features ~~~~~~~~~~~~ +- Added :py:meth:`DataArray.indices_min` and :py:meth:`DataArray.indices_max` + to get a dict of the indices for each dimension of the minimum or maximum of + a DataArray. (:pull:`3936`) + By `John Omotani `_, thanks to `Keisuke Fujii + `_ for work in :pull:`1469`. - Added :py:meth:`DataArray.polyfit` and :py:func:`xarray.polyval` for fitting polynomials. (:issue:`3349`) By `Pascal Bourgault `_. - Control over attributes of result in :py:func:`merge`, :py:func:`concat`, From f6a966c4c4c669ee1fb24d10dff508fb6dd9d9f8 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Sun, 5 Apr 2020 20:43:34 +0100 Subject: [PATCH 03/44] Fix type checking in DataArray._unravel_argminmax() --- xarray/core/dataarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index da6e9a6f815..2671df6609d 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3738,7 +3738,7 @@ def _unravel_argminmax( """ if dim is None: dim = self.dims - if isinstance(dim, Hashable): + if not isinstance(dim, Sequence) or isinstance(dim, str): dim = (dim,) # Get a name for the new dimension that does not conflict with any existing From 4395e7a6220c131cefe0950f111b9633ebb53d68 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Sun, 5 Apr 2020 20:44:05 +0100 Subject: [PATCH 04/44] Fix expected results for TestReduce3D.test_indices_max() --- xarray/tests/test_dataarray.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index e8b73a93c86..4da8819b600 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -5486,13 +5486,13 @@ def test_indices_max(self, x, minindex, maxindex, nanindex): {"x": np.array([2, 2]), "z": np.array([1, 1])}, {"y": np.array([0, 0, 0]), "z": np.array([0, 0, 1])}, {"x": np.array(2), "y": np.array(0), "z": np.array(1)}, - {"x": np.array([[2, 2], [2, 2]])}, - {"y": np.array([[0, 1], [0, 0], [0, 0]])}, - {"z": np.array([[1, 0], [0, 0], [1, 1]])}, - {"x": np.array([2, 2]), "y": np.array([0, 0])}, - {"x": np.array([2, 1]), "z": np.array([2, 1])}, - {"y": np.array([0, 0, 0]), "z": np.array([0, 0, 1])}, - {"x": np.array(2), "y": np.array(0), "z": np.array(1)}, + {"x": np.array([[1, 0], [1, 1]])}, + {"y": np.array([[1, 0], [0, 0], [1, 0]])}, + {"z": np.array([[1, 0], [0, 0], [0, 0]])}, + {"x": np.array([1, 0]), "y": np.array([0, 0])}, + {"x": np.array([0, 1]), "z": np.array([1, 0])}, + {"y": np.array([0, 0, 1]), "z": np.array([1, 0, 0])}, + {"x": np.array(0), "y": np.array(0), "z": np.array(1)}, {"x": np.array([[None, None], [None, None]])}, {"y": np.array([[None, None], [None, None], [None, None]])}, {"z": np.array([[None, None], [None, None], [None, None]])}, From deee3f8765b923fa04ca4603fa3258a377dccbbb Mon Sep 17 00:00:00 2001 From: John Omotani Date: Mon, 6 Apr 2020 22:07:56 +0100 Subject: [PATCH 05/44] Respect global default for keep_attrs --- xarray/core/dataarray.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 2671df6609d..ee3dc749a20 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -55,7 +55,7 @@ from .indexes import Indexes, default_indexes, propagate_indexes from .indexing import is_fancy_indexer from .merge import PANDAS_TYPES, _extract_indexes_from_coords -from .options import OPTIONS +from .options import OPTIONS, _get_keep_attrs from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs from .variable import ( IndexVariable, @@ -3763,6 +3763,8 @@ def _unravel_argminmax( for d, i in zip(dim, result_unravelled_indices) } + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) if keep_attrs: for da in result.values(): da.attrs = self.attrs @@ -3772,7 +3774,7 @@ def _unravel_argminmax( def indices_min( self, dim: Union[Hashable, Sequence[Hashable]] = None, - keep_attrs: bool = False, + keep_attrs: bool = None, skipna: bool = None, ) -> Dict[Hashable, "DataArray"]: """Indices of the minimum of the DataArray over one or more dimensions. Result @@ -3860,7 +3862,7 @@ def indices_min( def indices_max( self, dim: Union[Hashable, Sequence[Hashable]] = None, - keep_attrs: bool = False, + keep_attrs: bool = None, skipna: bool = None, ) -> Dict[Hashable, "DataArray"]: """Indices of the maximum of the DataArray over one or more dimensions. Result From be8b26cae665c7229134b482c184b503eb79b5ee Mon Sep 17 00:00:00 2001 From: John Omotani Date: Mon, 6 Apr 2020 22:17:43 +0100 Subject: [PATCH 06/44] Merge behaviour of indices_min/indices_max into argmin/argmax When argmin or argmax are called with a sequence for 'dim', they now return a dict with the indices for each dimension in dim. --- doc/api.rst | 2 - doc/whats-new.rst | 8 ++- xarray/core/dataarray.py | 63 ++++++++++++++------ xarray/core/duck_array_ops.py | 4 +- xarray/core/ops.py | 4 +- xarray/core/rolling.py | 4 +- xarray/tests/test_dataarray.py | 104 ++++++++++++++++----------------- 7 files changed, 108 insertions(+), 81 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index c96a5079694..b37c84e7a81 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -366,8 +366,6 @@ Computation :py:attr:`~DataArray.argmin` :py:attr:`~DataArray.idxmax` :py:attr:`~DataArray.idxmin` -:py:attr:`~DataArray.indices_max` -:py:attr:`~DataArray.indices_min` :py:attr:`~DataArray.max` :py:attr:`~DataArray.mean` :py:attr:`~DataArray.median` diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3ef2b9e40c8..25e635e7809 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,9 +29,11 @@ Breaking changes New Features ~~~~~~~~~~~~ -- Added :py:meth:`DataArray.indices_min` and :py:meth:`DataArray.indices_max` - to get a dict of the indices for each dimension of the minimum or maximum of - a DataArray. (:pull:`3936`) +- :py:meth:`DataArray.argmin` and :py:meth:`DataArray.argmax` now support + sequences of 'dim' arguments, and if a sequence is passed return a dict + (which can be passed to :py:meth:`isel` to get the value of the minimum) of + the indices for each dimension of the minimum or maximum of a DataArray. + (:pull:`3936`) By `John Omotani `_, thanks to `Keisuke Fujii `_ for work in :pull:`1469`. - Added :py:meth:`DataArray.polyfit` and :py:func:`xarray.polyval` for fitting polynomials. (:issue:`3349`) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ee3dc749a20..1507c4505e7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3730,16 +3730,35 @@ def _unravel_argminmax( self, argminmax: Hashable, dim: Union[Hashable, Sequence[Hashable], None], - keep_attrs: bool, + axis: Union[int, None], + keep_attrs: Optional[bool], skipna: Optional[bool], ) -> Dict[Hashable, "DataArray"]: """Apply argmin or argmax over one or more dimensions, returning the result as a dict of DataArray that can be passed directly to isel. """ - if dim is None: + if dim is None and axis is None: + warnings.warn( + "Behaviour of argmin/argmax with neither dim nor axis argument will " + "change to return a dict of indices of each dimension. To get a " + "single, flat index, please use np.argmin(da) or np.argmax(da) instead " + "of da.argmin() or da.argmax().", + DeprecationWarning, + ) + if dim is ...: + # In future, should do this also when (dim is None and axis is None) dim = self.dims - if not isinstance(dim, Sequence) or isinstance(dim, str): - dim = (dim,) + if ( + dim is None + or axis is not None + or not isinstance(dim, Sequence) + or isinstance(dim, str) + ): + # Return int index if single dimension is passed, and is not part of a + # sequence + return getattr(self, str(argminmax))( + dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna + ) # Get a name for the new dimension that does not conflict with any existing # dimension @@ -3771,9 +3790,10 @@ def _unravel_argminmax( return result - def indices_min( + def argmin( self, dim: Union[Hashable, Sequence[Hashable]] = None, + axis: Union[int, None] = None, keep_attrs: bool = None, skipna: bool = None, ) -> Dict[Hashable, "DataArray"]: @@ -3788,6 +3808,9 @@ def indices_min( dim : hashable or sequence of hashable, optional The dimensions over which to find the minimum. By default, finds minimum over all dimensions. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -3815,10 +3838,10 @@ def indices_min( >>> array.argmin() array(2) - >>> array.indices_min() + >>> array.argmin(...) {'x': array(2)} - >>> array.isel(array.indices_min()) + >>> array.isel(array.argmin(...)) array(-1) >>> array = xr.DataArray([[[3, 2, 1], [3, 1, 2], [2, 1, 3]], @@ -3836,7 +3859,7 @@ def indices_min( [1, 1, 1], [0, 0, 1]]) Dimensions without coordinates: y, z - >>> array.indices_min(dim="x") + >>> array.argmin(dim=["x"]) {'x': array([[1, 0, 0], [1, 1, 1], @@ -3846,22 +3869,23 @@ def indices_min( array([ 1, -5, 1]) Dimensions without coordinates: y - >>> array.indices_min(dim=("x", "z")) + >>> array.argmin(dim=["x", "z"]) {'x': array([0, 1, 0]) Dimensions without coordinates: y, 'z': array([2, 1, 1]) Dimensions without coordinates: y} - >>> array.isel(array.indices_min(dim=("x", "z"))) + >>> array.isel(array.argmin(dim=["x", "z"])) array([ 1, -5, 1]) Dimensions without coordinates: y """ - return self._unravel_argminmax("argmin", dim, keep_attrs, skipna) + return self._unravel_argminmax("_argmin_base", dim, axis, keep_attrs, skipna) - def indices_max( + def argmax( self, dim: Union[Hashable, Sequence[Hashable]] = None, + axis: Union[int, None] = None, keep_attrs: bool = None, skipna: bool = None, ) -> Dict[Hashable, "DataArray"]: @@ -3876,6 +3900,9 @@ def indices_max( dim : hashable or sequence of hashable, optional The dimensions over which to find the maximum. By default, finds maximum over all dimensions. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -3903,10 +3930,10 @@ def indices_max( >>> array.argmax() array(3) - >>> array.indices_max() + >>> array.argmax(...) {'x': array(3)} - >>> array.isel(array.indices_max()) + >>> array.isel(array.argmax(...)) array(3) @@ -3925,7 +3952,7 @@ def indices_max( [0, 1, 0], [0, 1, 0]]) Dimensions without coordinates: y, z - >>> array.indices_max(dim="x") + >>> array.argmax(dim=["x"]) {'x': array([[0, 1, 1], [0, 1, 0], @@ -3935,18 +3962,18 @@ def indices_max( array([3, 5, 3]) Dimensions without coordinates: y - >>> array.indices_max(dim=("x", "z")) + >>> array.argmax(dim=["x", "z"]) {'x': array([0, 1, 0]) Dimensions without coordinates: y, 'z': array([0, 1, 2]) Dimensions without coordinates: y} - >>> array.isel(array.indices_max(dim=("x", "z"))) + >>> array.isel(array.argmax(dim=["x", "z"])) array([3, 5, 3]) Dimensions without coordinates: y """ - return self._unravel_argminmax("argmax", dim, keep_attrs, skipna) + return self._unravel_argminmax("_argmax_base", dim, axis, keep_attrs, skipna) # priority most be higher than Variable to properly work with binary ufuncs diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 1340b456cf2..19dd180cb33 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -319,8 +319,8 @@ def f(values, axis=None, skipna=None, **kwargs): # Attributes `numeric_only`, `available_min_count` is used for docs. # See ops.inject_reduce_methods -argmax = _create_nan_agg_method("argmax", coerce_strings=True) -argmin = _create_nan_agg_method("argmin", coerce_strings=True) +_argmax_base = _create_nan_agg_method("argmax", coerce_strings=True) +_argmin_base = _create_nan_agg_method("argmin", coerce_strings=True) max = _create_nan_agg_method("max", coerce_strings=True) min = _create_nan_agg_method("min", coerce_strings=True) sum = _create_nan_agg_method("sum") diff --git a/xarray/core/ops.py b/xarray/core/ops.py index b789f93b4f1..d192f0216d2 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -47,8 +47,8 @@ # methods which remove an axis REDUCE_METHODS = ["all", "any"] NAN_REDUCE_METHODS = [ - "argmax", - "argmin", + "_argmax_base", + "_argmin_base", "max", "min", "mean", diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index ecba5307680..dec9fe513a9 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -130,8 +130,8 @@ def method(self, **kwargs): method.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name=name) return method - argmax = _reduce_method("argmax") - argmin = _reduce_method("argmin") + _argmax_base = _reduce_method("_argmax_base") + _argmin_base = _reduce_method("_argmin_base") max = _reduce_method("max") min = _reduce_method("min") mean = _reduce_method("mean") diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 4da8819b600..7dc6cf4885b 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4714,7 +4714,7 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): result7 = ar0.idxmax(fill_value=-1j) assert_identical(result7, expected7) - def test_indices_min(self, x, minindex, maxindex, nanindex): + def test_argmin_dim(self, x, minindex, maxindex, nanindex): ar = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs ) @@ -4726,18 +4726,18 @@ def test_indices_min(self, x, minindex, maxindex, nanindex): return expected0 = {"x": indarr[minindex]} - result0 = ar.indices_min() + result0 = ar.argmin(...) for key in expected0: assert_identical(result0[key], expected0[key]) - result1 = ar.indices_min(keep_attrs=True) + result1 = ar.argmin(..., keep_attrs=True) expected1 = deepcopy(expected0) for da in expected1.values(): da.attrs = self.attrs for key in expected1: assert_identical(result1[key], expected1[key]) - result2 = ar.indices_min(skipna=False) + result2 = ar.argmin(..., skipna=False) if nanindex is not None and ar.dtype.kind != "O": expected2 = {"x": indarr.isel(x=nanindex, drop=True)} expected2["x"].attrs = {} @@ -4747,7 +4747,7 @@ def test_indices_min(self, x, minindex, maxindex, nanindex): for key in expected2: assert_identical(result2[key], expected2[key]) - def test_indices_max(self, x, minindex, maxindex, nanindex): + def test_argmax_dim(self, x, minindex, maxindex, nanindex): ar = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs ) @@ -4759,18 +4759,18 @@ def test_indices_max(self, x, minindex, maxindex, nanindex): return expected0 = {"x": indarr[maxindex]} - result0 = ar.indices_max() + result0 = ar.argmax(...) for key in expected0: assert_identical(result0[key], expected0[key]) - result1 = ar.indices_max(keep_attrs=True) + result1 = ar.argmax(..., keep_attrs=True) expected1 = deepcopy(expected0) for da in expected1.values(): da.attrs = self.attrs for key in expected1: assert_identical(result1[key], expected1[key]) - result2 = ar.indices_max(skipna=False) + result2 = ar.argmax(..., skipna=False) if nanindex is not None and ar.dtype.kind != "O": expected2 = {"x": indarr.isel(x=nanindex, drop=True)} expected2["x"].attrs = {} @@ -5223,7 +5223,7 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): result7 = ar0.idxmax(dim="x", fill_value=-5j) assert_identical(result7, expected7) - def test_indices_min(self, x, minindex, maxindex, nanindex): + def test_argmin_dim(self, x, minindex, maxindex, nanindex): ar = xr.DataArray( x, dims=["y", "x"], @@ -5235,7 +5235,7 @@ def test_indices_min(self, x, minindex, maxindex, nanindex): if np.isnan(minindex).any(): with pytest.raises(ValueError): - ar.indices_min(dim="x") + ar.argmin(dim="x") return expected0 = [ @@ -5244,11 +5244,11 @@ def test_indices_min(self, x, minindex, maxindex, nanindex): ] expected0 = {"x": xr.concat(expected0, dim="y")} - result0 = ar.indices_min(dim="x") + result0 = ar.argmin(dim=["x"]) for key in expected0: assert_identical(result0[key], expected0[key]) - result1 = ar.indices_min(dim="x", keep_attrs=True) + result1 = ar.argmin(dim=["x"], keep_attrs=True) expected1 = deepcopy(expected0) expected1["x"].attrs = self.attrs for key in expected1: @@ -5265,12 +5265,12 @@ def test_indices_min(self, x, minindex, maxindex, nanindex): expected2 = {"x": xr.concat(expected2, dim="y")} expected2["x"].attrs = {} - result2 = ar.indices_min(dim="x", skipna=False) + result2 = ar.argmin(dim=["x"], skipna=False) for key in expected2: assert_identical(result2[key], expected2[key]) - result3 = ar.indices_min() + result3 = ar.argmin(...) min_xind = ar.isel(expected0).argmin() expected3 = { "y": DataArray(min_xind), @@ -5280,7 +5280,7 @@ def test_indices_min(self, x, minindex, maxindex, nanindex): for key in expected3: assert_identical(result3[key], expected3[key]) - def test_indices_max(self, x, minindex, maxindex, nanindex): + def test_argmax_dim(self, x, minindex, maxindex, nanindex): ar = xr.DataArray( x, dims=["y", "x"], @@ -5292,7 +5292,7 @@ def test_indices_max(self, x, minindex, maxindex, nanindex): if np.isnan(maxindex).any(): with pytest.raises(ValueError): - ar.indices_max(dim="x") + ar.argmax(dim="x") return expected0 = [ @@ -5301,11 +5301,11 @@ def test_indices_max(self, x, minindex, maxindex, nanindex): ] expected0 = {"x": xr.concat(expected0, dim="y")} - result0 = ar.indices_max(dim="x") + result0 = ar.argmax(dim=["x"]) for key in expected0: assert_identical(result0[key], expected0[key]) - result1 = ar.indices_max(dim="x", keep_attrs=True) + result1 = ar.argmax(dim=["x"], keep_attrs=True) expected1 = deepcopy(expected0) expected1["x"].attrs = self.attrs for key in expected1: @@ -5322,12 +5322,12 @@ def test_indices_max(self, x, minindex, maxindex, nanindex): expected2 = {"x": xr.concat(expected2, dim="y")} expected2["x"].attrs = {} - result2 = ar.indices_max(dim="x", skipna=False) + result2 = ar.argmax(dim=["x"], skipna=False) for key in expected2: assert_identical(result2[key], expected2[key]) - result3 = ar.indices_max() + result3 = ar.argmax(...) max_xind = ar.isel(expected0).argmax() expected3 = { "y": DataArray(max_xind), @@ -5504,7 +5504,7 @@ def test_indices_max(self, x, minindex, maxindex, nanindex): ], ) class TestReduce3D(TestReduce): - def test_indices_min( + def test_argmin_dim( self, x, minindices_x, @@ -5567,10 +5567,10 @@ def test_indices_min( ]: if np.array([np.isnan(i) for i in inds.values()]).any(): with pytest.raises(ValueError): - ar.indices_min(dim=[d for d in inds]) + ar.argmin(dim=[d for d in inds]) return - result0 = ar.indices_min(dim="x") + result0 = ar.argmin(dim=["x"]) expected0 = { key: xr.DataArray(value, dims=("y", "z")) for key, value in minindices_x.items() @@ -5578,7 +5578,7 @@ def test_indices_min( for key in expected0: assert_identical(result0[key], expected0[key]) - result1 = ar.indices_min(dim="y") + result1 = ar.argmin(dim=["y"]) expected1 = { key: xr.DataArray(value, dims=("x", "z")) for key, value in minindices_y.items() @@ -5586,7 +5586,7 @@ def test_indices_min( for key in expected1: assert_identical(result1[key], expected1[key]) - result2 = ar.indices_min(dim="z") + result2 = ar.argmin(dim=["z"]) expected2 = { key: xr.DataArray(value, dims=("x", "y")) for key, value in minindices_z.items() @@ -5594,28 +5594,28 @@ def test_indices_min( for key in expected2: assert_identical(result2[key], expected2[key]) - result3 = ar.indices_min(dim=("x", "y")) + result3 = ar.argmin(dim=("x", "y")) expected3 = { key: xr.DataArray(value, dims=("z")) for key, value in minindices_xy.items() } for key in expected3: assert_identical(result3[key], expected3[key]) - result4 = ar.indices_min(dim=("x", "z")) + result4 = ar.argmin(dim=("x", "z")) expected4 = { key: xr.DataArray(value, dims=("y")) for key, value in minindices_xz.items() } for key in expected4: assert_identical(result4[key], expected4[key]) - result5 = ar.indices_min(dim=("y", "z")) + result5 = ar.argmin(dim=("y", "z")) expected5 = { key: xr.DataArray(value, dims=("x")) for key, value in minindices_yz.items() } for key in expected5: assert_identical(result5[key], expected5[key]) - result6 = ar.indices_min() + result6 = ar.argmin(...) expected6 = {key: xr.DataArray(value) for key, value in minindices_xyz.items()} for key in expected6: assert_identical(result6[key], expected6[key]) @@ -5631,7 +5631,7 @@ def test_indices_min( for key, value in minindices_x.items() } - result7 = ar.indices_min(dim="x", skipna=False) + result7 = ar.argmin(dim=["x"], skipna=False) for key in expected7: assert_identical(result7[key], expected7[key]) @@ -5646,7 +5646,7 @@ def test_indices_min( for key, value in minindices_y.items() } - result8 = ar.indices_min(dim="y", skipna=False) + result8 = ar.argmin(dim=["y"], skipna=False) for key in expected8: assert_identical(result8[key], expected8[key]) @@ -5661,7 +5661,7 @@ def test_indices_min( for key, value in minindices_z.items() } - result9 = ar.indices_min(dim="z", skipna=False) + result9 = ar.argmin(dim=["z"], skipna=False) for key in expected9: assert_identical(result9[key], expected9[key]) @@ -5675,7 +5675,7 @@ def test_indices_min( key: xr.DataArray(value, dims="z") for key, value in minindices_xy.items() } - result10 = ar.indices_min(dim=("x", "y"), skipna=False) + result10 = ar.argmin(dim=("x", "y"), skipna=False) for key in expected10: assert_identical(result10[key], expected10[key]) @@ -5689,7 +5689,7 @@ def test_indices_min( key: xr.DataArray(value, dims="y") for key, value in minindices_xz.items() } - result11 = ar.indices_min(dim=("x", "z"), skipna=False) + result11 = ar.argmin(dim=("x", "z"), skipna=False) for key in expected11: assert_identical(result11[key], expected11[key]) @@ -5703,7 +5703,7 @@ def test_indices_min( key: xr.DataArray(value, dims="x") for key, value in minindices_yz.items() } - result12 = ar.indices_min(dim=("y", "z"), skipna=False) + result12 = ar.argmin(dim=("y", "z"), skipna=False) for key in expected12: assert_identical(result12[key], expected12[key]) @@ -5715,11 +5715,11 @@ def test_indices_min( } expected13 = {key: xr.DataArray(value) for key, value in minindices_xyz.items()} - result13 = ar.indices_min(skipna=False) + result13 = ar.argmin(..., skipna=False) for key in expected13: assert_identical(result13[key], expected13[key]) - def test_indices_max( + def test_argmax_dim( self, x, minindices_x, @@ -5782,10 +5782,10 @@ def test_indices_max( ]: if np.array([np.isnan(i) for i in inds.values()]).any(): with pytest.raises(ValueError): - ar.indices_max(dim=[d for d in inds]) + ar.argmax(dim=[d for d in inds]) return - result0 = ar.indices_max(dim="x") + result0 = ar.argmax(dim=["x"]) expected0 = { key: xr.DataArray(value, dims=("y", "z")) for key, value in maxindices_x.items() @@ -5793,7 +5793,7 @@ def test_indices_max( for key in expected0: assert_identical(result0[key], expected0[key]) - result1 = ar.indices_max(dim="y") + result1 = ar.argmax(dim=["y"]) expected1 = { key: xr.DataArray(value, dims=("x", "z")) for key, value in maxindices_y.items() @@ -5801,7 +5801,7 @@ def test_indices_max( for key in expected1: assert_identical(result1[key], expected1[key]) - result2 = ar.indices_max(dim="z") + result2 = ar.argmax(dim=["z"]) expected2 = { key: xr.DataArray(value, dims=("x", "y")) for key, value in maxindices_z.items() @@ -5809,28 +5809,28 @@ def test_indices_max( for key in expected2: assert_identical(result2[key], expected2[key]) - result3 = ar.indices_max(dim=("x", "y")) + result3 = ar.argmax(dim=("x", "y")) expected3 = { key: xr.DataArray(value, dims=("z")) for key, value in maxindices_xy.items() } for key in expected3: assert_identical(result3[key], expected3[key]) - result4 = ar.indices_max(dim=("x", "z")) + result4 = ar.argmax(dim=("x", "z")) expected4 = { key: xr.DataArray(value, dims=("y")) for key, value in maxindices_xz.items() } for key in expected4: assert_identical(result4[key], expected4[key]) - result5 = ar.indices_max(dim=("y", "z")) + result5 = ar.argmax(dim=("y", "z")) expected5 = { key: xr.DataArray(value, dims=("x")) for key, value in maxindices_yz.items() } for key in expected5: assert_identical(result5[key], expected5[key]) - result6 = ar.indices_max() + result6 = ar.argmax(...) expected6 = {key: xr.DataArray(value) for key, value in maxindices_xyz.items()} for key in expected6: assert_identical(result6[key], expected6[key]) @@ -5846,7 +5846,7 @@ def test_indices_max( for key, value in maxindices_x.items() } - result7 = ar.indices_max(dim="x", skipna=False) + result7 = ar.argmax(dim=["x"], skipna=False) for key in expected7: assert_identical(result7[key], expected7[key]) @@ -5861,7 +5861,7 @@ def test_indices_max( for key, value in maxindices_y.items() } - result8 = ar.indices_max(dim="y", skipna=False) + result8 = ar.argmax(dim=["y"], skipna=False) for key in expected8: assert_identical(result8[key], expected8[key]) @@ -5876,7 +5876,7 @@ def test_indices_max( for key, value in maxindices_z.items() } - result9 = ar.indices_max(dim="z", skipna=False) + result9 = ar.argmax(dim=["z"], skipna=False) for key in expected9: assert_identical(result9[key], expected9[key]) @@ -5890,7 +5890,7 @@ def test_indices_max( key: xr.DataArray(value, dims="z") for key, value in maxindices_xy.items() } - result10 = ar.indices_max(dim=("x", "y"), skipna=False) + result10 = ar.argmax(dim=("x", "y"), skipna=False) for key in expected10: assert_identical(result10[key], expected10[key]) @@ -5904,7 +5904,7 @@ def test_indices_max( key: xr.DataArray(value, dims="y") for key, value in maxindices_xz.items() } - result11 = ar.indices_max(dim=("x", "z"), skipna=False) + result11 = ar.argmax(dim=("x", "z"), skipna=False) for key in expected11: assert_identical(result11[key], expected11[key]) @@ -5918,7 +5918,7 @@ def test_indices_max( key: xr.DataArray(value, dims="x") for key, value in maxindices_yz.items() } - result12 = ar.indices_max(dim=("y", "z"), skipna=False) + result12 = ar.argmax(dim=("y", "z"), skipna=False) for key in expected12: assert_identical(result12[key], expected12[key]) @@ -5930,7 +5930,7 @@ def test_indices_max( } expected13 = {key: xr.DataArray(value) for key, value in maxindices_xyz.items()} - result13 = ar.indices_max(skipna=False) + result13 = ar.argmax(..., skipna=False) for key in expected13: assert_identical(result13[key], expected13[key]) From 6d9d509b30ea083c3eb627b2c111a3f33a2db456 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Mon, 6 Apr 2020 22:43:08 +0100 Subject: [PATCH 07/44] Basic overload of argmin() and argmax() for Dataset If single dim is passed to Dataset.argmin() or Dataset.argmax(), then pass through to _argmin_base or _argmax_base. If a sequence is passed for dim, raise an exception, because the result for each DataArray would be a dict, which cannot be stored in a Dataset. --- xarray/core/dataset.py | 50 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 97b3caf2b6e..96c91db9ffd 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6294,5 +6294,55 @@ def idxmax( ), ) + def argmin(self, dim=None, axis=None, **kwargs): + if dim is None and axis is None: + warnings.warn( + "Behaviour of DataArray.argmin() with neither dim nor axis argument " + "will change to return a dict of indices of each dimension, and then it " + "will be an error to call Dataset.argmin() with no argument. To get a " + "single, flat index, please use np.argmin(ds) instead of ds.argmin().", + DeprecationWarning, + ) + if ( + dim is None + or axis is not None + or not isinstance(dim, Sequence) + or isinstance(dim, str) + ): + # Return int index if single dimension is passed, and is not part of a + # sequence + return getattr(self, "_argmin_base")(dim=dim, axis=axis, **kwargs) + else: + raise ValueError( + "When dim is a sequence, DataArray.argmin() returns a " + "dict. dicts cannot be contained in a Dataset, so cannot " + "call Dataset.argmin() with a sequence for dim" + ) + + def argmax(self, dim=None, axis=None, **kwargs): + if dim is None and axis is None: + warnings.warn( + "Behaviour of DataArray.argmin() with neither dim nor axis argument " + "will change to return a dict of indices of each dimension, and then it " + "will be an error to call Dataset.argmin() with no argument. To get a " + "single, flat index, please use np.argmin(ds) instead of ds.argmin().", + DeprecationWarning, + ) + if ( + dim is None + or axis is not None + or not isinstance(dim, Sequence) + or isinstance(dim, str) + ): + # Return int index if single dimension is passed, and is not part of a + # sequence + return getattr(self, "_argmax_base")(dim=dim, axis=axis, **kwargs) + else: + raise ValueError( + "When dim is a sequence, DataArray.argmax() returns a " + "dict. dicts cannot be contained in a Dataset, so cannot " + "call Dataset.argmax() with a sequence for dim" + ) + ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False) From 70aaa9dc3390e48e8ec3538716b744a16902cadf Mon Sep 17 00:00:00 2001 From: John Omotani Date: Mon, 6 Apr 2020 22:52:08 +0100 Subject: [PATCH 08/44] Update Variable and dask tests with _argmin_base, _argmax_base The basic numpy-style argmin() and argmax() methods were renamed when adding support for handling multiple dimensions in DataArray.argmin() and DataArray.argmax(). Variable.argmin() and Variable.argmax() are therefore renamed as Variable._argmin_base() and Variable._argmax_base(). --- xarray/tests/test_dask.py | 8 ++++---- xarray/tests/test_variable.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 538dbbfb58b..cdb14ecfa63 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -209,11 +209,11 @@ def test_reduce(self): self.assertLazyAndAllClose(u.mean(), v.mean()) self.assertLazyAndAllClose(u.std(), v.std()) with raise_if_dask_computes(): - actual = v.argmax(dim="x") - self.assertLazyAndAllClose(u.argmax(dim="x"), actual) + actual = v._argmax_base(dim="x") + self.assertLazyAndAllClose(u._argmax_base(dim="x"), actual) with raise_if_dask_computes(): - actual = v.argmin(dim="x") - self.assertLazyAndAllClose(u.argmin(dim="x"), actual) + actual = v._argmin_base(dim="x") + self.assertLazyAndAllClose(u._argmin_base(dim="x"), actual) self.assertLazyAndAllClose((u > 1).any(), (v > 1).any()) self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) with raises_regex(NotImplementedError, "only works along an axis"): diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 78e3848b8fb..d28a7a95846 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1657,7 +1657,7 @@ def test_reduce_funcs(self): assert_identical(v.all(dim="x"), Variable([], False)) v = Variable("t", pd.date_range("2000-01-01", periods=3)) - assert v.argmax(skipna=True) == 2 + assert v._argmax_base(skipna=True) == 2 assert_identical(v.max(), Variable([], pd.Timestamp("2000-01-03"))) From f8952a8c9efb7a9a155f9de8ca77e1c2b727cacf Mon Sep 17 00:00:00 2001 From: John Omotani Date: Mon, 6 Apr 2020 23:35:47 +0100 Subject: [PATCH 09/44] Update api-hidden.rst with _argmin_base and _argmax_base --- doc/api-hidden.rst | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index cc9517a98ba..26ab98c711f 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -41,8 +41,8 @@ core.rolling.DatasetCoarsen.all core.rolling.DatasetCoarsen.any - core.rolling.DatasetCoarsen.argmax - core.rolling.DatasetCoarsen.argmin + core.rolling.DatasetCoarsen._argmax_base + core.rolling.DatasetCoarsen._argmin_base core.rolling.DatasetCoarsen.count core.rolling.DatasetCoarsen.max core.rolling.DatasetCoarsen.mean @@ -68,8 +68,8 @@ core.groupby.DatasetGroupBy.where core.groupby.DatasetGroupBy.all core.groupby.DatasetGroupBy.any - core.groupby.DatasetGroupBy.argmax - core.groupby.DatasetGroupBy.argmin + core.groupby.DatasetGroupBy._argmax_base + core.groupby.DatasetGroupBy._argmin_base core.groupby.DatasetGroupBy.count core.groupby.DatasetGroupBy.max core.groupby.DatasetGroupBy.mean @@ -85,8 +85,8 @@ core.resample.DatasetResample.all core.resample.DatasetResample.any core.resample.DatasetResample.apply - core.resample.DatasetResample.argmax - core.resample.DatasetResample.argmin + core.resample.DatasetResample._argmax_base + core.resample.DatasetResample._argmin_base core.resample.DatasetResample.assign core.resample.DatasetResample.assign_coords core.resample.DatasetResample.bfill @@ -110,8 +110,8 @@ core.resample.DatasetResample.dims core.resample.DatasetResample.groups - core.rolling.DatasetRolling.argmax - core.rolling.DatasetRolling.argmin + core.rolling.DatasetRolling._argmax_base + core.rolling.DatasetRolling._argmin_base core.rolling.DatasetRolling.count core.rolling.DatasetRolling.max core.rolling.DatasetRolling.mean @@ -183,8 +183,8 @@ core.rolling.DataArrayCoarsen.all core.rolling.DataArrayCoarsen.any - core.rolling.DataArrayCoarsen.argmax - core.rolling.DataArrayCoarsen.argmin + core.rolling.DataArrayCoarsen._argmax_base + core.rolling.DataArrayCoarsen._argmin_base core.rolling.DataArrayCoarsen.count core.rolling.DataArrayCoarsen.max core.rolling.DataArrayCoarsen.mean @@ -209,8 +209,8 @@ core.groupby.DataArrayGroupBy.where core.groupby.DataArrayGroupBy.all core.groupby.DataArrayGroupBy.any - core.groupby.DataArrayGroupBy.argmax - core.groupby.DataArrayGroupBy.argmin + core.groupby.DataArrayGroupBy._argmax_base + core.groupby.DataArrayGroupBy._argmin_base core.groupby.DataArrayGroupBy.count core.groupby.DataArrayGroupBy.max core.groupby.DataArrayGroupBy.mean @@ -226,8 +226,8 @@ core.resample.DataArrayResample.all core.resample.DataArrayResample.any core.resample.DataArrayResample.apply - core.resample.DataArrayResample.argmax - core.resample.DataArrayResample.argmin + core.resample.DataArrayResample._argmax_base + core.resample.DataArrayResample._argmin_base core.resample.DataArrayResample.assign_coords core.resample.DataArrayResample.bfill core.resample.DataArrayResample.count @@ -250,8 +250,8 @@ core.resample.DataArrayResample.dims core.resample.DataArrayResample.groups - core.rolling.DataArrayRolling.argmax - core.rolling.DataArrayRolling.argmin + core.rolling.DataArrayRolling._argmax_base + core.rolling.DataArrayRolling._argmin_base core.rolling.DataArrayRolling.count core.rolling.DataArrayRolling.max core.rolling.DataArrayRolling.mean @@ -349,8 +349,8 @@ Variable.all Variable.any - Variable.argmax - Variable.argmin + Variable._argmax_base + Variable._argmin_base Variable.argsort Variable.astype Variable.broadcast_equals @@ -421,8 +421,8 @@ IndexVariable.all IndexVariable.any - IndexVariable.argmax - IndexVariable.argmin + IndexVariable._argmax_base + IndexVariable._argmin_base IndexVariable.argsort IndexVariable.astype IndexVariable.broadcast_equals From 8caf2b8d07c14a2956a26b50ee08d83323c36058 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 7 Apr 2020 10:24:26 +0100 Subject: [PATCH 10/44] Explicitly defined class methods override injected methods If a method (such as 'argmin') has been explicitly defined on a class (so that hasattr(cls, "argmin")==True), then do not inject that method, as it would override the explicitly defined one. Instead inject a private method, prefixed by "_injected_" (such as '_injected_argmin'), so that the injected method is available to the explicitly defined one. Do not perform the hasattr check on binary ops, because this breaks some operations (e.g. addition between DataArray and int in test_dask.py). --- doc/api-hidden.rst | 12 ++++++++++++ xarray/core/ops.py | 19 ++++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 26ab98c711f..34ff5a32eb9 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -18,6 +18,8 @@ Dataset.any Dataset.argmax Dataset.argmin + Dataset._injected_argmax + Dataset._injected_argmin Dataset.max Dataset.min Dataset.mean @@ -94,6 +96,8 @@ core.resample.DatasetResample.ffill core.resample.DatasetResample.fillna core.resample.DatasetResample.first + core.resample.DatasetResample._injected_argmax + core.resample.DatasetResample._injected_argmin core.resample.DatasetResample.last core.resample.DatasetResample.map core.resample.DatasetResample.max @@ -160,6 +164,8 @@ DataArray.any DataArray.argmax DataArray.argmin + DataArray._injected_argmax + DataArray._injected_argmin DataArray.max DataArray.min DataArray.mean @@ -234,6 +240,8 @@ core.resample.DataArrayResample.ffill core.resample.DataArrayResample.fillna core.resample.DataArrayResample.first + core.resample.DataArrayResample._injected_argmax + core.resample.DataArrayResample._injected_argmin core.resample.DataArrayResample.last core.resample.DataArrayResample.map core.resample.DataArrayResample.max @@ -369,6 +377,8 @@ Variable.fillna Variable.get_axis_num Variable.identical + Variable._injected_argmax + Variable._injected_argmin Variable.isel Variable.isnull Variable.item @@ -442,6 +452,8 @@ IndexVariable.get_axis_num IndexVariable.get_level_variable IndexVariable.identical + IndexVariable._injected_argmax + IndexVariable._injected_argmin IndexVariable.isel IndexVariable.isnull IndexVariable.item diff --git a/xarray/core/ops.py b/xarray/core/ops.py index d192f0216d2..0cd4162cdad 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -260,6 +260,9 @@ def inject_reduce_methods(cls): + [("count", duck_array_ops.count, False)] ) for name, f, include_skipna in methods: + if hasattr(cls, name): + name = "_injected_" + name + numeric_only = getattr(f, "numeric_only", False) available_min_count = getattr(f, "available_min_count", False) min_count_docs = _MINCOUNT_DOCSTRING if available_min_count else "" @@ -278,6 +281,9 @@ def inject_reduce_methods(cls): def inject_cum_methods(cls): methods = [(name, getattr(duck_array_ops, name), True) for name in NAN_CUM_METHODS] for name, f, include_skipna in methods: + if hasattr(cls, name): + name = "_injected_" + name + numeric_only = getattr(f, "numeric_only", False) func = cls._reduce_method(f, include_skipna, numeric_only) func.__name__ = name @@ -325,24 +331,35 @@ def inject_all_ops_and_reduce_methods(cls, priority=50, array_only=True): # patch in standard special operations for name in UNARY_OPS: + if hasattr(cls, op_str(name)): + name = "_injected_" + name setattr(cls, op_str(name), cls._unary_op(get_op(name))) inject_binary_ops(cls, inplace=True) # patch in numpy/pandas methods for name in NUMPY_UNARY_METHODS: + if hasattr(cls, op_str(name)): + name = "_injected_" + name setattr(cls, name, cls._unary_op(_method_wrapper(name))) for name in PANDAS_UNARY_FUNCTIONS: f = _func_slash_method_wrapper(getattr(duck_array_ops, name), name=name) + if hasattr(cls, op_str(name)): + name = "_injected_" + name setattr(cls, name, cls._unary_op(f)) f = _func_slash_method_wrapper(duck_array_ops.around, name="round") - setattr(cls, "round", cls._unary_op(f)) + if hasattr(cls, "round"): + setattr(cls, "_injected_round", cls._unary_op(f)) + else: + setattr(cls, "round", cls._unary_op(f)) if array_only: # these methods don't return arrays of the same shape as the input, so # don't try to patch these in for Dataset objects for name in NUMPY_SAME_METHODS: + if hasattr(cls, op_str(name)): + name = "_injected_" + name setattr(cls, name, _values_method_wrapper(name)) inject_reduce_methods(cls) From 4778cfd527f9c0c312136dcba41894509f49ddf7 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 7 Apr 2020 10:29:18 +0100 Subject: [PATCH 11/44] Move StringAccessor back to bottom of DataArray class definition --- xarray/core/dataarray.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 1507c4505e7..16265349798 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3722,10 +3722,6 @@ def idxmax( keep_attrs=keep_attrs, ) - # this needs to be at the end, or mypy will confuse with `str` - # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names - str = property(StringAccessor) - def _unravel_argminmax( self, argminmax: Hashable, @@ -3975,6 +3971,10 @@ def argmax( """ return self._unravel_argminmax("_argmax_base", dim, axis, keep_attrs, skipna) + # this needs to be at the end, or mypy will confuse with `str` + # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names + str = property(StringAccessor) + # priority most be higher than Variable to properly work with binary ufuncs ops.inject_all_ops_and_reduce_methods(DataArray, priority=60) From 66cf085aad038553ccafd35be9fc5756436ee96e Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 7 Apr 2020 10:30:08 +0100 Subject: [PATCH 12/44] Revert use of _argmin_base and _argmax_base Now not needed because of change to injection in ops.py. --- doc/api-hidden.rst | 40 +++++++++++++++++------------------ xarray/core/dataarray.py | 8 +++---- xarray/core/dataset.py | 4 ++-- xarray/core/duck_array_ops.py | 4 ++-- xarray/core/ops.py | 4 ++-- xarray/core/rolling.py | 4 ++-- xarray/tests/test_dask.py | 8 +++---- xarray/tests/test_variable.py | 2 +- 8 files changed, 37 insertions(+), 37 deletions(-) diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 34ff5a32eb9..e6edfe68e4b 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -43,8 +43,8 @@ core.rolling.DatasetCoarsen.all core.rolling.DatasetCoarsen.any - core.rolling.DatasetCoarsen._argmax_base - core.rolling.DatasetCoarsen._argmin_base + core.rolling.DatasetCoarsen.argmax + core.rolling.DatasetCoarsen.argmin core.rolling.DatasetCoarsen.count core.rolling.DatasetCoarsen.max core.rolling.DatasetCoarsen.mean @@ -70,8 +70,8 @@ core.groupby.DatasetGroupBy.where core.groupby.DatasetGroupBy.all core.groupby.DatasetGroupBy.any - core.groupby.DatasetGroupBy._argmax_base - core.groupby.DatasetGroupBy._argmin_base + core.groupby.DatasetGroupBy.argmax + core.groupby.DatasetGroupBy.argmin core.groupby.DatasetGroupBy.count core.groupby.DatasetGroupBy.max core.groupby.DatasetGroupBy.mean @@ -87,8 +87,8 @@ core.resample.DatasetResample.all core.resample.DatasetResample.any core.resample.DatasetResample.apply - core.resample.DatasetResample._argmax_base - core.resample.DatasetResample._argmin_base + core.resample.DatasetResample.argmax + core.resample.DatasetResample.argmin core.resample.DatasetResample.assign core.resample.DatasetResample.assign_coords core.resample.DatasetResample.bfill @@ -114,8 +114,8 @@ core.resample.DatasetResample.dims core.resample.DatasetResample.groups - core.rolling.DatasetRolling._argmax_base - core.rolling.DatasetRolling._argmin_base + core.rolling.DatasetRolling.argmax + core.rolling.DatasetRolling.argmin core.rolling.DatasetRolling.count core.rolling.DatasetRolling.max core.rolling.DatasetRolling.mean @@ -189,8 +189,8 @@ core.rolling.DataArrayCoarsen.all core.rolling.DataArrayCoarsen.any - core.rolling.DataArrayCoarsen._argmax_base - core.rolling.DataArrayCoarsen._argmin_base + core.rolling.DataArrayCoarsen.argmax + core.rolling.DataArrayCoarsen.argmin core.rolling.DataArrayCoarsen.count core.rolling.DataArrayCoarsen.max core.rolling.DataArrayCoarsen.mean @@ -215,8 +215,8 @@ core.groupby.DataArrayGroupBy.where core.groupby.DataArrayGroupBy.all core.groupby.DataArrayGroupBy.any - core.groupby.DataArrayGroupBy._argmax_base - core.groupby.DataArrayGroupBy._argmin_base + core.groupby.DataArrayGroupBy.argmax + core.groupby.DataArrayGroupBy.argmin core.groupby.DataArrayGroupBy.count core.groupby.DataArrayGroupBy.max core.groupby.DataArrayGroupBy.mean @@ -232,8 +232,8 @@ core.resample.DataArrayResample.all core.resample.DataArrayResample.any core.resample.DataArrayResample.apply - core.resample.DataArrayResample._argmax_base - core.resample.DataArrayResample._argmin_base + core.resample.DataArrayResample.argmax + core.resample.DataArrayResample.argmin core.resample.DataArrayResample.assign_coords core.resample.DataArrayResample.bfill core.resample.DataArrayResample.count @@ -258,8 +258,8 @@ core.resample.DataArrayResample.dims core.resample.DataArrayResample.groups - core.rolling.DataArrayRolling._argmax_base - core.rolling.DataArrayRolling._argmin_base + core.rolling.DataArrayRolling.argmax + core.rolling.DataArrayRolling.argmin core.rolling.DataArrayRolling.count core.rolling.DataArrayRolling.max core.rolling.DataArrayRolling.mean @@ -357,8 +357,8 @@ Variable.all Variable.any - Variable._argmax_base - Variable._argmin_base + Variable.argmax + Variable.argmin Variable.argsort Variable.astype Variable.broadcast_equals @@ -431,8 +431,8 @@ IndexVariable.all IndexVariable.any - IndexVariable._argmax_base - IndexVariable._argmin_base + IndexVariable.argmax + IndexVariable.argmin IndexVariable.argsort IndexVariable.astype IndexVariable.broadcast_equals diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 16265349798..4b75fd331f2 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3752,7 +3752,7 @@ def _unravel_argminmax( ): # Return int index if single dimension is passed, and is not part of a # sequence - return getattr(self, str(argminmax))( + return getattr(self, str("_injected_" + argminmax))( dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna ) @@ -3769,7 +3769,7 @@ def _unravel_argminmax( result_dims = stacked.dims[:-1] reduce_shape = tuple(self.sizes[d] for d in dim) - result_flat_indices = getattr(stacked, str(argminmax))(axis=-1, skipna=skipna) + result_flat_indices = getattr(stacked, str("_injected_" + argminmax))(axis=-1, skipna=skipna) result_unravelled_indices = np.unravel_index(result_flat_indices, reduce_shape) @@ -3876,7 +3876,7 @@ def argmin( array([ 1, -5, 1]) Dimensions without coordinates: y """ - return self._unravel_argminmax("_argmin_base", dim, axis, keep_attrs, skipna) + return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna) def argmax( self, @@ -3969,7 +3969,7 @@ def argmax( array([3, 5, 3]) Dimensions without coordinates: y """ - return self._unravel_argminmax("_argmax_base", dim, axis, keep_attrs, skipna) + return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 96c91db9ffd..8bf6ff2554c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6311,7 +6311,7 @@ def argmin(self, dim=None, axis=None, **kwargs): ): # Return int index if single dimension is passed, and is not part of a # sequence - return getattr(self, "_argmin_base")(dim=dim, axis=axis, **kwargs) + return getattr(self, "_injected_argmin")(dim=dim, axis=axis, **kwargs) else: raise ValueError( "When dim is a sequence, DataArray.argmin() returns a " @@ -6336,7 +6336,7 @@ def argmax(self, dim=None, axis=None, **kwargs): ): # Return int index if single dimension is passed, and is not part of a # sequence - return getattr(self, "_argmax_base")(dim=dim, axis=axis, **kwargs) + return getattr(self, "_injected_argmax")(dim=dim, axis=axis, **kwargs) else: raise ValueError( "When dim is a sequence, DataArray.argmax() returns a " diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 19dd180cb33..1340b456cf2 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -319,8 +319,8 @@ def f(values, axis=None, skipna=None, **kwargs): # Attributes `numeric_only`, `available_min_count` is used for docs. # See ops.inject_reduce_methods -_argmax_base = _create_nan_agg_method("argmax", coerce_strings=True) -_argmin_base = _create_nan_agg_method("argmin", coerce_strings=True) +argmax = _create_nan_agg_method("argmax", coerce_strings=True) +argmin = _create_nan_agg_method("argmin", coerce_strings=True) max = _create_nan_agg_method("max", coerce_strings=True) min = _create_nan_agg_method("min", coerce_strings=True) sum = _create_nan_agg_method("sum") diff --git a/xarray/core/ops.py b/xarray/core/ops.py index 0cd4162cdad..57a86bab7b1 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -47,8 +47,8 @@ # methods which remove an axis REDUCE_METHODS = ["all", "any"] NAN_REDUCE_METHODS = [ - "_argmax_base", - "_argmin_base", + "argmax", + "argmin", "max", "min", "mean", diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index dec9fe513a9..ecba5307680 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -130,8 +130,8 @@ def method(self, **kwargs): method.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name=name) return method - _argmax_base = _reduce_method("_argmax_base") - _argmin_base = _reduce_method("_argmin_base") + argmax = _reduce_method("argmax") + argmin = _reduce_method("argmin") max = _reduce_method("max") min = _reduce_method("min") mean = _reduce_method("mean") diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index cdb14ecfa63..538dbbfb58b 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -209,11 +209,11 @@ def test_reduce(self): self.assertLazyAndAllClose(u.mean(), v.mean()) self.assertLazyAndAllClose(u.std(), v.std()) with raise_if_dask_computes(): - actual = v._argmax_base(dim="x") - self.assertLazyAndAllClose(u._argmax_base(dim="x"), actual) + actual = v.argmax(dim="x") + self.assertLazyAndAllClose(u.argmax(dim="x"), actual) with raise_if_dask_computes(): - actual = v._argmin_base(dim="x") - self.assertLazyAndAllClose(u._argmin_base(dim="x"), actual) + actual = v.argmin(dim="x") + self.assertLazyAndAllClose(u.argmin(dim="x"), actual) self.assertLazyAndAllClose((u > 1).any(), (v > 1).any()) self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x")) with raises_regex(NotImplementedError, "only works along an axis"): diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index d28a7a95846..78e3848b8fb 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1657,7 +1657,7 @@ def test_reduce_funcs(self): assert_identical(v.all(dim="x"), Variable([], False)) v = Variable("t", pd.date_range("2000-01-01", periods=3)) - assert v._argmax_base(skipna=True) == 2 + assert v.argmax(skipna=True) == 2 assert_identical(v.max(), Variable([], pd.Timestamp("2000-01-03"))) From c78c1fedfe3c40a167ff21803a39846691f6e2f9 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Wed, 8 Apr 2020 10:12:12 +0100 Subject: [PATCH 13/44] Move implementation of argmin, argmax from DataArray to Variable Makes use of argmin and argmax more general (they are available for Variable) and is straightforward for DataArray to wrap the Variable version. --- xarray/core/dataarray.py | 86 ++++----------------- xarray/core/variable.py | 160 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 174 insertions(+), 72 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 4b75fd331f2..94770c3a1ff 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -55,7 +55,7 @@ from .indexes import Indexes, default_indexes, propagate_indexes from .indexing import is_fancy_indexer from .merge import PANDAS_TYPES, _extract_indexes_from_coords -from .options import OPTIONS, _get_keep_attrs +from .options import OPTIONS from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs from .variable import ( IndexVariable, @@ -3722,77 +3722,13 @@ def idxmax( keep_attrs=keep_attrs, ) - def _unravel_argminmax( - self, - argminmax: Hashable, - dim: Union[Hashable, Sequence[Hashable], None], - axis: Union[int, None], - keep_attrs: Optional[bool], - skipna: Optional[bool], - ) -> Dict[Hashable, "DataArray"]: - """Apply argmin or argmax over one or more dimensions, returning the result as a - dict of DataArray that can be passed directly to isel. - """ - if dim is None and axis is None: - warnings.warn( - "Behaviour of argmin/argmax with neither dim nor axis argument will " - "change to return a dict of indices of each dimension. To get a " - "single, flat index, please use np.argmin(da) or np.argmax(da) instead " - "of da.argmin() or da.argmax().", - DeprecationWarning, - ) - if dim is ...: - # In future, should do this also when (dim is None and axis is None) - dim = self.dims - if ( - dim is None - or axis is not None - or not isinstance(dim, Sequence) - or isinstance(dim, str) - ): - # Return int index if single dimension is passed, and is not part of a - # sequence - return getattr(self, str("_injected_" + argminmax))( - dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna - ) - - # Get a name for the new dimension that does not conflict with any existing - # dimension - newdimname = "_unravel_argminmax_dim_0" - count = 1 - while newdimname in self.dims: - newdimname = "_unravel_argminmax_dim_{}".format(count) - count += 1 - - stacked = self.stack({newdimname: dim}) - - result_dims = stacked.dims[:-1] - reduce_shape = tuple(self.sizes[d] for d in dim) - - result_flat_indices = getattr(stacked, str("_injected_" + argminmax))(axis=-1, skipna=skipna) - - result_unravelled_indices = np.unravel_index(result_flat_indices, reduce_shape) - - result = { - d: DataArray(i, dims=result_dims) - for d, i in zip(dim, result_unravelled_indices) - } - - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) - if keep_attrs: - for da in result.values(): - da.attrs = self.attrs - - return result - def argmin( self, dim: Union[Hashable, Sequence[Hashable]] = None, axis: Union[int, None] = None, keep_attrs: bool = None, skipna: bool = None, - ) -> Dict[Hashable, "DataArray"]: + ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: """Indices of the minimum of the DataArray over one or more dimensions. Result returned as dict of DataArrays, which can be passed directly to isel(). @@ -3823,7 +3759,7 @@ def argmin( See also -------- - DataArray.argmin, DataArray.idxmin + Variable.argmin, DataArray.idxmin Examples -------- @@ -3876,7 +3812,11 @@ def argmin( array([ 1, -5, 1]) Dimensions without coordinates: y """ - return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna) + result = self.variable.argmin(dim, axis, keep_attrs, skipna) + if isinstance(result, dict): + return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()} + else: + return self._replace_maybe_drop_dims(result) def argmax( self, @@ -3884,7 +3824,7 @@ def argmax( axis: Union[int, None] = None, keep_attrs: bool = None, skipna: bool = None, - ) -> Dict[Hashable, "DataArray"]: + ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: """Indices of the maximum of the DataArray over one or more dimensions. Result returned as dict of DataArrays, which can be passed directly to isel(). @@ -3915,7 +3855,7 @@ def argmax( See also -------- - DataArray.argmax, DataArray.idxmax + Variable.argmax, DataArray.idxmax Examples -------- @@ -3969,7 +3909,11 @@ def argmax( array([3, 5, 3]) Dimensions without coordinates: y """ - return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) + result = self.variable.argmax(dim, axis, keep_attrs, skipna) + if isinstance(result, dict): + return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()} + else: + return self._replace_maybe_drop_dims(result) # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 68e823ca426..99ac6cd9fb7 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -6,7 +6,17 @@ from collections import defaultdict from datetime import timedelta from distutils.version import LooseVersion -from typing import Any, Dict, Hashable, Mapping, Tuple, TypeVar, Union +from typing import ( + Any, + Dict, + Hashable, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) import numpy as np import pandas as pd @@ -2069,6 +2079,154 @@ def _to_numeric(self, offset=None, datetime_unit=None, dtype=float): ) return type(self)(self.dims, numeric_array, self._attrs) + def _unravel_argminmax( + self, + argminmax: Hashable, + dim: Union[Hashable, Sequence[Hashable], None], + axis: Union[int, None], + keep_attrs: Optional[bool], + skipna: Optional[bool], + ) -> Union["Variable", Dict[Hashable, "Variable"]]: + """Apply argmin or argmax over one or more dimensions, returning the result as a + dict of DataArray that can be passed directly to isel. + """ + if dim is None and axis is None: + warnings.warn( + "Behaviour of argmin/argmax with neither dim nor axis argument will " + "change to return a dict of indices of each dimension. To get a " + "single, flat index, please use np.argmin(da) or np.argmax(da) instead " + "of da.argmin() or da.argmax().", + DeprecationWarning, + ) + if dim is ...: + # In future, should do this also when (dim is None and axis is None) + dim = self.dims + if ( + dim is None + or axis is not None + or not isinstance(dim, Sequence) + or isinstance(dim, str) + ): + # Return int index if single dimension is passed, and is not part of a + # sequence + return getattr(self, "_injected_" + str(argminmax))( + dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna + ) + + # Get a name for the new dimension that does not conflict with any existing + # dimension + newdimname = "_unravel_argminmax_dim_0" + count = 1 + while newdimname in self.dims: + newdimname = "_unravel_argminmax_dim_{}".format(count) + count += 1 + + stacked = self.stack({newdimname: dim}) + + result_dims = stacked.dims[:-1] + reduce_shape = tuple(self.sizes[d] for d in dim) + + result_flat_indices = getattr(stacked, "_injected_" + str(argminmax))( + axis=-1, skipna=skipna + ) + + result_unravelled_indices = np.unravel_index(result_flat_indices, reduce_shape) + + result = { + d: Variable(dims=result_dims, data=i) + for d, i in zip(dim, result_unravelled_indices) + } + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + if keep_attrs: + for v in result.values(): + v.attrs = self.attrs + + return result + + def argmin( + self, + dim: Union[Hashable, Sequence[Hashable]] = None, + axis: Union[int, None] = None, + keep_attrs: bool = None, + skipna: bool = None, + ) -> Union["Variable", Dict[Hashable, "Variable"]]: + """Indices of the minimum of the DataArray over one or more dimensions. Result + returned as dict of DataArrays, which can be passed directly to isel(). + + If there are multiple minima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : hashable or sequence of hashable, optional + The dimensions over which to find the minimum. By default, finds minimum over + all dimensions. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : dict of DataArray + + See also + -------- + DataArray.argmin, DataArray.idxmin + """ + return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna) + + def argmax( + self, + dim: Union[Hashable, Sequence[Hashable]] = None, + axis: Union[int, None] = None, + keep_attrs: bool = None, + skipna: bool = None, + ) -> Union["Variable", Dict[Hashable, "Variable"]]: + """Indices of the maximum of the DataArray over one or more dimensions. Result + returned as dict of DataArrays, which can be passed directly to isel(). + + If there are multiple maxima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : hashable or sequence of hashable, optional + The dimensions over which to find the maximum. By default, finds maximum over + all dimensions. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : dict of DataArray + + See also + -------- + DataArray.argmax, DataArray.idxmax + """ + return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) + ops.inject_all_ops_and_reduce_methods(Variable) From cb6742d81e461631ad518d2c9e97adcbbbf0a0e2 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Wed, 8 Apr 2020 10:13:49 +0100 Subject: [PATCH 14/44] Update tests for change to coordinates on result of argmin, argmax --- xarray/tests/test_dataarray.py | 56 +++++++++++++++++----------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 7dc6cf4885b..0d5624afdc9 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -5239,7 +5239,7 @@ def test_argmin_dim(self, x, minindex, maxindex, nanindex): return expected0 = [ - indarr.isel(y=yi, drop=True).isel(x=indi, drop=True) + indarr.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(minindex) ] expected0 = {"x": xr.concat(expected0, dim="y")} @@ -5259,7 +5259,7 @@ def test_argmin_dim(self, x, minindex, maxindex, nanindex): for x, y in zip(minindex, nanindex) ] expected2 = [ - indarr.isel(y=yi, drop=True).isel(x=indi, drop=True) + indarr.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(minindex) ] expected2 = {"x": xr.concat(expected2, dim="y")} @@ -5296,7 +5296,7 @@ def test_argmax_dim(self, x, minindex, maxindex, nanindex): return expected0 = [ - indarr.isel(y=yi, drop=True).isel(x=indi, drop=True) + indarr.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(maxindex) ] expected0 = {"x": xr.concat(expected0, dim="y")} @@ -5316,7 +5316,7 @@ def test_argmax_dim(self, x, minindex, maxindex, nanindex): for x, y in zip(maxindex, nanindex) ] expected2 = [ - indarr.isel(y=yi, drop=True).isel(x=indi, drop=True) + indarr.isel(y=yi).isel(x=indi, drop=True) for yi, indi in enumerate(maxindex) ] expected2 = {"x": xr.concat(expected2, dim="y")} @@ -5576,7 +5576,7 @@ def test_argmin_dim( for key, value in minindices_x.items() } for key in expected0: - assert_identical(result0[key], expected0[key]) + assert_identical(result0[key].drop_vars(["y", "z"]), expected0[key]) result1 = ar.argmin(dim=["y"]) expected1 = { @@ -5584,7 +5584,7 @@ def test_argmin_dim( for key, value in minindices_y.items() } for key in expected1: - assert_identical(result1[key], expected1[key]) + assert_identical(result1[key].drop_vars(["x", "z"]), expected1[key]) result2 = ar.argmin(dim=["z"]) expected2 = { @@ -5592,28 +5592,28 @@ def test_argmin_dim( for key, value in minindices_z.items() } for key in expected2: - assert_identical(result2[key], expected2[key]) + assert_identical(result2[key].drop_vars(["x", "y"]), expected2[key]) result3 = ar.argmin(dim=("x", "y")) expected3 = { key: xr.DataArray(value, dims=("z")) for key, value in minindices_xy.items() } for key in expected3: - assert_identical(result3[key], expected3[key]) + assert_identical(result3[key].drop_vars("z"), expected3[key]) result4 = ar.argmin(dim=("x", "z")) expected4 = { key: xr.DataArray(value, dims=("y")) for key, value in minindices_xz.items() } for key in expected4: - assert_identical(result4[key], expected4[key]) + assert_identical(result4[key].drop_vars("y"), expected4[key]) result5 = ar.argmin(dim=("y", "z")) expected5 = { key: xr.DataArray(value, dims=("x")) for key, value in minindices_yz.items() } for key in expected5: - assert_identical(result5[key], expected5[key]) + assert_identical(result5[key].drop_vars("x"), expected5[key]) result6 = ar.argmin(...) expected6 = {key: xr.DataArray(value) for key, value in minindices_xyz.items()} @@ -5633,7 +5633,7 @@ def test_argmin_dim( result7 = ar.argmin(dim=["x"], skipna=False) for key in expected7: - assert_identical(result7[key], expected7[key]) + assert_identical(result7[key].drop_vars(["y", "z"]), expected7[key]) minindices_y = { key: xr.where( @@ -5648,7 +5648,7 @@ def test_argmin_dim( result8 = ar.argmin(dim=["y"], skipna=False) for key in expected8: - assert_identical(result8[key], expected8[key]) + assert_identical(result8[key].drop_vars(["x", "z"]), expected8[key]) minindices_z = { key: xr.where( @@ -5663,7 +5663,7 @@ def test_argmin_dim( result9 = ar.argmin(dim=["z"], skipna=False) for key in expected9: - assert_identical(result9[key], expected9[key]) + assert_identical(result9[key].drop_vars(["x", "y"]), expected9[key]) minindices_xy = { key: xr.where( @@ -5677,7 +5677,7 @@ def test_argmin_dim( result10 = ar.argmin(dim=("x", "y"), skipna=False) for key in expected10: - assert_identical(result10[key], expected10[key]) + assert_identical(result10[key].drop_vars("z"), expected10[key]) minindices_xz = { key: xr.where( @@ -5691,7 +5691,7 @@ def test_argmin_dim( result11 = ar.argmin(dim=("x", "z"), skipna=False) for key in expected11: - assert_identical(result11[key], expected11[key]) + assert_identical(result11[key].drop_vars("y"), expected11[key]) minindices_yz = { key: xr.where( @@ -5705,7 +5705,7 @@ def test_argmin_dim( result12 = ar.argmin(dim=("y", "z"), skipna=False) for key in expected12: - assert_identical(result12[key], expected12[key]) + assert_identical(result12[key].drop_vars("x"), expected12[key]) minindices_xyz = { key: xr.where( @@ -5791,7 +5791,7 @@ def test_argmax_dim( for key, value in maxindices_x.items() } for key in expected0: - assert_identical(result0[key], expected0[key]) + assert_identical(result0[key].drop_vars(["y", "z"]), expected0[key]) result1 = ar.argmax(dim=["y"]) expected1 = { @@ -5799,7 +5799,7 @@ def test_argmax_dim( for key, value in maxindices_y.items() } for key in expected1: - assert_identical(result1[key], expected1[key]) + assert_identical(result1[key].drop_vars(["x", "z"]), expected1[key]) result2 = ar.argmax(dim=["z"]) expected2 = { @@ -5807,28 +5807,28 @@ def test_argmax_dim( for key, value in maxindices_z.items() } for key in expected2: - assert_identical(result2[key], expected2[key]) + assert_identical(result2[key].drop_vars(["x", "y"]), expected2[key]) result3 = ar.argmax(dim=("x", "y")) expected3 = { key: xr.DataArray(value, dims=("z")) for key, value in maxindices_xy.items() } for key in expected3: - assert_identical(result3[key], expected3[key]) + assert_identical(result3[key].drop_vars("z"), expected3[key]) result4 = ar.argmax(dim=("x", "z")) expected4 = { key: xr.DataArray(value, dims=("y")) for key, value in maxindices_xz.items() } for key in expected4: - assert_identical(result4[key], expected4[key]) + assert_identical(result4[key].drop_vars("y"), expected4[key]) result5 = ar.argmax(dim=("y", "z")) expected5 = { key: xr.DataArray(value, dims=("x")) for key, value in maxindices_yz.items() } for key in expected5: - assert_identical(result5[key], expected5[key]) + assert_identical(result5[key].drop_vars("x"), expected5[key]) result6 = ar.argmax(...) expected6 = {key: xr.DataArray(value) for key, value in maxindices_xyz.items()} @@ -5848,7 +5848,7 @@ def test_argmax_dim( result7 = ar.argmax(dim=["x"], skipna=False) for key in expected7: - assert_identical(result7[key], expected7[key]) + assert_identical(result7[key].drop_vars(["y", "z"]), expected7[key]) maxindices_y = { key: xr.where( @@ -5863,7 +5863,7 @@ def test_argmax_dim( result8 = ar.argmax(dim=["y"], skipna=False) for key in expected8: - assert_identical(result8[key], expected8[key]) + assert_identical(result8[key].drop_vars(["x", "z"]), expected8[key]) maxindices_z = { key: xr.where( @@ -5878,7 +5878,7 @@ def test_argmax_dim( result9 = ar.argmax(dim=["z"], skipna=False) for key in expected9: - assert_identical(result9[key], expected9[key]) + assert_identical(result9[key].drop_vars(["x", "y"]), expected9[key]) maxindices_xy = { key: xr.where( @@ -5892,7 +5892,7 @@ def test_argmax_dim( result10 = ar.argmax(dim=("x", "y"), skipna=False) for key in expected10: - assert_identical(result10[key], expected10[key]) + assert_identical(result10[key].drop_vars("z"), expected10[key]) maxindices_xz = { key: xr.where( @@ -5906,7 +5906,7 @@ def test_argmax_dim( result11 = ar.argmax(dim=("x", "z"), skipna=False) for key in expected11: - assert_identical(result11[key], expected11[key]) + assert_identical(result11[key].drop_vars("y"), expected11[key]) maxindices_yz = { key: xr.where( @@ -5920,7 +5920,7 @@ def test_argmax_dim( result12 = ar.argmax(dim=("y", "z"), skipna=False) for key in expected12: - assert_identical(result12[key], expected12[key]) + assert_identical(result12[key].drop_vars("x"), expected12[key]) maxindices_xyz = { key: xr.where( From ab480b5c88a059264086260e5090eb38b98aa7fa Mon Sep 17 00:00:00 2001 From: John Omotani Date: Fri, 10 Apr 2020 16:17:00 +0100 Subject: [PATCH 15/44] Add 'out' keyword to argmin/argmax methods - allow numpy call signature When np.argmin(da) is called, numpy passes an 'out' keyword argument to argmin/argmax. Need to allow this argument to avoid errors (but an exception is thrown if out is not None). --- xarray/core/dataarray.py | 12 ++++++++++-- xarray/core/variable.py | 17 +++++++++++++---- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 94770c3a1ff..c14b3565cb6 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3728,6 +3728,7 @@ def argmin( axis: Union[int, None] = None, keep_attrs: bool = None, skipna: bool = None, + out=None, ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: """Indices of the minimum of the DataArray over one or more dimensions. Result returned as dict of DataArrays, which can be passed directly to isel(). @@ -3752,6 +3753,9 @@ def argmin( skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). + out : None + 'out' should not be passed - provided for compatibility with numpy function + signature Returns ------- @@ -3812,7 +3816,7 @@ def argmin( array([ 1, -5, 1]) Dimensions without coordinates: y """ - result = self.variable.argmin(dim, axis, keep_attrs, skipna) + result = self.variable.argmin(dim, axis, keep_attrs, skipna, out) if isinstance(result, dict): return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()} else: @@ -3824,6 +3828,7 @@ def argmax( axis: Union[int, None] = None, keep_attrs: bool = None, skipna: bool = None, + out=None, ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: """Indices of the maximum of the DataArray over one or more dimensions. Result returned as dict of DataArrays, which can be passed directly to isel(). @@ -3848,6 +3853,9 @@ def argmax( skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). + out : None + 'out' should not be passed - provided for compatibility with numpy function + signature Returns ------- @@ -3909,7 +3917,7 @@ def argmax( array([3, 5, 3]) Dimensions without coordinates: y """ - result = self.variable.argmax(dim, axis, keep_attrs, skipna) + result = self.variable.argmax(dim, axis, keep_attrs, skipna, out) if isinstance(result, dict): return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()} else: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 99ac6cd9fb7..fce7376b790 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2086,6 +2086,7 @@ def _unravel_argminmax( axis: Union[int, None], keep_attrs: Optional[bool], skipna: Optional[bool], + out, ) -> Union["Variable", Dict[Hashable, "Variable"]]: """Apply argmin or argmax over one or more dimensions, returning the result as a dict of DataArray that can be passed directly to isel. @@ -2110,7 +2111,7 @@ def _unravel_argminmax( # Return int index if single dimension is passed, and is not part of a # sequence return getattr(self, "_injected_" + str(argminmax))( - dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna + dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna, out=out ) # Get a name for the new dimension that does not conflict with any existing @@ -2127,7 +2128,7 @@ def _unravel_argminmax( reduce_shape = tuple(self.sizes[d] for d in dim) result_flat_indices = getattr(stacked, "_injected_" + str(argminmax))( - axis=-1, skipna=skipna + axis=-1, skipna=skipna, out=out ) result_unravelled_indices = np.unravel_index(result_flat_indices, reduce_shape) @@ -2151,6 +2152,7 @@ def argmin( axis: Union[int, None] = None, keep_attrs: bool = None, skipna: bool = None, + out=None, ) -> Union["Variable", Dict[Hashable, "Variable"]]: """Indices of the minimum of the DataArray over one or more dimensions. Result returned as dict of DataArrays, which can be passed directly to isel(). @@ -2175,6 +2177,9 @@ def argmin( skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). + out : None + 'out' should not be passed - provided for compatibility with numpy function + signature Returns ------- @@ -2184,7 +2189,7 @@ def argmin( -------- DataArray.argmin, DataArray.idxmin """ - return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna) + return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna, out) def argmax( self, @@ -2192,6 +2197,7 @@ def argmax( axis: Union[int, None] = None, keep_attrs: bool = None, skipna: bool = None, + out=None, ) -> Union["Variable", Dict[Hashable, "Variable"]]: """Indices of the maximum of the DataArray over one or more dimensions. Result returned as dict of DataArrays, which can be passed directly to isel(). @@ -2216,6 +2222,9 @@ def argmax( skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). + out : None + 'out' should not be passed - provided for compatibility with numpy function + signature Returns ------- @@ -2225,7 +2234,7 @@ def argmax( -------- DataArray.argmax, DataArray.idxmax """ - return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) + return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna, out) ops.inject_all_ops_and_reduce_methods(Variable) From dca8e457d3192754e4686d06cf5682a5d22bb5c4 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Fri, 10 Apr 2020 17:49:01 +0100 Subject: [PATCH 16/44] Update and correct docstrings for argmin and argmax --- xarray/core/dataarray.py | 28 ++++++++++++++++++---------- xarray/core/variable.py | 28 ++++++++++++++++++---------- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c14b3565cb6..259bd72ac97 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3730,17 +3730,21 @@ def argmin( skipna: bool = None, out=None, ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: - """Indices of the minimum of the DataArray over one or more dimensions. Result - returned as dict of DataArrays, which can be passed directly to isel(). + """Index or indices of the minimum of the DataArray over one or more dimensions. + If a sequence is passed to 'dim', then result returned as dict of DataArrays, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns an int. If there are multiple minima, the indices of the first one found will be returned. Parameters ---------- - dim : hashable or sequence of hashable, optional + dim : hashable, sequence of hashable or ..., optional The dimensions over which to find the minimum. By default, finds minimum over - all dimensions. + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. axis : int, optional Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments can be supplied. @@ -3759,7 +3763,7 @@ def argmin( Returns ------- - result : dict of DataArray + result : DataArray or dict of DataArray See also -------- @@ -3830,17 +3834,21 @@ def argmax( skipna: bool = None, out=None, ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: - """Indices of the maximum of the DataArray over one or more dimensions. Result - returned as dict of DataArrays, which can be passed directly to isel(). + """Index or indices of the maximum of the DataArray over one or more dimensions. + If a sequence is passed to 'dim', then result returned as dict of DataArrays, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns an int. If there are multiple maxima, the indices of the first one found will be returned. Parameters ---------- - dim : hashable or sequence of hashable, optional + dim : hashable, sequence of hashable or ..., optional The dimensions over which to find the maximum. By default, finds maximum over - all dimensions. + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. axis : int, optional Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments can be supplied. @@ -3859,7 +3867,7 @@ def argmax( Returns ------- - result : dict of DataArray + result : DataArray or dict of DataArray See also -------- diff --git a/xarray/core/variable.py b/xarray/core/variable.py index fce7376b790..544b59d106e 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2154,17 +2154,21 @@ def argmin( skipna: bool = None, out=None, ) -> Union["Variable", Dict[Hashable, "Variable"]]: - """Indices of the minimum of the DataArray over one or more dimensions. Result - returned as dict of DataArrays, which can be passed directly to isel(). + """Index or indices of the minimum of the Variable over one or more dimensions. + If a sequence is passed to 'dim', then result returned as dict of Variables, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns an int. If there are multiple minima, the indices of the first one found will be returned. Parameters ---------- - dim : hashable or sequence of hashable, optional + dim : hashable, sequence of hashable or ..., optional The dimensions over which to find the minimum. By default, finds minimum over - all dimensions. + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. axis : int, optional Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments can be supplied. @@ -2183,7 +2187,7 @@ def argmin( Returns ------- - result : dict of DataArray + result : Variable or dict of Variable See also -------- @@ -2199,17 +2203,21 @@ def argmax( skipna: bool = None, out=None, ) -> Union["Variable", Dict[Hashable, "Variable"]]: - """Indices of the maximum of the DataArray over one or more dimensions. Result - returned as dict of DataArrays, which can be passed directly to isel(). + """Index or indices of the maximum of the Variable over one or more dimensions. + If a sequence is passed to 'dim', then result returned as dict of Variables, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns an int. If there are multiple maxima, the indices of the first one found will be returned. Parameters ---------- - dim : hashable or sequence of hashable, optional + dim : hashable, sequence of hashable or ..., optional The dimensions over which to find the maximum. By default, finds maximum over - all dimensions. + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. axis : int, optional Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments can be supplied. @@ -2228,7 +2236,7 @@ def argmax( Returns ------- - result : dict of DataArray + result : Variable or dict of Variable See also -------- From 52554b6a7c24702663d3c2298a7deefba59f14cc Mon Sep 17 00:00:00 2001 From: John Omotani Date: Fri, 10 Apr 2020 17:55:42 +0100 Subject: [PATCH 17/44] Correct suggested replacement for da.argmin() and da.argmax() --- xarray/core/variable.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 544b59d106e..2aa7da2846a 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2095,8 +2095,8 @@ def _unravel_argminmax( warnings.warn( "Behaviour of argmin/argmax with neither dim nor axis argument will " "change to return a dict of indices of each dimension. To get a " - "single, flat index, please use np.argmin(da) or np.argmax(da) instead " - "of da.argmin() or da.argmax().", + "single, flat index, please use np.argmin(da.data) or " + "np.argmax(da.data) instead of da.argmin() or da.argmax().", DeprecationWarning, ) if dim is ...: From ef826f69bdeb63e1a21790a48c3ff3aafe3035e3 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 21 Apr 2020 00:44:17 -0700 Subject: [PATCH 18/44] Remove use of _injected_ methods in argmin/argmax --- xarray/core/variable.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 2aa7da2846a..0aa67dda1ea 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2099,6 +2099,9 @@ def _unravel_argminmax( "np.argmax(da.data) instead of da.argmin() or da.argmax().", DeprecationWarning, ) + + argminax_func = getattr(duck_array_ops, argminmax) + if dim is ...: # In future, should do this also when (dim is None and axis is None) dim = self.dims @@ -2110,8 +2113,8 @@ def _unravel_argminmax( ): # Return int index if single dimension is passed, and is not part of a # sequence - return getattr(self, "_injected_" + str(argminmax))( - dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna, out=out + return self.reduce( + argminax_func, dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna ) # Get a name for the new dimension that does not conflict with any existing @@ -2127,11 +2130,11 @@ def _unravel_argminmax( result_dims = stacked.dims[:-1] reduce_shape = tuple(self.sizes[d] for d in dim) - result_flat_indices = getattr(stacked, "_injected_" + str(argminmax))( - axis=-1, skipna=skipna, out=out - ) + result_flat_indices = stacked.reduce(argminax_func, axis=-1, skipna=skipna) - result_unravelled_indices = np.unravel_index(result_flat_indices, reduce_shape) + result_unravelled_indices = np.unravel_index( + result_flat_indices.data, reduce_shape + ) result = { d: Variable(dims=result_dims, data=i) From 8a7c7adc6369b08ca64d2927c46c8bc32092cc7a Mon Sep 17 00:00:00 2001 From: johnomotani Date: Tue, 21 Apr 2020 18:32:17 +0100 Subject: [PATCH 19/44] Fix typo in name of argminmax_func Co-Authored-By: keewis --- xarray/core/variable.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 0aa67dda1ea..9c20c547748 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2100,7 +2100,7 @@ def _unravel_argminmax( DeprecationWarning, ) - argminax_func = getattr(duck_array_ops, argminmax) + argminmax_func = getattr(duck_array_ops, argminmax) if dim is ...: # In future, should do this also when (dim is None and axis is None) @@ -2114,7 +2114,7 @@ def _unravel_argminmax( # Return int index if single dimension is passed, and is not part of a # sequence return self.reduce( - argminax_func, dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna + argminmax_func, dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna ) # Get a name for the new dimension that does not conflict with any existing @@ -2130,7 +2130,7 @@ def _unravel_argminmax( result_dims = stacked.dims[:-1] reduce_shape = tuple(self.sizes[d] for d in dim) - result_flat_indices = stacked.reduce(argminax_func, axis=-1, skipna=skipna) + result_flat_indices = stacked.reduce(argminmax_func, axis=-1, skipna=skipna) result_unravelled_indices = np.unravel_index( result_flat_indices.data, reduce_shape From e56e2e7f175ea17e202e911da0dbcc8142acbdcb Mon Sep 17 00:00:00 2001 From: johnomotani Date: Tue, 21 Apr 2020 21:09:12 +0100 Subject: [PATCH 20/44] Mark argminmax argument to _unravel_argminmax as a string Co-Authored-By: keewis --- xarray/core/variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 9c20c547748..92dcdb8ea6a 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2081,7 +2081,7 @@ def _to_numeric(self, offset=None, datetime_unit=None, dtype=float): def _unravel_argminmax( self, - argminmax: Hashable, + argminmax: str, dim: Union[Hashable, Sequence[Hashable], None], axis: Union[int, None], keep_attrs: Optional[bool], From a99697ab883a38f6f59acccc6a12585ad733bacc Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 18:09:09 +0100 Subject: [PATCH 21/44] Hidden internal methods don't need to appear in docs --- doc/api-hidden.rst | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index e6edfe68e4b..cc9517a98ba 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -18,8 +18,6 @@ Dataset.any Dataset.argmax Dataset.argmin - Dataset._injected_argmax - Dataset._injected_argmin Dataset.max Dataset.min Dataset.mean @@ -96,8 +94,6 @@ core.resample.DatasetResample.ffill core.resample.DatasetResample.fillna core.resample.DatasetResample.first - core.resample.DatasetResample._injected_argmax - core.resample.DatasetResample._injected_argmin core.resample.DatasetResample.last core.resample.DatasetResample.map core.resample.DatasetResample.max @@ -164,8 +160,6 @@ DataArray.any DataArray.argmax DataArray.argmin - DataArray._injected_argmax - DataArray._injected_argmin DataArray.max DataArray.min DataArray.mean @@ -240,8 +234,6 @@ core.resample.DataArrayResample.ffill core.resample.DataArrayResample.fillna core.resample.DataArrayResample.first - core.resample.DataArrayResample._injected_argmax - core.resample.DataArrayResample._injected_argmin core.resample.DataArrayResample.last core.resample.DataArrayResample.map core.resample.DataArrayResample.max @@ -377,8 +369,6 @@ Variable.fillna Variable.get_axis_num Variable.identical - Variable._injected_argmax - Variable._injected_argmin Variable.isel Variable.isnull Variable.item @@ -452,8 +442,6 @@ IndexVariable.get_axis_num IndexVariable.get_level_variable IndexVariable.identical - IndexVariable._injected_argmax - IndexVariable._injected_argmin IndexVariable.isel IndexVariable.isnull IndexVariable.item From a785c3498cf70dec8399d5cc3d6d3b77a437a7fa Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 18:13:37 +0100 Subject: [PATCH 22/44] Basic docstrings for Dataset.argmin() and Dataset.argmax() --- xarray/core/dataset.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 8bf6ff2554c..b0e2aab7ea2 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6295,6 +6295,8 @@ def idxmax( ) def argmin(self, dim=None, axis=None, **kwargs): + """Apply argmin() to each variable in the Dataset + """ if dim is None and axis is None: warnings.warn( "Behaviour of DataArray.argmin() with neither dim nor axis argument " @@ -6320,6 +6322,8 @@ def argmin(self, dim=None, axis=None, **kwargs): ) def argmax(self, dim=None, axis=None, **kwargs): + """Apply argmax() to each variable in the Dataset + """ if dim is None and axis is None: warnings.warn( "Behaviour of DataArray.argmin() with neither dim nor axis argument " From ac897d4167902bb537890d005fe395c0939f1890 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 18:24:41 +0100 Subject: [PATCH 23/44] Set stacklevel for DeprecationWarning in argmin/argmax methods --- xarray/core/dataset.py | 2 ++ xarray/core/variable.py | 1 + 2 files changed, 3 insertions(+) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b0e2aab7ea2..6325599c38b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6304,6 +6304,7 @@ def argmin(self, dim=None, axis=None, **kwargs): "will be an error to call Dataset.argmin() with no argument. To get a " "single, flat index, please use np.argmin(ds) instead of ds.argmin().", DeprecationWarning, + stacklevel=2, ) if ( dim is None @@ -6331,6 +6332,7 @@ def argmax(self, dim=None, axis=None, **kwargs): "will be an error to call Dataset.argmin() with no argument. To get a " "single, flat index, please use np.argmin(ds) instead of ds.argmin().", DeprecationWarning, + stacklevel=2, ) if ( dim is None diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 92dcdb8ea6a..7ec257d4701 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2098,6 +2098,7 @@ def _unravel_argminmax( "single, flat index, please use np.argmin(da.data) or " "np.argmax(da.data) instead of da.argmin() or da.argmax().", DeprecationWarning, + stacklevel=3, ) argminmax_func = getattr(duck_array_ops, argminmax) From 752518e788bca317907f0721a2518955bcc86ea2 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 19:44:55 +0100 Subject: [PATCH 24/44] Revert "Explicitly defined class methods override injected methods" This reverts commit 8caf2b8d07c14a2956a26b50ee08d83323c36058. --- xarray/core/ops.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/xarray/core/ops.py b/xarray/core/ops.py index 57a86bab7b1..b789f93b4f1 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -260,9 +260,6 @@ def inject_reduce_methods(cls): + [("count", duck_array_ops.count, False)] ) for name, f, include_skipna in methods: - if hasattr(cls, name): - name = "_injected_" + name - numeric_only = getattr(f, "numeric_only", False) available_min_count = getattr(f, "available_min_count", False) min_count_docs = _MINCOUNT_DOCSTRING if available_min_count else "" @@ -281,9 +278,6 @@ def inject_reduce_methods(cls): def inject_cum_methods(cls): methods = [(name, getattr(duck_array_ops, name), True) for name in NAN_CUM_METHODS] for name, f, include_skipna in methods: - if hasattr(cls, name): - name = "_injected_" + name - numeric_only = getattr(f, "numeric_only", False) func = cls._reduce_method(f, include_skipna, numeric_only) func.__name__ = name @@ -331,35 +325,24 @@ def inject_all_ops_and_reduce_methods(cls, priority=50, array_only=True): # patch in standard special operations for name in UNARY_OPS: - if hasattr(cls, op_str(name)): - name = "_injected_" + name setattr(cls, op_str(name), cls._unary_op(get_op(name))) inject_binary_ops(cls, inplace=True) # patch in numpy/pandas methods for name in NUMPY_UNARY_METHODS: - if hasattr(cls, op_str(name)): - name = "_injected_" + name setattr(cls, name, cls._unary_op(_method_wrapper(name))) for name in PANDAS_UNARY_FUNCTIONS: f = _func_slash_method_wrapper(getattr(duck_array_ops, name), name=name) - if hasattr(cls, op_str(name)): - name = "_injected_" + name setattr(cls, name, cls._unary_op(f)) f = _func_slash_method_wrapper(duck_array_ops.around, name="round") - if hasattr(cls, "round"): - setattr(cls, "_injected_round", cls._unary_op(f)) - else: - setattr(cls, "round", cls._unary_op(f)) + setattr(cls, "round", cls._unary_op(f)) if array_only: # these methods don't return arrays of the same shape as the input, so # don't try to patch these in for Dataset objects for name in NUMPY_SAME_METHODS: - if hasattr(cls, op_str(name)): - name = "_injected_" + name setattr(cls, name, _values_method_wrapper(name)) inject_reduce_methods(cls) From 8b7365ba25c43f432b85e6e0dbdf66d846d698f2 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 19:48:44 +0100 Subject: [PATCH 25/44] Revert "Add 'out' keyword to argmin/argmax methods - allow numpy call signature" This reverts commit ab480b5c88a059264086260e5090eb38b98aa7fa. --- xarray/core/dataarray.py | 12 ++---------- xarray/core/variable.py | 13 ++----------- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 259bd72ac97..fc6e6831ad7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3728,7 +3728,6 @@ def argmin( axis: Union[int, None] = None, keep_attrs: bool = None, skipna: bool = None, - out=None, ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: """Index or indices of the minimum of the DataArray over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of DataArrays, @@ -3757,9 +3756,6 @@ def argmin( skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). - out : None - 'out' should not be passed - provided for compatibility with numpy function - signature Returns ------- @@ -3820,7 +3816,7 @@ def argmin( array([ 1, -5, 1]) Dimensions without coordinates: y """ - result = self.variable.argmin(dim, axis, keep_attrs, skipna, out) + result = self.variable.argmin(dim, axis, keep_attrs, skipna) if isinstance(result, dict): return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()} else: @@ -3832,7 +3828,6 @@ def argmax( axis: Union[int, None] = None, keep_attrs: bool = None, skipna: bool = None, - out=None, ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: """Index or indices of the maximum of the DataArray over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of DataArrays, @@ -3861,9 +3856,6 @@ def argmax( skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). - out : None - 'out' should not be passed - provided for compatibility with numpy function - signature Returns ------- @@ -3925,7 +3917,7 @@ def argmax( array([3, 5, 3]) Dimensions without coordinates: y """ - result = self.variable.argmax(dim, axis, keep_attrs, skipna, out) + result = self.variable.argmax(dim, axis, keep_attrs, skipna) if isinstance(result, dict): return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()} else: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 7ec257d4701..77310612403 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2086,7 +2086,6 @@ def _unravel_argminmax( axis: Union[int, None], keep_attrs: Optional[bool], skipna: Optional[bool], - out, ) -> Union["Variable", Dict[Hashable, "Variable"]]: """Apply argmin or argmax over one or more dimensions, returning the result as a dict of DataArray that can be passed directly to isel. @@ -2156,7 +2155,6 @@ def argmin( axis: Union[int, None] = None, keep_attrs: bool = None, skipna: bool = None, - out=None, ) -> Union["Variable", Dict[Hashable, "Variable"]]: """Index or indices of the minimum of the Variable over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of Variables, @@ -2185,9 +2183,6 @@ def argmin( skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). - out : None - 'out' should not be passed - provided for compatibility with numpy function - signature Returns ------- @@ -2197,7 +2192,7 @@ def argmin( -------- DataArray.argmin, DataArray.idxmin """ - return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna, out) + return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna) def argmax( self, @@ -2205,7 +2200,6 @@ def argmax( axis: Union[int, None] = None, keep_attrs: bool = None, skipna: bool = None, - out=None, ) -> Union["Variable", Dict[Hashable, "Variable"]]: """Index or indices of the maximum of the Variable over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of Variables, @@ -2234,9 +2228,6 @@ def argmax( skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). - out : None - 'out' should not be passed - provided for compatibility with numpy function - signature Returns ------- @@ -2246,7 +2237,7 @@ def argmax( -------- DataArray.argmax, DataArray.idxmax """ - return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna, out) + return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) ops.inject_all_ops_and_reduce_methods(Variable) From 46b04a66582344eda65cc6e33a8783d65b7951e2 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 20:29:01 +0100 Subject: [PATCH 26/44] Remove argmin and argmax from ops.py --- xarray/core/ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/core/ops.py b/xarray/core/ops.py index b789f93b4f1..d4aeea37aad 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -47,8 +47,6 @@ # methods which remove an axis REDUCE_METHODS = ["all", "any"] NAN_REDUCE_METHODS = [ - "argmax", - "argmin", "max", "min", "mean", From 1ef3c9720c0a5c91363dbc1652fa88665e9d55e8 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 20:29:17 +0100 Subject: [PATCH 27/44] Use self.reduce() in Dataset.argmin() and Dataset.argmax() Replaces need for "_injected_argmin" and "_injected_argmax". --- xarray/core/dataset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6325599c38b..1bbf79f479c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6314,7 +6314,8 @@ def argmin(self, dim=None, axis=None, **kwargs): ): # Return int index if single dimension is passed, and is not part of a # sequence - return getattr(self, "_injected_argmin")(dim=dim, axis=axis, **kwargs) + argmin_func = getattr(duck_array_ops, "argmin") + return self.reduce(argmin_func, dim=dim, axis=axis, **kwargs) else: raise ValueError( "When dim is a sequence, DataArray.argmin() returns a " @@ -6342,7 +6343,8 @@ def argmax(self, dim=None, axis=None, **kwargs): ): # Return int index if single dimension is passed, and is not part of a # sequence - return getattr(self, "_injected_argmax")(dim=dim, axis=axis, **kwargs) + argmax_func = getattr(duck_array_ops, "argmax") + return self.reduce(argmax_func, dim=dim, axis=axis, **kwargs) else: raise ValueError( "When dim is a sequence, DataArray.argmax() returns a " From 65ca2adb31fb346ddc692224c2ae6d5c704da9e5 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 20:50:05 +0100 Subject: [PATCH 28/44] Whitespace after 'title' lines in docstrings --- xarray/core/dataarray.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index fc6e6831ad7..3862bc33689 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3730,6 +3730,7 @@ def argmin( skipna: bool = None, ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: """Index or indices of the minimum of the DataArray over one or more dimensions. + If a sequence is passed to 'dim', then result returned as dict of DataArrays, which can be passed directly to isel(). If a single str is passed to 'dim' then returns an int. @@ -3830,6 +3831,7 @@ def argmax( skipna: bool = None, ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: """Index or indices of the maximum of the DataArray over one or more dimensions. + If a sequence is passed to 'dim', then result returned as dict of DataArrays, which can be passed directly to isel(). If a single str is passed to 'dim' then returns an int. From 1736abf114173e1e36e2a8ad297b32d4017492e8 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 20:50:23 +0100 Subject: [PATCH 29/44] Remove tests of np.argmax() and np.argmin() functions from test_units.py Applying numpy functions to xarray objects is not necessarily expected to work, and the wrapping of argmin() and argmax() is broken by xarray-specific interface of argmin() and argmax() methods of Variable, DataArray and Dataset. --- xarray/tests/test_units.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 2826dc2479c..3d4d7a49c1d 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -2232,8 +2232,20 @@ def test_repr(self, func, variant, dtype): function("any"), marks=pytest.mark.xfail(reason="not implemented by pint yet"), ), - function("argmax"), - function("argmin"), + pytest.param( + function("argmax"), + marks=pytest.mark.skip( + reason="calling np.argmax as a function on xarray objects is not " + "supported" + ) + ), + pytest.param( + function("argmin"), + marks=pytest.mark.skip( + reason="calling np.argmin as a function on xarray objects is not " + "supported" + ) + ), function("max"), function("mean"), pytest.param( @@ -3733,8 +3745,20 @@ def test_repr(self, func, variant, dtype): function("any"), marks=pytest.mark.xfail(reason="not implemented by pint"), ), - function("argmax"), - function("argmin"), + pytest.param( + function("argmax"), + marks=pytest.mark.skip( + reason="calling np.argmax as a function on xarray objects is not " + "supported" + ) + ), + pytest.param( + function("argmin"), + marks=pytest.mark.skip( + reason="calling np.argmin as a function on xarray objects is not " + "supported" + ) + ), function("max"), function("min"), function("mean"), From d9b55ee782de8a43663f0e4789e4874e34d21e01 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 18:13:57 +0100 Subject: [PATCH 30/44] Clearer deprecation warnings in Dataset.argmin() and Dataset.argmax() Also, previously suggested workaround was not correct. Remove suggestion as there is no workaround (but the removed behaviour is unlikely to be useful). --- xarray/core/dataset.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 1bbf79f479c..5a64b4aeb8a 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6299,10 +6299,11 @@ def argmin(self, dim=None, axis=None, **kwargs): """ if dim is None and axis is None: warnings.warn( - "Behaviour of DataArray.argmin() with neither dim nor axis argument " - "will change to return a dict of indices of each dimension, and then it " - "will be an error to call Dataset.argmin() with no argument. To get a " - "single, flat index, please use np.argmin(ds) instead of ds.argmin().", + "Once the behaviour of DataArray.argmin() and Variable.argmin() with " + "neither dim nor axis argument changes to return a dict of indices of " + "each dimension, for consistency it will be an error to call " + "Dataset.argmin() with no argument, since we don't return a dict of " + "Datasets.", DeprecationWarning, stacklevel=2, ) @@ -6328,10 +6329,11 @@ def argmax(self, dim=None, axis=None, **kwargs): """ if dim is None and axis is None: warnings.warn( - "Behaviour of DataArray.argmin() with neither dim nor axis argument " - "will change to return a dict of indices of each dimension, and then it " - "will be an error to call Dataset.argmin() with no argument. To get a " - "single, flat index, please use np.argmin(ds) instead of ds.argmin().", + "Once the behaviour of DataArray.argmax() and Variable.argmax() with " + "neither dim nor axis argument changes to return a dict of indices of " + "each dimension, for consistency it will be an error to call " + "Dataset.argmax() with no argument, since we don't return a dict of " + "Datasets.", DeprecationWarning, stacklevel=2, ) From 432dfbb0a97aa367e774dce2a0274521b7dc3a47 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 21:12:42 +0100 Subject: [PATCH 31/44] Add unravel_index to duck_array_ops, use in Variable._unravel_argminmax --- xarray/core/duck_array_ops.py | 1 + xarray/core/variable.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 1340b456cf2..9c4020a15b8 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -339,6 +339,7 @@ def f(values, axis=None, skipna=None, **kwargs): cumprod_1d.numeric_only = True cumsum_1d = _create_nan_agg_method("cumsum") cumsum_1d.numeric_only = True +unravel_index = _dask_or_eager_func("unravel_index") _mean = _create_nan_agg_method("mean") diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 77310612403..807b2ebd682 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2132,7 +2132,7 @@ def _unravel_argminmax( result_flat_indices = stacked.reduce(argminmax_func, axis=-1, skipna=skipna) - result_unravelled_indices = np.unravel_index( + result_unravelled_indices = duck_array_ops.unravel_index( result_flat_indices.data, reduce_shape ) From 20b448ae7989c12e986c30d4e0650edb477dddf4 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 22:25:50 +0100 Subject: [PATCH 32/44] Filter argmin/argmax DeprecationWarnings in tests --- xarray/tests/test_dataarray.py | 18 ++++++++++++++++++ xarray/tests/test_dataset.py | 6 ++++++ xarray/tests/test_units.py | 28 ++++++++++++++++++++-------- xarray/tests/test_variable.py | 3 +++ 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 0d5624afdc9..82b816ce70a 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4464,6 +4464,9 @@ def test_max(self, x, minindex, maxindex, nanindex): assert_identical(result2, expected2) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) def test_argmin(self, x, minindex, maxindex, nanindex): ar = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs @@ -4493,6 +4496,9 @@ def test_argmin(self, x, minindex, maxindex, nanindex): assert_identical(result2, expected2) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) def test_argmax(self, x, minindex, maxindex, nanindex): ar = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs @@ -4714,6 +4720,9 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): result7 = ar0.idxmax(fill_value=-1j) assert_identical(result7, expected7) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) def test_argmin_dim(self, x, minindex, maxindex, nanindex): ar = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs @@ -4747,6 +4756,9 @@ def test_argmin_dim(self, x, minindex, maxindex, nanindex): for key in expected2: assert_identical(result2[key], expected2[key]) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) def test_argmax_dim(self, x, minindex, maxindex, nanindex): ar = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs @@ -5223,6 +5235,9 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): result7 = ar0.idxmax(dim="x", fill_value=-5j) assert_identical(result7, expected7) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) def test_argmin_dim(self, x, minindex, maxindex, nanindex): ar = xr.DataArray( x, @@ -5280,6 +5295,9 @@ def test_argmin_dim(self, x, minindex, maxindex, nanindex): for key in expected3: assert_identical(result3[key], expected3[key]) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) def test_argmax_dim(self, x, minindex, maxindex, nanindex): ar = xr.DataArray( x, diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index a1cb7361e77..499a5fd662a 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4596,6 +4596,9 @@ def test_reduce_non_numeric(self): assert_equal(data1.mean(), data2.mean()) assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1")) + @pytest.mark.filterwarnings( + "ignore:Once the behaviour of DataArray:DeprecationWarning" + ) def test_reduce_strings(self): expected = Dataset({"x": "a"}) ds = Dataset({"x": ("y", ["a", "b"])}) @@ -4667,6 +4670,9 @@ def test_reduce_keep_attrs(self): for k, v in ds.data_vars.items(): assert v.attrs == data[k].attrs + @pytest.mark.filterwarnings( + "ignore:Once the behaviour of DataArray:DeprecationWarning" + ) def test_reduce_argmin(self): # regression test for #205 ds = Dataset({"a": ("x", [0, 1])}) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 3d4d7a49c1d..b3b836d365d 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1437,6 +1437,9 @@ def example_1d_objects(self): ), ids=repr, ) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) def test_aggregation(self, func, dtype): array = np.linspace(0, 1, 10).astype(dtype) * ( unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless @@ -2236,15 +2239,15 @@ def test_repr(self, func, variant, dtype): function("argmax"), marks=pytest.mark.skip( reason="calling np.argmax as a function on xarray objects is not " - "supported" - ) + "supported" + ), ), pytest.param( function("argmin"), marks=pytest.mark.skip( reason="calling np.argmin as a function on xarray objects is not " - "supported" - ) + "supported" + ), ), function("max"), function("mean"), @@ -2296,6 +2299,9 @@ def test_repr(self, func, variant, dtype): ), ids=repr, ) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) def test_aggregation(self, func, dtype): array = np.arange(10).astype(dtype) * ( unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless @@ -3749,15 +3755,15 @@ def test_repr(self, func, variant, dtype): function("argmax"), marks=pytest.mark.skip( reason="calling np.argmax as a function on xarray objects is not " - "supported" - ) + "supported" + ), ), pytest.param( function("argmin"), marks=pytest.mark.skip( reason="calling np.argmin as a function on xarray objects is not " - "supported" - ) + "supported" + ), ), function("max"), function("min"), @@ -3806,6 +3812,12 @@ def test_repr(self, func, variant, dtype): ), ids=repr, ) + @pytest.mark.filterwarnings( + "ignore:Once the behaviour of DataArray:DeprecationWarning" + ) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) def test_aggregation(self, func, dtype): unit_a = ( unit_registry.Pa if func.name != "cumprod" else unit_registry.dimensionless diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 78e3848b8fb..35aab69c955 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1639,6 +1639,9 @@ def test_big_endian_reduce(self): expected = Variable([], 5) assert_identical(expected, v.sum()) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) def test_reduce_funcs(self): v = Variable("x", np.array([1, np.nan, 2, 3])) assert_identical(v.mean(), Variable([], 2)) From 95845f9ad44822673c63a06177434a20fb7a2142 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 22:26:14 +0100 Subject: [PATCH 33/44] Correct test for exception for nan in test_argmax --- xarray/tests/test_dataarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 82b816ce70a..f5811616eee 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -4765,9 +4765,9 @@ def test_argmax_dim(self, x, minindex, maxindex, nanindex): ) indarr = xr.DataArray(np.arange(x.size, dtype=np.intp), dims=["x"]) - if np.isnan(minindex): + if np.isnan(maxindex): with pytest.raises(ValueError): - ar.argmin() + ar.argmax() return expected0 = {"x": indarr[maxindex]} From 0ee5146521bffaab438c32644cf170e6a7846d6c Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 21 Apr 2020 23:55:44 +0100 Subject: [PATCH 34/44] Remove injected argmin and argmax methods from api-hidden.rst --- doc/api-hidden.rst | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index cc9517a98ba..e7de00b4081 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -41,8 +41,6 @@ core.rolling.DatasetCoarsen.all core.rolling.DatasetCoarsen.any - core.rolling.DatasetCoarsen.argmax - core.rolling.DatasetCoarsen.argmin core.rolling.DatasetCoarsen.count core.rolling.DatasetCoarsen.max core.rolling.DatasetCoarsen.mean @@ -68,8 +66,6 @@ core.groupby.DatasetGroupBy.where core.groupby.DatasetGroupBy.all core.groupby.DatasetGroupBy.any - core.groupby.DatasetGroupBy.argmax - core.groupby.DatasetGroupBy.argmin core.groupby.DatasetGroupBy.count core.groupby.DatasetGroupBy.max core.groupby.DatasetGroupBy.mean @@ -85,8 +81,6 @@ core.resample.DatasetResample.all core.resample.DatasetResample.any core.resample.DatasetResample.apply - core.resample.DatasetResample.argmax - core.resample.DatasetResample.argmin core.resample.DatasetResample.assign core.resample.DatasetResample.assign_coords core.resample.DatasetResample.bfill @@ -110,8 +104,6 @@ core.resample.DatasetResample.dims core.resample.DatasetResample.groups - core.rolling.DatasetRolling.argmax - core.rolling.DatasetRolling.argmin core.rolling.DatasetRolling.count core.rolling.DatasetRolling.max core.rolling.DatasetRolling.mean @@ -183,8 +175,6 @@ core.rolling.DataArrayCoarsen.all core.rolling.DataArrayCoarsen.any - core.rolling.DataArrayCoarsen.argmax - core.rolling.DataArrayCoarsen.argmin core.rolling.DataArrayCoarsen.count core.rolling.DataArrayCoarsen.max core.rolling.DataArrayCoarsen.mean @@ -209,8 +199,6 @@ core.groupby.DataArrayGroupBy.where core.groupby.DataArrayGroupBy.all core.groupby.DataArrayGroupBy.any - core.groupby.DataArrayGroupBy.argmax - core.groupby.DataArrayGroupBy.argmin core.groupby.DataArrayGroupBy.count core.groupby.DataArrayGroupBy.max core.groupby.DataArrayGroupBy.mean @@ -226,8 +214,6 @@ core.resample.DataArrayResample.all core.resample.DataArrayResample.any core.resample.DataArrayResample.apply - core.resample.DataArrayResample.argmax - core.resample.DataArrayResample.argmin core.resample.DataArrayResample.assign_coords core.resample.DataArrayResample.bfill core.resample.DataArrayResample.count @@ -250,8 +236,6 @@ core.resample.DataArrayResample.dims core.resample.DataArrayResample.groups - core.rolling.DataArrayRolling.argmax - core.rolling.DataArrayRolling.argmin core.rolling.DataArrayRolling.count core.rolling.DataArrayRolling.max core.rolling.DataArrayRolling.mean @@ -421,8 +405,6 @@ IndexVariable.all IndexVariable.any - IndexVariable.argmax - IndexVariable.argmin IndexVariable.argsort IndexVariable.astype IndexVariable.broadcast_equals @@ -562,8 +544,6 @@ CFTimeIndex.all CFTimeIndex.any CFTimeIndex.append - CFTimeIndex.argmax - CFTimeIndex.argmin CFTimeIndex.argsort CFTimeIndex.asof CFTimeIndex.asof_locs From d02918389fb28e9995dfa6f1281e883c82c036a0 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Wed, 24 Jun 2020 21:07:11 +0100 Subject: [PATCH 35/44] flake8 fixes --- xarray/tests/test_dataarray.py | 84 ++++++++++++++++++++++------------ 1 file changed, 56 insertions(+), 28 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index e24ac45dfb3..acba66f4e80 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -5739,8 +5739,10 @@ def test_argmin_dim( minindices_x = { key: xr.where( - nanindices_x[key] == None, minindices_x[key], nanindices_x[key], - ) # noqa: E711 + nanindices_x[key] == None, # noqa: E711 + minindices_x[key], + nanindices_x[key], + ) for key in minindices_x } expected7 = { @@ -5754,8 +5756,10 @@ def test_argmin_dim( minindices_y = { key: xr.where( - nanindices_y[key] == None, minindices_y[key], nanindices_y[key], - ) # noqa: E711 + nanindices_y[key] == None, # noqa: E711 + minindices_y[key], + nanindices_y[key], + ) for key in minindices_y } expected8 = { @@ -5769,8 +5773,10 @@ def test_argmin_dim( minindices_z = { key: xr.where( - nanindices_z[key] == None, minindices_z[key], nanindices_z[key], - ) # noqa: E711 + nanindices_z[key] == None, # noqa: E711 + minindices_z[key], + nanindices_z[key], + ) for key in minindices_z } expected9 = { @@ -5784,8 +5790,10 @@ def test_argmin_dim( minindices_xy = { key: xr.where( - nanindices_xy[key] == None, minindices_xy[key], nanindices_xy[key], - ) # noqa: E711 + nanindices_xy[key] == None, # noqa: E711 + minindices_xy[key], + nanindices_xy[key], + ) for key in minindices_xy } expected10 = { @@ -5798,8 +5806,10 @@ def test_argmin_dim( minindices_xz = { key: xr.where( - nanindices_xz[key] == None, minindices_xz[key], nanindices_xz[key], - ) # noqa: E711 + nanindices_xz[key] == None, # noqa: E711 + minindices_xz[key], + nanindices_xz[key], + ) for key in minindices_xz } expected11 = { @@ -5812,8 +5822,10 @@ def test_argmin_dim( minindices_yz = { key: xr.where( - nanindices_yz[key] == None, minindices_yz[key], nanindices_yz[key], - ) # noqa: E711 + nanindices_yz[key] == None, # noqa: E711 + minindices_yz[key], + nanindices_yz[key], + ) for key in minindices_yz } expected12 = { @@ -5826,8 +5838,10 @@ def test_argmin_dim( minindices_xyz = { key: xr.where( - nanindices_xyz[key] == None, minindices_xyz[key], nanindices_xyz[key], - ) # noqa: E711 + nanindices_xyz[key] == None, # noqa: E711 + minindices_xyz[key], + nanindices_xyz[key], + ) for key in minindices_xyz } expected13 = {key: xr.DataArray(value) for key, value in minindices_xyz.items()} @@ -5954,8 +5968,10 @@ def test_argmax_dim( maxindices_x = { key: xr.where( - nanindices_x[key] == None, maxindices_x[key], nanindices_x[key], - ) # noqa: E711 + nanindices_x[key] == None, # noqa: E711 + maxindices_x[key], + nanindices_x[key], + ) for key in maxindices_x } expected7 = { @@ -5969,8 +5985,10 @@ def test_argmax_dim( maxindices_y = { key: xr.where( - nanindices_y[key] == None, maxindices_y[key], nanindices_y[key], - ) # noqa: E711 + nanindices_y[key] == None, # noqa: E711 + maxindices_y[key], + nanindices_y[key], + ) for key in maxindices_y } expected8 = { @@ -5984,8 +6002,10 @@ def test_argmax_dim( maxindices_z = { key: xr.where( - nanindices_z[key] == None, maxindices_z[key], nanindices_z[key], - ) # noqa: E711 + nanindices_z[key] == None, # noqa: E711 + maxindices_z[key], + nanindices_z[key], + ) for key in maxindices_z } expected9 = { @@ -5999,8 +6019,10 @@ def test_argmax_dim( maxindices_xy = { key: xr.where( - nanindices_xy[key] == None, maxindices_xy[key], nanindices_xy[key], - ) # noqa: E711 + nanindices_xy[key] == None, # noqa: E711 + maxindices_xy[key], + nanindices_xy[key], + ) for key in maxindices_xy } expected10 = { @@ -6013,8 +6035,10 @@ def test_argmax_dim( maxindices_xz = { key: xr.where( - nanindices_xz[key] == None, maxindices_xz[key], nanindices_xz[key], - ) # noqa: E711 + nanindices_xz[key] == None, # noqa: E711 + maxindices_xz[key], + nanindices_xz[key], + ) for key in maxindices_xz } expected11 = { @@ -6027,8 +6051,10 @@ def test_argmax_dim( maxindices_yz = { key: xr.where( - nanindices_yz[key] == None, maxindices_yz[key], nanindices_yz[key], - ) # noqa: E711 + nanindices_yz[key] == None, # noqa: E711 + maxindices_yz[key], + nanindices_yz[key], + ) for key in maxindices_yz } expected12 = { @@ -6041,8 +6067,10 @@ def test_argmax_dim( maxindices_xyz = { key: xr.where( - nanindices_xyz[key] == None, maxindices_xyz[key], nanindices_xyz[key], - ) # noqa: E711 + nanindices_xyz[key] == None, # noqa: E711 + maxindices_xyz[key], + nanindices_xyz[key], + ) for key in maxindices_xyz } expected13 = {key: xr.DataArray(value) for key, value in maxindices_xyz.items()} From a758b0f3b5c344b9a25470072468677a3b5e5bf7 Mon Sep 17 00:00:00 2001 From: johnomotani Date: Fri, 26 Jun 2020 11:43:02 +0100 Subject: [PATCH 36/44] Tidy up argmin/argmax following code review Co-authored-by: Deepak Cherian --- xarray/core/dataarray.py | 8 ++++---- xarray/core/variable.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c8b02360227..0ce76a5e23a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3822,7 +3822,7 @@ def idxmax( def argmin( self, dim: Union[Hashable, Sequence[Hashable]] = None, - axis: Union[int, None] = None, + axis: int = None, keep_attrs: bool = None, skipna: bool = None, ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: @@ -3830,7 +3830,7 @@ def argmin( If a sequence is passed to 'dim', then result returned as dict of DataArrays, which can be passed directly to isel(). If a single str is passed to 'dim' then - returns an int. + returns a DataArray with dtype int. If there are multiple minima, the indices of the first one found will be returned. @@ -3923,7 +3923,7 @@ def argmin( def argmax( self, dim: Union[Hashable, Sequence[Hashable]] = None, - axis: Union[int, None] = None, + axis: int = None, keep_attrs: bool = None, skipna: bool = None, ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: @@ -3931,7 +3931,7 @@ def argmax( If a sequence is passed to 'dim', then result returned as dict of DataArrays, which can be passed directly to isel(). If a single str is passed to 'dim' then - returns an int. + returns a DataArray with dtype int. If there are multiple maxima, the indices of the first one found will be returned. diff --git a/xarray/core/variable.py b/xarray/core/variable.py index bb702141196..c505c749557 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2152,14 +2152,14 @@ def _unravel_argminmax( def argmin( self, dim: Union[Hashable, Sequence[Hashable]] = None, - axis: Union[int, None] = None, + axis: int = None, keep_attrs: bool = None, skipna: bool = None, ) -> Union["Variable", Dict[Hashable, "Variable"]]: """Index or indices of the minimum of the Variable over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of Variables, which can be passed directly to isel(). If a single str is passed to 'dim' then - returns an int. + returns a Variable with dtype int. If there are multiple minima, the indices of the first one found will be returned. @@ -2197,14 +2197,14 @@ def argmin( def argmax( self, dim: Union[Hashable, Sequence[Hashable]] = None, - axis: Union[int, None] = None, + axis: int = None, keep_attrs: bool = None, skipna: bool = None, ) -> Union["Variable", Dict[Hashable, "Variable"]]: """Index or indices of the maximum of the Variable over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of Variables, which can be passed directly to isel(). If a single str is passed to 'dim' then - returns an int. + returns a Variable with dtype int. If there are multiple maxima, the indices of the first one found will be returned. From 9a54e0ca753f312ba1c8908fcde96d943161035e Mon Sep 17 00:00:00 2001 From: John Omotani Date: Fri, 26 Jun 2020 11:59:41 +0100 Subject: [PATCH 37/44] Remove filters for warnings from argmin/argmax from tests Pass an explicit axis or dim argument instead to avoid the warning. --- xarray/tests/test_units.py | 24 ++++++------------------ xarray/tests/test_variable.py | 5 +---- 2 files changed, 7 insertions(+), 22 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 881fae48e43..fbcfa38ecf7 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1408,8 +1408,8 @@ def test_real_and_imag(self): ( method("all"), method("any"), - method("argmax"), - method("argmin"), + method("argmax", axis=0), + method("argmin", axis=0), method("argsort"), method("cumprod"), method("cumsum"), @@ -1427,9 +1427,6 @@ def test_real_and_imag(self): ), ids=repr, ) - @pytest.mark.filterwarnings( - "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" - ) def test_aggregation(self, func, dtype): array = np.linspace(0, 1, 10).astype(dtype) * ( unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless @@ -2280,8 +2277,8 @@ def test_repr(self, func, variant, dtype): function("cumprod"), method("all"), method("any"), - method("argmax"), - method("argmin"), + method("argmax", axis=0), + method("argmin", axis=0), method("max"), method("mean"), method("median"), @@ -2298,9 +2295,6 @@ def test_repr(self, func, variant, dtype): ), ids=repr, ) - @pytest.mark.filterwarnings( - "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" - ) def test_aggregation(self, func, dtype): array = np.arange(10).astype(dtype) * ( unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless @@ -3853,8 +3847,8 @@ def test_repr(self, func, variant, dtype): function("cumprod"), method("all"), method("any"), - method("argmax"), - method("argmin"), + method("argmax", axis=0), + method("argmin", axis=0), method("max"), method("min"), method("mean"), @@ -3871,12 +3865,6 @@ def test_repr(self, func, variant, dtype): ), ids=repr, ) - @pytest.mark.filterwarnings( - "ignore:Once the behaviour of DataArray:DeprecationWarning" - ) - @pytest.mark.filterwarnings( - "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" - ) def test_aggregation(self, func, dtype): unit_a, unit_b = ( (unit_registry.Pa, unit_registry.degK) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index aad55af25bc..d79d40d67c0 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1639,9 +1639,6 @@ def test_big_endian_reduce(self): expected = Variable([], 5) assert_identical(expected, v.sum()) - @pytest.mark.filterwarnings( - "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" - ) def test_reduce_funcs(self): v = Variable("x", np.array([1, np.nan, 2, 3])) assert_identical(v.mean(), Variable([], 2)) @@ -1660,7 +1657,7 @@ def test_reduce_funcs(self): assert_identical(v.all(dim="x"), Variable([], False)) v = Variable("t", pd.date_range("2000-01-01", periods=3)) - assert v.argmax(skipna=True) == 2 + assert v.argmax(skipna=True, dim="t") == 2 assert_identical(v.max(), Variable([], pd.Timestamp("2000-01-03"))) From a07ce295b0397095fe597a55f88c649e03d38504 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Fri, 26 Jun 2020 12:05:48 +0100 Subject: [PATCH 38/44] Swap order of reduce_dims checks in Dataset.reduce() Prefer to pass reduce_dims=None when possible, including for variables with only one dimension. Avoids an error if an 'axis' keyword was passed. --- xarray/core/dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 26a2ec58ec7..8f60d04cc97 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4295,15 +4295,15 @@ def reduce( or np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_) ): - if len(reduce_dims) == 1: - # unpack dimensions for the benefit of functions - # like np.argmin which can't handle tuple arguments - (reduce_dims,) = reduce_dims - elif len(reduce_dims) == var.ndim: + if len(reduce_dims) == var.ndim: # prefer to aggregate over axis=None rather than # axis=(0, 1) if they will be equivalent, because # the former is often more efficient reduce_dims = None # type: ignore + elif len(reduce_dims) == 1: + # unpack dimensions for the benefit of functions + # like np.argmin which can't handle tuple arguments + (reduce_dims,) = reduce_dims variables[name] = var.reduce( func, dim=reduce_dims, From d77fe1157b88c54b1fa7199e6d7f0804a2139543 Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 26 Jun 2020 16:05:15 +0200 Subject: [PATCH 39/44] revert the changes to Dataset.reduce --- xarray/core/dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 8f60d04cc97..26a2ec58ec7 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4295,15 +4295,15 @@ def reduce( or np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_) ): - if len(reduce_dims) == var.ndim: + if len(reduce_dims) == 1: + # unpack dimensions for the benefit of functions + # like np.argmin which can't handle tuple arguments + (reduce_dims,) = reduce_dims + elif len(reduce_dims) == var.ndim: # prefer to aggregate over axis=None rather than # axis=(0, 1) if they will be equivalent, because # the former is often more efficient reduce_dims = None # type: ignore - elif len(reduce_dims) == 1: - # unpack dimensions for the benefit of functions - # like np.argmin which can't handle tuple arguments - (reduce_dims,) = reduce_dims variables[name] = var.reduce( func, dim=reduce_dims, From 308bb239b0b4649951476a2e7989d62ddb51c155 Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 26 Jun 2020 16:27:37 +0200 Subject: [PATCH 40/44] use dim instead of axis --- xarray/tests/test_units.py | 51 +++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index fbcfa38ecf7..b9e764a4969 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -297,19 +297,29 @@ def __call__(self, obj, *args, **kwargs): all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} + xarray_classes = ( + xr.Variable, + xr.DataArray, + xr.Dataset, + xr.core.groupby.GroupBy, + ) + + if not isinstance(obj, xarray_classes): + # remove typical xarray args like "dim" + exclude_kwargs = ("dim", "dims") + all_kwargs = { + key: value + for key, value in all_kwargs.items() + if key not in exclude_kwargs + } + func = getattr(obj, self.name, None) + if func is None or not isinstance(func, Callable): # fall back to module level numpy functions if not a xarray object if not isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): numpy_func = getattr(np, self.name) func = partial(numpy_func, obj) - # remove typical xarray args like "dim" - exclude_kwargs = ("dim", "dims") - all_kwargs = { - key: value - for key, value in all_kwargs.items() - if key not in exclude_kwargs - } else: raise AttributeError(f"{obj} has no method named '{self.name}'") @@ -1408,8 +1418,8 @@ def test_real_and_imag(self): ( method("all"), method("any"), - method("argmax", axis=0), - method("argmin", axis=0), + method("argmax", dim="x"), + method("argmin", dim="x"), method("argsort"), method("cumprod"), method("cumsum"), @@ -1433,7 +1443,11 @@ def test_aggregation(self, func, dtype): ) variable = xr.Variable("x", array) - units = extract_units(func(array)) + numpy_kwargs = func.kwargs.copy() + if "dim" in func.kwargs: + numpy_kwargs["axis"] = variable.get_axis_num(numpy_kwargs.pop("dim")) + + units = extract_units(func(array, **numpy_kwargs)) expected = attach_units(func(strip_units(variable)), units) actual = func(variable) @@ -2277,8 +2291,8 @@ def test_repr(self, func, variant, dtype): function("cumprod"), method("all"), method("any"), - method("argmax", axis=0), - method("argmin", axis=0), + method("argmax", dim="x"), + method("argmin", dim="x"), method("max"), method("mean"), method("median"), @@ -2301,6 +2315,10 @@ def test_aggregation(self, func, dtype): ) data_array = xr.DataArray(data=array, dims="x") + numpy_kwargs = func.kwargs.copy() + if "dim" in numpy_kwargs: + numpy_kwargs["axis"] = data_array.get_axis_num(numpy_kwargs.pop("dim")) + # units differ based on the applied function, so we need to # first compute the units units = extract_units(func(array)) @@ -3847,8 +3865,8 @@ def test_repr(self, func, variant, dtype): function("cumprod"), method("all"), method("any"), - method("argmax", axis=0), - method("argmin", axis=0), + method("argmax", dim=...), + method("argmin", dim=...), method("max"), method("min"), method("mean"), @@ -3877,6 +3895,11 @@ def test_aggregation(self, func, dtype): ds = xr.Dataset({"a": ("x", a), "b": ("x", b)}) + numpy_kwargs = func.kwargs.copy() + if "dim" in numpy_kwargs: + # can't translate dimension to axis: no guaranteed order + numpy_kwargs["axis"] = None + units_a = array_extract_units(func(a)) units_b = array_extract_units(func(b)) units = {"a": units_a, "b": units_b} From 1b53f49a54d0269cc5eba0ca96defe8ec7de9af2 Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 26 Jun 2020 19:12:51 +0200 Subject: [PATCH 41/44] use dimension instead of Ellipsis --- xarray/tests/test_units.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index b9e764a4969..20a5f0e8613 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -3865,8 +3865,8 @@ def test_repr(self, func, variant, dtype): function("cumprod"), method("all"), method("any"), - method("argmax", dim=...), - method("argmin", dim=...), + method("argmax", dim="x"), + method("argmin", dim="x"), method("max"), method("min"), method("mean"), @@ -3895,13 +3895,23 @@ def test_aggregation(self, func, dtype): ds = xr.Dataset({"a": ("x", a), "b": ("x", b)}) - numpy_kwargs = func.kwargs.copy() - if "dim" in numpy_kwargs: - # can't translate dimension to axis: no guaranteed order - numpy_kwargs["axis"] = None + if "dim" in func.kwargs: + numpy_kwargs = func.kwargs.copy() + dim = numpy_kwargs.pop("dim") - units_a = array_extract_units(func(a)) - units_b = array_extract_units(func(b)) + axis_a = ds.a.get_axis_num(dim) + axis_b = ds.b.get_axis_num(dim) + + numpy_kwargs_a = numpy_kwargs.copy() + numpy_kwargs_a["axis"] = axis_a + numpy_kwargs_b = numpy_kwargs.copy() + numpy_kwargs_b["axis"] = axis_b + else: + numpy_kwargs_a = {} + numpy_kwargs_b = {} + + units_a = array_extract_units(func(a, **numpy_kwargs_a)) + units_b = array_extract_units(func(b, **numpy_kwargs_b)) units = {"a": units_a, "b": units_b} actual = func(ds) From 5f80205d617f00db81829cddd3ccc3883ce74af7 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Fri, 26 Jun 2020 17:32:48 +0100 Subject: [PATCH 42/44] Make passing 'dim=...' to Dataset.argmin() or Dataset.argmax() an error --- xarray/core/dataset.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 26a2ec58ec7..cf0d8b444f5 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6384,7 +6384,7 @@ def argmin(self, dim=None, axis=None, **kwargs): if ( dim is None or axis is not None - or not isinstance(dim, Sequence) + or (not isinstance(dim, Sequence) and dim is not ...) or isinstance(dim, str) ): # Return int index if single dimension is passed, and is not part of a @@ -6393,9 +6393,9 @@ def argmin(self, dim=None, axis=None, **kwargs): return self.reduce(argmin_func, dim=dim, axis=axis, **kwargs) else: raise ValueError( - "When dim is a sequence, DataArray.argmin() returns a " - "dict. dicts cannot be contained in a Dataset, so cannot " - "call Dataset.argmin() with a sequence for dim" + "When dim is a sequence or ..., DataArray.argmin() returns a dict. " + "dicts cannot be contained in a Dataset, so cannot call " + "Dataset.argmin() with a sequence or ... for dim" ) def argmax(self, dim=None, axis=None, **kwargs): @@ -6414,7 +6414,7 @@ def argmax(self, dim=None, axis=None, **kwargs): if ( dim is None or axis is not None - or not isinstance(dim, Sequence) + or (not isinstance(dim, Sequence) and dim is not ...) or isinstance(dim, str) ): # Return int index if single dimension is passed, and is not part of a @@ -6423,9 +6423,9 @@ def argmax(self, dim=None, axis=None, **kwargs): return self.reduce(argmax_func, dim=dim, axis=axis, **kwargs) else: raise ValueError( - "When dim is a sequence, DataArray.argmax() returns a " - "dict. dicts cannot be contained in a Dataset, so cannot " - "call Dataset.argmax() with a sequence for dim" + "When dim is a sequence or ..., DataArray.argmin() returns a dict. " + "dicts cannot be contained in a Dataset, so cannot call " + "Dataset.argmin() with a sequence or ... for dim" ) From 540c281aa0912e92b5f6ed46e560de7e6e0f0f9c Mon Sep 17 00:00:00 2001 From: John Omotani Date: Fri, 26 Jun 2020 17:38:01 +0100 Subject: [PATCH 43/44] Better docstrings for Dataset.argmin() and Dataset.argmax() --- xarray/core/dataset.py | 74 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 70 insertions(+), 4 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index cf0d8b444f5..b46b1d6dce0 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6369,8 +6369,41 @@ def idxmax( ) def argmin(self, dim=None, axis=None, **kwargs): - """Apply argmin() to each variable in the Dataset - """ + """Indices of the minima of the member variables. + + If there are multiple minima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : str, optional + The dimension over which to find the minimum. By default, finds minimum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will be an error, since DataArray.argmin will + return a dict with indices for all dimensions, which does not make sense for + a Dataset. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : Dataset + + See also + -------- + DataArray.argmin + + """ if dim is None and axis is None: warnings.warn( "Once the behaviour of DataArray.argmin() and Variable.argmin() with " @@ -6399,8 +6432,41 @@ def argmin(self, dim=None, axis=None, **kwargs): ) def argmax(self, dim=None, axis=None, **kwargs): - """Apply argmax() to each variable in the Dataset - """ + """Indices of the maxima of the member variables. + + If there are multiple maxima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : str, optional + The dimension over which to find the maximum. By default, finds maximum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will be an error, since DataArray.argmax will + return a dict with indices for all dimensions, which does not make sense for + a Dataset. + axis : int, optional + Axis over which to apply `argmax`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : Dataset + + See also + -------- + DataArray.argmax + + """ if dim is None and axis is None: warnings.warn( "Once the behaviour of DataArray.argmax() and Variable.argmax() with " From 4aca9d9fcf5d2fbd4cecd951f5a99791ad5313ea Mon Sep 17 00:00:00 2001 From: johnomotani Date: Sat, 27 Jun 2020 18:03:40 +0100 Subject: [PATCH 44/44] Update doc/whats-new.rst Co-authored-by: keewis --- doc/whats-new.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8f1aeaebc9a..27384c43a80 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -59,7 +59,6 @@ New Features (:pull:`3936`) By `John Omotani `_, thanks to `Keisuke Fujii `_ for work in :pull:`1469`. -- Added :py:meth:`DataArray.polyfit` and :py:func:`xarray.polyval` for fitting polynomials. (:issue:`3349`) - Added :py:meth:`xarray.infer_freq` for extending frequency inferring to CFTime indexes and data (:pull:`4033`). By `Pascal Bourgault `_. - ``chunks='auto'`` is now supported in the ``chunks`` argument of