From 4e1f8391d4c2736ed92e5acba16dab51fb053aed Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 13 Nov 2024 22:26:01 +0800 Subject: [PATCH] Introduce `TypePlanner` for customizing type planning (#13294) * introduce `plan_data_type` for ExprPlanner * implement TypePlanner trait instead of extending ExprPlanner * enhance the document --- datafusion/core/src/execution/context/mod.rs | 52 ++++++++++++++++- .../core/src/execution/session_state.rs | 30 +++++++++- datafusion/expr/src/planner.rs | 18 +++++- datafusion/sql/src/planner.rs | 8 +++ datafusion/sql/tests/common/mod.rs | 57 +++++++++++++++++-- datafusion/sql/tests/sql_integration.rs | 56 +++++++++++++++++- 6 files changed, 211 insertions(+), 10 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index a2093c39fc7b..45dfe835880f 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1788,15 +1788,15 @@ impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> { #[cfg(test)] mod tests { - use std::env; - use std::path::PathBuf; - use super::{super::options::CsvReadOptions, *}; use crate::assert_batches_eq; use crate::execution::memory_pool::MemoryConsumer; use crate::execution::runtime_env::RuntimeEnvBuilder; use crate::test; use crate::test_util::{plan_and_collect, populate_csv_partitions}; + use arrow_schema::{DataType, TimeUnit}; + use std::env; + use std::path::PathBuf; use datafusion_common_runtime::SpawnedTask; @@ -1804,6 +1804,8 @@ mod tests { use crate::execution::session_state::SessionStateBuilder; use crate::physical_planner::PhysicalPlanner; use async_trait::async_trait; + use datafusion_expr::planner::TypePlanner; + use sqlparser::ast; use tempfile::TempDir; #[tokio::test] @@ -2200,6 +2202,29 @@ mod tests { Ok(()) } + #[tokio::test] + async fn custom_type_planner() -> Result<()> { + let state = SessionStateBuilder::new() + .with_default_features() + .with_type_planner(Arc::new(MyTypePlanner {})) + .build(); + let ctx = SessionContext::new_with_state(state); + let result = ctx + .sql("SELECT DATETIME '2021-01-01 00:00:00'") + .await? + .collect() + .await?; + let expected = [ + "+-----------------------------+", + "| Utf8(\"2021-01-01 00:00:00\") |", + "+-----------------------------+", + "| 2021-01-01T00:00:00 |", + "+-----------------------------+", + ]; + assert_batches_eq!(expected, &result); + Ok(()) + } + struct MyPhysicalPlanner {} #[async_trait] @@ -2260,4 +2285,25 @@ mod tests { Ok(ctx) } + + #[derive(Debug)] + struct MyTypePlanner {} + + impl TypePlanner for MyTypePlanner { + fn plan_type(&self, sql_type: &ast::DataType) -> Result> { + match sql_type { + ast::DataType::Datetime(precision) => { + let precision = match precision { + Some(0) => TimeUnit::Second, + Some(3) => TimeUnit::Millisecond, + Some(6) => TimeUnit::Microsecond, + None | Some(9) => TimeUnit::Nanosecond, + _ => unreachable!(), + }; + Ok(Some(DataType::Timestamp(precision, None))) + } + _ => Ok(None), + } + } + } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 6172783ab832..9fc081dd5329 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -48,7 +48,7 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::planner::ExprPlanner; +use datafusion_expr::planner::{ExprPlanner, TypePlanner}; use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::var_provider::{is_system_variables, VarType}; @@ -128,6 +128,8 @@ pub struct SessionState { analyzer: Analyzer, /// Provides support for customising the SQL planner, e.g. to add support for custom operators like `->>` or `?` expr_planners: Vec>, + /// Provides support for customising the SQL type planning + type_planner: Option>, /// Responsible for optimizing a logical plan optimizer: Optimizer, /// Responsible for optimizing a physical execution plan @@ -192,6 +194,7 @@ impl Debug for SessionState { .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) .field("expr_planners", &self.expr_planners) + .field("type_planner", &self.type_planner) .field("query_planners", &self.query_planner) .field("analyzer", &self.analyzer) .field("optimizer", &self.optimizer) @@ -955,6 +958,7 @@ pub struct SessionStateBuilder { session_id: Option, analyzer: Option, expr_planners: Option>>, + type_planner: Option>, optimizer: Option, physical_optimizers: Option, query_planner: Option>, @@ -984,6 +988,7 @@ impl SessionStateBuilder { session_id: None, analyzer: None, expr_planners: None, + type_planner: None, optimizer: None, physical_optimizers: None, query_planner: None, @@ -1031,6 +1036,7 @@ impl SessionStateBuilder { session_id: None, analyzer: Some(existing.analyzer), expr_planners: Some(existing.expr_planners), + type_planner: existing.type_planner, optimizer: Some(existing.optimizer), physical_optimizers: Some(existing.physical_optimizers), query_planner: Some(existing.query_planner), @@ -1125,6 +1131,12 @@ impl SessionStateBuilder { self } + /// Set the [`TypePlanner`] used to customize the behavior of the SQL planner. + pub fn with_type_planner(mut self, type_planner: Arc) -> Self { + self.type_planner = Some(type_planner); + self + } + /// Set the [`PhysicalOptimizerRule`]s used to optimize plans. pub fn with_physical_optimizer_rules( mut self, @@ -1318,6 +1330,7 @@ impl SessionStateBuilder { session_id, analyzer, expr_planners, + type_planner, optimizer, physical_optimizers, query_planner, @@ -1346,6 +1359,7 @@ impl SessionStateBuilder { session_id: session_id.unwrap_or(Uuid::new_v4().to_string()), analyzer: analyzer.unwrap_or_default(), expr_planners: expr_planners.unwrap_or_default(), + type_planner, optimizer: optimizer.unwrap_or_default(), physical_optimizers: physical_optimizers.unwrap_or_default(), query_planner: query_planner.unwrap_or(Arc::new(DefaultQueryPlanner {})), @@ -1456,6 +1470,11 @@ impl SessionStateBuilder { &mut self.expr_planners } + /// Returns the current type_planner value + pub fn type_planner(&mut self) -> &mut Option> { + &mut self.type_planner + } + /// Returns the current optimizer value pub fn optimizer(&mut self) -> &mut Option { &mut self.optimizer @@ -1578,6 +1597,7 @@ impl Debug for SessionStateBuilder { .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) .field("expr_planners", &self.expr_planners) + .field("type_planner", &self.type_planner) .field("query_planners", &self.query_planner) .field("analyzer_rules", &self.analyzer_rules) .field("analyzer", &self.analyzer) @@ -1619,6 +1639,14 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { &self.state.expr_planners } + fn get_type_planner(&self) -> Option> { + if let Some(type_planner) = &self.state.type_planner { + Some(Arc::clone(type_planner)) + } else { + None + } + } + fn get_table_source( &self, name: TableReference, diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 7dd7360e478f..42047e8e6caa 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -25,6 +25,7 @@ use datafusion_common::{ config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema, Result, TableReference, }; +use sqlparser::ast; use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF}; @@ -66,6 +67,11 @@ pub trait ContextProvider { &[] } + /// Getter for the data type planner + fn get_type_planner(&self) -> Option> { + None + } + /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description @@ -216,7 +222,7 @@ pub trait ExprPlanner: Debug + Send + Sync { /// custom expressions. #[derive(Debug, Clone)] pub struct RawBinaryExpr { - pub op: sqlparser::ast::BinaryOperator, + pub op: ast::BinaryOperator, pub left: Expr, pub right: Expr, } @@ -249,3 +255,13 @@ pub enum PlannerResult { /// The raw expression could not be planned, and is returned unmodified Original(T), } + +/// This trait allows users to customize the behavior of the data type planning +pub trait TypePlanner: Debug + Send + Sync { + /// Plan SQL type to DataFusion data type + /// + /// Returns None if not possible + fn plan_type(&self, _sql_type: &ast::DataType) -> Result> { + Ok(None) + } +} diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 4d44d5ff2584..ccb2ccf7126f 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -401,6 +401,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { + // First check if any of the registered type_planner can handle this type + if let Some(type_planner) = self.context_provider.get_type_planner() { + if let Some(data_type) = type_planner.plan_type(sql_type)? { + return Ok(data_type); + } + } + + // If no type_planner can handle this type, use the default conversion match sql_type { SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) => { // Arrays may be multi-dimensional. diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index b0fa17031849..63c296dfbc2f 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -18,15 +18,16 @@ use std::any::Any; #[cfg(test)] use std::collections::HashMap; -use std::fmt::Display; +use std::fmt::{Debug, Display}; use std::{sync::Arc, vec}; use arrow_schema::*; use datafusion_common::config::ConfigOptions; use datafusion_common::file_options::file_type::FileType; -use datafusion_common::{plan_err, GetExt, Result, TableReference}; -use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; +use datafusion_common::{plan_err, DFSchema, GetExt, Result, TableReference}; +use datafusion_expr::planner::{ExprPlanner, PlannerResult, TypePlanner}; +use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF}; +use datafusion_functions_nested::expr_fn::make_array; use datafusion_sql::planner::ContextProvider; struct MockCsvType {} @@ -54,6 +55,7 @@ pub(crate) struct MockSessionState { scalar_functions: HashMap>, aggregate_functions: HashMap>, expr_planners: Vec>, + type_planner: Option>, window_functions: HashMap>, pub config_options: ConfigOptions, } @@ -64,6 +66,11 @@ impl MockSessionState { self } + pub fn with_type_planner(mut self, type_planner: Arc) -> Self { + self.type_planner = Some(type_planner); + self + } + pub fn with_scalar_function(mut self, scalar_function: Arc) -> Self { self.scalar_functions .insert(scalar_function.name().to_string(), scalar_function); @@ -259,6 +266,14 @@ impl ContextProvider for MockContextProvider { fn get_expr_planners(&self) -> &[Arc] { &self.state.expr_planners } + + fn get_type_planner(&self) -> Option> { + if let Some(type_planner) = &self.state.type_planner { + Some(Arc::clone(type_planner)) + } else { + None + } + } } struct EmptyTable { @@ -280,3 +295,37 @@ impl TableSource for EmptyTable { Arc::clone(&self.table_schema) } } + +#[derive(Debug)] +pub struct CustomTypePlanner {} + +impl TypePlanner for CustomTypePlanner { + fn plan_type(&self, sql_type: &sqlparser::ast::DataType) -> Result> { + match sql_type { + sqlparser::ast::DataType::Datetime(precision) => { + let precision = match precision { + Some(0) => TimeUnit::Second, + Some(3) => TimeUnit::Millisecond, + Some(6) => TimeUnit::Microsecond, + None | Some(9) => TimeUnit::Nanosecond, + _ => unreachable!(), + }; + Ok(Some(DataType::Timestamp(precision, None))) + } + _ => Ok(None), + } + } +} + +#[derive(Debug)] +pub struct CustomExprPlanner {} + +impl ExprPlanner for CustomExprPlanner { + fn plan_array_literal( + &self, + exprs: Vec, + _schema: &DFSchema, + ) -> Result>> { + Ok(PlannerResult::Planned(make_array(exprs))) + } +} diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index b2f128778a1c..ab7e6c8d0bb7 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -41,13 +41,14 @@ use datafusion_sql::{ planner::{ParserOptions, SqlToRel}, }; -use crate::common::MockSessionState; +use crate::common::{CustomExprPlanner, CustomTypePlanner, MockSessionState}; use datafusion_functions::core::planner::CoreFunctionPlanner; use datafusion_functions_aggregate::{ approx_median::approx_median_udaf, count::count_udaf, min_max::max_udaf, min_max::min_udaf, }; use datafusion_functions_aggregate::{average::avg_udaf, grouping::grouping_udaf}; +use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_window::rank::rank_udwf; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; @@ -4497,3 +4498,56 @@ fn test_no_functions_registered() { "Internal error: No functions registered with this context." ); } + +#[test] +fn test_custom_type_plan() -> Result<()> { + let sql = "SELECT DATETIME '2001-01-01 18:00:00'"; + + // test the default behavior + let options = ParserOptions::default(); + let dialect = &GenericDialect {}; + let state = MockSessionState::default(); + let context = MockContextProvider { state }; + let planner = SqlToRel::new_with_options(&context, options); + let result = DFParser::parse_sql_with_dialect(sql, dialect); + let mut ast = result.unwrap(); + let err = planner.statement_to_plan(ast.pop_front().unwrap()); + assert_contains!( + err.unwrap_err().to_string(), + "This feature is not implemented: Unsupported SQL type Datetime(None)" + ); + + fn plan_sql(sql: &str) -> LogicalPlan { + let options = ParserOptions::default(); + let dialect = &GenericDialect {}; + let state = MockSessionState::default() + .with_scalar_function(make_array_udf()) + .with_expr_planner(Arc::new(CustomExprPlanner {})) + .with_type_planner(Arc::new(CustomTypePlanner {})); + let context = MockContextProvider { state }; + let planner = SqlToRel::new_with_options(&context, options); + let result = DFParser::parse_sql_with_dialect(sql, dialect); + let mut ast = result.unwrap(); + planner.statement_to_plan(ast.pop_front().unwrap()).unwrap() + } + + let plan = plan_sql(sql); + let expected = + "Projection: CAST(Utf8(\"2001-01-01 18:00:00\") AS Timestamp(Nanosecond, None))\ + \n EmptyRelation"; + assert_eq!(plan.to_string(), expected); + + let plan = plan_sql("SELECT CAST(TIMESTAMP '2001-01-01 18:00:00' AS DATETIME)"); + let expected = "Projection: CAST(CAST(Utf8(\"2001-01-01 18:00:00\") AS Timestamp(Nanosecond, None)) AS Timestamp(Nanosecond, None))\ + \n EmptyRelation"; + assert_eq!(plan.to_string(), expected); + + let plan = plan_sql( + "SELECT ARRAY[DATETIME '2001-01-01 18:00:00', DATETIME '2001-01-02 18:00:00']", + ); + let expected = "Projection: make_array(CAST(Utf8(\"2001-01-01 18:00:00\") AS Timestamp(Nanosecond, None)), CAST(Utf8(\"2001-01-02 18:00:00\") AS Timestamp(Nanosecond, None)))\ + \n EmptyRelation"; + assert_eq!(plan.to_string(), expected); + + Ok(()) +}