Skip to content

Commit

Permalink
remove most of the panics and panic-indexing from Protobuf module
Browse files Browse the repository at this point in the history
Signed-off-by: Craig Disselkoen <[email protected]>
  • Loading branch information
cdisselkoen committed Nov 27, 2024
1 parent 99a9c77 commit cbe3eeb
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 118 deletions.
9 changes: 8 additions & 1 deletion cedar-lean/Protobuf/BParsec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,21 @@ def run! [Inhabited α] (p : BParsec α) (ba : ByteArray) : α :=

-- Iterator wrappers

/-- Advance the iterator one byte, discarding it. The parse result is always `.ok ()`. -/
@[inline]
def next : BParsec Unit := λ pos =>
  { pos := pos.next, res := .ok () }

/-- Skip ahead `n` bytes without looking at them. The parse result is always `.ok ()`. -/
@[inline]
def forward (n : Nat) : BParsec Unit := fun it =>
  { res := .ok (), pos := it.forward n }

/--
Read the byte at the iterator's current index and step the iterator one byte.

Yields `some b` when a byte is present at the current index and `none` when the
index is out of bounds. The iterator's `next` is applied in both cases, and the
parse result itself is always `.ok`.
-/
@[inline]
def nextByte : BParsec (Option UInt8) := fun it =>
  { res := .ok it.data[it.pos]?, pos := it.next }

/-- Return some computation on the current iterator state, without changing the state -/
@[inline]
def inspect (f : ByteArray.ByteIterator → α) : BParsec α := λ pos =>
Expand All @@ -130,6 +136,7 @@ def remaining : BParsec Nat := inspect ByteArray.ByteIterator.remaining
@[inline]
def empty : BParsec Bool := inspect ByteArray.ByteIterator.empty

/--
Get the current iterator position, as a `Nat`.

Implemented through `inspect`, so the iterator state is read but not changed.
-/
@[inline]
def pos : BParsec Nat := inspect ByteArray.ByteIterator.pos

Expand Down
195 changes: 138 additions & 57 deletions cedar-lean/Protobuf/Proofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,13 @@ theorem foldl_iterator_progress {f : BParsec α} {g : β → α → β} {remaini
unfold foldlHelper at H
have H2 : ¬(ni = 0) := by omega
rw [if_neg H2] at H
simp only [Bind.bind, bind, pos, inspect] at H
simp only [Bind.bind, bind, pos, inspect, throw_eq_fail] at H
cases H3 : f pos₀ ; simp only [H3] at H ; rename_i pos₂ res₂
cases res₂ <;> simp only [ParseResult.mk.injEq, reduceCtorEq, and_false] at H
case ok res₂ =>
by_cases H4 : (pos₂.pos - pos₀.pos = 0)
case pos =>
simp only [H4, reduceIte, throw_eq_fail, fail, ParseResult.mk.injEq, reduceCtorEq, and_false] at H
simp only [H4, reduceIte, fail, ParseResult.mk.injEq, reduceCtorEq, and_false] at H
case neg =>
simp only [H4, reduceIte] at H
let ni2 := ni - (pos₂.pos - pos₀.pos)
Expand All @@ -194,65 +194,146 @@ namespace Proto

-- Decidable equality on parse results whose payload is a `(Char × Nat)` pair,
-- obtained by instance resolution.
-- NOTE(review): `utf8DecodeChar` now returns `BParsec Char` rather than
-- `BParsec (Char × Nat)`; confirm this instance is still needed at this type.
instance : DecidableEq (BParsec.ParseResult (Char × Nat)) := by apply inferInstance

theorem utf8DecodeChar.sizeGt0 {pos₀ pos₁ : ByteArray.ByteIterator} {i n : Nat} {c : Char}
(H : utf8DecodeChar i pos₀ = { pos := pos₁, res := .ok ⟨c, n⟩ }) :
n > 0
/-- Proof that `BParsec.nextByte` always progresses the iterator exactly 1 byte if it succeeds -/
theorem BParsec.nextByte.progress {pos₀ pos₁ : ByteArray.ByteIterator} {u : UInt8}
(h₀ : BParsec.nextByte pos₀ = { pos := pos₁, res := .ok (some u) }) :
pos₁.pos = pos₀.pos + 1
:= by
unfold utf8DecodeChar at H
simp only [bind, BParsec.bind, BParsec.inspect, beq_iff_eq, pure, bne_iff_ne, ne_eq,
BParsec.throw_eq_fail, gt_iff_lt, ite_not, Bool.and_eq_true, not_and, and_imp] at H
split at H
· simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq, Prod.mk.injEq] at H
omega
· split at H
· simp only [BParsec.bind, BParsec.inspect] at H
split at H
· split at H
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H
· split at H
· simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq,
Prod.mk.injEq] at H
omega
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H
· split at H
· simp only [BParsec.bind, BParsec.inspect] at H
split at H
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H
· split at H
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H
· split at H
· simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq,
Prod.mk.injEq] at H
omega
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H
· split at H
· simp only [BParsec.bind, BParsec.inspect] at H
split at H
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H
· split at H
· simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq,
Prod.mk.injEq] at H
omega
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false] at H

