Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions CompPoly/Univariate/NTT/Domain.lean
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,31 @@ def inverse (D : Domain R) : Domain R where

section RawHelpers

variable [BEq R] [LawfulBEq R]
variable [BEq R]

/-- Required convolution length for multiplying `p` and `q`. -/
def requiredLength (p q : CPolynomial.Raw R) : Nat :=
p.trim.size + q.trim.size - 1
if p.trim.size = 0 ∨ q.trim.size = 0 then
0
else
p.trim.size + q.trim.size - 1

@[simp] theorem requiredLength_eq_zero_of_left_trim_size_zero
(p q : CPolynomial.Raw R) (hp : p.trim.size = 0) :
requiredLength p q = 0 := by
simp [requiredLength, hp]

@[simp] theorem requiredLength_eq_zero_of_right_trim_size_zero
(p q : CPolynomial.Raw R) (hq : q.trim.size = 0) :
requiredLength p q = 0 := by
simp [requiredLength, hq]

theorem requiredLength_eq_of_trim_size_pos
(p q : CPolynomial.Raw R) (hp : 0 < p.trim.size) (hq : 0 < q.trim.size) :
requiredLength p q = p.trim.size + q.trim.size - 1 := by
have hp0 : p.trim.size ≠ 0 := Nat.ne_of_gt hp
have hq0 : q.trim.size ≠ 0 := Nat.ne_of_gt hq
simp [requiredLength, hp0, hq0]

/-- Whether domain `D` is large enough for multiplying `p` and `q`. -/
def fits (D : Domain R) (p q : CPolynomial.Raw R) : Prop :=
Expand Down
130 changes: 22 additions & 108 deletions CompPoly/Univariate/NTT/FastMul.lean
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,11 @@ private theorem coeff_truncate (m : Nat) (a : CPolynomial.Raw R) (i : Nat) :
· simp [hi]

private theorem mul_coeff_eq_zero_of_requiredLength_le
(p q : CPolynomial.Raw R) {i : Nat} (hi : Domain.requiredLength p q ≤ i) :
(p q : CPolynomial.Raw R) (hppos : 0 < p.trim.size) (hqpos : 0 < q.trim.size)
{i : Nat} (hi : Domain.requiredLength p q ≤ i) :
(p * q).coeff i = 0 := by
have hreq : p.trim.size + q.trim.size - 1 ≤ i := by
simpa [Domain.requiredLength_eq_of_trim_size_pos p q hppos hqpos] using hi
rw [CPolynomial.Raw.mul_coeff]
apply Finset.sum_eq_zero
intro x hx
Expand All @@ -216,7 +219,6 @@ private theorem mul_coeff_eq_zero_of_requiredLength_le
simp [hp0]
· have hxlt : x < p.trim.size := Nat.lt_of_not_ge hpx
have hqle : q.trim.size ≤ i - x := by
simp [Domain.requiredLength] at hi
omega
have hq0 : q.coeff (i - x) = 0 := coeff_zero_of_trim_size_le q hqle
simp [hq0]
Expand All @@ -239,20 +241,6 @@ private theorem mul_coeff_eq_zero_of_right_trim_size_zero
have hq0 : q.coeff (i - x) = 0 := coeff_zero_of_trim_size_le q (by omega)
simp [hq0]

private theorem natDegree_toPoly_lt_of_trim_size_le
(D : Domain R) (a : CPolynomial.Raw R) (ha : a.trim.size ≤ D.n) :
a.toPoly.natDegree < D.n := by
by_cases hzero : a.toPoly = 0
· rw [hzero]
exact D.n_pos
· have hround := CPolynomial.Raw.toImpl_toPoly (R := R) a
have hsize : a.toPoly.toImpl.size = a.trim.size := congrArg Array.size hround
rcases CPolynomial.Raw.toImpl_elim a.toPoly with ⟨hz, _himpl⟩ | ⟨_hnz, himpl⟩
· exact (hzero hz).elim
· have himpl_size : a.toPoly.toImpl.size = a.toPoly.natDegree + 1 := by
simp [himpl]
omega

