-
Notifications
You must be signed in to change notification settings - Fork 14
NTT-based univariate multiplication #174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
94c2902
7426cc0
7bbb5bb
fcc44e3
04a4665
cc742e0
ba21fde
5136d13
e4885b1
d728353
8a10675
eda9f7d
d79e0d0
b41dcc2
65ffb15
d19099a
e43b06a
a56b9a6
51db7dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| /- | ||
| Copyright (c) 2026 CompPoly. All rights reserved. | ||
| Released under Apache 2.0 license as described in the file LICENSE. | ||
| Authors: Salih Erdem Koçak, Doran Pamukçu | ||
| -/ | ||
| import CompPoly.Univariate.Raw | ||
| import Mathlib.RingTheory.RootsOfUnity.PrimitiveRoots | ||
|
|
||
| /-! | ||
| # NTT Domain | ||
|
|
||
| This file defines the radix-2 NTT domain parameters and basic raw-polynomial | ||
| shape helpers used by forward/inverse NTT. | ||
| -/ | ||
|
|
||
| namespace CompPoly | ||
| namespace CPolynomial | ||
| namespace NTT | ||
|
|
||
| variable {R : Type*} [Field R] | ||
|
|
||
| /-- Parameters for a radix-2 NTT domain of size `2 ^ logN`. -/ | ||
| structure Domain (R : Type*) [Field R] where | ||
| logN : Nat | ||
| omega : R | ||
| primitive : IsPrimitiveRoot omega (2 ^ logN) | ||
| natCast_ne_zero : (((2 ^ logN : Nat) : R) ≠ 0) | ||
|
|
||
| namespace Domain | ||
|
|
||
| /-- Domain size. -/ | ||
| @[simp] def n (D : Domain R) : Nat := 2 ^ D.logN | ||
|
|
||
| /-- Index type for vectors over the domain. -/ | ||
| abbrev Idx (D : Domain R) := Fin D.n | ||
|
|
||
| /-- The `i`-th evaluation node `omega^i`. -/ | ||
| @[inline] def node (D : Domain R) (i : D.Idx) : R := D.omega ^ (i : Nat) | ||
|
|
||
| /-- Inverse root of unity. -/ | ||
| @[inline] def omegaInv (D : Domain R) : R := D.omega⁻¹ | ||
|
|
||
| /-- Multiplicative inverse of the domain size in `R`. -/ | ||
| @[inline] def nInv (D : Domain R) : R := ((D.n : Nat) : R)⁻¹ | ||
|
|
||
| @[simp] lemma n_pos (D : Domain R) : 0 < D.n := by | ||
| simp [n] | ||
|
|
||
| @[simp] lemma n_ne_zero (D : Domain R) : D.n ≠ 0 := by | ||
| exact Nat.ne_of_gt D.n_pos | ||
|
|
||
| section RawHelpers | ||
|
|
||
| variable [BEq R] [LawfulBEq R] | ||
|
|
||
| /-- Required convolution length for multiplying `p` and `q`. -/ | ||
| def requiredLength (p q : CPolynomial.Raw R) : Nat := | ||
| p.trim.size + q.trim.size - 1 | ||
|
|
||
| /-- Whether domain `D` is large enough for multiplying `p` and `q`. -/ | ||
| def fits (D : Domain R) (p q : CPolynomial.Raw R) : Prop := | ||
| requiredLength p q ≤ D.n | ||
|
|
||
| /-- Truncate a polynomial to at most `m` coefficients. -/ | ||
| def truncate (m : Nat) (p : CPolynomial.Raw R) : CPolynomial.Raw R := | ||
| p.extract 0 m | ||
|
|
||
| end RawHelpers | ||
| end Domain | ||
|
|
||
| end NTT | ||
| end CPolynomial | ||
| end CompPoly | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| /- | ||
| Copyright (c) 2026 CompPoly. All rights reserved. | ||
| Released under Apache 2.0 license as described in the file LICENSE. | ||
| Authors: Salih Erdem Koçak, Doran Pamukçu | ||
| -/ | ||
| import CompPoly.Univariate.NTT.Domain | ||
| import CompPoly.Univariate.NTT.Forward | ||
| import CompPoly.Univariate.NTT.Inverse | ||
| import CompPoly.Univariate.Raw | ||
|
|
||
| /-! | ||
| # Fast Multiplication via NTT | ||
|
|
||
| This file wires forward NTT, pointwise multiplication, and inverse NTT into a | ||
| spec/implementation pipeline. | ||
| -/ | ||
|
|
||
| namespace CompPoly | ||
| namespace CPolynomial | ||
| namespace NTT | ||
| namespace FastMul | ||
|
|
||
| variable {R : Type*} [Field R] | ||
|
|
||
| /-- Pointwise multiplication in evaluation form. -/ | ||
| @[inline] def pointwiseMul (D : Domain R) (a b : Array R) : Array R := | ||
| Array.ofFn (fun i : D.Idx => a.getD i.1 0 * b.getD i.1 0) | ||
|
|
||
| @[simp] theorem size_pointwiseMul (D : Domain R) (a b : Array R) : | ||
| (pointwiseMul D a b).size = D.n := by | ||
| simp [pointwiseMul] | ||
|
|
||
| section RawMul | ||
|
|
||
| variable [BEq R] [LawfulBEq R] | ||
|
|
||
| /-- Spec pipeline for NTT-based multiplication. -/ | ||
| @[inline] def fastMulSpec (D : Domain R) (p q : CPolynomial.Raw R) : CPolynomial.Raw R := | ||
| let pHat := Forward.forwardSpec D p | ||
| let qHat := Forward.forwardSpec D q | ||
| let cHat := pointwiseMul D pHat qHat | ||
| let c := Inverse.inverseSpec D cHat | ||
| (Domain.truncate (Domain.requiredLength p q) c).trim | ||
|
|
||
| /-- Implementation pipeline for NTT-based multiplication. -/ | ||
| @[inline] def fastMulImpl (D : Domain R) (p q : CPolynomial.Raw R) : CPolynomial.Raw R := | ||
| let pHat := Forward.forwardImpl D p | ||
| let qHat := Forward.forwardImpl D q | ||
| let cHat := pointwiseMul D pHat qHat | ||
| let c := Inverse.inverseImpl D cHat | ||
| (Domain.truncate (Domain.requiredLength p q) c).trim | ||
|
|
||
| theorem fastMulImpl_correct (D : Domain R) (p q : CPolynomial.Raw R) : | ||
| fastMulImpl D p q = fastMulSpec D p q := by | ||
| simp [fastMulImpl, fastMulSpec, Forward.forwardImpl_correct, Inverse.inverseImpl_correct] | ||
|
|
||
| theorem fastMulSpec_coeff (D : Domain R) (p q : CPolynomial.Raw R) (i : Nat) : | ||
| (fastMulSpec D p q).coeff i = (p * q).coeff i := by | ||
| sorry | ||
|
|
||
| theorem fastMulSpec_eq_mul (D : Domain R) (p q : CPolynomial.Raw R) | ||
| (hfit : Domain.fits D p q) : fastMulSpec D p q = p * q := by | ||
| sorry | ||
|
|
||
| theorem fastMulImpl_eq_mul (D : Domain R) (p q : CPolynomial.Raw R) | ||
| (hfit : Domain.fits D p q) : fastMulImpl D p q = p * q := by | ||
| sorry | ||
|
|
||
| end RawMul | ||
|
|
||
| end FastMul | ||
| end NTT | ||
| end CPolynomial | ||
| end CompPoly |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| /- | ||
| Copyright (c) 2026 CompPoly. All rights reserved. | ||
| Released under Apache 2.0 license as described in the file LICENSE. | ||
| Authors: Salih Erdem Koçak, Doran Pamukçu | ||
| -/ | ||
| import CompPoly.Univariate.NTT.Domain | ||
| import CompPoly.Data.Nat.Bitwise | ||
|
|
||
| /-! | ||
| # Forward NTT | ||
|
|
||
| This file provides spec-level forward NTT definitions together with an | ||
| iterative radix-2 implementation. | ||
| -/ | ||
|
|
||
| open scoped BigOperators | ||
|
|
||
| namespace CompPoly | ||
| namespace CPolynomial | ||
| namespace NTT | ||
| namespace Forward | ||
|
|
||
| variable {R : Type*} [Field R] | ||
|
|
||
| /-- DFT/NTT formula at one output index. -/ | ||
| @[inline] def nttAt (D : Domain R) (a : Array R) (k : D.Idx) : R := | ||
| ∑ j : D.Idx, a.getD j.1 0 * D.omega ^ ((k : Nat) * (j : Nat)) | ||
|
|
||
| /-- Full forward transform specified directly from the NTT formula. -/ | ||
| @[inline] def forwardSpec (D : Domain R) (a : Array R) : Array R := | ||
| Array.ofFn (fun k : D.Idx => nttAt D a k) | ||
|
|
||
| /-- Reverse the lowest `bits` bits of `i`. -/ | ||
| def bitRevNat : Nat → Nat → Nat | ||
| | 0, _ => 0 | ||
| | bits + 1, i => ((i &&& 1) <<< bits) ||| bitRevNat bits (i >>> 1) | ||
|
|
||
| /-- Apply bit-reversal permutation to an evaluation array. -/ | ||
| def bitRevPermute (D : Domain R) (a : Array R) : Array R := | ||
| Array.ofFn (fun i : D.Idx => a.getD (bitRevNat D.logN i.1) 0) | ||
|
|
||
| /-- One butterfly stage of the iterative radix-2 transform. -/ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it looks like |
||
| def butterflyStage (D : Domain R) (stage : Nat) (a : Array R) : Array R := Id.run do | ||
| let blockSize : Nat := 2 ^ (stage + 1) | ||
| let half : Nat := 2 ^ stage | ||
| let wm := D.omega ^ (D.n / blockSize) | ||
| let mut acc := a | ||
| for block in [0:D.n / blockSize] do | ||
| let base := block * blockSize | ||
| let mut w : R := 1 | ||
| for j in [0:half] do | ||
| let i0 := base + j | ||
| let i1 := i0 + half | ||
| let u := acc.getD i0 0 | ||
| let t := w * acc.getD i1 0 | ||
| acc := acc.set! i0 (u + t) | ||
| acc := acc.set! i1 (u - t) | ||
| w := w * wm | ||
| return acc | ||
|
|
||
| /-- Run all radix-2 butterfly stages (complexity: `O(n log n)`). -/ | ||
| def runStages (D : Domain R) (a : Array R) : Array R := Id.run do | ||
| let mut acc := a | ||
| for stage in [0:D.logN] do | ||
| acc := butterflyStage D stage acc | ||
| return acc | ||
|
|
||
| /-- Intended fast implementation entry point for NTT. -/ | ||
| @[inline] def forwardImpl (D : Domain R) (p : CPolynomial.Raw R) : Array R := | ||
| runStages D (bitRevPermute D p) | ||
|
|
||
| @[simp] theorem size_forwardSpec (D : Domain R) (a : Array R) : | ||
| (forwardSpec D a).size = D.n := by | ||
| simp [forwardSpec] | ||
|
|
||
| theorem forwardImpl_correct (D : Domain R) (p : CPolynomial.Raw R) : | ||
| forwardImpl D p = forwardSpec D p := by | ||
| -- TODO: Prove the iterative radix-2 implementation matches the direct NTT formula. | ||
| sorry | ||
|
|
||
| end Forward | ||
| end NTT | ||
| end CPolynomial | ||
| end CompPoly | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| /- | ||
| Copyright (c) 2026 CompPoly. All rights reserved. | ||
| Released under Apache 2.0 license as described in the file LICENSE. | ||
| Authors: Salih Erdem Koçak, Doran Pamukçu | ||
| -/ | ||
| import CompPoly.Univariate.NTT.Domain | ||
| import CompPoly.Univariate.NTT.Forward | ||
|
|
||
| /-! | ||
| # Inverse NTT | ||
|
|
||
| This file provides inverse NTT APIs and correctness statement. | ||
| -/ | ||
|
|
||
| open scoped BigOperators | ||
|
|
||
| namespace CompPoly | ||
| namespace CPolynomial | ||
| namespace NTT | ||
| namespace Inverse | ||
|
|
||
| variable {R : Type*} [Field R] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it might be worth trying to weaken the field assumption, if possible. we can always generalize results once they're proved though so don't worry too much about this right away |
||
|
|
||
| /-- Inverse NTT formula at one output index. -/ | ||
| def inttAt (D : Domain R) (v : Array R) (k : D.Idx) : R := | ||
| D.nInv * ∑ j : D.Idx, v.getD j.1 0 * D.omegaInv ^ ((k : Nat) * (j : Nat)) | ||
|
|
||
| /-- Full inverse transform on arrays, specified from `inttAt`. -/ | ||
| def inverseSpec (D : Domain R) (v : Array R) : Array R := | ||
| Array.ofFn (fun k : D.Idx => inttAt D v k) | ||
|
|
||
| /-- Apply bit-reversal permutation to an evaluation array. -/ | ||
| def bitRevPermute (D : Domain R) (a : Array R) : Array R := | ||
| Array.ofFn (fun i : D.Idx => a.getD (Forward.bitRevNat D.logN i.1) 0) | ||
|
|
||
| /-- One butterfly stage of the iterative radix-2 inverse transform. -/ | ||
| def butterflyStage (D : Domain R) (stage : Nat) (a : Array R) : Array R := Id.run do | ||
| let blockSize : Nat := 2 ^ (stage + 1) | ||
| let half : Nat := 2 ^ stage | ||
| let wm := D.omegaInv ^ (D.n / blockSize) | ||
| let mut acc := a | ||
| for block in [0:D.n / blockSize] do | ||
| let base := block * blockSize | ||
| let mut w : R := 1 | ||
| for j in [0:half] do | ||
| let i0 := base + j | ||
| let i1 := i0 + half | ||
| let u := acc.getD i0 0 | ||
| let t := w * acc.getD i1 0 | ||
| acc := acc.set! i0 (u + t) | ||
| acc := acc.set! i1 (u - t) | ||
| w := w * wm | ||
| return acc | ||
|
|
||
| /-- Run all radix-2 inverse butterfly stages. -/ | ||
| def runStages (D : Domain R) (a : Array R) : Array R := Id.run do | ||
| let mut acc := a | ||
| for stage in [0:D.logN] do | ||
| acc := butterflyStage D stage acc | ||
| return acc | ||
|
|
||
| /-- Apply the final `n⁻¹` normalization for the inverse transform. -/ | ||
| def normalize (D : Domain R) (a : Array R) : Array R := | ||
| Array.ofFn (fun i : D.Idx => D.nInv * a.getD i.1 0) | ||
|
|
||
| @[simp] theorem size_inverseSpec (D : Domain R) (v : Array R) : | ||
| (inverseSpec D v).size = D.n := by | ||
| simp [inverseSpec] | ||
|
|
||
| @[simp] theorem size_normalize (D : Domain R) (a : Array R) : | ||
| (normalize D a).size = D.n := by | ||
| simp [normalize] | ||
|
|
||
| /-- Intended fast implementation entry point for inverse NTT. -/ | ||
| def inverseImpl (D : Domain R) (v : Array R) : CPolynomial.Raw R := | ||
| normalize D (runStages D (bitRevPermute D v)) | ||
|
|
||
| theorem inverseImpl_correct (D : Domain R) (v : Array R) : | ||
| inverseImpl D v = inverseSpec D v := by | ||
| -- TODO: Prove the iterative radix-2 inverse implementation matches the direct formula. | ||
| sorry | ||
|
|
||
| end Inverse | ||
| end NTT | ||
| end CPolynomial | ||
| end CompPoly | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note that in Nat,
0 - 1 = 0which can lead to undesirable behavior. worth checking that this underflow won't affect the rest