|
| 1 | +#include <ATen/ATen.h> |
| 2 | +#include <ATen/native/mps/OperationUtils.h> |
| 3 | +#include <ATen/native/mps/MPSGraphVenturaOps.h> |
| 4 | +#include <torch/library.h> |
| 5 | +#include <c10/util/Optional.h> |
| 6 | + |
| 7 | + |
| 8 | +namespace at { |
| 9 | +namespace native { |
| 10 | + |
// Computes the inverse of a (batched) square matrix `A` into `result`,
// writing per-batch LAPACK-style status codes into `info`.
//
// MPS path (macOS 13+): uses MPSGraph's inverseOfTensor: op. NOTE(review):
// this path never populates `info` — it is unconditionally zeroed below,
// because MPSGraph reports no singularity information; `check_errors` is
// therefore a no-op on the graph path. Confirm whether a post-hoc
// singularity check is required for strict parity with CPU/CUDA.
//
// CPU fallback (macOS < 13): delegates to at::linalg_inv_ex_out on CPU
// copies and copies the results back to the MPS tensors.
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info)
{
  TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
  if (!is_macos_13_or_newer()) {
    TORCH_WARN_ONCE("torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU.");
    // linalg_inv_ex_out resizes `info` as needed, so an empty int tensor suffices.
    auto cpu_info = at::empty({0}, kInt, c10::nullopt, kCPU, c10::nullopt, c10::nullopt);
    auto cpu_result = result.clone().to("cpu");
    // Forward `check_errors` so the fallback raises on singular inputs when
    // the caller asked for error checking (previously it was silently dropped
    // and the default `false` was always used).
    at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu"), check_errors);
    info.copy_(cpu_info);
    result.copy_(cpu_result);
    return;
  }

  using namespace mps;
  MPSStream* stream = getCurrentMPSStream();
  // The graph path has no error reporting; mark every batch entry as success.
  info.zero_();

  // Cached-graph record: one placeholder for the input matrix and one for
  // its inverse.
  struct CachedGraph : public MPSCachedGraph
  {
    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  // MPSGraph writes into a contiguous buffer; stage through a contiguous
  // temporary when `result` is not contiguous and copy back afterwards.
  Tensor output = result;
  bool isContiguous = true;
  if (!result.is_contiguous()) {
    output = result.contiguous();
    isContiguous = false;
  }

  MPSGraphCache* cache_ = MPSGraphCache::getInstance();

  @autoreleasepool {
    // Key the cached graph on the input tensor's shape/dtype signature.
    string key = "inv_out_mps" + getTensorsStringKey({A});
    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
    if(!cachedGraph)
    {
      MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {

        CachedGraph *newCachedGraph = nil;
        @autoreleasepool {
          MPSGraph* mpsGraph = make_mps_graph();
          newCachedGraph = new CachedGraph(mpsGraph);
          MPSGraphTensor* inputTensor= mpsGraphRankedPlaceHolder(mpsGraph, A);
          // inverseOfTensor: requires macOS 13 (Ventura); guarded above.
          MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor: inputTensor
                                                              name: nil];

          newCachedGraph->inputTensor_ = inputTensor;
          newCachedGraph->outputTensor_ = outputTensor;
        }

        return newCachedGraph;

      });
      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
    }

    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, isContiguous ? result : output);

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData()
    };

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
    };

    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
    // Materialize the staged contiguous inverse into the caller's strided output.
    if (!isContiguous) {
      result.copy_(output);
    }
  }
}
| 86 | +} |
| 87 | +} |
0 commit comments