Skip to content

Commit

Permalink
replace type signature for starts_with
Browse files Browse the repository at this point in the history
  • Loading branch information
zjregee committed Feb 25, 2025
1 parent 7299d0e commit 4df8fbe
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 29 deletions.
2 changes: 1 addition & 1 deletion datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1256,7 +1256,7 @@ fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
/// Coercion rules for binary (Binary/LargeBinary) to string (Utf8/LargeUtf8):
/// If one argument is binary and the other is a string then coerce to string
/// (e.g. for `like`)
fn binary_to_string_coercion(
pub fn binary_to_string_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ pub use datafusion_expr_common::columnar_value::ColumnarValue;
pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
pub use datafusion_expr_common::operator::Operator;
pub use datafusion_expr_common::signature::{
ArrayFunctionArgument, ArrayFunctionSignature, Signature, TypeSignature,
ArrayFunctionArgument, ArrayFunctionSignature, Coercion, Signature, TypeSignature,
TypeSignatureClass, Volatility, TIMEZONE_WILDCARD,
};
pub use datafusion_expr_common::type_coercion::binary;
Expand Down
89 changes: 65 additions & 24 deletions datafusion/functions/src/string/starts_with.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,41 @@ use std::sync::Arc;

use arrow::array::ArrayRef;
use arrow::datatypes::DataType;
use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};

use crate::utils::make_scalar_function;
use datafusion_common::types::logical_string;
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Documentation, Expr, Like};
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
use datafusion_expr::{
cast, Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

/// Returns true if string starts with prefix.
/// starts_with('alphabet', 'alph') = 't'
///
/// Indexes `args[0]` and `args[1]` directly, so at least two arrays must be passed.
/// The two arrays are brought to one common data type (casting whichever side differs)
/// before the Arrow `starts_with` comparison kernel is invoked.
///
/// # Errors
/// Returns an internal error when neither `string_coercion` nor
/// `binary_to_string_coercion` yields a common type, and propagates any error from
/// the Arrow cast or comparison kernels.
pub fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
// NOTE(review): the next two lines appear to be the body removed by this commit
// (direct kernel call, no coercion); the `if let` below is its replacement — confirm.
let result = arrow::compute::kernels::comparison::starts_with(&args[0], &args[1])?;
Ok(Arc::new(result) as ArrayRef)
// Try string-to-string coercion first; fall back to binary-to-string coercion.
if let Some(coercion_data_type) =
string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
binary_to_string_coercion(args[0].data_type(), args[1].data_type())
})
{
// Only cast when the argument is not already of the common type.
let arg0 = if args[0].data_type() == &coercion_data_type {
args[0].clone()
} else {
arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
};
// Same treatment for the prefix argument.
let arg1 = if args[1].data_type() == &coercion_data_type {
args[1].clone()
} else {
arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
};
let result = arrow::compute::kernels::comparison::starts_with(&arg0, &arg1)?;
Ok(Arc::new(result) as ArrayRef)
} else {
// No common type could be derived for the two inputs.
internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")
}
}

#[user_doc(
Expand Down Expand Up @@ -64,7 +86,13 @@ impl Default for StartsWithFunc {
impl StartsWithFunc {
/// Creates the function with its signature; volatility is `Immutable`.
pub fn new() -> Self {
Self {
// NOTE(review): `Signature::string(2, ...)` appears to be the line removed by this
// commit; the `Signature::coercible(...)` form below is its replacement — confirm.
signature: Signature::string(2, Volatility::Immutable),
signature: Signature::coercible(
// Two entries, presumably one per argument; each uses the native logical string class.
vec![
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
],
Volatility::Immutable,
),
}
}
}
Expand Down Expand Up @@ -98,7 +126,7 @@ impl ScalarUDFImpl for StartsWithFunc {
fn simplify(
&self,
args: Vec<Expr>,
_info: &dyn SimplifyInfo,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
if let Expr::Literal(scalar_value) = &args[1] {
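// Convert starts_with(col, 'prefix') to col LIKE 'prefix%', escaping any literal '%' in the prefix.
// NOTE(review): '_' and the escape character itself are also special in LIKE patterns but are
// not escaped below — confirm whether prefixes containing them are handled correctly.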
Expand All @@ -107,31 +135,44 @@ impl ScalarUDFImpl for StartsWithFunc {
// 2. 'ja\%' (escape special char '%')
// 3. 'ja\%%' (add suffix for starts_with)
let like_expr = match scalar_value {
ScalarValue::Utf8(Some(pattern)) => {
ScalarValue::Utf8(Some(pattern))
| ScalarValue::LargeUtf8(Some(pattern))
| ScalarValue::Utf8View(Some(pattern)) => {
let escaped_pattern = pattern.replace("%", "\\%");
let like_pattern = format!("{}%", escaped_pattern);
Expr::Literal(ScalarValue::Utf8(Some(like_pattern)))
}
ScalarValue::LargeUtf8(Some(pattern)) => {
let escaped_pattern = pattern.replace("%", "\\%");
let like_pattern = format!("{}%", escaped_pattern);
Expr::Literal(ScalarValue::LargeUtf8(Some(like_pattern)))
}
ScalarValue::Utf8View(Some(pattern)) => {
let escaped_pattern = pattern.replace("%", "\\%");
let like_pattern = format!("{}%", escaped_pattern);
Expr::Literal(ScalarValue::Utf8View(Some(like_pattern)))
}
_ => return Ok(ExprSimplifyResult::Original(args)),
};

return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
negated: false,
expr: Box::new(args[0].clone()),
pattern: Box::new(like_expr),
escape_char: None,
case_insensitive: false,
})));
let expr_data_type = info.get_data_type(&args[0])?;
let pattern_data_type = info.get_data_type(&like_expr)?;

if let Some(coercion_data_type) =
string_coercion(&expr_data_type, &pattern_data_type).or_else(|| {
binary_to_string_coercion(&expr_data_type, &pattern_data_type)
})
{
let expr = if expr_data_type == coercion_data_type {
args[0].clone()
} else {
cast(args[0].clone(), coercion_data_type.clone())
};

let pattern = if pattern_data_type == coercion_data_type {
like_expr
} else {
cast(like_expr, coercion_data_type)
};

return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
negated: false,
expr: Box::new(expr),
pattern: Box::new(pattern),
escape_char: None,
case_insensitive: false,
})));
}
}

Ok(ExprSimplifyResult::Original(args))
Expand Down
6 changes: 3 additions & 3 deletions datafusion/sqllogictest/test_files/string/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: starts_with(test.column1_utf8view, test.column2_utf8view) AS c1, starts_with(test.column1_utf8view, CAST(test.column2_utf8 AS Utf8View)) AS c2, starts_with(test.column1_utf8view, CAST(test.column2_large_utf8 AS Utf8View)) AS c3
01)Projection: starts_with(test.column1_utf8view, test.column2_utf8view) AS c1, starts_with(test.column1_utf8view, test.column2_utf8) AS c2, starts_with(test.column1_utf8view, test.column2_large_utf8) AS c3
02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view, column2_utf8view]

query BBB
Expand All @@ -326,7 +326,7 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: starts_with(CAST(test.column1_utf8 AS Utf8View), test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(CAST(test.column1_utf8 AS LargeUtf8), test.column2_large_utf8) AS c4
01)Projection: starts_with(test.column1_utf8, test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(test.column1_utf8, test.column2_large_utf8) AS c4
02)--TableScan: test projection=[column1_utf8, column2_utf8, column2_large_utf8, column2_utf8view]

query BBB
Expand Down Expand Up @@ -382,7 +382,7 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: starts_with(CAST(test.column1_utf8 AS Utf8View), substr(test.column1_utf8, Int64(1), Int64(2))) AS c1, starts_with(CAST(test.column1_large_utf8 AS Utf8View), substr(test.column1_large_utf8, Int64(1), Int64(2))) AS c2, starts_with(test.column1_utf8view, substr(test.column1_utf8view, Int64(1), Int64(2))) AS c3
01)Projection: starts_with(test.column1_utf8, substr(test.column1_utf8, Int64(1), Int64(2))) AS c1, starts_with(test.column1_large_utf8, substr(test.column1_large_utf8, Int64(1), Int64(2))) AS c2, starts_with(test.column1_utf8view, substr(test.column1_utf8view, Int64(1), Int64(2))) AS c3
02)--TableScan: test projection=[column1_utf8, column1_large_utf8, column1_utf8view]

query BBB
Expand Down

0 comments on commit 4df8fbe

Please sign in to comment.