Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: fuse index runend decoding - take_from #2527

Merged
merged 17 commits into from
Feb 28, 2025
3 changes: 0 additions & 3 deletions encodings/dict/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ impl ScalarAtFn<&DictArray> for DictEncoding {

impl TakeFn<&DictArray> for DictEncoding {
fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
// Dict
// codes: 0 0 1
// dict: a b c d e f g h
let codes = take(array.codes(), indices)?;
DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
}
Expand Down
22 changes: 22 additions & 0 deletions encodings/runend/benches/run_end_compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use itertools::repeat_n;
use num_traits::PrimInt;
use vortex_array::Array;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::compute::take;
use vortex_array::validity::Validity;
use vortex_buffer::Buffer;
use vortex_dtype::NativePType;
Expand Down Expand Up @@ -62,3 +63,24 @@ fn decompress<T: NativePType + PrimInt>(bencher: Bencher, (length, run_step): (u
.with_inputs(|| runend_array.to_array())
.bench_values(|array| array.to_canonical().unwrap());
}

#[divan::bench(args = BENCH_ARGS)]
#[allow(clippy::cast_possible_truncation)]
fn take_indices(bencher: Bencher, (length, run_step): (usize, usize)) {
let values = PrimitiveArray::new(
(0..length)
.step_by(run_step)
.enumerate()
.flat_map(|(idx, x)| repeat_n(idx as u64, x))
.collect::<Buffer<_>>(),
Validity::NonNullable,
);

let source_array = PrimitiveArray::from_iter(0..(length as i32)).into_array();
let (ends, values) = runend_encode(&values).unwrap();
let runend_array = RunEndArray::try_new(ends.into_array(), values).unwrap();

bencher
.with_inputs(|| (source_array.clone(), runend_array.to_array()))
.bench_refs(|(array, indices)| take(array, indices).unwrap());
}
6 changes: 6 additions & 0 deletions encodings/runend/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ mod invert;
mod scalar_at;
mod slice;
pub(crate) mod take;
mod take_from;

use vortex_array::Array;
use vortex_array::compute::{
BinaryNumericFn, CompareFn, FillNullFn, FilterFn, InvertFn, ScalarAtFn, SliceFn, TakeFn,
TakeFromFn,
};
use vortex_array::vtable::ComputeVTable;

Expand Down Expand Up @@ -47,6 +49,10 @@ impl ComputeVTable for RunEndEncoding {
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
Some(self)
}

fn take_from_fn(&self) -> Option<&dyn TakeFromFn<&dyn Array>> {
Some(self)
}
}

#[cfg(test)]
Expand Down
1 change: 1 addition & 0 deletions encodings/runend/src/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ impl TakeFn<&RunEndArray> for RunEndEncoding {
})
.collect::<VortexResult<Vec<_>>>()?
});

take_indices_unchecked(array, &checked_indices)
}
}
Expand Down
45 changes: 45 additions & 0 deletions encodings/runend/src/compute/take_from.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use vortex_array::compute::{TakeFromFn, take};
use vortex_array::{Array, ArrayRef};
use vortex_dtype::DType;
use vortex_error::VortexResult;

use crate::{RunEndArray, RunEndEncoding};

impl TakeFromFn<&RunEndArray> for RunEndEncoding {
/// Takes values from the source array using run-end encoded indices.
///
/// # Arguments
///
/// * `indices` - Run-end encoded indices
/// * `source` - Array to take values from
///
/// # Returns
///
/// * `Ok(Some(source))` - If successful
/// * `Ok(None)` - If the source array has an unsupported dtype
///
fn take_from(
&self,
indices: &RunEndArray,
source: &dyn Array,
) -> VortexResult<Option<ArrayRef>> {
// Only `Primitive` and `Bool` are valid run-end value types. - TODO: Support additional DTypes
if !matches!(source.dtype(), DType::Primitive(_, _) | DType::Bool(_)) {
return Ok(None);
}

// Transform the run-end encoding from storing indices to storing values
// by taking values from `source` at positions specified by `indices.values()`.
let values = take(source, indices.values())?;

// Create a new run-end array containing values as values, instead of indices as values.
let ree_array = RunEndArray::with_offset_and_length(
indices.ends().clone(),
values,
indices.offset(),
indices.len(),
)?;

Ok(Some(ree_array.into_array()))
}
}
2 changes: 2 additions & 0 deletions vortex-array/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub use search_sorted::*;
pub use slice::{SliceFn, slice};
pub use sum::*;
pub use take::{TakeFn, take, take_into};
pub use take_from::TakeFromFn;
pub use to_arrow::*;

mod between;
Expand All @@ -49,6 +50,7 @@ mod search_sorted;
mod slice;
mod sum;
mod take;
mod take_from;
mod to_arrow;

#[cfg(feature = "test-harness")]
Expand Down
7 changes: 7 additions & 0 deletions vortex-array/src/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,13 @@ fn derive_take_stats(arr: &dyn Array) -> StatsSet {
}

fn take_impl(array: &dyn Array, indices: &dyn Array) -> VortexResult<ArrayRef> {
// First look for a TakeFrom specialized on the indices.
if let Some(take_from_fn) = indices.vtable().take_from_fn() {
if let Some(arr) = take_from_fn.take_from(indices, array)? {
return Ok(arr);
}
}

// If TakeFn defined for the encoding, delegate to TakeFn.
// If we know from stats that indices are all valid, we can avoid all bounds checks.
if let Some(take_fn) = array.vtable().take_fn() {
Expand Down
22 changes: 22 additions & 0 deletions vortex-array/src/compute/take_from.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use vortex_error::{VortexExpect, VortexResult};

use crate::encoding::Encoding;
use crate::{Array, ArrayRef};

pub trait TakeFromFn<A> {
fn take_from(&self, indices: A, array: &dyn Array) -> VortexResult<Option<ArrayRef>>;
}

impl<E: Encoding> TakeFromFn<&dyn Array> for E
where
E: for<'a> TakeFromFn<&'a E::Array>,
{
fn take_from(&self, indices: &dyn Array, array: &dyn Array) -> VortexResult<Option<ArrayRef>> {
let indices = indices
.as_any()
.downcast_ref::<E::Array>()
.vortex_expect("Failed to downcast array");

TakeFromFn::take_from(self, indices, array)
}
}
6 changes: 5 additions & 1 deletion vortex-array/src/vtable/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::Array;
use crate::compute::{
BetweenFn, BinaryBooleanFn, BinaryNumericFn, CastFn, CompareFn, FillForwardFn, FillNullFn,
FilterFn, InvertFn, IsConstantFn, LikeFn, MaskFn, MinMaxFn, ScalarAtFn, SearchSortedFn,
SearchSortedUsizeFn, SliceFn, SumFn, TakeFn, ToArrowFn,
SearchSortedUsizeFn, SliceFn, SumFn, TakeFn, TakeFromFn, ToArrowFn,
};

/// VTable for dispatching compute functions to Vortex encodings.
Expand Down Expand Up @@ -133,6 +133,10 @@ pub trait ComputeVTable {
None
}

fn take_from_fn(&self) -> Option<&dyn TakeFromFn<&dyn Array>> {
None
}

/// Convert the array to an Arrow array of the given type.
///
/// See: [ToArrowFn].
Expand Down
Loading