From b1b43630d3206e86b17710cdccf61a2177fc4157 Mon Sep 17 00:00:00 2001 From: Cedric AUGONNET Date: Tue, 12 May 2026 11:55:24 +0200 Subject: [PATCH 1/5] [STF] Chain back-to-back stream contexts Destroy pool-owned streams with the stream pool and initialize the CUDA runtime only once so consecutive stream_ctx instances on a caller stream serialize without explicit synchronization. --- .../experimental/__places/stream_pool.cuh | 26 ++- .../__stf/internal/backend_ctx.cuh | 24 +- cudax/test/stf/CMakeLists.txt | 1 + .../stf/local_stf/stream_ctx_lifetime_btb.cu | 205 ++++++++++++++++++ 4 files changed, 249 insertions(+), 7 deletions(-) create mode 100644 cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu diff --git a/cudax/include/cuda/experimental/__places/stream_pool.cuh b/cudax/include/cuda/experimental/__places/stream_pool.cuh index 8c17dfa6321..eb12d93bc08 100644 --- a/cudax/include/cuda/experimental/__places/stream_pool.cuh +++ b/cudax/include/cuda/experimental/__places/stream_pool.cuh @@ -163,11 +163,35 @@ class stream_pool // Construct from a decorated stream, this is used to create a stream pool with a single stream. explicit impl(decorated_stream ds) : payload(1, mv(ds)) + , externally_owned(true) {} + // Release every stream the pool has lazily created. Externally-owned + // single-stream pools wrap user streams and must leave them alone. 
+ ~impl() + { + if (externally_owned) + { + return; + } + + for (auto& ds : payload) + { + if (ds.stream != nullptr) + { + [[maybe_unused]] cudaError_t err = cudaStreamDestroy(ds.stream); + ds.stream = nullptr; + } + } + } + + impl(const impl&) = delete; + impl& operator=(const impl&) = delete; + mutable ::std::mutex mtx; ::std::vector payload; - size_t index = 0; + size_t index = 0; + bool externally_owned = false; }; ::std::shared_ptr pimpl; diff --git a/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh b/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh index e64dc3b098d..5d5b2ad9ce5 100644 --- a/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh @@ -43,6 +43,7 @@ #include #include +#include #include #include #include @@ -122,12 +123,23 @@ protected: , user_provided_handle(bool(async_resources)) , async_resources(async_resources ? mv(async_resources) : async_resources_handle()) { - // Forces init - cudaError_t ret = cudaFree(0); - - // If we are running the task in the context of a CUDA callback, we are - // not allowed to issue any CUDA API call. - EXPECT((ret == cudaSuccess || ret == cudaErrorNotPermitted)); + // Force CUDA runtime init exactly once per process. Previous versions + // called ``cudaFree(0)`` unconditionally on every context construction, + // but ``cudaFree(0)`` is not capture-safe: under + // ``cudaStreamCaptureModeThreadLocal`` / ``Global`` (what Warp's + // ``ScopedCapture`` uses) it is rejected with + // ``cudaErrorStreamCaptureUnsupported`` *and* invalidates the current + // capture, poisoning every subsequent CUDA call on that capture chain. + // Running it once, before any user code might enter a capture region, is + // sufficient: CUDA init is a process-wide state that does not need to be + // re-checked per STF context. 
+ static ::std::once_flag cuda_init_flag; + ::std::call_once(cuda_init_flag, [] { + cudaError_t ret = cudaFree(0); + // If we are running the task in the context of a CUDA callback, we + // are not allowed to issue any CUDA API call. + EXPECT((ret == cudaSuccess || ret == cudaErrorNotPermitted)); + }); // Enable peer memory accesses (if not done already) machine::instance().enable_peer_accesses(); diff --git a/cudax/test/stf/CMakeLists.txt b/cudax/test/stf/CMakeLists.txt index 75746e23527..0fbe7f7f031 100644 --- a/cudax/test/stf/CMakeLists.txt +++ b/cudax/test/stf/CMakeLists.txt @@ -66,6 +66,7 @@ set( local_stf/stackable_nested_while.cu local_stf/stackable_nested.cu local_stf/stackable_node_pool_growth.cu + local_stf/stream_ctx_lifetime_btb.cu local_stf/stackable_read_only.cu local_stf/stackable_threads.cu local_stf/stackable_token.cu diff --git a/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu b/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu new file mode 100644 index 00000000000..c998e9724e1 --- /dev/null +++ b/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu @@ -0,0 +1,205 @@ +//===----------------------------------------------------------------------===// +// +// Part of CUDASTF in CUDA C++ Core Libraries, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +/** + * @file + * @brief Ensure back-to-back stream_ctx instances on a caller stream are ordered. + * + * The test submits two stream_ctx instances back-to-back on the same + * caller-provided stream, without an explicit synchronization between them. + * Each context launches independent token chains on STF pool streams. 
The + * second context writes value 2 into the same buffer written by the first + * context, so observing any value other than 2 means the contexts were not + * chained through the caller stream correctly. + * + * The explicit-sync and shared-handle variants exercise the same shape through + * the two configurations that were already known to be safe. + */ + +#include + +#include + +#include + +using namespace cuda::experimental::stf; + +namespace +{ +constexpr int N = 1 << 14; +constexpr int CHAIN_COUNT = 8; +constexpr int CHAIN_LEN = 40; +constexpr int OUTER = 5; +constexpr long long BUSY_CYCLES = 5'000'000; + +__global__ void slow_set_kernel(int* slice, int n, int value, long long ns) +{ + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= n) + { + return; + } + long long start = clock64(); + while (clock64() - start < ns) + { + // busy wait + } + slice[tid] = value; +} + +void submit_token_chains(stream_ctx& ctx, int* d_arr, int value) +{ + std::vector<logical_data<void_interface>> toks; + toks.reserve(CHAIN_COUNT); + for (int k = 0; k < CHAIN_COUNT; ++k) + { + toks.push_back(ctx.token()); + } + + const int per_chain = N / CHAIN_COUNT; + + for (int step = 0; step < CHAIN_LEN; ++step) + { + for (int k = 0; k < CHAIN_COUNT; ++k) + { + int* slice = d_arr + k * per_chain; + ctx.task(toks[k].rw())->*[=](cudaStream_t ts) { + const int threads = 128; + const int blocks = (per_chain + threads - 1) / threads; + slow_set_kernel<<<blocks, threads, 0, ts>>>(slice, per_chain, value, BUSY_CYCLES); + }; + } + } +} + +bool has_mismatch(const std::vector<int>& values, int expected) +{ + for (int value : values) + { + if (value != expected) + { + return true; + } + } + return false; +} + +void validate_buffer(int* d_arr) +{ + std::vector<int> h_arr(N, 0); + cuda_safe_call(cudaMemcpy(h_arr.data(), d_arr, N * sizeof(int), cudaMemcpyDeviceToHost)); + EXPECT(!has_mismatch(h_arr, 2)); +} + +void run_no_handle_no_sync_once() +{ + cudaStream_t stream{}; + cuda_safe_call(cudaStreamCreate(&stream)); + + int* d_arr = nullptr; + {
+ cuda_safe_call(cudaMalloc(&d_arr, N * sizeof(int))); + cuda_safe_call(cudaMemsetAsync(d_arr, 0, N * sizeof(int), stream)); + } + + { + stream_ctx ctx(stream); + submit_token_chains(ctx, d_arr, 1); + ctx.finalize(); + } + { + stream_ctx ctx(stream); + submit_token_chains(ctx, d_arr, 2); + ctx.finalize(); + } + + cuda_safe_call(cudaStreamSynchronize(stream)); + validate_buffer(d_arr); + + cuda_safe_call(cudaFree(d_arr)); + cuda_safe_call(cudaStreamDestroy(stream)); +} + +void run_no_handle_sync_once() +{ + cudaStream_t stream{}; + cuda_safe_call(cudaStreamCreate(&stream)); + + int* d_arr = nullptr; + { + cuda_safe_call(cudaMalloc(&d_arr, N * sizeof(int))); + cuda_safe_call(cudaMemsetAsync(d_arr, 0, N * sizeof(int), stream)); + } + + { + stream_ctx ctx(stream); + submit_token_chains(ctx, d_arr, 1); + ctx.finalize(); + } + cuda_safe_call(cudaStreamSynchronize(stream)); + { + stream_ctx ctx(stream); + submit_token_chains(ctx, d_arr, 2); + ctx.finalize(); + } + + cuda_safe_call(cudaStreamSynchronize(stream)); + validate_buffer(d_arr); + + cuda_safe_call(cudaFree(d_arr)); + cuda_safe_call(cudaStreamDestroy(stream)); +} + +void run_shared_handle_no_sync_once() +{ + cudaStream_t stream{}; + cuda_safe_call(cudaStreamCreate(&stream)); + + int* d_arr = nullptr; + { + cuda_safe_call(cudaMalloc(&d_arr, N * sizeof(int))); + cuda_safe_call(cudaMemsetAsync(d_arr, 0, N * sizeof(int), stream)); + } + + async_resources_handle handle; + { + stream_ctx ctx(stream, handle); + submit_token_chains(ctx, d_arr, 1); + ctx.finalize(); + } + { + stream_ctx ctx(stream, handle); + submit_token_chains(ctx, d_arr, 2); + ctx.finalize(); + } + + cuda_safe_call(cudaStreamSynchronize(stream)); + validate_buffer(d_arr); + + cuda_safe_call(cudaFree(d_arr)); + cuda_safe_call(cudaStreamDestroy(stream)); +} + +template <typename Test> +void repeat(Test test) +{ + for (int i = 0; i < OUTER; ++i) + { + test(); + } +} +} // namespace + +int main() +{ + repeat(run_no_handle_no_sync_once); + repeat(run_no_handle_sync_once); +
repeat(run_shared_handle_no_sync_once); +} From 2787106d13c1c516cb1d58cee03730f60155191f Mon Sep 17 00:00:00 2001 From: Cedric AUGONNET Date: Tue, 12 May 2026 12:15:45 +0200 Subject: [PATCH 2/5] [STF] Clarify CUDA runtime init comment Describe the runtime initialization invariant without relying on implementation history. --- .../experimental/__stf/internal/backend_ctx.cuh | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh b/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh index 5d5b2ad9ce5..74e47246f1b 100644 --- a/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh @@ -123,16 +123,10 @@ protected: , user_provided_handle(bool(async_resources)) , async_resources(async_resources ? mv(async_resources) : async_resources_handle()) { - // Force CUDA runtime init exactly once per process. Previous versions - // called ``cudaFree(0)`` unconditionally on every context construction, - // but ``cudaFree(0)`` is not capture-safe: under - // ``cudaStreamCaptureModeThreadLocal`` / ``Global`` (what Warp's - // ``ScopedCapture`` uses) it is rejected with - // ``cudaErrorStreamCaptureUnsupported`` *and* invalidates the current - // capture, poisoning every subsequent CUDA call on that capture chain. - // Running it once, before any user code might enter a capture region, is - // sufficient: CUDA init is a process-wide state that does not need to be - // re-checked per STF context. + // Initialize the CUDA runtime before STF starts issuing work. The + // initialization call is process-wide, so doing it once is sufficient and + // avoids making capture-unsafe runtime calls while a user stream is being + // captured. 
static ::std::once_flag cuda_init_flag; ::std::call_once(cuda_init_flag, [] { cudaError_t ret = cudaFree(0); From 9d0db468e0e6b651bb5c5071b01b07bd6602c259 Mon Sep 17 00:00:00 2001 From: Cedric AUGONNET Date: Tue, 12 May 2026 13:36:02 +0200 Subject: [PATCH 3/5] [STF] Avoid runtime init during stream capture Skip CUDA runtime initialization when constructing a stream_ctx from an already-capturing user stream; the stream itself implies CUDA is initialized, and normal contexts still initialize before issuing work. --- .../__stf/internal/backend_ctx.cuh | 13 ++++------ .../experimental/__stf/stream/stream_ctx.cuh | 24 ++++++++++++++----- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh b/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh index 74e47246f1b..b4cc4003f46 100644 --- a/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh @@ -114,7 +114,7 @@ protected: public: friend class backend_ctx_untyped; - impl(async_resources_handle async_resources = async_resources_handle()) + impl(async_resources_handle async_resources = async_resources_handle(), bool initialize_cuda_runtime = true) : auto_scheduler(reserved::scheduler::make(getenv("CUDASTF_SCHEDULE"))) , auto_reorderer(reserved::reorderer::make(getenv("CUDASTF_TASK_ORDER"))) // Record whether the handle was supplied by the caller *before* we @@ -123,17 +123,14 @@ protected: , user_provided_handle(bool(async_resources)) , async_resources(async_resources ? mv(async_resources) : async_resources_handle()) { - // Initialize the CUDA runtime before STF starts issuing work. The - // initialization call is process-wide, so doing it once is sufficient and - // avoids making capture-unsafe runtime calls while a user stream is being - // captured. 
- static ::std::once_flag cuda_init_flag; - ::std::call_once(cuda_init_flag, [] { + if (initialize_cuda_runtime) + { + // Initialize the CUDA runtime before STF starts issuing work. cudaError_t ret = cudaFree(0); // If we are running the task in the context of a CUDA callback, we // are not allowed to issue any CUDA API call. EXPECT((ret == cudaSuccess || ret == cudaErrorNotPermitted)); - }); + } // Enable peer memory accesses (if not done already) machine::instance().enable_peer_accesses(); diff --git a/cudax/include/cuda/experimental/__stf/stream/stream_ctx.cuh b/cudax/include/cuda/experimental/__stf/stream/stream_ctx.cuh index b03cc49b8f7..1a8d2de043f 100644 --- a/cudax/include/cuda/experimental/__stf/stream/stream_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/stream/stream_ctx.cuh @@ -142,8 +142,13 @@ public: : backend_ctx(::std::make_shared<impl>(mv(handle))) {} stream_ctx(cudaStream_t user_stream, async_resources_handle handle = async_resources_handle(nullptr)) - : backend_ctx(::std::make_shared<impl>(mv(handle))) + : backend_ctx(::std::make_shared<impl>(mv(handle), !is_capturing(user_stream))) { + // A valid user stream means the CUDA runtime has already been initialized. + // If that stream is currently capturing, avoid making the backend + // constructor issue any additional runtime initialization calls that could + // be incompatible with the capture. + // When the caller supplies their own ``async_resources_handle``, its // stream pool is very likely already populated with streams that carry // residual work from previous contexts. Folding those streams into an @@ -154,9 +159,7 @@ public: // the supported in-capture configuration.
if (state().user_provided_handle) { - cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone; - cuda_safe_call(cudaStreamIsCapturing(user_stream, &capture_status)); - EXPECT(capture_status == cudaStreamCaptureStatusNone, + EXPECT(!is_capturing(user_stream), "stream_ctx(user_stream, handle): user_stream is in a CUDA graph " "capture but a caller-provided async_resources_handle was " "supplied. The handle's stream pool may carry uncaptured work " @@ -174,6 +177,15 @@ public: ///@} +private: + [[nodiscard]] static bool is_capturing(cudaStream_t user_stream) + { + cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone; + cuda_safe_call(cudaStreamIsCapturing(user_stream, &capture_status)); + return capture_status != cudaStreamCaptureStatusNone; + } + +public: void set_user_stream(cudaStream_t user_stream) { // TODO first introduce the user stream in our pool @@ -593,8 +605,8 @@ private: class impl : public base::impl { public: - impl(async_resources_handle _async_resources = async_resources_handle(nullptr)) - : base::impl(mv(_async_resources)) + impl(async_resources_handle _async_resources = async_resources_handle(nullptr), bool initialize_cuda_runtime = true) + : base::impl(mv(_async_resources), initialize_cuda_runtime) { reserved::backend_ctx_setup_allocators(*this); } From 46d682a85e8be70d8a82e7bf8bf5875701950992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Augonnet?= <158148890+caugonnet@users.noreply.github.com> Date: Thu, 21 May 2026 11:21:27 +0200 Subject: [PATCH 4/5] Better algorithm for a test Co-authored-by: Andrei Alexandrescu --- cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu b/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu index c998e9724e1..b4d39a64500 100644 --- a/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu +++ b/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu @@ -81,14 
+81,7 @@ void submit_token_chains(stream_ctx& ctx, int* d_arr, int value) bool has_mismatch(const std::vector<int>& values, int expected) { - for (int value : values) - { - if (value != expected) - { - return true; - } - } - return false; + return ::std::any_of(values.begin(), values.end(), [=](int x) { return x != expected; }); } void validate_buffer(int* d_arr) From e170423e97b25c8c2e0814afaf9938bacf2f07e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Augonnet?= <158148890+caugonnet@users.noreply.github.com> Date: Thu, 21 May 2026 11:21:45 +0200 Subject: [PATCH 5/5] add missing const Co-authored-by: Andrei Alexandrescu --- cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu b/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu index b4d39a64500..d8bd51f1411 100644 --- a/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu +++ b/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu @@ -46,7 +46,7 @@ __global__ void slow_set_kernel(int* slice, int n, int value, long long ns) { return; } - long long start = clock64(); + const long long start = clock64(); while (clock64() - start < ns) { // busy wait