diff --git a/encodings/sequence/src/array.rs b/encodings/sequence/src/array.rs index 56e62764ce4..cb9c25c10b5 100644 --- a/encodings/sequence/src/array.rs +++ b/encodings/sequence/src/array.rs @@ -28,7 +28,6 @@ use vortex_array::dtype::PType; use vortex_array::expr::stats::Precision as StatPrecision; use vortex_array::expr::stats::Stat; use vortex_array::match_each_integer_ptype; -use vortex_array::match_each_native_ptype; use vortex_array::match_each_pvalue; use vortex_array::scalar::PValue; use vortex_array::scalar::Scalar; @@ -99,7 +98,9 @@ impl SequenceData { ) } - /// Constructs a sequence array using two integer values (with the same ptype). + /// Constructs a sequence array using two integer values. + /// + /// Arithmetic uses an inferred i64/u64 ptype based on base and multiplier. pub(crate) fn try_new( base: PValue, multiplier: PValue, @@ -109,7 +110,7 @@ impl SequenceData { ) -> VortexResult { let dtype = DType::Primitive(ptype, nullability); Self::validate(base, multiplier, &dtype, length)?; - let (base, multiplier) = Self::normalize(base, multiplier, ptype)?; + let (base, multiplier) = Self::normalize(base, multiplier)?; Ok(unsafe { Self::new_unchecked(base, multiplier) }) } @@ -125,20 +126,60 @@ impl SequenceData { }; if !ptype.is_int() { - vortex_bail!("only integer ptype are supported in SequenceArray currently") + vortex_bail!("only integer ptypes are supported in SequenceArray currently") } vortex_ensure!(length > 0, "SequenceArray length must be greater than zero"); - Self::try_last(base, multiplier, *ptype, length).map_err(|e| { + let last = Self::try_last(base, multiplier, length).map_err(|e| { e.with_context(format!( "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ", )) })?; + match_each_integer_ptype!(*ptype, |P| { + base.cast::

()?; + last.cast::

