Skip to content
26 changes: 25 additions & 1 deletion cudax/include/cuda/experimental/__places/stream_pool.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,35 @@ class stream_pool
// Construct from a decorated stream; this is used to create a stream pool
// with a single stream.
//
// The payload is a one-element vector holding `ds`, and the pool is flagged
// as externally owned so that the destructor returns early and never calls
// cudaStreamDestroy on the wrapped stream.
explicit impl(decorated_stream ds)
: payload(1, mv(ds))
, externally_owned(true)
{}

// Release every stream the pool has lazily created. Externally-owned
// single-stream pools wrap user streams and must leave them alone.
//
// Slots whose stream is still nullptr are skipped. Each destroyed slot is
// reset to nullptr so that no dangling handle is left in the payload.
~impl()
{
  if (externally_owned)
  {
    return;
  }

  for (auto& ds : payload)
  {
    if (ds.stream != nullptr)
    {
      // Best effort: the status is read but deliberately not acted upon, so
      // a failing cudaStreamDestroy never escapes from the destructor.
      [[maybe_unused]] cudaError_t err = cudaStreamDestroy(ds.stream);
      ds.stream                        = nullptr;
    }
  }
}

impl(const impl&) = delete;
impl& operator=(const impl&) = delete;

mutable ::std::mutex mtx;
::std::vector<decorated_stream> payload;
size_t index = 0;
size_t index = 0;
bool externally_owned = false;
};

::std::shared_ptr<impl> pimpl;
Expand Down
17 changes: 10 additions & 7 deletions cudax/include/cuda/experimental/__stf/internal/backend_ctx.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

#include <atomic>
#include <fstream>
#include <mutex>
Comment thread
caugonnet marked this conversation as resolved.
#include <sstream>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -113,7 +114,7 @@ protected:
public:
friend class backend_ctx_untyped;

