Skip to content

fix: ensure absmax_offset is of type float32 before passing to gemm_4bit#1971

Open
kathsucurry wants to merge 2 commits into
bitsandbytes-foundation:mainfrom
kathsucurry:fix/gemm-4bit-absmax-offset-dtype
Open

fix: ensure absmax_offset is of type float32 before passing to gemm_4bit#1971
kathsucurry wants to merge 2 commits into
bitsandbytes-foundation:mainfrom
kathsucurry:fix/gemm-4bit-absmax-offset-dtype

Conversation

@kathsucurry

@kathsucurry kathsucurry commented Jun 11, 2026

Copy link
Copy Markdown

TLDR

gemm_4bit kernel functions expect a float32 absmax_offset, but the offsets from some pre-quantized models are bfloat16. The discrepancy leads to gibberish offset values in the kernel, which further affects the outputs.

Relevant package versions:

  • torch: 2.10.0+cu128
  • accelerate: 1.13.0
  • transformers: 5.3.0

Background

I was playing around with Unsloth_Puzzles.ipynb part B on RTX5070Ti (sm120) with the same model, unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit, when I noticed that the training loss values changed drastically after commit 5453368 "[CUDA] New 4bit GEMM kernels for inference (#1949)". Specifically, the loss starts off much larger and quickly becomes 0 (NaN values):

Step Training loss before 5453368 Training loss after 5453368
1 2.109375 7.015625
2 2.343750 7.828125
3 2.750000 11.593750
4 2.015625 0.000000
5 1.960938 0.000000
6 1.945312 0.000000
7 1.656250 0.000000
8 1.750000 0.000000
9 1.617188 0.000000
10 1.882812 0.000000

Note that the training loss before the mentioned commit is similar to that in the notebook.

By trying out different models, I found that the loss values show the same issue when 1) a pre-quantized base model is used and 2) gemm_4bit is called (when use_custom is True). Some of the models I used are as follows:

Model Is affected Is pre-quantized use_custom is True
unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit Yes Yes Yes
unsloth/gemma-3-270m-bnb-4bit Yes Yes Yes
unsloth/gemma-3-270m No No Yes
unsloth/Llama-3.2-3B-Instruct-bnb-4bit No Yes No

Findings

The bug appears to be caused by the type of absmax_offset. The absmax_offset values coming from the pre-quantized models I've used so far are bfloat16, while the kernel function expects a float32 tensor. Consequently, the offset value becomes gibberish and extremely large, which causes overflow.

I've made a small change along with the corresponding test. The training loss for all the models mentioned above is now more stable. The model unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit, for instance, now gives the following outputs:

Step Training loss before 5453368 Training loss after 5453368 Training loss after the bug fix
1 2.109375 7.015625 2.109375
2 2.343750 7.828125 2.335938
3 2.750000 11.593750 2.755859
4 2.015625 0.000000 2.013672
5 1.960938 0.000000 1.948242
6 1.945312 0.000000 1.929688
7 1.656250 0.000000 1.637695
8 1.750000 0.000000 1.728516
9 1.617188 0.000000 1.561523
10 1.882812 0.000000 1.808594

I've only tried Unsloth's pre-quantized models, so do let me know if there are other pre-quantized models you'd like me to try (or if there are other tests you'd like me to do).

@matthewdouglas

Copy link
Copy Markdown
Member

Thanks for catching this! Out of curiosity, can you please provide some more detail on your environment? Specifically versions of torch, accelerate, and transformers.

This seems reasonable as a defensive measure, but I don't think the root issue is in the serialized models themselves. There's a quirk in the 4bit serialization where this offset value is actually stored in a utf-8 encoded uint8 tensor as a JSON string. When deserialized, it uses the Python float type:

if "nested_absmax" in qs_dict:
offset = torch.tensor(float(qs_dict["nested_offset"])).to(device)
state2 = cls(

I can see how in some circumstance (e.g. torch.set_default_dtype() is applied in context where we're deserializing) that this may materialize as a fp16 or bf16 tensor. I think this may be the more likely cause, so I may open a separate PR here to force that to be fp32.

The other possibility is QuantState.to() which is supposed to mostly only handle device movement, not casting. It's called when a weight (Params4bit.to()) is moved. It is possible this may be casting the offset if it is called as weight.to(torch.bfloat16), so I may also address this in a PR as well.

One last, but less likely, possibility here is that you have an AVX512BF16 enabled CPU, and moved between CPU and GPU at some point after running a forward pass on CPU. This may have inadvertently cast the offset.

Since you only observe the issue with pre-quantized models, it leans heavily toward being related to the deserialization.

@github-actions

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@matthewdouglas matthewdouglas added this to the v0.50.0 milestone Jun 11, 2026
Comment thread tests/test_ops.py Outdated
@kathsucurry

Copy link
Copy Markdown
Author

Ah I should have included the versions in the PR description. Below are the versions (and I'll add them in the description as well):

  • torch: 2.10.0+cu128
  • accelerate: 1.13.0
  • transformers: 5.3.0

Thank you so much for the comments! The detailed breakdown gives really helpful context on how the offset gets stored and deserialized. I still lack context on the serialization/deserialization path or QuantState.to(), so I'll be happy to look into the possibilities you outlined if you'd like me to!

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants