Skip to content

Commit 8ecb49b

Browse files
kulinseth authored and pytorchmergebot committed
[MPS] Add Inverse op. (pytorch#90428)
Pull Request resolved: pytorch#90428 Approved by: https://github.com/DenisVieriu97, https://github.com/malfet
1 parent 58b5a9d commit 8ecb49b

File tree

4 files changed

+107
-1
lines changed

4 files changed

+107
-1
lines changed

aten/src/ATen/native/mps/MPSGraphVenturaOps.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,7 @@
1414
- (MPSGraphTensor *)argSortWithTensor:(MPSGraphTensor *)tensor
1515
axis:(NSInteger)axis
1616
name:(NSString *)name;
17+
18+
- (MPSGraphTensor *)inverseOfTensor: (MPSGraphTensor *)tensor
19+
name:(NSString *)name;
1720
@end
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/native/mps/OperationUtils.h>
3+
#include <ATen/native/mps/MPSGraphVenturaOps.h>
4+
#include <torch/library.h>
5+
#include <c10/util/Optional.h>
6+
7+
8+
namespace at {
9+
namespace native {
10+
11+
// Computes the inverse of the (batched) square matrix `A` into `result` using
// the MPSGraph `inverseOfTensor:` kernel, available on macOS 13 (Ventura) and
// newer. On older systems the computation falls back to the CPU implementation.
// `info` is zeroed on the MPS path; singularity reporting (`check_errors`) is
// not performed there — NOTE(review): confirm this matches the structured-op
// contract for `linalg_inv_ex`.
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info)
{
  TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");

  // `inverseOfTensor:` only exists in the macOS 13+ MPSGraph; run on CPU and
  // copy the results back when it is unavailable.
  if (!is_macos_13_or_newer()) {
    TORCH_WARN_ONCE("torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU.");
    auto cpu_info = at::empty({0}, kInt, c10::nullopt, kCPU, c10::nullopt, c10::nullopt);
    // `.to("cpu")` already materializes a CPU copy; the previous
    // `result.clone().to("cpu")` added a redundant on-device clone whose data
    // is immediately overwritten by the out= call below.
    auto cpu_result = result.to("cpu");
    at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu"));
    info.copy_(cpu_info);
    result.copy_(cpu_result);
    return;
  }

  using namespace mps;
  MPSStream* stream = getCurrentMPSStream();
  info.zero_();

  // Cached graph shape: input placeholder -> inverseOfTensor -> output.
  struct CachedGraph : public MPSCachedGraph
  {
    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  // The graph writes into a dense buffer; when `result` is non-contiguous,
  // stage through a contiguous temporary and copy back afterwards.
  Tensor output = result;
  bool isContiguous = true;
  if (!result.is_contiguous()) {
    output = result.contiguous();
    isContiguous = false;
  }

  MPSGraphCache* cache_ = MPSGraphCache::getInstance();

  @autoreleasepool {
    string key = "inv_out_mps" + getTensorsStringKey({A});
    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
    if (!cachedGraph) {
      MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
        CachedGraph *newCachedGraph = nil;
        @autoreleasepool {
          MPSGraph* mpsGraph = make_mps_graph();
          newCachedGraph = new CachedGraph(mpsGraph);
          MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A);
          MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor
                                                              name:nil];
          newCachedGraph->inputTensor_ = inputTensor;
          newCachedGraph->outputTensor_ = outputTensor;
        }
        return newCachedGraph;
      });
      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
    }

    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, isContiguous ? result : output);

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()
    };

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
    };

    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
    if (!isContiguous) {
      result.copy_(output);
    }
  }
}
86+
}
87+
}

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12897,6 +12897,7 @@
1289712897
structured: True
1289812898
dispatch:
1289912899
CPU, CUDA: linalg_inv_ex_out
12900+
MPS: linalg_inv_ex_out_mps
1290012901

1290112902
- func: linalg_inv(Tensor A) -> Tensor
1290212903
python_module: linalg

test/test_mps.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4729,6 +4729,21 @@ def helper(shape, diag=0):
47294729
helper((2, 8, 4, 5), diag=-2)
47304730
helper((2, 8, 4, 5), diag=-3)
47314731

4732+
# Test inverse
def test_inverse(self):
    def helper(n):
        """Invert a random n x n matrix on CPU and on MPS and compare results."""
        matrix_cpu = torch.randn(n, n, device='cpu')
        matrix_mps = matrix_cpu.to('mps')

        inverse_cpu = torch.linalg.inv(matrix_cpu)
        inverse_mps = torch.linalg.inv(matrix_mps)
        self.assertEqual(inverse_cpu, inverse_mps)

    # Same sizes, in the same order, as the original test.
    for size in (2, 6, 3, 8):
        helper(size)
4746+
47324747
# Test tril
47334748
def test_tril(self):
47344749
def helper(shape, diag=0):
@@ -7796,6 +7811,7 @@ class TestConsistency(TestCase):
77967811
'diag_embed': [torch.uint8],
77977812
'diagonal_scatter': [torch.uint8],
77987813
'index_add': None,
7814+
'linalg.inv': ['f32'],
77997815
'log1p': None,
78007816
'long': None,
78017817
'nn.functional.avg_pool1d': [torch.int64],
@@ -7814,7 +7830,6 @@ class TestConsistency(TestCase):
78147830
'slice_scatter': [torch.uint8],
78157831
'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8], # moved from section below
78167832

7817-
78187833
# ALLOW_LIST doesn't know about variants
78197834
'nn.functional.padconstant': None,
78207835

0 commit comments

Comments
 (0)