diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs index 48f0412bf8c7..0db32d575761 100644 --- a/arrow-arith/src/arithmetic.rs +++ b/arrow-arith/src/arithmetic.rs @@ -1333,6 +1333,45 @@ where }); } +/// Perform `left % right` operation on two arrays. If either left or right value is null +/// then the result is also null. If any right hand value is zero then the result of this +/// operation will be `Err(ArrowError::DivideByZero)`. +pub fn modulus_dyn(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!( + left, + right, + |a, b| { + if b.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(a.mod_wrapping(b)) + } + }, + math_divide_checked_op_dict + ) + } + _ => { + downcast_primitive_array!( + (left, right) => { + math_checked_divide_op(left, right, |a, b| { + if b.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(a.mod_wrapping(b)) + } + }).map(|a| Arc::new(a) as ArrayRef) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported data type {}, {}", + left.data_type(), right.data_type() + ))) + ) + } + } +} + /// Perform `left / right` operation on two arrays. If either left or right value is null /// then the result is also null. If any right hand value is zero then the result of this /// operation will be `Err(ArrowError::DivideByZero)`. @@ -1551,6 +1590,23 @@ where Ok(unary(array, |a| a.mod_wrapping(modulo))) } +/// Modulus every value in an array by a scalar. If any value in the array is null then the +/// result is also null. If the scalar is zero then the result of this operation will be +/// `Err(ArrowError::DivideByZero)`. +pub fn modulus_scalar_dyn( + array: &dyn Array, + modulo: T::Native, +) -> Result +where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, +{ + if modulo.is_zero() { + return Err(ArrowError::DivideByZero); + } + unary_dyn::<_, T>(array, |value| value.mod_wrapping(modulo)) +} + /// Divide every value in an array by a scalar. If any value in the array is null then the /// result is also null. If the scalar is zero then the result of this operation will be /// `Err(ArrowError::DivideByZero)`. @@ -2170,6 +2226,14 @@ mod tests { assert_eq!(0, c.value(2)); assert_eq!(1, c.value(3)); assert_eq!(0, c.value(4)); + + let c = modulus_dyn(&a, &b).unwrap(); + let c = as_primitive_array::(&c); + assert_eq!(0, c.value(0)); + assert_eq!(3, c.value(1)); + assert_eq!(0, c.value(2)); + assert_eq!(1, c.value(3)); + assert_eq!(0, c.value(4)); } #[test] @@ -2182,6 +2246,16 @@ mod tests { modulus(&a, &b).unwrap(); } + #[test] + #[should_panic( + expected = "called `Result::unwrap()` on an `Err` value: DivideByZero" + )] + fn test_int_array_modulus_dyn_divide_by_zero() { + let a = Int32Array::from(vec![1]); + let b = Int32Array::from(vec![0]); + modulus_dyn(&a, &b).unwrap(); + } + #[test] fn test_int_array_modulus_overflow_wrapping() { let a = Int32Array::from(vec![i32::MIN]); @@ -2258,6 +2332,11 @@ mod tests { let c = modulus_scalar(&a, b).unwrap(); let expected = Int32Array::from(vec![0, 2, 0, 2, 1]); assert_eq!(c, expected); + + let c = modulus_scalar_dyn::(&a, b).unwrap(); + let c = as_primitive_array::(&c); + let expected = Int32Array::from(vec![0, 2, 0, 2, 1]); + assert_eq!(c, &expected); } #[test] @@ -2268,6 +2347,11 @@ mod tests { let actual = modulus_scalar(a, 3).unwrap(); let expected = Int32Array::from(vec![None, Some(0), Some(2), None]); assert_eq!(actual, expected); + + let actual = modulus_scalar_dyn::(a, 3).unwrap(); + let actual = as_primitive_array::(&actual); + let expected = Int32Array::from(vec![None, Some(0), Some(2), None]); + assert_eq!(actual, &expected); } #[test] @@ -2283,7 +2367,11 @@ mod tests { fn test_int_array_modulus_scalar_overflow_wrapping() { let a = Int32Array::from(vec![i32::MIN]); let result = modulus_scalar(&a, -1).unwrap(); - assert_eq!(0, result.value(0)) + assert_eq!(0, result.value(0)); + + let result = modulus_scalar_dyn::(&a, -1).unwrap(); + let result = as_primitive_array::(&result); + assert_eq!(0, result.value(0)); } #[test] @@ -2566,6 +2654,14 @@ mod tests { modulus(&a, &b).unwrap(); } + #[test] + #[should_panic(expected = "DivideByZero")] + fn test_i32_array_modulus_dyn_by_zero() { + let a = Int32Array::from(vec![15]); + let b = Int32Array::from(vec![0]); + modulus_dyn(&a, &b).unwrap(); + } + #[test] #[should_panic(expected = "DivideByZero")] fn test_f32_array_modulus_by_zero() { @@ -2574,6 +2670,14 @@ mod tests { modulus(&a, &b).unwrap(); } + #[test] + #[should_panic(expected = "DivideByZero")] + fn test_f32_array_modulus_dyn_by_zero() { + let a = Float32Array::from(vec![1.5]); + let b = Float32Array::from(vec![0.0]); + modulus_dyn(&a, &b).unwrap(); + } + #[test] fn test_f64_array_divide() { let a = Float64Array::from(vec![15.0, 15.0, 8.0]);