-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[WebAssembly] Constant fold wasm.dot #149619
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-analysis Author: Jasmine Tang (badumbatish) ChangesConstant fold wasm.dot of constant vectors/splats. Test case added in Full diff: https://github.com/llvm/llvm-project/pull/149619.diff 2 Files Affected:
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 9c1c2c6e60f02..2304c58b3f95f 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1657,6 +1657,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::aarch64_sve_convert_from_svbool:
case Intrinsic::wasm_alltrue:
case Intrinsic::wasm_anytrue:
+ case Intrinsic::wasm_dot:
// WebAssembly float semantics are always known
case Intrinsic::wasm_trunc_signed:
case Intrinsic::wasm_trunc_unsigned:
@@ -3826,6 +3827,36 @@ static Constant *ConstantFoldFixedVectorCall(
}
return ConstantVector::get(Result);
}
+ case Intrinsic::wasm_dot: {
+ unsigned NumElements =
+ cast<FixedVectorType>(Operands[0]->getType())->getNumElements();
+
+ assert(NumElements == 8 && NumElements / 2 == Result.size() &&
+ "wasm dot takes i16x8 and produce i32x4");
+ assert(Ty->isIntegerTy());
+ SmallVector<APInt, 8> MulVector;
+
+ for (unsigned I = 0; I < NumElements; ++I) {
+ ConstantInt *Elt0 =
+ cast<ConstantInt>(Operands[0]->getAggregateElement(I));
+ ConstantInt *Elt1 =
+ cast<ConstantInt>(Operands[1]->getAggregateElement(I));
+
+ // sext 32 first, according to specs
+ APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32);
+
+ // TODO: imul in specs includes a modulo operation
+ // Is this performed automatically via trunc = true in APInt creation of *
+ MulVector.push_back(IMul);
+ }
+ for (unsigned I = 0; I < Result.size(); ++I) {
+ // Same case as with imul
+ APInt IAdd = MulVector[I] + MulVector[I + Result.size()];
+ Result[I] = ConstantInt::get(Ty, IAdd);
+ }
+
+ return ConstantVector::get(Result);
+ }
default:
break;
}
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
new file mode 100644
index 0000000000000..02c6649becbce
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
@@ -0,0 +1,39 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+
+; RUN: opt -passes=instsimplify -S < %s | FileCheck %s
+
+; Test that intrinsics wasm dot call are constant folded
+
+target triple = "wasm32-unknown-unknown"
+
+
+define <4 x i32> @dot_zero() {
+; CHECK-LABEL: define <4 x i32> @dot_zero() {
+; CHECK-NEXT: ret <4 x i32> zeroinitializer
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer)
+ ret <4 x i32> %res
+}
+
+; a = 1 2 3 4 5 6 7 8
+; b = 1 2 3 4 5 6 7 8
+; k1|k2 = a * b = 1 4 9 16 25 36 49 64
+; k1 + k2 = (1+25) | (4+36) | (9+49) | (16+64)
+; result = 26 | 40 | 58 | 80
+define <4 x i32> @dot_nonzero() {
+; CHECK-LABEL: define <4 x i32> @dot_nonzero() {
+; CHECK-NEXT: ret <4 x i32> <i32 26, i32 40, i32 58, i32 80>
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>, <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>)
+ ret <4 x i32> %res
+}
+
+define <4 x i32> @dot_doubly_negative() {
+; CHECK-LABEL: define <4 x i32> @dot_doubly_negative() {
+; CHECK-NEXT: ret <4 x i32> splat (i32 2)
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>, <8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>)
+ ret <4 x i32> %res
+}
+
+
|
@llvm/pr-subscribers-llvm-transforms Author: Jasmine Tang (badumbatish) ChangesConstant fold wasm.dot of constant vectors/splats. Test case added in Full diff: https://github.com/llvm/llvm-project/pull/149619.diff 2 Files Affected:
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 9c1c2c6e60f02..2304c58b3f95f 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1657,6 +1657,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::aarch64_sve_convert_from_svbool:
case Intrinsic::wasm_alltrue:
case Intrinsic::wasm_anytrue:
+ case Intrinsic::wasm_dot:
// WebAssembly float semantics are always known
case Intrinsic::wasm_trunc_signed:
case Intrinsic::wasm_trunc_unsigned:
@@ -3826,6 +3827,36 @@ static Constant *ConstantFoldFixedVectorCall(
}
return ConstantVector::get(Result);
}
+ case Intrinsic::wasm_dot: {
+ unsigned NumElements =
+ cast<FixedVectorType>(Operands[0]->getType())->getNumElements();
+
+ assert(NumElements == 8 && NumElements / 2 == Result.size() &&
+ "wasm dot takes i16x8 and produce i32x4");
+ assert(Ty->isIntegerTy());
+ SmallVector<APInt, 8> MulVector;
+
+ for (unsigned I = 0; I < NumElements; ++I) {
+ ConstantInt *Elt0 =
+ cast<ConstantInt>(Operands[0]->getAggregateElement(I));
+ ConstantInt *Elt1 =
+ cast<ConstantInt>(Operands[1]->getAggregateElement(I));
+
+ // sext 32 first, according to specs
+ APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32);
+
+ // TODO: imul in specs includes a modulo operation
+ // Is this performed automatically via trunc = true in APInt creation of *
+ MulVector.push_back(IMul);
+ }
+ for (unsigned I = 0; I < Result.size(); ++I) {
+ // Same case as with imul
+ APInt IAdd = MulVector[I] + MulVector[I + Result.size()];
+ Result[I] = ConstantInt::get(Ty, IAdd);
+ }
+
+ return ConstantVector::get(Result);
+ }
default:
break;
}
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
new file mode 100644
index 0000000000000..02c6649becbce
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
@@ -0,0 +1,39 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+
+; RUN: opt -passes=instsimplify -S < %s | FileCheck %s
+
+; Test that intrinsics wasm dot call are constant folded
+
+target triple = "wasm32-unknown-unknown"
+
+
+define <4 x i32> @dot_zero() {
+; CHECK-LABEL: define <4 x i32> @dot_zero() {
+; CHECK-NEXT: ret <4 x i32> zeroinitializer
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer)
+ ret <4 x i32> %res
+}
+
+; a = 1 2 3 4 5 6 7 8
+; b = 1 2 3 4 5 6 7 8
+; k1|k2 = a * b = 1 4 9 16 25 36 49 64
+; k1 + k2 = (1+25) | (4+36) | (9+49) | (16+64)
+; result = 26 | 40 | 58 | 80
+define <4 x i32> @dot_nonzero() {
+; CHECK-LABEL: define <4 x i32> @dot_nonzero() {
+; CHECK-NEXT: ret <4 x i32> <i32 26, i32 40, i32 58, i32 80>
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>, <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>)
+ ret <4 x i32> %res
+}
+
+define <4 x i32> @dot_doubly_negative() {
+; CHECK-LABEL: define <4 x i32> @dot_doubly_negative() {
+; CHECK-NEXT: ret <4 x i32> splat (i32 2)
+;
+ %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>, <8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>)
+ ret <4 x i32> %res
+}
+
+
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Just a few nits.
// sext 32 first, according to specs | ||
APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32); | ||
|
||
// i16 -> i32 bypasses specs modulo on imul |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you elaborate on what "bypasses specs modulo" means? Would it be correct to say "Do not truncate the 32-bit result of the multiplication"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, i was thinking since we're just multiplying two i16 together, there's no way that the modulo operation with 2^32 on this multiplication will affect the result, thus the word bypass.
Should i change this to something like multilplying two originally 16bit integers doesn't need a modulo of 2^32
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would just leave out the comment if the spec doesn't mention anything about truncation. No need to worry about it if the spec doesn't!
assert(NumElements == 8 && Result.size() == 4 && | ||
"wasm dot takes i16x8 and produces i32x4"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit
assert(NumElements == 8 && Result.size() == 4 && | |
"wasm dot takes i16x8 and produces i32x4"); | |
assert(NumElements == 8 && Result.size() == 4 && | |
"wasm_dot takes i16x8 and produces i32x4"); |
assert(NumElements == 8 && Result.size() == 4 && | ||
"wasm dot takes i16x8 and produces i32x4"); | ||
assert(Ty->isIntegerTy()); | ||
SmallVector<APInt, 8> MulVector; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit, if we don't need the arbitrary precisioness of APInt or the dynamicness of a vector, we can just use a plain old int32_t array:
SmallVector<APInt, 8> MulVector; | |
int32_t MulVector[8]; |
And then you can use ConstantInt::getSExtValue
to get the values out. It should be a little bit faster than using an APInt.
Constant fold wasm.dot of constant vectors/splats.
Test case added in
llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
Related to #55933