Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions CompPoly/Data/Nat/Bitwise.lean
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ lemma testBit_true_eq_getBit_eq_1 (k n : Nat) : n.testBit k = ((Nat.getBit k n)
simp only [one_and_eq_mod_two, mod_two_bne_zero, beq_iff_eq, and_one_is_mod]

lemma testBit_false_eq_getBit_eq_0 (k n : Nat) :
(n.testBit k = false) = ((Nat.getBit k n) = 0) := by
(n.testBit k = false) = ((Nat.getBit k n) = 0) := by
unfold getBit
rw [Nat.testBit]
simp only [one_and_eq_mod_two, mod_two_bne_zero, beq_eq_false_iff_ne, ne_eq, mod_two_not_eq_one,
Expand Down Expand Up @@ -73,7 +73,7 @@ lemma getBit_zero_eq_zero {k : Nat} : getBit k 0 = 0 := by
rw [Nat.and_one_is_mod]

lemma getBit_eq_zero_or_one {k n : Nat} :
getBit k n = 0 ∨ getBit k n = 1 := by
getBit k n = 0 ∨ getBit k n = 1 := by
unfold getBit
rw [Nat.and_one_is_mod]
simp only [Nat.mod_two_eq_zero_or_one]
Expand All @@ -96,7 +96,7 @@ lemma getLowBits_zero_eq_zero {n : ℕ} : getLowBits 0 n = 0 := by
simp only [Nat.shiftLeft_zero, Nat.sub_self, Nat.and_zero]

lemma getLowBits_eq_mod_two_pow {numLowBits : ℕ} (n : ℕ) :
getLowBits numLowBits n = n % (2 ^ numLowBits) := by
getLowBits numLowBits n = n % (2 ^ numLowBits) := by
unfold getLowBits
rw [Nat.shiftLeft_eq, one_mul]
exact Nat.and_two_pow_sub_one_eq_mod n numLowBits
Expand Down Expand Up @@ -237,8 +237,8 @@ lemma and_two_pow_eq_two_pow_of_getBit_1 {n i : ℕ} (h_getBit : getBit i n = 1)
conv_lhs => rw [Nat.and_two_pow (n:=n) (i:=i)]
simp only [h_testBit_i_eq_1, Bool.toNat_true, one_mul]

lemma and_two_pow_eq_two_pow_of_getBit_eq_one {n i : ℕ} (h_getBit : getBit i n = 1)
: n &&& (2^i) = 2^i := by
lemma and_two_pow_eq_two_pow_of_getBit_eq_one {n i : ℕ} (h_getBit : getBit i n = 1) :
n &&& (2^i) = 2^i := by
apply eq_iff_eq_all_getBits.mpr; unfold getBit
intro k
have h_getBit_two_pow := getBit_two_pow (i := i) (k := k)
Expand Down Expand Up @@ -267,13 +267,13 @@ lemma eq_zero_or_eq_one_of_lt_two {n : ℕ} (h_lt : n < 2) : n = 0 ∨ n = 1 :=
· right; rfl

lemma div_2_form {nD2 b : ℕ} (h_b : b < 2) :
(nD2 * 2 + b) / 2 = nD2 := by
(nD2 * 2 + b) / 2 = nD2 := by
rw [←add_comm, ←mul_comm]
rw [Nat.add_mul_div_left (x := b) (y := 2) (z := nD2) (H := by norm_num)]
norm_num; exact h_b;

lemma and_by_split_lowBits {n m n1 m1 bn bm : ℕ} (h_bn : bn < 2) (h_bm : bm < 2)
(h_n : n = n1 * 2 + bn) (h_m : m = m1 * 2 + bm) :
(h_n : n = n1 * 2 + bn) (h_m : m = m1 * 2 + bm) :
n &&& m = (n1 &&& m1) * 2 + (bn &&& bm) := by -- main tool : Nat.div_add_mod /2
rw [h_n, h_m]
-- ⊢ (n1 * 2 + bn) &&& (m1 * 2 + bm) = (n1 &&& m1) * 2 + (bn &&& bm)
Expand Down Expand Up @@ -302,7 +302,7 @@ lemma and_by_split_lowBits {n m n1 m1 bn bm : ℕ} (h_bn : bn < 2) (h_bm : bm <
rw [←Nat.div_add_mod ((n1 * 2 + bn) &&& (m1 * 2 + bm)) 2, h_div_eq, h_mod_eq, Nat.div_add_mod]

lemma xor_by_split_lowBits {n m n1 m1 bn bm : ℕ} (h_bn : bn < 2) (h_bm : bm < 2)
(h_n : n = n1 * 2 + bn) (h_m : m = m1 * 2 + bm) :
(h_n : n = n1 * 2 + bn) (h_m : m = m1 * 2 + bm) :
n ^^^ m = (n1 ^^^ m1) * 2 + (bn ^^^ bm) := by
rw [h_n, h_m]
-- ⊢ (n1 * 2 + bn) ^^^ (m1 * 2 + bm) = (n1 ^^^ m1) * 2 + (bn ^^^ bm)
Expand Down Expand Up @@ -333,7 +333,7 @@ lemma xor_by_split_lowBits {n m n1 m1 bn bm : ℕ} (h_bn : bn < 2) (h_bm : bm <
rw [←Nat.div_add_mod ((n1 * 2 + bn) ^^^ (m1 * 2 + bm)) 2, h_div_eq, h_mod_eq, Nat.div_add_mod]

lemma or_by_split_lowBits {n m n1 m1 bn bm : ℕ} (h_bn : bn < 2) (h_bm : bm < 2)
(h_n : n = n1 * 2 + bn) (h_m : m = m1 * 2 + bm) :
(h_n : n = n1 * 2 + bn) (h_m : m = m1 * 2 + bm) :
n ||| m = (n1 ||| m1) * 2 + (bn ||| bm) := by
rw [h_n, h_m]
-- ⊢ (n1 * 2 + bn) ||| (m1 * 2 + bm) = (n1 ||| m1) * 2 + (bn ||| bm)
Expand Down Expand Up @@ -365,10 +365,10 @@ lemma or_by_split_lowBits {n m n1 m1 bn bm : ℕ} (h_bn : bn < 2) (h_bm : bm < 2

lemma sum_eq_xor_plus_twice_and (n : Nat) : ∀ m : ℕ, n + m = (n ^^^ m) + 2 * (n &&& m) := by
induction n using Nat.binaryRec with
| z =>
| zero =>
intro m
rw [zero_add, Nat.zero_and, mul_zero, add_zero, Nat.zero_xor]
| f bn n2 ih =>
| bit bn n2 ih =>
intro m
let resDiv2M := Nat.boddDiv2 m
let bm := resDiv2M.fst
Expand Down Expand Up @@ -397,7 +397,7 @@ lemma sum_eq_xor_plus_twice_and (n : Nat) : ∀ m : ℕ, n + m = (n ^^^ m) + 2 *
rw [←h_m]
unfold mVal
simp only [h_bm, h_m2]
exact Nat.bit_decomp m
exact Nat.bit_bodd_div2 m
rw [←h_mVal_eq_m]
have h_and : nVal &&& mVal = (n2 &&& m2) * 2 + (getBitN &&& getBitM) :=
and_by_split_lowBits (h_bn := h_getBitN) (h_bm := h_getBitM) (h_n := h_n) (h_m := h_m)
Expand All @@ -409,7 +409,7 @@ lemma sum_eq_xor_plus_twice_and (n : Nat) : ∀ m : ℕ, n + m = (n ^^^ m) + 2 *
omega

lemma add_shiftRight_distrib {n m k : ℕ} (h_and_zero : n &&& m = 0) :
(n + m) >>> k = (n >>> k) + (m >>> k) := by
(n + m) >>> k = (n >>> k) + (m >>> k) := by
rw [sum_eq_xor_plus_twice_and, h_and_zero, mul_zero, add_zero]
conv =>
rhs
Expand Down Expand Up @@ -482,7 +482,7 @@ lemma xor_eq_sub_iff_submask {n m : ℕ} (h : m ≤ n) : n ^^^ m = n - m ↔ n &
rw [Nat.and_self, Nat.xor_self, mul_zero, add_zero]

lemma getBit_of_add_distrib {n m k : ℕ}
(h_n_AND_m : n &&& m = 0) : getBit k (n + m) = getBit k n + getBit k m := by
(h_n_AND_m : n &&& m = 0) : getBit k (n + m) = getBit k n + getBit k m := by
unfold getBit
rw [sum_of_and_eq_zero_is_xor h_n_AND_m]
rw [Nat.shiftRight_xor_distrib, Nat.and_xor_distrib_right]
Expand All @@ -499,7 +499,7 @@ lemma getBit_of_add_distrib {n m k : ℕ}
exact (sum_of_and_eq_zero_is_xor (n := getBitN) (m := getBitM) h_getBitN_and_getBitM).symm

lemma add_two_pow_of_getBit_eq_zero_lt_two_pow {n m i : ℕ} (h_n : n < 2 ^ m) (h_i : i < m)
(h_getBit_at_i_eq_zero : getBit i n = 0) :
(h_getBit_at_i_eq_zero : getBit i n = 0) :
n + 2^i < 2^m := by
have h_j_and: n &&& (2^i) = 0 := by
rw [and_two_pow_eq_zero_of_getBit_0 (n:=n) (i:=i)]
Expand All @@ -511,7 +511,7 @@ lemma add_two_pow_of_getBit_eq_zero_lt_two_pow {n m i : ℕ} (h_n : n < 2 ^ m) (
exact h_and_lt

lemma getBit_of_multiple_of_power_of_two {n p : ℕ} : ∀ k,
getBit (k) (2^p * n) = if k < p then 0 else getBit (k-p) n := by
getBit (k) (2^p * n) = if k < p then 0 else getBit (k-p) n := by
intro k
have h_test := Nat.testBit_two_pow_mul (i := p) (a := n) (j:=k)
simp only [Nat.testBit, Nat.and_comm 1] at h_test
Expand Down Expand Up @@ -541,14 +541,14 @@ lemma getBit_of_multiple_of_power_of_two {n p : ℕ} : ∀ k,
simp only [getBit, Nat.and_one_is_mod, h_test]

lemma getBit_of_shiftLeft {n p : ℕ} :
∀ k, getBit (k) (n <<< p) = if k < p then 0 else getBit (k - p) n := by
∀ k, getBit (k) (n <<< p) = if k < p then 0 else getBit (k - p) n := by
intro k
rw [getBit_of_multiple_of_power_of_two (n:=n) (p:=p) (k:=k).symm]
congr
rw [Nat.shiftLeft_eq, mul_comm]

lemma getBit_of_shiftRight {n p : ℕ} :
∀ k, getBit k (n >>> p) = getBit (k+p) n := by
∀ k, getBit k (n >>> p) = getBit (k+p) n := by
intro k
unfold getBit
rw [←Nat.shiftRight_add]
Expand Down Expand Up @@ -588,7 +588,7 @@ lemma getBit_of_two_pow_sub_one {i k : ℕ} : getBit k (2^i - 1) =
simp only [h_test]

lemma getBit_of_sub_two_pow_of_bit_1 {n i j : ℕ} (h_getBit_eq_1 : getBit i n = 1) :
getBit j (n - 2^i) = (if j = i then 0 else getBit j n) := by
getBit j (n - 2^i) = (if j = i then 0 else getBit j n) := by
have h_2_pow_i_lt_n: 2^i ≤ n := by
apply Nat.ge_two_pow_of_testBit
rw [Nat.testBit_true_eq_getBit_eq_1]
Expand Down Expand Up @@ -657,7 +657,7 @@ lemma getBit_eq_pred_getBit_of_div_two {n k : ℕ} (h_k : k > 0) :

-- TODO: uniqueness of this representation?
theorem getBit_repr {ℓ : Nat} : ∀ j, j < 2^ℓ →
j = ∑ k ∈ Finset.Icc 0 (ℓ-1), (getBit k j) * 2^k := by
j = ∑ k ∈ Finset.Icc 0 (ℓ-1), (getBit k j) * 2^k := by
induction ℓ with
| zero =>
-- Base case : ℓ = 0
Expand Down Expand Up @@ -782,7 +782,7 @@ theorem getBit_repr {ℓ : Nat} : ∀ j, j < 2^ℓ →
rw [←h_j_eq]

theorem getBit_repr_univ {ℓ : Nat} : ∀ j, j < 2^ℓ →
j = ∑ k ∈ Finset.univ (α:=Fin ℓ), (getBit k j) * 2^k.val := by
j = ∑ k ∈ Finset.univ (α:=Fin ℓ), (getBit k j) * 2^k.val := by
intro j h_j
have h_repr_Icc := getBit_repr (ℓ:=ℓ) (j:=j) (by omega)
rw [h_repr_Icc]
Expand Down Expand Up @@ -905,7 +905,7 @@ theorem and_highBits_lowBits_eq_zero {n : ℕ} (numLowBits : ℕ) :
rw [h_getBit_right_eq_0, Nat.and_zero]

lemma num_eq_highBits_add_lowBits {n : ℕ} (numLowBits : ℕ) :
n = getHighBits numLowBits n + getLowBits numLowBits n := by
n = getHighBits numLowBits n + getLowBits numLowBits n := by
apply eq_iff_eq_all_getBits.mpr; unfold getBit
intro k
--- use 2 getBit extractions to get the condition for getLowBits of ((n >>> numLowBits) <<<
Expand All @@ -932,7 +932,7 @@ lemma num_eq_highBits_add_lowBits {n : ℕ} (numLowBits : ℕ) :
rw [Nat.sub_add_cancel (n:=k) (m:=numLowBits) (by omega)]

lemma num_eq_highBits_xor_lowBits {n : ℕ} (numLowBits : ℕ) :
n = getHighBits numLowBits n ^^^ getLowBits numLowBits n := by
n = getHighBits numLowBits n ^^^ getLowBits numLowBits n := by
rw [←sum_of_and_eq_zero_is_xor]
· exact num_eq_highBits_add_lowBits (n := n) (numLowBits := numLowBits)
· exact and_highBits_lowBits_eq_zero (n := n) (numLowBits := numLowBits)
Expand All @@ -957,7 +957,7 @@ lemma getBit_of_highBits_no_shl {n : ℕ} (numLowBits : ℕ) :
exact getBit_of_shiftRight k

lemma getBit_of_lt_two_pow {n : ℕ} (a : Fin (2 ^ n)) (k : ℕ) :
getBit k a = if k < n then getBit k a else 0 := by
getBit k a = if k < n then getBit k a else 0 := by
if h_k: k < n then
simp only [h_k, ↓reduceIte]
else
Expand All @@ -971,7 +971,7 @@ lemma getBit_of_lt_two_pow {n : ℕ} (a : Fin (2 ^ n)) (k : ℕ) :

-- Note: maybe we can generalize this into a non-empty set of diff bits
lemma exist_bit_diff_if_diff {n : ℕ} (a : Fin (2 ^ n)) (b : Fin (2 ^ n)) (h_a_ne_b : a ≠ b) :
∃ k: Fin n, getBit k a ≠ getBit k b := by
∃ k: Fin n, getBit k a ≠ getBit k b := by
by_contra h_no_diff
push_neg at h_no_diff
have h_a_eq_b: a = b := by
Expand Down
2 changes: 1 addition & 1 deletion CompPoly/Multivariate/CMvMonomial.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Mathlib.Algebra.Group.TypeTags.Basic
import Mathlib.Algebra.GroupWithZero.Nat
import Mathlib.Algebra.Ring.Defs
import Mathlib.Data.Nat.Lattice
import Std.Classes.Ord.Vector
import Batteries.Data.Vector.Basic

/-!
# Computable monomials
Expand Down
31 changes: 18 additions & 13 deletions CompPoly/Multivariate/Lawful.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
/-
Copyright (c) 2025 CompPoly. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Frantisek Silvasi
-/

import CompPoly.Multivariate.Unlawful
import Mathlib.Analysis.Normed.Ring.Lemmas

Expand Down Expand Up @@ -80,7 +86,7 @@ def fromUnlawful (p : Unlawful n R) : Lawful n R :=

@[grind←]
protected lemma grind_fromUnlawful_congr {p₁ p₂ : Unlawful n R}
(h : p₁ = p₂) : Lawful.fromUnlawful p₁ = Lawful.fromUnlawful p₂ := by grind
(h : p₁ = p₂) : Lawful.fromUnlawful p₁ = Lawful.fromUnlawful p₂ := by grind

def C (c : R) : Lawful n R :=
⟨Unlawful.C c, by grind⟩
Expand Down Expand Up @@ -112,8 +118,7 @@ lemma cast_fromUnlawful : (fromUnlawful p.1).1 = p.1 := by
unfold fromUnlawful
rcases p with ⟨p, hp⟩
simp; ext1 x
erw [ExtTreeMap.getElem?_filter, Option.filter_irrel (by intros; specialize hp x; grind)]
rfl
grind

section

Expand All @@ -127,23 +132,23 @@ instance [Add R] : Add (Lawful n R) := ⟨add⟩

@[grind=]
protected lemma grind_add_skip [Add R] {p₁ p₂ : Lawful n R} :
p₁ + p₂ = Lawful.fromUnlawful (p₁.1.add p₂.1) := rfl
p₁ + p₂ = Lawful.fromUnlawful (p₁.1.add p₂.1) := rfl

/--
Note to self: This goes too far.
-/
@[grind=]
protected lemma grind_add_skip_aggressive [Add R] {p₁ p₂ : Lawful n R} :
p₁ + p₂ = fromUnlawful (ExtTreeMap.mergeWith (fun _ c₁ c₂ => c₁ + c₂) p₁.1 p₂.1) := rfl
p₁ + p₂ = fromUnlawful (ExtTreeMap.mergeWith (fun _ c₁ c₂ => c₁ + c₂) p₁.1 p₂.1) := rfl

def mul [Mul R] [Add R] (p₁ p₂ : Lawful n R) : Lawful n R :=
fromUnlawful <| p₁.val * p₂.val

instance [Mul R] [Add R] [Zero R] : Mul (Lawful n R) := ⟨mul⟩

def npow [NatCast R] [Add R] [Mul R] : ℕ → Lawful n R → Lawful n R
| .zero , _ => 1
| .succ n, p => (npow n p) * p
| .zero , _ => 1
| .succ n, p => (npow n p) * p

instance [NatCast R] [Add R] [Mul R] : NatPow (Lawful n R) := ⟨fun e b ↦ npow b e⟩

Expand Down Expand Up @@ -202,19 +207,19 @@ section
variable {n₁ n₂ : ℕ}

def align
(p₁ : Lawful n₁ R) (p₂ : Lawful n₂ R) :
Lawful (n₁ ⊔ n₂) R × Lawful (n₁ ⊔ n₂) R :=
(p₁ : Lawful n₁ R) (p₂ : Lawful n₂ R) :
Lawful (n₁ ⊔ n₂) R × Lawful (n₁ ⊔ n₂) R :=
letI sup := n₁ ⊔ n₂
(
cast (by congr 1; grind) (p₁.extend sup),
cast (by congr 1; grind) (p₂.extend sup)
)

def liftPoly
(f : Lawful (n₁ ⊔ n₂) R →
Lawful (n₁ ⊔ n₂) R →
Lawful (n₁ ⊔ n₂) R)
(p₁ : Lawful n₁ R) (p₂ : Lawful n₂ R) : Lawful (n₁ ⊔ n₂) R :=
(f : Lawful (n₁ ⊔ n₂) R →
Lawful (n₁ ⊔ n₂) R →
Lawful (n₁ ⊔ n₂) R)
(p₁ : Lawful n₁ R) (p₂ : Lawful n₂ R) : Lawful (n₁ ⊔ n₂) R :=
Function.uncurry f (align p₁ p₂)

section
Expand Down
4 changes: 2 additions & 2 deletions CompPoly/Multivariate/MvPolyEquiv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import CompPoly.Multivariate.CMvPolynomial
import Mathlib.Algebra.MvPolynomial.Basic
import Mathlib.Algebra.Ring.Defs
import CompPoly.Multivariate.Lawful
import Std.Classes.Ord.Vector
import Batteries.Data.Vector.Basic

/-!
# `Equiv` and `RingEquiv` between `CMvPolynomial` and `MvPolynomial`.
Expand Down Expand Up @@ -364,7 +364,7 @@ lemma map_mul (a b : CMvPolynomial n R) :
getElem?_neg
]
unfold MvPolynomial.coeff MonoidAlgebra.single
rw [Finsupp.single_eq_of_ne (by symm; exact m_in)]
rw [Finsupp.single_eq_of_ne (by symm; grind)]
split
next h contra =>
exfalso; apply m_in; symm
Expand Down
Loading