diff --git a/aphrodite_kernels/CMakeLists.txt b/aphrodite_kernels/CMakeLists.txt
index 9ed334475a..43c78e3715 100644
--- a/aphrodite_kernels/CMakeLists.txt
+++ b/aphrodite_kernels/CMakeLists.txt
@@ -100,7 +100,8 @@ endif()
 #
 if (NOT APHRODITE_TARGET_DEVICE STREQUAL "cuda" AND
     NOT APHRODITE_TARGET_DEVICE STREQUAL "rocm")
-    if (APHRODITE_TARGET_DEVICE STREQUAL "cpu")
+    if (APHRODITE_TARGET_DEVICE STREQUAL "cpu" OR
+        APHRODITE_TARGET_DEVICE STREQUAL "android")
         include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
     else()
         return()
diff --git a/aphrodite_kernels/cmake/cpu_extension.cmake b/aphrodite_kernels/cmake/cpu_extension.cmake
index 55d197cffe..d314727840 100644
--- a/aphrodite_kernels/cmake/cpu_extension.cmake
+++ b/aphrodite_kernels/cmake/cpu_extension.cmake
@@ -337,6 +337,33 @@ if(USE_ONEDNN)
         ${APHRODITE_EXT_SRC})
 endif()
 
+# Mobile-optimized kernels (ARM NEON) for Android
+if(APHRODITE_TARGET_DEVICE STREQUAL "android")
+    message(STATUS "Building mobile-optimized kernels for Android")
+    set(APHRODITE_MOBILE_SRC
+        "csrc/cpu/mobile/quant.cpp"
+        "csrc/cpu/mobile/blas.cpp"
+        "csrc/cpu/mobile/gemm.cpp"
+        "csrc/cpu/mobile/reduce.cpp"
+        "csrc/cpu/mobile/scalar.cpp"
+        "csrc/cpu/mobile/nn.cpp"
+        "csrc/cpu/mobile/attention.cpp"
+        "csrc/cpu/mobile/torch_bindings.cpp")
+    set(APHRODITE_EXT_SRC
+        ${APHRODITE_EXT_SRC}
+        ${APHRODITE_MOBILE_SRC})
+    # Add ARM NEON compile flags for mobile kernels
+    # Note: These flags may already be set by ASIMD_FOUND detection above
+    # but we ensure they're present for Android builds
+    # AArch64 names Advanced SIMD (NEON) with the "+simd" -march modifier, and
+    # "-mfpu=" is an AArch32-only option, so only the -march flag is passed.
+    if(NOT ASIMD_FOUND)
+        list(APPEND CXX_COMPILE_FLAGS
+            "-march=armv8-a+simd")
+    endif()
+    message(STATUS "Mobile kernel source files: ${APHRODITE_MOBILE_SRC}")
+endif()
+
 message(STATUS "CPU extension source files: ${APHRODITE_EXT_SRC}")
 
 #
