Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 72 additions & 23 deletions encodings/sequence/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use vortex_array::dtype::PType;
use vortex_array::expr::stats::Precision as StatPrecision;
use vortex_array::expr::stats::Stat;
use vortex_array::match_each_integer_ptype;
use vortex_array::match_each_native_ptype;
use vortex_array::match_each_pvalue;
use vortex_array::scalar::PValue;
use vortex_array::scalar::Scalar;
Expand Down Expand Up @@ -99,7 +98,9 @@ impl SequenceData {
)
}

/// Constructs a sequence array using two integer values (with the same ptype).
/// Constructs a sequence array using two integer values.
///
/// Arithmetic uses an inferred i64/u64 ptype based on base and multiplier.
pub(crate) fn try_new(
base: PValue,
multiplier: PValue,
Expand All @@ -109,7 +110,7 @@ impl SequenceData {
) -> VortexResult<Self> {
let dtype = DType::Primitive(ptype, nullability);
Self::validate(base, multiplier, &dtype, length)?;
let (base, multiplier) = Self::normalize(base, multiplier, ptype)?;
let (base, multiplier) = Self::normalize(base, multiplier)?;

Ok(unsafe { Self::new_unchecked(base, multiplier) })
}
Expand All @@ -125,20 +126,60 @@ impl SequenceData {
};

if !ptype.is_int() {
vortex_bail!("only integer ptype are supported in SequenceArray currently")
vortex_bail!("only integer ptypes are supported in SequenceArray currently")
}

vortex_ensure!(length > 0, "SequenceArray length must be greater than zero");
Self::try_last(base, multiplier, *ptype, length).map_err(|e| {
let last = Self::try_last(base, multiplier, length).map_err(|e| {
e.with_context(format!(
"final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
))
})?;

match_each_integer_ptype!(*ptype, |P| {
base.cast::<P>()?;
last.cast::<P>()?;
VortexResult::Ok(())
})?;

Ok(())
}

fn normalize(base: PValue, multiplier: PValue, ptype: PType) -> VortexResult<(PValue, PValue)> {
fn infer_calculation_ptype(base: PValue, multiplier: PValue) -> VortexResult<PType> {
if !base.ptype().is_int() || !multiplier.ptype().is_int() {
vortex_bail!("only integer ptypes are supported in SequenceArray currently")
}

if base.ptype().is_signed_int() || multiplier.ptype().is_signed_int() {
Ok(PType::I64)
} else {
Ok(PType::U64)
}
}

fn infer_calculation_ptype_from_proto(
base: &vortex_proto::scalar::ScalarValue,
multiplier: &vortex_proto::scalar::ScalarValue,
) -> VortexResult<PType> {
use vortex_proto::scalar::scalar_value::Kind;
let base_kind = base
.kind
.as_ref()
.ok_or_else(|| vortex_err!("base value missing kind"))?;
let multiplier_kind = multiplier
.kind
.as_ref()
.ok_or_else(|| vortex_err!("multiplier value missing kind"))?;

match (base_kind, multiplier_kind) {
(Kind::Int64Value(_), _) | (_, Kind::Int64Value(_)) => Ok(PType::I64),
(Kind::Uint64Value(_), Kind::Uint64Value(_)) => Ok(PType::U64),
_ => vortex_bail!("only integer ptypes are supported in SequenceArray currently"),
}
}

fn normalize(base: PValue, multiplier: PValue) -> VortexResult<(PValue, PValue)> {
let ptype = Self::infer_calculation_ptype(base, multiplier)?;
match_each_integer_ptype!(ptype, |P| {
Ok((
PValue::from(base.cast::<P>()?),
Expand Down Expand Up @@ -170,6 +211,10 @@ impl SequenceData {
self.multiplier
}

pub(crate) fn cast_value(value: PValue, ptype: PType) -> VortexResult<PValue> {
match_each_integer_ptype!(ptype, |O| { Ok(PValue::from(value.cast::<O>()?)) })
}

pub fn into_parts(self) -> SequenceDataParts {
SequenceDataParts {
base: self.base,
Expand All @@ -181,9 +226,9 @@ impl SequenceData {
pub(crate) fn try_last(
base: PValue,
multiplier: PValue,
ptype: PType,
length: usize,
) -> VortexResult<PValue> {
let ptype = Self::infer_calculation_ptype(base, multiplier)?;
match_each_integer_ptype!(ptype, |P| {
let len_t = <P>::from_usize(length - 1)
.ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
Expand All @@ -199,7 +244,7 @@ impl SequenceData {
}

pub(crate) fn index_value(&self, idx: usize) -> PValue {
match_each_native_ptype!(self.ptype(), |P| {
match_each_integer_ptype!(self.ptype(), |P| {
let base = self.base.cast::<P>().vortex_expect("must be able to cast");
let multiplier = self
.multiplier
Expand Down Expand Up @@ -291,14 +336,22 @@ impl VTable for Sequence {
);
let metadata = SequenceMetadata::decode(metadata)?;

let ptype = dtype.as_ptype();
let base_metadata = metadata
.base
.as_ref()
.ok_or_else(|| vortex_err!("base required"))?;

let multiplier_metadata = metadata
.multiplier
.as_ref()
.ok_or_else(|| vortex_err!("multiplier required"))?;

let ptype =
SequenceData::infer_calculation_ptype_from_proto(base_metadata, multiplier_metadata)?;

// We go via Scalar to validate that the value is valid for the ptype.
let base = Scalar::from_proto_value(
metadata
.base
.as_ref()
.ok_or_else(|| vortex_err!("base required"))?,
base_metadata,
&DType::Primitive(ptype, NonNullable),
session,
)?
Expand All @@ -307,10 +360,7 @@ impl VTable for Sequence {
.vortex_expect("sequence array base should be a non-nullable primitive");

let multiplier = Scalar::from_proto_value(
metadata
.multiplier
.as_ref()
.ok_or_else(|| vortex_err!("multiplier required"))?,
multiplier_metadata,
&DType::Primitive(ptype, NonNullable),
session,
)?
Expand Down Expand Up @@ -345,10 +395,8 @@ impl OperationsVTable<Sequence> for Sequence {
index: usize,
_ctx: &mut ExecutionCtx,
) -> VortexResult<Scalar> {
Scalar::try_new(
array.dtype().clone(),
Some(ScalarValue::Primitive(array.index_value(index))),
)
let value = SequenceData::cast_value(array.index_value(index), array.dtype().as_ptype())?;
Scalar::try_new(array.dtype().clone(), Some(ScalarValue::Primitive(value)))
}
}

Expand Down Expand Up @@ -397,8 +445,9 @@ impl Sequence {
length: usize,
) -> SequenceArray {
let dtype = DType::Primitive(ptype, nullability);
let (base, multiplier) = SequenceData::normalize(base, multiplier, ptype)
.vortex_expect("SequenceArray parts must be normalized to the target ptype");
let (base, multiplier) = SequenceData::normalize(base, multiplier).vortex_expect(
"SequenceArray parts must be normalized to the inferred arithmetic ptype",
);
let stats = Self::stats(multiplier);
let data = unsafe { SequenceData::new_unchecked(base, multiplier) };
unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) }
Expand Down
38 changes: 24 additions & 14 deletions encodings/sequence/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,22 @@ use std::ops::Add;

use num_traits::CheckedAdd;
use num_traits::CheckedSub;
use num_traits::cast::NumCast;
use vortex_array::ArrayRef;
use vortex_array::ArrayView;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::Primitive;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::dtype::IntegerPType;
use vortex_array::dtype::NativePType;
use vortex_array::dtype::Nullability;
use vortex_array::match_each_integer_ptype;
use vortex_array::match_each_native_ptype;
use vortex_array::scalar::PValue;
use vortex_array::validity::Validity;
use vortex_buffer::BufferMut;
use vortex_buffer::trusted_len::TrustedLen;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;

use crate::Sequence;
Expand Down Expand Up @@ -59,24 +61,32 @@ unsafe impl<T: Copy + Add<Output = T>> TrustedLen for SequenceIter<T> {}
/// Decompresses a [`SequenceArray`] into a [`PrimitiveArray`].
#[inline]
pub fn sequence_decompress(array: &SequenceArray) -> VortexResult<ArrayRef> {
fn decompress_inner<P: NativePType>(
base: P,
multiplier: P,
fn decompress_inner<C: IntegerPType, O: IntegerPType>(
base: C,
multiplier: C,
len: usize,
nullability: Nullability,
) -> PrimitiveArray {
let values = BufferMut::from_trusted_len_iter(SequenceIter {
acc: base,
step: multiplier,
remaining: len,
});
let values = BufferMut::from_trusted_len_iter(
SequenceIter {
acc: base,
step: multiplier,
remaining: len,
}
.map(|value| {
<O as NumCast>::from(value)
.vortex_expect("validated sequence values must fit output ptype")
}),
);
PrimitiveArray::new(values, Validity::from(nullability))
}

let prim = match_each_native_ptype!(array.ptype(), |P| {
let base = array.base().cast::<P>()?;
let multiplier = array.multiplier().cast::<P>()?;
decompress_inner(base, multiplier, array.len(), array.dtype().nullability())
let prim = match_each_integer_ptype!(array.ptype(), |C| {
let base = array.base().cast::<C>()?;
let multiplier = array.multiplier().cast::<C>()?;
match_each_integer_ptype!(array.dtype().as_ptype(), |O| {
decompress_inner::<C, O>(base, multiplier, array.len(), array.dtype().nullability())
})
});
Ok(prim.into_array())
}
Expand Down Expand Up @@ -131,7 +141,7 @@ fn encode_primitive_array<P: NativePType + Into<PValue> + CheckedAdd + CheckedSu
return Ok(None);
}

if SequenceData::try_last(base.into(), multiplier.into(), P::PTYPE, slice.len()).is_err() {
if SequenceData::try_last(base.into(), multiplier.into(), slice.len()).is_err() {
// If the last value is out of range, we cannot encode
return Ok(None);
}
Expand Down
Loading
Loading