Skip to content

Commit

Permalink
Match ufunc.reduce behavior. (pydata#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Feb 23, 2018
1 parent b75259d commit 302a83c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
'sphinx',
'sphinxcontrib-napoleon',
'sphinx_rtd_theme',
'numpydoc',
],
},
zip_safe=False
Expand Down
14 changes: 10 additions & 4 deletions sparse/coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ def nanreduce(self, method, identity=None, axis=None, keepdims=False, **kwargs):
arr = _replace_nan(self, method.identity if identity is None else identity)
return arr.reduce(method, axis, keepdims, **kwargs)

def reduce(self, method, axis=None, keepdims=False, **kwargs):
def reduce(self, method, axis=(0,), keepdims=False, **kwargs):
"""
Performs a reduction operation on this array.
Expand Down Expand Up @@ -692,10 +692,10 @@ def reduce(self, method, axis=None, keepdims=False, **kwargs):
>>> s4.dtype
dtype('float16')
By default, this reduces the array down to one number, reducing along all axes.
By default, this reduces the array by only the first axis.
>>> s.reduce(np.add)
25
<COO: shape=(5,), dtype=int64, nnz=5, sorted=True, duplicates=False>
"""
zero_reduce_result = method.reduce([_zero_of_dtype(self.dtype)], **kwargs)

Expand Down Expand Up @@ -730,10 +730,16 @@ def reduce(self, method, axis=None, keepdims=False, **kwargs):
result[missing_counts] = method(result[missing_counts],
_zero_of_dtype(self.dtype), **kwargs)
coords = a.coords[0:1, inv_idx]

# Filter out zeros
mask = result != _zero_of_dtype(result.dtype)
coords = coords[:, mask]
result = result[mask]

a = COO(coords, result, shape=(a.shape[0],),
has_duplicates=False, sorted=True)

a = a.reshape([self.shape[d] for d in neg_axis])
a = a.reshape(tuple(self.shape[d] for d in neg_axis))
result = a

if keepdims:
Expand Down
20 changes: 18 additions & 2 deletions sparse/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_reductions(reduction, axis, keepdims, kwargs, eqkwargs):
y = x.todense()
xx = getattr(x, reduction)(axis=axis, keepdims=keepdims, **kwargs)
yy = getattr(y, reduction)(axis=axis, keepdims=keepdims, **kwargs)
assert_eq(xx, yy, check_nnz=False, **eqkwargs)
assert_eq(xx, yy, **eqkwargs)


@pytest.mark.parametrize('reduction,kwargs,eqkwargs', [
Expand All @@ -42,7 +42,23 @@ def test_ufunc_reductions(reduction, axis, keepdims, kwargs, eqkwargs):
y = x.todense()
xx = reduction(x, axis=axis, keepdims=keepdims, **kwargs)
yy = reduction(y, axis=axis, keepdims=keepdims, **kwargs)
assert_eq(xx, yy, check_nnz=False, **eqkwargs)
assert_eq(xx, yy, **eqkwargs)


@pytest.mark.parametrize('reduction,kwargs', [
(np.max, {}),
(np.sum, {'axis': 0}),
(np.prod, {'keepdims': True}),
(np.add.reduce, {}),
(np.add.reduce, {'keepdims': True}),
(np.minimum.reduce, {'axis': 0}),
])
def test_ufunc_reductions_kwargs(reduction, kwargs):
x = sparse.random((2, 3, 4), density=.5)
y = x.todense()
xx = reduction(x, **kwargs)
yy = reduction(y, **kwargs)
assert_eq(xx, yy)


@pytest.mark.parametrize('reduction', [
Expand Down

0 comments on commit 302a83c

Please sign in to comment.