Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
ae9415a
add external stream support to context
TedThemistokleous Apr 9, 2026
a9858f8
[AIGenerated] Add tests for external streams as well as fallback modes
TedThemistokleous Apr 9, 2026
bb7e424
Fix format
TedThemistokleous Apr 9, 2026
8bc9608
Cleanup
TedThemistokleous Apr 10, 2026
297f5c0
Update context to not rebind on the same stream
TedThemistokleous Apr 11, 2026
443b4eb
Remove noop clearStream call
TedThemistokleous Apr 16, 2026
2c770da
remove clear_stream from tests
TedThemistokleous Apr 16, 2026
e468657
Update context and tests
TedThemistokleous Apr 17, 2026
043597d
fix format
TedThemistokleous Apr 17, 2026
aea1da0
Revert wait_for/finish_on semantics
TedThemistokleous May 20, 2026
3cbfae0
Add set_queue/restore_queue() semantics for async streams.
TedThemistokleous May 20, 2026
dfebac1
Update interface to use set_queue /restore_queue semantics for extern…
TedThemistokleous May 21, 2026
837ba7e
Update tests
TedThemistokleous May 21, 2026
229f289
Ensure programs use external stream semantics instead of using intern…
TedThemistokleous May 21, 2026
6f650d5
Format
TedThemistokleous May 21, 2026
93dd5e9
Merge branch 'develop' into use_external_contexts
TedThemistokleous May 21, 2026
e49551b
regen API
TedThemistokleous May 21, 2026
568a34c
Fix tidy
TedThemistokleous May 21, 2026
eedbdfe
revert finish_on/wait_for
TedThemistokleous May 21, 2026
d30fb3d
Ensure async mode doesn't use wait_for/finish_on syncs
TedThemistokleous May 21, 2026
c9a9473
Fux event record
TedThemistokleous May 21, 2026
3e7169a
Use std::optional semantics for external_stream
TedThemistokleous May 21, 2026
42f1001
Fix optional assign with value_or vs value()
TedThemistokleous May 21, 2026
2a6bc0a
Add clear external stream to handle case where we want to unbind
TedThemistokleous May 21, 2026
d3107e9
Update tests
TedThemistokleous May 21, 2026
be90720
Add change for coverage regarding context and program changes.
TedThemistokleous May 22, 2026
6f289e8
Update get_queue to mirror set_queue to preserve typing between nullp…
TedThemistokleous May 22, 2026
098f78f
fix format
TedThemistokleous May 22, 2026
70939a9
Update tests to use unsafe_get isntead of the lazy create to avoid is…
TedThemistokleous May 22, 2026
5ee9d1e
Fix license
TedThemistokleous May 22, 2026
9c4f913
Create set_raw_stream() call
TedThemistokleous May 22, 2026
adead17
Remove previous steram, cleanup comments
TedThemistokleous May 22, 2026
4fa9fde
Cleanup tests and wait() call
TedThemistokleous May 22, 2026
0270739
Remove comment
TedThemistokleous May 22, 2026
816f9e1
Update license
TedThemistokleous May 22, 2026
f0a9151
Changes to context and handle format and outstanding test changes sin…
TedThemistokleous May 22, 2026
3315f14
Update src/targets/gpu/include/migraphx/gpu/context.hpp
pfultz2 May 22, 2026
376b4d1
Merge branch 'develop' into use_external_contexts
kahmed10 May 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions src/include/migraphx/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ any_ptr get_queue_context(T&)
return {};
}

/// Default set_queue hook used for context types that do not provide their own
/// set_queue member. Both arguments are accepted and ignored.
template <class Context>
void set_queue_context(Context& /*ctx*/, any_ptr /*queue*/)
{
    // Intentionally empty: nothing to bind for a context without queue support.
}

/// Default restore_queue hook used for context types that do not provide their
/// own restore_queue member. The argument is accepted and left untouched.
template <class Context>
void restore_queue_context(Context& /*ctx*/)
{
    // Intentionally empty: there is no bound queue to release.
}

template <class T>
void wait_for_context(T&, any_ptr)
{
Expand All @@ -89,6 +99,10 @@ struct MIGRAPHX_EXPORT context
// (optional)
any_ptr get_queue();
// (optional)
void set_queue(any_ptr queue);
// (optional)
void restore_queue();
// (optional)
void wait_for(any_ptr queue);
// (optional)
void finish_on(any_ptr queue);
Expand Down Expand Up @@ -142,6 +156,33 @@ struct context
return get_queue_context(private_detail_te_self);
}

