-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cleanup logical optimizer rules. #7919
Changes from 13 commits
c8dc3a4
b2509fd
921bd5c
7d8a911
80092c5
1648b1b
a9e4439
c939460
cecb82f
70d0a26
7273cc8
cb187bf
a8cd920
0741b39
fdee7c3
ee4b81e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -231,13 +231,13 @@ async fn group_by_dictionary() { | |
.expect("ran plan correctly"); | ||
|
||
let expected = [ | ||
"+-------+------------------------+", | ||
"| t.val | COUNT(DISTINCT t.dict) |", | ||
"+-------+------------------------+", | ||
"| 1 | 2 |", | ||
"| 2 | 2 |", | ||
"| 4 | 1 |", | ||
"+-------+------------------------+", | ||
"+-----+------------------------+", | ||
"| val | COUNT(DISTINCT t.dict) |", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
"+-----+------------------------+", | ||
"| 1 | 2 |", | ||
"| 2 | 2 |", | ||
"| 4 | 1 |", | ||
"+-----+------------------------+", | ||
]; | ||
assert_batches_sorted_eq!(expected, &results); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,21 +17,26 @@ | |
|
||
//! Built-in functions module contains all the built-in functions definitions. | ||
|
||
use std::cmp::Ordering; | ||
use std::collections::HashMap; | ||
use std::fmt; | ||
use std::str::FromStr; | ||
use std::sync::{Arc, OnceLock}; | ||
|
||
use crate::nullif::SUPPORTED_NULLIF_TYPES; | ||
use crate::signature::TIMEZONE_WILDCARD; | ||
use crate::type_coercion::functions::data_types; | ||
use crate::{ | ||
conditional_expressions, struct_expressions, utils, FuncMonotonicity, Signature, | ||
TypeSignature, Volatility, | ||
}; | ||
|
||
use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; | ||
use datafusion_common::{ | ||
internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, | ||
exec_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, | ||
Result, | ||
}; | ||
use std::collections::HashMap; | ||
use std::fmt; | ||
use std::str::FromStr; | ||
use std::sync::{Arc, OnceLock}; | ||
|
||
use strum::IntoEnumIterator; | ||
use strum_macros::EnumIter; | ||
|
||
|
@@ -315,6 +320,72 @@ fn function_to_name() -> &'static HashMap<BuiltinScalarFunction, &'static str> { | |
}) | ||
} | ||
|
||
/// Returns the wider type among lhs and rhs. | ||
/// Wider type is the type that can safely represent the other type without information loss. | ||
/// Returns Error if types are incompatible. | ||
fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result<DataType> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it make sense to move this function into the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This makes sense, I moved the function under |
||
Ok(match (lhs, rhs) { | ||
(lhs, rhs) if lhs == rhs => lhs.clone(), | ||
(DataType::Null, _) => rhs.clone(), | ||
(_, DataType::Null) => lhs.clone(), | ||
// Right UInt is larger than left UInt | ||
(DataType::UInt8, DataType::UInt16 | DataType::UInt32 | DataType::UInt64) => { | ||
rhs.clone() | ||
} | ||
(DataType::UInt16, DataType::UInt32 | DataType::UInt64) => rhs.clone(), | ||
(DataType::UInt32, DataType::UInt64) => rhs.clone(), | ||
// Left UInt is larger than right UInt. | ||
(DataType::UInt16 | DataType::UInt32 | DataType::UInt64, DataType::UInt8) => { | ||
lhs.clone() | ||
} | ||
(DataType::UInt32 | DataType::UInt64, DataType::UInt16) => lhs.clone(), | ||
(DataType::UInt64, DataType::UInt32) => lhs.clone(), | ||
// Right Int is larger than left Int | ||
(DataType::Int8, DataType::Int16 | DataType::Int32 | DataType::Int64) => { | ||
rhs.clone() | ||
} | ||
(DataType::Int16, DataType::Int32 | DataType::Int64) => rhs.clone(), | ||
(DataType::Int32, DataType::Int64) => rhs.clone(), | ||
// Left Int is larger than right Int. | ||
(DataType::Int16 | DataType::Int32 | DataType::Int64, DataType::Int8) => { | ||
lhs.clone() | ||
} | ||
(DataType::Int32 | DataType::Int64, DataType::Int16) => lhs.clone(), | ||
(DataType::Int64, DataType::Int32) => lhs.clone(), | ||
// Right Float is larger than left Float | ||
(DataType::Float16, DataType::Float32 | DataType::Float64) => rhs.clone(), | ||
(DataType::Float32, DataType::Float64) => rhs.clone(), | ||
// Left Float is larger than right Float. | ||
(DataType::Float32 | DataType::Float64, DataType::Float16) => lhs.clone(), | ||
(DataType::Float64, DataType::Float32) => lhs.clone(), | ||
// String | ||
(DataType::Utf8, DataType::LargeUtf8) => rhs.clone(), | ||
(DataType::LargeUtf8, DataType::Utf8) => lhs.clone(), | ||
(DataType::List(lhs_field), DataType::List(rhs_field)) => { | ||
let field_type = | ||
get_wider_type(lhs_field.data_type(), rhs_field.data_type())?; | ||
if lhs_field.name() != rhs_field.name() { | ||
return Err(exec_datafusion_err!( | ||
"There is no wider type that can represent both lhs: {:?}, rhs:{:?}", | ||
lhs, | ||
rhs | ||
)); | ||
} | ||
assert_eq!(lhs_field.name(), rhs_field.name()); | ||
let field_name = lhs_field.name(); | ||
let nullable = lhs_field.is_nullable() | rhs_field.is_nullable(); | ||
DataType::List(Arc::new(Field::new(field_name, field_type, nullable))) | ||
} | ||
(_, _) => { | ||
return Err(exec_datafusion_err!( | ||
"There is no wider type that can represent both lhs: {:?}, rhs:{:?}", | ||
lhs, | ||
rhs | ||
)); | ||
} | ||
}) | ||
} | ||
|
||
impl BuiltinScalarFunction { | ||
/// an allowlist of functions to take zero arguments, so that they will get special treatment | ||
/// while executing. | ||
|
@@ -468,18 +539,14 @@ impl BuiltinScalarFunction { | |
/// * `List(Int64)` has dimension 2 | ||
/// * `List(List(Int64))` has dimension 3 | ||
/// * etc. | ||
fn return_dimension(self, input_expr_type: DataType) -> u64 { | ||
let mut res: u64 = 1; | ||
fn return_dimension(self, input_expr_type: &DataType) -> u64 { | ||
let mut result: u64 = 1; | ||
let mut current_data_type = input_expr_type; | ||
loop { | ||
match current_data_type { | ||
DataType::List(field) => { | ||
current_data_type = field.data_type().clone(); | ||
res += 1; | ||
} | ||
_ => return res, | ||
} | ||
while let DataType::List(field) = current_data_type { | ||
current_data_type = field.data_type(); | ||
result += 1; | ||
} | ||
result | ||
} | ||
|
||
/// Returns the output [`DataType`] of this function | ||
|
@@ -538,11 +605,17 @@ impl BuiltinScalarFunction { | |
match input_expr_type { | ||
List(field) => { | ||
if !field.data_type().equals_datatype(&Null) { | ||
let dims = self.return_dimension(input_expr_type.clone()); | ||
if max_dims < dims { | ||
max_dims = dims; | ||
expr_type = input_expr_type.clone(); | ||
} | ||
let dims = self.return_dimension(input_expr_type); | ||
expr_type = match max_dims.cmp(&dims) { | ||
Ordering::Greater => expr_type, | ||
Ordering::Equal => { | ||
get_wider_type(&expr_type, input_expr_type)? | ||
} | ||
Ordering::Less => { | ||
max_dims = dims; | ||
input_expr_type.clone() | ||
} | ||
}; | ||
} | ||
} | ||
_ => { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
During schema check, we were missing out these cases as equal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI @viirya this may be of interest to you