Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CompPoly.lean
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ import CompPoly.Multivariate.Restrict
import CompPoly.Multivariate.Unlawful
import CompPoly.Multivariate.VarsDegrees
import CompPoly.Multivariate.Wheels
import CompPoly.Univariate.Basic
import CompPoly.Univariate.Lagrange
import CompPoly.ToMathlib.Finsupp.Fin
import CompPoly.ToMathlib.MvPolynomial.Equiv
import CompPoly.ToMathlib.Polynomial.BivariateDegree
Expand All @@ -84,6 +86,10 @@ import CompPoly.Univariate.Raw.Core
import CompPoly.Univariate.Raw.Division
import CompPoly.Univariate.Raw.Ops
import CompPoly.Univariate.Raw.Proofs
import CompPoly.Univariate.NTT.Domain
import CompPoly.Univariate.NTT.Forward
import CompPoly.Univariate.NTT.Inverse
import CompPoly.Univariate.NTT.FastMul
import CompPoly.Univariate.ToPoly
import CompPoly.Univariate.ToPoly.Core
import CompPoly.Univariate.ToPoly.Degree
Expand Down
73 changes: 73 additions & 0 deletions CompPoly/Univariate/NTT/Domain.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/-
Copyright (c) 2026 CompPoly. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Salih Erdem Koçak, Doran Pamukçu
-/
import CompPoly.Univariate.Raw
import Mathlib.RingTheory.RootsOfUnity.PrimitiveRoots

/-!
# NTT Domain

This file defines the radix-2 NTT domain parameters and basic raw-polynomial
shape helpers used by forward/inverse NTT.
-/

namespace CompPoly
namespace CPolynomial
namespace NTT

variable {R : Type*} [Field R]

/-- Parameters for a radix-2 NTT domain of size `2 ^ logN`. -/
structure Domain (R : Type*) [Field R] where
logN : Nat
omega : R
primitive : IsPrimitiveRoot omega (2 ^ logN)
natCast_ne_zero : (((2 ^ logN : Nat) : R) ≠ 0)

namespace Domain

/-- Domain size. -/
@[simp] def n (D : Domain R) : Nat := 2 ^ D.logN

/-- Index type for vectors over the domain. -/
abbrev Idx (D : Domain R) := Fin D.n

/-- The `i`-th evaluation node `omega^i`. -/
@[inline] def node (D : Domain R) (i : D.Idx) : R := D.omega ^ (i : Nat)

/-- Inverse root of unity. -/
@[inline] def omegaInv (D : Domain R) : R := D.omega⁻¹

/-- Multiplicative inverse of the domain size in `R`. -/
@[inline] def nInv (D : Domain R) : R := ((D.n : Nat) : R)⁻¹

@[simp] lemma n_pos (D : Domain R) : 0 < D.n := by
simp [n]

@[simp] lemma n_ne_zero (D : Domain R) : D.n ≠ 0 := by
exact Nat.ne_of_gt D.n_pos

section RawHelpers

variable [BEq R] [LawfulBEq R]

/-- Required convolution length for multiplying `p` and `q`. -/
def requiredLength (p q : CPolynomial.Raw R) : Nat :=
p.trim.size + q.trim.size - 1

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that in Nat, 0 - 1 = 0 which can lead to undesirable behavior. worth checking that this underflow won't affect the rest

/-- Whether domain `D` is large enough for multiplying `p` and `q`. -/
def fits (D : Domain R) (p q : CPolynomial.Raw R) : Prop :=
requiredLength p q ≤ D.n

/-- Truncate a polynomial to at most `m` coefficients. -/
def truncate (m : Nat) (p : CPolynomial.Raw R) : CPolynomial.Raw R :=
p.extract 0 m

end RawHelpers
end Domain

end NTT
end CPolynomial
end CompPoly
74 changes: 74 additions & 0 deletions CompPoly/Univariate/NTT/FastMul.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/-
Copyright (c) 2026 CompPoly. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Salih Erdem Koçak, Doran Pamukçu
-/
import CompPoly.Univariate.NTT.Domain
import CompPoly.Univariate.NTT.Forward
import CompPoly.Univariate.NTT.Inverse
import CompPoly.Univariate.Raw

/-!
# Fast Multiplication via NTT

This file wires forward NTT, pointwise multiplication, and inverse NTT into a
spec/implementation pipeline.
-/

namespace CompPoly
namespace CPolynomial
namespace NTT
namespace FastMul

variable {R : Type*} [Field R]

/-- Pointwise multiplication in evaluation form. -/
@[inline] def pointwiseMul (D : Domain R) (a b : Array R) : Array R :=
Array.ofFn (fun i : D.Idx => a.getD i.1 0 * b.getD i.1 0)