diff --git a/aphrodite_kernels/csrc/cpu/mobile/README.md b/aphrodite_kernels/csrc/cpu/mobile/README.md
new file mode 100644
index 0000000000..7583b6f087
--- /dev/null
+++ 
b/aphrodite_kernels/csrc/cpu/mobile/README.md
@@ -0,0 +1 @@
+Mobile-optimized compute kernels, based on [cactus](https://github.com/cactus-compute/cactus).
\ No newline at end of file
diff --git a/aphrodite_kernels/csrc/cpu/mobile/attention.cpp b/aphrodite_kernels/csrc/cpu/mobile/attention.cpp
new file mode 100644
index 0000000000..dcc8d7acae
--- /dev/null
+++ b/aphrodite_kernels/csrc/cpu/mobile/attention.cpp
@@ -0,0 +1,1268 @@
+#include "threading.hpp"
+#include <arm_neon.h>
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <limits>
+#include <vector>
+
+namespace aphrodite::mobile {
+
+// Scaled dot-product attention over int8-quantized tensors.
+//
+// Layout (as implied by the indexing below):
+//   queries / output : [batch_size, seq_len,    num_q_heads,  head_dim]
+//   keys / values    : [batch_size, kv_seq_len, num_kv_heads, head_dim]
+//   mask (optional)  : [batch_size, seq_len, kv_seq_len]; a 0 entry masks
+//                      that (query, key) pair
+//
+// Query head h reads KV head h / (num_q_heads / num_kv_heads), so num_q_heads
+// must be a non-zero multiple of num_kv_heads.
+//
+// scale            : softmax scale; 0 selects 1 / sqrt(head_dim)
+// q_scale, k_scale : dequantization scales applied to the integer Q.K score
+// v_scale          : dequantization scale of `values`
+// output_scale     : quantization scale of `output` (must be non-zero)
+// position_offset  : added to the query index to get its absolute position
+// window_size      : if > 0, keys more than window_size positions behind the
+//                    query are masked
+// is_causal        : mask keys whose position is after the query position
+//
+// Rows whose keys are all masked produce an all-zero output row.
+void attention_int8(const int8_t* queries, const int8_t* keys,
+                    const int8_t* values, int8_t* output, size_t batch_size,
+                    size_t seq_len, size_t kv_seq_len, size_t num_q_heads,
+                    size_t num_kv_heads, size_t head_dim, float scale,
+                    const int8_t* mask, float q_scale, float k_scale,
+                    float v_scale, float output_scale, size_t position_offset,
+                    size_t window_size, bool is_causal) {
+  if (scale == 0.0f) {
+    scale = 1.0f / sqrtf(static_cast<float>(head_dim));
+  }
+
+  constexpr size_t VECTOR_WIDTH = 16;
+  constexpr size_t TILE_K = 8;
+  constexpr size_t VECTOR_UNROLL = 2;
+  constexpr size_t UNROLLED_WIDTH = VECTOR_WIDTH * VECTOR_UNROLL;
+  constexpr float NEG_INF = -std::numeric_limits<float>::infinity();
+  const size_t head_dim_aligned = (head_dim / UNROLLED_WIDTH) * UNROLLED_WIDTH;
+
+  const size_t group_size = num_q_heads / num_kv_heads;
+
+  const size_t q_batch_stride = seq_len * num_q_heads * head_dim;
+  const size_t kv_batch_stride = kv_seq_len * num_kv_heads * head_dim;
+  const size_t o_batch_stride = seq_len * num_q_heads * head_dim;
+  const size_t q_seq_stride = num_q_heads * head_dim;
+  const size_t kv_seq_stride = num_kv_heads * head_dim;
+  const size_t o_seq_stride = num_q_heads * head_dim;
+  const size_t mask_batch_stride = mask ? seq_len * kv_seq_len : 0;
+
+  aphrodite::mobile::parallel_for(
+      batch_size * num_q_heads * seq_len,
+      aphrodite::mobile::Thresholds::ATTENTION,
+      [=](size_t start_idx, size_t end_idx) {
+        // Scratch buffers are allocated once per worker chunk and reused for
+        // every row instead of being re-allocated inside the hot loop.
+        std::vector<float> scores(kv_seq_len);
+        std::vector<float> out_acc(head_dim);
+
+        // One work item == one (batch, query head, query position) row.  Each
+        // item touches only its own output row, so rows are computed exactly
+        // once and workers never write to the same memory.
+        for (size_t work_idx = start_idx; work_idx < end_idx; ++work_idx) {
+          const size_t batch_idx = work_idx / (num_q_heads * seq_len);
+          const size_t remainder = work_idx % (num_q_heads * seq_len);
+          const size_t q_head_idx = remainder / seq_len;
+          const size_t q_pos = remainder % seq_len;
+          const size_t kv_head_idx = q_head_idx / group_size;
+          const size_t absolute_q_pos = position_offset + q_pos;
+
+          const int8_t* q_vec = queries + batch_idx * q_batch_stride +
+                                q_pos * q_seq_stride + q_head_idx * head_dim;
+          const int8_t* K_head =
+              keys + batch_idx * kv_batch_stride + kv_head_idx * head_dim;
+          const int8_t* V_head =
+              values + batch_idx * kv_batch_stride + kv_head_idx * head_dim;
+          int8_t* o_vec = output + batch_idx * o_batch_stride +
+                          q_pos * o_seq_stride + q_head_idx * head_dim;
+          const int8_t* mask_row =
+              mask ? (mask + batch_idx * mask_batch_stride +
+                      q_pos * kv_seq_len)
+                   : nullptr;
+
+          // ---- Pass 1: masked, scaled Q.K scores, TILE_K keys at a time ----
+          for (size_t kv_start = 0; kv_start < kv_seq_len;
+               kv_start += TILE_K) {
+            const size_t tile_len =
+                std::min(kv_start + TILE_K, kv_seq_len) - kv_start;
+
+            int32x4_t accumulators[TILE_K];
+            for (size_t t = 0; t < TILE_K; ++t) {
+              accumulators[t] = vdupq_n_s32(0);
+            }
+
+            // Main part of head_dim: two 16-lane vectors per step, with the
+            // query vectors loaded once and reused for every key in the tile.
+            for (size_t dim_block = 0; dim_block < head_dim_aligned;
+                 dim_block += UNROLLED_WIDTH) {
+              const int8x16_t q_vec_low = vld1q_s8(&q_vec[dim_block]);
+              const int8x16_t q_vec_high =
+                  vld1q_s8(&q_vec[dim_block + VECTOR_WIDTH]);
+
+              for (size_t kv_idx = 0; kv_idx < tile_len; ++kv_idx) {
+                const int8_t* k_vec =
+                    K_head + (kv_start + kv_idx) * kv_seq_stride;
+                const int8x16_t k_vec_low = vld1q_s8(&k_vec[dim_block]);
+                const int8x16_t k_vec_high =
+                    vld1q_s8(&k_vec[dim_block + VECTOR_WIDTH]);
+
+                accumulators[kv_idx] =
+                    accum_i8mm(accumulators[kv_idx], q_vec_low, k_vec_low);
+                accumulators[kv_idx] =
+                    accum_i8mm(accumulators[kv_idx], q_vec_high, k_vec_high);
+              }
+            }
+
+            // Remainder of head_dim: copy into zero-padded 16-byte buffers so
+            // the vector loads never read past the end of a head.
+            for (size_t dim_block = head_dim_aligned; dim_block < head_dim;
+                 dim_block += VECTOR_WIDTH) {
+              const size_t chunk =
+                  std::min(head_dim - dim_block, VECTOR_WIDTH);
+
+              int8_t q_tmp[VECTOR_WIDTH] = {};
+              memcpy(q_tmp, &q_vec[dim_block], chunk);
+              const int8x16_t q_vec_remainder = vld1q_s8(q_tmp);
+
+              for (size_t kv_idx = 0; kv_idx < tile_len; ++kv_idx) {
+                const int8_t* k_vec =
+                    K_head + (kv_start + kv_idx) * kv_seq_stride;
+
+                int8_t k_tmp[VECTOR_WIDTH] = {};
+                memcpy(k_tmp, &k_vec[dim_block], chunk);
+                const int8x16_t k_vec_remainder = vld1q_s8(k_tmp);
+
+                accumulators[kv_idx] = accum_i8mm(
+                    accumulators[kv_idx], q_vec_remainder, k_vec_remainder);
+              }
+            }
+
+            for (size_t kv_idx = 0; kv_idx < tile_len; ++kv_idx) {
+              const size_t kv_pos = kv_start + kv_idx;
+              const int32_t score = vaddvq_s32(accumulators[kv_idx]);
+
+              const bool masked =
+                  (is_causal && kv_pos > absolute_q_pos) ||
+                  (window_size > 0 && kv_pos < absolute_q_pos &&
+                   (absolute_q_pos - kv_pos) > window_size) ||
+                  (mask_row && mask_row[kv_pos] == 0);
+
+              scores[kv_pos] =
+                  masked ? NEG_INF
+                         : static_cast<float>(score) * q_scale * k_scale *
+                               scale;
+            }
+          }
+
+          // ---- Pass 2: numerically stable softmax over the row ----
+          float max_score = NEG_INF;
+          for (size_t kv_pos = 0; kv_pos < kv_seq_len; ++kv_pos) {
+            max_score = std::max(max_score, scores[kv_pos]);
+          }
+
+          float sum_exp = 0.0f;
+          for (size_t kv_pos = 0; kv_pos < kv_seq_len; ++kv_pos) {
+            if (scores[kv_pos] != NEG_INF) {
+              scores[kv_pos] = expf(scores[kv_pos] - max_score);
+              sum_exp += scores[kv_pos];
+            } else {
+              scores[kv_pos] = 0.0f;
+            }
+          }
+
+          if (sum_exp > 0.0f) {
+            const float inv_sum = 1.0f / sum_exp;
+            for (size_t kv_pos = 0; kv_pos < kv_seq_len; ++kv_pos) {
+              scores[kv_pos] *= inv_sum;
+            }
+          }
+
+          // ---- Pass 3: weighted sum of V, accumulated in fp32 ----
+          // Accumulating in float and quantizing once avoids rounding every
+          // partial sum to int8, which would discard small contributions.
+          std::fill(out_acc.begin(), out_acc.end(), 0.0f);
+          for (size_t kv_pos = 0; kv_pos < kv_seq_len; ++kv_pos) {
+            const float attn_weight = scores[kv_pos];
+            if (attn_weight == 0.0f) continue;
+
+            const int8_t* v_vec = V_head + kv_pos * kv_seq_stride;
+            const float dequant_weight = attn_weight * v_scale;
+            for (size_t dim = 0; dim < head_dim; ++dim) {
+              out_acc[dim] += dequant_weight * static_cast<float>(v_vec[dim]);
+            }
+          }
+
+          // Quantize: round half away from zero, saturate to the int8 range.
+          // Clamping happens in float so the integer cast cannot overflow.
+          for (size_t dim = 0; dim < head_dim; ++dim) {
+            const float rounded = out_acc[dim] / output_scale +
+                                  (out_acc[dim] >= 0.0f ? 0.5f : -0.5f);
+            const float clamped = std::max(-128.0f, std::min(127.0f, rounded));
+            o_vec[dim] = static_cast<int8_t>(static_cast<int32_t>(clamped));
+          }
+        }
+      });
+}
+
+void attention_f32(const float* queries, const float* keys, const float* values,
+                   float* output, size_t batch_size, size_t seq_len,
+                   size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
+                   size_t head_dim, float scale, const float* mask,
+                   size_t position_offset, size_t window_size, bool is_causal) {
+  if (scale == 0.0f) {
+    scale = 1.0f / sqrtf(static_cast<float>(head_dim));
+  }
+
+  constexpr size_t VECTOR_WIDTH = 4;
+  constexpr size_t TILE_Q = 4;
+  constexpr size_t TILE_K = 8;
+  constexpr size_t VECTOR_UNROLL = 2;
+  const size_t head_dim_aligned = (head_dim / (VECTOR_WIDTH * VECTOR_UNROLL)) *
+                                  (VECTOR_WIDTH * VECTOR_UNROLL);
+
+  const size_t group_size = num_q_heads / num_kv_heads;
+
+  const size_t q_batch_stride = seq_len * num_q_heads * head_dim;
+  const size_t kv_batch_stride = kv_seq_len * num_kv_heads * head_dim;
+  const size_t o_batch_stride = seq_len * num_q_heads * head_dim;
+  const size_t q_seq_stride = num_q_heads * head_dim;
+  const size_t kv_seq_stride = num_kv_heads * head_dim;
+  const size_t o_seq_stride = num_q_heads * head_dim;
+  const size_t mask_batch_stride = mask ? 
seq_len * kv_seq_len : 0; + + aphrodite::mobile::parallel_for( + batch_size * num_q_heads * seq_len, + aphrodite::mobile::Thresholds::ATTENTION, + [=](size_t start_idx, size_t end_idx) { + for (size_t work_idx = start_idx; work_idx < end_idx; ++work_idx) { + const size_t batch_idx = work_idx / (num_q_heads * seq_len); + const size_t remainder = work_idx % (num_q_heads * seq_len); + const size_t q_head_idx = remainder / seq_len; + const size_t q_pos = remainder % seq_len; + const size_t kv_head_idx = q_head_idx / group_size; + + const float* Q_base = queries + batch_idx * q_batch_stride; + const float* K_base = keys + batch_idx * kv_batch_stride; + const float* V_base = values + batch_idx * kv_batch_stride; + float* O_base = output + batch_idx * o_batch_stride; + const float* M = + mask ? (mask + batch_idx * mask_batch_stride) : nullptr; + + for (size_t q_start = q_pos; q_start <= q_pos; q_start += TILE_Q) { + const size_t q_end = std::min(q_start + TILE_Q, seq_len); + + std::vector attention_scores( + TILE_Q * kv_seq_len, -std::numeric_limits::infinity()); + + for (size_t q_offset = 0; q_offset < (q_end - q_start); + ++q_offset) { + const size_t q_pos = q_start + q_offset; + const float* q_vec = + Q_base + q_pos * q_seq_stride + q_head_idx * head_dim; + + for (size_t kv_start = 0; kv_start < kv_seq_len; + kv_start += TILE_K) { + const size_t kv_end = std::min(kv_start + TILE_K, kv_seq_len); + + std::vector accumulators(TILE_K, + vdupq_n_f32(0.0f)); + + for (size_t dim_block = 0; dim_block < head_dim_aligned; + dim_block += VECTOR_WIDTH * VECTOR_UNROLL) { + float32x4_t q_vec_low = vld1q_f32(&q_vec[dim_block]); + float32x4_t q_vec_high = + vld1q_f32(&q_vec[dim_block + VECTOR_WIDTH]); + + for (size_t kv_idx = 0; kv_idx < (kv_end - kv_start); + ++kv_idx) { + const size_t kv_pos = kv_start + kv_idx; + const float* k_vec = K_base + kv_pos * kv_seq_stride + + kv_head_idx * head_dim; + + if (kv_idx + 1 < (kv_end - kv_start)) { + const float* next_k_vec = K_base + + (kv_pos 
+ 1) * kv_seq_stride + + kv_head_idx * head_dim; + __builtin_prefetch(next_k_vec + dim_block, 0, 1); + } + + float32x4_t k_vec_low = vld1q_f32(&k_vec[dim_block]); + float32x4_t k_vec_high = + vld1q_f32(&k_vec[dim_block + VECTOR_WIDTH]); + + accumulators[kv_idx] = + vfmaq_f32(accumulators[kv_idx], q_vec_low, k_vec_low); + accumulators[kv_idx] = + vfmaq_f32(accumulators[kv_idx], q_vec_high, k_vec_high); + } + } + + for (size_t dim_block = head_dim_aligned; dim_block < head_dim; + dim_block += VECTOR_WIDTH) { + size_t remaining = head_dim - dim_block; + + float q_tmp[VECTOR_WIDTH] = {}; + if (remaining >= VECTOR_WIDTH) { + memcpy(q_tmp, &q_vec[dim_block], + VECTOR_WIDTH * sizeof(float)); + } else { + memcpy(q_tmp, &q_vec[dim_block], remaining * sizeof(float)); + } + float32x4_t q_vec_remainder = vld1q_f32(q_tmp); + + for (size_t kv_idx = 0; kv_idx < (kv_end - kv_start); + ++kv_idx) { + const size_t kv_pos = kv_start + kv_idx; + const float* k_vec = K_base + kv_pos * kv_seq_stride + + kv_head_idx * head_dim; + + float k_tmp[VECTOR_WIDTH] = {}; + if (remaining >= VECTOR_WIDTH) { + memcpy(k_tmp, &k_vec[dim_block], + VECTOR_WIDTH * sizeof(float)); + } else { + memcpy(k_tmp, &k_vec[dim_block], + remaining * sizeof(float)); + } + float32x4_t k_vec_remainder = vld1q_f32(k_tmp); + + accumulators[kv_idx] = vfmaq_f32( + accumulators[kv_idx], q_vec_remainder, k_vec_remainder); + } + } + + for (size_t kv_idx = 0; kv_idx < (kv_end - kv_start); + ++kv_idx) { + const size_t kv_pos = kv_start + kv_idx; + float score = vaddvq_f32(accumulators[kv_idx]); + float attention_score = score * scale; + + size_t absolute_q_pos = position_offset + q_pos; + + if (is_causal && kv_pos > absolute_q_pos) { + attention_score = -std::numeric_limits::infinity(); + } else if (window_size > 0 && kv_pos < absolute_q_pos && + (absolute_q_pos - kv_pos) > window_size) { + attention_score = -std::numeric_limits::infinity(); + } else if (M) { + const float mask_val = M[q_pos * kv_seq_len + kv_pos]; + if 
(mask_val == 0.0f) { + attention_score = -std::numeric_limits::infinity(); + } + } + + attention_scores[q_offset * kv_seq_len + kv_pos] = + attention_score; + } + } + } + + for (size_t q_offset = 0; q_offset < (q_end - q_start); + ++q_offset) { + const size_t q_pos = q_start + q_offset; + float* scores_row = &attention_scores[q_offset * kv_seq_len]; + + float max_score = -std::numeric_limits::infinity(); + for (size_t kv_pos = 0; kv_pos < kv_seq_len; ++kv_pos) { + max_score = std::max(max_score, scores_row[kv_pos]); + } + + float sum_exp = 0.0f; + for (size_t kv_pos = 0; kv_pos < kv_seq_len; ++kv_pos) { + if (scores_row[kv_pos] != + -std::numeric_limits::infinity()) { + scores_row[kv_pos] = expf(scores_row[kv_pos] - max_score); + sum_exp += scores_row[kv_pos]; + } else { + scores_row[kv_pos] = 0.0f; + } + } + + if (sum_exp > 0.0f) { + const float inv_sum = 1.0f / sum_exp; + for (size_t kv_pos = 0; kv_pos < kv_seq_len; ++kv_pos) { + scores_row[kv_pos] *= inv_sum; + } + } + + float* o_vec = + O_base + q_pos * o_seq_stride + q_head_idx * head_dim; + std::fill(o_vec, o_vec + head_dim, 0.0f); + + for (size_t kv_pos = 0; kv_pos < kv_seq_len; ++kv_pos) { + const float attn_weight = scores_row[kv_pos]; + if (attn_weight == 0.0f) continue; + + const float* v_vec = + V_base + kv_pos * kv_seq_stride + kv_head_idx * head_dim; + + if (kv_pos + 1 < kv_seq_len) { + const float* next_v_vec = V_base + + (kv_pos + 1) * kv_seq_stride + + kv_head_idx * head_dim; + __builtin_prefetch(next_v_vec, 0, 1); + } + + size_t dim_aligned = (head_dim / VECTOR_WIDTH) * VECTOR_WIDTH; + float32x4_t weight_vec = vdupq_n_f32(attn_weight); + + for (size_t dim = 0; dim < dim_aligned; dim += VECTOR_WIDTH) { + float32x4_t v_values = vld1q_f32(&v_vec[dim]); + float32x4_t o_values = vld1q_f32(&o_vec[dim]); + float32x4_t weighted = vmulq_f32(v_values, weight_vec); + float32x4_t result = vaddq_f32(o_values, weighted); + vst1q_f32(&o_vec[dim], result); + } + + for (size_t dim = dim_aligned; dim < head_dim; 
++dim) {
+                  o_vec[dim] += v_vec[dim] * attn_weight;
+                }
+              }
+            }
+          }
+        }
+      });
+}
+
+// Scaled dot-product attention for half-precision tensors, evaluated with an
+// online (streaming) softmax over blocks of BLOCK_SIZE keys so that no
+// kv_seq_len-sized score buffer is needed.  All arithmetic is done in fp32;
+// only loads and the final store are fp16.
+//
+// Layout (as implied by the indexing below):
+//   queries / output : [batch_size, seq_len,    num_q_heads,  head_dim]
+//   keys / values    : [batch_size, kv_seq_len, num_kv_heads, head_dim]
+//   mask (optional)  : [batch_size, seq_len, kv_seq_len]; a 0 entry masks
+//                      that (query, key) pair
+//
+// num_q_heads must be a non-zero multiple of num_kv_heads.  A `scale` of 0
+// selects 1 / sqrt(head_dim).  `position_offset`, `window_size` and
+// `is_causal` have the same meaning as in the other attention kernels in this
+// file.  Rows whose keys are all masked produce an all-zero output row.
+void attention_f16(const __fp16* queries, const __fp16* keys,
+                   const __fp16* values, __fp16* output, size_t batch_size,
+                   size_t seq_len, size_t kv_seq_len, size_t num_q_heads,
+                   size_t num_kv_heads, size_t head_dim, float scale,
+                   const __fp16* mask, size_t position_offset,
+                   size_t window_size, bool is_causal) {
+  if (scale == 0.0f) {
+    scale = 1.0f / sqrtf(static_cast<float>(head_dim));
+  }
+
+  constexpr size_t VECTOR_WIDTH = 8;
+  constexpr size_t BLOCK_SIZE = 32;
+  constexpr float NEG_INF = -std::numeric_limits<float>::infinity();
+  const size_t head_dim_aligned = (head_dim / VECTOR_WIDTH) * VECTOR_WIDTH;
+  const size_t num_vec_blocks = head_dim_aligned / VECTOR_WIDTH;
+  // Elements past the last full 8-lane vector; accumulated in scalar code.
+  const size_t tail_dims = head_dim - head_dim_aligned;
+
+  const size_t group_size = num_q_heads / num_kv_heads;
+
+  const size_t q_batch_stride = seq_len * num_q_heads * head_dim;
+  const size_t kv_batch_stride = kv_seq_len * num_kv_heads * head_dim;
+  const size_t o_batch_stride = seq_len * num_q_heads * head_dim;
+  const size_t q_seq_stride = num_q_heads * head_dim;
+  const size_t kv_seq_stride = num_kv_heads * head_dim;
+  const size_t o_seq_stride = num_q_heads * head_dim;
+  const size_t mask_batch_stride = mask ? seq_len * kv_seq_len : 0;
+
+  aphrodite::mobile::parallel_for(
+      batch_size * num_q_heads * seq_len,
+      aphrodite::mobile::Thresholds::ATTENTION,
+      [=](size_t start_idx, size_t end_idx) {
+        // Per-worker scratch, reused for every row in this chunk.
+        std::vector<float> block_scores(BLOCK_SIZE);
+        std::vector<float32x4_t> output_accum_low(num_vec_blocks);
+        std::vector<float32x4_t> output_accum_high(num_vec_blocks);
+        std::vector<float> output_accum_tail(tail_dims);
+
+        for (size_t work_idx = start_idx; work_idx < end_idx; ++work_idx) {
+          const size_t batch_idx = work_idx / (num_q_heads * seq_len);
+          const size_t remainder = work_idx % (num_q_heads * seq_len);
+          const size_t q_head_idx = remainder / seq_len;
+          const size_t q_pos = remainder % seq_len;
+          const size_t kv_head_idx = q_head_idx / group_size;
+          const size_t absolute_q_pos = position_offset + q_pos;
+
+          const __fp16* q_vec = queries + batch_idx * q_batch_stride +
+                                q_pos * q_seq_stride + q_head_idx * head_dim;
+          const __fp16* K_head =
+              keys + batch_idx * kv_batch_stride + kv_head_idx * head_dim;
+          const __fp16* V_head =
+              values + batch_idx * kv_batch_stride + kv_head_idx * head_dim;
+          __fp16* o_vec = output + batch_idx * o_batch_stride +
+                          q_pos * o_seq_stride + q_head_idx * head_dim;
+          const __fp16* mask_row =
+              mask ? (mask + batch_idx * mask_batch_stride +
+                      q_pos * kv_seq_len)
+                   : nullptr;
+
+          float running_max = NEG_INF;
+          float running_sum = 0.0f;
+          std::fill(output_accum_low.begin(), output_accum_low.end(),
+                    vdupq_n_f32(0.0f));
+          std::fill(output_accum_high.begin(), output_accum_high.end(),
+                    vdupq_n_f32(0.0f));
+          std::fill(output_accum_tail.begin(), output_accum_tail.end(), 0.0f);
+
+          // Restrict the key range to what the causal / sliding-window rules
+          // can possibly allow; per-key checks below handle the rest.
+          size_t kv_start = 0;
+          size_t kv_end = kv_seq_len;
+          if (window_size > 0 && window_size < kv_seq_len &&
+              absolute_q_pos > window_size) {
+            kv_start = absolute_q_pos - window_size;
+          }
+          if (is_causal) {
+            kv_end = std::min(kv_end, absolute_q_pos + 1);
+          }
+
+          for (size_t kv_block_start = kv_start; kv_block_start < kv_end;
+               kv_block_start += BLOCK_SIZE) {
+            const size_t block_size =
+                std::min(kv_block_start + BLOCK_SIZE, kv_end) - kv_block_start;
+
+            // ---- Scores for this block ----
+            float block_max = NEG_INF;
+            for (size_t kv_idx = 0; kv_idx < block_size; ++kv_idx) {
+              const size_t kv_pos = kv_block_start + kv_idx;
+
+              const bool masked =
+                  (is_causal && kv_pos > absolute_q_pos) ||
+                  (window_size > 0 && kv_pos < absolute_q_pos &&
+                   (absolute_q_pos - kv_pos) > window_size) ||
+                  (mask_row &&
+                   static_cast<float>(mask_row[kv_pos]) == 0.0f);
+              if (masked) {
+                // Skip the dot product entirely for masked keys.
+                block_scores[kv_idx] = NEG_INF;
+                continue;
+              }
+
+              const __fp16* k_vec = K_head + kv_pos * kv_seq_stride;
+              if (kv_idx + 1 < block_size) {
+                __builtin_prefetch(K_head + (kv_pos + 1) * kv_seq_stride, 0,
+                                   1);
+              }
+
+              float32x4_t score_accum_low = vdupq_n_f32(0.0f);
+              float32x4_t score_accum_high = vdupq_n_f32(0.0f);
+              for (size_t dim_block = 0; dim_block < head_dim_aligned;
+                   dim_block += VECTOR_WIDTH) {
+                const float16x8_t q_vec_f16 = vld1q_f16(&q_vec[dim_block]);
+                const float16x8_t k_vec_f16 = vld1q_f16(&k_vec[dim_block]);
+
+                score_accum_low =
+                    vfmaq_f32(score_accum_low,
+                              vcvt_f32_f16(vget_low_f16(q_vec_f16)),
+                              vcvt_f32_f16(vget_low_f16(k_vec_f16)));
+                score_accum_high =
+                    vfmaq_f32(score_accum_high,
+                              vcvt_f32_f16(vget_high_f16(q_vec_f16)),
+                              vcvt_f32_f16(vget_high_f16(k_vec_f16)));
+              }
+
+              float score =
+                  vaddvq_f32(vaddq_f32(score_accum_low, score_accum_high));
+              for (size_t dim = head_dim_aligned; dim < head_dim; ++dim) {
+                score += static_cast<float>(q_vec[dim]) *
+                         static_cast<float>(k_vec[dim]);
+              }
+              score *= scale;
+
+              block_scores[kv_idx] = score;
+              block_max = std::max(block_max, score);
+            }
+
+            // Every key in the block is masked: nothing to accumulate.
+            if (block_max == NEG_INF) continue;
+
+            // ---- Online-softmax rescale ----
+            // The reference maximum only ever grows, so the correction factor
+            // is <= 1 and cannot overflow.
+            if (block_max > running_max) {
+              if (running_sum > 0.0f) {
+                const float correction = expf(running_max - block_max);
+                running_sum *= correction;
+                for (size_t i = 0; i < num_vec_blocks; ++i) {
+                  output_accum_low[i] =
+                      vmulq_n_f32(output_accum_low[i], correction);
+                  output_accum_high[i] =
+                      vmulq_n_f32(output_accum_high[i], correction);
+                }
+                for (size_t t = 0; t < tail_dims; ++t) {
+                  output_accum_tail[t] *= correction;
+                }
+              }
+              running_max = block_max;
+            }
+
+            // ---- exp(score - running_max) for the block ----
+            // Vector path: 2^x with x split into integer and fractional parts;
+            // the fraction uses 1 + ln2*f + (ln2^2 / 2)*f^2 and the integer
+            // part is written straight into the float exponent field.
+            float block_sum = 0.0f;
+            const size_t vec_size = (block_size / 4) * 4;
+            const float32x4_t max_vec = vdupq_n_f32(running_max);
+            const float32x4_t neg_inf_vec = vdupq_n_f32(NEG_INF);
+
+            for (size_t kv_idx = 0; kv_idx < vec_size; kv_idx += 4) {
+              const float32x4_t scores = vld1q_f32(&block_scores[kv_idx]);
+              const uint32x4_t inf_mask = vceqq_f32(scores, neg_inf_vec);
+
+              float32x4_t x =
+                  vmulq_n_f32(vsubq_f32(scores, max_vec), 1.442695f);
+              // Clamp so that (xi + 127) below never goes negative; without
+              // this the shifted bit pattern becomes a negative/infinite
+              // float.  x == -127 yields an exponent field of 0, i.e. 0.0f.
+              x = vmaxq_f32(x, vdupq_n_f32(-127.0f));
+
+              int32x4_t xi = vcvtq_s32_f32(x);
+              const float32x4_t xf = vsubq_f32(x, vcvtq_f32_s32(xi));
+
+              float32x4_t y = vfmaq_n_f32(vdupq_n_f32(1.0f), xf, 0.6931472f);
+              y = vfmaq_f32(y, vmulq_f32(xf, xf), vdupq_n_f32(0.2402265f));
+
+              xi = vshlq_n_s32(vaddq_s32(xi, vdupq_n_s32(127)), 23);
+              y = vmulq_f32(y, vreinterpretq_f32_s32(xi));
+
+              // Masked keys contribute exactly zero.
+              y = vbslq_f32(inf_mask, vdupq_n_f32(0.0f), y);
+
+              vst1q_f32(&block_scores[kv_idx], y);
+              block_sum += vaddvq_f32(y);
+            }
+
+            for (size_t kv_idx = vec_size; kv_idx < block_size; ++kv_idx) {
+              if (block_scores[kv_idx] != NEG_INF) {
+                block_scores[kv_idx] =
+                    expf(block_scores[kv_idx] - running_max);
+                block_sum += block_scores[kv_idx];
+              } else {
+                block_scores[kv_idx] = 0.0f;
+              }
+            }
+
+            // ---- Accumulate weight * V ----
+            for (size_t kv_idx = 0; kv_idx < block_size; ++kv_idx) {
+              const float attn_weight = block_scores[kv_idx];
+              if (attn_weight == 0.0f) continue;
+
+              const __fp16* v_vec =
+                  V_head + (kv_block_start + kv_idx) * kv_seq_stride;
+              const float32x4_t weight_vec = vdupq_n_f32(attn_weight);
+
+              for (size_t i = 0; i < num_vec_blocks; ++i) {
+                const float16x8_t v_vec_f16 =
+                    vld1q_f16(&v_vec[i * VECTOR_WIDTH]);
+                output_accum_low[i] =
+                    vfmaq_f32(output_accum_low[i],
+                              vcvt_f32_f16(vget_low_f16(v_vec_f16)),
+                              weight_vec);
+                output_accum_high[i] =
+                    vfmaq_f32(output_accum_high[i],
+                              vcvt_f32_f16(vget_high_f16(v_vec_f16)),
+                              weight_vec);
+              }
+              // Tail elements when head_dim is not a multiple of 8.
+              for (size_t t = 0; t < tail_dims; ++t) {
+                output_accum_tail[t] +=
+                    attn_weight *
+                    static_cast<float>(v_vec[head_dim_aligned + t]);
+              }
+            }
+
+            running_sum += block_sum;
+          }
+
+          // ---- Normalize and store ----
+          if (running_sum > 0.0f) {
+            const float inv_sum = 1.0f / running_sum;
+            const float32x4_t inv_sum_vec = vdupq_n_f32(inv_sum);
+
+            for (size_t i = 0; i < num_vec_blocks; ++i) {
+              const float16x4_t low_f16 =
+                  vcvt_f16_f32(vmulq_f32(output_accum_low[i], inv_sum_vec));
+              const float16x4_t high_f16 =
+                  vcvt_f16_f32(vmulq_f32(output_accum_high[i], inv_sum_vec));
+              vst1q_f16(&o_vec[i * VECTOR_WIDTH],
+                        vcombine_f16(low_f16, high_f16));
+            }
+            for (size_t t = 0; t < tail_dims; ++t) {
+              o_vec[head_dim_aligned + t] =
+                  static_cast<__fp16>(output_accum_tail[t] * inv_sum);
+            }
+          } else {
+            for (size_t dim = 0; dim < head_dim; ++dim) {
+              o_vec[dim] = static_cast<__fp16>(0.0f);
+            }
+          }
+        }
+      });
+}
+
+void rms_norm_f32(const float* input, const float* weight, float* output,
+                  size_t batch_size, size_t dims, float eps) {
+  constexpr size_t SIMD_WIDTH = 4;
+  constexpr size_t UNROLL_FACTOR = 4;
+  constexpr size_t TILE_SIZE = SIMD_WIDTH * UNROLL_FACTOR;
+
+  for (size_t b = 0; b < batch_size; ++b) {
+    const float* input_row = input + b * dims;
+    float* output_row = output + b * dims;
+
+    
float32x4_t sum_squares_vec[UNROLL_FACTOR];
+    for (size_t u = 0; u < UNROLL_FACTOR; u++) {
+      sum_squares_vec[u] = vdupq_n_f32(0.0f);
+    }
+
+    size_t i = 0;
+    const size_t tile_end = (dims >= TILE_SIZE) ? dims - TILE_SIZE + 1 : 0;
+
+    for (; i < tile_end; i += TILE_SIZE) {
+      for (size_t u = 0; u < UNROLL_FACTOR; u++) {
+        float32x4_t input_vec = vld1q_f32(&input_row[i + u * SIMD_WIDTH]);
+        sum_squares_vec[u] =
+            vfmaq_f32(sum_squares_vec[u], input_vec, input_vec);
+      }
+    }
+
+    const size_t simd_end = (dims >= SIMD_WIDTH) ? dims - SIMD_WIDTH + 1 : 0;
+    for (; i < simd_end; i += SIMD_WIDTH) {
+      float32x4_t input_vec = vld1q_f32(&input_row[i]);
+      sum_squares_vec[0] = vfmaq_f32(sum_squares_vec[0], input_vec, input_vec);
+    }
+
+    float32x4_t total_sum = sum_squares_vec[0];
+    for (size_t u = 1; u < UNROLL_FACTOR; u++) {
+      total_sum = vaddq_f32(total_sum, sum_squares_vec[u]);
+    }
+    float sum_squares = vaddvq_f32(total_sum);
+
+    for (; i < dims; ++i) {
+      float val = input_row[i];
+      sum_squares += val * val;
+    }
+
+    float rms = sqrtf(sum_squares / static_cast<float>(dims) + eps);
+    float inv_rms = 1.0f / rms;
+    float32x4_t inv_rms_vec = vdupq_n_f32(inv_rms);
+
+    i = 0;
+    for (; i < tile_end; i += TILE_SIZE) {
+      for (size_t u = 0; u < UNROLL_FACTOR; u++) {
+        float32x4_t input_vec = vld1q_f32(&input_row[i + u * SIMD_WIDTH]);
+        float32x4_t weight_vec = vld1q_f32(&weight[i + u * SIMD_WIDTH]);
+        float32x4_t norm_vec =
+            vmulq_f32(vmulq_f32(input_vec, inv_rms_vec), weight_vec);
+        vst1q_f32(&output_row[i + u * SIMD_WIDTH], norm_vec);
+      }
+    }
+
+    for (; i < simd_end; i += SIMD_WIDTH) {
+      float32x4_t input_vec = vld1q_f32(&input_row[i]);
+      float32x4_t weight_vec = vld1q_f32(&weight[i]);
+      float32x4_t norm_vec =
+          vmulq_f32(vmulq_f32(input_vec, inv_rms_vec), weight_vec);
+      vst1q_f32(&output_row[i], norm_vec);
+    }
+
+    for (; i < dims; ++i) {
+      output_row[i] = input_row[i] * inv_rms * weight[i];
+    }
+  }
+}
+
+// Loads exactly four int8 values starting at `src`, widens them to fp32 and
+// multiplies by `scale_vec`.  The 4-byte memcpy keeps the access inside the
+// four elements being processed; an 8-byte vld1_s8 would read four bytes
+// beyond them.
+static inline float32x4_t rms_norm_load_i8x4_scaled(const int8_t* src,
+                                                    float32x4_t scale_vec) {
+  int32_t packed;
+  memcpy(&packed, src, sizeof(packed));
+  const int8x8_t lanes_i8 = vreinterpret_s8_s32(vdup_n_s32(packed));
+  const int16x4_t lanes_i16 = vget_low_s16(vmovl_s8(lanes_i8));
+  return vmulq_f32(vcvtq_f32_s32(vmovl_s16(lanes_i16)), scale_vec);
+}
+
+// RMS normalization of int8-quantized rows, producing fp32 output.
+//
+// For each of the `batch_size` rows of length `dims`:
+//   x[i]      = input[i] * input_scale
+//   output[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i]
+// `weight` has `dims` elements and is shared by all rows.
+void rms_norm_i8_f32(const int8_t* input, const float* weight, float* output,
+                     size_t batch_size, size_t dims, float eps,
+                     float input_scale) {
+  constexpr size_t SIMD_WIDTH = 4;
+  constexpr size_t UNROLL_FACTOR = 4;
+  constexpr size_t TILE_SIZE = SIMD_WIDTH * UNROLL_FACTOR;
+
+  const float32x4_t input_scale_vec = vdupq_n_f32(input_scale);
+  // Exclusive upper bounds for loop starts that still fit a full tile/vector.
+  const size_t tile_end = (dims >= TILE_SIZE) ? dims - TILE_SIZE + 1 : 0;
+  const size_t simd_end = (dims >= SIMD_WIDTH) ? dims - SIMD_WIDTH + 1 : 0;
+
+  for (size_t b = 0; b < batch_size; ++b) {
+    const int8_t* input_row = input + b * dims;
+    float* output_row = output + b * dims;
+
+    // ---- Pass 1: sum of squares, four independent accumulators ----
+    float32x4_t sum_squares_vec[UNROLL_FACTOR];
+    for (size_t u = 0; u < UNROLL_FACTOR; u++) {
+      sum_squares_vec[u] = vdupq_n_f32(0.0f);
+    }
+
+    size_t i = 0;
+    for (; i < tile_end; i += TILE_SIZE) {
+      for (size_t u = 0; u < UNROLL_FACTOR; u++) {
+        const float32x4_t input_f32 = rms_norm_load_i8x4_scaled(
+            &input_row[i + u * SIMD_WIDTH], input_scale_vec);
+        sum_squares_vec[u] =
+            vfmaq_f32(sum_squares_vec[u], input_f32, input_f32);
+      }
+    }
+    for (; i < simd_end; i += SIMD_WIDTH) {
+      const float32x4_t input_f32 =
+          rms_norm_load_i8x4_scaled(&input_row[i], input_scale_vec);
+      sum_squares_vec[0] = vfmaq_f32(sum_squares_vec[0], input_f32, input_f32);
+    }
+
+    float32x4_t total_sum = sum_squares_vec[0];
+    for (size_t u = 1; u < UNROLL_FACTOR; u++) {
+      total_sum = vaddq_f32(total_sum, sum_squares_vec[u]);
+    }
+    float sum_squares = vaddvq_f32(total_sum);
+
+    for (; i < dims; ++i) {
+      const float val = static_cast<float>(input_row[i]) * input_scale;
+      sum_squares += val * val;
+    }
+
+    const float rms = sqrtf(sum_squares / static_cast<float>(dims) + eps);
+    const float inv_rms = 1.0f / rms;
+    const float32x4_t inv_rms_vec = vdupq_n_f32(inv_rms);
+
+    // ---- Pass 2: normalize and apply the per-element weight ----
+    for (i = 0; i < simd_end; i += SIMD_WIDTH) {
+      const float32x4_t input_f32 =
+          rms_norm_load_i8x4_scaled(&input_row[i], input_scale_vec);
+      const float32x4_t weight_vec = vld1q_f32(&weight[i]);
+      vst1q_f32(&output_row[i],
+                vmulq_f32(vmulq_f32(input_f32, inv_rms_vec), weight_vec));
+    }
+    for (; i < dims; ++i) {
+      const float val = static_cast<float>(input_row[i]) * input_scale;
+      output_row[i] = (val * inv_rms) * weight[i];
+    }
+  }
+}
+
+// RMS normalization of fp16 rows with fp16 weights and fp16 output.
+//
+// For each of the `batch_size` rows of length `dims`:
+//   output[i] = input[i] / sqrt(mean(input^2) + eps) * weight[i]
+// Both the reduction and the normalization are evaluated in fp32 and narrowed
+// to fp16 once on store, so the vector path matches the scalar tail and only
+// fp16 <-> fp32 conversion instructions are required.
+void rms_norm_f16(const __fp16* input, const __fp16* weight, __fp16* output,
+                  size_t batch_size, size_t dims, float eps) {
+  constexpr size_t SIMD_WIDTH = 8;
+  constexpr size_t UNROLL_FACTOR = 2;
+  constexpr size_t TILE_SIZE = SIMD_WIDTH * UNROLL_FACTOR;
+
+  const size_t tile_end = (dims >= TILE_SIZE) ? dims - TILE_SIZE + 1 : 0;
+  const size_t simd_end = (dims >= SIMD_WIDTH) ? dims - SIMD_WIDTH + 1 : 0;
+
+  for (size_t b = 0; b < batch_size; ++b) {
+    const __fp16* input_row = input + b * dims;
+    __fp16* output_row = output + b * dims;
+
+    // ---- Pass 1: sum of squares in fp32 (low/high half per unrolled lane) --
+    float32x4_t sum_squares_vec[UNROLL_FACTOR * 2];
+    for (size_t u = 0; u < UNROLL_FACTOR * 2; u++) {
+      sum_squares_vec[u] = vdupq_n_f32(0.0f);
+    }
+
+    size_t i = 0;
+    for (; i < tile_end; i += TILE_SIZE) {
+      for (size_t u = 0; u < UNROLL_FACTOR; u++) {
+        const float16x8_t input_vec =
+            vld1q_f16(&input_row[i + u * SIMD_WIDTH]);
+        const float32x4_t input_low = vcvt_f32_f16(vget_low_f16(input_vec));
+        const float32x4_t input_high = vcvt_f32_f16(vget_high_f16(input_vec));
+        sum_squares_vec[u * 2] =
+            vfmaq_f32(sum_squares_vec[u * 2], input_low, input_low);
+        sum_squares_vec[u * 2 + 1] =
+            vfmaq_f32(sum_squares_vec[u * 2 + 1], input_high, input_high);
+      }
+    }
+    for (; i < simd_end; i += SIMD_WIDTH) {
+      const float16x8_t input_vec = vld1q_f16(&input_row[i]);
+      const float32x4_t input_low = vcvt_f32_f16(vget_low_f16(input_vec));
+      const float32x4_t input_high = vcvt_f32_f16(vget_high_f16(input_vec));
+      sum_squares_vec[0] = vfmaq_f32(sum_squares_vec[0], input_low, input_low);
+      sum_squares_vec[1] =
+          vfmaq_f32(sum_squares_vec[1], input_high, input_high);
+    }
+
+    float32x4_t total_sum = sum_squares_vec[0];
+    for (size_t u = 1; u < UNROLL_FACTOR * 2; u++) {
+      total_sum = vaddq_f32(total_sum, sum_squares_vec[u]);
+    }
+    float sum_squares = vaddvq_f32(total_sum);
+
+    for (; i < dims; ++i) {
+      const float val = static_cast<float>(input_row[i]);
+      sum_squares += val * val;
+    }
+
+    const float rms = sqrtf(sum_squares / static_cast<float>(dims) + eps);
+    const float inv_rms = 1.0f / rms;
+    // Kept in fp32: narrowing inv_rms to fp16 could overflow for very small
+    // inputs and would round differently from the scalar tail below.
+    const float32x4_t inv_rms_vec = vdupq_n_f32(inv_rms);
+
+    // ---- Pass 2: normalize in fp32, narrow to fp16 on store ----
+    for (i = 0; i < simd_end; i += SIMD_WIDTH) {
+      const float16x8_t input_vec = vld1q_f16(&input_row[i]);
+      const float16x8_t weight_vec = vld1q_f16(&weight[i]);
+
+      const float32x4_t norm_low =
+          vmulq_f32(vmulq_f32(vcvt_f32_f16(vget_low_f16(input_vec)),
+                              inv_rms_vec),
+                    vcvt_f32_f16(vget_low_f16(weight_vec)));
+      const float32x4_t norm_high =
+          vmulq_f32(vmulq_f32(vcvt_f32_f16(vget_high_f16(input_vec)),
+                              inv_rms_vec),
+                    vcvt_f32_f16(vget_high_f16(weight_vec)));
+
+      vst1q_f16(&output_row[i],
+                vcombine_f16(vcvt_f16_f32(norm_low), vcvt_f16_f32(norm_high)));
+    }
+    for (; i < dims; ++i) {
+      output_row[i] =
+          static_cast<__fp16>(static_cast<float>(input_row[i]) * inv_rms *
+                              static_cast<float>(weight[i]));
+    }
+  }
+}
+
+namespace CactusRoPE {
+
+struct RoPECache {
+  std::vector<float> cos_table;
+  std::vector<float> sin_table;
+  size_t max_seq_len;
+  size_t head_dim;
+  float theta;
+  bool initialized;
+
+  RoPECache() : max_seq_len(0), head_dim(0), theta(0.0f), initialized(false) {}
+};
+
+static thread_local RoPECache rope_cache;
+
+void 
precompute_rope_tables(size_t seq_len, size_t head_dim, float theta) { + if (rope_cache.initialized && rope_cache.max_seq_len >= seq_len && + rope_cache.head_dim == head_dim && rope_cache.theta == theta) { + return; + } + + const size_t half_dim = head_dim / 2; + const size_t table_size = seq_len * half_dim; + + rope_cache.cos_table.resize(table_size); + rope_cache.sin_table.resize(table_size); + + for (size_t pos = 0; pos < seq_len; ++pos) { + const float pos_float = static_cast(pos); + for (size_t dim_idx = 0; dim_idx < half_dim; ++dim_idx) { + const float inv_freq = + 1.0f / powf(theta, 2.0f * static_cast(dim_idx) / + static_cast(head_dim)); + const float angle = pos_float * inv_freq; + const size_t cache_idx = pos * half_dim + dim_idx; + rope_cache.cos_table[cache_idx] = cosf(angle); + rope_cache.sin_table[cache_idx] = sinf(angle); + } + } + + rope_cache.max_seq_len = seq_len; + rope_cache.head_dim = head_dim; + rope_cache.theta = theta; + rope_cache.initialized = true; +} + +void kernel_rope_neon_optimized_head(const float* input_head, + float* output_head, const float* cos_cache, + const float* sin_cache, size_t seq_len, + size_t head_dim, size_t pairs_per_head) { + constexpr size_t SIMD_WIDTH = 4; + constexpr size_t UNROLL_FACTOR = 4; + constexpr size_t VECTORIZED_PAIRS = SIMD_WIDTH * UNROLL_FACTOR; + + const size_t pairs_vectorized = + (pairs_per_head / VECTORIZED_PAIRS) * VECTORIZED_PAIRS; + + for (size_t seq_idx = 0; seq_idx < seq_len; ++seq_idx) { + const float* x = input_head + seq_idx * head_dim; + float* y = output_head + seq_idx * head_dim; + const size_t cache_base = seq_idx * pairs_per_head; + + size_t pair_idx = 0; + + for (; pair_idx < pairs_vectorized; pair_idx += VECTORIZED_PAIRS) { + const size_t cache_offset = cache_base + pair_idx; + + float32x4_t cos_vec[UNROLL_FACTOR]; + float32x4_t sin_vec[UNROLL_FACTOR]; + float32x4x2_t input_pairs[UNROLL_FACTOR]; + + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + cos_vec[u] = 
vld1q_f32(&cos_cache[cache_offset + u * SIMD_WIDTH]); + sin_vec[u] = vld1q_f32(&sin_cache[cache_offset + u * SIMD_WIDTH]); + input_pairs[u] = vld2q_f32(&x[2 * (pair_idx + u * SIMD_WIDTH)]); + } + + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + float32x4_t x1_vec = input_pairs[u].val[0]; + float32x4_t x2_vec = input_pairs[u].val[1]; + + float32x4_t y1_vec = + vfmsq_f32(vmulq_f32(x1_vec, cos_vec[u]), x2_vec, sin_vec[u]); + float32x4_t y2_vec = + vfmaq_f32(vmulq_f32(x2_vec, cos_vec[u]), x1_vec, sin_vec[u]); + + float32x4x2_t output_pairs; + output_pairs.val[0] = y1_vec; + output_pairs.val[1] = y2_vec; + vst2q_f32(&y[2 * (pair_idx + u * SIMD_WIDTH)], output_pairs); + } + } + + for (; pair_idx < pairs_per_head; ++pair_idx) { + const size_t cache_offset = cache_base + pair_idx; + const float cos_val = cos_cache[cache_offset]; + const float sin_val = sin_cache[cache_offset]; + + const float x1 = x[2 * pair_idx]; + const float x2 = x[2 * pair_idx + 1]; + + y[2 * pair_idx] = x1 * cos_val - x2 * sin_val; + y[2 * pair_idx + 1] = x1 * sin_val + x2 * cos_val; + } + } +} + +} // namespace CactusRoPE + +void rope_f32(const float* input, float* output, size_t batch_size, + size_t seq_len, size_t num_heads, size_t head_dim, + size_t start_pos, float theta) { + const size_t half_dim = head_dim / 2; + + CactusRoPE::precompute_rope_tables(seq_len + start_pos, head_dim, theta); + + const float* cos_cache = + CactusRoPE::rope_cache.cos_table.data() + start_pos * half_dim; + const float* sin_cache = + CactusRoPE::rope_cache.sin_table.data() + start_pos * half_dim; + + aphrodite::mobile::parallel_for( + batch_size * seq_len, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + for (size_t idx = start_idx; idx < end_idx; ++idx) { + const size_t batch_idx = idx / seq_len; + const size_t seq_idx = idx % seq_len; + + for (size_t head_idx = 0; head_idx < num_heads; ++head_idx) { + const size_t offset = + ((batch_idx * seq_len + seq_idx) * num_heads + 
head_idx) * + head_dim; + const float* input_ptr = input + offset; + float* output_ptr = output + offset; + + const float* cos_ptr = cos_cache + seq_idx * half_dim; + const float* sin_ptr = sin_cache + seq_idx * half_dim; + + for (size_t i = 0; i < half_dim; ++i) { + const float cos_val = cos_ptr[i]; + const float sin_val = sin_ptr[i]; + + const float x_first_half = input_ptr[i]; + const float x_second_half = input_ptr[i + half_dim]; + + output_ptr[i] = x_first_half * cos_val - x_second_half * sin_val; + + output_ptr[i + half_dim] = + x_second_half * cos_val + x_first_half * sin_val; + } + } + } + }); +} + +void rope_i8_f32_i8(const int8_t* input, int8_t* output, size_t batch_size, + size_t seq_len, size_t num_heads, size_t head_dim, + size_t start_pos, float theta, float input_scale, + float output_scale) { + const size_t half_dim = head_dim / 2; + + CactusRoPE::precompute_rope_tables(seq_len + start_pos, head_dim, theta); + + const float* cos_cache = + CactusRoPE::rope_cache.cos_table.data() + start_pos * half_dim; + const float* sin_cache = + CactusRoPE::rope_cache.sin_table.data() + start_pos * half_dim; + + aphrodite::mobile::parallel_for( + batch_size * seq_len, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + for (size_t idx = start_idx; idx < end_idx; ++idx) { + const size_t batch_idx = idx / seq_len; + const size_t seq_idx = idx % seq_len; + + for (size_t head_idx = 0; head_idx < num_heads; ++head_idx) { + const size_t offset = + ((batch_idx * seq_len + seq_idx) * num_heads + head_idx) * + head_dim; + const int8_t* x = input + offset; + int8_t* y = output + offset; + const size_t cache_base = seq_idx * half_dim; + + for (size_t i = 0; i < half_dim; ++i) { + const float cos_val = cos_cache[cache_base + i]; + const float sin_val = sin_cache[cache_base + i]; + + const float x_first_half = static_cast(x[i]) * input_scale; + const float x_second_half = + static_cast(x[i + half_dim]) * input_scale; + + const float y0 = 
x_first_half * cos_val - x_second_half * sin_val; + const float y1 = x_second_half * cos_val + x_first_half * sin_val; + + const float scaled_y0 = y0 / output_scale; + const float scaled_y1 = y1 / output_scale; + + y[i] = static_cast( + std::max(-128.0f, std::min(127.0f, std::round(scaled_y0)))); + y[i + half_dim] = static_cast( + std::max(-128.0f, std::min(127.0f, std::round(scaled_y1)))); + } + } + } + }); +} + +namespace CactusRoPEF16 { + +struct RoPECacheF16 { + std::vector<__fp16> cos_table; + std::vector<__fp16> sin_table; + size_t max_seq_len; + size_t head_dim; + float theta; + bool initialized; + + RoPECacheF16() + : max_seq_len(0), head_dim(0), theta(0.0f), initialized(false) {} +}; + +static thread_local RoPECacheF16 rope_cache_f16; + +void precompute_rope_tables_f16(size_t seq_len, size_t head_dim, float theta) { + if (rope_cache_f16.initialized && rope_cache_f16.max_seq_len >= seq_len && + rope_cache_f16.head_dim == head_dim && rope_cache_f16.theta == theta) { + return; + } + + const size_t half_dim = head_dim / 2; + const size_t table_size = seq_len * half_dim; + + rope_cache_f16.cos_table.resize(table_size); + rope_cache_f16.sin_table.resize(table_size); + + for (size_t pos = 0; pos < seq_len; ++pos) { + const float pos_float = static_cast(pos); + for (size_t i = 0; i < half_dim; ++i) { + const float freq = 1.0f / powf(theta, (2.0f * i) / head_dim); + const float angle = pos_float * freq; + + const size_t idx = pos * half_dim + i; + rope_cache_f16.cos_table[idx] = static_cast<__fp16>(cosf(angle)); + rope_cache_f16.sin_table[idx] = static_cast<__fp16>(sinf(angle)); + } + } + + rope_cache_f16.max_seq_len = seq_len; + rope_cache_f16.head_dim = head_dim; + rope_cache_f16.theta = theta; + rope_cache_f16.initialized = true; +} + +} // namespace CactusRoPEF16 + +void rope_f16(const __fp16* input, __fp16* output, size_t batch_size, + size_t seq_len, size_t num_heads, size_t head_dim, + size_t start_pos, float theta) { + const size_t half_dim = head_dim / 2; 
+ + CactusRoPEF16::precompute_rope_tables_f16(seq_len + start_pos, head_dim, + theta); + + const __fp16* cos_cache = + CactusRoPEF16::rope_cache_f16.cos_table.data() + start_pos * half_dim; + const __fp16* sin_cache = + CactusRoPEF16::rope_cache_f16.sin_table.data() + start_pos * half_dim; + + aphrodite::mobile::parallel_for( + batch_size * seq_len, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + for (size_t idx = start_idx; idx < end_idx; ++idx) { + const size_t batch_idx = idx / seq_len; + const size_t seq_idx = idx % seq_len; + + for (size_t head_idx = 0; head_idx < num_heads; ++head_idx) { + const size_t offset = + ((batch_idx * seq_len + seq_idx) * num_heads + head_idx) * + head_dim; + const __fp16* input_ptr = input + offset; + __fp16* output_ptr = output + offset; + + const __fp16* cos_ptr = cos_cache + seq_idx * half_dim; + const __fp16* sin_ptr = sin_cache + seq_idx * half_dim; + + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_half_dim = + (half_dim / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t i = 0; i < vectorized_half_dim; i += SIMD_WIDTH) { + float16x8_t cos_vec = vld1q_f16(&cos_ptr[i]); + float16x8_t sin_vec = vld1q_f16(&sin_ptr[i]); + + float16x8_t x_first_half = vld1q_f16(&input_ptr[i]); + float16x8_t x_second_half = vld1q_f16(&input_ptr[i + half_dim]); + + float16x8_t first_result = vfmsq_f16( + vmulq_f16(x_first_half, cos_vec), x_second_half, sin_vec); + float16x8_t second_result = vfmaq_f16( + vmulq_f16(x_second_half, cos_vec), x_first_half, sin_vec); + + vst1q_f16(&output_ptr[i], first_result); + vst1q_f16(&output_ptr[i + half_dim], second_result); + } + + for (size_t i = vectorized_half_dim; i < half_dim; ++i) { + const __fp16 cos_val = cos_ptr[i]; + const __fp16 sin_val = sin_ptr[i]; + + const __fp16 x_first_half = input_ptr[i]; + const __fp16 x_second_half = input_ptr[i + half_dim]; + + output_ptr[i] = x_first_half * cos_val - x_second_half * sin_val; + + output_ptr[i + half_dim] = + 
x_second_half * cos_val + x_first_half * sin_val; + } + } + } + }); +} +} // namespace aphrodite::mobile diff --git a/aphrodite_kernels/csrc/cpu/mobile/blas.cpp b/aphrodite_kernels/csrc/cpu/mobile/blas.cpp new file mode 100644 index 0000000000..ac5b7c2638 --- /dev/null +++ b/aphrodite_kernels/csrc/cpu/mobile/blas.cpp @@ -0,0 +1,1044 @@ +#include "threading.hpp" +#include +#include +#include +#include +#include + +namespace aphrodite::mobile { + +static inline size_t compute_linear_index(const size_t* coords, + const size_t* strides, size_t ndim) { + size_t index = 0; + for (size_t i = 0; i < ndim; ++i) { + index += coords[i] * strides[i]; + } + return index; +} + +static inline void increment_coords(size_t* coords, const size_t* shape, + size_t ndim) { + for (int i = ndim - 1; i >= 0; --i) { + coords[i]++; + if (coords[i] < shape[i]) { + break; + } + coords[i] = 0; + } +} + +void add_int8(const int8_t* a, const int8_t* b, int8_t* output, + size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + const size_t vectorized_end = + ((end_idx - start_idx) / NEON_VECTOR_SIZE) * NEON_VECTOR_SIZE + + start_idx; + + for (size_t i = start_idx; i < vectorized_end; i += NEON_VECTOR_SIZE) { + int8x16_t a_vec = vld1q_s8(&a[i]); + int8x16_t b_vec = vld1q_s8(&b[i]); + + int16x8_t a_low = vmovl_s8(vget_low_s8(a_vec)); + int16x8_t a_high = vmovl_s8(vget_high_s8(a_vec)); + int16x8_t b_low = vmovl_s8(vget_low_s8(b_vec)); + int16x8_t b_high = vmovl_s8(vget_high_s8(b_vec)); + + int16x8_t result_low = vaddq_s16(a_low, b_low); + int16x8_t result_high = vaddq_s16(a_high, b_high); + + int8x16_t result_vec = + vcombine_s8(vqmovn_s16(result_low), vqmovn_s16(result_high)); + vst1q_s8(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + int32_t sum = static_cast(a[i]) + static_cast(b[i]); + output[i] = clamp_to_int8(sum); + } + }); +} + +void 
subtract_int8(const int8_t* a, const int8_t* b, int8_t* output, + size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + const size_t vectorized_end = + ((end_idx - start_idx) / NEON_VECTOR_SIZE) * NEON_VECTOR_SIZE + + start_idx; + + for (size_t i = start_idx; i < vectorized_end; i += NEON_VECTOR_SIZE) { + int8x16_t a_vec = vld1q_s8(&a[i]); + int8x16_t b_vec = vld1q_s8(&b[i]); + + int16x8_t a_low = vmovl_s8(vget_low_s8(a_vec)); + int16x8_t a_high = vmovl_s8(vget_high_s8(a_vec)); + int16x8_t b_low = vmovl_s8(vget_low_s8(b_vec)); + int16x8_t b_high = vmovl_s8(vget_high_s8(b_vec)); + + int16x8_t result_low = vsubq_s16(a_low, b_low); + int16x8_t result_high = vsubq_s16(a_high, b_high); + + int8x16_t result_vec = + vcombine_s8(vqmovn_s16(result_low), vqmovn_s16(result_high)); + vst1q_s8(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + int32_t diff = + static_cast(a[i]) - static_cast(b[i]); + output[i] = clamp_to_int8(diff); + } + }); +} + +void multiply_int8(const int8_t* a, const int8_t* b, int8_t* output, + size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + const size_t vectorized_end = + ((end_idx - start_idx) / NEON_VECTOR_SIZE) * NEON_VECTOR_SIZE + + start_idx; + + for (size_t i = start_idx; i < vectorized_end; i += NEON_VECTOR_SIZE) { + int8x16_t a_vec = vld1q_s8(&a[i]); + int8x16_t b_vec = vld1q_s8(&b[i]); + + int16x8_t a_low = vmovl_s8(vget_low_s8(a_vec)); + int16x8_t a_high = vmovl_s8(vget_high_s8(a_vec)); + int16x8_t b_low = vmovl_s8(vget_low_s8(b_vec)); + int16x8_t b_high = vmovl_s8(vget_high_s8(b_vec)); + + int16x8_t result_low = vmulq_s16(a_low, b_low); + int16x8_t result_high = vmulq_s16(a_high, b_high); + + int8x16_t result_vec = + vcombine_s8(vqmovn_s16(result_low), vqmovn_s16(result_high)); + 
vst1q_s8(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + int32_t product = + static_cast(a[i]) * static_cast(b[i]); + output[i] = clamp_to_int8(product); + } + }); +} + +void divide_int8(const int8_t* a, const int8_t* b, int8_t* output, + size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + for (size_t i = start_idx; i < end_idx; ++i) { + if (b[i] == 0) { + output[i] = (a[i] >= 0) ? 127 : -128; + } else { + float result = static_cast(a[i]) / static_cast(b[i]); + output[i] = clamp_to_int8(result); + } + } + }); +} + +void add_f32(const float* a, const float* b, float* output, + size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t a_vec = vld1q_f32(&a[i]); + float32x4_t b_vec = vld1q_f32(&b[i]); + float32x4_t result_vec = vaddq_f32(a_vec, b_vec); + vst1q_f32(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = a[i] + b[i]; + } + }); +} + +void subtract_f32(const float* a, const float* b, float* output, + size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t a_vec = vld1q_f32(&a[i]); + float32x4_t b_vec = vld1q_f32(&b[i]); + float32x4_t result_vec = vsubq_f32(a_vec, b_vec); + vst1q_f32(&output[i], result_vec); + } + + for (size_t i = 
vectorized_end; i < end_idx; ++i) { + output[i] = a[i] - b[i]; + } + }); +} + +void multiply_f32(const float* a, const float* b, float* output, + size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t a_vec = vld1q_f32(&a[i]); + float32x4_t b_vec = vld1q_f32(&b[i]); + float32x4_t result_vec = vmulq_f32(a_vec, b_vec); + vst1q_f32(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = a[i] * b[i]; + } + }); +} + +void divide_f32(const float* a, const float* b, float* output, + size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t a_vec = vld1q_f32(&a[i]); + float32x4_t b_vec = vld1q_f32(&b[i]); + float32x4_t result_vec = vdivq_f32(a_vec, b_vec); + vst1q_f32(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = a[i] / b[i]; + } + }); +} + +void add_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output, + const size_t* a_strides, const size_t* b_strides, + const size_t* output_shape, size_t ndim) { + size_t total_elements = 1; + for (size_t i = 0; i < ndim; ++i) { + total_elements *= output_shape[i]; + } + + std::vector coords(ndim, 0); + + for (size_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) { + size_t a_idx = compute_linear_index(coords.data(), a_strides, ndim); + size_t b_idx = compute_linear_index(coords.data(), 
b_strides, ndim); + + output[linear_idx] = clamp_to_int8(static_cast(a[a_idx]) + + static_cast(b[b_idx])); + + increment_coords(coords.data(), output_shape, ndim); + } +} + +void subtract_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output, + const size_t* a_strides, const size_t* b_strides, + const size_t* output_shape, size_t ndim) { + size_t total_elements = 1; + for (size_t i = 0; i < ndim; ++i) { + total_elements *= output_shape[i]; + } + + std::vector coords(ndim, 0); + + for (size_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) { + size_t a_idx = compute_linear_index(coords.data(), a_strides, ndim); + size_t b_idx = compute_linear_index(coords.data(), b_strides, ndim); + + output[linear_idx] = clamp_to_int8(static_cast(a[a_idx]) - + static_cast(b[b_idx])); + + increment_coords(coords.data(), output_shape, ndim); + } +} + +void multiply_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output, + const size_t* a_strides, const size_t* b_strides, + const size_t* output_shape, size_t ndim) { + size_t total_elements = 1; + for (size_t i = 0; i < ndim; ++i) { + total_elements *= output_shape[i]; + } + + std::vector coords(ndim, 0); + + for (size_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) { + size_t a_idx = compute_linear_index(coords.data(), a_strides, ndim); + size_t b_idx = compute_linear_index(coords.data(), b_strides, ndim); + + output[linear_idx] = clamp_to_int8(static_cast(a[a_idx]) * + static_cast(b[b_idx])); + + increment_coords(coords.data(), output_shape, ndim); + } +} + +void divide_broadcast_int8(const int8_t* a, const int8_t* b, int8_t* output, + const size_t* a_strides, const size_t* b_strides, + const size_t* output_shape, size_t ndim) { + size_t total_elements = 1; + for (size_t i = 0; i < ndim; ++i) { + total_elements *= output_shape[i]; + } + + std::vector coords(ndim, 0); + + for (size_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) { + size_t a_idx = 
compute_linear_index(coords.data(), a_strides, ndim); + size_t b_idx = compute_linear_index(coords.data(), b_strides, ndim); + + int32_t b_val = static_cast(b[b_idx]); + if (b_val == 0) { + output[linear_idx] = 0; + } else { + output[linear_idx] = + clamp_to_int8(static_cast(a[a_idx]) / b_val); + } + + increment_coords(coords.data(), output_shape, ndim); + } +} + +void add_broadcast_f32(const float* a, const float* b, float* output, + const size_t* a_strides, const size_t* b_strides, + const size_t* output_shape, size_t ndim) { + size_t total_elements = 1; + for (size_t i = 0; i < ndim; ++i) { + total_elements *= output_shape[i]; + } + + std::vector coords(ndim, 0); + + for (size_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) { + size_t a_idx = compute_linear_index(coords.data(), a_strides, ndim); + size_t b_idx = compute_linear_index(coords.data(), b_strides, ndim); + + output[linear_idx] = a[a_idx] + b[b_idx]; + + increment_coords(coords.data(), output_shape, ndim); + } +} + +void subtract_broadcast_f32(const float* a, const float* b, float* output, + const size_t* a_strides, const size_t* b_strides, + const size_t* output_shape, size_t ndim) { + size_t total_elements = 1; + for (size_t i = 0; i < ndim; ++i) { + total_elements *= output_shape[i]; + } + + std::vector coords(ndim, 0); + + for (size_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) { + size_t a_idx = compute_linear_index(coords.data(), a_strides, ndim); + size_t b_idx = compute_linear_index(coords.data(), b_strides, ndim); + + output[linear_idx] = a[a_idx] - b[b_idx]; + + increment_coords(coords.data(), output_shape, ndim); + } +} + +void multiply_broadcast_f32(const float* a, const float* b, float* output, + const size_t* a_strides, const size_t* b_strides, + const size_t* output_shape, size_t ndim) { + size_t total_elements = 1; + for (size_t i = 0; i < ndim; ++i) { + total_elements *= output_shape[i]; + } + + std::vector coords(ndim, 0); + + for (size_t linear_idx = 0; 
linear_idx < total_elements; ++linear_idx) { + size_t a_idx = compute_linear_index(coords.data(), a_strides, ndim); + size_t b_idx = compute_linear_index(coords.data(), b_strides, ndim); + + output[linear_idx] = a[a_idx] * b[b_idx]; + + increment_coords(coords.data(), output_shape, ndim); + } +} + +void divide_broadcast_f32(const float* a, const float* b, float* output, + const size_t* a_strides, const size_t* b_strides, + const size_t* output_shape, size_t ndim) { + size_t total_elements = 1; + for (size_t i = 0; i < ndim; ++i) { + total_elements *= output_shape[i]; + } + + std::vector coords(ndim, 0); + + for (size_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) { + size_t a_idx = compute_linear_index(coords.data(), a_strides, ndim); + size_t b_idx = compute_linear_index(coords.data(), b_strides, ndim); + + output[linear_idx] = a[a_idx] / b[b_idx]; + + increment_coords(coords.data(), output_shape, ndim); + } +} + +void add_f16(const __fp16* a, const __fp16* b, __fp16* output, + size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float16x8_t a_vec = vld1q_f16(&a[i]); + float16x8_t b_vec = vld1q_f16(&b[i]); + float16x8_t result_vec = vaddq_f16(a_vec, b_vec); + vst1q_f16(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = a[i] + b[i]; + } + }); +} + +void add_f16_clipped(const __fp16* a, const __fp16* b, __fp16* output, + size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * 
SIMD_WIDTH; + + constexpr float FP16_MAX = 65500.0f; + const float32x4_t max_val = vdupq_n_f32(FP16_MAX); + const float32x4_t min_val = vdupq_n_f32(-FP16_MAX); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float16x8_t a_vec = vld1q_f16(&a[i]); + float16x8_t b_vec = vld1q_f16(&b[i]); + + float32x4_t a_low = vcvt_f32_f16(vget_low_f16(a_vec)); + float32x4_t a_high = vcvt_f32_f16(vget_high_f16(a_vec)); + float32x4_t b_low = vcvt_f32_f16(vget_low_f16(b_vec)); + float32x4_t b_high = vcvt_f32_f16(vget_high_f16(b_vec)); + + float32x4_t result_low = vaddq_f32(a_low, b_low); + float32x4_t result_high = vaddq_f32(a_high, b_high); + + result_low = vminq_f32(vmaxq_f32(result_low, min_val), max_val); + result_high = vminq_f32(vmaxq_f32(result_high, min_val), max_val); + + float16x4_t result_low_f16 = vcvt_f16_f32(result_low); + float16x4_t result_high_f16 = vcvt_f16_f32(result_high); + float16x8_t result_vec = + vcombine_f16(result_low_f16, result_high_f16); + + vst1q_f16(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + float result = static_cast(a[i]) + static_cast(b[i]); + result = std::fmin(std::fmax(result, -FP16_MAX), FP16_MAX); + output[i] = static_cast<__fp16>(result); + } + }); +} + +void subtract_f16(const __fp16* a, const __fp16* b, __fp16* output, + size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float16x8_t a_vec = vld1q_f16(&a[i]); + float16x8_t b_vec = vld1q_f16(&b[i]); + float16x8_t result_vec = vsubq_f16(a_vec, b_vec); + vst1q_f16(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = a[i] - b[i]; + } + }); +} + +void multiply_f16(const __fp16* 
a, const __fp16* b, __fp16* output, + size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float16x8_t a_vec = vld1q_f16(&a[i]); + float16x8_t b_vec = vld1q_f16(&b[i]); + float16x8_t result_vec = vmulq_f16(a_vec, b_vec); + vst1q_f16(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = a[i] * b[i]; + } + }); +} + +void divide_f16(const __fp16* a, const __fp16* b, __fp16* output, + size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float16x8_t a_vec = vld1q_f16(&a[i]); + float16x8_t b_vec = vld1q_f16(&b[i]); + float16x8_t result_vec = vdivq_f16(a_vec, b_vec); + vst1q_f16(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = a[i] / b[i]; + } + }); +} + +void add_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* output, + const size_t* a_strides, const size_t* b_strides, + const size_t* output_shape, size_t ndim) { + size_t total_elements = 1; + for (size_t i = 0; i < ndim; ++i) { + total_elements *= output_shape[i]; + } + + std::vector coords(ndim, 0); + constexpr float FP16_MAX = 65500.0f; + + for (size_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) { + size_t a_idx = compute_linear_index(coords.data(), a_strides, ndim); + size_t b_idx = compute_linear_index(coords.data(), b_strides, ndim); + + float result = static_cast(a[a_idx]) + 
static_cast(b[b_idx]); + result = std::fmin(std::fmax(result, -FP16_MAX), FP16_MAX); + output[linear_idx] = static_cast<__fp16>(result); + + increment_coords(coords.data(), output_shape, ndim); + } +} + +void subtract_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* output, + const size_t* a_strides, const size_t* b_strides, + const size_t* output_shape, size_t ndim) { + size_t total_elements = 1; + for (size_t i = 0; i < ndim; ++i) { + total_elements *= output_shape[i]; + } + + std::vector coords(ndim, 0); + + for (size_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) { + size_t a_idx = compute_linear_index(coords.data(), a_strides, ndim); + size_t b_idx = compute_linear_index(coords.data(), b_strides, ndim); + + output[linear_idx] = a[a_idx] - b[b_idx]; + + increment_coords(coords.data(), output_shape, ndim); + } +} + +void multiply_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* output, + const size_t* a_strides, const size_t* b_strides, + const size_t* output_shape, size_t ndim) { + size_t total_elements = 1; + for (size_t i = 0; i < ndim; ++i) { + total_elements *= output_shape[i]; + } + + std::vector coords(ndim, 0); + + for (size_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) { + size_t a_idx = compute_linear_index(coords.data(), a_strides, ndim); + size_t b_idx = compute_linear_index(coords.data(), b_strides, ndim); + + output[linear_idx] = a[a_idx] * b[b_idx]; + + increment_coords(coords.data(), output_shape, ndim); + } +} + +void divide_broadcast_f16(const __fp16* a, const __fp16* b, __fp16* output, + const size_t* a_strides, const size_t* b_strides, + const size_t* output_shape, size_t ndim) { + size_t total_elements = 1; + for (size_t i = 0; i < ndim; ++i) { + total_elements *= output_shape[i]; + } + + std::vector coords(ndim, 0); + + for (size_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) { + size_t a_idx = compute_linear_index(coords.data(), a_strides, ndim); + size_t b_idx = 
compute_linear_index(coords.data(), b_strides, ndim); + + output[linear_idx] = a[a_idx] / b[b_idx]; + + increment_coords(coords.data(), output_shape, ndim); + } +} + +void transpose_f32(const float* source, float* destination, const size_t* shape, + const size_t* permutation, size_t ndim, size_t start_idx, + size_t end_idx) { + if (ndim == 2 && permutation[0] == 1 && permutation[1] == 0) { + size_t num_rows = shape[0]; + size_t num_cols = shape[1]; + + constexpr size_t THRESHOLD = 8192; + constexpr size_t TILE_ROWS = 32; + if (num_rows * num_cols >= THRESHOLD) { + const size_t num_row_blocks = (num_rows + TILE_ROWS - 1) / TILE_ROWS; + + aphrodite::mobile::parallel_for( + num_row_blocks, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [=](size_t start_block, size_t end_block) { + for (size_t block_idx = start_block; block_idx < end_block; + ++block_idx) { + size_t start_row = block_idx * TILE_ROWS; + size_t end_row = std::min(start_row + TILE_ROWS, num_rows); + + transpose_2d_f32(source, destination, num_rows, num_cols, + start_row, end_row); + } + }); + } else { + cactus_transpose_2d_f32(source, destination, num_rows, num_cols, 0, + num_rows); + } + } else { + for (size_t idx = start_idx; idx < end_idx; ++idx) { + size_t src_idx = 0; + size_t tmp_idx = idx; + + for (size_t i = 0; i < ndim; ++i) { + size_t coord = tmp_idx % shape[permutation[ndim - 1 - i]]; + tmp_idx /= shape[permutation[ndim - 1 - i]]; + + size_t stride = 1; + for (size_t j = permutation[ndim - 1 - i] + 1; j < ndim; ++j) { + stride *= shape[j]; + } + src_idx += coord * stride; + } + + destination[idx] = source[src_idx]; + } + } +} + +void concat_f32(const float* input1, const float* input2, float* output, + const size_t* shape1, const size_t* shape2, + const size_t* output_shape, size_t ndims, int axis) { + if (axis < 0) axis += ndims; + + size_t outer_size = 1; + for (size_t i = 0; i < static_cast(axis); ++i) { + outer_size *= output_shape[i]; + } + + size_t inner_size = 1; + for (size_t i = 
axis + 1; i < ndims; ++i) { + inner_size *= output_shape[i]; + } + + size_t axis_size1 = shape1[axis]; + size_t axis_size2 = shape2[axis]; + + size_t input1_stride = axis_size1 * inner_size; + size_t input2_stride = axis_size2 * inner_size; + size_t output_stride = (axis_size1 + axis_size2) * inner_size; + + aphrodite::mobile::parallel_for( + outer_size, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start, size_t end) { + for (size_t outer = start; outer < end; ++outer) { + const float* in1_ptr = input1 + outer * input1_stride; + const float* in2_ptr = input2 + outer * input2_stride; + float* out_ptr = output + outer * output_stride; + + size_t copy_size1 = axis_size1 * inner_size; + std::memcpy(out_ptr, in1_ptr, copy_size1 * sizeof(float)); + + size_t copy_size2 = axis_size2 * inner_size; + std::memcpy(out_ptr + copy_size1, in2_ptr, + copy_size2 * sizeof(float)); + } + }); +} + +void concat_f16(const __fp16* input1, const __fp16* input2, __fp16* output, + const size_t* shape1, const size_t* shape2, + const size_t* output_shape, size_t ndims, int axis) { + if (axis < 0) axis += ndims; + + size_t outer_size = 1; + for (size_t i = 0; i < static_cast(axis); ++i) { + outer_size *= output_shape[i]; + } + + size_t inner_size = 1; + for (size_t i = axis + 1; i < ndims; ++i) { + inner_size *= output_shape[i]; + } + + size_t axis_size1 = shape1[axis]; + size_t axis_size2 = shape2[axis]; + + size_t input1_stride = axis_size1 * inner_size; + size_t input2_stride = axis_size2 * inner_size; + size_t output_stride = (axis_size1 + axis_size2) * inner_size; + + aphrodite::mobile::parallel_for( + outer_size, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start, size_t end) { + for (size_t outer = start; outer < end; ++outer) { + const __fp16* in1_ptr = input1 + outer * input1_stride; + const __fp16* in2_ptr = input2 + outer * input2_stride; + __fp16* out_ptr = output + outer * output_stride; + + size_t copy_size1 = axis_size1 * inner_size; + 
std::memcpy(out_ptr, in1_ptr, copy_size1 * sizeof(__fp16)); + + size_t copy_size2 = axis_size2 * inner_size; + std::memcpy(out_ptr + copy_size1, in2_ptr, + copy_size2 * sizeof(__fp16)); + } + }); +} + +void concat_int8(const int8_t* input1, const int8_t* input2, int8_t* output, + const size_t* shape1, const size_t* shape2, + const size_t* output_shape, size_t ndims, int axis) { + if (axis < 0) axis += ndims; + + size_t outer_size = 1; + for (size_t i = 0; i < static_cast<size_t>(axis); ++i) { + outer_size *= output_shape[i]; + } + + size_t inner_size = 1; + for (size_t i = axis + 1; i < ndims; ++i) { + inner_size *= output_shape[i]; + } + + size_t axis_size1 = shape1[axis]; + size_t axis_size2 = shape2[axis]; + + size_t input1_stride = axis_size1 * inner_size; + size_t input2_stride = axis_size2 * inner_size; + size_t output_stride = (axis_size1 + axis_size2) * inner_size; + + aphrodite::mobile::parallel_for( + outer_size, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start, size_t end) { + for (size_t outer = start; outer < end; ++outer) { + const int8_t* in1_ptr = input1 + outer * input1_stride; + const int8_t* in2_ptr = input2 + outer * input2_stride; + int8_t* out_ptr = output + outer * output_stride; + + size_t copy_size1 = axis_size1 * inner_size; + std::memcpy(out_ptr, in1_ptr, copy_size1 * sizeof(int8_t)); + + size_t copy_size2 = axis_size2 * inner_size; + std::memcpy(out_ptr + copy_size1, in2_ptr, + copy_size2 * sizeof(int8_t)); + } + }); +} + +void transpose_2d_int8(const int8_t* source, int8_t* destination, + size_t num_rows, size_t num_cols, size_t start_row, + size_t end_row) { + constexpr size_t TILE_SIZE = 64; + constexpr size_t VECTOR_WIDTH = 8; /* must equal the 8x8 block the NEON path transposes */ + + for (size_t row_tile_start = start_row; row_tile_start < end_row; + row_tile_start += TILE_SIZE) { + const size_t row_tile_end = std::min(row_tile_start + TILE_SIZE, end_row); + + for (size_t col_tile_start = 0; col_tile_start < num_cols; + col_tile_start += TILE_SIZE) { + const size_t col_tile_end
= + std::min(col_tile_start + TILE_SIZE, num_cols); + + for (size_t row_block = row_tile_start; row_block < row_tile_end; + row_block += VECTOR_WIDTH) { + const size_t row_block_end = + std::min(row_block + VECTOR_WIDTH, row_tile_end); + + for (size_t col_block = col_tile_start; col_block < col_tile_end; + col_block += VECTOR_WIDTH) { + const size_t col_block_end = + std::min(col_block + VECTOR_WIDTH, col_tile_end); + + if (row_block_end - row_block >= 8 && + col_block_end - col_block >= 8) { + int8x8_t rows[8]; + for (int i = 0; i < 8; i++) { + if (row_block + i < row_block_end) { + rows[i] = + vld1_s8(&source[(row_block + i) * num_cols + col_block]); + } else { + rows[i] = vdup_n_s8(0); + } + } + + int8x8x2_t r01 = vtrn_s8(rows[0], rows[1]); + int8x8x2_t r23 = vtrn_s8(rows[2], rows[3]); + int8x8x2_t r45 = vtrn_s8(rows[4], rows[5]); + int8x8x2_t r67 = vtrn_s8(rows[6], rows[7]); + + int16x4x2_t r0123_low = vtrn_s16(vreinterpret_s16_s8(r01.val[0]), + vreinterpret_s16_s8(r23.val[0])); + int16x4x2_t r0123_high = vtrn_s16(vreinterpret_s16_s8(r01.val[1]), + vreinterpret_s16_s8(r23.val[1])); + int16x4x2_t r4567_low = vtrn_s16(vreinterpret_s16_s8(r45.val[0]), + vreinterpret_s16_s8(r67.val[0])); + int16x4x2_t r4567_high = vtrn_s16(vreinterpret_s16_s8(r45.val[1]), + vreinterpret_s16_s8(r67.val[1])); + + int32x2x2_t final_0123 = + vtrn_s32(vreinterpret_s32_s16(r0123_low.val[0]), + vreinterpret_s32_s16(r4567_low.val[0])); + int32x2x2_t final_4567 = + vtrn_s32(vreinterpret_s32_s16(r0123_low.val[1]), + vreinterpret_s32_s16(r4567_low.val[1])); + int32x2x2_t final_89AB = + vtrn_s32(vreinterpret_s32_s16(r0123_high.val[0]), + vreinterpret_s32_s16(r4567_high.val[0])); + int32x2x2_t final_CDEF = + vtrn_s32(vreinterpret_s32_s16(r0123_high.val[1]), + vreinterpret_s32_s16(r4567_high.val[1])); + + int8x8_t transposed[8] = {vreinterpret_s8_s32(final_0123.val[0]), + vreinterpret_s8_s32(final_89AB.val[0]), + vreinterpret_s8_s32(final_4567.val[0]), + vreinterpret_s8_s32(final_CDEF.val[0]), +
vreinterpret_s8_s32(final_0123.val[1]), + vreinterpret_s8_s32(final_89AB.val[1]), + vreinterpret_s8_s32(final_4567.val[1]), + vreinterpret_s8_s32(final_CDEF.val[1])}; /* source columns 0..7, in order */ + + for (int col = 0; col < 8 && col_block + col < col_block_end; + col++) { + if (col_block + col < num_cols) { + vst1_s8(&destination[(col_block + col) * num_rows + row_block], + transposed[col]); + } + } + } else { + for (size_t row = row_block; row < row_block_end; row++) { + for (size_t col = col_block; col < col_block_end; col++) { + destination[col * num_rows + row] = + source[row * num_cols + col]; + } + } + } + } + } + } + } +} + +void transpose_2d_f32(const float* source, float* destination, size_t num_rows, + size_t num_cols, size_t start_row, size_t end_row) { + constexpr size_t TILE_SIZE = 32; + constexpr size_t VECTOR_WIDTH = 4; + + for (size_t row_tile_start = start_row; row_tile_start < end_row; + row_tile_start += TILE_SIZE) { + const size_t row_tile_end = std::min(row_tile_start + TILE_SIZE, end_row); + + for (size_t col_tile_start = 0; col_tile_start < num_cols; + col_tile_start += TILE_SIZE) { + const size_t col_tile_end = + std::min(col_tile_start + TILE_SIZE, num_cols); + + for (size_t row_block = row_tile_start; row_block < row_tile_end; + row_block += VECTOR_WIDTH) { + const size_t row_block_end = + std::min(row_block + VECTOR_WIDTH, row_tile_end); + + for (size_t col_block = col_tile_start; col_block < col_tile_end; + col_block += VECTOR_WIDTH) { + const size_t col_block_end = + std::min(col_block + VECTOR_WIDTH, col_tile_end); + + if (row_block_end - row_block >= 4 && + col_block_end - col_block >= 4) { + float32x4_t rows[4]; + for (int i = 0; i < 4; i++) { + if (row_block + i < row_block_end) { + rows[i] = + vld1q_f32(&source[(row_block + i) * num_cols + col_block]); + } else { + rows[i] = vdupq_n_f32(0.0f); + } + } + + float32x4x2_t r01 = vtrnq_f32(rows[0], rows[1]); + float32x4x2_t r23 = vtrnq_f32(rows[2], rows[3]); + + float32x4_t col0 =
vcombine_f32(vget_low_f32(r01.val[0]), + vget_low_f32(r23.val[0])); + float32x4_t col1 = vcombine_f32(vget_low_f32(r01.val[1]), + vget_low_f32(r23.val[1])); + float32x4_t col2 = vcombine_f32(vget_high_f32(r01.val[0]), + vget_high_f32(r23.val[0])); + float32x4_t col3 = vcombine_f32(vget_high_f32(r01.val[1]), + vget_high_f32(r23.val[1])); + + if (col_block + 0 < num_cols) { + if (row_block_end - row_block >= 4) { + vst1q_f32(&destination[(col_block + 0) * num_rows + row_block], + col0); + } else { + float temp[4]; + vst1q_f32(temp, col0); + for (size_t i = 0; i < row_block_end - row_block; ++i) { + destination[(col_block + 0) * num_rows + row_block + i] = + temp[i]; + } + } + } + if (col_block + 1 < num_cols) { + if (row_block_end - row_block >= 4) { + vst1q_f32(&destination[(col_block + 1) * num_rows + row_block], + col1); + } else { + float temp[4]; + vst1q_f32(temp, col1); + for (size_t i = 0; i < row_block_end - row_block; ++i) { + destination[(col_block + 1) * num_rows + row_block + i] = + temp[i]; + } + } + } + if (col_block + 2 < num_cols) { + if (row_block_end - row_block >= 4) { + vst1q_f32(&destination[(col_block + 2) * num_rows + row_block], + col2); + } else { + float temp[4]; + vst1q_f32(temp, col2); + for (size_t i = 0; i < row_block_end - row_block; ++i) { + destination[(col_block + 2) * num_rows + row_block + i] = + temp[i]; + } + } + } + if (col_block + 3 < num_cols) { + if (row_block_end - row_block >= 4) { + vst1q_f32(&destination[(col_block + 3) * num_rows + row_block], + col3); + } else { + float temp[4]; + vst1q_f32(temp, col3); + for (size_t i = 0; i < row_block_end - row_block; ++i) { + destination[(col_block + 3) * num_rows + row_block + i] = + temp[i]; + } + } + } + } else { + for (size_t row = row_block; row < row_block_end; row++) { + for (size_t col = col_block; col < col_block_end; col++) { + destination[col * num_rows + row] = + source[row * num_cols + col]; + } + } + } + } + } + } + } +} + +void transpose_int8(const int8_t* source, 
int8_t* destination, + const size_t* shape, const size_t* permutation, size_t ndim, + size_t start_idx, size_t end_idx) { + if (ndim == 2 && permutation[0] == 1 && permutation[1] == 0) { + size_t num_rows = shape[0]; + size_t num_cols = shape[1]; + + constexpr size_t THRESHOLD = 8192; + constexpr size_t TILE_ROWS = 64; + if (num_rows * num_cols >= THRESHOLD) { + const size_t num_row_blocks = (num_rows + TILE_ROWS - 1) / TILE_ROWS; + + aphrodite::mobile::parallel_for( + num_row_blocks, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [=](size_t start_block, size_t end_block) { + for (size_t block_idx = start_block; block_idx < end_block; + ++block_idx) { + size_t start_row = block_idx * TILE_ROWS; + size_t end_row = std::min(start_row + TILE_ROWS, num_rows); + + transpose_2d_int8(source, destination, num_rows, num_cols, + start_row, end_row); + } + }); + } else { + transpose_2d_int8(source, destination, num_rows, num_cols, 0, + num_rows); /* below THRESHOLD: one serial pass over all rows */ + } + } else { + for (size_t idx = start_idx; idx < end_idx; ++idx) { + size_t src_idx = 0; + size_t tmp_idx = idx; + + for (size_t i = 0; i < ndim; ++i) { + size_t coord = tmp_idx % shape[permutation[ndim - 1 - i]]; + tmp_idx /= shape[permutation[ndim - 1 - i]]; + + size_t stride = 1; + for (size_t j = permutation[ndim - 1 - i] + 1; j < ndim; ++j) { + stride *= shape[j]; + } + src_idx += coord * stride; + } + + destination[idx] = source[src_idx]; + } + } +} +} // namespace aphrodite::mobile diff --git a/aphrodite_kernels/csrc/cpu/mobile/gemm.cpp b/aphrodite_kernels/csrc/cpu/mobile/gemm.cpp new file mode 100644 index 0000000000..9c9362cc37 --- /dev/null +++ b/aphrodite_kernels/csrc/cpu/mobile/gemm.cpp @@ -0,0 +1,925 @@ +#include "threading.hpp" +#include +#include +#include +#include + +namespace aphrodite::mobile { + +static void matmul_int8_worker(const int8_t* a, const int8_t* b_transposed, + int8_t* c, size_t M, size_t K, size_t N, + size_t start_row, size_t end_row, float a_scale, + float b_scale, float c_scale) { +
constexpr int TILE_M = 4; + constexpr int TILE_N = 4; + constexpr int VECTOR_WIDTH = 16; + const size_t K_aligned = (K / (VECTOR_WIDTH * 2)) * (VECTOR_WIDTH * 2); + + for (size_t row_block = start_row; row_block < end_row; row_block += TILE_M) { + for (size_t col_block = 0; col_block < N; col_block += TILE_N) { + int32x4_t accumulators[TILE_M][TILE_N]; + for (int m = 0; m < TILE_M; ++m) + for (int n = 0; n < TILE_N; ++n) accumulators[m][n] = vdupq_n_s32(0); + + for (size_t k_block = 0; k_block < K_aligned; + k_block += VECTOR_WIDTH * 2) { + int8x16_t a_vec_low[TILE_M], a_vec_high[TILE_M]; + int8x16_t b_vec_low[TILE_N], b_vec_high[TILE_N]; + + for (int m = 0; m < TILE_M; ++m) { + size_t row = row_block + m; + if (row < M) { + a_vec_low[m] = vld1q_s8(&a[row * K + k_block]); + a_vec_high[m] = vld1q_s8(&a[row * K + k_block + VECTOR_WIDTH]); + } else { + a_vec_low[m] = vdupq_n_s8(0); + a_vec_high[m] = vdupq_n_s8(0); + } + } + + for (int n = 0; n < TILE_N; ++n) { + size_t col = col_block + n; + if (col < N) { + b_vec_low[n] = vld1q_s8(&b_transposed[col * K + k_block]); + b_vec_high[n] = + vld1q_s8(&b_transposed[col * K + k_block + VECTOR_WIDTH]); + } else { + b_vec_low[n] = vdupq_n_s8(0); + b_vec_high[n] = vdupq_n_s8(0); + } + } + + accumulators[0][0] = + accum_i8mm(accumulators[0][0], a_vec_low[0], b_vec_low[0]); + accumulators[0][1] = + accum_i8mm(accumulators[0][1], a_vec_low[0], b_vec_low[1]); + accumulators[0][2] = + accum_i8mm(accumulators[0][2], a_vec_low[0], b_vec_low[2]); + accumulators[0][3] = + accum_i8mm(accumulators[0][3], a_vec_low[0], b_vec_low[3]); + accumulators[1][0] = + accum_i8mm(accumulators[1][0], a_vec_low[1], b_vec_low[0]); + accumulators[1][1] = + accum_i8mm(accumulators[1][1], a_vec_low[1], b_vec_low[1]); + accumulators[1][2] = + accum_i8mm(accumulators[1][2], a_vec_low[1], b_vec_low[2]); + accumulators[1][3] = + accum_i8mm(accumulators[1][3], a_vec_low[1], b_vec_low[3]); + accumulators[2][0] = + accum_i8mm(accumulators[2][0], a_vec_low[2], 
b_vec_low[0]); + accumulators[2][1] = + accum_i8mm(accumulators[2][1], a_vec_low[2], b_vec_low[1]); + accumulators[2][2] = + accum_i8mm(accumulators[2][2], a_vec_low[2], b_vec_low[2]); + accumulators[2][3] = + accum_i8mm(accumulators[2][3], a_vec_low[2], b_vec_low[3]); + accumulators[3][0] = + accum_i8mm(accumulators[3][0], a_vec_low[3], b_vec_low[0]); + accumulators[3][1] = + accum_i8mm(accumulators[3][1], a_vec_low[3], b_vec_low[1]); + accumulators[3][2] = + accum_i8mm(accumulators[3][2], a_vec_low[3], b_vec_low[2]); + accumulators[3][3] = + accum_i8mm(accumulators[3][3], a_vec_low[3], b_vec_low[3]); + + accumulators[0][0] = + accum_i8mm(accumulators[0][0], a_vec_high[0], b_vec_high[0]); + accumulators[0][1] = + accum_i8mm(accumulators[0][1], a_vec_high[0], b_vec_high[1]); + accumulators[0][2] = + accum_i8mm(accumulators[0][2], a_vec_high[0], b_vec_high[2]); + accumulators[0][3] = + accum_i8mm(accumulators[0][3], a_vec_high[0], b_vec_high[3]); + accumulators[1][0] = + accum_i8mm(accumulators[1][0], a_vec_high[1], b_vec_high[0]); + accumulators[1][1] = + accum_i8mm(accumulators[1][1], a_vec_high[1], b_vec_high[1]); + accumulators[1][2] = + accum_i8mm(accumulators[1][2], a_vec_high[1], b_vec_high[2]); + accumulators[1][3] = + accum_i8mm(accumulators[1][3], a_vec_high[1], b_vec_high[3]); + accumulators[2][0] = + accum_i8mm(accumulators[2][0], a_vec_high[2], b_vec_high[0]); + accumulators[2][1] = + accum_i8mm(accumulators[2][1], a_vec_high[2], b_vec_high[1]); + accumulators[2][2] = + accum_i8mm(accumulators[2][2], a_vec_high[2], b_vec_high[2]); + accumulators[2][3] = + accum_i8mm(accumulators[2][3], a_vec_high[2], b_vec_high[3]); + accumulators[3][0] = + accum_i8mm(accumulators[3][0], a_vec_high[3], b_vec_high[0]); + accumulators[3][1] = + accum_i8mm(accumulators[3][1], a_vec_high[3], b_vec_high[1]); + accumulators[3][2] = + accum_i8mm(accumulators[3][2], a_vec_high[3], b_vec_high[2]); + accumulators[3][3] = + accum_i8mm(accumulators[3][3], a_vec_high[3], 
b_vec_high[3]); + } + + for (size_t k_block = K_aligned; k_block < K; k_block += VECTOR_WIDTH) { + size_t remaining = K - k_block; + int8x16_t a_vec[TILE_M], b_vec[TILE_N]; + + for (int m = 0; m < TILE_M; ++m) { + size_t row = row_block + m; + if (row < M) { + if (remaining >= VECTOR_WIDTH) { + a_vec[m] = vld1q_s8(&a[row * K + k_block]); + } else { + int8_t tmp[VECTOR_WIDTH] = {}; + memcpy(tmp, &a[row * K + k_block], remaining); + a_vec[m] = vld1q_s8(tmp); + } + } else { + a_vec[m] = vdupq_n_s8(0); + } + } + + for (int n = 0; n < TILE_N; ++n) { + size_t col = col_block + n; + if (col < N) { + if (remaining >= VECTOR_WIDTH) { + b_vec[n] = vld1q_s8(&b_transposed[col * K + k_block]); + } else { + int8_t tmp[VECTOR_WIDTH] = {}; + memcpy(tmp, &b_transposed[col * K + k_block], remaining); + b_vec[n] = vld1q_s8(tmp); + } + } else { + b_vec[n] = vdupq_n_s8(0); + } + } + + for (int m = 0; m < TILE_M; ++m) + for (int n = 0; n < TILE_N; ++n) + accumulators[m][n] = + accum_i8mm(accumulators[m][n], a_vec[m], b_vec[n]); + } + const float scale_factor = (a_scale * b_scale) / c_scale; + + for (int m = 0; m < TILE_M; ++m) { + size_t row = row_block + m; + if (row >= M) continue; + for (int n = 0; n < TILE_N; ++n) { + size_t col = col_block + n; + if (col >= N) continue; + int32_t sum = vaddvq_s32(accumulators[m][n]); + + float scaled_result = static_cast(sum) * scale_factor; + int32_t quantized_result = static_cast( + scaled_result + (scaled_result >= 0 ? 
0.5f : -0.5f)); + quantized_result = std::min(127, std::max(-128, quantized_result)); + + c[row * N + col] = static_cast(quantized_result); + } + } + } + } +} + +void matmul_int8(const int8_t* a, const int8_t* b_transposed, int8_t* c, + size_t M, size_t K, size_t N, float a_scale, float b_scale, + float c_scale) { + if (M == 0) return; + + constexpr size_t TILE_M = 4; + const size_t num_row_blocks = (M + TILE_M - 1) / TILE_M; + + aphrodite::mobile::parallel_for( + num_row_blocks, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [=](size_t start_block, size_t end_block) { + for (size_t block_idx = start_block; block_idx < end_block; + ++block_idx) { + size_t start_row = block_idx * TILE_M; + size_t end_row = std::min(start_row + TILE_M, M); + + matmul_int8_worker(a, b_transposed, c, M, K, N, start_row, end_row, + a_scale, b_scale, c_scale); + } + }); +} + +static void matmul_f16_worker(const __fp16* a, const __fp16* b_transposed, + __fp16* c, size_t M, size_t K, size_t N, + size_t start_row, size_t end_row) { + constexpr int TILE_M = 4; + constexpr int TILE_N = 4; + constexpr int VECTOR_WIDTH = 8; + const size_t K_aligned = (K / (VECTOR_WIDTH * 2)) * (VECTOR_WIDTH * 2); + + for (size_t row_block = start_row; row_block < end_row; row_block += TILE_M) { + for (size_t col_block = 0; col_block < N; col_block += TILE_N) { + float16x8_t accumulators[TILE_M][TILE_N]; + for (int m = 0; m < TILE_M; ++m) + for (int n = 0; n < TILE_N; ++n) accumulators[m][n] = vdupq_n_f16(0.0); + + for (size_t k_block = 0; k_block < K_aligned; + k_block += VECTOR_WIDTH * 2) { + float16x8_t a_vec_low[TILE_M], a_vec_high[TILE_M]; + float16x8_t b_vec_low[TILE_N], b_vec_high[TILE_N]; + + for (int m = 0; m < TILE_M; ++m) { + size_t row = row_block + m; + if (row < M) { + a_vec_low[m] = vld1q_f16(&a[row * K + k_block]); + a_vec_high[m] = vld1q_f16(&a[row * K + k_block + VECTOR_WIDTH]); + } else { + a_vec_low[m] = vdupq_n_f16(0.0); + a_vec_high[m] = vdupq_n_f16(0.0); + } + } + + for (int n = 0; n < 
TILE_N; ++n) { + size_t col = col_block + n; + if (col < N) { + b_vec_low[n] = vld1q_f16(&b_transposed[col * K + k_block]); + b_vec_high[n] = + vld1q_f16(&b_transposed[col * K + k_block + VECTOR_WIDTH]); + } else { + b_vec_low[n] = vdupq_n_f16(0.0); + b_vec_high[n] = vdupq_n_f16(0.0); + } + } + + for (int m = 0; m < TILE_M; ++m) + for (int n = 0; n < TILE_N; ++n) { + accumulators[m][n] = + accum_f16_dot(accumulators[m][n], a_vec_low[m], a_vec_high[m], + b_vec_low[n], b_vec_high[n]); + } + } + + for (size_t k_block = K_aligned; k_block < K; k_block += VECTOR_WIDTH) { + size_t remaining = K - k_block; + float16x8_t a_vec[TILE_M], b_vec[TILE_N]; + + for (int m = 0; m < TILE_M; ++m) { + size_t row = row_block + m; + if (row < M) { + if (remaining >= VECTOR_WIDTH) { + a_vec[m] = vld1q_f16(&a[row * K + k_block]); + } else { + __fp16 tmp[VECTOR_WIDTH] = {0.0}; + memcpy(tmp, &a[row * K + k_block], remaining * sizeof(__fp16)); + a_vec[m] = vld1q_f16(tmp); + } + } else { + a_vec[m] = vdupq_n_f16(0.0); + } + } + + for (int n = 0; n < TILE_N; ++n) { + size_t col = col_block + n; + if (col < N) { + if (remaining >= VECTOR_WIDTH) { + b_vec[n] = vld1q_f16(&b_transposed[col * K + k_block]); + } else { + __fp16 tmp[VECTOR_WIDTH] = {0.0}; + memcpy(tmp, &b_transposed[col * K + k_block], + remaining * sizeof(__fp16)); + b_vec[n] = vld1q_f16(tmp); + } + } else { + b_vec[n] = vdupq_n_f16(0.0); + } + } + + for (int m = 0; m < TILE_M; ++m) + for (int n = 0; n < TILE_N; ++n) + accumulators[m][n] = + vfmaq_f16(accumulators[m][n], a_vec[m], b_vec[n]); + } + + for (int m = 0; m < TILE_M; ++m) { + size_t row = row_block + m; + if (row >= M) continue; + for (int n = 0; n < TILE_N; ++n) { + size_t col = col_block + n; + if (col >= N) continue; + float16x4_t low = vget_low_f16(accumulators[m][n]); + float16x4_t high = vget_high_f16(accumulators[m][n]); + float16x4_t sum_vec = vadd_f16(low, high); + __fp16 sum = vget_lane_f16(sum_vec, 0) + vget_lane_f16(sum_vec, 1) + + vget_lane_f16(sum_vec, 2) + 
vget_lane_f16(sum_vec, 3); + c[row * N + col] = sum; + } + } + } + } +} + +void matmul_f16(const __fp16* a, const __fp16* b_transposed, __fp16* c, + size_t M, size_t K, size_t N) { + constexpr size_t TILE_M = 4; + const size_t num_row_blocks = (M + TILE_M - 1) / TILE_M; + + aphrodite::mobile::parallel_for( + num_row_blocks, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [=](size_t start_block, size_t end_block) { + for (size_t block_idx = start_block; block_idx < end_block; + ++block_idx) { + size_t start_row = block_idx * TILE_M; + size_t end_row = std::min(start_row + TILE_M, M); + + matmul_f16_worker(a, b_transposed, c, M, K, N, start_row, end_row); + } + }); +} + +static void matmul_f32_worker(const float* a, const float* b_transposed, + float* c, size_t M, size_t K, size_t N, + size_t start_row, size_t end_row) { + constexpr int TILE_M = 4; + constexpr int TILE_N = 4; + constexpr int VECTOR_WIDTH = 4; + const size_t K_aligned = (K / (VECTOR_WIDTH * 2)) * (VECTOR_WIDTH * 2); + + for (size_t row_block = start_row; row_block < end_row; row_block += TILE_M) { + for (size_t col_block = 0; col_block < N; col_block += TILE_N) { + float32x4_t accumulators[TILE_M][TILE_N]; + for (int m = 0; m < TILE_M; ++m) + for (int n = 0; n < TILE_N; ++n) accumulators[m][n] = vdupq_n_f32(0.0f); + + for (size_t k_block = 0; k_block < K_aligned; + k_block += VECTOR_WIDTH * 2) { + float32x4_t a_vec_low[TILE_M], a_vec_high[TILE_M]; + float32x4_t b_vec_low[TILE_N], b_vec_high[TILE_N]; + + for (int m = 0; m < TILE_M; ++m) { + size_t row = row_block + m; + if (row < M) { + a_vec_low[m] = vld1q_f32(&a[row * K + k_block]); + a_vec_high[m] = vld1q_f32(&a[row * K + k_block + VECTOR_WIDTH]); + } else { + a_vec_low[m] = vdupq_n_f32(0.0f); + a_vec_high[m] = vdupq_n_f32(0.0f); + } + } + + for (int n = 0; n < TILE_N; ++n) { + size_t col = col_block + n; + if (col < N) { + b_vec_low[n] = vld1q_f32(&b_transposed[col * K + k_block]); + b_vec_high[n] = + vld1q_f32(&b_transposed[col * K + k_block + 
VECTOR_WIDTH]); + } else { + b_vec_low[n] = vdupq_n_f32(0.0f); + b_vec_high[n] = vdupq_n_f32(0.0f); + } + } + + for (int m = 0; m < TILE_M; ++m) + for (int n = 0; n < TILE_N; ++n) { + accumulators[m][n] = + accum_f32_dot(accumulators[m][n], a_vec_low[m], a_vec_high[m], + b_vec_low[n], b_vec_high[n]); + } + } + + for (size_t k_block = K_aligned; k_block < K; k_block += VECTOR_WIDTH) { + size_t remaining = K - k_block; + float32x4_t a_vec[TILE_M], b_vec[TILE_N]; + + for (int m = 0; m < TILE_M; ++m) { + size_t row = row_block + m; + if (row < M) { + if (remaining >= VECTOR_WIDTH) { + a_vec[m] = vld1q_f32(&a[row * K + k_block]); + } else { + float tmp[VECTOR_WIDTH] = {0.0f}; + memcpy(tmp, &a[row * K + k_block], remaining * sizeof(float)); + a_vec[m] = vld1q_f32(tmp); + } + } else { + a_vec[m] = vdupq_n_f32(0.0f); + } + } + + for (int n = 0; n < TILE_N; ++n) { + size_t col = col_block + n; + if (col < N) { + if (remaining >= VECTOR_WIDTH) { + b_vec[n] = vld1q_f32(&b_transposed[col * K + k_block]); + } else { + float tmp[VECTOR_WIDTH] = {0.0f}; + memcpy(tmp, &b_transposed[col * K + k_block], + remaining * sizeof(float)); + b_vec[n] = vld1q_f32(tmp); + } + } else { + b_vec[n] = vdupq_n_f32(0.0f); + } + } + + for (int m = 0; m < TILE_M; ++m) + for (int n = 0; n < TILE_N; ++n) + accumulators[m][n] = + vfmaq_f32(accumulators[m][n], a_vec[m], b_vec[n]); + } + + for (int m = 0; m < TILE_M; ++m) { + size_t row = row_block + m; + if (row >= M) continue; + for (int n = 0; n < TILE_N; ++n) { + size_t col = col_block + n; + if (col >= N) continue; + float sum = vaddvq_f32(accumulators[m][n]); + c[row * N + col] = sum; + } + } + } + } +} + +void matmul_f32(const float* a, const float* b_transposed, float* c, size_t M, + size_t K, size_t N) { + constexpr size_t TILE_M = 4; + const size_t num_row_blocks = (M + TILE_M - 1) / TILE_M; + + aphrodite::mobile::parallel_for( + num_row_blocks, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [=](size_t start_block, size_t end_block) { + for 
(size_t block_idx = start_block; block_idx < end_block; + ++block_idx) { + size_t start_row = block_idx * TILE_M; + size_t end_row = std::min(start_row + TILE_M, M); + + matmul_f32_worker(a, b_transposed, c, M, K, N, start_row, end_row); + } + }); +} + +#if !defined(__ARM_FEATURE_MATMUL_INT8) +static void matmul_int8_to_int32_worker(const int8_t* a, + const int8_t* b_transposed, int32_t* c, + size_t M, size_t K, size_t N, + size_t start_row, size_t end_row) { + constexpr int TILE_M = 4; + constexpr int TILE_N = 4; + constexpr int VECTOR_WIDTH = 16; + constexpr int DOT_GRANULARITY = 4; + constexpr int VECTOR_UNROLL = 2; + const size_t K_aligned = + (K / (VECTOR_WIDTH * VECTOR_UNROLL)) * (VECTOR_WIDTH * VECTOR_UNROLL); + + for (size_t row_block = start_row; row_block < end_row; row_block += TILE_M) { + for (size_t col_block = 0; col_block < N; col_block += TILE_N) { + int32x4_t accumulators[TILE_M][TILE_N]; + for (int m = 0; m < TILE_M; ++m) + for (int n = 0; n < TILE_N; ++n) accumulators[m][n] = vdupq_n_s32(0); + + for (size_t k_block = 0; k_block < K_aligned; + k_block += VECTOR_WIDTH * VECTOR_UNROLL) { + int8x16_t a_vec[VECTOR_UNROLL][TILE_M]; + int8x16_t b_vec[VECTOR_UNROLL][TILE_N]; + + for (int m = 0; m < TILE_M; ++m) { + size_t row = row_block + m; + if (row < M) { + a_vec[0][m] = vld1q_s8(&a[row * K + k_block]); + a_vec[1][m] = vld1q_s8(&a[row * K + k_block + VECTOR_WIDTH]); + } else { + a_vec[0][m] = vdupq_n_s8(0); + a_vec[1][m] = vdupq_n_s8(0); + } + } + + for (int n = 0; n < TILE_N; ++n) { + size_t col = col_block + n; + if (col < N) { + b_vec[0][n] = vld1q_s8(&b_transposed[col * K + k_block]); + b_vec[1][n] = + vld1q_s8(&b_transposed[col * K + k_block + VECTOR_WIDTH]); + } else { + b_vec[0][n] = vdupq_n_s8(0); + b_vec[1][n] = vdupq_n_s8(0); + } + } + + for (int m = 0; m < TILE_M; ++m) { + for (int n = 0; n < TILE_N; ++n) { + accumulators[m][n] = + accum_i8mm(accumulators[m][n], a_vec[0][m], b_vec[0][n]); + accumulators[m][n] = + 
accum_i8mm(accumulators[m][n], a_vec[1][m], b_vec[1][n]); + } + } + } + + for (size_t k_block = K_aligned; k_block < K; + k_block += DOT_GRANULARITY) { + size_t remaining = + std::min(static_cast<size_t>(DOT_GRANULARITY), K - k_block); + + for (int m = 0; m < TILE_M; ++m) { + size_t row = row_block + m; + if (row >= M) continue; + + for (int n = 0; n < TILE_N; ++n) { + size_t col = col_block + n; + if (col >= N) continue; + + int32_t dot_product = 0; + for (size_t k = 0; k < remaining; ++k) { + dot_product += + static_cast<int32_t>(a[row * K + k_block + k]) * + static_cast<int32_t>(b_transposed[col * K + k_block + k]); + } + + int32x4_t dot_vec = vsetq_lane_s32(dot_product, vdupq_n_s32(0), 0); /* lane 0 only: vaddvq_s32 below sums all four lanes */ + accumulators[m][n] = vaddq_s32(accumulators[m][n], dot_vec); + } + } + } + + for (int m = 0; m < TILE_M; ++m) { + size_t row = row_block + m; + if (row >= M) continue; + for (int n = 0; n < TILE_N; ++n) { + size_t col = col_block + n; + if (col >= N) continue; + int32_t sum = vaddvq_s32(accumulators[m][n]); + c[row * N + col] = sum; + } + } + } + } +} +#endif + +#if !defined(__ARM_FEATURE_MATMUL_INT8) +void matmul_int8_to_int32(const int8_t* a, const int8_t* b_transposed, + int32_t* c, size_t M, size_t K, size_t N) { + if (M == 0) return; + + size_t num_threads = + aphrodite::mobile::compute_gemm_parallelism(M, K, N, sizeof(int8_t)); + + if (num_threads == 1) { + matmul_int8_to_int32_worker(a, b_transposed, c, M, K, N, 0, M); + return; + } + + size_t optimal_tile_m = + std::min(aphrodite::mobile::Thresholds::GEMM_TILE_M, (M + 1) / 2 * 2); + size_t optimal_tile_n = + std::min(aphrodite::mobile::Thresholds::GEMM_TILE_N, (N + 1) / 2 * 2); + + size_t k_cache_footprint = K * sizeof(int8_t); + if (k_cache_footprint > aphrodite::mobile::Thresholds::L2_CACHE_SIZE) { + optimal_tile_m = aphrodite::mobile::Thresholds::GEMM_TILE_M_SMALL; + optimal_tile_n = aphrodite::mobile::Thresholds::GEMM_TILE_N_SMALL; + } + + aphrodite::mobile::parallel_for_2d_tiled( + M, N, optimal_tile_m, optimal_tile_n, + [=](size_t row_start, size_t row_end,
size_t col_start, size_t col_end) { + constexpr int MICRO_TILE_M = 2; + constexpr int MICRO_TILE_N = 2; + constexpr int VECTOR_WIDTH = 16; + constexpr int VECTOR_UNROLL = 4; + const size_t K_aligned = (K / (VECTOR_WIDTH * VECTOR_UNROLL)) * + (VECTOR_WIDTH * VECTOR_UNROLL); + + for (size_t row_block = row_start; row_block < row_end; + row_block += MICRO_TILE_M) { + for (size_t col_block = col_start; col_block < col_end; + col_block += MICRO_TILE_N) { + int32x4_t accumulators[MICRO_TILE_M][MICRO_TILE_N]; + for (int m = 0; m < MICRO_TILE_M; ++m) + for (int n = 0; n < MICRO_TILE_N; ++n) + accumulators[m][n] = vdupq_n_s32(0); + + for (size_t k_block = 0; k_block < K_aligned; + k_block += VECTOR_WIDTH * VECTOR_UNROLL) { + int8x16_t a_vec[VECTOR_UNROLL][MICRO_TILE_M]; + int8x16_t b_vec[VECTOR_UNROLL][MICRO_TILE_N]; + + for (int m = 0; m < MICRO_TILE_M; ++m) { + size_t row = row_block + m; + if (row < row_end) { + a_vec[0][m] = vld1q_s8(&a[row * K + k_block]); + a_vec[1][m] = vld1q_s8(&a[row * K + k_block + VECTOR_WIDTH]); + a_vec[2][m] = + vld1q_s8(&a[row * K + k_block + VECTOR_WIDTH * 2]); + a_vec[3][m] = + vld1q_s8(&a[row * K + k_block + VECTOR_WIDTH * 3]); + } else { + a_vec[0][m] = vdupq_n_s8(0); + a_vec[1][m] = vdupq_n_s8(0); + a_vec[2][m] = vdupq_n_s8(0); + a_vec[3][m] = vdupq_n_s8(0); + } + } + + for (int n = 0; n < MICRO_TILE_N; ++n) { + size_t col = col_block + n; + if (col < col_end) { + b_vec[0][n] = vld1q_s8(&b_transposed[col * K + k_block]); + b_vec[1][n] = + vld1q_s8(&b_transposed[col * K + k_block + VECTOR_WIDTH]); + b_vec[2][n] = vld1q_s8( + &b_transposed[col * K + k_block + VECTOR_WIDTH * 2]); + b_vec[3][n] = vld1q_s8( + &b_transposed[col * K + k_block + VECTOR_WIDTH * 3]); + } else { + b_vec[0][n] = vdupq_n_s8(0); + b_vec[1][n] = vdupq_n_s8(0); + b_vec[2][n] = vdupq_n_s8(0); + b_vec[3][n] = vdupq_n_s8(0); + } + } + + accumulators[0][0] = + accum_i8mm(accumulators[0][0], a_vec[0][0], b_vec[0][0]); + accumulators[0][1] = + accum_i8mm(accumulators[0][1], 
a_vec[0][0], b_vec[0][1]); + accumulators[1][0] = + accum_i8mm(accumulators[1][0], a_vec[0][1], b_vec[0][0]); + accumulators[1][1] = + accum_i8mm(accumulators[1][1], a_vec[0][1], b_vec[0][1]); + + accumulators[0][0] = + accum_i8mm(accumulators[0][0], a_vec[1][0], b_vec[1][0]); + accumulators[0][1] = + accum_i8mm(accumulators[0][1], a_vec[1][0], b_vec[1][1]); + accumulators[1][0] = + accum_i8mm(accumulators[1][0], a_vec[1][1], b_vec[1][0]); + accumulators[1][1] = + accum_i8mm(accumulators[1][1], a_vec[1][1], b_vec[1][1]); + + accumulators[0][0] = + accum_i8mm(accumulators[0][0], a_vec[2][0], b_vec[2][0]); + accumulators[0][1] = + accum_i8mm(accumulators[0][1], a_vec[2][0], b_vec[2][1]); + accumulators[1][0] = + accum_i8mm(accumulators[1][0], a_vec[2][1], b_vec[2][0]); + accumulators[1][1] = + accum_i8mm(accumulators[1][1], a_vec[2][1], b_vec[2][1]); + + accumulators[0][0] = + accum_i8mm(accumulators[0][0], a_vec[3][0], b_vec[3][0]); + accumulators[0][1] = + accum_i8mm(accumulators[0][1], a_vec[3][0], b_vec[3][1]); + accumulators[1][0] = + accum_i8mm(accumulators[1][0], a_vec[3][1], b_vec[3][0]); + accumulators[1][1] = + accum_i8mm(accumulators[1][1], a_vec[3][1], b_vec[3][1]); + } + + for (size_t k_block = K_aligned; k_block < K; + k_block += VECTOR_WIDTH) { + size_t remaining = K - k_block; + int8x16_t a_vec[MICRO_TILE_M], b_vec[MICRO_TILE_N]; + + for (int m = 0; m < MICRO_TILE_M; ++m) { + size_t row = row_block + m; + if (row < row_end) { + if (remaining >= VECTOR_WIDTH) { + a_vec[m] = vld1q_s8(&a[row * K + k_block]); + } else { + int8_t tmp[VECTOR_WIDTH] = {}; + memcpy(tmp, &a[row * K + k_block], remaining); + a_vec[m] = vld1q_s8(tmp); + } + } else { + a_vec[m] = vdupq_n_s8(0); + } + } + + for (int n = 0; n < MICRO_TILE_N; ++n) { + size_t col = col_block + n; + if (col < col_end) { + if (remaining >= VECTOR_WIDTH) { + b_vec[n] = vld1q_s8(&b_transposed[col * K + k_block]); + } else { + int8_t tmp[VECTOR_WIDTH] = {}; + memcpy(tmp, &b_transposed[col * K + k_block], 
remaining); + b_vec[n] = vld1q_s8(tmp); + } + } else { + b_vec[n] = vdupq_n_s8(0); + } + } + + for (int m = 0; m < MICRO_TILE_M; ++m) + for (int n = 0; n < MICRO_TILE_N; ++n) + accumulators[m][n] = + accum_i8mm(accumulators[m][n], a_vec[m], b_vec[n]); + } + + for (int m = 0; m < MICRO_TILE_M; ++m) { + size_t row = row_block + m; + if (row >= row_end) continue; + for (int n = 0; n < MICRO_TILE_N; ++n) { + size_t col = col_block + n; + if (col >= col_end) continue; + int32_t sum = vaddvq_s32(accumulators[m][n]); + c[row * N + col] = sum; + } + } + } + } + }); +} +#endif + +#if defined(__ARM_FEATURE_MATMUL_INT8) + +static void matmul_int8_to_int32_smmla_worker( + const int8_t* a, const int8_t* b_transposed, int32_t* output, size_t M, + size_t K, size_t N, size_t start_row, size_t end_row, size_t start_col, + size_t end_col) { + const size_t K_aligned = (K / 8) * 8; + + for (size_t row_block = start_row; row_block < end_row; row_block += 4) { + for (size_t col_block = start_col; col_block < end_col; col_block += 4) { + int32x4_t acc[2][2] = {{vdupq_n_s32(0), vdupq_n_s32(0)}, + {vdupq_n_s32(0), vdupq_n_s32(0)}}; + + for (size_t k_block = 0; k_block < K_aligned; k_block += 8) { + if (row_block + 3 < M && col_block + 3 < N && k_block + 8 <= K) { + const size_t prefetch_distance = 64; + if (k_block + prefetch_distance < K) { + __builtin_prefetch(&a[row_block * K + k_block + prefetch_distance], + 0, 1); + __builtin_prefetch( + &a[(row_block + 1) * K + k_block + prefetch_distance], 0, 1); + __builtin_prefetch( + &a[(row_block + 2) * K + k_block + prefetch_distance], 0, 1); + __builtin_prefetch( + &a[(row_block + 3) * K + k_block + prefetch_distance], 0, 1); + } + + const int8_t* a_base = &a[row_block * K + k_block]; + int8x8_t a0 = vld1_s8(a_base); + int8x8_t a1 = vld1_s8(a_base + K); + int8x8_t a2 = vld1_s8(a_base + 2 * K); + int8x8_t a3 = vld1_s8(a_base + 3 * K); + + const int8_t* b_base = &b_transposed[col_block * K + k_block]; + int8x8_t b0 = vld1_s8(b_base); + int8x8_t 
b1 = vld1_s8(b_base + K); + int8x8_t b2 = vld1_s8(b_base + 2 * K); + int8x8_t b3 = vld1_s8(b_base + 3 * K); + + int8x16_t a_tiles[2] = {vcombine_s8(a0, a1), vcombine_s8(a2, a3)}; + int8x16_t b_tiles[2] = {vcombine_s8(b0, b1), vcombine_s8(b2, b3)}; + + asm volatile( + "smmla %0.4s, %4.16b, %6.16b\n" // acc[0][0] += a_tiles[0] * + // b_tiles[0] + "smmla %1.4s, %4.16b, %7.16b\n" // acc[0][1] += a_tiles[0] * + // b_tiles[1] + "smmla %2.4s, %5.16b, %6.16b\n" // acc[1][0] += a_tiles[1] * + // b_tiles[0] + "smmla %3.4s, %5.16b, %7.16b" // acc[1][1] += a_tiles[1] * + // b_tiles[1] + : "+w"(acc[0][0]), "+w"(acc[0][1]), "+w"(acc[1][0]), + "+w"(acc[1][1]) + : "w"(a_tiles[0]), "w"(a_tiles[1]), "w"(b_tiles[0]), + "w"(b_tiles[1])); + } else { + for (size_t r = 0; r < 4; r += 2) { + for (size_t c = 0; c < 4; c += 2) { + if (row_block + r < M && col_block + c < N) { + int8x16_t a_tile = vdupq_n_s8(0); + const int8_t* a_row_base = &a[(row_block + r) * K + k_block]; + + if (row_block + r < M && k_block + 8 <= K) { + int8x8_t a0 = vld1_s8(a_row_base); + a_tile = vcombine_s8(a0, vdup_n_s8(0)); + + } else if (row_block + r < M) { + int8_t temp_a[8] __attribute__((aligned(8))) = {0}; + size_t valid_k = std::min(8UL, K - k_block); + memcpy(temp_a, a_row_base, valid_k); + int8x8_t a0 = vld1_s8(temp_a); + a_tile = vcombine_s8(a0, vdup_n_s8(0)); + } + + if (row_block + r + 1 < M && k_block + 8 <= K) { + int8x8_t a1 = vld1_s8(a_row_base + K); + a_tile = vcombine_s8(vget_low_s8(a_tile), a1); + + } else if (row_block + r + 1 < M) { + int8_t temp_a[8] __attribute__((aligned(8))) = {0}; + size_t valid_k = std::min(8UL, K - k_block); + memcpy(temp_a, a_row_base + K, valid_k); + int8x8_t a1 = vld1_s8(temp_a); + a_tile = vcombine_s8(vget_low_s8(a_tile), a1); + } + + int8x16_t b_tile = vdupq_n_s8(0); + const int8_t* b_col_base = + &b_transposed[(col_block + c) * K + k_block]; + + if (col_block + c < N && k_block + 8 <= K) { + int8x8_t b0 = vld1_s8(b_col_base); + b_tile = vcombine_s8(b0, 
vdup_n_s8(0)); + + } else if (col_block + c < N) { + int8_t temp_b[8] __attribute__((aligned(8))) = {0}; + size_t valid_k = std::min(8UL, K - k_block); + memcpy(temp_b, b_col_base, valid_k); + int8x8_t b0 = vld1_s8(temp_b); + b_tile = vcombine_s8(b0, vdup_n_s8(0)); + } + + if (col_block + c + 1 < N && k_block + 8 <= K) { + int8x8_t b1 = vld1_s8(b_col_base + K); + b_tile = vcombine_s8(vget_low_s8(b_tile), b1); + + } else if (col_block + c + 1 < N) { + int8_t temp_b[8] __attribute__((aligned(8))) = {0}; + size_t valid_k = std::min(8UL, K - k_block); + memcpy(temp_b, b_col_base + K, valid_k); + int8x8_t b1 = vld1_s8(temp_b); + b_tile = vcombine_s8(vget_low_s8(b_tile), b1); + } + + asm volatile("smmla %0.4s, %1.16b, %2.16b" + : "+w"(acc[r / 2][c / 2]) + : "w"(a_tile), "w"(b_tile)); + } + } + } + } + } + + int32_t results[4][4]; + vst1q_s32(results[0], acc[0][0]); // rows 0-1, cols 0-1 + vst1q_s32(results[1], acc[0][1]); // rows 0-1, cols 2-3 + vst1q_s32(results[2], acc[1][0]); // rows 2-3, cols 0-1 + vst1q_s32(results[3], acc[1][1]); // rows 2-3, cols 2-3 + + for (size_t k = K_aligned; k < K; k++) { + for (size_t r = 0; r < 4 && row_block + r < M; r++) { + for (size_t c = 0; c < 4 && col_block + c < N; c++) { + size_t tile_idx = (r / 2) * 2 + (c / 2); + size_t elem_idx = (r % 2) * 2 + (c % 2); + results[tile_idx][elem_idx] += + a[(row_block + r) * K + k] * + b_transposed[(col_block + c) * K + k]; + } + } + } + + for (size_t r = 0; r < 4 && row_block + r < M; r++) { + for (size_t c = 0; c < 4 && col_block + c < N; c++) { + size_t tile_idx = (r / 2) * 2 + (c / 2); + size_t elem_idx = (r % 2) * 2 + (c % 2); + output[(row_block + r) * N + (col_block + c)] = + results[tile_idx][elem_idx]; + } + } + } + } +} + +void matmul_int8_to_int32_i8mm(const int8_t* a, const int8_t* b_transposed, + int32_t* c, size_t M, size_t K, size_t N) { + if (M == 0) return; + + size_t total_ops = M * K * N; + + memset(c, 0, M * N * sizeof(int32_t)); + + if (total_ops < 
aphrodite::mobile::Thresholds::GEMM_SMALL) { + matmul_int8_to_int32_smmla_worker(a, b_transposed, c, M, K, N, 0, M, + 0, N); + return; + } + + size_t num_threads = + aphrodite::mobile::compute_gemm_parallelism(M, K, N, sizeof(int8_t)); + + if (num_threads == 1) { + matmul_int8_to_int32_smmla_worker(a, b_transposed, c, M, K, N, 0, M, + 0, N); + return; + } + + size_t optimal_tile_m = + std::min(aphrodite::mobile::Thresholds::GEMM_TILE_M, (M + 3) / 4 * 4); + size_t optimal_tile_n = + std::min(aphrodite::mobile::Thresholds::GEMM_TILE_N, (N + 3) / 4 * 4); + + size_t k_cache_footprint = K * sizeof(int8_t); + if (k_cache_footprint > aphrodite::mobile::Thresholds::L2_CACHE_SIZE) { + optimal_tile_m = aphrodite::mobile::Thresholds::GEMM_TILE_M_SMALL; + optimal_tile_n = aphrodite::mobile::Thresholds::GEMM_TILE_N_SMALL; + } + + // c was already zeroed at function entry, so no second memset is needed here. + aphrodite::mobile::parallel_for_2d_tiled( + M, N, optimal_tile_m, optimal_tile_n, + [=](size_t row_start, size_t row_end, size_t col_start, size_t col_end) { + matmul_int8_to_int32_smmla_worker(a, b_transposed, c, M, K, N, + row_start, row_end, col_start, + col_end); + }); +} + +#endif // __ARM_FEATURE_MATMUL_INT8 +} // namespace aphrodite::mobile diff --git a/aphrodite_kernels/csrc/cpu/mobile/kernels.h b/aphrodite_kernels/csrc/cpu/mobile/kernels.h new file mode 100644 index 0000000000..6c357970ad --- /dev/null +++ b/aphrodite_kernels/csrc/cpu/mobile/kernels.h @@ -0,0 +1,157 @@ +#ifndef APHRODITE_MOBILE_KERNELS_H +#define APHRODITE_MOBILE_KERNELS_H + +#include +#include + +namespace aphrodite::mobile { + +// Quantization kernels +void int8_to_fp32(const int8_t* src, float* dst, size_t count, float scale); +void fp32_to_int8(const float* src, int8_t* dst, size_t count, float scale); +void dynamic_quantize_fp32_to_int8(const float* src, int8_t* dst, size_t count, + float* computed_scale); +void fp16_to_fp32(const __fp16* src, float* dst, size_t count); +void fp32_to_fp16(const float* src, 
__fp16* dst, size_t count); +void int8_to_fp16(const int8_t* src, __fp16* dst, size_t count, float scale); +void fp16_to_int8(const __fp16* src, int8_t* dst, size_t count, float scale); +float fp16_max_abs(const __fp16* src, size_t count); +void int32_to_fp16_scaled(const int32_t* src, __fp16* dst, size_t count, + float scale); + +// BLAS kernels - INT8 +void add_int8(const int8_t* a, const int8_t* b, int8_t* output, + size_t num_elements); +void subtract_int8(const int8_t* a, const int8_t* b, int8_t* output, + size_t num_elements); +void multiply_int8(const int8_t* a, const int8_t* b, int8_t* output, + size_t num_elements); +void divide_int8(const int8_t* a, const int8_t* b, int8_t* output, + size_t num_elements); + +// BLAS kernels - FP32 +void add_f32(const float* a, const float* b, float* output, + size_t num_elements); +void subtract_f32(const float* a, const float* b, float* output, + size_t num_elements); +void multiply_f32(const float* a, const float* b, float* output, + size_t num_elements); +void divide_f32(const float* a, const float* b, float* output, + size_t num_elements); + +// Matrix multiplication (GEMM) +void matmul_int8(const int8_t* a, const int8_t* b_transposed, int8_t* c, + size_t M, size_t K, size_t N, float a_scale, float b_scale, + float c_scale); +void matmul_f16(const __fp16* a, const __fp16* b_transposed, __fp16* c, + size_t M, size_t K, size_t N); +void matmul_f32(const float* a, const float* b_transposed, float* c, size_t M, + size_t K, size_t N); +void matmul_int8_to_int32(const int8_t* a, const int8_t* b_transposed, + int32_t* c, size_t M, size_t K, size_t N); +#if defined(__ARM_FEATURE_MATMUL_INT8) +void matmul_int8_to_int32_i8mm(const int8_t* a, const int8_t* b_transposed, + int32_t* c, size_t M, size_t K, size_t N); +#endif + +// Reduction kernels +int64_t sum_all_int8(const int8_t* data, size_t num_elements); +void sum_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, + size_t axis_size, size_t inner_size); 
+double mean_all_int8(const int8_t* data, size_t num_elements); +void mean_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, + size_t axis_size, size_t inner_size); +double mean_all_f16(const __fp16* data, size_t num_elements); +void mean_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, + size_t axis_size, size_t inner_size); +double variance_all_int8(const int8_t* data, size_t num_elements); +void variance_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, + size_t axis_size, size_t inner_size); +int64_t min_all_int8(const int8_t* data, size_t num_elements); +void min_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, + size_t axis_size, size_t inner_size); +int64_t max_all_int8(const int8_t* data, size_t num_elements); +void max_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, + size_t axis_size, size_t inner_size); +double sum_all_f32(const float* data, size_t num_elements); +void sum_axis_f32(const float* input, float* output, size_t outer_size, + size_t axis_size, size_t inner_size); + +// Scalar operation kernels +enum class ScalarOpType { + ADD, + SUBTRACT, + MULTIPLY, + DIVIDE, + EXP, + SQRT, + COS, + SIN +}; +void scalar_op_int8(const int8_t* input, int8_t* output, size_t num_elements, + float scalar_value, ScalarOpType op_type); +void scalar_op_f16(const __fp16* input, __fp16* output, size_t num_elements, + float scalar_value, ScalarOpType op_type); +void scalar_op_f32(const float* input, float* output, size_t num_elements, + float scalar_value, ScalarOpType op_type); + +// Neural network kernels +void silu_f32(const float* input, float* output, size_t num_elements); +void silu_f16(const __fp16* input, __fp16* output, size_t num_elements); +void silu_int8(const int8_t* input, int8_t* output, size_t num_elements, + float input_scale, float output_scale); +void gelu_f32(const float* input, float* output, size_t num_elements); +void gelu_f16(const __fp16* input, __fp16* output, 
size_t num_elements); +void gelu_int8(const int8_t* input, int8_t* output, size_t num_elements, + float input_scale, float output_scale); +void softmax_f32(const float* input, float* output, size_t batch_size, + size_t seq_len, size_t vocab_size); +void softmax_f16(const __fp16* input, __fp16* output, size_t batch_size, + size_t seq_len, size_t vocab_size); + +// Attention kernels +void attention_int8(const int8_t* queries, const int8_t* keys, + const int8_t* values, int8_t* output, size_t batch_size, + size_t seq_len, size_t kv_seq_len, size_t num_q_heads, + size_t num_kv_heads, size_t head_dim, float scale, + const int8_t* mask, float q_scale, float k_scale, + float v_scale, float output_scale, + size_t position_offset = 0, size_t window_size = 0, + bool is_causal = true); +void attention_f16(const __fp16* queries, const __fp16* keys, + const __fp16* values, __fp16* output, size_t batch_size, + size_t seq_len, size_t kv_seq_len, size_t num_q_heads, + size_t num_kv_heads, size_t head_dim, float scale, + const __fp16* mask, size_t position_offset = 0, + size_t window_size = 0, bool is_causal = true); +void attention_f32(const float* queries, const float* keys, const float* values, + float* output, size_t batch_size, size_t seq_len, + size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads, + size_t head_dim, float scale, const float* mask, + size_t position_offset = 0, size_t window_size = 0, + bool is_causal = true); + +// Normalization kernels +void rms_norm_f32(const float* input, const float* weight, float* output, + size_t batch_size, size_t dims, float eps); +void rms_norm_f16(const __fp16* input, const __fp16* weight, __fp16* output, + size_t batch_size, size_t dims, float eps); +void rms_norm_i8_f32(const int8_t* input, const float* weight, float* output, + size_t batch_size, size_t dims, float eps, + float input_scale); + +// RoPE kernels +void rope_f32(const float* input, float* output, size_t batch_size, + size_t seq_len, size_t num_heads, size_t 
head_dim, + size_t start_pos, float theta); +void rope_f16(const __fp16* input, __fp16* output, size_t batch_size, + size_t seq_len, size_t num_heads, size_t head_dim, + size_t start_pos, float theta); +void rope_i8_f32_i8(const int8_t* input, int8_t* output, size_t batch_size, + size_t seq_len, size_t num_heads, size_t head_dim, + size_t start_pos, float theta, float input_scale, + float output_scale); + +} // namespace aphrodite::mobile + +#endif // APHRODITE_MOBILE_KERNELS_H diff --git a/aphrodite_kernels/csrc/cpu/mobile/nn.cpp b/aphrodite_kernels/csrc/cpu/mobile/nn.cpp new file mode 100644 index 0000000000..7feeb3241b --- /dev/null +++ b/aphrodite_kernels/csrc/cpu/mobile/nn.cpp @@ -0,0 +1,857 @@ +#include "threading.hpp" +#include +#include +#include +#include +#include +#include +#include + +namespace aphrodite::mobile { + +void silu_f32(const float* input, float* output, size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + const float32x4_t one = vdupq_n_f32(1.0f); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t x = vld1q_f32(&input[i]); + + float32x4_t neg_x = vnegq_f32(x); + + float32x4_t exp_vals[4]; + float neg_vals[4], result_vals[4]; + vst1q_f32(neg_vals, neg_x); + + for (int j = 0; j < 4; j++) { + result_vals[j] = expf(neg_vals[j]); + } + exp_vals[0] = vld1q_f32(result_vals); + + float32x4_t one_plus_exp = vaddq_f32(one, exp_vals[0]); + + float32x4_t sigmoid = vdivq_f32(one, one_plus_exp); + + float32x4_t silu = vmulq_f32(x, sigmoid); + + vst1q_f32(&output[i], silu); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + float sigmoid = 1.0f / (1.0f + expf(-input[i])); + output[i] = input[i] * sigmoid; + } + }); +} + +void silu_f16(const __fp16* input, __fp16* 
output, size_t num_elements) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float16x8_t x = vld1q_f16(&input[i]); + + float32x4_t x_low = vcvt_f32_f16(vget_low_f16(x)); + float32x4_t x_high = vcvt_f32_f16(vget_high_f16(x)); + + float32x4_t neg_x_low = vnegq_f32(x_low); + float32x4_t neg_x_high = vnegq_f32(x_high); + + float exp_vals[8]; + vst1q_f32(&exp_vals[0], neg_x_low); + vst1q_f32(&exp_vals[4], neg_x_high); + + for (int j = 0; j < 8; j++) { + exp_vals[j] = expf(exp_vals[j]); + } + + float32x4_t exp_low = vld1q_f32(&exp_vals[0]); + float32x4_t exp_high = vld1q_f32(&exp_vals[4]); + + float32x4_t one_f32 = vdupq_n_f32(1.0f); + float32x4_t one_plus_exp_low = vaddq_f32(one_f32, exp_low); + float32x4_t one_plus_exp_high = vaddq_f32(one_f32, exp_high); + + float32x4_t sigmoid_low = vdivq_f32(one_f32, one_plus_exp_low); + float32x4_t sigmoid_high = vdivq_f32(one_f32, one_plus_exp_high); + + float16x4_t sigmoid_low_f16 = vcvt_f16_f32(sigmoid_low); + float16x4_t sigmoid_high_f16 = vcvt_f16_f32(sigmoid_high); + float16x8_t sigmoid = vcombine_f16(sigmoid_low_f16, sigmoid_high_f16); + + float16x8_t silu = vmulq_f16(x, sigmoid); + + vst1q_f16(&output[i], silu); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + float x_f32 = static_cast(input[i]); + float sigmoid = 1.0f / (1.0f + expf(-x_f32)); + output[i] = static_cast<__fp16>(x_f32 * sigmoid); + } + }); +} + +void silu_int8(const int8_t* input, int8_t* output, size_t num_elements, + float input_scale, float output_scale) { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + for (size_t i = start_idx; i < end_idx; ++i) { + 
float x = input[i] * input_scale; + float sigmoid = 1.0f / (1.0f + expf(-x)); + float silu = x * sigmoid; + float scaled = silu / output_scale; + output[i] = static_cast( + std::max(-128.0f, std::min(127.0f, roundf(scaled)))); + } + }); +} + +void gelu_f32(const float* input, float* output, size_t num_elements) { + const float sqrt_2_over_pi = 0.7978845608028654f; + const float coeff = 0.044715f; + + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + const float32x4_t half = vdupq_n_f32(0.5f); + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t sqrt_2_pi_vec = vdupq_n_f32(sqrt_2_over_pi); + const float32x4_t coeff_vec = vdupq_n_f32(coeff); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t x = vld1q_f32(&input[i]); + + float32x4_t x_cubed = vmulq_f32(vmulq_f32(x, x), x); + float32x4_t inner = vmlaq_f32(x, coeff_vec, x_cubed); + float32x4_t tanh_arg = vmulq_f32(sqrt_2_pi_vec, inner); + + float tanh_vals[4], arg_vals[4]; + vst1q_f32(arg_vals, tanh_arg); + for (int j = 0; j < 4; j++) { + tanh_vals[j] = tanhf(arg_vals[j]); + } + float32x4_t tanh_result = vld1q_f32(tanh_vals); + + float32x4_t gelu = + vmulq_f32(vmulq_f32(half, x), vaddq_f32(one, tanh_result)); + + vst1q_f32(&output[i], gelu); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + float x = input[i]; + float inner = sqrt_2_over_pi * (x + coeff * x * x * x); + output[i] = 0.5f * x * (1.0f + tanhf(inner)); + } + }); +} + +void gelu_f16(const __fp16* input, __fp16* output, size_t num_elements) { + const float sqrt_2_over_pi = 0.7978845608028654f; + const float coeff = 0.044715f; + + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + constexpr 
size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + const float32x4_t half = vdupq_n_f32(0.5f); + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t sqrt_2_pi_vec = vdupq_n_f32(sqrt_2_over_pi); + const float32x4_t coeff_vec = vdupq_n_f32(coeff); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float16x8_t x_f16 = vld1q_f16(&input[i]); + + float32x4_t x_low = vcvt_f32_f16(vget_low_f16(x_f16)); + float32x4_t x_high = vcvt_f32_f16(vget_high_f16(x_f16)); + + float32x4_t x_cubed_low = vmulq_f32(vmulq_f32(x_low, x_low), x_low); + float32x4_t x_cubed_high = + vmulq_f32(vmulq_f32(x_high, x_high), x_high); + + float32x4_t inner_low = vmlaq_f32(x_low, coeff_vec, x_cubed_low); + float32x4_t inner_high = vmlaq_f32(x_high, coeff_vec, x_cubed_high); + inner_low = vmulq_f32(sqrt_2_pi_vec, inner_low); + inner_high = vmulq_f32(sqrt_2_pi_vec, inner_high); + + float tanh_vals[8]; + vst1q_f32(&tanh_vals[0], inner_low); + vst1q_f32(&tanh_vals[4], inner_high); + for (int j = 0; j < 8; j++) { + tanh_vals[j] = tanhf(tanh_vals[j]); + } + float32x4_t tanh_low = vld1q_f32(&tanh_vals[0]); + float32x4_t tanh_high = vld1q_f32(&tanh_vals[4]); + + float32x4_t one_plus_tanh_low = vaddq_f32(one, tanh_low); + float32x4_t one_plus_tanh_high = vaddq_f32(one, tanh_high); + float32x4_t gelu_low = + vmulq_f32(vmulq_f32(half, x_low), one_plus_tanh_low); + float32x4_t gelu_high = + vmulq_f32(vmulq_f32(half, x_high), one_plus_tanh_high); + + float16x4_t gelu_low_f16 = vcvt_f16_f32(gelu_low); + float16x4_t gelu_high_f16 = vcvt_f16_f32(gelu_high); + float16x8_t gelu_f16 = vcombine_f16(gelu_low_f16, gelu_high_f16); + + vst1q_f16(&output[i], gelu_f16); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + float x = static_cast(input[i]); + float inner = sqrt_2_over_pi * (x + coeff * x * x * x); + float gelu = 0.5f * x * (1.0f + tanhf(inner)); + output[i] = static_cast<__fp16>(gelu); + } + }); +} + 
+void gelu_int8(const int8_t* input, int8_t* output, size_t num_elements, + float input_scale, float output_scale) { + const float sqrt_2_over_pi = 0.7978845608028654f; + const float coeff = 0.044715f; + + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + for (size_t i = start_idx; i < end_idx; ++i) { + float x = input[i] * input_scale; + float inner = sqrt_2_over_pi * (x + coeff * x * x * x); + float gelu = 0.5f * x * (1.0f + tanhf(inner)); + float scaled = gelu / output_scale; + output[i] = static_cast( + std::max(-128.0f, std::min(127.0f, roundf(scaled)))); + } + }); +} + +namespace CactusSoftmax { + +inline float32x4_t fast_exp_neon(float32x4_t x) { + const float32x4_t log2e = vdupq_n_f32(1.4426950408889634f); + const float32x4_t c1 = vdupq_n_f32(0.6931471805599453f); + const float32x4_t c2 = vdupq_n_f32(0.2402265069591007f); + const float32x4_t c3 = vdupq_n_f32(0.05550410866482158f); + const float32x4_t c4 = vdupq_n_f32(0.009618129842071803f); + const float32x4_t c5 = vdupq_n_f32(0.001333355814670656f); + + const float32x4_t clamp_min = vdupq_n_f32(-87.0f); + const float32x4_t clamp_max = vdupq_n_f32(87.0f); + + x = vmaxq_f32(x, clamp_min); + x = vminq_f32(x, clamp_max); + + x = vmulq_f32(x, log2e); + + int32x4_t xi = vcvtq_s32_f32(x); + float32x4_t xf = vsubq_f32(x, vcvtq_f32_s32(xi)); + + float32x4_t p = vdupq_n_f32(1.0f); + p = vfmaq_f32(p, c1, xf); + + float32x4_t xf2 = vmulq_f32(xf, xf); + p = vfmaq_f32(p, c2, xf2); + + float32x4_t xf3 = vmulq_f32(xf2, xf); + p = vfmaq_f32(p, c3, xf3); + + float32x4_t xf4 = vmulq_f32(xf3, xf); + p = vfmaq_f32(p, c4, xf4); + + float32x4_t xf5 = vmulq_f32(xf4, xf); + p = vfmaq_f32(p, c5, xf5); + + int32x4_t exponent = vaddq_s32(xi, vdupq_n_s32(127)); + exponent = vshlq_n_s32(exponent, 23); + float32x4_t scale = vreinterpretq_f32_s32(exponent); + + return vmulq_f32(p, scale); +} + +inline float32x4_t fast_reciprocal_neon(float32x4_t 
x) { + float32x4_t recip = vrecpeq_f32(x); + recip = vmulq_f32(recip, vrecpsq_f32(x, recip)); + recip = vmulq_f32(recip, vrecpsq_f32(x, recip)); + return recip; +} + +void kernel_softmax_neon_optimized_single(const float* input, float* output, + size_t vocab_size) { + constexpr size_t SIMD_WIDTH = 4; + constexpr size_t UNROLL_FACTOR = 8; + constexpr size_t VECTORIZED_WIDTH = SIMD_WIDTH * UNROLL_FACTOR; + const size_t vocab_vectorized = + (vocab_size / VECTORIZED_WIDTH) * VECTORIZED_WIDTH; + + float32x4_t max_vec[UNROLL_FACTOR]; + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + max_vec[u] = vdupq_n_f32(-std::numeric_limits::infinity()); + } + + for (size_t i = 0; i < vocab_vectorized; i += VECTORIZED_WIDTH) { + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + float32x4_t x_vec = vld1q_f32(&input[i + u * SIMD_WIDTH]); + max_vec[u] = vmaxq_f32(max_vec[u], x_vec); + } + } + + float32x4_t final_max = max_vec[0]; + for (size_t u = 1; u < UNROLL_FACTOR; u++) { + final_max = vmaxq_f32(final_max, max_vec[u]); + } + + float max_val = vmaxvq_f32(final_max); + for (size_t i = vocab_vectorized; i < vocab_size; ++i) { + max_val = std::max(max_val, input[i]); + } + + const float32x4_t max_broadcast = vdupq_n_f32(max_val); + + float32x4_t sum_vec[UNROLL_FACTOR]; + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + sum_vec[u] = vdupq_n_f32(0.0f); + } + + for (size_t i = 0; i < vocab_vectorized; i += VECTORIZED_WIDTH) { + float32x4_t x_vec[UNROLL_FACTOR]; + float32x4_t exp_vec[UNROLL_FACTOR]; + + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + x_vec[u] = vld1q_f32(&input[i + u * SIMD_WIDTH]); + } + + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + exp_vec[u] = fast_exp_neon(vsubq_f32(x_vec[u], max_broadcast)); + sum_vec[u] = vaddq_f32(sum_vec[u], exp_vec[u]); + } + + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + vst1q_f32(&output[i + u * SIMD_WIDTH], exp_vec[u]); + } + } + + float32x4_t final_sum = sum_vec[0]; + for (size_t u = 1; u < UNROLL_FACTOR; u++) { + final_sum = vaddq_f32(final_sum, 
sum_vec[u]); + } + + float sum = vaddvq_f32(final_sum); + for (size_t i = vocab_vectorized; i < vocab_size; ++i) { + float exp_val = expf(input[i] - max_val); + output[i] = exp_val; + sum += exp_val; + } + + const float32x4_t inv_sum_vec = fast_reciprocal_neon(vdupq_n_f32(sum)); + + for (size_t i = 0; i < vocab_vectorized; i += VECTORIZED_WIDTH) { + float32x4_t exp_vec[UNROLL_FACTOR]; + + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + exp_vec[u] = vld1q_f32(&output[i + u * SIMD_WIDTH]); + } + + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + exp_vec[u] = vmulq_f32(exp_vec[u], inv_sum_vec); + } + + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + vst1q_f32(&output[i + u * SIMD_WIDTH], exp_vec[u]); + } + } + + const float inv_sum = vgetq_lane_f32(inv_sum_vec, 0); + for (size_t i = vocab_vectorized; i < vocab_size; ++i) { + output[i] *= inv_sum; + } +} + +} // namespace CactusSoftmax + +void softmax_f32(const float* input, float* output, size_t batch_size, + size_t seq_len, size_t vocab_size) { + aphrodite::mobile::parallel_for( + batch_size * seq_len, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + for (size_t idx = start_idx; idx < end_idx; ++idx) { + const size_t offset = idx * vocab_size; + CactusSoftmax::kernel_softmax_neon_optimized_single( + input + offset, output + offset, vocab_size); + } + }); +} + +void kernel_softmax_f16_single(const __fp16* input, __fp16* output, + size_t vocab_size) { + constexpr size_t SIMD_WIDTH = 8; + constexpr size_t UNROLL_FACTOR = 4; + constexpr size_t VECTORIZED_WIDTH = SIMD_WIDTH * UNROLL_FACTOR; + const size_t vocab_vectorized = + (vocab_size / VECTORIZED_WIDTH) * VECTORIZED_WIDTH; + + float32x4_t max_vec[UNROLL_FACTOR * 2]; + for (size_t u = 0; u < UNROLL_FACTOR * 2; u++) { + max_vec[u] = vdupq_n_f32(-std::numeric_limits::infinity()); + } + + for (size_t i = 0; i < vocab_vectorized; i += VECTORIZED_WIDTH) { + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + float16x8_t x_vec_f16 = 
vld1q_f16(&input[i + u * SIMD_WIDTH]); + float32x4_t x_low = vcvt_f32_f16(vget_low_f16(x_vec_f16)); + float32x4_t x_high = vcvt_f32_f16(vget_high_f16(x_vec_f16)); + max_vec[u * 2] = vmaxq_f32(max_vec[u * 2], x_low); + max_vec[u * 2 + 1] = vmaxq_f32(max_vec[u * 2 + 1], x_high); + } + } + + float32x4_t final_max = max_vec[0]; + for (size_t u = 1; u < UNROLL_FACTOR * 2; u++) { + final_max = vmaxq_f32(final_max, max_vec[u]); + } + + float max_val = vmaxvq_f32(final_max); + for (size_t i = vocab_vectorized; i < vocab_size; ++i) { + max_val = std::max(max_val, static_cast(input[i])); + } + + const float32x4_t max_broadcast = vdupq_n_f32(max_val); + const float16x8_t max_broadcast_f16 = + vcombine_f16(vcvt_f16_f32(max_broadcast), vcvt_f16_f32(max_broadcast)); + + float32x4_t sum_vec[UNROLL_FACTOR * 2]; + for (size_t u = 0; u < UNROLL_FACTOR * 2; u++) { + sum_vec[u] = vdupq_n_f32(0.0f); + } + + for (size_t i = 0; i < vocab_vectorized; i += VECTORIZED_WIDTH) { + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + float16x8_t x_vec_f16 = vld1q_f16(&input[i + u * SIMD_WIDTH]); + + float16x8_t centered_f16 = vsubq_f16(x_vec_f16, max_broadcast_f16); + + float32x4_t centered_low = vcvt_f32_f16(vget_low_f16(centered_f16)); + float32x4_t centered_high = vcvt_f32_f16(vget_high_f16(centered_f16)); + + float32x4_t exp_low = CactusSoftmax::fast_exp_neon(centered_low); + float32x4_t exp_high = CactusSoftmax::fast_exp_neon(centered_high); + + float16x8_t exp_f16 = + vcombine_f16(vcvt_f16_f32(exp_low), vcvt_f16_f32(exp_high)); + vst1q_f16(&output[i + u * SIMD_WIDTH], exp_f16); + + sum_vec[u * 2] = vaddq_f32(sum_vec[u * 2], exp_low); + sum_vec[u * 2 + 1] = vaddq_f32(sum_vec[u * 2 + 1], exp_high); + } + } + + float32x4_t final_sum = sum_vec[0]; + for (size_t u = 1; u < UNROLL_FACTOR * 2; u++) { + final_sum = vaddq_f32(final_sum, sum_vec[u]); + } + + float sum = vaddvq_f32(final_sum); + for (size_t i = vocab_vectorized; i < vocab_size; ++i) { + float exp_val = expf(static_cast(input[i]) - 
max_val); + output[i] = static_cast<__fp16>(exp_val); + sum += exp_val; + } + + const float inv_sum = 1.0f / sum; + const float16x8_t inv_sum_vec_f16 = vdupq_n_f16(static_cast<__fp16>(inv_sum)); + + for (size_t i = 0; i < vocab_vectorized; i += VECTORIZED_WIDTH) { + for (size_t u = 0; u < UNROLL_FACTOR; u++) { + float16x8_t exp_vec = vld1q_f16(&output[i + u * SIMD_WIDTH]); + float16x8_t result = vmulq_f16(exp_vec, inv_sum_vec_f16); + vst1q_f16(&output[i + u * SIMD_WIDTH], result); + } + } + + for (size_t i = vocab_vectorized; i < vocab_size; ++i) { + output[i] = static_cast<__fp16>(static_cast(output[i]) * inv_sum); + } +} + +void softmax_f16(const __fp16* input, __fp16* output, size_t batch_size, + size_t seq_len, size_t vocab_size) { + aphrodite::mobile::parallel_for( + batch_size * seq_len, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + for (size_t idx = start_idx; idx < end_idx; ++idx) { + const size_t offset = idx * vocab_size; + kernel_softmax_f16_single(input + offset, output + offset, + vocab_size); + } + }); +} + +void sample_f32(const float* logits, uint32_t* output, size_t vocab_size, + float temperature, float top_p, size_t top_k, + size_t random_seed) { + std::vector filtered_logits(vocab_size); + + for (size_t i = 0; i < vocab_size; ++i) { + filtered_logits[i] = logits[i]; + } + + if (temperature > 0) { + for (size_t i = 0; i < vocab_size; ++i) { + filtered_logits[i] /= temperature; + } + } + + if (top_k > 0) { + std::vector> logit_pairs; + logit_pairs.reserve(vocab_size); + for (size_t i = 0; i < vocab_size; ++i) { + logit_pairs.emplace_back(filtered_logits[i], i); + } + std::sort(logit_pairs.begin(), logit_pairs.end(), + [](const auto& a, const auto& b) { return a.first > b.first; }); + + if (top_k < vocab_size) { + float kth_value = logit_pairs[top_k - 1].first; + for (size_t i = 0; i < vocab_size; ++i) { + if (filtered_logits[i] < kth_value) { + filtered_logits[i] = -std::numeric_limits::infinity(); + 
} + } + } + } + + constexpr float min_p = 0.15f; + if (min_p > 0.0f) { + float max_logit = + *std::max_element(filtered_logits.begin(), filtered_logits.end()); + if (!std::isinf(max_logit)) { + std::vector temp_probs(vocab_size); + float sum = 0.0f; + for (size_t i = 0; i < vocab_size; ++i) { + if (!std::isinf(filtered_logits[i])) { + temp_probs[i] = std::exp(filtered_logits[i] - max_logit); + sum += temp_probs[i]; + } else { + temp_probs[i] = 0.0f; + } + } + + if (sum > 0.0f) { + for (size_t i = 0; i < vocab_size; ++i) { + temp_probs[i] /= sum; + } + + float max_prob = + *std::max_element(temp_probs.begin(), temp_probs.end()); + float threshold = max_prob * min_p; + + for (size_t i = 0; i < vocab_size; ++i) { + if (temp_probs[i] < threshold) { + filtered_logits[i] = -std::numeric_limits::infinity(); + } + } + } + } + } + + if (top_p > 0.0f && top_p < 1.0f) { + std::vector> sorted_logits; + sorted_logits.reserve(vocab_size); + for (size_t i = 0; i < vocab_size; ++i) { + if (!std::isinf(filtered_logits[i])) { + sorted_logits.emplace_back(filtered_logits[i], i); + } + } + std::sort(sorted_logits.begin(), sorted_logits.end(), + [](const auto& a, const auto& b) { return a.first > b.first; }); + + float max_logit = sorted_logits.empty() ? 
0.0f : sorted_logits[0].first; + std::vector temp_probs; + temp_probs.reserve(sorted_logits.size()); + float sum = 0.0f; + for (const auto& pair : sorted_logits) { + float prob = std::exp(pair.first - max_logit); + temp_probs.push_back(prob); + sum += prob; + } + + for (float& prob : temp_probs) { + prob /= sum; + } + + float cumulative_prob = 0.0f; + std::vector indices_to_remove(sorted_logits.size(), false); + for (size_t i = 0; i < sorted_logits.size(); ++i) { + cumulative_prob += temp_probs[i]; + if (cumulative_prob > top_p) { + indices_to_remove[i] = true; + } + } + + if (!indices_to_remove.empty()) { + for (size_t i = 1; i < indices_to_remove.size(); ++i) { + indices_to_remove[i] = indices_to_remove[i - 1] || indices_to_remove[i]; + } + indices_to_remove[0] = false; + } + + for (size_t i = 0; i < sorted_logits.size(); ++i) { + if (indices_to_remove[i]) { + filtered_logits[sorted_logits[i].second] = + -std::numeric_limits::infinity(); + } + } + } + + float max_logit = + *std::max_element(filtered_logits.begin(), filtered_logits.end()); + if (std::isinf(max_logit)) { + output[0] = 0; + return; + } + + std::vector probs(vocab_size); + float sum = 0.0f; + for (size_t i = 0; i < vocab_size; ++i) { + if (std::isinf(filtered_logits[i])) { + probs[i] = 0.0f; + } else { + probs[i] = std::exp(filtered_logits[i] - max_logit); + sum += probs[i]; + } + } + + if (sum == 0.0f) { + output[0] = 0; + return; + } + + for (size_t i = 0; i < vocab_size; ++i) { + probs[i] /= sum; + } + + uint32_t actual_seed = + (random_seed == 0) ? 
std::random_device{}() : random_seed; + std::mt19937 gen(actual_seed); + std::uniform_real_distribution dist(0.0f, 1.0f); + float sample = dist(gen); + + float cumulative = 0.0f; + for (size_t i = 0; i < vocab_size; ++i) { + cumulative += probs[i]; + if (cumulative >= sample) { + output[0] = static_cast(i); + return; + } + } + + for (size_t i = vocab_size; i > 0; --i) { + if (probs[i - 1] > 0.0f) { + output[0] = static_cast(i - 1); + return; + } + } + + output[0] = 0; +} + +void sample_f16(const __fp16* logits, uint32_t* output, size_t vocab_size, + float temperature, float top_p, size_t top_k, + size_t random_seed) { + std::vector<__fp16> filtered_logits(vocab_size); + + if (temperature > 0) { + __fp16 inv_temp = static_cast<__fp16>(1.0f / temperature); + float16x8_t inv_temp_vec = vdupq_n_f16(inv_temp); + size_t i = 0; + for (; i + 8 <= vocab_size; i += 8) { + float16x8_t logits_vec = vld1q_f16(&logits[i]); + float16x8_t scaled = vmulq_f16(logits_vec, inv_temp_vec); + vst1q_f16(&filtered_logits[i], scaled); + } + for (; i < vocab_size; ++i) { + filtered_logits[i] = logits[i] * inv_temp; + } + } else { + std::memcpy(filtered_logits.data(), logits, vocab_size * sizeof(__fp16)); + } + + static std::vector token_history; + static const size_t MAX_HISTORY = 128; + static const float REPETITION_PENALTY = 1.1f; + + if (!token_history.empty() && REPETITION_PENALTY != 1.0f) { + const __fp16 penalty_inv = static_cast<__fp16>(1.0f / REPETITION_PENALTY); + const __fp16 penalty = static_cast<__fp16>(REPETITION_PENALTY); + + for (uint32_t prev_token : token_history) { + if (prev_token < vocab_size) { + filtered_logits[prev_token] = + (filtered_logits[prev_token] > static_cast<__fp16>(0)) + ? 
static_cast<__fp16>(filtered_logits[prev_token] * penalty_inv) + : static_cast<__fp16>(filtered_logits[prev_token] * penalty); + } + } + } + + if (top_k > 0) { + std::vector> logit_pairs; + logit_pairs.reserve(vocab_size); + for (size_t i = 0; i < vocab_size; ++i) { + logit_pairs.emplace_back(filtered_logits[i], i); + } + std::partial_sort( + logit_pairs.begin(), logit_pairs.begin() + std::min(top_k, vocab_size), + logit_pairs.end(), + [](const auto& a, const auto& b) { return a.first > b.first; }); + + if (top_k < vocab_size) { + __fp16 kth_value = logit_pairs[top_k - 1].first; + __fp16 neg_inf = + static_cast<__fp16>(-std::numeric_limits::infinity()); + for (size_t i = 0; i < vocab_size; ++i) { + if (filtered_logits[i] < kth_value) { + filtered_logits[i] = neg_inf; + } + } + } + } + + if (top_p > 0.0f && top_p < 1.0f) { + std::vector> sorted_logits; + sorted_logits.reserve(vocab_size); + __fp16 neg_inf = + static_cast<__fp16>(-std::numeric_limits::infinity()); + for (size_t i = 0; i < vocab_size; ++i) { + if (filtered_logits[i] != neg_inf) { + sorted_logits.emplace_back(filtered_logits[i], i); + } + } + std::sort(sorted_logits.begin(), sorted_logits.end(), + [](const auto& a, const auto& b) { return a.first > b.first; }); + + __fp16 max_logit = sorted_logits.empty() ? 
static_cast<__fp16>(0.0f) + : sorted_logits[0].first; + std::vector temp_probs; + temp_probs.reserve(sorted_logits.size()); + float sum = 0.0f; + for (const auto& pair : sorted_logits) { + float prob = std::exp(static_cast(pair.first - max_logit)); + temp_probs.push_back(prob); + sum += prob; + } + + for (float& prob : temp_probs) { + prob /= sum; + } + + float cumulative_prob = 0.0f; + std::vector indices_to_remove(sorted_logits.size(), false); + bool threshold_reached = false; + for (size_t i = 0; i < sorted_logits.size(); ++i) { + cumulative_prob += temp_probs[i]; + if (cumulative_prob > top_p && i > 0) { + threshold_reached = true; + } + if (threshold_reached) { + indices_to_remove[i] = true; + } + } + + if (!indices_to_remove.empty()) { + for (size_t i = 1; i < indices_to_remove.size(); ++i) { + indices_to_remove[i] = indices_to_remove[i - 1] || indices_to_remove[i]; + } + indices_to_remove[0] = false; + } + + for (size_t i = 0; i < sorted_logits.size(); ++i) { + if (indices_to_remove[i]) { + filtered_logits[sorted_logits[i].second] = neg_inf; + } + } + } + + __fp16 max_logit = + *std::max_element(filtered_logits.begin(), filtered_logits.end()); + __fp16 neg_inf = static_cast<__fp16>(-std::numeric_limits::infinity()); + if (max_logit == neg_inf) { + output[0] = 0; + return; + } + + std::vector probs(vocab_size); + float sum = 0.0f; + for (size_t i = 0; i < vocab_size; ++i) { + if (filtered_logits[i] == neg_inf) { + probs[i] = 0.0f; + } else { + probs[i] = std::exp(static_cast(filtered_logits[i] - max_logit)); + sum += probs[i]; + } + } + + if (sum == 0.0f) { + output[0] = 0; + return; + } + + for (size_t i = 0; i < vocab_size; ++i) { + probs[i] /= sum; + } + + uint32_t actual_seed = + (random_seed == 0) ? 
std::random_device{}() : random_seed; + std::mt19937 gen(actual_seed); + std::uniform_real_distribution dist(0.0f, 1.0f); + float sample = dist(gen); + + float cumulative = 0.0f; + for (size_t i = 0; i < vocab_size; ++i) { + cumulative += probs[i]; + if (cumulative >= sample) { + output[0] = static_cast(i); + token_history.push_back(output[0]); + if (token_history.size() > MAX_HISTORY) { + token_history.erase(token_history.begin()); + } + return; + } + } + + for (size_t i = vocab_size; i > 0; --i) { + if (probs[i - 1] > 0.0f) { + output[0] = static_cast(i - 1); + token_history.push_back(output[0]); + if (token_history.size() > MAX_HISTORY) { + token_history.erase(token_history.begin()); + } + return; + } + } + + output[0] = 0; + token_history.push_back(output[0]); + if (token_history.size() > MAX_HISTORY) { + token_history.erase(token_history.begin()); + } +} + +} // namespace aphrodite::mobile diff --git a/aphrodite_kernels/csrc/cpu/mobile/quant.cpp b/aphrodite_kernels/csrc/cpu/mobile/quant.cpp new file mode 100644 index 0000000000..62162cf70a --- /dev/null +++ b/aphrodite_kernels/csrc/cpu/mobile/quant.cpp @@ -0,0 +1,290 @@ +#include "threading.hpp" +#include +#include +#include + +namespace aphrodite::mobile { + +void int8_to_fp32(const int8_t* src, float* dst, size_t count, float scale) { + aphrodite::mobile::parallel_for( + count, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [src, dst, scale](size_t start, size_t end) { + const size_t simd_end = start + ((end - start) / 16) * 16; + float32x4_t scale_vec = vdupq_n_f32(scale); + + for (size_t i = start; i < simd_end; i += 16) { + int8x16_t input = vld1q_s8(&src[i]); + + int16x8_t low = vmovl_s8(vget_low_s8(input)); + int16x8_t high = vmovl_s8(vget_high_s8(input)); + + int32x4_t low_low = vmovl_s16(vget_low_s16(low)); + int32x4_t low_high = vmovl_s16(vget_high_s16(low)); + int32x4_t high_low = vmovl_s16(vget_low_s16(high)); + int32x4_t high_high = vmovl_s16(vget_high_s16(high)); + + float32x4_t f_low_low = 
vcvtq_f32_s32(low_low); + float32x4_t f_low_high = vcvtq_f32_s32(low_high); + float32x4_t f_high_low = vcvtq_f32_s32(high_low); + float32x4_t f_high_high = vcvtq_f32_s32(high_high); + + vst1q_f32(&dst[i], vmulq_f32(f_low_low, scale_vec)); + vst1q_f32(&dst[i + 4], vmulq_f32(f_low_high, scale_vec)); + vst1q_f32(&dst[i + 8], vmulq_f32(f_high_low, scale_vec)); + vst1q_f32(&dst[i + 12], vmulq_f32(f_high_high, scale_vec)); + } + + for (size_t i = simd_end; i < end; ++i) { + dst[i] = static_cast(src[i]) * scale; + } + }); +} + +void fp32_to_int8(const float* src, int8_t* dst, size_t count, float scale) { + const float inv_scale = 1.0f / scale; + + aphrodite::mobile::parallel_for( + count, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [src, dst, inv_scale](size_t start, size_t end) { + const size_t simd_end = start + ((end - start) / 16) * 16; + float32x4_t inv_scale_vec = vdupq_n_f32(inv_scale); + float32x4_t min_vec = vdupq_n_f32(-128.0f); + float32x4_t max_vec = vdupq_n_f32(127.0f); + + for (size_t i = start; i < simd_end; i += 16) { + float32x4_t input_0 = vld1q_f32(&src[i]); + float32x4_t input_1 = vld1q_f32(&src[i + 4]); + float32x4_t input_2 = vld1q_f32(&src[i + 8]); + float32x4_t input_3 = vld1q_f32(&src[i + 12]); + + float32x4_t scaled_0 = vmulq_f32(input_0, inv_scale_vec); + float32x4_t scaled_1 = vmulq_f32(input_1, inv_scale_vec); + float32x4_t scaled_2 = vmulq_f32(input_2, inv_scale_vec); + float32x4_t scaled_3 = vmulq_f32(input_3, inv_scale_vec); + + scaled_0 = vmaxq_f32(vminq_f32(scaled_0, max_vec), min_vec); + scaled_1 = vmaxq_f32(vminq_f32(scaled_1, max_vec), min_vec); + scaled_2 = vmaxq_f32(vminq_f32(scaled_2, max_vec), min_vec); + scaled_3 = vmaxq_f32(vminq_f32(scaled_3, max_vec), min_vec); + + int32x4_t int_0 = vcvtnq_s32_f32(scaled_0); + int32x4_t int_1 = vcvtnq_s32_f32(scaled_1); + int32x4_t int_2 = vcvtnq_s32_f32(scaled_2); + int32x4_t int_3 = vcvtnq_s32_f32(scaled_3); + + int16x8_t int16_low = + vcombine_s16(vqmovn_s32(int_0), vqmovn_s32(int_1)); + 
int16x8_t int16_high = + vcombine_s16(vqmovn_s32(int_2), vqmovn_s32(int_3)); + + int8x16_t result = + vcombine_s8(vqmovn_s16(int16_low), vqmovn_s16(int16_high)); + vst1q_s8(&dst[i], result); + } + + for (size_t i = simd_end; i < end; ++i) { + float quantized = src[i] * inv_scale; + dst[i] = static_cast( + std::round(std::max(-128.0f, std::min(127.0f, quantized)))); + } + }); +} + +void dynamic_quantize_fp32_to_int8(const float* src, int8_t* dst, size_t count, + float* computed_scale) { + if (count == 0) return; + + float32x4_t abs_max_vec = vdupq_n_f32(0.0f); + const size_t simd_end = (count / 4) * 4; + + for (size_t i = 0; i < simd_end; i += 4) { + float32x4_t input = vld1q_f32(&src[i]); + float32x4_t abs_input = vabsq_f32(input); + abs_max_vec = vmaxq_f32(abs_max_vec, abs_input); + } + + float abs_max = vmaxvq_f32(abs_max_vec); + + for (size_t i = simd_end; i < count; ++i) { + abs_max = std::max(abs_max, std::abs(src[i])); + } + + float scale = abs_max / 127.0f; + if (scale == 0.0f) scale = 1.0f; + + fp32_to_int8(src, dst, count, scale); + + if (computed_scale) *computed_scale = scale; +} + +void fp16_to_fp32(const __fp16* src, float* dst, size_t count) { + aphrodite::mobile::parallel_for( + count, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [src, dst](size_t start, size_t end) { + const size_t simd_end = start + ((end - start) / 8) * 8; + + for (size_t i = start; i < simd_end; i += 8) { + float16x8_t input = vld1q_f16(&src[i]); + + float32x4_t output_low = vcvt_f32_f16(vget_low_f16(input)); + float32x4_t output_high = vcvt_f32_f16(vget_high_f16(input)); + + vst1q_f32(&dst[i], output_low); + vst1q_f32(&dst[i + 4], output_high); + } + + for (size_t i = simd_end; i < end; ++i) { + dst[i] = static_cast(src[i]); + } + }); +} + +void fp32_to_fp16(const float* src, __fp16* dst, size_t count) { + aphrodite::mobile::parallel_for( + count, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [src, dst](size_t start, size_t end) { + const size_t simd_end = start + ((end - 
start) / 8) * 8; + + for (size_t i = start; i < simd_end; i += 8) { + float32x4_t input_low = vld1q_f32(&src[i]); + float32x4_t input_high = vld1q_f32(&src[i + 4]); + + float16x4_t output_low = vcvt_f16_f32(input_low); + float16x4_t output_high = vcvt_f16_f32(input_high); + + float16x8_t output = vcombine_f16(output_low, output_high); + vst1q_f16(&dst[i], output); + } + + for (size_t i = simd_end; i < end; ++i) { + dst[i] = static_cast<__fp16>(src[i]); + } + }); +} + +void int8_to_fp16(const int8_t* src, __fp16* dst, size_t count, float scale) { + aphrodite::mobile::parallel_for( + count, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [src, dst, scale](size_t start, size_t end) { + const size_t simd_end = start + ((end - start) / 8) * 8; + float32x4_t scale_vec = vdupq_n_f32(scale); + + for (size_t i = start; i < simd_end; i += 8) { + int8x8_t input = vld1_s8(&src[i]); + + int16x8_t int16 = vmovl_s8(input); + int32x4_t int32_low = vmovl_s16(vget_low_s16(int16)); + int32x4_t int32_high = vmovl_s16(vget_high_s16(int16)); + + float32x4_t float_low = vcvtq_f32_s32(int32_low); + float32x4_t float_high = vcvtq_f32_s32(int32_high); + + float_low = vmulq_f32(float_low, scale_vec); + float_high = vmulq_f32(float_high, scale_vec); + + float16x8_t output = + vcombine_f16(vcvt_f16_f32(float_low), vcvt_f16_f32(float_high)); + vst1q_f16(&dst[i], output); + } + + for (size_t i = simd_end; i < end; ++i) { + dst[i] = static_cast<__fp16>(static_cast(src[i]) * scale); + } + }); +} + +void fp16_to_int8(const __fp16* src, int8_t* dst, size_t count, float scale) { + const float inv_scale = 1.0f / scale; + + aphrodite::mobile::parallel_for( + count, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [src, dst, inv_scale](size_t start, size_t end) { + const size_t simd_end = start + ((end - start) / 8) * 8; + float32x4_t inv_scale_vec = vdupq_n_f32(inv_scale); + float32x4_t min_vec = vdupq_n_f32(-128.0f); + float32x4_t max_vec = vdupq_n_f32(127.0f); + + for (size_t i = start; i < simd_end; i 
+= 8) { + float16x8_t input = vld1q_f16(&src[i]); + + float32x4_t input_low = vcvt_f32_f16(vget_low_f16(input)); + float32x4_t input_high = vcvt_f32_f16(vget_high_f16(input)); + + float32x4_t scaled_low = vmulq_f32(input_low, inv_scale_vec); + float32x4_t scaled_high = vmulq_f32(input_high, inv_scale_vec); + + scaled_low = vmaxq_f32(vminq_f32(scaled_low, max_vec), min_vec); + scaled_high = vmaxq_f32(vminq_f32(scaled_high, max_vec), min_vec); + + int32x4_t int_low = vcvtnq_s32_f32(scaled_low); + int32x4_t int_high = vcvtnq_s32_f32(scaled_high); + + int16x8_t int16_combined = + vcombine_s16(vqmovn_s32(int_low), vqmovn_s32(int_high)); + int8x8_t result = vqmovn_s16(int16_combined); + + vst1_s8(&dst[i], result); + } + + for (size_t i = simd_end; i < end; ++i) { + float quantized = static_cast(src[i]) * inv_scale; + dst[i] = static_cast( + std::round(std::max(-128.0f, std::min(127.0f, quantized)))); + } + }); +} + +float fp16_max_abs(const __fp16* src, size_t count) { + float32x4_t abs_max_vec = vdupq_n_f32(0.0f); + const size_t simd_end = (count / 8) * 8; + + for (size_t i = 0; i < simd_end; i += 8) { + float16x8_t input = vld1q_f16(&src[i]); + + float32x4_t input_low = vcvt_f32_f16(vget_low_f16(input)); + float32x4_t input_high = vcvt_f32_f16(vget_high_f16(input)); + + float32x4_t abs_low = vabsq_f32(input_low); + float32x4_t abs_high = vabsq_f32(input_high); + + abs_max_vec = vmaxq_f32(abs_max_vec, abs_low); + abs_max_vec = vmaxq_f32(abs_max_vec, abs_high); + } + + float max_abs = vmaxvq_f32(abs_max_vec); + + for (size_t i = simd_end; i < count; ++i) { + float abs_val = std::abs(static_cast(src[i])); + max_abs = std::max(max_abs, abs_val); + } + + return max_abs; +} + +void int32_to_fp16_scaled(const int32_t* src, __fp16* dst, size_t count, + float scale) { + float32x4_t scale_vec = vdupq_n_f32(scale); + const size_t simd_end = (count / 8) * 8; + + for (size_t i = 0; i < simd_end; i += 8) { + int32x4_t int_low = vld1q_s32(&src[i]); + int32x4_t int_high = 
vld1q_s32(&src[i + 4]); + + float32x4_t fp32_low = vcvtq_f32_s32(int_low); + float32x4_t fp32_high = vcvtq_f32_s32(int_high); + + float32x4_t scaled_low = vmulq_f32(fp32_low, scale_vec); + float32x4_t scaled_high = vmulq_f32(fp32_high, scale_vec); + + float16x8_t result = + vcombine_f16(vcvt_f16_f32(scaled_low), vcvt_f16_f32(scaled_high)); + vst1q_f16(&dst[i], result); + } + + for (size_t i = simd_end; i < count; ++i) { + float fp32_val = static_cast(src[i]) * scale; + dst[i] = static_cast<__fp16>(fp32_val); + } +} + +} // namespace aphrodite::mobile \ No newline at end of file diff --git a/aphrodite_kernels/csrc/cpu/mobile/reduce.cpp b/aphrodite_kernels/csrc/cpu/mobile/reduce.cpp new file mode 100644 index 0000000000..251275ebe1 --- /dev/null +++ b/aphrodite_kernels/csrc/cpu/mobile/reduce.cpp @@ -0,0 +1,846 @@ +#include "threading.hpp" +#include +#include +#include +#include +#include + +namespace aphrodite::mobile { + +int64_t sum_all_int8(const int8_t* data, size_t num_elements) { + return aphrodite::mobile::parallel_reduce( + num_elements, aphrodite::mobile::Thresholds::ALL_REDUCE, + [&](size_t start_idx, size_t end_idx) -> int64_t { + constexpr size_t VECTOR_WIDTH = 16; + constexpr size_t TILE_SIZE = VECTOR_WIDTH * 4; + const size_t tile_aligned = + ((end_idx - start_idx) / TILE_SIZE) * TILE_SIZE + start_idx; + + int32x4_t sum_vec[4] = {vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0), + vdupq_n_s32(0)}; + + for (size_t i = start_idx; i < tile_aligned; i += TILE_SIZE) { + int8x16_t input_vec[4]; + input_vec[0] = vld1q_s8(&data[i]); + input_vec[1] = vld1q_s8(&data[i + VECTOR_WIDTH]); + input_vec[2] = vld1q_s8(&data[i + VECTOR_WIDTH * 2]); + input_vec[3] = vld1q_s8(&data[i + VECTOR_WIDTH * 3]); + + for (int j = 0; j < 4; ++j) { + int16x8_t low = vmovl_s8(vget_low_s8(input_vec[j])); + int16x8_t high = vmovl_s8(vget_high_s8(input_vec[j])); + + sum_vec[j] = vaddq_s32(sum_vec[j], vmovl_s16(vget_low_s16(low))); + sum_vec[j] = vaddq_s32(sum_vec[j], 
vmovl_s16(vget_high_s16(low))); + sum_vec[j] = vaddq_s32(sum_vec[j], vmovl_s16(vget_low_s16(high))); + sum_vec[j] = vaddq_s32(sum_vec[j], vmovl_s16(vget_high_s16(high))); + } + } + + const size_t vectorized_end = + ((end_idx - start_idx) / VECTOR_WIDTH) * VECTOR_WIDTH + start_idx; + for (size_t i = tile_aligned; i < vectorized_end; i += VECTOR_WIDTH) { + int8x16_t input_vec = vld1q_s8(&data[i]); + + int16x8_t low = vmovl_s8(vget_low_s8(input_vec)); + int16x8_t high = vmovl_s8(vget_high_s8(input_vec)); + + sum_vec[0] = vaddq_s32(sum_vec[0], vmovl_s16(vget_low_s16(low))); + sum_vec[0] = vaddq_s32(sum_vec[0], vmovl_s16(vget_high_s16(low))); + sum_vec[0] = vaddq_s32(sum_vec[0], vmovl_s16(vget_low_s16(high))); + sum_vec[0] = vaddq_s32(sum_vec[0], vmovl_s16(vget_high_s16(high))); + } + + int64_t thread_sum = 0; + for (int j = 0; j < 4; ++j) { + thread_sum += vaddvq_s32(sum_vec[j]); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + thread_sum += static_cast(data[i]); + } + + return thread_sum; + }, + 0LL, [](int64_t a, int64_t b) { return a + b; }); +} + +void sum_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, + size_t axis_size, size_t inner_size) { + if (inner_size > 1) { + aphrodite::mobile::parallel_for_2d( + outer_size, inner_size, aphrodite::mobile::Thresholds::AXIS_REDUCE, + [&](size_t outer, size_t inner) { + int32x4_t sum_vec = vdupq_n_s32(0); + int32_t scalar_sum = 0; + + constexpr size_t CACHE_BLOCK_SIZE = 128; + for (size_t block_start = 0; block_start < axis_size; + block_start += CACHE_BLOCK_SIZE) { + size_t block_end = + std::min(block_start + CACHE_BLOCK_SIZE, axis_size); + const size_t vectorized_axis = + ((block_end - block_start) / 4) * 4 + block_start; + + for (size_t a = block_start; a < vectorized_axis; a += 4) { + int32_t values[4]; + for (int j = 0; j < 4; j++) { + size_t idx = outer * axis_size * inner_size + + (a + j) * inner_size + inner; + values[j] = static_cast(input[idx]); + } + int32x4_t input_vec = 
vld1q_s32(values); + sum_vec = vaddq_s32(sum_vec, input_vec); + } + + for (size_t a = vectorized_axis; a < block_end; a++) { + size_t idx = + outer * axis_size * inner_size + a * inner_size + inner; + scalar_sum += static_cast(input[idx]); + } + } + + int32_t total_sum = vaddvq_s32(sum_vec) + scalar_sum; + total_sum = std::min(127, std::max(-128, total_sum)); + size_t output_idx = outer * inner_size + inner; + output[output_idx] = static_cast(total_sum); + }); + } else { + aphrodite::mobile::parallel_for( + outer_size, aphrodite::mobile::Thresholds::AXIS_REDUCE, + [&](size_t start_outer, size_t end_outer) { + for (size_t outer = start_outer; outer < end_outer; outer++) { + int32x4_t sum_vec = vdupq_n_s32(0); + size_t processed = 0; + + const size_t vectorized_axis = + (axis_size / NEON_VECTOR_SIZE) * NEON_VECTOR_SIZE; + for (size_t a = 0; a < vectorized_axis; a += NEON_VECTOR_SIZE) { + size_t base_idx = outer * axis_size + a; + int8x16_t input_vec = vld1q_s8(&input[base_idx]); + + int16x8_t low = vmovl_s8(vget_low_s8(input_vec)); + int16x8_t high = vmovl_s8(vget_high_s8(input_vec)); + + int32x4_t low_low = vmovl_s16(vget_low_s16(low)); + int32x4_t low_high = vmovl_s16(vget_high_s16(low)); + int32x4_t high_low = vmovl_s16(vget_low_s16(high)); + int32x4_t high_high = vmovl_s16(vget_high_s16(high)); + + sum_vec = vaddq_s32(sum_vec, low_low); + sum_vec = vaddq_s32(sum_vec, low_high); + sum_vec = vaddq_s32(sum_vec, high_low); + sum_vec = vaddq_s32(sum_vec, high_high); + + processed = a + NEON_VECTOR_SIZE; + } + + int32_t total_sum = vaddvq_s32(sum_vec); + + for (size_t a = processed; a < axis_size; a++) { + size_t idx = outer * axis_size + a; + total_sum += static_cast(input[idx]); + } + + total_sum = std::min(127, std::max(-128, total_sum)); + output[outer] = static_cast(total_sum); + } + }); + } +} + +double mean_all_int8(const int8_t* data, size_t num_elements) { + int64_t sum = sum_all_int8(data, num_elements); + return static_cast(sum) / static_cast(num_elements); 
+} + +void mean_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, + size_t axis_size, size_t inner_size) { + cactus_sum_axis_int8(input, output, outer_size, axis_size, inner_size); + + size_t result_size = outer_size * inner_size; + for (size_t i = 0; i < result_size; i++) { + double mean_val = + static_cast(output[i]) / static_cast(axis_size); + int8_t clamped_val = static_cast(std::round(mean_val)); + if (mean_val > 127) clamped_val = 127; + if (mean_val < -128) clamped_val = -128; + output[i] = clamped_val; + } +} + +double variance_all_int8(const int8_t* data, size_t num_elements) { + double mean = cactus_mean_all_int8(data, num_elements); + const size_t vectorized_elements = + (num_elements / NEON_VECTOR_SIZE) * NEON_VECTOR_SIZE; + + float32x4_t sum_squared_diff_vec = vdupq_n_f32(0.0f); + float32x4_t mean_vec = vdupq_n_f32(static_cast(mean)); + + for (size_t i = 0; i < vectorized_elements; i += NEON_VECTOR_SIZE) { + int8x16_t input_vec = vld1q_s8(&data[i]); + + int16x8_t low = vmovl_s8(vget_low_s8(input_vec)); + int16x8_t high = vmovl_s8(vget_high_s8(input_vec)); + + float32x4_t low_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(low))); + float32x4_t low_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(low))); + float32x4_t high_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(high))); + float32x4_t high_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(high))); + + float32x4_t diff_low_low = vsubq_f32(low_low, mean_vec); + float32x4_t diff_low_high = vsubq_f32(low_high, mean_vec); + float32x4_t diff_high_low = vsubq_f32(high_low, mean_vec); + float32x4_t diff_high_high = vsubq_f32(high_high, mean_vec); + + sum_squared_diff_vec = + vfmaq_f32(sum_squared_diff_vec, diff_low_low, diff_low_low); + sum_squared_diff_vec = + vfmaq_f32(sum_squared_diff_vec, diff_low_high, diff_low_high); + sum_squared_diff_vec = + vfmaq_f32(sum_squared_diff_vec, diff_high_low, diff_high_low); + sum_squared_diff_vec = + vfmaq_f32(sum_squared_diff_vec, diff_high_high, diff_high_high); + } 
+ + double sum_squared_diff = vaddvq_f32(sum_squared_diff_vec); + + for (size_t i = vectorized_elements; i < num_elements; ++i) { + double diff = static_cast(data[i]) - mean; + sum_squared_diff += diff * diff; + } + + return sum_squared_diff / static_cast(num_elements); +} + +void variance_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, + size_t axis_size, size_t inner_size) { + std::vector mean_output(outer_size * inner_size); + cactus_sum_axis_int8(input, mean_output.data(), outer_size, axis_size, + inner_size); + + for (size_t i = 0; i < outer_size * inner_size; i++) { + double mean_val = + static_cast(mean_output[i]) / static_cast(axis_size); + mean_output[i] = static_cast(std::round(mean_val)); + } + + std::vector sum_squared_diff(outer_size * inner_size, 0.0); + + for (size_t outer = 0; outer < outer_size; outer++) { + for (size_t inner = 0; inner < inner_size; inner++) { + size_t mean_idx = outer * inner_size + inner; + double mean_val = static_cast(mean_output[mean_idx]); + + float32x4_t sum_squared_diff_vec = vdupq_n_f32(0.0f); + float32x4_t mean_vec = vdupq_n_f32(static_cast(mean_val)); + + if (inner_size == 1) { + const size_t vectorized_axis = + (axis_size / NEON_VECTOR_SIZE) * NEON_VECTOR_SIZE; + for (size_t a = 0; a < vectorized_axis; a += NEON_VECTOR_SIZE) { + size_t base_idx = outer * axis_size + a; + int8x16_t input_vec = vld1q_s8(&input[base_idx]); + + int16x8_t low = vmovl_s8(vget_low_s8(input_vec)); + int16x8_t high = vmovl_s8(vget_high_s8(input_vec)); + + float32x4_t low_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(low))); + float32x4_t low_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(low))); + float32x4_t high_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(high))); + float32x4_t high_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(high))); + + float32x4_t diff_low_low = vsubq_f32(low_low, mean_vec); + float32x4_t diff_low_high = vsubq_f32(low_high, mean_vec); + float32x4_t diff_high_low = vsubq_f32(high_low, mean_vec); + float32x4_t 
diff_high_high = vsubq_f32(high_high, mean_vec); + + sum_squared_diff_vec = + vfmaq_f32(sum_squared_diff_vec, diff_low_low, diff_low_low); + sum_squared_diff_vec = + vfmaq_f32(sum_squared_diff_vec, diff_low_high, diff_low_high); + sum_squared_diff_vec = + vfmaq_f32(sum_squared_diff_vec, diff_high_low, diff_high_low); + sum_squared_diff_vec = + vfmaq_f32(sum_squared_diff_vec, diff_high_high, diff_high_high); + } + } else { + const size_t vectorized_axis = (axis_size / 4) * 4; + for (size_t a = 0; a < vectorized_axis; a += 4) { + float values[4]; + for (int j = 0; j < 4; j++) { + size_t idx = + outer * axis_size * inner_size + (a + j) * inner_size + inner; + values[j] = static_cast(input[idx]); + } + float32x4_t input_vec = vld1q_f32(values); + float32x4_t diff_vec = vsubq_f32(input_vec, mean_vec); + sum_squared_diff_vec = + vfmaq_f32(sum_squared_diff_vec, diff_vec, diff_vec); + } + } + + double sum_sq_diff = vaddvq_f32(sum_squared_diff_vec); + + size_t vectorized_axis = + (inner_size == 1) ? 
(axis_size / NEON_VECTOR_SIZE) * NEON_VECTOR_SIZE + : (axis_size / 4) * 4; + for (size_t a = vectorized_axis; a < axis_size; a++) { + size_t idx = outer * axis_size * inner_size + a * inner_size + inner; + double diff = static_cast(input[idx]) - mean_val; + sum_sq_diff += diff * diff; + } + + sum_squared_diff[mean_idx] = sum_sq_diff; + } + } + + for (size_t i = 0; i < sum_squared_diff.size(); i++) { + double variance_val = sum_squared_diff[i] / static_cast(axis_size); + int8_t clamped_val = static_cast(std::round(variance_val)); + if (variance_val > 127) clamped_val = 127; + if (variance_val < -128) clamped_val = -128; + output[i] = clamped_val; + } +} + +int64_t min_all_int8(const int8_t* data, size_t num_elements) { + if (num_elements == 0) return 0; + + constexpr size_t VECTOR_WIDTH = 16; + constexpr size_t TILE_SIZE = VECTOR_WIDTH * 4; + const size_t tile_aligned = (num_elements / TILE_SIZE) * TILE_SIZE; + + int8x16_t min_vec[4] = {vdupq_n_s8(std::numeric_limits::max()), + vdupq_n_s8(std::numeric_limits::max()), + vdupq_n_s8(std::numeric_limits::max()), + vdupq_n_s8(std::numeric_limits::max())}; + + for (size_t i = 0; i < tile_aligned; i += TILE_SIZE) { + int8x16_t input_vec[4]; + input_vec[0] = vld1q_s8(&data[i]); + input_vec[1] = vld1q_s8(&data[i + VECTOR_WIDTH]); + input_vec[2] = vld1q_s8(&data[i + VECTOR_WIDTH * 2]); + input_vec[3] = vld1q_s8(&data[i + VECTOR_WIDTH * 3]); + + min_vec[0] = vminq_s8(min_vec[0], input_vec[0]); + min_vec[1] = vminq_s8(min_vec[1], input_vec[1]); + min_vec[2] = vminq_s8(min_vec[2], input_vec[2]); + min_vec[3] = vminq_s8(min_vec[3], input_vec[3]); + } + + const size_t vectorized_elements = + (num_elements / VECTOR_WIDTH) * VECTOR_WIDTH; + for (size_t i = tile_aligned; i < vectorized_elements; i += VECTOR_WIDTH) { + int8x16_t input_vec = vld1q_s8(&data[i]); + min_vec[0] = vminq_s8(min_vec[0], input_vec); + } + + int8x16_t final_min = vminq_s8(vminq_s8(min_vec[0], min_vec[1]), + vminq_s8(min_vec[2], min_vec[3])); + int8_t min_val = 
vminvq_s8(final_min); + + for (size_t i = vectorized_elements; i < num_elements; ++i) { + if (data[i] < min_val) { + min_val = data[i]; + } + } + + return static_cast(min_val); +} + +void min_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, + size_t axis_size, size_t inner_size) { + aphrodite::mobile::parallel_for_2d( + outer_size, inner_size, aphrodite::mobile::Thresholds::AXIS_REDUCE, + [&](size_t outer, size_t inner) { + int8x16_t min_vec = vdupq_n_s8(std::numeric_limits::max()); + size_t vectorized_axis; + + if (inner_size == 1) { + vectorized_axis = (axis_size / NEON_VECTOR_SIZE) * NEON_VECTOR_SIZE; + for (size_t a = 0; a < vectorized_axis; a += NEON_VECTOR_SIZE) { + size_t base_idx = outer * axis_size + a; + int8x16_t input_vec = vld1q_s8(&input[base_idx]); + min_vec = vminq_s8(min_vec, input_vec); + } + } else { + vectorized_axis = (axis_size / NEON_VECTOR_SIZE) * NEON_VECTOR_SIZE; + for (size_t a = 0; a < vectorized_axis; a += NEON_VECTOR_SIZE) { + int8_t values[NEON_VECTOR_SIZE]; + for (size_t j = 0; j < NEON_VECTOR_SIZE; j++) { + size_t idx = + outer * axis_size * inner_size + (a + j) * inner_size + inner; + values[j] = input[idx]; + } + int8x16_t input_vec = vld1q_s8(values); + min_vec = vminq_s8(min_vec, input_vec); + } + } + + int8_t min_val = vminvq_s8(min_vec); + + for (size_t a = vectorized_axis; a < axis_size; a++) { + size_t idx; + if (inner_size == 1) { + idx = outer * axis_size + a; + } else { + idx = outer * axis_size * inner_size + a * inner_size + inner; + } + if (input[idx] < min_val) { + min_val = input[idx]; + } + } + + size_t output_idx = outer * inner_size + inner; + output[output_idx] = min_val; + }); +} + +int64_t max_all_int8(const int8_t* data, size_t num_elements) { + if (num_elements == 0) return 0; + + constexpr size_t VECTOR_WIDTH = 16; + constexpr size_t TILE_SIZE = VECTOR_WIDTH * 4; + const size_t tile_aligned = (num_elements / TILE_SIZE) * TILE_SIZE; + + int8x16_t max_vec[4] = 
{vdupq_n_s8(std::numeric_limits::min()), + vdupq_n_s8(std::numeric_limits::min()), + vdupq_n_s8(std::numeric_limits::min()), + vdupq_n_s8(std::numeric_limits::min())}; + + for (size_t i = 0; i < tile_aligned; i += TILE_SIZE) { + int8x16_t input_vec[4]; + input_vec[0] = vld1q_s8(&data[i]); + input_vec[1] = vld1q_s8(&data[i + VECTOR_WIDTH]); + input_vec[2] = vld1q_s8(&data[i + VECTOR_WIDTH * 2]); + input_vec[3] = vld1q_s8(&data[i + VECTOR_WIDTH * 3]); + + max_vec[0] = vmaxq_s8(max_vec[0], input_vec[0]); + max_vec[1] = vmaxq_s8(max_vec[1], input_vec[1]); + max_vec[2] = vmaxq_s8(max_vec[2], input_vec[2]); + max_vec[3] = vmaxq_s8(max_vec[3], input_vec[3]); + } + + const size_t vectorized_elements = + (num_elements / VECTOR_WIDTH) * VECTOR_WIDTH; + for (size_t i = tile_aligned; i < vectorized_elements; i += VECTOR_WIDTH) { + int8x16_t input_vec = vld1q_s8(&data[i]); + max_vec[0] = vmaxq_s8(max_vec[0], input_vec); + } + + int8x16_t final_max = vmaxq_s8(vmaxq_s8(max_vec[0], max_vec[1]), + vmaxq_s8(max_vec[2], max_vec[3])); + int8_t max_val = vmaxvq_s8(final_max); + + for (size_t i = vectorized_elements; i < num_elements; ++i) { + if (data[i] > max_val) { + max_val = data[i]; + } + } + + return static_cast(max_val); +} + +void max_axis_int8(const int8_t* input, int8_t* output, size_t outer_size, + size_t axis_size, size_t inner_size) { + aphrodite::mobile::parallel_for_2d( + outer_size, inner_size, aphrodite::mobile::Thresholds::AXIS_REDUCE, + [&](size_t outer, size_t inner) { + int8x16_t max_vec = vdupq_n_s8(std::numeric_limits::min()); + size_t vectorized_axis; + + if (inner_size == 1) { + vectorized_axis = (axis_size / NEON_VECTOR_SIZE) * NEON_VECTOR_SIZE; + for (size_t a = 0; a < vectorized_axis; a += NEON_VECTOR_SIZE) { + size_t base_idx = outer * axis_size + a; + int8x16_t input_vec = vld1q_s8(&input[base_idx]); + max_vec = vmaxq_s8(max_vec, input_vec); + } + } else { + vectorized_axis = (axis_size / NEON_VECTOR_SIZE) * NEON_VECTOR_SIZE; + for (size_t a = 0; a < 
vectorized_axis; a += NEON_VECTOR_SIZE) { + int8_t values[NEON_VECTOR_SIZE]; + for (size_t j = 0; j < NEON_VECTOR_SIZE; j++) { + size_t idx = + outer * axis_size * inner_size + (a + j) * inner_size + inner; + values[j] = input[idx]; + } + int8x16_t input_vec = vld1q_s8(values); + max_vec = vmaxq_s8(max_vec, input_vec); + } + } + + int8_t max_val = vmaxvq_s8(max_vec); + + for (size_t a = vectorized_axis; a < axis_size; a++) { + size_t idx; + if (inner_size == 1) { + idx = outer * axis_size + a; + } else { + idx = outer * axis_size * inner_size + a * inner_size + inner; + } + if (input[idx] > max_val) { + max_val = input[idx]; + } + } + + size_t output_idx = outer * inner_size + inner; + output[output_idx] = max_val; + }); +} + +double sum_all_f32(const float* data, size_t num_elements) { + return aphrodite::mobile::parallel_reduce( + num_elements, aphrodite::mobile::Thresholds::ALL_REDUCE, + [&](size_t start_idx, size_t end_idx) -> double { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + float32x4_t sum_vec = vdupq_n_f32(0.0f); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t input_vec = vld1q_f32(&data[i]); + sum_vec = vaddq_f32(sum_vec, input_vec); + } + + double thread_sum = static_cast(vaddvq_f32(sum_vec)); + + for (size_t i = vectorized_end; i < end_idx; ++i) { + thread_sum += static_cast(data[i]); + } + + return thread_sum; + }, + 0.0, [](double a, double b) { return a + b; }); +} + +void sum_axis_f32(const float* input, float* output, size_t outer_size, + size_t axis_size, size_t inner_size) { + aphrodite::mobile::parallel_for_2d( + outer_size, inner_size, aphrodite::mobile::Thresholds::AXIS_REDUCE, + [&](size_t outer, size_t inner) { + float32x4_t sum_vec = vdupq_n_f32(0.0f); + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_axis = (axis_size / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t a = 0; a < vectorized_axis; a += 
SIMD_WIDTH) { + float values[SIMD_WIDTH]; + for (size_t j = 0; j < SIMD_WIDTH; j++) { + size_t idx = + outer * axis_size * inner_size + (a + j) * inner_size + inner; + values[j] = input[idx]; + } + float32x4_t input_vec = vld1q_f32(values); + sum_vec = vaddq_f32(sum_vec, input_vec); + } + + float total_sum = vaddvq_f32(sum_vec); + + for (size_t a = vectorized_axis; a < axis_size; a++) { + size_t idx = outer * axis_size * inner_size + a * inner_size + inner; + total_sum += input[idx]; + } + + size_t output_idx = outer * inner_size + inner; + output[output_idx] = total_sum; + }); +} + +double mean_all_f16(const __fp16* data, size_t num_elements) { + return aphrodite::mobile::parallel_reduce( + num_elements, aphrodite::mobile::Thresholds::ALL_REDUCE, + [&](size_t start_idx, size_t end_idx) -> double { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + float16x8_t sum_vec = vdupq_n_f16(0.0f); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float16x8_t input_vec = vld1q_f16(&data[i]); + sum_vec = vaddq_f16(sum_vec, input_vec); + } + + double thread_sum = 0.0; + __fp16 sum_array[8]; + vst1q_f16(sum_array, sum_vec); + for (int j = 0; j < 8; j++) { + thread_sum += static_cast(sum_array[j]); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + thread_sum += static_cast(data[i]); + } + + return thread_sum; + }, + 0.0, [](double a, double b) { return a + b; }) / + static_cast(num_elements); +} + +void mean_axis_f16(const __fp16* input, __fp16* output, size_t outer_size, + size_t axis_size, size_t inner_size) { + aphrodite::mobile::parallel_for_2d( + outer_size, inner_size, aphrodite::mobile::Thresholds::AXIS_REDUCE, + [&](size_t outer, size_t inner) { + float16x8_t sum_vec = vdupq_n_f16(0.0f); + + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_axis = (axis_size / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t a = 0; a < vectorized_axis; a += 
SIMD_WIDTH) { + __fp16 values[SIMD_WIDTH]; + for (size_t j = 0; j < SIMD_WIDTH; j++) { + size_t idx = + outer * axis_size * inner_size + (a + j) * inner_size + inner; + values[j] = input[idx]; + } + float16x8_t input_vec = vld1q_f16(values); + sum_vec = vaddq_f16(sum_vec, input_vec); + } + + __fp16 total_sum = 0.0f; + __fp16 sum_array[8]; + vst1q_f16(sum_array, sum_vec); + for (int j = 0; j < 8; j++) { + total_sum += sum_array[j]; + } + + for (size_t a = vectorized_axis; a < axis_size; a++) { + size_t idx = outer * axis_size * inner_size + a * inner_size + inner; + total_sum += input[idx]; + } + + size_t output_idx = outer * inner_size + inner; + output[output_idx] = total_sum / static_cast<__fp16>(axis_size); + }); +} + +double mean_all_f32(const float* data, size_t num_elements) { + double sum = cactus_sum_all_f32(data, num_elements); + return sum / static_cast(num_elements); +} + +void mean_axis_f32(const float* input, float* output, size_t outer_size, + size_t axis_size, size_t inner_size) { + cactus_sum_axis_f32(input, output, outer_size, axis_size, inner_size); + const float divisor = static_cast(axis_size); + + aphrodite::mobile::parallel_for( + outer_size * inner_size, aphrodite::mobile::Thresholds::ELEMENT_WISE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + float32x4_t divisor_vec = vdupq_n_f32(divisor); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t output_vec = vld1q_f32(&output[i]); + output_vec = vdivq_f32(output_vec, divisor_vec); + vst1q_f32(&output[i], output_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] /= divisor; + } + }); +} + +double variance_all_f32(const float* data, size_t num_elements) { + double mean = cactus_mean_all_f32(data, num_elements); + + return aphrodite::mobile::parallel_reduce( + num_elements, 
aphrodite::mobile::Thresholds::ALL_REDUCE, + [&](size_t start_idx, size_t end_idx) -> double { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + float32x4_t mean_vec = vdupq_n_f32(static_cast(mean)); + float32x4_t var_vec = vdupq_n_f32(0.0f); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t input_vec = vld1q_f32(&data[i]); + float32x4_t diff = vsubq_f32(input_vec, mean_vec); + var_vec = vmlaq_f32(var_vec, diff, diff); + } + + double thread_var = static_cast(vaddvq_f32(var_vec)); + + for (size_t i = vectorized_end; i < end_idx; ++i) { + double diff = static_cast(data[i]) - mean; + thread_var += diff * diff; + } + + return thread_var; + }, + 0.0, [](double a, double b) { return a + b; }) / + static_cast(num_elements); +} + +void variance_axis_f32(const float* input, float* output, size_t outer_size, + size_t axis_size, size_t inner_size) { + std::vector means(outer_size * inner_size); + cactus_mean_axis_f32(input, means.data(), outer_size, axis_size, inner_size); + + aphrodite::mobile::parallel_for_2d( + outer_size, inner_size, aphrodite::mobile::Thresholds::AXIS_REDUCE, + [&](size_t outer, size_t inner) { + size_t output_idx = outer * inner_size + inner; + float mean_val = means[output_idx]; + + float32x4_t mean_vec = vdupq_n_f32(mean_val); + float32x4_t var_vec = vdupq_n_f32(0.0f); + + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_axis = (axis_size / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t a = 0; a < vectorized_axis; a += SIMD_WIDTH) { + float values[SIMD_WIDTH]; + for (size_t j = 0; j < SIMD_WIDTH; j++) { + size_t idx = + outer * axis_size * inner_size + (a + j) * inner_size + inner; + values[j] = input[idx]; + } + float32x4_t input_vec = vld1q_f32(values); + float32x4_t diff = vsubq_f32(input_vec, mean_vec); + var_vec = vmlaq_f32(var_vec, diff, diff); + } + + float total_var = vaddvq_f32(var_vec); + + for (size_t a = 
vectorized_axis; a < axis_size; a++) { + size_t idx = outer * axis_size * inner_size + a * inner_size + inner; + float diff = input[idx] - mean_val; + total_var += diff * diff; + } + + output[output_idx] = total_var / static_cast(axis_size); + }); +} + +float min_all_f32(const float* data, size_t num_elements) { + return aphrodite::mobile::parallel_reduce( + num_elements, aphrodite::mobile::Thresholds::ALL_REDUCE, + [&](size_t start_idx, size_t end_idx) -> float { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + float32x4_t min_vec = vdupq_n_f32(std::numeric_limits::max()); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t input_vec = vld1q_f32(&data[i]); + min_vec = vminq_f32(min_vec, input_vec); + } + + float thread_min = vminvq_f32(min_vec); + + for (size_t i = vectorized_end; i < end_idx; ++i) { + thread_min = std::min(thread_min, data[i]); + } + + return thread_min; + }, + std::numeric_limits::max(), + [](float a, float b) { return std::min(a, b); }); +} + +void min_axis_f32(const float* input, float* output, size_t outer_size, + size_t axis_size, size_t inner_size) { + aphrodite::mobile::parallel_for_2d( + outer_size, inner_size, aphrodite::mobile::Thresholds::AXIS_REDUCE, + [&](size_t outer, size_t inner) { + float32x4_t min_vec = vdupq_n_f32(std::numeric_limits::max()); + + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_axis = (axis_size / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t a = 0; a < vectorized_axis; a += SIMD_WIDTH) { + float values[SIMD_WIDTH]; + for (size_t j = 0; j < SIMD_WIDTH; j++) { + size_t idx = + outer * axis_size * inner_size + (a + j) * inner_size + inner; + values[j] = input[idx]; + } + float32x4_t input_vec = vld1q_f32(values); + min_vec = vminq_f32(min_vec, input_vec); + } + + float min_val = vminvq_f32(min_vec); + + for (size_t a = vectorized_axis; a < axis_size; a++) { + size_t idx = outer * axis_size 
* inner_size + a * inner_size + inner; + min_val = std::min(min_val, input[idx]); + } + + size_t output_idx = outer * inner_size + inner; + output[output_idx] = min_val; + }); +} + +float max_all_f32(const float* data, size_t num_elements) { + return aphrodite::mobile::parallel_reduce( + num_elements, aphrodite::mobile::Thresholds::ALL_REDUCE, + [&](size_t start_idx, size_t end_idx) -> float { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + float32x4_t max_vec = vdupq_n_f32(std::numeric_limits::lowest()); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t input_vec = vld1q_f32(&data[i]); + max_vec = vmaxq_f32(max_vec, input_vec); + } + + float thread_max = vmaxvq_f32(max_vec); + + for (size_t i = vectorized_end; i < end_idx; ++i) { + thread_max = std::max(thread_max, data[i]); + } + + return thread_max; + }, + std::numeric_limits::lowest(), + [](float a, float b) { return std::max(a, b); }); +} + +void max_axis_f32(const float* input, float* output, size_t outer_size, + size_t axis_size, size_t inner_size) { + aphrodite::mobile::parallel_for_2d( + outer_size, inner_size, aphrodite::mobile::Thresholds::AXIS_REDUCE, + [&](size_t outer, size_t inner) { + float32x4_t max_vec = vdupq_n_f32(std::numeric_limits::lowest()); + + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_axis = (axis_size / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t a = 0; a < vectorized_axis; a += SIMD_WIDTH) { + float values[SIMD_WIDTH]; + for (size_t j = 0; j < SIMD_WIDTH; j++) { + size_t idx = + outer * axis_size * inner_size + (a + j) * inner_size + inner; + values[j] = input[idx]; + } + float32x4_t input_vec = vld1q_f32(values); + max_vec = vmaxq_f32(max_vec, input_vec); + } + + float max_val = vmaxvq_f32(max_vec); + + for (size_t a = vectorized_axis; a < axis_size; a++) { + size_t idx = outer * axis_size * inner_size + a * inner_size + inner; + max_val = 
std::max(max_val, input[idx]); + } + + size_t output_idx = outer * inner_size + inner; + output[output_idx] = max_val; + }); +} +} // namespace aphrodite::mobile diff --git a/aphrodite_kernels/csrc/cpu/mobile/scalar.cpp b/aphrodite_kernels/csrc/cpu/mobile/scalar.cpp new file mode 100644 index 0000000000..ed12714e1d --- /dev/null +++ b/aphrodite_kernels/csrc/cpu/mobile/scalar.cpp @@ -0,0 +1,767 @@ +#include "threading.hpp" +#include "kernels.h" +#include +#include +#include + +namespace aphrodite::mobile { + +void scalar_op_int8(const int8_t* input, int8_t* output, size_t num_elements, + float scalar_value, ScalarOpType op_type) { + switch (op_type) { + case ScalarOpType::ADD: { + const int8_t scalar_int8 = clamp_to_int8(scalar_value); + const size_t vectorized_elements = + (num_elements / NEON_VECTOR_SIZE) * NEON_VECTOR_SIZE; + + int8x16_t scalar_vec = vdupq_n_s8(scalar_int8); + + for (size_t i = 0; i < vectorized_elements; i += NEON_VECTOR_SIZE) { + int8x16_t input_vec = vld1q_s8(&input[i]); + + int16x8_t input_low = vmovl_s8(vget_low_s8(input_vec)); + int16x8_t input_high = vmovl_s8(vget_high_s8(input_vec)); + int16x8_t scalar_low = vmovl_s8(vget_low_s8(scalar_vec)); + int16x8_t scalar_high = vmovl_s8(vget_high_s8(scalar_vec)); + + int16x8_t result_low = vaddq_s16(input_low, scalar_low); + int16x8_t result_high = vaddq_s16(input_high, scalar_high); + + int8x16_t result_vec = + vcombine_s8(vqmovn_s16(result_low), vqmovn_s16(result_high)); + vst1q_s8(&output[i], result_vec); + } + + for (size_t i = vectorized_elements; i < num_elements; ++i) { + int32_t sum = + static_cast(input[i]) + static_cast(scalar_int8); + output[i] = clamp_to_int8(sum); + } + break; + } + + case ScalarOpType::SUBTRACT: { + const int8_t scalar_int8 = clamp_to_int8(scalar_value); + const size_t vectorized_elements = + (num_elements / NEON_VECTOR_SIZE) * NEON_VECTOR_SIZE; + + int8x16_t scalar_vec = vdupq_n_s8(scalar_int8); + + for (size_t i = 0; i < vectorized_elements; i += NEON_VECTOR_SIZE) { 
+ int8x16_t input_vec = vld1q_s8(&input[i]); + + int16x8_t input_low = vmovl_s8(vget_low_s8(input_vec)); + int16x8_t input_high = vmovl_s8(vget_high_s8(input_vec)); + int16x8_t scalar_low = vmovl_s8(vget_low_s8(scalar_vec)); + int16x8_t scalar_high = vmovl_s8(vget_high_s8(scalar_vec)); + + int16x8_t result_low = vsubq_s16(input_low, scalar_low); + int16x8_t result_high = vsubq_s16(input_high, scalar_high); + + int8x16_t result_vec = + vcombine_s8(vqmovn_s16(result_low), vqmovn_s16(result_high)); + vst1q_s8(&output[i], result_vec); + } + + for (size_t i = vectorized_elements; i < num_elements; ++i) { + int32_t diff = + static_cast(input[i]) - static_cast(scalar_int8); + output[i] = clamp_to_int8(diff); + } + break; + } + + case ScalarOpType::MULTIPLY: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_BASIC, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + const float32x4_t scalar_f32 = vdupq_n_f32(scalar_value); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + int8x8_t input_s8 = vld1_s8(&input[i]); + int16x8_t input_s16 = vmovl_s8(input_s8); + + float32x4_t input_low_f32 = + vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_s16))); + float32x4_t input_high_f32 = + vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_s16))); + + float32x4_t result_low_f32 = vmulq_f32(input_low_f32, scalar_f32); + float32x4_t result_high_f32 = + vmulq_f32(input_high_f32, scalar_f32); + + int32x4_t result_low_s32 = vcvtq_s32_f32(result_low_f32); + int32x4_t result_high_s32 = vcvtq_s32_f32(result_high_f32); + + int16x8_t result_s16 = vcombine_s16(vqmovn_s32(result_low_s32), + vqmovn_s32(result_high_s32)); + vst1_s8(&output[i], vqmovn_s16(result_s16)); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + float input_float = static_cast(input[i]); + float result_float = input_float * scalar_value; + 
output[i] = clamp_to_int8(result_float); + } + }); + break; + } + + case ScalarOpType::DIVIDE: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_BASIC, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + const float32x4_t scalar_f32 = vdupq_n_f32(scalar_value); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + int8x8_t input_s8 = vld1_s8(&input[i]); + int16x8_t input_s16 = vmovl_s8(input_s8); + + float32x4_t input_low_f32 = + vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_s16))); + float32x4_t input_high_f32 = + vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_s16))); + + float32x4_t result_low_f32 = vdivq_f32(input_low_f32, scalar_f32); + float32x4_t result_high_f32 = + vdivq_f32(input_high_f32, scalar_f32); + + int32x4_t result_low_s32 = vcvtq_s32_f32(result_low_f32); + int32x4_t result_high_s32 = vcvtq_s32_f32(result_high_f32); + + int16x8_t result_s16 = vcombine_s16(vqmovn_s32(result_low_s32), + vqmovn_s32(result_high_s32)); + vst1_s8(&output[i], vqmovn_s16(result_s16)); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + float input_float = static_cast(input[i]); + float result_float = input_float / scalar_value; + output[i] = clamp_to_int8(result_float); + } + }); + break; + } + + case ScalarOpType::EXP: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + const float32x4_t log2e = vdupq_n_f32(1.4426950408889634f); + const float32x4_t c1 = vdupq_n_f32(0.6931471805599453f); + const float32x4_t c2 = vdupq_n_f32(0.2402265069591007f); + const float32x4_t c3 = vdupq_n_f32(0.05550410866482158f); + + for (size_t i = start_idx; i < vectorized_end; i += 
SIMD_WIDTH) { + int16x8_t input_s16 = vmovl_s8(vld1_s8(&input[i])); + float32x4_t in_low = + vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_s16))); + float32x4_t in_high = + vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_s16))); + + float32x4_t x_low = vmulq_f32(in_low, log2e); + int32x4_t xi_low = vcvtq_s32_f32(x_low); + float32x4_t xf_low = vsubq_f32(x_low, vcvtq_f32_s32(xi_low)); + float32x4_t p_low = vmlaq_f32(c2, c3, xf_low); + p_low = vmlaq_f32(c1, p_low, xf_low); + p_low = vmlaq_f32(vdupq_n_f32(1.0f), p_low, xf_low); + int32x4_t exponent_low = + vshlq_n_s32(vaddq_s32(xi_low, vdupq_n_s32(127)), 23); + float32x4_t scale_low = vreinterpretq_f32_s32(exponent_low); + float32x4_t result_low_f32 = vmulq_f32(p_low, scale_low); + + float32x4_t x_high = vmulq_f32(in_high, log2e); + int32x4_t xi_high = vcvtq_s32_f32(x_high); + float32x4_t xf_high = vsubq_f32(x_high, vcvtq_f32_s32(xi_high)); + float32x4_t p_high = vmlaq_f32(c2, c3, xf_high); + p_high = vmlaq_f32(c1, p_high, xf_high); + p_high = vmlaq_f32(vdupq_n_f32(1.0f), p_high, xf_high); + int32x4_t exponent_high = + vshlq_n_s32(vaddq_s32(xi_high, vdupq_n_s32(127)), 23); + float32x4_t scale_high = vreinterpretq_f32_s32(exponent_high); + float32x4_t result_high_f32 = vmulq_f32(p_high, scale_high); + + int16x8_t result_s16 = + vcombine_s16(vqmovn_s32(vcvtq_s32_f32(result_low_f32)), + vqmovn_s32(vcvtq_s32_f32(result_high_f32))); + vst1_s8(&output[i], vqmovn_s16(result_s16)); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + float input_float = static_cast(input[i]); + float result_float = expf(input_float); + output[i] = clamp_to_int8(result_float); + } + }); + break; + } + + case ScalarOpType::SQRT: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t i = start_idx; i < 
vectorized_end; i += SIMD_WIDTH) { + int8x8_t input_s8 = vld1_s8(&input[i]); + int16x8_t input_s16 = vmovl_s8(input_s8); + + float32x4_t input_low_f32 = + vcvtq_f32_s32(vmovl_s16(vget_low_s16(input_s16))); + float32x4_t input_high_f32 = + vcvtq_f32_s32(vmovl_s16(vget_high_s16(input_s16))); + + float32x4_t rsqrt_low = vrsqrteq_f32(input_low_f32); + float32x4_t rsqrt_high = vrsqrteq_f32(input_high_f32); + + rsqrt_low = vmulq_f32( + rsqrt_low, + vrsqrtsq_f32(vmulq_f32(input_low_f32, rsqrt_low), rsqrt_low)); + rsqrt_high = + vmulq_f32(rsqrt_high, + vrsqrtsq_f32(vmulq_f32(input_high_f32, rsqrt_high), + rsqrt_high)); + + float32x4_t result_low_f32 = vmulq_f32(input_low_f32, rsqrt_low); + float32x4_t result_high_f32 = + vmulq_f32(input_high_f32, rsqrt_high); + + int32x4_t result_low_s32 = vcvtq_s32_f32(result_low_f32); + int32x4_t result_high_s32 = vcvtq_s32_f32(result_high_f32); + + int16x8_t result_s16 = vcombine_s16(vqmovn_s32(result_low_s32), + vqmovn_s32(result_high_s32)); + vst1_s8(&output[i], vqmovn_s16(result_s16)); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + float input_float = static_cast(input[i]); + float result_float = sqrtf(input_float); + output[i] = clamp_to_int8(result_float); + } + }); + break; + } + + case ScalarOpType::COS: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + const float32x4_t two_pi = + vdupq_n_f32(2.0f * 3.14159265358979323846f); + const float32x4_t inv_two_pi = + vdupq_n_f32(1.0f / (2.0f * 3.14159265358979323846f)); + const float32x4_t c0 = vdupq_n_f32(1.0f); + const float32x4_t c2 = vdupq_n_f32(-0.5f); + const float32x4_t c4 = vdupq_n_f32(0.04166666666f); + const float32x4_t c6 = vdupq_n_f32(-0.00138888888f); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + int16x8_t 
in_s16 = vmovl_s8(vld1_s8(&input[i])); + float32x4_t x_low = + vcvtq_f32_s32(vmovl_s16(vget_low_s16(in_s16))); + float32x4_t x_high = + vcvtq_f32_s32(vmovl_s16(vget_high_s16(in_s16))); + + x_low = vsubq_f32( + x_low, + vmulq_f32(vrndnq_f32(vmulq_f32(x_low, inv_two_pi)), two_pi)); + x_high = vsubq_f32( + x_high, + vmulq_f32(vrndnq_f32(vmulq_f32(x_high, inv_two_pi)), two_pi)); + + auto poly = [&](float32x4_t x) { + float32x4_t x2 = vmulq_f32(x, x); + float32x4_t x4 = vmulq_f32(x2, x2); + float32x4_t x6 = vmulq_f32(x4, x2); + float32x4_t res = c0; + res = vmlaq_f32(res, x2, c2); + res = vmlaq_f32(res, x4, c4); + res = vmlaq_f32(res, x6, c6); + return res; + }; + + float32x4_t y_low = poly(x_low); + float32x4_t y_high = poly(x_high); + + int16x8_t result_s16 = vcombine_s16( + vqmovn_s32( + vcvtq_s32_f32(vmulq_f32(y_low, vdupq_n_f32(127.0f)))), + vqmovn_s32( + vcvtq_s32_f32(vmulq_f32(y_high, vdupq_n_f32(127.0f))))); + vst1_s8(&output[i], vqmovn_s16(result_s16)); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + float input_float = static_cast(input[i]); + float result_float = cosf(input_float); + output[i] = clamp_to_int8(result_float); + } + }); + break; + } + + case ScalarOpType::SIN: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + const float32x4_t two_pi = + vdupq_n_f32(2.0f * 3.14159265358979323846f); + const float32x4_t inv_two_pi = + vdupq_n_f32(1.0f / (2.0f * 3.14159265358979323846f)); + const float32x4_t c1 = vdupq_n_f32(1.0f); + const float32x4_t c3 = vdupq_n_f32(-0.16666666666f); + const float32x4_t c5 = vdupq_n_f32(0.00833333333f); + const float32x4_t c7 = vdupq_n_f32(-0.00019841269f); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + int16x8_t in_s16 = vmovl_s8(vld1_s8(&input[i])); + float32x4_t 
x_low = + vcvtq_f32_s32(vmovl_s16(vget_low_s16(in_s16))); + float32x4_t x_high = + vcvtq_f32_s32(vmovl_s16(vget_high_s16(in_s16))); + + x_low = vsubq_f32( + x_low, + vmulq_f32(vrndnq_f32(vmulq_f32(x_low, inv_two_pi)), two_pi)); + x_high = vsubq_f32( + x_high, + vmulq_f32(vrndnq_f32(vmulq_f32(x_high, inv_two_pi)), two_pi)); + + auto poly = [&](float32x4_t x) { + float32x4_t x2 = vmulq_f32(x, x); + float32x4_t x3 = vmulq_f32(x2, x); + float32x4_t x5 = vmulq_f32(x3, x2); + float32x4_t x7 = vmulq_f32(x5, x2); + float32x4_t y = vmulq_f32(x, c1); + y = vmlaq_f32(y, x3, c3); + y = vmlaq_f32(y, x5, c5); + y = vmlaq_f32(y, x7, c7); + return y; + }; + + float32x4_t y_low = poly(x_low); + float32x4_t y_high = poly(x_high); + + int16x8_t result_s16 = vcombine_s16( + vqmovn_s32( + vcvtq_s32_f32(vmulq_f32(y_low, vdupq_n_f32(127.0f)))), + vqmovn_s32( + vcvtq_s32_f32(vmulq_f32(y_high, vdupq_n_f32(127.0f))))); + vst1_s8(&output[i], vqmovn_s16(result_s16)); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + float input_float = static_cast(input[i]); + float result_float = sinf(input_float); + output[i] = clamp_to_int8(result_float); + } + }); + break; + } + } +} + +void scalar_op_f32(const float* input, float* output, size_t num_elements, + float scalar_value, ScalarOpType op_type) { + switch (op_type) { + case ScalarOpType::ADD: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_BASIC, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + float32x4_t scalar_vec = vdupq_n_f32(scalar_value); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t input_vec = vld1q_f32(&input[i]); + float32x4_t result_vec = vaddq_f32(input_vec, scalar_vec); + vst1q_f32(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = input[i] + scalar_value; + } + }); + 
break; + } + + case ScalarOpType::SUBTRACT: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_BASIC, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + float32x4_t scalar_vec = vdupq_n_f32(scalar_value); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t input_vec = vld1q_f32(&input[i]); + float32x4_t result_vec = vsubq_f32(input_vec, scalar_vec); + vst1q_f32(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = input[i] - scalar_value; + } + }); + break; + } + + case ScalarOpType::MULTIPLY: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_BASIC, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + float32x4_t scalar_vec = vdupq_n_f32(scalar_value); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t input_vec = vld1q_f32(&input[i]); + float32x4_t result_vec = vmulq_f32(input_vec, scalar_vec); + vst1q_f32(&output[i], result_vec); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = input[i] * scalar_value; + } + }); + break; + } + + case ScalarOpType::DIVIDE: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_BASIC, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + float32x4_t scalar_vec = vdupq_n_f32(scalar_value); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t input_vec = vld1q_f32(&input[i]); + float32x4_t result_vec = vdivq_f32(input_vec, scalar_vec); + vst1q_f32(&output[i], result_vec); + } + + for 
(size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = input[i] / scalar_value; + } + }); + break; + } + + case ScalarOpType::EXP: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t input_f32 = vld1q_f32(&input[i]); + + const float32x4_t log2e = vdupq_n_f32(1.4426950408889634f); + const float32x4_t c1 = vdupq_n_f32(0.6931471805599453f); + const float32x4_t c2 = vdupq_n_f32(0.2402265069591007f); + const float32x4_t c3 = vdupq_n_f32(0.05550410866482158f); + + float32x4_t x = vmulq_f32(input_f32, log2e); + int32x4_t xi = vcvtq_s32_f32(x); + float32x4_t xf = vsubq_f32(x, vcvtq_f32_s32(xi)); + + float32x4_t p = vmlaq_f32(c2, c3, xf); + p = vmlaq_f32(c1, p, xf); + p = vmlaq_f32(vdupq_n_f32(1.0f), p, xf); + + int32x4_t exponent = vaddq_s32(xi, vdupq_n_s32(127)); + exponent = vshlq_n_s32(exponent, 23); + float32x4_t scale = vreinterpretq_f32_s32(exponent); + + float32x4_t result_f32 = vmulq_f32(p, scale); + + vst1q_f32(&output[i], result_f32); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = std::exp(input[i]); + } + }); + break; + } + + case ScalarOpType::SQRT: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t input_f32 = vld1q_f32(&input[i]); + + input_f32 = vmaxq_f32(input_f32, vdupq_n_f32(0.0f)); + + uint32x4_t zero_mask = vceqq_f32(input_f32, vdupq_n_f32(0.0f)); + + float32x4_t rsqrt = vrsqrteq_f32(input_f32); + rsqrt = 
vmulq_f32( + rsqrt, vrsqrtsq_f32(vmulq_f32(input_f32, rsqrt), rsqrt)); + float32x4_t result_f32 = vmulq_f32(input_f32, rsqrt); + + result_f32 = vbslq_f32(zero_mask, vdupq_n_f32(0.0f), result_f32); + + vst1q_f32(&output[i], result_f32); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = std::sqrt(std::max(0.0f, input[i])); + } + }); + break; + } + + case ScalarOpType::COS: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + const float32x4_t two_pi = + vdupq_n_f32(2.0f * 3.14159265358979323846f); + const float32x4_t inv_two_pi = + vdupq_n_f32(1.0f / (2.0f * 3.14159265358979323846f)); + const float32x4_t c0 = vdupq_n_f32(1.0f); + const float32x4_t c2 = vdupq_n_f32(-0.5f); + const float32x4_t c4 = vdupq_n_f32(0.04166666666f); + const float32x4_t c6 = vdupq_n_f32(-0.00138888888f); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t x = vld1q_f32(&input[i]); + + x = vsubq_f32( + x, vmulq_f32(vrndnq_f32(vmulq_f32(x, inv_two_pi)), two_pi)); + + float32x4_t x2 = vmulq_f32(x, x); + float32x4_t x4 = vmulq_f32(x2, x2); + float32x4_t x6 = vmulq_f32(x4, x2); + + float32x4_t result = c0; + result = vmlaq_f32(result, x2, c2); + result = vmlaq_f32(result, x4, c4); + result = vmlaq_f32(result, x6, c6); + + vst1q_f32(&output[i], result); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = cosf(input[i]); + } + }); + break; + } + + case ScalarOpType::SIN: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 4; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + + const float32x4_t two_pi = + vdupq_n_f32(2.0f * 
3.14159265358979323846f); + const float32x4_t inv_two_pi = + vdupq_n_f32(1.0f / (2.0f * 3.14159265358979323846f)); + const float32x4_t c1 = vdupq_n_f32(1.0f); + const float32x4_t c3 = vdupq_n_f32(-0.16666666666f); + const float32x4_t c5 = vdupq_n_f32(0.00833333333f); + const float32x4_t c7 = vdupq_n_f32(-0.00019841269f); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float32x4_t x = vld1q_f32(&input[i]); + + x = vsubq_f32( + x, vmulq_f32(vrndnq_f32(vmulq_f32(x, inv_two_pi)), two_pi)); + + float32x4_t x2 = vmulq_f32(x, x); + float32x4_t x3 = vmulq_f32(x2, x); + float32x4_t x5 = vmulq_f32(x3, x2); + float32x4_t x7 = vmulq_f32(x5, x2); + + float32x4_t result = vmulq_f32(x, c1); + result = vmlaq_f32(result, x3, c3); + result = vmlaq_f32(result, x5, c5); + result = vmlaq_f32(result, x7, c7); + + vst1q_f32(&output[i], result); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = sinf(input[i]); + } + }); + break; + } + } +} + +void scalar_op_f16(const __fp16* input, __fp16* output, size_t num_elements, + float scalar_value, ScalarOpType op_type) { + const __fp16 scalar_f16 = static_cast<__fp16>(scalar_value); + + switch (op_type) { + case ScalarOpType::ADD: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_BASIC, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + const float16x8_t scalar_vec = vdupq_n_f16(scalar_f16); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float16x8_t in_vec = vld1q_f16(&input[i]); + float16x8_t result = vaddq_f16(in_vec, scalar_vec); + vst1q_f16(&output[i], result); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = input[i] + scalar_f16; + } + }); + break; + } + + case ScalarOpType::SUBTRACT: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_BASIC, 
+ [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + const float16x8_t scalar_vec = vdupq_n_f16(scalar_f16); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float16x8_t in_vec = vld1q_f16(&input[i]); + float16x8_t result = vsubq_f16(in_vec, scalar_vec); + vst1q_f16(&output[i], result); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = input[i] - scalar_f16; + } + }); + break; + } + + case ScalarOpType::MULTIPLY: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_BASIC, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + const float16x8_t scalar_vec = vdupq_n_f16(scalar_f16); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float16x8_t in_vec = vld1q_f16(&input[i]); + float16x8_t result = vmulq_f16(in_vec, scalar_vec); + vst1q_f16(&output[i], result); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = input[i] * scalar_f16; + } + }); + break; + } + + case ScalarOpType::DIVIDE: { + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_BASIC, + [&](size_t start_idx, size_t end_idx) { + constexpr size_t SIMD_WIDTH = 8; + const size_t vectorized_end = + start_idx + ((end_idx - start_idx) / SIMD_WIDTH) * SIMD_WIDTH; + const float16x8_t scalar_vec = vdupq_n_f16(scalar_f16); + + for (size_t i = start_idx; i < vectorized_end; i += SIMD_WIDTH) { + float16x8_t in_vec = vld1q_f16(&input[i]); + float16x8_t result = vdivq_f16(in_vec, scalar_vec); + vst1q_f16(&output[i], result); + } + + for (size_t i = vectorized_end; i < end_idx; ++i) { + output[i] = input[i] / scalar_f16; + } + }); + break; + } + + case ScalarOpType::EXP: + case ScalarOpType::SQRT: + case 
ScalarOpType::COS: + case ScalarOpType::SIN: { + // For complex operations, convert to float32, compute, then convert back + aphrodite::mobile::parallel_for( + num_elements, aphrodite::mobile::Thresholds::SCALAR_EXPENSIVE, + [&](size_t start_idx, size_t end_idx) { + for (size_t i = start_idx; i < end_idx; ++i) { + float val = static_cast(input[i]); + float result; + switch (op_type) { + case ScalarOpType::EXP: + result = std::exp(val); + break; + case ScalarOpType::SQRT: + result = std::sqrt(val); + break; + case ScalarOpType::COS: + result = std::cos(val); + break; + case ScalarOpType::SIN: + result = std::sin(val); + break; + default: + result = val; + break; + } + output[i] = static_cast<__fp16>(result); + } + }); + break; + } + } +} +} // namespace aphrodite::mobile diff --git a/aphrodite_kernels/csrc/cpu/mobile/threading.hpp b/aphrodite_kernels/csrc/cpu/mobile/threading.hpp new file mode 100644 index 0000000000..9bfa42b7ca --- /dev/null +++ b/aphrodite_kernels/csrc/cpu/mobile/threading.hpp @@ -0,0 +1,355 @@ +#ifndef APHRODITE_MOBILE_THREADING_HPP +#define APHRODITE_MOBILE_THREADING_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +constexpr size_t NEON_VECTOR_SIZE = 16; + +inline int8_t clamp_to_int8(float value) { + int32_t clamped = static_cast(roundf(value)); + return static_cast(std::max(-128, std::min(127, clamped))); +} + +inline int8_t clamp_to_int8(int32_t value) { + return static_cast(std::max(-128, std::min(127, value))); +} + +#if defined(__ARM_FEATURE_DOTPROD) +inline int32x4_t accum_i8mm(int32x4_t acc, int8x16_t a, int8x16_t b) { + return vdotq_s32(acc, a, b); +} +#else +inline int32x4_t accum_i8mm(int32x4_t acc, int8x16_t a, int8x16_t b) { + int16x8_t prod_low = vmull_s8(vget_low_s8(a), vget_low_s8(b)); + int32x4_t acc_high = vpaddlq_s16(vmull_s8(vget_high_s8(a), vget_high_s8(b))); + return vaddq_s32(vaddq_s32(acc, 
vpaddlq_s16(prod_low)), acc_high); +} +#endif + +inline float16x8_t accum_f16_dot(float16x8_t acc, float16x8_t a_low, + float16x8_t a_high, float16x8_t b_low, + float16x8_t b_high) { + acc = vfmaq_f16(acc, a_low, b_low); + return vfmaq_f16(acc, a_high, b_high); +} + +inline float32x4_t accum_f32_dot(float32x4_t acc, float32x4_t a_low, + float32x4_t a_high, float32x4_t b_low, + float32x4_t b_high) { + acc = vfmaq_f32(acc, a_low, b_low); + return vfmaq_f32(acc, a_high, b_high); +} + +namespace aphrodite::mobile { + +class ThreadPool { + private: + std::vector workers; + std::queue> tasks; + std::mutex queue_mutex; + std::condition_variable condition; + std::atomic stop{false}; + std::atomic active_workers{0}; + std::condition_variable finish_condition; + + void worker_thread() { + while (true) { + std::function task; + { + std::unique_lock lock(queue_mutex); + condition.wait(lock, [this] { return stop || !tasks.empty(); }); + + if (stop && tasks.empty()) return; + + task = std::move(tasks.front()); + tasks.pop(); + active_workers++; + } + + task(); + + active_workers--; + finish_condition.notify_all(); + } + } + + public: + explicit ThreadPool( + size_t num_threads = std::thread::hardware_concurrency()) { + workers.reserve(num_threads); + for (size_t i = 0; i < num_threads; ++i) { + workers.emplace_back(&ThreadPool::worker_thread, this); + } + } + + ~ThreadPool() { + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (auto& worker : workers) { + worker.join(); + } + } + + template + auto enqueue(F&& f) -> std::future { + using return_type = decltype(f()); + + auto task = + std::make_shared>(std::forward(f)); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + if (stop) throw std::runtime_error("enqueue on stopped ThreadPool"); + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; + } + + void wait_all() { + std::unique_lock lock(queue_mutex); + 
finish_condition.wait( + lock, [this] { return tasks.empty() && active_workers == 0; }); + } + + size_t num_workers() const { return workers.size(); } +}; + +inline ThreadPool& get_thread_pool() { + static ThreadPool pool; + return pool; +} + +inline size_t get_optimal_thread_count(size_t total_work, + size_t min_work_per_thread) { + if (total_work < min_work_per_thread) return 1; + size_t pool_size = get_thread_pool().num_workers(); + return std::min(pool_size, std::max(static_cast(1), + total_work / min_work_per_thread)); +} + +struct Thresholds { +#if defined(__ANDROID__) + static constexpr size_t ELEMENT_WISE = 5000; + static constexpr size_t AXIS_REDUCE = 1000; + static constexpr size_t ALL_REDUCE = 10000; + static constexpr size_t SCALAR_BASIC = 30000; + static constexpr size_t SCALAR_EXPENSIVE = 10000; + static constexpr size_t ATTENTION = 512; + static constexpr size_t GEMM_TILED = 20000; + static constexpr size_t GEMM_SMALL = 64 * 64 * 64; + static constexpr size_t GEMM_MEDIUM = 256 * 256 * 256; + static constexpr size_t GEMM_TILE_M = 64; + static constexpr size_t GEMM_TILE_N = 64; + static constexpr size_t GEMM_TILE_M_SMALL = 32; + static constexpr size_t GEMM_TILE_N_SMALL = 32; +#else // iOS + static constexpr size_t ELEMENT_WISE = 5000; + static constexpr size_t AXIS_REDUCE = 1000; + static constexpr size_t ALL_REDUCE = 10000; + static constexpr size_t SCALAR_BASIC = 5000; + static constexpr size_t SCALAR_EXPENSIVE = 2500; + static constexpr size_t ATTENTION = 4; + static constexpr size_t GEMM_TILED = 4; + static constexpr size_t GEMM_SMALL = 64 * 64 * 64; + static constexpr size_t GEMM_MEDIUM = 256 * 256 * 256; + static constexpr size_t GEMM_TILE_M = 64; + static constexpr size_t GEMM_TILE_N = 64; + static constexpr size_t GEMM_TILE_M_SMALL = 32; + static constexpr size_t GEMM_TILE_N_SMALL = 32; +#endif + static constexpr size_t L2_CACHE_SIZE = 256 * 1024; +}; + +class TaskHandle { + private: + std::vector> futures_; + bool auto_wait_; + + public: + 
TaskHandle(bool auto_wait = true) : auto_wait_(auto_wait) {} + + ~TaskHandle() { + if (auto_wait_) { + wait(); + } + } + + TaskHandle(TaskHandle&&) = default; + TaskHandle& operator=(TaskHandle&&) = default; + TaskHandle(const TaskHandle&) = delete; + TaskHandle& operator=(const TaskHandle&) = delete; + + void add_future(std::future&& f) { futures_.push_back(std::move(f)); } + + void wait() { + for (auto& f : futures_) { + if (f.valid()) { + f.wait(); + } + } + futures_.clear(); + } + + bool is_ready() const { + for (const auto& f : futures_) { + if (f.valid() && + f.wait_for(std::chrono::seconds(0)) != std::future_status::ready) { + return false; + } + } + return true; + } + + size_t task_count() const { return futures_.size(); } +}; + +template +TaskHandle parallel_for(size_t total_work, size_t threshold, WorkFunc work_func, + bool wait = true) { + const size_t num_threads = get_optimal_thread_count(total_work, threshold); + TaskHandle handle(!wait); + + if (num_threads == 1) { + if (wait) { + work_func(0, total_work); + return handle; + } + auto& pool = get_thread_pool(); + handle.add_future( + pool.enqueue([work_func, total_work]() { work_func(0, total_work); })); + return handle; + } + + auto& pool = get_thread_pool(); + const size_t work_per_thread = total_work / num_threads; + + for (size_t t = 0; t < num_threads; ++t) { + handle.add_future(pool.enqueue( + [work_func, t, num_threads, work_per_thread, total_work]() { + const size_t start_idx = t * work_per_thread; + const size_t end_idx = + (t == num_threads - 1) ? 
total_work : (t + 1) * work_per_thread; + work_func(start_idx, end_idx); + })); + } + + if (wait) { + handle.wait(); + } + return handle; +} + +template +void parallel_for_2d(size_t outer_size, size_t inner_size, size_t threshold, + WorkFunc work_func) { + const size_t total_work = outer_size * inner_size; + parallel_for(total_work, threshold, [&](size_t start_idx, size_t end_idx) { + for (size_t work_idx = start_idx; work_idx < end_idx; ++work_idx) { + const size_t outer = work_idx / inner_size; + const size_t inner = work_idx % inner_size; + work_func(outer, inner); + } + }); +} + +template +ResultType parallel_reduce(size_t total_work, size_t threshold, + WorkFunc work_func, ResultType init_value, + CombineFunc combine_func) { + const size_t num_threads = get_optimal_thread_count(total_work, threshold); + + if (num_threads == 1) { + return work_func(0, total_work); + } + + auto& pool = get_thread_pool(); + std::vector> futures; + std::vector partial_results(num_threads, init_value); + const size_t work_per_thread = total_work / num_threads; + + for (size_t t = 0; t < num_threads; ++t) { + futures.push_back(pool.enqueue([work_func, t, num_threads, work_per_thread, + total_work]() -> ResultType { + const size_t start_idx = t * work_per_thread; + const size_t end_idx = + (t == num_threads - 1) ? 
total_work : (t + 1) * work_per_thread; + return work_func(start_idx, end_idx); + })); + } + + ResultType result = init_value; + for (auto& future : futures) { + result = combine_func(result, future.get()); + } + return result; +} + +inline size_t compute_gemm_parallelism(size_t M, size_t K, size_t N, + size_t element_size) { + size_t total_ops = M * K * N; + + if (total_ops < Thresholds::GEMM_SMALL) return 1; + + if (total_ops < Thresholds::GEMM_MEDIUM) { + return std::min(static_cast(2), get_thread_pool().num_workers()); + } + + size_t bytes_accessed = (M * K + K * N + M * N) * element_size; + size_t cache_tiles = (bytes_accessed + Thresholds::L2_CACHE_SIZE - 1) / + Thresholds::L2_CACHE_SIZE; + + size_t compute_threads = + std::sqrt(static_cast(total_ops) / Thresholds::GEMM_SMALL); + size_t memory_threads = cache_tiles; + + size_t optimal = std::min(compute_threads, memory_threads); + return std::min(optimal, get_thread_pool().num_workers()); +} + +template +void parallel_for_2d_tiled(size_t rows, size_t cols, size_t tile_rows, + size_t tile_cols, WorkFunc work_func) { + size_t num_row_tiles = (rows + tile_rows - 1) / tile_rows; + size_t num_col_tiles = (cols + tile_cols - 1) / tile_cols; + size_t total_tiles = num_row_tiles * num_col_tiles; + + parallel_for(total_tiles, Thresholds::GEMM_TILED, + [=](size_t start_tile, size_t end_tile) { + for (size_t tile_idx = start_tile; tile_idx < end_tile; + ++tile_idx) { + size_t tile_row = tile_idx / num_col_tiles; + size_t tile_col = tile_idx % num_col_tiles; + + size_t row_start = tile_row * tile_rows; + size_t row_end = std::min(row_start + tile_rows, rows); + size_t col_start = tile_col * tile_cols; + size_t col_end = std::min(col_start + tile_cols, cols); + + work_func(row_start, row_end, col_start, col_end); + } + }); +} +} // namespace aphrodite::mobile + +#endif // APHRODITE_MOBILE_THREADING_HPP \ No newline at end of file diff --git a/aphrodite_kernels/csrc/cpu/mobile/torch_bindings.cpp 
b/aphrodite_kernels/csrc/cpu/mobile/torch_bindings.cpp new file mode 100644 index 0000000000..1cd784d8c4 --- /dev/null +++ b/aphrodite_kernels/csrc/cpu/mobile/torch_bindings.cpp @@ -0,0 +1,214 @@ +#include "kernels.h" +#include + +#if defined(__aarch64__) || defined(__ARM_NEON) + +namespace { + +template +T* get_data_ptr(torch::Tensor& tensor) { + return tensor.data_ptr(); +} + +template +const T* get_data_ptr(const torch::Tensor& tensor) { + return tensor.data_ptr(); +} + +} // namespace + +// Quantization wrappers +void mobile_int8_to_fp32(const torch::Tensor& src, torch::Tensor& dst, + double scale) { + TORCH_CHECK(src.scalar_type() == torch::kInt8, "Input must be int8"); + TORCH_CHECK(dst.scalar_type() == torch::kFloat32, "Output must be float32"); + TORCH_CHECK(src.numel() == dst.numel(), + "Input and output must have same size"); + + aphrodite::mobile::int8_to_fp32(get_data_ptr(src), + get_data_ptr(dst), src.numel(), + static_cast(scale)); +} + +void mobile_fp32_to_int8(const torch::Tensor& src, torch::Tensor& dst, + double scale) { + TORCH_CHECK(src.scalar_type() == torch::kFloat32, "Input must be float32"); + TORCH_CHECK(dst.scalar_type() == torch::kInt8, "Output must be int8"); + TORCH_CHECK(src.numel() == dst.numel(), + "Input and output must have same size"); + + aphrodite::mobile::fp32_to_int8(get_data_ptr(src), + get_data_ptr(dst), src.numel(), + static_cast(scale)); +} + +void mobile_dynamic_quantize_fp32_to_int8(const torch::Tensor& src, + torch::Tensor& dst, + torch::Tensor& computed_scale) { + TORCH_CHECK(src.scalar_type() == torch::kFloat32, "Input must be float32"); + TORCH_CHECK(dst.scalar_type() == torch::kInt8, "Output must be int8"); + TORCH_CHECK(computed_scale.scalar_type() == torch::kFloat32, + "Scale must be float32"); + TORCH_CHECK(src.numel() == dst.numel(), + "Input and output must have same size"); + + float scale; + aphrodite::mobile::dynamic_quantize_fp32_to_int8( + get_data_ptr(src), get_data_ptr(dst), src.numel(), &scale); + 
*get_data_ptr(computed_scale) = scale; +} + +// Matrix multiplication wrappers +void mobile_matmul_int8(const torch::Tensor& a, + const torch::Tensor& b_transposed, torch::Tensor& c, + double a_scale, double b_scale, double c_scale) { + TORCH_CHECK(a.scalar_type() == torch::kInt8, "Matrix A must be int8"); + TORCH_CHECK(b_transposed.scalar_type() == torch::kInt8, + "Matrix B must be int8"); + TORCH_CHECK(c.scalar_type() == torch::kInt8, "Matrix C must be int8"); + + int64_t M = a.size(0); + int64_t K = a.size(1); + int64_t N = b_transposed.size(0); + + TORCH_CHECK(b_transposed.size(1) == K, "Matrix dimensions must match"); + TORCH_CHECK(c.size(0) == M && c.size(1) == N, + "Output matrix dimensions must match"); + + aphrodite::mobile::matmul_int8( + get_data_ptr(a), get_data_ptr(b_transposed), + get_data_ptr(c), M, K, N, static_cast(a_scale), + static_cast(b_scale), static_cast(c_scale)); +} + +void mobile_matmul_f16(const torch::Tensor& a, + const torch::Tensor& b_transposed, torch::Tensor& c) { + TORCH_CHECK(a.scalar_type() == torch::kFloat16, "Matrix A must be float16"); + TORCH_CHECK(b_transposed.scalar_type() == torch::kFloat16, + "Matrix B must be float16"); + TORCH_CHECK(c.scalar_type() == torch::kFloat16, "Matrix C must be float16"); + + int64_t M = a.size(0); + int64_t K = a.size(1); + int64_t N = b_transposed.size(0); + + TORCH_CHECK(b_transposed.size(1) == K, "Matrix dimensions must match"); + TORCH_CHECK(c.size(0) == M && c.size(1) == N, + "Output matrix dimensions must match"); + + aphrodite::mobile::matmul_f16(get_data_ptr<__fp16>(a), + get_data_ptr<__fp16>(b_transposed), + get_data_ptr<__fp16>(c), M, K, N); +} + +void mobile_matmul_f32(const torch::Tensor& a, + const torch::Tensor& b_transposed, torch::Tensor& c) { + TORCH_CHECK(a.scalar_type() == torch::kFloat32, "Matrix A must be float32"); + TORCH_CHECK(b_transposed.scalar_type() == torch::kFloat32, + "Matrix B must be float32"); + TORCH_CHECK(c.scalar_type() == torch::kFloat32, "Matrix C must be 
float32"); + + int64_t M = a.size(0); + int64_t K = a.size(1); + int64_t N = b_transposed.size(0); + + TORCH_CHECK(b_transposed.size(1) == K, "Matrix dimensions must match"); + TORCH_CHECK(c.size(0) == M && c.size(1) == N, + "Output matrix dimensions must match"); + + aphrodite::mobile::matmul_f32(get_data_ptr(a), + get_data_ptr(b_transposed), + get_data_ptr(c), M, K, N); +} + +// Activation wrappers +void mobile_silu_f32(const torch::Tensor& input, torch::Tensor& output) { + TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32"); + TORCH_CHECK(output.scalar_type() == torch::kFloat32, + "Output must be float32"); + TORCH_CHECK(input.numel() == output.numel(), + "Input and output must have same size"); + + aphrodite::mobile::silu_f32(get_data_ptr(input), + get_data_ptr(output), input.numel()); +} + +void mobile_gelu_f32(const torch::Tensor& input, torch::Tensor& output) { + TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32"); + TORCH_CHECK(output.scalar_type() == torch::kFloat32, + "Output must be float32"); + TORCH_CHECK(input.numel() == output.numel(), + "Input and output must have same size"); + + aphrodite::mobile::gelu_f32(get_data_ptr(input), + get_data_ptr(output), input.numel()); +} + +// RMS Norm wrappers +void mobile_rms_norm_f32(const torch::Tensor& input, + const torch::Tensor& weight, torch::Tensor& output, + double eps) { + TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32"); + TORCH_CHECK(weight.scalar_type() == torch::kFloat32, + "Weight must be float32"); + TORCH_CHECK(output.scalar_type() == torch::kFloat32, + "Output must be float32"); + + int64_t batch_size = input.size(0); + int64_t dims = input.size(1); + + TORCH_CHECK(weight.size(0) == dims, "Weight size must match input dimension"); + TORCH_CHECK(output.size(0) == batch_size && output.size(1) == dims, + "Output shape must match input"); + + aphrodite::mobile::rms_norm_f32( + get_data_ptr(input), get_data_ptr(weight), + 
get_data_ptr(output), batch_size, dims, static_cast(eps)); +} + +// Reduction wrappers +int64_t mobile_sum_all_int8(const torch::Tensor& data) { + TORCH_CHECK(data.scalar_type() == torch::kInt8, "Input must be int8"); + return aphrodite::mobile::sum_all_int8(get_data_ptr(data), + data.numel()); +} + +double mobile_mean_all_int8(const torch::Tensor& data) { + TORCH_CHECK(data.scalar_type() == torch::kInt8, "Input must be int8"); + return aphrodite::mobile::mean_all_int8(get_data_ptr(data), + data.numel()); +} + +// Scalar operation wrappers +void mobile_scalar_op_int8(const torch::Tensor& input, torch::Tensor& output, + double scalar_value, int64_t op_type) { + TORCH_CHECK(input.scalar_type() == torch::kInt8, "Input must be int8"); + TORCH_CHECK(output.scalar_type() == torch::kInt8, "Output must be int8"); + TORCH_CHECK(input.numel() == output.numel(), + "Input and output must have same size"); + TORCH_CHECK(op_type >= 0 && op_type <= 7, "Invalid operation type"); + + aphrodite::mobile::ScalarOpType op = + static_cast(op_type); + aphrodite::mobile::scalar_op_int8(get_data_ptr(input), + get_data_ptr(output), input.numel(), + static_cast(scalar_value), op); +} + +void mobile_scalar_op_f32(const torch::Tensor& input, torch::Tensor& output, + double scalar_value, int64_t op_type) { + TORCH_CHECK(input.scalar_type() == torch::kFloat32, "Input must be float32"); + TORCH_CHECK(output.scalar_type() == torch::kFloat32, + "Output must be float32"); + TORCH_CHECK(input.numel() == output.numel(), + "Input and output must have same size"); + TORCH_CHECK(op_type >= 0 && op_type <= 7, "Invalid operation type"); + + aphrodite::mobile::ScalarOpType op = + static_cast(op_type); + aphrodite::mobile::scalar_op_f32(get_data_ptr(input), + get_data_ptr(output), input.numel(), + static_cast(scalar_value), op); +} + +#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/aphrodite_kernels/csrc/cpu/torch_bindings.cpp b/aphrodite_kernels/csrc/cpu/torch_bindings.cpp index 
27726fa419..3eeddaf21f 100644 --- a/aphrodite_kernels/csrc/cpu/torch_bindings.cpp +++ b/aphrodite_kernels/csrc/cpu/torch_bindings.cpp @@ -4,6 +4,10 @@ #include +#if defined(__aarch64__) || defined(__ARM_NEON) + #include "mobile/torch_bindings.cpp" +#endif + std::string init_cpu_threads_env(const std::string& cpu_ids); void release_dnnl_matmul_handler(int64_t handler); @@ -259,6 +263,65 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant); #endif + + // Mobile-optimized kernels (ARM NEON) +#if defined(__aarch64__) || defined(__ARM_NEON) + // Quantization ops + ops.def("mobile_int8_to_fp32(Tensor src, Tensor! dst, float scale) -> ()"); + ops.impl("mobile_int8_to_fp32", torch::kCPU, &mobile_int8_to_fp32); + + ops.def("mobile_fp32_to_int8(Tensor src, Tensor! dst, float scale) -> ()"); + ops.impl("mobile_fp32_to_int8", torch::kCPU, &mobile_fp32_to_int8); + + ops.def( + "mobile_dynamic_quantize_fp32_to_int8(Tensor src, Tensor! dst, Tensor! " + "computed_scale) -> ()"); + ops.impl("mobile_dynamic_quantize_fp32_to_int8", torch::kCPU, + &mobile_dynamic_quantize_fp32_to_int8); + + // Matrix multiplication ops + ops.def( + "mobile_matmul_int8(Tensor a, Tensor b_transposed, Tensor! c, float " + "a_scale, float b_scale, float c_scale) -> ()"); + ops.impl("mobile_matmul_int8", torch::kCPU, &mobile_matmul_int8); + + ops.def("mobile_matmul_f16(Tensor a, Tensor b_transposed, Tensor! c) -> ()"); + ops.impl("mobile_matmul_f16", torch::kCPU, &mobile_matmul_f16); + + ops.def("mobile_matmul_f32(Tensor a, Tensor b_transposed, Tensor! c) -> ()"); + ops.impl("mobile_matmul_f32", torch::kCPU, &mobile_matmul_f32); + + // Activation ops + ops.def("mobile_silu_f32(Tensor input, Tensor! output) -> ()"); + ops.impl("mobile_silu_f32", torch::kCPU, &mobile_silu_f32); + + ops.def("mobile_gelu_f32(Tensor input, Tensor! 
output) -> ()"); + ops.impl("mobile_gelu_f32", torch::kCPU, &mobile_gelu_f32); + + // Normalization ops + ops.def( + "mobile_rms_norm_f32(Tensor input, Tensor weight, Tensor! output, float " + "eps) -> ()"); + ops.impl("mobile_rms_norm_f32", torch::kCPU, &mobile_rms_norm_f32); + + // Reduction ops + ops.def("mobile_sum_all_int8(Tensor data) -> int"); + ops.impl("mobile_sum_all_int8", torch::kCPU, &mobile_sum_all_int8); + + ops.def("mobile_mean_all_int8(Tensor data) -> float"); + ops.impl("mobile_mean_all_int8", torch::kCPU, &mobile_mean_all_int8); + + // Scalar operation ops + ops.def( + "mobile_scalar_op_int8(Tensor input, Tensor! output, float scalar_value, " + "int op_type) -> ()"); + ops.impl("mobile_scalar_op_int8", torch::kCPU, &mobile_scalar_op_int8); + + ops.def( + "mobile_scalar_op_f32(Tensor input, Tensor! output, float scalar_value, " + "int op_type) -> ()"); + ops.impl("mobile_scalar_op_f32", torch::kCPU, &mobile_scalar_op_f32); +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { diff --git a/aphrodite_kernels/csrc/ops.h b/aphrodite_kernels/csrc/ops.h index e62cbd34e5..56014b20d5 100644 --- a/aphrodite_kernels/csrc/ops.h +++ b/aphrodite_kernels/csrc/ops.h @@ -446,4 +446,33 @@ void qr_open_handles(fptr_t _fa, const std::vector& handles); void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half = false); int64_t qr_max_size(); +#endif + +// Mobile-optimized kernels (ARM NEON) +#if defined(__aarch64__) || defined(__ARM_NEON) +void mobile_int8_to_fp32(const torch::Tensor& src, torch::Tensor& dst, + double scale); +void mobile_fp32_to_int8(const torch::Tensor& src, torch::Tensor& dst, + double scale); +void mobile_dynamic_quantize_fp32_to_int8(const torch::Tensor& src, + torch::Tensor& dst, + torch::Tensor& computed_scale); +void mobile_matmul_int8(const torch::Tensor& a, + const torch::Tensor& b_transposed, torch::Tensor& c, + double a_scale, double b_scale, double 
c_scale); +void mobile_matmul_f16(const torch::Tensor& a, + const torch::Tensor& b_transposed, torch::Tensor& c); +void mobile_matmul_f32(const torch::Tensor& a, + const torch::Tensor& b_transposed, torch::Tensor& c); +void mobile_silu_f32(const torch::Tensor& input, torch::Tensor& output); +void mobile_gelu_f32(const torch::Tensor& input, torch::Tensor& output); +void mobile_rms_norm_f32(const torch::Tensor& input, + const torch::Tensor& weight, torch::Tensor& output, + double eps); +int64_t mobile_sum_all_int8(const torch::Tensor& data); +double mobile_mean_all_int8(const torch::Tensor& data); +void mobile_scalar_op_int8(const torch::Tensor& input, torch::Tensor& output, + double scalar_value, int64_t op_type); +void mobile_scalar_op_f32(const torch::Tensor& input, torch::Tensor& output, + double scalar_value, int64_t op_type); #endif \ No newline at end of file