[WebAssembly] Constant fold wasm.dot #149619

Open
wants to merge 5 commits into
base: main
from

Conversation

badumbatish
Contributor

@badumbatish badumbatish commented Jul 18, 2025

Constant fold wasm.dot of constant vectors/splats.

Test case added in llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll

Related to #55933

@llvmbot llvmbot added labels Jul 18, 2025:

  • llvm:instcombine (Covers the InstCombine, InstSimplify and AggressiveInstCombine passes)
  • llvm:analysis (Includes value tracking, cost tables and constant folding)
  • llvm:transforms
@llvmbot
Member

llvmbot commented Jul 18, 2025

@llvm/pr-subscribers-llvm-analysis

Author: Jasmine Tang (badumbatish)

Changes

Constant fold wasm.dot of constant vectors/splats.

Test case added in llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll


Full diff: https://github.com/llvm/llvm-project/pull/149619.diff

2 Files Affected:

  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+31)
  • (added) llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll (+39)
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 9c1c2c6e60f02..2304c58b3f95f 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1657,6 +1657,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
   case Intrinsic::aarch64_sve_convert_from_svbool:
   case Intrinsic::wasm_alltrue:
   case Intrinsic::wasm_anytrue:
+  case Intrinsic::wasm_dot:
   // WebAssembly float semantics are always known
   case Intrinsic::wasm_trunc_signed:
   case Intrinsic::wasm_trunc_unsigned:
@@ -3826,6 +3827,36 @@ static Constant *ConstantFoldFixedVectorCall(
     }
     return ConstantVector::get(Result);
   }
+  case Intrinsic::wasm_dot: {
+    unsigned NumElements =
+        cast<FixedVectorType>(Operands[0]->getType())->getNumElements();
+
+    assert(NumElements == 8 && NumElements / 2 == Result.size() &&
+           "wasm dot takes i16x8 and produce i32x4");
+    assert(Ty->isIntegerTy());
+    SmallVector<APInt, 8> MulVector;
+
+    for (unsigned I = 0; I < NumElements; ++I) {
+      ConstantInt *Elt0 =
+          cast<ConstantInt>(Operands[0]->getAggregateElement(I));
+      ConstantInt *Elt1 =
+          cast<ConstantInt>(Operands[1]->getAggregateElement(I));
+
+      // sext 32 first, according to specs
+      APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32);
+
+      // TODO: imul in specs includes a modulo operation
+      // Is this performed automatically via trunc = true in APInt creation of *
+      MulVector.push_back(IMul);
+    }
+    for (unsigned I = 0; I < Result.size(); ++I) {
+      // Same case as with imul
+      APInt IAdd = MulVector[I] + MulVector[I + Result.size()];
+      Result[I] = ConstantInt::get(Ty, IAdd);
+    }
+
+    return ConstantVector::get(Result);
+  }
   default:
     break;
   }
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
new file mode 100644
index 0000000000000..02c6649becbce
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
@@ -0,0 +1,39 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+
+; RUN: opt -passes=instsimplify -S < %s | FileCheck %s
+
+; Test that intrinsics wasm dot call are constant folded
+
+target triple = "wasm32-unknown-unknown"
+
+
+define <4 x i32> @dot_zero() {
+; CHECK-LABEL: define <4 x i32> @dot_zero() {
+; CHECK-NEXT:    ret <4 x i32> zeroinitializer
+;
+  %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer)
+  ret <4 x i32> %res
+}
+
+; a               =   1    2    3    4    5    6    7    8
+; b               =   1    2    3    4    5    6    7    8
+; k1|k2 = a * b   =   1    4    9   16   25   36   49   64
+; k1 + k2         =   (1+25) |  (4+36) | (9+49)  | (16+64)
+; result          =    26    |   40    |   58    |   80
+define <4 x i32> @dot_nonzero() {
+; CHECK-LABEL: define <4 x i32> @dot_nonzero() {
+; CHECK-NEXT:    ret <4 x i32> <i32 26, i32 40, i32 58, i32 80>
+;
+  %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>, <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @dot_doubly_negative() {
+; CHECK-LABEL: define <4 x i32> @dot_doubly_negative() {
+; CHECK-NEXT:    ret <4 x i32> splat (i32 2)
+;
+  %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>, <8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>)
+  ret <4 x i32> %res
+}
+
+
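A standalone model of the fold may make the lane arithmetic in the diff easier to follow. The sketch below is not LLVM code: `I16x8`, `I32x4`, `Pairing`, `wrappingAdd`, and `foldDot` are hypothetical names, and plain fixed-width integers stand in for `APInt`. `Pairing::SplitHalves` mirrors the indexing in the diff (`MulVector[I] + MulVector[I + Result.size()]`), which is what the `dot_nonzero` expectations of 26, 40, 58, 80 reflect. `Pairing::Adjacent` is included for comparison because the WebAssembly SIMD proposal describes `i32x4.dot_i16x8_s` as adding adjacent pairs of products; which pairing `llvm.wasm.dot` should use is worth confirming against the spec.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using I16x8 = std::array<int16_t, 8>;
using I32x4 = std::array<int32_t, 4>;

// Hypothetical switch between the two ways of pairing the eight products.
enum class Pairing { SplitHalves, Adjacent };

// Add two 32-bit values modulo 2^32 without signed-overflow UB.
// The final conversion back to int32_t is modular in C++20 and on
// two's complement targets in practice.
inline int32_t wrappingAdd(int32_t A, int32_t B) {
  return static_cast<int32_t>(static_cast<uint32_t>(A) +
                              static_cast<uint32_t>(B));
}

// Model of the constant fold: multiply lane-wise after sign extension,
// then add the products two at a time to produce four i32 lanes.
inline I32x4 foldDot(const I16x8 &A, const I16x8 &B, Pairing P) {
  std::array<int32_t, 8> Mul{};
  for (unsigned I = 0; I < 8; ++I)
    // The product of two sign-extended i16 values always fits in 32 bits.
    Mul[I] = static_cast<int32_t>(A[I]) * static_cast<int32_t>(B[I]);

  I32x4 Result{};
  for (unsigned I = 0; I < 4; ++I) {
    if (P == Pairing::SplitHalves)
      Result[I] = wrappingAdd(Mul[I], Mul[I + 4]); // indexing used in the diff
    else
      Result[I] = wrappingAdd(Mul[2 * I], Mul[2 * I + 1]); // adjacent lanes
  }
  return Result;
}
```

With the inputs from `dot_nonzero`, `SplitHalves` yields 26, 40, 58, 80 and `Adjacent` yields 5, 25, 61, 113. The all-zero and all `-1` tests give the same result under either pairing, so they do not distinguish the two.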

@llvmbot
Member

llvmbot commented Jul 18, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Jasmine Tang (badumbatish)

@badumbatish
Contributor Author

@tlively @dschuff Hi! I would love some reviews on the PR.

@badumbatish badumbatish requested a review from tlively July 24, 2025 23:47
Collaborator

@tlively tlively left a comment

Nice! Just a few nits.

// sext 32 first, according to specs
APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32);

// i16 -> i32 bypasses specs modulo on imul
Collaborator

Can you elaborate on what "bypasses specs modulo" means? Would it be correct to say "Do not truncate the 32-bit result of the multiplication"?

Contributor Author

@badumbatish badumbatish Jul 25, 2025

Yes, I was thinking that since we're just multiplying two i16 values together, the modulo 2^32 operation cannot affect the result of the multiplication, hence the word "bypass".

Should I change this to something like "multiplying two originally 16-bit integers doesn't need a modulo of 2^32"?

Contributor

I would just leave out the comment if the spec doesn't mention anything about truncation. No need to worry about it if the spec doesn't!
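The range argument in this thread can be checked mechanically. The sketch below is a standalone illustration, not part of the patch; `I16Min`, `I16Max`, `fitsInI32`, `MaxProduct`, and `MinProduct` are hypothetical names. It computes the extreme products of two i16 values in 64-bit arithmetic and confirms that both fit in 32 bits, so reducing the product modulo 2^32 changes nothing. It also shows that the same guarantee does not extend to the sum of two such products, which is where wraparound can actually occur.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

constexpr int64_t I16Min = std::numeric_limits<int16_t>::min(); // -32768
constexpr int64_t I16Max = std::numeric_limits<int16_t>::max(); //  32767

// True when V is representable as a signed 32-bit integer.
constexpr bool fitsInI32(int64_t V) {
  return V >= std::numeric_limits<int32_t>::min() &&
         V <= std::numeric_limits<int32_t>::max();
}

// Extreme products of two i16 values, computed without any truncation.
constexpr int64_t MaxProduct = I16Min * I16Min; // 2^30
constexpr int64_t MinProduct = I16Min * I16Max; // -(2^30) + 2^15

static_assert(fitsInI32(MaxProduct), "largest i16 product fits in i32");
static_assert(fitsInI32(MinProduct), "smallest i16 product fits in i32");

// Two maximal products sum to 2^31, one past the largest i32 value, so the
// addition step is the only place where modulo 2^32 can change the result.
static_assert(!fitsInI32(MaxProduct + MaxProduct),
              "sum of two maximal products does not fit in i32");
```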

Comment on lines +3834 to +3835
assert(NumElements == 8 && Result.size() == 4 &&
"wasm dot takes i16x8 and produces i32x4");
Contributor

Nit

Suggested change
- assert(NumElements == 8 && Result.size() == 4 &&
- "wasm dot takes i16x8 and produces i32x4");
+ assert(NumElements == 8 && Result.size() == 4 &&
+ "wasm_dot takes i16x8 and produces i32x4");

assert(NumElements == 8 && Result.size() == 4 &&
"wasm dot takes i16x8 and produces i32x4");
assert(Ty->isIntegerTy());
SmallVector<APInt, 8> MulVector;
Contributor

Nit: if we don't need the arbitrary precision of APInt or the dynamic sizing of a vector, we can just use a plain old int32_t array:

Suggested change
- SmallVector<APInt, 8> MulVector;
+ int32_t MulVector[8];

And then you can use ConstantInt::getSExtValue to get the values out. It should be a little bit faster than using an APInt.
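A rough sketch of what the suggested shape could look like, written as standalone C++ so it can be compiled without LLVM. `dotWithPlainArray` is a hypothetical name, and the `int16_t` inputs stand in for the per-lane values that `ConstantInt::getSExtValue` would return. The pairing of lane `I` with lane `I + 4` simply follows the diff under review and is an assumption here. One thing the plain-integer version has to handle itself is the addition: unlike a 32-bit `APInt`, adding two `int32_t` values can overflow, so this sketch widens to 64 bits first and then keeps the low 32 bits.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the fold body using a fixed-size int32_t array
// instead of SmallVector<APInt, 8>.
inline void dotWithPlainArray(const int16_t (&Lhs)[8], const int16_t (&Rhs)[8],
                              int32_t (&Out)[4]) {
  int32_t MulVector[8];
  for (unsigned I = 0; I < 8; ++I)
    // Sign-extended i16 * i16 always fits in int32_t.
    MulVector[I] = static_cast<int32_t>(Lhs[I]) * static_cast<int32_t>(Rhs[I]);

  for (unsigned I = 0; I < 4; ++I) {
    // Widen before adding so the sum itself cannot overflow.
    int64_t Sum = static_cast<int64_t>(MulVector[I]) + MulVector[I + 4];
    // Keep the low 32 bits. The conversion to uint32_t is modular; the
    // conversion back to int32_t is modular in C++20 and on two's
    // complement targets in practice.
    Out[I] = static_cast<int32_t>(static_cast<uint32_t>(Sum));
  }
}
```

With all sixteen input lanes set to -32768, each product is 2^30 and each sum is 2^31, which wraps to `INT32_MIN` in the 32-bit result.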


4 participants