Skip to content

Commit

Permalink
add schema check
Browse files Browse the repository at this point in the history
  • Loading branch information
OussamaSaoudi-db committed Nov 26, 2024
1 parent c00157e commit 2348f7e
Showing 1 changed file with 68 additions and 34 deletions.
102 changes: 68 additions & 34 deletions kernel/src/table_changes/log_replay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub(crate) fn table_changes_action_iter(
commit_files: impl IntoIterator<Item = ParsedLogPath>,
table_schema: &SchemaRef,
predicate: Option<ExpressionRef>,
schema: SchemaRef,
) -> DeltaResult<impl Iterator<Item = DeltaResult<TableChangesScanData>>> {
let json_client = engine.get_json_handler();
let filter = DataSkippingFilter::new(engine, table_schema, predicate);
Expand All @@ -55,6 +56,7 @@ pub(crate) fn table_changes_action_iter(
json_client.clone(),
filter.clone(),
expression_evaluator.clone(),
schema.clone(),
);
scanner.visit_commit()?;
scanner.resolve_dvs();
Expand Down Expand Up @@ -87,6 +89,7 @@ struct LogReplayScanner {
filter: Option<DataSkippingFilter>,
expression_evaluator: Arc<dyn ExpressionEvaluator>,
add_paths: HashSet<String>,
schema: SchemaRef,
}

impl LogReplayScanner {
Expand All @@ -95,6 +98,7 @@ impl LogReplayScanner {
json_client: Arc<dyn JsonHandler>,
filter: Option<DataSkippingFilter>,
expression_evaluator: Arc<dyn ExpressionEvaluator>,
schema: SchemaRef,
) -> Self {
Self {
timestamp: commit_file.location.last_modified,
Expand All @@ -105,6 +109,7 @@ impl LogReplayScanner {
remove_dvs: Default::default(),
expression_evaluator,
add_paths: Default::default(),
schema,
}
}
fn visit_commit(&mut self) -> DeltaResult<()> {
Expand All @@ -116,19 +121,15 @@ impl LogReplayScanner {
for actions in action_iter {
let actions = actions?;

//let actions_arrow: &ArrowEngineData =
// actions.as_ref().any_ref().downcast_ref().unwrap();
//println!("actions: {:?}", actions_arrow.record_batch());
// apply data skipping to get back a selection vector for actions that passed skipping
// note: None implies all files passed data skipping.
// Apply data skipping to get back a selection vector for actions that passed skipping
let filter_vector = self
.filter
.as_ref()
.map(|filter| filter.apply(actions.as_ref()))
.transpose()?;

// we start our selection vector based on what was filtered. we will add to this vector
// below if a file has been removed
// We start our selection vector based on what was filtered. We will add to this vector
// below if a file has been removed. Note: None implies all files passed data skipping.
let selection_vector = match filter_vector {
Some(ref filter_vector) => filter_vector.clone(),
None => vec![true; actions.len()],
Expand All @@ -139,7 +140,10 @@ impl LogReplayScanner {
protocol.ensure_read_supported()?;
}
if let Some(schema) = visitor.schema {
// TODO: Ensure schema is compatible
require!(
self.schema.as_ref() == &schema,
Error::generic("Got unexpected schma")
);
}
}
Ok(())
Expand Down Expand Up @@ -167,6 +171,7 @@ impl LogReplayScanner {
filter,
expression_evaluator,
add_paths: _,
schema: _,
} = self;
let remove_dvs = Arc::new(remove_dvs);

Expand Down Expand Up @@ -207,7 +212,7 @@ struct Phase1Visitor<'a> {
scanner: &'a mut LogReplayScanner,
selection_vector: Vec<bool>,
protocol: Option<Protocol>,
schema: Option<String>,
schema: Option<StructType>,
}
impl<'a> Phase1Visitor<'a> {
fn new(scanner: &'a mut LogReplayScanner, selection_vector: Vec<bool>) -> Self {
Expand Down Expand Up @@ -295,8 +300,7 @@ impl<'a> RowVisitor for Phase1Visitor<'a> {
} else if let Some(timestamp) = getters[8].get_long(i, "commitInfo.timestamp")? {
self.scanner.timestamp = timestamp;
} else if let Some(schema) = getters[9].get_str(i, "metaData.schemaString")? {
self.schema = Some(schema.to_string());
// TODO: Validate that the schema is as expected
self.schema = Some(serde_json::from_str(schema)?);
} else if let Some(min_reader_version) =
getters[10].get_int(i, "protocol.min_reader_version")?
{
Expand Down Expand Up @@ -389,7 +393,7 @@ mod tests {

use crate::actions::deletion_vector::DeletionVectorDescriptor;
use crate::scan::state::DvInfo;
use crate::schema::{DataType, StructField, StructType};
use crate::schema::{self, DataType, StructField, StructType};
use itertools::Itertools;
use object_store::local::LocalFileSystem;
use object_store::ObjectStore;
Expand All @@ -400,7 +404,9 @@ mod tests {
use tempfile::TempDir;

use super::{get_add_transform_expr, LogReplayScanner, TableChangesScanData};
use crate::actions::{get_log_add_schema, Add, Cdc, CommitInfo, Metadata, Protocol, Remove};
use crate::actions::{
get_log_add_schema, get_log_schema, Add, Cdc, CommitInfo, Metadata, Protocol, Remove,
};
use crate::engine::sync::SyncEngine;
use crate::log_segment::LogSegment;
use crate::path::ParsedLogPath;
Expand Down Expand Up @@ -492,26 +498,11 @@ mod tests {
self.dir.path()
}
}
fn get_init_commit() -> Vec<Action> {
let schema = StructType::new([
fn get_schema() -> StructType {
StructType::new([
StructField::new("id", DataType::LONG, true),
StructField::new("value", DataType::STRING, true),
]);
let schema_string = serde_json::to_string(&schema).unwrap();
vec![
Metadata {
schema_string,
configuration: HashMap::from([
("enableChangeDataFeed".to_string(), "true".to_string()),
("enableDeletionVectors".to_string(), "true".to_string()),
]),
..Default::default()
}
.into(),
Protocol::try_new(3, 7, Some(["deletionVectors"]), Some(["deletionVectors"]))
.unwrap()
.into(),
]
])
}

fn get_segment(
Expand Down Expand Up @@ -542,6 +533,7 @@ mod tests {
engine.get_json_handler(),
None,
expression_evaluator.clone(),
get_schema().into(),
)
}
fn result_to_sv(iter: impl Iterator<Item = DeltaResult<TableChangesScanData>>) -> Vec<bool> {
Expand All @@ -555,7 +547,23 @@ mod tests {
async fn metadata_protocol() {
let engine = SyncEngine::new();
let mut mock_table = MockTable::new();
mock_table.commit(&get_init_commit()).await;
let schema_string = serde_json::to_string(&get_schema()).unwrap();
mock_table
.commit(&[
Metadata {
schema_string,
configuration: HashMap::from([
("enableChangeDataFeed".to_string(), "true".to_string()),
("enableDeletionVectors".to_string(), "true".to_string()),
]),
..Default::default()
}
.into(),
Protocol::try_new(3, 7, Some(["deletionVectors"]), Some(["deletionVectors"]))
.unwrap()
.into(),
])
.await;

let mut commits = get_segment(&engine, mock_table.table_root(), 0, None)
.unwrap()
Expand All @@ -576,6 +584,33 @@ mod tests {
);
}

#[tokio::test]
async fn incompatible_schema() {
let engine = SyncEngine::new();
let mut mock_table = MockTable::new();
let schema = get_schema().project(&["id"]).unwrap();
let schema_string = serde_json::to_string(&schema).unwrap();
mock_table
.commit(&[Metadata {
schema_string,
configuration: HashMap::from([
("enableChangeDataFeed".to_string(), "true".to_string()),
("enableDeletionVectors".to_string(), "true".to_string()),
]),
..Default::default()
}
.into()])
.await;

let mut commits = get_segment(&engine, mock_table.table_root(), 0, None)
.unwrap()
.into_iter();

let mut scanner = get_commit_log_scanner(&engine, commits.next().unwrap());

assert!(scanner.visit_commit().is_err());
}

#[tokio::test]
async fn table_changes_add_remove() {
let engine = SyncEngine::new();
Expand Down Expand Up @@ -631,7 +666,6 @@ mod tests {
async fn table_changes_cdc() {
let engine = SyncEngine::new();
let mut mock_table = MockTable::new();
mock_table.commit(&get_init_commit()).await;
mock_table
.commit(&[
Add {
Expand All @@ -656,7 +690,7 @@ mod tests {
.unwrap()
.into_iter();

let commit = commits.nth(1).unwrap();
let commit = commits.next().unwrap();
let mut scanner = get_commit_log_scanner(&engine, commit);

scanner.visit_commit().unwrap();
Expand Down

0 comments on commit 2348f7e

Please sign in to comment.