Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/api/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,13 @@ static void set_file_format(file_options& options, const char* format) { options
static void set_default_dim_value(onnx_options& options, size_t value)
{
options.default_dim_value = value;
options.default_set = true;
}

static void set_default_dyn_dim_value(onnx_options& options, const shape::dynamic_dimension& dd)
{
options.default_dyn_dim_value = dd;
options.default_set = true;
}

static void set_default_loop_iterations(onnx_options& options, int64_t value)
Expand Down
24 changes: 24 additions & 0 deletions src/driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@

#include <fstream>
#include <optional>
#include <set>
#include <sstream>

namespace {
Expand Down Expand Up @@ -408,6 +409,7 @@ struct loader
{
auto v = from_json_string(convert_to_json(default_dyn_dim));
options.default_dyn_dim_value = from_value<migraphx::shape::dynamic_dimension>(v);
options.default_set = true;
}
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = true;
Expand Down Expand Up @@ -593,6 +595,25 @@ struct program_params
return map_load_args;
}

void warn_unset_inputs(const std::unordered_map<std::string, shape>& param_shapes) const
{
std::set<std::string> load_arg_names;
for(auto&& x : load_args_info)
if(not x.empty() and x[0] == '@')
load_arg_names.insert(x.substr(1));
std::set<std::string> unset;
for(const auto& param : param_shapes)
if(not contains(param.first, "#output_") and not contains(fill0, param.first) and
not contains(fill1, param.first) and not contains(load_arg_names, param.first))
unset.insert(param.first);
if(unset.empty())
return;
log::warn() << "Input(s) without explicit values: " << join_strings(std::move(unset), ", ")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not all inputs require explicit values. Most dont require this. I think this warning is just too noisy.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea was to warn when users didn't set values for inputs such as the attention_mask and hopefully reduce the amount of tickets coming to us. Not sure if there's a good list out there for such inputs that require very specific values, but if not I can remove this warning.

<< ". These will be filled with random data and may cause unexpected behavior. "
"Use `--fill0 <name>`, `--fill1 <name>`, or "
"`--load-arg @<name> <file>` if the program fails to run.";
}

auto generate(const program& p,
const target& t,
bool offload,
Expand All @@ -615,6 +636,9 @@ struct program_params
m[s] = fill_argument(static_param_shapes.at(s), 0);
for(auto&& s : fill1)
m[s] = fill_argument(static_param_shapes.at(s), 1);

warn_unset_inputs(param_shapes);

fill_param_map(m, static_param_shapes, t, offload);
auto load_arg_map = program_params::parse_load_args(load_args_info, t, offload);
for(auto&& arg : load_arg_map)
Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/onnx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ struct onnx_options
/// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
/// parser throws)
shape::dynamic_dimension default_dyn_dim_value = {1, 1};
bool default_set = false;
/// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
/// Explicitly specify a symbolic named parameter dimension
Expand Down
1 change: 1 addition & 0 deletions src/onnx/include/migraphx/onnx/onnx_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ struct onnx_parser
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
shape::dynamic_dimension default_dyn_dim_value = {1, 1};
bool default_set = false;
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, shape::dynamic_dimension> dim_params;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
Expand Down
1 change: 1 addition & 0 deletions src/onnx/onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ static program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
parser.default_dyn_dim_value = options.default_dyn_dim_value;
}
parser.default_set = options.default_set;
if(not options.map_input_dims.empty() and not options.map_dyn_input_dims.empty())
{
MIGRAPHX_THROW("PARSE_ONNX_FROM: both map_input_dims and map_dyn_input_dims non-empty, only"
Expand Down
31 changes: 31 additions & 0 deletions src/onnx/onnx_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@
#include <migraphx/op/unknown.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/env.hpp>
#include <migraphx/logger.hpp>
#include <onnx.pb.h>
#include <iomanip>
#include <set>
#include <sstream>

namespace migraphx {
Expand Down Expand Up @@ -259,6 +261,33 @@ void onnx_parser::parse_undefined(module* mod, const std::string& name)
}
}

static void warn_unresolved_dim_params(const onnx_parser& parser, const onnx::GraphProto& graph)
{
if(parser.default_set)
return;
std::set<std::string> unresolved;
for(const auto& input : graph.input())
{
if(contains(parser.map_input_dims, input.name()) or
contains(parser.map_dyn_input_dims, input.name()))
continue;
Comment thread
eddieliao marked this conversation as resolved.
for(const auto& d : input.type().tensor_type().shape().dim())
{
// Skip batch dims and dims that are already set
if(d.has_dim_param() and not contains(parser.dim_params, d.dim_param()) and
to_lower(d.dim_param()).find("batch") == std::string::npos)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The batch dimension is not always named batch so I dont think you can rely on that.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How often would a batch dimension not contain the word "batch"? The only example I can think of right now would be just using "N", which I can add an edge case for if desired.

unresolved.insert(d.dim_param());
}
}
if(unresolved.empty())
return;
log::warn() << "Model has unbound symbolic dimension(s): "
Comment thread
eddieliao marked this conversation as resolved.
<< join_strings(std::move(unresolved), ", ") << ". These default to "
<< parser.default_dyn_dim_value << " and may cause unexpected behavior. "
<< "Try setting `--dim-param @<name> <value>` or `--input-dim @<input> <dims>` "
Comment thread
eddieliao marked this conversation as resolved.
"if program compilation fails.";
}

void onnx_parser::parse_from(std::istream& is, std::string name)
{
auto* mm = prog.get_main_module();
Expand All @@ -275,6 +304,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)

if(model.has_graph())
{
warn_unresolved_dim_params(*this, model.graph());
(void)this->parse_graph(mm, model.graph());
}
}
Expand All @@ -295,6 +325,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size)