/-- Uglier version of `parseStringHelper` which is functionally equivalent to
`parseStringHelper`, but has a termination proof, unlike `parseStringHelper`.
-/
private def parseStringHelper_unoptimized (remaining : Nat) (r : String) : BParsec String := do
simp only [BParsec.nextByte, ByteArray.ByteIterator.next, BParsec.ParseResult.mk.injEq,
Except.ok.injEq] at h₀
have ⟨h₀, _⟩ := h₀
subst pos₁
simp only

/--
Proof that `utf8DecodeChar` always progresses the iterator at least 1 byte if it succeeds.

Given `utf8DecodeChar pos₀ = { pos := pos₁, res := .ok c }`, concludes
`pos₁.pos > pos₀.pos`.

Proof outline: unfold `utf8DecodeChar`, case on the result of every
`BParsec.nextByte` call and `split` on every `if`. Branches that end in a
failure cannot equal an `.ok` result and are discharged by `simp only`.
Branches that succeed use `BParsec.nextByte.progress` for the bytes read and
finish with `simp`/`omega`.
-/
theorem utf8DecodeChar.progress {pos₀ pos₁ : ByteArray.ByteIterator} {c : Char}
(h₀ : utf8DecodeChar pos₀ = { pos := pos₁, res := .ok c }) :
pos₁.pos > pos₀.pos
:= by
revert h₀
unfold utf8DecodeChar
simp only [bind, BParsec.bind, BParsec.throw_eq_fail, beq_iff_eq, pure,
bne_iff_ne, ne_eq, gt_iff_lt, ite_not, Bool.and_eq_true, not_and, and_imp]
-- Case on the first `nextByte`; `pos₂` is the iterator after it.
cases h₁ : BParsec.nextByte pos₀
case mk pos₂ res =>
cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false, false_implies]
case ok opt =>
cases opt <;> simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]
case some c₀ =>
-- From here on `h₁ : pos₂.pos = pos₀.pos + 1`.
replace h₁ := BParsec.nextByte.progress h₁
split
-- First branch: success right after the first byte, so `h₁` alone closes the goal.
· simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq, and_imp]
intro h₀ ; subst pos₂
simp only [h₁, Nat.lt_add_one, implies_true]
· split
-- Branch that reads one more byte (iterator `pos₃`).
· simp only [BParsec.bind]
cases h₂ : BParsec.nextByte pos₂
case mk pos₃ res =>
cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]
case ok opt =>
cases opt <;> simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq,
and_false, false_implies]
case some c₁ =>
replace h₂ := BParsec.nextByte.progress h₂
split
· split
simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]
split
· simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq, and_imp]
intro h₀ ; subst pos₃
simp only [h₁, h₂]
omega
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]
· split
-- Branch that reads two more bytes (iterators `pos₃`, `pos₄`).
· simp only [BParsec.bind]
cases h₂ : BParsec.nextByte pos₂
case mk pos₃ res =>
cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]
case ok opt₁ =>
cases h₃ : BParsec.nextByte pos₃
case mk pos₄ res =>
cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]
case ok opt₂ =>
split
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]
· split
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]
· replace h₂ := BParsec.nextByte.progress h₂
replace h₃ := BParsec.nextByte.progress h₃
split
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq,
and_false, false_implies]
· split
· simp only [BParsec.pure, BParsec.ParseResult.mk.injEq, Except.ok.injEq,
and_imp]
intro h₀ ; subst pos₄
omega
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq,
and_false, false_implies]
· split
-- Branch that reads three more bytes (iterators `pos₃`, `pos₄`, `pos₅`).
· simp only [BParsec.bind]
cases h₂ : BParsec.nextByte pos₂
case mk pos₃ res =>
cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]
case ok opt₁ =>
cases h₃ : BParsec.nextByte pos₃
case mk pos₄ res =>
cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]
case ok opt₃ =>
cases h₄ : BParsec.nextByte pos₄
case mk pos₅ res =>
cases res <;> simp only [BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]
case ok opt₄ =>
split
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq,
and_false, false_implies]
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq,
and_false, false_implies]
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq,
and_false, false_implies]
· split
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq,
and_false, false_implies]
· replace h₂ := BParsec.nextByte.progress h₂
replace h₃ := BParsec.nextByte.progress h₃
replace h₄ := BParsec.nextByte.progress h₄
split
· simp [h₂, h₃, h₄]
intro h₀ ; subst pos₅
omega
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq,
and_false, false_implies]
-- Last remaining branch: closed as a failure, so there is no success to account for.
· simp only [BParsec.fail, BParsec.ParseResult.mk.injEq, reduceCtorEq, and_false,
false_implies]

