Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Khatrirao cleanup #127

Merged
merged 3 commits into from
Jun 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions pyttb/cp_apr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,8 +1171,10 @@ def tt_calcpi_prowsubprob(
Pi *= Model[i][Data.subs[sparse_indices, i], :]
else:
Pi = ttb.khatrirao(
Model.factor_matrices[:factorIndex]
+ Model.factor_matrices[factorIndex + 1 : ndims + 1],
*(
Model.factor_matrices[:factorIndex]
+ Model.factor_matrices[factorIndex + 1 : ndims + 1]
),
reverse=True,
)

Expand Down Expand Up @@ -1660,8 +1662,10 @@ def calculatePi(Data, Model, rank, factorIndex, ndims):
Pi *= Model[i][Data.subs[:, i], :]
else:
Pi = ttb.khatrirao(
Model.factor_matrices[:factorIndex]
+ Model.factor_matrices[factorIndex + 1 :],
*(
Model.factor_matrices[:factorIndex]
+ Model.factor_matrices[factorIndex + 1 :]
),
reverse=True,
)

Expand Down
76 changes: 26 additions & 50 deletions pyttb/khatrirao.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright 2022 National Technology & Engineering Solutions of Sandia,
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
# U.S. Government retains certain rights in this software.

"""Khatri-Rao Product Implementation"""
import numpy as np


def khatrirao(*listOfMatrices, reverse=False):
def khatrirao(*matrices: np.ndarray, reverse: bool = False) -> np.ndarray:
"""
KHATRIRAO Khatri-Rao product of matrices.

Expand All @@ -16,69 +16,45 @@ def khatrirao(*listOfMatrices, reverse=False):

Parameters
----------
Matrices: [:class:`numpy.ndarray`] or :class:`numpy.ndarray`,:class:`numpy.ndarray`...
matrices: Collection of matrices to take the product of
reverse: bool Set to true to calculate product in reverse

Returns
-------
product: float

Examples
--------
>>> A = np.random.normal(size=(5,2))
>>> B = np.random.normal(size=(5,2))
>>> _ = khatrirao(A,B) #<-- Khatri-Rao of A and B
>>> _ = khatrirao(B,A,reverse=True) #<-- same thing as above
>>> _ = khatrirao([A,A,B]) #<-- passing a list
>>> _ = khatrirao([B,A,A],reverse = True) #<-- same as above
>>> _ = khatrirao(A,A,B) #<-- passing multiple items
>>> _ = khatrirao(B,A,A,reverse = True) #<-- same as above
>>> _ = khatrirao(*[A,A,B]) #<-- passing a list via unpacking items
"""
# Determine if list of matrices of multiple matrix arguments
if isinstance(listOfMatrices[0], list):
if len(listOfMatrices) == 1:
listOfMatrices = listOfMatrices[0]
else:
assert (
False
), "Khatri Rao Acts on multiple Array arguments or a list of Arrays"
if len(matrices) == 1 and isinstance(matrices[0], list):
raise ValueError(
"Khatrirao interface has changed. Instead of "
" `khatrirao([matrix_a, matrix_b])` please update to use argument "
"unpacking `khatrirao(*[matrix_a, matrix_b])`. This reduces ambiguity "
"in usage moving forward. "
)

if not isinstance(reverse, bool):
raise ValueError(f"Expected a bool for reverse but received {reverse}")

# Error checking on input and set matrix order
if reverse == True:
listOfMatrices = list(reversed(listOfMatrices))
ndimsA = [len(matrix.shape) == 2 for matrix in listOfMatrices]
if not np.all(ndimsA):
if reverse is True:
matrices = tuple(reversed(matrices))
if not all(len(matrix.shape) == 2 for matrix in matrices):
assert False, "Each argument must be a matrix"

ncolFirst = listOfMatrices[0].shape[1]
ncols = [matrix.shape[1] == ncolFirst for matrix in listOfMatrices]
if not np.all(ncols):
ncolFirst = matrices[0].shape[1]
if not all(matrix.shape[1] == ncolFirst for matrix in matrices):
assert False, "All matrices must have the same number of columns."

# Computation
# print(f'A =\n {listOfMatrices}')
P = listOfMatrices[0]
# print(f'size_P = \n{P.shape}')
# print(f'P = \n{P}')
if ncolFirst == 1:
for i in listOfMatrices[1:]:
# print(f'size_Ai = \n{i.shape}')
# print(f'size_reshape_Ai = \n{np.reshape(i, newshape=(-1, ncolFirst)).shape}')
# print(f'size_P = \n{P.shape}')
# print(f'size_reshape_P = \n{np.reshape(P, newshape=(ncolFirst, -1)).shape}')
P = np.reshape(i, newshape=(-1, ncolFirst)) * np.reshape(
P, newshape=(ncolFirst, -1), order="F"
)
# print(f'size_P = \n{P.shape}')
# print(f'P = \n{P}')
else:
for i in listOfMatrices[1:]:
# print(f'size_Ai = \n{i.shape}')
# print(f'size_reshape_Ai = \n{np.reshape(i, newshape=(-1, 1, ncolFirst)).shape}')
# print(f'size_P = \n{P.shape}')
# print(f'size_reshape_P = \n{np.reshape(P, newshape=(1, -1, ncolFirst)).shape}')
P = np.reshape(i, newshape=(-1, 1, ncolFirst)) * np.reshape(
P, newshape=(1, -1, ncolFirst), order="F"
)
# print(f'size_P = \n{P.shape}')
# print(f'P = \n{P}')

P = matrices[0]
for i in matrices[1:]:
P = np.reshape(i, newshape=(-1, 1, ncolFirst)) * np.reshape(
P, newshape=(1, -1, ncolFirst), order="F"
)
return np.reshape(P, newshape=(-1, ncolFirst), order="F")
2 changes: 1 addition & 1 deletion pyttb/ktensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,7 @@ def full(self):
[63. 85.]]
<BLANKLINE>
"""
data = self.weights @ ttb.khatrirao(self.factor_matrices, reverse=True).T
data = self.weights @ ttb.khatrirao(*self.factor_matrices, reverse=True).T
return ttb.tensor.from_data(data, self.shape)

def innerprod(self, other):
Expand Down
4 changes: 2 additions & 2 deletions pyttb/sptensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def allsubs(self) -> np.ndarray:
for n in range(0, self.ndims):
i = o.copy()
i[n] = np.expand_dims(np.arange(0, self.shape[n]), axis=1)
s[:, n] = np.squeeze(ttb.khatrirao(i))
s[:, n] = np.squeeze(ttb.khatrirao(*i))

return s.astype(int)

Expand Down Expand Up @@ -1723,7 +1723,7 @@ def _set_subtensor(self, key, value):
i[n] = np.array(keyCopy[n])[:, None]
else:
i[n] = np.array(keyCopy[n], ndmin=2)
addsubs[:, n] = ttb.khatrirao(i).transpose()[:]
addsubs[:, n] = ttb.khatrirao(*i).transpose()[:]

if self.subs.size > 0:
# Replace existing values
Expand Down
10 changes: 6 additions & 4 deletions pyttb/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,16 +702,18 @@ def mttkrp(self, U: Union[ttb.ktensor, List[np.ndarray]], n: int) -> np.ndarray:
szn = self.shape[n]

if n == 0:
Ur = ttb.khatrirao(U[1 : self.ndims], reverse=True)
Ur = ttb.khatrirao(*U[1 : self.ndims], reverse=True)
Y = np.reshape(self.data, (szn, szr), order="F")
return Y @ Ur
if n == self.ndims - 1: # pylint: disable=no-else-return
Ul = ttb.khatrirao(U[0 : self.ndims - 1], reverse=True)
Ul = ttb.khatrirao(*U[0 : self.ndims - 1], reverse=True)
Y = np.reshape(self.data, (szl, szn), order="F")
return Y.T @ Ul
else:
Ul = ttb.khatrirao(U[n + 1 :], reverse=True)
Ur = np.reshape(ttb.khatrirao(U[0:n], reverse=True), (szl, 1, R), order="F")
Ul = ttb.khatrirao(*U[n + 1 :], reverse=True)
Ur = np.reshape(
ttb.khatrirao(*U[0:n], reverse=True), (szl, 1, R), order="F"
)
Y = np.reshape(self.data, (-1, szr), order="F")
Y = Y @ Ul
Y = np.reshape(Y, (szl, szn, R), order="F")
Expand Down
19 changes: 10 additions & 9 deletions tests/test_khatrirao.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def test_khatrirao():
[64, 125, 216],
]
)
assert (ttb.khatrirao([A, A, A]) == answer).all()
assert (ttb.khatrirao([A, A, A], reverse=True) == answer).all()
assert (ttb.khatrirao(*[A, A, A]) == answer).all()
assert (ttb.khatrirao(*[A, A, A], reverse=True) == answer).all()
assert (ttb.khatrirao(A, A, A) == answer).all()

# Test case where inputs are column vectors
Expand All @@ -40,19 +40,20 @@ def test_khatrirao():
a_2[3, 0] * np.ones((16, 1)),
)
)
assert (ttb.khatrirao([a_2, a_1, a_1]) == result).all()
assert (ttb.khatrirao(*[a_2, a_1, a_1]) == result).all()
assert (ttb.khatrirao(a_2, a_1, a_1) == result).all()

with pytest.raises(AssertionError) as excinfo:
ttb.khatrirao([a_2, a_1, a_1], a_2)
assert "Khatri Rao Acts on multiple Array arguments or a list of Arrays" in str(
excinfo
)

with pytest.raises(AssertionError) as excinfo:
ttb.khatrirao(a_2, a_1, np.ones((2, 2, 2)))
assert "Each argument must be a matrix" in str(excinfo)

with pytest.raises(AssertionError) as excinfo:
ttb.khatrirao(a_2, a_1, a_3)
assert "All matrices must have the same number of columns." in str(excinfo)

# Check old interface error
with pytest.raises(ValueError):
ttb.khatrirao([a_1, a_1, a_1])

with pytest.raises(ValueError):
ttb.khatrirao(a_1, a_1, reverse="cat")
1 change: 1 addition & 0 deletions tests/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_linting():
os.path.join(os.path.dirname(ttb.__file__), f"{ttb.tensor.__name__}.py"),
os.path.join(os.path.dirname(ttb.__file__), f"{ttb.sptensor.__name__}.py"),
ttb.pyttb_utils.__file__,
os.path.join(os.path.dirname(ttb.__file__), f"{ttb.khatrirao.__name__}.py"),
]
# TODO pylint fails to import pyttb in tests
# add mypy check
Expand Down