From 926f90af4996b5f0b61d7ec3ad80805b5a52a517 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Thu, 9 Apr 2026 16:07:41 +0100 Subject: [PATCH 1/8] TakeExecute for FilterArray Signed-off-by: Robert Kruszewski --- vortex-array/public-api.lock | 12 + vortex-array/src/arrays/filter/mod.rs | 1 + vortex-array/src/arrays/filter/take.rs | 1119 ++++++++++++++++++++++ vortex-array/src/arrays/filter/vtable.rs | 10 + vortex-buffer/public-api.lock | 4 + vortex-buffer/src/bit/buf.rs | 33 + vortex-buffer/src/bit/count_ones.rs | 6 +- vortex-buffer/src/bit/mod.rs | 1 + vortex-buffer/src/bit/select.rs | 690 +++++++++++++ vortex-mask/Cargo.toml | 4 - vortex-mask/public-api.lock | 2 + vortex-mask/src/lib.rs | 48 +- vortex-mask/src/tests.rs | 16 + 13 files changed, 1939 insertions(+), 7 deletions(-) create mode 100644 vortex-array/src/arrays/filter/take.rs create mode 100644 vortex-buffer/src/bit/select.rs diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index a790196e149..85750e6aee1 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -2764,6 +2764,10 @@ impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Extension pub fn vortex_array::arrays::Extension::take(vortex_array::ArrayView<'_, vortex_array::arrays::Extension>, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Filter + +pub fn vortex_array::arrays::Filter::take(vortex_array::ArrayView<'_, vortex_array::arrays::Filter>, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::FixedSizeList pub fn vortex_array::arrays::FixedSizeList::take(vortex_array::ArrayView<'_, vortex_array::arrays::FixedSizeList>, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -2980,6 +2984,10 @@ impl 
vortex_array::ValidityVTable for vortex_array pub fn vortex_array::arrays::Filter::validity(vortex_array::ArrayView<'_, vortex_array::arrays::Filter>) -> vortex_error::VortexResult +impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Filter + +pub fn vortex_array::arrays::Filter::take(vortex_array::ArrayView<'_, vortex_array::arrays::Filter>, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub struct vortex_array::arrays::filter::FilterData impl vortex_array::arrays::filter::FilterData @@ -6248,6 +6256,10 @@ impl vortex_array::ValidityVTable for vortex_array pub fn vortex_array::arrays::Filter::validity(vortex_array::ArrayView<'_, vortex_array::arrays::Filter>) -> vortex_error::VortexResult +impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Filter + +pub fn vortex_array::arrays::Filter::take(vortex_array::ArrayView<'_, vortex_array::arrays::Filter>, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub struct vortex_array::arrays::FixedSizeList impl core::clone::Clone for vortex_array::arrays::FixedSizeList diff --git a/vortex-array/src/arrays/filter/mod.rs b/vortex-array/src/arrays/filter/mod.rs index 39859bc5b25..15a987e8f20 100644 --- a/vortex-array/src/arrays/filter/mod.rs +++ b/vortex-array/src/arrays/filter/mod.rs @@ -16,6 +16,7 @@ pub use kernel::FilterReduce; pub use kernel::FilterReduceAdaptor; mod rules; +mod take; mod vtable; pub use vtable::Filter; diff --git a/vortex-array/src/arrays/filter/take.rs b/vortex-array/src/arrays/filter/take.rs new file mode 100644 index 00000000000..2733e93b65f --- /dev/null +++ b/vortex-array/src/arrays/filter/take.rs @@ -0,0 +1,1119 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_buffer::Buffer; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_mask::AllOr; +use 
vortex_mask::Mask; + +use super::Filter; +use crate::ArrayRef; +use crate::IntoArray; +use crate::array::ArrayView; +use crate::arrays::Decimal; +use crate::arrays::DecimalArray; +use crate::arrays::Primitive; +use crate::arrays::PrimitiveArray; +use crate::arrays::decimal::DecimalArrayExt; +use crate::arrays::dict::TakeExecute; +use crate::arrays::dict::TakeExecuteAdaptor; +use crate::arrays::filter::FilterArrayExt; +use crate::builtins::ArrayBuiltins; +use crate::dtype::DType; +use crate::dtype::IntegerPType; +use crate::dtype::NativeDecimalType; +use crate::dtype::NativePType; +use crate::executor::ExecutionCtx; +use crate::kernel::ParentKernelSet; +use crate::match_each_decimal_value_type; +use crate::match_each_integer_ptype; +use crate::match_each_native_ptype; +use crate::validity::Validity; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(Filter))]); + +const NULLABLE_FIXED_WIDTH_FULL_TAKE_FALLBACK_LEN: usize = 4096; + +fn take_impl( + array: ArrayView<'_, Filter>, + indices: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let indices_validity = indices.validity()?.execute_mask(indices.len(), ctx)?; + if let Some(taken) = take_primitive_fast_path(array, indices, &indices_validity)? { + return Ok(taken); + } + if let Some(taken) = take_decimal_fast_path(array, indices, &indices_validity)? 
{ + return Ok(taken); + } + + if indices_validity.all_true() { + let translated = translate_indices_fast(array.filter_mask(), indices, array.len())?; + let translated_indices = PrimitiveArray::new( + translated, + Validity::from_mask(indices_validity, indices.dtype().nullability()), + ) + .into_array(); + + return array.child().take(translated_indices); + } + + let translated = translate_nullable_indices_fast( + array.filter_mask(), + indices, + &indices_validity, + array.len(), + )?; + let translated_indices = PrimitiveArray::new( + translated, + Validity::from_mask(indices_validity, indices.dtype().nullability()), + ) + .into_array(); + + array.child().take(translated_indices) +} + +fn should_fallback_nullable_fixed_width_full_take( + array: ArrayView<'_, Filter>, + indices: &ArrayRef, +) -> VortexResult { + if indices.len() < NULLABLE_FIXED_WIDTH_FULL_TAKE_FALLBACK_LEN + || indices.len() < array.len() + || !indices.dtype().is_nullable() + { + return Ok(false); + } + + // For large nullable full-or-larger fixed-width takes with nullable children, materializing the + // filter first and then taking the child beats translating every nullable rank through the + // parent. + if array.child().dtype().is_decimal() || array.child().dtype().is_primitive() { + return Ok(!array.child().validity()?.no_nulls()); + } + + Ok(false) +} + +fn translate_nullable_indices_fast( + filter: &Mask, + indices: &PrimitiveArray, + indices_validity: &Mask, + filtered_len: usize, +) -> VortexResult> { + match_each_integer_ptype!(indices.ptype(), |P| { + translate_nullable_ranks_fast( + filter, + indices.as_slice::

(), + indices_validity, + filtered_len, + ) + }) +} + +fn translate_nullable_ranks_fast( + filter: &Mask, + ranks: &[P], + indices_validity: &Mask, + filtered_len: usize, +) -> VortexResult> { + if indices_validity.all_true() { + return translate_ranks_fast(filter, ranks, filtered_len); + } + if indices_validity.all_false() { + return Ok(Buffer::zeroed(ranks.len())); + } + + if let Some(start) = contiguous_filter_start(filter, filtered_len) { + return translate_nullable_ranks_with_offset(ranks, indices_validity, filtered_len, start); + } + + match filter.indices() { + AllOr::All => translate_nullable_ranks_identity(ranks, indices_validity, filtered_len), + AllOr::None => unreachable!("empty filters are handled by take preconditions"), + AllOr::Some(filter_indices) => translate_nullable_ranks_with_indices( + ranks, + indices_validity, + filtered_len, + filter_indices, + ), + } +} + +fn translate_nullable_ranks_with_offset( + ranks: &[P], + indices_validity: &Mask, + filtered_len: usize, + start: usize, +) -> VortexResult> { + let mut translated = BufferMut::::with_capacity(ranks.len()); + let translated_ptr = translated.spare_capacity_mut().as_mut_ptr().cast::(); + + for (idx, rank) in ranks.iter().enumerate() { + let translated_rank = if indices_validity.value(idx) { + let rank = validate_rank(*rank, filtered_len)?; + u64::try_from(start + rank)? + } else { + 0 + }; + + // SAFETY: `translated` has capacity for all ranks and this loop initializes each + // output slot once. + unsafe { translated_ptr.add(idx).write(translated_rank) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. 
+ unsafe { translated.set_len(ranks.len()) }; + Ok(translated.freeze()) +} + +fn translate_nullable_ranks_identity( + ranks: &[P], + indices_validity: &Mask, + filtered_len: usize, +) -> VortexResult> { + let mut translated = BufferMut::::with_capacity(ranks.len()); + let translated_ptr = translated.spare_capacity_mut().as_mut_ptr().cast::(); + + for (idx, rank) in ranks.iter().enumerate() { + let translated_rank = if indices_validity.value(idx) { + u64::try_from(validate_rank(*rank, filtered_len)?)? + } else { + 0 + }; + + // SAFETY: `translated` has capacity for all ranks and this loop initializes each + // output slot once. + unsafe { translated_ptr.add(idx).write(translated_rank) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { translated.set_len(ranks.len()) }; + Ok(translated.freeze()) +} + +fn translate_nullable_ranks_with_indices( + ranks: &[P], + indices_validity: &Mask, + filtered_len: usize, + filter_indices: &[usize], +) -> VortexResult> { + let mut translated = BufferMut::::with_capacity(ranks.len()); + let translated_ptr = translated.spare_capacity_mut().as_mut_ptr().cast::(); + + for (idx, rank) in ranks.iter().enumerate() { + let translated_rank = if indices_validity.value(idx) { + let rank = validate_rank(*rank, filtered_len)?; + // SAFETY: `rank` was checked against the filtered length, so it is in bounds for + // `filter_indices`; filter indices are valid child positions by construction. + unsafe { u64::try_from(*filter_indices.get_unchecked(rank))? } + } else { + 0 + }; + + // SAFETY: `translated` has capacity for all ranks and this loop initializes each + // output slot once. + unsafe { translated_ptr.add(idx).write(translated_rank) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. 
+ unsafe { translated.set_len(ranks.len()) }; + Ok(translated.freeze()) +} + +fn translate_indices_fast( + filter: &Mask, + indices: &PrimitiveArray, + filtered_len: usize, +) -> VortexResult> { + match_each_integer_ptype!(indices.ptype(), |P| { + translate_ranks_fast(filter, indices.as_slice::

(), filtered_len) + }) +} + +fn translate_ranks_fast( + filter: &Mask, + ranks: &[P], + filtered_len: usize, +) -> VortexResult> { + let mut translated = BufferMut::::with_capacity(ranks.len()); + let translated_ptr = translated.spare_capacity_mut().as_mut_ptr().cast::(); + + if let Some(start) = contiguous_filter_start(filter, filtered_len) { + for (idx, rank) in ranks.iter().enumerate() { + let Some(rank) = rank.to_usize() else { + vortex_bail!(OutOfBounds: 0, 0, filtered_len); + }; + if rank >= filtered_len { + vortex_bail!(OutOfBounds: rank, 0, filtered_len); + } + + // SAFETY: `translated` has capacity for all ranks and this loop initializes each + // output slot once. + unsafe { translated_ptr.add(idx).write(u64::try_from(start + rank)?) }; + } + } else { + let filter_indices = match filter.indices() { + AllOr::All => { + for (idx, rank) in ranks.iter().enumerate() { + let Some(rank) = rank.to_usize() else { + vortex_bail!(OutOfBounds: 0, 0, filtered_len); + }; + if rank >= filtered_len { + vortex_bail!(OutOfBounds: rank, 0, filtered_len); + } + + // SAFETY: `translated` has capacity for all ranks and this loop initializes + // each output slot once. + unsafe { translated_ptr.add(idx).write(u64::try_from(rank)?) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { translated.set_len(ranks.len()) }; + return Ok(translated.freeze()); + } + AllOr::None => unreachable!("empty filters are handled by take preconditions"), + AllOr::Some(filter_indices) => filter_indices, + }; + + for (idx, rank) in ranks.iter().enumerate() { + let Some(rank) = rank.to_usize() else { + vortex_bail!(OutOfBounds: 0, 0, filtered_len); + }; + if rank >= filtered_len { + vortex_bail!(OutOfBounds: rank, 0, filtered_len); + } + + // SAFETY: `translated` has capacity for all ranks. `rank` was checked against the + // filtered length, and filter indices are valid child positions by construction. 
+ unsafe { + translated_ptr + .add(idx) + .write(u64::try_from(*filter_indices.get_unchecked(rank))?) + }; + } + } + + // SAFETY: Each loop path writes exactly `ranks.len()` initialized values. + unsafe { translated.set_len(ranks.len()) }; + Ok(translated.freeze()) +} + +fn validate_rank(rank: P, filtered_len: usize) -> VortexResult { + let Some(rank) = rank.to_usize() else { + vortex_bail!(OutOfBounds: 0, 0, filtered_len); + }; + if rank >= filtered_len { + vortex_bail!(OutOfBounds: rank, 0, filtered_len); + } + Ok(rank) +} + +fn take_primitive_fast_path( + array: ArrayView<'_, Filter>, + indices: &PrimitiveArray, + indices_validity: &Mask, +) -> VortexResult> { + let Some(child) = array.child().as_opt::() else { + return Ok(None); + }; + + let child_validity = child.validity()?; + if !child_validity.no_nulls() { + return Ok(None); + } + + let output_validity = + Validity::from_mask(indices_validity.clone(), indices.dtype().nullability()); + match_each_native_ptype!(child.ptype(), |T| { + match_each_integer_ptype!(indices.ptype(), |P| { + take_primitive_fast_path_typed::( + child, + array.filter_mask(), + array.len(), + indices, + indices_validity, + output_validity, + ) + .map(Some) + }) + }) +} + +fn take_decimal_fast_path( + array: ArrayView<'_, Filter>, + indices: &PrimitiveArray, + indices_validity: &Mask, +) -> VortexResult> { + let Some(child) = array.child().as_opt::() else { + return Ok(None); + }; + + let child_validity = child.validity()?; + if !child_validity.no_nulls() { + return Ok(None); + } + + let output_validity = + Validity::from_mask(indices_validity.clone(), indices.dtype().nullability()); + match_each_decimal_value_type!(child.values_type(), |T| { + match_each_integer_ptype!(indices.ptype(), |P| { + take_decimal_fast_path_typed::( + child, + array.filter_mask(), + array.len(), + indices, + indices_validity, + output_validity, + ) + .map(Some) + }) + }) +} + +fn take_decimal_fast_path_typed( + child: ArrayView<'_, Decimal>, + filter: &Mask, + 
filtered_len: usize, + indices: &PrimitiveArray, + indices_validity: &Mask, + output_validity: Validity, +) -> VortexResult +where + T: NativeDecimalType, + P: IntegerPType, +{ + let ranks = indices.as_slice::

(); + let decimal_dtype = child.decimal_dtype(); + + if indices_validity.all_true() { + if let Some((start, end)) = contiguous_sequential_take_range(filter, ranks, filtered_len)? { + let values = child.buffer_handle().slice_typed::(start..end); + // SAFETY: The values are sliced from an existing valid decimal array, and the output + // validity was built for exactly the sliced take length. + return Ok(unsafe { + DecimalArray::new_unchecked_handle( + values, + T::DECIMAL_TYPE, + decimal_dtype, + output_validity, + ) + } + .into_array()); + } + + let taken = take_filtered_values::( + child.buffer::().as_slice(), + filter, + ranks, + filtered_len, + )?; + // SAFETY: Taking existing decimal values preserves the decimal dtype invariants, and the + // output validity was built for the take length. + return Ok( + unsafe { DecimalArray::new_unchecked(taken, decimal_dtype, output_validity) } + .into_array(), + ); + } + + let taken = take_filtered_values_nullable::( + child.buffer::().as_slice(), + filter, + ranks, + indices_validity, + filtered_len, + )?; + // SAFETY: Valid ranks copy existing decimal values, null ranks write default placeholders that + // are hidden by output validity, and the output validity was built for the take length. + Ok(unsafe { DecimalArray::new_unchecked(taken, decimal_dtype, output_validity) }.into_array()) +} + +fn take_primitive_fast_path_typed( + child: ArrayView<'_, Primitive>, + filter: &Mask, + filtered_len: usize, + indices: &PrimitiveArray, + indices_validity: &Mask, + output_validity: Validity, +) -> VortexResult +where + T: NativePType, + P: IntegerPType, +{ + let ranks = indices.as_slice::

(); + + if indices_validity.all_true() { + return take_primitive_fast_path_all_valid::( + child, + filter, + filtered_len, + ranks, + output_validity, + ); + } + + take_primitive_fast_path_nullable::( + child.as_slice::(), + filter, + filtered_len, + ranks, + indices_validity, + output_validity, + ) +} + +fn take_primitive_fast_path_all_valid( + child: ArrayView<'_, Primitive>, + filter: &Mask, + filtered_len: usize, + ranks: &[P], + output_validity: Validity, +) -> VortexResult +where + T: NativePType, + P: IntegerPType, +{ + if let Some((start, end)) = contiguous_sequential_take_range(filter, ranks, filtered_len)? { + return Ok(PrimitiveArray::from_buffer_handle( + child.buffer_handle().slice_typed::(start..end), + T::PTYPE, + output_validity, + ) + .into_array()); + } + + let taken = take_filtered_values::(child.as_slice::(), filter, ranks, filtered_len)?; + Ok(PrimitiveArray::new(taken, output_validity).into_array()) +} + +fn take_primitive_fast_path_nullable( + values: &[T], + filter: &Mask, + filtered_len: usize, + ranks: &[P], + indices_validity: &Mask, + output_validity: Validity, +) -> VortexResult +where + T: NativePType, + P: IntegerPType, +{ + let taken = take_filtered_values_nullable::( + values, + filter, + ranks, + indices_validity, + filtered_len, + )?; + Ok(PrimitiveArray::new(taken, output_validity).into_array()) +} + +fn take_filtered_values_nullable( + values: &[T], + filter: &Mask, + ranks: &[P], + indices_validity: &Mask, + filtered_len: usize, +) -> VortexResult> +where + T: Copy + Default, + P: IntegerPType, +{ + if indices_validity.all_false() { + return Ok(Buffer::zeroed(ranks.len())); + } + + if let Some(start) = contiguous_filter_start(filter, filtered_len) { + let mut out = BufferMut::::with_capacity(ranks.len()); + let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); + for (idx, rank) in ranks.iter().enumerate() { + let value = if indices_validity.value(idx) { + let rank = validate_rank(*rank, filtered_len)?; + // SAFETY: 
`rank` was checked against the contiguous filtered length. + unsafe { *values.get_unchecked(start + rank) } + } else { + T::default() + }; + + // SAFETY: `out` has capacity for all ranks and this loop initializes each output slot + // once. + unsafe { out_ptr.add(idx).write(value) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { out.set_len(ranks.len()) }; + return Ok(out.freeze()); + } + + let indices = match filter.indices() { + AllOr::All => { + return take_values_by_rank_nullable(values, ranks, indices_validity, filtered_len); + } + AllOr::None => unreachable!("empty filters are handled by take preconditions"), + AllOr::Some(indices) => indices, + }; + + let mut out = BufferMut::::with_capacity(ranks.len()); + let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); + for (idx, rank) in ranks.iter().enumerate() { + let value = if indices_validity.value(idx) { + let rank = validate_rank(*rank, filtered_len)?; + // SAFETY: `rank` was bounds-checked against `indices`, whose values are valid + // positions in `values`. + unsafe { *values.get_unchecked(*indices.get_unchecked(rank)) } + } else { + T::default() + }; + + // SAFETY: `out` has capacity for all ranks and this loop initializes each output slot + // once. + unsafe { out_ptr.add(idx).write(value) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { out.set_len(ranks.len()) }; + Ok(out.freeze()) +} + +fn take_values_by_rank_nullable( + values: &[T], + ranks: &[P], + indices_validity: &Mask, + filtered_len: usize, +) -> VortexResult> +where + T: Copy + Default, + P: IntegerPType, +{ + let mut out = BufferMut::::with_capacity(ranks.len()); + let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); + for (idx, rank) in ranks.iter().enumerate() { + let value = if indices_validity.value(idx) { + let rank = validate_rank(*rank, filtered_len)?; + // SAFETY: `rank` was bounds-checked. 
+ unsafe { *values.get_unchecked(rank) } + } else { + T::default() + }; + + // SAFETY: `out` has capacity for all ranks and this loop initializes each output slot + // once. + unsafe { out_ptr.add(idx).write(value) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { out.set_len(ranks.len()) }; + Ok(out.freeze()) +} + +fn contiguous_sequential_take_range( + filter: &Mask, + ranks: &[P], + filtered_len: usize, +) -> VortexResult> { + let Some(start) = contiguous_filter_start(filter, filtered_len) else { + return Ok(None); + }; + + for (idx, rank) in ranks.iter().enumerate() { + let Some(rank) = rank.to_usize() else { + vortex_bail!(OutOfBounds: 0, 0, filtered_len); + }; + if rank >= filtered_len { + vortex_bail!(OutOfBounds: rank, 0, filtered_len); + } + if rank != idx { + return Ok(None); + } + } + + Ok(Some((start, start + ranks.len()))) +} + +fn take_filtered_values( + values: &[T], + filter: &Mask, + ranks: &[P], + filtered_len: usize, +) -> VortexResult> +where + T: Copy + Default, + P: IntegerPType, +{ + if let Some(start) = contiguous_filter_start(filter, filtered_len) { + let mut out = BufferMut::::with_capacity(ranks.len()); + let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); + for (idx, rank) in ranks.iter().enumerate() { + let Some(rank) = rank.to_usize() else { + vortex_bail!(OutOfBounds: 0, 0, filtered_len); + }; + if rank >= filtered_len { + vortex_bail!(OutOfBounds: rank, 0, filtered_len); + } + + // SAFETY: `out` has capacity for all ranks. The filter is contiguous with + // `filtered_len` values starting at `start`, and `rank` was checked above. + unsafe { out_ptr.add(idx).write(*values.get_unchecked(start + rank)) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. 
+ unsafe { out.set_len(ranks.len()) }; + return Ok(out.freeze()); + } + + let indices = match filter.indices() { + AllOr::All => return take_values_by_rank(values, ranks, filtered_len), + AllOr::None => unreachable!("empty filters are handled by take preconditions"), + AllOr::Some(indices) => indices, + }; + + if ranks.len() == filtered_len && !first_rank_is_zero(ranks, filtered_len)? { + let filtered = gather_values_by_indices(values, indices); + return take_values_by_rank(filtered.as_slice(), ranks, filtered_len); + } + + let mut out = BufferMut::::with_capacity(ranks.len()); + let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); + for (idx, rank) in ranks.iter().enumerate() { + let Some(rank) = rank.to_usize() else { + vortex_bail!(OutOfBounds: 0, 0, filtered_len); + }; + if rank >= filtered_len { + vortex_bail!(OutOfBounds: rank, 0, filtered_len); + } + + // SAFETY: `out` has capacity for all ranks. `rank` was bounds-checked against + // `indices`, whose values are valid positions in `values`. + unsafe { + out_ptr + .add(idx) + .write(*values.get_unchecked(*indices.get_unchecked(rank))) + }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { out.set_len(ranks.len()) }; + Ok(out.freeze()) +} + +fn gather_values_by_indices(values: &[T], indices: &[usize]) -> Buffer +where + T: Copy + Default, +{ + let mut out = BufferMut::::with_capacity(indices.len()); + let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); + + for (idx, &value_idx) in indices.iter().enumerate() { + // SAFETY: `out` has capacity for all indices and mask indices are valid positions in the + // child values buffer by construction. + unsafe { out_ptr.add(idx).write(*values.get_unchecked(value_idx)) }; + } + + // SAFETY: The loop writes exactly `indices.len()` initialized values. 
+ unsafe { out.set_len(indices.len()) }; + out.freeze() +} + +fn first_rank_is_zero(ranks: &[P], filtered_len: usize) -> VortexResult { + let Some(first) = ranks.first() else { + return Ok(false); + }; + let Some(first) = first.to_usize() else { + vortex_bail!(OutOfBounds: 0, 0, filtered_len); + }; + if first >= filtered_len { + vortex_bail!(OutOfBounds: first, 0, filtered_len); + } + Ok(first == 0) +} + +fn take_values_by_rank( + values: &[T], + ranks: &[P], + filtered_len: usize, +) -> VortexResult> +where + T: Copy + Default, + P: IntegerPType, +{ + let mut out = BufferMut::::with_capacity(ranks.len()); + let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); + for (idx, rank) in ranks.iter().enumerate() { + let Some(rank) = rank.to_usize() else { + vortex_bail!(OutOfBounds: 0, 0, filtered_len); + }; + if rank >= filtered_len { + vortex_bail!(OutOfBounds: rank, 0, filtered_len); + } + + // SAFETY: `out` has capacity for all ranks and `rank` was bounds-checked. + unsafe { out_ptr.add(idx).write(*values.get_unchecked(rank)) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { out.set_len(ranks.len()) }; + Ok(out.freeze()) +} + +fn contiguous_filter_start(filter: &Mask, filtered_len: usize) -> Option { + let start = filter.first()?; + let end = filter.last()?.checked_add(1)?; + (end - start == filtered_len).then_some(start) +} + +impl TakeExecute for Filter { + fn take( + array: ArrayView<'_, Filter>, + indices: &ArrayRef, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + // Bool filtering is already very cheap. Translating take indices through the filter adds + // overhead without improving the downstream bool take, so leave bool children on the + // regular filter path. 
+ if array.child().dtype().is_boolean() { + return Ok(None); + } + + let DType::Primitive(ptype, nullability) = indices.dtype() else { + vortex_bail!("Invalid indices dtype: {}", indices.dtype()) + }; + + if should_fallback_nullable_fixed_width_full_take(array, indices)? { + return Ok(None); + } + + let unsigned_indices = if ptype.is_unsigned_int() { + indices.clone().execute::(ctx)? + } else { + indices + .clone() + .cast(DType::Primitive(ptype.to_unsigned(), *nullability))? + .execute::(ctx)? + }; + + take_impl(array, &unsigned_indices, ctx).map(Some) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + use vortex_mask::Mask; + use vortex_session::VortexSession; + + use crate::IntoArray; + use crate::RecursiveCanonical; + use crate::arrays::BoolArray; + use crate::arrays::DecimalArray; + use crate::arrays::Dict; + use crate::arrays::DictArray; + use crate::arrays::FilterArray; + use crate::arrays::FixedSizeListArray; + use crate::arrays::ListArray; + use crate::arrays::Primitive; + use crate::arrays::PrimitiveArray; + use crate::arrays::StructArray; + use crate::arrays::VarBinViewArray; + use crate::assert_arrays_eq; + use crate::dtype::DecimalDType; + use crate::dtype::FieldNames; + use crate::executor::ExecutionCtx; + use crate::validity::Validity; + + #[test] + fn test_take_execute_kernel_maps_indices_through_filter() -> VortexResult<()> { + let filter = FilterArray::new( + PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40), None]) + .into_array(), + Mask::from_iter([true, false, true, true, false]), + ) + .into_array(); + let parent = DictArray::try_new( + PrimitiveArray::new( + buffer![2u64, 100, 0], + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + ) + .into_array(), + filter.clone(), + )? + .into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter + .execute_parent(&parent, 1, &mut ctx)? 
+ .expect("filter child should execute its take parent"); + + assert!(result.as_opt::().is_some()); + assert_arrays_eq!( + result.execute::(&mut ctx)?.0, + PrimitiveArray::from_option_iter([Some(40i32), None, Some(10)]).into_array() + ); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_nullable_fast_path_maps_indices_through_filter() -> VortexResult<()> + { + let filter = FilterArray::new( + buffer![10i32, 20, 30, 40, 50].into_array(), + Mask::from_slices(5, vec![(1, 4)]), + ) + .into_array(); + let parent = DictArray::try_new( + PrimitiveArray::new( + buffer![2u64, 100, 0], + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + ) + .into_array(), + filter.clone(), + )? + .into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter + .execute_parent(&parent, 1, &mut ctx)? + .expect("filter child should execute its take parent"); + + assert!(result.as_opt::().is_some()); + assert_arrays_eq!( + result.execute::(&mut ctx)?.0, + PrimitiveArray::from_option_iter([Some(40i32), None, Some(20)]).into_array() + ); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_fast_path_maps_indices_through_filter() -> VortexResult<()> { + let filter = FilterArray::new( + buffer![10i32, 20, 30, 40, 50, 60].into_array(), + Mask::from_indices(6, vec![1, 3, 4, 5]), + ) + .into_array(); + let parent = + DictArray::try_new(buffer![2u64, 0, 3].into_array(), filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter + .execute_parent(&parent, 1, &mut ctx)? 
+ .expect("filter child should execute its take parent"); + + assert!(result.as_opt::().is_some()); + assert_arrays_eq!( + result.execute::(&mut ctx)?.0, + PrimitiveArray::from_iter([50i32, 20, 60]).into_array() + ); + Ok(()) + } + + fn assert_take_execute_maps_child_dtype( + child: crate::ArrayRef, + expected: crate::ArrayRef, + ) -> VortexResult<()> { + let filter = + FilterArray::new(child, Mask::from_iter([true, false, true, true, false])).into_array(); + let parent = + DictArray::try_new(buffer![2u64, 0, 1].into_array(), filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter + .execute_parent(&parent, 1, &mut ctx)? + .expect("filter child should execute its take parent"); + + assert_arrays_eq!(result.execute::(&mut ctx)?.0, expected); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_skips_bool_filter_child() -> VortexResult<()> { + let filter = FilterArray::new( + BoolArray::from_iter([true, false, true, true, false]).into_array(), + Mask::from_iter([true, false, true, true, false]), + ) + .into_array(); + let parent = + DictArray::try_new(buffer![2u64, 0, 1].into_array(), filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter.execute_parent(&parent, 1, &mut ctx)?; + + assert!(result.is_none()); + Ok(()) + } + + fn execute_large_nullable_fixed_width_take( + child: crate::ArrayRef, + ) -> VortexResult> { + let filter = + FilterArray::new(child, Mask::from_iter([true, false, true, true, false])).into_array(); + let indices = PrimitiveArray::from_option_iter( + (0..=super::NULLABLE_FIXED_WIDTH_FULL_TAKE_FALLBACK_LEN) + .map(|idx| Some((idx % 3) as u64)), + ) + .into_array(); + let parent = DictArray::try_new(indices, filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + filter.execute_parent(&parent, 1, &mut ctx) + } + + #[test] + fn 
test_take_execute_kernel_handles_large_nullable_primitive_take_without_child_nulls() + -> VortexResult<()> { + let result = + execute_large_nullable_fixed_width_take(buffer![10i32, 20, 30, 40, 50].into_array())?; + + assert_eq!( + result + .expect("non-null fixed-width child should stay on the fast path") + .len(), + super::NULLABLE_FIXED_WIDTH_FULL_TAKE_FALLBACK_LEN + 1 + ); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_handles_large_nullable_decimal_take_without_child_nulls() + -> VortexResult<()> { + let decimal_dtype = DecimalDType::new(19, 2); + + let result = execute_large_nullable_fixed_width_take( + DecimalArray::new( + buffer![100i128, 200, 300, 400, 500], + decimal_dtype, + Validity::NonNullable, + ) + .into_array(), + )?; + + assert_eq!( + result + .expect("non-null fixed-width child should stay on the fast path") + .len(), + super::NULLABLE_FIXED_WIDTH_FULL_TAKE_FALLBACK_LEN + 1 + ); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_falls_back_for_large_nullable_primitive_take_with_child_nulls() + -> VortexResult<()> { + let result = execute_large_nullable_fixed_width_take( + PrimitiveArray::from_option_iter([Some(10i32), Some(20), None, Some(40), Some(50)]) + .into_array(), + )?; + + assert!(result.is_none()); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_falls_back_for_large_nullable_decimal_take_with_child_nulls() + -> VortexResult<()> { + let decimal_dtype = DecimalDType::new(19, 2); + let result = execute_large_nullable_fixed_width_take( + DecimalArray::from_option_iter( + [Some(100i128), Some(200), None, Some(400), Some(500)], + decimal_dtype, + ) + .into_array(), + )?; + + assert!(result.is_none()); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_handles_decimal_filter_child() -> VortexResult<()> { + let decimal_dtype = DecimalDType::new(19, 2); + + assert_take_execute_maps_child_dtype( + DecimalArray::new( + buffer![100i128, 200, 300, 400, 500], + decimal_dtype, + Validity::NonNullable, + ) + .into_array(), + 
DecimalArray::new( + buffer![400i128, 100, 300], + decimal_dtype, + Validity::NonNullable, + ) + .into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_handles_fixed_size_list_filter_child() -> VortexResult<()> { + assert_take_execute_maps_child_dtype( + FixedSizeListArray::new( + buffer![10u32, 11, 20, 21, 30, 31, 40, 41, 50, 51].into_array(), + 2, + Validity::NonNullable, + 5, + ) + .into_array(), + FixedSizeListArray::new( + buffer![40u32, 41, 10, 11, 30, 31].into_array(), + 2, + Validity::NonNullable, + 3, + ) + .into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_handles_list_filter_child() -> VortexResult<()> { + assert_take_execute_maps_child_dtype( + ListArray::try_new( + buffer![10u32, 11, 20, 30, 31, 32, 40, 50, 51].into_array(), + buffer![0u32, 2, 3, 6, 7, 9].into_array(), + Validity::NonNullable, + )? + .into_array(), + ListArray::try_new( + buffer![40u32, 10, 11, 30, 31, 32].into_array(), + buffer![0u32, 1, 3, 6].into_array(), + Validity::NonNullable, + )? + .into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_handles_string_filter_child() -> VortexResult<()> { + assert_take_execute_maps_child_dtype( + VarBinViewArray::from_iter_str(["a", "b", "c", "d", "e"]).into_array(), + VarBinViewArray::from_iter_str(["d", "a", "c"]).into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_handles_struct_filter_child() -> VortexResult<()> { + assert_take_execute_maps_child_dtype( + StructArray::try_new( + FieldNames::from(["id", "value"]), + vec![ + buffer![10u32, 20, 30, 40, 50].into_array(), + buffer![100u64, 200, 300, 400, 500].into_array(), + ], + 5, + Validity::NonNullable, + )? + .into_array(), + StructArray::try_new( + FieldNames::from(["id", "value"]), + vec![ + buffer![40u32, 10, 30].into_array(), + buffer![400u64, 100, 300].into_array(), + ], + 3, + Validity::NonNullable, + )? 
+ .into_array(), + ) + } +} diff --git a/vortex-array/src/arrays/filter/vtable.rs b/vortex-array/src/arrays/filter/vtable.rs index 9f70e6e2614..312ed536f77 100644 --- a/vortex-array/src/arrays/filter/vtable.rs +++ b/vortex-array/src/arrays/filter/vtable.rs @@ -34,6 +34,7 @@ use crate::arrays::filter::execute::execute_filter; use crate::arrays::filter::execute::execute_filter_fast_paths; use crate::arrays::filter::rules::PARENT_RULES; use crate::arrays::filter::rules::RULES; +use crate::arrays::filter::take::PARENT_KERNELS; use crate::buffer::BufferHandle; use crate::dtype::DType; use crate::executor::ExecutionCtx; @@ -170,6 +171,15 @@ impl VTable for Filter { PARENT_RULES.evaluate(array, parent, child_idx) } + fn execute_parent( + array: ArrayView<'_, Self>, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } + fn reduce(array: ArrayView<'_, Self>) -> VortexResult> { RULES.evaluate(array) } diff --git a/vortex-buffer/public-api.lock b/vortex-buffer/public-api.lock index 4e0b1190a7a..ec0570ac4a3 100644 --- a/vortex-buffer/public-api.lock +++ b/vortex-buffer/public-api.lock @@ -280,6 +280,10 @@ pub fn vortex_buffer::BitBuffer::new_with_offset(vortex_buffer::ByteBuffer, usiz pub fn vortex_buffer::BitBuffer::offset(&self) -> usize +pub fn vortex_buffer::BitBuffer::select(&self, usize) -> usize + +pub fn vortex_buffer::BitBuffer::select_sorted_batch(&self, &[usize]) -> alloc::vec::Vec + pub fn vortex_buffer::BitBuffer::set_indices(&self) -> arrow_buffer::util::bit_iterator::BitIndexIterator<'_> pub fn vortex_buffer::BitBuffer::set_slices(&self) -> arrow_buffer::util::bit_iterator::BitSliceIterator<'_> diff --git a/vortex-buffer/src/bit/buf.rs b/vortex-buffer/src/bit/buf.rs index d9c30ea1917..108c99d0657 100644 --- a/vortex-buffer/src/bit/buf.rs +++ b/vortex-buffer/src/bit/buf.rs @@ -25,6 +25,8 @@ use crate::bit::count_ones::count_ones; use crate::bit::get_bit_unchecked; use 
crate::bit::ops::bitwise_binary_op; use crate::bit::ops::bitwise_unary_op; +use crate::bit::select::bit_select; +use crate::bit::select::bit_select_sorted_batch; use crate::buffer; /// An immutable bitset stored as a packed byte buffer. @@ -319,6 +321,37 @@ impl BitBuffer { count_ones(self.buffer.as_slice(), self.offset, self.len) } + /// Returns the position of the `nth` set bit (0-indexed). + /// + /// This is the "select" operation on a bitmap: given a rank `nth`, find + /// which logical bit position holds that rank. + /// + /// # Panics + /// + /// Panics (debug) or produces undefined results (release) if `nth` is + /// greater than or equal to [`true_count`](Self::true_count). + pub fn select(&self, nth: usize) -> usize { + bit_select(self.buffer.as_slice(), self.offset, self.len, nth) + } + + /// Select positions for multiple ranks in a single pass over the bitmap. + /// + /// `sorted_ranks` must be sorted in non-decreasing order, with each value + /// less than [`true_count`](Self::true_count). This is O(L/64 + N) where + /// L = bitmap length and N = number of ranks, compared to O(N × L/64) for + /// individual [`select`](Self::select) calls. + pub fn select_sorted_batch(&self, sorted_ranks: &[usize]) -> Vec { + let mut out = vec![0; sorted_ranks.len()]; + bit_select_sorted_batch( + self.buffer.as_slice(), + self.offset, + self.len, + sorted_ranks, + &mut out, + ); + out + } + /// Get the number of unset bits in the buffer. 
pub fn false_count(&self) -> usize { self.len - self.true_count() diff --git a/vortex-buffer/src/bit/count_ones.rs b/vortex-buffer/src/bit/count_ones.rs index 6d70d47cfa7..df5844a2914 100644 --- a/vortex-buffer/src/bit/count_ones.rs +++ b/vortex-buffer/src/bit/count_ones.rs @@ -22,7 +22,11 @@ pub fn count_ones(bytes: &[u8], offset: usize, len: usize) -> usize { } #[inline] -fn align_offset_len(bytes: &[u8], offset: usize, len: usize) -> (Option, &[u8], Option) { +pub(super) fn align_offset_len( + bytes: &[u8], + offset: usize, + len: usize, +) -> (Option, &[u8], Option) { let start_byte = offset / 8; let start_bit = offset % 8; let end_bit = offset + len; diff --git a/vortex-buffer/src/bit/mod.rs b/vortex-buffer/src/bit/mod.rs index 034be84a18c..37930d788b7 100644 --- a/vortex-buffer/src/bit/mod.rs +++ b/vortex-buffer/src/bit/mod.rs @@ -13,6 +13,7 @@ mod buf_mut; mod count_ones; mod macros; mod ops; +mod select; pub use arrow_buffer::bit_chunk_iterator::BitChunkIterator; pub use arrow_buffer::bit_chunk_iterator::BitChunks; diff --git a/vortex-buffer/src/bit/select.rs b/vortex-buffer/src/bit/select.rs new file mode 100644 index 00000000000..1ac6336971d --- /dev/null +++ b/vortex-buffer/src/bit/select.rs @@ -0,0 +1,690 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use super::count_ones::align_offset_len; + +/// Returns the position of the `nth` set bit (0-indexed) within the logical range +/// `[offset, offset + len)` of the given byte slice. +/// +/// The returned position is relative to the logical start (i.e., 0-indexed from `offset`). +/// +/// Uses architecture-specific optimizations: +/// - **aarch64**: NEON `vcnt`-based popcount for the word-level scan. +/// - **x86_64 + BMI2**: `pdep` + `tzcnt` for the final in-word select. +/// - **Scalar fallback**: 4× unrolled word scan with `count_ones`, byte-level narrowing. 
+#[inline] +pub fn bit_select(bytes: &[u8], offset: usize, len: usize, nth: usize) -> usize { + let (head, middle, tail) = align_offset_len(bytes, offset, len); + let mut remaining = nth; + let mut pos = 0usize; + + // ── partial first byte ────────────────────────────────────────────── + if let Some(head) = head { + let count = head.count_ones() as usize; + if remaining < count { + return select_in_byte(head, remaining); + } + remaining -= count; + let start_bit = offset % 8; + pos = (8 - start_bit).min(len); + } + + // ── aligned middle bytes ──────────────────────────────────────────── + if !middle.is_empty() { + let (words, tail_bytes) = middle.as_chunks::<8>(); + + let (rem, new_pos, word_idx) = scan_words(words, remaining, pos); + remaining = rem; + pos = new_pos; + + if word_idx < words.len() { + let word = u64::from_le_bytes(words[word_idx]); + return pos + select_in_word(word, remaining); + } + + // Remaining aligned bytes that don't fill a full u64. + for &byte in tail_bytes { + let count = byte.count_ones() as usize; + if remaining < count { + return pos + select_in_byte(byte, remaining); + } + remaining -= count; + pos += 8; + } + } + + // ── partial last byte ─────────────────────────────────────────────── + if let Some(tail) = tail { + debug_assert!( + remaining < tail.count_ones() as usize, + "bit_select: nth={nth} out of bounds" + ); + return pos + select_in_byte(tail, remaining); + } + + unreachable!("bit_select: nth={nth} exceeds set bit count") +} + +// ── Batch select (sorted ranks, single pass) ─────────────────────────── + +/// Select positions for multiple ranks in a single pass over the bitmap. +/// +/// `sorted_ranks` must be sorted in non-decreasing order; each value must be +/// less than the total number of set bits in `[offset, offset+len)`. +/// Results are written to `out[0..sorted_ranks.len()]`. 
+pub fn bit_select_sorted_batch( + bytes: &[u8], + offset: usize, + len: usize, + sorted_ranks: &[usize], + out: &mut [usize], +) { + debug_assert!(out.len() >= sorted_ranks.len()); + if sorted_ranks.is_empty() { + return; + } + + let (head, middle, tail) = align_offset_len(bytes, offset, len); + let mut cumul = 0usize; + let mut pos = 0usize; + let mut ri = 0usize; // index into sorted_ranks / out + + // ── head byte ─────────────────────────────────────────────────── + if let Some(head) = head { + let count = head.count_ones() as usize; + while ri < sorted_ranks.len() && sorted_ranks[ri] < cumul + count { + out[ri] = select_in_byte(head, sorted_ranks[ri] - cumul); + ri += 1; + } + cumul += count; + let start_bit = offset % 8; + pos = (8 - start_bit).min(len); + if ri >= sorted_ranks.len() { + return; + } + } + + // ── middle bytes ──────────────────────────────────────────────── + if !middle.is_empty() { + let (words, tail_bytes) = middle.as_chunks::<8>(); + + scan_words_batch_impl(words, sorted_ranks, out, &mut cumul, &mut pos, &mut ri); + if ri >= sorted_ranks.len() { + return; + } + + for &byte in tail_bytes { + let count = byte.count_ones() as usize; + while ri < sorted_ranks.len() && sorted_ranks[ri] < cumul + count { + out[ri] = pos + select_in_byte(byte, sorted_ranks[ri] - cumul); + ri += 1; + } + cumul += count; + pos += 8; + if ri >= sorted_ranks.len() { + return; + } + } + } + + // ── tail byte ─────────────────────────────────────────────────── + if let Some(tail) = tail { + let count = tail.count_ones() as usize; + while ri < sorted_ranks.len() && sorted_ranks[ri] < cumul + count { + out[ri] = pos + select_in_byte(tail, sorted_ranks[ri] - cumul); + ri += 1; + } + } +} + +// ── aarch64 NEON batch scan ───────────────────────────────────────────── + +#[cfg(target_arch = "aarch64")] +#[allow(clippy::cast_possible_truncation)] +fn scan_words_batch_impl( + words: &[[u8; 8]], + sorted_ranks: &[usize], + out: &mut [usize], + cumul: &mut usize, + pos: &mut 
usize, + ri: &mut usize, +) { + use std::arch::aarch64::vcntq_u8; + use std::arch::aarch64::vgetq_lane_u64; + use std::arch::aarch64::vld1q_u8; + use std::arch::aarch64::vpaddlq_u8; + use std::arch::aarch64::vpaddlq_u16; + use std::arch::aarch64::vpaddlq_u32; + + let mut wi = 0; + + // 4-word NEON blocks — skip entire blocks when no targets inside. + while wi + 4 <= words.len() && *ri < sorted_ranks.len() { + let ptr = words[wi].as_ptr(); + // SAFETY: wi + 4 <= words.len() guarantees 32 contiguous bytes. + let (c0, c1, c2, c3) = unsafe { + let pop_lo = vcntq_u8(vld1q_u8(ptr)); + let pop_hi = vcntq_u8(vld1q_u8(ptr.add(16))); + let sums_lo = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop_lo))); + let sums_hi = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop_hi))); + ( + vgetq_lane_u64::<0>(sums_lo) as usize, + vgetq_lane_u64::<1>(sums_lo) as usize, + vgetq_lane_u64::<0>(sums_hi) as usize, + vgetq_lane_u64::<1>(sums_hi) as usize, + ) + }; + let total = c0 + c1 + c2 + c3; + + if sorted_ranks[*ri] >= *cumul + total { + *cumul += total; + *pos += 256; + wi += 4; + continue; + } + + // At least one target in this block — emit per word. + let counts = [c0, c1, c2, c3]; + for (j, &count) in counts.iter().enumerate() { + while *ri < sorted_ranks.len() && sorted_ranks[*ri] < *cumul + count { + let word = u64::from_le_bytes(words[wi + j]); + out[*ri] = *pos + select_in_word(word, sorted_ranks[*ri] - *cumul); + *ri += 1; + } + *cumul += count; + *pos += 64; + } + wi += 4; + } + + // Remaining words, scalar. 
+ while wi < words.len() && *ri < sorted_ranks.len() { + let word = u64::from_le_bytes(words[wi]); + let count = word.count_ones() as usize; + while *ri < sorted_ranks.len() && sorted_ranks[*ri] < *cumul + count { + out[*ri] = *pos + select_in_word(word, sorted_ranks[*ri] - *cumul); + *ri += 1; + } + *cumul += count; + *pos += 64; + wi += 1; + } +} + +// ── Scalar batch scan ─────────────────────────────────────────────────── + +#[cfg(not(target_arch = "aarch64"))] +fn scan_words_batch_impl( + words: &[[u8; 8]], + sorted_ranks: &[usize], + out: &mut [usize], + cumul: &mut usize, + pos: &mut usize, + ri: &mut usize, +) { + let mut wi = 0; + + while wi + 4 <= words.len() && *ri < sorted_ranks.len() { + let c0 = u64::from_le_bytes(words[wi]).count_ones() as usize; + let c1 = u64::from_le_bytes(words[wi + 1]).count_ones() as usize; + let c2 = u64::from_le_bytes(words[wi + 2]).count_ones() as usize; + let c3 = u64::from_le_bytes(words[wi + 3]).count_ones() as usize; + let total = c0 + c1 + c2 + c3; + + if sorted_ranks[*ri] >= *cumul + total { + *cumul += total; + *pos += 256; + wi += 4; + continue; + } + + let counts = [c0, c1, c2, c3]; + for (j, &count) in counts.iter().enumerate() { + while *ri < sorted_ranks.len() && sorted_ranks[*ri] < *cumul + count { + let word = u64::from_le_bytes(words[wi + j]); + out[*ri] = *pos + select_in_word(word, sorted_ranks[*ri] - *cumul); + *ri += 1; + } + *cumul += count; + *pos += 64; + } + wi += 4; + } + + while wi < words.len() && *ri < sorted_ranks.len() { + let word = u64::from_le_bytes(words[wi]); + let count = word.count_ones() as usize; + while *ri < sorted_ranks.len() && sorted_ranks[*ri] < *cumul + count { + out[*ri] = *pos + select_in_word(word, sorted_ranks[*ri] - *cumul); + *ri += 1; + } + *cumul += count; + *pos += 64; + wi += 1; + } +} + +// ── Word-level scan ───────────────────────────────────────────────────── + +/// Scan `words` accumulating popcounts. Returns `(remaining, position, word_index)`. 
+/// +/// If `word_index < words.len()`, the target bit is inside that word and `remaining` +/// is the rank *within* that word. Otherwise all words were consumed. +#[inline] +fn scan_words(words: &[[u8; 8]], remaining: usize, pos: usize) -> (usize, usize, usize) { + scan_words_impl(words, remaining, pos) +} + +// ── aarch64 NEON scan ─────────────────────────────────────────────────── + +#[cfg(target_arch = "aarch64")] +#[allow(clippy::cast_possible_truncation)] // u64 → usize is lossless on aarch64 (64-bit) +#[inline] +fn scan_words_impl( + words: &[[u8; 8]], + mut remaining: usize, + mut pos: usize, +) -> (usize, usize, usize) { + use std::arch::aarch64::vcntq_u8; + use std::arch::aarch64::vgetq_lane_u64; + use std::arch::aarch64::vld1q_u8; + use std::arch::aarch64::vpaddlq_u8; + use std::arch::aarch64::vpaddlq_u16; + use std::arch::aarch64::vpaddlq_u32; + + let mut idx = 0; + + // Process 4 u64 words at a time using two 128-bit NEON registers. + while idx + 4 <= words.len() { + let ptr = words[idx].as_ptr(); + // SAFETY: idx + 4 <= words.len() guarantees 32 contiguous bytes from ptr. + // NEON vld1q_u8 supports unaligned access. + let (count_0, count_1, count_2, count_3) = unsafe { + let pop_lo = vcntq_u8(vld1q_u8(ptr)); + let pop_hi = vcntq_u8(vld1q_u8(ptr.add(16))); + let sums_lo = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop_lo))); + let sums_hi = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop_hi))); + ( + vgetq_lane_u64::<0>(sums_lo) as usize, + vgetq_lane_u64::<1>(sums_lo) as usize, + vgetq_lane_u64::<0>(sums_hi) as usize, + vgetq_lane_u64::<1>(sums_hi) as usize, + ) + }; + + let total = count_0 + count_1 + count_2 + count_3; + if remaining >= total { + remaining -= total; + pos += 256; + idx += 4; + continue; + } + + // Narrow down to the exact word. 
+ if remaining < count_0 { + return (remaining, pos, idx); + } + remaining -= count_0; + pos += 64; + if remaining < count_1 { + return (remaining, pos, idx + 1); + } + remaining -= count_1; + pos += 64; + if remaining < count_2 { + return (remaining, pos, idx + 2); + } + remaining -= count_2; + pos += 64; + return (remaining, pos, idx + 3); + } + + // Process pairs. + while idx + 2 <= words.len() { + let ptr = words[idx].as_ptr(); + // SAFETY: idx + 2 <= words.len() guarantees 16 contiguous bytes. + let (count_0, count_1) = unsafe { + let pop = vcntq_u8(vld1q_u8(ptr)); + let sums = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop))); + ( + vgetq_lane_u64::<0>(sums) as usize, + vgetq_lane_u64::<1>(sums) as usize, + ) + }; + let total = count_0 + count_1; + if remaining < total { + if remaining < count_0 { + return (remaining, pos, idx); + } + return (remaining - count_0, pos + 64, idx + 1); + } + remaining -= total; + pos += 128; + idx += 2; + } + + // Single trailing word. + if idx < words.len() { + let word = u64::from_le_bytes(words[idx]); + let count = word.count_ones() as usize; + if remaining < count { + return (remaining, pos, idx); + } + remaining -= count; + pos += 64; + idx += 1; + } + + (remaining, pos, idx) +} + +// ── Scalar scan (x86_64 / generic) ───────────────────────────────────── + +#[cfg(not(target_arch = "aarch64"))] +#[inline] +fn scan_words_impl( + words: &[[u8; 8]], + mut remaining: usize, + mut pos: usize, +) -> (usize, usize, usize) { + let mut idx = 0; + + // 4× unrolled: the four independent `count_ones` calls pipeline well. 
+ while idx + 4 <= words.len() { + let count_0 = u64::from_le_bytes(words[idx]).count_ones() as usize; + let count_1 = u64::from_le_bytes(words[idx + 1]).count_ones() as usize; + let count_2 = u64::from_le_bytes(words[idx + 2]).count_ones() as usize; + let count_3 = u64::from_le_bytes(words[idx + 3]).count_ones() as usize; + let total = count_0 + count_1 + count_2 + count_3; + + if remaining >= total { + remaining -= total; + pos += 256; + idx += 4; + continue; + } + + if remaining < count_0 { + return (remaining, pos, idx); + } + remaining -= count_0; + pos += 64; + if remaining < count_1 { + return (remaining, pos, idx + 1); + } + remaining -= count_1; + pos += 64; + if remaining < count_2 { + return (remaining, pos, idx + 2); + } + remaining -= count_2; + pos += 64; + return (remaining, pos, idx + 3); + } + + while idx < words.len() { + let word = u64::from_le_bytes(words[idx]); + let count = word.count_ones() as usize; + if remaining < count { + return (remaining, pos, idx); + } + remaining -= count; + pos += 64; + idx += 1; + } + + (remaining, pos, idx) +} + +// ── In-word select ────────────────────────────────────────────────────── + +/// Position of the `nth` set bit inside a u64 (0-indexed, little-endian bit order). +#[inline] +fn select_in_word(word: u64, nth: usize) -> usize { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("bmi2") { + // SAFETY: runtime detection guarantees the required target feature. + return unsafe { select_in_word_bmi2(word, nth) }; + } + } + select_in_word_scalar(word, nth) +} + +/// BMI2: deposit a single bit at the nth set-bit position, then count trailing zeros. 
+#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "bmi2")] +unsafe fn select_in_word_bmi2(word: u64, nth: usize) -> usize { + use std::arch::x86_64::_pdep_u64; + use std::arch::x86_64::_tzcnt_u64; + + use vortex_error::VortexExpect; + + usize::try_from(unsafe { _tzcnt_u64(_pdep_u64(1u64 << nth, word)) }) + .vortex_expect("safe to convert tzcnt result to usize") +} + +/// Scalar: narrow to the correct byte, then clear `nth` lowest set bits and trailing-zeros. +#[inline] +fn select_in_word_scalar(word: u64, mut nth: usize) -> usize { + let bytes = word.to_le_bytes(); + let mut bit_offset = 0usize; + for &byte in &bytes { + let count = byte.count_ones() as usize; + if nth < count { + return bit_offset + select_in_byte(byte, nth); + } + nth -= count; + bit_offset += 8; + } + unreachable!("select_in_word: nth exceeds popcount") +} + +// ── In-byte select ────────────────────────────────────────────────────── + +/// Position of the `nth` set bit inside a byte (0-indexed, LSB-first). +/// +/// Clears the lowest `nth` set bits, then uses `trailing_zeros`. +#[inline] +fn select_in_byte(byte: u8, nth: usize) -> usize { + debug_assert!(nth < byte.count_ones() as usize); + let mut bits = u32::from(byte); + for _ in 0..nth { + bits &= bits - 1; // clear lowest set bit + } + bits.trailing_zeros() as usize +} + +#[cfg(test)] +mod tests { + #![allow(clippy::cast_possible_truncation)] + + use rstest::rstest; + + use super::*; + + #[test] + fn test_select_all_set() { + // Every bit is set — select(n) == n. + let buf = [0xFFu8; 16]; // 128 bits, all set + for nth in 0..128 { + assert_eq!(bit_select(&buf, 0, 128, nth), nth, "nth={nth}"); + } + } + + #[test] + fn test_select_every_other() { + // 0b01010101 repeated: bits 0,2,4,6 of each byte are set. + let buf = [0x55u8; 16]; // 128 bits, 64 set + for nth in 0..64 { + assert_eq!(bit_select(&buf, 0, 128, nth), nth * 2, "nth={nth}"); + } + } + + #[test] + fn test_select_single_bit() { + // Only bit 42 is set. 
+ let mut buf = [0u8; 16]; + buf[42 / 8] |= 1 << (42 % 8); + assert_eq!(bit_select(&buf, 0, 128, 0), 42); + } + + #[rstest] + #[case(0, 128)] + #[case(3, 100)] + #[case(7, 50)] + #[case(1, 7)] + #[case(5, 5)] + #[case(0, 1)] + #[case(0, 64)] + #[case(1, 64)] + #[case(0, 65)] + #[case(3, 256)] + fn test_select_agrees_with_naive(#[case] offset: usize, #[case] len: usize) { + let total_bits = offset + len; + let total_bytes = total_bits.div_ceil(8); + // Deterministic pattern with moderate density. + let buf: Vec = (0..total_bytes) + .map(|i| ((i.wrapping_mul(0x9E) ^ 0xA5) & 0xFF) as u8) + .collect(); + + // Collect set-bit positions naively. + let expected: Vec = (0..len) + .filter(|&i| { + let phys = offset + i; + (buf[phys / 8] >> (phys % 8)) & 1 == 1 + }) + .collect(); + + for (nth, &expected_pos) in expected.iter().enumerate() { + assert_eq!( + bit_select(&buf, offset, len, nth), + expected_pos, + "offset={offset} len={len} nth={nth}" + ); + } + } + + #[test] + fn test_select_large_buffer() { + // ~64 KB buffer, ~50% density. + let len = 65_536 * 8; + let buf: Vec = (0u32..65_536) + .map(|i| ((i.wrapping_mul(0x37) ^ 0xBC) & 0xFF) as u8) + .collect(); + + let true_count = buf.iter().map(|b| b.count_ones() as usize).sum::(); + + // Spot-check a few positions. + let first = bit_select(&buf, 0, len, 0); + let last = bit_select(&buf, 0, len, true_count - 1); + assert!(first < len); + assert!(last < len); + assert!(first <= last); + + // Verify the found positions are actually set. + assert_ne!(buf[first / 8] & (1 << (first % 8)), 0); + assert_ne!(buf[last / 8] & (1 << (last % 8)), 0); + } + + #[test] + fn test_select_in_word_basic() { + // 0b1010_1010 = 0xAA — bits 1,3,5,7 are set. 
+ let word = 0x00000000_000000AAu64; + assert_eq!(select_in_word(word, 0), 1); + assert_eq!(select_in_word(word, 1), 3); + assert_eq!(select_in_word(word, 2), 5); + assert_eq!(select_in_word(word, 3), 7); + } + + #[test] + fn test_select_in_word_all_set() { + let word = u64::MAX; + for nth in 0..64 { + assert_eq!(select_in_word(word, nth), nth, "nth={nth}"); + } + } + + #[test] + fn test_select_in_byte_basic() { + assert_eq!(select_in_byte(0b1010_1010, 0), 1); + assert_eq!(select_in_byte(0b1010_1010, 1), 3); + assert_eq!(select_in_byte(0b1010_1010, 2), 5); + assert_eq!(select_in_byte(0b1010_1010, 3), 7); + assert_eq!(select_in_byte(0b0000_0001, 0), 0); + assert_eq!(select_in_byte(0b1000_0000, 0), 7); + assert_eq!(select_in_byte(0xFF, 7), 7); + } + + // ── batch select tests ────────────────────────────────────────── + + #[test] + fn test_batch_select_all_set() { + let buf = [0xFFu8; 16]; // 128 bits, all set + let ranks: Vec = (0..128).collect(); + let mut out = vec![0usize; 128]; + bit_select_sorted_batch(&buf, 0, 128, &ranks, &mut out); + for (nth, &pos) in out.iter().enumerate() { + assert_eq!(pos, nth, "nth={nth}"); + } + } + + #[test] + fn test_batch_select_every_other() { + let buf = [0x55u8; 16]; // 128 bits, 64 set + let ranks: Vec = (0..64).collect(); + let mut out = vec![0usize; 64]; + bit_select_sorted_batch(&buf, 0, 128, &ranks, &mut out); + for (nth, &pos) in out.iter().enumerate() { + assert_eq!(pos, nth * 2, "nth={nth}"); + } + } + + #[test] + fn test_batch_select_sparse_ranks() { + let buf = [0xFFu8; 16]; // 128 set + let ranks = [0, 10, 50, 100, 127]; + let mut out = [0usize; 5]; + bit_select_sorted_batch(&buf, 0, 128, &ranks, &mut out); + assert_eq!(out, [0, 10, 50, 100, 127]); + } + + #[test] + fn test_batch_select_empty() { + let buf = [0xFFu8; 4]; + let mut out = []; + bit_select_sorted_batch(&buf, 0, 32, &[], &mut out); + } + + #[rstest] + #[case(0, 128)] + #[case(3, 100)] + #[case(7, 50)] + #[case(0, 65)] + #[case(3, 256)] + fn 
test_batch_select_agrees_with_individual(#[case] offset: usize, #[case] len: usize) { + let total_bytes = (offset + len).div_ceil(8); + let buf: Vec = (0..total_bytes) + .map(|i| ((i.wrapping_mul(0x9E) ^ 0xA5) & 0xFF) as u8) + .collect(); + + // Get individual results. + let true_count = (0..len) + .filter(|&i| { + let phys = offset + i; + (buf[phys / 8] >> (phys % 8)) & 1 == 1 + }) + .count(); + + let all_ranks: Vec = (0..true_count).collect(); + let individual: Vec = all_ranks + .iter() + .map(|&r| bit_select(&buf, offset, len, r)) + .collect(); + + let mut batch = vec![0usize; true_count]; + bit_select_sorted_batch(&buf, offset, len, &all_ranks, &mut batch); + + assert_eq!(batch, individual, "offset={offset} len={len}"); + } +} diff --git a/vortex-mask/Cargo.toml b/vortex-mask/Cargo.toml index a2af3d27cb0..037ffdb0d9d 100644 --- a/vortex-mask/Cargo.toml +++ b/vortex-mask/Cargo.toml @@ -35,9 +35,5 @@ rstest = { workspace = true } name = "intersect_by_rank" harness = false -[[bench]] -name = "rank" -harness = false - [lints] workspace = true diff --git a/vortex-mask/public-api.lock b/vortex-mask/public-api.lock index 006bdafaa64..75eb5da0187 100644 --- a/vortex-mask/public-api.lock +++ b/vortex-mask/public-api.lock @@ -82,6 +82,8 @@ pub fn vortex_mask::Mask::new_true(usize) -> Self pub fn vortex_mask::Mask::rank(&self, usize) -> usize +pub fn vortex_mask::Mask::rank_batch(&self, &[usize]) -> alloc::vec::Vec + pub fn vortex_mask::Mask::slice(&self, impl core::ops::range::RangeBounds) -> Self pub fn vortex_mask::Mask::slices(&self) -> vortex_mask::AllOr<&[(usize, usize)]> diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index 339f9a12d74..bbcb5772616 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ -473,8 +473,52 @@ impl Mask { match &self { Self::AllTrue(_) => n, Self::AllFalse(_) => unreachable!("no true values in all-false mask"), - // TODO(joe): optimize this function - Self::Values(values) => values.indices()[n], + 
Self::Values(values) => values.buffer.select(n), + } + } + + /// Translate multiple positions through the mask in batch. + /// + /// For each `ranks[i]`, computes the position of the `ranks[i]`-th set bit, + /// equivalent to calling [`rank`](Self::rank) for each element. + /// + /// This is O(N log N + L/64) vs O(N × L/64) for individual calls, where + /// N = `ranks.len()` and L = mask length. + pub fn rank_batch(&self, ranks: &[usize]) -> Vec { + if ranks.is_empty() { + return vec![]; + } + match &self { + Self::AllTrue(_) => ranks.to_vec(), + Self::AllFalse(_) => unreachable!("no true values in all-false mask"), + Self::Values(values) => { + if let Some(indices) = values.indices.get() { + return ranks.iter().map(|&rank| indices[rank]).collect(); + } + + if ranks.is_sorted() { + return values.buffer.select_sorted_batch(ranks); + } + + if ranks.len() >= values.true_count().div_ceil(2) { + let indices = values.indices(); + return ranks.iter().map(|&rank| indices[rank]).collect(); + } + + // Sort an index permutation by rank value. + let mut perm: Vec = (0..ranks.len()).collect(); + perm.sort_unstable_by_key(|&i| ranks[i]); + + let sorted_ranks: Vec = perm.iter().map(|&i| ranks[i]).collect(); + let sorted_results = values.buffer.select_sorted_batch(&sorted_ranks); + + // Scatter back to original order. 
+ let mut results = vec![0usize; ranks.len()]; + for (perm_idx, &orig_idx) in perm.iter().enumerate() { + results[orig_idx] = sorted_results[perm_idx]; + } + results + } } } diff --git a/vortex-mask/src/tests.rs b/vortex-mask/src/tests.rs index d1496fcb30f..e74bcd52e82 100644 --- a/vortex-mask/src/tests.rs +++ b/vortex-mask/src/tests.rs @@ -88,6 +88,22 @@ fn test_mask_value() { assert!(values.value(4)); } +#[test] +fn test_rank_batch_sorted_ranks() { + let mask = Mask::from_buffer(BitBuffer::from_iter([ + false, true, false, true, true, false, false, true, + ])); + + assert_eq!(mask.rank_batch(&[0, 1, 2, 3]), vec![1, 3, 4, 7]); +} + +#[test] +fn test_rank_batch_unsorted_ranks_with_cached_indices() { + let mask = Mask::from_indices(8, vec![1, 3, 4, 7]); + + assert_eq!(mask.rank_batch(&[3, 0, 2, 1]), vec![7, 1, 4, 3]); +} + #[test] fn test_mask_first() { assert_eq!(Mask::new_true(5).first(), Some(0)); From ada7ed463e366193b3e3183fe302c277e0054f59 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Thu, 14 May 2026 11:58:21 +0100 Subject: [PATCH 2/8] fixes Signed-off-by: Robert Kruszewski --- vortex-mask/Cargo.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vortex-mask/Cargo.toml b/vortex-mask/Cargo.toml index 037ffdb0d9d..a2af3d27cb0 100644 --- a/vortex-mask/Cargo.toml +++ b/vortex-mask/Cargo.toml @@ -35,5 +35,9 @@ rstest = { workspace = true } name = "intersect_by_rank" harness = false +[[bench]] +name = "rank" +harness = false + [lints] workspace = true From 9f224fc715da241a8906d68a590a81f2deb194d0 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Thu, 14 May 2026 12:05:56 +0100 Subject: [PATCH 3/8] less Signed-off-by: Robert Kruszewski --- vortex-mask/src/lib.rs | 45 ---------------------------------------- vortex-mask/src/tests.rs | 16 -------------- 2 files changed, 61 deletions(-) diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index bbcb5772616..dfc0d745995 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ 
-477,51 +477,6 @@ impl Mask { } } - /// Translate multiple positions through the mask in batch. - /// - /// For each `ranks[i]`, computes the position of the `ranks[i]`-th set bit, - /// equivalent to calling [`rank`](Self::rank) for each element. - /// - /// This is O(N log N + L/64) vs O(N × L/64) for individual calls, where - /// N = `ranks.len()` and L = mask length. - pub fn rank_batch(&self, ranks: &[usize]) -> Vec { - if ranks.is_empty() { - return vec![]; - } - match &self { - Self::AllTrue(_) => ranks.to_vec(), - Self::AllFalse(_) => unreachable!("no true values in all-false mask"), - Self::Values(values) => { - if let Some(indices) = values.indices.get() { - return ranks.iter().map(|&rank| indices[rank]).collect(); - } - - if ranks.is_sorted() { - return values.buffer.select_sorted_batch(ranks); - } - - if ranks.len() >= values.true_count().div_ceil(2) { - let indices = values.indices(); - return ranks.iter().map(|&rank| indices[rank]).collect(); - } - - // Sort an index permutation by rank value. - let mut perm: Vec = (0..ranks.len()).collect(); - perm.sort_unstable_by_key(|&i| ranks[i]); - - let sorted_ranks: Vec = perm.iter().map(|&i| ranks[i]).collect(); - let sorted_results = values.buffer.select_sorted_batch(&sorted_ranks); - - // Scatter back to original order. - let mut results = vec![0usize; ranks.len()]; - for (perm_idx, &orig_idx) in perm.iter().enumerate() { - results[orig_idx] = sorted_results[perm_idx]; - } - results - } - } - } - /// Slice the mask. 
pub fn slice(&self, range: impl RangeBounds) -> Self { let start = match range.start_bound() { diff --git a/vortex-mask/src/tests.rs b/vortex-mask/src/tests.rs index e74bcd52e82..d1496fcb30f 100644 --- a/vortex-mask/src/tests.rs +++ b/vortex-mask/src/tests.rs @@ -88,22 +88,6 @@ fn test_mask_value() { assert!(values.value(4)); } -#[test] -fn test_rank_batch_sorted_ranks() { - let mask = Mask::from_buffer(BitBuffer::from_iter([ - false, true, false, true, true, false, false, true, - ])); - - assert_eq!(mask.rank_batch(&[0, 1, 2, 3]), vec![1, 3, 4, 7]); -} - -#[test] -fn test_rank_batch_unsorted_ranks_with_cached_indices() { - let mask = Mask::from_indices(8, vec![1, 3, 4, 7]); - - assert_eq!(mask.rank_batch(&[3, 0, 2, 1]), vec![7, 1, 4, 3]); -} - #[test] fn test_mask_first() { assert_eq!(Mask::new_true(5).first(), Some(0)); From 2b94a171bb5126d00241d2a39c7337d8216697ab Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Thu, 14 May 2026 12:09:08 +0100 Subject: [PATCH 4/8] less Signed-off-by: Robert Kruszewski --- vortex-buffer/src/bit/buf.rs | 19 --- vortex-buffer/src/bit/select.rs | 271 -------------------------------- 2 files changed, 290 deletions(-) diff --git a/vortex-buffer/src/bit/buf.rs b/vortex-buffer/src/bit/buf.rs index 108c99d0657..3d49cba0716 100644 --- a/vortex-buffer/src/bit/buf.rs +++ b/vortex-buffer/src/bit/buf.rs @@ -26,7 +26,6 @@ use crate::bit::get_bit_unchecked; use crate::bit::ops::bitwise_binary_op; use crate::bit::ops::bitwise_unary_op; use crate::bit::select::bit_select; -use crate::bit::select::bit_select_sorted_batch; use crate::buffer; /// An immutable bitset stored as a packed byte buffer. @@ -334,24 +333,6 @@ impl BitBuffer { bit_select(self.buffer.as_slice(), self.offset, self.len, nth) } - /// Select positions for multiple ranks in a single pass over the bitmap. - /// - /// `sorted_ranks` must be sorted in non-decreasing order, with each value - /// less than [`true_count`](Self::true_count). 
This is O(L/64 + N) where - /// L = bitmap length and N = number of ranks, compared to O(N × L/64) for - /// individual [`select`](Self::select) calls. - pub fn select_sorted_batch(&self, sorted_ranks: &[usize]) -> Vec { - let mut out = vec![0; sorted_ranks.len()]; - bit_select_sorted_batch( - self.buffer.as_slice(), - self.offset, - self.len, - sorted_ranks, - &mut out, - ); - out - } - /// Get the number of unset bits in the buffer. pub fn false_count(&self) -> usize { self.len - self.true_count() diff --git a/vortex-buffer/src/bit/select.rs b/vortex-buffer/src/bit/select.rs index 1ac6336971d..2647c080c22 100644 --- a/vortex-buffer/src/bit/select.rs +++ b/vortex-buffer/src/bit/select.rs @@ -65,205 +65,6 @@ pub fn bit_select(bytes: &[u8], offset: usize, len: usize, nth: usize) -> usize unreachable!("bit_select: nth={nth} exceeds set bit count") } -// ── Batch select (sorted ranks, single pass) ─────────────────────────── - -/// Select positions for multiple ranks in a single pass over the bitmap. -/// -/// `sorted_ranks` must be sorted in non-decreasing order; each value must be -/// less than the total number of set bits in `[offset, offset+len)`. -/// Results are written to `out[0..sorted_ranks.len()]`. 
-pub fn bit_select_sorted_batch( - bytes: &[u8], - offset: usize, - len: usize, - sorted_ranks: &[usize], - out: &mut [usize], -) { - debug_assert!(out.len() >= sorted_ranks.len()); - if sorted_ranks.is_empty() { - return; - } - - let (head, middle, tail) = align_offset_len(bytes, offset, len); - let mut cumul = 0usize; - let mut pos = 0usize; - let mut ri = 0usize; // index into sorted_ranks / out - - // ── head byte ─────────────────────────────────────────────────── - if let Some(head) = head { - let count = head.count_ones() as usize; - while ri < sorted_ranks.len() && sorted_ranks[ri] < cumul + count { - out[ri] = select_in_byte(head, sorted_ranks[ri] - cumul); - ri += 1; - } - cumul += count; - let start_bit = offset % 8; - pos = (8 - start_bit).min(len); - if ri >= sorted_ranks.len() { - return; - } - } - - // ── middle bytes ──────────────────────────────────────────────── - if !middle.is_empty() { - let (words, tail_bytes) = middle.as_chunks::<8>(); - - scan_words_batch_impl(words, sorted_ranks, out, &mut cumul, &mut pos, &mut ri); - if ri >= sorted_ranks.len() { - return; - } - - for &byte in tail_bytes { - let count = byte.count_ones() as usize; - while ri < sorted_ranks.len() && sorted_ranks[ri] < cumul + count { - out[ri] = pos + select_in_byte(byte, sorted_ranks[ri] - cumul); - ri += 1; - } - cumul += count; - pos += 8; - if ri >= sorted_ranks.len() { - return; - } - } - } - - // ── tail byte ─────────────────────────────────────────────────── - if let Some(tail) = tail { - let count = tail.count_ones() as usize; - while ri < sorted_ranks.len() && sorted_ranks[ri] < cumul + count { - out[ri] = pos + select_in_byte(tail, sorted_ranks[ri] - cumul); - ri += 1; - } - } -} - -// ── aarch64 NEON batch scan ───────────────────────────────────────────── - -#[cfg(target_arch = "aarch64")] -#[allow(clippy::cast_possible_truncation)] -fn scan_words_batch_impl( - words: &[[u8; 8]], - sorted_ranks: &[usize], - out: &mut [usize], - cumul: &mut usize, - pos: &mut 
usize, - ri: &mut usize, -) { - use std::arch::aarch64::vcntq_u8; - use std::arch::aarch64::vgetq_lane_u64; - use std::arch::aarch64::vld1q_u8; - use std::arch::aarch64::vpaddlq_u8; - use std::arch::aarch64::vpaddlq_u16; - use std::arch::aarch64::vpaddlq_u32; - - let mut wi = 0; - - // 4-word NEON blocks — skip entire blocks when no targets inside. - while wi + 4 <= words.len() && *ri < sorted_ranks.len() { - let ptr = words[wi].as_ptr(); - // SAFETY: wi + 4 <= words.len() guarantees 32 contiguous bytes. - let (c0, c1, c2, c3) = unsafe { - let pop_lo = vcntq_u8(vld1q_u8(ptr)); - let pop_hi = vcntq_u8(vld1q_u8(ptr.add(16))); - let sums_lo = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop_lo))); - let sums_hi = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop_hi))); - ( - vgetq_lane_u64::<0>(sums_lo) as usize, - vgetq_lane_u64::<1>(sums_lo) as usize, - vgetq_lane_u64::<0>(sums_hi) as usize, - vgetq_lane_u64::<1>(sums_hi) as usize, - ) - }; - let total = c0 + c1 + c2 + c3; - - if sorted_ranks[*ri] >= *cumul + total { - *cumul += total; - *pos += 256; - wi += 4; - continue; - } - - // At least one target in this block — emit per word. - let counts = [c0, c1, c2, c3]; - for (j, &count) in counts.iter().enumerate() { - while *ri < sorted_ranks.len() && sorted_ranks[*ri] < *cumul + count { - let word = u64::from_le_bytes(words[wi + j]); - out[*ri] = *pos + select_in_word(word, sorted_ranks[*ri] - *cumul); - *ri += 1; - } - *cumul += count; - *pos += 64; - } - wi += 4; - } - - // Remaining words, scalar. 
- while wi < words.len() && *ri < sorted_ranks.len() { - let word = u64::from_le_bytes(words[wi]); - let count = word.count_ones() as usize; - while *ri < sorted_ranks.len() && sorted_ranks[*ri] < *cumul + count { - out[*ri] = *pos + select_in_word(word, sorted_ranks[*ri] - *cumul); - *ri += 1; - } - *cumul += count; - *pos += 64; - wi += 1; - } -} - -// ── Scalar batch scan ─────────────────────────────────────────────────── - -#[cfg(not(target_arch = "aarch64"))] -fn scan_words_batch_impl( - words: &[[u8; 8]], - sorted_ranks: &[usize], - out: &mut [usize], - cumul: &mut usize, - pos: &mut usize, - ri: &mut usize, -) { - let mut wi = 0; - - while wi + 4 <= words.len() && *ri < sorted_ranks.len() { - let c0 = u64::from_le_bytes(words[wi]).count_ones() as usize; - let c1 = u64::from_le_bytes(words[wi + 1]).count_ones() as usize; - let c2 = u64::from_le_bytes(words[wi + 2]).count_ones() as usize; - let c3 = u64::from_le_bytes(words[wi + 3]).count_ones() as usize; - let total = c0 + c1 + c2 + c3; - - if sorted_ranks[*ri] >= *cumul + total { - *cumul += total; - *pos += 256; - wi += 4; - continue; - } - - let counts = [c0, c1, c2, c3]; - for (j, &count) in counts.iter().enumerate() { - while *ri < sorted_ranks.len() && sorted_ranks[*ri] < *cumul + count { - let word = u64::from_le_bytes(words[wi + j]); - out[*ri] = *pos + select_in_word(word, sorted_ranks[*ri] - *cumul); - *ri += 1; - } - *cumul += count; - *pos += 64; - } - wi += 4; - } - - while wi < words.len() && *ri < sorted_ranks.len() { - let word = u64::from_le_bytes(words[wi]); - let count = word.count_ones() as usize; - while *ri < sorted_ranks.len() && sorted_ranks[*ri] < *cumul + count { - out[*ri] = *pos + select_in_word(word, sorted_ranks[*ri] - *cumul); - *ri += 1; - } - *cumul += count; - *pos += 64; - wi += 1; - } -} - // ── Word-level scan ───────────────────────────────────────────────────── /// Scan `words` accumulating popcounts. Returns `(remaining, position, word_index)`. 
@@ -615,76 +416,4 @@ mod tests { assert_eq!(select_in_byte(0b1000_0000, 0), 7); assert_eq!(select_in_byte(0xFF, 7), 7); } - - // ── batch select tests ────────────────────────────────────────── - - #[test] - fn test_batch_select_all_set() { - let buf = [0xFFu8; 16]; // 128 bits, all set - let ranks: Vec = (0..128).collect(); - let mut out = vec![0usize; 128]; - bit_select_sorted_batch(&buf, 0, 128, &ranks, &mut out); - for (nth, &pos) in out.iter().enumerate() { - assert_eq!(pos, nth, "nth={nth}"); - } - } - - #[test] - fn test_batch_select_every_other() { - let buf = [0x55u8; 16]; // 128 bits, 64 set - let ranks: Vec = (0..64).collect(); - let mut out = vec![0usize; 64]; - bit_select_sorted_batch(&buf, 0, 128, &ranks, &mut out); - for (nth, &pos) in out.iter().enumerate() { - assert_eq!(pos, nth * 2, "nth={nth}"); - } - } - - #[test] - fn test_batch_select_sparse_ranks() { - let buf = [0xFFu8; 16]; // 128 set - let ranks = [0, 10, 50, 100, 127]; - let mut out = [0usize; 5]; - bit_select_sorted_batch(&buf, 0, 128, &ranks, &mut out); - assert_eq!(out, [0, 10, 50, 100, 127]); - } - - #[test] - fn test_batch_select_empty() { - let buf = [0xFFu8; 4]; - let mut out = []; - bit_select_sorted_batch(&buf, 0, 32, &[], &mut out); - } - - #[rstest] - #[case(0, 128)] - #[case(3, 100)] - #[case(7, 50)] - #[case(0, 65)] - #[case(3, 256)] - fn test_batch_select_agrees_with_individual(#[case] offset: usize, #[case] len: usize) { - let total_bytes = (offset + len).div_ceil(8); - let buf: Vec = (0..total_bytes) - .map(|i| ((i.wrapping_mul(0x9E) ^ 0xA5) & 0xFF) as u8) - .collect(); - - // Get individual results. 
- let true_count = (0..len) - .filter(|&i| { - let phys = offset + i; - (buf[phys / 8] >> (phys % 8)) & 1 == 1 - }) - .count(); - - let all_ranks: Vec = (0..true_count).collect(); - let individual: Vec = all_ranks - .iter() - .map(|&r| bit_select(&buf, offset, len, r)) - .collect(); - - let mut batch = vec![0usize; true_count]; - bit_select_sorted_batch(&buf, offset, len, &all_ranks, &mut batch); - - assert_eq!(batch, individual, "offset={offset} len={len}"); - } } From ef9decdfe19be48b2e592394be9a3d265e1082b4 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Thu, 14 May 2026 12:10:50 +0100 Subject: [PATCH 5/8] try Signed-off-by: Robert Kruszewski --- vortex-mask/src/lib.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index dfc0d745995..e88c67d22a0 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ -473,7 +473,13 @@ impl Mask { match &self { Self::AllTrue(_) => n, Self::AllFalse(_) => unreachable!("no true values in all-false mask"), - Self::Values(values) => values.buffer.select(n), + Self::Values(values) => { + if let Some(indices) = values.indices.get() { + return indices[n]; + } + + values.buffer.select(n) + } } } From 6d1c9a82959c558987ce58b54524ab728fe8bb95 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Thu, 14 May 2026 22:48:06 +0100 Subject: [PATCH 6/8] api Signed-off-by: Robert Kruszewski --- vortex-buffer/public-api.lock | 2 -- vortex-mask/public-api.lock | 2 -- 2 files changed, 4 deletions(-) diff --git a/vortex-buffer/public-api.lock b/vortex-buffer/public-api.lock index ec0570ac4a3..3080ee8116e 100644 --- a/vortex-buffer/public-api.lock +++ b/vortex-buffer/public-api.lock @@ -282,8 +282,6 @@ pub fn vortex_buffer::BitBuffer::offset(&self) -> usize pub fn vortex_buffer::BitBuffer::select(&self, usize) -> usize -pub fn vortex_buffer::BitBuffer::select_sorted_batch(&self, &[usize]) -> alloc::vec::Vec - pub fn 
vortex_buffer::BitBuffer::set_indices(&self) -> arrow_buffer::util::bit_iterator::BitIndexIterator<'_> pub fn vortex_buffer::BitBuffer::set_slices(&self) -> arrow_buffer::util::bit_iterator::BitSliceIterator<'_> diff --git a/vortex-mask/public-api.lock b/vortex-mask/public-api.lock index 75eb5da0187..006bdafaa64 100644 --- a/vortex-mask/public-api.lock +++ b/vortex-mask/public-api.lock @@ -82,8 +82,6 @@ pub fn vortex_mask::Mask::new_true(usize) -> Self pub fn vortex_mask::Mask::rank(&self, usize) -> usize -pub fn vortex_mask::Mask::rank_batch(&self, &[usize]) -> alloc::vec::Vec - pub fn vortex_mask::Mask::slice(&self, impl core::ops::range::RangeBounds) -> Self pub fn vortex_mask::Mask::slices(&self) -> vortex_mask::AllOr<&[(usize, usize)]> From 0af4aada0a15784aeb3a183d7e74aaeca308f2fc Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 15 May 2026 00:11:17 +0100 Subject: [PATCH 7/8] simplify Signed-off-by: Robert Kruszewski --- vortex-array/src/arrays/filter/take.rs | 310 ++++++++----------------- vortex-mask/src/lib.rs | 3 +- 2 files changed, 96 insertions(+), 217 deletions(-) diff --git a/vortex-array/src/arrays/filter/take.rs b/vortex-array/src/arrays/filter/take.rs index 2733e93b65f..53ad8b90b54 100644 --- a/vortex-array/src/arrays/filter/take.rs +++ b/vortex-array/src/arrays/filter/take.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use vortex_buffer::BitBuffer; use vortex_buffer::Buffer; use vortex_buffer::BufferMut; use vortex_error::VortexResult; @@ -12,9 +13,8 @@ use super::Filter; use crate::ArrayRef; use crate::IntoArray; use crate::array::ArrayView; -use crate::arrays::Decimal; +use crate::arrays::ConstantArray; use crate::arrays::DecimalArray; -use crate::arrays::Primitive; use crate::arrays::PrimitiveArray; use crate::arrays::decimal::DecimalArrayExt; use crate::arrays::dict::TakeExecute; @@ -30,6 +30,7 @@ use crate::kernel::ParentKernelSet; use 
crate::match_each_decimal_value_type; use crate::match_each_integer_ptype; use crate::match_each_native_ptype; +use crate::scalar::Scalar; use crate::validity::Validity; pub(super) const PARENT_KERNELS: ParentKernelSet = @@ -43,37 +44,41 @@ fn take_impl( ctx: &mut ExecutionCtx, ) -> VortexResult { let indices_validity = indices.validity()?.execute_mask(indices.len(), ctx)?; - if let Some(taken) = take_primitive_fast_path(array, indices, &indices_validity)? { + if array.child().dtype().is_primitive() + && let Some(taken) = take_primitive(array, indices, &indices_validity, ctx)? + { return Ok(taken); } - if let Some(taken) = take_decimal_fast_path(array, indices, &indices_validity)? { + if array.child().dtype().is_decimal() + && let Some(taken) = take_decimal(array, indices, &indices_validity, ctx)? + { return Ok(taken); } - if indices_validity.all_true() { - let translated = translate_indices_fast(array.filter_mask(), indices, array.len())?; - let translated_indices = PrimitiveArray::new( - translated, - Validity::from_mask(indices_validity, indices.dtype().nullability()), - ) - .into_array(); - - return array.child().take(translated_indices); - } + match indices_validity.bit_buffer() { + AllOr::All => { + let translated = translate_indices(array.filter_mask(), indices)?; + let translated_indices = + PrimitiveArray::new(translated, indices.validity()?).into_array(); - let translated = translate_nullable_indices_fast( - array.filter_mask(), - indices, - &indices_validity, - array.len(), - )?; - let translated_indices = PrimitiveArray::new( - translated, - Validity::from_mask(indices_validity, indices.dtype().nullability()), - ) - .into_array(); + return array.child().take(translated_indices); + } + AllOr::None => { + return Ok(ConstantArray::new( + Scalar::null(array.dtype().as_nullable()), + indices.len(), + ) + .into_array()); + } + AllOr::Some(b) => { + let translated = + translate_nullable_indices(array.filter_mask(), indices, b, array.len())?; + let 
translated_indices = + PrimitiveArray::new(translated, indices.validity()?).into_array(); - array.child().take(translated_indices) + array.child().take(translated_indices) + } + } } fn should_fallback_nullable_fixed_width_full_take( @@ -97,14 +102,14 @@ fn should_fallback_nullable_fixed_width_full_take( Ok(false) } -fn translate_nullable_indices_fast( +fn translate_nullable_indices( filter: &Mask, indices: &PrimitiveArray, - indices_validity: &Mask, + indices_validity: &BitBuffer, filtered_len: usize, ) -> VortexResult> { match_each_integer_ptype!(indices.ptype(), |P| { - translate_nullable_ranks_fast( + translate_nullable_ranks( filter, indices.as_slice::

(), indices_validity, @@ -113,20 +118,13 @@ fn translate_nullable_indices_fast( }) } -fn translate_nullable_ranks_fast( +fn translate_nullable_ranks( filter: &Mask, ranks: &[P], - indices_validity: &Mask, + indices_validity: &BitBuffer, filtered_len: usize, ) -> VortexResult> { - if indices_validity.all_true() { - return translate_ranks_fast(filter, ranks, filtered_len); - } - if indices_validity.all_false() { - return Ok(Buffer::zeroed(ranks.len())); - } - - if let Some(start) = contiguous_filter_start(filter, filtered_len) { + if let Some(start) = contiguous_filter_start(filter) { return translate_nullable_ranks_with_offset(ranks, indices_validity, filtered_len, start); } @@ -144,7 +142,7 @@ fn translate_nullable_ranks_fast( fn translate_nullable_ranks_with_offset( ranks: &[P], - indices_validity: &Mask, + indices_validity: &BitBuffer, filtered_len: usize, start: usize, ) -> VortexResult> { @@ -171,7 +169,7 @@ fn translate_nullable_ranks_with_offset( fn translate_nullable_ranks_identity( ranks: &[P], - indices_validity: &Mask, + indices_validity: &BitBuffer, filtered_len: usize, ) -> VortexResult> { let mut translated = BufferMut::::with_capacity(ranks.len()); @@ -196,7 +194,7 @@ fn translate_nullable_ranks_identity( fn translate_nullable_ranks_with_indices( ranks: &[P], - indices_validity: &Mask, + indices_validity: &BitBuffer, filtered_len: usize, filter_indices: &[usize], ) -> VortexResult> { @@ -223,51 +221,33 @@ fn translate_nullable_ranks_with_indices( Ok(translated.freeze()) } -fn translate_indices_fast( - filter: &Mask, - indices: &PrimitiveArray, - filtered_len: usize, -) -> VortexResult> { +fn translate_indices(filter: &Mask, indices: &PrimitiveArray) -> VortexResult> { match_each_integer_ptype!(indices.ptype(), |P| { - translate_ranks_fast(filter, indices.as_slice::

(), filtered_len) + translate_ranks(filter, indices.as_slice::

()) }) } -fn translate_ranks_fast( - filter: &Mask, - ranks: &[P], - filtered_len: usize, -) -> VortexResult> { +fn translate_ranks(filter: &Mask, ranks: &[P]) -> VortexResult> { let mut translated = BufferMut::::with_capacity(ranks.len()); let translated_ptr = translated.spare_capacity_mut().as_mut_ptr().cast::(); - if let Some(start) = contiguous_filter_start(filter, filtered_len) { + if let Some(start) = contiguous_filter_start(filter) { for (idx, rank) in ranks.iter().enumerate() { - let Some(rank) = rank.to_usize() else { - vortex_bail!(OutOfBounds: 0, 0, filtered_len); - }; - if rank >= filtered_len { - vortex_bail!(OutOfBounds: rank, 0, filtered_len); - } - // SAFETY: `translated` has capacity for all ranks and this loop initializes each // output slot once. - unsafe { translated_ptr.add(idx).write(u64::try_from(start + rank)?) }; + unsafe { + translated_ptr + .add(idx) + .write(u64::try_from(start + rank.as_())?) + }; } } else { let filter_indices = match filter.indices() { AllOr::All => { for (idx, rank) in ranks.iter().enumerate() { - let Some(rank) = rank.to_usize() else { - vortex_bail!(OutOfBounds: 0, 0, filtered_len); - }; - if rank >= filtered_len { - vortex_bail!(OutOfBounds: rank, 0, filtered_len); - } - // SAFETY: `translated` has capacity for all ranks and this loop initializes // each output slot once. - unsafe { translated_ptr.add(idx).write(u64::try_from(rank)?) }; + unsafe { translated_ptr.add(idx).write(u64::try_from(rank.as_())?) }; } // SAFETY: The loop writes exactly `ranks.len()` initialized values. @@ -279,19 +259,12 @@ fn translate_ranks_fast( }; for (idx, rank) in ranks.iter().enumerate() { - let Some(rank) = rank.to_usize() else { - vortex_bail!(OutOfBounds: 0, 0, filtered_len); - }; - if rank >= filtered_len { - vortex_bail!(OutOfBounds: rank, 0, filtered_len); - } - // SAFETY: `translated` has capacity for all ranks. `rank` was checked against the // filtered length, and filter indices are valid child positions by construction. 
unsafe { translated_ptr .add(idx) - .write(u64::try_from(*filter_indices.get_unchecked(rank))?) + .write(u64::try_from(*filter_indices.get_unchecked(rank.as_()))?) }; } } @@ -302,23 +275,20 @@ fn translate_ranks_fast( } fn validate_rank(rank: P, filtered_len: usize) -> VortexResult { - let Some(rank) = rank.to_usize() else { - vortex_bail!(OutOfBounds: 0, 0, filtered_len); - }; + let rank: usize = rank.as_(); if rank >= filtered_len { vortex_bail!(OutOfBounds: rank, 0, filtered_len); } Ok(rank) } -fn take_primitive_fast_path( +fn take_primitive( array: ArrayView<'_, Filter>, indices: &PrimitiveArray, indices_validity: &Mask, + ctx: &mut ExecutionCtx, ) -> VortexResult> { - let Some(child) = array.child().as_opt::() else { - return Ok(None); - }; + let child = array.child().clone().execute::(ctx)?; let child_validity = child.validity()?; if !child_validity.no_nulls() { @@ -329,10 +299,9 @@ fn take_primitive_fast_path( Validity::from_mask(indices_validity.clone(), indices.dtype().nullability()); match_each_native_ptype!(child.ptype(), |T| { match_each_integer_ptype!(indices.ptype(), |P| { - take_primitive_fast_path_typed::( + take_primitive_typed::( child, array.filter_mask(), - array.len(), indices, indices_validity, output_validity, @@ -342,14 +311,13 @@ fn take_primitive_fast_path( }) } -fn take_decimal_fast_path( +fn take_decimal( array: ArrayView<'_, Filter>, indices: &PrimitiveArray, indices_validity: &Mask, + ctx: &mut ExecutionCtx, ) -> VortexResult> { - let Some(child) = array.child().as_opt::() else { - return Ok(None); - }; + let child = array.child().clone().execute::(ctx)?; let child_validity = child.validity()?; if !child_validity.no_nulls() { @@ -360,10 +328,9 @@ fn take_decimal_fast_path( Validity::from_mask(indices_validity.clone(), indices.dtype().nullability()); match_each_decimal_value_type!(child.values_type(), |T| { match_each_integer_ptype!(indices.ptype(), |P| { - take_decimal_fast_path_typed::( + take_decimal_typed::( child, 
array.filter_mask(), - array.len(), indices, indices_validity, output_validity, @@ -373,10 +340,9 @@ fn take_decimal_fast_path( }) } -fn take_decimal_fast_path_typed( - child: ArrayView<'_, Decimal>, +fn take_decimal_typed( + child: DecimalArray, filter: &Mask, - filtered_len: usize, indices: &PrimitiveArray, indices_validity: &Mask, output_validity: Validity, @@ -389,7 +355,7 @@ where let decimal_dtype = child.decimal_dtype(); if indices_validity.all_true() { - if let Some((start, end)) = contiguous_sequential_take_range(filter, ranks, filtered_len)? { + if let Some((start, end)) = contiguous_sequential_take_range(filter, ranks)? { let values = child.buffer_handle().slice_typed::(start..end); // SAFETY: The values are sliced from an existing valid decimal array, and the output // validity was built for exactly the sliced take length. @@ -404,12 +370,7 @@ where .into_array()); } - let taken = take_filtered_values::( - child.buffer::().as_slice(), - filter, - ranks, - filtered_len, - )?; + let taken = take_filtered_values::(child.buffer::().as_slice(), filter, ranks)?; // SAFETY: Taking existing decimal values preserves the decimal dtype invariants, and the // output validity was built for the take length. return Ok( @@ -423,17 +384,15 @@ where filter, ranks, indices_validity, - filtered_len, )?; // SAFETY: Valid ranks copy existing decimal values, null ranks write default placeholders that // are hidden by output validity, and the output validity was built for the take length. Ok(unsafe { DecimalArray::new_unchecked(taken, decimal_dtype, output_validity) }.into_array()) } -fn take_primitive_fast_path_typed( - child: ArrayView<'_, Primitive>, +fn take_primitive_typed( + child: PrimitiveArray, filter: &Mask, - filtered_len: usize, indices: &PrimitiveArray, indices_validity: &Mask, output_validity: Validity, @@ -445,29 +404,21 @@ where let ranks = indices.as_slice::

(); if indices_validity.all_true() { - return take_primitive_fast_path_all_valid::( - child, - filter, - filtered_len, - ranks, - output_validity, - ); + return take_primitive_all_valid::(child, filter, ranks, output_validity); } - take_primitive_fast_path_nullable::( + take_primitive_nullable::( child.as_slice::(), filter, - filtered_len, ranks, indices_validity, output_validity, ) } -fn take_primitive_fast_path_all_valid( - child: ArrayView<'_, Primitive>, +fn take_primitive_all_valid( + child: PrimitiveArray, filter: &Mask, - filtered_len: usize, ranks: &[P], output_validity: Validity, ) -> VortexResult @@ -475,7 +426,7 @@ where T: NativePType, P: IntegerPType, { - if let Some((start, end)) = contiguous_sequential_take_range(filter, ranks, filtered_len)? { + if let Some((start, end)) = contiguous_sequential_take_range(filter, ranks)? { return Ok(PrimitiveArray::from_buffer_handle( child.buffer_handle().slice_typed::(start..end), T::PTYPE, @@ -484,14 +435,13 @@ where .into_array()); } - let taken = take_filtered_values::(child.as_slice::(), filter, ranks, filtered_len)?; + let taken = take_filtered_values::(child.as_slice::(), filter, ranks)?; Ok(PrimitiveArray::new(taken, output_validity).into_array()) } -fn take_primitive_fast_path_nullable( +fn take_primitive_nullable( values: &[T], filter: &Mask, - filtered_len: usize, ranks: &[P], indices_validity: &Mask, output_validity: Validity, @@ -500,13 +450,7 @@ where T: NativePType, P: IntegerPType, { - let taken = take_filtered_values_nullable::( - values, - filter, - ranks, - indices_validity, - filtered_len, - )?; + let taken = take_filtered_values_nullable::(values, filter, ranks, indices_validity)?; Ok(PrimitiveArray::new(taken, output_validity).into_array()) } @@ -515,7 +459,6 @@ fn take_filtered_values_nullable( filter: &Mask, ranks: &[P], indices_validity: &Mask, - filtered_len: usize, ) -> VortexResult> where T: Copy + Default, @@ -525,12 +468,12 @@ where return Ok(Buffer::zeroed(ranks.len())); } - if let 
Some(start) = contiguous_filter_start(filter, filtered_len) { + if let Some(start) = contiguous_filter_start(filter) { let mut out = BufferMut::::with_capacity(ranks.len()); let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); for (idx, rank) in ranks.iter().enumerate() { let value = if indices_validity.value(idx) { - let rank = validate_rank(*rank, filtered_len)?; + let rank = validate_rank(*rank, filter.true_count())?; // SAFETY: `rank` was checked against the contiguous filtered length. unsafe { *values.get_unchecked(start + rank) } } else { @@ -549,7 +492,12 @@ where let indices = match filter.indices() { AllOr::All => { - return take_values_by_rank_nullable(values, ranks, indices_validity, filtered_len); + return take_values_by_rank_nullable( + values, + ranks, + indices_validity, + filter.true_count(), + ); } AllOr::None => unreachable!("empty filters are handled by take preconditions"), AllOr::Some(indices) => indices, @@ -559,7 +507,7 @@ where let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); for (idx, rank) in ranks.iter().enumerate() { let value = if indices_validity.value(idx) { - let rank = validate_rank(*rank, filtered_len)?; + let rank = validate_rank(*rank, filter.true_count())?; // SAFETY: `rank` was bounds-checked against `indices`, whose values are valid // positions in `values`. 
unsafe { *values.get_unchecked(*indices.get_unchecked(rank)) } @@ -611,20 +559,13 @@ where fn contiguous_sequential_take_range( filter: &Mask, ranks: &[P], - filtered_len: usize, ) -> VortexResult> { - let Some(start) = contiguous_filter_start(filter, filtered_len) else { + let Some(start) = contiguous_filter_start(filter) else { return Ok(None); }; for (idx, rank) in ranks.iter().enumerate() { - let Some(rank) = rank.to_usize() else { - vortex_bail!(OutOfBounds: 0, 0, filtered_len); - }; - if rank >= filtered_len { - vortex_bail!(OutOfBounds: rank, 0, filtered_len); - } - if rank != idx { + if rank.as_() != idx { return Ok(None); } } @@ -632,27 +573,16 @@ fn contiguous_sequential_take_range( Ok(Some((start, start + ranks.len()))) } -fn take_filtered_values( - values: &[T], - filter: &Mask, - ranks: &[P], - filtered_len: usize, -) -> VortexResult> +fn take_filtered_values(values: &[T], filter: &Mask, ranks: &[P]) -> VortexResult> where T: Copy + Default, P: IntegerPType, { - if let Some(start) = contiguous_filter_start(filter, filtered_len) { + if let Some(start) = contiguous_filter_start(filter) { let mut out = BufferMut::::with_capacity(ranks.len()); let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); for (idx, rank) in ranks.iter().enumerate() { - let Some(rank) = rank.to_usize() else { - vortex_bail!(OutOfBounds: 0, 0, filtered_len); - }; - if rank >= filtered_len { - vortex_bail!(OutOfBounds: rank, 0, filtered_len); - } - + let rank = validate_rank(*rank, filter.true_count())?; // SAFETY: `out` has capacity for all ranks. The filter is contiguous with // `filtered_len` values starting at `start`, and `rank` was checked above. 
unsafe { out_ptr.add(idx).write(*values.get_unchecked(start + rank)) }; @@ -664,32 +594,20 @@ where } let indices = match filter.indices() { - AllOr::All => return take_values_by_rank(values, ranks, filtered_len), + AllOr::All => return take_values_by_rank(values, ranks), AllOr::None => unreachable!("empty filters are handled by take preconditions"), AllOr::Some(indices) => indices, }; - if ranks.len() == filtered_len && !first_rank_is_zero(ranks, filtered_len)? { - let filtered = gather_values_by_indices(values, indices); - return take_values_by_rank(filtered.as_slice(), ranks, filtered_len); - } - let mut out = BufferMut::::with_capacity(ranks.len()); let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); for (idx, rank) in ranks.iter().enumerate() { - let Some(rank) = rank.to_usize() else { - vortex_bail!(OutOfBounds: 0, 0, filtered_len); - }; - if rank >= filtered_len { - vortex_bail!(OutOfBounds: rank, 0, filtered_len); - } - // SAFETY: `out` has capacity for all ranks. `rank` was bounds-checked against // `indices`, whose values are valid positions in `values`. unsafe { out_ptr .add(idx) - .write(*values.get_unchecked(*indices.get_unchecked(rank))) + .write(*values.get_unchecked(*indices.get_unchecked(rank.as_()))) }; } @@ -698,42 +616,7 @@ where Ok(out.freeze()) } -fn gather_values_by_indices(values: &[T], indices: &[usize]) -> Buffer -where - T: Copy + Default, -{ - let mut out = BufferMut::::with_capacity(indices.len()); - let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); - - for (idx, &value_idx) in indices.iter().enumerate() { - // SAFETY: `out` has capacity for all indices and mask indices are valid positions in the - // child values buffer by construction. - unsafe { out_ptr.add(idx).write(*values.get_unchecked(value_idx)) }; - } - - // SAFETY: The loop writes exactly `indices.len()` initialized values. 
- unsafe { out.set_len(indices.len()) }; - out.freeze() -} - -fn first_rank_is_zero(ranks: &[P], filtered_len: usize) -> VortexResult { - let Some(first) = ranks.first() else { - return Ok(false); - }; - let Some(first) = first.to_usize() else { - vortex_bail!(OutOfBounds: 0, 0, filtered_len); - }; - if first >= filtered_len { - vortex_bail!(OutOfBounds: first, 0, filtered_len); - } - Ok(first == 0) -} - -fn take_values_by_rank( - values: &[T], - ranks: &[P], - filtered_len: usize, -) -> VortexResult> +fn take_values_by_rank(values: &[T], ranks: &[P]) -> VortexResult> where T: Copy + Default, P: IntegerPType, @@ -741,12 +624,7 @@ where let mut out = BufferMut::::with_capacity(ranks.len()); let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); for (idx, rank) in ranks.iter().enumerate() { - let Some(rank) = rank.to_usize() else { - vortex_bail!(OutOfBounds: 0, 0, filtered_len); - }; - if rank >= filtered_len { - vortex_bail!(OutOfBounds: rank, 0, filtered_len); - } + let rank = validate_rank(*rank, values.len())?; // SAFETY: `out` has capacity for all ranks and `rank` was bounds-checked. 
unsafe { out_ptr.add(idx).write(*values.get_unchecked(rank)) }; @@ -757,10 +635,10 @@ where Ok(out.freeze()) } -fn contiguous_filter_start(filter: &Mask, filtered_len: usize) -> Option { +fn contiguous_filter_start(filter: &Mask) -> Option { let start = filter.first()?; let end = filter.last()?.checked_add(1)?; - (end - start == filtered_len).then_some(start) + (end - start == filter.true_count()).then_some(start) } impl TakeExecute for Filter { diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index e88c67d22a0..47c4a97281e 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ -457,7 +457,8 @@ impl Mask { if let Some(slices) = values.slices.get() { return slices.last().map(|(_, end)| end - 1); } - values.buffer.set_slices().last().map(|(_, end)| end - 1) + + (values.true_count != 0).then(|| values.buffer.select(values.true_count - 1)) } } } From 866cecb9d456ef5c3adfbe340c5f0b1957c6eff7 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 15 May 2026 14:23:35 +0100 Subject: [PATCH 8/8] refactor Signed-off-by: Robert Kruszewski --- vortex-array/src/arrays/filter/take.rs | 477 ++++++++++++++++--------- vortex-buffer/public-api.lock | 2 +- vortex-buffer/src/bit/buf.rs | 7 +- vortex-buffer/src/bit/select.rs | 39 +- vortex-mask/src/lib.rs | 25 +- 5 files changed, 355 insertions(+), 195 deletions(-) diff --git a/vortex-array/src/arrays/filter/take.rs b/vortex-array/src/arrays/filter/take.rs index 53ad8b90b54..d589781119e 100644 --- a/vortex-array/src/arrays/filter/take.rs +++ b/vortex-array/src/arrays/filter/take.rs @@ -36,43 +36,59 @@ use crate::validity::Validity; pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(Filter))]); -const NULLABLE_FIXED_WIDTH_FULL_TAKE_FALLBACK_LEN: usize = 4096; +const BIG_TAKE_FALLBACK_LEN: usize = 4096; fn take_impl( array: ArrayView<'_, Filter>, indices: &PrimitiveArray, ctx: &mut ExecutionCtx, ) -> VortexResult { - let indices_validity = 
indices.validity()?.execute_mask(indices.len(), ctx)?; if array.child().dtype().is_primitive() - && let Some(taken) = take_primitive(array, indices, &indices_validity, ctx)? + && let Some(taken) = take_primitive(array, indices, ctx)? { return Ok(taken); } + if array.child().dtype().is_decimal() - && let Some(taken) = take_decimal(array, indices, &indices_validity, ctx)? + && let Some(taken) = take_decimal(array, indices, ctx)? { return Ok(taken); } + let indices_validity = indices.validity()?.execute_mask(indices.len(), ctx)?; + if indices_validity.all_true() + && let Some((start, end)) = + contiguous_sequential_take_range_indices(array.filter_mask(), indices)? + { + return array.child().slice(start..end); + } + + if indices_validity.all_true() + && let Some(take_len) = sequential_take_len(indices, array.len())? + { + if take_len == 0 { + return array.child().slice(0..0); + } + let rank_mask = Mask::from_slices(array.len(), vec![(0, take_len)]); + let mask = array.filter_mask().intersect_by_rank(&rank_mask); + return array.child().filter(mask); + } + match indices_validity.bit_buffer() { AllOr::All => { let translated = translate_indices(array.filter_mask(), indices)?; let translated_indices = PrimitiveArray::new(translated, indices.validity()?).into_array(); - return array.child().take(translated_indices); - } - AllOr::None => { - return Ok(ConstantArray::new( - Scalar::null(array.dtype().as_nullable()), - indices.len(), - ) - .into_array()); + array.child().take(translated_indices) } + AllOr::None => Ok(ConstantArray::new( + Scalar::null(array.dtype().as_nullable()), + indices.len(), + ) + .into_array()), AllOr::Some(b) => { - let translated = - translate_nullable_indices(array.filter_mask(), indices, b, array.len())?; + let translated = translate_nullable_indices(array.filter_mask(), indices, b)?; let translated_indices = PrimitiveArray::new(translated, indices.validity()?).into_array(); @@ -81,40 +97,13 @@ fn take_impl( } } -fn 
should_fallback_nullable_fixed_width_full_take( - array: ArrayView<'_, Filter>, - indices: &ArrayRef, -) -> VortexResult { - if indices.len() < NULLABLE_FIXED_WIDTH_FULL_TAKE_FALLBACK_LEN - || indices.len() < array.len() - || !indices.dtype().is_nullable() - { - return Ok(false); - } - - // For large nullable full-or-larger fixed-width takes with nullable children, materializing the - // filter first and then taking the child beats translating every nullable rank through the - // parent. - if array.child().dtype().is_decimal() || array.child().dtype().is_primitive() { - return Ok(!array.child().validity()?.no_nulls()); - } - - Ok(false) -} - fn translate_nullable_indices( filter: &Mask, indices: &PrimitiveArray, indices_validity: &BitBuffer, - filtered_len: usize, ) -> VortexResult> { match_each_integer_ptype!(indices.ptype(), |P| { - translate_nullable_ranks( - filter, - indices.as_slice::

(), - indices_validity, - filtered_len, - ) + translate_nullable_ranks(filter, indices.as_slice::

(), indices_validity) }) } @@ -122,8 +111,8 @@ fn translate_nullable_ranks( filter: &Mask, ranks: &[P], indices_validity: &BitBuffer, - filtered_len: usize, ) -> VortexResult> { + let filtered_len = filter.true_count(); if let Some(start) = contiguous_filter_start(filter) { return translate_nullable_ranks_with_offset(ranks, indices_validity, filtered_len, start); } @@ -227,27 +216,61 @@ fn translate_indices(filter: &Mask, indices: &PrimitiveArray) -> VortexResult VortexResult> { + match_each_integer_ptype!(indices.ptype(), |P| { + contiguous_sequential_take_range(filter, indices.as_slice::

()) + }) +} + +fn sequential_take_len( + indices: &PrimitiveArray, + filtered_len: usize, +) -> VortexResult> { + match_each_integer_ptype!(indices.ptype(), |P| { + sequential_take_len_typed(indices.as_slice::

(), filtered_len) + }) +} + +fn sequential_take_len_typed( + ranks: &[P], + filtered_len: usize, +) -> VortexResult> { + for (idx, rank) in ranks.iter().enumerate() { + if rank.as_() != idx { + return Ok(None); + } + } + + if ranks.len() > filtered_len { + vortex_bail!(OutOfBounds: ranks.len() - 1, 0, filtered_len); + } + + Ok(Some(ranks.len())) +} + fn translate_ranks(filter: &Mask, ranks: &[P]) -> VortexResult> { let mut translated = BufferMut::::with_capacity(ranks.len()); let translated_ptr = translated.spare_capacity_mut().as_mut_ptr().cast::(); + let filtered_len = filter.true_count(); if let Some(start) = contiguous_filter_start(filter) { for (idx, rank) in ranks.iter().enumerate() { + let rank = validate_rank(*rank, filtered_len)?; // SAFETY: `translated` has capacity for all ranks and this loop initializes each // output slot once. - unsafe { - translated_ptr - .add(idx) - .write(u64::try_from(start + rank.as_())?) - }; + unsafe { translated_ptr.add(idx).write(u64::try_from(start + rank)?) }; } } else { let filter_indices = match filter.indices() { AllOr::All => { for (idx, rank) in ranks.iter().enumerate() { + let rank = validate_rank(*rank, filtered_len)?; // SAFETY: `translated` has capacity for all ranks and this loop initializes // each output slot once. - unsafe { translated_ptr.add(idx).write(u64::try_from(rank.as_())?) }; + unsafe { translated_ptr.add(idx).write(u64::try_from(rank)?) }; } // SAFETY: The loop writes exactly `ranks.len()` initialized values. @@ -259,12 +282,13 @@ fn translate_ranks(filter: &Mask, ranks: &[P]) -> VortexResult< }; for (idx, rank) in ranks.iter().enumerate() { + let rank = validate_rank(*rank, filtered_len)?; // SAFETY: `translated` has capacity for all ranks. `rank` was checked against the // filtered length, and filter indices are valid child positions by construction. unsafe { translated_ptr .add(idx) - .write(u64::try_from(*filter_indices.get_unchecked(rank.as_()))?) 
+ .write(u64::try_from(*filter_indices.get_unchecked(rank))?) }; } } @@ -285,28 +309,12 @@ fn validate_rank(rank: P, filtered_len: usize) -> VortexResult< fn take_primitive( array: ArrayView<'_, Filter>, indices: &PrimitiveArray, - indices_validity: &Mask, ctx: &mut ExecutionCtx, ) -> VortexResult> { let child = array.child().clone().execute::(ctx)?; - - let child_validity = child.validity()?; - if !child_validity.no_nulls() { - return Ok(None); - } - - let output_validity = - Validity::from_mask(indices_validity.clone(), indices.dtype().nullability()); match_each_native_ptype!(child.ptype(), |T| { match_each_integer_ptype!(indices.ptype(), |P| { - take_primitive_typed::( - child, - array.filter_mask(), - indices, - indices_validity, - output_validity, - ) - .map(Some) + take_primitive_typed::(child, array.filter_mask(), indices, ctx).map(Some) }) }) } @@ -314,28 +322,12 @@ fn take_primitive( fn take_decimal( array: ArrayView<'_, Filter>, indices: &PrimitiveArray, - indices_validity: &Mask, ctx: &mut ExecutionCtx, ) -> VortexResult> { let child = array.child().clone().execute::(ctx)?; - - let child_validity = child.validity()?; - if !child_validity.no_nulls() { - return Ok(None); - } - - let output_validity = - Validity::from_mask(indices_validity.clone(), indices.dtype().nullability()); match_each_decimal_value_type!(child.values_type(), |T| { match_each_integer_ptype!(indices.ptype(), |P| { - take_decimal_typed::( - child, - array.filter_mask(), - indices, - indices_validity, - output_validity, - ) - .map(Some) + take_decimal_typed::(child, array.filter_mask(), indices, ctx).map(Some) }) }) } @@ -344,8 +336,7 @@ fn take_decimal_typed( child: DecimalArray, filter: &Mask, indices: &PrimitiveArray, - indices_validity: &Mask, - output_validity: Validity, + ctx: &mut ExecutionCtx, ) -> VortexResult where T: NativeDecimalType, @@ -353,10 +344,13 @@ where { let ranks = indices.as_slice::

(); let decimal_dtype = child.decimal_dtype(); + let child_validity = child.validity()?; + let indices_validity = indices.validity()?.execute_mask(indices.len(), ctx)?; - if indices_validity.all_true() { + let taken = if indices_validity.all_true() { if let Some((start, end)) = contiguous_sequential_take_range(filter, ranks)? { let values = child.buffer_handle().slice_typed::(start..end); + let output_validity = contiguous_output_validity(&child_validity, indices, start..end)?; // SAFETY: The values are sliced from an existing valid decimal array, and the output // validity was built for exactly the sliced take length. return Ok(unsafe { @@ -370,21 +364,17 @@ where .into_array()); } - let taken = take_filtered_values::(child.buffer::().as_slice(), filter, ranks)?; - // SAFETY: Taking existing decimal values preserves the decimal dtype invariants, and the - // output validity was built for the take length. - return Ok( - unsafe { DecimalArray::new_unchecked(taken, decimal_dtype, output_validity) } - .into_array(), - ); - } - - let taken = take_filtered_values_nullable::( - child.buffer::().as_slice(), - filter, - ranks, - indices_validity, - )?; + take_filtered_values::(child.buffer::().as_slice(), filter, ranks)? + } else { + take_filtered_values_nullable::( + child.buffer::().as_slice(), + filter, + ranks, + &indices_validity, + )? + }; + let output_validity = + take_output_validity(&child_validity, filter, indices, &indices_validity)?; // SAFETY: Valid ranks copy existing decimal values, null ranks write default placeholders that // are hidden by output validity, and the output validity was built for the take length. 
Ok(unsafe { DecimalArray::new_unchecked(taken, decimal_dtype, output_validity) }.into_array()) @@ -394,39 +384,45 @@ fn take_primitive_typed( child: PrimitiveArray, filter: &Mask, indices: &PrimitiveArray, - indices_validity: &Mask, - output_validity: Validity, + ctx: &mut ExecutionCtx, ) -> VortexResult where T: NativePType, P: IntegerPType, { let ranks = indices.as_slice::

(); + let child_validity = child.validity()?; + let indices_validity = indices.validity()?.execute_mask(indices.len(), ctx)?; if indices_validity.all_true() { - return take_primitive_all_valid::(child, filter, ranks, output_validity); + return take_primitive_all_valid::(child, filter, indices, ranks, &child_validity); } - take_primitive_nullable::( + let taken = take_filtered_values_nullable::( child.as_slice::(), filter, ranks, - indices_validity, - output_validity, - ) + &indices_validity, + )?; + let output_validity = + take_output_validity(&child_validity, filter, indices, &indices_validity)?; + + Ok(PrimitiveArray::new(taken, output_validity).into_array()) } fn take_primitive_all_valid( child: PrimitiveArray, filter: &Mask, + indices: &PrimitiveArray, ranks: &[P], - output_validity: Validity, + child_validity: &Validity, ) -> VortexResult where T: NativePType, P: IntegerPType, { if let Some((start, end)) = contiguous_sequential_take_range(filter, ranks)? { + let output_validity = contiguous_output_validity(child_validity, indices, start..end)?; return Ok(PrimitiveArray::from_buffer_handle( child.buffer_handle().slice_typed::(start..end), T::PTYPE, @@ -436,22 +432,49 @@ where } let taken = take_filtered_values::(child.as_slice::(), filter, ranks)?; + let output_validity = take_output_validity( + child_validity, + filter, + indices, + &Mask::new_true(indices.len()), + )?; Ok(PrimitiveArray::new(taken, output_validity).into_array()) } -fn take_primitive_nullable( - values: &[T], +fn contiguous_output_validity( + child_validity: &Validity, + indices: &PrimitiveArray, + range: std::ops::Range, +) -> VortexResult { + if child_validity.no_nulls() { + return indices.validity(); + } + + child_validity.slice(range) +} + +fn take_output_validity( + child_validity: &Validity, filter: &Mask, - ranks: &[P], + indices: &PrimitiveArray, indices_validity: &Mask, - output_validity: Validity, -) -> VortexResult -where - T: NativePType, - P: IntegerPType, -{ - let taken = 
take_filtered_values_nullable::(values, filter, ranks, indices_validity)?; - Ok(PrimitiveArray::new(taken, output_validity).into_array()) +) -> VortexResult { + if child_validity.no_nulls() { + return indices.validity(); + } + + let translated_indices = match indices_validity.bit_buffer() { + AllOr::All => PrimitiveArray::new(translate_indices(filter, indices)?, indices.validity()?) + .into_array(), + AllOr::None => return Ok(Validity::AllInvalid), + AllOr::Some(b) => PrimitiveArray::new( + translate_nullable_indices(filter, indices, b)?, + indices.validity()?, + ) + .into_array(), + }; + + child_validity.take(&translated_indices) } fn take_filtered_values_nullable( @@ -570,6 +593,11 @@ fn contiguous_sequential_take_range( } } + let filtered_len = filter.true_count(); + if ranks.len() > filtered_len { + vortex_bail!(OutOfBounds: ranks.len() - 1, 0, filtered_len); + } + Ok(Some((start, start + ranks.len()))) } @@ -578,11 +606,13 @@ where T: Copy + Default, P: IntegerPType, { + let filtered_len = filter.true_count(); + if let Some(start) = contiguous_filter_start(filter) { let mut out = BufferMut::::with_capacity(ranks.len()); let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); for (idx, rank) in ranks.iter().enumerate() { - let rank = validate_rank(*rank, filter.true_count())?; + let rank = validate_rank(*rank, filtered_len)?; // SAFETY: `out` has capacity for all ranks. The filter is contiguous with // `filtered_len` values starting at `start`, and `rank` was checked above. unsafe { out_ptr.add(idx).write(*values.get_unchecked(start + rank)) }; @@ -602,12 +632,13 @@ where let mut out = BufferMut::::with_capacity(ranks.len()); let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); for (idx, rank) in ranks.iter().enumerate() { + let rank = validate_rank(*rank, filtered_len)?; // SAFETY: `out` has capacity for all ranks. `rank` was bounds-checked against // `indices`, whose values are valid positions in `values`. 
unsafe { out_ptr .add(idx) - .write(*values.get_unchecked(*indices.get_unchecked(rank.as_()))) + .write(*values.get_unchecked(*indices.get_unchecked(rank))) }; } @@ -641,6 +672,10 @@ fn contiguous_filter_start(filter: &Mask) -> Option { (end - start == filter.true_count()).then_some(start) } +fn should_materialize_big_take(array: ArrayView<'_, Filter>, indices: &ArrayRef) -> bool { + indices.len() >= array.len() && array.len() >= BIG_TAKE_FALLBACK_LEN +} + impl TakeExecute for Filter { fn take( array: ArrayView<'_, Filter>, @@ -658,7 +693,7 @@ impl TakeExecute for Filter { vortex_bail!("Invalid indices dtype: {}", indices.dtype()) }; - if should_fallback_nullable_fixed_width_full_take(array, indices)? { + if should_materialize_big_take(array, indices) { return Ok(None); } @@ -686,7 +721,6 @@ mod tests { use crate::RecursiveCanonical; use crate::arrays::BoolArray; use crate::arrays::DecimalArray; - use crate::arrays::Dict; use crate::arrays::DictArray; use crate::arrays::FilterArray; use crate::arrays::FixedSizeListArray; @@ -724,7 +758,6 @@ mod tests { .execute_parent(&parent, 1, &mut ctx)? 
.expect("filter child should execute its take parent"); - assert!(result.as_opt::().is_some()); assert_arrays_eq!( result.execute::(&mut ctx)?.0, PrimitiveArray::from_option_iter([Some(40i32), None, Some(10)]).into_array() @@ -786,6 +819,94 @@ mod tests { Ok(()) } + fn assert_take_execute_rejects_out_of_bounds_rank( + child: crate::ArrayRef, + filter_mask: Mask, + codes: crate::ArrayRef, + ) -> VortexResult<()> { + let filter = FilterArray::new(child, filter_mask).into_array(); + let parent = DictArray::try_new(codes, filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + if let Err(err) = filter.execute_parent(&parent, 1, &mut ctx) { + assert!( + err.to_string().contains("out of bounds"), + "unexpected error: {err}" + ); + return Ok(()); + } + + panic!("out-of-bounds rank should fail"); + } + + #[test] + fn test_take_execute_kernel_rejects_contiguous_sequential_rank_past_filter_len() + -> VortexResult<()> { + assert_take_execute_rejects_out_of_bounds_rank( + buffer![10i32, 20, 30, 40, 50].into_array(), + Mask::from_slices(5, vec![(1, 4)]), + buffer![0u64, 1, 2, 3].into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_rejects_random_mask_rank_past_filter_len() -> VortexResult<()> { + assert_take_execute_rejects_out_of_bounds_rank( + buffer![10i32, 20, 30, 40, 50].into_array(), + Mask::from_indices(5, vec![1, 3, 4]), + buffer![2u64, 3].into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_rejects_translated_rank_past_filter_len() -> VortexResult<()> { + assert_take_execute_rejects_out_of_bounds_rank( + ListArray::try_new( + buffer![10u32, 11, 20, 30, 31, 32, 40, 50, 51].into_array(), + buffer![0u32, 2, 3, 6, 7, 9].into_array(), + Validity::NonNullable, + )? 
+ .into_array(), + Mask::from_indices(5, vec![0, 2, 4]), + buffer![0u64, 3].into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_handles_empty_sequential_take() -> VortexResult<()> { + let filter = FilterArray::new( + ListArray::try_new( + buffer![10u32, 11, 20, 30, 31, 32, 40, 50, 51].into_array(), + buffer![0u32, 2, 3, 6, 7, 9].into_array(), + Validity::NonNullable, + )? + .into_array(), + Mask::from_indices(5, vec![0, 2, 4]), + ) + .into_array(); + let parent = DictArray::try_new( + PrimitiveArray::from_iter(std::iter::empty::()).into_array(), + filter.clone(), + )? + .into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter + .execute_parent(&parent, 1, &mut ctx)? + .expect("filter child should execute its take parent"); + + assert_arrays_eq!( + result.execute::(&mut ctx)?.0, + ListArray::try_new( + PrimitiveArray::from_iter(std::iter::empty::()).into_array(), + buffer![0u32].into_array(), + Validity::NonNullable, + )? + .into_array() + ); + Ok(()) + } + fn assert_take_execute_maps_child_dtype( child: crate::ArrayRef, expected: crate::ArrayRef, @@ -821,85 +942,97 @@ mod tests { Ok(()) } - fn execute_large_nullable_fixed_width_take( - child: crate::ArrayRef, + fn execute_large_primitive_full_take( + filter_mask: Mask, ) -> VortexResult> { - let filter = - FilterArray::new(child, Mask::from_iter([true, false, true, true, false])).into_array(); - let indices = PrimitiveArray::from_option_iter( - (0..=super::NULLABLE_FIXED_WIDTH_FULL_TAKE_FALLBACK_LEN) - .map(|idx| Some((idx % 3) as u64)), + let child_len = filter_mask.len(); + let child_len_u32 = u32::try_from(child_len)?; + let filter = FilterArray::new( + PrimitiveArray::from_iter(0..child_len_u32).into_array(), + filter_mask, ) .into_array(); - let parent = DictArray::try_new(indices, filter.clone())?.into_array(); + let indices = PrimitiveArray::from_iter((0..filter.len()).map(|idx| idx as u64)); + let parent = DictArray::try_new(indices.into_array(), 
filter.clone())?.into_array(); let mut ctx = ExecutionCtx::new(VortexSession::empty()); filter.execute_parent(&parent, 1, &mut ctx) } #[test] - fn test_take_execute_kernel_handles_large_nullable_primitive_take_without_child_nulls() - -> VortexResult<()> { - let result = - execute_large_nullable_fixed_width_take(buffer![10i32, 20, 30, 40, 50].into_array())?; + fn test_take_execute_kernel_materializes_large_full_take() -> VortexResult<()> { + let filtered_len = super::BIG_TAKE_FALLBACK_LEN; + let result = execute_large_primitive_full_take(Mask::from_indices( + filtered_len * 2, + (0..filtered_len).map(|idx| idx * 2), + ))?; - assert_eq!( - result - .expect("non-null fixed-width child should stay on the fast path") - .len(), - super::NULLABLE_FIXED_WIDTH_FULL_TAKE_FALLBACK_LEN + 1 - ); + assert!(result.is_none()); Ok(()) } #[test] - fn test_take_execute_kernel_handles_large_nullable_decimal_take_without_child_nulls() - -> VortexResult<()> { - let decimal_dtype = DecimalDType::new(19, 2); + fn test_take_execute_kernel_materializes_large_contiguous_full_take() -> VortexResult<()> { + let filtered_len = super::BIG_TAKE_FALLBACK_LEN; + let result = execute_large_primitive_full_take(Mask::from_slices( + filtered_len * 2, + vec![(filtered_len / 2, filtered_len / 2 + filtered_len)], + ))?; - let result = execute_large_nullable_fixed_width_take( - DecimalArray::new( - buffer![100i128, 200, 300, 400, 500], - decimal_dtype, - Validity::NonNullable, - ) - .into_array(), - )?; - - assert_eq!( - result - .expect("non-null fixed-width child should stay on the fast path") - .len(), - super::NULLABLE_FIXED_WIDTH_FULL_TAKE_FALLBACK_LEN + 1 - ); + assert!(result.is_none()); Ok(()) } #[test] - fn test_take_execute_kernel_falls_back_for_large_nullable_primitive_take_with_child_nulls() - -> VortexResult<()> { - let result = execute_large_nullable_fixed_width_take( + fn test_take_execute_kernel_handles_nullable_primitive_filter_child() -> VortexResult<()> { + let filter = FilterArray::new( 
PrimitiveArray::from_option_iter([Some(10i32), Some(20), None, Some(40), Some(50)]) .into_array(), - )?; + Mask::from_iter([true, false, true, true, false]), + ) + .into_array(); + let parent = + DictArray::try_new(buffer![2u64, 0, 1].into_array(), filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); - assert!(result.is_none()); + let result = filter.execute_parent(&parent, 1, &mut ctx)?; + + assert_arrays_eq!( + result + .expect("filter child should execute its take parent") + .execute::(&mut ctx)? + .0, + PrimitiveArray::from_option_iter([Some(40i32), Some(10), None]).into_array() + ); Ok(()) } #[test] - fn test_take_execute_kernel_falls_back_for_large_nullable_decimal_take_with_child_nulls() - -> VortexResult<()> { + fn test_take_execute_kernel_handles_nullable_decimal_filter_child() -> VortexResult<()> { let decimal_dtype = DecimalDType::new(19, 2); - let result = execute_large_nullable_fixed_width_take( + let filter = FilterArray::new( DecimalArray::from_option_iter( [Some(100i128), Some(200), None, Some(400), Some(500)], decimal_dtype, ) .into_array(), - )?; + Mask::from_iter([true, false, true, true, false]), + ) + .into_array(); + let parent = + DictArray::try_new(buffer![2u64, 0, 1].into_array(), filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); - assert!(result.is_none()); + let result = filter.execute_parent(&parent, 1, &mut ctx)?; + + assert_arrays_eq!( + result + .expect("filter child should execute its take parent") + .execute::(&mut ctx)? 
+ .0, + DecimalArray::from_option_iter([Some(400i128), Some(100), None], decimal_dtype) + .into_array() + ); Ok(()) } diff --git a/vortex-buffer/public-api.lock b/vortex-buffer/public-api.lock index 3080ee8116e..55dec017580 100644 --- a/vortex-buffer/public-api.lock +++ b/vortex-buffer/public-api.lock @@ -280,7 +280,7 @@ pub fn vortex_buffer::BitBuffer::new_with_offset(vortex_buffer::ByteBuffer, usiz pub fn vortex_buffer::BitBuffer::offset(&self) -> usize -pub fn vortex_buffer::BitBuffer::select(&self, usize) -> usize +pub fn vortex_buffer::BitBuffer::select(&self, usize) -> core::option::Option pub fn vortex_buffer::BitBuffer::set_indices(&self) -> arrow_buffer::util::bit_iterator::BitIndexIterator<'_> diff --git a/vortex-buffer/src/bit/buf.rs b/vortex-buffer/src/bit/buf.rs index 3d49cba0716..9b98d1cd5d6 100644 --- a/vortex-buffer/src/bit/buf.rs +++ b/vortex-buffer/src/bit/buf.rs @@ -325,11 +325,8 @@ impl BitBuffer { /// This is the "select" operation on a bitmap: given a rank `nth`, find /// which logical bit position holds that rank. /// - /// # Panics - /// - /// Panics (debug) or produces undefined results (release) if `nth` is - /// greater than or equal to [`true_count`](Self::true_count). - pub fn select(&self, nth: usize) -> usize { + /// Returns `None` if `nth` is greater than or equal to the number of set bits. + pub fn select(&self, nth: usize) -> Option { bit_select(self.buffer.as_slice(), self.offset, self.len, nth) } diff --git a/vortex-buffer/src/bit/select.rs b/vortex-buffer/src/bit/select.rs index 2647c080c22..f04983e6ad8 100644 --- a/vortex-buffer/src/bit/select.rs +++ b/vortex-buffer/src/bit/select.rs @@ -7,13 +7,14 @@ use super::count_ones::align_offset_len; /// `[offset, offset + len)` of the given byte slice. /// /// The returned position is relative to the logical start (i.e., 0-indexed from `offset`). +/// Returns `None` if `nth` is out of bounds. 
/// /// Uses architecture-specific optimizations: /// - **aarch64**: NEON `vcnt`-based popcount for the word-level scan. /// - **x86_64 + BMI2**: `pdep` + `tzcnt` for the final in-word select. /// - **Scalar fallback**: 4× unrolled word scan with `count_ones`, byte-level narrowing. #[inline] -pub fn bit_select(bytes: &[u8], offset: usize, len: usize, nth: usize) -> usize { +pub fn bit_select(bytes: &[u8], offset: usize, len: usize, nth: usize) -> Option { let (head, middle, tail) = align_offset_len(bytes, offset, len); let mut remaining = nth; let mut pos = 0usize; @@ -22,7 +23,7 @@ pub fn bit_select(bytes: &[u8], offset: usize, len: usize, nth: usize) -> usize if let Some(head) = head { let count = head.count_ones() as usize; if remaining < count { - return select_in_byte(head, remaining); + return Some(select_in_byte(head, remaining)); } remaining -= count; let start_bit = offset % 8; @@ -39,14 +40,14 @@ pub fn bit_select(bytes: &[u8], offset: usize, len: usize, nth: usize) -> usize if word_idx < words.len() { let word = u64::from_le_bytes(words[word_idx]); - return pos + select_in_word(word, remaining); + return Some(pos + select_in_word(word, remaining)); } // Remaining aligned bytes that don't fill a full u64. 
for &byte in tail_bytes { let count = byte.count_ones() as usize; if remaining < count { - return pos + select_in_byte(byte, remaining); + return Some(pos + select_in_byte(byte, remaining)); } remaining -= count; pos += 8; @@ -54,15 +55,13 @@ pub fn bit_select(bytes: &[u8], offset: usize, len: usize, nth: usize) -> usize } // ── partial last byte ─────────────────────────────────────────────── - if let Some(tail) = tail { - debug_assert!( - remaining < tail.count_ones() as usize, - "bit_select: nth={nth} out of bounds" - ); - return pos + select_in_byte(tail, remaining); + if let Some(tail) = tail + && remaining < tail.count_ones() as usize + { + return Some(pos + select_in_byte(tail, remaining)); } - unreachable!("bit_select: nth={nth} exceeds set bit count") + None } // ── Word-level scan ───────────────────────────────────────────────────── @@ -309,7 +308,7 @@ mod tests { // Every bit is set — select(n) == n. let buf = [0xFFu8; 16]; // 128 bits, all set for nth in 0..128 { - assert_eq!(bit_select(&buf, 0, 128, nth), nth, "nth={nth}"); + assert_eq!(bit_select(&buf, 0, 128, nth), Some(nth), "nth={nth}"); } } @@ -318,7 +317,7 @@ mod tests { // 0b01010101 repeated: bits 0,2,4,6 of each byte are set. let buf = [0x55u8; 16]; // 128 bits, 64 set for nth in 0..64 { - assert_eq!(bit_select(&buf, 0, 128, nth), nth * 2, "nth={nth}"); + assert_eq!(bit_select(&buf, 0, 128, nth), Some(nth * 2), "nth={nth}"); } } @@ -327,7 +326,15 @@ mod tests { // Only bit 42 is set. 
let mut buf = [0u8; 16]; buf[42 / 8] |= 1 << (42 % 8); - assert_eq!(bit_select(&buf, 0, 128, 0), 42); + assert_eq!(bit_select(&buf, 0, 128, 0), Some(42)); + } + + #[test] + fn test_select_out_of_bounds_returns_none() { + let buf = [0b0001_0100u8]; + assert_eq!(bit_select(&buf, 0, 8, 0), Some(2)); + assert_eq!(bit_select(&buf, 0, 8, 1), Some(4)); + assert_eq!(bit_select(&buf, 0, 8, 2), None); } #[rstest] @@ -360,7 +367,7 @@ mod tests { for (nth, &expected_pos) in expected.iter().enumerate() { assert_eq!( bit_select(&buf, offset, len, nth), - expected_pos, + Some(expected_pos), "offset={offset} len={len} nth={nth}" ); } @@ -379,6 +386,8 @@ mod tests { // Spot-check a few positions. let first = bit_select(&buf, 0, len, 0); let last = bit_select(&buf, 0, len, true_count - 1); + let first = first.expect("buffer has at least one set bit"); + let last = last.expect("true_count - 1 is in bounds"); assert!(first < len); assert!(last < len); assert!(first <= last); diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index 47c4a97281e..30c05f8f410 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ -458,7 +458,22 @@ impl Mask { return slices.last().map(|(_, end)| end - 1); } - (values.true_count != 0).then(|| values.buffer.select(values.true_count - 1)) + if values.true_count == 0 { + return None; + } + + Some( + values + .buffer + .select(values.true_count - 1) + .unwrap_or_else(|| { + vortex_panic!( + "Rank {} out of bounds for mask with true count {}", + values.true_count - 1, + values.true_count + ) + }), + ) } } } @@ -479,7 +494,13 @@ impl Mask { return indices[n]; } - values.buffer.select(n) + values.buffer.select(n).unwrap_or_else(|| { + vortex_panic!( + "Rank {} out of bounds for mask with true count {}", + n, + values.true_count + ) + }) } } }