diff --git a/cedar-lean/Protobuf/BParsec.lean b/cedar-lean/Protobuf/BParsec.lean index 49565087..32ee5b51 100644 --- a/cedar-lean/Protobuf/BParsec.lean +++ b/cedar-lean/Protobuf/BParsec.lean @@ -104,15 +104,21 @@ def run! [Inhabited α] (p : BParsec α) (ba : ByteArray) : α := -- Iterator wrappers -/-- Advance the iterator -/ +/-- Advance the iterator one byte, discarding it -/ @[inline] def next : BParsec Unit := λ pos => { pos := pos.next, res := .ok () } +/-- Advance the iterator `n` bytes, discarding them -/ @[inline] def forward (n : Nat) : BParsec Unit := λ pos => { pos := pos.forward n, res := .ok () } +/-- Advance the iterator one byte, returning it, or `None` if the iterator was empty -/ +@[inline] +def nextByte : BParsec (Option UInt8) := λ pos => + { pos := pos.next, res := .ok pos.data[pos.pos]? } + /-- Return some computation on the current iterator state, without changing the state -/ @[inline] def inspect (f : ByteArray.ByteIterator → α) : BParsec α := λ pos => @@ -130,6 +136,7 @@ def remaining : BParsec Nat := inspect ByteArray.ByteIterator.remaining @[inline] def empty : BParsec Bool := inspect ByteArray.ByteIterator.empty +/-- Get the current iterator position, as a `Nat` -/ @[inline] def pos : BParsec Nat := inspect ByteArray.ByteIterator.pos diff --git a/cedar-lean/Protobuf/Proofs.lean b/cedar-lean/Protobuf/Proofs.lean index 2a9986e0..2397f090 100644 --- a/cedar-lean/Protobuf/Proofs.lean +++ b/cedar-lean/Protobuf/Proofs.lean @@ -162,13 +162,13 @@ theorem foldl_iterator_progress {f : BParsec α} {g : β → α → β} {remaini unfold foldlHelper at H have H2 : ¬(ni = 0) := by omega rw [if_neg H2] at H - simp only [Bind.bind, bind, pos, inspect] at H + simp only [Bind.bind, bind, pos, inspect, throw_eq_fail] at H cases H3 : f pos₀ ; simp only [H3] at H ; rename_i pos₂ res₂ cases res₂ <;> simp only [ParseResult.mk.injEq, reduceCtorEq, and_false] at H case ok res₂ => by_cases H4 : (pos₂.pos - pos₀.pos = 0) case pos => - simp only [H4, reduceIte, throw_eq_fail, fail, ParseResult.mk.injEq, reduceCtorEq, and_false] at H + simp only [H4, reduceIte, fail, ParseResult.mk.injEq, reduceCtorEq, and_false] at H case neg => simp only [H4, reduceIte] at H let ni2 := ni - (pos₂.pos - pos₀.pos) @@ -194,65 +194,146 @@ namespace Proto instance : DecidableEq (BParsec.ParseResult (Char × Nat)) := by apply inferInstance -theorem utf8DecodeChar.sizeGt0 {pos₀ pos₁ : ByteArray.ByteIterator} {i n : Nat} {c : Char} - (H : utf8DecodeChar i pos₀ = { pos := pos₁, res := .ok ⟨c, n⟩ }) : - n > 0 +/-- Proof that `BParsec.nextByte` always progresses the iterator exactly 1 byte if it succeeds -/ +theorem BParsec.nextByte.progress {pos₀ pos₁ : ByteArray.ByteIterator} {u : UInt8} + (h₀ : BParsec.nextByte pos₀ = { pos := pos₁, res := .ok (some u) }) : + pos₁.pos = pos₀.pos + 1 := by - unfold utf8DecodeChar at H - simp only [bind, BParsec.bind, BParsec.inspect, beq_iff_eq, pure, bne_iff_ne, ne_eq, - BParsec.throw_eq_fail, gt_iff_lt, ite_not, Bool.and_eq_true, not_and, and_imp] at H - split at H - · simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq, Prod.mk.injEq] at H - omega - · split at H - · simp only [BParsec.bind, BParsec.inspect] at H - split at H - · split at H - · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H - · split at H - · simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq, - Prod.mk.injEq] at H - omega - · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H - · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H - · split at H - · simp only [BParsec.bind, BParsec.inspect] at H - split at H - · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H - · split at H - · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H - · split at H - · simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq, - Prod.mk.injEq] at H - omega - · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H - · split at H - · simp only [BParsec.bind, BParsec.inspect] at H - split at H - · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H - · split at H - · simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq, - Prod.mk.injEq] at H - omega - · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H - · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H - -/-- Uglier version of `parseStringHelper` which is functionally equivalent to -`parseStringHelper`, but has a termination proof, unlike `parseStringHelper`. --/ -private def parseStringHelper_unoptimized (remaining : Nat) (r : String) : BParsec String := do + simp only [BParsec.nextByte, ByteArray.ByteIterator.next, BParsec.ParseResult.mk.injEq, + Except.ok.injEq] at h₀ + have ⟨h₀, _⟩ := h₀ + subst pos₁ + simp only + +/-- Proof that `utf8DecodeChar` always progresses the iterator at least 1 byte if it succeeds -/ +theorem utf8DecodeChar.progress {pos₀ pos₁ : ByteArray.ByteIterator} {c : Char} + (h₀ : utf8DecodeChar pos₀ = { pos := pos₁, res := .ok c }) : + pos₁.pos > pos₀.pos +:= by + revert h₀ + unfold utf8DecodeChar + simp only [bind, BParsec.bind, BParsec.throw_eq_fail, beq_iff_eq, pure, + bne_iff_ne, ne_eq, gt_iff_lt, ite_not, Bool.and_eq_true, not_and, and_imp] + cases h₁ : BParsec.nextByte pos₀ + case mk pos₂ res => + cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, false_implies] + case ok opt => + cases opt <;> simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + case some c₀ => + replace h₁ := BParsec.nextByte.progress h₁ + split + · simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq, and_imp] + intro h₀ ; subst pos₂ + simp only [h₁, Nat.lt_add_one, implies_true] + · split + · simp only [BParsec.bind] + cases h₂ : BParsec.nextByte pos₂ + case mk pos₃ res => + cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + case ok opt => + cases opt <;> simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, + and_false, false_implies] + case some c₁ => + replace h₂ := BParsec.nextByte.progress h₂ + split + · split + simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + split + · simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq, and_imp] + intro h₀ ; subst pos₃ + simp only [h₁, h₂] + omega + · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + · split + · simp only [BParsec.bind] + cases h₂ : BParsec.nextByte pos₂ + case mk pos₃ res => + cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + case ok opt₁ => + cases h₃ : BParsec.nextByte pos₃ + case mk pos₄ res => + cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + case ok opt₂ => + split + · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + · split + · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + · replace h₂ := BParsec.nextByte.progress h₂ + replace h₃ := BParsec.nextByte.progress h₃ + split + · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, + and_false, false_implies] + · split + · simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq, + and_imp] + intro h₀ ; subst pos₄ + omega + · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, + and_false, false_implies] + · split + · simp only [BParsec.bind] + cases h₂ : BParsec.nextByte pos₂ + case mk pos₃ res => + cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + case ok opt₁ => + cases h₃ : BParsec.nextByte pos₃ + case mk pos₄ res => + cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + case ok opt₃ => + cases h₄ : BParsec.nextByte pos₄ + case mk pos₅ res => + cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + case ok opt₄ => + split + · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, + and_false, false_implies] + · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, + and_false, false_implies] + · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, + and_false, false_implies] + · split + · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, + and_false, false_implies] + · replace h₂ := BParsec.nextByte.progress h₂ + replace h₃ := BParsec.nextByte.progress h₃ + replace h₄ := BParsec.nextByte.progress h₄ + split + · simp [h₂, h₃, h₄] + intro h₀ ; subst pos₅ + omega + · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, + and_false, false_implies] + · simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, + false_implies] + +/-- Restating the definition of `parseStringHelper`, but now thanks to the above +theorem, we can prove termination for `parseStringHelper` -/ +def parseStringHelper' (remaining : Nat) (r : String) : BParsec String := do if remaining = 0 then pure r else let empty ← BParsec.empty if empty then throw s!"Expected more packed uints, Size Remaining: {remaining}" else - let pos ← BParsec.pos λ pos₀ => - let result := utf8DecodeChar pos pos₀ - match H : result with - | { pos := pos₁, res := .ok ⟨c, elementSize⟩ } => - have _ : elementSize > 0 := utf8DecodeChar.sizeGt0 H - (do - BParsec.forward (elementSize) - parseStringHelper_unoptimized (remaining - elementSize) (r.push c)) pos₀ - | { pos := pos₁, res := .error msg } => { pos := pos₁, res := .error msg } + match h₀ : utf8DecodeChar pos₀ with + | { pos := pos₁, res := .ok c } => + have : pos₁.pos > pos₀.pos := utf8DecodeChar.progress h₀ + let elementSize := pos₁.pos - pos₀.pos + parseStringHelper' (remaining - elementSize) (r.push c) pos₁ + | { pos := pos₁, res := .error e } => { pos := pos₁, res := .error e } +termination_by remaining end Proto diff --git a/cedar-lean/Protobuf/String.lean b/cedar-lean/Protobuf/String.lean index 1f227868..53838142 100644 --- a/cedar-lean/Protobuf/String.lean +++ b/cedar-lean/Protobuf/String.lean @@ -25,57 +25,67 @@ Decode UTF-8 encoded strings with ByteArray Parser Combinators namespace Proto --- NOTE: Will panic if there's not enough bytes to determine the next character --- NOTE: Does not progress iterator --- Returns the size of the character as well +/-- + Decodes a UTF8 Char from the iterator, advancing it appropriately, and + throwing (not panicking) if the byte sequence at the iterator's current position + is invalid UTF8 or not long enough +-/ @[inline] -def utf8DecodeChar (i : Nat) : BParsec (Char × Nat) := do - let c ← BParsec.inspect λ pos => pos.data[i]! - if c &&& 0x80 == 0 then - let char := ⟨c.toUInt32, .inl (Nat.lt_trans c.1.2 (by decide))⟩ - pure ⟨char, 1⟩ - else if c &&& 0xe0 == 0xc0 then - let c1 ← BParsec.inspect λ pos => pos.data[i+1]! - if c1 &&& 0xc0 != 0x80 then throw s!"Not a valid UTF8 Char: {c} {c1}" else - let r := ((c &&& 0x1f).toUInt32 <<< 6) ||| (c1 &&& 0x3f).toUInt32 - if 0x80 > r then throw s!"Not a valid UTF8 Char: {c} {c1}" else +def utf8DecodeChar : BParsec Char := do + let c₀ ← BParsec.nextByte + match c₀ with + | none => throw "Not enough bytes for UTF8 Char" + | some c₀ => + if c₀ &&& 0x80 == 0 then + pure ⟨c₀.toUInt32, .inl (Nat.lt_trans c₀.1.2 (by decide))⟩ + else if c₀ &&& 0xe0 == 0xc0 then + let c₁ ← BParsec.nextByte + match c₁ with + | none => throw "Not enough bytes for UTF8 Char" + | some c₁ => + if c₁ &&& 0xc0 != 0x80 then throw s!"Not a valid UTF8 Char: {c₀} {c₁}" else + let r := ((c₀ &&& 0x1f).toUInt32 <<< 6) ||| (c₁ &&& 0x3f).toUInt32 + if 0x80 > r then throw s!"Not a valid UTF8 Char: {c₀} {c₁}" else if h : r < 0xd800 then - let char := ⟨r, .inl h⟩ - pure ⟨char, 2⟩ - else throw s!"Not valid UTF8 Char: {c} {c1}" - else if c &&& 0xf0 == 0xe0 then - let c1 ← BParsec.inspect λ pos => pos.data[i+1]! - let c2 ← BParsec.inspect λ pos => pos.data[i+2]! - if ¬(c1 &&& 0xc0 == 0x80 && c2 &&& 0xc0 == 0x80) then - throw s!"Not a valid UTF8 Char: {c} {c1} {c2}" + pure ⟨r, .inl h⟩ + else throw s!"Not valid UTF8 Char: {c₀} {c₁}" + else if c₀ &&& 0xf0 == 0xe0 then + let c₁ ← BParsec.nextByte + let c₂ ← BParsec.nextByte + match c₁, c₂ with + | none, _ | _, none => throw "Not enough bytes for UTF8 Char" + | some c₁, some c₂ => + if ¬(c₁ &&& 0xc0 == 0x80 && c₂ &&& 0xc0 == 0x80) then + throw s!"Not a valid UTF8 Char: {c₀} {c₁} {c₂}" else let r := - ((c &&& 0x0f).toUInt32 <<< 12) ||| - ((c1 &&& 0x3f).toUInt32 <<< 6) ||| - (c2 &&& 0x3f).toUInt32 - if (0x800 > r) then throw s!"Not a valid UTF8 Char: {c} {c1} {c2}" else + ((c₀ &&& 0x0f).toUInt32 <<< 12) ||| + ((c₁ &&& 0x3f).toUInt32 <<< 6) ||| + (c₂ &&& 0x3f).toUInt32 + if (0x800 > r) then throw s!"Not a valid UTF8 Char: {c₀} {c₁} {c₂}" else if h : r < 0xd800 ∨ 0xdfff < r ∧ r < 0x110000 then - let char := ⟨r, h⟩ - pure ⟨char, 3⟩ - else throw s!"Not valid UTF8 Char: {c} {c1} {c2}" - else if c &&& 0xf8 == 0xf0 then - let c1 ← BParsec.inspect λ pos => pos.data[i+1]! - let c2 ← BParsec.inspect λ pos => pos.data[i+2]! - let c3 ← BParsec.inspect λ pos => pos.data[i+3]! - if ¬(c1 &&& 0xc0 == 0x80 && c2 &&& 0xc0 == 0x80 && c3 &&& 0xc0 == 0x80) then - throw s!"Not a valid UTF8 Char: {c} {c1} {c2} {c3}" + pure ⟨r, h⟩ + else throw s!"Not valid UTF8 Char: {c₀} {c₁} {c₂}" + else if c₀ &&& 0xf8 == 0xf0 then + let c₁ ← BParsec.nextByte + let c₂ ← BParsec.nextByte + let c₃ ← BParsec.nextByte + match c₁, c₂, c₃ with + | none, _, _ | _, none, _ | _, _, none => throw "Not enough bytes for UTF8 Char" + | some c₁, some c₂, some c₃ => + if ¬(c₁ &&& 0xc0 == 0x80 && c₂ &&& 0xc0 == 0x80 && c₃ &&& 0xc0 == 0x80) then + throw s!"Not a valid UTF8 Char: {c₀} {c₁} {c₂} {c₃}" else let r := - ((c &&& 0x07).toUInt32 <<< 18) ||| - ((c1 &&& 0x3f).toUInt32 <<< 12) ||| - ((c2 &&& 0x3f).toUInt32 <<< 6) ||| - (c3 &&& 0x3f).toUInt32 + ((c₀ &&& 0x07).toUInt32 <<< 18) ||| + ((c₁ &&& 0x3f).toUInt32 <<< 12) ||| + ((c₂ &&& 0x3f).toUInt32 <<< 6) ||| + (c₃ &&& 0x3f).toUInt32 if h : 0x10000 ≤ r ∧ r < 0x110000 then - let char := ⟨r, .inr ⟨Nat.lt_of_lt_of_le (by decide) h.1, h.2⟩⟩ - pure ⟨char, 4⟩ - else throw s!"Not valid UTF8 Char: {c} {c1} {c2} {c3}" + pure ⟨r, .inr ⟨Nat.lt_of_lt_of_le (by decide) h.1, h.2⟩⟩ + else throw s!"Not valid UTF8 Char: {c₀} {c₁} {c₂} {c₃}" else - throw s!"Not valid UTF8 Char: {c}" + throw s!"Not valid UTF8 Char: {c₀}" -- Progresses ByteArray.Iterator @@ -84,9 +94,10 @@ partial def parseStringHelper (remaining : Nat) (r : String) : BParsec String := if remaining = 0 then pure r else let empty ← BParsec.empty if empty then throw s!"Expected more packed uints, Size Remaining: {remaining}" else - let pos ← BParsec.pos - let ⟨c, elementSize⟩ ← utf8DecodeChar pos - BParsec.forward (elementSize) + let start_pos ← BParsec.pos + let c ← utf8DecodeChar + let end_pos ← BParsec.pos + let elementSize := end_pos - start_pos parseStringHelper (remaining - elementSize) (r.push c) @[inline] diff --git a/cedar-lean/Protobuf/Varint.lean b/cedar-lean/Protobuf/Varint.lean index f0b64fe4..51c3a4db 100644 --- a/cedar-lean/Protobuf/Varint.lean +++ b/cedar-lean/Protobuf/Varint.lean @@ -37,6 +37,11 @@ namespace Proto -- Does not progress iterator -- Has panic! indexing, should work towards adding needed proof +-- Probably the way to remove the panic! indexing would be to reorganize this module. +-- Instead of first searching to find the number of bytes we'll need to parse, +-- without progressing the iterator (`find_varint_size`), and then actually +-- parsing them and progressing the iterator (`parse_uint64` and friends), we +-- should do everything in one pass that progresses the iterator as it goes. private def find_end_of_varint_helper (n : Nat) : BParsec Nat := do let empty ← BParsec.empty if empty then throw "Expected more bytes" @@ -65,16 +70,15 @@ def find_varint_size : BParsec Nat := do pure (end_idx - start_idx) --- Note: Panic indexing used but may be able to remove with some work private def parse_uint64_helper (remaining : Nat) (p : Nat) (r : UInt64) : BParsec UInt64 := do if remaining = 0 then pure r else - let empty ← BParsec.empty - if empty then throw "Expected more bytes" else - let byte ← BParsec.inspect λ pos => pos.data[pos.pos]! - BParsec.next -- Progress iterator - have byte2 := clear_msb8 byte - have byte3 := byte2.toUInt64 <<< (7 * p.toUInt64) - parse_uint64_helper (remaining - 1) (p + 1) (r ||| byte3) + let byte ← BParsec.nextByte + match byte with + | none => throw "Expected more bytes" + | some byte => + have byte2 := clear_msb8 byte + have byte3 := byte2.toUInt64 <<< (7 * p.toUInt64) + parse_uint64_helper (remaining - 1) (p + 1) (r ||| byte3) @[inline] @@ -91,13 +95,13 @@ instance : Field UInt64 := { private def parse_uint32_helper (remaining : Nat) (p : Nat) (r : UInt32) : BParsec UInt32 := do if remaining = 0 then pure r else - let empty ← BParsec.empty -- NOTE: Might be able to remove if we add a hypotheses in the definition - if empty then throw "Expected more bytes" else - let byte ← BParsec.inspect λ pos => pos.data[pos.pos]! - BParsec.next -- Progress iterator - have byte2 := clear_msb8 byte - have byte3 := byte2.toUInt32 <<< (7 * p.toUInt32) - parse_uint32_helper (remaining - 1) (p + 1) (r ||| byte3) + let byte ← BParsec.nextByte + match byte with + | none => throw "Expected more bytes" + | some byte => + have byte2 := clear_msb8 byte + have byte3 := byte2.toUInt32 <<< (7 * p.toUInt32) + parse_uint32_helper (remaining - 1) (p + 1) (r ||| byte3) @[inline]