[AIMIGRAPHX-1017] Skip Q/DQ for Attention Ops #4900

Draft
eddieliao wants to merge 4 commits into develop from skip_attention_qdq

@eddieliao
Contributor

Motivation

Skips inserting quantize/dequantize (Q/DQ) pairs inside attention patterns so that the whole pattern can be fused into a single kernel later in the pipeline.

Technical Details

Before:

2026-05-20 17:05:48.707049 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:1134] Running [ MIGraphX Version: 2.16.0.20250912-17-427-g854b494ae ]: ./build-develop/bin/driver perf attention_fp8_test.onnx --fp8
2026-05-20 17:05:48.707179 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:448] Reading: attention_fp8_test.onnx
2026-05-20 17:05:48.707908 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:754] Quantizing to fp8 ...
2026-05-20 17:05:53.643159 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:762] Compiling ...
module: "main"
@0 = check_context::migraphx::gpu::context -> float_type, {}, {}
@1 = hip::hip_allocate_memory[shape=int8_type, {295698432}, {1},id=main:scratch] -> int8_type, {295698432}, {1}
@2 = load[offset=18874368,end=85983232](@1) -> float_type, {1, 16, 1024, 1024}, {16777216, 1048576, 1024, 1}
k = @param:k -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
q = @param:q -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
@5 = gpu::code_object[code_object=16440,symbol_name=mlir_quantizelinear_transpose_quantizelinear_quant_dot,global=131072,local=256,output_arg=2,](q,k,@2) -> float_type, {1, 16, 1024, 1024}, {16777216, 1048576, 1024, 1}
@6 = load[offset=16777216,end=18874368](@1) -> fp8e4m3fn_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
v = @param:v -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
@8 = gpu::code_object[code_object=6504,symbol_name=quantizelinear_kernel,global=2097152,local=1024,](v,@6) -> fp8e4m3fn_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
@9 = load[offset=0,end=16777216](@1) -> fp8e4m3fn_type, {1, 16, 1024, 1024}, {16777216, 1048576, 1024, 1}
@10 = gpu::code_object[code_object=10472,symbol_name=dequantizelinear_reduce_max_sub_exp_reduce_sum_div_quantizelinear_kernel,global=4194304,local=256,](@5,@9) -> fp8e4m3fn_type, {1, 16, 1024, 1024}, {16777216, 1048576, 1024, 1}
@11 = load[offset=287309824,end=295698432](@1) -> float_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
@12 = load[offset=18874368,end=287309824](@1) -> uint8_type, {268435456}, {1}
@13 = gpu::hip_quant_gemm[alpha=1,beta=0,trans_batch=0,solution_idx=0](@10,@8,@12,@11) -> float_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
main:#output_0 = @param:main:#output_0 -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
w_o = @param:w_o -> half_type, {1, 16, 128, 128}, {262144, 16384, 128, 1}
@16 = gpu::code_object[code_object=11112,symbol_name=mlir_dequantizelinear_quantizelinear_quantizelinear_quant_dot_dequantizelinear,global=65536,local=256,output_arg=2,](@13,w_o,main:#output_0) -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
@17 = @return(@16)


2026-05-20 17:05:57.213673 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:933] Allocating params ...
2026-05-20 17:05:57.238790 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:935] Running performance report ...
@0 = check_context::migraphx::gpu::context -> float_type, {}, {}: 0.00036984ms, 1%
@1 = hip::hip_allocate_memory[shape=int8_type, {295698432}, {1},id=main:scratch] -> int8_type, {295698432}, {1}: 0.00039316ms, 1%
@2 = load[offset=18874368,end=85983232](@1) -> float_type, {1, 16, 1024, 1024}, {16777216, 1048576, 1024, 1}: 0.00046054ms, 1%
k = @param:k -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.00031982ms, 1%
q = @param:q -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.00032ms, 1%
@5 = gpu::code_object[code_object=16440,symbol_name=mlir_quantizelinear_transpose_quantizelinear_quant_dot,global=131072,local=256,output_arg=2,](q,k,@2) -> float_type, {1, 16, 1024, 1024}, {16777216, 1048576, 1024, 1}: 0.0294877ms, 19%
@6 = load[offset=16777216,end=18874368](@1) -> fp8e4m3fn_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.00047354ms, 1%
v = @param:v -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.00032ms, 1%
@8 = gpu::code_object[code_object=6504,symbol_name=quantizelinear_kernel,global=2097152,local=1024,](v,@6) -> fp8e4m3fn_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.0188608ms, 12%
@9 = load[offset=0,end=16777216](@1) -> fp8e4m3fn_type, {1, 16, 1024, 1024}, {16777216, 1048576, 1024, 1}: 0.0004709ms, 1%
@10 = gpu::code_object[code_object=10472,symbol_name=dequantizelinear_reduce_max_sub_exp_reduce_sum_div_quantizelinear_kernel,global=4194304,local=256,](@5,@9) -> fp8e4m3fn_type, {1, 16, 1024, 1024}, {16777216, 1048576, 1024, 1}: 0.056494ms, 36%
@11 = load[offset=287309824,end=295698432](@1) -> float_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.00047252ms, 1%
@12 = load[offset=18874368,end=287309824](@1) -> uint8_type, {268435456}, {1}: 0.00045208ms, 1%
@13 = gpu::hip_quant_gemm[alpha=1,beta=0,trans_batch=0,solution_idx=0](@10,@8,@12,@11) -> float_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.0317516ms, 20%
main:#output_0 = @param:main:#output_0 -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.00035904ms, 1%
w_o = @param:w_o -> half_type, {1, 16, 128, 128}, {262144, 16384, 128, 1}: 0.00032ms, 1%
@16 = gpu::code_object[code_object=11112,symbol_name=mlir_dequantizelinear_quantizelinear_quantizelinear_quant_dot_dequantizelinear,global=65536,local=256,output_arg=2,](@13,w_o,main:#output_0) -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.0181806ms, 12%
@17 = @return(@16)
Summary:
gpu::code_object::dequantizelinear_reduce_max_sub_exp_reduce_sum_div_quantizelinear_kernel: 0.056494ms / 1 = 0.056494ms, 36%
gpu::hip_quant_gemm: 0.0317516ms / 1 = 0.0317516ms, 20%
gpu::code_object::mlir_quantizelinear_transpose_quantizelinear_quant_dot: 0.0294877ms / 1 = 0.0294877ms, 19%
gpu::code_object::quantizelinear_kernel: 0.0188608ms / 1 = 0.0188608ms, 12%
gpu::code_object::mlir_dequantizelinear_quantizelinear_quantizelinear_quant_dot_dequantizelinear: 0.0181806ms / 1 = 0.0181806ms, 12%
load: 0.00232958ms / 5 = 0.000465916ms, 2%
@param: 0.00163886ms / 5 = 0.000327772ms, 2%
hip::hip_allocate_memory: 0.00039316ms / 1 = 0.00039316ms, 1%
check_context::migraphx::gpu::context: 0.00036984ms / 1 = 0.00036984ms, 1%

Batch size: 1
Rate: 9745.39 inferences/sec
Total time: 0.102613ms (Min: 0.09991ms, Max: 0.19147ms, Mean: 0.104161ms, Median: 0.1026ms)
Percentiles (90%, 95%, 99%): (0.10583ms, 0.10612ms, 0.15152ms)
Total instructions time: 0.159506ms
Overhead time: 0.00174ms, -0.0568935ms
Overhead: 2%, -55%
2026-05-20 17:05:57.336887 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:1143] MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention\
2026-05-20 17:05:57.336940 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:1155] [ MIGraphX Version: 2.16.0.20250912-17-427-g854b494ae ] Complete(8.62984s): ./build-develop/bin/driver perf attention_fp8_test.onnx --fp8

After:

2026-05-20 17:06:02.500580 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:1134] Running [ MIGraphX Version: 2.16.0.20250912-17-452-g2ae91710e ]: ./build/bin/driver perf attention_fp8_test.onnx --fp8
2026-05-20 17:06:02.500727 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:448] Reading: attention_fp8_test.onnx
2026-05-20 17:06:02.501537 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:754] Quantizing to fp8 ...
2026-05-20 17:06:04.842209 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:762] Compiling ...
module: "main"
@0 = check_context::migraphx::gpu::context -> float_type, {}, {}
@1 = hip::hip_allocate_memory[shape=int8_type, {2097152}, {1},id=main:scratch] -> int8_type, {2097152}, {1}
@2 = load[offset=0,end=2097152](@1) -> fp8e4m3fn_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
v = @param:v -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
k = @param:k -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
q = @param:q -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
@6 = gpu::code_object[code_object=11328,symbol_name=mlir_transpose_dot_convert_reshape_reduce_max_reshape_sub_exp_reshape_reduce_sum_reshape_div_convert_dot_quantizelinear,global=65536,local=256,output_arg=3,](q,k,v,@2) -> fp8e4m3fn_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
main:#output_0 = @param:main:#output_0 -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
w_o = @param:w_o -> half_type, {1, 16, 128, 128}, {262144, 16384, 128, 1}
@9 = gpu::code_object[code_object=7208,symbol_name=mlir_quantizelinear_quant_dot_dequantizelinear,global=131072,local=256,output_arg=2,](w_o,@6,main:#output_0) -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}
@10 = @return(@9)


2026-05-20 17:06:07.429231 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:933] Allocating params ...
2026-05-20 17:06:07.454508 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:935] Running performance report ...
@0 = check_context::migraphx::gpu::context -> float_type, {}, {}: 0.0003777ms, 1%
@1 = hip::hip_allocate_memory[shape=int8_type, {2097152}, {1},id=main:scratch] -> int8_type, {2097152}, {1}: 0.00040294ms, 1%
@2 = load[offset=0,end=2097152](@1) -> fp8e4m3fn_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.00045ms, 1%
v = @param:v -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.0003138ms, 1%
k = @param:k -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.00029554ms, 1%
q = @param:q -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.00031052ms, 1%
@6 = gpu::code_object[code_object=11328,symbol_name=mlir_transpose_dot_convert_reshape_reduce_max_reshape_sub_exp_reshape_reduce_sum_reshape_div_convert_dot_quantizelinear,global=65536,local=256,output_arg=3,](q,k,v,@2) -> fp8e4m3fn_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.0427579ms, 70%
main:#output_0 = @param:main:#output_0 -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.0003567ms, 1%
w_o = @param:w_o -> half_type, {1, 16, 128, 128}, {262144, 16384, 128, 1}: 0.00030078ms, 1%
@9 = gpu::code_object[code_object=7208,symbol_name=mlir_quantizelinear_quant_dot_dequantizelinear,global=131072,local=256,output_arg=2,](w_o,@6,main:#output_0) -> half_type, {1, 16, 1024, 128}, {2097152, 131072, 128, 1}: 0.015936ms, 26%
@10 = @return(@9)
Summary:
gpu::code_object::mlir_transpose_dot_convert_reshape_reduce_max_reshape_sub_exp_reshape_reduce_sum_reshape_div_convert_dot_quantizelinear: 0.0427579ms / 1 = 0.0427579ms, 70%
gpu::code_object::mlir_quantizelinear_quant_dot_dequantizelinear: 0.015936ms / 1 = 0.015936ms, 26%
@param: 0.00157734ms / 5 = 0.000315468ms, 3%
load: 0.00045ms / 1 = 0.00045ms, 1%
hip::hip_allocate_memory: 0.00040294ms / 1 = 0.00040294ms, 1%
check_context::migraphx::gpu::context: 0.0003777ms / 1 = 0.0003777ms, 1%

Batch size: 1
Rate: 20371.4 inferences/sec
Total time: 0.0490884ms (Min: 0.04807ms, Max: 0.052309ms, Mean: 0.0491553ms, Median: 0.04904ms)
Percentiles (90%, 95%, 99%): (0.04982ms, 0.05042ms, 0.0521ms)
Total instructions time: 0.0615019ms
Overhead time: 0.000947ms, -0.0124135ms
Overhead: 2%, -25%
2026-05-20 17:06:07.471109 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:1143] MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention\
2026-05-20 17:06:07.471138 [INFO] [/code/AMDMIGraphX/src/driver/main.cpp:1155] [ MIGraphX Version: 2.16.0.20250912-17-452-g2ae91710e ] Complete(4.97051s): ./build/bin/driver perf attention_fp8_test.onnx --fp8
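
For a quick read of the two reports above, the headline numbers can be compared directly. The sketch below is illustrative only: the struct and function names are made up here (they are not part of MIGraphX), and the values are copied by hand from the `Total time`, `hip::hip_allocate_memory`, and GPU kernel lines of the logs.

```cpp
#include <cassert>
#include <cstddef>

// Hand-copied figures from one `driver perf` report (hypothetical helper type).
struct perf_report
{
    double total_time_ms;      // "Total time" line
    std::size_t scratch_bytes; // size of the main:scratch buffer
    int gpu_kernels;           // gpu::code_object + gpu::hip_quant_gemm instructions
};

// Before: develop build. After: skip_attention_qdq build.
const perf_report report_before{0.102613, 295698432, 5};
const perf_report report_after{0.0490884, 2097152, 2};

// Ratio of total times; a value above 1 means `b` is faster than `a`.
inline double speedup(const perf_report& a, const perf_report& b)
{
    assert(b.total_time_ms > 0.0);
    return a.total_time_ms / b.total_time_ms;
}

// Ratio of scratch buffer sizes; a value above 1 means `b` needs less scratch.
inline double scratch_ratio(const perf_report& a, const perf_report& b)
{
    assert(b.scratch_bytes > 0);
    return static_cast<double>(a.scratch_bytes) / static_cast<double>(b.scratch_bytes);
}
```

With these inputs, `speedup(report_before, report_after)` comes out to roughly 2.09, which agrees with the reported rates (9745.39 vs 20371.4 inferences/sec). The scratch allocation shrinks from 282 MiB to 2 MiB, a factor of 141, and the number of GPU kernels drops from 5 to 2.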

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

eddieliao requested a review from pfultz2 May 20, 2026
eddieliao self-assigned this May 20, 2026
eddieliao added the Matchers, FP8, and Perf Improve labels May 20, 2026
@eddieliao
Contributor Author

PR currently in draft as I am looking for some feedback on the premise of this change before I go and work on/clean up the actual implementation.

Contributor

Copilot AI left a comment

Pull request overview

This PR updates the FP8 quantization pipeline to avoid inserting Q/DQ pairs into dot -> softmax -> dot attention subgraphs so those regions can be fused later (improving performance and reducing scratch usage), and factors the attention-pattern matcher into a reusable header.

Changes:

  • Extracts a reusable match::dot_softmax_dot matcher and uses it in GPU prefusion matching.
  • Adds a skip_instructions mechanism to capture_arguments_pass and wires it through FP8 quantization.
  • Detects attention regions in quantize_fp8() and skips capture/QDQ insertion for those instructions.
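
The skip mechanism described in the list above can be sketched on a toy IR. This is a minimal illustration under stated assumptions, not the PR's implementation: `toy_instruction`, `toy_capture_pass`, and `select` are hypothetical names, instructions are identified by their index in a vector rather than by `instruction_ref`, and the real pass inserts capture ops instead of returning indices.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

// Toy stand-in for an IR instruction; NOT the MIGraphX instruction type.
struct toy_instruction
{
    std::string name;
};

// Hypothetical analogue of a capture pass with a skip set.
struct toy_capture_pass
{
    // Op names that would normally be captured (and later get Q/DQ).
    std::unordered_set<std::string> ins_names = {"dot", "convolution"};
    // Indices of instructions to leave untouched, e.g. an attention region.
    std::unordered_set<std::size_t> skip_instructions{};

    // Returns the indices that would still be captured.
    std::vector<std::size_t> select(const std::vector<toy_instruction>& prog) const
    {
        std::vector<std::size_t> selected;
        for(std::size_t i = 0; i < prog.size(); ++i)
        {
            if(ins_names.count(prog[i].name) == 0)
                continue; // not a quantizable op
            if(skip_instructions.count(i) != 0)
                continue; // explicitly skipped
            selected.push_back(i);
        }
        return selected;
    }
};

// Usage: two attention dots (indices 0 and 2) are skipped, the output
// projection dot (index 3) is still captured.
inline void toy_capture_demo()
{
    std::vector<toy_instruction> prog = {{"dot"}, {"softmax"}, {"dot"}, {"dot"}};
    toy_capture_pass pass;
    pass.skip_instructions = {0, 2};
    assert(pass.select(prog) == std::vector<std::size_t>{3});
}
```

Keeping the skip set as plain data on the pass, as the `skip_instructions` member in this PR appears to do, leaves the detection of attention regions to the caller, so the pass itself needs no knowledge of attention.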

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

File | Description
src/targets/gpu/prefuse_ops.cpp | Switches attention prefusion matching to the new shared dot_softmax_dot matcher.
src/quantize_8bits.cpp | Skips inserting capture ops (and therefore Q/DQ) for a provided set of instructions.
src/quantization.cpp | Detects attention regions and passes them into the capture pass to skip Q/DQ insertion.
src/include/migraphx/quantize_8bits.hpp | Extends capture_arguments_pass API to carry a skip set.
src/include/migraphx/match/dot_softmax_dot.hpp | Introduces a reusable matcher for undecomposed attention (dot -> softmax -> dot).

Comment thread src/quantization.cpp Outdated
Comment on lines 44 to 50
struct MIGRAPHX_EXPORT capture_arguments_pass
{
    std::unordered_set<std::string> ins_names = {"dot", "convolution"};
    std::function<void(std::size_t, std::vector<argument>)> f{};
    std::size_t* param_index = nullptr;
    std::unordered_set<instruction_ref> skip_instructions{};
    std::string name() const { return "capture_arguments"; }
Comment on lines +34 to +40
/// Match the (undecomposed) `dot -> softmax -> dot` attention pattern, with
/// optional `mul` (scale), `add` (bias), or `where` (mask) ops between the
/// first dot and the softmax. This is the form before `rewrite_reduce`
/// decomposes softmax into its `div(exp(sub(x, max)), sum(exp(...)))` chain.
///
/// `gemm_pred` is applied to both dot operations; pass `match::any()` to
/// match any dot. `bias_pred` is applied to the optional `add` (bias) op.
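
The pattern described in the doc comment above can be sketched on a toy graph. The code below is a simplified illustration and does not use the MIGraphX matcher API: `toy_node` and `match_dot_softmax_dot` are hypothetical names, each node carries only a single data-flow predecessor, and `gemm_pred` and `bias_pred` are not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

// Toy graph node: an op name plus the index of its main data input
// (-1 if none). This is an assumption made for illustration only.
struct toy_node
{
    std::string op;
    int data_input;
};

// Walks backwards from `last` and returns the node indices of a
// dot -> [mul|add|where]* -> softmax -> dot region, ordered from the second
// dot back to the first. Returns an empty vector if the pattern does not match.
inline std::vector<int> match_dot_softmax_dot(const std::vector<toy_node>& g, int last)
{
    auto valid = [&](int i) { return i >= 0 and i < static_cast<int>(g.size()); };
    static const std::unordered_set<std::string> optional_ops = {"mul", "add", "where"};

    if(not valid(last) or g[last].op != "dot")
        return {};
    int sm = g[last].data_input;
    if(not valid(sm) or g[sm].op != "softmax")
        return {};

    std::vector<int> region{last, sm};
    int cur = g[sm].data_input;
    // Bounded walk over the optional scale/bias/mask ops.
    for(std::size_t steps = 0; steps < g.size(); ++steps)
    {
        if(not valid(cur) or optional_ops.count(g[cur].op) == 0)
            break;
        region.push_back(cur);
        cur = g[cur].data_input;
    }
    if(not valid(cur) or g[cur].op != "dot")
        return {};
    region.push_back(cur);
    return region;
}

// Usage: q -> dot -> mul -> softmax -> dot -> dot (output projection).
inline void toy_match_demo()
{
    std::vector<toy_node> g = {
        {"@param", -1}, {"dot", 0}, {"mul", 1}, {"softmax", 2}, {"dot", 3}, {"dot", 4}};
    assert((match_dot_softmax_dot(g, 4) == std::vector<int>{4, 3, 2, 1}));
    assert(match_dot_softmax_dot(g, 5).empty()); // projection dot is not attention
}
```

In this sketch the returned indices are what a caller would feed into a skip set, so that the attention region is left in floating point while the trailing projection dot remains eligible for quantization.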
Comment thread src/quantization.cpp Outdated
codecov Bot commented May 20, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #4900   +/-   ##
========================================
  Coverage    92.88%   92.88%           
========================================
  Files          587      588    +1     
  Lines        30348    30365   +17     
========================================
+ Hits         28187    28204   +17     
  Misses        2161     2161           
Files with missing lines | Coverage Δ
src/include/migraphx/match/dot_softmax_dot.hpp | 100.00% <100.00%> (ø)
src/include/migraphx/quantize_8bits.hpp | 100.00% <ø> (ø)
src/quantization.cpp | 86.42% <100.00%> (+1.09%) ⬆️
src/quantize_8bits.cpp | 100.00% <100.00%> (ø)