Skip to content

Commit

Permalink
Merge remote-tracking branch 'apache/main' into alamb/deprecation_policy
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Dec 17, 2024
2 parents 1f5bf65 + 54dccad commit 7632458
Show file tree
Hide file tree
Showing 33 changed files with 846 additions and 192 deletions.
2 changes: 1 addition & 1 deletion arrow-array/src/array/dictionary_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ pub struct DictionaryArray<K: ArrowDictionaryKeyType> {
/// map to the real values.
keys: PrimitiveArray<K>,

/// Array of dictionary values (can by any DataType).
/// Array of dictionary values (can be any DataType).
values: ArrayRef,

/// Values are ordered.
Expand Down
49 changes: 28 additions & 21 deletions arrow-cast/src/cast/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,13 @@ where
O::Native::from_decimal(adjusted)
};

Ok(match cast_options.safe {
true => array.unary_opt(f),
false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
Ok(if cast_options.safe {
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
} else {
array.try_unary(|x| {
f(x).ok_or_else(|| error(x))
.and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
})?
})
}

Expand All @@ -137,15 +141,20 @@ where

let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());

Ok(match cast_options.safe {
true => array.unary_opt(f),
false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,
Ok(if cast_options.safe {
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
} else {
array.try_unary(|x| {
f(x).ok_or_else(|| error(x))
.and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v))
})?
})
}