@[simp] theorem size_pointwiseMul (D : Domain R) (a b : Array R) :
(pointwiseMul D a b).size = D.n := by
simp [pointwiseMul]

section RawMul

variable [BEq R] [LawfulBEq R]

/-- Spec pipeline for NTT-based multiplication. -/
@[inline] def fastMulSpec (D : Domain R) (p q : CPolynomial.Raw R) : CPolynomial.Raw R :=
let pHat := Forward.forwardSpec D p
let qHat := Forward.forwardSpec D q
let cHat := pointwiseMul D pHat qHat
let c := Inverse.inverseSpec D cHat
(Domain.truncate (Domain.requiredLength p q) c).trim

/-- Implementation pipeline for NTT-based multiplication. -/
@[inline] def fastMulImpl (D : Domain R) (p q : CPolynomial.Raw R) : CPolynomial.Raw R :=
let pHat := Forward.forwardImpl D p
let qHat := Forward.forwardImpl D q
let cHat := pointwiseMul D pHat qHat
let c := Inverse.inverseImpl D cHat
(Domain.truncate (Domain.requiredLength p q) c).trim

theorem fastMulImpl_correct (D : Domain R) (p q : CPolynomial.Raw R) :
fastMulImpl D p q = fastMulSpec D p q := by
simp [fastMulImpl, fastMulSpec, Forward.forwardImpl_correct, Inverse.inverseImpl_correct]

theorem fastMulSpec_coeff (D : Domain R) (p q : CPolynomial.Raw R) (i : Nat) :
(fastMulSpec D p q).coeff i = (p * q).coeff i := by
sorry

theorem fastMulSpec_eq_mul (D : Domain R) (p q : CPolynomial.Raw R)
(hfit : Domain.fits D p q) : fastMulSpec D p q = p * q := by
sorry

theorem fastMulImpl_eq_mul (D : Domain R) (p q : CPolynomial.Raw R)
(hfit : Domain.fits D p q) : fastMulImpl D p q = p * q := by
sorry

end RawMul

end FastMul
end NTT
end CPolynomial
end CompPoly
84 changes: 84 additions & 0 deletions CompPoly/Univariate/NTT/Forward.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/-
Copyright (c) 2026 CompPoly. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Salih Erdem Koçak, Doran Pamukçu
-/
import CompPoly.Univariate.NTT.Domain
import CompPoly.Data.Nat.Bitwise

/-!
# Forward NTT

This file provides spec-level forward NTT definitions together with an
iterative radix-2 implementation.
-/

open scoped BigOperators

namespace CompPoly
namespace CPolynomial
namespace NTT
namespace Forward

variable {R : Type*} [Field R]

/-- DFT/NTT formula at one output index. -/
@[inline] def nttAt (D : Domain R) (a : Array R) (k : D.Idx) : R :=
∑ j : D.Idx, a.getD j.1 0 * D.omega ^ ((k : Nat) * (j : Nat))

/-- Full forward transform specified directly from the NTT formula. -/
@[inline] def forwardSpec (D : Domain R) (a : Array R) : Array R :=
Array.ofFn (fun k : D.Idx => nttAt D a k)

/-- Reverse the lowest `bits` bits of `i`. -/
def bitRevNat : Nat → Nat → Nat
| 0, _ => 0
| bits + 1, i => ((i &&& 1) <<< bits) ||| bitRevNat bits (i >>> 1)

/-- Apply bit-reversal permutation to an evaluation array. -/
def bitRevPermute (D : Domain R) (a : Array R) : Array R :=
Array.ofFn (fun i : D.Idx => a.getD (bitRevNat D.logN i.1) 0)

/-- One butterfly stage of the iterative radix-2 transform. -/
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks like bitRevPermute, butterflyStage, and runStages are nearly identical between Forward.lean and Inverse.lean. to avoid code duplication, it's probably a good idea to factor these into a shared definition (parametric over omega/omegaInv) in a shared file. you might want to put all the shared operations in there as well (e.g. including bitRevNat)

def butterflyStage (D : Domain R) (stage : Nat) (a : Array R) : Array R := Id.run do
let blockSize : Nat := 2 ^ (stage + 1)
let half : Nat := 2 ^ stage
let wm := D.omega ^ (D.n / blockSize)
let mut acc := a
for block in [0:D.n / blockSize] do
let base := block * blockSize
let mut w : R := 1
for j in [0:half] do
let i0 := base + j
let i1 := i0 + half
let u := acc.getD i0 0
let t := w * acc.getD i1 0
acc := acc.set! i0 (u + t)
acc := acc.set! i1 (u - t)
w := w * wm
return acc

/-- Run all radix-2 butterfly stages (complexity: `O(n log n)`). -/
def runStages (D : Domain R) (a : Array R) : Array R := Id.run do
let mut acc := a
for stage in [0:D.logN] do
acc := butterflyStage D stage acc
return acc

