From 2e33c08d743093a4b76ee1243e0c8b74cbab795a Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Wed, 13 May 2026 23:50:36 +0200 Subject: [PATCH] Refactor extraction of tuning policy selectors --- cub/cub/device/device_adjacent_difference.cuh | 38 +-- cub/cub/device/device_merge.cuh | 17 +- cub/cub/device/device_merge_sort.cuh | 255 ++++++++++-------- .../dispatch/dispatch_adjacent_difference.cuh | 19 +- cub/cub/device/dispatch/dispatch_merge.cuh | 21 +- 5 files changed, 182 insertions(+), 168 deletions(-) diff --git a/cub/cub/device/device_adjacent_difference.cuh b/cub/cub/device/device_adjacent_difference.cuh index 48aaa4f99e7..6df8eb399d9 100644 --- a/cub/cub/device/device_adjacent_difference.cuh +++ b/cub/cub/device/device_adjacent_difference.cuh @@ -635,11 +635,9 @@ struct DeviceAdjacentDifference { _CCCL_NVTX_RANGE_SCOPE("cub::DeviceAdjacentDifference::SubtractLeftCopy"); - using OffsetT = detail::choose_offset_t; - using default_policy_selector = detail::adjacent_difference::policy_selector_from_types; - - return detail::dispatch_with_env_and_tuning( - env, [&](auto policy_selector, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) { + using OffsetT = detail::choose_offset_t; + return detail::dispatch_with_env( + env, [&](auto tuning_env, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) { return detail::adjacent_difference::dispatch( d_temp_storage, temp_storage_bytes, @@ -648,7 +646,7 @@ struct DeviceAdjacentDifference static_cast(num_items), difference_op, stream, - policy_selector); + tuning_env); }); } @@ -721,11 +719,8 @@ struct DeviceAdjacentDifference _CCCL_NVTX_RANGE_SCOPE("cub::DeviceAdjacentDifference::SubtractLeft"); using OffsetT = detail::choose_offset_t; - using default_policy_selector = - detail::adjacent_difference::policy_selector_from_types; - - return detail::dispatch_with_env_and_tuning( - env, [&](auto policy_selector, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) { + return detail::dispatch_with_env( + env, [&](auto tuning_env, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) { return detail::adjacent_difference::dispatch( d_temp_storage, temp_storage_bytes, @@ -734,7 +729,7 @@ struct DeviceAdjacentDifference static_cast(num_items), difference_op, stream, - policy_selector); + tuning_env); }); } @@ -817,11 +812,9 @@ struct DeviceAdjacentDifference { _CCCL_NVTX_RANGE_SCOPE("cub::DeviceAdjacentDifference::SubtractRightCopy"); - using OffsetT = detail::choose_offset_t; - using default_policy_selector = detail::adjacent_difference::policy_selector_from_types; - - return detail::dispatch_with_env_and_tuning( - env, [&](auto policy_selector, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) { + using OffsetT = detail::choose_offset_t; + return detail::dispatch_with_env( + env, [&](auto tuning_env, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) { return detail::adjacent_difference::dispatch( d_temp_storage, temp_storage_bytes, @@ -830,7 +823,7 @@ struct DeviceAdjacentDifference static_cast(num_items), difference_op, stream, - policy_selector); + tuning_env); }); } @@ -903,11 +896,8 @@ struct DeviceAdjacentDifference _CCCL_NVTX_RANGE_SCOPE("cub::DeviceAdjacentDifference::SubtractRight"); using OffsetT = detail::choose_offset_t; - using default_policy_selector = - detail::adjacent_difference::policy_selector_from_types; - - return detail::dispatch_with_env_and_tuning( - env, [&](auto policy_selector, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) { + return detail::dispatch_with_env( + env, [&](auto tuning_env, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) { return detail::adjacent_difference::dispatch( d_temp_storage, temp_storage_bytes, @@ -916,7 +906,7 @@ struct DeviceAdjacentDifference static_cast(num_items), difference_op, stream, - policy_selector); + tuning_env); }); } }; diff --git a/cub/cub/device/device_merge.cuh b/cub/cub/device/device_merge.cuh index cab154900d0..62ac40e182c 100644 --- a/cub/cub/device/device_merge.cuh +++ b/cub/cub/device/device_merge.cuh @@ -192,10 +192,8 @@ struct DeviceMerge { _CCCL_NVTX_RANGE_SCOPE("cub::DeviceMerge::MergeKeys"); - using default_policy_selector = - detail::merge::policy_selector_from_types, NullType, int64_t>; - return detail::dispatch_with_env_and_tuning( - env, [&](auto policy_selector, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) { + return detail::dispatch_with_env( + env, [&](auto tuning_env, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) { return detail::merge::dispatch( d_temp_storage, temp_storage_bytes, @@ -209,7 +207,7 @@ struct DeviceMerge static_cast(nullptr), compare_op, stream, - policy_selector); + tuning_env); }); } @@ -416,10 +414,9 @@ struct DeviceMerge EnvT env = {}) { _CCCL_NVTX_RANGE_SCOPE("cub::DeviceMerge::MergePairs"); - using default_policy_selector = detail::merge:: - policy_selector_from_types, detail::it_value_t, int64_t>; - return detail::dispatch_with_env_and_tuning( - env, [&](auto policy_selector, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) { + + return detail::dispatch_with_env( + env, [&](auto tuning_env, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) { return detail::merge::dispatch( d_temp_storage, temp_storage_bytes, @@ -433,7 +430,7 @@ struct DeviceMerge values_out, compare_op, stream, - policy_selector); + tuning_env); }); } }; diff --git a/cub/cub/device/device_merge_sort.cuh b/cub/cub/device/device_merge_sort.cuh index a735eb1c27d..c38efa405d3 100644 --- a/cub/cub/device/device_merge_sort.cuh +++ b/cub/cub/device/device_merge_sort.cuh @@ -97,6 +97,44 @@ private: return "cub::DeviceMergeSort"; } + // TODO(bgruber): I would ideally like to have the logic of extracting the policy selector from the tuning environment + // inside the dispatch function, but this will not work with CCCL.C, which needs to pass a stateful policy selector. + // Refactor this once we have a host code JIT compiler. + template > + CUB_RUNTIME_FUNCTION static auto select_tuning_and_dispatch( + void* d_temp_storage, + size_t& temp_storage_bytes, + KeyInputIteratorT d_input_keys, + ValueInputIteratorT d_input_values, + KeyIteratorT d_output_keys, + ValueIteratorT d_output_values, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream, + TuningEnvT = {}) + { + using default_policy_selector_t = detail::merge_sort::policy_selector_from_types; + using policy_selector_t = ::cuda::std::execution:: + __query_result_or_t; + return detail::merge_sort::dispatch( + d_temp_storage, + temp_storage_bytes, + d_input_keys, + d_input_values, + d_output_keys, + d_output_values, + num_items, + compare_op, + stream, + policy_selector_t{}); + } + // Internal version without NVTX range template CUB_RUNTIME_FUNCTION static cudaError_t SortPairsNoNVTX( @@ -299,23 +337,20 @@ public: { _CCCL_NVTX_RANGE_SCOPE(GetName()); - using ChooseOffsetT = detail::choose_offset_t; - using default_policy_selector = detail::merge_sort::policy_selector_from_types; - - return detail::dispatch_with_env_and_tuning( - env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) { - return detail::merge_sort::dispatch( - storage, - bytes, - d_keys, - d_values, - d_keys, - d_values, - static_cast(num_items), - compare_op, - stream, - policy_selector); - }); + using ChooseOffsetT = detail::choose_offset_t; + return detail::dispatch_with_env(env, [&](auto tuning_env, void* storage, size_t& bytes, auto stream) { + return select_tuning_and_dispatch( + storage, + bytes, + d_keys, + d_values, + d_keys, + d_values, + static_cast(num_items), + compare_op, + stream, + tuning_env); + }); } /** @@ -553,23 +588,20 @@ public: { _CCCL_NVTX_RANGE_SCOPE(GetName()); - using ChooseOffsetT = detail::choose_offset_t; - using default_policy_selector = detail::merge_sort::policy_selector_from_types; - - return detail::dispatch_with_env_and_tuning( - env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) { - return detail::merge_sort::dispatch( - storage, - bytes, - d_input_keys, - d_input_values, - d_output_keys, - d_output_values, - static_cast(num_items), - compare_op, - stream, - policy_selector); - }); + using ChooseOffsetT = detail::choose_offset_t; + return detail::dispatch_with_env(env, [&](auto tuning_env, void* storage, size_t& bytes, auto stream) { + return select_tuning_and_dispatch( + storage, + bytes, + d_input_keys, + d_input_values, + d_output_keys, + d_output_values, + static_cast(num_items), + compare_op, + stream, + tuning_env); + }); } private: @@ -754,23 +786,20 @@ public: { _CCCL_NVTX_RANGE_SCOPE(GetName()); - using ChooseOffsetT = detail::choose_offset_t; - using default_policy_selector = detail::merge_sort::policy_selector_from_types; - - return detail::dispatch_with_env_and_tuning( - env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) { - return detail::merge_sort::dispatch( - storage, - bytes, - d_keys, - static_cast(nullptr), - d_keys, - static_cast(nullptr), - static_cast(num_items), - compare_op, - stream, - policy_selector); - }); + using ChooseOffsetT = detail::choose_offset_t; + return detail::dispatch_with_env(env, [&](auto tuning_env, void* storage, size_t& bytes, auto stream) { + return select_tuning_and_dispatch( + storage, + bytes, + d_keys, + static_cast(nullptr), + d_keys, + static_cast(nullptr), + static_cast(num_items), + compare_op, + stream, + tuning_env); + }); } private: @@ -984,23 +1013,20 @@ public: { _CCCL_NVTX_RANGE_SCOPE(GetName()); - using ChooseOffsetT = detail::choose_offset_t; - using default_policy_selector = detail::merge_sort::policy_selector_from_types; - - return detail::dispatch_with_env_and_tuning( - env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) { - return detail::merge_sort::dispatch( - storage, - bytes, - d_input_keys, - static_cast(nullptr), - d_output_keys, - static_cast(nullptr), - static_cast(num_items), - compare_op, - stream, - policy_selector); - }); + using ChooseOffsetT = detail::choose_offset_t; + return detail::dispatch_with_env(env, [&](auto tuning_env, void* storage, size_t& bytes, auto stream) { + return select_tuning_and_dispatch( + storage, + bytes, + d_input_keys, + static_cast(nullptr), + d_output_keys, + static_cast(nullptr), + static_cast(num_items), + compare_op, + stream, + tuning_env); + }); } /** @@ -1181,23 +1207,20 @@ public: { _CCCL_NVTX_RANGE_SCOPE(GetName()); - using ChooseOffsetT = detail::choose_offset_t; - using default_policy_selector = detail::merge_sort::policy_selector_from_types; - - return detail::dispatch_with_env_and_tuning( - env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) { - return detail::merge_sort::dispatch( - storage, - bytes, - d_keys, - d_values, - d_keys, - d_values, - static_cast(num_items), - compare_op, - stream, - policy_selector); - }); + using ChooseOffsetT = detail::choose_offset_t; + return detail::dispatch_with_env(env, [&](auto tuning_env, void* storage, size_t& bytes, auto stream) { + return select_tuning_and_dispatch( + storage, + bytes, + d_keys, + d_values, + d_keys, + d_values, + static_cast(num_items), + compare_op, + stream, + tuning_env); + }); } /** @@ -1359,23 +1382,20 @@ public: { _CCCL_NVTX_RANGE_SCOPE(GetName()); - using ChooseOffsetT = detail::choose_offset_t; - using default_policy_selector = detail::merge_sort::policy_selector_from_types; - - return detail::dispatch_with_env_and_tuning( - env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) { - return detail::merge_sort::dispatch( - storage, - bytes, - d_keys, - static_cast(nullptr), - d_keys, - static_cast(nullptr), - static_cast(num_items), - compare_op, - stream, - policy_selector); - }); + using ChooseOffsetT = detail::choose_offset_t; + return detail::dispatch_with_env(env, [&](auto tuning_env, void* storage, size_t& bytes, auto stream) { + return select_tuning_and_dispatch( + storage, + bytes, + d_keys, + static_cast(nullptr), + d_keys, + static_cast(nullptr), + static_cast(num_items), + compare_op, + stream, + tuning_env); + }); } /** @@ -1562,23 +1582,20 @@ public: { _CCCL_NVTX_RANGE_SCOPE(GetName()); - using ChooseOffsetT = detail::choose_offset_t; - using default_policy_selector = detail::merge_sort::policy_selector_from_types; - - return detail::dispatch_with_env_and_tuning( - env, [&](auto policy_selector, void* storage, size_t& bytes, auto stream) { - return detail::merge_sort::dispatch( - storage, - bytes, - d_input_keys, - static_cast(nullptr), - d_output_keys, - static_cast(nullptr), - static_cast(num_items), - compare_op, - stream, - policy_selector); - }); + using ChooseOffsetT = detail::choose_offset_t; + return detail::dispatch_with_env(env, [&](auto tuning_env, void* storage, size_t& bytes, auto stream) { + return select_tuning_and_dispatch( + storage, + bytes, + d_input_keys, + static_cast(nullptr), + d_output_keys, + static_cast(nullptr), + static_cast(num_items), + compare_op, + stream, + tuning_env); + }); } }; diff --git a/cub/cub/device/dispatch/dispatch_adjacent_difference.cuh b/cub/cub/device/dispatch/dispatch_adjacent_difference.cuh index 808a696d8f6..ca0d3fc584b 100644 --- a/cub/cub/device/dispatch/dispatch_adjacent_difference.cuh +++ b/cub/cub/device/dispatch/dispatch_adjacent_difference.cuh @@ -28,6 +28,7 @@ #include #include +#include #include #include #include @@ -320,11 +321,8 @@ template , + typename TuningEnvT = ::cuda::std::execution::env<>, typename KernelLauncherFactory = CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER_FACTORY> -#if _CCCL_HAS_CONCEPTS() - requires adjacent_difference_policy_selector -#endif // _CCCL_HAS_CONCEPTS() CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE auto dispatch( void* d_temp_storage, size_t& temp_storage_bytes, @@ -333,18 +331,25 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE auto dispatch( OffsetT num_items, DifferenceOpT difference_op, cudaStream_t stream, - PolicySelector policy_selector = {}, + TuningEnvT = {}, KernelLauncherFactory launcher_factory = {}) { using InputT = detail::it_value_t; + using default_policy_selector_t = policy_selector_from_types; + using policy_selector_t = + ::cuda::std::execution::__query_result_or_t; +#if _CCCL_HAS_CONCEPTS() + static_assert(adjacent_difference_policy_selector); +#endif // _CCCL_HAS_CONCEPTS() + ::cuda::compute_capability cc{}; if (const auto error = CubDebug(launcher_factory.PtxComputeCap(cc))) { return error; } - const adjacent_difference_policy active_policy = policy_selector(cc); + const adjacent_difference_policy active_policy = policy_selector_t{}(cc); #if _CCCL_HOSTED() && defined(CUB_DEBUG_LOG) NV_IF_TARGET(NV_IS_HOST, ({ ::std::stringstream ss; @@ -429,7 +434,7 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE auto dispatch( if (const auto error = CubDebug( THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(num_tiles, active_policy.threads_per_block, 0, stream) - .doit(DeviceAdjacentDifferenceDifferenceKernel < PolicySelector, + .doit(DeviceAdjacentDifferenceDifferenceKernel < policy_selector_t, InputIteratorT, OutputIteratorT, DifferenceOpT, diff --git a/cub/cub/device/dispatch/dispatch_merge.cuh b/cub/cub/device/dispatch/dispatch_merge.cuh index 18f0b310105..421884fe976 100644 --- a/cub/cub/device/dispatch/dispatch_merge.cuh +++ b/cub/cub/device/dispatch/dispatch_merge.cuh @@ -24,6 +24,7 @@ #include #include +#include #include CUB_NAMESPACE_BEGIN @@ -189,11 +190,8 @@ template , it_value_t, Offset>, + typename TuningEnvT = ::cuda::std::execution::env<>, typename KernelLauncherFactory = CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER_FACTORY> -#if _CCCL_HAS_CONCEPTS() - requires merge_policy_selector -#endif // _CCCL_HAS_CONCEPTS() CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( void* d_temp_storage, size_t& temp_storage_bytes, @@ -207,16 +205,23 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( ValueIt3 d_values_out, CompareOp compare_op, cudaStream_t stream, - PolicySelector policy_selector = {}, + TuningEnvT = {}, KernelLauncherFactory launcher_factory = {}) { + using default_policy_selector_t = policy_selector_from_types, it_value_t, Offset>; + using policy_selector_t = + ::cuda::std::execution::__query_result_or_t; +#if _CCCL_HAS_CONCEPTS() + static_assert(merge_policy_selector); +#endif // _CCCL_HAS_CONCEPTS() + ::cuda::compute_capability cc{}; if (const auto error = CubDebug(launcher_factory.PtxComputeCap(cc))) { return error; } - return dispatch_compute_cap(policy_selector, cc, [&](auto policy_getter) { + return dispatch_compute_cap(policy_selector_t{}, cc, [&](auto policy_getter) { #if _CCCL_HOSTED() && defined(CUB_DEBUG_LOG) NV_IF_TARGET(NV_IS_HOST, ({ std::stringstream ss; @@ -270,7 +275,7 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron( partition_grid_size, threads_per_partition_block, 0, stream) .doit(device_partition_merge_path_kernel< - PolicySelector, + policy_selector_t, KeyIt1, ValueIt1, KeyIt2, @@ -301,7 +306,7 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron( static_cast(num_tiles), static_cast(AgentT::threads_per_block), 0, stream) .doit( - device_merge_kernel, + device_merge_kernel, d_keys1, d_values1, num_items1,