diff --git a/Lampe/Lampe.lean b/Lampe/Lampe.lean index 299d5788..46ad2877 100644 --- a/Lampe/Lampe.lean +++ b/Lampe/Lampe.lean @@ -3,7 +3,6 @@ import Lampe.Ast.Extensions import Lampe.Builtin.Arith import Lampe.Builtin.Array import Lampe.Builtin.Basic -import Lampe.Builtin.BigInt import Lampe.Builtin.Bit import Lampe.Builtin.Cast import Lampe.Builtin.Cmp diff --git a/Lampe/Lampe/Ast/Extensions.lean b/Lampe/Lampe/Ast/Extensions.lean index a5698c14..c9882012 100644 --- a/Lampe/Lampe/Ast/Extensions.lean +++ b/Lampe/Lampe/Ast/Extensions.lean @@ -1,7 +1,6 @@ import Lampe.Ast import Lampe.Builtin.Arith import Lampe.Builtin.Array -import Lampe.Builtin.BigInt import Lampe.Builtin.Bit import Lampe.Builtin.Cmp import Lampe.Builtin.Field @@ -30,13 +29,13 @@ def Expr.writeRef (ref : rep tp.ref) (val : rep tp) : Lampe.Expr rep .unit := /-- A utility function for creating a slice expression. -/ @[reducible] -def Expr.mkSlice (n : Nat) (vals : HList rep (List.replicate n tp)) +def Expr.mkSlice (n : Nat) (vals : HList rep (List.replicate n tp)) : Lampe.Expr rep (.slice tp) := Lampe.Expr.callBuiltin _ (.slice tp) .mkSlice vals /-- A utility function for creating an array expression. -/ @[reducible] -def Expr.mkArray (n : Lampe.U 32) (vals : HList rep (List.replicate n.toNat tp)) +def Expr.mkArray (n : Lampe.U 32) (vals : HList rep (List.replicate n.toNat tp)) : Lampe.Expr rep (.array tp n) := Lampe.Expr.callBuiltin _ (.array tp n) .mkArray vals @@ -52,7 +51,7 @@ def Expr.mkRepArray (n : Lampe.U 32) (val : rep tp) : Lampe.Expr rep (.array tp /-- A utility function for creating a tuple expression. -/ @[reducible] -def Expr.mkTuple (name : Option String) (args : HList rep tps) +def Expr.mkTuple (name : Option String) (args : HList rep tps) : Lampe.Expr rep (.tuple name tps) := Lampe.Expr.callBuiltin tps (.tuple name tps) .mkTuple args @@ -64,12 +63,12 @@ def Expr.modifyLens (r : rep $ .ref tp₁) (v : rep tp₂) (lens : Lampe.Lens re /-- A utility function for creating a lens read expression. -/ @[reducible] -def Expr.getLens (v : rep tp₁) (lens : Lampe.Lens rep tp₁ tp₂) +def Expr.getLens (v : rep tp₁) (lens : Lampe.Lens rep tp₁ tp₂) : Lampe.Expr rep tp₂ := Lampe.Expr.callBuiltin _ tp₂ (.getLens lens) h![v] /-- A utility function for creating a member access. -/ @[reducible] def Expr.getMember (v : rep (Tp.tuple name tps)) (member : Lampe.Builtin.Member tp tps) - : Lampe.Expr rep tp := + : Lampe.Expr rep tp := Expr.callBuiltin _ tp (Lampe.Builtin.getMember member) h![v] diff --git a/Lampe/Lampe/Builtin/Basic.lean b/Lampe/Lampe/Builtin/Basic.lean index b5930f18..4852ded0 100644 --- a/Lampe/Lampe/Builtin/Basic.lean +++ b/Lampe/Lampe/Builtin/Basic.lean @@ -70,6 +70,46 @@ end Lampe namespace Lampe.Builtin +inductive genericOmni {A : Type} + (sgn : A → List Tp × Tp) + (desc : {p : Prime} → (a : A) → (args : HList (Tp.denote p) (sgn a).fst) → (Tp.denote p (sgn a).snd) → Prop) + : Omni where + | ok {p st a args Q val} : + (h : desc a args val) + → Q (some (st, val)) + → (genericOmni sgn desc) p st (sgn a).fst (sgn a).snd args Q + | err {p st a args Q} : + (h : ∀ val, ¬ desc a args val) + → Q none + → (genericOmni sgn desc) p st (sgn a).fst (sgn a).snd args Q + +def newGenericBuiltin {A : Type} + (sgn : A → List Tp × Tp) + (desc : {p : Prime} → (a : A) → (args : HList (Tp.denote p) (sgn a).fst) → (Tp.denote p (sgn a).snd) → Prop) : Builtin := { + omni := genericOmni sgn desc + conseq := by + unfold omni_conseq + intros + cases_type genericOmni + · apply genericOmni.ok <;> tauto + · apply genericOmni.err <;> tauto + frame := by + unfold omni_frame + intros + cases_type genericOmni + · apply genericOmni.ok + · assumption + · simp + repeat apply Exists.intro + constructor + · assumption + · tauto + · apply genericOmni.err + · assumption + · simp + tauto +} + inductive genericPureOmni {A : Type} (sgn : A → List Tp × Tp) (desc : {p : Prime} @@ -165,6 +205,17 @@ def assert := newPureBuiltin (fun h![a] => ⟨a == true, fun _ => ()⟩) +/-- +Defines the static assertion builtin that takes a boolean and a message of type `tp : Tp`. +We assume the following: +- If `a == true`, it evaluates to `()`. +- Else, an exception is thrown. +-/ +def staticAssert := newGenericPureBuiltin + (fun tp => ⟨[.bool, tp], .unit⟩) + (fun _ h![a, _] => ⟨a == true, + fun _ => ()⟩) + inductive freshOmni : Omni where | mk {P st tp Q} : (∀ v, Q (some (st, v))) → freshOmni P st [] tp h![] Q diff --git a/Lampe/Lampe/Builtin/BigInt.lean b/Lampe/Lampe/Builtin/BigInt.lean deleted file mode 100644 index 55c66b95..00000000 --- a/Lampe/Lampe/Builtin/BigInt.lean +++ /dev/null @@ -1,94 +0,0 @@ -import Lampe.Builtin.Basic -namespace Lampe.Builtin - -/-- -Defines the equality comparison between two big ints. - -In Noir, this builtin corresponds to `a == b` for values `a`, `b` of type `BigInt`. --/ -def bigIntEq := newPureBuiltin - ⟨[.bi, .bi], .bool⟩ - (fun h![a, b] => ⟨True, - fun _ => a = b⟩) - -/-- -Defines the addition of two bigints `(a b : Int)`. -The builtin is assumed to return `a + b`. - -In Noir, this builtin corresponds to `a + b` for bigints `a`, `b`. --/ -def bigIntAdd := newPureBuiltin - ⟨[.bi, .bi], (.bi)⟩ - (fun h![a, b] => ⟨True, - fun _ => a + b⟩) - -/-- -Defines the subtraction of two bigints `(a b : Int)`. -The builtin is assumed to return `a - b`. - -In Noir, this builtin corresponds to `a - b` for bigints `a`, `b`. --/ -def bigIntSub := newPureBuiltin - ⟨[.bi, .bi], (.bi)⟩ - (fun h![a, b] => ⟨True, - fun _ => a - b⟩) - -/-- -Defines the multiplication of two bigints `(a b : Int)`. -The builtin is assumed to return `a * b`. - -In Noir, this builtin corresponds to `a * b` for bigints `a`, `b`. --/ -def bigIntMul := newPureBuiltin - ⟨[.bi, .bi], (.bi)⟩ - (fun h![a, b] => ⟨True, - fun _ => a * b⟩) - -/-- -Defines the division of two bigints `(a b : Int)`. We make the following assumptions: -- If `b = 0`, an exception is thrown. -- Otherwise, the builtin is assumed to return `a / b`. - -In Noir, this builtin corresponds to `a / b` for bigints `a`, `b`. --/ -def bigIntDiv := newPureBuiltin - ⟨[.bi, .bi], (.bi)⟩ - (fun h![a, b] => ⟨b ≠ 0, - fun _ => a / b⟩) - -/-- -Defines the conversion of a byte slice `bytes : List (U 8)` in little-endian encoding to a `BigInt`. -Modulus parameter is ignored. - -In Noir, this builtin corresponds to `fn from_le_bytes(bytes: [u8], modulus: [u8])` implemented for `BigInt`. - -/ -def bigIntFromLeBytes := newPureBuiltin - ⟨[.slice (.u 8), .slice (.u 8)], .bi⟩ - (fun h![bs, _] => ⟨True, - fun _ => composeFromRadix 256 (bs.map (fun u => u.toNat))⟩) - -/-- -Converts a list `l` to a vector of size `n`s. -- If `n < l.length`, then the output is truncated from the end. -- If `n > l.length`, then the higher indices are populated with `zero`. --/ -def listToVec (l : List α) (zero : α) : List.Vector α n := - ⟨l.takeD n zero, List.takeD_length _ _ _⟩ - -/-- -Defines the conversion of `a : Int` to its byte slice representation `l : Array 32 (U 8)` in little-endian encoding. -For integers that can be represented by less than 32 bytes, the higher bytes of `l` are set to zero. - -We make the following assumptions: -- If `a` cannot be represented by 32 bytes, an exception is thrown. -- Else, the builtin returns `l`. - -In Noir, this builtin corresponds to `fn to_le_bytes(self) -> [u8; 32]` implemented for `BigInt`. --/ -def bigIntToLeBytes := newPureBuiltin - ⟨[.bi], (.array (.u 8) 32)⟩ - (fun h![a] => ⟨bitsCanRepresent 256 a, fun _ => - Builtin.listToVec (decomposeToRadix 256 a.toNat (by linarith)) 0 - |>.map (fun n => BitVec.ofNat 8 n)⟩) - -end Lampe.Builtin diff --git a/Lampe/Lampe/Builtin/Field.lean b/Lampe/Lampe/Builtin/Field.lean index 341d0d0a..36d2fc61 100644 --- a/Lampe/Lampe/Builtin/Field.lean +++ b/Lampe/Lampe/Builtin/Field.lean @@ -6,9 +6,10 @@ For a prime `p`, a field element `a : Fp p`, and a bit size `w : U 32`, this builtin evaluates to `()` if and only if `a < 2^w`, i.e., it can be represented by `w` bits. Otherwise, an exception is thrown. -In Noir, this builtin corresponds to `fn __assert_max_bit_size(self, bit_size: u32)` implemented for `Field`. +In Noir, this builtin corresponds to `fn __assert_max_bit_size(self, bit_size: u32)` implemented +for `Field`. -/ -def fApplyRangeConstraint := newPureBuiltin +def applyRangeConstraint := newPureBuiltin ⟨[.field, (.u 32)], .unit⟩ (fun h![a, w] => ⟨a.val < 2^w.toNat, fun _ => ()⟩) @@ -22,8 +23,8 @@ In Noir, this builtin corresponds to `fn modulus_num_bits() -> u64` implemented -/ def fModNumBits := newPureBuiltin ⟨[.field], (.u 64)⟩ - (@fun p h![_] => ⟨numBits p.val < 2^64, - fun _ => numBits p.val⟩) + (@fun p h![_] => ⟨numBits p.natVal < 2^64, + fun _ => numBits p.natVal⟩) /-- For a prime `p`, a field element `a : Fp p`, this builtin evaluates to the bit representation of `p` in little-endian format. @@ -32,7 +33,7 @@ In Noir, this builtin corresponds to `fn modulus_le_bits() -> [u1]` implemented -/ def fModLeBits := newTotalPureBuiltin ⟨[.field], (.slice (.u 1))⟩ - (@fun p h![_] => decomposeToRadix 2 p.val (by tauto)) + (@fun p h![_] => RadixVec.of ⟨2, by linarith⟩ p.natVal |>.toDigitsBE.toList.reverse) /-- For a prime `p`, a field element `a : Fp p`, this builtin evaluates to the bit representation of `p` in big-endian format. @@ -41,7 +42,7 @@ In Noir, this builtin corresponds to `fn modulus_be_bits() -> [u1]` implemented -/ def fModBeBits := newTotalPureBuiltin ⟨[.field], (.slice (.u 1))⟩ - (@fun p h![_] => .reverse (decomposeToRadix 2 p.val (by tauto))) + (@fun p h![_] => RadixVec.toDigitsBE' 2 p.natVal |>.map fun d => BitVec.ofNatLT d.val d.prop) --.of ⟨2, by linarith⟩ p.natVal |>.toDigitsBE.toList) /-- For a prime `p`, a field element `a : Fp p`, this builtin evaluates to the byte representation of `p` in little-endian format. @@ -50,7 +51,7 @@ In Noir, this builtin corresponds to `fn modulus_le_bytes() -> [u8]` implemented -/ def fModLeBytes := newTotalPureBuiltin ⟨[.field], (.slice (.u 8))⟩ - (@fun p h![_] => decomposeToRadix 256 p.val (by linarith)) + (@fun p h![_] => RadixVec.of ⟨256, by linarith⟩ p.natVal |>.toDigitsBE.toList.reverse) /-- For a prime `p`, a field element `a : Fp p`, this builtin evaluates to the bit representation of `p` in big-endian format. @@ -59,7 +60,7 @@ In Noir, this builtin corresponds to `fn modulus_be_bytes() -> [u8]` implemented -/ def fModBeBytes := newTotalPureBuiltin ⟨[.field], (.slice (.u 8))⟩ - (@fun p h![_] => .reverse (decomposeToRadix 256 p.val (by linarith))) + (@fun p h![_] => RadixVec.of ⟨256, by linarith⟩ p.natVal |>.toDigitsBE.toList) /-- Represents the builtin that converts a field element to an unsigned integer. @@ -90,7 +91,7 @@ Represents the builtin that converts an unsigned integer into a field element. Noir's semantics for this conversion take the unsigned integer and zero-extend it up to the size of the field. We do this by taking our unsigned int as an arbitrary `i ∈ ℤ` and then convert this to a -field element by zero extending. In Noir, this builtin corresponds to `fn as_field(self) -> Field` +field element by zero extending. In Noir, this builtin corresponds to `fn as_field(self) -> Field` implemented for uints of bit size `s`. Integers are also internally represented as field elements with an additional restriction that all @@ -111,7 +112,7 @@ Represents the builtin that converts a signed integer into a field element. Noir's semantics for this conversion take the signed integer and zero-extend it up to the size of the field. We do this by taking our signed int as an arbitrary `i ∈ ℤ` and then convert this to a -field element by zero extending. In Noir, this builtin corresponds to `fn as_field(self) -> Field` +field element by zero extending. In Noir, this builtin corresponds to `fn as_field(self) -> Field` implemented for uints of bit size `s`. Integers are also internally represented as field elements with an additional restriction that all @@ -127,4 +128,86 @@ def iAsField := newGenericTotalPureBuiltin -- source of which can be viewed with `set_option trace.Meta.synthInstance`. (fun _ h![a] => a.toNat) +/-- +Represents the builtin that returns the bit representation of the modulus of a field in +little-endian format. +-/ +def modulusLeBits : Builtin := newTotalPureBuiltin + ⟨[], (.slice (.u 1))⟩ + (fun {p} h![] => RadixVec.of ⟨2, by linarith⟩ p.natVal |>.toDigitsBE.toList.reverse) + +/-- +Represents the builtin that returns the bit representation of the modulus of a field in +big-endian format. +-/ +def modulusBeBits : Builtin := newTotalPureBuiltin + ⟨[], (.slice (.u 1))⟩ + (fun {p} h![] => RadixVec.toDigitsBE' 2 p.natVal |>.map fun d => BitVec.ofNatLT d.val d.prop) + +/-- +Represents the builtin that returns the byte representation of the modulus of a field in +little-endian format. +-/ +def modulusLeBytes : Builtin := newTotalPureBuiltin + ⟨[], (.slice (.u 8))⟩ + (fun {p} h![] => RadixVec.of ⟨256, by linarith⟩ p.natVal |>.toDigitsBE.toList.reverse) + +/-- +Represents the builtin that returns the byte representation of the modulus of a field in +big-endian format. +-/ +def modulusBeBytes : Builtin := newTotalPureBuiltin + ⟨[], (.slice (.u 8))⟩ + (fun {p} h![] => RadixVec.of ⟨256, by linarith⟩ p.natVal |>.toDigitsBE.toList) + +/-- +Represents the builtin that returns the number of bits in the modulus of a field. +-/ +def modulusNumBits : Builtin := newTotalPureBuiltin + ⟨[], (.u 64)⟩ + -- Note: We could use the `log2` definition but this is easier to reason about. + (fun {p} h![] => numBits p.natVal) + +/-- +Represents the builtin that converts a field element to its bit representation in little-endian format. + +Fails if `f ≥ 2^s`. +-/ +def toLeBits : Builtin := newGenericBuiltin + (fun s => ([.field], .array (.u 1) s)) + fun _ h![f] output => + f = RadixVec.ofDigitsBE (r := 2) (output.map BitVec.toFin).reverse + +/-- +Represents the builtin that converts a field element to its bit representation in big-endian format. + +Fails if `f ≥ 2^s`. +-/ +def toBeBits : Builtin := newGenericBuiltin + (fun s => ([.field], .array (.u 1) s)) + fun _ h![f] output => + f = RadixVec.ofDigitsBE (r := 2) (output.map BitVec.toFin) + +/-- +Represents the builtin that converts a field element to its radix representation in little-endian +format. + +Fails if `r ≤ 1` or `f ≥ 2^s`. +-/ +def toLeRadix : Builtin := newGenericBuiltin + (fun s => ([.field, .u 32], .array (.u 8) s)) + fun _ h![f, r] output => + f = RadixVec.ofLimbsBE r.toNat (output.map BitVec.toNat).reverse + +/-- +Represents the builtin that converts a field element to its radix representation in big-endian +format. + +Fails if `r ≤ 1` or `f ≥ 2^s`. +-/ +def toBeRadix : Builtin := newGenericBuiltin + (fun s => ([.field, .u 32], .array (.u 8) s)) + fun _ h![f, r] output => + f = RadixVec.ofLimbsBE r.toNat (output.map BitVec.toNat) + end Lampe.Builtin diff --git a/Lampe/Lampe/Builtin/Helpers.lean b/Lampe/Lampe/Builtin/Helpers.lean index bc499b05..05023d97 100644 --- a/Lampe/Lampe/Builtin/Helpers.lean +++ b/Lampe/Lampe/Builtin/Helpers.lean @@ -1,29 +1,300 @@ -import Mathlib.Tactic.Linarith +import Mathlib.Tactic namespace Lampe -/-- Extends the given list `a` up to length `len` with the default value of `α` -/ -def extList (lst : List α) (len : Nat) (default : α) : List α - := lst ++ (List.replicate (len - lst.length) default) - -@[reducible] -def decomposeToRadix (r : Nat) (v : Nat) (h : r > 1) : List Nat := match v with -| .zero => List.nil -| v' + 1 => List.cons ((v' + 1) % r) (decomposeToRadix r ((v' + 1) / r) h) -decreasing_by - rw [Nat.succ_eq_add_one, Nat.div_lt_iff_lt_mul] - rw [Nat.lt_mul_iff_one_lt_right] - tauto - exact Nat.succ_pos v' - rw [<-Nat.ne_zero_iff_zero_lt] - aesop - -example : decomposeToRadix 10 123 (by linarith) = [3, 2, 1] := by rfl -example : decomposeToRadix 2 123 (by linarith) = [1, 1, 0, 1, 1, 1, 1] := by rfl - -def composeFromRadix (r : Nat) (l : List Nat) : Nat := (l.reverse.foldl (fun acc d => acc * r + d) 0) - -example : (composeFromRadix 10 [3, 2, 1]) = 123 := by rfl -example : (composeFromRadix 2 [1, 1, 0, 1, 1, 1, 1]) = 123 := by rfl +abbrev Radix: Type := { n : Nat // n > 1 } + +abbrev Digit (r : Radix) := Fin r.1 +abbrev RadixVec (r : Radix) (d : Nat) := Fin (r ^ d) + +namespace RadixVec + +variable {r: Radix} {d : Nat} + +def of (r : Radix) (v : Nat) : RadixVec r (Nat.log r.val v + 1) := + ⟨v, Nat.lt_pow_succ_log_self r.prop _⟩ + +def msd (v: RadixVec r (d + 1)): Digit r := + Fin.mk + (v.val / (r.1 ^ d)) + (Nat.div_lt_of_lt_mul v.prop) + +def lsds (v : RadixVec r (d + 1)) : RadixVec r d := + Fin.mk + (v.val - msd v * r ^ d) + (by + simp only [msd] + rw [Nat.div_mul_self_eq_mod_sub_self] + have := Nat.mod_le v (r ^ d) + have : v.val ≥ (v.val - v.val % r^d) := by apply Nat.sub_le + zify [*] + ring_nf + convert Int.emod_lt _ _ using 1 + · simp + · have : r.val ≠ 0 := by + intro hp + have := r.prop + linarith + simp [*] + ) + +theorem msd_lsds_decomposition {v : RadixVec r (d + 1)}: + v = ⟨v.msd.val * r^d + v.lsds.val, by + simp only [msd, lsds] + rw [Nat.div_mul_self_eq_mod_sub_self] + have := Nat.mod_le v (r ^ d) + have : v.val ≥ (v.val - v.val % r^d) := by apply Nat.sub_le + zify [*] + have := v.prop + zify at this + simp [*] + ⟩ := by + simp only [msd, lsds] + apply Fin.eq_of_val_eq + simp only + rw [Nat.div_mul_self_eq_mod_sub_self] + have := Nat.mod_le v (r ^ d) + have : v.val ≥ (v - v % r^d) := by apply Nat.sub_le + zify [*] + simp + +theorem msd_lsds_decomposition_unique {v : RadixVec r (d + 1)} {msd' : Digit r} {lsds' : RadixVec r d} {h}: + v = ⟨msd'.val * r^d + lsds'.val, h⟩ ↔ msd' = v.msd ∧ lsds' = v.lsds := by + apply Iff.intro + · rintro rfl + apply And.intro + · simp only [msd] + apply Fin.eq_of_val_eq + simp only + rw [mul_comm, Nat.mul_add_div, Nat.div_eq_of_lt, Nat.add_zero] + · exact lsds'.prop + · have := r.prop + have := Nat.one_le_pow d r.val (by linarith) + linarith + · simp only [lsds, msd] + apply Fin.eq_of_val_eq + simp only + rw [mul_comm, Nat.mul_add_div, Nat.div_eq_of_lt, mul_comm] + · simp + · exact lsds'.prop + · have := r.prop + have := Nat.one_le_pow d r.val (by linarith) + linarith + · rintro ⟨rfl, rfl⟩ + apply msd_lsds_decomposition + +def toDigitsBE {d} (v : RadixVec r d): List.Vector (Digit r) d := match d with +| 0 => List.Vector.nil +| _ + 1 => v.msd ::ᵥ toDigitsBE v.lsds + +def ofLimbsBE {d} (r : Nat) (v : List.Vector ℕ d): ℕ := match d with +| 0 => 0 +| d + 1 => v.head * r^d + ofLimbsBE r v.tail + +def ofDigitsBE {d} {r : Radix} (v : List.Vector (Digit r) d): RadixVec r d := ⟨ofLimbsBE r.val (v.map (fun d => d.val)), by + induction d with + | zero => simp [ofLimbsBE] + | succ d ih => + simp only [ofLimbsBE, List.Vector.head_map, List.Vector.tail_map] + calc + _ < v.head.val * r.val^d + r.val^d := by + have := ih v.tail + linarith + _ = (v.head.val + 1) * r.val^d := by linarith + _ ≤ r * r.val^d := by + have := Nat.succ_le_of_lt v.head.prop + apply Nat.mul_le_mul_right + assumption + _ = _ := by simp [Nat.pow_succ, Nat.mul_comm] +⟩ + + +def ofDigitsBE' (l : List (Digit r)): Nat := + (RadixVec.ofDigitsBE ⟨l, rfl⟩).val + +@[simp] +theorem ofDigitsBE'_nil : ofDigitsBE' (r:=r) [] = 0 := by rfl + +@[simp] +theorem ofDigitsBE'_cons : ofDigitsBE' (r:=r) (x :: xs) = x * r ^ xs.length + ofDigitsBE' xs := by + rfl + +@[simp] +theorem ofDigitsBE'_append : + ofDigitsBE' (r:=r) (xs ++ ys) = ofDigitsBE' xs * r ^ ys.length + ofDigitsBE' ys := by + induction xs with + | nil => simp + | cons _ _ ih => + simp only [List.cons_append, ofDigitsBE'_cons, List.length_append, ih, + Nat.pow_add, Nat.add_mul + ] + linarith + +def toDigitsBE' (r: Radix) (n : Nat): List (Digit r) := + toDigitsBE ⟨n, Nat.lt_pow_succ_log_self r.prop _⟩ |>.toList + +instance : OfNat Radix 2 where + ofNat := ⟨2, by simp⟩ + +lemma ofDigitsBE_cons {r: Radix} {d: Nat} {x: Digit r} {xs: List.Vector (Digit r) d}: + ofDigitsBE (r:=r) (x ::ᵥ xs) + = (x.val * r.val ^ d + ofDigitsBE xs) := by + rfl + +@[simp] +theorem ofDigitsBE_cons' {r: Radix} {d: Nat} {x: Digit r} {xs: List.Vector (Digit r) d}: + ofDigitsBE (r:=r) (x ::ᵥ xs) = ⟨x.val * r.val ^ d + ofDigitsBE xs, by + calc + _ < x.val * r.val ^ d + r.val ^ d := by simp + _ = (x.val + 1) * r.val ^ d := by linarith + _ ≤ r * r.val ^ d := by + apply Nat.mul_le_mul_right + have := x.prop + linarith + _ = _ := by simp [Nat.pow_succ, Nat.mul_comm] + ⟩ := by + rfl + + +theorem ofDigitsBE_lt {r:Radix} {d: Nat} {l: List.Vector (Digit r) d}: + (ofDigitsBE l).val < r.val ^ d := by + induction d with + | zero => simp + | succ d ih => + cases' l using List.Vector.casesOn with _ x xs + rw [ofDigitsBE_cons, Nat.pow_succ] + have : r.val - 1 + 1 = r.val := by omega + calc + _ ≤ (r - 1) * r.val ^ d + ofDigitsBE xs := by + rw [ + add_le_add_iff_right, + Nat.mul_le_mul_right_iff (by by_contra; simp_all) + ] + apply Nat.le_of_lt_succ + rw [Nat.succ_eq_add_one, this] + exact x.prop + _ < (r - 1) * r.val ^ d + r.val ^ d := by simp [*] + _ = (↑r - 1) * ↑r ^ d + 1 * ↑r ^ d := by simp + _ = _ := by rw [←Nat.add_mul, this, mul_comm] + +theorem ofDigitsBE'_lt {r:Radix} {l: List (Digit r)}: + ofDigitsBE' l < r ^ l.length := by + induction l with + | nil => simp + | cons x xs ih => + rw [ofDigitsBE'_cons, List.length_cons, Nat.pow_succ] + have : r.val - 1 + 1 = r.val := by omega + calc + _ ≤ (r - 1) * r.val ^ xs.length + ofDigitsBE' xs := by + rw [ + add_le_add_iff_right, + Nat.mul_le_mul_right_iff (by by_contra; simp_all) + ] + apply Nat.le_of_lt_succ + rw [Nat.succ_eq_add_one, this] + exact x.prop + _ < (r - 1) * r.val ^ xs.length + r.val ^ xs.length := by linarith + _ = (↑r - 1) * ↑r ^ xs.length + 1 * ↑r ^ xs.length := by simp + _ = _ := by rw [←Nat.add_mul, this, mul_comm] + +theorem ofDigitsBE_toDigitsBE {r: Radix} {n : RadixVec r d}: ofDigitsBE (toDigitsBE n) = n := by + induction d with + | zero => + cases' r with r hr + cases' n with n hn + have : n = 0 := by simp_all + simp [toDigitsBE, ofDigitsBE, ofLimbsBE, this] + | succ d ih => + conv_rhs => rw [msd_lsds_decomposition (v:=n)] + have := Fin.val_eq_of_eq $ ih (n := n.lsds) + simpa [ofDigitsBE, toDigitsBE, ofLimbsBE] + +theorem toDigitsBE_ofDigitsBE {r: Radix} {v : List.Vector (Digit r) d}: toDigitsBE (ofDigitsBE v) = v := by + induction' v using List.Vector.inductionOn + · rfl + · simp only [toDigitsBE, ofDigitsBE_cons'] + congr + · rw [msd] + apply Fin.eq_of_val_eq + simp only + rw [Nat.mul_comm, Nat.mul_add_div] + · rw [Nat.div_eq_of_lt] + · simp + · exact ofDigitsBE_lt + · cases r + apply Nat.lt_of_succ_le + apply Nat.one_le_pow + linarith + · rename_i h + simp only [lsds] + conv_rhs => rw [←h] + congr 1 + apply Fin.eq_of_val_eq + simp only [msd] + conv_lhs => + arg 2 + arg 1 + rw [Nat.mul_comm] + rw [Nat.mul_add_div] + · rw [Nat.div_eq_of_lt] + · simp + · exact ofDigitsBE_lt + · cases r + apply Nat.lt_of_succ_le + apply Nat.one_le_pow + linarith + +theorem ofDigitsBE'_toDigitsBE' {r: Radix} {n : Nat}: + ofDigitsBE' (toDigitsBE' r n) = n := by + simp only [toDigitsBE', ofDigitsBE'] + generalize_proofs hn hlen + conv_rhs => change (Fin.mk n hn).val; rw [←ofDigitsBE_toDigitsBE (n := ⟨n, hn⟩)] + congr <;> simp + +theorem ofDigitsBE_mono {r: Radix} {l₁ l₂: List.Vector (Digit r) d}: + l₁.toList < l₂.toList → ofDigitsBE l₁ < ofDigitsBE l₂ := by + intro hp + induction d with + | zero => + cases List.Vector.eq_nil l₁ + cases List.Vector.eq_nil l₂ + simp at hp + | succ d ih => + cases' l₁ using List.Vector.casesOn with _ h₁ + cases' l₂ using List.Vector.casesOn with _ h₂ + cases' hp + · rename_i t₁ t₂ hh + rw [Fin.lt_def] at hh + simp only [ofDigitsBE_cons', List.Vector.head, Fin.mk_lt_mk] + calc + _ < h₁.val * r.val ^ d + r.val ^ d := by simp + _ = (h₁.val + 1) * r.val ^ d := by linarith + _ ≤ h₂ * r.val ^ d := by + apply Nat.mul_le_mul_right + linarith + _ ≤ _ := by linarith + · simp only [ofDigitsBE_cons', List.Vector.head, Fin.mk_lt_mk, List.Vector.tail] + rename_i _ _ hp + have := ih hp + rw [Fin.lt_def] at this + linarith + +theorem ofDigitsBE'_mono {r: Radix} {l₁ l₂: List (Digit r)}: l₁.length = l₂.length → l₁ < l₂ → ofDigitsBE' l₁ < ofDigitsBE' l₂ := by + intro hl hlt + have := ofDigitsBE_mono (l₁ := ⟨l₁, hl⟩) (l₂ := ⟨l₂, rfl⟩) hlt + rw [Fin.lt_def] at this + simp only [ofDigitsBE'] + convert this + +theorem ofDigitsBE'_toList {r : Radix} {l : List.Vector (Digit r) d}: ofDigitsBE' l.toList = (ofDigitsBE l).val := by + simp only [ofDigitsBE'] + congr <;> simp + +theorem ofDigitsBE'_subtype_eq {r : Radix} {l : List.Vector (Digit r) d} (hlt : ofDigitsBE' l.toList < r.val^d) : + (⟨ofDigitsBE' l.toList, hlt⟩ : RadixVec r d) = ofDigitsBE l := by + ext + exact ofDigitsBE'_toList + +end RadixVec end Lampe diff --git a/Lampe/Lampe/Builtin/Runtime.lean b/Lampe/Lampe/Builtin/Runtime.lean index 0c32f4e4..062f2c21 100644 --- a/Lampe/Lampe/Builtin/Runtime.lean +++ b/Lampe/Lampe/Builtin/Runtime.lean @@ -2,6 +2,11 @@ import Lampe.Builtin.Basic namespace Lampe.Builtin +/-- +Returns whether the execution is performed in an unconstrained context. + +Note we always return false, as otherwise we would be unable to reason about the code. +-/ def isUnconstrained := newTotalPureBuiltin ([], .bool) (fun _ => false) diff --git a/Lampe/Lampe/Builtin/Stubs.lean b/Lampe/Lampe/Builtin/Stubs.lean index 3e184953..d4a30fd9 100644 --- a/Lampe/Lampe/Builtin/Stubs.lean +++ b/Lampe/Lampe/Builtin/Stubs.lean @@ -30,7 +30,6 @@ def stub : Builtin := { -- Note that many of the names here explicitly do not follow the Lean naming scheme, as they have -- to match the name in extracted code that comes from Noir. def aes128Encrypt := stub -def applyRangeConstraint := stub def arrayRefcount := stub def asWitness := stub def assertConstant := stub @@ -45,22 +44,11 @@ def embeddedCurveAdd := stub def fmtstrAsCtstring := stub def keccakf1600 := stub def mkFormatString := stub -def modulesLeBytes := stub -def modulusBeBits := stub -def modulusBeBytes := stub -def modulusLeBits := stub -def modulusLeBytes := stub -def modulusNumBits := stub def multiScalarMul := stub def poseidon2Permutation := stub def recursiveAggregation := stub def sha256Compression := stub def sliceRefcount := stub -def staticAssert := stub def strAsCtstring := stub -def toBeBits := stub -def toBeRadix := stub -def toLeBits := stub -def toLeRadix := stub end Lampe.Builtin diff --git a/Lampe/Lampe/Hoare/Builtins.lean b/Lampe/Lampe/Hoare/Builtins.lean index 35147a0f..ceb5a95e 100644 --- a/Lampe/Lampe/Hoare/Builtins.lean +++ b/Lampe/Lampe/Hoare/Builtins.lean @@ -2,7 +2,6 @@ import Lampe.Hoare.SepTotal import Lampe.Builtin.Arith import Lampe.Builtin.Array -import Lampe.Builtin.BigInt import Lampe.Builtin.Bit import Lampe.Builtin.Cast import Lampe.Builtin.Cmp @@ -72,6 +71,25 @@ def genericTotalPureBuiltin_intro {A : Type} {sgn : A → List Tp × Tp} {desc} any_goals rfl tauto +theorem genericBuiltin_intro {A : Type} {a : A} {sgn desc args} : + STHoare p Γ + ⟦⟧ + (.callBuiltin (sgn a).fst (sgn a).snd (Builtin.newGenericBuiltin sgn desc) args) + (fun v => desc a args v) := by + intro H st p + constructor + unfold mapToValHeapCondition + simp [Builtin.newGenericBuiltin, mapToValHeapCondition] + + by_cases h: ∃v, desc a args v + · cases' h with v h + apply Builtin.genericOmni.ok + · exact h + · exists ∅, st, by simp, by simp, by simp [SLP.lift, h], st, ∅ + simp_all + · apply Builtin.genericOmni.err + · simp_all + · simp -- Arithmetics @@ -199,46 +217,6 @@ theorem asSlice_intro : STHoarePureBuiltin p Γ Builtin.asSlice (by tauto) h![ar apply pureBuiltin_intro_consequence <;> try tauto tauto --- BigInt - -theorem bigIntEq_intro : STHoarePureBuiltin p Γ Builtin.bigIntEq (by tauto) h![a, b] (a := ()) := by - simp only [STHoarePureBuiltin, SLP.exists_pure] - apply pureBuiltin_intro_consequence <;> try tauto - tauto - -theorem bigIntAdd_intro : STHoarePureBuiltin p Γ Builtin.bigIntAdd (by tauto) h![a, b] (a := ()) := by - simp only [STHoarePureBuiltin, SLP.exists_pure] - apply pureBuiltin_intro_consequence <;> try tauto - tauto - -theorem bigIntSub_intro : STHoarePureBuiltin p Γ Builtin.bigIntSub (by tauto) h![a, b] (a := ()) := by - simp only [STHoarePureBuiltin, SLP.exists_pure] - apply pureBuiltin_intro_consequence <;> try tauto - tauto - -theorem bigIntMul_intro : STHoarePureBuiltin p Γ Builtin.bigIntMul (by tauto) h![a, b] (a := ()) := by - simp only [STHoarePureBuiltin, SLP.exists_pure] - apply pureBuiltin_intro_consequence <;> try tauto - tauto - -theorem bigIntDiv_intro : STHoarePureBuiltin p Γ Builtin.bigIntDiv (by tauto) h![a, b] (a := ()) := by - simp only [STHoarePureBuiltin, SLP.exists_pure] - apply pureBuiltin_intro_consequence <;> try tauto - tauto - -theorem bigIntFromLeBytes_intro : STHoarePureBuiltin p Γ Builtin.bigIntFromLeBytes (by tauto) h![bs, mbs] (a := ()) := by - simp only [STHoarePureBuiltin, SLP.exists_pure] - apply pureBuiltin_intro_consequence <;> try tauto - tauto - -theorem bigIntToLeBytes_intro : STHoarePureBuiltin p Γ Builtin.bigIntToLeBytes (by tauto) h![a] (a := ()) := by - simp only [STHoarePureBuiltin, SLP.exists_pure] - apply pureBuiltin_intro_consequence <;> try rfl - . dsimp only - intro h - use h - exact () - -- Bitwise def bNot_intro := genericTotalPureBuiltin_intro Builtin.bNot rfl () @@ -256,14 +234,14 @@ def iAnd_intro := genericTotalPureBuiltin_intro Builtin.iAnd rfl def iOr_intro := genericTotalPureBuiltin_intro Builtin.iOr rfl def iXor_intro := genericTotalPureBuiltin_intro Builtin.iXor rfl -theorem iShl_intro {p Γ W} +theorem iShl_intro {p Γ W} {a b : Tp.denote p (.i W)} : STHoarePureBuiltin p Γ Builtin.iShl (by tauto) h![a, b] (a := W) := by simp only [STHoarePureBuiltin, SLP.exists_pure] apply pureBuiltin_intro_consequence <;> try tauto tauto -theorem iShr_intro {p Γ W} +theorem iShr_intro {p Γ W} {a b : Tp.denote p (.i W)} : STHoarePureBuiltin p Γ Builtin.iShr (by tauto) h![a, b] (a := W) := by simp only [STHoarePureBuiltin, SLP.exists_pure] @@ -306,6 +284,16 @@ theorem uLt_intro : STHoarePureBuiltin p Γ Builtin.uLt (by tauto) h![a, b] := b apply pureBuiltin_intro_consequence <;> try tauto tauto +theorem uLeq_intro : STHoarePureBuiltin p Γ Builtin.uLeq (by tauto) h![a, b] := by + simp only [STHoarePureBuiltin, SLP.exists_pure] + apply pureBuiltin_intro_consequence <;> try tauto + tauto + +theorem uNeq_intro : STHoarePureBuiltin p Γ Builtin.uNeq (by tauto) h![a, b] := by + simp only [STHoarePureBuiltin, SLP.exists_pure] + apply pureBuiltin_intro_consequence <;> try tauto + tauto + theorem iLt_intro : STHoarePureBuiltin p Γ Builtin.iLt (by tauto) h![a, b] := by simp only [STHoarePureBuiltin, SLP.exists_pure] apply pureBuiltin_intro_consequence <;> try tauto @@ -323,8 +311,8 @@ theorem iGt_intro : STHoarePureBuiltin p Γ Builtin.iGt (by tauto) h![a, b] := b -- Field misc -theorem fApplyRangeConstraint_intro : - STHoarePureBuiltin p Γ Builtin.fApplyRangeConstraint (by tauto) h![f, c] (a := ()) := by +theorem applyRangeConstraint_intro : + STHoarePureBuiltin p Γ Builtin.applyRangeConstraint (by tauto) h![f, c] (a := ()) := by simp only [STHoarePureBuiltin, SLP.exists_pure] apply pureBuiltin_intro_consequence <;> try tauto tauto @@ -592,6 +580,51 @@ theorem getLens_intro {lens : Lens (Tp.denote p) tp₁ tp₂} : apply SLP.ent_star_top at h simp_all +-- Field + +theorem toLeBits_intro {f : Tp.denote p Tp.field} : + STHoare p Γ ⟦⟧ (.callBuiltin [Tp.field] ((Tp.u 1).array N) Builtin.toLeBits h![f]) + fun output => f = RadixVec.ofDigitsBE (r := 2) (output.map BitVec.toFin).reverse := by + apply STHoare.consequence + case h_hoare => + apply genericBuiltin_intro (sgn := fun s => ([.field], .array (.u 1) s)) + · apply SLP.entails_self + · intro + apply SLP.entails_self + +theorem toBeBits_intro {f : Tp.denote p Tp.field} : + STHoare p Γ ⟦⟧ (.callBuiltin [Tp.field] ((Tp.u 1).array s) Builtin.toBeBits h![f]) + fun output => f = RadixVec.ofDigitsBE (r := 2) (output.map BitVec.toFin) := by + apply STHoare.consequence + case h_hoare => + apply genericBuiltin_intro (sgn := fun s => ([.field], .array (.u 1) s)) + · apply SLP.entails_self + · intro + apply SLP.entails_self + +theorem toLeRadix_intro {f : Tp.denote p Tp.field} {r : Tp.denote p (Tp.u 32)} : + STHoare p Γ ⟦⟧ (.callBuiltin [Tp.field, Tp.u 32] ((Tp.u 8).array s) Builtin.toLeRadix h![f, r]) + fun output => f = RadixVec.ofLimbsBE r.toNat (output.map BitVec.toNat).reverse := by + apply STHoare.consequence + case h_hoare => + apply genericBuiltin_intro (sgn := fun s => ([.field, .u 32], .array (.u 8) s)) + · apply SLP.entails_self + · simp only + intro + apply SLP.entails_self + + +theorem toBeRadix_intro {f : Tp.denote p Tp.field} {r : Tp.denote p (Tp.u 32)} : + STHoare p Γ ⟦⟧ (.callBuiltin [Tp.field, Tp.u 32] ((Tp.u 8).array s) Builtin.toBeRadix h![f, r]) + fun output => f = RadixVec.ofLimbsBE r.toNat (output.map BitVec.toNat) := by + apply STHoare.consequence + case h_hoare => + apply genericBuiltin_intro (sgn := fun s => ([.field, .u 32], .array (.u 8) s)) + · apply SLP.entails_self + · simp only + intro + apply SLP.entails_self + -- Misc theorem assert_intro : STHoarePureBuiltin p Γ Builtin.assert (by tauto) h![a] (a := ()) := by @@ -599,6 +632,11 @@ theorem assert_intro : STHoarePureBuiltin p Γ Builtin.assert (by tauto) h![a] ( apply pureBuiltin_intro_consequence <;> try tauto tauto +theorem staticAssert_intro : STHoarePureBuiltin p Γ Builtin.staticAssert (by tauto) (a := tp) h![c, b] := by + simp only [STHoarePureBuiltin, SLP.exists_pure] + apply pureBuiltin_intro_consequence <;> try tauto + tauto + theorem cast_intro [Builtin.CastTp tp tp'] : STHoare p Γ ⟦⟧ (.callBuiltin [tp] tp' .cast h![v]) (fun v' => v' = @Builtin.CastTp.cast tp tp' _ p v) := by unfold STHoare THoare diff --git a/Lampe/Lampe/Syntax/Builders.lean b/Lampe/Lampe/Syntax/Builders.lean index 5f0251e1..3d49cf1c 100644 --- a/Lampe/Lampe/Syntax/Builders.lean +++ b/Lampe/Lampe/Syntax/Builders.lean @@ -5,7 +5,6 @@ import Lampe.Ast import Lampe.Ast.Extensions import Lampe.Builtin.Arith import Lampe.Builtin.Array -import Lampe.Builtin.BigInt import Lampe.Builtin.Bit import Lampe.Builtin.Cmp import Lampe.Builtin.Field @@ -579,7 +578,7 @@ def makeTraitDef [MonadUtil m] : Syntax → m (List $ TSyntax `command) -- def makeStructDef [MonadUtil m] (name : TSyntax `noir_ident): Syntax → m (TSyntax `term) -/-- +/-- Extracts any deprecation message that may have been attached to a definition. Returns `some ""` if the entity is deprecated but has no message, and returns `none` if the entity @@ -587,6 +586,5 @@ is not deprecated at all. -/ def parseDeprecatedMessage [MonadUtil m] : (stx : Syntax) → m (Option String) | `(noir_depr?|[[deprecated]]) => pure $ some "" -| `(noir_depr?|[[deprecated $msg:str]]) => pure msg.getString +| `(noir_depr?|[[deprecated $msg:str]]) => pure msg.getString | _ => pure none - diff --git a/Lampe/Lampe/Syntax/Delab.lean b/Lampe/Lampe/Syntax/Delab.lean index a2da715b..2b95d547 100644 --- a/Lampe/Lampe/Syntax/Delab.lean +++ b/Lampe/Lampe/Syntax/Delab.lean @@ -37,7 +37,6 @@ partial def ppTp (expr : Lean.Expr) : DelabM <| TSyntax `noir_type := do | Tp.i n => let i := mkIdent <| .mkSimple s!"i{← ppExpr n}" return ← `(noir_type|$(⟨i⟩):noir_type) - | Tp.bi => return ⟨mkIdent `bi⟩ | Tp.field => return ⟨mkIdent `Field⟩ | Tp.bool => return ⟨mkIdent `bool⟩ | Tp.unit => return ⟨mkIdent `Unit⟩ @@ -324,7 +323,7 @@ def delabLam : Delab := whenDelabExprOption getExpr >>= fun expr => pure (args, body) | _ => throwError "unable to parse args of Lambda" - let funArgs ← args.getElems.zip argTps.toArray |>.mapM fun (arg, tp) => do + let funArgs ← args.getElems.zip argTps.toArray |>.mapM fun (arg, tp) => do `(noir_lam_param|$(⟨arg⟩):noir_pat : $(← ppTp tp)) return ← ``(⸨fn($funArgs,*) : $(←ppTp outTp) := $(⟨extractInnerLampeExpr body⟩)⸩) @@ -458,7 +457,7 @@ def delabLampeConstU : Delab := whenDelabExprOption getExpr >>= fun expr => @[app_delab Lampe.Expr.litStr] def delabLampeLitStr : Delab := whenDelabExprOption getExpr >>= fun expr => whenFullyApplied expr do let args := expr.getAppArgs - let Expr.lit (Literal.strVal noirStr) := args[2]!.getAppArgs[0]! + let Expr.lit (Literal.strVal noirStr) := args[2]!.getAppArgs[0]! | throwError "Expected string literal as argument but none found" return ←``(⸨ $(⟨Syntax.mkStrLit noirStr⟩) ⸩) diff --git a/Lampe/Lampe/Tactic/Steps.lean b/Lampe/Lampe/Tactic/Steps.lean index 65b42c09..c295c2de 100644 --- a/Lampe/Lampe/Tactic/Steps.lean +++ b/Lampe/Lampe/Tactic/Steps.lean @@ -81,8 +81,9 @@ def getClosingTerm (val : Lean.Expr) : TacticM (Option (TSyntax `term)) := withT match n with | ``Lampe.Builtin.fresh => return some (←``(fresh_intro)) | ``Lampe.Builtin.assert => return some (←``(assert_intro)) + | ``Lampe.Builtin.staticAssert => return some (←``(staticAssert_intro)) - | ``Lampe.Builtin.bNot => return some (←``(genericTotalPureBuiltin_intro Builtin.bNot rfl)) + | ``Lampe.Builtin.bNot => return some (←``(genericTotalPureBuiltin_intro Builtin.bNot (a := ()) rfl)) | ``Lampe.Builtin.bAnd => return some (←``(genericTotalPureBuiltin_intro Builtin.bAnd rfl)) | ``Lampe.Builtin.bXor => return some (←``(genericTotalPureBuiltin_intro Builtin.bXor rfl)) | ``Lampe.Builtin.bOr => return some (←``(genericTotalPureBuiltin_intro Builtin.bOr rfl)) @@ -128,6 +129,7 @@ def getClosingTerm (val : Lean.Expr) : TacticM (Option (TSyntax `term)) := withT | ``Lampe.Builtin.uDiv => return some (←``(uDiv_intro)) | ``Lampe.Builtin.uSub => return some (←``(uSub_intro)) | ``Lampe.Builtin.uRem => return some (←``(uRem_intro)) + | ``Lampe.Builtin.uLeq => return some (←``(uLeq_intro)) | ``Lampe.Builtin.iAdd => return some (←``(iAdd_intro)) | ``Lampe.Builtin.iMul => return some (←``(iMul_intro)) @@ -136,10 +138,16 @@ def getClosingTerm (val : Lean.Expr) : TacticM (Option (TSyntax `term)) := withT | ``Lampe.Builtin.iRem => return some (←``(iRem_intro)) | ``Lampe.Builtin.iNeg => return some (←``(iNeg_intro)) + | ``Lampe.Builtin.modulusLeBits => return some (←``(genericTotalPureBuiltin_intro Builtin.modulusLeBits (a := ()) rfl)) + | ``Lampe.Builtin.modulusBeBits => return some (←``(genericTotalPureBuiltin_intro Builtin.modulusBeBits (a := ()) rfl)) + | ``Lampe.Builtin.modulusLeBytes => return some (←``(genericTotalPureBuiltin_intro Builtin.modulusLeBytes (a := ()) rfl)) + | ``Lampe.Builtin.modulusBeBytes => return some (←``(genericTotalPureBuiltin_intro Builtin.modulusBeBytes (a := ()) rfl)) + | ``Lampe.Builtin.modulusNumBits => return some (←``(genericTotalPureBuiltin_intro Builtin.modulusNumBits (a := ()) rfl)) + | ``Lampe.Builtin.strAsBytes => return some (←``(strAsBytes_intro)) | ``Lampe.Builtin.arrayAsStrUnchecked => return some (←``(arrayAsStrUnchecked_intro)) - | ``Lampe.Builtin.isUnconstrained => + | ``Lampe.Builtin.isUnconstrained => return some (←``(genericTotalPureBuiltin_intro Builtin.isUnconstrained rfl)) -- Array builtins @@ -173,8 +181,8 @@ def getClosingTerm (val : Lean.Expr) : TacticM (Option (TSyntax `term)) := withT | ``Lampe.Builtin.ref => return some (←``(ref_intro)) | ``Lampe.Builtin.readRef => return some (←``(readRef_intro)) - | ``Lampe.Builtin.fApplyRangeConstraint => return some (←``(fApplyRangeConstraint_intro)) - | ``Lampe.Builtin.fModBeBits => return some (←``(genericTotalPureBuiltin_intro Builtin.fModBeBits rfl)) + -- Field builtins + | ``Lampe.Builtin.applyRangeConstraint => return some (←``(applyRangeConstraint_intro)) | ``Lampe.Builtin.fModBeBits => return some (←``(genericTotalPureBuiltin_intro Builtin.fModBeBits rfl)) | ``Lampe.Builtin.fModBeBytes => return some (←``(genericTotalPureBuiltin_intro Builtin.fModBeBytes rfl)) | ``Lampe.Builtin.fModLeBits => return some (←``(genericTotalPureBuiltin_intro Builtin.fModLeBits rfl)) | ``Lampe.Builtin.fModLeBytes => return some (←``(genericTotalPureBuiltin_intro Builtin.fModLeBytes rfl)) @@ -183,6 +191,10 @@ def getClosingTerm (val : Lean.Expr) : TacticM (Option (TSyntax `term)) := withT | ``Lampe.Builtin.iFromField => return some (←``(genericTotalPureBuiltin_intro Builtin.iFromField rfl)) | ``Lampe.Builtin.uAsField => return some (←``(genericTotalPureBuiltin_intro Builtin.uAsField rfl)) | ``Lampe.Builtin.uFromField => return some (←``(genericTotalPureBuiltin_intro Builtin.uFromField rfl)) + | ``Lampe.Builtin.toLeBits => return some (←``(toLeBits_intro)) + | ``Lampe.Builtin.toBeBits => return some (←``(toBeBits_intro)) + | ``Lampe.Builtin.toLeRadix => return some (←``(toLeRadix_intro)) + | ``Lampe.Builtin.toBeRadix => return some (←``(toBeRadix_intro)) -- Tuple/struct builtins | ``Lampe.Builtin.makeData => return some (← ``(genericTotalPureBuiltin_intro (a := (_, _)) Builtin.makeData rfl)) @@ -322,8 +334,8 @@ partial def doPlucks (goal : MVarId) (pre : SLTerm) : TacticM (MVarId × List MV | .star _ (.lift _) r => let plucker ← mkConstWithFreshMVarLevels ``pluck_pures_destructively let goal :: impls₁ ← goal.apply plucker | throwError "unexpected goals in pluck_pures_destructively" - let (goal, impls₂) ← doPlucks goal r let goal ← introPure goal + let (goal, impls₂) ← doPlucks goal r pure (goal, impls₁ ++ impls₂) | .lift _ => let plucker ← mkConstWithFreshMVarLevels ``pluck_final_pure_destructively diff --git a/Lampe/Lampe/Tp.lean b/Lampe/Lampe/Tp.lean index b38b43be..2c3c498d 100644 --- a/Lampe/Lampe/Tp.lean +++ b/Lampe/Lampe/Tp.lean @@ -22,7 +22,6 @@ variable (p : Prime) inductive Tp where | u (size : Nat) | i (size : Nat) -| bi -- BigInt | bool | unit | str (size: U 32) @@ -162,7 +161,6 @@ def Tp.denoteArgs : List Tp → Type def Tp.denote : Tp → Type | .u n => U n | .i n => I n -| .bi => Int | .bool => Bool | .unit => Unit | .str n => NoirStr n.toNat @@ -202,7 +200,6 @@ def delabTpDenote : Delab := whenDelabTp getExpr >>= fun expr => whenFullyApplie | Tp.field => mkAppM `Lampe.Fp #[p] | Tp.u n => mkAppM `Lampe.U #[n] | Tp.i n => mkAppM `Lampe.I #[n] - | Tp.bi => mkAppM `Int #[] | Tp.bool => mkAppM `Bool #[] | Tp.unit => mkAppM `Unit #[] | Tp.str n => @@ -238,7 +235,7 @@ def Tp.zeroArgs (args : List Tp) : HList (Tp.denote p) args := def Tp.zero (tp : Tp) : Tp.denote p tp := match tp with -| .u _ | .i _ | .bi | .field => 0 +| .u _ | .i _ | .field => 0 | .bool => False | .unit => () | .str n => List.Vector.replicate n.toNat 0 diff --git a/src/lean/emit/context.rs b/src/lean/emit/context.rs index 440e35e1..25a07686 100644 --- a/src/lean/emit/context.rs +++ b/src/lean/emit/context.rs @@ -123,13 +123,9 @@ impl EmitContext { /// [`LEAN_QUOTE_END`] if it is necessary. #[must_use] pub fn quote_name_if_needed(text: &str) -> String { - if !text.starts_with(LEAN_QUOTE_START) || !text.ends_with(LEAN_QUOTE_END) { - if text.contains("::") { - let text = text.replace(LEAN_QUOTE_START, "").replace(LEAN_QUOTE_END, ""); - format!("{LEAN_QUOTE_START}{text}{LEAN_QUOTE_END}") - } else { - text.to_string() - } + if text.contains("::") { + let text = text.replace(LEAN_QUOTE_START, "").replace(LEAN_QUOTE_END, ""); + format!("{LEAN_QUOTE_START}{text}{LEAN_QUOTE_END}") } else { text.to_string() } @@ -159,6 +155,10 @@ mod test { EmitContext::quote_name_if_needed("foo::bar42"), format!("{LEAN_QUOTE_START}foo::bar42{LEAN_QUOTE_END}") ); + assert_eq!( + EmitContext::quote_name_if_needed("«std-1.0.0-beta.12»::slice::«all»"), + format!("{LEAN_QUOTE_START}std-1.0.0-beta.12::slice::all{LEAN_QUOTE_END}") + ); } #[test] diff --git a/src/lean/mod.rs b/src/lean/mod.rs index 60658d1a..4a67c7eb 100644 --- a/src/lean/mod.rs +++ b/src/lean/mod.rs @@ -14,7 +14,7 @@ pub const LEAN_QUOTE_START: &str = "«"; pub const LEAN_QUOTE_END: &str = "»"; /// Keywords that are built into Lean's syntax, so we need to quote them -pub const LEAN_KEYWORDS: &[&str] = &["from", "meta"]; +pub const LEAN_KEYWORDS: &[&str] = &["all", "from", "meta"]; fn conflicts_with_lean_keyword(text: &str) -> bool { LEAN_KEYWORDS.contains(&text) diff --git a/stdlib/lampe/Stdlib/Cmp.lean b/stdlib/lampe/Stdlib/Cmp.lean index 226d5d35..f157dcca 100644 --- a/stdlib/lampe/Stdlib/Cmp.lean +++ b/stdlib/lampe/Stdlib/Cmp.lean @@ -1068,7 +1068,6 @@ theorem slice_ord_pure_spec {p T a b} Ordering.then_of_ne_eq] steps [equal_spec, Eq.ordering_eq_spec] - · trivial apply STHoare.ite_intro · intro @@ -1102,7 +1101,6 @@ theorem tuple2_ord_pure_spec {p A B self other} resolve_trait steps [A_ord_f, equal_spec, Eq.ordering_eq_spec] - · exact () apply STHoare.ite_intro · intro diff --git a/stdlib/lampe/Stdlib/Ext.lean b/stdlib/lampe/Stdlib/Ext.lean new file mode 100644 index 00000000..5effa10b --- /dev/null +++ b/stdlib/lampe/Stdlib/Ext.lean @@ -0,0 +1,4 @@ +import Stdlib.Ext.List +import Stdlib.Ext.BitVec +import Stdlib.Ext.Vector +import Stdlib.Ext.Nat diff --git a/stdlib/lampe/Stdlib/Ext/BitVec.lean b/stdlib/lampe/Stdlib/Ext/BitVec.lean new file mode 100644 index 00000000..f7867360 --- /dev/null +++ b/stdlib/lampe/Stdlib/Ext/BitVec.lean @@ -0,0 +1,41 @@ +import Lampe + +/-! +# BitVec Extensions + +Mathlib-style extensions for BitVec operations. +-/ + +open Lampe + +instance : Std.Total (fun (x1 : U s) x2 => ¬x1 < x2) := { total := by simp [BitVec.le_total] } + +instance : Std.Antisymm (fun (x1 : U s) x2 => ¬x1 < x2) where + antisymm _ _ _ _ := by + simp_all only [BitVec.not_lt] + apply BitVec.le_antisymm <;> assumption + +instance : Std.Irrefl (fun (x1 : U s) x2 => x1 < x2) where + irrefl _ := BitVec.lt_irrefl _ + +lemma U.cases_one (i : U 1) : i = 0 ∨ i = 1 := by fin_cases i <;> simp + +@[simp] +theorem BitVec.toFin_ofFin_comp (n : ℕ) : + (fun (i : BitVec n) => i.toFin) ∘ BitVec.ofFin = id := by + funext x + simp [BitVec.toFin_ofFin] + +@[simp] +theorem BitVec.ofFin_toFin_comp (n : ℕ) : + BitVec.ofFin ∘ (fun (i : BitVec n) => i.toFin) = id := by + funext x + rfl + +lemma U32.index_toNat (len i : ℕ) (hlen : len < 2^32) (hi : i < 2^32) (hi_lt : i < len) : + (({ toFin := ⟨len, hlen⟩ } : U 32) - 1 - (BitVec.ofNatLT i hi)).toNat = len - 1 - i := by + have h1 : ({ toFin := ⟨len, hlen⟩ } : U 32).toNat = len := by simp + have h2 : (BitVec.ofNatLT i hi : U 32).toNat = i := by simp + have h3 : (1 : U 32).toNat = 1 := by decide + simp only [BitVec.toNat_sub, h1, h2, h3, Nat.reducePow] + omega diff --git a/stdlib/lampe/Stdlib/Ext/List.lean b/stdlib/lampe/Stdlib/Ext/List.lean new file mode 100644 index 00000000..4def899b --- /dev/null +++ b/stdlib/lampe/Stdlib/Ext/List.lean @@ -0,0 +1,64 @@ +import Mathlib.Tactic + +/-! +# List Extensions + +Mathlib-style extensions for List operations, primarily for lexicographic comparisons. +-/ + +theorem List.lt_append_of_lt [DecidableEq α] [LT α] [DecidableLT α] + (l₁ l₂ l₃ l₄ : List α) : + l₁.length = l₂.length → l₁ < l₂ → l₁ ++ l₃ < l₂ ++ l₄ := by + intro hl hlt + rw [List.lt_iff_exists] at hlt + simp only [hl, List.take_length, lt_self_iff_false, and_false, exists_idem, false_or] at hlt + rcases hlt with ⟨i, h, _⟩ + rw [List.lt_iff_exists] + right + exists + i, + (by simp only [List.length_append]; linarith), + (by simp only [List.length_append]; linarith) + apply And.intro + · intro j hj + have : j < l₁.length := by linarith + have : j < l₂.length := by linarith + simp_all + · simp_all + +theorem List.take_succ_lt_of_take_lt [DecidableEq α] [LT α] [DecidableLT α] {l₁ l₂ : List α} + (hi₁ : i < l₁.length) (hi₂ : i < l₂.length) (hlt : l₁.take i < l₂.take i) : + l₁.take (i + 1) < l₂.take (i + 1) := by + rw [List.take_succ_eq_append_getElem hi₁, List.take_succ_eq_append_getElem hi₂] + apply List.lt_append_of_lt + · simp [ + Nat.min_eq_left (Nat.le_of_lt hi₁), Nat.min_eq_left (Nat.le_of_lt hi₂) + ] + · exact hlt + +theorem List.take_succ_lt_of_getElem_lt [DecidableEq α] [LT α] [DecidableLT α] + {l₁ l₂ : List α} + (hi₁ : i < l₁.length) (hi₂ : i < l₂.length) + (heq : l₁.take i = l₂.take i) (hlt : l₁[i] < l₂[i]) : + l₁.take (i + 1) < l₂.take (i + 1) := by + rw [ + List.take_succ_eq_append_getElem hi₁, + List.take_succ_eq_append_getElem hi₂, + heq + ] + exact List.append_left_lt (List.cons_lt_cons_iff.mpr (Or.inl hlt)) + +theorem List.lt_of_take_lt [DecidableEq α] [LT α] [DecidableLT α] {l₁ l₂ : List α} {n : Nat} + (hlen₁ : l₁.length = n) (hlen₂ : l₂.length = n) + (hlt : l₁.take n < l₂.take n) : l₁ < l₂ := by + rw [←List.take_length (l := l₁), ←List.take_length (l := l₂), hlen₁, hlen₂] + exact hlt + +theorem List.do_pure_eq_map {α β : Type} (l : List α) (f : α → β) : + (do let a ← l; pure (f a)) = List.map f l := by + induction l with + | nil => rfl + | cons x xs ih => + show List.flatMap _ (x :: xs) = _ + simp only [List.flatMap_cons, Pure.pure, List.singleton_append, List.map_cons] + congr 1 diff --git a/stdlib/lampe/Stdlib/Ext/Nat.lean b/stdlib/lampe/Stdlib/Ext/Nat.lean new file mode 100644 index 00000000..fe0e1754 --- /dev/null +++ b/stdlib/lampe/Stdlib/Ext/Nat.lean @@ -0,0 +1,32 @@ +import Mathlib.Tactic + +/-! +# Nat Extensions + +Mathlib-style extensions for Nat operations. +-/ + +namespace Nat + +theorem mod_sub_add_eq (n i k : Nat) (hi : i ≤ k) (hk : k < n) : + (n - i + k) % n = k - i := by + have hin : i ≤ n := le_trans hi (le_of_lt hk) + have hcalc : n - i + k = n + (k - i) := by + calc + n - i + k = n + k - i := by + symm + exact Nat.sub_add_comm hin + _ = n + (k - i) := by + exact Nat.add_sub_assoc hi n + calc + (n - i + k) % n = (n + (k - i)) % n := by + simp [hcalc] + _ = ((n % n) + (k - i) % n) % n := by + simp [Nat.add_mod] + _ = (k - i) % n := by simp + _ = k - i := by + apply Nat.mod_eq_of_lt + have : k - i ≤ k := Nat.sub_le _ _ + exact lt_of_le_of_lt this hk + +end Nat diff --git a/stdlib/lampe/Stdlib/Ext/README.md b/stdlib/lampe/Stdlib/Ext/README.md new file mode 100644 index 00000000..d760b94c --- /dev/null +++ b/stdlib/lampe/Stdlib/Ext/README.md @@ -0,0 +1,3 @@ +# Ext + +Here is where we should shove lemmas which are essentially just an extension of Mathlib. diff --git a/stdlib/lampe/Stdlib/Ext/Vector.lean b/stdlib/lampe/Stdlib/Ext/Vector.lean new file mode 100644 index 00000000..3268984d --- /dev/null +++ b/stdlib/lampe/Stdlib/Ext/Vector.lean @@ -0,0 +1,12 @@ +import Lampe + +/-! +# List.Vector Extensions + +Mathlib-style extensions for List.Vector operations. +-/ + +theorem List.Vector.reverse_map {α β : Type} {d : ℕ} (v : List.Vector α d) (f : α → β) : + (v.map f).reverse = v.reverse.map f := by + apply List.Vector.eq + simp [List.Vector.toList_reverse] diff --git a/stdlib/lampe/Stdlib/Field/Basic.lean b/stdlib/lampe/Stdlib/Field/Basic.lean new file mode 100644 index 00000000..d2030262 --- /dev/null +++ b/stdlib/lampe/Stdlib/Field/Basic.lean @@ -0,0 +1,19 @@ +import «std-1.0.0-beta.12».Extracted +import Lampe + +namespace Lampe.Stdlib.Field + +open «std-1.0.0-beta.12» (env) + +-- idk where to put this one cuz i don't want to create circular dependency +-- but bn254 depends on it as does mod +theorem assert_max_bit_size_intro : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::assert_max_bit_size».call h![BIT_SIZE] h![f]) + (fun r => ∃∃ h : f.val < 2 ^ BIT_SIZE.toNat, r = ()) := by + enter_decl + steps + all_goals + simp_all + +end Lampe.Stdlib.Field diff --git a/stdlib/lampe/Stdlib/Field/Bn254.lean b/stdlib/lampe/Stdlib/Field/Bn254.lean index f5b60e7b..efd8aded 100644 --- a/stdlib/lampe/Stdlib/Field/Bn254.lean +++ b/stdlib/lampe/Stdlib/Field/Bn254.lean @@ -1,6 +1,403 @@ import «std-1.0.0-beta.12».Extracted import Lampe +import Stdlib.Field.Basic namespace Lampe.Stdlib.Field.Bn254 -open «std-1.0.0-beta.12» +open Lampe +open «std-1.0.0-beta.12» (env) + +abbrev PLO := «std-1.0.0-beta.12::field::bn254::PLO» +abbrev PHI := «std-1.0.0-beta.12::field::bn254::PHI» +abbrev TWO_POW_128 := «std-1.0.0-beta.12::field::bn254::TWO_POW_128» + +def ploNat : Nat := 53438638232309528389504892708671455233 + +def phiNat : Nat := 64323764613183177041862057485226039389 + +def pow128 : Nat := 2 ^ 128 + +lemma pow128_lt_prime {p} [Prime.BitsGT p 129] : pow128 < p.natVal := by + simpa [pow128] using (Prime.BitsGT.lt_prime (prime := p) (bits := 128)) + +lemma pow128_val {p} [Prime.BitsGT p 129] : ((pow128 : Nat) : Fp p).val = pow128 := by + simpa [pow128] using (ZMod.val_natCast_of_lt (pow128_lt_prime (p := p))) + +lemma val_add_one_of_lt {p} [Prime.BitsGT p 129] {x : Fp p} (hx : x.val < pow128) : + (x + 1).val = x.val + 1 := by + have h1_val : (1 : Fp p).val = 1 := by + have h1_lt : (1 : Nat) < p.natVal := by + linarith [pow128_lt_prime (p := p)] + simpa using (ZMod.val_natCast_of_lt h1_lt) + have hx1_lt : x.val + 1 < p.natVal := by + linarith [hx, pow128_lt_prime (p := p)] + have hx1_lt' : x.val + (1 : Fp p).val < p.natVal := by + simpa [h1_val] using hx1_lt + simpa [h1_val] using (ZMod.val_add_of_lt hx1_lt') + +lemma limbs_gt_of_hi_gt {a_lo a_hi b_lo b_hi : Nat} + (hb_lo : b_lo < pow128) (hhi : b_hi < a_hi) : + a_lo + pow128 * a_hi > b_lo + pow128 * b_hi := by + have h_rhs : b_lo + pow128 * b_hi < pow128 * (b_hi + 1) := by + have h1 : b_lo + pow128 * b_hi < pow128 + pow128 * b_hi := + Nat.add_lt_add_right hb_lo _ + simpa [Nat.mul_add, Nat.add_comm, Nat.add_left_comm, Nat.add_assoc] using h1 + have hpow : pow128 * (b_hi + 1) ≤ pow128 * a_hi := by + exact Nat.mul_le_mul_left _ (Nat.succ_le_of_lt hhi) + have h_lhs_ge : pow128 * (b_hi + 1) ≤ a_lo + pow128 * a_hi := by + exact le_trans hpow (Nat.le_add_left _ _) + exact lt_of_lt_of_le h_rhs h_lhs_ge + +lemma prime_sub_pow128_gt {p} [Prime.BitsGT p 129] : p.natVal - pow128 > pow128 := by + have hp : (2 ^ 129 : Nat) < p.natVal := by + simpa using (Prime.BitsGT.lt_prime (prime := p) (bits := 129)) + have hsum : pow128 + pow128 < p.natVal := by + have : (2 ^ 129 : Nat) = pow128 + pow128 := by + simp [pow128, Nat.pow_succ, two_mul] + simpa [this] using hp + exact (Nat.lt_sub_iff_add_lt).2 hsum + +lemma sub_val_gt_pow128_of_lt {p} [Prime.BitsGT p 129] {a b : Fp p} + (ha : a.val < pow128) (hb : b.val ≤ pow128) (h : a.val < b.val) : + (a - b).val > pow128 := by + have hb_lt : b.val < p.natVal := lt_of_le_of_lt hb (pow128_lt_prime (p := p)) + have hbne : b ≠ 0 := by + intro hbz + subst hbz + simpa using h + haveI : NeZero b := ⟨hbne⟩ + have hneg : (-b).val = p.natVal - b.val := by + simpa using (ZMod.val_neg_of_ne_zero b) + have hsum_lt : a.val + (-b).val < p.natVal := by + have hb_le : b.val ≤ p.natVal := le_of_lt hb_lt + have hsum_lt' : + a.val + (p.natVal - b.val) < b.val + (p.natVal - b.val) := + Nat.add_lt_add_right h _ + simpa [hneg, Nat.add_sub_of_le hb_le] using hsum_lt' + have hval : (a - b).val = a.val + (-b).val := by + simpa [sub_eq_add_neg] using (ZMod.val_add_of_lt hsum_lt) + have hsum_eq : a.val + (-b).val = p.natVal - (b.val - a.val) := by + have ha_le : a.val ≤ b.val := le_of_lt h + have hb_le : b.val ≤ p.natVal := le_of_lt hb_lt + have hd_le : b.val - a.val ≤ p.natVal := by + exact le_trans (Nat.sub_le _ _) hb_le + have hsum : + a.val + (p.natVal - b.val) + (b.val - a.val) = p.natVal := by + calc + a.val + (p.natVal - b.val) + (b.val - a.val) + = (a.val + (b.val - a.val)) + (p.natVal - b.val) := by + simp [Nat.add_assoc, Nat.add_left_comm, Nat.add_comm] + _ = b.val + (p.natVal - b.val) := by + simp [Nat.add_sub_of_le ha_le] + _ = p.natVal := by + simp [Nat.add_sub_of_le hb_le] + have hsum' : + a.val + (p.natVal - b.val) + (b.val - a.val) = + p.natVal - (b.val - a.val) + (b.val - a.val) := by + calc + a.val + (p.natVal - b.val) + (b.val - a.val) = p.natVal := hsum + _ = p.natVal - (b.val - a.val) + (b.val - a.val) := by + symm + exact Nat.sub_add_cancel hd_le + have hsum_eq' : a.val + (p.natVal - b.val) = p.natVal - (b.val - a.val) := by + exact Nat.add_right_cancel hsum' + simpa [hneg] using hsum_eq' + have hdiff : b.val - a.val ≤ pow128 := by + exact le_trans (Nat.sub_le _ _) hb + have hgt : p.natVal - (b.val - a.val) > pow128 := by + have hge : p.natVal - (b.val - a.val) ≥ p.natVal - pow128 := by + exact Nat.sub_le_sub_left hdiff _ + linarith [prime_sub_pow128_gt (p := p), hge] + calc + (a - b).val = a.val + (-b).val := hval + _ = p.natVal - (b.val - a.val) := hsum_eq + _ > pow128 := hgt + +lemma sub_val_lt_pow128_of_le {p} [Prime.BitsGT p 129] {a b : Fp p} + (ha : a.val < pow128) (hle : b.val ≤ a.val) : + (a - b).val < pow128 := by + haveI : NeZero p.natVal := by infer_instance + have hval : (a - b).val = a.val - b.val := by + exact ZMod.val_sub (a := a) (b := b) hle + have : a.val - b.val < pow128 := by + exact lt_of_le_of_lt (Nat.sub_le _ _) ha + simpa [hval] using this + +theorem plo_spec {p} : + STHoare p env ⟦⟧ + (PLO.call h![] h![]) + (fun r => r = (ploNat : Fp p)) := by + enter_decl + steps + rename_i hplo + simpa [ploNat] using hplo + +theorem phi_spec {p} : + STHoare p env ⟦⟧ + (PHI.call h![] h![]) + (fun r => r = (phiNat : Fp p)) := by + enter_decl + steps + rename_i hphi + simpa [phiNat] using hphi + +theorem two_pow_128_spec {p} : + STHoare p env ⟦⟧ + (TWO_POW_128.call h![] h![]) + (fun r => r = (pow128 : Fp p)) := by + enter_decl + steps + rename_i hpow + simpa [pow128] using hpow + +-- FIXME: steps requires this even tho it's an empty postcondition +theorem lte_hint_intro {p a b} : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::bn254::lte_hint».call h![] h![a, b]) + (fun _ => ⟦⟧) := by + enter_decl + steps + +-- FIXME: steps requires this even tho it's an empty postcondition +theorem decompose_hint_intro {p x} : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::bn254::decompose_hint».call h![] h![x]) + (fun _ => ⟦⟧) := by + enter_decl + steps + +theorem assert_gt_limbs_intro {p a b} [Prime.BitsGT p 129] : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::bn254::assert_gt_limbs».call h![] h![a, b]) + (fun _ => ⟦ + (a.1.val < pow128 ∧ a.2.1.val < pow128 ∧ b.1.val < pow128 ∧ b.2.1.val < pow128) → + a.1.val + pow128 * a.2.1.val > b.1.val + pow128 * b.2.1.val + ⟧) := by + enter_decl + steps [Lampe.Stdlib.Field.assert_max_bit_size_intro, two_pow_128_spec (p := p), lte_hint_intro] + intro hbounds + rcases hbounds with ⟨ha_lo, ha_hi, hb_lo, hb_hi⟩ + rename_i h_rlo_lt h_rhi_lt h_out h_out_eq + clear h_out h_out_eq + simp [Builtin.indexTpl] at * + subst alo + subst ahi + subst blo + subst bhi + cases hborrow : borrow <;> simp [hborrow] at * + · have hrlo_def : rlo = a.1 - b.1 - 1 := by + assumption + have hrhi_def : rhi = a.2.1 - b.2.1 := by + assumption + have hhi_ge : b.2.1.val ≤ a.2.1.val := by + by_contra hgt + have hlt : a.2.1.val < b.2.1.val := lt_of_not_ge hgt + have hgt' := sub_val_gt_pow128_of_lt (p := p) ha_hi (le_of_lt hb_hi) hlt + have hrhi_lt' : (a.2.1 - b.2.1).val < pow128 := by + simpa [hrhi_def] using h_rhi_lt + linarith [hrhi_lt', hgt'] + by_cases hhi_eq : b.2.1.val = a.2.1.val + · have hblo1_val : (b.1 + 1).val = b.1.val + 1 := + val_add_one_of_lt (p := p) hb_lo + have hblo1_le : (b.1 + 1).val ≤ pow128 := by + have : b.1.val + 1 ≤ pow128 := Nat.succ_le_of_lt hb_lo + simpa [hblo1_val] using this + have hrlo_lt' : (a.1 - (b.1 + 1)).val < pow128 := by + have : (a.1 - b.1 - 1).val < pow128 := by + simpa [hrlo_def] using h_rlo_lt + simpa [sub_eq_add_neg, add_assoc, add_left_comm, add_comm] using this + have hlo : b.1.val < a.1.val := by + by_contra hge + have hle : a.1.val ≤ b.1.val := le_of_not_gt hge + have hlt : a.1.val < (b.1 + 1).val := by + have : a.1.val < b.1.val + 1 := Nat.lt_succ_of_le hle + simpa [hblo1_val] using this + have hgt := sub_val_gt_pow128_of_lt (p := p) ha_lo hblo1_le hlt + linarith + have hgt : b.1.val + pow128 * b.2.1.val < a.1.val + pow128 * b.2.1.val := + Nat.add_lt_add_right hlo _ + simpa [hhi_eq, Nat.add_comm, Nat.add_left_comm, Nat.add_assoc] using hgt + · have hhi_lt : b.2.1.val < a.2.1.val := lt_of_le_of_ne hhi_ge hhi_eq + exact limbs_gt_of_hi_gt hb_lo hhi_lt + · have hrhi_def : rhi = a.2.1 - b.2.1 - 1 := by + assumption + have hbhi1_val : (b.2.1 + 1).val = b.2.1.val + 1 := + val_add_one_of_lt (p := p) hb_hi + have hbhi1_le : (b.2.1 + 1).val ≤ pow128 := by + have : b.2.1.val + 1 ≤ pow128 := Nat.succ_le_of_lt hb_hi + simpa [hbhi1_val] using this + have hrhi_lt' : (a.2.1 - (b.2.1 + 1)).val < pow128 := by + have : (a.2.1 - b.2.1 - 1).val < pow128 := by + simpa [hrhi_def] using h_rhi_lt + simpa [sub_eq_add_neg, add_assoc, add_left_comm, add_comm] using this + have hhi : b.2.1.val < a.2.1.val := by + by_contra hge + have hle : a.2.1.val ≤ b.2.1.val := le_of_not_gt hge + have hlt : a.2.1.val < (b.2.1 + 1).val := by + have : a.2.1.val < b.2.1.val + 1 := Nat.lt_succ_of_le hle + simpa [hbhi1_val] using this + have hgt := sub_val_gt_pow128_of_lt (p := p) ha_hi hbhi1_le hlt + linarith + exact limbs_gt_of_hi_gt hb_lo hhi + +theorem decompose_intro {p x} [Prime.BitsGT p 129] + (hmod : p.natVal = ploNat + pow128 * phiNat) : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::bn254::decompose».call h![] h![x]) + (fun r => + ∃∃ xlo xhi, + r = (xlo, xhi, ()) ∧ + xlo.val < pow128 ∧ + xhi.val < pow128 ∧ + x.val = xlo.val + pow128 * xhi.val) := by + enter_decl + steps + · exact () + apply STHoare.iteFalse_intro + steps [Lampe.Stdlib.Field.assert_max_bit_size_intro, assert_gt_limbs_intro (p := p), + decompose_hint_intro (p := p), plo_spec (p := p), phi_spec (p := p), two_pow_128_spec (p := p)] + simp [SLP.exists_pure, beq_true, decide_eq_true_eq] at * + sl + rename_i r hret hxlo hxhi hxeq hlimbs + have hret' : r = (xlo, xhi, ()) := by + simpa using hret + have hxlo' : xlo.val < pow128 := by + simpa [pow128] using hxlo + have hxhi' : xhi.val < pow128 := by + simpa [pow128] using hxhi + have hxeq' : x = xlo + (pow128 : Fp p) * xhi := by + simpa using hxeq + have hplo_val : (ploNat : Fp p).val = ploNat := by + have hplo_lt : ploNat < p.natVal := by + have hplo_lt' : ploNat < pow128 := by decide + linarith [hplo_lt', pow128_lt_prime (p := p)] + simpa using (ZMod.val_natCast_of_lt hplo_lt) + have hphi_val : (phiNat : Fp p).val = phiNat := by + have hphi_lt : phiNat < p.natVal := by + have hphi_lt' : phiNat < pow128 := by decide + linarith [hphi_lt', pow128_lt_prime (p := p)] + simpa using (ZMod.val_natCast_of_lt hphi_lt) + have hplo : (ploNat : Fp p).val < pow128 := by + have hplo_nat : ploNat < pow128 := by decide + simpa [hplo_val] using hplo_nat + have hphi : (phiNat : Fp p).val < pow128 := by + have hphi_nat : phiNat < pow128 := by decide + simpa [hphi_val] using hphi_nat + have hlimbs' : (ploNat : Fp p).val < pow128 ∧ (phiNat : Fp p).val < pow128 ∧ + xlo.val < pow128 ∧ xhi.val < pow128 := by + exact ⟨hplo, hphi, hxlo', hxhi'⟩ + have hlimbs_imp : + (ploNat : Fp p).val < pow128 ∧ (phiNat : Fp p).val < pow128 ∧ + xlo.val < pow128 ∧ xhi.val < pow128 → + (ploNat : Fp p).val + pow128 * (phiNat : Fp p).val > + xlo.val + pow128 * xhi.val := by + assumption + have hgt : (ploNat : Fp p).val + pow128 * (phiNat : Fp p).val > + xlo.val + pow128 * xhi.val := by + exact hlimbs_imp hlimbs' + have hgt' : ploNat + pow128 * phiNat > xlo.val + pow128 * xhi.val := by + simpa [hplo_val, hphi_val] using hgt + have hsum_lt : xlo.val + pow128 * xhi.val < p.natVal := by + simpa [hmod] using hgt' + have hmul_lt : pow128 * xhi.val < p.natVal := by + exact lt_of_le_of_lt (Nat.le_add_left _ _) hsum_lt + have hmul_lt' : (pow128 : Fp p).val * xhi.val < p.natVal := by + simpa [pow128_val (p := p)] using hmul_lt + have hsum_val : (xlo + (pow128 : Fp p) * xhi).val = xlo.val + pow128 * xhi.val := by + have hmul_val' : ((pow128 : Fp p) * xhi).val = (pow128 : Fp p).val * xhi.val := by + simpa using (ZMod.val_mul_of_lt hmul_lt') + have hsum_lt' : xlo.val + ((pow128 : Fp p) * xhi).val < p.natVal := by + simpa [hmul_val', pow128_val (p := p)] using hsum_lt + simpa [hmul_val', pow128_val (p := p)] using + (ZMod.val_add_of_lt (a := xlo) (b := (pow128 : Fp p) * xhi) hsum_lt') + have hval_eq : x.val = xlo.val + pow128 * xhi.val := by + have : x.val = (xlo + (pow128 : Fp p) * xhi).val := by + simpa using congrArg ZMod.val hxeq' + simpa [hsum_val] using this + refine ⟨xhi, ?_⟩ + exact ⟨hret', hxlo', hxhi', hval_eq⟩ + +theorem assert_gt_intro {p a b} [Prime.BitsGT p 129] + (hmod : p.natVal = ploNat + pow128 * phiNat) : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::bn254::assert_gt».call h![] h![a, b]) + (fun _ => a.val > b.val) := by + enter_decl + steps + · exact () + apply STHoare.iteFalse_intro + steps [decompose_intro (p := p) (hmod := hmod), assert_gt_limbs_intro (p := p)] + simp [SLP.exists_pure] at * + rename_i _ a_lo a_hi ha_raw b_lo b_hi hb_raw _ hlimbs + rcases ha_raw with ⟨ha_eq, ha_lo_lt, ha_hi_lt, ha_val⟩ + rcases hb_raw with ⟨hb_eq, hb_lo_lt, hb_hi_lt, hb_val⟩ + have hlimbs' : + a_lo.val < pow128 → + a_hi.val < pow128 → + b_lo.val < pow128 → + b_hi.val < pow128 → + b_lo.val + pow128 * b_hi.val < a_lo.val + pow128 * a_hi.val := by + simpa [ha_eq, hb_eq] using hlimbs + have hgt : b_lo.val + pow128 * b_hi.val < a_lo.val + pow128 * a_hi.val := by + exact hlimbs' ha_lo_lt ha_hi_lt hb_lo_lt hb_hi_lt + linarith [ha_val, hb_val, hgt] + +theorem assert_lt_intro {p a b} [Prime.BitsGT p 129] + (hmod : p.natVal = ploNat + pow128 * phiNat) : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::bn254::assert_lt».call h![] h![a, b]) + (fun _ => a.val < b.val) := by + enter_decl + steps [assert_gt_intro (p := p) (hmod := hmod)] + omega + +theorem field_less_than_intro {p x y} : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::field_less_than».call h![] h![x, y]) + (fun _ => ⟦⟧) := by + enter_decl + steps + +theorem gt_intro {p a b} [Prime.BitsGT p 129] + (hmod : p.natVal = ploNat + pow128 * phiNat) : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::bn254::gt».call h![] h![a, b]) + (fun r => r = decide (a.val > b.val)) := by + enter_decl + steps + · exact () + apply STHoare.iteFalse_intro + steps + apply STHoare.ite_intro + · intro h_eq + steps + have h_eq' : a = b := by + simpa [decide_eq_true_eq] using h_eq + simp_all [h_eq'] + · intro h_eq + steps [field_less_than_intro] + apply STHoare.ite_intro + · intro hlt + steps [assert_gt_intro (p := p) (hmod := hmod)] + rename_i r hret + have hgt_ba : b.val > a.val := by + assumption + have hnot : ¬ a.val > b.val := by + exact Nat.not_lt.mpr (le_of_lt hgt_ba) + simpa [hnot, decide_eq_false_iff_not] using hret + · intro hlt + steps [assert_gt_intro (p := p) (hmod := hmod)] + rename_i r hret + have hgt_ab : a.val > b.val := by + assumption + simpa [hgt_ab, decide_eq_true_eq] using hret + +theorem lt_intro {p a b} [Prime.BitsGT p 129] + (hmod : p.natVal = ploNat + pow128 * phiNat) : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::bn254::lt».call h![] h![a, b]) + (fun r => r = decide (a.val < b.val)) := by + enter_decl + steps [gt_intro (p := p) (hmod := hmod)] + rename_i r hret + simpa using (hret : r = decide (b.val > a.val)) diff --git a/stdlib/lampe/Stdlib/Field/Mod.lean b/stdlib/lampe/Stdlib/Field/Mod.lean index 14fd25ff..1ffe13cf 100644 --- a/stdlib/lampe/Stdlib/Field/Mod.lean +++ b/stdlib/lampe/Stdlib/Field/Mod.lean @@ -1,6 +1,994 @@ import «std-1.0.0-beta.12».Extracted import Lampe +import Lampe.Builtin.Helpers +import Stdlib.Field.Basic +import Stdlib.Field.Bn254 +import Stdlib.Compat +import Stdlib.Ext namespace Lampe.Stdlib.Field -open «std-1.0.0-beta.12» +open «std-1.0.0-beta.12» (env) + +lemma bits_lt_of_lex_lt {data pdata : List (BitVec 1)} + (hlen : data.length = pdata.length) + (hlt : data < pdata) + (hpdata : pdata = List.map (fun (d : Digit 2) => BitVec.ofNatLT d.val d.prop) + (RadixVec.toDigitsBE' 2 p)) : + RadixVec.ofDigitsBE' (data.map (fun i => (i.toFin : Digit 2))) < p := by + rw [←RadixVec.ofDigitsBE'_toDigitsBE' (r := 2) (n := p)] + apply RadixVec.ofDigitsBE'_mono + · simp [hlen, hpdata, List.length_map] + · have hself : RadixVec.toDigitsBE' 2 p = + List.map (fun (i : BitVec 1) => (i.toFin : Digit 2)) + (List.map (fun (d : Digit 2) => BitVec.ofNatLT d.val d.prop) + (RadixVec.toDigitsBE' 2 p)) := by + rw [List.map_map, eq_comm] + convert List.map_id _ + rw [hself] + apply List.map_lt + · intro x y h + rw [BitVec.lt_def] at h + rw [Fin.lt_def] + exact h + · rw [hpdata] at hlt + exact hlt + +lemma bytes_lt_of_lex_lt {data pdata : List (BitVec 8)} + (hlen : data.length = pdata.length) + (hlt : data < pdata) + (hpdata : pdata = List.map (fun (d : Digit ⟨256, by decide⟩) => BitVec.ofNatLT d.val d.prop) + (RadixVec.toDigitsBE' ⟨256, by decide⟩ p)) : + RadixVec.ofDigitsBE' (data.map (fun i => (i.toFin : Digit ⟨256, by decide⟩))) < p := by + rw [←RadixVec.ofDigitsBE'_toDigitsBE' (r := ⟨256, by decide⟩) (n := p)] + apply RadixVec.ofDigitsBE'_mono + · simp [RadixVec.toDigitsBE', hlen, hpdata, List.length_map] + · have hself : RadixVec.toDigitsBE' ⟨256, by decide⟩ p = + List.map (fun (i : BitVec 8) => i.toFin) + (List.map (fun (d : Digit ⟨256, by decide⟩) => BitVec.ofNatLT d.val d.prop) + (RadixVec.toDigitsBE' ⟨256, by decide⟩ p)) := by + simp only [List.map_map] + rw [eq_comm] + convert List.map_id _ + rw [hself] + apply List.map_lt + · intro x y h + rw [BitVec.lt_def] at h + rw [Fin.lt_def] + exact h + · rw [hpdata] at hlt + exact hlt + +lemma ofDigitsBE'_lt_of_shorter_than_modulus {r : Radix} {data : List (Digit r)} {P : Prime} + (hlen : data.length < (RadixVec.toDigitsBE' r P.natVal).length) : + RadixVec.ofDigitsBE' data < P.natVal := by + have hr : 1 < r.val := r.prop + calc RadixVec.ofDigitsBE' data + < r.val ^ data.length := RadixVec.ofDigitsBE'_lt + _ ≤ r.val ^ ((RadixVec.toDigitsBE' r P.natVal).length - 1) := by + apply Nat.pow_le_pow_right (Nat.le_of_lt hr) + apply Nat.le_pred_of_lt hlen + _ = r.val ^ Nat.log r.val P.natVal := by simp [RadixVec.toDigitsBE'] + _ ≤ P.natVal := by + apply Nat.pow_log_le_self + simp [Prime.natVal] + +theorem to_be_radix_intro : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::to_be_radix».call h![N] h![f, 256]) + fun o => + ∃∃ (v : List.Vector (Digit ⟨256, by decide⟩) N.toNat), + o = v.map BitVec.ofFin ⋆ + f = RadixVec.ofDigitsBE v := by + enter_decl + steps + · exact () + apply STHoare.letIn_intro + apply STHoare.iteTrue_intro + · steps + apply STHoare.skip_intro + intro _ + steps + case v => + rename_i v _ + exact v.map BitVec.toFin + · apply List.Vector.eq + simp + rw [eq_comm] + exact List.map_id _ + · subst_vars + rw [RadixVec.ofDigitsBE] + congr 2 + apply List.Vector.eq + simp + +theorem to_le_radix_intro : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::to_le_radix».call h![N] h![f, 256]) + fun o => + ∃∃ (v : List.Vector (Digit ⟨256, by decide⟩) N.toNat), + o = v.reverse.map BitVec.ofFin ⋆ + f = RadixVec.ofDigitsBE v := by + enter_decl + steps + · exact () + apply STHoare.letIn_intro + apply STHoare.iteTrue_intro + · steps + apply STHoare.skip_intro + intro _ + steps + case v => + rename_i v _ + exact v.reverse.map BitVec.toFin + · apply List.Vector.eq + simp [List.Vector.toList_reverse, Function.comp_def] + · subst_vars + simp only [BitVec.ofNat_eq_ofNat] + congr 2 + apply List.Vector.eq + simp [List.Vector.toList_reverse, BitVec.toFin, Fin.val_mk] + +theorem to_be_bits_intro : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::to_be_bits».call h![N] h![f]) + fun r => ∃∃(lt : f.val < (2 ^ N.toNat)), + r = (RadixVec.toDigitsBE (d := N.toNat) (r := 2) + ⟨f.val, by simp_all [OfNat.ofNat]⟩ |>.map BitVec.ofFin) := by + rcases N with ⟨⟨N,_⟩⟩ + enter_decl + steps + · exact () + step_as (⟦⟧) + (fun _ => RadixVec.ofDigitsBE' (bits.toList.map (fun i => (i.toFin : Digit 2))) < p.natVal) + · apply STHoare.iteTrue_intro + steps + rename' p => pbits + by_cases h: bits.length = pbits.length + · cases' bits with bits bitsLen + simp only [BitVec.toNat_ofFin] at bitsLen + cases bitsLen + loop_inv nat fun i _ _ => + (bits.take i ≤ pbits.take i) ⋆ [ok ↦ ⟨_, decide <| bits.take i < (pbits.take i)⟩] + · simp + · simp only [h] + simp [BitVec.ofNatLT_eq_ofNat] + · simp + · intro i _ _ + steps + by_cases h: bits.take i < pbits.take i + · simp only [h] + apply STHoare.iteFalse_intro + have : bits.take (i + 1) < pbits.take (i + 1) := + List.take_succ_lt_of_take_lt (by simp_all) (by simp_all) h + steps + · exact List.le_of_lt this + · simp_all + · simp only [h] + apply STHoare.iteTrue_intro + rename bits.take i ≤ pbits.take i => hle + have : bits.take i = pbits.take i := by + rw [List.le_iff_lt_or_eq] at hle + tauto + steps + by_cases hi : bits[i] = pbits[i] + · convert STHoare.iteFalse_intro _ + · simp [List.Vector.get, hi] + · rw [List.take_succ_eq_append_getElem (by assumption)] + rw [List.take_succ_eq_append_getElem (by assumption)] + rw [this, hi] + steps + · apply List.le_refl + · congr + simp [List.le_refl] + · convert STHoare.iteTrue_intro _ + · simp [List.Vector.get, hi] + · steps 7 + have hpbit : pbits[i] = 1 := by simp_all [Int.cast, IntCast.intCast] + have hbit : bits[i] = 0 := by have := U.cases_one bits[i]; simp_all + have bitle : bits[i] < pbits[i] := by simp [hpbit, hbit] + have : bits.take (i + 1) < pbits.take (i + 1) := + List.take_succ_lt_of_getElem_lt (by assumption) (by assumption) this bitle + steps + · exact List.le_of_lt this + · congr + simp [this] + steps + rename decide _ = true => hlt + have : bits.length = pbits.length := by simp_all + simp only [BitVec.toNat_ofFin, List.take_length, beq_true, decide_eq_true_eq] at hlt + simp only [this, List.take_length] at hlt + apply bits_lt_of_lex_lt this hlt + subst pbits + rfl + · loop_inv nat fun _ _ _ => [ok ↦ ⟨_, true⟩] + · congr + simp only [ + BitVec.toNat_ne, BitVec.natCast_eq_ofNat, BitVec.ofNat_toNat, + List.Vector.length, + ] + simp_all + · simp + · intro _ _ _ + steps + apply STHoare.iteFalse_intro + steps + steps + have hlen_lt : bits.length < pbits.length := by + apply lt_of_le_of_ne + · simp only [ + BitVec.natCast_eq_ofNat, BitVec.ofNat_toNat, BitVec.setWidth_eq, + List.Vector.length, + ] at * + simp_all + · assumption + apply ofDigitsBE'_lt_of_shorter_than_modulus (P := p) + subst pbits + simp_all + steps + rotate_left + · rename_i v _ + subst_vars + simp + rw [ZMod.val_natCast] + apply lt_of_le_of_lt (Nat.mod_le _ _) + apply RadixVec.ofDigitsBE_lt + · rename_i h v _ + subst_vars + simp only [←List.Vector.toList_map, RadixVec.ofDigitsBE'_toList] at h + conv_rhs => + enter [2, 1, 1] + rw [ZMod.val_natCast] + rw [Nat.mod_eq_of_lt h] + apply List.Vector.eq + rw [eq_comm] + simp only [ + BitVec.toNat_ofFin, Fin.eta, RadixVec.toDigitsBE_ofDigitsBE, + List.Vector.toList_map, List.map_map + ] + convert List.map_id _ + +set_option maxHeartbeats 300000 +theorem to_le_bits_intro : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::to_le_bits».call h![N] h![f]) + fun r => ∃∃(lt : f.val < (2 ^ N.toNat)), + r = (RadixVec.toDigitsBE (d := N.toNat) (r := 2) + ⟨f.val, by simp_all [OfNat.ofNat]⟩ |>.map BitVec.ofFin |>.reverse) := by + rcases N with ⟨⟨N,_⟩⟩ + enter_decl + steps + · exact () + step_as (⟦⟧) (fun _ => + RadixVec.ofDigitsBE' (bits.toList.reverse.map (fun i => (i.toFin : Digit 2))) < p.natVal) + · apply STHoare.iteTrue_intro + steps + rename' p => pbits + by_cases h: bits.length = pbits.length + · cases' bits with bits bitsLen + simp only [BitVec.toNat_ofFin] at bitsLen + cases bitsLen + loop_inv nat fun i _ _ => + (bits.reverse.take i ≤ pbits.reverse.take i) ⋆ + [ok ↦ ⟨_, decide <| bits.reverse.take i < (pbits.reverse.take i)⟩] + · simp + · simp only [h] + simp [BitVec.ofNatLT_eq_ofNat] + · simp + · intro i _ _ + steps + by_cases hlt: bits.reverse.take i < pbits.reverse.take i + · simp only [hlt] + apply STHoare.iteFalse_intro + have : bits.reverse.take (i + 1) < pbits.reverse.take (i + 1) := + List.take_succ_lt_of_take_lt (by simp_all) (by simp_all) hlt + steps + · exact List.le_of_lt this + · simp_all + · simp only [hlt] + apply STHoare.iteTrue_intro + rename bits.reverse.take i ≤ pbits.reverse.take i => hle + have heq : bits.reverse.take i = pbits.reverse.take i := by + rw [List.le_iff_lt_or_eq] at hle + tauto + have hi_lt_bits : i < bits.reverse.length := by simp_all [List.length_reverse] + have hi_lt_pbits : i < pbits.reverse.length := by simp_all [List.length_reverse] + steps + have hi_lt : i < bits.length := by simp_all [List.length_reverse] + have hlen_eq : bits.length = pbits.length := by simp_all + have hlen32 : bits.length < 2^32 := by simp_all + have hi32 : i < 2^32 := Nat.lt_trans hi_lt hlen32 + have hidx := U32.index_toNat bits.length i hlen32 hi32 hi_lt + by_cases hi : bits.reverse[i]'hi_lt_bits = pbits.reverse[i]'hi_lt_pbits + · convert STHoare.iteFalse_intro _ + · simp [List.Vector.get, h, hlen_eq] at hi ⊢ + simp_all [List.get_eq_getElem] + · rw [ + List.take_succ_eq_append_getElem hi_lt_bits, + List.take_succ_eq_append_getElem hi_lt_pbits, + heq, hi + ] + steps + · apply List.le_refl + · congr + simp [List.le_refl] + · convert STHoare.iteTrue_intro _ + · simp [List.Vector.get, List.length_reverse, h, hlen_eq] at hi ⊢ + simp_all [List.get_eq_getElem] + · steps 9 + rename_i hassert + have hpbit : pbits[pbits.length - 1 - i] = 1 := by + simp only [beq_true, decide_eq_true_eq, List.get_eq_getElem] at hassert + convert hassert using 2 + rw [←hlen_eq] + exact hidx.symm + have hbit : bits[bits.length - 1 - i] = 0 := by + have := U.cases_one bits[bits.length - 1 - i] + simp_all + have hbit_lt : bits.reverse[i]'hi_lt_bits < pbits.reverse[i]'hi_lt_pbits := by + simp [hpbit, hbit] + have : bits.reverse.take (i + 1) < pbits.reverse.take (i + 1) := + List.take_succ_lt_of_getElem_lt hi_lt_bits hi_lt_pbits heq hbit_lt + steps + · exact List.le_of_lt this + · simp [this] + steps + rename decide _ = true => hlt_final + have hlen : bits.length = pbits.length := by simp_all + simp [ + BitVec.toNat_ofFin, List.take_length, List.length_reverse + ] at hlt_final + simp only [hlen, List.take_length, List.length_reverse] at hlt_final + have hlt_full : bits.reverse < pbits.reverse := + List.lt_of_take_lt (by simp [hlen]) (by simp) hlt_final + have hpbits_rev : pbits.reverse = + List.map (fun (d : Digit 2) => BitVec.ofNatLT d.val d.prop) + (RadixVec.toDigitsBE' 2 p.natVal) := by + subst pbits + simp only [ + RadixVec.toDigitsBE', RadixVec.of, + List.do_pure_eq_map, List.map_map, + List.map_reverse, List.reverse_reverse + ] + congr 1 + funext x + simp [BitVec.ofNatLT, BitVec.ofFin] + have hlen_rev : bits.reverse.length = pbits.reverse.length := by + simp [List.length_reverse, hlen] + apply bits_lt_of_lex_lt hlen_rev (hpbits_rev ▸ hlt_full) hpbits_rev + · + loop_inv nat fun _ _ _ => [ok ↦ ⟨_, true⟩] + · congr + simp only [ + BitVec.toNat_ne, BitVec.natCast_eq_ofNat, BitVec.ofNat_toNat + ] + simp_all + · simp + · intro _ _ _ + steps + apply STHoare.iteFalse_intro + steps + steps + have hlen_lt : bits.length < pbits.length := by + apply lt_of_le_of_ne + · simp only [ + BitVec.natCast_eq_ofNat, BitVec.ofNat_toNat, BitVec.setWidth_eq + ] at * + simp_all + · assumption + apply ofDigitsBE'_lt_of_shorter_than_modulus (P := p) + subst pbits + simp [RadixVec.toDigitsBE'] at hlen_lt ⊢ + exact hlen_lt + steps + rotate_left + · rename_i v _ + subst_vars + simp + rw [ZMod.val_natCast] + apply lt_of_le_of_lt (Nat.mod_le _ _) + apply RadixVec.ofDigitsBE_lt + · + rename_i h v _ + subst_vars + simp [ + ←List.Vector.toList_map, List.Vector.toList_reverse, + ←RadixVec.ofDigitsBE'_toList, + ] at h + conv_rhs => + enter [1, 2, 1, 1] + rw [ZMod.val_natCast] + simp [ + ←List.Vector.toList_map, List.Vector.toList_reverse, + ←RadixVec.ofDigitsBE'_toList + ] + rw [Nat.mod_eq_of_lt h] + conv_rhs => + enter [1, 2, 1, 1] + rw [← List.Vector.toList_reverse] + conv_rhs => + rw [ + RadixVec.ofDigitsBE'_subtype_eq, RadixVec.toDigitsBE_ofDigitsBE, + List.Vector.reverse_map, List.Vector.reverse_reverse, + ] + apply List.Vector.eq + simp only [ + List.Vector.toList_map, List.map_map, List.map_id, + BitVec.ofFin_toFin_comp + ] + +theorem to_be_bytes_intro : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::to_be_bytes».call h![N] h![f]) + fun o => + ∃∃(lt : f.val < (256 ^ N.toNat)), o = (RadixVec.toDigitsBE + (d := N.toNat) + (r := ⟨256, by decide⟩) + ⟨f.val, by simp_all [OfNat.ofNat]⟩ |>.map BitVec.ofFin) := by + rcases N with ⟨⟨N, _⟩⟩ + enter_decl + steps [to_be_radix_intro] + · exact () + step_as (⟦⟧) (fun _ => + RadixVec.ofDigitsBE' (bytes.toList.map (fun i => (i.toFin : Digit ⟨256, by decide⟩))) + < p.natVal) + · apply STHoare.iteTrue_intro + steps + rename' p => pbytes -- pbytes is the modulus bytes slice + by_cases h: bytes.length = pbytes.length + · cases' bytes with bytes bytesLen + simp only [BitVec.toNat_ofFin] at bytesLen + cases bytesLen + loop_inv nat fun i _ _ => + (bytes.take i ≤ pbytes.take i) ⋆ + [ok ↦ ⟨_, decide <| bytes.take i < (pbytes.take i)⟩] + · simp + · simp only [h] + simp [BitVec.ofNatLT_eq_ofNat] + · simp + · intro i _ _ + steps + by_cases h: bytes.take i < pbytes.take i + · simp only [h] + apply STHoare.iteFalse_intro + have : bytes.take (i + 1) < pbytes.take (i + 1) := + List.take_succ_lt_of_take_lt (by simp_all) (by simp_all) h + steps + · exact List.le_of_lt this + · simp_all + · simp only [h] + apply STHoare.iteTrue_intro + rename bytes.take i ≤ pbytes.take i => hle + have heq : bytes.take i = pbytes.take i := by + rw [List.le_iff_lt_or_eq] at hle + tauto + steps + by_cases hi : bytes[i] = pbytes[i] + · convert STHoare.iteFalse_intro _ + · simp [List.Vector.get, hi] + · rw [List.take_succ_eq_append_getElem (by assumption)] + rw [List.take_succ_eq_append_getElem (by assumption)] + rw [heq, hi] + steps + · apply List.le_refl + · congr + simp [List.le_refl] + · convert STHoare.iteTrue_intro _ + · simp [List.Vector.get, hi] + · steps 7 + rename_i hlt_byte + have hbyte_lt : bytes[i] < pbytes[i] := by + simp only [beq_true, decide_eq_true_eq, BitVec.lt_def] at hlt_byte ⊢ + convert hlt_byte using 2 + have : bytes.take (i + 1) < pbytes.take (i + 1) := + List.take_succ_lt_of_getElem_lt (by assumption) (by assumption) heq hbyte_lt + steps + · exact List.le_of_lt this + · congr + simp [this] + steps + rename decide _ = true => hlt + have hlen : bytes.length = pbytes.length := by simp_all + simp [BitVec.toNat_ofFin] at hlt + simp [hlen] at hlt + have hpbytes_eq : pbytes = + List.map (fun (d : Digit ⟨256, by decide⟩) => BitVec.ofNatLT d.val d.prop) + (RadixVec.toDigitsBE' ⟨256, by decide⟩ p.natVal) := by + subst pbytes + simp only [RadixVec.toDigitsBE', List.do_pure_eq_map] + congr 1 + funext x + simp [BitVec.ofNatLT, BitVec.ofFin] + apply bytes_lt_of_lex_lt hlen hlt hpbytes_eq + · loop_inv nat fun _ _ _ => [ok ↦ ⟨_, true⟩] + · congr + simp only [BitVec.toNat_ne, BitVec.natCast_eq_ofNat, BitVec.ofNat_toNat] + simp_all + · simp + · intro _ _ _ + steps + apply STHoare.iteFalse_intro + steps + steps + have hlen_lt : bytes.length < pbytes.length := by + apply lt_of_le_of_ne + · simp_all + · assumption + have hpbytes_len : + pbytes.length = (RadixVec.toDigitsBE' ⟨256, by decide⟩ p.natVal).length := by + subst pbytes + simp [RadixVec.toDigitsBE'] + apply ofDigitsBE'_lt_of_shorter_than_modulus (P := p) + simp [List.Vector.toList_length, hpbytes_len] at hlen_lt ⊢ + exact hlen_lt + steps + rotate_left + · rename_i v _ + subst_vars + simp + rw [ZMod.val_natCast] + apply lt_of_le_of_lt (Nat.mod_le _ _) + apply RadixVec.ofDigitsBE_lt + · + subst_vars + rename_i _ h + simp [ + List.Vector.toList_map, List.Vector.toList_reverse, + ←RadixVec.ofDigitsBE'_toList, List.map_map, List.map_id, + BitVec.toFin_ofFin_comp 8, BitVec.toFin_ofFin, Function.comp, + ] at h + conv_rhs => + enter [2, 1, 1] + rw [ZMod.val_natCast] + simp [ + List.Vector.toList_map, List.Vector.toList_reverse, + ←RadixVec.ofDigitsBE'_toList, List.map_map, List.map_id, + BitVec.toFin_ofFin_comp 8, BitVec.toFin_ofFin, Function.comp + ] + rw [Nat.mod_eq_of_lt h] + apply List.Vector.eq + conv_rhs => + enter [1, 2] + rw [RadixVec.ofDigitsBE'_subtype_eq, RadixVec.toDigitsBE_ofDigitsBE] + + +set_option maxHeartbeats 500000 +theorem to_le_bytes_intro : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::to_le_bytes».call h![N] h![f]) + fun o => + ∃∃(lt : f.val < (256 ^ N.toNat)), o = (RadixVec.toDigitsBE + (d := N.toNat) + (r := ⟨256, by decide⟩) + ⟨f.val, by simp_all [OfNat.ofNat]⟩ |>.map BitVec.ofFin |>.reverse) := by + rcases N with ⟨⟨N, _⟩⟩ + enter_decl + steps [to_le_radix_intro] + · exact () + step_as (⟦⟧) (fun _ => + RadixVec.ofDigitsBE' + (bytes.toList.reverse.map (fun i => (i.toFin : Digit ⟨256, by decide⟩))) < p.natVal) + · apply STHoare.iteTrue_intro + steps + rename' p => pbytes + by_cases h: bytes.length = pbytes.length + · cases' bytes with bytes bytesLen + simp only [BitVec.toNat_ofFin] at bytesLen + cases bytesLen + loop_inv nat fun i _ _ => + (bytes.reverse.take i ≤ pbytes.reverse.take i) ⋆ + [ok ↦ ⟨_, decide <| bytes.reverse.take i < (pbytes.reverse.take i)⟩] + · simp + · simp only [h] + simp [BitVec.ofNatLT_eq_ofNat] + · simp + · intro i _ _ + steps + by_cases hlt: bytes.reverse.take i < pbytes.reverse.take i + · simp only [hlt] + apply STHoare.iteFalse_intro + have : bytes.reverse.take (i + 1) < pbytes.reverse.take (i + 1) := + List.take_succ_lt_of_take_lt (by simp_all) (by simp_all) hlt + steps + · exact List.le_of_lt this + · simp_all + · simp only [hlt] + apply STHoare.iteTrue_intro + rename bytes.reverse.take i ≤ pbytes.reverse.take i => hle + have heq : bytes.reverse.take i = pbytes.reverse.take i := by + rw [List.le_iff_lt_or_eq] at hle + tauto + have hi_lt_bytes : i < bytes.reverse.length := by simp_all [List.length_reverse] + have hi_lt_pbytes : i < pbytes.reverse.length := by simp_all [List.length_reverse] + steps + have hi_lt : i < bytes.length := by simp_all [List.length_reverse] + have hlen_eq : bytes.length = pbytes.length := by simp_all + have hlen32 : bytes.length < 2^32 := by simp_all + have hi32 : i < 2^32 := Nat.lt_trans hi_lt hlen32 + have hidx := U32.index_toNat bytes.length i hlen32 hi32 hi_lt + by_cases hi : bytes.reverse[i]'hi_lt_bytes = pbytes.reverse[i]'hi_lt_pbytes + · convert STHoare.iteFalse_intro _ + · simp [ + List.Vector.get, List.getElem_reverse, List.length_reverse, + h, hlen_eq + ] at hi ⊢ + simp_all [List.get_eq_getElem] + · rw [ + List.take_succ_eq_append_getElem hi_lt_bytes, + List.take_succ_eq_append_getElem hi_lt_pbytes, heq, hi + ] + steps + · apply List.le_refl + · congr + simp [List.le_refl] + · convert STHoare.iteTrue_intro _ + · simp only [ + List.Vector.get, List.getElem_reverse, List.length_reverse, h, hlen_eq + ] at hi ⊢ + simp_all [List.get_eq_getElem] + · steps 14 + rename_i hassert_lt + have hbyte_lt : bytes.reverse[i]'hi_lt_bytes < pbytes.reverse[i]'hi_lt_pbytes := by + simp only [List.getElem_reverse, h, List.length_reverse, hlen_eq] + simp only [List.Vector.get, List.get_eq_getElem] at hassert_lt + convert hassert_lt using 2 + simp_all + have : bytes.reverse.take (i + 1) < pbytes.reverse.take (i + 1) := + List.take_succ_lt_of_getElem_lt hi_lt_bytes hi_lt_pbytes heq hbyte_lt + steps + · exact List.le_of_lt this + · simp [this] + steps + rename decide _ = true => hlt_final + have hlen : bytes.length = pbytes.length := by simp_all + simp [hlen] at hlt_final + have hlt_full : bytes.reverse < pbytes.reverse := + List.lt_of_take_lt (by simp [hlen]) (by simp) hlt_final + have hpbytes_rev : pbytes.reverse = + List.map (fun (d : Digit ⟨256, by decide⟩) => BitVec.ofNatLT d.val d.prop) + (RadixVec.toDigitsBE' ⟨256, by decide⟩ p.natVal) := by + subst pbytes + simp only [ + RadixVec.toDigitsBE', RadixVec.of, List.do_pure_eq_map, List.map_map, + List.map_reverse, List.reverse_reverse + ] + congr 1 + funext x + simp [BitVec.ofNatLT, BitVec.ofFin] + have hlen_rev : bytes.reverse.length = pbytes.reverse.length := by + simp [List.length_reverse, hlen] + apply bytes_lt_of_lex_lt hlen_rev (hpbytes_rev ▸ hlt_full) hpbytes_rev + · + loop_inv nat fun _ _ _ => [ok ↦ ⟨_, true⟩] + · congr + simp only [BitVec.toNat_ne, BitVec.natCast_eq_ofNat, BitVec.ofNat_toNat] + simp_all + · simp + · intro _ _ _ + steps + apply STHoare.iteFalse_intro + steps + steps + have hlen_lt : bytes.length < pbytes.length := by + apply lt_of_le_of_ne + · simp_all + · assumption + have hpbytes_len : + pbytes.length = (RadixVec.toDigitsBE' ⟨256, by decide⟩ p.natVal).length := by + subst pbytes + simp [RadixVec.toDigitsBE', RadixVec.of, List.do_pure_eq_map] + apply ofDigitsBE'_lt_of_shorter_than_modulus (P := p) + simp only [ + List.length_map, List.length_reverse, List.Vector.toList_length, + hpbytes_len + ] at hlen_lt ⊢ + exact hlen_lt + steps + rotate_left + · rename_i v _ + subst_vars + simp + rw [ZMod.val_natCast] + apply lt_of_le_of_lt (Nat.mod_le _ _) + apply RadixVec.ofDigitsBE_lt + · rename_i hbound vDigits hvDigits + rename_i v hbytes hf + rw [hbytes] at hbound + have hbound' : RadixVec.ofDigitsBE v < p.natVal := by + simp [List.Vector.toList_reverse, BitVec.toFin_ofFin_comp 8] at hbound + rw [RadixVec.ofDigitsBE'_toList] at hbound + exact hbound + subst hvDigits hbytes hf + have hval_eq : + ZMod.val (↑↑(RadixVec.ofDigitsBE v) : ZMod p.natVal) = (RadixVec.ofDigitsBE v).val := by + rw [ZMod.val_natCast, Nat.mod_eq_of_lt hbound'] + have hlt256N : ZMod.val (↑↑(RadixVec.ofDigitsBE v) : ZMod p.natVal) < 256^N := by + rw [hval_eq] + exact (RadixVec.ofDigitsBE v).isLt + have hSubtype : + (⟨ZMod.val (↑↑(RadixVec.ofDigitsBE v) : ZMod p.natVal), hlt256N⟩ : + RadixVec ⟨256, by decide⟩ N) = RadixVec.ofDigitsBE v := by + ext + simp only [hval_eq] + simp only [hSubtype, RadixVec.toDigitsBE_ofDigitsBE, List.Vector.reverse_map] + +set_option maxHeartbeats 2000000 +theorem pow_32_intro {p self exponent} : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::pow_32».call h![] h![self, exponent]) + (fun r => ∃∃ hlt : exponent.val < 2^32, r = self ^ exponent.val) := by + enter_decl + steps [to_le_bits_intro] + simp [SLP.exists_pure] at * + rename_i hlt hb + set digits := + RadixVec.toDigitsBE (d := 32) (r := 2) ⟨exponent.val, hlt⟩ with hdigits + have hb_bits : b = (digits.map BitVec.ofFin).reverse := by + simpa [digits] using hb + have hb_digits : + b.reverse.map (fun i => (i.toFin : Digit 2)) = digits := by + apply List.Vector.eq + have hcomp : + ((fun (i : BitVec 1) => i.toFin) ∘ (BitVec.ofFin (w := 1))) = + (fun x : Fin (2^1) => x) := by + funext x + simp [Function.comp, BitVec.toFin_ofFin] + simp [ + hb_bits, List.Vector.reverse_map, List.Vector.toList_reverse, List.map_map, + hcomp + ] + have hb_digits_list : + b.toList.reverse.map (fun i => (i.toFin : Digit 2)) = digits.toList := by + simpa [List.Vector.toList_reverse] using + congrArg List.Vector.toList hb_digits + loop_inv nat fun i _ _ => + [r ↦ ⟨.field, self ^ (RadixVec.ofDigitsBE' (digits.toList.take (i - 1)))⟩] + · simp + · intro i hi_lo hhi + steps + · congr 1 + have hi_lo' : 1 ≤ i := by simpa using hi_lo + have hhi : i < 33 := by simpa using hhi + have hi_lt32 : i - 1 < 32 := by omega + have hi_le : i ≤ 32 := by exact Nat.lt_succ_iff.mp hhi + have hi_lt : i - 1 < digits.toList.length := by + simp [digits, List.Vector.toList_length, hi_lt32] + have hi_lt_rev : i - 1 < b.toList.reverse.length := by + simp [List.length_reverse, List.Vector.toList_length, hi_lt32] + have hmap : + (b.toList.reverse[i - 1]'hi_lt_rev).toFin = + (b.toList.reverse.map (fun i => (i.toFin : Digit 2)))[i - 1]'(by + simpa [List.length_map] using hi_lt_rev) := by + simp [ + (List.getElem_map_rev (f := fun i => (i.toFin : Digit 2)) + (l := b.toList.reverse) (n := i - 1) (h := hi_lt_rev)) + ] + have hidx : + (b.toList.reverse[i - 1]'hi_lt_rev).toFin = + digits.toList[i - 1]'hi_lt := by + simpa [hb_digits_list] using hmap + set a := RadixVec.ofDigitsBE' (digits.toList.take (i - 1)) with ha + have hindex_lt32 : 32 - i < 32 := by omega + have hindex_lt : 32 - i < b.toList.length := by + simpa [List.Vector.toList_length] using hindex_lt32 + have hmod : (4294967296 - i + 32) % 4294967296 = 32 - i := by + exact Nat.mod_sub_add_eq 4294967296 i 32 hi_le (by decide) + have hindex_fin : + (⟨(4294967296 - i + 32) % 4294967296, by + simp [hmod, hindex_lt32]⟩ : Fin 32) = + ⟨32 - i, hindex_lt32⟩ := by + apply Fin.ext + simp [hmod] + have htake : + digits.toList.take i = + digits.toList.take (i - 1) ++ [digits.toList[i - 1]'hi_lt] := by + have hi_eq : i = i - 1 + 1 := by + exact (Nat.sub_add_cancel hi_lo').symm + have htake' : + digits.toList.take (i - 1 + 1) = + digits.toList.take (i - 1) ++ [digits.toList[i - 1]'hi_lt] := by + exact List.take_succ_eq_append_getElem hi_lt + conv_lhs => rw [hi_eq] + exact htake' + have hdigits_take : + RadixVec.ofDigitsBE' (digits.toList.take i) = + RadixVec.ofDigitsBE' (digits.toList.take (i - 1)) * 2 + + digits.toList[i - 1]'hi_lt := by + rw [htake, RadixVec.ofDigitsBE'_append, RadixVec.ofDigitsBE'_cons] + have hradix : (↑(2 : Radix) : Nat) = 2 := rfl + simp [hradix] + have hpow2 : + self ^ (a * 2) = self ^ a * self ^ a := by + simpa [ha, pow_two] using (pow_mul self a 2) + have hbit_info (bit : U 1) + (hbit_rev : b.toList.reverse[i - 1]'hi_lt_rev = bit) : + (↑(BitVec.toNat (List.Vector.get b ⟨32 - i, hindex_lt32⟩)) : Fp p) = + (↑(BitVec.toNat bit) : Fp p) ∧ + digits.toList[i - 1]'hi_lt = (bit.toFin : Digit 2) := by + have hsub : 32 - 1 - (i - 1) = 32 - i := by + omega + have hbit_index : b.toList[32 - i]'hindex_lt = bit := by + simpa [ + List.getElem_reverse, List.length_reverse, List.Vector.toList_length, + hsub + ] using hbit_rev + have hbit_nat : + (↑(BitVec.toNat (List.Vector.get b ⟨32 - i, hindex_lt32⟩)) : Fp p) = + (↑(BitVec.toNat bit) : Fp p) := by + simpa [List.Vector.get, List.get_eq_getElem] using + (congrArg (fun x => (↑x.toNat : Fp p)) hbit_index) + have hbit_digit : digits.toList[i - 1]'hi_lt = (bit.toFin : Digit 2) := by + have hbit_fin : + (b.toList.reverse[i - 1]'hi_lt_rev).toFin = bit.toFin := by + simpa using congrArg (fun x => (x.toFin : Digit 2)) hbit_rev + exact hidx.symm.trans hbit_fin + exact ⟨hbit_nat, hbit_digit⟩ + by_cases hbit : b.toList.reverse[i - 1]'hi_lt_rev = 0 + · rcases hbit_info 0 hbit with ⟨hbit_nat, hbit_digit⟩ + have hdigits_zero : + RadixVec.ofDigitsBE' (digits.toList.take i) = + RadixVec.ofDigitsBE' (digits.toList.take (i - 1)) * 2 := by + simp [hdigits_take, hbit_digit] + calc + ↑(BitVec.toNat (List.Vector.get b ⟨(4294967296 - i + 32) % 4294967296, by + simpa [hmod, List.Vector.toList_length] using hindex_lt⟩)) * + (self ^ a * self ^ a * self) + + (1 - ↑(BitVec.toNat (List.Vector.get b ⟨(4294967296 - i + 32) % 4294967296, by + simpa [hmod, List.Vector.toList_length] using hindex_lt⟩))) * + (self ^ a * self ^ a) + = self ^ a * self ^ a := by + simp [hindex_fin, hbit_nat] + _ = self ^ (a * 2) := by + symm + exact hpow2 + _ = self ^ RadixVec.ofDigitsBE' (digits.toList.take i) := by + simp [ha, hdigits_zero] + · have hbit_rev : b.toList.reverse[i - 1]'hi_lt_rev = 1 := by + have := U.cases_one (b.toList.reverse[i - 1]'hi_lt_rev) + tauto + rcases hbit_info 1 hbit_rev with ⟨hbit_nat, hbit_digit⟩ + have hdigits_one : + RadixVec.ofDigitsBE' (digits.toList.take i) = + RadixVec.ofDigitsBE' (digits.toList.take (i - 1)) * 2 + 1 := by + have hmod1 : (1 % (2 : Nat)) = 1 := by + exact Nat.mod_eq_of_lt (by decide) + have hradix : (↑(2 : Radix) : Nat) = 2 := rfl + simp [hdigits_take, hbit_digit, hmod1, hradix] + calc + ↑(BitVec.toNat (List.Vector.get b ⟨(4294967296 - i + 32) % 4294967296, by + simpa [hmod, List.Vector.toList_length] using hindex_lt⟩)) * + (self ^ a * self ^ a * self) + + (1 - ↑(BitVec.toNat (List.Vector.get b ⟨(4294967296 - i + 32) % 4294967296, by + simpa [hmod, List.Vector.toList_length] using hindex_lt⟩))) * + (self ^ a * self ^ a) + = self ^ a * self ^ a * self := by + simp [hindex_fin, hbit_nat] + _ = self ^ (a * 2) * self := by + simp [hpow2] + _ = self ^ (a * 2 + 1) := by + simp [pow_add, pow_one] + _ = self ^ RadixVec.ofDigitsBE' (digits.toList.take i) := by + simp [ha, hdigits_one] + · + have htake32 : List.take 32 digits.toList = digits.toList := by + simp [List.Vector.toList_length, List.take_length (l := digits.toList)] + have hdigits_val : RadixVec.ofDigitsBE' digits.toList = exponent.val := by + have hdigits_eq : RadixVec.ofDigitsBE digits = ⟨exponent.val, hlt⟩ := by + simpa [hdigits] using (RadixVec.ofDigitsBE_toDigitsBE (n := ⟨exponent.val, hlt⟩)) + have := RadixVec.ofDigitsBE'_toList (l := digits) + simp [hdigits_eq, this] + have hpow_val : + self ^ RadixVec.ofDigitsBE' (List.take 32 digits.toList) = self ^ exponent.val := by + simp [htake32, hdigits_val] + have hlt32 : ZMod.val exponent < 4294967296 := by simpa [hlt] + refine STHoare.consequence + (H₁ := + [r ↦ ⟨Tp.field, self ^ RadixVec.ofDigitsBE' (List.take 32 digits.toList)⟩]) + (Q₁ := fun v => + ⟦v = self ^ RadixVec.ofDigitsBE' (List.take 32 digits.toList)⟧ ⋆ + [r ↦ ⟨Tp.field, self ^ RadixVec.ofDigitsBE' (List.take 32 digits.toList)⟩]) + ?_ ?_ ?_ + · exact SLP.entails_self + · intro v + simp [SLP.star_assoc] + apply SLP.pure_left + intro hv + apply SLP.pure_right + · refine And.intro hlt32 ?_ + calc + v = self ^ RadixVec.ofDigitsBE' (List.take 32 digits.toList) := hv + _ = self ^ exponent.val := hpow_val + · exact SLP.entails_top + · simpa using (STHoare.readRef_intro (p := p) (Γ := env) (r := r) + (tp := Tp.field) + (v := self ^ RadixVec.ofDigitsBE' (List.take 32 digits.toList))) + +theorem lt_intro {p self another} [Prime.BitsGT p 129] + (hmod : p.natVal = Lampe.Stdlib.Field.Bn254.ploNat + + Lampe.Stdlib.Field.Bn254.pow128 * Lampe.Stdlib.Field.Bn254.phiNat) : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::lt».call h![] h![self, another]) + (fun r => r = decide (self.val < another.val)) := by + enter_decl + steps [Lampe.Stdlib.Compat.is_bn254_spec] + apply STHoare.iteTrue_intro + steps [Lampe.Stdlib.Field.Bn254.lt_intro (p := p) (hmod := hmod)] + rename_i hlt + simp [hlt] + +theorem from_le_bytes_intro : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::from_le_bytes».call h![N] h![bytes]) + fun output => output = Fp.ofBytesLE (P := p) bytes.toList := by + rcases N with ⟨⟨N, hN⟩⟩ + enter_decl + steps + loop_inv nat fun i _ _ => + [v ↦ ⟨.field, (256 ^ i : Fp p)⟩] ⋆ + [result ↦ ⟨.field, Fp.ofBytesLE (P := p) (bytes.toList.take i)⟩] + · simp + · intro i _ hhi + steps + · congr 1 + conv at hhi => rhs; whnf + simp only [ + Lens.modify, BitVec.ofNat_eq_ofNat, BitVec.reduceToNat, Builtin.instCastTpUField, + Builtin.instCastTpU, BitVec.natCast_eq_ofNat, List.take_succ, getElem?, decidableGetElem?, + List.Vector.toList_length + ] + simp only [hhi, Fp.ofBytesLE, List.map_append, ofBaseLE_append] + have hi_le : i ≤ N := by linarith + have hi_mod : i % 4294967296 = i := by + apply Nat.mod_eq_of_lt + linarith [hi_le, hN] + simp [*, List.Vector.get, ofBaseLE] + rw [mul_comm] + rfl + steps + simp_all + rw [List.take_of_length_le] + · simp + +theorem from_be_bytes_intro : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::from_be_bytes».call h![N] h![bytes]) + fun output => output = Fp.ofBytesLE (P := p) bytes.toList.reverse := by + rcases N with ⟨⟨N, hN⟩⟩ + enter_decl + steps + loop_inv nat fun i _ _ => + [v ↦ ⟨.field, (256 ^ i : Fp p)⟩] ⋆ + [result ↦ ⟨.field, Fp.ofBytesLE (P := p) (bytes.toList.reverse.take i)⟩] + · simp + · intro i _ hhi + steps + · congr 1 + conv at hhi => rhs; whnf + simp only [ + Lens.modify, BitVec.ofNat_eq_ofNat, Builtin.instCastTpUField, Builtin.instCastTpU, + BitVec.natCast_eq_ofNat, List.take_succ, getElem?, decidableGetElem?, + List.Vector.toList_length, List.length_reverse + ] + simp only [hhi, Fp.ofBytesLE, List.map_append, ofBaseLE_append] + have hi_le : i ≤ N := by linarith + have hi_mod : i % 4294967296 = i := by + apply Nat.mod_eq_of_lt + linarith [hi_le, hN] + have hlen32 : N < 2^32 := by simp [hN] + have hi32 : i < 2^32 := Nat.lt_trans hhi hlen32 + have hidx := U32.index_toNat N i hlen32 hi32 hhi + simp_all [List.Vector.get, ofBaseLE, List.getElem_reverse] + rw [mul_comm] + rfl + steps + simp_all + rw [List.take_of_length_le] + · simp [List.length_reverse] + +theorem sgn0_intro : + STHoare p env ⟦⟧ + («std-1.0.0-beta.12::field::sgn0».call h![] h![f]) + (fun r => r = @Builtin.CastTp.cast Tp.field (Tp.u 1) _ p f) := by + enter_decl + simpa using + (Lampe.STHoare.cast_intro (p := p) (Γ := env) (tp := Tp.field) (tp' := Tp.u 1) (v := f)) diff --git a/stdlib/lampe/Stdlib/Ops/Bit.lean b/stdlib/lampe/Stdlib/Ops/Bit.lean index 731f2058..6830f4c2 100644 --- a/stdlib/lampe/Stdlib/Ops/Bit.lean +++ b/stdlib/lampe/Stdlib/Ops/Bit.lean @@ -29,7 +29,6 @@ theorem bool_not_spec {p a} resolve_trait steps simp_all - exact () set_option maxRecDepth 1300 in theorem u128_not_spec {p a} @@ -665,4 +664,3 @@ theorem i64_shr_spec {p a b} resolve_trait steps simp_all - diff --git a/stdlib/lampe/Stdlib/Option.lean b/stdlib/lampe/Stdlib/Option.lean index 742001c4..fdde8dc6 100644 --- a/stdlib/lampe/Stdlib/Option.lean +++ b/stdlib/lampe/Stdlib/Option.lean @@ -83,7 +83,6 @@ theorem is_none_spec {p T v} : STHoare p env ⟦⟧ steps subst_vars simp - exact () theorem is_some_spec {p T v} : STHoare p env ⟦⟧ («std-1.0.0-beta.12::option::Option::is_some».call h![T] h![v]) diff --git a/stdlib/lampe/Stdlib/Slice.lean b/stdlib/lampe/Stdlib/Slice.lean index 3eb647a8..f9249037 100644 --- a/stdlib/lampe/Stdlib/Slice.lean +++ b/stdlib/lampe/Stdlib/Slice.lean @@ -6,7 +6,7 @@ import Stdlib.TraitMethods namespace Lampe.Stdlib.Slice open «std-1.0.0-beta.12» -set_option maxRecDepth 10000 +set_option maxRecDepth 1000 set_option Lampe.pp.Expr true set_option Lampe.pp.STHoare true @@ -601,4 +601,3 @@ theorem any_pure_spec {p T Env l f fb fEmb} simp_all only [List.any_eq_true, Bool.decide_or, Bool.decide_eq_true, List.any_append, List.any_cons, List.any_nil, Bool.or_false] rw [←List.any_eq] - diff --git a/stdlib/lampe/Stdlib/Stdlib.lean b/stdlib/lampe/Stdlib/Stdlib.lean index 90d3665e..7d457152 100644 --- a/stdlib/lampe/Stdlib/Stdlib.lean +++ b/stdlib/lampe/Stdlib/Stdlib.lean @@ -1,5 +1,6 @@ import Stdlib.Append import Stdlib.Array.CheckShuffle +import Stdlib.Ext import Stdlib.Array.Mod import Stdlib.Cmp import Stdlib.Collections.BoundedVec diff --git a/stdlib/lampe/std-1.0.0-beta.12/Extracted/Array/Mod.lean b/stdlib/lampe/std-1.0.0-beta.12/Extracted/Array/Mod.lean index c6977d49..7dbd84f7 100644 --- a/stdlib/lampe/std-1.0.0-beta.12/Extracted/Array/Mod.lean +++ b/stdlib/lampe/std-1.0.0-beta.12/Extracted/Array/Mod.lean @@ -65,7 +65,7 @@ noir_def «std-1.0.0-beta.12»::array::reduce(self: accumulator } -noir_def «std-1.0.0-beta.12»::array::all(self: Array, predicate: λ(T) -> bool) -> bool := { +noir_def «std-1.0.0-beta.12»::array::«all»(self: Array, predicate: λ(T) -> bool) -> bool := { let mut ret = #_true; { let ζi0 = self; diff --git a/stdlib/lampe/std-1.0.0-beta.12/Extracted/Slice.lean b/stdlib/lampe/std-1.0.0-beta.12/Extracted/Slice.lean index 7126c583..d492d858 100644 --- a/stdlib/lampe/std-1.0.0-beta.12/Extracted/Slice.lean +++ b/stdlib/lampe/std-1.0.0-beta.12/Extracted/Slice.lean @@ -147,7 +147,7 @@ noir_def «std-1.0.0-beta.12»::slice::join(self: Slice, separator: ret } -noir_def «std-1.0.0-beta.12»::slice::all(self: Slice, predicate: λ(T) -> bool) -> bool := { +noir_def «std-1.0.0-beta.12»::slice::«all»(self: Slice, predicate: λ(T) -> bool) -> bool := { let mut ret = #_true; { let ζi0 = self;