Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion bitsandbytes/backends/cuda/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,14 +867,17 @@ def _(
else:
raise RuntimeError(f"unsupported dtype {A.dtype}")

# Offset is expected to be a float32 tensor.
absmax_offset_f32 = absmax_offset.to(dtype=torch.float32) if absmax_offset is not None else None

with _cuda_device_of(A):
fn(
A.data_ptr(),
B.data_ptr(),
absmax.data_ptr(),
absmax_8bit.data_ptr() if absmax_8bit is not None else None,
absmax_code.data_ptr() if absmax_code is not None else None,
absmax_offset.data_ptr() if absmax_offset is not None else None,
absmax_offset_f32.data_ptr() if absmax_offset_f32 is not None else None,
out.data_ptr(),
bias.data_ptr() if bias is not None else None,
M,
Expand Down
41 changes: 41 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,47 @@ def test_gemm_4bit(self, device, dtype, quant_type, compress_statistics, has_bia
kwargs={"bias": bias},
)

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("offset_dtype", [torch.float16, torch.bfloat16], ids=describe_dtype)
def test_gemm_4bit_non_float32_offset(self, device, dtype, offset_dtype):
    """Regression test: a non-float32 absmax offset must match its float32 upcast.

    Quantizes a random weight with NF4 and compressed statistics, then runs
    ``gemm_4bit`` twice: once with the offset cast to ``offset_dtype`` and once
    with that same (already rounded) offset converted to float32. Both outputs
    are compared with ``torch.testing.assert_close``.
    """
    out_features = in_features = block = 64
    activations = torch.randn(4, in_features, dtype=dtype, device=device)
    weight = torch.randn(out_features, in_features, dtype=dtype, device=device)
    packed, state = bitsandbytes.functional.quantize_4bit(
        weight, blocksize=block, quant_type="nf4", compress_statistics=True
    )

    def run_gemm(offset):
        # Only the offset tensor differs between the two invocations below.
        return torch.ops.bitsandbytes.gemm_4bit.default(
            activations,
            packed,
            list(weight.shape),
            state.state2.absmax,
            block,
            "nf4",
            absmax_8bit=state.absmax,
            absmax_code=state.state2.code,
            absmax_offset=offset,
        )

    # Mimic a pre-quantized checkpoint whose offset is stored in a non-float32 dtype.
    low_precision_offset = state.offset.to(dtype=offset_dtype)

    # The reference uses the same rounded value, explicitly widened to float32,
    # so any difference comes from how the op handles the dtype, not from rounding.
    widened_offset = low_precision_offset.to(dtype=torch.float32)

    expected = run_gemm(widened_offset)
    actual = run_gemm(low_precision_offset)
    torch.testing.assert_close(actual, expected)


class TestNonContiguousInputs:
"""Regression tests for #1342 and #1690: quantization must handle non-contiguous tensors correctly."""
Expand Down
Loading