/-- Restating the definition of `parseStringHelper`, but now thanks to the above
theorem, we can prove termination for `parseStringHelper` -/
def parseStringHelper' (remaining : Nat) (r : String) : BParsec String := do
if remaining = 0 then pure r else
let empty ← BParsec.empty
if empty then throw s!"Expected more packed uints, Size Remaining: {remaining}" else
let pos ← BParsec.pos
λ pos₀ =>
let result := utf8DecodeChar pos pos₀
match H : result with
| { pos := pos₁, res := .ok ⟨c, elementSize⟩ } =>
have _ : elementSize > 0 := utf8DecodeChar.sizeGt0 H
(do
BParsec.forward (elementSize)
parseStringHelper_unoptimized (remaining - elementSize) (r.push c)) pos₀
| { pos := pos₁, res := .error msg } => { pos := pos₁, res := .error msg }
match h₀ : utf8DecodeChar pos₀ with
| { pos := pos₁, res := .ok c } =>
have : pos₁.pos > pos₀.pos := utf8DecodeChar.progress h₀
let elementSize := pos₁.pos - pos₀.pos
parseStringHelper' (remaining - elementSize) (r.push c) pos₁
| { pos := pos₁, res := .error e } => { pos := pos₁, res := .error e }
termination_by remaining

end Proto
101 changes: 56 additions & 45 deletions cedar-lean/Protobuf/String.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,57 +25,67 @@ Decode UTF-8 encoded strings with ByteArray Parser Combinators

namespace Proto