()?; + VortexResult::Ok(()) + })?; + Ok(()) } - fn normalize(base: PValue, multiplier: PValue, ptype: PType) -> VortexResult<(PValue, PValue)> { + fn infer_calculation_ptype(base: PValue, multiplier: PValue) -> VortexResult { + if !base.ptype().is_int() || !multiplier.ptype().is_int() { + vortex_bail!("only integer ptypes are supported in SequenceArray currently") + } + + if base.ptype().is_signed_int() || multiplier.ptype().is_signed_int() { + Ok(PType::I64) + } else { + Ok(PType::U64) + } + } + + fn infer_calculation_ptype_from_proto( + base: &vortex_proto::scalar::ScalarValue, + multiplier: &vortex_proto::scalar::ScalarValue, + ) -> VortexResult { + use vortex_proto::scalar::scalar_value::Kind; + let base_kind = base + .kind + .as_ref() + .ok_or_else(|| vortex_err!("base value missing kind"))?; + let multiplier_kind = multiplier + .kind + .as_ref() + .ok_or_else(|| vortex_err!("multiplier value missing kind"))?; + + match (base_kind, multiplier_kind) { + (Kind::Int64Value(_), _) | (_, Kind::Int64Value(_)) => Ok(PType::I64), + (Kind::Uint64Value(_), Kind::Uint64Value(_)) => Ok(PType::U64), + _ => vortex_bail!("only integer ptypes are supported in SequenceArray currently"), + } + } + + fn normalize(base: PValue, multiplier: PValue) -> VortexResult<(PValue, PValue)> { + let ptype = Self::infer_calculation_ptype(base, multiplier)?; match_each_integer_ptype!(ptype, |P| { Ok(( PValue::from(base.cast::

()?), @@ -158,7 +199,7 @@ impl SequenceData { Self { base, multiplier } } - pub fn ptype(&self) -> PType { + pub(crate) fn calculation_ptype(&self) -> PType { self.base.ptype() } @@ -170,6 +211,10 @@ impl SequenceData { self.multiplier } + pub(crate) fn cast_value(value: PValue, ptype: PType) -> VortexResult { + match_each_integer_ptype!(ptype, |O| { Ok(PValue::from(value.cast::()?)) }) + } + pub fn into_parts(self) -> SequenceDataParts { SequenceDataParts { base: self.base, @@ -181,9 +226,9 @@ impl SequenceData { pub(crate) fn try_last( base: PValue, multiplier: PValue, - ptype: PType, length: usize, ) -> VortexResult { + let ptype = Self::infer_calculation_ptype(base, multiplier)?; match_each_integer_ptype!(ptype, |P| { let len_t =

::from_usize(length - 1) .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?; @@ -199,7 +244,7 @@ impl SequenceData { } pub(crate) fn index_value(&self, idx: usize) -> PValue { - match_each_native_ptype!(self.ptype(), |P| { + match_each_integer_ptype!(self.calculation_ptype(), |P| { let base = self.base.cast::

().vortex_expect("must be able to cast"); let multiplier = self .multiplier @@ -291,14 +336,22 @@ impl VTable for Sequence { ); let metadata = SequenceMetadata::decode(metadata)?; - let ptype = dtype.as_ptype(); + let base_metadata = metadata + .base + .as_ref() + .ok_or_else(|| vortex_err!("base required"))?; + + let multiplier_metadata = metadata + .multiplier + .as_ref() + .ok_or_else(|| vortex_err!("multiplier required"))?; + + let ptype = + SequenceData::infer_calculation_ptype_from_proto(base_metadata, multiplier_metadata)?; // We go via Scalar to validate that the value is valid for the ptype. let base = Scalar::from_proto_value( - metadata - .base - .as_ref() - .ok_or_else(|| vortex_err!("base required"))?, + base_metadata, &DType::Primitive(ptype, NonNullable), session, )? @@ -307,10 +360,7 @@ impl VTable for Sequence { .vortex_expect("sequence array base should be a non-nullable primitive"); let multiplier = Scalar::from_proto_value( - metadata - .multiplier - .as_ref() - .ok_or_else(|| vortex_err!("multiplier required"))?, + multiplier_metadata, &DType::Primitive(ptype, NonNullable), session, )? @@ -345,10 +395,8 @@ impl OperationsVTable for Sequence { index: usize, _ctx: &mut ExecutionCtx, ) -> VortexResult { - Scalar::try_new( - array.dtype().clone(), - Some(ScalarValue::Primitive(array.index_value(index))), - ) + let value = SequenceData::cast_value(array.index_value(index), array.dtype().as_ptype())?; + Scalar::try_new(array.dtype().clone(), Some(ScalarValue::Primitive(value))) } } @@ -397,8 +445,9 @@ impl Sequence { length: usize, ) -> SequenceArray { let dtype = DType::Primitive(ptype, nullability); - let (base, multiplier) = SequenceData::normalize(base, multiplier, ptype) - .vortex_expect("SequenceArray parts must be normalized to the target ptype"); + let (base, multiplier) = SequenceData::normalize(base, multiplier).vortex_expect( + "SequenceArray parts must be normalized to the inferred arithmetic ptype", + ); let stats = Self::stats(multiplier); let data = unsafe { SequenceData::new_unchecked(base, multiplier) }; unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) } diff --git a/encodings/sequence/src/compress.rs b/encodings/sequence/src/compress.rs index b20980cf132..7b6de5bff26 100644 --- a/encodings/sequence/src/compress.rs +++ b/encodings/sequence/src/compress.rs @@ -5,20 +5,22 @@ use std::ops::Add; use num_traits::CheckedAdd; use num_traits::CheckedSub; +use num_traits::cast::NumCast; use vortex_array::ArrayRef; use vortex_array::ArrayView; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::Primitive; use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::IntegerPType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::match_each_integer_ptype; -use vortex_array::match_each_native_ptype; use vortex_array::scalar::PValue; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_buffer::trusted_len::TrustedLen; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::Sequence; @@ -59,24 +61,32 @@ unsafe impl> TrustedLen for SequenceIter {} /// Decompresses a [`SequenceArray`] into a [`PrimitiveArray`]. #[inline] pub fn sequence_decompress(array: &SequenceArray) -> VortexResult { - fn decompress_inner( - base: P, - multiplier: P, + fn decompress_inner( + base: C, + multiplier: C, len: usize, nullability: Nullability, ) -> PrimitiveArray { - let values = BufferMut::from_trusted_len_iter(SequenceIter { - acc: base, - step: multiplier, - remaining: len, - }); + let values = BufferMut::from_trusted_len_iter( + SequenceIter { + acc: base, + step: multiplier, + remaining: len, + } + .map(|value| { + ::from(value) + .vortex_expect("validated sequence values must fit output ptype") + }), + ); PrimitiveArray::new(values, Validity::from(nullability)) } - let prim = match_each_native_ptype!(array.ptype(), |P| { - let base = array.base().cast::

()?; - let multiplier = array.multiplier().cast::

()?; - decompress_inner(base, multiplier, array.len(), array.dtype().nullability()) + let prim = match_each_integer_ptype!(array.calculation_ptype(), |C| { + let base = array.base().cast::()?; + let multiplier = array.multiplier().cast::()?; + match_each_integer_ptype!(array.dtype().as_ptype(), |O| { + decompress_inner::(base, multiplier, array.len(), array.dtype().nullability()) + }) }); Ok(prim.into_array()) } @@ -131,7 +141,7 @@ fn encode_primitive_array + CheckedAdd + CheckedSu return Ok(None); } - if SequenceData::try_last(base.into(), multiplier.into(), P::PTYPE, slice.len()).is_err() { + if SequenceData::try_last(base.into(), multiplier.into(), slice.len()).is_err() { // If the last value is out of range, we cannot encode return Ok(None); } diff --git a/encodings/sequence/src/compute/cast.rs b/encodings/sequence/src/compute/cast.rs index e0a63d0c922..41f00981238 100644 --- a/encodings/sequence/src/compute/cast.rs +++ b/encodings/sequence/src/compute/cast.rs @@ -5,12 +5,8 @@ use vortex_array::ArrayRef; use vortex_array::ArrayView; use vortex_array::IntoArray; use vortex_array::dtype::DType; -use vortex_array::dtype::Nullability; -use vortex_array::scalar::Scalar; -use vortex_array::scalar::ScalarValue; use vortex_array::scalar_fn::fns::cast::CastReduce; use vortex_error::VortexResult; -use vortex_error::vortex_err; use crate::Sequence; impl CastReduce for Sequence { @@ -26,62 +22,21 @@ impl CastReduce for Sequence { return Ok(None); } - // Check if this is just a nullability change - if array.ptype() == *target_ptype && array.dtype().nullability() != *target_nullability { - // For SequenceArray, we can just create a new one with the same parameters - // but different nullability - return Ok(Some( - Sequence::try_new( - array.base(), - array.multiplier(), - *target_ptype, - *target_nullability, - array.len(), - )? - .into_array(), - )); - } - - // For type changes, we need to cast the base and multiplier - if array.ptype() != *target_ptype { - // Create scalars from PValues and cast them - let base_scalar = Scalar::try_new( - DType::Primitive(array.ptype(), Nullability::NonNullable), - Some(ScalarValue::Primitive(array.base())), - )?; - let multiplier_scalar = Scalar::try_new( - DType::Primitive(array.ptype(), Nullability::NonNullable), - Some(ScalarValue::Primitive(array.multiplier())), - )?; - - let new_base_scalar = - base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?; - let new_multiplier_scalar = multiplier_scalar - .cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?; - - // Extract PValues from the casted scalars - let new_base = new_base_scalar - .as_primitive() - .pvalue() - .ok_or_else(|| vortex_err!("Cast resulted in null base value"))?; - let new_multiplier = new_multiplier_scalar - .as_primitive() - .pvalue() - .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?; - - return Ok(Some( - Sequence::try_new( - new_base, - new_multiplier, - *target_ptype, - *target_nullability, - array.len(), - )? - .into_array(), - )); + if array.dtype() == dtype { + return Ok(None); } - Ok(None) + // try_new also validates that the produced values fit the target ptype. + Ok(Some( + Sequence::try_new( + array.base(), + array.multiplier(), + *target_ptype, + *target_nullability, + array.len(), + )? + .into_array(), + )) } } @@ -99,6 +54,9 @@ mod tests { use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; + use vortex_array::scalar::Scalar; + use vortex_array::scalar::ScalarValue; + use vortex_error::VortexResult; use vortex_session::VortexSession; use crate::Sequence; @@ -197,14 +155,40 @@ mod tests { ); } + #[test] + fn test_cast_sequence_keeps_arithmetic_ptype_but_scalar_uses_output_dtype() -> VortexResult<()> + { + // Cast the public dtype to u8 + let casted = Sequence::try_new_typed(100i32, -10i32, Nullability::NonNullable, 5)? + .into_array() + .cast(DType::Primitive(PType::U8, Nullability::NonNullable))?; + + let sequence = casted + .as_typed::() + .expect("integer sequence cast should preserve SequenceArray"); + assert_eq!(sequence.calculation_ptype(), PType::I64); + assert_eq!( + sequence.dtype(), + &DType::Primitive(PType::U8, Nullability::NonNullable) + ); + + let scalar = casted.execute_scalar(1, &mut SESSION.create_execution_ctx())?; + assert_eq!( + scalar, + Scalar::try_new( + DType::Primitive(PType::U8, Nullability::NonNullable), + Some(ScalarValue::from(90u8)), + )? + ); + + Ok(()) + } + #[rstest] #[case::i32(Sequence::try_new_typed(0i32, 1i32, Nullability::NonNullable, 5).unwrap())] #[case::u64(Sequence::try_new_typed(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())] - // TODO(DK): SequenceArray does not actually conform. You cannot cast this array to u8 even - // though all its values are representable therein. - // - // #[case::negative_step(Sequence::try_new_typed(100i32, -10i32, Nullability::NonNullable, - // 5).unwrap())] + #[case::negative_step(Sequence::try_new_typed(100i32, -10i32, Nullability::NonNullable, + 5).unwrap())] #[case::single(Sequence::try_new_typed(42i64, 0i64, Nullability::NonNullable, 1).unwrap())] #[case::constant(Sequence::try_new_typed( 100i32, diff --git a/encodings/sequence/src/compute/filter.rs b/encodings/sequence/src/compute/filter.rs index 3ef24188fbf..bb15114f8b5 100644 --- a/encodings/sequence/src/compute/filter.rs +++ b/encodings/sequence/src/compute/filter.rs @@ -1,14 +1,15 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use num_traits::NumCast; use vortex_array::ArrayRef; use vortex_array::ArrayView; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::filter::FilterKernel; -use vortex_array::dtype::NativePType; -use vortex_array::match_each_native_ptype; +use vortex_array::dtype::IntegerPType; +use vortex_array::match_each_integer_ptype; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexExpect; @@ -24,22 +25,30 @@ impl FilterKernel for Sequence { _ctx: &mut ExecutionCtx, ) -> VortexResult> { let validity = Validity::from(array.dtype().nullability()); - match_each_native_ptype!(array.ptype(), |P| { - let mul = array.multiplier().cast::

()?; - let base = array.base().cast::

()?; - Ok(Some(filter_impl(mul, base, mask, validity))) + match_each_integer_ptype!(array.calculation_ptype(), |C| { + let mul = array.multiplier().cast::()?; + let base = array.base().cast::()?; + match_each_integer_ptype!(array.dtype().as_ptype(), |O| { + Ok(Some(filter_impl::(mul, base, mask, validity))) + }) }) } } -fn filter_impl(mul: T, base: T, mask: &Mask, validity: Validity) -> ArrayRef { +fn filter_impl( + mul: C, + base: C, + mask: &Mask, + validity: Validity, +) -> ArrayRef { let mask_values = mask .values() .vortex_expect("FilterKernel precondition: mask is Mask::Values"); - let mut buffer = BufferMut::::with_capacity(mask_values.true_count()); + let mut buffer = BufferMut::::with_capacity(mask_values.true_count()); buffer.extend(mask_values.indices().iter().map(|&idx| { - let i = T::from_usize(idx).vortex_expect("all valid indices fit"); - base + i * mul + let i = C::from_usize(idx).vortex_expect("all valid indices fit"); + ::from(base + i * mul) + .vortex_expect("valid sequence values must fit output ptype") })); PrimitiveArray::new(buffer.freeze(), validity).into_array() } diff --git a/encodings/sequence/src/compute/min_max.rs b/encodings/sequence/src/compute/min_max.rs index 2ff8b2abd58..d17ab8b2e6d 100644 --- a/encodings/sequence/src/compute/min_max.rs +++ b/encodings/sequence/src/compute/min_max.rs @@ -48,7 +48,7 @@ impl DynAggregateKernel for SequenceMinMaxKernel { } let base = seq.base(); - let last = SequenceData::try_last(base, seq.multiplier(), seq.ptype(), seq.len())?; + let last = SequenceData::try_last(base, seq.multiplier(), seq.len())?; // Determine min and max based on multiplier direction. // For unsigned types, multiplier is always >= 0. @@ -65,7 +65,11 @@ impl DynAggregateKernel for SequenceMinMaxKernel { float: |_v| { unreachable!("float multiplier not supported for SequenceArray") } ); - let non_nullable_dtype = DType::Primitive(seq.ptype(), Nullability::NonNullable); + let output_ptype = seq.dtype().as_ptype(); + let min_pvalue = SequenceData::cast_value(min_pvalue, output_ptype)?; + let max_pvalue = SequenceData::cast_value(max_pvalue, output_ptype)?; + + let non_nullable_dtype = DType::Primitive(output_ptype, Nullability::NonNullable); let min_scalar = Scalar::try_new( non_nullable_dtype.clone(), Some(ScalarValue::Primitive(min_pvalue)), diff --git a/encodings/sequence/src/compute/slice.rs b/encodings/sequence/src/compute/slice.rs index c2b64b68ef4..8d38c4076b8 100644 --- a/encodings/sequence/src/compute/slice.rs +++ b/encodings/sequence/src/compute/slice.rs @@ -19,7 +19,7 @@ impl SliceReduce for Sequence { Sequence::new_unchecked( array.index_value(range.start), array.multiplier(), - array.ptype(), + array.dtype().as_ptype(), array.dtype().nullability(), range.len(), ) diff --git a/encodings/sequence/src/compute/take.rs b/encodings/sequence/src/compute/take.rs index 4b056d0ae7f..d13bd269bd0 100644 --- a/encodings/sequence/src/compute/take.rs +++ b/encodings/sequence/src/compute/take.rs @@ -11,10 +11,8 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::dict::TakeExecute; use vortex_array::dtype::DType; use vortex_array::dtype::IntegerPType; -use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::match_each_integer_ptype; -use vortex_array::match_each_native_ptype; use vortex_array::scalar::Scalar; use vortex_array::validity::Validity; use vortex_buffer::Buffer; @@ -26,9 +24,9 @@ use vortex_mask::Mask; use crate::Sequence; -fn take_inner( - mul: S, - base: S, +fn take_inner( + mul: C, + base: C, indices: &[T], indices_mask: Mask, result_nullability: Nullability, @@ -40,14 +38,15 @@ fn take_inner( if i.as_() >= len { vortex_panic!(OutOfBounds: i.as_(), 0, len); } - let i = ::from::(*i).vortex_expect("all indices fit"); - base + i * mul + let i = ::from::(*i).vortex_expect("all indices fit"); + ::from(base + i * mul) + .vortex_expect("validated sequence values must fit output ptype") })), Validity::from(result_nullability), ) .into_array(), AllOr::None => ConstantArray::new( - Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)), + Scalar::null(DType::Primitive(O::PTYPE, Nullability::Nullable)), indices.len(), ) .into_array(), @@ -60,10 +59,11 @@ fn take_inner( } let i = - ::from::(*i).vortex_expect("all valid indices fit"); - base + i * mul + ::from::(*i).vortex_expect("all valid indices fit"); + ::from(base + i * mul) + .vortex_expect("validated sequence values must fit output ptype") } else { - S::zero() + O::zero() } })); PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array() @@ -71,6 +71,45 @@ fn take_inner( } } +fn take_with_typed_indices( + array: ArrayView<'_, Sequence>, + indices: &[T], + indices_mask: Mask, + result_nullability: Nullability, +) -> VortexResult { + match_each_integer_ptype!(array.calculation_ptype(), |C| { + let mul = array.multiplier().cast::()?; + let base = array.base().cast::()?; + + match_each_integer_ptype!(array.dtype().as_ptype(), |O| { + Ok(take_inner::( + mul, + base, + indices, + indices_mask, + result_nullability, + array.len(), + )) + }) + }) +} + +fn take_sequence( + array: ArrayView<'_, Sequence>, + indices: &PrimitiveArray, + indices_mask: Mask, + result_nullability: Nullability, +) -> VortexResult { + match_each_integer_ptype!(indices.ptype(), |T| { + take_with_typed_indices::( + array, + indices.as_slice::(), + indices_mask, + result_nullability, + ) + }) +} + impl TakeExecute for Sequence { fn take( array: ArrayView<'_, Self>, @@ -81,21 +120,7 @@ impl TakeExecute for Sequence { let indices = indices.clone().execute::(ctx)?; let result_nullability = array.dtype().nullability() | indices.dtype().nullability(); - match_each_integer_ptype!(indices.ptype(), |T| { - let indices = indices.as_slice::(); - match_each_native_ptype!(array.ptype(), |S| { - let mul = array.multiplier().cast::()?; - let base = array.base().cast::()?; - Ok(Some(take_inner( - mul, - base, - indices, - mask, - result_nullability, - array.len(), - ))) - }) - }) + take_sequence(array, &indices, mask, result_nullability).map(Some) } }