// Member-calling set_queue dispatcher. Call sites pass char(0) as the first
// argument, which matches this `char` overload exactly, so it is chosen ahead
// of the `float` overload whenever it is viable. The trailing decltype makes it
// viable only if `private_detail_te_self.set_queue(queue)` is a well-formed
// expression for the wrapped type.
template <class T>
static auto private_detail_te_default_set_queue(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.set_queue(queue))
{
private_detail_te_self.set_queue(queue);
}

// Fallback set_queue dispatcher. Taking `float` as the tag means a char(0)
// argument needs a conversion, so this overload is only selected when the
// `char` overload is not viable. It forwards to the free function
// set_queue_context(), called unqualified.
template <class T>
static void
private_detail_te_default_set_queue(float, T&& private_detail_te_self, any_ptr queue)
{
set_queue_context(private_detail_te_self, queue);
}

// Member-calling restore_queue dispatcher. Call sites pass char(0), which
// matches this `char` overload exactly, so it is chosen ahead of the `float`
// overload whenever it is viable. The trailing decltype makes it viable only if
// `private_detail_te_self.restore_queue()` is a well-formed expression for the
// wrapped type.
template <class T>
static auto private_detail_te_default_restore_queue(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.restore_queue())
{
private_detail_te_self.restore_queue();
}

// Fallback restore_queue dispatcher. Taking `float` as the tag means a char(0)
// argument needs a conversion, so this overload is only selected when the
// `char` overload is not viable. It forwards to the free function
// restore_queue_context(), called unqualified.
template <class T>
static void private_detail_te_default_restore_queue(float, T&& private_detail_te_self)
{
restore_queue_context(private_detail_te_self);
}

template <class T>
Comment thread
TedThemistokleous marked this conversation as resolved.
static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.wait_for(queue))
Expand Down Expand Up @@ -192,6 +233,10 @@ struct context
std::declval<const value&>()),
private_detail_te_default_get_queue(char(0),
std::declval<PrivateDetailTypeErasedT>()),
private_detail_te_default_set_queue(
char(0), std::declval<PrivateDetailTypeErasedT>(), std::declval<any_ptr>()),
private_detail_te_default_restore_queue(char(0),
std::declval<PrivateDetailTypeErasedT>()),
private_detail_te_default_wait_for(
char(0), std::declval<PrivateDetailTypeErasedT>(), std::declval<any_ptr>()),
private_detail_te_default_finish_on(
Expand Down Expand Up @@ -289,6 +334,18 @@ struct context
return (*this).private_detail_te_get_handle().get_queue();
}

/// Forwards set_queue to the wrapped context through the type-erased handle.
/// Precondition (checked by assert): the wrapper holds a handle.
void set_queue(any_ptr queue)
{
    assert(this->private_detail_te_handle_mem_var);
    this->private_detail_te_get_handle().set_queue(queue);
}

/// Forwards restore_queue to the wrapped context through the type-erased handle.
/// Precondition (checked by assert): the wrapper holds a handle.
void restore_queue()
{
    assert(this->private_detail_te_handle_mem_var);
    this->private_detail_te_get_handle().restore_queue();
}

void wait_for(any_ptr queue)
{
assert((*this).private_detail_te_handle_mem_var);
Expand Down Expand Up @@ -323,6 +380,8 @@ struct context
virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0;
virtual any_ptr get_queue() = 0;
virtual void set_queue(any_ptr queue) = 0;
virtual void restore_queue() = 0;
virtual void wait_for(any_ptr queue) = 0;
virtual void finish_on(any_ptr queue) = 0;
virtual void finish() const = 0;
Expand Down Expand Up @@ -373,6 +432,18 @@ struct context
return private_detail_te_default_get_queue(char(0), private_detail_te_value);
}

// Handle-side set_queue: dispatches on the stored value. The char tag makes
// overload resolution try the member-calling dispatcher first.
void set_queue(any_ptr queue) override
{
    private_detail_te_default_set_queue(char{0}, private_detail_te_value, queue);
}

// Handle-side restore_queue: dispatches on the stored value. The char tag makes
// overload resolution try the member-calling dispatcher first.
void restore_queue() override
{
    private_detail_te_default_restore_queue(char{0}, private_detail_te_value);
}