private theorem natDegree_toPoly_lt_trim_size_of_pos
(a : CPolynomial.Raw R) (ha : 0 < a.trim.size) :
a.toPoly.natDegree < a.trim.size := by
Expand Down Expand Up @@ -295,10 +283,7 @@ private theorem raw_eval_mul (x : R) (p q : CPolynomial.Raw R) :
rw [← CPolynomial.Raw.eval_toPoly_eq_eval x (p * q)]
rw [← CPolynomial.Raw.eval_toPoly_eq_eval x p]
rw [← CPolynomial.Raw.eval_toPoly_eq_eval x q]
have hpoly : (p * q).toPoly = p.toPoly * q.toPoly := by
ext i
exact CPolynomial.Raw.toPoly_mul_coeff p q i
rw [hpoly]
rw [CPolynomial.Raw.toPoly_mul p q]
simp

private theorem pointwise_forwardSpec_eq_forwardSpec_mul_of_natDegree_lt
Expand All @@ -314,7 +299,8 @@ private theorem pointwise_forwardSpec_eq_forwardSpec_mul_of_natDegree_lt
let k : D.Idx := ⟨i, hiD⟩
have hpsize : i < (Forward.forwardSpec D p).size := by simpa [Forward.forwardSpec] using hiD
have hqsize : i < (Forward.forwardSpec D q).size := by simpa [Forward.forwardSpec] using hiD
have hpqsize : i < (Forward.forwardSpec D (p * q)).size := by simpa [Forward.forwardSpec] using hiD
have hpqsize : i < (Forward.forwardSpec D (p * q)).size := by
simpa [Forward.forwardSpec] using hiD
have hpget : (Forward.forwardSpec D p)[i]'hpsize = p.eval (D.node k) := by
simpa [k] using forwardSpec_get_eq_eval_of_natDegree_lt D p hpdeg k
have hqget : (Forward.forwardSpec D q)[i]'hqsize = q.eval (D.node k) := by
Expand All @@ -324,71 +310,6 @@ private theorem pointwise_forwardSpec_eq_forwardSpec_mul_of_natDegree_lt
simp [pointwiseMul]
rw [hpget, hqget, hpqget, raw_eval_mul]

private theorem forwardSpec_getD_eq_zero_of_trim_size_zero
(D : Domain R) (p : CPolynomial.Raw R) (hp : p.trim.size = 0) (i : Nat) :
(Forward.forwardSpec D p).getD i 0 = 0 := by
by_cases hi : i < (Forward.forwardSpec D p).size
· rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_getElem hi]
simp only [Option.getD_some]
have hiD : i < D.n := by simpa [Forward.forwardSpec] using hi
let k : D.Idx := ⟨i, hiD⟩
change (Forward.forwardSpec D p)[k.1] = 0
simp [Forward.forwardSpec, Forward.nttAt]
apply Finset.sum_eq_zero
intro j _
have hp0 : p.coeff j.1 = 0 := coeff_zero_of_trim_size_le p (by omega)
simp [CPolynomial.Raw.coeff] at hp0
simp [hp0]
· rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_none (Nat.le_of_not_lt hi)]
simp

omit [BEq R] [LawfulBEq R] in
private theorem inverseSpec_getD_eq_zero_of_getD_zero
(D : Domain R) (a : Array R) (ha : ∀ i : Nat, a.getD i 0 = 0) (i : Nat) :
(Inverse.inverseSpec D a).getD i 0 = 0 := by
by_cases hi : i < (Inverse.inverseSpec D a).size
· rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_getElem hi]
simp only [Option.getD_some]
simp only [Inverse.inverseSpec, Inverse.inttAt, Array.getElem_ofFn]
have hsum : (∑ j : D.Idx, a.getD j.1 0 * D.omegaInv ^ (i * j.1)) = 0 := by
apply Finset.sum_eq_zero
intro j _
simp [ha j.1]
rw [hsum]
simp
· rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_none (Nat.le_of_not_lt hi)]
simp

