
[Good first issue] Add Q4K support to MLX backend #20172

@metascroy

Description

🚀 The feature, motivation and pitch

Summary

Add native fused Q4_K Metal kernels for the MLX backend, matching the Q6_K
support added in #20004. Today, Q4_K linear/embedding ops are lowered by repacking
the GGUF blob into MLX's native affine 4-bit qparams at export time and calling
MLX's built-in quantized matmul / gather. We want Q4_K to instead read the raw
GGUF block_q4_K directly in fused custom Metal kernels, the same way Q6_K does.

Background

ExportableGGUFTensor (extension/llm/export/gguf.py) lowers a quantized
linear/embedding to a torchao::dequantize_gguf -> linear/embedding subgraph.
The MLX pattern handlers in
backends/mlx/custom_kernel_ops/gguf/patterns.py match that subgraph and lower
it without materializing the dequantized weight.
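One way to see how a fused kernel avoids materializing the weight is through the algebra of Q4_K's affine form. Each 32-weight sub-block dequantizes a 4-bit code q as d·sc·q − dmin·m, so a dot product with an activation slice x can accumulate sum(q·x) and sum(x) separately and apply the scale and min once per sub-block. The scalar sketch below assumes the sub-block (sc, m, codes) tuples are already unpacked. The function names and that representation are hypothetical illustrations, not the backend's API, and real Metal kernels would read the packed bytes and vectorize across SIMD lanes.

```python
import math
import random

# Hypothetical unpacked view of one Q4_K super-block (256 weights):
# d, dmin are the super-block scales (as Python floats here), and each of
# the 8 sub-blocks is (sc, m, codes) with 6-bit sc/m and 32 4-bit codes.

def dequantize_then_dot(d, dmin, subblocks, x):
    """Materialize all 256 dequantized weights, then take the dot product."""
    w = []
    for sc, m, codes in subblocks:
        w.extend(d * sc * q - dmin * m for q in codes)
    return sum(wi * xi for wi, xi in zip(w, x))

def fused_dot(d, dmin, subblocks, x):
    """Same result without building w: per sub-block, accumulate
    sum(q * x) and sum(x), then apply the scale and the min once."""
    acc = 0.0
    for j, (sc, m, codes) in enumerate(subblocks):
        xs = x[32 * j : 32 * (j + 1)]
        sum_qx = sum(q * xi for q, xi in zip(codes, xs))
        sum_x = sum(xs)
        acc += d * sc * sum_qx - dmin * m * sum_x
    return acc

rng = random.Random(0)
d, dmin = 0.01, 0.005
subblocks = [
    (rng.randrange(64), rng.randrange(64), [rng.randrange(16) for _ in range(32)])
    for _ in range(8)
]
x = [rng.uniform(-1.0, 1.0) for _ in range(256)]
print(math.isclose(fused_dot(d, dmin, subblocks, x),
                   dequantize_then_dot(d, dmin, subblocks, x),
                   rel_tol=1e-9, abs_tol=1e-9))
```

The two results agree up to floating-point rounding; the fused form does only two multiplies per sub-block for the scale and min instead of one per weight.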

The two formats are handled very differently today:

  • Q6_K → fused custom Metal kernels in
    backends/mlx/custom_kernel_ops/gguf/q6k/. A block_q6_K struct plus
    dequant helpers live in q6k/common.py (_Q6K_HEADER), and q6k/linear.py
    emits two kernels ported from llama.cpp:

    • M == 1 (decode): a fused mat-vec kernel (kernel_mul_mv_q6_K_f32_impl).
    • M > 1 (prefill): a tiled simdgroup mat-mat kernel (kernel_mul_mm).
    • dynamic M: both are emitted and selected at runtime via an IfNode.
    All of these kernels read the GGUF bytes directly and never repack.
  • Q4_K → backends/mlx/custom_kernel_ops/gguf/q4k/. Instead of custom
    kernels, q4k/common.py::_repack_mlx unpacks the GGUF blob and repacks it
    into MLX affine qparams (S*Q + B, group_size 32, 4-bit), and q4k/linear.py
    / q4k/embedding.py just emit a generic MLX QuantizedMatmulNode /
    quantized gather. This works but is a "rewrite to MLX quantized linear"
    rather than a true GGUF kernel: it requires an export-time repack and stores
    MLX-format constants instead of the original GGUF bytes.
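The current repack is possible because each 32-weight Q4_K sub-block is already affine. Below is a minimal sketch, with hypothetical function names, of mapping one sub-block's GGUF parameters to an MLX-style (S, B) pair with group_size 32. In exact arithmetic the two forms agree for every 4-bit code. Any mismatch in practice would presumably come from the dtype the repacked S and B are stored in, which is an assumption about _repack_mlx that this sketch does not model.

```python
# For one 32-element Q4_K sub-block j, GGUF dequantizes a 4-bit code q as
#     w = (d * sc_j) * q - (dmin * m_j)
# which is the affine form w = S * q + B with
#     S = d * sc_j,  B = -(dmin * m_j),  group_size = 32.

def q4k_subblock_to_affine(d, dmin, sc, m):
    """Hypothetical helper: map one Q4_K sub-block's (scale, min) to (S, B)."""
    return d * sc, -(dmin * m)

def check_equivalence(d, dmin, sc, m):
    """True if the affine form reproduces the GGUF formula for all 16 codes."""
    S, B = q4k_subblock_to_affine(d, dmin, sc, m)
    for q in range(16):  # every possible 4-bit code
        gguf = d * sc * q - dmin * m
        affine = S * q + B
        if gguf != affine:
            return False
    return True

# Sweep every 6-bit scale/min pair for one choice of super-block scales.
print(all(check_equivalence(0.01, 0.005, sc, m)
          for sc in range(64) for m in range(64)))
```

Reading block_q4_K directly keeps this same math but moves it into the kernel, so the export no longer needs to store S and B as separate MLX-format constants.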

Task

Implement fused Q4_K Metal kernels analogous to Q6_K so that Q4_K consumes the
raw block_q4_K directly, removing the dependency on the export-time
repack-to-MLX-qparams path.

Concretely:

  1. q4k/common.py — add a _Q4K_HEADER Metal header with the block_q4_K
    struct and dequant helpers (per-element for embedding, vectorized for
    matmul), plus QK_K / Q4K_BLOCK_BYTES constants. Port from llama.cpp
    dequantize_q4_K (ggml-common.h / ggml-metal.metal). Note that Q4_K's layout
    differs from Q6_K: it is affine, carrying both a super-block scale d and a
    super-block min scale dmin, with 6-bit packed sub-block scales/mins:

    #define QK_K 256
    #define K_SCALE_SIZE 12
    typedef struct {
        half     d;                    // super-block scale for the quantized scales
        half     dmin;                 // super-block scale for the quantized mins
        uint8_t  scales[K_SCALE_SIZE]; // 6-bit packed scales + mins
        uint8_t  qs[QK_K/2];           // 4-bit quants
    } block_q4_K;                      // 144 bytes
  2. q4k/linear.py — replace the _repack_mlx + QuantizedMatmulNode path
    with mat-vec (decode), mat-mat (prefill), and dynamic-M IfNode emission,
    mirroring q6k/linear.py (kernel_mul_mv_q4_K_f32_impl and the Q4_K
    kernel_mul_mm variant).

  3. q4k/embedding.py — replace the MLX quantized gather with a per-element
    Q4_K dequant gather, mirroring q6k/embedding.py.

  4. patterns.py — update the module docstrings/comments that currently say
    "Q4_K → MLX's native 4-bit affine ops" once the kernels land. (Dispatch is
    already keyed on ggml_type, so the handler wiring should need little
    change.)

  5. Remove the now-unused _repack_mlx helper if nothing else depends on it.
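As a reference while porting, here is a scalar Python model of the 144-byte block_q4_K, following llama.cpp's CPU reference (get_scale_min_k4 and dequantize_row_q4_K) as we understand it. It assumes little-endian fp16 d and dmin at byte offsets 0 and 2, the 12 packed scale bytes at offsets 4 to 15, and qs at offsets 16 to 143. The dequantize_* names are hypothetical. The Metal kernels would vectorize this logic, and the per-element variant illustrates the kind of access a Q4_K embedding gather needs.

```python
import struct

QK_K = 256
K_SCALE_SIZE = 12
Q4K_BLOCK_BYTES = 2 + 2 + K_SCALE_SIZE + QK_K // 2  # 144

def get_scale_min_k4(j, scales):
    """Unpack the 6-bit scale and min of sub-block j (0..7) from 12 packed bytes."""
    if j < 4:
        sc = scales[j] & 63
        m = scales[j + 4] & 63
    else:
        # Low 4 bits live in bytes 8..11; the top 2 bits come from bytes 0..7.
        sc = (scales[j + 4] & 0xF) | ((scales[j - 4] >> 6) << 4)
        m = (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4)
    return sc, m

def dequantize_block_q4_K(block):
    """Dequantize one 144-byte block_q4_K into 256 floats."""
    assert len(block) == Q4K_BLOCK_BYTES
    d, dmin = struct.unpack_from("<ee", block, 0)  # two little-endian fp16 values
    scales = block[4 : 4 + K_SCALE_SIZE]
    qs = block[4 + K_SCALE_SIZE :]
    out = []
    for chunk in range(QK_K // 64):  # 4 chunks of 64 weights, 32 qs bytes each
        q = qs[32 * chunk : 32 * (chunk + 1)]
        sc_lo, m_lo = get_scale_min_k4(2 * chunk, scales)
        sc_hi, m_hi = get_scale_min_k4(2 * chunk + 1, scales)
        out.extend(d * sc_lo * (b & 0xF) - dmin * m_lo for b in q)  # low nibbles
        out.extend(d * sc_hi * (b >> 4) - dmin * m_hi for b in q)   # high nibbles
    return out

def dequantize_element_q4_K(block, i):
    """Per-element variant: dequantize only weight i (0..255) of the block."""
    d, dmin = struct.unpack_from("<ee", block, 0)
    scales = block[4 : 4 + K_SCALE_SIZE]
    qs = block[4 + K_SCALE_SIZE :]
    chunk, l = divmod(i, 64)
    byte = qs[32 * chunk + (l % 32)]
    code = (byte & 0xF) if l < 32 else (byte >> 4)
    sc, m = get_scale_min_k4(2 * chunk + (l >= 32), scales)
    return d * sc * code - dmin * m

blk = struct.pack("<ee", 0.5, 0.25) + bytes(range(12)) + bytes(range(128))
w = dequantize_block_q4_K(blk)
print(w[0], w[32])  # -1.0 -1.25
```

Within each 64-weight chunk, the low nibbles of the 32 qs bytes form one sub-block and the high nibbles form the next. That split, along with the j >= 4 branch of the scale unpacking, is worth testing explicitly against a kernel's output.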

Testing

Tests already exist and exercise Q4_K — see
backends/mlx/custom_kernel_ops/gguf/test/test_linear.py (and
test_embedding.py). There is already a make_q4_k_blob fixture and Q4_K
configs in GGUFLinearTest.get_test_configs. The current reference
(_fp32_linear_reference) special-cases Q4_K to reconstruct the repacked MLX
qparams; once kernels read the raw blob, switch the Q4_K reference to the
gguf-exact dequant (weight.dequantize(torch.float32)), same as Q6_K.

Run on an Apple-silicon machine:

python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear run -v
python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_embedding run -v

Pointers

  • Reference implementation (Q6_K): backends/mlx/custom_kernel_ops/gguf/q6k/{common,linear,embedding}.py
  • Code to replace (Q4_K): backends/mlx/custom_kernel_ops/gguf/q4k/{common,linear,embedding}.py
  • Pattern handlers: backends/mlx/custom_kernel_ops/gguf/patterns.py
  • Tests: backends/mlx/custom_kernel_ops/gguf/test/test_linear.py, test_embedding.py
  • Reference PR: [MLX][Gemma4] Introduce Q6K kernels #20004 (Q6_K support)

Attribution

The Q4_K block layout and Metal dequant helpers should be ported from llama.cpp
(ggml-common.h / ggml-metal.metal: block_q4_K, dequantize_q4_K,
kernel_mul_mv_q4_K_f32_impl, kernel_mul_mm), which is MIT-licensed
(Copyright (c) 2023-2024 The ggml authors). Keep the inline "ported from ..."
notes, as in the Q6_K kernels.
