diff --git a/pyttb/cp_apr.py b/pyttb/cp_apr.py index 89020a0c..e9afa110 100644 --- a/pyttb/cp_apr.py +++ b/pyttb/cp_apr.py @@ -1171,8 +1171,10 @@ def tt_calcpi_prowsubprob( Pi *= Model[i][Data.subs[sparse_indices, i], :] else: Pi = ttb.khatrirao( - Model.factor_matrices[:factorIndex] - + Model.factor_matrices[factorIndex + 1 : ndims + 1], + *( + Model.factor_matrices[:factorIndex] + + Model.factor_matrices[factorIndex + 1 : ndims + 1] + ), reverse=True, ) @@ -1660,8 +1662,10 @@ def calculatePi(Data, Model, rank, factorIndex, ndims): Pi *= Model[i][Data.subs[:, i], :] else: Pi = ttb.khatrirao( - Model.factor_matrices[:factorIndex] - + Model.factor_matrices[factorIndex + 1 :], + *( + Model.factor_matrices[:factorIndex] + + Model.factor_matrices[factorIndex + 1 :] + ), reverse=True, ) diff --git a/pyttb/khatrirao.py b/pyttb/khatrirao.py index aded4497..083b1814 100644 --- a/pyttb/khatrirao.py +++ b/pyttb/khatrirao.py @@ -1,11 +1,11 @@ # Copyright 2022 National Technology & Engineering Solutions of Sandia, # LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the # U.S. Government retains certain rights in this software. - +"""Khatri-Rao Product Implementation""" import numpy as np -def khatrirao(*listOfMatrices, reverse=False): +def khatrirao(*matrices: np.ndarray, reverse: bool = False) -> np.ndarray: """ KHATRIRAO Khatri-Rao product of matrices. @@ -16,69 +16,45 @@ def khatrirao(*listOfMatrices, reverse=False): Parameters ---------- - Matrices: [:class:`numpy.ndarray`] or :class:`numpy.ndarray`,:class:`numpy.ndarray`... + matrices: Collection of matrices to take the product of reverse: bool Set to true to calculate product in reverse - Returns - ------- - product: float - Examples -------- >>> A = np.random.normal(size=(5,2)) >>> B = np.random.normal(size=(5,2)) >>> _ = khatrirao(A,B) #<-- Khatri-Rao of A and B >>> _ = khatrirao(B,A,reverse=True) #<-- same thing as above - >>> _ = khatrirao([A,A,B]) #<-- passing a list - >>> _ = khatrirao([B,A,A],reverse = True) #<-- same as above + >>> _ = khatrirao(A,A,B) #<-- passing multiple items + >>> _ = khatrirao(B,A,A,reverse = True) #<-- same as above + >>> _ = khatrirao(*[A,A,B]) #<-- passing a list via unpacking items """ # Determine if list of matrices of multiple matrix arguments - if isinstance(listOfMatrices[0], list): - if len(listOfMatrices) == 1: - listOfMatrices = listOfMatrices[0] - else: - assert ( - False - ), "Khatri Rao Acts on multiple Array arguments or a list of Arrays" + if len(matrices) == 1 and isinstance(matrices[0], list): + raise ValueError( + "Khatrirao interface has changed. Instead of " + " `khatrirao([matrix_a, matrix_b])` please update to use argument " + "unpacking `khatrirao(*[matrix_a, matrix_b])`. This reduces ambiguity " + "in usage moving forward. " + ) + + if not isinstance(reverse, bool): + raise ValueError(f"Expected a bool for reverse but received {reverse}") # Error checking on input and set matrix order - if reverse == True: - listOfMatrices = list(reversed(listOfMatrices)) - ndimsA = [len(matrix.shape) == 2 for matrix in listOfMatrices] - if not np.all(ndimsA): + if reverse is True: + matrices = tuple(reversed(matrices)) + if not all(len(matrix.shape) == 2 for matrix in matrices): assert False, "Each argument must be a matrix" - ncolFirst = listOfMatrices[0].shape[1] - ncols = [matrix.shape[1] == ncolFirst for matrix in listOfMatrices] - if not np.all(ncols): + ncolFirst = matrices[0].shape[1] + if not all(matrix.shape[1] == ncolFirst for matrix in matrices): assert False, "All matrices must have the same number of columns." # Computation - # print(f'A =\n {listOfMatrices}') - P = listOfMatrices[0] - # print(f'size_P = \n{P.shape}') - # print(f'P = \n{P}') - if ncolFirst == 1: - for i in listOfMatrices[1:]: - # print(f'size_Ai = \n{i.shape}') - # print(f'size_reshape_Ai = \n{np.reshape(i, newshape=(-1, ncolFirst)).shape}') - # print(f'size_P = \n{P.shape}') - # print(f'size_reshape_P = \n{np.reshape(P, newshape=(ncolFirst, -1)).shape}') - P = np.reshape(i, newshape=(-1, ncolFirst)) * np.reshape( - P, newshape=(ncolFirst, -1), order="F" - ) - # print(f'size_P = \n{P.shape}') - # print(f'P = \n{P}') - else: - for i in listOfMatrices[1:]: - # print(f'size_Ai = \n{i.shape}') - # print(f'size_reshape_Ai = \n{np.reshape(i, newshape=(-1, 1, ncolFirst)).shape}') - # print(f'size_P = \n{P.shape}') - # print(f'size_reshape_P = \n{np.reshape(P, newshape=(1, -1, ncolFirst)).shape}') - P = np.reshape(i, newshape=(-1, 1, ncolFirst)) * np.reshape( - P, newshape=(1, -1, ncolFirst), order="F" - ) - # print(f'size_P = \n{P.shape}') - # print(f'P = \n{P}') - + P = matrices[0] + for i in matrices[1:]: + P = np.reshape(i, newshape=(-1, 1, ncolFirst)) * np.reshape( + P, newshape=(1, -1, ncolFirst), order="F" + ) return np.reshape(P, newshape=(-1, ncolFirst), order="F") diff --git a/pyttb/ktensor.py b/pyttb/ktensor.py index 58b068d3..95e82293 100644 --- a/pyttb/ktensor.py +++ b/pyttb/ktensor.py @@ -990,7 +990,7 @@ def full(self): [63. 85.]] """ - data = self.weights @ ttb.khatrirao(self.factor_matrices, reverse=True).T + data = self.weights @ ttb.khatrirao(*self.factor_matrices, reverse=True).T return ttb.tensor.from_data(data, self.shape) def innerprod(self, other): diff --git a/pyttb/sptensor.py b/pyttb/sptensor.py index 9aeb83a5..59bc60d9 100644 --- a/pyttb/sptensor.py +++ b/pyttb/sptensor.py @@ -348,7 +348,7 @@ def allsubs(self) -> np.ndarray: for n in range(0, self.ndims): i = o.copy() i[n] = np.expand_dims(np.arange(0, self.shape[n]), axis=1) - s[:, n] = np.squeeze(ttb.khatrirao(i)) + s[:, n] = np.squeeze(ttb.khatrirao(*i)) return s.astype(int) @@ -1723,7 +1723,7 @@ def _set_subtensor(self, key, value): i[n] = np.array(keyCopy[n])[:, None] else: i[n] = np.array(keyCopy[n], ndmin=2) - addsubs[:, n] = ttb.khatrirao(i).transpose()[:] + addsubs[:, n] = ttb.khatrirao(*i).transpose()[:] if self.subs.size > 0: # Replace existing values diff --git a/pyttb/tensor.py b/pyttb/tensor.py index b415824d..2658afd2 100644 --- a/pyttb/tensor.py +++ b/pyttb/tensor.py @@ -702,16 +702,18 @@ def mttkrp(self, U: Union[ttb.ktensor, List[np.ndarray]], n: int) -> np.ndarray: szn = self.shape[n] if n == 0: - Ur = ttb.khatrirao(U[1 : self.ndims], reverse=True) + Ur = ttb.khatrirao(*U[1 : self.ndims], reverse=True) Y = np.reshape(self.data, (szn, szr), order="F") return Y @ Ur if n == self.ndims - 1: # pylint: disable=no-else-return - Ul = ttb.khatrirao(U[0 : self.ndims - 1], reverse=True) + Ul = ttb.khatrirao(*U[0 : self.ndims - 1], reverse=True) Y = np.reshape(self.data, (szl, szn), order="F") return Y.T @ Ul else: - Ul = ttb.khatrirao(U[n + 1 :], reverse=True) - Ur = np.reshape(ttb.khatrirao(U[0:n], reverse=True), (szl, 1, R), order="F") + Ul = ttb.khatrirao(*U[n + 1 :], reverse=True) + Ur = np.reshape( + ttb.khatrirao(*U[0:n], reverse=True), (szl, 1, R), order="F" + ) Y = np.reshape(self.data, (-1, szr), order="F") Y = Y @ Ul Y = np.reshape(Y, (szl, szn, R), order="F") diff --git a/tests/test_khatrirao.py b/tests/test_khatrirao.py index 9dd00ceb..47bdbd37 100644 --- a/tests/test_khatrirao.py +++ b/tests/test_khatrirao.py @@ -24,8 +24,8 @@ def test_khatrirao(): [64, 125, 216], ] ) - assert (ttb.khatrirao([A, A, A]) == answer).all() - assert (ttb.khatrirao([A, A, A], reverse=True) == answer).all() + assert (ttb.khatrirao(*[A, A, A]) == answer).all() + assert (ttb.khatrirao(*[A, A, A], reverse=True) == answer).all() assert (ttb.khatrirao(A, A, A) == answer).all() # Test case where inputs are column vectors @@ -40,15 +40,9 @@ def test_khatrirao(): a_2[3, 0] * np.ones((16, 1)), ) ) - assert (ttb.khatrirao([a_2, a_1, a_1]) == result).all() + assert (ttb.khatrirao(*[a_2, a_1, a_1]) == result).all() assert (ttb.khatrirao(a_2, a_1, a_1) == result).all() - with pytest.raises(AssertionError) as excinfo: - ttb.khatrirao([a_2, a_1, a_1], a_2) - assert "Khatri Rao Acts on multiple Array arguments or a list of Arrays" in str( - excinfo - ) - with pytest.raises(AssertionError) as excinfo: ttb.khatrirao(a_2, a_1, np.ones((2, 2, 2))) assert "Each argument must be a matrix" in str(excinfo) @@ -56,3 +50,10 @@ def test_khatrirao(): with pytest.raises(AssertionError) as excinfo: ttb.khatrirao(a_2, a_1, a_3) assert "All matrices must have the same number of columns." in str(excinfo) + + # Check old interface error + with pytest.raises(ValueError): + ttb.khatrirao([a_1, a_1, a_1]) + + with pytest.raises(ValueError): + ttb.khatrirao(a_1, a_1, reverse="cat") diff --git a/tests/test_package.py b/tests/test_package.py index fe87ea42..5011603a 100644 --- a/tests/test_package.py +++ b/tests/test_package.py @@ -27,6 +27,7 @@ def test_linting(): os.path.join(os.path.dirname(ttb.__file__), f"{ttb.tensor.__name__}.py"), os.path.join(os.path.dirname(ttb.__file__), f"{ttb.sptensor.__name__}.py"), ttb.pyttb_utils.__file__, + os.path.join(os.path.dirname(ttb.__file__), f"{ttb.khatrirao.__name__}.py"), ] # TODO pylint fails to import pyttb in tests # add mypy check