private theorem pointwise_getD_eq_zero_of_left_trim_size_zero
(D : Domain R) (p q : CPolynomial.Raw R) (hp : p.trim.size = 0) (i : Nat) :
(pointwiseMul D (Forward.forwardSpec D p) (Forward.forwardSpec D q)).getD i 0 = 0 := by
by_cases hi : i < (pointwiseMul D (Forward.forwardSpec D p) (Forward.forwardSpec D q)).size
· rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_getElem hi]
simp [pointwiseMul]
left
have hpsize : i < (Forward.forwardSpec D p).size := by
simpa [Forward.forwardSpec, pointwiseMul] using hi
have hpget := forwardSpec_getD_eq_zero_of_trim_size_zero D p hp i
rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_getElem hpsize] at hpget
simpa using hpget
· rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_none (Nat.le_of_not_lt hi)]
simp

private theorem pointwise_getD_eq_zero_of_right_trim_size_zero
(D : Domain R) (p q : CPolynomial.Raw R) (hq : q.trim.size = 0) (i : Nat) :
(pointwiseMul D (Forward.forwardSpec D p) (Forward.forwardSpec D q)).getD i 0 = 0 := by
by_cases hi : i < (pointwiseMul D (Forward.forwardSpec D p) (Forward.forwardSpec D q)).size
· rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_getElem hi]
simp [pointwiseMul]
right
have hqsize : i < (Forward.forwardSpec D q).size := by
simpa [Forward.forwardSpec, pointwiseMul] using hi
have hqget := forwardSpec_getD_eq_zero_of_trim_size_zero D q hq i
rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_getElem hqsize] at hqget
simpa using hqget
· rw [Array.getD_eq_getD_getElem?, Array.getElem?_eq_none (Nat.le_of_not_lt hi)]
simp

/-- Spec pipeline for NTT-based multiplication. -/
@[inline] def fastMulSpec (D : Domain R) (p q : CPolynomial.Raw R) : CPolynomial.Raw R :=
let pHat := Forward.forwardSpec D p
Expand All @@ -403,29 +324,24 @@ private theorem fastMulSpec_coeff_eq_zero_of_left_trim_size_zero
rw [fastMulSpec]
rw [CPolynomial.Raw.Trim.coeff_eq_coeff]
rw [coeff_truncate]
by_cases hi : i < Domain.requiredLength p q
· rw [if_pos hi]
rw [CPolynomial.Raw.coeff]
exact inverseSpec_getD_eq_zero_of_getD_zero D
(pointwiseMul D (Forward.forwardSpec D p) (Forward.forwardSpec D q))
(pointwise_getD_eq_zero_of_left_trim_size_zero D p q hp) i
· rw [if_neg hi]
rw [Domain.requiredLength_eq_zero_of_left_trim_size_zero p q hp]
simp

private theorem fastMulSpec_coeff_eq_zero_of_right_trim_size_zero
(D : Domain R) (p q : CPolynomial.Raw R) (hq : q.trim.size = 0) (i : Nat) :
(fastMulSpec D p q).coeff i = 0 := by
rw [fastMulSpec]
rw [CPolynomial.Raw.Trim.coeff_eq_coeff]
rw [coeff_truncate]
by_cases hi : i < Domain.requiredLength p q
· rw [if_pos hi]
rw [CPolynomial.Raw.coeff]
exact inverseSpec_getD_eq_zero_of_getD_zero D
(pointwiseMul D (Forward.forwardSpec D p) (Forward.forwardSpec D q))
(pointwise_getD_eq_zero_of_right_trim_size_zero D p q hq) i
· rw [if_neg hi]