impl(async_resources_handle async_resources = async_resources_handle())
impl(async_resources_handle async_resources = async_resources_handle(), bool initialize_cuda_runtime = true)
: auto_scheduler(reserved::scheduler::make(getenv("CUDASTF_SCHEDULE")))
, auto_reorderer(reserved::reorderer::make(getenv("CUDASTF_TASK_ORDER")))
// Record whether the handle was supplied by the caller *before* we
Expand All @@ -122,12 +123,14 @@ protected:
, user_provided_handle(bool(async_resources))
, async_resources(async_resources ? mv(async_resources) : async_resources_handle())
{
// Forces init
cudaError_t ret = cudaFree(0);

// If we are running the task in the context of a CUDA callback, we are
// not allowed to issue any CUDA API call.
EXPECT((ret == cudaSuccess || ret == cudaErrorNotPermitted));
if (initialize_cuda_runtime)
{
// Initialize the CUDA runtime before STF starts issuing work.
cudaError_t ret = cudaFree(0);
// If we are running the task in the context of a CUDA callback, we
// are not allowed to issue any CUDA API call.
EXPECT((ret == cudaSuccess || ret == cudaErrorNotPermitted));
}

// Enable peer memory accesses (if not done already)
machine::instance().enable_peer_accesses();
Expand Down
24 changes: 18 additions & 6 deletions cudax/include/cuda/experimental/__stf/stream/stream_ctx.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,13 @@ public:
: backend_ctx<stream_ctx>(::std::make_shared<impl>(mv(handle)))
{}
stream_ctx(cudaStream_t user_stream, async_resources_handle handle = async_resources_handle(nullptr))
: backend_ctx<stream_ctx>(::std::make_shared<impl>(mv(handle)))
: backend_ctx<stream_ctx>(::std::make_shared<impl>(mv(handle), !is_capturing(user_stream)))
{
// A valid user stream means the CUDA runtime has already been initialized.
// If that stream is currently capturing, avoid making the backend
// constructor issue any additional runtime initialization calls that could
// be incompatible with the capture.

// When the caller supplies their own ``async_resources_handle``, its
// stream pool is very likely already populated with streams that carry
// residual work from previous contexts. Folding those streams into an
Expand All @@ -154,9 +159,7 @@ public:
// the supported in-capture configuration.
if (state().user_provided_handle)
{
cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone;
cuda_safe_call(cudaStreamIsCapturing(user_stream, &capture_status));
EXPECT(capture_status == cudaStreamCaptureStatusNone,
EXPECT(!is_capturing(user_stream),
"stream_ctx(user_stream, handle): user_stream is in a CUDA graph "
"capture but a caller-provided async_resources_handle was "
"supplied. The handle's stream pool may carry uncaptured work "
Expand All @@ -174,6 +177,15 @@ public:

///@}

private:
// Queries the capture status of `user_stream` through cudaStreamIsCapturing
// (checked by cuda_safe_call) and returns true for any status other than
// cudaStreamCaptureStatusNone.
[[nodiscard]] static bool is_capturing(cudaStream_t user_stream)
{
  cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
  cuda_safe_call(cudaStreamIsCapturing(user_stream, &status));

  if (status == cudaStreamCaptureStatusNone)
  {
    return false;
  }
  return true;
}

public:
void set_user_stream(cudaStream_t user_stream)
{
// TODO first introduce the user stream in our pool
Expand Down Expand Up @@ -594,8 +606,8 @@ private:
class impl : public base::impl
{
public:
// Builds the stream backend state.
//
// _async_resources        handle forwarded to base::impl (defaults to a handle
//                         constructed from nullptr).
// initialize_cuda_runtime forwarded unchanged to base::impl; defaults to true
//                         so existing single-argument callers keep their
//                         behavior.
//
// After the base is constructed, the allocators are installed with
// uncached_stream_allocator through backend_ctx_setup_allocators.
impl(async_resources_handle _async_resources = async_resources_handle(nullptr), bool initialize_cuda_runtime = true)
    : base::impl(mv(_async_resources), initialize_cuda_runtime)
{
  reserved::backend_ctx_setup_allocators<impl, uncached_stream_allocator>(*this);
}
Expand Down
1 change: 1 addition & 0 deletions cudax/test/stf/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ set(
local_stf/stackable_nested_while.cu
local_stf/stackable_nested.cu
local_stf/stackable_node_pool_growth.cu
local_stf/stream_ctx_lifetime_btb.cu
local_stf/stackable_read_only.cu
local_stf/stackable_threads.cu
local_stf/stackable_token.cu
Expand Down
198 changes: 198 additions & 0 deletions cudax/test/stf/local_stf/stream_ctx_lifetime_btb.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
//===----------------------------------------------------------------------===//
//
// Part of CUDASTF in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

/**
* @file
* @brief Ensure back-to-back stream_ctx instances on a caller stream are ordered.
*
* The test submits two stream_ctx instances back-to-back on the same
* caller-provided stream, without an explicit synchronization between them.
* Each context launches independent token chains on STF pool streams. The
* second context writes value 2 into the same buffer written by the first
* context, so observing any value other than 2 means the contexts were not
* chained through the caller stream correctly.
*
* The explicit-sync and shared-handle variants exercise the same shape through
* the two configurations that were already known to be safe.
*/

#include <cuda/experimental/stf.cuh>

#include <algorithm>
#include <vector>

#include <cuda_runtime.h>

using namespace cuda::experimental::stf;

namespace
{
// Number of ints in the device buffer that every test variant allocates.
constexpr int N = 1 << 14;
// Number of independent token chains; the buffer is split into this many
// equally sized slices.
constexpr int CHAIN_COUNT = 8;
// Number of tasks submitted on each chain.
constexpr int CHAIN_LEN = 40;
// Number of times repeat() invokes each test variant.
constexpr int OUTER = 5;
// Busy-wait bound passed to slow_set_kernel, which compares it against
// clock64() deltas.
constexpr long long BUSY_CYCLES = 5'000'000;

// Writes `value` into slice[0..n) after each participating thread has spun
// for a while.
//
// Launch layout: 1D grid of 1D blocks, one thread per element; threads whose
// flat index is >= n exit immediately. No shared memory is used.
//
// `ns` is the spin bound. Despite its name it is compared against clock64()
// differences, so it is expressed in clock64() ticks.
__global__ void slow_set_kernel(int* slice, int n, int value, long long ns)
{
  const int idx = threadIdx.x + blockDim.x * blockIdx.x;
  if (idx < n)
  {
    const long long t0 = clock64();
    // Spin until the elapsed tick count reaches the requested bound.
    for (long long elapsed = clock64() - t0; elapsed < ns; elapsed = clock64() - t0)
    {
    }
    slice[idx] = value;
  }
}

// Submits CHAIN_COUNT independent chains of CHAIN_LEN tasks on `ctx`.
//
// Each chain owns one token and one contiguous slice of N / CHAIN_COUNT ints
// of `d_arr`. Every task takes its chain's token in rw() mode and launches
// slow_set_kernel on the task stream to store `value` in that slice. Tasks
// are submitted step by step, visiting the chains in index order within each
// step.
void submit_token_chains(stream_ctx& ctx, int* d_arr, int value)
{
  const int chunk     = N / CHAIN_COUNT;
  const int block_dim = 128;
  const int grid_dim  = (chunk + block_dim - 1) / block_dim;

  // One token per chain.
  std::vector<logical_data<void_interface>> tokens;
  tokens.reserve(CHAIN_COUNT);
  for (int created = 0; created != CHAIN_COUNT; ++created)
  {
    tokens.push_back(ctx.token());
  }

  for (int round = 0; round < CHAIN_LEN; ++round)
  {
    // Walks through the buffer one slice per chain.
    int* cursor = d_arr;
    for (auto& tok : tokens)
    {
      int* slice = cursor;
      ctx.task(tok.rw())->*[=](cudaStream_t ts) {
        slow_set_kernel<<<grid_dim, block_dim, 0, ts>>>(slice, chunk, value, BUSY_CYCLES);
      };
      cursor += chunk;
    }
  }
}

// Returns true when at least one element of `values` differs from `expected`.
// An empty vector has no mismatching element and therefore yields false.
//
// Relies on ::std::any_of, which is declared in <algorithm>; that header must
// be included directly by this file.
bool has_mismatch(const std::vector<int>& values, int expected)
{
  return ::std::any_of(values.begin(), values.end(), [=](int x) {
    return x != expected;
  });
}

// Copies the N ints at `d_arr` back to the host with a blocking cudaMemcpy
// and expects every one of them to equal 2.
void validate_buffer(int* d_arr)
{
  const auto bytes = N * sizeof(int);

  std::vector<int> h_arr(N);
  cuda_safe_call(cudaMemcpy(h_arr.data(), d_arr, bytes, cudaMemcpyDeviceToHost));

  EXPECT(!has_mismatch(h_arr, 2));
}

// Variant under test: two stream_ctx instances are created back-to-back on
// the same caller stream, without a caller-provided handle and without any
// synchronization between them. The first writes 1, the second writes 2; the
// buffer is validated after a final stream synchronization.
void run_no_handle_no_sync_once()
{
  const auto bytes = N * sizeof(int);

  cudaStream_t stream{};
  cuda_safe_call(cudaStreamCreate(&stream));

  int* d_arr = nullptr;
  cuda_safe_call(cudaMalloc(&d_arr, bytes));
  cuda_safe_call(cudaMemsetAsync(d_arr, 0, bytes, stream));

  // Each invocation owns one short-lived context that is finalized and then
  // destroyed before the lambda returns.
  const auto run_context = [&](int value) {
    stream_ctx ctx(stream);
    submit_token_chains(ctx, d_arr, value);
    ctx.finalize();
  };

  run_context(1);
  run_context(2);

  cuda_safe_call(cudaStreamSynchronize(stream));
  validate_buffer(d_arr);

  cuda_safe_call(cudaFree(d_arr));
  cuda_safe_call(cudaStreamDestroy(stream));
}

// Same shape as the no-sync variant, except that the caller stream is
// explicitly synchronized after the first context is gone and before the
// second one is created.
void run_no_handle_sync_once()
{
  const auto bytes = N * sizeof(int);

  cudaStream_t stream{};
  cuda_safe_call(cudaStreamCreate(&stream));

  int* d_arr = nullptr;
  cuda_safe_call(cudaMalloc(&d_arr, bytes));
  cuda_safe_call(cudaMemsetAsync(d_arr, 0, bytes, stream));

  // Each invocation owns one short-lived context that is finalized and then
  // destroyed before the lambda returns.
  const auto run_context = [&](int value) {
    stream_ctx ctx(stream);
    submit_token_chains(ctx, d_arr, value);
    ctx.finalize();
  };

  run_context(1);
  // Explicit barrier between the two contexts.
  cuda_safe_call(cudaStreamSynchronize(stream));
  run_context(2);

  cuda_safe_call(cudaStreamSynchronize(stream));
  validate_buffer(d_arr);

  cuda_safe_call(cudaFree(d_arr));
  cuda_safe_call(cudaStreamDestroy(stream));
}

// Two back-to-back contexts on the same caller stream, both constructed with
// the same async_resources_handle and with no synchronization in between.
// The first writes 1, the second writes 2; the buffer is validated after a
// final stream synchronization.
void run_shared_handle_no_sync_once()
{
  const auto bytes = N * sizeof(int);

  cudaStream_t stream{};
  cuda_safe_call(cudaStreamCreate(&stream));

  int* d_arr = nullptr;
  cuda_safe_call(cudaMalloc(&d_arr, bytes));
  cuda_safe_call(cudaMemsetAsync(d_arr, 0, bytes, stream));

  // Declared after the buffer setup and shared by both contexts.
  async_resources_handle handle;

  // Each invocation owns one short-lived context that is finalized and then
  // destroyed before the lambda returns.
  const auto run_context = [&](int value) {
    stream_ctx ctx(stream, handle);
    submit_token_chains(ctx, d_arr, value);
    ctx.finalize();
  };

  run_context(1);
  run_context(2);

  cuda_safe_call(cudaStreamSynchronize(stream));
  validate_buffer(d_arr);

  cuda_safe_call(cudaFree(d_arr));
  cuda_safe_call(cudaStreamDestroy(stream));
}

// Invokes `test` with no arguments OUTER times in a row.
template <typename Test>
void repeat(Test test)
{
  for (int remaining = OUTER; remaining > 0; --remaining)
  {
    test();
  }
}
} // namespace

// Runs every variant OUTER times: first the unsynchronized back-to-back
// case, then the explicit-sync and shared-handle variants.
int main()
{
  repeat(run_no_handle_no_sync_once);
  repeat(run_no_handle_sync_once);
  repeat(run_shared_handle_no_sync_once);
  return 0;
}