Skip to content

Commit

Permalink
use logical type for signature
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 committed Nov 4, 2024
1 parent 85f92ef commit 30be8a0
Show file tree
Hide file tree
Showing 16 changed files with 118 additions and 36 deletions.
2 changes: 2 additions & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions datafusion/common/src/types/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ impl fmt::Debug for dyn LogicalType {
}
}

impl std::fmt::Display for dyn LogicalType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}

impl PartialEq for dyn LogicalType {
fn eq(&self, other: &Self) -> bool {
self.signature().eq(&other.signature())
Expand Down
46 changes: 41 additions & 5 deletions datafusion/common/src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
// under the License.

use super::{
LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields,
TypeSignature,
LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalTypeRef,
LogicalUnionFields, TypeSignature,
};
use crate::error::{Result, _internal_err};
use arrow::compute::can_cast_types;
use arrow_schema::{
DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields,
};
use std::sync::Arc;
use std::sync::{Arc, OnceLock};

/// Representation of a type that DataFusion can handle natively. It is a subset
/// of the physical variants in Arrow's native [`DataType`].
Expand Down Expand Up @@ -348,6 +349,12 @@ impl LogicalType for NativeType {
// mapping solutions to provide backwards compatibility while transitioning from
// the purely physical system to a logical / physical system.

impl From<&DataType> for NativeType {
fn from(value: &DataType) -> Self {
value.clone().into()
}
}

impl From<DataType> for NativeType {
fn from(value: DataType) -> Self {
use NativeType::*;
Expand Down Expand Up @@ -392,8 +399,37 @@ impl From<DataType> for NativeType {
}
}

impl From<&DataType> for NativeType {
fn from(value: &DataType) -> Self {
value.clone().into()
impl NativeType {
#[inline]
pub fn is_numeric(&self) -> bool {
use NativeType::*;
matches!(
self,
UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float16
| Float32
| Float64
)
}

/// This function is the NativeType version of `can_cast_types`.
/// It handles general coercion rules that are widely applicable.
/// Avoid adding specific coercion cases here.
/// Aim to keep this logic as SIMPLE as possible!
pub fn can_cast_to(&self, target_type: &Self) -> bool {
// In Postgres, most functions coerce numeric strings to numeric inputs,
// but they do not accept numeric inputs as strings.
if self.is_numeric() && target_type == &NativeType::String {
return false;
}

true
}
}
10 changes: 7 additions & 3 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//! and return types of functions in DataFusion.
use arrow::datatypes::DataType;
use datafusion_common::types::LogicalTypeRef;

/// Constant that is used as a placeholder for any valid timezone.
/// This is used where a function can accept a timestamp type with any
Expand Down Expand Up @@ -109,7 +110,7 @@ pub enum TypeSignature {
/// For example, `Coercible(vec![DataType::Float64])` accepts
/// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]`
/// since i32 and f32 can be casted to f64
Coercible(Vec<DataType>),
Coercible(Vec<LogicalTypeRef>),
/// Fixed number of arguments of arbitrary types
/// If a function takes 0 argument, its `TypeSignature` should be `Any(0)`
Any(usize),
Expand Down Expand Up @@ -201,7 +202,10 @@ impl TypeSignature {
TypeSignature::Numeric(num) => {
vec![format!("Numeric({num})")]
}
TypeSignature::Exact(types) | TypeSignature::Coercible(types) => {
TypeSignature::Coercible(types) => {
vec![Self::join_types(types, ", ")]
}
TypeSignature::Exact(types) => {
vec![Self::join_types(types, ", ")]
}
TypeSignature::Any(arg_count) => {
Expand Down Expand Up @@ -322,7 +326,7 @@ impl Signature {
}
}
/// Target coerce types in order
pub fn coercible(target_types: Vec<DataType>, volatility: Volatility) -> Self {
pub fn coercible(target_types: Vec<LogicalTypeRef>, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Coercible(target_types),
volatility,
Expand Down
64 changes: 48 additions & 16 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow::{
};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, plan_err,
types::{logical_string, NativeType},
utils::{coerced_fixed_size_list_to_list, list_ndims},
Result,
};
Expand Down Expand Up @@ -401,6 +402,10 @@ fn get_valid_types(
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::String(number) => {
// TODO: we can switch to coercible after all the string functions support utf8view since it is choosen as the default string type.
//
// let data_types = get_valid_types(&TypeSignature::Coercible(vec![logical_string(); *number]), current_types)?.swap_remove(0);

if *number < 1 {
return plan_err!(
"The signature expected at least one argument but received {}",
Expand All @@ -415,20 +420,38 @@ fn get_valid_types(
);
}

fn coercion_rule(
let mut new_types = Vec::with_capacity(current_types.len());
for data_type in current_types.iter() {
let logical_data_type: NativeType = data_type.into();

match logical_data_type {
NativeType::String => {
new_types.push(data_type.to_owned());
}
NativeType::Null => {
new_types.push(DataType::Utf8);
}
_ => {
return plan_err!(
"The signature expected NativeType::String but received {data_type}"
);
}
}
}

let data_types = new_types;

// Find the common string type for the given types
fn find_common_type(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Result<DataType> {
match (lhs_type, rhs_type) {
(DataType::Null, DataType::Null) => Ok(DataType::Utf8),
(DataType::Null, data_type) | (data_type, DataType::Null) => {
coercion_rule(data_type, &DataType::Utf8)
}
(DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
coercion_rule(lhs, rhs)
find_common_type(lhs, rhs)
}
(DataType::Dictionary(_, v), other)
| (other, DataType::Dictionary(_, v)) => coercion_rule(v, other),
| (other, DataType::Dictionary(_, v)) => find_common_type(v, other),
_ => {
if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
Ok(coerced_type)
Expand All @@ -444,15 +467,13 @@ fn get_valid_types(
}

// Length checked above, safe to unwrap
let mut coerced_type = current_types.first().unwrap().to_owned();
for t in current_types.iter().skip(1) {
coerced_type = coercion_rule(&coerced_type, t)?;
let mut coerced_type = data_types.first().unwrap().to_owned();
for t in data_types.iter().skip(1) {
coerced_type = find_common_type(&coerced_type, t)?;
}

fn base_type_or_default_type(data_type: &DataType) -> DataType {
if data_type.is_null() {
DataType::Utf8
} else if let DataType::Dictionary(_, v) = data_type {
if let DataType::Dictionary(_, v) = data_type {
base_type_or_default_type(v)
} else {
data_type.to_owned()
Expand Down Expand Up @@ -506,14 +527,25 @@ fn get_valid_types(
);
}

let mut new_types = Vec::with_capacity(current_types.len());
for (data_type, target_type) in current_types.iter().zip(target_types.iter())
{
if !can_cast_types(data_type, target_type) {
return plan_err!("{data_type} is not coercible to {target_type}");
let logical_data_type: NativeType = data_type.into();
if logical_data_type == *target_type.native() {
new_types.push(data_type.to_owned());
} else if logical_data_type.can_cast_to(target_type.native()) {
let casted_type = target_type.default_cast_for(data_type)?;
new_types.push(casted_type);
} else {
return plan_err!(
"The signature expected {:?} but received {:?}",
target_type.native(),
logical_data_type
);
}
}

vec![target_types.to_owned()]
vec![new_types]
}
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
Expand Down
3 changes: 2 additions & 1 deletion datafusion/functions-aggregate/src/stddev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use std::sync::{Arc, OnceLock};
use arrow::array::Float64Array;
use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};

use datafusion_common::types::logical_float64;
use datafusion_common::{internal_err, not_impl_err, Result};
use datafusion_common::{plan_err, ScalarValue};
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
Expand Down Expand Up @@ -72,7 +73,7 @@ impl Stddev {
pub fn new() -> Self {
Self {
signature: Signature::coercible(
vec![DataType::Float64],
vec![logical_float64()],
Volatility::Immutable,
),
alias: vec!["stddev_samp".to_string()],
Expand Down
5 changes: 3 additions & 2 deletions datafusion/functions-aggregate/src/variance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ use std::sync::OnceLock;
use std::{fmt::Debug, sync::Arc};

use datafusion_common::{
downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue,
downcast_value, not_impl_err, plan_err, types::logical_float64, DataFusionError,
Result, ScalarValue,
};
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
use datafusion_expr::{
Expand Down Expand Up @@ -83,7 +84,7 @@ impl VarianceSample {
Self {
aliases: vec![String::from("var_sample"), String::from("var_samp")],
signature: Signature::coercible(
vec![DataType::Float64],
vec![logical_float64()],
Volatility::Immutable,
),
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/bit_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl ScalarUDFImpl for BitLengthFunc {
ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar(
ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)),
)),
_ => unreachable!(),
_ => unreachable!("bit length"),
},
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ impl ScalarUDFImpl for ConcatFunc {
}
};
}
_ => unreachable!(),
_ => unreachable!("concat"),
}
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/concat_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
ColumnarValueRef::NonNullableArray(string_array)
}
}
_ => unreachable!(),
_ => unreachable!("concat ws"),
};

let mut columns = Vec::with_capacity(args.len() - 1);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ mod tests {
let args = vec![ColumnarValue::Array(input)];
let result = match func.invoke(&args)? {
ColumnarValue::Array(result) => result,
_ => unreachable!(),
_ => unreachable!("lower"),
};
assert_eq!(&expected, &result);
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/octet_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl ScalarUDFImpl for OctetLengthFunc {
ScalarValue::Utf8View(v) => Ok(ColumnarValue::Scalar(
ScalarValue::Int32(v.as_ref().map(|x| x.len() as i32)),
)),
_ => unreachable!(),
_ => unreachable!("OctetLengthFunc"),
},
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/upper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ mod tests {
let args = vec![ColumnarValue::Array(input)];
let result = match func.invoke(&args)? {
ColumnarValue::Array(result) => result,
_ => unreachable!(),
_ => unreachable!("upper"),
};
assert_eq!(&expected, &result);
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/unicode/character_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn character_length(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = args[0].as_string_view();
character_length_general::<Int32Type, _>(string_array)
}
_ => unreachable!(),
_ => unreachable!("CharacterLengthFunc"),
}
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/unicode/lpad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
length_array,
&args[2],
),
(_, _) => unreachable!(),
(_, _) => unreachable!("lpad"),
}
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1940,7 +1940,7 @@ select position('' in '')
----
1

query error DataFusion error: Error during planning: Error during planning: Int64 and Int64 are not coercible to a common string
query error DataFusion error: Error during planning: Error during planning: The signature expected NativeType::String but received Int64
select position(1 in 1)

query I
Expand Down

0 comments on commit 30be8a0

Please sign in to comment.