
Add fp16 support for 8-bit MatMulNBits on ARM64 and fix pre-existing bugs #27692

Merged
jambayk merged 12 commits into main from jambayk/mnb-arm-16 on Mar 18, 2026

Conversation

@jambayk (Contributor) commented Mar 17, 2026

Description

This PR adds fp16 (half-precision) support for 8-bit MatMulNBits on ARM64 NEON and fixes several pre-existing bugs discovered during testing.

New features:

  • HQNBIT_CompFp16 for 8-bit: Added HQ8BitGemmPackQuantBData_CompFp16 and HQ8BitBlkDequantBForHgemm_CompFp16 NEON kernels that pack and dequantize 8-bit quantized weights for fp16 GEMM. Reuses the existing HQ4BitGemmKernel_CompFp16 for the actual compute since the dequantized B matrix has the same layout (a simplified dequantization sketch follows this list).
  • HQNBIT_CompInt8 for 4-bit: Added accuracy level 4 (int8 compute) support for fp16 4-bit MatMulNBits. Converts fp16 activations to fp32, then uses the existing SQ4Bit int8 kernels.
  • HQNBIT_CompInt8 for 8-bit: Added accuracy level 4 (int8 compute) support for fp16 8-bit MatMulNBits. Converts fp16 scales to fp32 for packing, then uses the existing SQ8Bit int8 kernels.
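For readers unfamiliar with the block-quantized format, here is a minimal scalar sketch of the dequantization the CompFp16 path performs on 8-bit weights. The layout, helper name, and default zero point of 128 are assumptions for illustration; the real work is done by the NEON HQ8BitBlkDequantBForHgemm_CompFp16 kernel on a packed, interleaved layout and in fp16 registers.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar reference for block-wise dequantization of 8-bit MatMulNBits weights.
// Assumed (hypothetical) layout: B is K x N stored column-major, with one scale
// and one optional zero point per block_size-long block of a column; when no
// zero points are supplied, 128 (the midpoint of uint8) is assumed. float is
// used here only to keep the sketch portable; the NEON kernel stays in fp16.
std::vector<float> DequantizeB(const std::vector<uint8_t>& quant_b,      // K * N bytes
                               const std::vector<float>& scales,         // blocks_per_col * N
                               const std::vector<uint8_t>& zero_points,  // same shape as scales, may be empty
                               size_t K, size_t N, size_t block_size) {
  const size_t blocks_per_col = (K + block_size - 1) / block_size;
  std::vector<float> b(K * N);
  for (size_t n = 0; n < N; ++n) {
    for (size_t k = 0; k < K; ++k) {
      const size_t blk = k / block_size;
      const float scale = scales[n * blocks_per_col + blk];
      const int zp = zero_points.empty() ? 128 : zero_points[n * blocks_per_col + blk];
      // dequantized value = scale * (quantized - zero_point)
      b[n * K + k] = scale * static_cast<float>(static_cast<int>(quant_b[n * K + k]) - zp);
    }
  }
  return b;
}
```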

Bug fixes:

  • Bias offset bug in CompFp16 (Windows ARM multithreading): Fixed missing + RangeStartN when initializing the Bias pointer in HQ4BitGemm_CompFp16 and HQ8BitGemm_CompFp16. This caused incorrect results when using multiple threads, as worker threads processing column ranges beyond the first would read bias values from the wrong offset (a simplified sketch follows this list).
  • QuantBDataWorkspace not set for MLFloat16 fallback (macOS ARM crash): Removed #ifdef MLAS_TARGET_AMD64_IX86 guard around setting QuantBDataWorkspace in ComputeBPacked<MLFloat16>, so macOS ARM (which uses the fp32 fallback path) correctly sets the workspace pointer for SQNBIT_CompInt8.
  • Scale/ZP packing skipped on non-x64 in MLFloat16 PrePack (macOS ARM gibberish): Removed #ifdef MLAS_TARGET_AMD64_IX86 guard around the SQNBIT_CompInt8 scale and zero-point packing in the MatMulNBits<MLFloat16>::PrePack specialization. Added nbits_ == 8 condition to match the generic template's behavior on ARM (only 8-bit needs separate scale packing on ARM, while x64 needs it for both 4-bit and 8-bit).
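A minimal sketch of the bias-offset issue from the first bug fix above, with hypothetical names and a simplified column-major B / row-major C layout. The one-line point: a worker handling the column range starting at RangeStartN must advance the Bias pointer by RangeStartN, just as it advances the output pointer.

```cpp
#include <cstddef>

// Hypothetical per-thread worker: it owns the output columns
// [RangeStartN, RangeStartN + RangeCountN). B is column-major, C is row-major
// with leading dimension ldc. The bug was that Bias was not advanced by
// RangeStartN, so every worker beyond the first read bias values meant for
// columns it was not computing.
void GemmWorker(const float* A, const float* B, const float* Bias, float* C,
                size_t M, size_t K, size_t ldc,
                size_t RangeStartN, size_t RangeCountN) {
  const float* bias_for_range = (Bias != nullptr) ? Bias + RangeStartN : nullptr;  // the fix
  float* c_for_range = C + RangeStartN;
  for (size_t m = 0; m < M; ++m) {
    for (size_t n = 0; n < RangeCountN; ++n) {
      float acc = (bias_for_range != nullptr) ? bias_for_range[n] : 0.0f;
      for (size_t k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[(RangeStartN + n) * K + k];
      }
      c_for_range[m * ldc + n] = acc;
    }
  }
}
```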

Motivation and Context

8-bit quantized models with fp16 inputs are increasingly common on ARM devices (Windows ARM, macOS Apple Silicon). The existing MatMulNBits implementation only supported 4-bit for the HQNBIT fp16 paths. This change extends support to 8-bit, enabling faster inference for 8-bit quantized models on ARM64 without requiring fp16→fp32 conversion of the weights.

The bug fixes address issues that were either pre-existing (the #ifdef guards were copy-paste inconsistencies from prior PRs) or introduced alongside the fp16 NEON support (the Bias offset issue). These caused crashes or incorrect output on macOS ARM and multithreaded Windows ARM configurations.

Improvements

Measured on Snapdragon X Elite - X1E78100 - Qualcomm Oryon CPU

Accuracy level 4 (uses HQNBIT_CompInt8) vs Accuracy level 1 (uses HQNBIT_CompFp16)

| Model | Seq 1 | Seq 256 | Seq 512 |
|---|---|---|---|
| **4-bit** | | | |
| Qwen 0.5B | 1.19× (9.6ms) | 1.36× (428ms) | 1.27× (1119ms) |
| Qwen 1.5B | 0.89× (39.8ms) | 1.62× (1371ms) | 1.54× (2694ms) |
| Qwen 3B | 1.16× (46.8ms) | 1.54× (2654ms) | 1.43× (5427ms) |
| **8-bit** | | | |
| Qwen 0.5B | 0.79× (22.5ms) | 2.59× (257ms) | 2.16× (642ms) |
| Qwen 1.5B | 1.14× (41.4ms) | 2.50× (848ms) | 2.55× (1636ms) |
| Qwen 3B | 1.07× (52.9ms) | 1.95× (2133ms) | 2.29× (3799ms) |

Latest changes vs ORT 1.24.3 (both accuracy level 4)

On ORT 1.24.3:

  • 4-bit uses HQNBIT_CompFp16
  • 8-bit uses a naive unpacked dequantize-and-matmul path
| Model | Seq 1 | Seq 256 | Seq 512 |
|---|---|---|---|
| **4-bit** | | | |
| Qwen 0.5B | 1.13× (9.6ms) | 1.35× (428ms) | 1.27× (1119ms) |
| Qwen 1.5B | 0.82× (39.8ms) | 1.40× (1371ms) | 1.47× (2694ms) |
| Qwen 3B | 1.16× (46.8ms) | 1.47× (2654ms) | 1.51× (5427ms) |
| **8-bit** | | | |
| Qwen 0.5B | 35.4× (22.5ms) | 5.0× (257ms) | 3.2× (642ms) |
| Qwen 1.5B | 98.0× (41.4ms) | 6.8× (848ms) | 4.7× (1636ms) |
| Qwen 3B | 107.8× (52.9ms) | 4.1× (2133ms) | 3.1× (3799ms) |

jambayk added 6 commits March 16, 2026 19:24
Enable native fp16 GEMM (HQNBIT_CompFp16) for 8-bit quantized weights on
ARM64 with NEON fp16 intrinsics. Previously only 4-bit weights had the
fp16 compute path; 8-bit fp16 inputs fell back to the slow ComputeBUnpacked
path with multiple precision conversions.

Key changes:
- Add HQ8BitGemmVariant_CompFp16 variant and wire up dispatch/availability
- Implement 8N-interleaved B data packing (HQ8BitGemmPackQuantBData_CompFp16)
  using two Transpose8x8 operations per 16K x 8N tile
- Implement NEON fp16 dequant kernel (HQ8BitBlkDequantBForHgemm_CompFp16)
  that loads interleaved uint8 data, widens to fp16, and applies per-column
  scale/zero-point via FMA
- Reuse existing HQ4BitGemmKernel_CompFp16 for the GEMM step since it
  operates on dequantized fp16 data and is bit-width agnostic
- Add MLAS-level unit tests for 8-bit prepack and dequant
- Add operator-level test for Float16_8b_ARM_CompFp16
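A scalar sketch of the 8N-interleaved packing described in this commit message, using hypothetical names and assuming N is a multiple of 8; the actual HQ8BitGemmPackQuantBData_CompFp16 kernel produces this layout with NEON Transpose8x8 operations on 16K x 8N tiles and also handles tail columns.

```cpp
#include <cstddef>
#include <cstdint>

// Repack column-major quantized B (K x N, one byte per weight) so that each
// group of 8 columns is interleaved row by row: for a given k, the 8 column
// values are stored contiguously. The dequant kernel can then load 8 output
// columns with one contiguous vector load per k.
// Hypothetical simplification: assumes N is a multiple of 8.
void PackB8NInterleaved(const uint8_t* src, uint8_t* dst, size_t K, size_t N) {
  for (size_t n0 = 0; n0 < N; n0 += 8) {          // one 8-column group at a time
    uint8_t* tile = dst + n0 * K;                 // each group occupies 8 * K bytes
    for (size_t k = 0; k < K; ++k) {
      for (size_t j = 0; j < 8; ++j) {
        tile[k * 8 + j] = src[(n0 + j) * K + k];  // gather one value from each column
      }
    }
  }
}
```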
Copilot AI left a comment

Pull request overview

This PR extends MLAS/MatMulNBits support on ARM64 by enabling an fp16-native (HQNBIT_CompFp16) path for 8-bit quantized weights and adds HQNBIT_CompInt8 variants for fp16 inputs, alongside several platform-specific packing/workspace fixes and new test coverage.

Changes:

  • Add ARM64 NEON kernels to prepack (8N interleave) and dequantize 8-bit quantized B for fp16 GEMM, reusing the existing fp16 GEMM microkernel.
  • Add HQNBIT_CompInt8 support paths (4-bit and 8-bit) and update workspace sizing/alignment/variant dispatch accordingly.
  • Add unit/integration tests and wire new kernel source into the build.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

| File | Description |
|---|---|
| onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16_8bit.cpp | New ARM64 NEON implementation for 8-bit B prepack + fp16 dequant used by HQNBIT_CompFp16 |
| onnxruntime/core/mlas/lib/qnbitgemm.cpp | Adds HQ8Bit variants (CompFp16/CompInt8), integrates packing dispatch, and fixes bias offset for fp16 HQ GEMM |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp | Extends workspace sizing/alignment for HQNBIT_CompInt8 and wires new HQ8Bit fp16 dispatch |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h | Declares new HQ8Bit fp16 pack/dequant entry points |
| onnxruntime/core/mlas/lib/qnbitgemm.h | Extends dispatch struct with HQ8Bit pack/dequant pointers |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Fixes workspace/scales/ZP packing guards and adds HQ8Bit CompInt8 fp16 handling |
| onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp | Adds targeted NEON fp16 8-bit prepack and dequant tests |
| onnxruntime/test/contrib_ops/matmul_8bits_test.cc | Adds an ARM64-only fp16 8-bit test for the HQNBIT_CompFp16 path |
| cmake/onnxruntime_mlas.cmake | Adds the new NEON source file to MLAS build for ARM64 targets |


Copilot AI left a comment

Pull request overview

This PR extends MLAS MatMulNBits support on ARM64 by adding native fp16 (HQNBIT_CompFp16) handling for 8-bit quantized weights, adds fp16 “accuracy level 4” (HQNBIT_CompInt8) paths for 4-bit and 8-bit, and fixes a few platform-specific bugs in MatMulNBits prepack/compute plumbing.

Changes:

  • Added ARM64 NEON kernels to pack and dequantize 8-bit block-quantized B for fp16 GEMM (reusing the existing HQ4Bit fp16 GEMM kernel for compute).
  • Added HQNBIT_CompInt8 support paths (accuracy level 4) for fp16 MatMulNBits, including workspace sizing/alignment and dispatch integration.
  • Added/updated unit tests and fixed pre-existing issues around bias offsets and packing/workspace setup on non-x64.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.

| File | Description |
|---|---|
| onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp | Adds 8-bit prepack and dequant unit tests for ARM64 fp16 HQ paths. |
| onnxruntime/test/contrib_ops/matmul_8bits_test.cc | Adds an ARM64-only fp16 HQNBIT_CompFp16 test for 8-bit MatMulNBits. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h | Declares new ARM64 NEON 8-bit HQ fp16 pack/dequant entry points. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp | Updates workspace sizing/alignment and NEON dispatch table for new compute types and 8-bit HQ fp16 hooks. |
| onnxruntime/core/mlas/lib/qnbitgemm.h | Extends dispatch struct with HQ 8-bit pack/dequant function pointers. |
| onnxruntime/core/mlas/lib/qnbitgemm.cpp | Adds new HQ 8-bit variants, implements HQ 8-bit fp16/int8 compute, and fixes bias offset bug in fp16 HQ kernels. |
| onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16_8bit.cpp | New NEON implementation for 8-bit B packing (8N interleave) and dequant-to-fp16 for HQNBIT_CompFp16. |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Adjusts PrePack/ComputeBPacked logic for HQ int8 + fp16 workflows and fixes prepack/workspace guards on non-x64. |
| cmake/onnxruntime_mlas.cmake | Wires the new NEON fp16 8-bit kernel source into builds. |


Copilot AI left a comment

Pull request overview

This PR extends MLAS MatMulNBits on ARM64 by adding native fp16 (HQNBIT_CompFp16) support for 8-bit quantized weights, and adds/adjusts HQNBIT_CompInt8 plumbing for fp16 inputs (notably for 8-bit), alongside fixes for packing/workspace issues that affected non-x64 platforms and multithreaded bias handling.

Changes:

  • Add ARM64 NEON kernels to pack (8N interleave) and dequantize 8-bit B data to fp16 for HQNBIT_CompFp16, reusing the existing fp16 GEMM microkernel.
  • Extend QNBitGemm dispatch/variant selection and workspace logic to recognize HQ 8-bit variants (fp16 + int8 compute paths).
  • Add targeted unit tests for the new kernels and an ARM64 fp16 MatMulNBits test case; adjust MatMulNBits prepack/workspace behaviors for non-x64 paths.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.

| File | Description |
|---|---|
| onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp | Adds fp16 NEON unit tests for 8-bit prepack and dequant, and improves fp16 compare logic. |
| onnxruntime/test/contrib_ops/matmul_8bits_test.cc | Adds an ARM64-only fp16 native GEMM (accuracy level 2 / HQNBIT_CompFp16) 8-bit MatMulNBits test. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h | Declares new HQ 8-bit fp16 pack/dequant entry points for NEON dispatch. |
| onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp | Wires new HQ 8-bit fp16 functions into the NEON QNBitGemm dispatch; updates workspace sizing for HQNBIT_CompInt8. |
| onnxruntime/core/mlas/lib/qnbitgemm.h | Extends the dispatch struct with HQ8Bit pack/dequant function pointers. |
| onnxruntime/core/mlas/lib/qnbitgemm.cpp | Adds HQ8Bit variants, availability checks, packing routes, and new HQ8Bit GEMM implementations for fp16 and int8 compute. |
| onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16_8bit.cpp | New: implements 8-bit B packing (8N interleave) and dequant-to-fp16 for HQNBIT_CompFp16 on ARM64 NEON. |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Adjusts PrePack/ComputeBPacked logic for HQ/SQ int8 paths and non-x64 scale/zp packing behavior. |
| cmake/onnxruntime_mlas.cmake | Adds the new NEON fp16 8-bit kernel source file to MLAS build lists and compile flags. |


@hariharans29 (Member) commented:

> Added accuracy level 4 (int8 compute) support for fp16 8-bit MatMulNBits. Converts fp16 scales to fp32 for packing, then uses the existing SQ8Bit int8 kernels.

I am a little confused by this (probably my understanding of the kernels is not super good :)). But should this statement read: converts fp16 activations to fp32 and then uses the existing SQ8Bit int8 kernels, along with scales conversion from fp16 to fp32?

@jambayk (Contributor, Author) commented Mar 17, 2026

> Added accuracy level 4 (int8 compute) support for fp16 8-bit MatMulNBits. Converts fp16 scales to fp32 for packing, then uses the existing SQ8Bit int8 kernels.
>
> I am a little confused by this (probably my understanding of the kernels is not super good :)). But should this statement read: converts fp16 activations to fp32 and then uses the existing SQ8Bit int8 kernels, along with scales conversion from fp16 to fp32?

Yes, the activations are also converted to fp32. The scales are pre-converted to fp32 during prepack.

The OpTester destructor asserts testing_function_called_ in Debug
builds (raises SIGTRAP if false). RunTest8Bits for MLFloat16 on CPU
was constructing the OpTester but never calling test.RunWithConfig(),
causing the SIGTRAP.

Add MLFloat16 path in the CPU #else branch that checks
MlasIsQNBitGemmAvailable(8, 32, HQNBIT_CompFp16) and runs the test
when the fp16 GEMM kernels are available.
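A standalone sketch of the failure mode and the guard pattern, with hypothetical names (this is not the actual OpTester code); per the commit message, the real fix checks MlasIsQNBitGemmAvailable(8, 32, HQNBIT_CompFp16) before constructing and running the tester.

```cpp
#include <cassert>

// Minimal stand-in for the test helper: in Debug builds the destructor asserts
// that the test was actually executed, so constructing a tester and never
// running it trips the assertion (SIGTRAP).
struct MiniTester {
  bool run_called = false;
  void RunWithConfig() { run_called = true; /* execute the configured test */ }
  ~MiniTester() { assert(run_called && "tester constructed but never run"); }
};

// The guard pattern: only build and run the fp16 test when the fp16 GEMM
// kernel is available, and never construct a tester that will not be run.
void RunFp16TestIfAvailable(bool fp16_kernel_available) {
  if (!fp16_kernel_available) {
    return;  // skip instead of constructing an unused tester
  }
  MiniTester test;
  // ... add inputs, attributes, expected outputs ...
  test.RunWithConfig();
}

int main() {
  RunFp16TestIfAvailable(true);
  RunFp16TestIfAvailable(false);
  return 0;
}
```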
@hariharans29 (Member) commented:

Probably related to this PR: #27251

Address reviewer feedback:
1. Add Float16_8b_ARM_CompInt8 test (accuracy_level=4, HQNBIT_CompInt8)
   for fp16 A x Int8 B on ARM64. This path was untested.
2. Expand Float16_8b_ARM_CompFp16 from 9 to 36 test cases to match the
   coverage of Float32_8b_AccuracyLevel4: small/large N (1-288), large K
   (1024, 1234), block_size 16/32/64/128, K not divisible by block_size
   (260/32), M=1/2/100/199.
3. Fix RunTest8Bits CPU path to select the correct compute type based on
   accuracy_level (HQNBIT_CompInt8 for level 4, HQNBIT_CompFp16 otherwise)
   to avoid SIGTRAP when adding the CompInt8 test.
hariharans29 previously approved these changes Mar 17, 2026
jambayk enabled auto-merge (squash) March 17, 2026 21:45
fp16 accumulation error grows with K. The previous abs_error=0.055
was too tight for K=1024/1234/3072 cases. Align with existing
conventions:
- CompFp16: abs=0.1 (matches 4-bit fp16 tests)
- CompInt8: abs=0.1*1.02 (matches fp32 int8 tests)
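For context, a rough first-order bound (standard rounding-error analysis, not something measured in the PR; the kernels' actual accumulation order may give smaller error): with fp16 unit roundoff u = 2⁻¹¹ ≈ 4.9×10⁻⁴, the worst-case error of a length-K dot product accumulated in fp16 grows roughly linearly in K (typical error behaves more like √K·u), which is why a tolerance tuned for small K becomes too tight at K = 1024–3072:

$$
\bigl|\widehat{s} - s\bigr| \;\lesssim\; K\,u \sum_{k=1}^{K} |a_k\,b_k|, \qquad s=\sum_{k=1}^{K} a_k b_k .
$$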
jambayk requested a review from hariharans29 March 17, 2026 22:16
jambayk merged commit c1f38c0 into main Mar 18, 2026
91 checks passed
jambayk deleted the jambayk/mnb-arm-16 branch March 18, 2026 07:49