Add fp16 support for 8-bit MatMulNBits on ARM64 and fix pre-existing bugs #27692
Conversation
Enable native fp16 GEMM (HQNBIT_CompFp16) for 8-bit quantized weights on ARM64 with NEON fp16 intrinsics. Previously only 4-bit weights had the fp16 compute path; 8-bit fp16 inputs fell back to the slow ComputeBUnpacked path with multiple precision conversions.
Key changes:
- Add HQ8BitGemmVariant_CompFp16 variant and wire up dispatch/availability
- Implement 8N-interleaved B data packing (HQ8BitGemmPackQuantBData_CompFp16) using two Transpose8x8 operations per 16K x 8N tile
- Implement NEON fp16 dequant kernel (HQ8BitBlkDequantBForHgemm_CompFp16) that loads interleaved uint8 data, widens to fp16, and applies per-column scale/zero-point via FMA (see the sketch after this list)
- Reuse existing HQ4BitGemmKernel_CompFp16 for the GEMM step since it operates on dequantized fp16 data and is bit-width agnostic
- Add MLAS-level unit tests for 8-bit prepack and dequant
- Add operator-level test for Float16_8b_ARM_CompFp16
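A rough sketch of the dequant step named above (the function name and data layout are illustrative; only the load → widen → fp16-FMA sequence follows the description, and it assumes an ARMv8.2+fp16 target):

```cpp
#include <arm_neon.h>  // build with fp16 vector arithmetic, e.g. -march=armv8.2-a+fp16

// Illustrative: dequantize 8 uint8-quantized weights for one column,
// computing (q - zero_point) * scale as q * scale + (-zero_point * scale).
static inline float16x8_t Dequant8x8BitSketch(const uint8_t* quant, float16_t scale, uint8_t zero_point) {
    uint16x8_t q_u16 = vmovl_u8(vld1_u8(quant));     // load 8 x uint8, widen to uint16
    float16x8_t q_f16 = vcvtq_f16_u16(q_u16);        // convert to fp16
    float16x8_t scale_v = vdupq_n_f16(scale);
    float16x8_t offset_v = vdupq_n_f16((float16_t)(-(float)zero_point * (float)scale));
    return vfmaq_f16(offset_v, q_f16, scale_v);      // per-column scale/zero-point via FMA
}
```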
Pull request overview
This PR extends MLAS/MatMulNBits support on ARM64 by enabling an fp16-native (HQNBIT_CompFp16) path for 8-bit quantized weights and adds HQNBIT_CompInt8 variants for fp16 inputs, alongside several platform-specific packing/workspace fixes and new test coverage.
Changes:
- Add ARM64 NEON kernels to prepack (8N interleave) and dequantize 8-bit quantized B for fp16 GEMM, reusing the existing fp16 GEMM microkernel.
- Add HQNBIT_CompInt8 support paths (4-bit and 8-bit) and update workspace sizing/alignment/variant dispatch accordingly.
- Add unit/integration tests and wire new kernel source into the build.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16_8bit.cpp | New ARM64 NEON implementation for 8-bit B prepack + fp16 dequant used by HQNBIT_CompFp16 |
| onnxruntime/core/mlas/lib/qnbitgemm.cpp | Adds HQ8Bit variants (CompFp16/CompInt8), integrates packing dispatch, and fixes bias offset for fp16 HQ GEMM |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp | Extends workspace sizing/alignment for HQNBIT_CompInt8 and wires new HQ8Bit fp16 dispatch |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h | Declares new HQ8Bit fp16 pack/dequant entry points |
| onnxruntime/core/mlas/lib/qnbitgemm.h | Extends dispatch struct with HQ8Bit pack/dequant pointers |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Fixes workspace/scales/ZP packing guards and adds HQ8Bit CompInt8 fp16 handling |
| onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp | Adds targeted NEON fp16 8-bit prepack and dequant tests |
| onnxruntime/test/contrib_ops/matmul_8bits_test.cc | Adds an ARM64-only fp16 8-bit test for the HQNBIT_CompFp16 path |
| cmake/onnxruntime_mlas.cmake | Adds the new NEON source file to MLAS build for ARM64 targets |
Pull request overview
This PR extends MLAS MatMulNBits support on ARM64 by adding native fp16 (HQNBIT_CompFp16) handling for 8-bit quantized weights, adds fp16 “accuracy level 4” (HQNBIT_CompInt8) paths for 4-bit and 8-bit, and fixes a few platform-specific bugs in MatMulNBits prepack/compute plumbing.
Changes:
- Added ARM64 NEON kernels to pack and dequantize 8-bit block-quantized B for fp16 GEMM (reusing the existing HQ4Bit fp16 GEMM kernel for compute).
- Added HQNBIT_CompInt8 support paths (accuracy level 4) for fp16 MatMulNBits, including workspace sizing/alignment and dispatch integration.
- Added/updated unit tests and fixed pre-existing issues around bias offsets and packing/workspace setup on non-x64.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp | Adds 8-bit prepack and dequant unit tests for ARM64 fp16 HQ paths. |
| onnxruntime/test/contrib_ops/matmul_8bits_test.cc | Adds an ARM64-only fp16 HQNBIT_CompFp16 test for 8-bit MatMulNBits. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h | Declares new ARM64 NEON 8-bit HQ fp16 pack/dequant entry points. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp | Updates workspace sizing/alignment and NEON dispatch table for new compute types and 8-bit HQ fp16 hooks. |
| onnxruntime/core/mlas/lib/qnbitgemm.h | Extends dispatch struct with HQ 8-bit pack/dequant function pointers. |
| onnxruntime/core/mlas/lib/qnbitgemm.cpp | Adds new HQ 8-bit variants, implements HQ 8-bit fp16/int8 compute, and fixes bias offset bug in fp16 HQ kernels. |
| onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16_8bit.cpp | New NEON implementation for 8-bit B packing (8N interleave) and dequant-to-fp16 for HQNBIT_CompFp16. |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Adjusts PrePack/ComputeBPacked logic for HQ int8 + fp16 workflows and fixes prepack/workspace guards on non-x64. |
| cmake/onnxruntime_mlas.cmake | Wires the new NEON fp16 8-bit kernel source into builds. |
Pull request overview
This PR extends MLAS MatMulNBits on ARM64 by adding native fp16 (HQNBIT_CompFp16) support for 8-bit quantized weights, and adds/adjusts HQNBIT_CompInt8 plumbing for fp16 inputs (notably for 8-bit), alongside fixes for packing/workspace issues that affected non-x64 platforms and multithreaded bias handling.
Changes:
- Add ARM64 NEON kernels to pack (8N interleave) and dequantize 8-bit B data to fp16 for HQNBIT_CompFp16, reusing the existing fp16 GEMM microkernel.
- Extend QNBitGemm dispatch/variant selection and workspace logic to recognize HQ 8-bit variants (fp16 + int8 compute paths).
- Add targeted unit tests for the new kernels and an ARM64 fp16 MatMulNBits test case; adjust MatMulNBits prepack/workspace behaviors for non-x64 paths.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp | Adds fp16 NEON unit tests for 8-bit prepack and dequant, and improves fp16 compare logic. |
| onnxruntime/test/contrib_ops/matmul_8bits_test.cc | Adds an ARM64-only fp16 native GEMM (accuracy level 2 / HQNBIT_CompFp16) 8-bit MatMulNBits test. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h | Declares new HQ 8-bit fp16 pack/dequant entry points for NEON dispatch. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp | Wires new HQ 8-bit fp16 functions into the NEON QNBitGemm dispatch; updates workspace sizing for HQNBIT_CompInt8. |
| onnxruntime/core/mlas/lib/qnbitgemm.h | Extends the dispatch struct with HQ8Bit pack/dequant function pointers. |
| onnxruntime/core/mlas/lib/qnbitgemm.cpp | Adds HQ8Bit variants, availability checks, packing routes, and new HQ8Bit GEMM implementations for fp16 and int8 compute. |
| onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16_8bit.cpp | New: implements 8-bit B packing (8N interleave) and dequant-to-fp16 for HQNBIT_CompFp16 on ARM64 NEON. |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Adjusts PrePack/ComputeBPacked logic for HQ/SQ int8 paths and non-x64 scale/zp packing behavior. |
| cmake/onnxruntime_mlas.cmake | Adds the new NEON fp16 8-bit kernel source file to MLAS build lists and compile flags. |
I am a little confused by this (probably my understanding of the kernels is not super good :)).
Yes, the activations are also converted to fp32. The scales are pre-converted to fp32 during prepack.
The OpTester destructor asserts testing_function_called_ in Debug builds (raises SIGTRAP if false). RunTest8Bits for MLFloat16 on CPU was constructing the OpTester but never calling test.RunWithConfig(), causing the SIGTRAP. Add MLFloat16 path in the CPU #else branch that checks MlasIsQNBitGemmAvailable(8, 32, HQNBIT_CompFp16) and runs the test when the fp16 GEMM kernels are available.
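A condensed sketch of that guard, with the RunTest8Bits scaffolding elided; `AType` and `test` stand in for the test's template parameter and OpTester instance and are assumptions, not the literal code:

```cpp
// CPU (#else) branch of RunTest8Bits, fp16 case only (sketch).
if constexpr (std::is_same_v<AType, MLFloat16>) {
  // Only drive the tester when the native fp16 8-bit GEMM kernels are present.
  if (MlasIsQNBitGemmAvailable(/*BlkBitWidth*/ 8, /*BlkLen*/ 32, HQNBIT_CompFp16)) {
    test.RunWithConfig();
  }
}
```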
Probably related to this PR: #27251
Address reviewer feedback:
1. Add Float16_8b_ARM_CompInt8 test (accuracy_level=4, HQNBIT_CompInt8) for fp16 A x Int8 B on ARM64. This path was untested.
2. Expand Float16_8b_ARM_CompFp16 from 9 to 36 test cases to match the coverage of Float32_8b_AccuracyLevel4: small/large N (1-288), large K (1024, 1234), block_size 16/32/64/128, K not divisible by block_size (260/32), M=1/2/100/199.
3. Fix RunTest8Bits CPU path to select the correct compute type based on accuracy_level (HQNBIT_CompInt8 for level 4, HQNBIT_CompFp16 otherwise) to avoid SIGTRAP when adding the CompInt8 test; see the sketch after this list.
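A minimal sketch of the selection in item 3; the helper name is hypothetical, and the enum type name `MLAS_QNBIT_GEMM_COMPUTE_TYPE` is assumed from the MLAS headers:

```cpp
// Hypothetical helper: map the MatMulNBits accuracy_level attribute to the
// fp16-input compute type described above (4 -> int8 compute, else native fp16).
MLAS_QNBIT_GEMM_COMPUTE_TYPE SelectFp16ComputeType(int64_t accuracy_level) {
  return accuracy_level == 4 ? HQNBIT_CompInt8 : HQNBIT_CompFp16;
}
```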
fp16 accumulation error grows with K. The previous abs_error=0.055 was too tight for the K=1024/1234/3072 cases. Align with existing conventions:
- CompFp16: abs=0.1 (matches 4-bit fp16 tests)
- CompInt8: abs=0.1*1.02 (matches fp32 int8 tests)
Description
This PR adds fp16 (half-precision) support for 8-bit MatMulNBits on ARM64 NEON and fixes several pre-existing bugs discovered during testing.
New features:
- `HQ8BitGemmPackQuantBData_CompFp16` and `HQ8BitBlkDequantBForHgemm_CompFp16`: NEON kernels that pack and dequantize 8-bit quantized weights for fp16 GEMM. The existing `HQ4BitGemmKernel_CompFp16` is reused for the actual compute, since the dequantized B matrix has the same layout.

Bug fixes:
- Added the missing `+ RangeStartN` when initializing the `Bias` pointer in `HQ4BitGemm_CompFp16` and `HQ8BitGemm_CompFp16`. This caused incorrect results when using multiple threads, as worker threads processing column ranges beyond the first would read bias values from the wrong offset (a sketch of the corrected initialization follows this list).
- Fixed the `#ifdef MLAS_TARGET_AMD64_IX86` guard around setting `QuantBDataWorkspace` in `ComputeBPacked<MLFloat16>`, so macOS ARM (which uses the fp32 fallback path) correctly sets the workspace pointer for SQNBIT_CompInt8.
- Fixed the `#ifdef MLAS_TARGET_AMD64_IX86` guard around the SQNBIT_CompInt8 scale and zero-point packing in the `MatMulNBits<MLFloat16>::PrePack` specialization, and added an `nbits_ == 8` condition to match the generic template's behavior on ARM (only 8-bit needs separate scale packing on ARM, while x64 needs it for both 4-bit and 8-bit).
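A minimal sketch of the corrected initialization from the first bullet; the surrounding GEMM driver is elided, and `DataParams`, `RangeStartN`, and `MLAS_FP16` follow MLAS naming but the snippet is illustrative, not the literal diff:

```cpp
// Each worker thread computes the column range [RangeStartN, RangeStartN + RangeCountN).
// Without "+ RangeStartN", threads past the first column block read bias values
// that belong to other columns, producing wrong results only in multithreaded runs.
const MLAS_FP16* Bias = (DataParams->Bias == nullptr)
                            ? nullptr
                            : DataParams->Bias + RangeStartN;
```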
Motivation and Context
8-bit quantized models with fp16 inputs are increasingly common on ARM devices (Windows ARM, macOS Apple Silicon). The existing MatMulNBits implementation only supported 4-bit for the HQNBIT fp16 paths. This change extends support to 8-bit, enabling faster inference for 8-bit quantized models on ARM64 without requiring fp16→fp32 conversion of the weights.
The bug fixes address issues that were either pre-existing (the `#ifdef` guards were copy-paste inconsistencies from prior PRs) or introduced alongside the fp16 NEON support (the Bias offset issue). These caused crashes or incorrect output on macOS ARM and multithreaded Windows ARM configurations.
Improvements
Measured on Snapdragon X Elite - X1E78100 - Qualcomm Oryon CPU
Accuracy level 4 (uses HQNBIT_CompInt8) vs Accuracy level 1 (uses HQNBIT_CompFp16)
Latest changes vs ORT 1.24.3 (both accuracy level 4)
On ORT 1.24.3: