Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 63 additions & 81 deletions src/ifma_avx512.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,57 +13,6 @@
#include <cstdint>
#include <x86intrin.h>

// Precomputed shuffle masks for K = 1 to 15
static const uint8_t shuffle_masks[15][16] = {
// K = 1: [15, 0x80, 0x80, ...]
{15, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80},
// K = 2: [14, 0x80, 15, 0x80, ...]
{14, 0x80, 15, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80},
// K = 3: [13, 0x80, 14, 15, 0x80, ...]
{13, 0x80, 14, 15, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80},
// K = 4: [12, 0x80, 13, 14, 15, 0x80, ...]
{12, 0x80, 13, 14, 15, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80},
// K = 5: [11, 0x80, 12, 13, 14, 15, 0x80, ...]
{11, 0x80, 12, 13, 14, 15, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80},
// K = 6: [10, 0x80, 11, 12, 13, 14, 15, 0x80, ...]
{10, 0x80, 11, 12, 13, 14, 15, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80},
// K = 7: [9, 0x80, 10, 11, 12, 13, 14, 15, 0x80, ...]
{9, 0x80, 10, 11, 12, 13, 14, 15, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80},
// K = 8: [8, 0x80, 9, 10, 11, 12, 13, 14, 15, 0x80, ...]
{8, 0x80, 9, 10, 11, 12, 13, 14, 15, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80},
// K = 9: [7, 0x80, 8, 9, 10, 11, 12, 13, 14, 15, 0x80, ...]
{7, 0x80, 8, 9, 10, 11, 12, 13, 14, 15, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80},
// K = 10: [6, 0x80, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0x80, ...]
{6, 0x80, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0x80, 0x80, 0x80, 0x80, 0x80},
// K = 11: [5, 0x80, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0x80, ...]
{5, 0x80, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0x80, 0x80, 0x80, 0x80},
// K = 12: [4, 0x80, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0x80, ...]
{4, 0x80, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0x80, 0x80, 0x80},
// K = 13: [3, 0x80, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0x80, ...]
{3, 0x80, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0x80, 0x80},
// K = 14: [2, 0x80, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0x80]
{2, 0x80, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0x80},
// K = 15: [1, 0x80, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
{1, 0x80, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
};

// K should be between 1 and 15
inline __m128i shift_and_insert_dot(__m128i input, int K) {
// Prepare a vector with '.' (0x2E) at index 1 and zeros elsewhere
__m128i dot_vector = _mm_setr_epi8(0, 0x2E, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);

// Load the precomputed shuffle mask for K (index K-1)
__m128i mask = _mm_loadu_si128((__m128i*)shuffle_masks[K - 1]);

// Perform the shuffle to reposition the K bytes
__m128i shuffled = _mm_shuffle_epi8(input, mask);

// Blend with dot_vector to insert '.' at index 1
__m128i result = _mm_or_si128(shuffled, dot_vector);

return result;
}

/*
The IFMA decimal print method:

Expand All @@ -80,45 +29,73 @@ n = 84736251
84736251 = n mod 10^8

From this paper
https://arxiv.org/abs/1902.01961
page 8:

uint32_t d = ...; // your divisor > 0
// c = ceil ( (1 < <64) / d ) ; we take L = N
uint64_t c = UINT64_C (0xFFFFFFFFFFFFFFFF ) / d + 1;
// fastmod computes (n mod d) given precomputed c
uint32_t fastmod ( uint32_t n, uint64_t c, uint32_t d) {
uint64_t lowbits = c * n;
return (( __uint128_t ) lowbits * d) >> 64;
}
Lemire, D., Bartlett, C., & Kaser, O. (2021). Integer division by constants: optimal bounds. Heliyon, 7(6).
https://arxiv.org/abs/2012.12369

Theorem 4 (page 3)

It says that ( (c * n + c) % m ) * d / m gives n mod d, for 0 <= n <= N as long as

(1 - 1/(N+1))*1/d ≤ c/m < 1/d

or

N m ≤c d (N+1) < m (N+1)


As long as d does not divide m, we can set c = floor (m / d) and c/m < 1/d is satisfied.

It remains to verify the left identity.

N m ≤c d (N+1)

We want m to be 2^52 and N = 10^8 - 1, so we need to verify that

Fastmod fits well for this AVX512FMA instruction pair:
VPMADD52LUQ => lowbits = c * n + 0
VPMADD52HUQ => highbits = lowbits * 10 + asciiZero
just uses 52b and 104b numbers instead of 64 and 128, and highbits use 10 instead of d, and produces 8 decimal digits for 0 <= n <= 99999999.
(10^8 - 1) * 2^52 ≤ floor(2^52 / d) * d * 10^8

The only problem is that in the 8th digit case the VPMADD52HUQ overflows, if we use the original 0x2af31dd ( = (2^53 - 1)/(10^8) + 1) constant as c in VPMADD52LUQ:
where d = 10, ... , 10^8

In Python, we can check:

for k in range(1,9):
d = 10**k
lhs = (10**8-1) * 2**52
rhs = (2**52//d)*d*10**8
assert lhs <= rhs

This fits well with the IFMA instruction pair, which computes (c * n + c) mod 2^52 and then multiplies the result by 10 and adds '0' to get the ASCII code of the digit.

We set call 'c' ifma_const, set m = 2^52, and we compute

(c n + c) % 2^52 as

lowbits_l = _mm512_madd52lo_epu64(ifma_const, bcstq_l, ifma_const)

and then we compute

((c n + c) % 2^52) * 10 + '0' as

_mm512_madd52hi_epu64(asciiZero, zmmTen, lowbits_l)

where asciiZero is the vector of '0' characters and zmmTen is the vector of 10s.

0x2af31dd * 99999999 = 0x10000001a50b23

Solution: we use 0x2af31dc = 0x2af31dd - 1 as c, and use 0x1A1A400 bias instead of 0. 0x1A1A400 is the smallest bias, which does not underflows in case of the smallest 8-digit number:

0x2af31dc * 10000000 = 0x19999996FD600 = 450359960000000
(0x19999996FD600 + 0x1A1A400) * 10 = 0x1000000EAEC400
*/

// (0xFFFFFFFFFFFFFFFF ) / d + 1 for d = 10^8, 10^7, ..., 10^1
static const __m512i ifma_const = _mm512_setr_epi64(
0x00000000002af31dc, 0x0000000001ad7f29b, 0x0000000010c6f7a0c, 0x00000000a7c5ac472,
0x000000068db8bac72, 0x0000004189374bc6b, 0x0000028f5c28f5c29, 0x0000199999999999a
);

champagne_lemire_really_inline __m512i to_string_avx512ifma_8digits(uint64_t n) {
__m512i bcstq_l = _mm512_set1_epi64(n);
constexpr uint64_t twoto52 = 0x10000000000000ULL; // 2^52
__m512i ifma_const = _mm512_setr_epi64(
twoto52 / 100000000, twoto52 / 10000000, twoto52 / 1000000, twoto52 / 100000,
twoto52 / 10000, twoto52 / 1000, twoto52 / 100, twoto52 / 10
);
__m512i zmmzero = _mm512_castsi128_si512(_mm_cvtsi64_si128(0x01A1A400));
__m512i zmmTen = _mm512_set1_epi64(10);
__m512i asciiZero = _mm512_set1_epi64('0');
__m512i lowbits_l = _mm512_madd52lo_epu64(zmmzero, bcstq_l, ifma_const);
__m512i lowbits_l = _mm512_madd52lo_epu64(ifma_const, bcstq_l, ifma_const); // ifma_const * bcstq_l + ifma_const
__m512i highbits_l = _mm512_madd52hi_epu64(asciiZero, zmmTen, lowbits_l);
return highbits_l;
}
Expand All @@ -133,17 +110,22 @@ champagne_lemire_really_inline __m128i to_string_avx512ifma(uint64_t n) {
uint64_t n_07_00 = n % 100000000;
__m512i bcstq_h = _mm512_set1_epi64(n_15_08);
__m512i bcstq_l = _mm512_set1_epi64(n_07_00);
__m512i zmmzero = _mm512_castsi128_si512(_mm_cvtsi64_si128(0x01A1A400));
constexpr uint64_t twoto52 = 0x10000000000000ULL; // 2^52
__m512i ifma_const = _mm512_setr_epi64(
twoto52 / 100000000, twoto52 / 10000000, twoto52 / 1000000, twoto52 / 100000,
twoto52 / 10000, twoto52 / 1000, twoto52 / 100, twoto52 / 10
);

__m512i zmmTen = _mm512_set1_epi64(10);
__m512i asciiZero = _mm512_set1_epi64('0');

__m512i permb_const = _mm512_castsi128_si512(
_mm_set_epi8(0x78, 0x70, 0x68, 0x60, 0x58, 0x50, 0x48, 0x40,
0x38, 0x30, 0x28, 0x20, 0x18, 0x10, 0x08, 0x00));
__m512i lowbits_h = _mm512_madd52lo_epu64(zmmzero, bcstq_h, ifma_const);
__m512i lowbits_l = _mm512_madd52lo_epu64(zmmzero, bcstq_l, ifma_const);
__m512i highbits_h = _mm512_madd52hi_epu64(asciiZero, zmmTen, lowbits_h);
__m512i highbits_l = _mm512_madd52hi_epu64(asciiZero, zmmTen, lowbits_l);
__m512i lowbits_h = _mm512_madd52lo_epu64(ifma_const, bcstq_h, ifma_const); // lowbits_h = ifma_const * bcstq_h + ifma_const
__m512i lowbits_l = _mm512_madd52lo_epu64(ifma_const, bcstq_l, ifma_const); // lowbits_l = ifma_const * bcstq_l + ifma_const
__m512i highbits_h = _mm512_madd52hi_epu64(asciiZero, zmmTen, lowbits_h); // highbits_h = lowbits_h * 10 + asciiZero
__m512i highbits_l = _mm512_madd52hi_epu64(asciiZero, zmmTen, lowbits_l); // highbits_l = lowbits_l * 10 + asciiZero

// idx & 0x40 ? highbits_h[idx & 0x3F] : highbits_l[idx & 0x3F]
__m512i perm = _mm512_permutex2var_epi8(highbits_h, permb_const, highbits_l);
Expand Down