diff --git a/src/include/migraphx/context.hpp b/src/include/migraphx/context.hpp index 634ed55b7c2..e9c4d5b185e 100644 --- a/src/include/migraphx/context.hpp +++ b/src/include/migraphx/context.hpp @@ -67,6 +67,16 @@ any_ptr get_queue_context(T&) return {}; } +template +void set_queue_context(T&, any_ptr) +{ +} + +template +void restore_queue_context(T&) +{ +} + template void wait_for_context(T&, any_ptr) { @@ -89,6 +99,10 @@ struct MIGRAPHX_EXPORT context // (optional) any_ptr get_queue(); // (optional) + void set_queue(any_ptr queue); + // (optional) + void restore_queue(); + // (optional) void wait_for(any_ptr queue); // (optional) void finish_on(any_ptr queue); @@ -142,6 +156,33 @@ struct context return get_queue_context(private_detail_te_self); } + template + static auto private_detail_te_default_set_queue(char, T&& private_detail_te_self, any_ptr queue) + -> decltype(private_detail_te_self.set_queue(queue)) + { + private_detail_te_self.set_queue(queue); + } + + template + static void + private_detail_te_default_set_queue(float, T&& private_detail_te_self, any_ptr queue) + { + set_queue_context(private_detail_te_self, queue); + } + + template + static auto private_detail_te_default_restore_queue(char, T&& private_detail_te_self) + -> decltype(private_detail_te_self.restore_queue()) + { + private_detail_te_self.restore_queue(); + } + + template + static void private_detail_te_default_restore_queue(float, T&& private_detail_te_self) + { + restore_queue_context(private_detail_te_self); + } + template static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue) -> decltype(private_detail_te_self.wait_for(queue)) @@ -192,6 +233,10 @@ struct context std::declval()), private_detail_te_default_get_queue(char(0), std::declval()), + private_detail_te_default_set_queue( + char(0), std::declval(), std::declval()), + private_detail_te_default_restore_queue(char(0), + std::declval()), private_detail_te_default_wait_for( char(0), std::declval(), std::declval()), private_detail_te_default_finish_on( @@ -289,6 +334,18 @@ struct context return (*this).private_detail_te_get_handle().get_queue(); } + void set_queue(any_ptr queue) + { + assert((*this).private_detail_te_handle_mem_var); + (*this).private_detail_te_get_handle().set_queue(queue); + } + + void restore_queue() + { + assert((*this).private_detail_te_handle_mem_var); + (*this).private_detail_te_get_handle().restore_queue(); + } + void wait_for(any_ptr queue) { assert((*this).private_detail_te_handle_mem_var); @@ -323,6 +380,8 @@ struct context virtual value to_value() const = 0; virtual void from_value(const value& v) = 0; virtual any_ptr get_queue() = 0; + virtual void set_queue(any_ptr queue) = 0; + virtual void restore_queue() = 0; virtual void wait_for(any_ptr queue) = 0; virtual void finish_on(any_ptr queue) = 0; virtual void finish() const = 0; @@ -373,6 +432,18 @@ struct context return private_detail_te_default_get_queue(char(0), private_detail_te_value); } + void set_queue(any_ptr queue) override + { + + private_detail_te_default_set_queue(char(0), private_detail_te_value, queue); + } + + void restore_queue() override + { + + private_detail_te_default_restore_queue(char(0), private_detail_te_value); + } + void wait_for(any_ptr queue) override { diff --git a/src/program.cpp b/src/program.cpp index ecc54c4fd1d..2fd7662b122 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -644,7 +644,7 @@ std::vector program::eval(const parameter_map& params, if(exec_env.async) { assert(contexts.size() == 1); - contexts.front().wait_for(exec_env.queue); + contexts.front().set_queue(exec_env.queue); } // When MIGRAPHX_TRACE_EVAL is set, overwrite any user-provided trace callback with our trace @@ -697,7 +697,7 @@ std::vector program::eval(const parameter_map& params, if(exec_env.async) { assert(contexts.size() == 1); - contexts.front().finish_on(exec_env.queue); + contexts.front().restore_queue(); } return ret; diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index dd7f0c4a7b4..c02016c7394 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -42,6 +42,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -90,6 +91,8 @@ struct hip_device hipStream_t get() { + if(external_stream.has_value()) + return external_stream.value(); if(not enabled(MIGRAPHX_ENABLE_NULL_STREAM{})) { setup(); @@ -144,12 +147,39 @@ struct hip_device } #endif + void set_raw_stream(hipStream_t raw_stream) + { +#if MIGRAPHX_USE_MIOPEN + if(mihandle != nullptr) + miopenSetStream(mihandle.get(), raw_stream); +#endif +#if MIGRAPHX_USE_ROCBLAS + if(rbhandle != nullptr) + rocblas_set_stream(rbhandle.get(), raw_stream); +#endif + } + + bool has_external_stream() const { return external_stream.has_value(); } + + void set_queue(hipStream_t q) + { + external_stream = q; + set_raw_stream(q); + } + + void restore_queue() + { + external_stream.reset(); + set_raw_stream(s.get()); + } + void wait() const { - if(s == nullptr) + hipStream_t cur = external_stream.value_or(s.get()); + if(cur == nullptr) return; setup(); - auto status = hipStreamSynchronize(s.get()); + auto status = hipStreamSynchronize(cur); if(status != hipSuccess) MIGRAPHX_THROW("Failed to wait: " + hip_error(status)); } @@ -173,6 +203,8 @@ struct hip_device private: std::size_t id = 0; shared s = nullptr; + std::optional external_stream{}; + #if MIGRAPHX_USE_MIOPEN shared mihandle = nullptr; #endif @@ -334,25 +366,45 @@ struct context this->current_device = std::make_shared(device, n_streams); } + // Pure event-based synchronization point. Records an event on the + // caller's queue and makes the context's current stream wait on it. void wait_for(any_ptr queue) { auto status = hipEventRecord(begin_event.get(), queue.get()); if(status != hipSuccess) MIGRAPHX_THROW("Failed to record: " + hip_error(status)); - get_stream().wait(begin_event.get()); } + // Symmetric counterpart of wait_for(). Records an event on the context's + // current stream and makes the caller's queue wait on it. void finish_on(any_ptr queue) { get_stream().record(finish_event.get()); - auto status = hipStreamWaitEvent(queue.get(), finish_event.get(), 0); if(status != hipSuccess) MIGRAPHX_THROW("Failed to wait on event: " + hip_error(status)); } - any_ptr get_queue() { return get_stream().get(); } + any_ptr get_queue() + { + auto* s = get_stream().get(); + return s == nullptr ? any_ptr{} : any_ptr{s}; + } + + // Bind a caller-provided queue for subsequent submissions. + // Passing an empty / null any_ptr is equivalent to binding the HIP + // default stream (nullptr), which is a distinct, valid operation from + // restore_queue() -- the two must not be conflated. We bypass the typed + // any_ptr accessor when the pointer is null because a default-constructed + // any_ptr carries no type name and would otherwise throw on get<>(). + void set_queue(any_ptr queue) + { + hipStream_t s = queue.unsafe_get() == nullptr ? nullptr : queue.get(); + get_stream().set_queue(s); + } + + void restore_queue() { get_stream().restore_queue(); } std::pair get_perf_events() const { diff --git a/test/eval_test.cpp b/test/eval_test.cpp index 57b5c085efb..1261f0265b8 100644 --- a/test/eval_test.cpp +++ b/test/eval_test.cpp @@ -28,6 +28,8 @@ #include #include #include +#include +#include #include #include "test.hpp" #include @@ -151,6 +153,24 @@ struct double_invert_target migraphx::context get_context() const { return {}; } }; +// Minimal context that implements the optional set_queue/restore_queue members +// of the context concept +// Each call bumps a counter so a test can verify the dispatch routed through. +struct tracked_ctx +{ + int set_calls = 0; + int restore_calls = 0; + migraphx::any_ptr last_queue{}; + + void finish() const {} + void set_queue(migraphx::any_ptr q) + { + ++set_calls; + last_queue = q; + } + void restore_queue() { ++restore_calls; } +}; + TEST_CASE(literal_test1) { migraphx::program p; @@ -645,4 +665,57 @@ TEST_CASE(eval_trace_with_target_test) EXPECT(not fired_ops.empty()); } +TEST_CASE(async_eval_on_cpu_target_invokes_set_and_restore_queue) +{ + // The async branches of program::eval() call contexts.front().set_queue() + // before generic_eval and contexts.front().restore_queue() after. id_target + // wraps an id_target::context that has no set_queue/restore_queue members, + // so this exercises the type-erased facade end-to-end on a non-GPU build: + // - program::eval async prologue + epilogue (set_queue / restore_queue) + // - context::set_queue / context::restore_queue facade bodies + // - the float-overload (no-member) dispatchers + // - the default set_queue_context / restore_queue_context free-function + // fallbacks + migraphx::program p; + auto* mm = p.get_main_module(); + auto one = mm->add_literal(1); + auto two = mm->add_literal(2); + mm->add_instruction(migraphx::make_op("add"), one, two); + p.compile(id_target{}); + + int dummy = 0; + migraphx::execution_environment exec_env; + exec_env.queue = migraphx::any_ptr{&dummy}; + exec_env.async = true; + + auto result = p.eval({}, exec_env).back(); + EXPECT(result == migraphx::literal{3}); + + // A default-constructed any_ptr is a legal exec_env.queue and must also + // round-trip through set_queue / restore_queue without throwing. + migraphx::execution_environment exec_env_null; + exec_env_null.async = true; + auto result2 = p.eval({}, exec_env_null).back(); + EXPECT(result2 == migraphx::literal{3}); +} + +TEST_CASE(context_facade_dispatches_to_member_set_and_restore_queue) +{ + // Sister test of the one above: tracked_ctx *does* implement set_queue and + // restore_queue. + migraphx::context ctx{tracked_ctx{}}; + + int dummy = 0; + migraphx::any_ptr q{&dummy}; + ctx.set_queue(q); + ctx.set_queue(q); + ctx.restore_queue(); + + auto* held = ctx.any_cast(); + EXPECT(held != nullptr); + EXPECT(held->set_calls == 2); + EXPECT(held->restore_calls == 1); + EXPECT(held->last_queue.unsafe_get() == &dummy); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/gpu/context_serialize.cpp b/test/gpu/context_serialize.cpp index 845f594a8f1..b5f1d71f094 100644 --- a/test/gpu/context_serialize.cpp +++ b/test/gpu/context_serialize.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -51,7 +51,10 @@ TEST_CASE(gpu_context_serialize) TEST_CASE(context_queue) { migraphx::context ctx = migraphx::gpu::context{0, 3}; - EXPECT(ctx.get_queue().get() != nullptr); + // unsafe_get() avoids a type-mismatch throw if MIGRAPHX_ENABLE_NULL_STREAM + // is set: context::get_queue() returns an untyped any_ptr when the bound + // stream is nullptr, so get() would not yield "false". + EXPECT(ctx.get_queue().unsafe_get() != nullptr); } int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/gpu/external_stream.cpp b/test/gpu/external_stream.cpp new file mode 100644 index 00000000000..c5c758d9ad9 --- /dev/null +++ b/test/gpu/external_stream.cpp @@ -0,0 +1,595 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "test.hpp" + +using hip_stream_ptr = MIGRAPHX_MANAGE_PTR(hipStream_t, hipStreamDestroy); + +static hip_stream_ptr create_external_stream() +{ + hipStream_t stream; + auto status = hipStreamCreate(&stream); + if(status != hipSuccess) + MIGRAPHX_THROW("Failed to create stream"); + return hip_stream_ptr{stream}; +} + +static void verify_data(const migraphx::argument& result, const migraphx::shape& s, float expected) +{ + std::vector expected_data(s.elements(), expected); + auto expected_arg = migraphx::argument{s, expected_data.data()}; + EXPECT(result == expected_arg); +} + +TEST_CASE(test_stream_override_get) +{ + migraphx::gpu::context ctx{}; + auto& stream = ctx.get_stream(); + + hipStream_t internal = stream.get(); + EXPECT(internal != nullptr); + // A freshly-constructed context has no external binding. + EXPECT(not stream.has_external_stream()); + + auto ext = create_external_stream(); + stream.set_queue(ext.get()); + + EXPECT(stream.get() == ext.get()); + EXPECT(stream.get() != internal); + EXPECT(stream.has_external_stream()); + + // Under std::optional semantics, set_queue(nullptr) does NOT + // clear the binding: it rebinds to the HIP default stream (which is a + // legal stream value). has_external_stream() therefore stays true, and + // get() now returns nullptr (= the default stream). + stream.set_queue(nullptr); + + EXPECT(stream.get() == nullptr); + EXPECT(stream.get() != internal); + EXPECT(stream.has_external_stream()); +} + +TEST_CASE(test_stream_override_get_queue) +{ + migraphx::gpu::context ctx{}; + auto ext = create_external_stream(); + + hipStream_t original_queue = ctx.get_queue().get(); + EXPECT(original_queue != nullptr); + + ctx.get_stream().set_queue(ext.get()); + EXPECT(ctx.get_queue().get() == ext.get()); + + // Rebinding to nullptr means "the HIP default stream", not "no binding". + // The active queue therefore changes value, but the external binding + // remains in effect. Use unsafe_get() to compare against nullptr: + // context::get_queue() returns a default-constructed (untyped) any_ptr + // when the bound stream is nullptr, so get() would throw a + // type-mismatch exception on the empty name string. + ctx.get_stream().set_queue(nullptr); + + EXPECT(ctx.get_queue().unsafe_get() == nullptr); + EXPECT(ctx.get_queue().unsafe_get() != original_queue); + EXPECT(ctx.get_stream().has_external_stream()); +} + +TEST_CASE(test_context_set_and_restore_queue) +{ + migraphx::gpu::context ctx{}; + auto ext = create_external_stream(); + + migraphx::any_ptr queue(ext.get()); + + hipStream_t before = ctx.get_queue().get(); + EXPECT(before != nullptr); + + // set_queue() (not wait_for) is what redirects the active binding. + ctx.set_queue(queue); + EXPECT(ctx.get_queue().get() == ext.get()); + EXPECT(ctx.get_queue().get() != before); + EXPECT(ctx.get_stream().has_external_stream()); + + // restore_queue() puts the original binding back; wait_for/finish_on + // intentionally do not. + ctx.restore_queue(); + EXPECT(ctx.get_queue().get() == before); + EXPECT(not ctx.get_stream().has_external_stream()); +} + +TEST_CASE(test_context_wait_for_finish_on_do_not_rebind) +{ + migraphx::gpu::context ctx{}; + auto ext = create_external_stream(); + + migraphx::any_ptr queue(ext.get()); + + hipStream_t before = ctx.get_queue().get(); + + // wait_for() is pure event sync; it must NOT mutate the active binding. + ctx.wait_for(queue); + EXPECT(ctx.get_queue().get() == before); + EXPECT(not ctx.get_stream().has_external_stream()); + + ctx.finish_on(queue); + EXPECT(ctx.get_queue().get() == before); + EXPECT(not ctx.get_stream().has_external_stream()); +} + +TEST_CASE(test_context_restore_queue_is_noop_when_unsaved) +{ + migraphx::gpu::context ctx{}; + hipStream_t before = ctx.get_queue().get(); + + // Safe to call without a prior set_queue() -- the async epilogue in + // program::eval relies on this. + ctx.restore_queue(); + EXPECT(ctx.get_queue().get() == before); + EXPECT(not ctx.get_stream().has_external_stream()); +} + +TEST_CASE(test_context_set_queue_with_null_then_restore) +{ + migraphx::gpu::context ctx{}; + auto ext = create_external_stream(); + + // Pre-bind an external stream so the "original" binding under test is + // not the internal stream. + ctx.get_stream().set_queue(ext.get()); + EXPECT(ctx.get_queue().get() == ext.get()); + + // nullptr is a *valid* queue value -- it binds the HIP default stream. + // Under std::optional semantics the binding is still + // active (has_external_stream() == true); the value is just nullptr. + // set_queue(null) must NOT be conflated with restore. Read back via + // unsafe_get() because context::get_queue() returns an untyped any_ptr + // when the bound stream is nullptr. + ctx.set_queue(migraphx::any_ptr{}); + EXPECT(ctx.get_queue().unsafe_get() == nullptr); + EXPECT(ctx.get_queue().unsafe_get() != ext.get()); + EXPECT(ctx.get_stream().has_external_stream()); + + // restore_queue() unconditionally unbinds the external stream and routes + // submissions back to the internal stream -- it does NOT replay the + // previously-bound `ext` value. Callers that need to re-establish a + // prior external binding must call set_queue() themselves. + ctx.restore_queue(); + EXPECT(not ctx.get_stream().has_external_stream()); + EXPECT(ctx.get_queue().get() != ext.get()); +} + +TEST_CASE(test_external_stream_eval_uses_caller_stream) +{ + const unsigned int m = 64; + const unsigned int k = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {m, k}}); + auto y = mm->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {k, m}})); + mm->add_instruction(migraphx::make_op("dot"), x, y); + + p.compile(migraphx::make_target("gpu")); + + migraphx::shape input_shape{migraphx::shape::float_type, {m, k}}; + migraphx::shape output_shape{migraphx::shape::float_type, {m, m}}; + auto input = migraphx::fill_argument(input_shape, 1); + auto ginput = migraphx::gpu::to_gpu(input); + + auto output = migraphx::fill_argument(output_shape, 0); + auto goutput = migraphx::gpu::to_gpu(output); + + auto ext = create_external_stream(); + + auto results = p.eval({{"x", ginput}, {"main:#output_0", goutput}}, {ext.get(), true}); + + EXPECT(not results.empty()); + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + auto host_output = migraphx::gpu::from_gpu(goutput); + EXPECT(host_output != output); +} + +TEST_CASE(test_external_stream_serialized_on_caller_stream) +{ + const unsigned int n = 256; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto ext = create_external_stream(); + + auto results = p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + + EXPECT(not results.empty()); + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 3.0f); +} + +TEST_CASE(test_multiple_async_evals_same_stream) +{ + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto ext = create_external_stream(); + + for(int iter = 0; iter < 5; ++iter) + { + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + } + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 3.0f); +} + +TEST_CASE(test_external_stream_cleared_after_eval) +{ + // When the context had NO external binding prior to async eval, the + // epilogue's restore_queue() must put the context back into the + // "no binding" state -- the caller's transient stream must not leak. + // This requires previous_stream to distinguish "no save" from + // "saved (nothing bound)". + const unsigned int n = 64; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto ext = create_external_stream(); + + migraphx::context& ctx_ref = p.get_context(); + auto* gpu_ctx = ctx_ref.any_cast(); + EXPECT(gpu_ctx != nullptr); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + + hipStream_t internal_stream = gpu_ctx->get_queue().get(); + + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + EXPECT(gpu_ctx->get_queue().get() == internal_stream); +} + +TEST_CASE(test_external_stream_eval_unbinds_prior_binding) +{ + // program::eval's async epilogue calls restore_queue(), which under the + // current contract unconditionally unbinds whatever external stream was + // active at the start of eval -- including a binding the caller had + // installed *before* eval. Callers that want to keep a prior binding + // alive across async evals must re-install it themselves; eval is not + // responsible for preserving it. + const unsigned int n = 64; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto prior = create_external_stream(); + auto ext = create_external_stream(); + + migraphx::context& ctx_ref = p.get_context(); + auto* gpu_ctx = ctx_ref.any_cast(); + EXPECT(gpu_ctx != nullptr); + + gpu_ctx->get_stream().set_queue(prior.get()); + EXPECT(gpu_ctx->get_queue().get() == prior.get()); + + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + + // Async eval's restore_queue() unconditionally drops the external + // binding. Neither the caller's `ext` nor the previously-bound + // `prior` remains in effect. + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + EXPECT(gpu_ctx->get_queue().get() != ext.get()); + EXPECT(gpu_ctx->get_queue().get() != prior.get()); +} + +TEST_CASE(test_wait_for_finish_on_require_typed_queue) +{ + // The earlier "null-stream event fallback" was intentionally removed: + // wait_for()/finish_on() are now pure event-sync primitives that assume + // the caller has provided a typed (hipStream_t) any_ptr. Passing a + // default-constructed any_ptr is a programmer error and surfaces as a + // type-mismatch exception, rather than silently no-op'ing. + // + // The async eval path in program::eval no longer calls wait_for() / + // finish_on() at all -- it relies on set_queue()/restore_queue() for + // queue rebinding -- so this is purely a direct-API contract test. + migraphx::gpu::context ctx{}; + + hipStream_t internal_before = ctx.get_queue().get(); + + bool threw_on_wait_for = false; + try + { + ctx.wait_for(migraphx::any_ptr{}); + } + catch(const migraphx::exception&) + { + threw_on_wait_for = true; + } + EXPECT(threw_on_wait_for); + + bool threw_on_finish_on = false; + try + { + ctx.finish_on(migraphx::any_ptr{}); + } + catch(const migraphx::exception&) + { + threw_on_finish_on = true; + } + EXPECT(threw_on_finish_on); + + // The active binding is untouched on the error path. + EXPECT(not ctx.get_stream().has_external_stream()); + EXPECT(ctx.get_queue().get() == internal_before); +} + +TEST_CASE(test_async_eval_with_null_queue_uses_default_stream) +{ + // A default-constructed any_ptr is treated as "bind the HIP default + // stream (nullptr)" by context::set_queue() -- not as a request for an + // event fallback (there isn't one anymore). The eval must dispatch on + // the default stream and produce correct results once that stream is + // synchronized. + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 5.0f); + std::vector ydata(n, 7.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + auto results = + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {migraphx::any_ptr{}, true}); + + EXPECT(not results.empty()); + + EXPECT(hipDeviceSynchronize() == hipSuccess); + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 12.0f); +} + +TEST_CASE(test_non_async_eval_uses_internal_stream) +{ + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 4.0f); + std::vector ydata(n, 6.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + migraphx::context& ctx_ref = p.get_context(); + auto* gpu_ctx = ctx_ref.any_cast(); + EXPECT(gpu_ctx != nullptr); + + auto results = p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}); + + EXPECT(not results.empty()); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + + p.finish(); + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 10.0f); +} + +TEST_CASE(test_mixed_async_and_sync_evals) +{ + // Interleave async (caller-supplied stream) and sync (default) evals on + // the same program and verify each produces correct results. Async + // eval's restore_queue() unconditionally unbinds the external stream, + // so any caller-installed prior binding must be re-installed between + // async cycles. The sync block exercises the (uncommon but legal) + // sync-eval-with-pre-bound-external-stream path, which depends on + // wait() syncing the bound external stream rather than the internal one. + const unsigned int n = 128; + + migraphx::program p; + auto* mm = p.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {n}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {n}}); + mm->add_instruction(migraphx::make_op("add"), x, y); + + p.compile(migraphx::make_target("gpu")); + + std::vector xdata(n, 1.0f); + std::vector ydata(n, 2.0f); + auto xarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, xdata.data()}; + auto yarg = migraphx::argument{migraphx::shape{migraphx::shape::float_type, {n}}, ydata.data()}; + + auto gx = migraphx::gpu::to_gpu(xarg); + auto gy = migraphx::gpu::to_gpu(yarg); + + migraphx::shape out_shape{migraphx::shape::float_type, {n}}; + auto out = migraphx::fill_argument(out_shape, 0); + auto gout = migraphx::gpu::to_gpu(out); + + migraphx::context& ctx_ref = p.get_context(); + auto* gpu_ctx = ctx_ref.any_cast(); + EXPECT(gpu_ctx != nullptr); + + auto prior = create_external_stream(); + auto ext = create_external_stream(); + + gpu_ctx->get_stream().set_queue(prior.get()); + EXPECT(gpu_ctx->get_queue().get() == prior.get()); + + // Async eval with caller-supplied stream. The async epilogue unbinds + // whatever was bound -- including `prior` -- after eval. + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout}}, {ext.get(), true}); + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + + auto host_result = migraphx::gpu::from_gpu(gout); + verify_data(host_result, out_shape, 3.0f); + + // Re-bind `prior` so the sync eval below runs on it. Sync eval does + // not touch the queue binding, so kernels submit to `prior` and + // p.finish() syncs `prior` (via wait()'s external_stream preference). + gpu_ctx->get_stream().set_queue(prior.get()); + auto gout2 = migraphx::gpu::to_gpu(out); + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout2}}); + EXPECT(gpu_ctx->get_queue().get() == prior.get()); + p.finish(); + + auto host_result2 = migraphx::gpu::from_gpu(gout2); + verify_data(host_result2, out_shape, 3.0f); + + // Another async eval; again the caller's stream is unbound on exit + // and so is `prior` (which was still bound coming in). + auto gout3 = migraphx::gpu::to_gpu(out); + p.eval({{"x", gx}, {"y", gy}, {"main:#output_0", gout3}}, {ext.get(), true}); + EXPECT(hipStreamSynchronize(ext.get()) == hipSuccess); + EXPECT(not gpu_ctx->get_stream().has_external_stream()); + + auto host_result3 = migraphx::gpu::from_gpu(gout3); + verify_data(host_result3, out_shape, 3.0f); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/program_test.cpp b/test/program_test.cpp index 78e83e49db1..e1612b685dd 100644 --- a/test/program_test.cpp +++ b/test/program_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -196,4 +196,20 @@ TEST_CASE(program_copy) } } +TEST_CASE(program_file_version_accessor_matches_serialized_value) +{ + // The new program::get_program_file_version() accessor is also what + // to_value() writes into the "version" field of a serialized program and + // what from_value() compares against on load. Guard against silent drift + // between the accessor and the serialized form. + migraphx::program p; + auto* mm = p.get_main_module(); + mm->add_literal(1); + p.compile(migraphx::make_target("ref")); + + auto v = p.get_program_file_version(); + EXPECT(v > 0); + EXPECT(p.to_value().at("version").to() == v); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/tools/include/context.hpp b/tools/include/context.hpp index ab700b3f8fd..93f16f6ab88 100644 --- a/tools/include/context.hpp +++ b/tools/include/context.hpp @@ -67,6 +67,16 @@ any_ptr get_queue_context(T&) return {}; } +template +void set_queue_context(T&, any_ptr) +{ +} + +template +void restore_queue_context(T&) +{ +} + template void wait_for_context(T&, any_ptr) { @@ -76,13 +86,17 @@ template void finish_on_context(T&, any_ptr){} <% - interface('context', - virtual('to_value', returns = 'value', const = True, default = 'to_value_context'), - virtual('from_value', v = 'const value&', default = 'from_value_context'), - virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'), - virtual('wait_for', queue = 'any_ptr', returns = 'void', default = 'wait_for_context'), - virtual('finish_on', queue = 'any_ptr', returns = 'void', default = 'finish_on_context'), - virtual('finish', returns = 'void', const = True)) %> + interface( + 'context', + virtual('to_value', returns = 'value', const = True, default = 'to_value_context'), + virtual('from_value', v = 'const value&', default = 'from_value_context'), + virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'), + virtual('set_queue', queue = 'any_ptr', returns = 'void', default = 'set_queue_context'), + virtual('restore_queue', returns = 'void', default = 'restore_queue_context'), + virtual('wait_for', queue = 'any_ptr', returns = 'void', default = 'wait_for_context'), + virtual('finish_on', queue = 'any_ptr', returns = 'void', default = 'finish_on_context'), + virtual('finish', returns = 'void', const = True)) +%> inline void migraphx_to_value(value& v, const context& ctx) {