/-- Implementation pipeline for NTT-based multiplication. -/
rw [Domain.requiredLength_eq_zero_of_right_trim_size_zero p q hq]
simp

/--
Implementation pipeline for NTT-based multiplication.

Callers must provide a domain satisfying `Domain.fits D p q`; otherwise the
result is truncated to the domain-supported convolution length.
-/
@[inline] def fastMulImpl (D : Domain R) (p q : CPolynomial.Raw R) : CPolynomial.Raw R :=
let pHat := Forward.forwardImpl D p
let qHat := Forward.forwardImpl D q
Expand All @@ -452,18 +368,15 @@ theorem fastMulSpec_coeff (D : Domain R) (p q : CPolynomial.Raw R)
have hfit' : Domain.requiredLength p q ≤ D.n := by
simpa [Domain.fits] using hfit
have hfitLen : p.trim.size + q.trim.size - 1 ≤ D.n := by
simpa [Domain.requiredLength] using hfit'
simpa [Domain.requiredLength_eq_of_trim_size_pos p q hppos hqpos] using hfit'
have hpdeg_lt_trim := natDegree_toPoly_lt_trim_size_of_pos p hppos
have hqdeg_lt_trim := natDegree_toPoly_lt_trim_size_of_pos q hqpos
have hpdeg : p.toPoly.natDegree < D.n := by
omega
have hqdeg : q.toPoly.natDegree < D.n := by
omega
have hpqdeg : (p * q).toPoly.natDegree < D.n := by
have hpoly : (p * q).toPoly = p.toPoly * q.toPoly := by
ext j
exact CPolynomial.Raw.toPoly_mul_coeff p q j
rw [hpoly]
rw [CPolynomial.Raw.toPoly_mul p q]
refine lt_of_le_of_lt Polynomial.natDegree_mul_le ?_
omega
rw [fastMulSpec]
Expand All @@ -478,7 +391,8 @@ theorem fastMulSpec_coeff (D : Domain R) (p q : CPolynomial.Raw R)
rw [hpoint]
exact inverse_forwardSpec_coeff_of_lt D (p * q) hiD
· rw [if_neg hi]
exact (mul_coeff_eq_zero_of_requiredLength_le p q (Nat.le_of_not_lt hi)).symm
exact (mul_coeff_eq_zero_of_requiredLength_le p q hppos hqpos
(Nat.le_of_not_lt hi)).symm

theorem fastMulSpec_eq_mul (D : Domain R) (p q : CPolynomial.Raw R)
(hfit : Domain.fits D p q) : fastMulSpec D p q = p * q := by
Expand Down
21 changes: 8 additions & 13 deletions CompPoly/Univariate/NTT/Forward.lean
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ private def forwardMathPairsSpec
simp [forwardStageSpec, bitRevPermute]
| succ completed ih =>
rw [forwardStageSpec_succ]
exact size_butterflyStage D completed (forwardStageSpec D completed a) ih
rw [size_butterflyStage]
exact ih

@[simp] theorem size_forwardStagePureSpec (D : Domain R) (completed : Nat) (a : Array R) :
(forwardStagePureSpec D completed a).size = D.n := by
Expand All @@ -145,7 +146,8 @@ private def forwardMathPairsSpec
simp [forwardStagePureSpec, bitRevPermute]
| succ completed ih =>
rw [forwardStagePureSpec_succ]
exact size_butterflyStageSpec D completed (forwardStagePureSpec D completed a) ih
rw [size_butterflyStageSpec]
exact ih

/--
The algorithmic stage recursion agrees with the pure stage recursion.
Expand All @@ -162,7 +164,6 @@ theorem forwardStageSpec_eq_forwardStagePureSpec (D : Domain R) (a : Array R) :
| succ completed ih =>
rw [forwardStageSpec_succ, forwardStagePureSpec_succ, ih]
rw [butterflyStage_eq_butterflyStageSpec D completed (forwardStagePureSpec D completed a)]
exact size_forwardStagePureSpec D completed a

