Skip to content

Commit

Permalink
Replace _support function
Browse files Browse the repository at this point in the history
For unknown reasons, np.sum is slow on a very large boolean array.
  • Loading branch information
dbarbier committed Jan 3, 2020
1 parent 26ec7ab commit b731fd2
Showing 1 changed file with 11 additions and 30 deletions.
41 changes: 11 additions & 30 deletions mlxtend/frequent_patterns/apriori.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,32 +121,6 @@ def apriori(df, min_support=0.5, use_colnames=False, max_len=None, verbose=0,
"""

def _support(_x, _n_rows, _is_sparse):
    """Private helper returning the support of each column of `_x`.

    The support of a column is the sum of its entries (taken down the
    rows, i.e. along axis 0) divided by `_n_rows`.

    Parameters
    -----------
    _x : matrix of bools or binary values
    _n_rows : numeric, the divisor applied to every column sum
        (callers pass the number of rows in `_x`)
    _is_sparse : bool, True if `_x` is sparse; accepted for interface
        compatibility but not used by this function

    Returns
    -----------
    np.array, flattened to one dimension, one support value per column

    Examples
    -----------
    For usage examples, please see
    http://rasbt.github.io/mlxtend/user_guide/frequent_patterns/apriori/

    """
    column_totals = np.sum(_x, axis=0)
    fractions = column_totals / _n_rows
    # np.array(...) followed by reshape(-1) flattens whatever np.sum
    # returned (e.g. a 2-D result) into a plain 1-D array.
    return np.array(fractions).reshape(-1)

if min_support <= 0.:
raise ValueError('`min_support` must be a positive '
'number within the interval `(0, 1]`. '
Expand Down Expand Up @@ -180,7 +154,17 @@ def _support(_x, _n_rows, _is_sparse):
# dense DataFrame
X = df.values
is_sparse = False
support = _support(X, X.shape[0], is_sparse)
if is_sparse:
# Count nonnull entries via direct access to X indices;
# this requires X to be stored in CSC format, and to call
# X.eliminate_zeros() to remove null entries from X.
support = np.array([X.indptr[idx+1] - X.indptr[idx]
for idx in range(X.shape[1])], dtype=int)
else:
# Faster than np.count_nonzero(X, axis=0) or np.sum(X, axis=0), why?
support = np.array([np.count_nonzero(X[:, idx])
for idx in range(X.shape[1])], dtype=int)
support = support / X.shape[0]
support_dict = {1: support[support >= min_support]}
itemset_dict = {1: [(idx,) for idx in np.where(support >= min_support)[0]]}
max_itemset = 1
Expand All @@ -199,9 +183,6 @@ def _support(_x, _n_rows, _is_sparse):
processed += 1
count[:] = 0
for item in itemset:
# Count nonnull entries via direct access to X indices;
# this requires X to be stored in CSC format, and to call
# X.eliminate_zeros() to remove null entries from X.
count[X.indices[X.indptr[item]:X.indptr[item+1]]] += 1
support = np.count_nonzero(count == len(itemset)) / X.shape[0]
if support >= min_support:
Expand Down

0 comments on commit b731fd2

Please sign in to comment.