[MPS] Add Inverse op. (pytorch#90428)
kulinseth authored and pytorchmergebot committed Dec 19, 2022
1 parent 58b5a9d commit 8ecb49b
Showing 4 changed files with 107 additions and 1 deletion.
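In short, this commit routes torch.linalg.inv (through the structured linalg_inv_ex op) to a native MPSGraph kernel on macOS 13+, with a CPU fallback on older systems. A minimal usage sketch, mirroring the new test added below and assuming a machine with an MPS device:

import torch

# Assumes an Apple-silicon build of PyTorch with the MPS backend available.
cpu_input = torch.randn(4, 4, device='cpu')
mps_input = cpu_input.to('mps')

cpu_result = torch.linalg.inv(cpu_input)   # CPU reference
mps_result = torch.linalg.inv(mps_input)   # runs on the new MPS kernel

# The MPS result should match the CPU reference (within float32 tolerance).
torch.testing.assert_close(cpu_result, mps_result.cpu())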
3 changes: 3 additions & 0 deletions aten/src/ATen/native/mps/MPSGraphVenturaOps.h
@@ -14,4 +14,7 @@
- (MPSGraphTensor *)argSortWithTensor:(MPSGraphTensor *)tensor
                                 axis:(NSInteger)axis
                                 name:(NSString *)name;

- (MPSGraphTensor *)inverseOfTensor:(MPSGraphTensor *)tensor
                               name:(NSString *)name;
@end
87 changes: 87 additions & 0 deletions aten/src/ATen/native/mps/operations/Inverse.mm
@@ -0,0 +1,87 @@
#include <ATen/ATen.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <torch/library.h>
#include <c10/util/Optional.h>


namespace at {
namespace native {

TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info)
{
  TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
  if (!is_macos_13_or_newer()) {
    // MPSGraph gained inverseOfTensor: in macOS 13 (Ventura); fall back to the
    // CPU implementation on older releases.
    TORCH_WARN_ONCE("torch.linalg_inv_ex.inverse is supported by MPS on macOS 13+, please upgrade. Falling back to CPU.");
    auto cpu_info = at::empty({0}, kInt, c10::nullopt, kCPU, c10::nullopt, c10::nullopt);
    auto cpu_result = result.clone().to("cpu");
    at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu"));
    info.copy_(cpu_info);
    result.copy_(cpu_result);
    return;
  }

  using namespace mps;
  MPSStream* stream = getCurrentMPSStream();
  info.zero_();

  struct CachedGraph : public MPSCachedGraph
  {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  // MPSGraph expects a contiguous output buffer; materialize one if needed
  // and copy back into `result` afterwards.
  Tensor output = result;
  bool isContiguous = true;
  if (!result.is_contiguous()) {
    output = result.contiguous();
    isContiguous = false;
  }

  MPSGraphCache* cache_ = MPSGraphCache::getInstance();

  @autoreleasepool {
    // Graphs are cached by op name plus the input's shape/dtype signature.
    string key = "inv_out_mps" + getTensorsStringKey({A});
    CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
    if (!cachedGraph) {
      MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
        CachedGraph* newCachedGraph = nil;
        @autoreleasepool {
          MPSGraph* mpsGraph = make_mps_graph();
          newCachedGraph = new CachedGraph(mpsGraph);
          MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A);
          MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor
                                                              name:nil];
          newCachedGraph->inputTensor_ = inputTensor;
          newCachedGraph->outputTensor_ = outputTensor;
        }
        return newCachedGraph;
      });
      cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
    }

    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, isContiguous ? result : output);

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()
    };

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
    };

    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
    if (!isContiguous) {
      result.copy_(output);
    }
  }
}
} // namespace native
} // namespace at
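Note the pre-Ventura path above: the kernel warns once and computes the inverse on the CPU before copying results back to the MPS tensors. A hedged sketch of how that fallback surfaces in Python, assuming macOS earlier than 13 (TORCH_WARN_ONCE is typically re-raised as a Python UserWarning, and only on the first call per process):

import warnings
import torch

A = torch.randn(2, 2, device='mps')
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    inv = torch.linalg.inv(A)  # on macOS < 13 this round-trips through the CPU

# On macOS 13+ no warning is emitted and the MPSGraph kernel runs natively.
print([str(w.message) for w in caught])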
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -12897,6 +12897,7 @@
  structured: True
  dispatch:
    CPU, CUDA: linalg_inv_ex_out
    MPS: linalg_inv_ex_out_mps

- func: linalg_inv(Tensor A) -> Tensor
  python_module: linalg
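The dispatch entry above registers the MPS kernel for the structured linalg_inv_ex op, which torch.linalg.inv is implemented on top of. A small sketch of calling it directly (the info tensor is zero when the matrix is invertible):

import torch

A = torch.randn(3, 3, device='mps')
inverse, info = torch.linalg.inv_ex(A)  # structured op registered above
assert int(info) == 0  # 0 means the inversion succeeded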
17 changes: 16 additions & 1 deletion test/test_mps.py
@@ -4729,6 +4729,21 @@ def helper(shape, diag=0):
        helper((2, 8, 4, 5), diag=-2)
        helper((2, 8, 4, 5), diag=-3)

    # Test inverse
    def test_inverse(self):
        def helper(n):
            cpu_input = torch.randn(n, n, device='cpu')
            mps_input = cpu_input.to('mps')

            cpu_result = torch.linalg.inv(cpu_input)
            mps_result = torch.linalg.inv(mps_input)
            self.assertEqual(cpu_result, mps_result)

        helper(2)
        helper(6)
        helper(3)
        helper(8)

    # Test tril
    def test_tril(self):
        def helper(shape, diag=0):
@@ -7796,6 +7811,7 @@ class TestConsistency(TestCase):
        'diag_embed': [torch.uint8],
        'diagonal_scatter': [torch.uint8],
        'index_add': None,
        'linalg.inv': ['f32'],
        'log1p': None,
        'long': None,
        'nn.functional.avg_pool1d': [torch.int64],
@@ -7814,7 +7830,6 @@ class TestConsistency(TestCase):
        'slice_scatter': [torch.uint8],
        'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8],  # moved from section below

        # ALLOW_LIST doesn't know about variants
        'nn.functional.padconstant': None,

Expand Down

0 comments on commit 8ecb49b

Please sign in to comment.