From a51abb0c7b2931eb7c35924f35cbfcf7601c431f Mon Sep 17 00:00:00 2001 From: Ahsan Saghir Date: Thu, 7 May 2026 19:15:37 +0000 Subject: [PATCH] Fix attention for non-standard literal --- src/fuse_attention.cpp | 17 +++-- test/fuse_attention.cpp | 156 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 5 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index b0cc4d7cc05..92c2ef0f21a 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -170,11 +170,18 @@ struct find_attention auto expand = fix([&](auto self, auto ins) { for(auto input : ins->inputs()) { - if(not contains(attn_inss, input) and input->can_eval()) - { - attn_inss.insert(input); - self(input); - } + if(contains(attn_inss, input) or not input->can_eval()) + continue; + // A captured @literal is lowered as `migraphx.literal`, whose + // dense_elements_attr requires standard (row-major) strides. + // Leave non-standard literals (e.g. constant-folded transposed + // bias) outside the group so they enter as a parameter + // instead, where mlir_contiguous + adjust_param_shapes can + // normalise the layout. + if(input->name() == "@literal" and not input->get_shape().standard()) + continue; + attn_inss.insert(input); + self(input); } }); auto starts = attn_inss; diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index 64cb4ca2d06..c2d49f86a98 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -2229,6 +2229,162 @@ TEST_CASE(ceil_mul_of_function) EXPECT(migraphx::ceil_mul_of(2049, 32) == 2080); // 2049 -> 2080 (padding = 31) } +// A bias literal that has been constant-folded into a non-standard (transposed) +// layout must NOT be inlined as `@literal` inside the attention group, +// otherwise it lowers to a `migraphx.literal` op with non-standard strides +// which rocMLIR rejects. It should instead remain in the main module and +// enter the group as a regular input (parameter). 
+TEST_CASE(gemm_softmax_gemm_with_transposed_literal_bias) +{ + migraphx::shape s1{migraphx::shape::half_type, {1, 4, 8, 8}}; + // Non-standard but packed strides: the same byte layout you'd get from + // constant-folding `transpose([0,3,1,2], @literal{1,8,8,4})` into a + // single literal viewed as {1, 4, 8, 8}. + migraphx::shape s_bias{migraphx::shape::half_type, {1, 4, 8, 8}, {4, 1, 32, 4}}; + auto s1_elements = s1.elements(); + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("1", s1); + auto b = mm->add_parameter("2", s1); + auto b1 = mm->add_parameter("3", s1); + std::vector bias_vals(s1_elements, 0.5f); + auto bias = mm->add_literal(migraphx::literal{s_bias, bias_vals}); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), + b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = mm->add_instruction(migraphx::make_op("add"), gemm1, bias); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), add); + rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), + rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), add, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), + rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + + // Sanity: the literal really is non-standard (otherwise this test + // wouldn't exercise the fix). 
+ EXPECT(not bias->get_shape().standard()); + EXPECT(bias->get_shape().packed()); + } + run_pass(p1, {.attn_enabled = true}); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("1", s1); + auto b = mm->add_parameter("2", s1); + auto b1 = mm->add_parameter("3", s1); + std::vector bias_vals(s1_elements, 0.5f); + // Bias literal stays in the main module and enters the attention + // group as an outer input rather than an inlined @literal. + auto bias = mm->add_literal(migraphx::literal{s_bias, bias_vals}); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), + b1); + auto group = add_group( + p2, "attn0", "attention", {a, b, bias, b1}, [=](auto* gm, const auto& inputs) { + auto gemm1 = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = gm->add_instruction(migraphx::make_op("add"), gemm1, inputs[2]); + auto rmax = + gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), add); + rmax = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rmax); + auto sub = gm->add_instruction(migraphx::make_op("sub"), add, rmax); + auto exp = gm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = + gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + rsum = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsum); + auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = gm->add_instruction(migraphx::make_op("dot"), div, inputs[3]); + return std::vector{gemm2}; + }); + mm->add_return({group}); + } + EXPECT(p1.sort() == p2.sort()); +} + +// Companion to the test above: when the captured bias literal is in standard +// (row-major) layout the existing inlining behaviour MUST be preserved - +// otherwise we'd lose the optimisation that the 
original +// "Pull in evaluable constants so MLIR can detect causal masks" code path +// is intentionally enabling. +TEST_CASE(gemm_softmax_gemm_with_standard_literal_bias) +{ + migraphx::shape s1{migraphx::shape::half_type, {1, 4, 8, 8}}; + auto s1_elements = s1.elements(); + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("1", s1); + auto b = mm->add_parameter("2", s1); + auto b1 = mm->add_parameter("3", s1); + std::vector bias_vals(s1_elements, 0.25f); + auto bias = mm->add_literal(migraphx::literal{s1, bias_vals}); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), + b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = mm->add_instruction(migraphx::make_op("add"), gemm1, bias); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), add); + rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), + rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), add, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), + rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + + EXPECT(bias->get_shape().standard()); + } + run_pass(p1, {.attn_enabled = true}); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("1", s1); + auto b = mm->add_parameter("2", s1); + auto b1 = mm->add_parameter("3", s1); + std::vector bias_vals(s1_elements, 0.25f); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), 
b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), + b1); + // Standard-strided bias is inlined inside the attention group. + auto group = + add_group(p2, "attn0", "attention", {a, b, b1}, [=](auto* gm, const auto& inputs) { + auto bias_lit = gm->add_literal(migraphx::literal{s1, bias_vals}); + auto gemm1 = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto add = gm->add_instruction(migraphx::make_op("add"), gemm1, bias_lit); + auto rmax = + gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), add); + rmax = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rmax); + auto sub = gm->add_instruction(migraphx::make_op("sub"), add, rmax); + auto exp = gm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = + gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + rsum = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsum); + auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = gm->add_instruction(migraphx::make_op("dot"), div, inputs[2]); + return std::vector{gemm2}; + }); + mm->add_return({group}); + } + EXPECT(p1.sort() == p2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv);