diff --git a/TensorToolbox/import_data.py b/TensorToolbox/import_data.py index 96fb2ce8..b0905bf0 100644 --- a/TensorToolbox/import_data.py +++ b/TensorToolbox/import_data.py @@ -24,7 +24,9 @@ def import_data(filename): if data_type == 'tensor': - assert False, f"{data_type} is not currently allowed" + shape = import_shape(fp) + data = import_array(fp, np.prod(shape)) + return ttb.tensor().from_data(data, shape) elif data_type == 'sptensor': @@ -35,26 +37,22 @@ def import_data(filename): elif data_type == 'matrix': - assert False, f"{data_type} is not currently allowed" + shape = import_shape(fp) + mat = import_array(fp, np.prod(shape)) + mat = np.reshape(mat, np.array(shape)) + return mat elif data_type == 'ktensor': shape = import_shape(fp) - #print(f"shape: {shape}") r = import_rank(fp) - #print(f"rank: {r}") - weights = np.array(fp.readline().strip().split(' '),dtype="float") - #print(f"weights: {weights}") + weights = import_array(fp, r) factor_matrices = [] for n in range(len(shape)): fac_type = fp.readline().strip() - #print(f"fac_type: {fac_type}") fac_shape = import_shape(fp) - #print(f"fac_shape: {fac_shape}") - fac = np.zeros(fac_shape, dtype="float") - for r in range(fac_shape[0]): - fac[r,:] = fp.readline().strip().split(' ') - #print(f"fac: {fac}") + fac = import_array(fp, np.prod(fac_shape)) + fac = np.reshape(fac, np.array(fac_shape)) factor_matrices.append(fac) return ttb.ktensor().from_data(weights, factor_matrices) @@ -87,6 +85,10 @@ def import_sparse_array(fp, n, nz): vals = np.zeros((nz, 1)) for k in range(nz): line = fp.readline().strip().split(' ') - subs[k,:] = line[:-1] + # 1-based indexing in file, 0-based indexing in package + subs[k,:] = [np.int64(i)-1 for i in line[:-1]] vals[k,0] = line[-1] return subs, vals + +def import_array(fp, n): + return np.fromfile(fp, count=n, sep=' ') diff --git a/tests/test_import_export_data.py b/tests/test_import_export_data.py new file mode 100644 index 00000000..db1f26b0 --- /dev/null +++ b/tests/test_import_export_data.py @@ -0,0 +1,108 @@ +# Copyright 2022 National Technology & Engineering Solutions of Sandia, +# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the +# U.S. Government retains certain rights in this software. + +import numpy as np +import pytest +import os +import TensorToolbox as ttb + +@pytest.fixture() +def sample_tensor_2way(): + data = np.array([[1., 2., 3.], [4., 5., 6.]]) + shape = (2, 3) + params = {'data':data, 'shape': shape} + tensorInstance = ttb.tensor().from_data(data, shape) + return params, tensorInstance + +@pytest.fixture() +def sample_tensor_3way(): + data = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.]) + shape = (2, 3, 2) + params = {'data':np.reshape(data, np.array(shape), order='F'), 'shape': shape} + tensorInstance = ttb.tensor().from_data(data, shape) + return params, tensorInstance + +@pytest.fixture() +def sample_tensor_4way(): + data = np.arange(1, 82) + shape = (3, 3, 3, 3) + params = {'data':np.reshape(data, np.array(shape), order='F'), 'shape': shape} + tensorInstance = ttb.tensor().from_data(data, shape) + return params, tensorInstance + +@pytest.mark.indevelopment +def test_import_data_tensor(): + # truth data + T = ttb.tensor.from_data(np.ones((3,3,3)), (3,3,3)) + + # imported data + data_filename = os.path.join(os.path.dirname(__file__),'data','tensor.tns') + X = ttb.import_data(data_filename) + + assert X.shape == (3, 3, 3) + assert T.isequal(X) + +@pytest.mark.indevelopment +def test_import_data_sptensor(): + # truth data + subs = np.array([[0, 0, 0],[0, 2, 2],[1, 1, 1],[1, 2, 0],[1, 2, 1],[1, 2, 2], + [1, 3, 1],[2, 0, 0],[2, 0, 1],[2, 2, 0],[2, 2, 1],[2, 3, 0], + [2, 3, 2],[3, 0, 0],[3, 0, 1],[3, 2, 0],[4, 0, 2],[4, 3, 2]]) + vals = np.reshape(np.array(range(1,19)),(18,1)) + shape = (5, 4, 3) + S = ttb.sptensor().from_data(subs, vals, shape) + + # imported data + data_filename = os.path.join(os.path.dirname(__file__),'data','sptensor.tns') + X = ttb.import_data(data_filename) + + assert S.isequal(X) + +@pytest.mark.indevelopment +def test_import_data_ktensor(): + # truth data + weights = np.array([3, 2]) + fm0 = np.array([[1., 5.], [2., 6.], [3., 7.], [4., 8.]]) + fm1 = np.array([[ 2., 7.], [ 3., 8.], [ 4., 9.], [ 5., 10.], [ 6., 11.]]) + fm2 = np.array([[3., 6.], [4., 7.], [5., 8.]]) + factor_matrices = [fm0, fm1, fm2] + K = ttb.ktensor.from_data(weights, factor_matrices) + + # imported data + data_filename = os.path.join(os.path.dirname(__file__),'data','ktensor.tns') + X = ttb.import_data(data_filename) + + assert K.isequal(X) + +@pytest.mark.indevelopment +def test_import_data_array(): + # truth data + M = np.array([[1., 5.], [2., 6.], [3., 7.], [4., 8.]]) + print('\nM') + print(M) + + # imported data + data_filename = os.path.join(os.path.dirname(__file__),'data','matrix.tns') + X = ttb.import_data(data_filename) + print('\nX') + print(X) + + assert (M == X).all() + +@pytest.mark.indevelopment +def test_export_data_tensor(): + pass + +@pytest.mark.indevelopment +def test_export_data_sptensor(): + pass + +@pytest.mark.indevelopment +def test_export_data_ktensor(): + pass + +@pytest.mark.indevelopment +def test_export_data_array(): + pass +