Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/fuse_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,18 @@ struct find_attention
auto expand = fix([&](auto self, auto ins) {
for(auto input : ins->inputs())
{
if(not contains(attn_inss, input) and input->can_eval())
{
attn_inss.insert(input);
self(input);
}
if(contains(attn_inss, input) or not input->can_eval())
continue;
// A captured @literal is lowered as `migraphx.literal`, whose
// dense_elements_attr requires standard (row-major) strides.
// Leave non-standard literals (e.g. constant-folded transposed
// bias) outside the group so they enter as a parameter
// instead, where mlir_contiguous + adjust_param_shapes can
// normalize the layout.
Comment on lines +178 to +180
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just remove this comment. There are so many things incorrect in this comment that it will just cause more confusion.

if(input->name() == "@literal" and not input->get_shape().standard())
continue;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem like the correct fix. We need to skip literals that are not iota literals, but you are not checking if it's an iota literal.

attn_inss.insert(input);
self(input);
}
});
auto starts = attn_inss;
Expand Down
156 changes: 156 additions & 0 deletions test/fuse_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2229,6 +2229,162 @@
EXPECT(migraphx::ceil_mul_of(2049, 32) == 2080); // 2049 -> 2080 (padding = 31)
}

// Attention fusion with a bias literal whose strides are packed but not
// row-major. The expected program keeps that literal in the main module and
// passes it to the "attn0" group as an ordinary input, rather than recreating
// it as an `@literal` inside the group.
TEST_CASE(gemm_softmax_gemm_with_transposed_literal_bias)
{
    migraphx::shape s1{migraphx::shape::half_type, {1, 4, 8, 8}};
    // Same lens as s1, but axis 1 is the fastest-varying axis (stride 1),
    // followed by axis 3 (stride 4) and axis 2 (stride 32). The buffer is
    // therefore dense (packed) without being in standard order.
    migraphx::shape s_bias{migraphx::shape::half_type, {1, 4, 8, 8}, {4, 1, 32, 4}};
    const std::vector<float> bias_data(s1.elements(), 0.5f);

    // Operators that are used several times by both programs.
    const auto swap_last_two = migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}});
    const auto bcast_to_s1   = migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}});

    // Input program: dot -> add(bias) -> softmax (spelled out) -> dot.
    migraphx::program p1;
    {
        auto* mm  = p1.get_main_module();
        auto q    = mm->add_parameter("1", s1);
        auto k    = mm->add_parameter("2", s1);
        auto v    = mm->add_parameter("3", s1);
        auto bias = mm->add_literal(migraphx::literal{s_bias, bias_data});
        auto kt   = mm->add_instruction(swap_last_two, k);
        auto vt   = mm->add_instruction(swap_last_two, v);

        auto scores   = mm->add_instruction(migraphx::make_op("dot"), q, kt);
        auto biased   = mm->add_instruction(migraphx::make_op("add"), scores, bias);
        auto row_max  = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), biased);
        auto max_b    = mm->add_instruction(bcast_to_s1, row_max);
        auto shifted  = mm->add_instruction(migraphx::make_op("sub"), biased, max_b);
        auto expd     = mm->add_instruction(migraphx::make_op("exp"), shifted);
        auto row_sum  = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), expd);
        auto sum_b    = mm->add_instruction(bcast_to_s1, row_sum);
        auto probs    = mm->add_instruction(migraphx::make_op("div"), expd, sum_b);
        auto attended = mm->add_instruction(migraphx::make_op("dot"), probs, vt);
        mm->add_return({attended});

        // Guard the premise of the test: the literal must be packed yet
        // non-standard, otherwise the scenario under test is not exercised.
        EXPECT(not bias->get_shape().standard());
        EXPECT(bias->get_shape().packed());
    }
    run_pass(p1, {.attn_enabled = true});

    // Expected program: the bias literal lives in the main module and is the
    // third input of the attention group.
    migraphx::program p2;
    {
        auto* mm  = p2.get_main_module();
        auto q    = mm->add_parameter("1", s1);
        auto k    = mm->add_parameter("2", s1);
        auto v    = mm->add_parameter("3", s1);
        auto bias = mm->add_literal(migraphx::literal{s_bias, bias_data});
        auto kt   = mm->add_instruction(swap_last_two, k);
        auto vt   = mm->add_instruction(swap_last_two, v);

        auto group = add_group(
            p2, "attn0", "attention", {q, kt, bias, vt}, [=](auto* gm, const auto& inputs) {
                auto scores = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]);
                auto biased = gm->add_instruction(migraphx::make_op("add"), scores, inputs[2]);
                auto row_max =
                    gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), biased);
                auto max_b   = gm->add_instruction(bcast_to_s1, row_max);
                auto shifted = gm->add_instruction(migraphx::make_op("sub"), biased, max_b);
                auto expd    = gm->add_instruction(migraphx::make_op("exp"), shifted);
                auto row_sum =
                    gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), expd);
                auto sum_b    = gm->add_instruction(bcast_to_s1, row_sum);
                auto probs    = gm->add_instruction(migraphx::make_op("div"), expd, sum_b);
                auto attended = gm->add_instruction(migraphx::make_op("dot"), probs, inputs[3]);
                return std::vector<migraphx::instruction_ref>{attended};
            });
        mm->add_return({group});
    }
    EXPECT(p1.sort() == p2.sort());
}

