Add CSR (compressed sparse row) layout for sparse tensors (pytorch#50937)

Summary:
Implement compressed sparse row format. Derived from the GCS implementation at pytorch#44190
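
For context, a minimal self-contained C++ sketch of the layout this commit implements; it is illustrative only and not code from the PR. A CSR matrix is described by three 1-D arrays: compressed row pointers, column indices, and values.

#include <cstdio>
#include <vector>

// CSR representation of the 3x4 matrix
//   [[1, 0, 2, 0],
//    [0, 0, 3, 0],
//    [4, 5, 0, 6]]
// Row i owns the slice [crow_indices[i], crow_indices[i + 1]) of col_indices/values.
int main() {
  std::vector<int> crow_indices = {0, 2, 3, 6};        // size(0) + 1 entries
  std::vector<int> col_indices  = {0, 2, 2, 0, 1, 3};  // column of each stored value
  std::vector<double> values    = {1, 2, 3, 4, 5, 6};  // nnz entries

  // y = A * x: the row-wise access pattern CSR is optimized for.
  std::vector<double> x = {1, 1, 1, 1}, y(3, 0.0);
  for (int row = 0; row < 3; ++row) {
    for (int k = crow_indices[row]; k < crow_indices[row + 1]; ++k) {
      y[row] += values[k] * x[col_indices[k]];
    }
  }
  std::printf("%g %g %g\n", y[0], y[1], y[2]);  // prints: 3 3 15
}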

Pull Request resolved: pytorch#50937

Reviewed By: mrshenli

Differential Revision: D27439865

Pulled By: ezyang

fbshipit-source-id: 3ba3dcb9679505b980ff6a5f513e913bbae2fb1d
v0dro authored and facebook-github-bot committed Apr 12, 2021
1 parent c6d9ca0 commit 5fb1142
Showing 52 changed files with 2,310 additions and 201 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -130,6 +130,7 @@ genrule(
"aten/src/ATen/RegisterMkldnnCPU.cpp",
"aten/src/ATen/RegisterQuantizedCPU.cpp",
"aten/src/ATen/RegisterSparseCPU.cpp",
"aten/src/ATen/RegisterSparseCsrCPU.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
"aten/src/ATen/RegisterMeta.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
153 changes: 153 additions & 0 deletions aten/src/ATen/SparseCsrTensorImpl.cpp
@@ -0,0 +1,153 @@
#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/core/LegacyTypeDispatch.h>

namespace at {
namespace {
DeviceType SparseCsrTensorSetToDeviceType(DispatchKeySet key_set) {
  if (key_set.has(DispatchKey::SparseCsrCPU)) {
    return kCPU;
  } else if (key_set.has(DispatchKey::SparseCsrCUDA)) {
    return kCUDA;
  } else {
    TORCH_CHECK(false,
        "Cannot construct SparseCsrTensor with non-sparse tensor type ID ",
        key_set);
  }
}
} // namespace

SparseCsrTensorImpl::SparseCsrTensorImpl(
    at::DispatchKeySet key_set,
    const caffe2::TypeMeta data_type)
    : SparseCsrTensorImpl(
          key_set,
          data_type,
          at::empty(
              {0},
              at::initialTensorOptions()
                  .device(SparseCsrTensorSetToDeviceType(key_set))
                  .dtype(ScalarType::Int)) // crow_indices
          ,
          at::empty(
              {0},
              at::initialTensorOptions()
                  .device(SparseCsrTensorSetToDeviceType(key_set))
                  .dtype(ScalarType::Int)) // col_indices
          ,
          at::empty(
              {0},
              at::initialTensorOptions()
                  .device(SparseCsrTensorSetToDeviceType(key_set))
                  .dtype(data_type)) // values
          ) {}

SparseCsrTensorImpl::SparseCsrTensorImpl(
    at::DispatchKeySet key_set,
    const caffe2::TypeMeta data_type,
    at::Tensor crow_indices,
    at::Tensor col_indices,
    at::Tensor values)
    : TensorImpl(key_set, data_type, values.device()),
      crow_indices_(std::move(crow_indices)),
      col_indices_(std::move(col_indices)),
      values_(std::move(values)) {}

void SparseCsrTensorImpl::resize_and_clear_(
    const int64_t nnz_size,
    IntArrayRef size) {
  // Call crow_indices().options() here since the struct constructor calls the
  // tensor constructor with args for device-specific init.
  auto empty_crow_indices = at::empty(size[0] + 1, crow_indices().options());
  auto empty_col_indices = at::empty(nnz_size, col_indices().options());
  auto empty_values = at::empty(nnz_size, values().options());

  crow_indices_ = empty_crow_indices;
  col_indices_ = empty_col_indices;
  values_ = empty_values;
  sizes_and_strides_.set_sizes(size);
}

void SparseCsrTensorImpl::resize_as_sparse_csr_tensor_(const Tensor& src) {
  crow_indices_ = at::empty_like(
      src.crow_indices(),
      src.crow_indices().options(),
      src.crow_indices().suggest_memory_format());
  col_indices_ = at::empty_like(
      src.col_indices(),
      src.col_indices().options(),
      src.col_indices().suggest_memory_format());
  values_ = at::empty_like(
      src.values(),
      src.values().options(),
      src.values().suggest_memory_format());
  sizes_and_strides_.set_sizes(src.sizes());
}

void SparseCsrTensorImpl::set_member_tensors(
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values) {
  auto crow_indices_type = crow_indices.scalar_type();
  auto col_indices_type = col_indices.scalar_type();

  TORCH_CHECK(
      crow_indices_type == col_indices_type,
      "both crow_indices and col_indices should have the same type.");
  TORCH_CHECK(
      crow_indices_type == kInt || crow_indices_type == kLong,
      "crow_indices and col_indices must be an int32 or int64 type, but got: ",
      crow_indices_type);
  TORCH_CHECK(
      values.scalar_type() == typeMetaToScalarType(dtype()),
      "dtype of values (",
      values.scalar_type(),
      ") must match dtype of sparse tensor (",
      typeMetaToScalarType(dtype()),
      ")");

  TORCH_CHECK(
      col_indices.layout() == kStrided,
      "expected col_indices to be a strided tensor, but got col_indices of layout ",
      col_indices.layout());
  TORCH_CHECK(
      crow_indices.layout() == kStrided,
      "expected crow_indices to be a strided tensor, but got crow_indices of layout ",
      crow_indices.layout());
  TORCH_CHECK(
      values.layout() == kStrided && values.is_contiguous(),
      "expected values to be a strided and contiguous tensor, but got values of layout ",
      values.layout());

  TORCH_CHECK(
      values.device().type() == device().type(),
      "device type of values (",
      values.device().type(),
      ") must match device type of the sparse tensor (",
      device().type(),
      ")");
  TORCH_CHECK(
      values.is_cuda() || col_indices.get_device() == crow_indices.get_device(),
      "crow_indices and col_indices devices (",
      crow_indices.get_device(),
      ", ",
      col_indices.get_device(),
      ") must match with the (non-cuda) device of values (",
      values.get_device(),
      ")");

  TORCH_CHECK(
      col_indices.size(0) == values.size(0),
      "col_indices and values must have equal sizes, but got col_indices.size(0): ",
      col_indices.size(0),
      ", values.size(0): ",
      values.size(0));

  crow_indices_ = crow_indices;
  col_indices_ = col_indices;
  values_ = values;
}
} // namespace at
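
The checks in set_member_tensors above spell out the invariants a caller has to satisfy: crow_indices has size(0) + 1 entries, col_indices and values have one entry per stored element, index dtypes are int32 or int64, and layouts and devices match. As an illustration only, the helper below is hypothetical, not part of the PR, and assumes the standard at::tensor / at::TensorOptions factories; it builds valid component tensors for a small matrix.

#include <ATen/ATen.h>
#include <vector>

// Hypothetical helper (not from the PR): builds component tensors for the
// 2x3 matrix [[0, 1, 0], [2, 0, 3]] that satisfy the invariants checked by
// SparseCsrTensorImpl::set_member_tensors.
void make_csr_components(at::Tensor& crow, at::Tensor& col, at::Tensor& vals) {
  auto index_opts = at::TensorOptions().dtype(at::kInt).device(at::kCPU);
  // crow_indices: size(0) + 1 entries, same integer dtype as col_indices.
  std::vector<int> crow_data = {0, 1, 3};
  // col_indices and values: one entry per stored element (nnz == 3).
  std::vector<int> col_data = {1, 0, 2};
  std::vector<double> val_data = {1.0, 2.0, 3.0};
  crow = at::tensor(crow_data, index_opts);
  col = at::tensor(col_data, index_opts);
  vals = at::tensor(val_data, at::TensorOptions().dtype(at::kDouble));
  // set_member_tensors additionally requires col.size(0) == vals.size(0),
  // strided and contiguous tensors, and matching devices.
}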
55 changes: 55 additions & 0 deletions aten/src/ATen/SparseCsrTensorImpl.h
@@ -0,0 +1,55 @@
#pragma once

#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>

namespace at {

// Struct implementing a sparse CSR tensor. It uses three 1-D tensors to
// represent the data: `crow_indices_`, `col_indices_` and `values_`.
// The `crow_indices_` tensor is an integer tensor of shape `(size(0) + 1)`
// that represents the compressed row indices of the CSR tensor. The
// `col_indices_` tensor is an integer tensor of shape `(nnz())`
// that explicitly stores the column index of each value of the sparse
// tensor. The `values_` tensor can be of any pytorch-supported data type
// and has shape `(nnz())`.
//
// Since the main advantage of the CSR format over the COO format is speed of
// computation, care must be taken to facilitate smooth interfacing of
// these data structures with optimized libraries such as MKL and MAGMA.
// Since the MKL interface for pytorch currently uses indexing with int32
// type, it is important to make sure that the `crow_indices` and `col_indices`
// are of type int32 when calling MKL routines such as SPMM or SPMV.
//
// If not calling MKL, it should be alright to use 64-bit integer tensors
// for indexing.
struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
  Tensor crow_indices_;
  Tensor col_indices_;
  Tensor values_;

 public:
  explicit SparseCsrTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);

  void resize_and_clear_(const int64_t nnz_size, IntArrayRef size);
  void resize_as_sparse_csr_tensor_(const Tensor& src);
  void set_member_tensors(
      const Tensor& crow_indices,
      const Tensor& col_indices,
      const Tensor& values);

  const Tensor& crow_indices() const { return crow_indices_; }
  const Tensor& col_indices() const { return col_indices_; }
  const Tensor& values() const { return values_; }
  int nnz() { return values_.size(0); }

 private:
  explicit SparseCsrTensorImpl(
      at::DispatchKeySet key_set,
      const caffe2::TypeMeta data_type,
      at::Tensor crow_indices,
      at::Tensor col_indices,
      at::Tensor values);
};
} // namespace at
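
The note above about int32 indexing is a constraint at MKL call sites rather than on the impl itself. A minimal sketch of the kind of cast such a kernel might perform (illustrative, not code from this PR):

#include <ATen/ATen.h>

// Illustrative helper only: the MKL sparse routines mentioned above take
// int32 index arrays, so a kernel holding int64 crow/col indices would
// down-cast first (assuming the index values fit in int32).
at::Tensor to_mkl_indices(const at::Tensor& indices) {
  // Tensor::to returns the input unchanged if it is already of dtype kInt.
  return indices.to(at::kInt);
}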
20 changes: 20 additions & 0 deletions aten/src/ATen/SparseCsrTensorUtils.h
@@ -0,0 +1,20 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>

namespace at {
namespace sparse_csr {

using SparseCsrTensor = Tensor;

inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
  AT_ASSERTM(
      self.is_sparse_csr(),
      "_internal_get_SparseCsrTensorImpl: not a sparse CSR tensor");
  return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
}
} // namespace sparse_csr
} // namespace at
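
A short sketch of how get_sparse_csr_impl is meant to be used (illustrative, not from the PR); it relies only on the accessors declared in SparseCsrTensorImpl.h:

#include <ATen/SparseCsrTensorUtils.h>

// Given a tensor that already uses the sparse CSR layout, the impl exposes
// the three component tensors directly.
int64_t csr_nnz(const at::Tensor& self) {
  auto* impl = at::sparse_csr::get_sparse_csr_impl(self); // asserts is_sparse_csr()
  // crow_indices() has self.size(0) + 1 entries; col_indices() and values()
  // both have nnz() entries.
  return impl->values().size(0);
}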
4 changes: 4 additions & 0 deletions aten/src/ATen/core/DeprecatedTypeProperties.h
@@ -34,6 +34,10 @@ class TORCH_API DeprecatedTypeProperties {
    return layout_from_backend(backend()) == kSparse;
  }

  bool is_sparse_csr() const {
    return layout_from_backend(backend()) == kSparseCsr;
  }

  DeviceType device_type() const {
    return backendToDeviceType(backend_);
  }
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -402,6 +402,7 @@ _(aten, is_same_size) \
_(aten, is_set_to) \
_(aten, is_signed) \
_(aten, is_sparse) \
_(aten, is_sparse_csr) \
_(aten, isclose) \
_(aten, isreal) \
_(aten, istft) \
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Resize.cpp
@@ -85,7 +85,7 @@ Tensor& resize_as_(
        !optional_memory_format.has_value(),
        "Unsupported memory format for sparse tensor resize_as_ :",
        optional_memory_format.value());
-    return native::resize_as_sparse_(self, the_template);
+    return at::native::resize_as_sparse_(self, the_template);
  }
  Tensor& result = self.resize_(the_template.sizes());
  if (optional_memory_format.has_value()) {
4 changes: 4 additions & 0 deletions aten/src/ATen/native/TypeProperties.cpp
@@ -30,6 +30,10 @@ bool is_sparse(const Tensor& self) {
  return self.is_sparse();
}

bool is_sparse_csr(const Tensor& self) {
  return self.is_sparse_csr();
}

bool is_quantized(const Tensor& self) {
  return self.is_quantized();
}