Skip to content

Commit

Permalink
Refactor types in partial evaluation (#379)
Browse files Browse the repository at this point in the history
Signed-off-by: Craig Disselkoen <[email protected]>
  • Loading branch information
cdisselkoen authored Jul 16, 2024
1 parent c6bc61c commit 23edc4e
Show file tree
Hide file tree
Showing 30 changed files with 1,058 additions and 1,604 deletions.
2 changes: 1 addition & 1 deletion cedar-lean/Cedar/Data/Map.lean
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace Cedar.Data

inductive Map (α : Type u) (β : Type v) where
| mk : List (α × β) -> Map α β
deriving Repr, DecidableEq, Repr, Inhabited
deriving Repr, DecidableEq, Inhabited

namespace Map

Expand Down
1 change: 0 additions & 1 deletion cedar-lean/Cedar/Partial.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import Cedar.Partial.Authorizer
import Cedar.Partial.Entities
import Cedar.Partial.Evaluator
import Cedar.Partial.Expr
import Cedar.Partial.Request
import Cedar.Partial.Response
import Cedar.Partial.Value
4 changes: 1 addition & 3 deletions cedar-lean/Cedar/Partial/Authorizer.lean
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,8 @@ def isAuthorized (req : Partial.Request) (entities : Partial.Entities) (policies
{
residuals := policies.filterMap λ policy => match Partial.evaluate policy.toExpr req entities with
| .ok (.value (.prim (.bool false))) => none
| .ok (.value v) => some (.residual policy.id policy.effect v.asPartialExpr)
| .ok (.residual r) => some (.residual policy.id policy.effect r)
| .ok pv => some (.residual policy.id policy.effect pv)
| .error e => some (.error policy.id e)
req,
entities,
}

Expand Down
151 changes: 117 additions & 34 deletions cedar-lean/Cedar/Partial/Evaluator.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
-/

import Cedar.Partial.Entities
import Cedar.Partial.Expr
import Cedar.Partial.Request
import Cedar.Partial.Value
import Cedar.Spec.Evaluator
Expand All @@ -27,16 +26,19 @@ import Cedar.Spec.Value
namespace Cedar.Partial

open Cedar.Data
open Cedar.Spec (Attr BinaryOp EntityUID ExtFun Result UnaryOp Var intOrErr)
open Cedar.Spec (Attr BinaryOp EntityUID Expr ExtFun Result UnaryOp Var intOrErr)
open Cedar.Spec.Error

/-- Analogous to Spec.apply₁ but for partial values -/
def apply₁ (op₁ : UnaryOp) (pv : Partial.Value) : Result Partial.Value :=
match pv with
/--
Partial-evaluate `op₁ pv₁`. No analogue in Spec.Evaluator; this logic (that
sits between `Partial.evaluate` and `Spec.apply₁`) is not needed in the
equivalent Spec.Evaluator position
-/
def evaluateUnaryApp (op₁ : UnaryOp) : Partial.Value → Result Partial.Value
| .value v₁ => do
let val ← Spec.apply₁ op₁ v₁
.ok (.value val)
| .residual r => .ok (.residual (Partial.Expr.unaryApp op₁ r))
| pv => .ok (.residual (.unaryApp op₁ pv))

/-- Analogous to Spec.inₑ but for partial entities -/
def inₑ (uid₁ uid₂ : EntityUID) (es : Partial.Entities) : Bool :=
Expand All @@ -53,14 +55,14 @@ def apply₂ (op₂ : BinaryOp) (v₁ v₂ : Spec.Value) (es : Partial.Entities)
| .eq, _, _ => .ok (.value (v₁ == v₂))
| .less, .prim (.int i), .prim (.int j) => .ok (.value ((i < j): Bool))
| .lessEq, .prim (.int i), .prim (.int j) => .ok (.value ((i ≤ j): Bool))
| .add, .prim (.int i), .prim (.int j) => intOrErr (i.add? j) >>= λ x => .ok (.value x)
| .sub, .prim (.int i), .prim (.int j) => intOrErr (i.sub? j) >>= λ x => .ok (.value x)
| .mul, .prim (.int i), .prim (.int j) => intOrErr (i.mul? j) >>= λ x => .ok (.value x)
| .add, .prim (.int i), .prim (.int j) => do .ok (.value (← intOrErr (i.add? j)))
| .sub, .prim (.int i), .prim (.int j) => do .ok (.value (← intOrErr (i.sub? j)))
| .mul, .prim (.int i), .prim (.int j) => do .ok (.value (← intOrErr (i.mul? j)))
| .contains, .set vs₁, _ => .ok (.value (vs₁.contains v₂))
| .containsAll, .set vs₁, .set vs₂ => .ok (.value (vs₂.subset vs₁))
| .containsAny, .set vs₁, .set vs₂ => .ok (.value (vs₁.intersects vs₂))
| .mem, .prim (.entityUID uid₁), .prim (.entityUID uid₂) => .ok (.value (Partial.inₑ uid₁ uid₂ es))
| .mem, .prim (.entityUID uid₁), .set (vs) => Partial.inₛ uid₁ vs es >>= λ x => .ok (.value x)
| .mem, .prim (.entityUID uid₁), .set (vs) => do .ok (.value (← Partial.inₛ uid₁ vs es))
| _, _, _ => .error .typeError

/--
Expand All @@ -71,9 +73,7 @@ def apply₂ (op₂ : BinaryOp) (v₁ v₂ : Spec.Value) (es : Partial.Entities)
def evaluateBinaryApp (op₂ : BinaryOp) (pv₁ pv₂ : Partial.Value) (es : Partial.Entities) : Result Partial.Value :=
match (pv₁, pv₂) with
| (.value v₁, .value v₂) => Partial.apply₂ op₂ v₁ v₂ es
| (.value v₁, .residual r₂) => .ok (.residual (Partial.Expr.binaryApp op₂ v₁.asPartialExpr r₂))
| (.residual r₁, .value v₂) => .ok (.residual (Partial.Expr.binaryApp op₂ r₁ v₂.asPartialExpr))
| (.residual r₁, .residual r₂) => .ok (.residual (Partial.Expr.binaryApp op₂ r₁ r₂))
| (pv₁, pv₂) => .ok (.residual (.binaryApp op₂ pv₁ pv₂))

/-- Analogous to Spec.attrsOf but for lookup functions that return partial values -/
def attrsOf (v : Spec.Value) (lookup : EntityUID → Result (Map Attr Partial.Value)) : Result (Map Attr Partial.Value) :=
Expand All @@ -84,7 +84,7 @@ def attrsOf (v : Spec.Value) (lookup : EntityUID → Result (Map Attr Partial.Va

/-- Analogous to Spec.hasAttr but for partial entities -/
def hasAttr (v : Spec.Value) (a : Attr) (es : Partial.Entities) : Result Spec.Value := do
let r ← Partial.attrsOf v (fun uid => .ok (es.attrsOrEmpty uid))
let r ← Partial.attrsOf v (λ uid => .ok (es.attrsOrEmpty uid))
.ok (r.contains a)

/--
Expand All @@ -97,7 +97,7 @@ def evaluateHasAttr (pv : Partial.Value) (a : Attr) (es : Partial.Entities) : Re
| .value v₁ => do
let val ← Partial.hasAttr v₁ a es
.ok (.value val)
| .residual r => .ok (.residual (Partial.Expr.hasAttr r a)) -- TODO more precise: even though pv is a residual we may know concretely whether it contains the particular attr we care about
| .residual r => .ok (.residual (.hasAttr (.residual r) a)) -- could be more precise; see cedar-spec#395

/-- Analogous to Spec.getAttr but for partial entities -/
def getAttr (v : Spec.Value) (a : Attr) (es : Partial.Entities) : Result Partial.Value := do
Expand All @@ -110,9 +110,9 @@ def getAttr (v : Spec.Value) (a : Attr) (es : Partial.Entities) : Result Partial
Spec.Evaluator position
-/
def evaluateGetAttr (pv : Partial.Value) (a : Attr) (es : Partial.Entities) : Result Partial.Value := do
match pv with
| .value v₁ => Partial.getAttr v₁ a es
| .residual r => .ok (.residual (Partial.Expr.getAttr r a)) -- TODO more precise: pv will be a .residual if it contains any unknowns, but we might have a concrete value for the particular attr we care about
match pv with
| .value v₁ => Partial.getAttr v₁ a es
| .residual r => .ok (.residual (.getAttr (.residual r) a)) -- could be more precise; see cedar-spec#395

/-- Analogous to Spec.bindAttr but for partial values -/
def bindAttr (a : Attr) (res : Result Partial.Value) : Result (Attr × Partial.Value) := do
Expand All @@ -127,18 +127,18 @@ def evaluateVar (v : Var) (req : Partial.Request) : Result Partial.Value :=
| .resource => .ok req.resource
| .context => match req.context.mapMOnValues λ v => match v with | .value v => some v | .residual _ => none with
| some m => .ok (.value m)
| none => .ok (.residual (Partial.Expr.record (req.context.mapOnValues Partial.Value.asPartialExpr).kvs))
| none => .ok (.residual (.record req.context.kvs))

/-- Call an extension function with partial values as arguments -/
def evaluateCall (xfn : ExtFun) (args : List Partial.Value) : Result Partial.Value :=
match args.mapM (λ pval => match pval with | .value v => some v | .residual _ => none) with
| some vs => do
let val ← Spec.call xfn vs
.ok (.value val)
| none => .ok (.residual (Partial.Expr.call xfn (args.map Partial.Value.asPartialExpr)))
| none => .ok (.residual (.call xfn args))

/-- Analogous to Spec.evaluate but performs partial evaluation on partial expr/request/entities -/
def evaluate (x : Partial.Expr) (req : Partial.Request) (es : Partial.Entities) : Result Partial.Value :=
/-- Analogous to Spec.evaluate but performs partial evaluation given partial request/entities -/
def evaluate (x : Expr) (req : Partial.Request) (es : Partial.Entities) : Result Partial.Value :=
match x with
| .lit l => .ok (.value l)
| .var v => evaluateVar v req
Expand All @@ -148,7 +148,7 @@ def evaluate (x : Partial.Expr) (req : Partial.Request) (es : Partial.Entities)
| .value v => do
let b ← v.asBool
if b then Partial.evaluate x₂ req es else Partial.evaluate x₃ req es
| .residual r => .ok (.residual (Partial.Expr.ite r x₂ x₃))
| .residual r => .ok (.residual (.ite (.residual r) (x₂.substToPartialValue req) (x₃.substToPartialValue req)))
| .and x₁ x₂ => do
let pval ← Partial.evaluate x₁ req es
match pval with
Expand All @@ -161,7 +161,7 @@ def evaluate (x : Partial.Expr) (req : Partial.Request) (es : Partial.Entities)
let b ← v.asBool
.ok (.value b)
| .residual r => .ok (.residual r)
| .residual r => .ok (.residual (Partial.Expr.and r x₂))
| .residual r => .ok (.residual (.and (.residual r) (x₂.substToPartialValue req)))
| .or x₁ x₂ => do
let pval ← Partial.evaluate x₁ req es
match pval with
Expand All @@ -174,10 +174,10 @@ def evaluate (x : Partial.Expr) (req : Partial.Request) (es : Partial.Entities)
let b ← v.asBool
.ok (.value b)
| .residual r => .ok (.residual r)
| .residual r => .ok (.residual (Partial.Expr.or r x₂))
| .residual r => .ok (.residual (.or (.residual r) (x₂.substToPartialValue req)))
| .unaryApp op₁ x₁ => do
let pval ← Partial.evaluate x₁ req es
Partial.apply₁ op₁ pval
evaluateUnaryApp op₁ pval
| .binaryApp op₂ x₁ x₂ => do
let pval₁ ← Partial.evaluate x₁ req es
let pval₂ ← Partial.evaluate x₂ req es
Expand All @@ -189,16 +189,99 @@ def evaluate (x : Partial.Expr) (req : Partial.Request) (es : Partial.Entities)
let pval₁ ← Partial.evaluate x₁ req es
evaluateGetAttr pval₁ a es
| .set xs => do
let vs ← xs.mapM₁ (fun ⟨x₁, _⟩ => Partial.evaluate x₁ req es)
match vs.mapM (fun pval => match pval with | .value v => some v | .residual _ => none) with
let pvs ← xs.mapM₁ (λ ⟨x₁, _⟩ => Partial.evaluate x₁ req es)
match pvs.mapM (λ pval => match pval with | .value v => some v | .residual _ => none) with
| some vs => .ok (.value (Set.make vs))
| none => .ok (.residual (Partial.Expr.set (vs.map Partial.Value.asPartialExpr)))
| none => .ok (.residual (.set pvs))
| .record axs => do
let avs ← axs.mapM₂ (fun ⟨(a₁, x₁), _⟩ => Partial.bindAttr a₁ (Partial.evaluate x₁ req es))
match avs.mapM (fun (a, pval) => match pval with | .value v => some (a, v) | .residual _ => none) with
let apvs ← axs.mapM₂ (λ ⟨(a₁, x₁), _⟩ => Partial.bindAttr a₁ (Partial.evaluate x₁ req es))
match apvs.mapM (λ (a, pval) => match pval with | .value v => some (a, v) | .residual _ => none) with
| some avs => .ok (.value (Map.make avs))
| none => .ok (.residual (Partial.Expr.record (avs.map fun (a, v) => (a, v.asPartialExpr))))
| none => .ok (.residual (.record apvs))
| .call xfn xs => do
let pvs ← xs.mapM₁ (fun ⟨x₁, _⟩ => Partial.evaluate x₁ req es)
let pvs ← xs.mapM₁ (λ ⟨x₁, _⟩ => Partial.evaluate x₁ req es)
evaluateCall xfn pvs
| .unknown u => .ok (.residual (Partial.Expr.unknown u))

mutual

/--
Evaluate a `Partial.Value`, possibly reducing it. For instance, `3 + 5` will
evaluate to `8`. This can be relevant if a substitution was recently made on
the `Partial.Value`.
-/
def evaluateValue (pv : Partial.Value) (es : Partial.Entities) : Result Partial.Value :=
match pv with
| .value v => .ok (.value v)
| .residual r => evaluateResidual r es

/--
Evaluate a `ResidualExpr`, possibly reducing it. For instance, `3 + 5` will
evaluate to `8`. This can be relevant if a substitution was recently made on
the `ResidualExpr`.
-/
def evaluateResidual (x : Partial.ResidualExpr) (es : Partial.Entities) : Result Partial.Value :=
match x with
| .unknown u => .ok u
| .ite pv₁ pv₂ pv₃ => do
let pv₁' ← Partial.evaluateValue pv₁ es
match pv₁' with
| .value v₁' => do
let b ← v₁'.asBool
if b then Partial.evaluateValue pv₂ es else Partial.evaluateValue pv₃ es
| .residual r₁' => .ok (.residual (.ite (.residual r₁') pv₂ pv₃))
| .and pv₁ pv₂ => do
let pv₁' ← Partial.evaluateValue pv₁ es
match pv₁' with
| .value v₁' => do
let b ← v₁'.asBool
if !b then .ok (.value b) else do
let pv₂' ← Partial.evaluateValue pv₂ es
match pv₂' with
| .value v₂' => do
let b ← v₂'.asBool
.ok (.value b)
| .residual r₂' => .ok (.residual r₂')
| .residual r₁' => .ok (.residual (.and (.residual r₁') pv₂))
| .or pv₁ pv₂ => do
let pv₁' ← Partial.evaluateValue pv₁ es
match pv₁' with
| .value v₁' => do
let b ← v₁'.asBool
if b then .ok (.value b) else do
let pv₂' ← Partial.evaluateValue pv₂ es
match pv₂' with
| .value v₂' => do
let b ← v₂'.asBool
.ok (.value b)
| .residual r₂' => .ok (.residual r₂')
| .residual r₁' => .ok (.residual (.or (.residual r₁') pv₂))
| .unaryApp op₁ pv₁ => do
let pv₁' ← Partial.evaluateValue pv₁ es
evaluateUnaryApp op₁ pv₁'
| .binaryApp op₂ pv₁ pv₂ => do
let pv₁' ← Partial.evaluateValue pv₁ es
let pv₂' ← Partial.evaluateValue pv₂ es
evaluateBinaryApp op₂ pv₁' pv₂' es
| .hasAttr pv₁ a => do
let pv₁' ← Partial.evaluateValue pv₁ es
evaluateHasAttr pv₁' a es
| .getAttr pv₁ a => do
let pv₁' ← Partial.evaluateValue pv₁ es
evaluateGetAttr pv₁' a es
| .set pvs => do
let pvs' ← pvs.mapM₁ (λ ⟨pv, _⟩ => Partial.evaluateValue pv es)
match pvs'.mapM (λ pv => match pv with | .value v => some v | .residual _ => none) with
| some vs => .ok (.value (Set.make vs))
| none => .ok (.residual (.set pvs'))
| .record apvs => do
let apvs' ← apvs.mapM₂ (λ ⟨(a, pv), _⟩ => Partial.bindAttr a (Partial.evaluateValue pv es))
match apvs'.mapM (λ (a, pv) => match pv with | .value v => some (a, v) | .residual _ => none) with
| some avs => .ok (.value (Map.make avs))
| none => .ok (.residual (.record apvs'))
| .call xfn pvs => do
let pvs' ← pvs.mapM₁ (λ ⟨pv, _⟩ => Partial.evaluateValue pv es)
evaluateCall xfn pvs'

end

end Cedar.Partial
Loading

0 comments on commit 23edc4e

Please sign in to comment.