Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace TypeSignature::String with TypeSignature::Coercible for starts_with #14812

Merged
merged 1 commit into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1256,7 +1256,7 @@ fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
/// Coercion rules for binary (Binary/LargeBinary) to string (Utf8/LargeUtf8):
/// If one argument is binary and the other is a string then coerce to string
/// (e.g. for `like`)
fn binary_to_string_coercion(
pub fn binary_to_string_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
Expand Down
93 changes: 68 additions & 25 deletions datafusion/functions/src/string/starts_with.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,42 @@ use std::sync::Arc;
use arrow::array::ArrayRef;
use arrow::datatypes::DataType;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::type_coercion::binary::{
binary_to_string_coercion, string_coercion,
};

use crate::utils::make_scalar_function;
use datafusion_common::types::logical_string;
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Documentation, Expr, Like};
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
use datafusion_expr::{
cast, Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

/// Returns true if string starts with prefix.
/// starts_with('alphabet', 'alph') = 't'
pub fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
let result = arrow::compute::kernels::comparison::starts_with(&args[0], &args[1])?;
Ok(Arc::new(result) as ArrayRef)
fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
if let Some(coercion_data_type) =
string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
binary_to_string_coercion(args[0].data_type(), args[1].data_type())
})
{
let arg0 = if args[0].data_type() == &coercion_data_type {
Arc::clone(&args[0])
} else {
arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
};
let arg1 = if args[1].data_type() == &coercion_data_type {
Arc::clone(&args[1])
} else {
arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
};
let result = arrow::compute::kernels::comparison::starts_with(&arg0, &arg1)?;
Ok(Arc::new(result) as ArrayRef)
} else {
internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")
}
}

#[user_doc(
Expand Down Expand Up @@ -64,7 +88,13 @@ impl Default for StartsWithFunc {
impl StartsWithFunc {
pub fn new() -> Self {
Self {
signature: Signature::string(2, Volatility::Immutable),
signature: Signature::coercible(
vec![
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
],
Volatility::Immutable,
),
}
}
}
Expand Down Expand Up @@ -98,7 +128,7 @@ impl ScalarUDFImpl for StartsWithFunc {
fn simplify(
&self,
args: Vec<Expr>,
_info: &dyn SimplifyInfo,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
if let Expr::Literal(scalar_value) = &args[1] {
// Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping
Expand All @@ -107,31 +137,44 @@ impl ScalarUDFImpl for StartsWithFunc {
// 2. 'ja\%' (escape special char '%')
// 3. 'ja\%%' (add suffix for starts_with)
let like_expr = match scalar_value {
ScalarValue::Utf8(Some(pattern)) => {
ScalarValue::Utf8(Some(pattern))
| ScalarValue::LargeUtf8(Some(pattern))
| ScalarValue::Utf8View(Some(pattern)) => {
let escaped_pattern = pattern.replace("%", "\\%");
let like_pattern = format!("{}%", escaped_pattern);
Expr::Literal(ScalarValue::Utf8(Some(like_pattern)))
}
ScalarValue::LargeUtf8(Some(pattern)) => {
let escaped_pattern = pattern.replace("%", "\\%");
let like_pattern = format!("{}%", escaped_pattern);
Expr::Literal(ScalarValue::LargeUtf8(Some(like_pattern)))
}
ScalarValue::Utf8View(Some(pattern)) => {
let escaped_pattern = pattern.replace("%", "\\%");
let like_pattern = format!("{}%", escaped_pattern);
Expr::Literal(ScalarValue::Utf8View(Some(like_pattern)))
}
_ => return Ok(ExprSimplifyResult::Original(args)),
};

return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
negated: false,
expr: Box::new(args[0].clone()),
pattern: Box::new(like_expr),
escape_char: None,
case_insensitive: false,
})));
let expr_data_type = info.get_data_type(&args[0])?;
let pattern_data_type = info.get_data_type(&like_expr)?;

if let Some(coercion_data_type) =
string_coercion(&expr_data_type, &pattern_data_type).or_else(|| {
binary_to_string_coercion(&expr_data_type, &pattern_data_type)
})
{
let expr = if expr_data_type == coercion_data_type {
args[0].clone()
} else {
cast(args[0].clone(), coercion_data_type.clone())
};

let pattern = if pattern_data_type == coercion_data_type {
like_expr
} else {
cast(like_expr, coercion_data_type)
};

return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like {
negated: false,
expr: Box::new(expr),
pattern: Box::new(pattern),
escape_char: None,
case_insensitive: false,
})));
}
}

Ok(ExprSimplifyResult::Original(args))
Expand Down
6 changes: 3 additions & 3 deletions datafusion/sqllogictest/test_files/string/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: starts_with(test.column1_utf8view, test.column2_utf8view) AS c1, starts_with(test.column1_utf8view, CAST(test.column2_utf8 AS Utf8View)) AS c2, starts_with(test.column1_utf8view, CAST(test.column2_large_utf8 AS Utf8View)) AS c3
01)Projection: starts_with(test.column1_utf8view, test.column2_utf8view) AS c1, starts_with(test.column1_utf8view, test.column2_utf8) AS c2, starts_with(test.column1_utf8view, test.column2_large_utf8) AS c3
02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view, column2_utf8view]

query BBB
Expand All @@ -326,7 +326,7 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: starts_with(CAST(test.column1_utf8 AS Utf8View), test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(CAST(test.column1_utf8 AS LargeUtf8), test.column2_large_utf8) AS c4
01)Projection: starts_with(test.column1_utf8, test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(test.column1_utf8, test.column2_large_utf8) AS c4
02)--TableScan: test projection=[column1_utf8, column2_utf8, column2_large_utf8, column2_utf8view]

query BBB
Expand Down Expand Up @@ -382,7 +382,7 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: starts_with(CAST(test.column1_utf8 AS Utf8View), substr(test.column1_utf8, Int64(1), Int64(2))) AS c1, starts_with(CAST(test.column1_large_utf8 AS Utf8View), substr(test.column1_large_utf8, Int64(1), Int64(2))) AS c2, starts_with(test.column1_utf8view, substr(test.column1_utf8view, Int64(1), Int64(2))) AS c3
01)Projection: starts_with(test.column1_utf8, substr(test.column1_utf8, Int64(1), Int64(2))) AS c1, starts_with(test.column1_large_utf8, substr(test.column1_large_utf8, Int64(1), Int64(2))) AS c2, starts_with(test.column1_utf8view, substr(test.column1_utf8view, Int64(1), Int64(2))) AS c3
02)--TableScan: test projection=[column1_utf8, column1_large_utf8, column1_utf8view]

query BBB
Expand Down