torch.linalg.lstsq: forward/backward AD support (pytorch#65054)
Summary:
As per the title: this adds forward- and backward-mode automatic differentiation support for torch.linalg.lstsq (only the solution output is differentiable).

cc ezyang albanD zou3519 gqchen pearu nikitaved soulitzer Lezcano Varal7 jianyuh mruberry walterddr IvanYashchuk xwang233

Pull Request resolved: pytorch#65054

Reviewed By: zou3519

Differential Revision: D31729468

Pulled By: albanD

fbshipit-source-id: ab7df824bc80128e7f64f6444c7a4baa4786c161
nikitaved authored and facebook-github-bot committed Oct 18, 2021
1 parent 6bde474 commit 7fad47e
Showing 6 changed files with 120 additions and 15 deletions.
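
For context (not part of the diff): a minimal sketch of what this change enables, assuming only the public torch.autograd.gradcheck and torch.autograd.forward_ad APIs. Both reverse- and forward-mode AD now flow through the solution output of torch.linalg.lstsq; the inputs below are arbitrary examples.

import torch
import torch.autograd.forward_ad as fwAD
from torch.autograd import gradcheck

A = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
B = torch.randn(5, 2, dtype=torch.double, requires_grad=True)

# Reverse mode: only the solution output is differentiable
# (see output_differentiability in derivatives.yaml below).
solution = torch.linalg.lstsq(A, B).solution
solution.sum().backward()
print(A.grad.shape, B.grad.shape)  # torch.Size([5, 3]) torch.Size([5, 2])

# Numerical check of the analytic backward.
assert gradcheck(lambda A, B: torch.linalg.lstsq(A, B).solution, (A, B))

# Forward mode via dual tensors.
with fwAD.dual_level():
    A_dual = fwAD.make_dual(A.detach(), torch.randn_like(A))
    B_dual = fwAD.make_dual(B.detach(), torch.randn_like(B))
    tangent = fwAD.unpack_dual(torch.linalg.lstsq(A_dual, B_dual).solution).tangent
    print(tangent.shape)  # torch.Size([3, 2])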
4 changes: 2 additions & 2 deletions tools/autograd/derivatives.yaml
@@ -871,8 +871,8 @@
A: not_implemented("lstsq")

- name: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
self: not_implemented("linalg_lstsq")
b: not_implemented("linalg_lstsq")
self, b: linalg_lstsq_backward(grad, self, b, rcond, driver, grad_input_mask)
solution: linalg_lstsq_jvp(self_p, b_p, self_t, b_t)
output_differentiability: [True, False, False, False]

- name: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
@@ -108,7 +108,7 @@
'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_conj_physical',
'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'cumulative_trapezoid',
'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve', '_lu_with_info',
'linalg_pinv',
'linalg_pinv', 'linalg_lstsq',
}

GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
49 changes: 49 additions & 0 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -2494,6 +2494,55 @@ Tensor linalg_eig_backward(const std::vector<torch::autograd::Variable> &grads,
}
}

Tensor linalg_lstsq_jvp(
const Tensor& A,
const Tensor& B,
const Tensor& dA,
const Tensor& dB
) {
auto pinvA = at::linalg_pinv(A);
auto dpinvA = pinv_jvp(A, pinvA, dA);
auto dX = dpinvA.matmul(B) + pinvA.matmul(dB);
return dX;
}

std::tuple<Tensor, Tensor> linalg_lstsq_backward(
const Tensor& grad,
const Tensor& A,
const Tensor& B,
const c10::optional<double> rcond,
const c10::optional<c10::string_view> driver,
const std::array<bool, 2>& grad_input_mask
) {
Tensor A_grad, B_grad;
if (!grad.defined()) {
return std::make_tuple(A_grad, B_grad);
}

auto A_requires_grad = grad_input_mask[0];
auto B_requires_grad = grad_input_mask[1];

Tensor pinvA;
if (A_requires_grad) {
pinvA = at::linalg_pinv(A);
auto pinvA_grad = grad.matmul(B.transpose(-1, -2).conj());
A_grad = pinv_backward(pinvA_grad, pinvA, A);
}

if (B_requires_grad) {
if (!pinvA.defined()) {
pinvA = at::linalg_pinv(A);
}
// Equivalent to
// B_grad = std::get<0>(at::linalg_lstsq(A.transpose(-1, -2).conj(), grad, rcond, driver));
// but we avoid this approach as `gelsy` is non-deterministic
B_grad = pinvA.transpose(-1, -2).conj().matmul(grad);
}

return std::make_tuple(A_grad, B_grad);
}


// jvp functions for eigenvalues and eigenvectors are separate
// because currently forward AD only works with one rule per output
Tensor eigh_jvp_eigenvalues(
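
Aside (not part of the commit): these formulas rely on the least-squares solution being X = pinv(A) B, so the gradient flowing back into B is pinv(A)^H grad. For minimum-norm drivers such as gelsd this agrees with the lstsq-on-A^H formulation mentioned in the comment above; a quick numerical check with an arbitrary full-rank A:

import torch

A = torch.randn(6, 4, dtype=torch.double)      # overdetermined, (almost surely) full column rank
grad = torch.randn(4, 2, dtype=torch.double)   # incoming gradient w.r.t. the solution X

# B_grad as computed in linalg_lstsq_backward: pinv(A)^H @ grad
via_pinv = torch.linalg.pinv(A).transpose(-1, -2).conj() @ grad

# The "equivalent to" formulation from the comment: lstsq on A^H with a
# minimum-norm CPU driver ('gelsd') returns the same minimum-norm solution.
via_lstsq = torch.linalg.lstsq(A.transpose(-1, -2).conj(), grad, driver='gelsd').solution

print(torch.allclose(via_pinv, via_lstsq))  # True (up to numerical tolerance)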
15 changes: 15 additions & 0 deletions torch/csrc/autograd/FunctionsManual.h
@@ -166,6 +166,12 @@ Tensor slice_backward_wrapper(
int64_t step);
Tensor linalg_eig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
const Tensor& L, const Tensor& V);
Tensor linalg_lstsq_jvp(
const Tensor& A,
const Tensor& B,
const Tensor& dA,
const Tensor& dB
);
Tensor eigh_jvp_eigenvectors(const Tensor& input_tangent, const Tensor& eigenvalues, const Tensor& eigenvectors);
Tensor eigh_jvp_eigenvalues(const Tensor& input_tangent, const Tensor& eigenvalues, const Tensor& eigenvectors);
Tensor eigh_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
@@ -299,6 +305,15 @@ Tensor _det_lu_based_helper_backward(
const Tensor& pivs
);

std::tuple<Tensor, Tensor> linalg_lstsq_backward(
const Tensor& grad,
const Tensor& A,
const Tensor& B,
const c10::optional<double> rcond,
const c10::optional<c10::string_view> driver,
const std::array<bool, 2>& grad_input_mask
);

Tensor lu_backward_base(
const variable_list& grads,
const Tensor& self,
11 changes: 7 additions & 4 deletions torch/testing/_internal/common_device_type.py
@@ -1146,15 +1146,18 @@ def skipCPUIfNoMkl(fn):
def skipCUDAIfNoMagma(fn):
return skipCUDAIf('no_magma', "no MAGMA library detected")(skipCUDANonDefaultStreamIf(True)(fn))

def has_cusolver():
version = _get_torch_cuda_version()
# cuSolver is disabled on cuda < 10.1.243
return version >= (10, 2)

# Skips a test on CUDA if cuSOLVER is not available
def skipCUDAIfNoCusolver(fn):
version = _get_torch_cuda_version()
return skipCUDAIf(version < (10, 2), "cuSOLVER not available")(fn)
return skipCUDAIf(not has_cusolver(), "cuSOLVER not available")(fn)

# Skips a test if both cuSOLVER and MAGMA are not available
def skipCUDAIfNoMagmaAndNoCusolver(fn):
version = _get_torch_cuda_version()
if version >= (10, 2):
if has_cusolver():
return fn
else:
# cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA
54 changes: 46 additions & 8 deletions torch/testing/_internal/common_methods_invocations.py
@@ -25,7 +25,7 @@
from torch.testing._internal.common_device_type import \
(onlyCUDA, onlyOnCPUAndCUDA, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIfRocm, precisionOverride,
toleranceOverride, tol)
toleranceOverride, tol, has_cusolver)
from torch.testing._internal.common_cuda import CUDA11OrLater, SM53OrLater, SM60OrLater
from torch.testing._internal.common_utils import \
(is_iterable_of_tensors,
@@ -3936,13 +3936,31 @@ def sample_inputs_linalg_cholesky_inverse(op_info, device, dtype, requires_grad=

def sample_inputs_linalg_lstsq(op_info, device, dtype, requires_grad=False, **kwargs):
from torch.testing._internal.common_utils import random_well_conditioned_matrix

device = torch.device(device)

drivers: Tuple[str, ...]
if device.type == 'cuda':
drivers = ('gels',)
else:
drivers = ('gels', 'gelsy', 'gelss', 'gelsd')

# we generate matrices of shape (..., n + delta, n)
deltas: Tuple[int, ...]
if device.type == 'cpu' or has_cusolver():
deltas = (-1, 0, +1)
# only square systems if cuSOLVER is not available,
# because we solve an lstsq problem with a transposed matrix in the backward
else:
deltas = (0,)

out = []
for batch in ((), (3,), (3, 3)):
shape = batch + (3, 3)
for batch, driver, delta in product(((), (3,), (3, 3)), drivers, deltas):
shape = batch + (3 + delta, 3)
a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
a.requires_grad = requires_grad
a.requires_grad_(requires_grad)
b = make_tensor(shape, device, dtype, low=None, high=None, requires_grad=requires_grad)
out.append(SampleInput(a, args=(b,)))
out.append(SampleInput(a, args=(b,), kwargs=dict(driver=driver)))
return out

def sample_inputs_householder_product(op_info, device, dtype, requires_grad, **kwargs):
@@ -7721,12 +7739,32 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
decorators=[skipCUDAIfNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]),
OpInfo('linalg.lstsq',
aten_name='linalg_lstsq',
op=torch.linalg.lstsq,
dtypes=floating_and_complex_types(),
supports_out=True,
sample_inputs_func=sample_inputs_linalg_lstsq,
supports_autograd=False,
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]),
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
skips=(
# we skip gradient checks for this suite as they are tested in
# variant_test_name='grad_oriented'
DecorateInfo(unittest.skip("Skipped!"), 'TestGradients'),
)),
OpInfo('linalg.lstsq',
aten_name='linalg_lstsq',
variant_test_name='grad_oriented',
# gradchecks for forward AD fail with multi-Tensor outputs
op=lambda a, b, driver: torch.linalg.lstsq(a, b, driver=driver)[0],
dtypes=floating_and_complex_types(),
sample_inputs_func=sample_inputs_linalg_lstsq,
supports_autograd=True,
supports_forward_ad=True,
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
skips=(
# tests do not work with passing lambda for op
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
)),
OpInfo('linalg.matrix_power',
aliases=('matrix_power',),
aten_name='linalg_matrix_power',
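
Aside (not part of the diff): the grad_oriented variant above wraps lstsq in a lambda that returns only the solution, so gradient tests never see the non-differentiable outputs. A hedged sketch of the same pattern, assuming gradcheck's check_forward_ad flag is available and using the CPU driver 'gelsd':

import torch
from torch.autograd import gradcheck

# Mirrors the op lambda of the grad_oriented variant.
op = lambda a, b, driver: torch.linalg.lstsq(a, b, driver=driver)[0]

a = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(4, 2, dtype=torch.double, requires_grad=True)

# Reverse- and forward-mode numerical checks on the wrapped op.
assert gradcheck(lambda a, b: op(a, b, 'gelsd'), (a, b), check_forward_ad=True)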
