diff --git a/cudax/include/cuda/experimental/__places/stream_pool.cuh b/cudax/include/cuda/experimental/__places/stream_pool.cuh index 844115e3cb4..db68e79f8ba 100644 --- a/cudax/include/cuda/experimental/__places/stream_pool.cuh +++ b/cudax/include/cuda/experimental/__places/stream_pool.cuh @@ -159,11 +159,37 @@ class stream_pool // Construct from a decorated stream, this is used to create a stream pool with a single stream. explicit impl(decorated_stream ds) : payload(1, mv(ds)) + , externally_owned(true) {} + // Release every stream the pool has lazily created. Externally-owned + // single-stream pools wrap user streams and must leave them alone. + ~impl() noexcept + { + if (externally_owned) + { + return; + } + + for (auto& ds : payload) + { + if (ds.stream != nullptr) + { + // Stream destruction can fail during CUDA runtime teardown; the + // destructor has no useful way to report or recover from that. + (void) cudaStreamDestroy(ds.stream); + ds.stream = nullptr; + } + } + } + + impl(const impl&) = delete; + impl& operator=(const impl&) = delete; + mutable ::std::mutex mtx; ::std::vector payload; - size_t index = 0; + size_t index = 0; + bool externally_owned = false; }; ::std::shared_ptr pimpl; diff --git a/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh b/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh index 124091e143a..132108f4431 100644 --- a/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh @@ -43,6 +43,7 @@ #include #include +#include #include #include #include @@ -113,7 +114,7 @@ protected: public: friend class backend_ctx_untyped; - impl(async_resources_handle async_resources = async_resources_handle()) + impl(async_resources_handle async_resources = async_resources_handle(), bool initialize_cuda_runtime = true) : auto_scheduler(reserved::scheduler::make(getenv("CUDASTF_SCHEDULE"))) , 
auto_reorderer(reserved::reorderer::make(getenv("CUDASTF_TASK_ORDER"))) // Record whether the handle was supplied by the caller *before* we @@ -122,12 +123,14 @@ protected: , user_provided_handle(bool(async_resources)) , async_resources(async_resources ? mv(async_resources) : async_resources_handle()) { - // Forces init - cudaError_t ret = cudaFree(0); - - // If we are running the task in the context of a CUDA callback, we are - // not allowed to issue any CUDA API call. - EXPECT((ret == cudaSuccess || ret == cudaErrorNotPermitted)); + if (initialize_cuda_runtime) + { + // Initialize the CUDA runtime before STF starts issuing work. + cudaError_t ret = cudaFree(0); + // If we are running the task in the context of a CUDA callback, we + // are not allowed to issue any CUDA API call. + EXPECT((ret == cudaSuccess || ret == cudaErrorNotPermitted)); + } // Enable peer memory accesses (if not done already) machine::instance().enable_peer_accesses(); diff --git a/cudax/include/cuda/experimental/__stf/stream/stream_ctx.cuh b/cudax/include/cuda/experimental/__stf/stream/stream_ctx.cuh index 5acd1a9dad0..93d0ed5d4ab 100644 --- a/cudax/include/cuda/experimental/__stf/stream/stream_ctx.cuh +++ b/cudax/include/cuda/experimental/__stf/stream/stream_ctx.cuh @@ -142,8 +142,13 @@ public: : backend_ctx(::std::make_shared(mv(handle))) {} stream_ctx(cudaStream_t user_stream, async_resources_handle handle = async_resources_handle(nullptr)) - : backend_ctx(::std::make_shared(mv(handle))) + : backend_ctx(::std::make_shared(mv(handle), !is_capturing(user_stream))) { + // A valid user stream means the CUDA runtime has already been initialized. + // If that stream is currently capturing, avoid making the backend + // constructor issue any additional runtime initialization calls that could + // be incompatible with the capture. 
+ // When the caller supplies their own ``async_resources_handle``, its // stream pool is very likely already populated with streams that carry // residual work from previous contexts. Folding those streams into an @@ -154,9 +159,7 @@ public: // the supported in-capture configuration. if (state().user_provided_handle) { - cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone; - cuda_safe_call(cudaStreamIsCapturing(user_stream, &capture_status)); - EXPECT(capture_status == cudaStreamCaptureStatusNone, + EXPECT(!is_capturing(user_stream), "stream_ctx(user_stream, handle): user_stream is in a CUDA graph " "capture but a caller-provided async_resources_handle was " "supplied. The handle's stream pool may carry uncaptured work " @@ -174,6 +177,15 @@ public: ///@} +private: + [[nodiscard]] static bool is_capturing(cudaStream_t user_stream) + { + cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone; + cuda_safe_call(cudaStreamIsCapturing(user_stream, &capture_status)); + return capture_status != cudaStreamCaptureStatusNone; + } + +public: void set_user_stream(cudaStream_t user_stream) { // TODO first introduce the user stream in our pool @@ -594,8 +606,8 @@ private: class impl : public base::impl { public: - impl(async_resources_handle _async_resources = async_resources_handle(nullptr)) - : base::impl(mv(_async_resources)) + impl(async_resources_handle _async_resources = async_resources_handle(nullptr), bool initialize_cuda_runtime = true) + : base::impl(mv(_async_resources), initialize_cuda_runtime) { reserved::backend_ctx_setup_allocators(*this); } diff --git a/cudax/test/stf/CMakeLists.txt b/cudax/test/stf/CMakeLists.txt index 187bbb6f1e6..a3d37f1daf1 100644 --- a/cudax/test/stf/CMakeLists.txt +++ b/cudax/test/stf/CMakeLists.txt @@ -68,6 +68,7 @@ set( local_stf/stackable_nested_while.cu local_stf/stackable_nested.cu local_stf/stackable_node_pool_growth.cu + local_stf/stream_ctx_lifetime_btb.cu local_stf/stackable_read_only.cu 
local_stf/stackable_threads.cu local_stf/stackable_token.cu
diff --git a/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu b/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu
new file mode 100644
index 00000000000..6fd4123e158
--- /dev/null
+++ b/cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu
@@ -0,0 +1,224 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of CUDASTF in CUDA C++ Core Libraries,
+// under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES.
+//
+//===----------------------------------------------------------------------===//
+
+/**
+ * @file
+ * @brief Ensure back-to-back stream_ctx instances on a caller stream are ordered.
+ *
+ * The test submits two stream_ctx instances back-to-back on the same
+ * caller-provided stream, without an explicit synchronization between them.
+ * Each context launches independent token chains on STF pool streams. The
+ * second context writes value 2 into the same buffer written by the first
+ * context, so observing any value other than 2 means the contexts were not
+ * chained through the caller stream correctly.
+ *
+ * The explicit-sync and shared-handle variants exercise the same shape through
+ * the two configurations that were already known to be safe.
+ */
+
+#include <cuda/experimental/stf.cuh>
+
+#include <algorithm>
+#include <vector>
+
+using namespace cuda::experimental::stf;
+
+namespace
+{
+// Number of ints in the device buffer written by both contexts.
+constexpr int N = 1 << 14;
+// Number of independent token chains; each chain owns N / CHAIN_COUNT ints.
+constexpr int CHAIN_COUNT = 8;
+// Number of tasks submitted on each chain.
+constexpr int CHAIN_LEN = 40;
+// Number of times repeat() runs each scenario.
+constexpr int OUTER = 5;
+// Busy-wait length, in clock64() ticks, applied by every writing thread.
+constexpr long long BUSY_CYCLES = 5'000'000;
+
+/**
+ * @brief Spin for `cycles` clock64() ticks, then store `value` into `slice[tid]`.
+ *
+ * Indexes with a 1D grid of 1D blocks; threads whose global index is >= `n`
+ * return without spinning or writing. Uses no shared memory.
+ */
+__global__ void slow_set_kernel(int* slice, int n, int value, long long cycles)
+{
+  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid >= n)
+  {
+    return;
+  }
+  const long long start = clock64();
+  while (clock64() - start < cycles)
+  {
+    // busy wait
+  }
+  slice[tid] = value;
+}
+
+/**
+ * @brief Submit CHAIN_COUNT chains of CHAIN_LEN tasks that each write `value`.
+ *
+ * Every chain gets its own token, accessed with rw() by each of its tasks, and
+ * writes its own N / CHAIN_COUNT slice of `d_arr`.
+ */
+void submit_token_chains(stream_ctx& ctx, int* d_arr, int value)
+{
+  std::vector<decltype(ctx.token())> toks;
+  toks.reserve(CHAIN_COUNT);
+  for (int k = 0; k < CHAIN_COUNT; ++k)
+  {
+    toks.push_back(ctx.token());
+  }
+
+  const int per_chain = N / CHAIN_COUNT;
+
+  for (int step = 0; step < CHAIN_LEN; ++step)
+  {
+    for (int k = 0; k < CHAIN_COUNT; ++k)
+    {
+      int* slice = d_arr + k * per_chain;
+      ctx.task(toks[k].rw())->*[=](cudaStream_t ts) {
+        const int threads = 128;
+        const int blocks  = (per_chain + threads - 1) / threads;
+        slow_set_kernel<<<blocks, threads, 0, ts>>>(slice, per_chain, value, BUSY_CYCLES);
+        // A rejected launch is only visible through the runtime's last-error state.
+        cuda_safe_call(cudaGetLastError());
+      };
+    }
+  }
+}
+
+// Return true when at least one element of `values` differs from `expected`.
+bool has_mismatch(const std::vector<int>& values, int expected)
+{
+  return std::any_of(values.begin(), values.end(), [=](int x) {
+    return x != expected;
+  });
+}
+
+// Copy the N-int device buffer to the host and require every element to be 2.
+void validate_buffer(int* d_arr)
+{
+  std::vector<int> h_arr(N, 0);
+  cuda_safe_call(cudaMemcpy(h_arr.data(), d_arr, N * sizeof(int), cudaMemcpyDeviceToHost));
+  EXPECT(!has_mismatch(h_arr, 2));
+}
+
+// Two contexts built without a handle on one stream, with no synchronization between them.
+void run_no_handle_no_sync_once()
+{
+  cudaStream_t stream{};
+  cuda_safe_call(cudaStreamCreate(&stream));
+
+  int* d_arr = nullptr;
+  {
+    cuda_safe_call(cudaMalloc(&d_arr, N * sizeof(int)));
+    cuda_safe_call(cudaMemsetAsync(d_arr, 0, N * sizeof(int), stream));
+  }
+
+  {
+    stream_ctx ctx(stream);
+    submit_token_chains(ctx, d_arr, 1);
+    ctx.finalize();
+  }
+  {
+    stream_ctx ctx(stream);
+    submit_token_chains(ctx, d_arr, 2);
+    ctx.finalize();
+  }
+
+  cuda_safe_call(cudaStreamSynchronize(stream));
+  validate_buffer(d_arr);
+
+  cuda_safe_call(cudaFree(d_arr));
+  cuda_safe_call(cudaStreamDestroy(stream));
+}
+
+// Two contexts built without a handle on one stream, with the stream synchronized between them.
+void run_no_handle_sync_once()
+{
+  cudaStream_t stream{};
+  cuda_safe_call(cudaStreamCreate(&stream));
+
+  int* d_arr = nullptr;
+  {
+    cuda_safe_call(cudaMalloc(&d_arr, N * sizeof(int)));
+    cuda_safe_call(cudaMemsetAsync(d_arr, 0, N * sizeof(int), stream));
+  }
+
+  {
+    stream_ctx ctx(stream);
+    submit_token_chains(ctx, d_arr, 1);
+    ctx.finalize();
+  }
+  cuda_safe_call(cudaStreamSynchronize(stream));
+  {
+    stream_ctx ctx(stream);
+    submit_token_chains(ctx, d_arr, 2);
+    ctx.finalize();
+  }
+
+  cuda_safe_call(cudaStreamSynchronize(stream));
+  validate_buffer(d_arr);
+
+  cuda_safe_call(cudaFree(d_arr));
+  cuda_safe_call(cudaStreamDestroy(stream));
+}
+
+// Two contexts sharing one async_resources_handle on one stream, with no synchronization between them.
+void run_shared_handle_no_sync_once()
+{
+  cudaStream_t stream{};
+  cuda_safe_call(cudaStreamCreate(&stream));
+
+  int* d_arr = nullptr;
+  {
+    cuda_safe_call(cudaMalloc(&d_arr, N * sizeof(int)));
+    cuda_safe_call(cudaMemsetAsync(d_arr, 0, N * sizeof(int), stream));
+  }
+
+  async_resources_handle handle;
+  {
+    stream_ctx ctx(stream, handle);
+    submit_token_chains(ctx, d_arr, 1);
+    ctx.finalize();
+  }
+  {
+    stream_ctx ctx(stream, handle);
+    submit_token_chains(ctx, d_arr, 2);
+    ctx.finalize();
+  }
+
+  cuda_safe_call(cudaStreamSynchronize(stream));
+  validate_buffer(d_arr);
+
+  cuda_safe_call(cudaFree(d_arr));
+  cuda_safe_call(cudaStreamDestroy(stream));
+}
+
+// Invoke `test` OUTER times.
+template <typename Test>
+void repeat(Test test)
+{
+  for (int i = 0; i < OUTER; ++i)
+  {
+    test();
+  }
+}
+} // namespace
+
+int main()
+{
+  repeat(run_no_handle_no_sync_once);
+  repeat(run_no_handle_sync_once);
+  repeat(run_shared_handle_no_sync_once);
+}