/--
Base case of the mathematical stage invariant: before any butterflies, the
Expand Down Expand Up @@ -371,12 +372,6 @@ private theorem omega_pow_domain_half_eq_neg_one
exact IsPrimitiveRoot.pow D.n_pos D.primitive hprod
exact IsPrimitiveRoot.eq_neg_one_of_two_right hprim2

private theorem omega_pow_stage_stride_eq_neg_one
(D : Domain R) (stage : Nat) (hstage : stage < D.logN) :
D.omega ^ (2 ^ stage * 2 ^ (D.logN - (stage + 1))) = -1 := by
rw [stage_stride_half_eq_domain_half D stage hstage]
exact omega_pow_domain_half_eq_neg_one D (by omega)

private theorem forwardMathPairsSpec_get_lower_current
(D : Domain R) (stage block j : Nat) (a : Array R) (hj : j < 2 ^ stage)
(hi : block * 2 ^ (stage + 1) + j < (forwardMathPairsSpec D stage block j a).size) :
Expand Down Expand Up @@ -540,7 +535,7 @@ private theorem forwardMathValueAt_succ_upper
ring

private theorem eq_lower_or_upper_of_block_pair
(stage block j i : Nat) (_hj : j < 2 ^ stage)
(stage block j i : Nat)
(hblock : i / 2 ^ (stage + 1) = block) (hpair : i % 2 ^ stage = j) :
i = block * 2 ^ (stage + 1) + j ∨
i = block * 2 ^ (stage + 1) + j + 2 ^ stage := by
Expand Down Expand Up @@ -583,7 +578,7 @@ private theorem eq_lower_or_upper_of_block_pair
omega

private theorem forwardMathPairsSpec_get_unchanged
(D : Domain R) (stage block j : Nat) (a : Array R) (hj : j < 2 ^ stage)
(D : Domain R) (stage block j : Nat) (a : Array R)
{i : Nat}
(hiOld : i < (forwardMathPairsSpec D stage block j a).size)
(hiNew : i < (forwardMathPairsSpec D stage block (j + 1) a).size)
Expand All @@ -603,7 +598,7 @@ private theorem forwardMathPairsSpec_get_unchanged
rw [if_neg hltPair]
by_cases hltPairNext : i % 2 ^ stage < j + 1
· have hpair : i % 2 ^ stage = j := by omega
rcases eq_lower_or_upper_of_block_pair stage block j i hj hEqBlock hpair with h | h
rcases eq_lower_or_upper_of_block_pair stage block j i hEqBlock hpair with h | h
· exact (hneLower h.symm).elim
· exact (hneUpper h.symm).elim
· rw [if_neg hltPairNext]
Expand Down Expand Up @@ -708,7 +703,7 @@ private theorem butterflyInnerStep_forwardMathPairsSpec_succ
rw [forwardMathPairsSpec_get_lower_next D stage block donePairs a hdonePairs hi₂]
exact (forwardMathValueAt_succ_lower D stage block donePairs a hstage hdonePairs).symm
· rw [if_neg hLower]
exact forwardMathPairsSpec_get_unchanged D stage block donePairs a hdonePairs
exact forwardMathPairsSpec_get_unchanged D stage block donePairs a
hi₁ hi₂ hLower hUpper
· rw [pow_succ]

Expand Down
24 changes: 9 additions & 15 deletions CompPoly/Univariate/NTT/Transform.lean
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,8 @@ specification.
This is where the local `set!` bookkeeping for one stage belongs.
-/
theorem butterflyStage_eq_butterflyStageSpec
(D : Domain R) (stage : Nat) (a : Array R) (ha : a.size = D.n) :
(D : Domain R) (stage : Nat) (a : Array R) :
butterflyStage D stage a = butterflyStageSpec D stage a := by
have _ := ha
let blockSize : Nat := 2 ^ (stage + 1)
let half : Nat := 2 ^ stage
let wm := D.omega ^ (D.n / blockSize)
Expand Down Expand Up @@ -209,19 +208,14 @@ private theorem size_butterflyStageSpec_aux
| n + 1, acc => by
simp [size_butterflyStageSpec_aux blockSize half wm n acc]

