diff --git a/Cargo.lock b/Cargo.lock index 37d4e38ece8..7ad5d1ad40c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10458,6 +10458,7 @@ dependencies = [ "sqllogictest", "thiserror 2.0.18", "tokio", + "tracing-subscriber", "vortex", "vortex-datafusion", "vortex-duckdb", diff --git a/vortex-array/src/arrays/bool/compute/cast.rs b/vortex-array/src/arrays/bool/compute/cast.rs index fe9332346ca..52418e1176c 100644 --- a/vortex-array/src/arrays/bool/compute/cast.rs +++ b/vortex-array/src/arrays/bool/compute/cast.rs @@ -1,6 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use num_traits::One; +use num_traits::Zero; +use vortex_buffer::BufferMut; use vortex_error::VortexResult; use crate::ArrayRef; @@ -9,8 +12,10 @@ use crate::IntoArray; use crate::array::ArrayView; use crate::arrays::Bool; use crate::arrays::BoolArray; +use crate::arrays::PrimitiveArray; use crate::arrays::bool::BoolArrayExt; use crate::dtype::DType; +use crate::match_each_native_ptype; use crate::scalar_fn::fns::cast::CastKernel; use crate::scalar_fn::fns::cast::CastReduce; @@ -38,17 +43,34 @@ impl CastKernel for Bool { dtype: &DType, ctx: &mut ExecutionCtx, ) -> VortexResult> { - if !dtype.is_boolean() { - return Ok(None); + if dtype.is_boolean() { + let new_validity = + array + .validity()? + .cast_nullability(dtype.nullability(), array.len(), ctx)?; + return Ok(Some( + BoolArray::new(array.to_bit_buffer(), new_validity).into_array(), + )); } + let DType::Primitive(new_ptype, new_nullability) = dtype else { + return Ok(None); + }; + let new_validity = array .validity()? 
- .cast_nullability(dtype.nullability(), array.len(), ctx)?; - Ok(Some( - BoolArray::new(array.to_bit_buffer(), new_validity).into_array(), - )) + .cast_nullability(*new_nullability, array.len(), ctx)?; + + let bits = array.to_bit_buffer(); + let len = bits.len(); + + Ok(Some(match_each_native_ptype!(*new_ptype, |T| { + let (one, zero) = (::one(), ::zero()); + let mut buffer = BufferMut::::with_capacity(len); + buffer.extend(bits.iter().map(|v| if v { one } else { zero })); + PrimitiveArray::new(buffer.freeze(), new_validity).into_array() + }))) } } @@ -102,4 +124,22 @@ mod tests { fn test_cast_bool_conformance(#[case] array: BoolArray) { test_cast_conformance(&array.into_array()); } + + #[rstest] + #[case(crate::dtype::PType::I8)] + #[case(crate::dtype::PType::I32)] + #[case(crate::dtype::PType::I64)] + #[case(crate::dtype::PType::U8)] + #[case(crate::dtype::PType::U64)] + #[case(crate::dtype::PType::F32)] + #[case(crate::dtype::PType::F64)] + fn cast_bool_to_primitive(#[case] target: crate::dtype::PType) { + let mut ctx = SESSION.create_execution_ctx(); + let arr = BoolArray::from_iter(vec![true, false, true]).into_array(); + let out = arr + .cast(DType::Primitive(target, Nullability::NonNullable)) + .unwrap(); + let out = out.execute::(&mut ctx).unwrap().into_array(); + assert_eq!(out.len(), 3); + } } diff --git a/vortex-duckdb/build.rs b/vortex-duckdb/build.rs index 7647011e366..186dba9757b 100644 --- a/vortex-duckdb/build.rs +++ b/vortex-duckdb/build.rs @@ -27,11 +27,13 @@ const DEFAULT_DUCKDB_VERSION: &str = "1.5.3"; const BUILD_ARTIFACTS: [&str; 3] = ["libduckdb.dylib", "libduckdb.so", "libduckdb_static.a"]; -const SOURCE_FILES: [&str; 7] = [ +const SOURCE_FILES: [&str; 9] = [ "cpp/vortex_duckdb.cpp", "cpp/copy_function.cpp", "cpp/expr.cpp", + "cpp/optimizer.cpp", "cpp/scalar_fn_pushdown.cpp", + "cpp/cast_pushdown.cpp", "cpp/table_filter.cpp", "cpp/table_function.cpp", "cpp/vector.cpp", diff --git a/vortex-duckdb/cpp/cast_pushdown.cpp 
b/vortex-duckdb/cpp/cast_pushdown.cpp new file mode 100644 index 00000000000..7121c192a94 --- /dev/null +++ b/vortex-duckdb/cpp/cast_pushdown.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors +#include "cast_pushdown.hpp" +#include "table_function.hpp" + +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" + +// A GET reachable through a single-child chain of filters/projections. A join +// (or any other multi-child operator) breaks the chain. +// See test/sql/copy/csv/test_insert_into_types.test in duckdb (cast not pushed past a join) +static bool ReachesPushdownGet(const LogicalOperator &op) { + const LogicalOperator *cur = &op; + while (cur->children.size() == 1) { + cur = cur->children[0].get(); + switch (cur->type) { + case LogicalOperatorType::LOGICAL_GET: + return cur->Cast().function.bind == duckdb_vx_table_function_bind; + case LogicalOperatorType::LOGICAL_FILTER: + case LogicalOperatorType::LOGICAL_PROJECTION: + continue; + default: + return false; + } + } + return false; +} + +void CastCollect::VisitOperator(LogicalOperator &op) { + /* + * Logical projection expressions are columns which reference underlying + * GETs. Don't process them, as they would add conflicts for every column + * used in projection. Example: PROJECTION(col) -> GET(col). We don't want + * to visit BoundColumnRefExpression in PROJECTION to avoid registering a + * non-existent conflict. + * + * However, CastReplace will visit them because we need to update their + * types if pushdown succeeded. 
+ */ + if (op.type != LogicalOperatorType::LOGICAL_PROJECTION) { + return LogicalOperatorVisitor::VisitOperator(op); + } + auto &projection = op.Cast(); + + // Only push casts from a projection that forwards just column refs and + // casts and reaches a GET without a join in between. A constant or other + // expression makes the projection ineligible. + // See test/sql/copy/csv/test_csv_error_message_type.test (top-level cast + // to VARCHAR must still push) and test_large_integer_detection.test (a + // nested cast to VARCHAR must not) in duckdb. + bool clean = ReachesPushdownGet(projection); + for (const auto &e : projection.expressions) { + switch (e->GetExpressionClass()) { + case ExpressionClass::BOUND_COLUMN_REF: + case ExpressionClass::BOUND_CAST: + continue; + default: + clean = false; + break; + } + } + if (clean) { + for (const auto &e : projection.expressions) { + if (e->GetExpressionClass() == ExpressionClass::BOUND_CAST) { + top_level_casts.insert(e.get()); + } + } + } + if (projections.count(projection.table_index)) { + VisitOperatorChildren(op); + return; + } + + LogicalOperatorVisitor::VisitOperator(op); +} + +ExpressionPtr CastCollect::VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) { + if (const auto binding = Resolve(expr.binding, analyses, projections)) { + // Column is used without cast applied to it, register a conflict. + // Not emplace() as we need to update the value if it was present + binding->analysis.col_to_expr[binding->column_index] = nullptr; + } + return std::move(*ptr); +} + +ExpressionPtr CastCollect::VisitReplace(BoundCastExpression &expr, ExpressionPtr *ptr) { + if (expr.child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + // Descend into children so e.g. 
fn(col, other) still sees "col" and + // registers a conflict + return nullptr; + } + const auto &bound_col = expr.child->Cast(); + const auto binding = Resolve(bound_col.binding, analyses, projections); + if (!binding) { + return nullptr; + } + auto &col_to_expr = binding->analysis.col_to_expr; + + if (auto it = col_to_expr.find(binding->column_index); it == col_to_expr.end()) { + // Only a top-level projection cast starts a candidate. + if (top_level_casts.count(&expr)) { + col_to_expr.emplace(binding->column_index, &expr); + } + } else if (it->second == nullptr || + it->second->Cast().return_type != expr.return_type) { + // Different target type, or already a conflict. + it->second = nullptr; + } + + return std::move(*ptr); +} + +ExpressionPtr CastReplace::VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) { + const auto binding = Resolve(expr.binding, analyses, projections); + if (!binding) { + return std::move(*ptr); + } + + const auto &[analysis, column_index, projection] = *binding; + if (CanPushdownColumn(analysis, column_index)) { + const idx_t storage_index = analysis.get.GetColumnIds()[column_index].GetPrimaryIndex(); + const LogicalType return_type = analysis.get.returned_types[storage_index]; + expr.return_type = return_type; + // LogicalProjection types are resolved by calling + // LogicalProjection::ResolveTypes, so we need to check whether types in + // projection have been resolved, and updated them only if needed. 
+ if (projection != nullptr && !projection->types.empty()) { + projection->types[column_index] = return_type; + } + } + + return std::move(*ptr); +} + +ExpressionPtr CastReplace::VisitReplace(BoundCastExpression &expr, ExpressionPtr *ptr) { + if (expr.child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return nullptr; // Same as in ScalarFnCollect::VisitReplace + } + auto &bound_col_base = expr.child; + const auto &bound_col = bound_col_base->Cast(); + const auto binding = Resolve(bound_col.binding, analyses, projections); + if (!binding) { + return nullptr; + } + + const auto &[analysis, column_index, projection] = *binding; + if (!CanPushdownColumn(analysis, column_index)) { + return std::move(*ptr); + } + + const idx_t storage_index = analysis.get.GetColumnIds()[column_index].GetPrimaryIndex(); + const LogicalType return_type = analysis.get.returned_types[storage_index]; + bound_col_base->return_type = return_type; + // Same as in CastReplace::VisitReplace(BoundColumnRefExpression) + if (projection != nullptr && !projection->types.empty()) { + projection->types[column_index] = return_type; + } + return std::move(bound_col_base); +} + +CastCollect::CastCollect(Analyses &analyses, const Projections &projections) + : analyses(analyses), projections(projections) { +} + +CastReplace::CastReplace(Analyses &analyses, const Projections &projections) + : analyses(analyses), projections(projections) { +} diff --git a/vortex-duckdb/cpp/expr.cpp b/vortex-duckdb/cpp/expr.cpp index afe2573adc2..566b760bd23 100644 --- a/vortex-duckdb/cpp/expr.cpp +++ b/vortex-duckdb/cpp/expr.cpp @@ -2,8 +2,10 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors #include "expr.h" +#include "duckdb/common/type_visitor.hpp" #include "duckdb/function/scalar_function.hpp" #include "duckdb/planner/expression/bound_between_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" 
#include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" @@ -129,3 +131,17 @@ extern "C" void duckdb_vx_expr_get_bound_function(duckdb_vx_expr ffi_expr, out->scalar_function = reinterpret_cast(&expr.function); out->bind_info = expr.bind_info.get(); } + +extern "C" duckdb_vx_expr duckdb_vx_expr_get_bound_cast_child(duckdb_vx_expr ffi_expr) { + D_ASSERT(ffi_expr); + auto &expr = reinterpret_cast(ffi_expr)->Cast(); + return reinterpret_cast(expr.child.get()); +} + +extern "C" bool duckdb_vx_logical_type_contains_128bit(duckdb_logical_type ffi_type) { + D_ASSERT(ffi_type); + auto &type = *reinterpret_cast(ffi_type); + return TypeVisitor::Contains(type, [](const LogicalType &t) { + return t.id() == LogicalTypeId::HUGEINT || t.id() == LogicalTypeId::UHUGEINT; + }); +} diff --git a/vortex-duckdb/cpp/include/cast_pushdown.hpp b/vortex-duckdb/cpp/include/cast_pushdown.hpp new file mode 100644 index 00000000000..c805a6b3615 --- /dev/null +++ b/vortex-duckdb/cpp/include/cast_pushdown.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors +#pragma once +#include "optimizer.hpp" + +#include "duckdb/common/unordered_set.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/logical_operator.hpp" + +using namespace duckdb; + +/** + * Collect CAST(col) expressions. If "col" is used without CAST in "plan", + * record in "analyses.conflicts" + */ +struct CastCollect final : LogicalOperatorVisitor { + Analyses &analyses; + const Projections &projections; + // Casts that are direct outputs of a clean projection over a GET. Only these + // start a pushdown candidate; a nested cast may push down a different value. 
+ // See test/sql/copy/csv/auto/test_large_integer_detection.test in duckdb + unordered_set top_level_casts; + + CastCollect(Analyses &analyses, const Projections &projections); + void VisitOperator(LogicalOperator &op) override; + ExpressionPtr VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) override; + ExpressionPtr VisitReplace(BoundCastExpression &expr, ExpressionPtr *ptr) override; +}; + +/* + * For "col" in columns collected by CastCollect, replace CAST(col) with "col" + * if "col" doesn't have conflicting usage. Update return types for bound + * columns and logical projections referencing this column. + */ +struct CastReplace final : LogicalOperatorVisitor { + Analyses &analyses; + const Projections &projections; + + CastReplace(Analyses &analyses, const Projections &aliases); + ExpressionPtr VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) override; + ExpressionPtr VisitReplace(BoundCastExpression &expr, ExpressionPtr *ptr) override; +}; diff --git a/vortex-duckdb/cpp/include/expr.h b/vortex-duckdb/cpp/include/expr.h index 5b7997596d6..a1d4499584e 100644 --- a/vortex-duckdb/cpp/include/expr.h +++ b/vortex-duckdb/cpp/include/expr.h @@ -264,6 +264,12 @@ typedef struct { void duckdb_vx_expr_get_bound_function(duckdb_vx_expr expr, duckdb_vx_expr_bound_function *out); +duckdb_vx_expr duckdb_vx_expr_get_bound_cast_child(duckdb_vx_expr expr); + +// Check if type or contained types i.e. List(T) contains HUGEINT/UHUGEINT +// These are not present in DType so we can't convert. 
+bool duckdb_vx_logical_type_contains_128bit(duckdb_logical_type type); + #ifdef __cplusplus /* End C ABI */ } #endif diff --git a/vortex-duckdb/cpp/include/optimizer.hpp b/vortex-duckdb/cpp/include/optimizer.hpp new file mode 100644 index 00000000000..09394af8a0f --- /dev/null +++ b/vortex-duckdb/cpp/include/optimizer.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors +#pragma once +#include "table_function.hpp" + +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/operator/logical_get.hpp" + +#include + +// Aliases here are for ease of migration to duckdb 2.0 where these are +// separate types + +using namespace duckdb; + +/** + * Column index in requested scan. Example: + * + * CREATE TABLE t (a1 INTEGER, a2 INTEGER, a3 INTEGER); + * SELECT a2, a3 FROM t; + * + * a2's TableColumnScanIndex is 0, a3's TableColumnScanIndex is 1, + * index is index in SELECT clause. + */ +using TableColumnScanIndex = idx_t; +using ProjectionIndex = TableColumnScanIndex; + +/** + * Column index in table's storage. Example: + * + * CREATE TABLE t (a1 INTEGER, a2 INTEGER, a3 INTEGER); + * SELECT a2, a3 FROM t; + * + * a2's TableColumnStorageIndex is 1, a3's TableColumnStorageIndex is 2, + * index is index of column in table storage. + * + * for i: TableColumnScanIndex, column_ids[i].GetPrimaryIndex() is + * TableColumnStorageIndex + */ +using TableColumnStorageIndex = idx_t; + +using TableIndex = idx_t; + +using ExpressionPtr = unique_ptr; +using LogicalOperatorPtr = unique_ptr; + +struct GetAnalysis { + LogicalGet &get; + /** + * for fn(col), mapping of "col scan index" -> "expression applied to function". + * "expression" is nullptr iff column is used with a different expression + * or without expression application in the query plan (i.e. SELECT col). 
+ */ + unordered_map col_to_expr; + + TableColumnStorageIndex StorageIndex(TableColumnScanIndex idx) const; +}; + +using Analyses = unordered_map; + +/* + * Using scalar function pushdown as a specific example, + * SELECT fn(col) FROM '*.vortex' yields a PROJECTION fn(col) -> GET (vortex) + * plan. PROJECTION's "col" table_index is 1, vortex GET's table_index is 0. + * So we want to track original table_index for GET in case column is found + * in filter we failed to push down (i.e. WHERE prefix(col, 'h')) as well as + * projection's table_index. + * + * So we keep a mapping of + * + * "projection table index" to "projection operator". + * + * to resolve this. + * For simplicity, current implementation is limited to one level i.e. + * PROJECTION -> GET (i.e. read from VIEW) is pushed down but VIEW->VIEW->GET + * or VIEW->CTE->GET is not. + * + * Storing a reference is fine because the plan outlives the optimizer pass. + */ +using Projections = unordered_map; + +void FindGetsAndProjections(LogicalOperator &op, Analyses &analyses, Projections &aliases); + +struct GetBinding { + GetAnalysis &analysis; + TableColumnScanIndex column_index; + // If column binding was part of a projection, this is non-nullptr + LogicalProjection *projection; +}; + +/* + * Given a column binding, resolve it to a GET and a GET's column scan index. + * Returns nullopt for virtual columns and columns which are neither part of + * GET nor part of PROJECTION wrapping a GET. + */ +std::optional Resolve(ColumnBinding binding, Analyses &analyses, const Projections &projections); + +// A passthrough projection only forwards its child columns, e.g. a VIEW's +// "SELECT col". 
+bool IsPassthrough(const LogicalProjection &projection); + +// There are no conflicting column usages in the plan +bool CanPushdownColumn(const GetAnalysis &analysis, TableColumnScanIndex idx); + +template +LogicalOperatorPtr TryPushdown(ClientContext &context, LogicalOperatorPtr plan) { + Analyses analyses; + Projections projections; + FindGetsAndProjections(*plan, analyses, projections); + if (analyses.empty()) { + return plan; + } + Collect(analyses, projections).VisitOperator(*plan); + + bool any_pushed = false; + for (auto &[_, analysis] : analyses) { + for (auto &[column_index, expr] : analysis.col_to_expr) { + if (expr == nullptr) { // Conflict for column + continue; + } + const TableColumnStorageIndex storage_index = analysis.StorageIndex(column_index); + TableFunctionProjectionExpressionInput input {analysis.get, *expr, storage_index}; + if (projection_expression_pushdown(context, input)) { + analysis.get.types[column_index] = expr->return_type; + // LOGICAL_GET doesn't initialize .types of LogicalOperator + analysis.get.returned_types[storage_index] = expr->return_type; + any_pushed = true; + } else { // failed to push down expression, can't replace it + expr = nullptr; + } + } + } + + if (any_pushed) { + Replace(analyses, projections).VisitOperator(*plan); + } + return plan; +} diff --git a/vortex-duckdb/cpp/include/scalar_fn_pushdown.hpp b/vortex-duckdb/cpp/include/scalar_fn_pushdown.hpp index ef590c96dcc..19bad361cb0 100644 --- a/vortex-duckdb/cpp/include/scalar_fn_pushdown.hpp +++ b/vortex-duckdb/cpp/include/scalar_fn_pushdown.hpp @@ -1,82 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors #pragma once -#include "duckdb.h" - -#include "duckdb/optimizer/optimizer_extension.hpp" +#include "optimizer.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/operator/logical_get.hpp" 
-#include using namespace duckdb; -using ExpressionPtr = unique_ptr; -using LogicalOperatorPtr = unique_ptr; - -/** - * Column index in requested scan. Example: - * - * CREATE TABLE t (a1 INTEGER, a2 INTEGER, a3 INTEGER); - * SELECT a2, a3 FROM t; - * - * a2's TableColumnScanIndex is 0, a3's TableColumnScanIndex is 1, - * index is index in SELECT clause. - */ -using TableColumnScanIndex = idx_t; - -/** - * Column index in table's storage. Example: - * - * CREATE TABLE t (a1 INTEGER, a2 INTEGER, a3 INTEGER); - * SELECT a2, a3 FROM t; - * - * a2's TableColumnStorageIndex is 1, a3's TableColumnScanIndex is 2, - * index is index of column in table storage. - * - * for i: TableColumnScanIndex, column_ids[i].GetPrimaryIndex() is - * TableColumnStorageIndex - */ -using TableColumnStorageIndex = idx_t; - -using TableIndex = idx_t; - -struct GetAnalysis { - LogicalGet &get; - /** - * for fn(col), mapping of "col scan index" -> "fn expression". - * "fn expression" is nullptr iff column is used with a different function - * or without function application in the query plan. - */ - unordered_map col_to_fn; - - TableColumnStorageIndex StorageIndex(TableColumnScanIndex idx) const; -}; - -using Analyses = unordered_map; - -/* - * SELECT fn(col) FROM '*.vortex' yields a PROJECTION fn(col) -> GET (vortex) - * plan. PROJECTION's "col" table_index is 1, vortex GET's table_index is 0. - * So we want to track original table_index for GET in case column is found - * in filter we failed to push down (i.e. WHERE prefix(col, 'h')) as well as - * projection's table_index. - * - * So we keep a mapping of - * - * "projection table index" to "projection operator". - * - * to resolve this. - * For simplicity, current implementation is limited to one level i.e. - * PROJECTION -> GET (i.e. read from VIEW) is pushed down but VIEW->VIEW->GET - * or VIEW->CTE->GET is not. - * - * Storing a reference is fine because the plan outlives the optimizer pass. 
- */ -using Projections = unordered_map; - -LogicalOperatorPtr TryPushdownScalarFunctions(ClientContext &context, LogicalOperatorPtr plan); - /** * Collect fn(col) expressions i.e. expressions where a single function (not * a function chain) wraps a single bound column. If "col" is used without @@ -105,19 +35,3 @@ struct ScalarFnReplace final : LogicalOperatorVisitor { ExpressionPtr VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) override; ExpressionPtr VisitReplace(BoundFunctionExpression &expr, ExpressionPtr *ptr) override; }; - -void FindGetsAndProjections(LogicalOperator &op, Analyses &analyses, Projections &aliases); - -struct GetBinding { - GetAnalysis &analysis; - TableColumnScanIndex column_index; - // If column binding was part of a projection, this is non-nullptr - LogicalProjection *projection; -}; - -/* - * Given a column binding, resolve it to a GET and a GET's column scan index. - * Returns nullopt for virtual columns and columns which are neither part of - * GET nor part of PROJECTION wrapping a GET. 
- */ -std::optional Resolve(ColumnBinding binding, Analyses &analyses, const Projections &projections); diff --git a/vortex-duckdb/cpp/optimizer.cpp b/vortex-duckdb/cpp/optimizer.cpp new file mode 100644 index 00000000000..84e3d1fa781 --- /dev/null +++ b/vortex-duckdb/cpp/optimizer.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors +#include "optimizer.hpp" +#include "table_function.hpp" + +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" + +void FindGetsAndProjections(LogicalOperator &op, Analyses &analyses, Projections &projections) { + using enum LogicalOperatorType; + switch (op.type) { + case LOGICAL_GET: { + if (auto &get = op.Cast(); get.function.bind == duckdb_vx_table_function_bind) { + analyses.emplace(get.table_index, GetAnalysis {get, {}}); + } + break; + } + case LOGICAL_PROJECTION: { + LogicalProjection &projection = op.Cast(); + D_ASSERT(projection.children.size() == 1); + auto &child = *projection.children[0]; + if (!IsPassthrough(projection) || child.type != LOGICAL_GET) { + break; + } + // The GET itself is recorded when recursion reaches it below. Only + // passthrough projections wrapping a vortex GET act as aliases. 
+ if (auto &get = child.Cast(); get.function.bind == duckdb_vx_table_function_bind) { + projections.emplace(projection.table_index, projection); + } + break; + } + default: + break; + } + + for (auto &child : op.children) { + FindGetsAndProjections(*child, analyses, projections); + } +} + +TableColumnStorageIndex GetAnalysis::StorageIndex(TableColumnScanIndex idx) const { + return get.GetColumnIds()[idx].GetPrimaryIndex(); +} + +std::optional Resolve(ColumnBinding binding, Analyses &analyses, const Projections &projections) { + if (IsVirtualColumn(binding.column_index)) { + return std::nullopt; + } + if (const auto it = analyses.find(binding.table_index); it != analyses.end()) { + return {{it->second, binding.column_index, nullptr}}; + } + + const auto projection_it = projections.find(binding.table_index); + if (projection_it == projections.end()) { + return std::nullopt; + } + + LogicalProjection &projection = projection_it->second; + const ExpressionPtr &inner = projection.expressions[binding.column_index]; + if (inner->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return std::nullopt; + } + const ColumnBinding &get_binding = inner->Cast().binding; + if (IsVirtualColumn(get_binding.column_index)) { + return std::nullopt; + } + if (const auto it = analyses.find(get_binding.table_index); it != analyses.end()) { + return {{it->second, get_binding.column_index, &projection}}; + } + return std::nullopt; +} + +bool CanPushdownColumn(const GetAnalysis &analysis, TableColumnScanIndex idx) { + const auto it = analysis.col_to_expr.find(idx); + return it != analysis.col_to_expr.end() && it->second != nullptr; +} + +bool IsPassthrough(const LogicalProjection &projection) { + if (projection.expressions.empty()) { + return false; // don't register empty projections in Projections + } + for (const auto &e : projection.expressions) { + if (e->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return false; + } + } + return true; +} diff --git 
a/vortex-duckdb/cpp/scalar_fn_pushdown.cpp b/vortex-duckdb/cpp/scalar_fn_pushdown.cpp index 057920f79f7..e52a0a52f8f 100644 --- a/vortex-duckdb/cpp/scalar_fn_pushdown.cpp +++ b/vortex-duckdb/cpp/scalar_fn_pushdown.cpp @@ -1,9 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +#include "scalar_fn_pushdown.hpp" + #include "duckdb/catalog/catalog.hpp" #include "duckdb/planner/operator/logical_projection.hpp" -#include "scalar_fn_pushdown.hpp" -#include "table_function.hpp" +#include "duckdb/planner/operator/logical_get.hpp" + #include /** @@ -12,114 +14,6 @@ * If there are any functions left, this means they were not pushed down and * may produce conflicts (e.g. WHERE prefix("str", 'h')). */ - -LogicalOperatorPtr TryPushdownScalarFunctions(ClientContext &context, LogicalOperatorPtr plan) { - Analyses analyses; - Projections projections; - FindGetsAndProjections(*plan, analyses, projections); - if (analyses.empty()) { - return plan; - } - ScalarFnCollect(analyses, projections).VisitOperator(*plan); - - bool any_pushed = false; - for (auto &[_, analysis] : analyses) { - for (auto &[column_index, expr] : analysis.col_to_fn) { - if (expr == nullptr) { // Conflict for column - continue; - } - const TableColumnStorageIndex storage_index = analysis.StorageIndex(column_index); - TableFunctionProjectionExpressionInput input {analysis.get, *expr, storage_index}; - if (projection_expression_pushdown(context, input)) { - analysis.get.types[column_index] = expr->return_type; - analysis.get.returned_types[storage_index] = expr->return_type; - any_pushed = true; - } else { // failed to push down expression, can't replace it - expr = nullptr; - } - } - } - - if (any_pushed) { - ScalarFnReplace(analyses, projections).VisitOperator(*plan); - } - return plan; -} - -// A passthrough projection only forwards its child columns, e.g. a VIEW's -// "SELECT col". 
-static bool is_passthrough(const LogicalProjection &projection) { - if (projection.expressions.empty()) { - return false; // don't register empty projections in Projections - } - for (const auto &e : projection.expressions) { - if (e->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { - return false; - } - } - return true; -} - -void FindGetsAndProjections(LogicalOperator &op, Analyses &analyses, Projections &projections) { - using enum LogicalOperatorType; - switch (op.type) { - case LOGICAL_GET: { - if (auto &get = op.Cast(); get.function.bind == duckdb_vx_table_function_bind) { - analyses.emplace(get.table_index, GetAnalysis {get, {}}); - } - break; - } - case LOGICAL_PROJECTION: { - LogicalProjection &projection = op.Cast(); - D_ASSERT(projection.children.size() == 1); - auto &child = *projection.children[0]; - if (!is_passthrough(projection) || child.type != LOGICAL_GET) { - break; - } - // The GET itself is recorded when recursion reaches it below. Only - // passthrough projections wrapping a vortex GET act as aliases. 
- if (auto &get = child.Cast(); get.function.bind == duckdb_vx_table_function_bind) { - projections.emplace(projection.table_index, projection); - } - break; - } - default: - break; - } - - for (auto &child : op.children) { - FindGetsAndProjections(*child, analyses, projections); - } -} - -std::optional Resolve(ColumnBinding binding, Analyses &analyses, const Projections &projections) { - if (IsVirtualColumn(binding.column_index)) { - return std::nullopt; - } - if (const auto it = analyses.find(binding.table_index); it != analyses.end()) { - return {{it->second, binding.column_index, nullptr}}; - } - - const auto projection_it = projections.find(binding.table_index); - if (projection_it == projections.end()) { - return std::nullopt; - } - - LogicalProjection &projection = projection_it->second; - const ExpressionPtr &inner = projection.expressions[binding.column_index]; - if (inner->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { - return std::nullopt; - } - const ColumnBinding &get_binding = inner->Cast().binding; - if (IsVirtualColumn(get_binding.column_index)) { - return std::nullopt; - } - if (const auto it = analyses.find(get_binding.table_index); it != analyses.end()) { - return {{it->second, get_binding.column_index, &projection}}; - } - return std::nullopt; -} - void ScalarFnCollect::VisitOperator(LogicalOperator &op) { /* * Logical projection expressions are columns which reference underlying @@ -143,7 +37,7 @@ ExpressionPtr ScalarFnCollect::VisitReplace(BoundColumnRefExpression &expr, Expr if (const auto binding = Resolve(expr.binding, analyses, projections)) { // Column is used without function applied to it, register a conflict. 
// Not emplace() as we need to update the value if it was present - binding->analysis.col_to_fn[binding->column_index] = nullptr; + binding->analysis.col_to_expr[binding->column_index] = nullptr; } return std::move(*ptr); } @@ -160,11 +54,11 @@ ExpressionPtr ScalarFnCollect::VisitReplace(BoundFunctionExpression &expr, Expre if (!binding) { return nullptr; } - auto &col_to_fn = binding->analysis.col_to_fn; + auto &col_to_expr = binding->analysis.col_to_expr; - if (auto it = col_to_fn.find(binding->column_index); it == col_to_fn.end()) { + if (auto it = col_to_expr.find(binding->column_index); it == col_to_expr.end()) { // This is the first time we see the column used by a single function. - col_to_fn.emplace(binding->column_index, &expr); + col_to_expr.emplace(binding->column_index, &expr); } else if (it->second == nullptr || !it->second->Equals(expr)) { // Either column is used with different function in "expr" or // there already is a conflict. @@ -174,11 +68,6 @@ ExpressionPtr ScalarFnCollect::VisitReplace(BoundFunctionExpression &expr, Expre return std::move(*ptr); } -static bool can_pushdown_column(const GetAnalysis &analysis, TableColumnScanIndex idx) { - const auto it = analysis.col_to_fn.find(idx); - return it != analysis.col_to_fn.end() && it->second != nullptr; -} - ExpressionPtr ScalarFnReplace::VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) { const auto binding = Resolve(expr.binding, analyses, projections); if (!binding) { @@ -186,9 +75,9 @@ ExpressionPtr ScalarFnReplace::VisitReplace(BoundColumnRefExpression &expr, Expr } const auto &[analysis, column_index, projection] = *binding; - if (can_pushdown_column(analysis, column_index)) { + if (CanPushdownColumn(analysis, column_index)) { expr.return_type = analysis.get.types[column_index]; - if (projection != nullptr) { + if (projection != nullptr && !projection->types.empty()) { projection->types[column_index] = expr.return_type; } } @@ -209,12 +98,12 @@ ExpressionPtr 
ScalarFnReplace::VisitReplace(BoundFunctionExpression &expr, Expre } const auto &[analysis, column_index, projection] = *binding; - if (!can_pushdown_column(analysis, column_index)) { + if (!CanPushdownColumn(analysis, column_index)) { return std::move(*ptr); } bound_col_base->return_type = analysis.get.types[column_index]; - if (projection != nullptr) { + if (projection != nullptr && !projection->types.empty()) { projection->types[column_index] = bound_col_base->return_type; } return std::move(bound_col_base); @@ -227,7 +116,3 @@ ScalarFnCollect::ScalarFnCollect(Analyses &analyses, const Projections &projecti ScalarFnReplace::ScalarFnReplace(Analyses &analyses, const Projections &projections) : analyses(analyses), projections(projections) { } - -TableColumnStorageIndex GetAnalysis::StorageIndex(TableColumnScanIndex idx) const { - return get.GetColumnIds()[idx].GetPrimaryIndex(); -} diff --git a/vortex-duckdb/cpp/vortex_duckdb.cpp b/vortex-duckdb/cpp/vortex_duckdb.cpp index 091f98a1703..27710025b44 100644 --- a/vortex-duckdb/cpp/vortex_duckdb.cpp +++ b/vortex-duckdb/cpp/vortex_duckdb.cpp @@ -4,6 +4,7 @@ #include "data.hpp" #include "error.hpp" #include "scalar_fn_pushdown.hpp" +#include "cast_pushdown.hpp" #include "vortex_duckdb.h" #include "duckdb/catalog/catalog.hpp" @@ -268,7 +269,8 @@ extern "C" duckdb_blob duckdb_vx_value_get_geometry(duckdb_value value) { } static void VortexOptimizeFunction(OptimizerExtensionInput &input, unique_ptr &plan) { - plan = TryPushdownScalarFunctions(input.context, std::move(plan)); + plan = TryPushdown(input.context, std::move(plan)); + plan = TryPushdown(input.context, std::move(plan)); } struct VortexOptimizerExtension final : OptimizerExtension { diff --git a/vortex-duckdb/src/convert/dtype.rs b/vortex-duckdb/src/convert/dtype.rs index 4238b354182..e8690cc7d93 100644 --- a/vortex-duckdb/src/convert/dtype.rs +++ b/vortex-duckdb/src/convert/dtype.rs @@ -88,8 +88,8 @@ impl FromLogicalType for DType { 
DUCKDB_TYPE::DUCKDB_TYPE_USMALLINT => DType::Primitive(U16, nullability), DUCKDB_TYPE::DUCKDB_TYPE_UINTEGER => DType::Primitive(U32, nullability), DUCKDB_TYPE::DUCKDB_TYPE_UBIGINT => DType::Primitive(U64, nullability), - DUCKDB_TYPE::DUCKDB_TYPE_HUGEINT => todo!(), - DUCKDB_TYPE::DUCKDB_TYPE_UHUGEINT => todo!(), + DUCKDB_TYPE::DUCKDB_TYPE_HUGEINT => vortex_bail!("I128 is not in Vortex type system"), + DUCKDB_TYPE::DUCKDB_TYPE_UHUGEINT => vortex_bail!("U128 is not in Vortex type system"), DUCKDB_TYPE::DUCKDB_TYPE_FLOAT => DType::Primitive(F32, nullability), DUCKDB_TYPE::DUCKDB_TYPE_DOUBLE => DType::Primitive(F64, nullability), DUCKDB_TYPE::DUCKDB_TYPE_VARCHAR => DType::Utf8(nullability), diff --git a/vortex-duckdb/src/convert/expr.rs b/vortex-duckdb/src/convert/expr.rs index 387b644fe30..287c9a9625f 100644 --- a/vortex-duckdb/src/convert/expr.rs +++ b/vortex-duckdb/src/convert/expr.rs @@ -38,6 +38,7 @@ use vortex::scalar_fn::fns::like::LikeOptions; use vortex::scalar_fn::fns::literal::Literal; use vortex::scalar_fn::fns::operators::Operator; +use crate::convert::dtype::FromLogicalType; use crate::cpp::DUCKDB_TYPE; use crate::cpp::DUCKDB_VX_EXPR_TYPE; use crate::duckdb; @@ -45,6 +46,7 @@ use crate::duckdb::BoundFunction; use crate::duckdb::BoundOperator; use crate::duckdb::ExpressionClass; use crate::duckdb::ExpressionClass::BoundBetween; +use crate::duckdb::ExpressionClass::BoundCast; use crate::duckdb::ExpressionClass::BoundColumnRef; use crate::duckdb::ExpressionClass::BoundComparison; use crate::duckdb::ExpressionClass::BoundConjunction; @@ -87,8 +89,7 @@ fn try_from_bound_function( let col = byte_length(col); // byte_length returns u64, strlen expects i64. // At this point we don't know column's dtype so we ultimately - // set it to be nullable. For non-nullable column the nullability - // will be AllValid so it's a marginal cost. + // set it to be nullable. 
let dtype = DType::Primitive(PType::I64, Nullability::Nullable); cast(col, dtype) } @@ -141,7 +142,7 @@ fn try_from_bound_function( return Ok(None); }; - // We don't know the column's nullability here, so we set it to nullable. + // We don't know the column's nullability here build_list_length(col, Nullability::Nullable) } // len/length semantics depend on the return type of underlying expr. @@ -155,7 +156,7 @@ fn try_from_bound_function( return Ok(None); }; - // Same nullability rationale as in "array_length" branch. + // We don't know the column's nullability here let list_len_expr = build_list_length(col, Nullability::Nullable); return Ok(Some(list_len_expr)); } else { @@ -202,12 +203,13 @@ fn is_supported_length_alias(func: &BoundFunction) -> bool { // Example: optional filters may fail to parse on our side (we return // Ok(None)), so we don't allow pushing these. pub fn can_push_expression(value: &duckdb::ExpressionRef) -> bool { - let Some(value) = value.as_class() else { + let Some(class) = value.as_class() else { return false; }; - match value { + match class { BoundColumnRef(_) => true, BoundConstant(_) => true, + BoundCast(c) => !value.return_type().contains_128bit() && can_push_expression(c.child), BoundRef => true, BoundComparison(comp) => can_push_expression(comp.left) && can_push_expression(comp.right), BoundBetween(between) => { @@ -255,27 +257,41 @@ pub fn try_from_projection_expression( value: &duckdb::ExpressionRef, field: &DuckdbField, ) -> VortexResult> { - let Some(value) = value.as_class() else { + let Some(class) = value.as_class() else { return Ok(None); }; - let ExpressionClass::BoundFunction(func) = value else { - return Ok(None); - }; - Ok(match func.scalar_function.name() { - "strlen" => { - let col = byte_length(get_item(field.name.as_str(), root())); - // byte_length returns u64, strlen expects i64 - let dtype = DType::Primitive(PType::I64, field.dtype.nullability()); - let col = cast(col, dtype); - Some(col) + Ok(match class { + 
ExpressionClass::BoundFunction(func) => { + match func.scalar_function.name() { + "strlen" => { + let col = byte_length(get_item(field.name.as_str(), root())); + // byte_length returns u64, strlen expects i64 + let dtype = DType::Primitive(PType::I64, field.dtype.nullability()); + let col = cast(col, dtype); + Some(col) + } + "array_length" => { + // Only accept array_length(expr) rather than array_length(expr, dim). + (func.children().count() == 1).then(|| list_length_on_field(field)) + } + // len/length have different semantics depending on field dtype. + "len" | "length" => { + matches!(field.dtype, DType::List(..) | DType::FixedSizeList(..)) + .then(|| list_length_on_field(field)) + } + _ => None, + } } - "array_length" => { - // Only accept array_length(expr) rather than array_length(expr, dim). - (func.children().count() == 1).then(|| list_length_on_field(field)) + BoundCast(_) => { + let target = value.return_type(); + if target.contains_128bit() { + None + } else { + let dtype = DType::from_logical_type(target, field.dtype.nullability())?; + let col = get_item(field.name.as_str(), root()); + Some(cast(col, dtype)) + } } - // len/length have different semantics depending on field dtype. - "len" | "length" => matches!(field.dtype, DType::List(..) 
| DType::FixedSizeList(..)) - .then(|| list_length_on_field(field)), _ => None, }) } @@ -286,14 +302,14 @@ fn try_from_expression_inner( value: &duckdb::ExpressionRef, col_sub: Option<&Expression>, ) -> VortexResult> { - let Some(value) = value.as_class() else { + let Some(class) = value.as_class() else { debug!( class_id = ?value.as_class_id(), "unknown expression class id" ); return Ok(None); }; - Ok(Some(match value { + Ok(Some(match class { BoundRef => { let Some(col) = col_sub else { vortex_bail!("BoundRef requested but no column supplied"); @@ -372,6 +388,18 @@ fn try_from_expression_inner( ExpressionClass::BoundFunction(func) => { return try_from_bound_function(&func, col_sub); } + BoundCast(c) => { + let target = value.return_type(); + if target.contains_128bit() { + return Ok(None); + } + let Some(child) = try_from_expression_inner(c.child, col_sub)? else { + return Ok(None); + }; + // We don't know the column's nullability here + let dtype = DType::from_logical_type(target, Nullability::Nullable)?; + cast(child, dtype) + } BoundConjunction(conj) => { let Some(children) = conj .children() diff --git a/vortex-duckdb/src/duckdb/expr.rs b/vortex-duckdb/src/duckdb/expr.rs index 48f255f744e..59b8f5405a9 100644 --- a/vortex-duckdb/src/duckdb/expr.rs +++ b/vortex-duckdb/src/duckdb/expr.rs @@ -44,6 +44,12 @@ impl ExpressionRef { pub fn as_class(&self) -> Option> { Some( match unsafe { cpp::duckdb_vx_expr_get_class(self.as_ptr()) } { + cpp::DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_CAST => { + let child = unsafe { + Expression::borrow(cpp::duckdb_vx_expr_get_bound_cast_child(self.as_ptr())) + }; + ExpressionClass::BoundCast(BoundCast { child }) + } cpp::DUCKDB_VX_EXPR_CLASS::DUCKDB_VX_EXPR_CLASS_BOUND_COLUMN_REF => { let name = unsafe { let ptr = cpp::duckdb_vx_expr_get_bound_column_ref_get_name(self.as_ptr()); @@ -165,10 +171,15 @@ pub enum ExpressionClass<'a> { BoundBetween(BoundBetween<'a>), BoundOperator(BoundOperator<'a>), 
BoundFunction(BoundFunction<'a>), + BoundCast(BoundCast<'a>), /// Column inside ExpressionFilter for expression pushed down to Vortex. BoundRef, } +pub struct BoundCast<'a> { + pub child: &'a ExpressionRef, +} + pub struct BoundColumnRef { pub name: DDBString, } diff --git a/vortex-duckdb/src/duckdb/logical_type.rs b/vortex-duckdb/src/duckdb/logical_type.rs index 7a4627e6953..e19ca438d0e 100644 --- a/vortex-duckdb/src/duckdb/logical_type.rs +++ b/vortex-duckdb/src/duckdb/logical_type.rs @@ -33,6 +33,7 @@ use crate::cpp::duckdb_union_type_member_count; use crate::cpp::duckdb_union_type_member_name; use crate::cpp::duckdb_union_type_member_type; use crate::cpp::duckdb_vx_create_geometry; +use crate::cpp::duckdb_vx_logical_type_contains_128bit; use crate::cpp::duckdb_vx_logical_type_copy; use crate::cpp::duckdb_vx_logical_type_stringify; use crate::cpp::idx_t; @@ -153,6 +154,14 @@ impl LogicalType { Self::new(DUCKDB_TYPE::DUCKDB_TYPE_BLOB) } + pub fn uint8() -> Self { + Self::new(DUCKDB_TYPE::DUCKDB_TYPE_UTINYINT) + } + + pub fn uint16() -> Self { + Self::new(DUCKDB_TYPE::DUCKDB_TYPE_USMALLINT) + } + pub fn uint32() -> Self { Self::new(DUCKDB_TYPE::DUCKDB_TYPE_UINTEGER) } @@ -165,6 +174,14 @@ impl LogicalType { Self::new(DUCKDB_TYPE::DUCKDB_TYPE_UHUGEINT) } + pub fn int8() -> Self { + Self::new(DUCKDB_TYPE::DUCKDB_TYPE_TINYINT) + } + + pub fn int16() -> Self { + Self::new(DUCKDB_TYPE::DUCKDB_TYPE_SMALLINT) + } + pub fn int32() -> Self { Self::new(DUCKDB_TYPE::DUCKDB_TYPE_INTEGER) } @@ -234,6 +251,11 @@ impl LogicalTypeRef { matches!(self.as_type_id(), DUCKDB_TYPE::DUCKDB_TYPE_DECIMAL) } + /// true if T is [U]HUGEINT or some child type e.g. 
LIST(T) contains these + pub fn contains_128bit(&self) -> bool { + unsafe { duckdb_vx_logical_type_contains_128bit(self.as_ptr()) } + } + pub fn geometry_crs(&self) -> Option { unsafe { let c_string = duckdb_geometry_type_get_crs(self.as_ptr()); diff --git a/vortex-sqllogictest/Cargo.toml b/vortex-sqllogictest/Cargo.toml index 9cc2f4c63e9..efa226cda0f 100644 --- a/vortex-sqllogictest/Cargo.toml +++ b/vortex-sqllogictest/Cargo.toml @@ -25,6 +25,7 @@ rstest = { workspace = true } sqllogictest = "0.29.1" thiserror = { workspace = true } tokio = { workspace = true, features = ["full"] } +tracing-subscriber = { workspace = true, features = ["env-filter"] } vortex = { workspace = true, features = ["tokio"] } vortex-datafusion = { workspace = true } vortex-duckdb = { workspace = true } diff --git a/vortex-sqllogictest/bin/sqllogictests-runner.rs b/vortex-sqllogictest/bin/sqllogictests-runner.rs index 8f99ff4ec49..4eb0feafc68 100644 --- a/vortex-sqllogictest/bin/sqllogictests-runner.rs +++ b/vortex-sqllogictest/bin/sqllogictests-runner.rs @@ -194,6 +194,11 @@ fn complete_files( } fn main() -> anyhow::Result { + drop( + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(), + ); let mut raw_args: Vec = std::env::args().collect(); // We remove the `--complete` flag that isn't standard before we pass the rest. 
let complete = { diff --git a/vortex-sqllogictest/slt/duckdb/cast_pushdown.slt b/vortex-sqllogictest/slt/duckdb/cast_pushdown.slt new file mode 100644 index 00000000000..a2e5a9cfe92 --- /dev/null +++ b/vortex-sqllogictest/slt/duckdb/cast_pushdown.slt @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright the Vortex contributors + +include ../setup.slt.no + +statement ok +COPY (SELECT '2020-01-01 02:20:30' ts) TO '$__TEST_DIR__/cast_pushdown-date.vortex'; + +# WHERE cast is pushed down to reader +query TT +EXPLAIN SELECT * FROM '$__TEST_DIR__/cast_pushdown-date.vortex' +WHERE ts::DATE = '2020-01-01'; +---- +:FILTER + +statement ok +COPY (SELECT * FROM ( + VALUES (1, 0),(1, 0),(1, 0),(1, 0),(1, 0),(1, 0),(2, 0),(3, 0),(3, 0),(3, 1) +) AS t(column00, column01)) +TO '$__TEST_DIR__/cast_pushdown.vortex'; + +# Column is used uncasted +query TT +EXPLAIN SELECT * FROM '$__TEST_DIR__/cast_pushdown.vortex'; +---- +:.*PROJECTION.* + +query I +SELECT column00 FROM '$__TEST_DIR__/cast_pushdown.vortex' ORDER BY 1; +---- +1 +1 +1 +1 +1 +1 +2 +3 +3 +3 + +# Column is used casted +query TT +EXPLAIN SELECT column00::UTINYINT FROM '$__TEST_DIR__/cast_pushdown.vortex'; +---- +:.*PROJECTION.* + +query I +SELECT column00::UTINYINT FROM '$__TEST_DIR__/cast_pushdown.vortex' ORDER BY 1; +---- +1 +1 +1 +1 +1 +1 +2 +3 +3 +3 + +# Column is used casted and uncasted +query TT +EXPLAIN SELECT column00, column00::UTINYINT FROM '$__TEST_DIR__/cast_pushdown.vortex'; +---- +:.*PROJECTION.* + +query II +SELECT column00, column00::UTINYINT FROM '$__TEST_DIR__/cast_pushdown.vortex' ORDER BY 1, 2; +---- +1 1 +1 1 +1 1 +1 1 +1 1 +1 1 +2 2 +3 3 +3 3 +3 3 + +# Column is used uncasted with filter +query TT +EXPLAIN SELECT column00 FROM '$__TEST_DIR__/cast_pushdown.vortex' +WHERE column00 > 0; +---- +:.*PROJECTION.* + +query I +SELECT column00 FROM '$__TEST_DIR__/cast_pushdown.vortex' +WHERE column00 > 0 ORDER BY 1; +---- +1 +1 +1 +1 +1 +1 +2 +3 +3 +3 + +# Column is used 
uncasted with filter on casted. +# Cast is pushed in WHERE separately +query TT +EXPLAIN SELECT column00 FROM '$__TEST_DIR__/cast_pushdown.vortex' +WHERE column00::UTINYINT > 0; +---- +:.*FILTER.* + +query I +SELECT column00 FROM '$__TEST_DIR__/cast_pushdown.vortex' +WHERE column00::UTINYINT > 0 ORDER BY 1; +---- +1 +1 +1 +1 +1 +1 +2 +3 +3 +3 + +# Same cast in SELECT and ORDER BY: no conflict +query TT +EXPLAIN SELECT column00::UTINYINT +FROM '$__TEST_DIR__/cast_pushdown.vortex' +ORDER BY column00::UTINYINT; +---- +:.*PROJECTION.* + +query I +SELECT column00::UTINYINT +FROM '$__TEST_DIR__/cast_pushdown.vortex' +ORDER BY column00::UTINYINT; +---- +1 +1 +1 +1 +1 +1 +2 +3 +3 +3 + +# Two different casts of the same column: conflict +query TT +EXPLAIN SELECT column00::UTINYINT, column00::USMALLINT +FROM '$__TEST_DIR__/cast_pushdown.vortex'; +---- +:.*PROJECTION.* + +query II +SELECT column00::UTINYINT, column00::USMALLINT +FROM '$__TEST_DIR__/cast_pushdown.vortex' ORDER BY 1, 2; +---- +1 1 +1 1 +1 1 +1 1 +1 1 +1 1 +2 2 +3 3 +3 3 +3 3 + +# TRY_CAST: column is only used with try_cast +query TT +EXPLAIN SELECT TRY_CAST(column00 AS UTINYINT) FROM '$__TEST_DIR__/cast_pushdown.vortex'; +---- +:.*PROJECTION.* + +query I +SELECT TRY_CAST(column00 AS UTINYINT) FROM '$__TEST_DIR__/cast_pushdown.vortex' ORDER BY 1; +---- +1 +1 +1 +1 +1 +1 +2 +3 +3 +3 + +# TRY_CAST: column used uncasted with try_cast filter. 
+query TT +EXPLAIN SELECT column00 FROM '$__TEST_DIR__/cast_pushdown.vortex' +WHERE TRY_CAST(column00 AS UTINYINT) > 0; +---- +:.*FILTER.* + +query I +SELECT column00 FROM '$__TEST_DIR__/cast_pushdown.vortex' +WHERE TRY_CAST(column00 AS UTINYINT) > 0 ORDER BY 1; +---- +1 +1 +1 +1 +1 +1 +2 +3 +3 +3 + +# TRY_CAST: same try_cast in SELECT and WHERE; type pushed down so filter uses column index +query TT +EXPLAIN SELECT TRY_CAST(column00 AS UTINYINT) +FROM '$__TEST_DIR__/cast_pushdown.vortex' +WHERE TRY_CAST(column00 AS UTINYINT) > 0; +---- +:.*CAST.* + +query I +SELECT TRY_CAST(column00 AS UTINYINT) +FROM '$__TEST_DIR__/cast_pushdown.vortex' +WHERE TRY_CAST(column00 AS UTINYINT) > 0 ORDER BY 1; +---- +1 +1 +1 +1 +1 +1 +2 +3 +3 +3 + +# CAST and TRY_CAST of the same target type: no conflict, both pushed down; +# PROJECTION remains only to duplicate the column, no cast expressions in the plan +query TT +EXPLAIN SELECT column00::UTINYINT, TRY_CAST(column00 AS UTINYINT) +FROM '$__TEST_DIR__/cast_pushdown.vortex'; +---- +:.*CAST.* + +query II +SELECT column00::UTINYINT, TRY_CAST(column00 AS UTINYINT) +FROM '$__TEST_DIR__/cast_pushdown.vortex' ORDER BY 1, 2; +---- +1 1 +1 1 +1 1 +1 1 +1 1 +1 1 +2 2 +3 3 +3 3 +3 3 + +# CAST and TRY_CAST of different target types: conflict, pushdown blocked +query TT +EXPLAIN SELECT column00::UTINYINT, TRY_CAST(column00 AS USMALLINT) +FROM '$__TEST_DIR__/cast_pushdown.vortex'; +---- +:.*PROJECTION.* + +query II +SELECT column00::UTINYINT, TRY_CAST(column00 AS USMALLINT) +FROM '$__TEST_DIR__/cast_pushdown.vortex' ORDER BY 1, 2; +---- +1 1 +1 1 +1 1 +1 1 +1 1 +1 1 +2 2 +3 3 +3 3 +3 3 + +# i128 and u128 casts are not allowed as Vortex doesn't support these types + +query TT +EXPLAIN SELECT column00::HUGEINT FROM '$__TEST_DIR__/cast_pushdown.vortex'; +---- +:.*PROJECTION.* + +query TT +EXPLAIN SELECT column00::UHUGEINT FROM '$__TEST_DIR__/cast_pushdown.vortex'; +---- +:.*PROJECTION.* + +query TT +EXPLAIN SELECT column00 +FROM 
'$__TEST_DIR__/cast_pushdown.vortex' +WHERE column01::HUGEINT >= column00; +---- +:.*FILTER.* + +query TT +EXPLAIN SELECT column01 +FROM '$__TEST_DIR__/cast_pushdown.vortex' +WHERE column00::UHUGEINT >= column01 +---- +:.*FILTER.* diff --git a/vortex-sqllogictest/slt/duckdb/projection_expression_pushdown.slt b/vortex-sqllogictest/slt/duckdb/projection_expression_pushdown.slt index 9c431097bd7..36e099004cd 100644 --- a/vortex-sqllogictest/slt/duckdb/projection_expression_pushdown.slt +++ b/vortex-sqllogictest/slt/duckdb/projection_expression_pushdown.slt @@ -351,13 +351,13 @@ ORDER BY strlen(str); 3 5 -# prefix isn't pushed down as a complex filter so WHERE is not pushed down +# || isn't pushed down as a complex filter so WHERE is not pushed down # although this is only usage of str in SELECT, we can't push strlen down # as there is usage in WHERE. # This also tests functions with multiple arguments using "str" inside query I SELECT strlen(str) FROM '${WORK_DIR}/pe-pushdown.vortex' -WHERE prefix("str", 'H') > 0 +WHERE prefix(str || 'O', 'H') > 0 ORDER BY strlen(str); ---- 2 @@ -373,13 +373,6 @@ ORDER BY strlen(str); 3 5 -# explain: prefix()/suffix() in WHERE are multi-arg uses of str, no pushdown -query TT -EXPLAIN (FORMAT JSON) -SELECT strlen(str) FROM '${WORK_DIR}/pe-pushdown.vortex' WHERE prefix("str", 'H') > 0; ----- -:SELECT projections - # conflict with concat(), no pushdown query I SELECT strlen(str) FROM '${WORK_DIR}/pe-pushdown.vortex' diff --git a/vortex-sqllogictest/src/duckdb.rs b/vortex-sqllogictest/src/duckdb.rs index 1052781d430..cbf42feae7a 100644 --- a/vortex-sqllogictest/src/duckdb.rs +++ b/vortex-sqllogictest/src/duckdb.rs @@ -63,11 +63,15 @@ impl DuckDB { fn normalize_column_type(logical_type: &LogicalTypeRef) -> DFColumnType { let type_id = logical_type.as_type_id(); - if type_id == LogicalType::int32().as_type_id() + if type_id == LogicalType::int8().as_type_id() + || type_id == LogicalType::int16().as_type_id() + || type_id == 
LogicalType::int32().as_type_id() || type_id == LogicalType::int64().as_type_id() + || type_id == LogicalType::int128().as_type_id() + || type_id == LogicalType::uint8().as_type_id() + || type_id == LogicalType::uint16().as_type_id() || type_id == LogicalType::uint32().as_type_id() || type_id == LogicalType::uint64().as_type_id() - || type_id == LogicalType::int128().as_type_id() || type_id == LogicalType::uint128().as_type_id() { DFColumnType::Integer