From 61092d2d67e0246be4b86f268c372daef053db02 Mon Sep 17 00:00:00 2001 From: Alexander Falk Date: Wed, 26 Feb 2025 10:47:43 +0100 Subject: [PATCH] refactor: replaced asterisk with constraint name in get_constraints for table_config and added as_any to DeltaCheck to allow type checking in enforce_checks Signed-off-by: Alexander Falk --- crates/core/src/delta_datafusion/mod.rs | 60 ++++++++++++++++++++++- crates/core/src/kernel/mod.rs | 4 +- crates/core/src/kernel/models/schema.rs | 5 ++ crates/core/src/operations/constraints.rs | 26 ++++++++++ crates/core/src/table/config.rs | 3 +- crates/core/src/table/mod.rs | 9 ++++ 6 files changed, 104 insertions(+), 3 deletions(-) diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index e72112c230..e2fdcbdada 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -1415,9 +1415,14 @@ impl DeltaDataChecker { )); } + let field_to_select = if check.as_any().is::() { + "*" + } else { + check.get_name() + }; let sql = format!( "SELECT {} FROM `{table_name}` WHERE NOT ({}) LIMIT 1", - check.get_name(), + field_to_select, check.get_expression() ); @@ -2160,6 +2165,59 @@ mod tests { assert!(matches!(result, Err(DeltaTableError::Generic { .. }))); } + #[tokio::test] + async fn test_enforce_constraints() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", ArrowDataType::Utf8, false), + Field::new("b", ArrowDataType::Int32, false), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["a", "b", "c", "d"])), + Arc::new(arrow::array::Int32Array::from(vec![1, 10, 10, 100])), + ], + ) + .unwrap(); + // Empty constraints is okay + let constraints: Vec = vec![]; + assert!(DeltaDataChecker::new_with_constraints(constraints) + .check_batch(&batch) + .await + .is_ok()); + + // Valid invariants return Ok(()) + let constraints = vec![ + Constraint::new("custom_a", "a is not null"), + Constraint::new("custom_b", "b < 1000"), + ]; + assert!(DeltaDataChecker::new_with_constraints(constraints) + .check_batch(&batch) + .await + .is_ok()); + + // Violated invariants returns an error with list of violations + let constraints = vec![ + Constraint::new("custom_a", "a is null"), + Constraint::new("custom_B", "b < 100"), + ]; + let result = DeltaDataChecker::new_with_constraints(constraints) + .check_batch(&batch) + .await; + assert!(result.is_err()); + assert!(matches!(result, Err(DeltaTableError::InvalidData { .. }))); + if let Err(DeltaTableError::InvalidData { violations }) = result { + assert_eq!(violations.len(), 2); + } + + // Irrelevant constraints return a different error + let constraints = vec![Constraint::new("custom_c", "c > 2000")]; + let result = DeltaDataChecker::new_with_constraints(constraints) + .check_batch(&batch) + .await; + assert!(result.is_err()); + } + #[test] fn roundtrip_test_delta_exec_plan() { let ctx = SessionContext::new(); diff --git a/crates/core/src/kernel/mod.rs b/crates/core/src/kernel/mod.rs index b2fcd71634..44a09d7745 100644 --- a/crates/core/src/kernel/mod.rs +++ b/crates/core/src/kernel/mod.rs @@ -3,7 +3,7 @@ //! The Kernel module contains all the logic for reading and processing the Delta Lake transaction log. use delta_kernel::engine::arrow_expression::ArrowExpressionHandler; -use std::sync::LazyLock; +use std::{any::Any, sync::LazyLock}; pub mod arrow; pub mod error; @@ -21,6 +21,8 @@ pub trait DataCheck { fn get_name(&self) -> &str; /// The SQL expression to use for the check fn get_expression(&self) -> &str; + + fn as_any(&self) -> &dyn Any; } static ARROW_HANDLER: LazyLock = diff --git a/crates/core/src/kernel/models/schema.rs b/crates/core/src/kernel/models/schema.rs index 862f629d6c..947b6794d1 100644 --- a/crates/core/src/kernel/models/schema.rs +++ b/crates/core/src/kernel/models/schema.rs @@ -1,5 +1,6 @@ //! Delta table schema +use std::any::Any; use std::sync::Arc; pub use delta_kernel::schema::{ @@ -44,6 +45,10 @@ impl DataCheck for Invariant { fn get_expression(&self) -> &str { &self.invariant_sql } + + fn as_any(&self) -> &dyn Any { + self + } } /// Trait to add convenience functions to struct type diff --git a/crates/core/src/operations/constraints.rs b/crates/core/src/operations/constraints.rs index 70e6ecc98e..7be5205574 100644 --- a/crates/core/src/operations/constraints.rs +++ b/crates/core/src/operations/constraints.rs @@ -272,6 +272,32 @@ mod tests { .to_owned() } + #[tokio::test] + async fn test_get_constraints_with_correct_names() -> DeltaResult<()> { + // The key of a constraint is allowed to be custom + // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#check-constraints + let batch = get_record_batch(None, false); + let write = DeltaOps(create_bare_table()) + .write(vec![batch.clone()]) + .await?; + let table = DeltaOps(write); + + let constraint = table + .add_constraint() + .with_constraint("my_custom_constraint", "value < 100") + .await; + assert!(constraint.is_ok()); + let constraints = constraint + .unwrap() + .state + .unwrap() + .table_config() + .get_constraints(); + assert!(constraints.len() == 1); + assert_eq!(constraints[0].name, "my_custom_constraint"); + Ok(()) + } + #[tokio::test] async fn add_constraint_with_invalid_data() -> DeltaResult<()> { let batch = get_record_batch(None, false); diff --git a/crates/core/src/table/config.rs b/crates/core/src/table/config.rs index 0557cf285b..857352f478 100644 --- a/crates/core/src/table/config.rs +++ b/crates/core/src/table/config.rs @@ -363,7 +363,8 @@ impl TableConfig<'_> { .iter() .filter_map(|(field, value)| { if field.starts_with("delta.constraints") { - value.as_ref().map(|f| Constraint::new("*", f)) + let constraint_name = field.replace("delta.constraints.", ""); + value.as_ref().map(|f| Constraint::new(&constraint_name, f)) } else { None } diff --git a/crates/core/src/table/mod.rs b/crates/core/src/table/mod.rs index 8cc76d16a8..cbeccf893b 100644 --- a/crates/core/src/table/mod.rs +++ b/crates/core/src/table/mod.rs @@ -1,5 +1,6 @@ //! Delta Table read and write implementation +use std::any::Any; use std::cmp::{min, Ordering}; use std::collections::HashMap; use std::fmt; @@ -155,6 +156,10 @@ impl DataCheck for Constraint { fn get_expression(&self) -> &str { &self.expr } + + fn as_any(&self) -> &dyn Any { + self + } } /// A generated column @@ -195,6 +200,10 @@ impl DataCheck for GeneratedColumn { fn get_expression(&self) -> &str { &self.validation_expr } + + fn as_any(&self) -> &dyn Any { + self + } } /// Return partition fields along with their data type from the current schema.