Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 27 additions & 54 deletions CompPoly/Data/Fin/BigOperators.lean
Original file line number Diff line number Diff line change
Expand Up @@ -165,77 +165,50 @@ This is useful for definitions that process elements in reverse order, like `fol
convert motive_next_ind
termination_by (r - 1 - i.val)

-- The theorem statement and its proof.
-- TODO: state a more generalized and reusable version of this, where f is from Fin r → M
/--
Splits a sum over `Fin (2^n)` into a sum over even indices and a sum over odd indices.
Useful for Fast Fourier Transform (FFT) type recursions.
Splits a sum over `Fin (2 * r)` into a sum over even indices and a sum over odd indices.
Generalizes `Fin.sum_univ_odd_even` from `2^n` to arbitrary `r`.
-/
theorem Fin.sum_univ_odd_even {n : ℕ} {M : Type*} [AddCommMonoid M] (f : ℕ → M) :
(∑ i : Fin (2 ^ n), f (2 * i)) + (∑ i : Fin (2 ^ n), f (2 * i + 1))
= ∑ i: Fin (2 ^ (n+1)), f i := by
theorem Fin.sum_univ_even_odd {r : ℕ} {M : Type*} [AddCommMonoid M] (f : ℕ → M) :
(∑ i : Fin r, f (2 * i)) + (∑ i : Fin r, f (2 * i + 1))
= ∑ i : Fin (2 * r), f i := by
set f_even := fun i => f (2 * i)
set f_odd := fun i => f (2 * i + 1)
conv_lhs =>
enter [1, 2, i]
change f_even i
enter [1, 2, i]; change f_even i
conv_lhs =>
enter [2, 2, i]
change f_odd i
enter [2, 2, i]; change f_odd i
simp only [Fin.sum_univ_eq_sum_range]

-- Let's define the sets of even and odd numbers.
let evens: Finset ℕ := Finset.image (fun i ↦ 2 * i) (Finset.range (2^n))
let odds: Finset ℕ := Finset.image (fun i ↦ 2 * i + 1) (Finset.range (2^n))

let evens : Finset ℕ := Finset.image (fun i ↦ 2 * i) (Finset.range r)
let odds : Finset ℕ := Finset.image (fun i ↦ 2 * i + 1) (Finset.range r)
conv_lhs =>
enter [1];
rw [←Finset.sum_image (g:=fun i => 2 * i) (by simp)]

enter [1]; rw [← Finset.sum_image (g := fun i => 2 * i) (by simp)]
conv_lhs =>
enter [2];
rw [← Finset.sum_image (g:=fun i => 2 * i + 1) (by simp)]

-- First, we prove that the set on the RHS is the disjoint union of evens and odds.
enter [2]; rw [← Finset.sum_image (g := fun i => 2 * i + 1) (by simp)]
have h_disjoint : Disjoint evens odds := by
apply Finset.disjoint_iff_ne.mpr
-- Assume for contradiction that an element `x` is in both sets.
rintro x hx y hy hxy
-- Unpack the definitions of `evens` and `odds`.
rcases Finset.mem_image.mp hx with ⟨k₁, _, rfl⟩
rcases Finset.mem_image.mp hy with ⟨k₂, _, rfl⟩
omega

have h_union : evens ∪ odds = Finset.range (2 ^ (n + 1)) := by
apply Finset.ext; intro x
simp only [Finset.mem_union, Finset.mem_range]
-- ⊢ x ∈ evens ∨ x ∈ odds ↔ x < 2 ^ (n + 1)
have h_union : evens ∪ odds = Finset.range (2 * r) := by
ext x; simp only [Finset.mem_union, Finset.mem_range]
constructor
· -- First direction: `x ∈ evens ∪ odds → x < 2^(n+1)`
-- This follows from the bounds of the original range `Finset.range (2^n)`.
intro h
rcases h with (h_even | h_odd)
· rcases Finset.mem_image.mp h_even with ⟨k₁, hk₁, rfl⟩
simp at hk₁
omega
· rcases Finset.mem_image.mp h_odd with ⟨k₂, hk₂, rfl⟩
simp at hk₂
omega
· -- Second direction: `x < 2^(n+1) → x ∈ evens ∪ odds`
intro hx
· rintro (h_even | h_odd)
· rcases Finset.mem_image.mp h_even with ⟨k₁, hk₁, rfl⟩; simp at hk₁; omega
· rcases Finset.mem_image.mp h_odd with ⟨k₂, hk₂, rfl⟩; simp at hk₂; omega
· intro hx
obtain (⟨k, rfl⟩ | ⟨k, rfl⟩) := Nat.even_or_odd x
· left;
unfold evens
simp only [Finset.mem_image, Finset.mem_range]
use k;
omega
· right;
unfold odds
simp only [Finset.mem_image, Finset.mem_range]
use k;
omega
-- Now, rewrite the RHS using this partition.
rw [←h_union, Finset.sum_union h_disjoint]
· left; simp only [evens, Finset.mem_image, Finset.mem_range]; use k; omega
· right; simp only [odds, Finset.mem_image, Finset.mem_range]; use k; omega
rw [← h_union, Finset.sum_union h_disjoint]

/-- Specialized form of `Fin.sum_univ_even_odd` for `r = 2^n`. -/
theorem Fin.sum_univ_odd_even {n : ℕ} {M : Type*} [AddCommMonoid M] (f : ℕ → M) :
(∑ i : Fin (2 ^ n), f (2 * i)) + (∑ i : Fin (2 ^ n), f (2 * i + 1))
= ∑ i : Fin (2 ^ (n + 1)), f i := by
rw [show 2 ^ (n + 1) = 2 * 2 ^ n from by rw [pow_succ, Nat.mul_comm]]
exact Fin.sum_univ_even_odd f

/--
Splits a sum over an interval `[a, c]` into two sums over `[a, b]` and `[b+1, c]`.
Expand Down