if(model.has_graph())
{
warn_unresolved_dim_params(*this, model.graph());
(void)this->parse_graph(mm, model.graph());
}
}
Expand Down
1 change: 1 addition & 0 deletions test/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
add_c_api_test(c_op test_c_op_construct.c ${TEST_ONNX_DIR})
add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR})
add_api_test(parser_warning test_parser_warning.cpp ${TEST_ONNX_DIR})
add_api_test(onnx_op_list test_onnx_op_list.cpp ${TEST_ONNX_DIR})
add_api_test(trace_callback test_trace_callback.cpp ${TEST_ONNX_DIR})
# GPU-based tests
Expand Down
84 changes: 84 additions & 0 deletions test/api/test_parser_warning.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include <migraphx/logger.hpp>
#include <read_onnx.hpp>
#include "test.hpp"

namespace {
struct warning_sink
{
std::vector<std::string> messages;
std::size_t id;

warning_sink()
{
id = migraphx::log::add_sink(
[this](migraphx::log::severity, std::string_view msg, migraphx::source_location) {
messages.emplace_back(msg);
},
migraphx::log::severity::warn);
}

~warning_sink() { migraphx::log::remove_sink(id); }

bool any_unbound_dim_warning() const
{
return std::any_of(messages.begin(), messages.end(), [](const std::string& m) {
return m.find("unbound symbolic dimension") != std::string::npos;
});
}
};
} // namespace

TEST_CASE(set_default_dim_value_suppresses_unbound_dim_warning)
{
warning_sink sink;
migraphx::onnx_options opts;
opts.set_default_dim_value(4);
(void)read_onnx("dim_param_test.onnx", opts);

EXPECT(not sink.any_unbound_dim_warning());
}

TEST_CASE(set_default_dyn_dim_value_suppresses_unbound_dim_warning)
{
warning_sink sink;
migraphx::onnx_options opts;
opts.set_default_dyn_dim_value(migraphx::dynamic_dimension{1, 1});
(void)read_onnx("dim_param_test.onnx", opts);

EXPECT(not sink.any_unbound_dim_warning());
}

TEST_CASE(unbound_dim_param_emits_warning)
{
warning_sink sink;
(void)read_onnx("dim_param_test.onnx");

EXPECT(sink.any_unbound_dim_warning());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
2 changes: 2 additions & 0 deletions tools/api/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,13 @@ static void set_file_format(file_options& options, const char* format) { options
static void set_default_dim_value(onnx_options& options, size_t value)
{
options.default_dim_value = value;
options.default_set = true;
}

static void set_default_dyn_dim_value(onnx_options& options, const shape::dynamic_dimension& dd)
{
options.default_dyn_dim_value = dd;
options.default_set = true;
}

static void set_default_loop_iterations(onnx_options& options, int64_t value)
Expand Down
Loading