// Only support one type of decimal cast operations
pub(crate) fn cast_decimal_to_decimal_same_type<T>(
array: &PrimitiveArray<T>,
input_precision: u8,
input_scale: i8,
output_precision: u8,
output_scale: i8,
Expand All @@ -155,29 +164,27 @@ where
T: DecimalType,
T::Native: DecimalCast + ArrowNativeTypeOp,
{
let array: PrimitiveArray<T> = match input_scale.cmp(&output_scale) {
Ordering::Equal => {
// the scale doesn't change, the native value don't need to be changed
let array: PrimitiveArray<T> =
if input_scale == output_scale && input_precision <= output_precision {
array.clone()
}
Ordering::Greater => convert_to_smaller_scale_decimal::<T, T>(
array,
input_scale,
output_precision,
output_scale,
cast_options,
)?,
Ordering::Less => {
// input_scale < output_scale
} else if input_scale < output_scale {
// the scale doesn't change, but precision may change and cause overflow
convert_to_bigger_or_equal_scale_decimal::<T, T>(
array,
input_scale,
output_precision,
output_scale,
cast_options,
)?
}
};
} else {
convert_to_smaller_scale_decimal::<T, T>(
array,
input_scale,
output_precision,
output_scale,
cast_options,
)?
};

Ok(Arc::new(array.with_precision_and_scale(
output_precision,
Expand Down
122 changes: 104 additions & 18 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -830,18 +830,20 @@ pub fn cast_with_options(
(Map(_, ordered1), Map(_, ordered2)) if ordered1 == ordered2 => {
cast_map_values(array.as_map(), to_type, cast_options, ordered1.to_owned())
}
(Decimal128(_, s1), Decimal128(p2, s2)) => {
(Decimal128(p1, s1), Decimal128(p2, s2)) => {
cast_decimal_to_decimal_same_type::<Decimal128Type>(
array.as_primitive(),
*p1,
*s1,
*p2,
*s2,
cast_options,
)
}
(Decimal256(_, s1), Decimal256(p2, s2)) => {
(Decimal256(p1, s1), Decimal256(p2, s2)) => {
cast_decimal_to_decimal_same_type::<Decimal256Type>(
array.as_primitive(),
*p1,
*s1,
*p2,
*s2,
Expand Down Expand Up @@ -1472,7 +1474,7 @@ pub fn cast_with_options(
(BinaryView, _) => Err(ArrowError::CastError(format!(
"Casting from {from_type:?} to {to_type:?} not supported",
))),
(from_type, Utf8View) if from_type.is_numeric() => {
(from_type, Utf8View) if from_type.is_primitive() => {
value_to_string_view(array, cast_options)
}
(from_type, LargeUtf8) if from_type.is_primitive() => {
Expand Down Expand Up @@ -2694,13 +2696,16 @@ mod tests {
// negative test
let array = vec![Some(123456), None];
let array = create_decimal_array(array, 10, 0).unwrap();
let result = cast(&array, &DataType::Decimal128(2, 2));
assert!(result.is_ok());
let array = result.unwrap();
let array: &Decimal128Array = array.as_primitive();
let err = array.validate_decimal_precision(2);
let result_safe = cast(&array, &DataType::Decimal128(2, 2));
assert!(result_safe.is_ok());
let options = CastOptions {
safe: false,
..Default::default()
};

let result_unsafe = cast_with_options(&array, &DataType::Decimal128(2, 2), &options);
assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal128 of precision 2. Max is 99",
err.unwrap_err().to_string());
result_unsafe.unwrap_err().to_string());
}

#[test]
Expand Down Expand Up @@ -5258,9 +5263,6 @@ mod tests {
assert_eq!("2018-12-25T00:00:00", c.value(1));
}

// Cast Timestamp to Utf8View is not supported yet
// TODO: Implement casting from Timestamp to Utf8View
// https://github.com/apache/arrow-rs/issues/6734
macro_rules! assert_cast_timestamp_to_string {
($array:expr, $datatype:expr, $output_array_type: ty, $expected:expr) => {{
let out = cast(&$array, &$datatype).unwrap();
Expand Down Expand Up @@ -5295,7 +5297,7 @@ mod tests {
None,
];

// assert_cast_timestamp_to_string!(array, DataType::Utf8View, StringViewArray, expected);
assert_cast_timestamp_to_string!(array, DataType::Utf8View, StringViewArray, expected);
assert_cast_timestamp_to_string!(array, DataType::Utf8, StringArray, expected);
assert_cast_timestamp_to_string!(array, DataType::LargeUtf8, LargeStringArray, expected);
}
Expand All @@ -5319,7 +5321,13 @@ mod tests {
Some("2018-12-25 00:00:02.001000"),
None,
];
// assert_cast_timestamp_to_string!(array_without_tz, DataType::Utf8View, StringViewArray, cast_options, expected);
assert_cast_timestamp_to_string!(
array_without_tz,
DataType::Utf8View,
StringViewArray,
cast_options,
expected
);
assert_cast_timestamp_to_string!(
array_without_tz,
DataType::Utf8,
Expand All @@ -5343,7 +5351,13 @@ mod tests {
Some("2018-12-25 05:45:02.001000"),
None,
];
// assert_cast_timestamp_to_string!(array_with_tz, DataType::Utf8View, StringViewArray, cast_options, expected);
assert_cast_timestamp_to_string!(
array_with_tz,
DataType::Utf8View,
StringViewArray,
cast_options,
expected
);
assert_cast_timestamp_to_string!(
array_with_tz,
DataType::Utf8,
Expand Down Expand Up @@ -8451,7 +8465,7 @@ mod tests {
let input_type = DataType::Decimal128(10, 3);
let output_type = DataType::Decimal256(10, 5);
assert!(can_cast_types(&input_type, &output_type));
let array = vec![Some(i128::MAX), Some(i128::MIN)];
let array = vec![Some(123456), Some(-123456)];
let input_decimal_array = create_decimal_array(array, 10, 3).unwrap();
let array = Arc::new(input_decimal_array) as ArrayRef;

Expand All @@ -8461,8 +8475,8 @@ mod tests {
Decimal256Array,
&output_type,
vec![
Some(i256::from_i128(i128::MAX).mul_wrapping(hundred)),
Some(i256::from_i128(i128::MIN).mul_wrapping(hundred))
Some(i256::from_i128(123456).mul_wrapping(hundred)),
Some(i256::from_i128(-123456).mul_wrapping(hundred))
]
);
}
Expand Down Expand Up @@ -9926,4 +9940,76 @@ mod tests {
"Cast non-nullable to non-nullable struct field returning null should fail",
);
}

#[test]
fn test_decimal_to_decimal_throw_error_on_precision_overflow_same_scale() {
let array = vec![Some(123456789)];
let array = create_decimal_array(array, 24, 2).unwrap();
println!("{:?}", array);
let input_type = DataType::Decimal128(24, 2);
let output_type = DataType::Decimal128(6, 2);
assert!(can_cast_types(&input_type, &output_type));

let options = CastOptions {
safe: false,
..Default::default()
};
let result = cast_with_options(&array, &output_type, &options);
assert_eq!(result.unwrap_err().to_string(),
"Invalid argument error: 123456790 is too large to store in a Decimal128 of precision 6. Max is 999999");
}

#[test]
fn test_decimal_to_decimal_throw_error_on_precision_overflow_lower_scale() {
let array = vec![Some(123456789)];
let array = create_decimal_array(array, 24, 2).unwrap();
println!("{:?}", array);
let input_type = DataType::Decimal128(24, 4);
let output_type = DataType::Decimal128(6, 2);
assert!(can_cast_types(&input_type, &output_type));

let options = CastOptions {
safe: false,
..Default::default()
};
let result = cast_with_options(&array, &output_type, &options);
assert_eq!(result.unwrap_err().to_string(),
"Invalid argument error: 123456790 is too large to store in a Decimal128 of precision 6. Max is 999999");
}

#[test]
fn test_decimal_to_decimal_throw_error_on_precision_overflow_greater_scale() {
let array = vec![Some(123456789)];
let array = create_decimal_array(array, 24, 2).unwrap();
println!("{:?}", array);
let input_type = DataType::Decimal128(24, 2);
let output_type = DataType::Decimal128(6, 3);
assert!(can_cast_types(&input_type, &output_type));

let options = CastOptions {
safe: false,
..Default::default()
};
let result = cast_with_options(&array, &output_type, &options);
assert_eq!(result.unwrap_err().to_string(),
"Invalid argument error: 1234567890 is too large to store in a Decimal128 of precision 6. Max is 999999");
}

#[test]
fn test_decimal_to_decimal_throw_error_on_precision_overflow_diff_type() {
let array = vec![Some(123456789)];
let array = create_decimal_array(array, 24, 2).unwrap();
println!("{:?}", array);
let input_type = DataType::Decimal128(24, 2);
let output_type = DataType::Decimal256(6, 2);
assert!(can_cast_types(&input_type, &output_type));

let options = CastOptions {
safe: false,
..Default::default()
};
let result = cast_with_options(&array, &output_type, &options);
assert_eq!(result.unwrap_err().to_string(),
"Invalid argument error: 123456789 is too large to store in a Decimal256 of precision 6. Max is 999999");
}
}
9 changes: 9 additions & 0 deletions arrow-flight/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,10 @@ fn prepare_field_for_flight(
)
.with_metadata(field.metadata().clone())
} else {
#[allow(deprecated)]
let dict_id = dictionary_tracker.set_dict_id(field.as_ref());

#[allow(deprecated)]
Field::new_dict(
field.name(),
field.data_type().clone(),
Expand Down Expand Up @@ -583,7 +585,9 @@ fn prepare_schema_for_flight(
)
.with_metadata(field.metadata().clone())
} else {
#[allow(deprecated)]
let dict_id = dictionary_tracker.set_dict_id(field.as_ref());
#[allow(deprecated)]
Field::new_dict(
field.name(),
field.data_type().clone(),
Expand Down Expand Up @@ -650,10 +654,12 @@ struct FlightIpcEncoder {

impl FlightIpcEncoder {
fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self {
#[allow(deprecated)]
let preserve_dict_id = options.preserve_dict_id();
Self {
options,
data_gen: IpcDataGenerator::default(),
#[allow(deprecated)]
dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id(
error_on_replacement,
preserve_dict_id,
Expand Down Expand Up @@ -1541,6 +1547,7 @@ mod tests {
async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
let expected_schema = batches.first().unwrap().schema();

#[allow(deprecated)]
let encoder = FlightDataEncoderBuilder::default()
.with_options(IpcWriteOptions::default().with_preserve_dict_id(false))
.with_dictionary_handling(DictionaryHandling::Resend)
Expand Down Expand Up @@ -1568,6 +1575,7 @@ mod tests {
HashMap::from([("some_key".to_owned(), "some_value".to_owned())]),
);

#[allow(deprecated)]
let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true);

let got = prepare_schema_for_flight(&schema, &mut dictionary_tracker, false);
Expand Down Expand Up @@ -1598,6 +1606,7 @@ mod tests {
options: &IpcWriteOptions,
) -> (Vec<FlightData>, FlightData) {
let data_gen = IpcDataGenerator::default();
#[allow(deprecated)]
let mut dictionary_tracker =
DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());

Expand Down
1 change: 1 addition & 0 deletions arrow-flight/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ pub struct IpcMessage(pub Bytes);

fn flight_schema_as_encoded_data(arrow_schema: &Schema, options: &IpcWriteOptions) -> EncodedData {
let data_gen = writer::IpcDataGenerator::default();
#[allow(deprecated)]
let mut dict_tracker =
writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
data_gen.schema_to_bytes_with_dictionary_tracker(arrow_schema, &mut dict_tracker, options)
Expand Down
Loading

0 comments on commit 7632458

Please sign in to comment.