diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 079511761ad..ed6ff0b3853 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -559,9 +559,16 @@ struct find_slice_shape_transforms return; new_desc.simplify(); - // Optimizes shape transforms if the slice cant be optimized + // Bail to the safe path if rebasing onto the slice input is unsafe: + // either the sliced axis splits into multiple dst axes, or the rebase + // changed the output shape on a sliced axis if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { - return new_desc.get_dst_axes_from_src(axis).size() != 1; + if(new_desc.get_dst_axes_from_src(axis).size() != 1) + return true; + auto dst_axis = new_desc.get_dst_axes_from_src(axis).front(); + return new_desc.lens()[dst_axis] * slice->get_shape().lens()[axis] != + ins->get_shape().lens()[dst_axis] * + slice->inputs().front()->get_shape().lens()[axis]; })) { auto opt_ops = desc.generate(); diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 3290129ea9c..ae032ce4f35 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -5073,6 +5073,28 @@ TEST_CASE(slice_reshape_multibroadcast_rebase_axis) EXPECT(m1.get_output_shapes() == m2.get_output_shapes()); } +TEST_CASE(slice_multibroadcast_over_sliced_axis) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 4, 6}}); + // TODO: this reshape+slice+reshape chain could be simplified to + // reshape[2, 4, 6] -> slice[axes=0, starts=1, ends=2]. + auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 24}}}), x); + auto sl = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), rsp); + auto rsp2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 4, 6}}}), sl); + auto mb = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 4, 6}}}), + rsp2); + auto bias = m1.add_parameter("bias", {migraphx::shape::float_type, {3, 4, 6}}); + auto sum = m1.add_instruction(migraphx::make_op("add"), mb, bias); + m1.add_return({sum}); + } + auto m2 = m1; + run_pass(m1); + EXPECT(m1.get_output_shapes() == m2.get_output_shapes()); +} + TEST_CASE(broadcast_nop_reduce_mean) { migraphx::module m1;