void wait_for(any_ptr queue) override
{

Expand Down
4 changes: 2 additions & 2 deletions src/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ std::vector<argument> program::eval(const parameter_map& params,
if(exec_env.async)
{
assert(contexts.size() == 1);
contexts.front().wait_for(exec_env.queue);
contexts.front().set_queue(exec_env.queue);
}

// When MIGRAPHX_TRACE_EVAL is set, overwrite any user-provided trace callback with our trace
Expand Down Expand Up @@ -697,7 +697,7 @@ std::vector<argument> program::eval(const parameter_map& params,
if(exec_env.async)
{
assert(contexts.size() == 1);
contexts.front().finish_on(exec_env.queue);
contexts.front().restore_queue();
}

return ret;
Expand Down
62 changes: 57 additions & 5 deletions src/targets/gpu/include/migraphx/gpu/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <migraphx/gpu/hsa_chiplet.hpp>
#include <unordered_map>
#include <memory>
#include <optional>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -90,6 +91,8 @@ struct hip_device

hipStream_t get()
{
if(external_stream.has_value())
return external_stream.value();
if(not enabled(MIGRAPHX_ENABLE_NULL_STREAM{}))
{
setup();
Expand Down Expand Up @@ -144,12 +147,39 @@ struct hip_device
}
#endif

void set_raw_stream(hipStream_t raw_stream)
{
#if MIGRAPHX_USE_MIOPEN
if(mihandle != nullptr)
miopenSetStream(mihandle.get(), raw_stream);
#endif
#if MIGRAPHX_USE_ROCBLAS
if(rbhandle != nullptr)
rocblas_set_stream(rbhandle.get(), raw_stream);
#endif
}

/// Reports whether a caller-provided stream is currently bound. A bound null
/// stream still counts, because the optional then holds a value.
bool has_external_stream() const
{
    return static_cast<bool>(external_stream);
}

/// Binds q as the external stream and points the existing library handles at it.
/// @param q Caller-provided stream; stored as-is, including nullptr.
void set_queue(hipStream_t q)
{
    external_stream.emplace(q);
    set_raw_stream(q);
}

void restore_queue()
{
external_stream.reset();
set_raw_stream(s.get());
}

void wait() const
{
if(s == nullptr)
hipStream_t cur = external_stream.value_or(s.get());
if(cur == nullptr)
return;
setup();
auto status = hipStreamSynchronize(s.get());
auto status = hipStreamSynchronize(cur);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to wait: " + hip_error(status));
}
Expand All @@ -173,6 +203,8 @@ struct hip_device
private:
std::size_t id = 0;
shared<hip_stream_ptr> s = nullptr;
std::optional<hipStream_t> external_stream{};

#if MIGRAPHX_USE_MIOPEN
shared<miopen_handle> mihandle = nullptr;
#endif
Expand Down Expand Up @@ -334,25 +366,45 @@ struct context
this->current_device = std::make_shared<hip_device>(device, n_streams);
}

// Event-based synchronization point: begin_event is recorded on the caller's
// queue, then the context's current stream is made to wait on that event.
// Throws via MIGRAPHX_THROW if recording the event fails.
void wait_for(any_ptr queue)
{
    hipStream_t caller_stream = queue.get<hipStream_t>();
    const auto record_status  = hipEventRecord(begin_event.get(), caller_stream);
    if(record_status != hipSuccess)
        MIGRAPHX_THROW("Failed to record: " + hip_error(record_status));

    get_stream().wait(begin_event.get());
}

// Mirror image of wait_for(): finish_event is recorded on the context's current
// stream first, then the caller's queue is made to wait on that event.
// Throws via MIGRAPHX_THROW if the wait cannot be enqueued.
void finish_on(any_ptr queue)
{
    // Record before touching the caller's queue, matching the original order.
    get_stream().record(finish_event.get());

    hipStream_t caller_stream = queue.get<hipStream_t>();
    const auto wait_status    = hipStreamWaitEvent(caller_stream, finish_event.get(), 0);
    if(wait_status != hipSuccess)
        MIGRAPHX_THROW("Failed to wait on event: " + hip_error(wait_status));
}

any_ptr get_queue() { return get_stream().get(); }
// Returns the current stream wrapped in an any_ptr. A null stream is reported
// as a default-constructed any_ptr rather than a typed null pointer.
any_ptr get_queue()
{
    hipStream_t current = get_stream().get();
    if(current == nullptr)
        return any_ptr{};
    return any_ptr{current};
}

// Binds a caller-provided queue for subsequent submissions.
// An empty / null any_ptr binds the HIP default stream (nullptr). That is a
// valid operation in its own right and is not the same as restore_queue().
// The typed accessor is only used for non-null pointers: a default-constructed
// any_ptr carries no type name, so get<>() on it would throw.
void set_queue(any_ptr queue)
{
    hipStream_t target = nullptr;
    if(queue.unsafe_get() != nullptr)
        target = queue.get<hipStream_t>();
    get_stream().set_queue(target);
}

// Releases any caller-provided queue by delegating to the current stream's
// restore_queue().
void restore_queue()
{
    get_stream().restore_queue();
}

std::pair<hipEvent_t, hipEvent_t> get_perf_events() const
{
Expand Down
73 changes: 73 additions & 0 deletions test/eval_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/execution_environment.hpp>
#include <migraphx/any_ptr.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
Expand Down Expand Up @@ -151,6 +153,24 @@ struct double_invert_target
migraphx::context get_context() const { return {}; }
};

// Minimal context that provides the optional set_queue/restore_queue members of
// the context concept. Every call increments a counter, and set_queue also
// remembers its argument, so a test can verify that dispatch reached the
// member functions.
struct tracked_ctx
{
    int set_calls     = 0;
    int restore_calls = 0;
    migraphx::any_ptr last_queue{};

    void finish() const {}

    void set_queue(migraphx::any_ptr q)
    {
        last_queue = q;
        set_calls += 1;
    }

    void restore_queue() { restore_calls += 1; }
};

TEST_CASE(literal_test1)
{
migraphx::program p;
Expand Down Expand Up @@ -645,4 +665,57 @@ TEST_CASE(eval_trace_with_target_test)
EXPECT(not fired_ops.empty());
}

TEST_CASE(async_eval_on_cpu_target_invokes_set_and_restore_queue)
{
    // In async mode program::eval() calls contexts.front().set_queue() before
    // generic_eval and contexts.front().restore_queue() afterwards. The context
    // wrapped by id_target has neither member, so on a non-GPU build this walks
    // the whole type-erased path:
    //   - the async prologue and epilogue of program::eval
    //   - the context::set_queue / context::restore_queue facade bodies
    //   - the float-tagged (no-member) dispatchers
    //   - the default set_queue_context / restore_queue_context free functions
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();
    auto lhs       = main_mod->add_literal(1);
    auto rhs       = main_mod->add_literal(2);
    main_mod->add_instruction(migraphx::make_op("add"), lhs, rhs);
    prog.compile(id_target{});

    // Case 1: a non-null queue pointer.
    int queue_storage = 0;
    migraphx::execution_environment async_env;
    async_env.queue = migraphx::any_ptr{&queue_storage};
    async_env.async = true;

    auto sum = prog.eval({}, async_env).back();
    EXPECT(sum == migraphx::literal{3});

    // Case 2: a default-constructed any_ptr is also a legal queue and has to
    // pass through set_queue / restore_queue without throwing.
    migraphx::execution_environment null_queue_env;
    null_queue_env.async = true;
    auto sum_null_queue  = prog.eval({}, null_queue_env).back();
    EXPECT(sum_null_queue == migraphx::literal{3});
}

TEST_CASE(context_facade_dispatches_to_member_set_and_restore_queue)
{
    // Counterpart of the async eval test: tracked_ctx does define set_queue and
    // restore_queue, so the facade must reach those members and the counters
    // must reflect every call.
    migraphx::context erased{tracked_ctx{}};

    int queue_storage = 0;
    migraphx::any_ptr queue{&queue_storage};
    erased.set_queue(queue);
    erased.set_queue(queue);
    erased.restore_queue();

    auto* tracked = erased.any_cast<tracked_ctx>();
    EXPECT(tracked != nullptr);
    EXPECT(tracked->set_calls == 2);
    EXPECT(tracked->restore_calls == 1);
    EXPECT(tracked->last_queue.unsafe_get() == &queue_storage);
}

// Test driver entry point: passes the command line to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
7 changes: 5 additions & 2 deletions test/gpu/context_serialize.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -51,7 +51,10 @@ TEST_CASE(gpu_context_serialize)
// Checks that a GPU context wrapped in the type-erased migraphx::context
// exposes a non-null queue.
TEST_CASE(context_queue)
{
migraphx::context ctx = migraphx::gpu::context{0, 3};
// NOTE(review): this typed get<hipStream_t>() check appears to be the
// pre-change line of the diff, shown next to its replacement below -- confirm
// it is removed in the final file.
EXPECT(ctx.get_queue().get<hipStream_t>() != nullptr);
// Compare the raw pointer via unsafe_get() instead of get<hipStream_t>():
// context::get_queue() returns an untyped any_ptr when the bound stream is
// nullptr (e.g. with MIGRAPHX_ENABLE_NULL_STREAM set), and the typed accessor
// would throw on it rather than letting this check evaluate to false.
EXPECT(ctx.get_queue().unsafe_get() != nullptr);
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
Loading
Loading