diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index a790196e149..85750e6aee1 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -2764,6 +2764,10 @@ impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Extension pub fn vortex_array::arrays::Extension::take(vortex_array::ArrayView<'_, vortex_array::arrays::Extension>, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Filter + +pub fn vortex_array::arrays::Filter::take(vortex_array::ArrayView<'_, vortex_array::arrays::Filter>, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::FixedSizeList pub fn vortex_array::arrays::FixedSizeList::take(vortex_array::ArrayView<'_, vortex_array::arrays::FixedSizeList>, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -2980,6 +2984,10 @@ impl vortex_array::ValidityVTable for vortex_array pub fn vortex_array::arrays::Filter::validity(vortex_array::ArrayView<'_, vortex_array::arrays::Filter>) -> vortex_error::VortexResult +impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Filter + +pub fn vortex_array::arrays::Filter::take(vortex_array::ArrayView<'_, vortex_array::arrays::Filter>, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub struct vortex_array::arrays::filter::FilterData impl vortex_array::arrays::filter::FilterData @@ -6248,6 +6256,10 @@ impl vortex_array::ValidityVTable for vortex_array pub fn vortex_array::arrays::Filter::validity(vortex_array::ArrayView<'_, vortex_array::arrays::Filter>) -> vortex_error::VortexResult +impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Filter + +pub fn vortex_array::arrays::Filter::take(vortex_array::ArrayView<'_, vortex_array::arrays::Filter>, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub struct vortex_array::arrays::FixedSizeList impl core::clone::Clone for vortex_array::arrays::FixedSizeList diff --git a/vortex-array/src/arrays/filter/mod.rs b/vortex-array/src/arrays/filter/mod.rs index 39859bc5b25..15a987e8f20 100644 --- a/vortex-array/src/arrays/filter/mod.rs +++ b/vortex-array/src/arrays/filter/mod.rs @@ -16,6 +16,7 @@ pub use kernel::FilterReduce; pub use kernel::FilterReduceAdaptor; mod rules; +mod take; mod vtable; pub use vtable::Filter; diff --git a/vortex-array/src/arrays/filter/take.rs b/vortex-array/src/arrays/filter/take.rs new file mode 100644 index 00000000000..d589781119e --- /dev/null +++ b/vortex-array/src/arrays/filter/take.rs @@ -0,0 +1,1130 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_buffer::BitBuffer; +use vortex_buffer::Buffer; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_mask::AllOr; +use vortex_mask::Mask; + +use super::Filter; +use crate::ArrayRef; +use crate::IntoArray; +use crate::array::ArrayView; +use crate::arrays::ConstantArray; +use crate::arrays::DecimalArray; +use crate::arrays::PrimitiveArray; +use crate::arrays::decimal::DecimalArrayExt; +use crate::arrays::dict::TakeExecute; +use crate::arrays::dict::TakeExecuteAdaptor; +use crate::arrays::filter::FilterArrayExt; +use crate::builtins::ArrayBuiltins; +use crate::dtype::DType; +use crate::dtype::IntegerPType; +use crate::dtype::NativeDecimalType; +use crate::dtype::NativePType; +use crate::executor::ExecutionCtx; +use crate::kernel::ParentKernelSet; +use crate::match_each_decimal_value_type; +use crate::match_each_integer_ptype; +use crate::match_each_native_ptype; +use crate::scalar::Scalar; +use crate::validity::Validity; + +pub(super) const PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(Filter))]); + +const BIG_TAKE_FALLBACK_LEN: usize = 4096; + +fn take_impl( + array: ArrayView<'_, Filter>, + indices: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> VortexResult { + if array.child().dtype().is_primitive() + && let Some(taken) = take_primitive(array, indices, ctx)? + { + return Ok(taken); + } + + if array.child().dtype().is_decimal() + && let Some(taken) = take_decimal(array, indices, ctx)? + { + return Ok(taken); + } + + let indices_validity = indices.validity()?.execute_mask(indices.len(), ctx)?; + if indices_validity.all_true() + && let Some((start, end)) = + contiguous_sequential_take_range_indices(array.filter_mask(), indices)? + { + return array.child().slice(start..end); + } + + if indices_validity.all_true() + && let Some(take_len) = sequential_take_len(indices, array.len())? + { + if take_len == 0 { + return array.child().slice(0..0); + } + let rank_mask = Mask::from_slices(array.len(), vec![(0, take_len)]); + let mask = array.filter_mask().intersect_by_rank(&rank_mask); + return array.child().filter(mask); + } + + match indices_validity.bit_buffer() { + AllOr::All => { + let translated = translate_indices(array.filter_mask(), indices)?; + let translated_indices = + PrimitiveArray::new(translated, indices.validity()?).into_array(); + + array.child().take(translated_indices) + } + AllOr::None => Ok(ConstantArray::new( + Scalar::null(array.dtype().as_nullable()), + indices.len(), + ) + .into_array()), + AllOr::Some(b) => { + let translated = translate_nullable_indices(array.filter_mask(), indices, b)?; + let translated_indices = + PrimitiveArray::new(translated, indices.validity()?).into_array(); + + array.child().take(translated_indices) + } + } +} + +fn translate_nullable_indices( + filter: &Mask, + indices: &PrimitiveArray, + indices_validity: &BitBuffer, +) -> VortexResult> { + match_each_integer_ptype!(indices.ptype(), |P| { + translate_nullable_ranks(filter, indices.as_slice::

(), indices_validity) + }) +} + +fn translate_nullable_ranks( + filter: &Mask, + ranks: &[P], + indices_validity: &BitBuffer, +) -> VortexResult> { + let filtered_len = filter.true_count(); + if let Some(start) = contiguous_filter_start(filter) { + return translate_nullable_ranks_with_offset(ranks, indices_validity, filtered_len, start); + } + + match filter.indices() { + AllOr::All => translate_nullable_ranks_identity(ranks, indices_validity, filtered_len), + AllOr::None => unreachable!("empty filters are handled by take preconditions"), + AllOr::Some(filter_indices) => translate_nullable_ranks_with_indices( + ranks, + indices_validity, + filtered_len, + filter_indices, + ), + } +} + +fn translate_nullable_ranks_with_offset( + ranks: &[P], + indices_validity: &BitBuffer, + filtered_len: usize, + start: usize, +) -> VortexResult> { + let mut translated = BufferMut::::with_capacity(ranks.len()); + let translated_ptr = translated.spare_capacity_mut().as_mut_ptr().cast::(); + + for (idx, rank) in ranks.iter().enumerate() { + let translated_rank = if indices_validity.value(idx) { + let rank = validate_rank(*rank, filtered_len)?; + u64::try_from(start + rank)? + } else { + 0 + }; + + // SAFETY: `translated` has capacity for all ranks and this loop initializes each + // output slot once. + unsafe { translated_ptr.add(idx).write(translated_rank) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { translated.set_len(ranks.len()) }; + Ok(translated.freeze()) +} + +fn translate_nullable_ranks_identity( + ranks: &[P], + indices_validity: &BitBuffer, + filtered_len: usize, +) -> VortexResult> { + let mut translated = BufferMut::::with_capacity(ranks.len()); + let translated_ptr = translated.spare_capacity_mut().as_mut_ptr().cast::(); + + for (idx, rank) in ranks.iter().enumerate() { + let translated_rank = if indices_validity.value(idx) { + u64::try_from(validate_rank(*rank, filtered_len)?)? + } else { + 0 + }; + + // SAFETY: `translated` has capacity for all ranks and this loop initializes each + // output slot once. + unsafe { translated_ptr.add(idx).write(translated_rank) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { translated.set_len(ranks.len()) }; + Ok(translated.freeze()) +} + +fn translate_nullable_ranks_with_indices( + ranks: &[P], + indices_validity: &BitBuffer, + filtered_len: usize, + filter_indices: &[usize], +) -> VortexResult> { + let mut translated = BufferMut::::with_capacity(ranks.len()); + let translated_ptr = translated.spare_capacity_mut().as_mut_ptr().cast::(); + + for (idx, rank) in ranks.iter().enumerate() { + let translated_rank = if indices_validity.value(idx) { + let rank = validate_rank(*rank, filtered_len)?; + // SAFETY: `rank` was checked against the filtered length, so it is in bounds for + // `filter_indices`; filter indices are valid child positions by construction. + unsafe { u64::try_from(*filter_indices.get_unchecked(rank))? } + } else { + 0 + }; + + // SAFETY: `translated` has capacity for all ranks and this loop initializes each + // output slot once. + unsafe { translated_ptr.add(idx).write(translated_rank) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { translated.set_len(ranks.len()) }; + Ok(translated.freeze()) +} + +fn translate_indices(filter: &Mask, indices: &PrimitiveArray) -> VortexResult> { + match_each_integer_ptype!(indices.ptype(), |P| { + translate_ranks(filter, indices.as_slice::

()) + }) +} + +fn contiguous_sequential_take_range_indices( + filter: &Mask, + indices: &PrimitiveArray, +) -> VortexResult> { + match_each_integer_ptype!(indices.ptype(), |P| { + contiguous_sequential_take_range(filter, indices.as_slice::

()) + }) +} + +fn sequential_take_len( + indices: &PrimitiveArray, + filtered_len: usize, +) -> VortexResult> { + match_each_integer_ptype!(indices.ptype(), |P| { + sequential_take_len_typed(indices.as_slice::

(), filtered_len) + }) +} + +fn sequential_take_len_typed( + ranks: &[P], + filtered_len: usize, +) -> VortexResult> { + for (idx, rank) in ranks.iter().enumerate() { + if rank.as_() != idx { + return Ok(None); + } + } + + if ranks.len() > filtered_len { + vortex_bail!(OutOfBounds: ranks.len() - 1, 0, filtered_len); + } + + Ok(Some(ranks.len())) +} + +fn translate_ranks(filter: &Mask, ranks: &[P]) -> VortexResult> { + let mut translated = BufferMut::::with_capacity(ranks.len()); + let translated_ptr = translated.spare_capacity_mut().as_mut_ptr().cast::(); + let filtered_len = filter.true_count(); + + if let Some(start) = contiguous_filter_start(filter) { + for (idx, rank) in ranks.iter().enumerate() { + let rank = validate_rank(*rank, filtered_len)?; + // SAFETY: `translated` has capacity for all ranks and this loop initializes each + // output slot once. + unsafe { translated_ptr.add(idx).write(u64::try_from(start + rank)?) }; + } + } else { + let filter_indices = match filter.indices() { + AllOr::All => { + for (idx, rank) in ranks.iter().enumerate() { + let rank = validate_rank(*rank, filtered_len)?; + // SAFETY: `translated` has capacity for all ranks and this loop initializes + // each output slot once. + unsafe { translated_ptr.add(idx).write(u64::try_from(rank)?) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { translated.set_len(ranks.len()) }; + return Ok(translated.freeze()); + } + AllOr::None => unreachable!("empty filters are handled by take preconditions"), + AllOr::Some(filter_indices) => filter_indices, + }; + + for (idx, rank) in ranks.iter().enumerate() { + let rank = validate_rank(*rank, filtered_len)?; + // SAFETY: `translated` has capacity for all ranks. `rank` was checked against the + // filtered length, and filter indices are valid child positions by construction. + unsafe { + translated_ptr + .add(idx) + .write(u64::try_from(*filter_indices.get_unchecked(rank))?) + }; + } + } + + // SAFETY: Each loop path writes exactly `ranks.len()` initialized values. + unsafe { translated.set_len(ranks.len()) }; + Ok(translated.freeze()) +} + +fn validate_rank(rank: P, filtered_len: usize) -> VortexResult { + let rank: usize = rank.as_(); + if rank >= filtered_len { + vortex_bail!(OutOfBounds: rank, 0, filtered_len); + } + Ok(rank) +} + +fn take_primitive( + array: ArrayView<'_, Filter>, + indices: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + let child = array.child().clone().execute::(ctx)?; + match_each_native_ptype!(child.ptype(), |T| { + match_each_integer_ptype!(indices.ptype(), |P| { + take_primitive_typed::(child, array.filter_mask(), indices, ctx).map(Some) + }) + }) +} + +fn take_decimal( + array: ArrayView<'_, Filter>, + indices: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + let child = array.child().clone().execute::(ctx)?; + match_each_decimal_value_type!(child.values_type(), |T| { + match_each_integer_ptype!(indices.ptype(), |P| { + take_decimal_typed::(child, array.filter_mask(), indices, ctx).map(Some) + }) + }) +} + +fn take_decimal_typed( + child: DecimalArray, + filter: &Mask, + indices: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> VortexResult +where + T: NativeDecimalType, + P: IntegerPType, +{ + let ranks = indices.as_slice::

(); + let decimal_dtype = child.decimal_dtype(); + let child_validity = child.validity()?; + let indices_validity = indices.validity()?.execute_mask(indices.len(), ctx)?; + + let taken = if indices_validity.all_true() { + if let Some((start, end)) = contiguous_sequential_take_range(filter, ranks)? { + let values = child.buffer_handle().slice_typed::(start..end); + let output_validity = contiguous_output_validity(&child_validity, indices, start..end)?; + // SAFETY: The values are sliced from an existing valid decimal array, and the output + // validity was built for exactly the sliced take length. + return Ok(unsafe { + DecimalArray::new_unchecked_handle( + values, + T::DECIMAL_TYPE, + decimal_dtype, + output_validity, + ) + } + .into_array()); + } + + take_filtered_values::(child.buffer::().as_slice(), filter, ranks)? + } else { + take_filtered_values_nullable::( + child.buffer::().as_slice(), + filter, + ranks, + &indices_validity, + )? + }; + let output_validity = + take_output_validity(&child_validity, filter, indices, &indices_validity)?; + // SAFETY: Valid ranks copy existing decimal values, null ranks write default placeholders that + // are hidden by output validity, and the output validity was built for the take length. + Ok(unsafe { DecimalArray::new_unchecked(taken, decimal_dtype, output_validity) }.into_array()) +} + +fn take_primitive_typed( + child: PrimitiveArray, + filter: &Mask, + indices: &PrimitiveArray, + ctx: &mut ExecutionCtx, +) -> VortexResult +where + T: NativePType, + P: IntegerPType, +{ + let ranks = indices.as_slice::

(); + let child_validity = child.validity()?; + + let indices_validity = indices.validity()?.execute_mask(indices.len(), ctx)?; + if indices_validity.all_true() { + return take_primitive_all_valid::(child, filter, indices, ranks, &child_validity); + } + + let taken = take_filtered_values_nullable::( + child.as_slice::(), + filter, + ranks, + &indices_validity, + )?; + let output_validity = + take_output_validity(&child_validity, filter, indices, &indices_validity)?; + + Ok(PrimitiveArray::new(taken, output_validity).into_array()) +} + +fn take_primitive_all_valid( + child: PrimitiveArray, + filter: &Mask, + indices: &PrimitiveArray, + ranks: &[P], + child_validity: &Validity, +) -> VortexResult +where + T: NativePType, + P: IntegerPType, +{ + if let Some((start, end)) = contiguous_sequential_take_range(filter, ranks)? { + let output_validity = contiguous_output_validity(child_validity, indices, start..end)?; + return Ok(PrimitiveArray::from_buffer_handle( + child.buffer_handle().slice_typed::(start..end), + T::PTYPE, + output_validity, + ) + .into_array()); + } + + let taken = take_filtered_values::(child.as_slice::(), filter, ranks)?; + let output_validity = take_output_validity( + child_validity, + filter, + indices, + &Mask::new_true(indices.len()), + )?; + Ok(PrimitiveArray::new(taken, output_validity).into_array()) +} + +fn contiguous_output_validity( + child_validity: &Validity, + indices: &PrimitiveArray, + range: std::ops::Range, +) -> VortexResult { + if child_validity.no_nulls() { + return indices.validity(); + } + + child_validity.slice(range) +} + +fn take_output_validity( + child_validity: &Validity, + filter: &Mask, + indices: &PrimitiveArray, + indices_validity: &Mask, +) -> VortexResult { + if child_validity.no_nulls() { + return indices.validity(); + } + + let translated_indices = match indices_validity.bit_buffer() { + AllOr::All => PrimitiveArray::new(translate_indices(filter, indices)?, indices.validity()?) + .into_array(), + AllOr::None => return Ok(Validity::AllInvalid), + AllOr::Some(b) => PrimitiveArray::new( + translate_nullable_indices(filter, indices, b)?, + indices.validity()?, + ) + .into_array(), + }; + + child_validity.take(&translated_indices) +} + +fn take_filtered_values_nullable( + values: &[T], + filter: &Mask, + ranks: &[P], + indices_validity: &Mask, +) -> VortexResult> +where + T: Copy + Default, + P: IntegerPType, +{ + if indices_validity.all_false() { + return Ok(Buffer::zeroed(ranks.len())); + } + + if let Some(start) = contiguous_filter_start(filter) { + let mut out = BufferMut::::with_capacity(ranks.len()); + let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); + for (idx, rank) in ranks.iter().enumerate() { + let value = if indices_validity.value(idx) { + let rank = validate_rank(*rank, filter.true_count())?; + // SAFETY: `rank` was checked against the contiguous filtered length. + unsafe { *values.get_unchecked(start + rank) } + } else { + T::default() + }; + + // SAFETY: `out` has capacity for all ranks and this loop initializes each output slot + // once. + unsafe { out_ptr.add(idx).write(value) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { out.set_len(ranks.len()) }; + return Ok(out.freeze()); + } + + let indices = match filter.indices() { + AllOr::All => { + return take_values_by_rank_nullable( + values, + ranks, + indices_validity, + filter.true_count(), + ); + } + AllOr::None => unreachable!("empty filters are handled by take preconditions"), + AllOr::Some(indices) => indices, + }; + + let mut out = BufferMut::::with_capacity(ranks.len()); + let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); + for (idx, rank) in ranks.iter().enumerate() { + let value = if indices_validity.value(idx) { + let rank = validate_rank(*rank, filter.true_count())?; + // SAFETY: `rank` was bounds-checked against `indices`, whose values are valid + // positions in `values`. + unsafe { *values.get_unchecked(*indices.get_unchecked(rank)) } + } else { + T::default() + }; + + // SAFETY: `out` has capacity for all ranks and this loop initializes each output slot + // once. + unsafe { out_ptr.add(idx).write(value) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { out.set_len(ranks.len()) }; + Ok(out.freeze()) +} + +fn take_values_by_rank_nullable( + values: &[T], + ranks: &[P], + indices_validity: &Mask, + filtered_len: usize, +) -> VortexResult> +where + T: Copy + Default, + P: IntegerPType, +{ + let mut out = BufferMut::::with_capacity(ranks.len()); + let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); + for (idx, rank) in ranks.iter().enumerate() { + let value = if indices_validity.value(idx) { + let rank = validate_rank(*rank, filtered_len)?; + // SAFETY: `rank` was bounds-checked. + unsafe { *values.get_unchecked(rank) } + } else { + T::default() + }; + + // SAFETY: `out` has capacity for all ranks and this loop initializes each output slot + // once. + unsafe { out_ptr.add(idx).write(value) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { out.set_len(ranks.len()) }; + Ok(out.freeze()) +} + +fn contiguous_sequential_take_range( + filter: &Mask, + ranks: &[P], +) -> VortexResult> { + let Some(start) = contiguous_filter_start(filter) else { + return Ok(None); + }; + + for (idx, rank) in ranks.iter().enumerate() { + if rank.as_() != idx { + return Ok(None); + } + } + + let filtered_len = filter.true_count(); + if ranks.len() > filtered_len { + vortex_bail!(OutOfBounds: ranks.len() - 1, 0, filtered_len); + } + + Ok(Some((start, start + ranks.len()))) +} + +fn take_filtered_values(values: &[T], filter: &Mask, ranks: &[P]) -> VortexResult> +where + T: Copy + Default, + P: IntegerPType, +{ + let filtered_len = filter.true_count(); + + if let Some(start) = contiguous_filter_start(filter) { + let mut out = BufferMut::::with_capacity(ranks.len()); + let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); + for (idx, rank) in ranks.iter().enumerate() { + let rank = validate_rank(*rank, filtered_len)?; + // SAFETY: `out` has capacity for all ranks. The filter is contiguous with + // `filtered_len` values starting at `start`, and `rank` was checked above. + unsafe { out_ptr.add(idx).write(*values.get_unchecked(start + rank)) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { out.set_len(ranks.len()) }; + return Ok(out.freeze()); + } + + let indices = match filter.indices() { + AllOr::All => return take_values_by_rank(values, ranks), + AllOr::None => unreachable!("empty filters are handled by take preconditions"), + AllOr::Some(indices) => indices, + }; + + let mut out = BufferMut::::with_capacity(ranks.len()); + let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); + for (idx, rank) in ranks.iter().enumerate() { + let rank = validate_rank(*rank, filtered_len)?; + // SAFETY: `out` has capacity for all ranks. `rank` was bounds-checked against + // `indices`, whose values are valid positions in `values`. + unsafe { + out_ptr + .add(idx) + .write(*values.get_unchecked(*indices.get_unchecked(rank))) + }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { out.set_len(ranks.len()) }; + Ok(out.freeze()) +} + +fn take_values_by_rank(values: &[T], ranks: &[P]) -> VortexResult> +where + T: Copy + Default, + P: IntegerPType, +{ + let mut out = BufferMut::::with_capacity(ranks.len()); + let out_ptr = out.spare_capacity_mut().as_mut_ptr().cast::(); + for (idx, rank) in ranks.iter().enumerate() { + let rank = validate_rank(*rank, values.len())?; + + // SAFETY: `out` has capacity for all ranks and `rank` was bounds-checked. + unsafe { out_ptr.add(idx).write(*values.get_unchecked(rank)) }; + } + + // SAFETY: The loop writes exactly `ranks.len()` initialized values. + unsafe { out.set_len(ranks.len()) }; + Ok(out.freeze()) +} + +fn contiguous_filter_start(filter: &Mask) -> Option { + let start = filter.first()?; + let end = filter.last()?.checked_add(1)?; + (end - start == filter.true_count()).then_some(start) +} + +fn should_materialize_big_take(array: ArrayView<'_, Filter>, indices: &ArrayRef) -> bool { + indices.len() >= array.len() && array.len() >= BIG_TAKE_FALLBACK_LEN +} + +impl TakeExecute for Filter { + fn take( + array: ArrayView<'_, Filter>, + indices: &ArrayRef, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + // Bool filtering is already very cheap. Translating take indices through the filter adds + // overhead without improving the downstream bool take, so leave bool children on the + // regular filter path. + if array.child().dtype().is_boolean() { + return Ok(None); + } + + let DType::Primitive(ptype, nullability) = indices.dtype() else { + vortex_bail!("Invalid indices dtype: {}", indices.dtype()) + }; + + if should_materialize_big_take(array, indices) { + return Ok(None); + } + + let unsigned_indices = if ptype.is_unsigned_int() { + indices.clone().execute::(ctx)? + } else { + indices + .clone() + .cast(DType::Primitive(ptype.to_unsigned(), *nullability))? + .execute::(ctx)? + }; + + take_impl(array, &unsigned_indices, ctx).map(Some) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + use vortex_mask::Mask; + use vortex_session::VortexSession; + + use crate::IntoArray; + use crate::RecursiveCanonical; + use crate::arrays::BoolArray; + use crate::arrays::DecimalArray; + use crate::arrays::DictArray; + use crate::arrays::FilterArray; + use crate::arrays::FixedSizeListArray; + use crate::arrays::ListArray; + use crate::arrays::Primitive; + use crate::arrays::PrimitiveArray; + use crate::arrays::StructArray; + use crate::arrays::VarBinViewArray; + use crate::assert_arrays_eq; + use crate::dtype::DecimalDType; + use crate::dtype::FieldNames; + use crate::executor::ExecutionCtx; + use crate::validity::Validity; + + #[test] + fn test_take_execute_kernel_maps_indices_through_filter() -> VortexResult<()> { + let filter = FilterArray::new( + PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40), None]) + .into_array(), + Mask::from_iter([true, false, true, true, false]), + ) + .into_array(); + let parent = DictArray::try_new( + PrimitiveArray::new( + buffer![2u64, 100, 0], + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + ) + .into_array(), + filter.clone(), + )? + .into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter + .execute_parent(&parent, 1, &mut ctx)? + .expect("filter child should execute its take parent"); + + assert_arrays_eq!( + result.execute::(&mut ctx)?.0, + PrimitiveArray::from_option_iter([Some(40i32), None, Some(10)]).into_array() + ); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_nullable_fast_path_maps_indices_through_filter() -> VortexResult<()> + { + let filter = FilterArray::new( + buffer![10i32, 20, 30, 40, 50].into_array(), + Mask::from_slices(5, vec![(1, 4)]), + ) + .into_array(); + let parent = DictArray::try_new( + PrimitiveArray::new( + buffer![2u64, 100, 0], + Validity::Array(BoolArray::from_iter([true, false, true]).into_array()), + ) + .into_array(), + filter.clone(), + )? + .into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter + .execute_parent(&parent, 1, &mut ctx)? + .expect("filter child should execute its take parent"); + + assert!(result.as_opt::().is_some()); + assert_arrays_eq!( + result.execute::(&mut ctx)?.0, + PrimitiveArray::from_option_iter([Some(40i32), None, Some(20)]).into_array() + ); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_fast_path_maps_indices_through_filter() -> VortexResult<()> { + let filter = FilterArray::new( + buffer![10i32, 20, 30, 40, 50, 60].into_array(), + Mask::from_indices(6, vec![1, 3, 4, 5]), + ) + .into_array(); + let parent = + DictArray::try_new(buffer![2u64, 0, 3].into_array(), filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter + .execute_parent(&parent, 1, &mut ctx)? + .expect("filter child should execute its take parent"); + + assert!(result.as_opt::().is_some()); + assert_arrays_eq!( + result.execute::(&mut ctx)?.0, + PrimitiveArray::from_iter([50i32, 20, 60]).into_array() + ); + Ok(()) + } + + fn assert_take_execute_rejects_out_of_bounds_rank( + child: crate::ArrayRef, + filter_mask: Mask, + codes: crate::ArrayRef, + ) -> VortexResult<()> { + let filter = FilterArray::new(child, filter_mask).into_array(); + let parent = DictArray::try_new(codes, filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + if let Err(err) = filter.execute_parent(&parent, 1, &mut ctx) { + assert!( + err.to_string().contains("out of bounds"), + "unexpected error: {err}" + ); + return Ok(()); + } + + panic!("out-of-bounds rank should fail"); + } + + #[test] + fn test_take_execute_kernel_rejects_contiguous_sequential_rank_past_filter_len() + -> VortexResult<()> { + assert_take_execute_rejects_out_of_bounds_rank( + buffer![10i32, 20, 30, 40, 50].into_array(), + Mask::from_slices(5, vec![(1, 4)]), + buffer![0u64, 1, 2, 3].into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_rejects_random_mask_rank_past_filter_len() -> VortexResult<()> { + assert_take_execute_rejects_out_of_bounds_rank( + buffer![10i32, 20, 30, 40, 50].into_array(), + Mask::from_indices(5, vec![1, 3, 4]), + buffer![2u64, 3].into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_rejects_translated_rank_past_filter_len() -> VortexResult<()> { + assert_take_execute_rejects_out_of_bounds_rank( + ListArray::try_new( + buffer![10u32, 11, 20, 30, 31, 32, 40, 50, 51].into_array(), + buffer![0u32, 2, 3, 6, 7, 9].into_array(), + Validity::NonNullable, + )? + .into_array(), + Mask::from_indices(5, vec![0, 2, 4]), + buffer![0u64, 3].into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_handles_empty_sequential_take() -> VortexResult<()> { + let filter = FilterArray::new( + ListArray::try_new( + buffer![10u32, 11, 20, 30, 31, 32, 40, 50, 51].into_array(), + buffer![0u32, 2, 3, 6, 7, 9].into_array(), + Validity::NonNullable, + )? + .into_array(), + Mask::from_indices(5, vec![0, 2, 4]), + ) + .into_array(); + let parent = DictArray::try_new( + PrimitiveArray::from_iter(std::iter::empty::()).into_array(), + filter.clone(), + )? + .into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter + .execute_parent(&parent, 1, &mut ctx)? + .expect("filter child should execute its take parent"); + + assert_arrays_eq!( + result.execute::(&mut ctx)?.0, + ListArray::try_new( + PrimitiveArray::from_iter(std::iter::empty::()).into_array(), + buffer![0u32].into_array(), + Validity::NonNullable, + )? + .into_array() + ); + Ok(()) + } + + fn assert_take_execute_maps_child_dtype( + child: crate::ArrayRef, + expected: crate::ArrayRef, + ) -> VortexResult<()> { + let filter = + FilterArray::new(child, Mask::from_iter([true, false, true, true, false])).into_array(); + let parent = + DictArray::try_new(buffer![2u64, 0, 1].into_array(), filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter + .execute_parent(&parent, 1, &mut ctx)? + .expect("filter child should execute its take parent"); + + assert_arrays_eq!(result.execute::(&mut ctx)?.0, expected); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_skips_bool_filter_child() -> VortexResult<()> { + let filter = FilterArray::new( + BoolArray::from_iter([true, false, true, true, false]).into_array(), + Mask::from_iter([true, false, true, true, false]), + ) + .into_array(); + let parent = + DictArray::try_new(buffer![2u64, 0, 1].into_array(), filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter.execute_parent(&parent, 1, &mut ctx)?; + + assert!(result.is_none()); + Ok(()) + } + + fn execute_large_primitive_full_take( + filter_mask: Mask, + ) -> VortexResult> { + let child_len = filter_mask.len(); + let child_len_u32 = u32::try_from(child_len)?; + let filter = FilterArray::new( + PrimitiveArray::from_iter(0..child_len_u32).into_array(), + filter_mask, + ) + .into_array(); + let indices = PrimitiveArray::from_iter((0..filter.len()).map(|idx| idx as u64)); + let parent = DictArray::try_new(indices.into_array(), filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + filter.execute_parent(&parent, 1, &mut ctx) + } + + #[test] + fn test_take_execute_kernel_materializes_large_full_take() -> VortexResult<()> { + let filtered_len = super::BIG_TAKE_FALLBACK_LEN; + let result = execute_large_primitive_full_take(Mask::from_indices( + filtered_len * 2, + (0..filtered_len).map(|idx| idx * 2), + ))?; + + assert!(result.is_none()); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_materializes_large_contiguous_full_take() -> VortexResult<()> { + let filtered_len = super::BIG_TAKE_FALLBACK_LEN; + let result = execute_large_primitive_full_take(Mask::from_slices( + filtered_len * 2, + vec![(filtered_len / 2, filtered_len / 2 + filtered_len)], + ))?; + + assert!(result.is_none()); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_handles_nullable_primitive_filter_child() -> VortexResult<()> { + let filter = FilterArray::new( + PrimitiveArray::from_option_iter([Some(10i32), Some(20), None, Some(40), Some(50)]) + .into_array(), + Mask::from_iter([true, false, true, true, false]), + ) + .into_array(); + let parent = + DictArray::try_new(buffer![2u64, 0, 1].into_array(), filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter.execute_parent(&parent, 1, &mut ctx)?; + + assert_arrays_eq!( + result + .expect("filter child should execute its take parent") + .execute::(&mut ctx)? + .0, + PrimitiveArray::from_option_iter([Some(40i32), Some(10), None]).into_array() + ); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_handles_nullable_decimal_filter_child() -> VortexResult<()> { + let decimal_dtype = DecimalDType::new(19, 2); + let filter = FilterArray::new( + DecimalArray::from_option_iter( + [Some(100i128), Some(200), None, Some(400), Some(500)], + decimal_dtype, + ) + .into_array(), + Mask::from_iter([true, false, true, true, false]), + ) + .into_array(); + let parent = + DictArray::try_new(buffer![2u64, 0, 1].into_array(), filter.clone())?.into_array(); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let result = filter.execute_parent(&parent, 1, &mut ctx)?; + + assert_arrays_eq!( + result + .expect("filter child should execute its take parent") + .execute::(&mut ctx)? + .0, + DecimalArray::from_option_iter([Some(400i128), Some(100), None], decimal_dtype) + .into_array() + ); + Ok(()) + } + + #[test] + fn test_take_execute_kernel_handles_decimal_filter_child() -> VortexResult<()> { + let decimal_dtype = DecimalDType::new(19, 2); + + assert_take_execute_maps_child_dtype( + DecimalArray::new( + buffer![100i128, 200, 300, 400, 500], + decimal_dtype, + Validity::NonNullable, + ) + .into_array(), + DecimalArray::new( + buffer![400i128, 100, 300], + decimal_dtype, + Validity::NonNullable, + ) + .into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_handles_fixed_size_list_filter_child() -> VortexResult<()> { + assert_take_execute_maps_child_dtype( + FixedSizeListArray::new( + buffer![10u32, 11, 20, 21, 30, 31, 40, 41, 50, 51].into_array(), + 2, + Validity::NonNullable, + 5, + ) + .into_array(), + FixedSizeListArray::new( + buffer![40u32, 41, 10, 11, 30, 31].into_array(), + 2, + Validity::NonNullable, + 3, + ) + .into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_handles_list_filter_child() -> VortexResult<()> { + assert_take_execute_maps_child_dtype( + ListArray::try_new( + buffer![10u32, 11, 20, 30, 31, 32, 40, 50, 51].into_array(), + buffer![0u32, 2, 3, 6, 7, 9].into_array(), + Validity::NonNullable, + )? + .into_array(), + ListArray::try_new( + buffer![40u32, 10, 11, 30, 31, 32].into_array(), + buffer![0u32, 1, 3, 6].into_array(), + Validity::NonNullable, + )? + .into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_handles_string_filter_child() -> VortexResult<()> { + assert_take_execute_maps_child_dtype( + VarBinViewArray::from_iter_str(["a", "b", "c", "d", "e"]).into_array(), + VarBinViewArray::from_iter_str(["d", "a", "c"]).into_array(), + ) + } + + #[test] + fn test_take_execute_kernel_handles_struct_filter_child() -> VortexResult<()> { + assert_take_execute_maps_child_dtype( + StructArray::try_new( + FieldNames::from(["id", "value"]), + vec![ + buffer![10u32, 20, 30, 40, 50].into_array(), + buffer![100u64, 200, 300, 400, 500].into_array(), + ], + 5, + Validity::NonNullable, + )? + .into_array(), + StructArray::try_new( + FieldNames::from(["id", "value"]), + vec![ + buffer![40u32, 10, 30].into_array(), + buffer![400u64, 100, 300].into_array(), + ], + 3, + Validity::NonNullable, + )? + .into_array(), + ) + } +} diff --git a/vortex-array/src/arrays/filter/vtable.rs b/vortex-array/src/arrays/filter/vtable.rs index 9f70e6e2614..312ed536f77 100644 --- a/vortex-array/src/arrays/filter/vtable.rs +++ b/vortex-array/src/arrays/filter/vtable.rs @@ -34,6 +34,7 @@ use crate::arrays::filter::execute::execute_filter; use crate::arrays::filter::execute::execute_filter_fast_paths; use crate::arrays::filter::rules::PARENT_RULES; use crate::arrays::filter::rules::RULES; +use crate::arrays::filter::take::PARENT_KERNELS; use crate::buffer::BufferHandle; use crate::dtype::DType; use crate::executor::ExecutionCtx; @@ -170,6 +171,15 @@ impl VTable for Filter { PARENT_RULES.evaluate(array, parent, child_idx) } + fn execute_parent( + array: ArrayView<'_, Self>, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } + fn reduce(array: ArrayView<'_, Self>) -> VortexResult> { RULES.evaluate(array) } diff --git a/vortex-buffer/public-api.lock b/vortex-buffer/public-api.lock index 4e0b1190a7a..55dec017580 100644 --- a/vortex-buffer/public-api.lock +++ b/vortex-buffer/public-api.lock @@ -280,6 +280,8 @@ pub fn vortex_buffer::BitBuffer::new_with_offset(vortex_buffer::ByteBuffer, usiz pub fn vortex_buffer::BitBuffer::offset(&self) -> usize +pub fn vortex_buffer::BitBuffer::select(&self, usize) -> core::option::Option + pub fn vortex_buffer::BitBuffer::set_indices(&self) -> arrow_buffer::util::bit_iterator::BitIndexIterator<'_> pub fn vortex_buffer::BitBuffer::set_slices(&self) -> arrow_buffer::util::bit_iterator::BitSliceIterator<'_> diff --git a/vortex-buffer/src/bit/buf.rs b/vortex-buffer/src/bit/buf.rs index d9c30ea1917..9b98d1cd5d6 100644 --- a/vortex-buffer/src/bit/buf.rs +++ b/vortex-buffer/src/bit/buf.rs @@ -25,6 +25,7 @@ use crate::bit::count_ones::count_ones; use crate::bit::get_bit_unchecked; use crate::bit::ops::bitwise_binary_op; use crate::bit::ops::bitwise_unary_op; +use crate::bit::select::bit_select; use crate::buffer; /// An immutable bitset stored as a packed byte buffer. @@ -319,6 +320,16 @@ impl BitBuffer { count_ones(self.buffer.as_slice(), self.offset, self.len) } + /// Returns the position of the `nth` set bit (0-indexed). + /// + /// This is the "select" operation on a bitmap: given a rank `nth`, find + /// which logical bit position holds that rank. + /// + /// Returns `None` if `nth` is greater than or equal to the number of set bits. + pub fn select(&self, nth: usize) -> Option { + bit_select(self.buffer.as_slice(), self.offset, self.len, nth) + } + /// Get the number of unset bits in the buffer. pub fn false_count(&self) -> usize { self.len - self.true_count() diff --git a/vortex-buffer/src/bit/count_ones.rs b/vortex-buffer/src/bit/count_ones.rs index 6d70d47cfa7..df5844a2914 100644 --- a/vortex-buffer/src/bit/count_ones.rs +++ b/vortex-buffer/src/bit/count_ones.rs @@ -22,7 +22,11 @@ pub fn count_ones(bytes: &[u8], offset: usize, len: usize) -> usize { } #[inline] -fn align_offset_len(bytes: &[u8], offset: usize, len: usize) -> (Option, &[u8], Option) { +pub(super) fn align_offset_len( + bytes: &[u8], + offset: usize, + len: usize, +) -> (Option, &[u8], Option) { let start_byte = offset / 8; let start_bit = offset % 8; let end_bit = offset + len; diff --git a/vortex-buffer/src/bit/mod.rs b/vortex-buffer/src/bit/mod.rs index 034be84a18c..37930d788b7 100644 --- a/vortex-buffer/src/bit/mod.rs +++ b/vortex-buffer/src/bit/mod.rs @@ -13,6 +13,7 @@ mod buf_mut; mod count_ones; mod macros; mod ops; +mod select; pub use arrow_buffer::bit_chunk_iterator::BitChunkIterator; pub use arrow_buffer::bit_chunk_iterator::BitChunks; diff --git a/vortex-buffer/src/bit/select.rs b/vortex-buffer/src/bit/select.rs new file mode 100644 index 00000000000..f04983e6ad8 --- /dev/null +++ b/vortex-buffer/src/bit/select.rs @@ -0,0 +1,428 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use super::count_ones::align_offset_len; + +/// Returns the position of the `nth` set bit (0-indexed) within the logical range +/// `[offset, offset + len)` of the given byte slice. +/// +/// The returned position is relative to the logical start (i.e., 0-indexed from `offset`). +/// Returns `None` if `nth` is out of bounds. +/// +/// Uses architecture-specific optimizations: +/// - **aarch64**: NEON `vcnt`-based popcount for the word-level scan. +/// - **x86_64 + BMI2**: `pdep` + `tzcnt` for the final in-word select. +/// - **Scalar fallback**: 4× unrolled word scan with `count_ones`, byte-level narrowing. +#[inline] +pub fn bit_select(bytes: &[u8], offset: usize, len: usize, nth: usize) -> Option { + let (head, middle, tail) = align_offset_len(bytes, offset, len); + let mut remaining = nth; + let mut pos = 0usize; + + // ── partial first byte ────────────────────────────────────────────── + if let Some(head) = head { + let count = head.count_ones() as usize; + if remaining < count { + return Some(select_in_byte(head, remaining)); + } + remaining -= count; + let start_bit = offset % 8; + pos = (8 - start_bit).min(len); + } + + // ── aligned middle bytes ──────────────────────────────────────────── + if !middle.is_empty() { + let (words, tail_bytes) = middle.as_chunks::<8>(); + + let (rem, new_pos, word_idx) = scan_words(words, remaining, pos); + remaining = rem; + pos = new_pos; + + if word_idx < words.len() { + let word = u64::from_le_bytes(words[word_idx]); + return Some(pos + select_in_word(word, remaining)); + } + + // Remaining aligned bytes that don't fill a full u64. + for &byte in tail_bytes { + let count = byte.count_ones() as usize; + if remaining < count { + return Some(pos + select_in_byte(byte, remaining)); + } + remaining -= count; + pos += 8; + } + } + + // ── partial last byte ─────────────────────────────────────────────── + if let Some(tail) = tail + && remaining < tail.count_ones() as usize + { + return Some(pos + select_in_byte(tail, remaining)); + } + + None +} + +// ── Word-level scan ───────────────────────────────────────────────────── + +/// Scan `words` accumulating popcounts. Returns `(remaining, position, word_index)`. +/// +/// If `word_index < words.len()`, the target bit is inside that word and `remaining` +/// is the rank *within* that word. Otherwise all words were consumed. +#[inline] +fn scan_words(words: &[[u8; 8]], remaining: usize, pos: usize) -> (usize, usize, usize) { + scan_words_impl(words, remaining, pos) +} + +// ── aarch64 NEON scan ─────────────────────────────────────────────────── + +#[cfg(target_arch = "aarch64")] +#[allow(clippy::cast_possible_truncation)] // u64 → usize is lossless on aarch64 (64-bit) +#[inline] +fn scan_words_impl( + words: &[[u8; 8]], + mut remaining: usize, + mut pos: usize, +) -> (usize, usize, usize) { + use std::arch::aarch64::vcntq_u8; + use std::arch::aarch64::vgetq_lane_u64; + use std::arch::aarch64::vld1q_u8; + use std::arch::aarch64::vpaddlq_u8; + use std::arch::aarch64::vpaddlq_u16; + use std::arch::aarch64::vpaddlq_u32; + + let mut idx = 0; + + // Process 4 u64 words at a time using two 128-bit NEON registers. + while idx + 4 <= words.len() { + let ptr = words[idx].as_ptr(); + // SAFETY: idx + 4 <= words.len() guarantees 32 contiguous bytes from ptr. + // NEON vld1q_u8 supports unaligned access. + let (count_0, count_1, count_2, count_3) = unsafe { + let pop_lo = vcntq_u8(vld1q_u8(ptr)); + let pop_hi = vcntq_u8(vld1q_u8(ptr.add(16))); + let sums_lo = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop_lo))); + let sums_hi = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop_hi))); + ( + vgetq_lane_u64::<0>(sums_lo) as usize, + vgetq_lane_u64::<1>(sums_lo) as usize, + vgetq_lane_u64::<0>(sums_hi) as usize, + vgetq_lane_u64::<1>(sums_hi) as usize, + ) + }; + + let total = count_0 + count_1 + count_2 + count_3; + if remaining >= total { + remaining -= total; + pos += 256; + idx += 4; + continue; + } + + // Narrow down to the exact word. + if remaining < count_0 { + return (remaining, pos, idx); + } + remaining -= count_0; + pos += 64; + if remaining < count_1 { + return (remaining, pos, idx + 1); + } + remaining -= count_1; + pos += 64; + if remaining < count_2 { + return (remaining, pos, idx + 2); + } + remaining -= count_2; + pos += 64; + return (remaining, pos, idx + 3); + } + + // Process pairs. + while idx + 2 <= words.len() { + let ptr = words[idx].as_ptr(); + // SAFETY: idx + 2 <= words.len() guarantees 16 contiguous bytes. + let (count_0, count_1) = unsafe { + let pop = vcntq_u8(vld1q_u8(ptr)); + let sums = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(pop))); + ( + vgetq_lane_u64::<0>(sums) as usize, + vgetq_lane_u64::<1>(sums) as usize, + ) + }; + let total = count_0 + count_1; + if remaining < total { + if remaining < count_0 { + return (remaining, pos, idx); + } + return (remaining - count_0, pos + 64, idx + 1); + } + remaining -= total; + pos += 128; + idx += 2; + } + + // Single trailing word. + if idx < words.len() { + let word = u64::from_le_bytes(words[idx]); + let count = word.count_ones() as usize; + if remaining < count { + return (remaining, pos, idx); + } + remaining -= count; + pos += 64; + idx += 1; + } + + (remaining, pos, idx) +} + +// ── Scalar scan (x86_64 / generic) ───────────────────────────────────── + +#[cfg(not(target_arch = "aarch64"))] +#[inline] +fn scan_words_impl( + words: &[[u8; 8]], + mut remaining: usize, + mut pos: usize, +) -> (usize, usize, usize) { + let mut idx = 0; + + // 4× unrolled: the four independent `count_ones` calls pipeline well. + while idx + 4 <= words.len() { + let count_0 = u64::from_le_bytes(words[idx]).count_ones() as usize; + let count_1 = u64::from_le_bytes(words[idx + 1]).count_ones() as usize; + let count_2 = u64::from_le_bytes(words[idx + 2]).count_ones() as usize; + let count_3 = u64::from_le_bytes(words[idx + 3]).count_ones() as usize; + let total = count_0 + count_1 + count_2 + count_3; + + if remaining >= total { + remaining -= total; + pos += 256; + idx += 4; + continue; + } + + if remaining < count_0 { + return (remaining, pos, idx); + } + remaining -= count_0; + pos += 64; + if remaining < count_1 { + return (remaining, pos, idx + 1); + } + remaining -= count_1; + pos += 64; + if remaining < count_2 { + return (remaining, pos, idx + 2); + } + remaining -= count_2; + pos += 64; + return (remaining, pos, idx + 3); + } + + while idx < words.len() { + let word = u64::from_le_bytes(words[idx]); + let count = word.count_ones() as usize; + if remaining < count { + return (remaining, pos, idx); + } + remaining -= count; + pos += 64; + idx += 1; + } + + (remaining, pos, idx) +} + +// ── In-word select ────────────────────────────────────────────────────── + +/// Position of the `nth` set bit inside a u64 (0-indexed, little-endian bit order). +#[inline] +fn select_in_word(word: u64, nth: usize) -> usize { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("bmi2") { + // SAFETY: runtime detection guarantees the required target feature. + return unsafe { select_in_word_bmi2(word, nth) }; + } + } + select_in_word_scalar(word, nth) +} + +/// BMI2: deposit a single bit at the nth set-bit position, then count trailing zeros. +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "bmi2")] +unsafe fn select_in_word_bmi2(word: u64, nth: usize) -> usize { + use std::arch::x86_64::_pdep_u64; + use std::arch::x86_64::_tzcnt_u64; + + use vortex_error::VortexExpect; + + usize::try_from(unsafe { _tzcnt_u64(_pdep_u64(1u64 << nth, word)) }) + .vortex_expect("safe to convert tzcnt result to usize") +} + +/// Scalar: narrow to the correct byte, then clear `nth` lowest set bits and trailing-zeros. +#[inline] +fn select_in_word_scalar(word: u64, mut nth: usize) -> usize { + let bytes = word.to_le_bytes(); + let mut bit_offset = 0usize; + for &byte in &bytes { + let count = byte.count_ones() as usize; + if nth < count { + return bit_offset + select_in_byte(byte, nth); + } + nth -= count; + bit_offset += 8; + } + unreachable!("select_in_word: nth exceeds popcount") +} + +// ── In-byte select ────────────────────────────────────────────────────── + +/// Position of the `nth` set bit inside a byte (0-indexed, LSB-first). +/// +/// Clears the lowest `nth` set bits, then uses `trailing_zeros`. +#[inline] +fn select_in_byte(byte: u8, nth: usize) -> usize { + debug_assert!(nth < byte.count_ones() as usize); + let mut bits = u32::from(byte); + for _ in 0..nth { + bits &= bits - 1; // clear lowest set bit + } + bits.trailing_zeros() as usize +} + +#[cfg(test)] +mod tests { + #![allow(clippy::cast_possible_truncation)] + + use rstest::rstest; + + use super::*; + + #[test] + fn test_select_all_set() { + // Every bit is set — select(n) == n. + let buf = [0xFFu8; 16]; // 128 bits, all set + for nth in 0..128 { + assert_eq!(bit_select(&buf, 0, 128, nth), Some(nth), "nth={nth}"); + } + } + + #[test] + fn test_select_every_other() { + // 0b01010101 repeated: bits 0,2,4,6 of each byte are set. + let buf = [0x55u8; 16]; // 128 bits, 64 set + for nth in 0..64 { + assert_eq!(bit_select(&buf, 0, 128, nth), Some(nth * 2), "nth={nth}"); + } + } + + #[test] + fn test_select_single_bit() { + // Only bit 42 is set. + let mut buf = [0u8; 16]; + buf[42 / 8] |= 1 << (42 % 8); + assert_eq!(bit_select(&buf, 0, 128, 0), Some(42)); + } + + #[test] + fn test_select_out_of_bounds_returns_none() { + let buf = [0b0001_0100u8]; + assert_eq!(bit_select(&buf, 0, 8, 0), Some(2)); + assert_eq!(bit_select(&buf, 0, 8, 1), Some(4)); + assert_eq!(bit_select(&buf, 0, 8, 2), None); + } + + #[rstest] + #[case(0, 128)] + #[case(3, 100)] + #[case(7, 50)] + #[case(1, 7)] + #[case(5, 5)] + #[case(0, 1)] + #[case(0, 64)] + #[case(1, 64)] + #[case(0, 65)] + #[case(3, 256)] + fn test_select_agrees_with_naive(#[case] offset: usize, #[case] len: usize) { + let total_bits = offset + len; + let total_bytes = total_bits.div_ceil(8); + // Deterministic pattern with moderate density. + let buf: Vec = (0..total_bytes) + .map(|i| ((i.wrapping_mul(0x9E) ^ 0xA5) & 0xFF) as u8) + .collect(); + + // Collect set-bit positions naively. + let expected: Vec = (0..len) + .filter(|&i| { + let phys = offset + i; + (buf[phys / 8] >> (phys % 8)) & 1 == 1 + }) + .collect(); + + for (nth, &expected_pos) in expected.iter().enumerate() { + assert_eq!( + bit_select(&buf, offset, len, nth), + Some(expected_pos), + "offset={offset} len={len} nth={nth}" + ); + } + } + + #[test] + fn test_select_large_buffer() { + // ~64 KB buffer, ~50% density. + let len = 65_536 * 8; + let buf: Vec = (0u32..65_536) + .map(|i| ((i.wrapping_mul(0x37) ^ 0xBC) & 0xFF) as u8) + .collect(); + + let true_count = buf.iter().map(|b| b.count_ones() as usize).sum::(); + + // Spot-check a few positions. + let first = bit_select(&buf, 0, len, 0); + let last = bit_select(&buf, 0, len, true_count - 1); + let first = first.expect("buffer has at least one set bit"); + let last = last.expect("true_count - 1 is in bounds"); + assert!(first < len); + assert!(last < len); + assert!(first <= last); + + // Verify the found positions are actually set. + assert_ne!(buf[first / 8] & (1 << (first % 8)), 0); + assert_ne!(buf[last / 8] & (1 << (last % 8)), 0); + } + + #[test] + fn test_select_in_word_basic() { + // 0b1010_1010 = 0xAA — bits 1,3,5,7 are set. + let word = 0x00000000_000000AAu64; + assert_eq!(select_in_word(word, 0), 1); + assert_eq!(select_in_word(word, 1), 3); + assert_eq!(select_in_word(word, 2), 5); + assert_eq!(select_in_word(word, 3), 7); + } + + #[test] + fn test_select_in_word_all_set() { + let word = u64::MAX; + for nth in 0..64 { + assert_eq!(select_in_word(word, nth), nth, "nth={nth}"); + } + } + + #[test] + fn test_select_in_byte_basic() { + assert_eq!(select_in_byte(0b1010_1010, 0), 1); + assert_eq!(select_in_byte(0b1010_1010, 1), 3); + assert_eq!(select_in_byte(0b1010_1010, 2), 5); + assert_eq!(select_in_byte(0b1010_1010, 3), 7); + assert_eq!(select_in_byte(0b0000_0001, 0), 0); + assert_eq!(select_in_byte(0b1000_0000, 0), 7); + assert_eq!(select_in_byte(0xFF, 7), 7); + } +} diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index 339f9a12d74..30c05f8f410 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ -457,7 +457,23 @@ impl Mask { if let Some(slices) = values.slices.get() { return slices.last().map(|(_, end)| end - 1); } - values.buffer.set_slices().last().map(|(_, end)| end - 1) + + if values.true_count == 0 { + return None; + } + + Some( + values + .buffer + .select(values.true_count - 1) + .unwrap_or_else(|| { + vortex_panic!( + "Rank {} out of bounds for mask with true count {}", + values.true_count - 1, + values.true_count + ) + }), + ) } } } @@ -473,8 +489,19 @@ impl Mask { match &self { Self::AllTrue(_) => n, Self::AllFalse(_) => unreachable!("no true values in all-false mask"), - // TODO(joe): optimize this function - Self::Values(values) => values.indices()[n], + Self::Values(values) => { + if let Some(indices) = values.indices.get() { + return indices[n]; + } + + values.buffer.select(n).unwrap_or_else(|| { + vortex_panic!( + "Rank {} out of bounds for mask with true count {}", + values.true_count - 1, + values.true_count + ) + }) + } } }