-- NOTE: Will panic if there's not enough bytes to determine the next character
-- NOTE: Does not progress iterator
-- Returns the size of the character as well
/--
Decodes a UTF8 Char from the iterator, advancing it appropriately, and
throwing (not panicking) if the byte sequence at the iterator's current position
is invalid UTF8 or not long enough
-/
@[inline]
def utf8DecodeChar (i : Nat) : BParsec (Char × Nat) := do
let c ← BParsec.inspect λ pos => pos.data[i]!
if c &&& 0x80 == 0 then
let char := ⟨c.toUInt32, .inl (Nat.lt_trans c.1.2 (by decide))⟩
pure ⟨char, 1
else if c &&& 0xe0 == 0xc0 then
let c1 ← BParsec.inspect λ pos => pos.data[i+1]!
if c1 &&& 0xc0 != 0x80 then throw s!"Not a valid UTF8 Char: {c} {c1}" else
let r := ((c &&& 0x1f).toUInt32 <<< 6) ||| (c1 &&& 0x3f).toUInt32
if 0x80 > r then throw s!"Not a valid UTF8 Char: {c} {c1}" else
def utf8DecodeChar : BParsec Char := do
let c₀ ← BParsec.nextByte
match c₀ with
| none => throw "Not enough bytes for UTF8 Char"
| some c₀ =>
if c₀ &&& 0x80 == 0 then
pure ⟨c₀.toUInt32, .inl (Nat.lt_trans c₀.1.2 (by decide))⟩
else if c₀ &&& 0xe0 == 0xc0 then
let c₁ ← BParsec.nextByte
match c₁ with
| none => throw "Not enough bytes for UTF8 Char"
| some c₁ =>
if c₁ &&& 0xc0 != 0x80 then throw s!"Not a valid UTF8 Char: {c₀} {c₁}" else
let r := ((c₀ &&& 0x1f).toUInt32 <<< 6) ||| (c₁ &&& 0x3f).toUInt32
if 0x80 > r then throw s!"Not a valid UTF8 Char: {c₀} {c₁}" else
if h : r < 0xd800 then
let char := ⟨r, .inl h⟩
pure ⟨char, 2
else throw s!"Not valid UTF8 Char: {c} {c1}"
else if c &&& 0xf0 == 0xe0 then
let c1 ← BParsec.inspect λ pos => pos.data[i+1]!
let c2 ← BParsec.inspect λ pos => pos.data[i+2]!
if ¬(c1 &&& 0xc0 == 0x80 && c2 &&& 0xc0 == 0x80) then
throw s!"Not a valid UTF8 Char: {c} {c1} {c2}"
pure ⟨r, .inl h⟩
else throw s!"Not valid UTF8 Char: {c₀} {c₁}"
else if c₀ &&& 0xf0 == 0xe0 then
let c₁ ← BParsec.nextByte
let c₂ ← BParsec.nextByte
match c₁, c₂ with
| none, _ | _, none => throw "Not enough bytes for UTF8 Char"
| some c₁, some c₂ =>
if ¬(c₁ &&& 0xc0 == 0x80 && c₂ &&& 0xc0 == 0x80) then
throw s!"Not a valid UTF8 Char: {c₀} {c₁} {c₂}"
else
let r :=
((c &&& 0x0f).toUInt32 <<< 12) |||
((c1 &&& 0x3f).toUInt32 <<< 6) |||
(c2 &&& 0x3f).toUInt32
if (0x800 > r) then throw s!"Not a valid UTF8 Char: {c} {c1} {c2}" else
((c &&& 0x0f).toUInt32 <<< 12) |||
((c₁ &&& 0x3f).toUInt32 <<< 6) |||
(c₂ &&& 0x3f).toUInt32
if (0x800 > r) then throw s!"Not a valid UTF8 Char: {c} {c₁} {c₂}" else
if h : r < 0xd8000xdfff < r ∧ r < 0x110000 then
let char := ⟨r, h⟩
pure ⟨char, 3
else throw s!"Not valid UTF8 Char: {c} {c1} {c2}"
else if c &&& 0xf8 == 0xf0 then
let c1 ← BParsec.inspect λ pos => pos.data[i+1]!
let c2 ← BParsec.inspect λ pos => pos.data[i+2]!
let c3 ← BParsec.inspect λ pos => pos.data[i+3]!
if ¬(c1 &&& 0xc0 == 0x80 && c2 &&& 0xc0 == 0x80 && c3 &&& 0xc0 == 0x80) then
throw s!"Not a valid UTF8 Char: {c} {c1} {c2} {c3}"
pure ⟨r, h⟩
else throw s!"Not valid UTF8 Char: {c₀} {c₁} {c₂}"
else if c₀ &&& 0xf8 == 0xf0 then
let c₁ ← BParsec.nextByte
let c₂ ← BParsec.nextByte
let c₃ ← BParsec.nextByte
match c₁, c₂, c₃ with
| none, _, _ | _, none, _ | _, _, none => throw "Not enough bytes for UTF8 Char"
| some c₁, some c₂, some c₃ =>
if ¬(c₁ &&& 0xc0 == 0x80 && c₂ &&& 0xc0 == 0x80 && c₃ &&& 0xc0 == 0x80) then
throw s!"Not a valid UTF8 Char: {c₀} {c₁} {c₂} {c₃}"
else
let r :=
((c &&& 0x07).toUInt32 <<< 18) |||
((c1 &&& 0x3f).toUInt32 <<< 12) |||
((c2 &&& 0x3f).toUInt32 <<< 6) |||
(c3 &&& 0x3f).toUInt32
((c &&& 0x07).toUInt32 <<< 18) |||
((c₁ &&& 0x3f).toUInt32 <<< 12) |||
((c₂ &&& 0x3f).toUInt32 <<< 6) |||
(c₃ &&& 0x3f).toUInt32
if h : 0x10000 ≤ r ∧ r < 0x110000 then
let char := ⟨r, .inr ⟨Nat.lt_of_lt_of_le (by decide) h.1, h.2⟩⟩
pure ⟨char, 4
else throw s!"Not valid UTF8 Char: {c} {c1} {c2} {c3}"
pure ⟨r, .inr ⟨Nat.lt_of_lt_of_le (by decide) h.1, h.2⟩⟩
else throw s!"Not valid UTF8 Char: {c₀} {c₁} {c₂} {c₃}"
else
throw s!"Not valid UTF8 Char: {c}"
throw s!"Not valid UTF8 Char: {c}"


-- Progresses ByteArray.Iterator
Expand All @@ -84,9 +94,10 @@ partial def parseStringHelper (remaining : Nat) (r : String) : BParsec String :=
if remaining = 0 then pure r else
let empty ← BParsec.empty
if empty then throw s!"Expected more packed uints, Size Remaining: {remaining}" else
let pos ← BParsec.pos
let ⟨c, elementSize⟩ ← utf8DecodeChar pos
BParsec.forward (elementSize)
let start_pos ← BParsec.pos
let c ← utf8DecodeChar
let end_pos ← BParsec.pos
let elementSize := end_pos - start_pos
parseStringHelper (remaining - elementSize) (r.push c)

@[inline]
Expand Down
Loading

0 comments on commit cbe3eeb

Please sign in to comment.