From a00876624eebdbb37bf01dd4f2710b970c560a4d Mon Sep 17 00:00:00 2001 From: oflatt Date: Wed, 11 Sep 2024 12:49:58 -0700 Subject: [PATCH 1/5] entity loader file Signed-off-by: oflatt --- cedar-policy-validator/src/entity_loader.rs | 7 +++++++ cedar-policy-validator/src/lib.rs | 2 ++ 2 files changed, 9 insertions(+) create mode 100644 cedar-policy-validator/src/entity_loader.rs diff --git a/cedar-policy-validator/src/entity_loader.rs b/cedar-policy-validator/src/entity_loader.rs new file mode 100644 index 000000000..2634847be --- /dev/null +++ b/cedar-policy-validator/src/entity_loader.rs @@ -0,0 +1,7 @@ + + + + +trait EntityLoader { + fn load_entities(&mut self, entities: Vec<(EntityUID, +} \ No newline at end of file diff --git a/cedar-policy-validator/src/lib.rs b/cedar-policy-validator/src/lib.rs index d23ff2d42..88c8ee09c 100644 --- a/cedar-policy-validator/src/lib.rs +++ b/cedar-policy-validator/src/lib.rs @@ -43,6 +43,8 @@ mod entity_manifest_analysis; mod entity_manifest_type_annotations; #[cfg(feature = "entity-manifest")] pub mod entity_slicing; +#[cfg(feature = "entity-manifest")] +pub mod entity_loader; mod err; pub use err::*; mod coreschema; From e39124c61a2efce17d12adc49d7d6be351d5d360 Mon Sep 17 00:00:00 2001 From: oflatt Date: Thu, 12 Sep 2024 17:05:43 -0700 Subject: [PATCH 2/5] implement entity loader api, untested Signed-off-by: oflatt --- cedar-policy-core/src/ast/entity.rs | 7 + cedar-policy-validator/src/entity_loader.rs | 330 ++++++++++++++++++- cedar-policy-validator/src/entity_slicing.rs | 37 +++ 3 files changed, 371 insertions(+), 3 deletions(-) diff --git a/cedar-policy-core/src/ast/entity.rs b/cedar-policy-core/src/ast/entity.rs index 38be60363..474f29d10 100644 --- a/cedar-policy-core/src/ast/entity.rs +++ b/cedar-policy-core/src/ast/entity.rs @@ -492,6 +492,13 @@ impl Entity { pub(crate) fn add_ancestor(&mut self, uid: EntityUID) { self.ancestors.insert(uid); } + + /// Add a set of ancestors to this `Entity`. + /// TODO why is `add_ancestor` pub(crate) instead of pub, and should this be too? + pub fn add_ancestors(&mut self, ancestors: HashSet) { + self.ancestors.extend(ancestors); + } + /// Mark the given `UID` as an ancestor of this `Entity` #[cfg(fuzzing)] pub fn add_ancestor(&mut self, uid: EntityUID) { diff --git a/cedar-policy-validator/src/entity_loader.rs b/cedar-policy-validator/src/entity_loader.rs index 2634847be..b6f8c63ae 100644 --- a/cedar-policy-validator/src/entity_loader.rs +++ b/cedar-policy-validator/src/entity_loader.rs @@ -1,7 +1,331 @@ +/* + * Copyright Cedar Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +//! Entity Loader API implementation +//! Loads entities based on the entity manifest. +use std::{ + collections::{BTreeMap, HashMap, HashSet}, + sync::Arc, +}; +use cedar_policy_core::{ + ast::{Context, Entity, EntityUID, Literal, PartialValue, Request, Value, ValueKind, Var}, + entities::{Entities, NoEntitiesSchema, TCComputation}, + extensions::Extensions, +}; +use smol_str::SmolStr; -trait EntityLoader { - fn load_entities(&mut self, entities: Vec<(EntityUID, -} \ No newline at end of file +use crate::{ + entity_manifest::{ + AccessTrie, EntityManifest, EntityRoot, PartialRequestError, RootAccessTrie, + }, + entity_slicing::{ + EntitySliceError, PartialContextError, PartialEntityError, WrongNumberOfEntitiesError, + }, +}; + +/// A request that an entity be loaded. +/// Optionally, instead of loading the full entity the `access_trie` +/// may be used to load only some fields of the entity. +#[derive(Debug)] +pub struct EntityRequest<'a> { + /// The id of the entity requested + entity_id: EntityUID, + /// The fieds of the entity requested + access_trie: &'a AccessTrie, +} + +/// A request that the ancestors of an entity be loaded. +/// Optionally, the `ancestors` set may be used to just load ancestors in the set. +#[derive(Debug)] +pub struct AncestorsRequest { + /// The id of the entity whose ancestors are requested + entity_id: EntityUID, + /// The ancestors that are requested, if present + ancestors: HashSet, +} + +/// Implement [`EntityLoader`] to easily load entities using their ids +/// into a Cedar [`Entities`] store. +/// The most basic implementation loads full entities (including all ancestors) in the `load_entities` method and loads the context in the `load_context` method. +/// More advanced implementations make use of the [`AccessTrie`]s provided to load partial entities and context, as well as the `load_ancestors` method to load particular ancestors. +pub trait EntityLoader { + /// Loads the concrete context based on the request. + /// Only context attributes mentioned in the `access_trie` are required. + fn load_context(&mut self, access_trie: AccessTrie) -> Context; + + /// `load_entities` is called multiple times to load entities based on their ids. + /// For each entity request in the `to_load` vector, expects one loaded entity in the resulting vector. + /// Each [`EntityRequest`] comes with an [`AccessTrie`], which can optionally be used. + /// Only fields mentioned in the entity's [`AccessTrie`] are needed, but it is sound to provide other fields as well. + /// Note that the same entity may be requested multiple times, with different [`AccessTrie`]s. + /// + /// Either `load_entities` must load all the ancestors of each entity, unless `load_ancestors` is implemented. + fn load_entities(&mut self, to_load: &[EntityRequest<'_>]) -> Vec; + + /// Optionally, `load_entities` can forgo loading ancestors in the entity hierarchy. + /// Instead, `load_ancestors` implements loading them. + /// For each entity, `load_ancestors` produces a set of ancestors entities in the resulting vector. + /// + /// Each [`AncestorsRequest`] should result in one set of ancestors in the resulting vector. + /// Only ancestors in the request are required, but it is sound to provide other ancestors as well. + fn load_ancestors(&mut self, entities: &Vec) -> Vec>; +} + +fn initial_entities_to_load<'a>( + root_access_trie: &'a RootAccessTrie, + context: &Context, + request: &Request, +) -> Result>, EntitySliceError> { + let Context::Value(context_value) = &context else { + return Err(PartialContextError {}.into()); + }; + + let mut to_load = match root_access_trie.trie.get(&EntityRoot::Var(Var::Context)) { + Some(access_trie) => find_remaining_entities_context(context_value, access_trie)?, + _ => vec![], + }; + + for (key, access_trie) in &root_access_trie.trie { + to_load.push(EntityRequest { + entity_id: match key { + EntityRoot::Var(Var::Principal) => request + .principal() + .uid() + .ok_or(PartialRequestError {})? + .clone(), + EntityRoot::Var(Var::Action) => request + .action() + .uid() + .ok_or(PartialRequestError {})? + .clone(), + EntityRoot::Var(Var::Resource) => request + .resource() + .uid() + .ok_or(PartialRequestError {})? + .clone(), + EntityRoot::Literal(lit) => lit.clone(), + EntityRoot::Var(Var::Context) => continue, + }, + access_trie, + }); + } + + Ok(to_load) +} + +/// Loads entities based on the entity manifest, request, and +/// the implemented [`EntityLoader`]. +/// Returns both the new entity store and the loaded context. +pub fn load_entities( + manifest: &EntityManifest, + request: &Request, + loader: &mut dyn EntityLoader, +) -> Result<(Context, Entities), EntitySliceError> { + let Some(root_access_trie) = manifest + .per_action + .get(&request.to_request_type().ok_or(PartialRequestError {})?) + else { + match Entities::from_entities( + vec![], + None::<&NoEntitiesSchema>, + TCComputation::AssumeAlreadyComputed, + Extensions::all_available(), + ) { + Ok(entities) => return Ok((Context::empty(), entities)), + Err(err) => return Err(err.into()), + }; + }; + + let context = match root_access_trie.trie.get(&EntityRoot::Var(Var::Context)) { + Some(access_trie) => loader.load_context(access_trie.clone()), + _ => Context::empty(), + }; + + let mut entities: HashMap = Default::default(); + // entity requests in progress + let mut to_load: Vec> = + initial_entities_to_load(&root_access_trie, &context, &request)?; + // later, find the ancestors of these entities using their ancestor tries + let mut to_find_ancestors = vec![]; + + // Main loop of loading entities, one batch at a time + while !to_load.is_empty() { + // first, record the entities in `to_find_ancestors` + for entity_request in &to_load { + to_find_ancestors.push(( + entity_request.entity_id.clone(), + &entity_request.access_trie.ancestors_trie, + )); + } + + let new_entities = loader.load_entities(&to_load); + if new_entities.len() != to_load.len() { + return Err(WrongNumberOfEntitiesError { + expected: to_load.len(), + got: new_entities.len(), + } + .into()); + } + + let mut next_to_load = vec![]; + for (entity_request, loaded) in to_load.drain(..).zip(new_entities) { + next_to_load.extend(find_remaining_entities( + &loaded, + entity_request.access_trie, + )?); + entities.insert(entity_request.entity_id, loaded); + } + + to_load = next_to_load; + } + + // now that all the entities are loaded + // we need to load their ancestors + let mut ancestors_requests = vec![]; + for (entity_id, ancestors_trie) in to_find_ancestors { + ancestors_requests.push(compute_ancestors_request( + entity_id, + ancestors_trie, + &entities, + &context, + request, + )?); + } + + let loaded_ancestors = loader.load_ancestors(&ancestors_requests); + for (request, ancestors) in ancestors_requests.into_iter().zip(loaded_ancestors) { + // PANIC SAFETY: ancestor requests are only created for entities already loaded in the entities map + #[allow(clippy::unwrap_used)] + entities + .get_mut(&request.entity_id) + .unwrap() + .add_ancestors(ancestors); + } + + // finally, convert the loaded entities into a Cedar Entities store + + match Entities::from_entities( + entities.values().cloned(), + None::<&NoEntitiesSchema>, + TCComputation::AssumeAlreadyComputed, + Extensions::all_available(), + ) { + Ok(entities) => Ok((context, entities)), + Err(e) => Err(e.into()), + } +} + +fn find_remaining_entities_context<'a>( + context_value: &Arc>, + fields: &'a AccessTrie, +) -> Result>, EntitySliceError> { + let mut remaining = vec![]; + for (field, slice) in &fields.children { + if let Some(value) = context_value.get(field) { + find_remaining_entities_value(&mut remaining, value, slice)?; + } + // the attribute may not be present, since the schema can define + // attributes that are optional + } + Ok(remaining) +} + +/// This helper function finds all entity references that need to be +/// loaded given an already-loaded [`Entity`] and corresponding [`Fields`]. +/// Returns pairs of entity and slices that need to be loaded. +fn find_remaining_entities<'a>( + entity: &Entity, + fields: &'a AccessTrie, +) -> Result>, EntitySliceError> { + let mut remaining = vec![]; + for (field, slice) in &fields.children { + if let Some(pvalue) = entity.get(field) { + let PartialValue::Value(value) = pvalue else { + return Err(PartialEntityError {}.into()); + }; + find_remaining_entities_value(&mut remaining, value, slice)?; + } + // the attribute may not be present, since the schema can define + // attributes that are optional + } + + Ok(remaining) +} + +fn find_remaining_entities_value<'a>( + remaining: &mut Vec>, + value: &Value, + trie: &'a AccessTrie, +) -> Result<(), EntitySliceError> { + match value.value_kind() { + ValueKind::Lit(literal) => { + if let Literal::EntityUID(entity_id) = literal { + remaining.push(EntityRequest { + entity_id: (**entity_id).clone(), + access_trie: trie, + }); + } + } + ValueKind::Set(_) => (), + ValueKind::ExtensionValue(_) => (), + ValueKind::Record(record) => { + for (field, child_slice) in &trie.children { + // only need to slice if field is present + if let Some(value) = record.get(field) { + find_remaining_entities_value(remaining, value, child_slice)?; + } + } + } + }; + Ok(()) +} + +/// Traverse the already-loaded entities using the ancestors trie +/// to find the entity ids that are required. +fn compute_ancestors_request( + entity_id: EntityUID, + ancestors_trie: &RootAccessTrie, + entities: &HashMap, + context: &Context, + request: &Request, +) -> Result { + // similar to load_entities, we traverse the access trie + // this time using the already-loaded entities and looking for + // is_ancestor tags. + let mut ancestors = HashSet::new(); + + let mut to_visit = initial_entities_to_load(ancestors_trie, context, request)?; + + while !to_visit.is_empty() { + let mut next_to_visit = vec![]; + for entity_request in to_visit.drain(..) { + if entity_request.access_trie.is_ancestor { + ancestors.insert(entity_request.entity_id.clone()); + } + if let Some(entity) = entities.get(&entity_request.entity_id) { + next_to_visit.extend(find_remaining_entities(entity, entity_request.access_trie)?); + } + } + to_visit = next_to_visit; + } + + Ok(AncestorsRequest { + ancestors, + entity_id, + }) +} diff --git a/cedar-policy-validator/src/entity_slicing.rs b/cedar-policy-validator/src/entity_slicing.rs index 2b31387c9..bbbbdc984 100644 --- a/cedar-policy-validator/src/entity_slicing.rs +++ b/cedar-policy-validator/src/entity_slicing.rs @@ -63,6 +63,36 @@ pub struct PartialEntityError {} impl Diagnostic for PartialEntityError {} +/// Error when an entity loader returns the wrong number of entities. +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Clone, Error, Eq, PartialEq)] +#[error("entity loader returned the wrong number of entities. Expected {expected} but got {got} entities")] +pub struct WrongNumberOfEntitiesError { + pub(crate) expected: usize, + pub(crate) got: usize, +} + +/// Error when an entity loader returns a value missing an attribute. +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Clone, Error, Eq, PartialEq)] +#[error("entity loader produced entity with value {value}. Expected value to be a record with attribute {attribute}")] +pub struct NonRecordValueError { + pub(crate) value: Value, + pub(crate) attribute: SmolStr, +} + +/// Context was partial during entity loading +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Clone, Error, Eq, PartialEq)] +#[error("entity loader produced a partial context. Expected a concrete value")] +pub struct PartialContextError {} + /// An error generated by entity slicing. /// TODO make public API wrapper #[derive(Debug, Error, Diagnostic)] @@ -87,6 +117,13 @@ pub enum EntitySliceError { /// Found a partial entity during entity loading. #[error(transparent)] PartialEntity(#[from] PartialEntityError), + + #[error(transparent)] + PartialContext(#[from] PartialContextError), + + /// The entity loader produced the wrong number of entities. + #[error(transparent)] + WrongNumberOfEntities(#[from] WrongNumberOfEntitiesError), } impl EntityManifest { From 0407a104da814ca9819d6ca544670c248f9875a1 Mon Sep 17 00:00:00 2001 From: oflatt Date: Mon, 16 Sep 2024 12:22:57 -0700 Subject: [PATCH 3/5] tests passing with new entity loader, ancestors a bit ugly Signed-off-by: oflatt --- cedar-policy-validator/src/entity_loader.rs | 237 +++++++++++++++---- cedar-policy-validator/src/entity_slicing.rs | 230 +++++------------- cedar-policy-validator/src/types.rs | 10 + 3 files changed, 257 insertions(+), 220 deletions(-) diff --git a/cedar-policy-validator/src/entity_loader.rs b/cedar-policy-validator/src/entity_loader.rs index b6f8c63ae..f9fbb1885 100644 --- a/cedar-policy-validator/src/entity_loader.rs +++ b/cedar-policy-validator/src/entity_loader.rs @@ -42,32 +42,48 @@ use crate::{ /// Optionally, instead of loading the full entity the `access_trie` /// may be used to load only some fields of the entity. #[derive(Debug)] -pub struct EntityRequest<'a> { +pub(crate) struct EntityRequest { /// The id of the entity requested - entity_id: EntityUID, + pub(crate) entity_id: EntityUID, /// The fieds of the entity requested + pub(crate) access_trie: AccessTrie, +} + +/// An entity request may be an entity or `None` when +/// the entity is not present. +pub(crate) type EntityAnswer = Option; + +/// The entity request before sub-entitity tries have been +/// pruned using `prune_child_entity_dereferences`. +pub(crate) struct EntityRequestRef<'a> { + entity_id: EntityUID, access_trie: &'a AccessTrie, } +impl<'a> EntityRequestRef<'a> { + fn to_request(&self) -> EntityRequest { + EntityRequest { + entity_id: self.entity_id.clone(), + access_trie: self.access_trie.prune_child_entity_dereferences(), + } + } +} + /// A request that the ancestors of an entity be loaded. /// Optionally, the `ancestors` set may be used to just load ancestors in the set. #[derive(Debug)] -pub struct AncestorsRequest { +pub(crate) struct AncestorsRequest { /// The id of the entity whose ancestors are requested - entity_id: EntityUID, + pub(crate) entity_id: EntityUID, /// The ancestors that are requested, if present - ancestors: HashSet, + pub(crate) ancestors: HashSet, } /// Implement [`EntityLoader`] to easily load entities using their ids /// into a Cedar [`Entities`] store. /// The most basic implementation loads full entities (including all ancestors) in the `load_entities` method and loads the context in the `load_context` method. /// More advanced implementations make use of the [`AccessTrie`]s provided to load partial entities and context, as well as the `load_ancestors` method to load particular ancestors. -pub trait EntityLoader { - /// Loads the concrete context based on the request. - /// Only context attributes mentioned in the `access_trie` are required. - fn load_context(&mut self, access_trie: AccessTrie) -> Context; - +pub(crate) trait EntityLoader { /// `load_entities` is called multiple times to load entities based on their ids. /// For each entity request in the `to_load` vector, expects one loaded entity in the resulting vector. /// Each [`EntityRequest`] comes with an [`AccessTrie`], which can optionally be used. @@ -75,7 +91,10 @@ pub trait EntityLoader { /// Note that the same entity may be requested multiple times, with different [`AccessTrie`]s. /// /// Either `load_entities` must load all the ancestors of each entity, unless `load_ancestors` is implemented. - fn load_entities(&mut self, to_load: &[EntityRequest<'_>]) -> Vec; + fn load_entities( + &mut self, + to_load: &[EntityRequest], + ) -> Result, EntitySliceError>; /// Optionally, `load_entities` can forgo loading ancestors in the entity hierarchy. /// Instead, `load_ancestors` implements loading them. @@ -83,25 +102,31 @@ pub trait EntityLoader { /// /// Each [`AncestorsRequest`] should result in one set of ancestors in the resulting vector. /// Only ancestors in the request are required, but it is sound to provide other ancestors as well. - fn load_ancestors(&mut self, entities: &Vec) -> Vec>; + fn load_ancestors( + &mut self, + entities: &[AncestorsRequest], + ) -> Result>, EntitySliceError>; } fn initial_entities_to_load<'a>( root_access_trie: &'a RootAccessTrie, context: &Context, request: &Request, -) -> Result>, EntitySliceError> { + required_ancestors: &mut HashSet, +) -> Result>, EntitySliceError> { let Context::Value(context_value) = &context else { return Err(PartialContextError {}.into()); }; let mut to_load = match root_access_trie.trie.get(&EntityRoot::Var(Var::Context)) { - Some(access_trie) => find_remaining_entities_context(context_value, access_trie)?, + Some(access_trie) => { + find_remaining_entities_context(context_value, access_trie, required_ancestors)? + } _ => vec![], }; for (key, access_trie) in &root_access_trie.trie { - to_load.push(EntityRequest { + to_load.push(EntityRequestRef { entity_id: match key { EntityRoot::Var(Var::Principal) => request .principal() @@ -128,14 +153,54 @@ fn initial_entities_to_load<'a>( Ok(to_load) } +impl AccessTrie { + /// Removes any entity dereferences in the children of this trie, + /// recursively. + /// These can be included in [`EntityRequest`]s, which don't include + /// referenced entities. + pub(crate) fn prune_child_entity_dereferences(&self) -> AccessTrie { + let children = self + .children + .iter() + .map(|(k, v)| (k.clone(), Box::new(v.prune_entity_dereferences()))) + .collect(); + + AccessTrie { + children, + ancestors_trie: self.ancestors_trie.clone(), + is_ancestor: self.is_ancestor, + node_type: self.node_type.clone(), + } + } + + pub(crate) fn prune_entity_dereferences(&self) -> AccessTrie { + // PANIC SAFETY: Node types should always be present on entity manifests after creation. + #[allow(clippy::unwrap_used)] + let children = if self.node_type.as_ref().unwrap().is_entity_type() { + HashMap::new() + } else { + self.children + .iter() + .map(|(k, v)| (k.clone(), Box::new(v.prune_entity_dereferences()))) + .collect() + }; + + AccessTrie { + children, + ancestors_trie: self.ancestors_trie.clone(), + is_ancestor: self.is_ancestor, + node_type: self.node_type.clone(), + } + } +} + /// Loads entities based on the entity manifest, request, and /// the implemented [`EntityLoader`]. -/// Returns both the new entity store and the loaded context. -pub fn load_entities( +pub(crate) fn load_entities( manifest: &EntityManifest, request: &Request, loader: &mut dyn EntityLoader, -) -> Result<(Context, Entities), EntitySliceError> { +) -> Result { let Some(root_access_trie) = manifest .per_action .get(&request.to_request_type().ok_or(PartialRequestError {})?) @@ -146,20 +211,17 @@ pub fn load_entities( TCComputation::AssumeAlreadyComputed, Extensions::all_available(), ) { - Ok(entities) => return Ok((Context::empty(), entities)), + Ok(entities) => return Ok(entities), Err(err) => return Err(err.into()), }; }; - let context = match root_access_trie.trie.get(&EntityRoot::Var(Var::Context)) { - Some(access_trie) => loader.load_context(access_trie.clone()), - _ => Context::empty(), - }; + let context = request.context().ok_or(PartialRequestError {})?; let mut entities: HashMap = Default::default(); // entity requests in progress - let mut to_load: Vec> = - initial_entities_to_load(&root_access_trie, &context, &request)?; + let mut to_load: Vec> = + initial_entities_to_load(root_access_trie, context, request, &mut Default::default())?; // later, find the ancestors of these entities using their ancestor tries let mut to_find_ancestors = vec![]; @@ -173,7 +235,12 @@ pub fn load_entities( )); } - let new_entities = loader.load_entities(&to_load); + let new_entities = loader.load_entities( + &to_load + .iter() + .map(|entity_ref| entity_ref.to_request()) + .collect::>(), + )?; if new_entities.len() != to_load.len() { return Err(WrongNumberOfEntitiesError { expected: to_load.len(), @@ -183,12 +250,15 @@ pub fn load_entities( } let mut next_to_load = vec![]; - for (entity_request, loaded) in to_load.drain(..).zip(new_entities) { - next_to_load.extend(find_remaining_entities( - &loaded, - entity_request.access_trie, - )?); - entities.insert(entity_request.entity_id, loaded); + for (entity_request, loaded_maybe) in to_load.drain(..).zip(new_entities) { + if let Some(loaded) = loaded_maybe { + next_to_load.extend(find_remaining_entities( + &loaded, + entity_request.access_trie, + &mut Default::default(), + )?); + entities.insert(entity_request.entity_id, loaded); + } } to_load = next_to_load; @@ -202,42 +272,42 @@ pub fn load_entities( entity_id, ancestors_trie, &entities, - &context, + context, request, )?); } - let loaded_ancestors = loader.load_ancestors(&ancestors_requests); + let loaded_ancestors = loader.load_ancestors(&ancestors_requests)?; for (request, ancestors) in ancestors_requests.into_iter().zip(loaded_ancestors) { - // PANIC SAFETY: ancestor requests are only created for entities already loaded in the entities map - #[allow(clippy::unwrap_used)] - entities - .get_mut(&request.entity_id) - .unwrap() - .add_ancestors(ancestors); + if let Some(entity) = entities.get_mut(&request.entity_id) { + entity.add_ancestors(ancestors); + } } // finally, convert the loaded entities into a Cedar Entities store - match Entities::from_entities( entities.values().cloned(), None::<&NoEntitiesSchema>, TCComputation::AssumeAlreadyComputed, Extensions::all_available(), ) { - Ok(entities) => Ok((context, entities)), + Ok(entities) => Ok(entities), Err(e) => Err(e.into()), } } +/// Given a context value and an access trie, find all of the remaining +/// entities in the context. +/// Also keep track of required ancestors when encountering the `is_ancestor` flag. fn find_remaining_entities_context<'a>( context_value: &Arc>, fields: &'a AccessTrie, -) -> Result>, EntitySliceError> { + required_ancestors: &mut HashSet, +) -> Result>, EntitySliceError> { let mut remaining = vec![]; for (field, slice) in &fields.children { if let Some(value) = context_value.get(field) { - find_remaining_entities_value(&mut remaining, value, slice)?; + find_remaining_entities_value(&mut remaining, value, slice, required_ancestors)?; } // the attribute may not be present, since the schema can define // attributes that are optional @@ -248,17 +318,27 @@ fn find_remaining_entities_context<'a>( /// This helper function finds all entity references that need to be /// loaded given an already-loaded [`Entity`] and corresponding [`Fields`]. /// Returns pairs of entity and slices that need to be loaded. +/// Also, finds ancestors that are required whenever the `is_ancestor` +/// flag is found on a node. fn find_remaining_entities<'a>( entity: &Entity, fields: &'a AccessTrie, -) -> Result>, EntitySliceError> { + required_ancestors: &mut HashSet, +) -> Result>, EntitySliceError> { + // first, check if we need to add to `required_ancestors` + // most cases are handled by `find_remaining_entities_value`, but + // cedar variables require this logic + if fields.is_ancestor { + required_ancestors.insert(entity.uid().clone()); + } + let mut remaining = vec![]; for (field, slice) in &fields.children { if let Some(pvalue) = entity.get(field) { let PartialValue::Value(value) = pvalue else { return Err(PartialEntityError {}.into()); }; - find_remaining_entities_value(&mut remaining, value, slice)?; + find_remaining_entities_value(&mut remaining, value, slice, required_ancestors)?; } // the attribute may not be present, since the schema can define // attributes that are optional @@ -268,26 +348,74 @@ fn find_remaining_entities<'a>( } fn find_remaining_entities_value<'a>( - remaining: &mut Vec>, + remaining: &mut Vec>, value: &Value, trie: &'a AccessTrie, + required_ancestors: &mut HashSet, ) -> Result<(), EntitySliceError> { + // unless this is an entity id, ancestors should not be required + assert!( + trie.ancestors_trie == Default::default() + || matches!(value.value_kind(), ValueKind::Lit(Literal::EntityUID(_))) + ); + + // unless this is an entity id or set, it should not be an + // ancestor + assert!( + !trie.is_ancestor + || matches!( + value.value_kind(), + ValueKind::Lit(Literal::EntityUID(_)) | ValueKind::Set(_) + ) + ); + match value.value_kind() { ValueKind::Lit(literal) => { if let Literal::EntityUID(entity_id) = literal { - remaining.push(EntityRequest { + // when ancestors are required, add this to the set + if trie.is_ancestor { + required_ancestors.insert((**entity_id).clone()); + } + + remaining.push(EntityRequestRef { entity_id: (**entity_id).clone(), access_trie: trie, }); } } - ValueKind::Set(_) => (), + ValueKind::Set(set) => { + // when ancestors are required, request all of them + // when this is an ancestor, request all of the entities + // in this set + if trie.is_ancestor { + for val in set.iter() { + match val.value_kind() { + ValueKind::Lit(Literal::EntityUID(id)) => { + required_ancestors.insert((**id).clone()); + } + // PANIC SAFETY: see above panic- set must contain entities + #[allow(clippy::panic)] + _ => { + panic!( + "Found is_ancestor on set of non-entity-type {}", + val.value_kind() + ); + } + } + } + } + } ValueKind::ExtensionValue(_) => (), ValueKind::Record(record) => { for (field, child_slice) in &trie.children { // only need to slice if field is present if let Some(value) = record.get(field) { - find_remaining_entities_value(remaining, value, child_slice)?; + find_remaining_entities_value( + remaining, + value, + child_slice, + required_ancestors, + )?; } } } @@ -309,7 +437,7 @@ fn compute_ancestors_request( // is_ancestor tags. let mut ancestors = HashSet::new(); - let mut to_visit = initial_entities_to_load(ancestors_trie, context, request)?; + let mut to_visit = initial_entities_to_load(ancestors_trie, context, request, &mut ancestors)?; while !to_visit.is_empty() { let mut next_to_visit = vec![]; @@ -317,8 +445,13 @@ fn compute_ancestors_request( if entity_request.access_trie.is_ancestor { ancestors.insert(entity_request.entity_id.clone()); } + if let Some(entity) = entities.get(&entity_request.entity_id) { - next_to_visit.extend(find_remaining_entities(entity, entity_request.access_trie)?); + next_to_visit.extend(find_remaining_entities( + entity, + entity_request.access_trie, + &mut ancestors, + )?); } } to_visit = next_to_visit; diff --git a/cedar-policy-validator/src/entity_slicing.rs b/cedar-policy-validator/src/entity_slicing.rs index bbbbdc984..e7160a4e5 100644 --- a/cedar-policy-validator/src/entity_slicing.rs +++ b/cedar-policy-validator/src/entity_slicing.rs @@ -4,19 +4,19 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::Display; use cedar_policy_core::entities::err::EntitiesError; -use cedar_policy_core::entities::{Dereference, NoEntitiesSchema, TCComputation}; -use cedar_policy_core::extensions::Extensions; +use cedar_policy_core::entities::Dereference; use cedar_policy_core::{ - ast::{Entity, EntityUID, Literal, PartialValue, Request, Value, ValueKind, Var}, + ast::{Entity, EntityUID, Literal, PartialValue, Request, Value, ValueKind}, entities::Entities, }; use miette::Diagnostic; use smol_str::SmolStr; use thiserror::Error; -use crate::entity_manifest::{ - AccessTrie, EntityManifest, EntityRoot, PartialRequestError, RootAccessTrie, +use crate::entity_loader::{ + load_entities, AncestorsRequest, EntityAnswer, EntityLoader, EntityRequest, }; +use crate::entity_manifest::{AccessTrie, EntityManifest, PartialRequestError}; /// Error when expressions are partial during entity /// slicing. @@ -118,6 +118,7 @@ pub enum EntitySliceError { #[error(transparent)] PartialEntity(#[from] PartialEntityError), + /// The entity loader returned a partial context. #[error(transparent)] PartialContext(#[from] PartialContextError), @@ -134,99 +135,64 @@ impl EntityManifest { entities: &Entities, request: &Request, ) -> Result { - let request_type = request.to_request_type().ok_or(PartialRequestError {})?; - self.per_action - .get(&request_type) - .map(|primary| primary.slice_entities(entities, request)) - .unwrap_or(Ok(Entities::default())) + let mut slicer = EntitySlicer { entities }; + load_entities(self, request, &mut slicer) } } -impl RootAccessTrie { - /// Given entities and a request, return a new entitity store - /// which is a slice of the old one. - fn slice_entities( - &self, - entities: &Entities, - request: &Request, - ) -> Result { - self.slice_entities_internal(entities, request) - .map(|res| res.0) +struct EntitySlicer<'a> { + entities: &'a Entities, +} + +impl<'a> EntityLoader for EntitySlicer<'a> { + fn load_entities( + &mut self, + to_load: &[EntityRequest], + ) -> Result, EntitySliceError> { + let mut res = vec![]; + for request in to_load { + if let Dereference::Data(entity) = self.entities.entity(&request.entity_id) { + // filter down the entity fields to those requested + res.push(Some(request.access_trie.slice_entity(entity)?)); + } else { + res.push(None); + } + } + + Ok(res) } - /// Returns a new entity store and also the ancestor entities found - /// along the way. - fn slice_entities_internal( - &self, - entities: &Entities, - request: &Request, - ) -> Result<(Entities, HashSet), EntitySliceError> { - let mut res = HashMap::::new(); - let mut ancestors = HashSet::new(); - for (root, slice) in &self.trie { - match root { - EntityRoot::Literal(lit) => { - slice.slice_entity(entities, request, lit, &mut res, &mut ancestors)?; - } - EntityRoot::Var(Var::Action) => { - let entity_id = request.action().uid().ok_or(PartialRequestError {})?; - slice.slice_entity(entities, request, entity_id, &mut res, &mut ancestors)?; - } - EntityRoot::Var(Var::Principal) => { - let entity_id = request.principal().uid().ok_or(PartialRequestError {})?; - slice.slice_entity(entities, request, entity_id, &mut res, &mut ancestors)?; - } - EntityRoot::Var(Var::Resource) => { - let resource_id = request.resource().uid().ok_or(PartialRequestError {})?; - slice.slice_entity(entities, request, resource_id, &mut res, &mut ancestors)?; - } - EntityRoot::Var(Var::Context) => { - if slice.children.is_empty() { - // no data loading needed - } else { - let partial_val: PartialValue = PartialValue::from( - request.context().ok_or(PartialRequestError {})?.clone(), - ); - let PartialValue::Value(val) = partial_val else { - return Err(PartialRequestError {}.into()); - }; - slice.slice_val(entities, request, &val, &mut res, &mut ancestors)?; + fn load_ancestors( + &mut self, + entities: &[AncestorsRequest], + ) -> Result>, EntitySliceError> { + let mut res = vec![]; + + for request in entities { + if let Dereference::Data(entity) = self.entities.entity(&request.entity_id) { + let mut ancestors = HashSet::new(); + + for required_ancestor in &request.ancestors { + if entity.is_descendant_of(required_ancestor) { + ancestors.insert(required_ancestor.clone()); } } + + res.push(ancestors); + } else { + // if the entity isn't there, we don't need any ancestors + res.push(HashSet::new()); } } - Ok(( - Entities::from_entities( - res.into_values(), - None::<&NoEntitiesSchema>, - TCComputation::AssumeAlreadyComputed, - Extensions::all_available(), - )?, - ancestors, - )) + + Ok(res) } } impl AccessTrie { /// Given an entities store, an entity id, and a resulting store /// Slice the entities and put them in the resulting store. - fn slice_entity( - &self, - entities: &Entities, - request: &Request, - lit: &EntityUID, - res: &mut HashMap, - res_ancestors: &mut HashSet, - ) -> Result<(), EntitySliceError> { - // add to the res_ancestors set if this is a relavent ancestor - if self.is_ancestor { - res_ancestors.insert(lit.clone()); - } - - // If the entity is not present, no need to slice - let Dereference::Data(entity) = entities.entity(lit) else { - return Ok(()); - }; + fn slice_entity(&self, entity: &Entity) -> Result { let mut new_entity = HashMap::::new(); for (field, slice) in &self.children { // only slice when field is available @@ -234,68 +200,24 @@ impl AccessTrie { let PartialValue::Value(val) = pval else { return Err(PartialEntityError {}.into()); }; - let sliced = slice.slice_val(entities, request, &val, res, res_ancestors)?; + let sliced = slice.slice_val(&val)?; new_entity.insert(field.clone(), PartialValue::Value(sliced)); } } - let new_ancestors = if self.ancestors_trie != Default::default() { - let relavent_ancestors = self - .ancestors_trie - .slice_entities_internal(entities, request)? - .1; - relavent_ancestors - .into_iter() - .filter(|ancestor| entity.is_descendant_of(ancestor)) - .collect() - } else { - HashSet::new() - }; - - let new_entity = - Entity::new_with_attr_partial_value(lit.clone(), new_entity, new_ancestors); - - // PANIC SAFETY: Entities in the entity store with the same ID should be compatible to union together. - #[allow(clippy::expect_used)] - if let Some(existing) = res.get_mut(lit) { - // Here we union the new entity with any existing one - *existing = existing - .union(&new_entity) - .expect("Incompatible values found in entity store"); - } else { - res.insert(lit.clone(), new_entity); - } - Ok(()) + Ok(Entity::new_with_attr_partial_value( + entity.uid().clone(), + new_entity, + Default::default(), + )) } - fn slice_val( - &self, - entities: &Entities, - request: &Request, - val: &Value, - res: &mut HashMap, - res_ancestors: &mut HashSet, - ) -> Result { - // unless this is an entity id, parents should not be required - assert!( - self.ancestors_trie == Default::default() - || matches!(val.value_kind(), ValueKind::Lit(Literal::EntityUID(_))) - ); - - // unless this is an entity id or set, it should not be an - // ancestor - assert!( - !self.is_ancestor - || matches!( - val.value_kind(), - ValueKind::Lit(Literal::EntityUID(_)) | ValueKind::Set(_) - ) - ); - + fn slice_val(&self, val: &Value) -> Result { Ok(match val.value_kind() { - ValueKind::Lit(Literal::EntityUID(id)) => { - self.slice_entity(entities, request, id, res, res_ancestors)?; + ValueKind::Lit(Literal::EntityUID(_)) => { + // entities shouldn't need to be dereferenced + assert!(self.children.is_empty()); val.clone() } ValueKind::Set(_) | ValueKind::ExtensionValue(_) | ValueKind::Lit(_) => { @@ -306,32 +228,6 @@ impl AccessTrie { .into()); } - // when this is an ancestor, request all of the entities - // in this set - if self.is_ancestor { - // PANIC SAFETY: is_ancestor is only called on the rhs of an `is`, which the typechecker ensures is an entity or set of entity type. - #[allow(clippy::panic)] - let ValueKind::Set(set) = val.value_kind() else { - panic!("Found is_ancestor on non-entity type {}", val.value_kind()) - }; - - for val in set.iter() { - match val.value_kind() { - ValueKind::Lit(Literal::EntityUID(id)) => { - res_ancestors.insert((**id).clone()); - } - // PANIC SAFETY: see above panic- set must contain entities - #[allow(clippy::panic)] - _ => { - panic!( - "Found is_ancestor on set of non-entity-type {}", - val.value_kind() - ); - } - } - } - } - val.clone() } ValueKind::Record(record) => { @@ -339,10 +235,7 @@ impl AccessTrie { for (field, slice) in &self.children { // only slice when field is available if let Some(v) = record.get(field) { - new_map.insert( - field.clone(), - slice.slice_val(entities, request, v, res, res_ancestors)?, - ); + new_map.insert(field.clone(), slice.slice_val(v)?); } } @@ -356,7 +249,8 @@ impl AccessTrie { mod entity_slice_tests { use cedar_policy_core::{ ast::{Context, PolicyID, PolicySet}, - entities::EntityJsonParser, + entities::{EntityJsonParser, TCComputation}, + extensions::Extensions, parser::parse_policy, }; diff --git a/cedar-policy-validator/src/types.rs b/cedar-policy-validator/src/types.rs index b9d535790..a733fd6a9 100644 --- a/cedar-policy-validator/src/types.rs +++ b/cedar-policy-validator/src/types.rs @@ -640,6 +640,16 @@ impl Type { }, } } + + /// Returns `true` when the type is a type of an entity + pub(crate) fn is_entity_type(&self) -> bool { + match self { + Type::EntityOrRecord(EntityRecordKind::Entity(_)) => true, + Type::EntityOrRecord(EntityRecordKind::AnyEntity) => true, + Type::EntityOrRecord(EntityRecordKind::ActionEntity { .. }) => true, + _ => false, + } + } } impl Display for Type { From 9b018bf4098e577760aaa1c64d02b28482b7bebb Mon Sep 17 00:00:00 2001 From: oflatt Date: Mon, 16 Sep 2024 12:33:57 -0700 Subject: [PATCH 4/5] clean up ancestor loading somewhat Signed-off-by: oflatt --- cedar-policy-validator/src/entity_loader.rs | 23 +++++++-------------- cedar-policy-validator/src/types.rs | 13 ++++++------ 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/cedar-policy-validator/src/entity_loader.rs b/cedar-policy-validator/src/entity_loader.rs index f9fbb1885..87a3bcdb4 100644 --- a/cedar-policy-validator/src/entity_loader.rs +++ b/cedar-policy-validator/src/entity_loader.rs @@ -318,20 +318,12 @@ fn find_remaining_entities_context<'a>( /// This helper function finds all entity references that need to be /// loaded given an already-loaded [`Entity`] and corresponding [`Fields`]. /// Returns pairs of entity and slices that need to be loaded. -/// Also, finds ancestors that are required whenever the `is_ancestor` -/// flag is found on a node. +/// Also, any sets marked `is_ancestor` are added to the `required_ancestors` set. fn find_remaining_entities<'a>( entity: &Entity, fields: &'a AccessTrie, required_ancestors: &mut HashSet, ) -> Result>, EntitySliceError> { - // first, check if we need to add to `required_ancestors` - // most cases are handled by `find_remaining_entities_value`, but - // cedar variables require this logic - if fields.is_ancestor { - required_ancestors.insert(entity.uid().clone()); - } - let mut remaining = vec![]; for (field, slice) in &fields.children { if let Some(pvalue) = entity.get(field) { @@ -347,6 +339,8 @@ fn find_remaining_entities<'a>( Ok(remaining) } +/// Like `find_remaining_entities`, but for values. +/// Any sets that are marked `is_ancestor` are added to the `required_ancestors` set. fn find_remaining_entities_value<'a>( remaining: &mut Vec>, value: &Value, @@ -372,10 +366,8 @@ fn find_remaining_entities_value<'a>( match value.value_kind() { ValueKind::Lit(literal) => { if let Literal::EntityUID(entity_id) = literal { - // when ancestors are required, add this to the set - if trie.is_ancestor { - required_ancestors.insert((**entity_id).clone()); - } + // no need to add to ancestors set here because + // we are creating an entity request. remaining.push(EntityRequestRef { entity_id: (**entity_id).clone(), @@ -384,7 +376,6 @@ fn find_remaining_entities_value<'a>( } } ValueKind::Set(set) => { - // when ancestors are required, request all of them // when this is an ancestor, request all of the entities // in this set if trie.is_ancestor { @@ -393,7 +384,7 @@ fn find_remaining_entities_value<'a>( ValueKind::Lit(Literal::EntityUID(id)) => { required_ancestors.insert((**id).clone()); } - // PANIC SAFETY: see above panic- set must contain entities + // PANIC SAFETY: see assert above- ancestor annotation is only valid on sets of entities or entities #[allow(clippy::panic)] _ => { panic!( @@ -442,6 +433,8 @@ fn compute_ancestors_request( while !to_visit.is_empty() { let mut next_to_visit = vec![]; for entity_request in to_visit.drain(..) { + // check the is_ancestor flag for entities + // the is_ancestor flag on sets of entities is handled by find_remaining_entities if entity_request.access_trie.is_ancestor { ancestors.insert(entity_request.entity_id.clone()); } diff --git a/cedar-policy-validator/src/types.rs b/cedar-policy-validator/src/types.rs index a733fd6a9..df2163865 100644 --- a/cedar-policy-validator/src/types.rs +++ b/cedar-policy-validator/src/types.rs @@ -642,13 +642,14 @@ impl Type { } /// Returns `true` when the type is a type of an entity + #[cfg(feature = "entity-manifest")] pub(crate) fn is_entity_type(&self) -> bool { - match self { - Type::EntityOrRecord(EntityRecordKind::Entity(_)) => true, - Type::EntityOrRecord(EntityRecordKind::AnyEntity) => true, - Type::EntityOrRecord(EntityRecordKind::ActionEntity { .. }) => true, - _ => false, - } + matches!( + self, + Type::EntityOrRecord(EntityRecordKind::Entity(_)) + | Type::EntityOrRecord(EntityRecordKind::AnyEntity) + | Type::EntityOrRecord(EntityRecordKind::ActionEntity { .. }) + ) } } From b3804884294a7a6126e8e56cd5e971538f0cc931 Mon Sep 17 00:00:00 2001 From: oflatt Date: Mon, 16 Sep 2024 12:36:02 -0700 Subject: [PATCH 5/5] database consistency warning Signed-off-by: oflatt --- cedar-policy-validator/src/entity_loader.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cedar-policy-validator/src/entity_loader.rs b/cedar-policy-validator/src/entity_loader.rs index 87a3bcdb4..d56319ca7 100644 --- a/cedar-policy-validator/src/entity_loader.rs +++ b/cedar-policy-validator/src/entity_loader.rs @@ -83,6 +83,9 @@ pub(crate) struct AncestorsRequest { /// into a Cedar [`Entities`] store. /// The most basic implementation loads full entities (including all ancestors) in the `load_entities` method and loads the context in the `load_context` method. /// More advanced implementations make use of the [`AccessTrie`]s provided to load partial entities and context, as well as the `load_ancestors` method to load particular ancestors. +/// +/// Warning: `load_entities` is called multiple times. If database +/// consistency is required, this API should not be used. Instead, use the entity manifest directly. pub(crate) trait EntityLoader { /// `load_entities` is called multiple times to load entities based on their ids. /// For each entity request in the `to_load` vector, expects one loaded entity in the resulting vector.