From 812c27b0ac315c7934de2037d5d83d4fe893a158 Mon Sep 17 00:00:00 2001 From: Uros Petkovic Date: Tue, 19 May 2026 14:48:22 +0200 Subject: [PATCH 1/5] Fix simplify reshapes error for thm models --- src/simplify_reshapes.cpp | 15 +++++++++++++++ test/simplify_reshapes_test.cpp | 20 ++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 884d387e3f3..68c13d96fbd 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -570,6 +570,21 @@ struct find_slice_shape_transforms return; } + // Bail to the safe path if the rebase changed the output shape on a + // sliced axis (e.g. a broadcast over the sliced axis was absorbed) + if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { + auto dst_axis = new_desc.get_dst_axes_from_src(axis).front(); + return new_desc.lens()[dst_axis] * slice->get_shape().lens()[axis] != + ins->get_shape().lens()[dst_axis] * + slice->inputs().front()->get_shape().lens()[axis]; + })) + { + auto opt_ops = desc.generate(); + auto y = insert_ops(m, ins, opt_ops, slice); + m.replace_instruction(ins, y); + return; + } + // Map slice axes using the rebased descriptor to correctly track // where dimensions end up after rebase reorders them std::vector new_axes; diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index dd022ba28d7..4781321de48 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -5073,4 +5073,24 @@ TEST_CASE(slice_reshape_multibroadcast_rebase_axis) EXPECT(m1.get_output_shapes() == m2.get_output_shapes()); } +TEST_CASE(slice_multibroadcast_over_sliced_axis) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 4, 6}}); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 24}}}), x); + auto sl = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), rsp); + auto rsp2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 6}}}), sl); + auto mb = m1.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {3, 4, 6}}}), rsp2); + auto bias = m1.add_parameter("bias", {migraphx::shape::float_type, {3, 4, 6}}); + auto sum = m1.add_instruction(migraphx::make_op("add"), mb, bias); + m1.add_return({sum}); + } + auto m2 = m1; + run_pass(m1); + EXPECT(m1.get_output_shapes() == m2.get_output_shapes()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From b07ea78fe055eea849dae754aab4b31d11788708 Mon Sep 17 00:00:00 2001 From: Uros Petkovic <127323899+urpetkov-amd@users.noreply.github.com> Date: Tue, 19 May 2026 15:13:02 +0200 Subject: [PATCH 2/5] Update test/simplify_reshapes_test.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/simplify_reshapes_test.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 4781321de48..5246f4a3396 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -5077,9 +5077,9 @@ TEST_CASE(slice_multibroadcast_over_sliced_axis) { migraphx::module m1; { - auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 4, 6}}); - auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 24}}}), x); - auto sl = m1.add_instruction( + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 4, 6}}); + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 24}}}), x); + auto sl = m1.add_instruction( migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), rsp); auto rsp2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 6}}}), sl); auto mb = m1.add_instruction( From 5cd47d3bdd4bef500d78226313a0a6758b814377 Mon Sep 17 00:00:00 2001 From: Uros Petkovic <127323899+urpetkov-amd@users.noreply.github.com> Date: Tue, 19 May 2026 15:13:46 +0200 Subject: [PATCH 3/5] Update test/simplify_reshapes_test.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/simplify_reshapes_test.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 5246f4a3396..3fda093b898 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -5081,11 +5081,11 @@ TEST_CASE(slice_multibroadcast_over_sliced_axis) auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 24}}}), x); auto sl = m1.add_instruction( migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), rsp); - auto rsp2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 6}}}), sl); - auto mb = m1.add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {3, 4, 6}}}), rsp2); - auto bias = m1.add_parameter("bias", {migraphx::shape::float_type, {3, 4, 6}}); - auto sum = m1.add_instruction(migraphx::make_op("add"), mb, bias); + auto rsp2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 6}}}), sl); + auto mb = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4, 6}}}), + rsp2); + auto bias = m1.add_parameter("bias", {migraphx::shape::float_type, {3, 4, 6}}); + auto sum = m1.add_instruction(migraphx::make_op("add"), mb, bias); m1.add_return({sum}); } auto m2 = m1; From bb4202e741607f7b6618cbc5b641b700271d7141 Mon Sep 17 00:00:00 2001 From: Uros Petkovic Date: Tue, 19 May 2026 15:14:25 +0200 Subject: [PATCH 4/5] Simplify solution --- src/simplify_reshapes.cpp | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 68c13d96fbd..dd564ca1efc 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -559,20 +559,12 @@ struct find_slice_shape_transforms return; new_desc.simplify(); - // Optimizes shape transforms if the slice cant be optimized - if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { - return new_desc.get_dst_axes_from_src(axis).size() != 1; - })) - { - auto opt_ops = desc.generate(); - auto y = insert_ops(m, ins, opt_ops, slice); - m.replace_instruction(ins, y); - return; - } - - // Bail to the safe path if the rebase changed the output shape on a - // sliced axis (e.g. a broadcast over the sliced axis was absorbed) + // Bail to the safe path if rebasing onto the slice input is unsafe: + // either the sliced axis splits into multiple dst axes, or the rebase + // changed the output shape on a sliced axis if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { + if(new_desc.get_dst_axes_from_src(axis).size() != 1) + return true; auto dst_axis = new_desc.get_dst_axes_from_src(axis).front(); return new_desc.lens()[dst_axis] * slice->get_shape().lens()[axis] != ins->get_shape().lens()[dst_axis] * From c4fb7e37c70a89e6b1e0e4f9e5f418311b9fc2b5 Mon Sep 17 00:00:00 2001 From: Uros Petkovic Date: Tue, 19 May 2026 16:09:01 +0200 Subject: [PATCH 5/5] Adding TODO comment --- test/simplify_reshapes_test.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 9dada0dbda0..ae032ce4f35 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -5077,7 +5077,9 @@ TEST_CASE(slice_multibroadcast_over_sliced_axis) { migraphx::module m1; { - auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 4, 6}}); + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 4, 6}}); + // TODO: this reshape+slice+reshape chain could be simplified to + // reshape[2, 4, 6] -> slice[axes=0, starts=1, ends=2]. auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 24}}}), x); auto sl = m1.add_instruction( migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), rsp); @@ -5091,6 +5093,8 @@ TEST_CASE(slice_multibroadcast_over_sliced_axis) auto m2 = m1; run_pass(m1); EXPECT(m1.get_output_shapes() == m2.get_output_shapes()); +} + TEST_CASE(broadcast_nop_reduce_mean) { migraphx::module m1;