Skip to content

Commit

Permalink
feat: move SIMD gather to PrimitiveArray::take (#2538)
Browse files Browse the repository at this point in the history
SIMD gather on primitive values is more generally applicable than only
in dict decoding and is therefore moved to `PrimitiveArray::take`.
  • Loading branch information
0ax1 authored Feb 27, 2025
1 parent 5a23ffe commit 2b895da
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 108 deletions.
30 changes: 2 additions & 28 deletions encodings/dict/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,10 @@ use vortex_array::{
Array, ArrayCanonicalImpl, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl,
Canonical, Encoding, EncodingId, IntoArray, RkyvMetadata, ToCanonical, encoding_ids,
};
use vortex_dtype::{
DType, PType, match_each_integer_ptype, match_each_native_simd_ptype,
match_each_unsigned_integer_ptype,
};
use vortex_dtype::{DType, match_each_integer_ptype};
use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
use vortex_mask::{AllOr, Mask};

use crate::compress::dict_decode_typed_primitive;
use crate::serde::DictMetadata;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -101,29 +97,7 @@ impl ArrayCanonicalImpl for DictArray {
let canonical_values: ArrayRef = self.values().to_canonical()?.into_array();
take(&canonical_values, self.codes())?.to_canonical()
}
DType::Primitive(ptype, _)
// TODO(alex): handle nullable codes & values
if *ptype != PType::F16
&& self.codes().all_valid()?
&& self.values().all_valid()? =>
{
let codes = self.codes().to_primitive()?;
let values = self.values().to_primitive()?;

match_each_unsigned_integer_ptype!(codes.ptype(), |$C| {
match_each_native_simd_ptype!(values.ptype(), |$V| {
// SIMD types larger than the SIMD register size are beneficial for
// performance as this leads to better instruction level parallelism.
let decoded = dict_decode_typed_primitive::<$C, $V, 64>(
codes.as_slice(),
values.as_slice(),
self.dtype().nullability(),
);
decoded.to_canonical()
})
})
}
_ => take(self.values(), self.codes())?.to_canonical()
_ => take(self.values(), self.codes())?.to_canonical(),
}
}

Expand Down
74 changes: 0 additions & 74 deletions encodings/dict/src/compress.rs

This file was deleted.

3 changes: 0 additions & 3 deletions encodings/dict/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#![feature(portable_simd)]

//! Implementation of Dictionary encoding.
//!
//! Expose a [DictArray] which is zero-copy equivalent to Arrow's
Expand All @@ -8,7 +6,6 @@ pub use array::*;

mod array;
pub mod builders;
mod compress;
mod compute;
mod serde;
mod stats;
Expand Down
98 changes: 96 additions & 2 deletions vortex-array/src/arrays/primitive/compute/take.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
use std::simd;

use num_traits::AsPrimitive;
use vortex_buffer::Buffer;
use vortex_dtype::{NativePType, match_each_integer_ptype, match_each_native_ptype};
use simd::num::SimdUint;
use vortex_buffer::{Alignment, Buffer, BufferMut};
use vortex_dtype::{
NativePType, Nullability, PType, match_each_integer_ptype, match_each_native_ptype,
match_each_native_simd_ptype, match_each_unsigned_integer_ptype,
};
use vortex_error::{VortexResult, vortex_err};
use vortex_mask::Mask;

Expand All @@ -17,6 +23,27 @@ impl TakeFn<&PrimitiveArray> for PrimitiveEncoding {
let indices = indices.to_primitive()?;
let validity = array.validity().take(&indices)?;

if array.ptype() != PType::F16
&& indices.dtype().is_unsigned_int()
&& indices.all_valid()?
&& array.all_valid()?
{
// TODO(alex): handle nullable codes & values
match_each_unsigned_integer_ptype!(indices.ptype(), |$C| {
match_each_native_simd_ptype!(array.ptype(), |$V| {
// SIMD types larger than the SIMD register size are beneficial for
// performance as this leads to better instruction level parallelism.
let decoded = take_primitive_simd::<$C, $V, 64>(
indices.as_slice(),
array.as_slice(),
array.dtype().nullability() | indices.dtype().nullability(),
);

return Ok(decoded.into_array()) as VortexResult<ArrayRef>;
})
});
}

match_each_native_ptype!(array.ptype(), |$T| {
match_each_integer_ptype!(indices.ptype(), |$I| {
let values = take_primitive(array.as_slice::<$T>(), indices.as_slice::<$I>());
Expand Down Expand Up @@ -74,6 +101,73 @@ fn take_primitive<T: NativePType, I: NativePType + AsPrimitive<usize>>(
indices.iter().map(|idx| array[idx.as_()]).collect()
}

/// Takes elements from an array using SIMD indexing.
///
/// # Type Parameters
/// * `C` - Index type
/// * `V` - Value type
/// * `LANE_COUNT` - Number of SIMD lanes to process in parallel
///
/// # Parameters
/// * `indices` - Indices to gather values from
/// * `values` - Source values to index
/// * `nullability` - Nullability of the resulting array
///
/// # Returns
/// A `PrimitiveArray` containing the gathered values where each index has been replaced with
/// the corresponding value from the source array.
fn take_primitive_simd<I, V, const LANE_COUNT: usize>(
indices: &[I],
values: &[V],
nullability: Nullability,
) -> PrimitiveArray
where
I: simd::SimdElement + AsPrimitive<usize>,
V: simd::SimdElement + NativePType,
simd::LaneCount<LANE_COUNT>: simd::SupportedLaneCount,
simd::Simd<I, LANE_COUNT>: SimdUint<Cast<usize> = simd::Simd<usize, LANE_COUNT>>,
{
let indices_len = indices.len();

let mut buffer = BufferMut::<V>::with_capacity_aligned(
indices_len,
Alignment::of::<simd::Simd<V, LANE_COUNT>>(),
);

let buf_slice = buffer.spare_capacity_mut();

for chunk_idx in 0..(indices_len / LANE_COUNT) {
let offset = chunk_idx * LANE_COUNT;
let mask = simd::Mask::from_bitmask(u64::MAX);
let codes_chunk = simd::Simd::<I, LANE_COUNT>::from_slice(&indices[offset..]);

unsafe {
let selection = simd::Simd::gather_select_unchecked(
values,
mask,
codes_chunk.cast::<usize>(),
simd::Simd::<V, LANE_COUNT>::default(),
);

selection.store_select_ptr(buf_slice.as_mut_ptr().add(offset) as *mut V, mask.cast());
}
}

for idx in ((indices_len / LANE_COUNT) * LANE_COUNT)..indices_len {
unsafe {
buf_slice
.get_unchecked_mut(idx)
.write(values[indices[idx].as_()]);
}
}

unsafe {
buffer.set_len(indices_len);
}

PrimitiveArray::new(buffer.freeze(), nullability.into())
}

#[cfg(test)]
mod test {
use vortex_buffer::buffer;
Expand Down
3 changes: 2 additions & 1 deletion vortex-array/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![feature(once_cell_try)]
#![feature(trusted_len)]
#![feature(portable_simd)]
#![feature(substr_range)]
#![feature(trusted_len)]
//! Vortex crate containing core logic for encoding and memory representation of [arrays](ArrayRef).
//!
//! At the heart of Vortex are [arrays](ArrayRef) and [encodings](vtable::EncodingVTable).
Expand Down

0 comments on commit 2b895da

Please sign in to comment.