// Companion to the test above: when the captured bias literal is in standard
// (row-major) layout the existing inlining behavior MUST be preserved -

Check warning on line 2315 in test/fuse_attention.cpp

View workflow job for this annotation

GitHub Actions / misspell

[misspell] test/fuse_attention.cpp#L2315

"behaviour" is a misspelling of "behavior"
Raw output
./test/fuse_attention.cpp:2315:44: "behaviour" is a misspelling of "behavior"
// otherwise we'd lose the optimization that the original
// "Pull in evaluable constants so MLIR can detect causal masks" code path
// is intentionally enabling.
TEST_CASE(gemm_softmax_gemm_with_standard_literal_bias)
{
    // A bias literal in standard layout is expected to be recreated inside
    // the "attn0" group, so the group takes only the three tensor inputs.
    migraphx::shape s1{migraphx::shape::half_type, {1, 4, 8, 8}};
    const std::vector<float> bias_data(s1.elements(), 0.25f);

    // Operators that are used several times by both programs.
    const auto swap_last_two = migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}});
    const auto bcast_to_s1   = migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}});

    // Input program: dot -> add(bias) -> softmax (spelled out) -> dot.
    migraphx::program p1;
    {
        auto* mm  = p1.get_main_module();
        auto q    = mm->add_parameter("1", s1);
        auto k    = mm->add_parameter("2", s1);
        auto v    = mm->add_parameter("3", s1);
        auto bias = mm->add_literal(migraphx::literal{s1, bias_data});
        auto kt   = mm->add_instruction(swap_last_two, k);
        auto vt   = mm->add_instruction(swap_last_two, v);

        auto scores   = mm->add_instruction(migraphx::make_op("dot"), q, kt);
        auto biased   = mm->add_instruction(migraphx::make_op("add"), scores, bias);
        auto row_max  = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), biased);
        auto max_b    = mm->add_instruction(bcast_to_s1, row_max);
        auto shifted  = mm->add_instruction(migraphx::make_op("sub"), biased, max_b);
        auto expd     = mm->add_instruction(migraphx::make_op("exp"), shifted);
        auto row_sum  = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), expd);
        auto sum_b    = mm->add_instruction(bcast_to_s1, row_sum);
        auto probs    = mm->add_instruction(migraphx::make_op("div"), expd, sum_b);
        auto attended = mm->add_instruction(migraphx::make_op("dot"), probs, vt);
        mm->add_return({attended});

        // Guard the premise of the test: this literal is in standard layout.
        EXPECT(bias->get_shape().standard());
    }
    run_pass(p1, {.attn_enabled = true});

    // Expected program: no literal in the main module; the group body creates
    // the bias literal itself before the first dot.
    migraphx::program p2;
    {
        auto* mm = p2.get_main_module();
        auto q   = mm->add_parameter("1", s1);
        auto k   = mm->add_parameter("2", s1);
        auto v   = mm->add_parameter("3", s1);
        auto kt  = mm->add_instruction(swap_last_two, k);
        auto vt  = mm->add_instruction(swap_last_two, v);

        auto group =
            add_group(p2, "attn0", "attention", {q, kt, vt}, [=](auto* gm, const auto& inputs) {
                auto bias   = gm->add_literal(migraphx::literal{s1, bias_data});
                auto scores = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]);
                auto biased = gm->add_instruction(migraphx::make_op("add"), scores, bias);
                auto row_max =
                    gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), biased);
                auto max_b   = gm->add_instruction(bcast_to_s1, row_max);
                auto shifted = gm->add_instruction(migraphx::make_op("sub"), biased, max_b);
                auto expd    = gm->add_instruction(migraphx::make_op("exp"), shifted);
                auto row_sum =
                    gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), expd);
                auto sum_b    = gm->add_instruction(bcast_to_s1, row_sum);
                auto probs    = gm->add_instruction(migraphx::make_op("div"), expd, sum_b);
                auto attended = gm->add_instruction(migraphx::make_op("dot"), probs, inputs[2]);
                return std::vector<migraphx::instruction_ref>{attended};
            });
        mm->add_return({group});
    }
    EXPECT(p1.sort() == p2.sort());
}

int main(int argc, const char* argv[])
{
test::run(argc, argv);
Expand Down
Loading