Remove CountWildcardRule in Analyzer and move the functionality into ExprPlanner; add plan_aggregate and plan_window to planner #14689

Merged 25 commits on Feb 21, 2025
2 changes: 2 additions & 0 deletions datafusion/core/src/execution/session_state_defaults.rs
@@ -94,6 +94,8 @@ impl SessionStateDefaults {
feature = "unicode_expressions"
))]
Arc::new(functions::planner::UserDefinedFunctionPlanner),
Arc::new(functions_aggregate::planner::AggregateFunctionPlanner),
Arc::new(functions_window::planner::WindowFunctionPlanner),
];

expr_planners
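With this change, the aggregate and window planners ship in the default `ExprPlanner` list. As a minimal sketch of explicit registration — assuming `SessionStateBuilder::with_expr_planners` replaces the default planner list, and `context_with_planners` is only an illustrative helper:

```rust
use std::sync::Arc;

use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::prelude::SessionContext;
use datafusion_expr::planner::ExprPlanner;
use datafusion_functions_aggregate::planner::AggregateFunctionPlanner;
use datafusion_functions_window::planner::WindowFunctionPlanner;

// Build a SessionContext whose state carries the two planners added by this
// PR explicitly, rather than relying on SessionStateDefaults.
fn context_with_planners() -> SessionContext {
    let planners: Vec<Arc<dyn ExprPlanner>> = vec![
        Arc::new(AggregateFunctionPlanner),
        Arc::new(WindowFunctionPlanner),
    ];
    let state = SessionStateBuilder::new()
        .with_default_features()
        .with_expr_planners(planners)
        .build();
    SessionContext::new_with_state(state)
}
```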
34 changes: 33 additions & 1 deletion datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -22,6 +22,7 @@ use arrow::{
array::{Int32Array, StringArray},
record_batch::RecordBatch,
};
use datafusion_functions_aggregate::count::count_all;
use std::sync::Arc;

use datafusion::error::Result;
@@ -31,7 +32,7 @@ use datafusion::prelude::*;
use datafusion::assert_batches_eq;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::expr::Alias;
use datafusion_expr::ExprSchemable;
use datafusion_expr::{table_scan, ExprSchemable, LogicalPlanBuilder};
use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont};
use datafusion_functions_nested::map::map;

@@ -1123,3 +1124,34 @@ async fn test_fn_map() -> Result<()> {

Ok(())
}

/// Call the count wildcard (`count_all`) from the DataFrame API
#[tokio::test]
async fn test_count_wildcard() -> Result<()> {
let schema = Schema::new(vec![
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::UInt32, false),
Field::new("c", DataType::UInt32, false),
]);

let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("b")], vec![count_all()])
.unwrap()
.project(vec![count_all()])
.unwrap()
.sort(vec![count_all().sort(true, false)])
.unwrap()
.build()
.unwrap();

let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\
\n Projection: count(*) [count(*):Int64]\
\n Aggregate: groupBy=[[test.b]], aggr=[[count(*)]] [b:UInt32, count(*):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

let formatted_plan = plan.display_indent_schema().to_string();
assert_eq!(formatted_plan, expected);

Ok(())
}
30 changes: 15 additions & 15 deletions datafusion/core/tests/dataframe/mod.rs
@@ -32,7 +32,8 @@ use arrow::datatypes::{
};
use arrow::error::ArrowError;
use arrow::util::pretty::pretty_format_batches;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_functions_aggregate::count::{count_all, count_udaf};
use datafusion_functions_aggregate::expr_fn::{
array_agg, avg, count, count_distinct, max, median, min, sum,
};
@@ -72,7 +73,7 @@ use datafusion_expr::expr::{GroupingSet, Sort, WindowFunction};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
cast, col, create_udf, exists, in_subquery, lit, out_ref_col, placeholder,
scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan,
scalar_subquery, when, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan,
ScalarFunctionImplementation, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
@@ -2463,8 +2464,8 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
let df_results = ctx
.table("t1")
.await?
.aggregate(vec![col("b")], vec![count(wildcard())])?
.sort(vec![count(wildcard()).sort(true, false)])?
.aggregate(vec![col("b")], vec![count_all()])?
.sort(vec![count_all().sort(true, false)])?
.explain(false, false)?
.collect()
.await?;
@@ -2498,8 +2499,8 @@ async fn test_count_wildcard_on_where_in() -> Result<()> {
Arc::new(
ctx.table("t2")
.await?
.aggregate(vec![], vec![count(wildcard())])?
.select(vec![count(wildcard())])?
.aggregate(vec![], vec![count_all()])?
.select(vec![count_all()])?
.into_optimized_plan()?,
),
))?
@@ -2532,8 +2533,8 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> {
.filter(exists(Arc::new(
ctx.table("t2")
.await?
.aggregate(vec![], vec![count(wildcard())])?
.select(vec![count(wildcard())])?
.aggregate(vec![], vec![count_all()])?
.select(vec![count_all()])?
.into_unoptimized_plan(),
// Usually, into_optimized_plan() should be used here, but due to
// https://github.com/apache/datafusion/issues/5771,
@@ -2568,7 +2569,7 @@ async fn test_count_wildcard_on_window() -> Result<()> {
.await?
.select(vec![Expr::WindowFunction(WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![wildcard()],
vec![Expr::Literal(COUNT_STAR_EXPANSION)],
))
.order_by(vec![Sort::new(col("a"), false, true)])
.window_frame(WindowFrame::new_bounds(
@@ -2599,17 +2600,16 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
let sql_results = ctx
.sql("select count(*) from t1")
.await?
.select(vec![col("count(*)")])?
.explain(false, false)?
.collect()
.await?;

// add `.select(vec![count(wildcard())])?` to make sure we can analyze all node instead of just top node.
// add `.select(vec![count_all()])?` to make sure we can analyze all nodes instead of just the top node.
let df_results = ctx
.table("t1")
.await?
.aggregate(vec![], vec![count(wildcard())])?
.select(vec![count(wildcard())])?
.aggregate(vec![], vec![count_all()])?
.select(vec![count_all()])?
.explain(false, false)?
.collect()
.await?;
@@ -2646,8 +2646,8 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
ctx.table("t2")
.await?
.filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))?
.aggregate(vec![], vec![count(wildcard())])?
.select(vec![col(count(wildcard()).to_string())])?
.aggregate(vec![], vec![count_all()])?
.select(vec![col(count_all().to_string())])?
.into_unoptimized_plan(),
))
.gt(lit(ScalarValue::UInt8(Some(0)))),
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/explain_analyze.rs
@@ -780,7 +780,7 @@ async fn explain_logical_plan_only() {
let expected = vec![
vec![
"logical_plan",
"Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\
"Aggregate: groupBy=[[]], aggr=[[count(*)]]\
\n SubqueryAlias: t\
\n Projection: \
\n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"
1 change: 0 additions & 1 deletion datafusion/expr/src/expr.rs
@@ -2294,7 +2294,6 @@ impl Display for SchemaDisplay<'_> {
| Expr::OuterReferenceColumn(..)
| Expr::Placeholder(_)
| Expr::Wildcard { .. } => write!(f, "{}", self.0),

Expr::AggregateFunction(AggregateFunction { func, params }) => {
match func.schema_name(params) {
Ok(name) => {
62 changes: 53 additions & 9 deletions datafusion/expr/src/planner.rs
@@ -25,9 +25,12 @@ use datafusion_common::{
config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema,
Result, TableReference,
};
use sqlparser::ast;
use sqlparser::ast::{self, NullTreatment};

use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF};
use crate::{
AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame,
WindowFunctionDefinition, WindowUDF,
};

/// Provides the `SQL` query planner meta-data about tables and
/// functions referenced in SQL statements, without a direct dependency on the
@@ -138,7 +141,7 @@ pub trait ExprPlanner: Debug + Send + Sync {

/// Plan an array literal, such as `[1, 2, 3]`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_array_literal(
&self,
exprs: Vec<Expr>,
@@ -149,14 +152,14 @@

/// Plan a `POSITION` expression, such as `POSITION(<expr> in <expr>)`
///
/// returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_position(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Original(args))
}

/// Plan a dictionary literal, such as `{ key: value, ...}`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_dictionary_literal(
&self,
expr: RawDictionaryExpr,
@@ -167,14 +170,14 @@

/// Plan an extract expression, such as`EXTRACT(month FROM foo)`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
Member: This is a good change. nit: this could go in a separate PR to keep the PR size lower.

Contributor Author: I think we can keep this.

fn plan_extract(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Original(args))
}

/// Plan an substring expression, such as `SUBSTRING(<expr> [FROM <expr>] [FOR <expr>])`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_substring(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Original(args))
}
@@ -195,14 +198,14 @@

/// Plans an overlay expression, such as `overlay(str PLACING substr FROM pos [FOR count])`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_overlay(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Original(args))
}

/// Plans a `make_map` expression, such as `make_map(key1, value1, key2, value2, ...)`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Original(args))
}
@@ -230,6 +233,23 @@
fn plan_any(&self, expr: RawBinaryExpr) -> Result<PlannerResult<RawBinaryExpr>> {
Ok(PlannerResult::Original(expr))
}

/// Plans aggregate functions, such as `COUNT(<expr>)`
///
/// Returns original expression arguments if not possible
fn plan_aggregate(
&self,
expr: RawAggregateExpr,
) -> Result<PlannerResult<RawAggregateExpr>> {
Ok(PlannerResult::Original(expr))
}

/// Plans window functions, such as `COUNT(<expr>) OVER (...)`
///
/// Returns original expression arguments if not possible
fn plan_window(&self, expr: RawWindowExpr) -> Result<PlannerResult<RawWindowExpr>> {
Ok(PlannerResult::Original(expr))
}
}
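
Taken together, these hooks let an `ExprPlanner` intercept aggregate construction during SQL planning. Below is a hedged sketch — not the PR's exact code — of a planner that rewrites `count()`/`count(*)` into `count(1)` via `COUNT_STAR_EXPANSION`; `CountStarPlanner` is illustrative, and it assumes `AggregateFunction::new_udf` keeps its six-argument form:

```rust
use datafusion_common::Result;
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawAggregateExpr};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::Expr;

#[derive(Debug)]
struct CountStarPlanner;

impl ExprPlanner for CountStarPlanner {
    fn plan_aggregate(
        &self,
        expr: RawAggregateExpr,
    ) -> Result<PlannerResult<RawAggregateExpr>> {
        // Only `count()` and `count(*)` are rewritten; everything else is
        // handed back unchanged for the next planner (or the default path).
        let is_count_star = expr.func.name() == "count"
            && (expr.args.is_empty()
                || matches!(expr.args.as_slice(), [Expr::Wildcard { .. }]));
        if !is_count_star {
            return Ok(PlannerResult::Original(expr));
        }

        let RawAggregateExpr {
            func,
            args: _,
            distinct,
            filter,
            order_by,
            null_treatment,
        } = expr;

        // `count(*)` is equivalent to `count(1)`: substitute the literal
        // expansion and keep the remaining clauses intact.
        Ok(PlannerResult::Planned(Expr::AggregateFunction(
            AggregateFunction::new_udf(
                func,
                vec![Expr::Literal(COUNT_STAR_EXPANSION)],
                distinct,
                filter,
                order_by,
                null_treatment,
            ),
        )))
    }
}
```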

/// An operator with two arguments to plan
@@ -266,6 +286,30 @@ pub struct RawDictionaryExpr {
pub values: Vec<Expr>,
}

/// This structure is used by `AggregateFunctionPlanner` to plan operators with
/// custom expressions.
#[derive(Debug, Clone)]
pub struct RawAggregateExpr {
pub func: Arc<AggregateUDF>,
pub args: Vec<Expr>,
pub distinct: bool,
pub filter: Option<Box<Expr>>,
pub order_by: Option<Vec<SortExpr>>,
pub null_treatment: Option<NullTreatment>,
}

/// This structure is used by `WindowFunctionPlanner` to plan operators with
/// custom expressions.
#[derive(Debug, Clone)]
pub struct RawWindowExpr {
pub func_def: WindowFunctionDefinition,
pub args: Vec<Expr>,
pub partition_by: Vec<Expr>,
pub order_by: Vec<SortExpr>,
pub window_frame: WindowFrame,
pub null_treatment: Option<NullTreatment>,
}
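
The window side works the same way, under the same assumptions: destructure `RawWindowExpr` field by field and rebuild the expression through the `ExprFunctionExt` builder seen in the test changes above. A sketch (`CountStarWindowPlanner` is illustrative, and it assumes `WindowFunctionDefinition::name` is available):

```rust
use datafusion_common::Result;
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawWindowExpr};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{Expr, ExprFunctionExt};

#[derive(Debug)]
struct CountStarWindowPlanner;

impl ExprPlanner for CountStarWindowPlanner {
    fn plan_window(&self, expr: RawWindowExpr) -> Result<PlannerResult<RawWindowExpr>> {
        // Only the zero-argument `count() OVER (...)` form is rewritten.
        if expr.func_def.name() != "count" || !expr.args.is_empty() {
            return Ok(PlannerResult::Original(expr));
        }
        let RawWindowExpr {
            func_def,
            args: _,
            partition_by,
            order_by,
            window_frame,
            null_treatment,
        } = expr;
        // Rebuild `count(1) OVER (...)`, preserving the window clauses.
        let rewritten = Expr::WindowFunction(WindowFunction::new(
            func_def,
            vec![Expr::Literal(COUNT_STAR_EXPANSION)],
        ))
        .partition_by(partition_by)
        .order_by(order_by)
        .window_frame(window_frame)
        .null_treatment(null_treatment)
        .build()?;
        Ok(PlannerResult::Planned(rewritten))
    }
}
```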

/// Result of planning a raw expr with [`ExprPlanner`]
#[derive(Debug, Clone)]
pub enum PlannerResult<T> {
21 changes: 13 additions & 8 deletions datafusion/expr/src/udaf.rs
@@ -515,27 +515,32 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
null_treatment,
} = params;

let mut schema_name = String::new();
let mut display_name = String::new();

schema_name.write_fmt(format_args!(
display_name.write_fmt(format_args!(
"{}({}{})",
self.name(),
if *distinct { "DISTINCT " } else { "" },
expr_vec_fmt!(args)
))?;

if let Some(nt) = null_treatment {
schema_name.write_fmt(format_args!(" {}", nt))?;
display_name.write_fmt(format_args!(" {}", nt))?;
}
if let Some(fe) = filter {
schema_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
}
if let Some(order_by) = order_by {
schema_name
.write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
if let Some(ob) = order_by {
display_name.write_fmt(format_args!(
" ORDER BY [{}]",
ob.iter()
.map(|o| format!("{o}"))
.collect::<Vec<String>>()
.join(", ")
))?;
}

Ok(schema_name)
Ok(display_name)
}
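
For illustration: with `distinct` set, a filter on `b`, and one sort key, this default implementation renders a name like `count(DISTINCT a) FILTER (WHERE b) ORDER BY [c ASC NULLS LAST]` — an assumed rendering, since the exact output depends on each expression's `Display` implementation.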

/// Returns the user-defined display name of the function, given the arguments