Skip to content

Commit fdd3e20

Browse files
Revert "[MPS] Fix binary ops between int32 tensor with int64 scalar (pytorch#80220)"
This reverts commit a6556ef. Reverted pytorch#80220 on behalf of https://github.com/malfet due to Did not push the final version of commit
1 parent e98e7fe commit fdd3e20

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

aten/src/ATen/native/mps/operations/BinaryOps.mm

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
3838

3939
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
4040
@autoreleasepool {
41-
string key = op_name + getTensorsStringKey({self, other, output}, /*use_scalar_value*/ false);
41+
string key = op_name + getTensorsStringKey({self, other}, /*use_scalar_value*/ false);
4242
BinaryOpCachedGraph* cachedGraph = static_cast<BinaryOpCachedGraph *>(cache_->LookUp(key));
4343

4444
if(!cachedGraph) {
@@ -62,9 +62,6 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
6262
secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
6363
}
6464
newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
65-
if (output.scalar_type() != common_dtype) {
66-
newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, output.scalar_type());
67-
}
6865
}
6966
return newCachedGraph;
7067
});

0 commit comments

Comments
 (0)