theorem size_butterflyStageSpec (D : Domain R) (stage : Nat) (a : Array R) (ha : a.size = D.n) :
(butterflyStageSpec D stage a).size = D.n := by
let blockSize : Nat := 2 ^ (stage + 1)
let half : Nat := 2 ^ stage
let wm := D.omega ^ (D.n / blockSize)
rw [show (butterflyStageSpec D stage a).size = a.size by
simp [butterflyStageSpec, size_butterflyStageSpec_aux]]
exact ha

theorem size_butterflyStage (D : Domain R) (stage : Nat) (a : Array R) (ha : a.size = D.n) :
(butterflyStage D stage a).size = D.n := by
rw [butterflyStage_eq_butterflyStageSpec D stage a ha]
exact size_butterflyStageSpec D stage a ha
theorem size_butterflyStageSpec (D : Domain R) (stage : Nat) (a : Array R) :
(butterflyStageSpec D stage a).size = a.size := by
simp [butterflyStageSpec, size_butterflyStageSpec_aux]

theorem size_butterflyStage (D : Domain R) (stage : Nat) (a : Array R) :
(butterflyStage D stage a).size = a.size := by
rw [butterflyStage_eq_butterflyStageSpec D stage a]
exact size_butterflyStageSpec D stage a

/-- Run all radix-2 butterfly stages (complexity: `O(n log n)`). -/
def runStages (D : Domain R) (a : Array R) : Array R := Id.run do
Expand Down
6 changes: 6 additions & 0 deletions CompPoly/Univariate/ToPoly/Equiv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ lemma Raw.toPoly_mul_coeff [LawfulBEq R] (p q : CPolynomial.Raw R) (i : ℕ) :
rcases h_coeff (i - x) with ⟨_, hq⟩
simp [hp, hq]

@[grind =]
lemma Raw.toPoly_mul [LawfulBEq R] (p q : CPolynomial.Raw R) :
(p * q).toPoly = p.toPoly * q.toPoly := by
ext i
exact Raw.toPoly_mul_coeff p q i

@[grind =]
lemma toPoly_mul_coeffC [LawfulBEq R] (p q : CPolynomial R) (i : ℕ) :
(p.val * q.val).toPoly.coeff i = (p.val.toPoly * q.val.toPoly).coeff i := by
Expand Down
2 changes: 1 addition & 1 deletion tests/CompPolyTests/Univariate/NTT/FastMul.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import CompPolyTests.Univariate.NTT.Common
/-!
# Univariate NTT FastMul Tests

Concrete executable checks for the temporary spec-backed NTT multiplication path.
Concrete executable checks for the iterative butterfly NTT multiplication path.
-/

namespace CompPoly
Expand Down
2 changes: 1 addition & 1 deletion tests/CompPolyTests/Univariate/NTT/Forward.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import CompPolyTests.Univariate.NTT.Common
/-!
# Univariate NTT Forward Tests

Concrete executable checks for the temporary spec-backed forward NTT path.
Concrete executable checks for the iterative butterfly forward NTT path.
-/

namespace CompPoly
Expand Down
2 changes: 1 addition & 1 deletion tests/CompPolyTests/Univariate/NTT/Inverse.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import CompPolyTests.Univariate.NTT.Common
/-!
# Univariate NTT Inverse Tests

Concrete executable checks for the temporary spec-backed inverse NTT path.
Concrete executable checks for the iterative butterfly inverse NTT path.
-/

namespace CompPoly
Expand Down
Loading