/-- Intended fast implementation entry point for NTT. -/
@[inline] def forwardImpl (D : Domain R) (p : CPolynomial.Raw R) : Array R :=
runStages D (bitRevPermute D p)

@[simp] theorem size_forwardSpec (D : Domain R) (a : Array R) :
(forwardSpec D a).size = D.n := by
simp [forwardSpec]

theorem forwardImpl_correct (D : Domain R) (p : CPolynomial.Raw R) :
forwardImpl D p = forwardSpec D p := by
-- TODO: Prove the iterative radix-2 implementation matches the direct NTT formula.
sorry

end Forward
end NTT
end CPolynomial
end CompPoly
86 changes: 86 additions & 0 deletions CompPoly/Univariate/NTT/Inverse.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/-
Copyright (c) 2026 CompPoly. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Salih Erdem Koçak, Doran Pamukçu
-/
import CompPoly.Univariate.NTT.Domain
import CompPoly.Univariate.NTT.Forward

/-!
# Inverse NTT

This file provides inverse NTT APIs and correctness statement.
-/

open scoped BigOperators

namespace CompPoly
namespace CPolynomial
namespace NTT
namespace Inverse

variable {R : Type*} [Field R]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it might be worth trying to weaken the field assumption, if possible. we can always generalize results once they're proved though so don't worry too much about this right away


/-- Inverse NTT formula at one output index. -/
def inttAt (D : Domain R) (v : Array R) (k : D.Idx) : R :=
D.nInv * ∑ j : D.Idx, v.getD j.1 0 * D.omegaInv ^ ((k : Nat) * (j : Nat))

/-- Full inverse transform on arrays, specified from `inttAt`. -/
def inverseSpec (D : Domain R) (v : Array R) : Array R :=
Array.ofFn (fun k : D.Idx => inttAt D v k)

/-- Apply bit-reversal permutation to an evaluation array. -/
def bitRevPermute (D : Domain R) (a : Array R) : Array R :=
Array.ofFn (fun i : D.Idx => a.getD (Forward.bitRevNat D.logN i.1) 0)

/-- One butterfly stage of the iterative radix-2 inverse transform. -/
def butterflyStage (D : Domain R) (stage : Nat) (a : Array R) : Array R := Id.run do
let blockSize : Nat := 2 ^ (stage + 1)
let half : Nat := 2 ^ stage
let wm := D.omegaInv ^ (D.n / blockSize)
let mut acc := a
for block in [0:D.n / blockSize] do
let base := block * blockSize
let mut w : R := 1
for j in [0:half] do
let i0 := base + j
let i1 := i0 + half
let u := acc.getD i0 0
let t := w * acc.getD i1 0
acc := acc.set! i0 (u + t)
acc := acc.set! i1 (u - t)
w := w * wm
return acc

/-- Run all radix-2 inverse butterfly stages. -/
def runStages (D : Domain R) (a : Array R) : Array R := Id.run do
let mut acc := a
for stage in [0:D.logN] do
acc := butterflyStage D stage acc
return acc

/-- Apply the final `n⁻¹` normalization for the inverse transform. -/
def normalize (D : Domain R) (a : Array R) : Array R :=
Array.ofFn (fun i : D.Idx => D.nInv * a.getD i.1 0)

@[simp] theorem size_inverseSpec (D : Domain R) (v : Array R) :
(inverseSpec D v).size = D.n := by
simp [inverseSpec]

@[simp] theorem size_normalize (D : Domain R) (a : Array R) :
(normalize D a).size = D.n := by
simp [normalize]

/-- Intended fast implementation entry point for inverse NTT. -/
def inverseImpl (D : Domain R) (v : Array R) : CPolynomial.Raw R :=
normalize D (runStages D (bitRevPermute D v))

theorem inverseImpl_correct (D : Domain R) (v : Array R) :
inverseImpl D v = inverseSpec D v := by
-- TODO: Prove the iterative radix-2 inverse implementation matches the direct formula.
sorry

end Inverse
end NTT
end CPolynomial
end CompPoly
3 changes: 3 additions & 0 deletions tests/CompPolyTests.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ import CompPolyTests.Univariate.Raw
import CompPolyTests.Univariate.Basic
import CompPolyTests.Univariate.Linear
import CompPolyTests.Univariate.ToPoly
import CompPolyTests.Univariate.NTT.Forward
import CompPolyTests.Univariate.NTT.Inverse
import CompPolyTests.Univariate.NTT.FastMul
import CompPolyTests.Bivariate.Basic
import CompPolyTests.Bivariate.Degree
import CompPolyTests.Bivariate.WeightedDegree
Expand Down
Loading
Loading