diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 7825d6585..09d2e2244 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -867,6 +867,9 @@ def _( else: raise RuntimeError(f"unsupported dtype {A.dtype}") + # Offset is expected to be a float32 tensor. + absmax_offset_f32 = absmax_offset.to(dtype=torch.float32) if absmax_offset is not None else None + with _cuda_device_of(A): fn( A.data_ptr(), @@ -874,7 +877,7 @@ def _( absmax.data_ptr(), absmax_8bit.data_ptr() if absmax_8bit is not None else None, absmax_code.data_ptr() if absmax_code is not None else None, - absmax_offset.data_ptr() if absmax_offset is not None else None, + absmax_offset_f32.data_ptr() if absmax_offset_f32 is not None else None, out.data_ptr(), bias.data_ptr() if bias is not None else None, M, diff --git a/tests/test_ops.py b/tests/test_ops.py index 3550c0b6f..69589dcc0 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -333,6 +333,47 @@ def test_gemm_4bit(self, device, dtype, quant_type, compress_statistics, has_bia kwargs={"bias": bias}, ) + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=describe_dtype) + @pytest.mark.parametrize("offset_dtype", [torch.float16, torch.bfloat16], ids=describe_dtype) + def test_gemm_4bit_non_float32_offset(self, device, dtype, offset_dtype): + """Regression test: offset tensors not in float32 must still produce correct results.""" + N, K, blocksize = 64, 64, 64 + A = torch.randn(4, K, dtype=dtype, device=device) + B = torch.randn(N, K, dtype=dtype, device=device) + B_q, qs = bitsandbytes.functional.quantize_4bit( + B, blocksize=blocksize, quant_type="nf4", compress_statistics=True + ) + + # Simulate a pre-quantized model where offset may not be float32. + offset_non_f32 = qs.offset.to(dtype=offset_dtype) + + # Reference: explicitly use the rounded float32 value. + offset_as_f32 = offset_non_f32.to(dtype=torch.float32) + ref = torch.ops.bitsandbytes.gemm_4bit.default( + A, + B_q, + list(B.shape), + qs.state2.absmax, + blocksize, + "nf4", + absmax_8bit=qs.absmax, + absmax_code=qs.state2.code, + absmax_offset=offset_as_f32, + ) + out = torch.ops.bitsandbytes.gemm_4bit.default( + A, + B_q, + list(B.shape), + qs.state2.absmax, + blocksize, + "nf4", + absmax_8bit=qs.absmax, + absmax_code=qs.state2.code, + absmax_offset=offset_non_f32, + ) + torch.testing.assert_close(out, ref) + class TestNonContiguousInputs: """Regression tests for #1342 and #1690: quantization must handle non-contiguous tensors correctly."""