diff --git a/CompPoly/Data/Fin/BigOperators.lean b/CompPoly/Data/Fin/BigOperators.lean index 7aa35245..590a762f 100644 --- a/CompPoly/Data/Fin/BigOperators.lean +++ b/CompPoly/Data/Fin/BigOperators.lean @@ -165,77 +165,50 @@ This is useful for definitions that process elements in reverse order, like `fol convert motive_next_ind termination_by (r - 1 - i.val) --- The theorem statement and its proof. --- TODO: state a more generalized and reusable version of this, where f is from Fin r → M /-- -Splits a sum over `Fin (2^n)` into a sum over even indices and a sum over odd indices. -Useful for Fast Fourier Transform (FFT) type recursions. +Splits a sum over `Fin (2 * r)` into a sum over even indices and a sum over odd indices. +Generalizes `Fin.sum_univ_odd_even` from `2^n` to arbitrary `r`. -/ -theorem Fin.sum_univ_odd_even {n : ℕ} {M : Type*} [AddCommMonoid M] (f : ℕ → M) : - (∑ i : Fin (2 ^ n), f (2 * i)) + (∑ i : Fin (2 ^ n), f (2 * i + 1)) - = ∑ i: Fin (2 ^ (n+1)), f i := by +theorem Fin.sum_univ_even_odd {r : ℕ} {M : Type*} [AddCommMonoid M] (f : ℕ → M) : + (∑ i : Fin r, f (2 * i)) + (∑ i : Fin r, f (2 * i + 1)) + = ∑ i : Fin (2 * r), f i := by set f_even := fun i => f (2 * i) set f_odd := fun i => f (2 * i + 1) conv_lhs => - enter [1, 2, i] - change f_even i + enter [1, 2, i]; change f_even i conv_lhs => - enter [2, 2, i] - change f_odd i + enter [2, 2, i]; change f_odd i simp only [Fin.sum_univ_eq_sum_range] - - -- Let's define the sets of even and odd numbers. - let evens: Finset ℕ := Finset.image (fun i ↦ 2 * i) (Finset.range (2^n)) - let odds: Finset ℕ := Finset.image (fun i ↦ 2 * i + 1) (Finset.range (2^n)) - + let evens : Finset ℕ := Finset.image (fun i ↦ 2 * i) (Finset.range r) + let odds : Finset ℕ := Finset.image (fun i ↦ 2 * i + 1) (Finset.range r) conv_lhs => - enter [1]; - rw [←Finset.sum_image (g:=fun i => 2 * i) (by simp)] - + enter [1]; rw [← Finset.sum_image (g := fun i => 2 * i) (by simp)] conv_lhs => - enter [2]; - rw [← Finset.sum_image (g:=fun i => 2 * i + 1) (by simp)] - - -- First, we prove that the set on the RHS is the disjoint union of evens and odds. + enter [2]; rw [← Finset.sum_image (g := fun i => 2 * i + 1) (by simp)] have h_disjoint : Disjoint evens odds := by apply Finset.disjoint_iff_ne.mpr - -- Assume for contradiction that an element `x` is in both sets. rintro x hx y hy hxy - -- Unpack the definitions of `evens` and `odds`. rcases Finset.mem_image.mp hx with ⟨k₁, _, rfl⟩ rcases Finset.mem_image.mp hy with ⟨k₂, _, rfl⟩ omega - - have h_union : evens ∪ odds = Finset.range (2 ^ (n + 1)) := by - apply Finset.ext; intro x - simp only [Finset.mem_union, Finset.mem_range] - -- ⊢ x ∈ evens ∨ x ∈ odds ↔ x < 2 ^ (n + 1) + have h_union : evens ∪ odds = Finset.range (2 * r) := by + ext x; simp only [Finset.mem_union, Finset.mem_range] constructor - · -- First direction: `x ∈ evens ∪ odds → x < 2^(n+1)` - -- This follows from the bounds of the original range `Finset.range (2^n)`. - intro h - rcases h with (h_even | h_odd) - · rcases Finset.mem_image.mp h_even with ⟨k₁, hk₁, rfl⟩ - simp at hk₁ - omega - · rcases Finset.mem_image.mp h_odd with ⟨k₂, hk₂, rfl⟩ - simp at hk₂ - omega - · -- Second direction: `x < 2^(n+1) → x ∈ evens ∪ odds` - intro hx + · rintro (h_even | h_odd) + · rcases Finset.mem_image.mp h_even with ⟨k₁, hk₁, rfl⟩; simp at hk₁; omega + · rcases Finset.mem_image.mp h_odd with ⟨k₂, hk₂, rfl⟩; simp at hk₂; omega + · intro hx obtain (⟨k, rfl⟩ | ⟨k, rfl⟩) := Nat.even_or_odd x - · left; - unfold evens - simp only [Finset.mem_image, Finset.mem_range] - use k; - omega - · right; - unfold odds - simp only [Finset.mem_image, Finset.mem_range] - use k; - omega - -- Now, rewrite the RHS using this partition. - rw [←h_union, Finset.sum_union h_disjoint] + · left; simp only [evens, Finset.mem_image, Finset.mem_range]; use k; omega + · right; simp only [odds, Finset.mem_image, Finset.mem_range]; use k; omega + rw [← h_union, Finset.sum_union h_disjoint] + +/-- Specialized form of `Fin.sum_univ_even_odd` for `r = 2^n`. -/ +theorem Fin.sum_univ_odd_even {n : ℕ} {M : Type*} [AddCommMonoid M] (f : ℕ → M) : + (∑ i : Fin (2 ^ n), f (2 * i)) + (∑ i : Fin (2 ^ n), f (2 * i + 1)) + = ∑ i : Fin (2 ^ (n + 1)), f i := by + rw [show 2 ^ (n + 1) = 2 * 2 ^ n from by rw [pow_succ, Nat.mul_comm]] + exact Fin.sum_univ_even_odd f /-- Splits a sum over an interval `[a, c]` into two sums over `[a, b]` and `[b+1, c]`.