Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ttensor implementation #51

Merged
merged 10 commits into from
Feb 21, 2023
35 changes: 20 additions & 15 deletions pyttb/sptensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1712,7 +1712,7 @@ def __mul__(self, other):
-------
:class:`pyttb.sptensor`
"""
if isinstance(other, (float,int)):
if isinstance(other, (float, int, np.number)):
return ttb.sptensor.from_data(self.subs, self.vals*other, self.shape)

if isinstance(other, (ttb.sptensor,ttb.tensor,ttb.ktensor)) and self.shape != other.shape:
Expand Down Expand Up @@ -1754,7 +1754,7 @@ def __rmul__(self, other):
-------
:class:`pyttb.sptensor`
"""
if isinstance(other, (float,int)):
if isinstance(other, (float, int, np.number)):
return self.__mul__(other)
else:
assert False, "This object cannot be multiplied by sptensor"
Expand Down Expand Up @@ -2173,15 +2173,14 @@ def __repr__(self): # pragma: no cover

__str__ = __repr__

def ttm(self, matrices, mode, dims=None, transpose=False):
def ttm(self, matrices, dims=None, transpose=False):
"""
Sparse tensor times matrix.

Parameters
----------
matrices: A matrix or list of matrices
mode:
dims:
dims: :class:`Numpy.ndarray`, int
transpose: Transpose matrices to be multiplied

Returns
Expand All @@ -2190,10 +2189,15 @@ def ttm(self, matrices, mode, dims=None, transpose=False):
"""
if dims is None:
dims = np.arange(self.ndims)
elif isinstance(dims, list):
dims = np.array(dims)
elif np.isscalar(dims) or isinstance(dims, list):
dims = np.array([dims])

# Handle list of matrices
if isinstance(matrices, list):
# Check dimensions are valid
[dims, vidx] = tt_dimscheck(mode, self.ndims, len(matrices))
[dims, vidx] = tt_dimscheck(dims, self.ndims, len(matrices))
# Calculate individual products
Y = self.ttm(matrices[vidx[0]], dims[0], transpose=transpose)
for i in range(1, dims.size):
Expand All @@ -2208,33 +2212,34 @@ def ttm(self, matrices, mode, dims=None, transpose=False):
if transpose:
matrices = matrices.transpose()

# Check mode
if not np.isscalar(mode) or mode < 0 or mode > self.ndims-1:
assert False, "Mode must be in [0, ndims)"
# Ensure this is the terminal single dimension case
if not (dims.size == 1 and np.isin(dims, np.arange(self.ndims))):
assert False, "dims must contain values in [0,self.dims)"
dims = dims[0]

# Compute the product

# Check that sizes match
if self.shape[mode] != matrices.shape[1]:
if self.shape[dims] != matrices.shape[1]:
assert False, "Matrix shape doesn't match tensor shape"

# Compute the new size
siz = np.array(self.shape)
siz[mode] = matrices.shape[0]
siz[dims] = matrices.shape[0]

# Compute self[mode]'
Xnt = ttb.tt_to_sparse_matrix(self, mode, True)
Xnt = ttb.tt_to_sparse_matrix(self, dims, True)

# Reshape puts the reshaped things after the unchanged modes, transpose then puts it in front
idx = 0

# Convert to sparse matrix and do multiplication; generally result is sparse
Z = Xnt.dot(matrices.transpose())

# Rearrange back into sparse tensor of original shape
Ynt = ttb.tt_from_sparse_matrix(Z, self.shape, mode, idx)
# Rearrange back into sparse tensor of correct shape
Ynt = ttb.tt_from_sparse_matrix(Z, siz, dims, idx)

if Z.nnz <= 0.5 * np.prod(siz):
if not isinstance(Z, np.ndarray) and Z.nnz <= 0.5 * np.prod(siz):
return Ynt
else:
# TODO evaluate performance loss by casting into sptensor then tensor. I assume minimal since we are already
Expand Down
15 changes: 3 additions & 12 deletions pyttb/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ def ttm(self, matrix, dims=None, transpose=False):
assert False, "matrix must be of type numpy.ndarray"

if not (dims.size == 1 and np.isin(dims, np.arange(self.ndims))):
assert False, "dims must contain values in [0,self.dims]"
assert False, "dims must contain values in [0,self.dims)"

# old version (ver=0)
shape = np.array(self.shape)
Expand Down Expand Up @@ -1271,7 +1271,6 @@ def __getitem__(self, item):
kpdims = [] # dimensions to keep
rmdims = [] # dimensions to remove

# Determine the new size and what dimensions to keep
# Determine the new size and what dimensions to keep
for i in range(0, len(region)):
if isinstance(region[i], slice):
Expand All @@ -1289,19 +1288,11 @@ def __getitem__(self, item):

# If the size is zero, then the result is returned as a scalar
# otherwise, we convert the result to a tensor

if newsiz.size == 0:
a = newdata
else:
if rmdims.size == 0:
a = ttb.tensor.from_data(newdata)
else:
# If extracted data is a vector then no need to tranpose it
if len(newdata.shape) == 1:
a = ttb.tensor.from_data(newdata)
else:
a = ttb.tensor.from_data(np.transpose(newdata, np.concatenate((kpdims, rmdims))))
return ttb.tt_subsubsref(a, item)
a = ttb.tensor.from_data(newdata)
return a

# *** CASE 2a: Subscript indexing ***
if len(item) > 1 and isinstance(item[-1], str) and item[-1] == 'extract':
Expand Down
Loading