Update internal code for torch.geqrf (#56250)
Summary:
Pull Request resolved: #56250

Moved `apply_geqrf` to `BatchLinearAlgebraKernel.cpp`. Added
`geqrf_stub` dispatch.
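
For context, the stub added here follows ATen's usual DECLARE/DEFINE/REGISTER dispatch pattern. A minimal recap assembled from the diff below (all names appear in this change; the snippet is a sketch, not the verbatim patch):

    // BatchLinearAlgebra.h: declare the kernel signature and the stub.
    using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, int64_t /*m*/, int64_t /*n*/);
    DECLARE_DISPATCH(geqrf_fn, geqrf_stub);

    // BatchLinearAlgebra.cpp: define the dispatch table and call through it,
    // routing on the tensor's device type.
    DEFINE_DISPATCH(geqrf_stub);
    geqrf_stub(input.device().type(), QR, tau, input.size(-2), input.size(-1));

    // BatchLinearAlgebraKernel.cpp: register the CPU kernel for each CPU capability.
    REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel);
    REGISTER_AVX_DISPATCH(geqrf_stub, &geqrf_kernel);
    REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel);
    REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel);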

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D27907362

Pulled By: mruberry

fbshipit-source-id: 6719464aef29dcf3bbbde060edf79f1e32fc8ad6
IvanYashchuk authored and facebook-github-bot committed Apr 25, 2021
1 parent d5ff432 commit e97c17a
Showing 3 changed files with 89 additions and 80 deletions.
80 changes: 3 additions & 77 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -217,9 +217,6 @@ void lapackCholeskySolve(char uplo, int n, int nrhs, scalar_t *a, int lda, scala
template<class scalar_t>
void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info);

template<class scalar_t>
void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);

template<class scalar_t, class value_t=scalar_t>
void lapackSymeig(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int *info);

@@ -1604,70 +1601,7 @@ std::tuple<Tensor&, Tensor&> triangular_solve_out(const Tensor& self, const Tens

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ qr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/*
The geqrf function computes the QR decomposition of matrices stored in `self`.
However, rather than producing a Q matrix directly, it produces a sequence of
elementary reflectors which may later be composed to construct Q - for example
with the orgqr or ormqr functions.
Args:
* `self` - [in] Input tensor for QR decomposition
           [out] QR decomposition result, written in place, which contains:
                 i) The elements of R, on and above the diagonal.
                 ii) The directions of the elementary reflectors below the
                     diagonal, which implicitly define Q.
* `tau` - [out] Tensor which will contain the magnitudes of the reflectors
implicitly defining Q.
* `m` - The number of rows of `self` to consider
* `n` - The number of columns of `self` to consider (actual sizes of `self` could be larger)
For further details, please see the LAPACK documentation for GEQRF.
*/
template <typename scalar_t>
static void apply_geqrf(const Tensor& self, const Tensor& tau, int64_t m, int64_t n) {
#ifndef USE_LAPACK
TORCH_CHECK(
false,
"Calling torch.geqrf on a CPU tensor requires compiling ",
"PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
using value_t = typename c10::scalar_value_type<scalar_t>::type;
auto self_data = self.data_ptr<scalar_t>();
auto tau_data = tau.data_ptr<scalar_t>();
auto self_matrix_stride = matrixStride(self);
auto tau_stride = tau.size(-1);
auto batch_size = batchCount(self);
auto lda = std::max<int>(1, m);

int info;
// First run a workspace query to get the optimal work size.
// Since we deal with batches of matrices with the same dimensions, doing this outside
// the loop saves (batch_size - 1) workspace queries which would provide the same result
// and (batch_size - 1) calls to allocate and deallocate workspace using at::empty()
int lwork = -1;
scalar_t wkopt;
lapackGeqrf<scalar_t>(m, n, self_data, lda, tau_data, &wkopt, lwork, &info);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);

// if lwork is less than 'n' then a warning is printed:
// Intel MKL ERROR: Parameter 7 was incorrect on entry to SGEQRF.
lwork = std::max<int>(std::max<int>(1, n), real_impl<scalar_t, value_t>(wkopt));
Tensor work = at::empty({lwork}, self.options());

for (const auto i : c10::irange(batch_size)) {
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
scalar_t* tau_working_ptr = &tau_data[i * tau_stride];

// now compute the actual QR and tau
lapackGeqrf<scalar_t>(m, n, self_working_ptr, lda, tau_working_ptr, work.data_ptr<scalar_t>(), lwork, &info);

// info from lapackGeqrf only reports whether some argument had an illegal value,
// so we don't need to check it after every call
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
}
#endif
}
DEFINE_DISPATCH(geqrf_stub);

static void geqrf_out_helper(const Tensor& input, const Tensor& QR, const Tensor& tau) {
TORCH_INTERNAL_ASSERT(input.dim() >= 2);
@@ -1700,13 +1634,7 @@ static void geqrf_out_helper(const Tensor& input, const Tensor& QR, const Tensor

// geqrf_stub (apply_geqrf) performs calculations in-place and 'QR' must be a copy of input
QR.copy_(input);

// TODO: implement geqrf_stub
// DEFINE_DISPATCH(geqrf_stub);
// geqrf_stub(input.device().type(), QR, tau);
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "geqrf_cpu", [&]{
apply_geqrf<scalar_t>(QR, tau, input.size(-2), input.size(-1));
});
geqrf_stub(input.device().type(), QR, tau, input.size(-2), input.size(-1));
}

std::tuple<Tensor&, Tensor&> geqrf_out(const Tensor& input, Tensor& QR, Tensor& tau) {
@@ -1801,9 +1729,7 @@ std::tuple<Tensor, Tensor> _linalg_qr_helper_cpu(const Tensor& self, std::string
q_working_copy = at::empty_strided(q_sizes, q_strides, self.options());
q_working_copy.narrow(-1, 0, n).copy_(self);

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "qr_cpu", [&]{
apply_geqrf<scalar_t>(q_working_copy, tau_working_copy, m, n);
});
geqrf_stub(q_working_copy.device().type(), q_working_copy, tau_working_copy, m, n);

R = q_working_copy.slice(-2, 0, n_columns_q).slice(-1, 0, n).triu();
if (!compute_q) {
12 changes: 9 additions & 3 deletions aten/src/ATen/native/BatchLinearAlgebra.h
@@ -14,13 +14,16 @@ namespace at { namespace native {
// Define per-batch functions to be used in the implementation of batched
// linear algebra operations

template<class scalar_t>
template <class scalar_t>
void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info);

template<class scalar_t, class value_t=scalar_t>
template <class scalar_t, class value_t=scalar_t>
void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);

template<class scalar_t>
template <class scalar_t>
void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);

template <class scalar_t>
void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);

template <class scalar_t, class value_t = scalar_t>
@@ -43,6 +46,9 @@ using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/

DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub);

using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, int64_t /*m*/, int64_t /*n*/);
DECLARE_DISPATCH(geqrf_fn, geqrf_stub);

using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/, int64_t /*n_columns*/);
DECLARE_DISPATCH(orgqr_fn, orgqr_stub);

77 changes: 77 additions & 0 deletions aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
@@ -330,6 +330,78 @@ void linalg_eigh_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos
});
}

/*
The geqrf function computes the QR decomposition of matrices stored in `input`.
However, rather than producing a Q matrix directly, it produces a sequence of
elementary reflectors which may later be composed to construct Q - for example
with the orgqr or ormqr functions.
Args:
* `input` - [in] Input tensor for QR decomposition
            [out] QR decomposition result, written in place, which contains:
                  i) The elements of R, on and above the diagonal.
                  ii) The directions of the elementary reflectors below the
                      diagonal, which implicitly define Q.
* `tau` - [out] Tensor which will contain the magnitudes of the reflectors
implicitly defining Q.
* `m` - The number of rows of `input` to consider
* `n` - The number of columns of `input` to consider (actual sizes of `input` could be larger)
For further details, please see the LAPACK documentation for GEQRF.
*/
template <typename scalar_t>
static void apply_geqrf(const Tensor& input, const Tensor& tau, int64_t m, int64_t n) {
#ifndef USE_LAPACK
TORCH_CHECK(
false,
"Calling torch.geqrf on a CPU tensor requires compiling ",
"PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
using value_t = typename c10::scalar_value_type<scalar_t>::type;
auto input_data = input.data_ptr<scalar_t>();
auto tau_data = tau.data_ptr<scalar_t>();
auto input_matrix_stride = matrixStride(input);
auto tau_stride = tau.size(-1);
auto batch_size = batchCount(input);
auto lda = std::max<int>(1, m);

int info;
// First run a workspace query to get the optimal work size.
// Since we deal with batches of matrices with the same dimensions, doing this outside
// the loop saves (batch_size - 1) workspace queries which would provide the same result
// and (batch_size - 1) calls to allocate and deallocate workspace using at::empty()
int lwork = -1;
scalar_t wkopt;
lapackGeqrf<scalar_t>(m, n, input_data, lda, tau_data, &wkopt, lwork, &info);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);

// if lwork is less than 'n' then a warning is printed:
// Intel MKL ERROR: Parameter 7 was incorrect on entry to SGEQRF.
lwork = std::max<int>(std::max<int>(1, n), real_impl<scalar_t, value_t>(wkopt));
Tensor work = at::empty({lwork}, input.options());

for (const auto i : c10::irange(batch_size)) {
scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
scalar_t* tau_working_ptr = &tau_data[i * tau_stride];

// now compute the actual QR and tau
lapackGeqrf<scalar_t>(m, n, input_working_ptr, lda, tau_working_ptr, work.data_ptr<scalar_t>(), lwork, &info);

// info from lapackGeqrf only reports whether some argument had an illegal value,
// so we don't need to check it after every call
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
}
#endif
}

// This is a type dispatching helper function for 'apply_geqrf'
void geqrf_kernel(const Tensor& input, const Tensor& tau, int64_t m, int64_t n) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "geqrf_cpu", [&]{
apply_geqrf<scalar_t>(input, tau, m, n);
});
}
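
For reference, the reflector representation produced by geqrf (and consumed by orgqr below) is the standard LAPACK convention: with v_i the direction of the i-th elementary reflector (stored below the diagonal, with an implicit leading 1) and tau_i its magnitude,

    Q = H_1 H_2 ... H_k,  where  H_i = I - tau_i * v_i * v_i^H

so Q is never formed explicitly here; orgqr/ormqr accumulate or apply the H_i on demand.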

/*
The orgqr function allows reconstruction of an orthogonal (or unitary) matrix Q,
from a sequence of elementary reflectors, such as produced by the geqrf function.
@@ -481,6 +553,11 @@ REGISTER_AVX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);

REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel);
REGISTER_AVX_DISPATCH(geqrf_stub, &geqrf_kernel);
REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel);
REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel);

REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl);
REGISTER_AVX_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
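
As a usage sketch (not part of this change): the geqrf/orgqr pairing documented above can be exercised from the ATen C++ API, assuming the standard at::geqrf and at::orgqr bindings; the reconstruction check at the end is purely illustrative.

    #include <ATen/ATen.h>
    #include <iostream>
    #include <tuple>

    int main() {
      at::Tensor a = at::randn({3, 3});
      // geqrf returns (QR, tau): R on and above the diagonal,
      // reflector directions below it, reflector magnitudes in tau.
      auto result = at::geqrf(a);
      at::Tensor qr = std::get<0>(result);
      at::Tensor tau = std::get<1>(result);
      at::Tensor r = qr.triu();
      // orgqr composes the elementary reflectors into an explicit Q.
      at::Tensor q = at::orgqr(qr, tau);
      // Q @ R should reproduce the input up to rounding error.
      std::cout << (q.matmul(r) - a).abs().max() << "\n";
      return 0;
    }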
