From fefc50398125358edff252013b836a048e8df27a Mon Sep 17 00:00:00 2001 From: zjregee Date: Thu, 27 Feb 2025 03:17:29 +0000 Subject: [PATCH] replace type signature for starts_with --- .../expr-common/src/type_coercion/binary.rs | 2 +- .../functions/src/string/starts_with.rs | 93 ++++++++++++++----- .../test_files/string/string_view.slt | 6 +- 3 files changed, 72 insertions(+), 29 deletions(-) diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 64c26192ae0f..682cc885cd6b 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -1256,7 +1256,7 @@ fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { /// Coercion rules for binary (Binary/LargeBinary) to string (Utf8/LargeUtf8): /// If one argument is binary and the other is a string then coerce to string /// (e.g. for `like`) -fn binary_to_string_coercion( +pub fn binary_to_string_coercion( lhs_type: &DataType, rhs_type: &DataType, ) -> Option { diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index f1344780eb4c..71df83352f96 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -21,18 +21,42 @@ use std::sync::Arc; use arrow::array::ArrayRef; use arrow::datatypes::DataType; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::type_coercion::binary::{ + binary_to_string_coercion, string_coercion, +}; use crate::utils::make_scalar_function; +use datafusion_common::types::logical_string; use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, Documentation, Expr, Like}; -use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + cast, Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; use datafusion_macros::user_doc; /// Returns true if string starts with prefix. /// starts_with('alphabet', 'alph') = 't' -pub fn starts_with(args: &[ArrayRef]) -> Result { - let result = arrow::compute::kernels::comparison::starts_with(&args[0], &args[1])?; - Ok(Arc::new(result) as ArrayRef) +fn starts_with(args: &[ArrayRef]) -> Result { + if let Some(coercion_data_type) = + string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| { + binary_to_string_coercion(args[0].data_type(), args[1].data_type()) + }) + { + let arg0 = if args[0].data_type() == &coercion_data_type { + Arc::clone(&args[0]) + } else { + arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)? + }; + let arg1 = if args[1].data_type() == &coercion_data_type { + Arc::clone(&args[1]) + } else { + arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)? + }; + let result = arrow::compute::kernels::comparison::starts_with(&arg0, &arg1)?; + Ok(Arc::new(result) as ArrayRef) + } else { + internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View") + } } #[user_doc( @@ -64,7 +88,13 @@ impl Default for StartsWithFunc { impl StartsWithFunc { pub fn new() -> Self { Self { - signature: Signature::string(2, Volatility::Immutable), + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), } } } @@ -98,7 +128,7 @@ impl ScalarUDFImpl for StartsWithFunc { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + info: &dyn SimplifyInfo, ) -> Result { if let Expr::Literal(scalar_value) = &args[1] { // Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping @@ -107,31 +137,44 @@ impl ScalarUDFImpl for StartsWithFunc { // 2. 'ja\%' (escape special char '%') // 3. 'ja\%%' (add suffix for starts_with) let like_expr = match scalar_value { - ScalarValue::Utf8(Some(pattern)) => { + ScalarValue::Utf8(Some(pattern)) + | ScalarValue::LargeUtf8(Some(pattern)) + | ScalarValue::Utf8View(Some(pattern)) => { let escaped_pattern = pattern.replace("%", "\\%"); let like_pattern = format!("{}%", escaped_pattern); Expr::Literal(ScalarValue::Utf8(Some(like_pattern))) } - ScalarValue::LargeUtf8(Some(pattern)) => { - let escaped_pattern = pattern.replace("%", "\\%"); - let like_pattern = format!("{}%", escaped_pattern); - Expr::Literal(ScalarValue::LargeUtf8(Some(like_pattern))) - } - ScalarValue::Utf8View(Some(pattern)) => { - let escaped_pattern = pattern.replace("%", "\\%"); - let like_pattern = format!("{}%", escaped_pattern); - Expr::Literal(ScalarValue::Utf8View(Some(like_pattern))) - } _ => return Ok(ExprSimplifyResult::Original(args)), }; - return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like { - negated: false, - expr: Box::new(args[0].clone()), - pattern: Box::new(like_expr), - escape_char: None, - case_insensitive: false, - }))); + let expr_data_type = info.get_data_type(&args[0])?; + let pattern_data_type = info.get_data_type(&like_expr)?; + + if let Some(coercion_data_type) = + string_coercion(&expr_data_type, &pattern_data_type).or_else(|| { + binary_to_string_coercion(&expr_data_type, &pattern_data_type) + }) + { + let expr = if expr_data_type == coercion_data_type { + args[0].clone() + } else { + cast(args[0].clone(), coercion_data_type.clone()) + }; + + let pattern = if pattern_data_type == coercion_data_type { + like_expr + } else { + cast(like_expr, coercion_data_type) + }; + + return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like { + negated: false, + expr: Box::new(expr), + pattern: Box::new(pattern), + escape_char: None, + case_insensitive: false, + }))); + } } Ok(ExprSimplifyResult::Original(args)) diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 69cdd58b794d..754937e18f14 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -299,7 +299,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: starts_with(test.column1_utf8view, test.column2_utf8view) AS c1, starts_with(test.column1_utf8view, CAST(test.column2_utf8 AS Utf8View)) AS c2, starts_with(test.column1_utf8view, CAST(test.column2_large_utf8 AS Utf8View)) AS c3 +01)Projection: starts_with(test.column1_utf8view, test.column2_utf8view) AS c1, starts_with(test.column1_utf8view, test.column2_utf8) AS c2, starts_with(test.column1_utf8view, test.column2_large_utf8) AS c3 02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view, column2_utf8view] query BBB @@ -326,7 +326,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: starts_with(CAST(test.column1_utf8 AS Utf8View), test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(CAST(test.column1_utf8 AS LargeUtf8), test.column2_large_utf8) AS c4 +01)Projection: starts_with(test.column1_utf8, test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(test.column1_utf8, test.column2_large_utf8) AS c4 02)--TableScan: test projection=[column1_utf8, column2_utf8, column2_large_utf8, column2_utf8view] query BBB @@ -382,7 +382,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: starts_with(CAST(test.column1_utf8 AS Utf8View), substr(test.column1_utf8, Int64(1), Int64(2))) AS c1, starts_with(CAST(test.column1_large_utf8 AS Utf8View), substr(test.column1_large_utf8, Int64(1), Int64(2))) AS c2, starts_with(test.column1_utf8view, substr(test.column1_utf8view, Int64(1), Int64(2))) AS c3 +01)Projection: starts_with(test.column1_utf8, substr(test.column1_utf8, Int64(1), Int64(2))) AS c1, starts_with(test.column1_large_utf8, substr(test.column1_large_utf8, Int64(1), Int64(2))) AS c2, starts_with(test.column1_utf8view, substr(test.column1_utf8view, Int64(1), Int64(2))) AS c3 02)--TableScan: test projection=[column1_utf8, column1_large_utf8, column1_utf8view] query BBB