diff --git a/.github/scripts/build-cpu.sh b/.github/scripts/build-cpu.sh index 5daeb5ea5..5db76ecce 100644 --- a/.github/scripts/build-cpu.sh +++ b/.github/scripts/build-cpu.sh @@ -4,7 +4,11 @@ declare build_os set -xeuo pipefail -pip install cmake==3.28.3 +if [[ "${build_os}" == windows* ]]; then + pip install cmake==3.30.9 +else + pip install cmake==3.28.3 +fi if [ "${build_os:0:5}" == macos ] && [ "${build_arch}" == aarch64 ]; then cmake -DCMAKE_OSX_ARCHITECTURES=arm64 -DCOMPUTE_BACKEND=cpu . diff --git a/CMakeLists.txt b/CMakeLists.txt index a787866f6..3d420edb1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -118,6 +118,11 @@ if (BUILD_CPU) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) string(TOLOWER "${CMAKE_SYSTEM_PROCESSOR}" HOST_ARCH) + if(MSVC) + # Use the experimental OpenMP runtime for persistent thread pool support. + # Requires CMake 3.30+; silently ignored on older CMake versions. + set(OpenMP_RUNTIME_MSVC "experimental") + endif() find_package(OpenMP) endif() @@ -350,35 +355,52 @@ set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) add_library(bitsandbytes SHARED ${SRC_FILES}) target_compile_features(bitsandbytes PUBLIC cxx_std_17) target_include_directories(bitsandbytes PUBLIC csrc) +set_target_properties(bitsandbytes PROPERTIES VISIBILITY_INLINES_HIDDEN ON) if (BUILD_CPU) + include(CheckIPOSupported) + check_ipo_supported(RESULT ipo_supported OUTPUT ipo_output) + if (ipo_supported AND NOT MSVC) + set_property(TARGET bitsandbytes PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) + endif() + if (OpenMP_CXX_FOUND) target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX) add_definitions(-DHAS_OPENMP) endif() - if ((HOST_ARCH MATCHES "x86_64|amd64") AND (NOT MSVC)) - include(CheckCXXCompilerFlag) - check_cxx_compiler_flag(-mavx512f HAS_AVX512F_FLAG) - check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG) - if (HAS_AVX512F_FLAG) - target_compile_options(bitsandbytes PRIVATE -mavx512f) - target_compile_options(bitsandbytes PRIVATE -mavx512dq) - target_compile_options(bitsandbytes PRIVATE -mavx512bw) - target_compile_options(bitsandbytes PRIVATE -mavx512vl) + if (NOT MSVC) + if (CMAKE_SYSTEM_NAME STREQUAL "Linux") + target_compile_options(bitsandbytes PRIVATE -fno-semantic-interposition) endif() - if (HAS_AVX512BF16_FLAG) - target_compile_options(bitsandbytes PRIVATE -mavx512bf16) + + if (HOST_ARCH MATCHES "x86_64|amd64") + include(CheckCXXCompilerFlag) + check_cxx_compiler_flag(-mavx512f HAS_AVX512F_FLAG) + check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG) + if (HAS_AVX512F_FLAG) + target_compile_options( + bitsandbytes PRIVATE + -mavx512f + -mavx512bw + -mavx512dq + -mavx512vl + ) + endif() + if (HAS_AVX512BF16_FLAG) + target_compile_options(bitsandbytes PRIVATE -mavx512bf16) + endif() + target_compile_options( + bitsandbytes PRIVATE + -mprefer-vector-width=256 + -mfma + -mavx2 + -mf16c + -mlzcnt + -mbmi + -mbmi2 + ) endif() - target_compile_options( - bitsandbytes PRIVATE - -mprefer-vector-width=256 - -mfma - -mavx2 - -mlzcnt - -mbmi - -mbmi2 - ) endif() endif() diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index ed6803eda..597511c4b 100755 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -20,8 +20,10 @@ # However, we can overflow if we use this without AVX512_VNNI support. # This is fixed in torch 2.6+, so we set this as the minimum to be safe. # For more information: https://github.com/pytorch/pytorch/pull/136942 -# TODO(matthewdouglas): aarch64? -if torch.__version__ >= (2, 6): +# +# Without AVX-512 (including aarch64), torch._int_mm uses a scalar fallback +# that is much slower than fp32 matmul. Only use it when AVX-512 is available. +if torch.__version__ >= (2, 6) and _has_avx512: @register_kernel("bitsandbytes::int8_linear_matmul", "cpu") def _(A: torch.Tensor, B: torch.Tensor): diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 2a8912674..dfb9046ac 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -102,42 +102,28 @@ static inline void // interleaved.val[0] has elements 0-7, interleaved.val[1] has elements 8-15 uint8x16_t indices = vcombine_u8(interleaved.val[0], interleaved.val[1]); - // Use flat LUT for fast indexed access - // Store LUT as flat float array on stack (likely in L1 cache) - float flat_lut[16]; - vst1q_f32(flat_lut, lut[0]); - vst1q_f32(flat_lut + 4, lut[1]); - vst1q_f32(flat_lut + 8, lut[2]); - vst1q_f32(flat_lut + 12, lut[3]); - - // Extract indices and do lookups in groups of 4 for NEON multiply - uint8_t idx_arr[16]; - vst1q_u8(idx_arr, indices); - + // Reinterpret float LUT as 64-byte table for vqtbl4q_u8 lookup. + // Each 4-bit index i maps to bytes [i*4 .. i*4+3] in the table. + uint8x16x4_t lut_bytes = { + vreinterpretq_u8_f32(lut[0]), vreinterpretq_u8_f32(lut[1]), vreinterpretq_u8_f32(lut[2]), + vreinterpretq_u8_f32(lut[3]) + }; + // Multiply each index by 4 to get byte offset (max 15*4=60 < 64, safe) + uint8x16_t base = vshlq_n_u8(indices, 2); + // Expand each base offset to 4 consecutive bytes via zip + const uint8x16_t off = vreinterpretq_u8_u32(vdupq_n_u32(0x03020100)); + uint8x8_t lo = vget_low_u8(base), hi = vget_high_u8(base); + uint8x8x2_t z0 = vzip_u8(lo, lo); + uint8x8x2_t z1 = vzip_u8(hi, hi); + uint8x8x2_t zlo = vzip_u8(z0.val[0], z0.val[0]); + uint8x8x2_t zhi = vzip_u8(z0.val[1], z0.val[1]); + uint8x8x2_t zlo2 = vzip_u8(z1.val[0], z1.val[0]); + uint8x8x2_t zhi2 = vzip_u8(z1.val[1], z1.val[1]); float32x4_t vscale = vdupq_n_f32(scale); - - // Process 4 values at a time with NEON - load from temp buffer - float tmp_vals[16]; - tmp_vals[0] = flat_lut[idx_arr[0]]; - tmp_vals[1] = flat_lut[idx_arr[1]]; - tmp_vals[2] = flat_lut[idx_arr[2]]; - tmp_vals[3] = flat_lut[idx_arr[3]]; - tmp_vals[4] = flat_lut[idx_arr[4]]; - tmp_vals[5] = flat_lut[idx_arr[5]]; - tmp_vals[6] = flat_lut[idx_arr[6]]; - tmp_vals[7] = flat_lut[idx_arr[7]]; - tmp_vals[8] = flat_lut[idx_arr[8]]; - tmp_vals[9] = flat_lut[idx_arr[9]]; - tmp_vals[10] = flat_lut[idx_arr[10]]; - tmp_vals[11] = flat_lut[idx_arr[11]]; - tmp_vals[12] = flat_lut[idx_arr[12]]; - tmp_vals[13] = flat_lut[idx_arr[13]]; - tmp_vals[14] = flat_lut[idx_arr[14]]; - tmp_vals[15] = flat_lut[idx_arr[15]]; - float32x4_t v0 = vld1q_f32(tmp_vals); - float32x4_t v1 = vld1q_f32(tmp_vals + 4); - float32x4_t v2 = vld1q_f32(tmp_vals + 8); - float32x4_t v3 = vld1q_f32(tmp_vals + 12); + float32x4_t v0 = vreinterpretq_f32_u8(vqtbl4q_u8(lut_bytes, vaddq_u8(vcombine_u8(zlo.val[0], zlo.val[1]), off))); + float32x4_t v1 = vreinterpretq_f32_u8(vqtbl4q_u8(lut_bytes, vaddq_u8(vcombine_u8(zhi.val[0], zhi.val[1]), off))); + float32x4_t v2 = vreinterpretq_f32_u8(vqtbl4q_u8(lut_bytes, vaddq_u8(vcombine_u8(zlo2.val[0], zlo2.val[1]), off))); + float32x4_t v3 = vreinterpretq_f32_u8(vqtbl4q_u8(lut_bytes, vaddq_u8(vcombine_u8(zhi2.val[0], zhi2.val[1]), off))); vst1q_f32(out, vmulq_f32(v0, vscale)); vst1q_f32(out + 4, vmulq_f32(v1, vscale)); @@ -173,27 +159,59 @@ static inline void neon_f32_to_fp16x4(const float32x4_t src, fp16_t* dst) { vst1_u16(reinterpret_cast(dst), vreinterpret_u16_f16(half)); } -// NEON-optimized absmax computation for a block of float32 -static inline float neon_absmax_f32(const float* data, long long n) { +// NEON-optimized FP16 to float conversion (4 values at a time) +static inline float32x4_t neon_fp16x4_to_f32(const fp16_t* src) { + uint16x4_t raw = vld1_u16(reinterpret_cast(src)); + return vcvt_f32_f16(vreinterpret_f16_u16(raw)); +} + +// NEON-optimized absmax computation for a block of float32, bf16, or fp16. +template static inline float neon_absmax(const T* data, long long n) { float32x4_t vmax = vdupq_n_f32(0.0f); long long i = 0; - // Process 16 elements per iteration for better throughput for (; i + 16 <= n; i += 16) { - float32x4_t v0 = vabsq_f32(vld1q_f32(data + i)); - float32x4_t v1 = vabsq_f32(vld1q_f32(data + i + 4)); - float32x4_t v2 = vabsq_f32(vld1q_f32(data + i + 8)); - float32x4_t v3 = vabsq_f32(vld1q_f32(data + i + 12)); - vmax = vmaxq_f32(vmax, vmaxq_f32(vmaxq_f32(v0, v1), vmaxq_f32(v2, v3))); + float32x4_t v0, v1, v2, v3; + if constexpr (std::is_same::value) { + const float* p = reinterpret_cast(data + i); + v0 = vld1q_f32(p); + v1 = vld1q_f32(p + 4); + v2 = vld1q_f32(p + 8); + v3 = vld1q_f32(p + 12); + } else if constexpr (std::is_same::value) { + v0 = neon_bf16x4_to_f32(data + i); + v1 = neon_bf16x4_to_f32(data + i + 4); + v2 = neon_bf16x4_to_f32(data + i + 8); + v3 = neon_bf16x4_to_f32(data + i + 12); + } else { + v0 = neon_fp16x4_to_f32(data + i); + v1 = neon_fp16x4_to_f32(data + i + 4); + v2 = neon_fp16x4_to_f32(data + i + 8); + v3 = neon_fp16x4_to_f32(data + i + 12); + } + vmax = vmaxq_f32( + vmax, vmaxq_f32(vmaxq_f32(vabsq_f32(v0), vabsq_f32(v1)), vmaxq_f32(vabsq_f32(v2), vabsq_f32(v3))) + ); } for (; i + 4 <= n; i += 4) { - float32x4_t v = vld1q_f32(data + i); + float32x4_t v; + if constexpr (std::is_same::value) + v = vld1q_f32(reinterpret_cast(data + i)); + else if constexpr (std::is_same::value) + v = neon_bf16x4_to_f32(data + i); + else + v = neon_fp16x4_to_f32(data + i); vmax = vmaxq_f32(vmax, vabsq_f32(v)); } - // Horizontal max float result = vmaxvq_f32(vmax); - // Handle remainder for (; i < n; ++i) { - result = std::max(result, std::fabs(data[i])); + float val; + if constexpr (std::is_same::value) + val = data[i]; + else if constexpr (std::is_same::value) + val = bf16_to_float(data[i].v); + else + val = fp16_to_float(data[i].v); + result = std::max(result, std::fabs(val)); } return result; } @@ -213,7 +231,6 @@ static inline uint16x4_t neon_norm_to_lut_index_x4(float32x4_t vals) { #endif // _M_ARM64 || __aarch64__ #if defined(__AVX512F__) -#include inline __m256i cvt_fp32_to_fp16(const __m512 src) { return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); @@ -259,6 +276,29 @@ static inline __m512 set_fp4_lut() { } #endif +static constexpr float fp4_lut[16] = { + 0.0f, 0.005208333333f, 0.66666667f, 1.0f, 0.33333333f, 0.5f, 0.16666667f, 0.25f, + -0.0f, -0.005208333333f, -0.66666667f, -1.0f, -0.33333333f, -0.5f, -0.16666667f, -0.25f, +}; +static constexpr float nf4_lut[16] = { + -1.0f, + -0.6961928009986877f, + -0.5250730514526367f, + -0.39491748809814453f, + -0.28444138169288635f, + -0.18477343022823334f, + -0.09105003625154495f, + 0.0f, + 0.07958029955625534f, + 0.16093020141124725f, + 0.24611230194568634f, + 0.33791524171829224f, + 0.44070982933044434f, + 0.5626170039176941f, + 0.7229568362236023f, + 1.0f, +}; + // 4-bit (FP4 / NF4) dequantization helper extracted from the original else branch. // DATA_TYPE: 1 = FP4, 2 = NF4 template @@ -270,53 +310,45 @@ void dequantizeBlockwise4bitCpu( return; #if defined(_M_ARM64) || defined(__aarch64__) - { + // n % blocksize == 0: absmax is organized by flat element blocks; row and block + // boundaries must align or the 2D absmax indexing gives wrong scale values. + if (n % blocksize == 0) { long long dim_0 = m; long long dim_1 = n; long long input_dim_1 = dim_1 >> 1; long long absmax_dim_1 = dim_1 / blocksize; - // NEON path: process 16 output values at a time (8 packed bytes) - // Only use when blocksize evenly divides dim_1 to ensure correct scale indexing - constexpr long long VEC_LEN = 16; - if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN && (dim_1 % blocksize == 0)) { - float32x4_t lut[4]; - if constexpr (DATA_TYPE == 1) { - neon_fp4_lut(lut); - } else { - neon_nf4_lut(lut); - } - constexpr long long k_step = VEC_LEN / 2; // 8 bytes per iteration - BNB_OMP_PARALLEL_FOR - for (long long block_idx = 0; block_idx < dim_0; ++block_idx) { - for (long long k = 0; k < input_dim_1; k += k_step) { - long long scale_idx = k * 2 / blocksize; - float scale = absmax[block_idx * absmax_dim_1 + scale_idx]; - const uint8_t* p = &A[block_idx * input_dim_1 + k]; - - // Dequantize 16 values into a temp float buffer - float tmp_f32[16]; - neon_dequant_4bit_16values(p, scale, lut, tmp_f32); - - // Store results (convert to output type using NEON) - T* pout = &out[block_idx * dim_1 + k * 2]; - if constexpr (std::is_same()) { - // Direct copy - already float - std::memcpy(pout, tmp_f32, 16 * sizeof(float)); - } else if constexpr (std::is_same()) { - neon_f32_to_bf16x4(vld1q_f32(tmp_f32), pout); - neon_f32_to_bf16x4(vld1q_f32(tmp_f32 + 4), pout + 4); - neon_f32_to_bf16x4(vld1q_f32(tmp_f32 + 8), pout + 8); - neon_f32_to_bf16x4(vld1q_f32(tmp_f32 + 12), pout + 12); - } else if constexpr (std::is_same()) { - neon_f32_to_fp16x4(vld1q_f32(tmp_f32), pout); - neon_f32_to_fp16x4(vld1q_f32(tmp_f32 + 4), pout + 4); - neon_f32_to_fp16x4(vld1q_f32(tmp_f32 + 8), pout + 8); - neon_f32_to_fp16x4(vld1q_f32(tmp_f32 + 12), pout + 12); - } + float32x4_t neon_lut[4]; + if constexpr (DATA_TYPE == 1) { + neon_fp4_lut(neon_lut); + } else { + neon_nf4_lut(neon_lut); + } + constexpr long long k_step = 8; // 8 packed bytes = 16 output values + BNB_OMP_PARALLEL_FOR + for (long long block_idx = 0; block_idx < dim_0; ++block_idx) { + for (long long k = 0; k < input_dim_1; k += k_step) { + long long scale_idx = k * 2 / blocksize; + float scale = absmax[block_idx * absmax_dim_1 + scale_idx]; + const uint8_t* p = &A[block_idx * input_dim_1 + k]; + float tmp_f32[16]; + neon_dequant_4bit_16values(p, scale, neon_lut, tmp_f32); + T* pout = &out[block_idx * dim_1 + k * 2]; + if constexpr (std::is_same()) { + std::memcpy(pout, tmp_f32, 16 * sizeof(float)); + } else if constexpr (std::is_same()) { + neon_f32_to_bf16x4(vld1q_f32(tmp_f32), pout); + neon_f32_to_bf16x4(vld1q_f32(tmp_f32 + 4), pout + 4); + neon_f32_to_bf16x4(vld1q_f32(tmp_f32 + 8), pout + 8); + neon_f32_to_bf16x4(vld1q_f32(tmp_f32 + 12), pout + 12); + } else { + neon_f32_to_fp16x4(vld1q_f32(tmp_f32), pout); + neon_f32_to_fp16x4(vld1q_f32(tmp_f32 + 4), pout + 4); + neon_f32_to_fp16x4(vld1q_f32(tmp_f32 + 8), pout + 8); + neon_f32_to_fp16x4(vld1q_f32(tmp_f32 + 12), pout + 12); } } - return; } + return; } #endif // _M_ARM64 || __aarch64__ @@ -334,22 +366,16 @@ void dequantizeBlockwise4bitCpu( BNB_OMP_PARALLEL_FOR for (int block_idx = 0; block_idx < dim_0; ++block_idx) { for (int k = 0; k < input_dim_1; k += k_step) { - // Load 64 bits of nf4 data and a single scale data - uint8_t* p = &A[block_idx * input_dim_1 + k]; - uint64_t packed; - std::memcpy(&packed, p, sizeof(uint64_t)); + const uint8_t* p = &A[block_idx * input_dim_1 + k]; auto scale_idx = k * 2 / blocksize; auto vscales = _mm512_set1_ps((float)absmax[block_idx * absmax_dim_1 + scale_idx]); - // unpack nf4 data to 32-bit integers - uint64_t high = 0; - uint64_t low = 0; - for (int i = 0; i < 4; ++i) { - low |= ((packed >> (2 * i * 4)) & 0xf) << ((2 * i + 1) * 8); - low |= ((packed >> ((2 * i + 1) * 4)) & 0xf) << (2 * i * 8); - high |= ((packed >> (2 * i * 4 + 32)) & 0xf) << ((2 * i + 1) * 8); - high |= ((packed >> ((2 * i + 1) * 4 + 32)) & 0xf) << (2 * i * 8); - } - __m128i packed_128 = _mm_set_epi64x(high, low); + // Unpack 8 packed bytes into 16 nibble indices using SSE. + // Each byte holds two 4-bit values; high nibble is the first output element. + __m128i raw = _mm_loadl_epi64(reinterpret_cast(p)); + __m128i mask4 = _mm_set1_epi8(0x0f); + __m128i hi = _mm_and_si128(_mm_srli_epi16(raw, 4), mask4); + __m128i lo = _mm_and_si128(raw, mask4); + __m128i packed_128 = _mm_unpacklo_epi8(hi, lo); __m512i vint32 = _mm512_cvtepu8_epi32(packed_128); // Table look-up __m512 vout = _mm512_permutexvar_ps(vint32, lut); @@ -371,6 +397,7 @@ void dequantizeBlockwise4bitCpu( } #endif // Scalar fallback branch + const float* lut = DATA_TYPE == 1 ? fp4_lut : nf4_lut; long long total = m * n; BNB_OMP_PARALLEL_FOR for (long long block_idx = 0; block_idx < total; block_idx += blocksize) { @@ -381,9 +408,9 @@ void dequantizeBlockwise4bitCpu( unsigned char byte = A[byte_index]; // High nibble first (matches previous code logic) - float v0 = (DATA_TYPE == 1 ? dDequantizeFP4(byte >> 4) : dDequantizeNF4(byte >> 4)) * scale; + float v0 = lut[byte >> 4] * scale; // Low nibble second - float v1 = (DATA_TYPE == 1 ? dDequantizeFP4(byte & 0x0F) : dDequantizeNF4(byte & 0x0F)) * scale; + float v1 = lut[byte & 0x0F] * scale; if constexpr (std::is_same::value) { out[block_idx + i] = float_to_bf16(v0); @@ -418,6 +445,32 @@ void dequantizeBlockwise8bitCpu( long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx); long long block_end = block_idx + valid_items; float scale = absmax[block_idx / blocksize]; +#if defined(_M_ARM64) || defined(__aarch64__) + { + float32x4_t vscale = vdupq_n_f32(scale); + long long i = block_idx; + for (; i + 4 <= block_end; i += 4) { + float tmp[4] = {code[A[i]], code[A[i + 1]], code[A[i + 2]], code[A[i + 3]]}; + float32x4_t v = vmulq_f32(vld1q_f32(tmp), vscale); + if constexpr (std::is_same::value) + vst1q_f32(reinterpret_cast(out + i), v); + else if constexpr (std::is_same::value) + neon_f32_to_bf16x4(v, out + i); + else + neon_f32_to_fp16x4(v, out + i); + } + for (; i < block_end; ++i) { + float v = code[A[i]] * scale; + if constexpr (std::is_same::value) + out[i] = float_to_bf16(v); + else if constexpr (std::is_same::value) + out[i] = float_to_fp16(v); + else + out[i] = static_cast(v); + } + } +#else +#pragma omp simd for (long long i = block_idx; i < block_end; ++i) { float v = code[A[i]] * scale; if constexpr (std::is_same::value) { @@ -428,6 +481,7 @@ void dequantizeBlockwise8bitCpu( out[i] = static_cast(v); } } +#endif } } @@ -436,7 +490,7 @@ void dequantizeBlockwise8bitCpu( // which would SIGILL on non-AVX512 CPUs like Zen3. These functions are scalar C++ and don't need AVX512. #if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) #pragma GCC push_options -#pragma GCC target("no-avx512f") +#pragma GCC target("avx2,fma,no-avx512f") #endif // Precomputed direct lookup table: maps quantized uint16 index [0..65535] to codebook index. @@ -537,24 +591,20 @@ void quantize_cpu_impl(float* code, const T* A, float* absmax, unsigned char* ou float absmax_block = 0.0f; #if defined(_M_ARM64) || defined(__aarch64__) - if constexpr (std::is_same::value) { - // Use NEON-optimized absmax for float32 - absmax_block = neon_absmax_f32(reinterpret_cast(A + block_start), block_len); - } else -#endif - { - for (long long i = block_start; i < block_end; ++i) { - float val; - if constexpr (std::is_same::value) { - val = A[i]; - } else if constexpr (std::is_same::value) { - val = bf16_to_float(A[i].v); - } else if constexpr (std::is_same::value) { - val = fp16_to_float(A[i].v); - } - absmax_block = std::max(absmax_block, std::fabs(val)); - } + absmax_block = neon_absmax(A + block_start, block_len); +#else +#pragma omp simd reduction(max : absmax_block) + for (long long i = block_start; i < block_end; ++i) { + float val; + if constexpr (std::is_same::value) + val = A[i]; + else if constexpr (std::is_same::value) + val = bf16_to_float(A[i].v); + else + val = fp16_to_float(A[i].v); + absmax_block = std::max(absmax_block, std::fabs(val)); } +#endif absmax[b] = absmax_block; @@ -568,13 +618,18 @@ void quantize_cpu_impl(float* code, const T* A, float* absmax, unsigned char* ou const float inv_absmax = 1.0f / absmax_block; #if defined(_M_ARM64) || defined(__aarch64__) - if constexpr (std::is_same::value) { - // NEON-optimized normalize + LUT index for float32 - const float* src = A + block_start; + { long long i = 0; float32x4_t vinv = vdupq_n_f32(inv_absmax); for (; i + 4 <= block_len; i += 4) { - float32x4_t v = vmulq_f32(vld1q_f32(src + i), vinv); + float32x4_t v; + if constexpr (std::is_same::value) + v = vld1q_f32(reinterpret_cast(A + block_start + i)); + else if constexpr (std::is_same::value) + v = neon_bf16x4_to_f32(A + block_start + i); + else + v = neon_fp16x4_to_f32(A + block_start + i); + v = vmulq_f32(v, vinv); uint16x4_t indices = neon_norm_to_lut_index_x4(v); uint16_t idx_arr[4]; vst1_u16(idx_arr, indices); @@ -584,25 +639,28 @@ void quantize_cpu_impl(float* code, const T* A, float* absmax, unsigned char* ou out[block_start + i + 3] = lut[idx_arr[3]]; } for (; i < block_len; ++i) { - float normed_value = src[i] * inv_absmax; - out[block_start + i] = lut[norm_to_lut_index(normed_value)]; - } - } else -#endif - { - for (long long i = block_start; i < block_end; ++i) { float val; - if constexpr (std::is_same::value) { - val = A[i]; - } else if constexpr (std::is_same::value) { - val = bf16_to_float(A[i].v); - } else if constexpr (std::is_same::value) { - val = fp16_to_float(A[i].v); - } - float normed_value = val * inv_absmax; - out[i] = lut[norm_to_lut_index(normed_value)]; + if constexpr (std::is_same::value) + val = A[block_start + i]; + else if constexpr (std::is_same::value) + val = bf16_to_float(A[block_start + i].v); + else + val = fp16_to_float(A[block_start + i].v); + out[block_start + i] = lut[norm_to_lut_index(val * inv_absmax)]; } } +#else + for (long long i = block_start; i < block_end; ++i) { + float val; + if constexpr (std::is_same::value) + val = A[i]; + else if constexpr (std::is_same::value) + val = bf16_to_float(A[i].v); + else + val = fp16_to_float(A[i].v); + out[i] = lut[norm_to_lut_index(val * inv_absmax)]; + } +#endif } } diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 14df69921..3a1ec40a8 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -13,6 +13,10 @@ #include #endif +#if defined(__x86_64__) || defined(_M_X64) +#include +#endif + // amx-bf16 #define TILE_M 16 #define TILE_N 16 @@ -32,7 +36,7 @@ template inline int get_cache_blocks(int chunk_size) { } // forced unroll for perf critical path -#if __has_attribute(always_inline) +#if defined(__has_attribute) && __has_attribute(always_inline) #define ALWAYS_INLINE __attribute__((__always_inline__)) inline #else #define ALWAYS_INLINE inline @@ -147,6 +151,12 @@ static float bf16_to_float(uint16_t bf16) { } static inline fp16_t float_to_fp16(float x) { +#if defined(__AVX2__) + // F16C is guaranteed on all AVX2 CPUs; matches CUDA round-to-nearest-even behavior + return fp16_t{ + (uint16_t)_mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC), 0) + }; +#else uint32_t bits; std::memcpy(&bits, &x, 4); uint32_t sign = (bits >> 31) & 0x1; @@ -186,9 +196,13 @@ static inline fp16_t float_to_fp16(float x) { h = (sign << 15) | ((uint16_t)exp_h << 10) | ((uint16_t)(mant_rounded >> 13)); } return fp16_t{h}; +#endif } static inline float fp16_to_float(uint16_t h) { +#if defined(__AVX2__) + return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(h))); +#else uint32_t sign = (h >> 15) & 0x1; uint32_t exp = (h >> 10) & 0x1F; uint32_t mant = h & 0x3FF; @@ -216,6 +230,7 @@ static inline float fp16_to_float(uint16_t h) { float f; std::memcpy(&f, &bits, sizeof(f)); return f; +#endif } inline float dDequantizeFP4(unsigned char val) { diff --git a/tests/test_autograd.py b/tests/test_autograd.py index d150f4735..7d273c853 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -134,7 +134,7 @@ def test_matmullt( @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", [96], ids=id_formatter("dim4")) -@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("req_grad", REQ_GRAD_NO_B_WEIGHT, ids=id_formatter("req_grad")) @pytest.mark.parametrize("transpose_B", TRUE_FALSE, ids=id_formatter("transpose_B")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @@ -169,8 +169,8 @@ def test_matmul_4bit( for i in range(3): A = torch.randn(size=dimA, device=device, requires_grad=req_grad[0], dtype=dtype) - B = torch.randn(size=dimB, device=device, requires_grad=req_grad[1], dtype=dtype) - target = torch.randn(size=(dim2, dim4), device=device, requires_grad=req_grad[1], dtype=dtype) + B = torch.randn(size=dimB, device=device, dtype=dtype) + target = torch.randn(size=(dim2, dim4), device=device, dtype=dtype) bias = None bias2 = None if has_bias: @@ -212,9 +212,7 @@ def test_matmul_4bit( loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() loss_bnb.backward() gradA1 = A.grad - gradB1 = B.grad A.grad = None - B.grad = None if has_bias: gradBias1 = bias.grad bias.grad = None @@ -222,9 +220,7 @@ def test_matmul_4bit( loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() loss_torch.backward() gradA2 = A.grad - gradB2 = B.grad A.grad = None - B.grad = None if has_bias: gradBias2 = bias.grad bias.grad = None diff --git a/tests/test_functional.py b/tests/test_functional.py index 95d8727f7..e4cd6a128 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -275,10 +275,6 @@ def test_few_bit_quant(self, device, bits, method): @pytest.mark.parametrize("device", get_available_devices()) def test_fp8_quant(self, device): - # TODO - if device == "cpu": - pytest.skip("CPU implementation segfaults") - for e_bits in range(1, 7): p_bits = 7 - e_bits code = F.create_fp8_map(True, e_bits, p_bits).to(device) @@ -896,6 +892,30 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, double_quant, kind): dim_key = "le512" if dim <= 512 else "gt512" thresholds = gemv_thresholds[dtype][dim_key] + + # On CPU with AVX512BF16, fp16/fp32 inputs are downcast to bf16 for the fused + # kernel for performance. Thresholds calibrated from 100 iterations on CPU. + cpu_bf16_cast = device == "cpu" and F.has_avx512bf16() and dtype in (torch.float16, torch.float32) + if cpu_bf16_cast: + thresholds = { + "le512": { + "err1": (2.72e-4, 9.96e-5), + "relerr1": ( + 1.88e-3 if dtype == torch.float16 else 1.64e-3, + 1.27e-2 if dtype == torch.float16 else 3.61e-3, + ), + "maxerr1": (1.22e-3, 3.80e-4), + }, + "gt512": { + "err1": (1.00e-4, 3.48e-5), + "relerr1": ( + 6.92e-4 if dtype == torch.float16 else 6.31e-4, + 9.21e-4 if dtype == torch.float16 else 4.71e-4, + ), + "maxerr1": (5.16e-4, 1.68e-4), + }, + }[dim_key] + for metric_name, metric_val in [("err1", err1), ("relerr1", relerr1), ("maxerr1", maxerr1)]: mean_val, std_val = thresholds[metric_name] limit = mean_val + N_SIGMA * std_val @@ -906,11 +926,12 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, double_quant, kind): # Ratios check that gemv_4bit and matmul_4bit produce consistent results. # These are tight bounds on internal consistency, not absolute accuracy. - if dtype == torch.float16: + # On CPU with AVX512BF16, fp16/fp32 use bf16 arithmetic so get bf16-level bounds. + if dtype == torch.float16 and not cpu_bf16_cast: assert absratio < 1.005 and absratio > 0.995 assert relratio < 1.005 and relratio > 0.992 assert maxratio < 1.005 and maxratio > 0.992 - elif dtype == torch.float32: + elif dtype == torch.float32 and not cpu_bf16_cast: assert absratio < 1.005 and absratio > 0.995 assert relratio < 1.005 and relratio > 0.995 assert maxratio < 1.005 and maxratio > 0.995 @@ -918,6 +939,10 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, double_quant, kind): assert absratio < 1.005 and absratio > 0.995 assert relratio < 1.05 and relratio > 0.96 assert maxratio < 1.05 and maxratio > 0.97 + elif cpu_bf16_cast: + assert absratio < 1.02 and absratio > 0.98 + assert relratio < 1.1 and relratio > 0.90 + assert maxratio < 1.1 and maxratio > 0.90 @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 12ed0eb27..79ede45b2 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -221,9 +221,6 @@ def test_params4bit_torch_chunk_split(device, quant_type): if device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8): pytest.skip("This configuration is not supported on HPU.") - if device == "cpu": - pytest.skip("CPU quantization causes segfault, skipping CPU test") - original_tensor = torch.randn(8, 4, dtype=torch.float16, device="cpu") params4bit = bnb.nn.Params4bit(data=original_tensor, quant_type=quant_type, requires_grad=False) diff --git a/tests/test_ops.py b/tests/test_ops.py index bd5217748..3550c0b6f 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -98,12 +98,8 @@ class TestInt8BlockwiseQuantOps: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): - if device == "cpu": - if dtype != torch.float32: - pytest.skip("CPU implementation is only available for float32") - - if blocksize != 256: - pytest.skip("CPU implementation is slow; only test blocksize=256") + if device == "cpu" and blocksize != 256: + pytest.skip("CPU implementation is slow; only test blocksize=256") code = bitsandbytes.functional.create_dynamic_map().to(device) A = torch.randn(1024, 1024, dtype=dtype, device=device) @@ -122,9 +118,6 @@ def test_quantize_blockwise(self, device, dtype, blocksize): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): - if device == "cpu" and dtype != torch.float32: - pytest.skip("CPU implementation is only available for float32") - A = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8, device=device) code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32) diff --git a/tests/test_optim.py b/tests/test_optim.py index dbfb9d469..0a4b3d6af 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,9 +1,6 @@ -import os -from os.path import join -import shutil +import io import sys import time -import uuid from lion_pytorch import Lion import pytest @@ -27,16 +24,6 @@ def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): torch.testing.assert_close(a, b, rtol=rtol, atol=atol) -def get_temp_dir(): - path = f"/tmp/autoswap/{uuid.uuid4()}" - os.makedirs(path, exist_ok=True) - return path - - -def rm_path(path): - shutil.rmtree(path) - - str2optimizers = {} ## TODO: maybe remove these three. @@ -223,13 +210,13 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device): assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=15) if i % (k // 5) == 0 and i > 0: - path = get_temp_dir() - torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt")) + buf = io.BytesIO() + torch.save(bnb_optimizer.state_dict(), buf) del bnb_optimizer bnb_optimizer = None bnb_optimizer = str2optimizers[optim_name][1]([p2]) - bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) - rm_path(path) + buf.seek(0) + bnb_optimizer.load_state_dict(torch.load(buf)) # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 10 errors for Lion assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10) @@ -441,13 +428,13 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): raws1cpy = bnb_optimizer.state[p2][name2].clone() qmap1 = bnb_optimizer.state[p2][qmap].clone() - path = get_temp_dir() - torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt")) + buf = io.BytesIO() + torch.save(bnb_optimizer.state_dict(), buf) del bnb_optimizer bnb_optimizer = None bnb_optimizer = str2optimizers[optim_name][1]([p2]) - bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) - rm_path(path) + buf.seek(0) + bnb_optimizer.load_state_dict(torch.load(buf)) torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2]) torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap]) @@ -577,16 +564,18 @@ def test_ademamix_state_dict_no_nan(optim_name, optim_factory, device): # Save state model_sd = {k: v.clone() for k, v in model.state_dict().items()} opt_sd = opt.state_dict() - path = get_temp_dir() - torch.save(opt_sd, join(path, "opt.pt")) - torch.save(model_sd, join(path, "model.pt")) + opt_buf = io.BytesIO() + model_buf = io.BytesIO() + torch.save(opt_sd, opt_buf) + torch.save(model_sd, model_buf) # Create fresh model and optimizer, load state model2 = nn.Linear(256, 64).to(device) - model2.load_state_dict(torch.load(join(path, "model.pt"))) + model_buf.seek(0) + model2.load_state_dict(torch.load(model_buf)) opt2 = optim_factory(model2.parameters()) - opt2.load_state_dict(torch.load(join(path, "opt.pt"))) - rm_path(path) + opt_buf.seek(0) + opt2.load_state_dict(torch.load(opt_buf)) # Verify loaded state matches original byte-for-byte orig_params = list(model.parameters())