Skip to content

Commit e7698ff

Browse files
kurtamohler authored and
pytorchmergebot committed
[MPS] Move abs op to Metal (pytorch#155474)
Pull Request resolved: pytorch#155474 Approved by: https://github.com/Skylion007, https://github.com/malfet
1 parent 7a48cc6 commit e7698ff

File tree

5 files changed

+24
-34
lines changed

5 files changed

+24
-34
lines changed

aten/src/ATen/native/mps/kernels/UnaryKernel.metal

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ struct sigmoid_functor {
4343
}
4444
};
4545

46+
struct abs_functor {
47+
template <typename T, enable_if_t<!is_complex_v<T>, bool> = true>
48+
inline T operator()(const T x) {
49+
return static_cast<T>(precise::abs(x));
50+
}
51+
template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
52+
inline T operator()(const T x) {
53+
return T(::precise::sqrt(dot(x, x)), 0);
54+
}
55+
};
56+
4657
struct sin_functor {
4758
template <typename T>
4859
inline enable_if_t<is_scalar_floating_point_v<T>, T> operator()(const T x) {
@@ -321,6 +332,14 @@ REGISTER_UNARY_OP(bitwise_not, char, char);
321332
REGISTER_UNARY_OP(bitwise_not, uchar, uchar);
322333
REGISTER_UNARY_OP(bitwise_not, bool, bool);
323334

335+
REGISTER_UNARY_OP(abs, int, int);
336+
REGISTER_UNARY_OP(abs, long, long);
337+
REGISTER_UNARY_OP(abs, short, short);
338+
REGISTER_UNARY_OP(abs, char, char);
339+
REGISTER_UNARY_OP(abs, uchar, uchar);
340+
REGISTER_UNARY_OP(abs, float, float);
341+
REGISTER_UNARY_OP(abs, half, half);
342+
324343
#define INSTANTIATE_UNARY_KERNELS2(DTYPE0, DTYPE1) \
325344
REGISTER_UNARY_OP(erf, DTYPE1, DTYPE0); \
326345
REGISTER_UNARY_OP(erfc, DTYPE1, DTYPE0); \
@@ -343,6 +362,7 @@ REGISTER_UNARY_OP(bitwise_not, bool, bool);
343362
#if __METAL_VERSION__ >= 310
344363
INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat);
345364
REGISTER_UNARY_OP(neg, bfloat, bfloat);
365+
REGISTER_UNARY_OP(abs, bfloat, bfloat);
346366
#endif
347367
INSTANTIATE_UNARY_KERNELS2(half, half);
348368
INSTANTIATE_UNARY_KERNELS2(float, float);
@@ -357,6 +377,7 @@ INSTANTIATE_UNARY_KERNELS2(float, long);
357377
REGISTER_UNARY_OP(neg, DTYPE##2, DTYPE##2); \
358378
REGISTER_UNARY_OP(exp, DTYPE##2, DTYPE##2); \
359379
REGISTER_UNARY_OP(sigmoid, DTYPE##2, DTYPE##2); \
380+
REGISTER_UNARY_OP(abs, DTYPE##2, DTYPE##2); \
360381
REGISTER_UNARY_OP(exp2, DTYPE##2, DTYPE##2); \
361382
REGISTER_UNARY_OP(log, DTYPE##2, DTYPE##2); \
362383
REGISTER_UNARY_OP(log10, DTYPE##2, DTYPE##2); \

aten/src/ATen/native/mps/operations/UnaryKernel.mm

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ static void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) {
3030
REGISTER_UNARY_TI_DISPATCH(erfinv);
3131
REGISTER_UNARY_TI_DISPATCH(sinc);
3232
REGISTER_UNARY_TI_DISPATCH(tanh);
33+
REGISTER_UNARY_TI_DISPATCH(abs);
3334
REGISTER_UNARY_TI_DISPATCH(sin);
3435
REGISTER_UNARY_TI_DISPATCH(cos);
3536
REGISTER_UNARY_TI_DISPATCH(tan);

aten/src/ATen/native/mps/operations/UnaryOps.mm

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include <ATen/NativeFunctions.h>
1313
#else
1414
#include <ATen/ops/_copy_from_and_resize.h>
15-
#include <ATen/ops/abs_native.h>
1615
#include <ATen/ops/acos_native.h>
1716
#include <ATen/ops/acosh_native.h>
1817
#include <ATen/ops/angle_native.h>
@@ -206,36 +205,6 @@ static void unary_op(const Tensor& self,
206205
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(acosh_out_mps, acosh)
207206
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atanh_out_mps, atanh)
208207

209-
Tensor& abs_out_mps(const Tensor& self, Tensor& output) {
210-
using namespace mps;
211-
212-
if (!output.is_same_size(self)) {
213-
output.resize_(self.sizes());
214-
}
215-
216-
if (self.numel() == 0) {
217-
return output;
218-
}
219-
220-
if (supportsComplex() || !self.is_complex()) {
221-
unary_op_noresize(self, output, "abs_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
222-
auto rc = [mpsGraph absoluteWithTensor:inputTensor name:nil];
223-
if (self.is_complex()) {
224-
rc = [mpsGraph realPartOfTensor:rc name:nil];
225-
}
226-
return rc;
227-
});
228-
} else {
229-
Tensor realInput = at::view_as_real(self);
230-
unary_op_noresize(
231-
realInput, output, "abs_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
232-
auto rc = lengthOfComplexAsReal(mpsGraph, inputTensor);
233-
return [mpsGraph reshapeTensor:rc withShape:getMPSShape(output) name:nil];
234-
});
235-
}
236-
return output;
237-
}
238-
239208
Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
240209
auto bool_self = self.to(ScalarType::Bool);
241210
mps::unary_op(bool_self, output, "logical_not_out_mps", [](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,7 @@
357357
- func: abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
358358
device_check: NoCheck # TensorIterator
359359
dispatch:
360-
CPU, CUDA: abs_out
361-
MPS: abs_out_mps
360+
CPU, CUDA, MPS: abs_out
362361
SparseCPU, SparseCUDA: abs_sparse_out
363362
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr_out
364363
tags: pointwise

torch/testing/_internal/common_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3852,7 +3852,7 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad
38523852
unittest.expectedFailure,
38533853
'TestModule',
38543854
'test_memory_format',
3855-
active_if=operator.itemgetter('training'),
3855+
active_if=operator.itemgetter('training') and not _macos15_or_newer,
38563856
device_type='mps',
38573857
),)
38583858
),

0 commit comments

Comments (0)