diff --git a/cub/benchmarks/bench/partition/flagged.cu b/cub/benchmarks/bench/partition/flagged.cu index 162a0a97186..436ea152d31 100644 --- a/cub/benchmarks/bench/partition/flagged.cu +++ b/cub/benchmarks/bench/partition/flagged.cu @@ -33,20 +33,18 @@ # endif // TUNE_LOAD template -struct policy_hub_t +struct policy_selector { - struct policy_t : cub::ChainedPolicy<300, policy_t, policy_t> + [[nodiscard]] _CCCL_HOST_DEVICE constexpr auto operator()(cuda::compute_capability) const + -> cub::detail::select::select_if_policy { - using SelectIfPolicyT = - cub::AgentSelectIfPolicy; - }; - - using MaxPolicy = policy_t; + return {TUNE_THREADS_PER_BLOCK, + TUNE_ITEMS_PER_THREAD, + TUNE_LOAD_ALGORITHM, + TUNE_LOAD_MODIFIER, + cub::BLOCK_SCAN_WARP_SCANS, + delay_constructor_policy}; + } }; #endif // TUNE_BASE @@ -70,31 +68,11 @@ void init_output_partition_buffer(FlagsItT, OffsetT, T* d_out, T*& d_partition_o template void flagged(nvbench::state& state, nvbench::type_list) { - using input_it_t = const T*; - using flag_it_t = const bool*; - using num_selected_it_t = OffsetT*; - using select_op_t = cub::NullType; - using equality_op_t = cub::NullType; using offset_t = OffsetT; constexpr bool use_distinct_out_partitions = UseDistinctPartitionT::value; using output_it_t = typename ::cuda::std:: conditional, T*>::type; - using dispatch_t = cub::DispatchSelectIf< - input_it_t, - flag_it_t, - output_it_t, - num_selected_it_t, - select_op_t, - equality_op_t, - offset_t, - cub::SelectImpl::Partition -#if !TUNE_BASE - , - policy_hub_t -#endif // TUNE_BASE - >; - // Retrieve axis parameters const auto elements = static_cast(state.get_int64("Elements{io}")); const bit_entropy entropy = str_to_entropy(state.get_string("Entropy")); @@ -106,9 +84,9 @@ void flagged(nvbench::state& state, nvbench::type_list num_selected(1); thrust::device_vector out(elements); - input_it_t d_in = thrust::raw_pointer_cast(in.data()); - flag_it_t d_flags = thrust::raw_pointer_cast(flags.data()); - num_selected_it_t d_num_selected = thrust::raw_pointer_cast(num_selected.data()); + const T* d_in = thrust::raw_pointer_cast(in.data()); + const bool* d_flags = thrust::raw_pointer_cast(flags.data()); + offset_t* d_num_selected = thrust::raw_pointer_cast(num_selected.data()); output_it_t d_out{}; init_output_partition_buffer(flags.cbegin(), elements, thrust::raw_pointer_cast(out.data()), d_out); @@ -118,25 +96,25 @@ void flagged(nvbench::state& state, nvbench::type_list(elements); state.add_global_memory_writes(1); - std::size_t temp_size{}; - dispatch_t::Dispatch( - nullptr, temp_size, d_in, d_flags, d_out, d_num_selected, select_op_t{}, equality_op_t{}, elements, nullptr); - - thrust::device_vector temp(temp_size); - auto* temp_storage = thrust::raw_pointer_cast(temp.data()); - + caching_allocator_t alloc; state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) { - dispatch_t::Dispatch( - temp_storage, - temp_size, + auto env = cub_bench_env( + alloc, + launch +#if !TUNE_BASE + , + cuda::execution::tune(policy_selector{}) +#endif // !TUNE_BASE + ); + _CCCL_TRY_CUDA_API( + cub::DevicePartition::Flagged, + "Flagged failed", d_in, d_flags, d_out, d_num_selected, - select_op_t{}, - equality_op_t{}, - elements, - launch.get_stream()); + static_cast(elements), + env); }); } diff --git a/cub/benchmarks/bench/partition/if.cu b/cub/benchmarks/bench/partition/if.cu index 5d9c531fbfa..49dcc53d4d6 100644 --- a/cub/benchmarks/bench/partition/if.cu +++ b/cub/benchmarks/bench/partition/if.cu @@ -33,20 +33,18 @@ # endif // TUNE_LOAD template -struct policy_hub_t +struct policy_selector { - struct policy_t : cub::ChainedPolicy<300, policy_t, policy_t> + [[nodiscard]] _CCCL_HOST_DEVICE constexpr auto operator()(cuda::compute_capability) const + -> cub::detail::select::select_if_policy { - using SelectIfPolicyT = - cub::AgentSelectIfPolicy; - }; - - using MaxPolicy = policy_t; + return {TUNE_THREADS_PER_BLOCK, + TUNE_ITEMS_PER_THREAD, + TUNE_LOAD_ALGORITHM, + TUNE_LOAD_MODIFIER, + cub::BLOCK_SCAN_WARP_SCANS, + delay_constructor_policy}; + } }; #endif // !TUNE_BASE @@ -71,31 +69,12 @@ void init_output_partition_buffer(InItT, OffsetT, T* d_out, SelectOpT, T*& d_par template void partition(nvbench::state& state, nvbench::type_list) { - using input_it_t = const T*; - using flag_it_t = cub::NullType*; - using num_selected_it_t = OffsetT*; using select_op_t = less_then_t; - using equality_op_t = cub::NullType; using offset_t = OffsetT; constexpr bool use_distinct_out_partitions = UseDistinctPartitionT::value; using output_it_t = typename ::cuda::std:: conditional, T*>::type; - using dispatch_t = cub::DispatchSelectIf< - input_it_t, - flag_it_t, - output_it_t, - num_selected_it_t, - select_op_t, - equality_op_t, - offset_t, - cub::SelectImpl::Partition -#if !TUNE_BASE - , - policy_hub_t -#endif // !TUNE_BASE - >; - // Retrieve axis parameters const auto elements = static_cast(state.get_int64("Elements{io}")); const bit_entropy entropy = str_to_entropy(state.get_string("Entropy")); @@ -108,9 +87,8 @@ void partition(nvbench::state& state, nvbench::type_list out(elements); - input_it_t d_in = thrust::raw_pointer_cast(in.data()); - flag_it_t d_flags = nullptr; - num_selected_it_t d_num_selected = thrust::raw_pointer_cast(num_selected.data()); + const T* d_in = thrust::raw_pointer_cast(in.data()); + offset_t* d_num_selected = thrust::raw_pointer_cast(num_selected.data()); output_it_t d_out{}; init_output_partition_buffer(in.cbegin(), elements, thrust::raw_pointer_cast(out.data()), select_op, d_out); @@ -119,25 +97,25 @@ void partition(nvbench::state& state, nvbench::type_list(elements); state.add_global_memory_writes(1); - std::size_t temp_size{}; - dispatch_t::Dispatch( - nullptr, temp_size, d_in, d_flags, d_out, d_num_selected, select_op, equality_op_t{}, elements, nullptr); - - thrust::device_vector temp(temp_size); - auto* temp_storage = thrust::raw_pointer_cast(temp.data()); - + caching_allocator_t alloc; state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) { - dispatch_t::Dispatch( - temp_storage, - temp_size, + auto env = cub_bench_env( + alloc, + launch +#if !TUNE_BASE + , + cuda::execution::tune(policy_selector{}) +#endif // !TUNE_BASE + ); + _CCCL_TRY_CUDA_API( + cub::DevicePartition::If, + "If failed", d_in, - d_flags, d_out, d_num_selected, + static_cast(elements), select_op, - equality_op_t{}, - elements, - launch.get_stream()); + env); }); } diff --git a/cub/benchmarks/bench/partition/three_way.cu b/cub/benchmarks/bench/partition/three_way.cu index f9ace6b1a95..0475868b5e1 100644 --- a/cub/benchmarks/bench/partition/three_way.cu +++ b/cub/benchmarks/bench/partition/three_way.cu @@ -1,8 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2011-2023, NVIDIA CORPORATION. All rights reserved. // SPDX-License-Identifier: BSD-3 -#include -#include +#include #include #include @@ -16,7 +15,7 @@ #if !TUNE_BASE template -struct tuned_policy_selector_t +struct policy_selector { [[nodiscard]] _CCCL_HOST_DEVICE constexpr auto operator()(cuda::compute_capability) const -> cub::detail::three_way_partition::three_way_partition_policy @@ -26,7 +25,7 @@ struct tuned_policy_selector_t TUNE_TRANSPOSE == 0 ? cub::BLOCK_LOAD_DIRECT : cub::BLOCK_LOAD_WARP_TRANSPOSE, cub::LOAD_DEFAULT, cub::BLOCK_SCAN_WARP_SCANS, - cub::detail::delay_constructor_policy_from_type}; + delay_constructor_policy}; } }; #endif // !TUNE_BASE @@ -34,11 +33,8 @@ struct tuned_policy_selector_t template void partition(nvbench::state& state, nvbench::type_list) { - using input_it_t = const T*; - using output_it_t = T*; - using num_selected_it_t = OffsetT*; - using select_op_t = less_then_t; - using offset_t = OffsetT; + using select_op_t = less_then_t; + using offset_t = OffsetT; // Retrieve axis parameters const auto elements = static_cast(state.get_int64("Elements{io}")); @@ -54,50 +50,44 @@ void partition(nvbench::state& state, nvbench::type_list) select_op_t select_op_2{right_border}; thrust::device_vector in = generate(elements, entropy, min_val, max_val); - thrust::device_vector num_selected(1); + thrust::device_vector num_selected(2); thrust::device_vector out_1(elements); thrust::device_vector out_2(elements); thrust::device_vector out_3(elements); - input_it_t d_in = thrust::raw_pointer_cast(in.data()); - output_it_t d_out_1 = thrust::raw_pointer_cast(out_1.data()); - output_it_t d_out_2 = thrust::raw_pointer_cast(out_2.data()); - output_it_t d_out_3 = thrust::raw_pointer_cast(out_3.data()); - num_selected_it_t d_num_selected = thrust::raw_pointer_cast(num_selected.data()); + const T* d_in = thrust::raw_pointer_cast(in.data()); + T* d_out_1 = thrust::raw_pointer_cast(out_1.data()); + T* d_out_2 = thrust::raw_pointer_cast(out_2.data()); + T* d_out_3 = thrust::raw_pointer_cast(out_3.data()); + offset_t* d_num_selected = thrust::raw_pointer_cast(num_selected.data()); state.add_element_count(elements); state.add_global_memory_reads(elements); state.add_global_memory_writes(elements); - state.add_global_memory_writes(1); + state.add_global_memory_writes(2); - std::size_t temp_size{}; - auto dispatch = [&](void* temp_storage, cudaStream_t stream) { - return cub::detail::three_way_partition::dispatch( - temp_storage, - temp_size, + caching_allocator_t alloc; + state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) { + auto env = cub_bench_env( + alloc, + launch +#if !TUNE_BASE + , + cuda::execution::tune(policy_selector{}) +#endif // !TUNE_BASE + ); + _CCCL_TRY_CUDA_API( + cub::DevicePartition::If, + "If three-way failed", d_in, d_out_1, d_out_2, d_out_3, d_num_selected, + static_cast(elements), select_op_1, select_op_2, - static_cast(elements), - stream -#if !TUNE_BASE - , - policy_selector_t{} -#endif // !TUNE_BASE - ); - }; - - dispatch(nullptr, nullptr); - - thrust::device_vector temp(temp_size, thrust::no_init); - auto* temp_storage = thrust::raw_pointer_cast(temp.data()); - - state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) { - dispatch(temp_storage, launch.get_stream()); + env); }); } diff --git a/cub/cub/device/device_partition.cuh b/cub/cub/device/device_partition.cuh index cfc5073383b..0a7c341c646 100644 --- a/cub/cub/device/device_partition.cuh +++ b/cub/cub/device/device_partition.cuh @@ -54,46 +54,6 @@ CUB_NAMESPACE_BEGIN //! @endrst struct DevicePartition { -private: - template - CUB_RUNTIME_FUNCTION static cudaError_t partition_impl( - void* d_temp_storage, - size_t& temp_storage_bytes, - InputIteratorT d_in, - FlagIteratorT d_flags, - OutputIteratorT d_out, - NumSelectedIteratorT d_num_selected_out, - OffsetT num_items, - SelectOpT select_op, - cudaStream_t stream) - { - using default_policy_selector = detail::select:: - policy_selector_from_types; - using policy_selector = - ::cuda::std::execution::__query_result_or_t; - - using EqualityOp = NullType; - - return detail::select::dispatch( - d_temp_storage, - temp_storage_bytes, - d_in, - d_flags, - d_out, - d_num_selected_out, - select_op, - EqualityOp{}, - static_cast(num_items), - stream, - policy_selector{}); - } - public: //! @rst //! Uses the ``d_flags`` sequence to split the corresponding items from @@ -321,21 +281,32 @@ public: { _CCCL_NVTX_RANGE_SCOPE("cub::DevicePartition::Flagged"); - using ChooseOffsetT = detail::choose_signed_offset; - using offset_t = typename ChooseOffsetT::type; + using choose_offset_t = detail::choose_signed_offset; + using offset_t = typename choose_offset_t::type; + using default_policy_selector = detail::select:: + policy_selector_from_types; // Check if the number of items exceeds the range covered by the selected signed offset type - if (const cudaError_t error = ChooseOffsetT::is_exceeding_offset_type(num_items)) + if (const auto error = choose_offset_t::is_exceeding_offset_type(num_items)) { return error; } - // Dispatch with environment - handles all boilerplate - return detail::dispatch_with_env(env, [&]([[maybe_unused]] auto tuning, void* storage, size_t& bytes, auto stream) { - using tuning_t = decltype(tuning); - return partition_impl( - storage, bytes, d_in, d_flags, d_out, d_num_selected_out, static_cast(num_items), NullType{}, stream); - }); + return detail::dispatch_with_env_and_tuning( + env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) { + return detail::select::dispatch( + storage, + bytes, + d_in, + d_flags, + d_out, + d_num_selected_out, + NullType{}, + NullType{}, + static_cast(num_items), + stream, + policy_selector); + }); } //! @rst @@ -572,92 +543,34 @@ public: { _CCCL_NVTX_RANGE_SCOPE("cub::DevicePartition::If"); - using ChooseOffsetT = detail::choose_signed_offset; - using offset_t = typename ChooseOffsetT::type; - - // Check if the number of items exceeds the range covered by the selected signed offset type - if (const cudaError_t error = ChooseOffsetT::is_exceeding_offset_type(num_items)) - { - return error; - } - - // Dispatch with environment - handles all boilerplate - return detail::dispatch_with_env(env, [&]([[maybe_unused]] auto tuning, void* storage, size_t& bytes, auto stream) { - using tuning_t = decltype(tuning); - return partition_impl( - storage, - bytes, - d_in, - static_cast(nullptr), - d_out, - d_num_selected_out, - static_cast(num_items), - select_op, - stream); - }); - } - -private: - template - friend class DispatchSegmentedSort; - - // Internal version without NVTX range - template - CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t IfNoNVTX( - void* d_temp_storage, - size_t& temp_storage_bytes, - InputIteratorT d_in, - FirstOutputIteratorT d_first_part_out, - SecondOutputIteratorT d_second_part_out, - UnselectedOutputIteratorT d_unselected_out, - NumSelectedIteratorT d_num_selected_out, - NumItemsT num_items, - SelectFirstPartOp select_first_part_op, - SelectSecondPartOp select_second_part_op, - cudaStream_t stream = nullptr) - { - using ChooseOffsetT = detail::choose_signed_offset; - using OffsetT = typename ChooseOffsetT::type; + using choose_offset_t = detail::choose_signed_offset; + using offset_t = typename choose_offset_t::type; + using default_policy_selector = detail::select:: + policy_selector_from_types; - // Signed integer type for global offsets // Check if the number of items exceeds the range covered by the selected signed offset type - if (const auto error = ChooseOffsetT::is_exceeding_offset_type(num_items)) + if (const auto error = choose_offset_t::is_exceeding_offset_type(num_items)) { return error; } - return detail::three_way_partition::dispatch( - d_temp_storage, - temp_storage_bytes, - d_in, - d_first_part_out, - d_second_part_out, - d_unselected_out, - d_num_selected_out, - select_first_part_op, - select_second_part_op, - static_cast(num_items), - stream); + return detail::dispatch_with_env_and_tuning( + env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) { + return detail::select::dispatch( + storage, + bytes, + d_in, + static_cast(nullptr), + d_out, + d_num_selected_out, + select_op, + NullType{}, + static_cast(num_items), + stream, + policy_selector); + }); } -public: //! @rst //! Uses two functors to split the corresponding items from ``d_in`` into a three partitioned sequences //! ``d_first_part_out``, ``d_second_part_out``, and ``d_unselected_out``. @@ -865,7 +778,16 @@ public: cudaStream_t stream = nullptr) { _CCCL_NVTX_RANGE_SCOPE_IF(d_temp_storage, "cub::DevicePartition::If"); - return IfNoNVTX( + using choose_offset_t = detail::choose_signed_offset; + + // Check if the number of items exceeds the range covered by the selected signed offset type + if (const auto error = choose_offset_t::is_exceeding_offset_type(num_items)) + { + return error; + } + + using offset_t = typename choose_offset_t::type; + return detail::three_way_partition::dispatch( d_temp_storage, temp_storage_bytes, d_in, @@ -873,9 +795,9 @@ public: d_second_part_out, d_unselected_out, d_num_selected_out, - num_items, select_first_part_op, select_second_part_op, + static_cast(num_items), stream); } @@ -999,20 +921,33 @@ public: { _CCCL_NVTX_RANGE_SCOPE("cub::DevicePartition::If"); - return detail::dispatch_with_env(env, [&]([[maybe_unused]] auto tuning, void* storage, size_t& bytes, auto stream) { - return IfNoNVTX( - storage, - bytes, - d_in, - d_first_part_out, - d_second_part_out, - d_unselected_out, - d_num_selected_out, - num_items, - select_first_part_op, - select_second_part_op, - stream); - }); + using choose_offset_t = detail::choose_signed_offset; + if (const auto error = choose_offset_t::is_exceeding_offset_type(num_items)) + { + return error; + } + + using offset_t = typename choose_offset_t::type; + using default_policy_selector = + detail::three_way_partition::policy_selector_from_types, + detail::three_way_partition::per_partition_offset_t>; + + return detail::dispatch_with_env_and_tuning( + env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) { + return detail::three_way_partition::dispatch( + storage, + bytes, + d_in, + d_first_part_out, + d_second_part_out, + d_unselected_out, + d_num_selected_out, + select_first_part_op, + select_second_part_op, + static_cast(num_items), + stream, + policy_selector); + }); } }; diff --git a/cub/test/catch2_test_device_partition_env.cu b/cub/test/catch2_test_device_partition_env.cu index 187c0d27eb0..b2a0c5c339d 100644 --- a/cub/test/catch2_test_device_partition_env.cu +++ b/cub/test/catch2_test_device_partition_env.cu @@ -9,8 +9,10 @@ struct stream_registry_factory_t; #include +#include #include +#include #include #include "catch2_test_device_select_common.cuh" @@ -269,3 +271,119 @@ TEST_CASE("Device partition uses custom stream", "[partition][device]") REQUIRE(cudaSuccess == cudaStreamDestroy(custom_stream)); } + +#if TEST_LAUNCH != 1 + +struct less_than_5_t +{ + __host__ __device__ bool operator()(int val) const + { + return val < 5; + } +}; + +template +struct partition_policy_selector +{ + _CCCL_API constexpr auto operator()(cuda::compute_capability) const -> cub::detail::select::select_if_policy + { + return {static_cast(BlockThreads), + 10, + cub::BLOCK_LOAD_DIRECT, + cub::LOAD_DEFAULT, + cub::BLOCK_SCAN_WARP_SCANS, + cub::detail::delay_constructor_policy{cub::detail::delay_constructor_kind::fixed_delay, 350, 450}}; + } +}; + +template +struct three_way_partition_policy_selector +{ + _CCCL_API constexpr auto operator()(cuda::compute_capability) const + -> cub::detail::three_way_partition::three_way_partition_policy + { + return {static_cast(BlockThreads), + 10, + cub::BLOCK_LOAD_DIRECT, + cub::LOAD_DEFAULT, + cub::BLOCK_SCAN_WARP_SCANS, + cub::detail::delay_constructor_policy{cub::detail::delay_constructor_kind::fixed_delay, 350, 450}}; + } +}; + +using block_sizes = + c2h::type_list, cuda::std::integral_constant>; + +C2H_TEST("DevicePartition::If can be tuned", "[partition][device]", block_sizes) +{ + constexpr unsigned int target_block_size = c2h::get<0, TestType>::value; + auto d_in = c2h::device_vector{1, 2, 3, 4, 5, 6, 7, 8}; + auto d_out = c2h::device_vector(8); + auto d_num_selected = c2h::device_vector(1); + auto d_block_size = c2h::device_vector(1); + + block_size_extracting_op select_op{thrust::raw_pointer_cast(d_block_size.data())}; + + auto env = cuda::execution::tune(partition_policy_selector{}); + + device_partition_if(d_in.begin(), d_out.begin(), d_num_selected.begin(), 8, select_op, env); + REQUIRE(d_num_selected[0] == 4); + REQUIRE(d_block_size[0] == target_block_size); +} + +C2H_TEST("DevicePartition::Flagged can be tuned", "[partition][device]", block_sizes) +{ + constexpr unsigned int target_block_size = c2h::get<0, TestType>::value; + auto d_in = c2h::device_vector{1, 2, 3, 4, 5, 6, 7, 8}; + auto d_out = c2h::device_vector(8); + auto d_num_selected = c2h::device_vector(1); + auto d_block_size = c2h::device_vector(1); + + block_size_extracting_constant_iterator flags_begin(1, thrust::raw_pointer_cast(d_block_size.data())); + + auto env = cuda::execution::tune(partition_policy_selector{}); + + device_partition_flagged(d_in.begin(), flags_begin, d_out.begin(), d_num_selected.begin(), 8, env); + REQUIRE(d_num_selected[0] == 8); + REQUIRE(d_block_size[0] == target_block_size); +} + +struct less_than_7_t +{ + __host__ __device__ bool operator()(int val) const + { + return val < 7; + } +}; + +C2H_TEST("DevicePartition::If three-way can be tuned", "[partition][device]", block_sizes) +{ + constexpr unsigned int target_block_size = c2h::get<0, TestType>::value; + auto d_in = c2h::device_vector{0, 2, 3, 9, 5, 2, 81, 8}; + auto d_small_out = c2h::device_vector(8); + auto d_large_out = c2h::device_vector(8); + auto d_unselected_out = c2h::device_vector(8); + auto d_num_selected = c2h::device_vector(2); + auto d_block_size = c2h::device_vector(1); + + block_size_extracting_op small_selector{thrust::raw_pointer_cast(d_block_size.data())}; + greater_than_t large_selector{50}; + + auto env = cuda::execution::tune(three_way_partition_policy_selector{}); + + device_partition_if( + d_in.begin(), + d_small_out.begin(), + d_large_out.begin(), + d_unselected_out.begin(), + d_num_selected.begin(), + static_cast(d_in.size()), + small_selector, + large_selector, + env); + REQUIRE(d_num_selected[0] == 5); + REQUIRE(d_num_selected[1] == 1); + REQUIRE(d_block_size[0] == target_block_size); +} + +#endif // TEST_LAUNCH != 1 diff --git a/thrust/thrust/system/cuda/detail/partition.h b/thrust/thrust/system/cuda/detail/partition.h index a3bd215277c..30c0eb011c6 100644 --- a/thrust/thrust/system/cuda/detail/partition.h +++ b/thrust/thrust/system/cuda/detail/partition.h @@ -43,97 +43,81 @@ namespace cuda_cub namespace detail { template -struct DispatchPartitionIf +[[nodiscard]] cudaError_t THRUST_RUNTIME_FUNCTION dispatch_partition( + execution_policy& policy, + void* d_temp_storage, + size_t& temp_storage_bytes, + InputIt first, + StencilIt stencil, + OutputIt output, + Predicate predicate, + OffsetT num_items, + std::size_t& num_selected) { - static cudaError_t THRUST_RUNTIME_FUNCTION dispatch( - execution_policy& policy, - void* d_temp_storage, - size_t& temp_storage_bytes, - InputIt first, - StencilIt stencil, - OutputIt output, - Predicate predicate, - OffsetT num_items, - std::size_t& num_selected) + // FIXME(bgruber): the call to the dispatch function should be replaced by a public CUB API, but there is currently no + // exposure of a partition with stencil and predicate (a `DevicePartition:FlaggedIf`) + + using equality_op_t = cub::NullType; + + cudaStream_t stream = cuda_cub::stream(policy); + + std::size_t allocation_sizes[2] = {0, sizeof(OffsetT)}; + void* allocations[2] = {nullptr, nullptr}; + + // Query algorithm memory requirements + cudaError_t status = cub::detail::select::dispatch( + nullptr, + allocation_sizes[0], + first, + stencil, + output, + static_cast(nullptr), + predicate, + equality_op_t{}, + num_items, + stream); + _CUDA_CUB_RET_IF_FAIL(status); + + status = cub::detail::alias_temporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes); + _CUDA_CUB_RET_IF_FAIL(status); + + // Return if we're only querying temporary storage requirements + if (d_temp_storage == nullptr) { - using num_selected_out_it_t = OffsetT*; - using equality_op_t = cub::NullType; - - cudaError_t status = cudaSuccess; - cudaStream_t stream = cuda_cub::stream(policy); - - std::size_t allocation_sizes[2] = {0, sizeof(OffsetT)}; - void* allocations[2] = {nullptr, nullptr}; - - // Query algorithm memory requirements - status = cub::DispatchSelectIf< - InputIt, - StencilIt, - OutputIt, - num_selected_out_it_t, - Predicate, - equality_op_t, - OffsetT, - cub::SelectImpl::Partition>::Dispatch(nullptr, - allocation_sizes[0], - first, - stencil, - output, - static_cast(nullptr), - predicate, - equality_op_t{}, - num_items, - stream); - _CUDA_CUB_RET_IF_FAIL(status); - - status = cub::detail::alias_temporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes); - _CUDA_CUB_RET_IF_FAIL(status); - - // Return if we're only querying temporary storage requirements - if (d_temp_storage == nullptr) - { - return status; - } - - // Return for empty problems - if (num_items == 0) - { - num_selected = 0; - return status; - } - - // Memory allocation for the number of selected output items - OffsetT* d_num_selected_out = thrust::detail::aligned_reinterpret_cast(allocations[1]); - - // Run algorithm - status = cub::DispatchSelectIf< - InputIt, - StencilIt, - OutputIt, - num_selected_out_it_t, - Predicate, - equality_op_t, - OffsetT, - cub::SelectImpl::Partition>::Dispatch(allocations[0], - allocation_sizes[0], - first, - stencil, - output, - d_num_selected_out, - predicate, - equality_op_t{}, - num_items, - stream); - _CUDA_CUB_RET_IF_FAIL(status); - - // Get number of selected items - status = cuda_cub::synchronize(policy); - _CUDA_CUB_RET_IF_FAIL(status); - num_selected = static_cast(get_value(policy, d_num_selected_out)); + return status; + } + // Return for empty problems + if (num_items == 0) + { + num_selected = 0; return status; } -}; + + // Memory allocation for the number of selected output items + OffsetT* d_num_selected_out = thrust::detail::aligned_reinterpret_cast(allocations[1]); + + // Run algorithm + status = cub::detail::select::dispatch( + allocations[0], + allocation_sizes[0], + first, + stencil, + output, + d_num_selected_out, + predicate, + equality_op_t{}, + num_items, + stream); + _CUDA_CUB_RET_IF_FAIL(status); + + // Get number of selected items + status = cuda_cub::synchronize(policy); + _CUDA_CUB_RET_IF_FAIL(status); + num_selected = static_cast(get_value(policy, d_num_selected_out)); + + return status; +} template THRUST_RUNTIME_FUNCTION std::size_t partition( @@ -144,24 +128,15 @@ THRUST_RUNTIME_FUNCTION std::size_t partition( OutputIt output, Predicate predicate) { - using size_type = thrust::detail::it_difference_t; - - size_type num_items = ::cuda::std::distance(first, last); + const auto num_items = ::cuda::std::distance(first, last); std::size_t num_selected{}; cudaError_t status = cudaSuccess; size_t temp_storage_bytes = 0; - // 32-bit offset-type dispatch - using dispatch32_t = DispatchPartitionIf; - - // 64-bit offset-type dispatch - using dispatch64_t = DispatchPartitionIf; - // Query temporary storage requirements - THRUST_INDEX_TYPE_DISPATCH2( + THRUST_INDEX_TYPE_DISPATCH( status, - dispatch32_t::dispatch, - dispatch64_t::dispatch, + dispatch_partition, num_items, (policy, nullptr, temp_storage_bytes, first, stencil, output, predicate, num_items_fixed, num_selected)); cuda_cub::throw_on_error(status, "partition failed on 1st step"); @@ -171,10 +146,9 @@ THRUST_RUNTIME_FUNCTION std::size_t partition( void* temp_storage = static_cast(tmp.data().get()); // Run algorithm - THRUST_INDEX_TYPE_DISPATCH2( + THRUST_INDEX_TYPE_DISPATCH( status, - dispatch32_t::dispatch, - dispatch64_t::dispatch, + dispatch_partition, num_items, (policy, temp_storage, temp_storage_bytes, first, stencil, output, predicate, num_items_fixed, num_selected)); cuda_cub::throw_on_error(status, "partition failed on 2nd step");