diff --git a/vortex-array/src/arrays/bool/compute/cast.rs b/vortex-array/src/arrays/bool/compute/cast.rs index fe9332346ca..3b47ce62c8f 100644 --- a/vortex-array/src/arrays/bool/compute/cast.rs +++ b/vortex-array/src/arrays/bool/compute/cast.rs @@ -1,6 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use num_traits::One; +use num_traits::Zero; +use vortex_buffer::BufferMut; use vortex_error::VortexResult; use crate::ArrayRef; @@ -9,8 +12,10 @@ use crate::IntoArray; use crate::array::ArrayView; use crate::arrays::Bool; use crate::arrays::BoolArray; +use crate::arrays::PrimitiveArray; use crate::arrays::bool::BoolArrayExt; use crate::dtype::DType; +use crate::match_each_native_ptype; use crate::scalar_fn::fns::cast::CastKernel; use crate::scalar_fn::fns::cast::CastReduce; @@ -38,17 +43,34 @@ impl CastKernel for Bool { dtype: &DType, ctx: &mut ExecutionCtx, ) -> VortexResult> { - if !dtype.is_boolean() { - return Ok(None); + if dtype.is_boolean() { + let new_validity = + array + .validity()? + .cast_nullability(dtype.nullability(), array.len(), ctx)?; + return Ok(Some( + BoolArray::new(array.to_bit_buffer(), new_validity).into_array(), + )); } + let DType::Primitive(new_ptype, new_nullability) = dtype else { + return Ok(None); + }; + let new_validity = array .validity()? - .cast_nullability(dtype.nullability(), array.len(), ctx)?; - Ok(Some( - BoolArray::new(array.to_bit_buffer(), new_validity).into_array(), - )) + .cast_nullability(*new_nullability, array.len(), ctx)?; + + let bits = array.to_bit_buffer(); + let len = bits.len(); + + Ok(Some(match_each_native_ptype!(*new_ptype, |T| { + let (one, zero) = (::one(), ::zero()); + let mut buffer = BufferMut::::with_capacity(len); + buffer.extend(bits.iter().map(|v| if v { one } else { zero })); + PrimitiveArray::new(buffer.freeze(), new_validity).into_array() + }))) } } @@ -67,6 +89,7 @@ mod tests { use crate::compute::conformance::cast::test_cast_conformance; use crate::dtype::DType; use crate::dtype::Nullability; + use crate::dtype::PType; static SESSION: LazyLock = LazyLock::new(crate::array_session); @@ -102,4 +125,22 @@ mod tests { fn test_cast_bool_conformance(#[case] array: BoolArray) { test_cast_conformance(&array.into_array()); } + + #[rstest] + #[case(PType::I8)] + #[case(PType::I32)] + #[case(PType::I64)] + #[case(PType::U8)] + #[case(PType::U64)] + #[case(PType::F32)] + #[case(PType::F64)] + fn cast_bool_to_primitive(#[case] target: PType) { + let mut ctx = SESSION.create_execution_ctx(); + let arr = BoolArray::from_iter(vec![true, false, true]).into_array(); + let out = arr + .cast(DType::Primitive(target, Nullability::NonNullable)) + .unwrap(); + let out = out.execute::(&mut ctx).unwrap().into_array(); + assert_eq!(out.len(), 3); + } }