diff --git a/src/include/migraphx/match/dot_softmax_dot.hpp b/src/include/migraphx/match/dot_softmax_dot.hpp new file mode 100644 index 00000000000..d91cff27e1f --- /dev/null +++ b/src/include/migraphx/match/dot_softmax_dot.hpp @@ -0,0 +1,67 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_MATCH_DOT_SOFTMAX_DOT_HPP +#define MIGRAPHX_GUARD_MATCH_DOT_SOFTMAX_DOT_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace match { + +/// Match the (undecomposed) `dot -> softmax -> dot` attention pattern, with +/// optional `mul` (scale), `add` (bias), or `where` (mask) ops between the +/// first dot and the softmax. This is the form before `rewrite_reduce` +/// decomposes softmax into its `div(exp(sub(x, max)), sum(exp(...)))` chain. +/// +/// `gemm_pred` is applied to both dot operations; pass `match::any()` to +/// match any dot. `bias_pred` is applied to the optional `add` (bias) op. +/// +/// Bound names: "gemm1", "gemm2", "softmax", and (when the corresponding op +/// is present) "scale", "bias", "select_const", "select_cond". +template +inline auto dot_softmax_dot(GemmPred gemm_pred, BiasPred bias_pred) +{ + auto gemm1 = + match::skip(match::name("contiguous"))(match::name("dot")(gemm_pred.bind("gemm1"))); + auto mul = match::name("mul")( + match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1)); + auto where = match::name("where")(match::arg(2)(match::is_constant().bind("select_const")), + match::arg(1)(mul), + match::arg(0)(match::any().bind("select_cond"))); + auto add = match::name("add")( + bias_pred, match::nargs(2), match::either_arg(0, 1)(match::none_of(mul).bind("bias"), mul)); + auto softmax = match::name("softmax")(match::arg(0)(match::any_of(mul, add, gemm1, where))) + .bind("softmax"); + return match::name("dot")(gemm_pred.bind("gemm2"))(match::arg(0)(softmax)); +} + +inline auto dot_softmax_dot() { return dot_softmax_dot(match::any(), match::any()); } + +} // namespace match +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/quantize_8bits.hpp b/src/include/migraphx/quantize_8bits.hpp index 9345025837d..a597f8bbb41 100644 --- a/src/include/migraphx/quantize_8bits.hpp +++ b/src/include/migraphx/quantize_8bits.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -30,6 +30,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -45,6 +46,7 @@ struct MIGRAPHX_EXPORT capture_arguments_pass std::unordered_set ins_names = {"dot", "convolution"}; std::function)> f{}; std::size_t* param_index = nullptr; + std::unordered_set skip_instructions{}; std::string name() const { return "capture_arguments"; } void apply(module& m) const; }; diff --git a/src/quantization.cpp b/src/quantization.cpp index 35be7dbe6d7..8a8d262fdd8 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -23,6 +23,8 @@ */ #include #include +#include +#include #include #include #include @@ -95,6 +97,20 @@ static void quantize_8bits(program& prog, // avoid loss of precision. run_passes(prog, {rewrite_rnn{}, normalize_ops{}, optimize_module{}}, quant_tracer()); + // Skip Q/DQ insertion for instructions inside attention regions so the + // dot->softmax->dot pattern remains intact for fuse_attention + auto* mm = prog.get_main_module(); + std::unordered_set skip_instructions; + for(auto ins : iterator_for(*mm)) + { + auto r = match::match_instruction(*mm, ins, match::dot_softmax_dot()); + if(r.result == mm->end()) + continue; + auto region = + find_instructions_between(r.instructions["gemm1"], r.instructions["gemm2"], mm); + skip_instructions.insert(region.begin(), region.end()); + } + std::shared_ptr>> quant_8bit_params = std::make_shared>>(); std::shared_ptr> max_abs_vals = std::make_shared>(); @@ -127,8 +143,10 @@ static void quantize_8bits(program& prog, // pass to add capture argument op std::size_t param_num = 0; - run_passes( - prog, {capture_arguments_pass{ins_names, calc_quant_params, ¶m_num}}, quant_tracer()); + run_passes(prog, + {capture_arguments_pass{ + ins_names, calc_quant_params, ¶m_num, std::move(skip_instructions)}}, + quant_tracer()); quant_8bit_params->resize(param_num, std::pair(64.0f, 0.0f)); max_abs_vals->resize(param_num, 0.0f); @@ -193,8 +211,9 @@ void quantize_int4_weights(program& prog) void quantize_fp8(program& prog, const target& t, const std::vector& calibration) { - std::unordered_set supported_ins_names; auto* mm = prog.get_main_module(); + + std::unordered_set supported_ins_names; for(auto ins : iterator_for(*mm)) { if(ins->name() == "convert") @@ -206,6 +225,7 @@ void quantize_fp8(program& prog, const target& t, const std::vectorname()); } } + quantize_8bits(prog, t, shape::fp8e4m3fn_type, calibration, supported_ins_names); } diff --git a/src/quantize_8bits.cpp b/src/quantize_8bits.cpp index 1f06a3451b8..74be907c866 100644 --- a/src/quantize_8bits.cpp +++ b/src/quantize_8bits.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -106,6 +106,10 @@ void capture_arguments_pass::apply(module& m) const // NOLINT { continue; } + if(contains(skip_instructions, ins)) + { + continue; + } auto inputs = ins->inputs(); std::vector new_args; diff --git a/src/targets/gpu/prefuse_ops.cpp b/src/targets/gpu/prefuse_ops.cpp index 3db298930bd..c7262b20367 100644 --- a/src/targets/gpu/prefuse_ops.cpp +++ b/src/targets/gpu/prefuse_ops.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -180,23 +181,8 @@ struct find_gemm_softmax_gemm auto matcher() const { - auto gemm1 = match::skip(match::name("contiguous"))(match::name("dot")( - match::any_of(is_ck_gemm(), is_test_gemm(enable_attention)).bind("gemm1"))); - auto mul = match::name("mul")( - match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1)); - auto where = match::name("where")(match::arg(2)(match::is_constant().bind("select_const")), - match::arg(1)(mul), - match::arg(0)(match::any().bind("select_cond"))); - auto add = - match::name("add")(is_bias_supported(), - match::nargs(2), - match::either_arg(0, 1)(match::none_of(mul).bind("bias"), mul)); - auto softmax = match::name("softmax")(match::arg(0)(match::any_of(mul, add, gemm1, where))) - .bind("softmax"); - - return match::name("dot")( - match::any_of(is_ck_gemm(), is_test_gemm(enable_attention)).bind("gemm2"))( - match::arg(0)(softmax)); + return match::dot_softmax_dot(match::any_of(is_ck_gemm(), is_test_gemm(enable_attention)), + is_bias_supported()); } void apply(module_pass_manager& mpm, const match::matcher_result& r) const diff --git a/test/quantization.cpp b/test/quantization.cpp index a33c88badec..158e3f88417 100644 --- a/test/quantization.cpp +++ b/test/quantization.cpp @@ -610,6 +610,47 @@ TEST_CASE(op_capture_subgraph) } } +TEST_CASE(op_capture_skip_attention) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {1, 4, 8, 8}}; + auto x = mm->add_parameter("x", s); + auto wq = mm->add_parameter("wq", s); + auto wo = mm->add_parameter("wo", s); + auto k = mm->add_parameter("k", s); + auto v = mm->add_parameter("v", s); + + // Pre-attention projection: outside the attention region. + auto pre = mm->add_instruction(migraphx::make_op("dot"), x, wq); + + // Attention region: these dots must NOT be captured. + auto kt = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k); + auto qk = mm->add_instruction(migraphx::make_op("dot"), pre, kt); + auto sm = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), qk); + auto attn = mm->add_instruction(migraphx::make_op("dot"), sm, v); + + // Post-attention projection: outside the attention region. + auto out = mm->add_instruction(migraphx::make_op("dot"), attn, wo); + mm->add_return({out}); + + migraphx::target t = migraphx::make_target("ref"); + std::vector cali; + migraphx::quantize_fp8(p, t, cali); + + auto has_dq_input = [](auto ins) { + return std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto in) { + return in->name() == "dequantizelinear"; + }); + }; + + EXPECT(not has_dq_input(qk)); + EXPECT(not has_dq_input(attn)); + EXPECT(has_dq_input(pre)); + EXPECT(has_dq_input(out)); +} + TEST_CASE(dot_float) { auto create_program = [] {