diff --git a/CMakeLists.txt b/CMakeLists.txt index fcd8b2eef13..db0ead08223 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -166,6 +166,7 @@ option(ENABLE_AUTEST_UDS "Setup autest with curl using UDS (default OFF)") option(ENABLE_BENCHMARKS "Build benchmarks (default OFF)") option(EXTERNAL_YAML_CPP "Use external yaml-cpp (default OFF)") option(EXTERNAL_LIBSWOC "Use external libswoc (default OFF)") +option(EXTERNAL_HWY "Use external highway (default OFF)") option(LINK_PLUGINS "Link core libraries to plugins (default OFF)") option(ENABLE_PROBES "Enable ATS SystemTap probes (default OFF)") option(ENABLE_VERIFY_PLUGINS "Enable plugin verification tests (default ON)" ON) @@ -415,6 +416,11 @@ if(EXTERNAL_LIBSWOC) find_package(libswoc REQUIRED) endif() +if(EXTERNAL_HWY) + message(STATUS "Looking for external highway") + find_package(HWY "1.4.0" CONFIG REQUIRED) +endif() + include(Check128BitCas) include(ConfigureTransparentProxy) diff --git a/NOTICE b/NOTICE index 31787fecbf0..f36f46979a5 100644 --- a/NOTICE +++ b/NOTICE @@ -118,3 +118,9 @@ LS-HPACK provides functionality to encode and decode HTTP headers using HPACK compression mechanism specified in RFC 7541. Copyright (c) 2018 - 2023 LiteSpeed Technologies Inc, (MIT License) https://github.com/litespeedtech/ls-hpack.git + +~~ + +Highway is a C++ library that provides portable SIMD/vector intrinsics. +Copyright 2019 Google LLC (Apache License 2.0 or BSD 3-Clause License) +https://github.com/google/highway diff --git a/README.md b/README.md index 46d5eaf7533..b14b0245273 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,7 @@ trafficserver ............. Top src dir ├── lib ................... Third-party libraries │ ├── Catch2 ............ Unit testing framework │ ├── fastlz ............ Fast compression library +│ ├── highway ........... Portable SIMD/vector intrinsics │ ├── ls-hpack .......... HPACK compression for HTTP/2 │ ├── swoc .............. Solid Wall of Code utility library │ ├── systemtap ......... 
SystemTap integration
diff --git a/ci/rat-exclude.txt b/ci/rat-exclude.txt
index 39ad1b11ca8..8ebfa76679f 100644
--- a/ci/rat-exclude.txt
+++ b/ci/rat-exclude.txt
@@ -73,6 +73,7 @@ blib/**
 **/yamlcpp/**
 **/systemtap/**
 **/swoc/**
+**/highway/**
 tests/gold_tests/autest-site/min_cfg
 tests/gold_tests/h2/rules/huge_resp_hdrs.conf
 tools/http_load/**
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index 6e28a50eb15..17771e4ee5f 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -78,3 +78,45 @@ add_library(systemtap::systemtap INTERFACE IMPORTED GLOBAL)
 target_include_directories(systemtap::systemtap INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/systemtap")
 
 add_subdirectory(ls-hpack)
+
+# Currently v1.4.0
+#
+# Upstream: https://github.com/google/highway
+#
+# Source-only, preserves the readme, license, cmake
+# infrastructure and main source; does not include
+# examples, tests or documentation.
+#
+# Note that hwy/tests/list_targets.cc is retained
+# for compiled-in target information
+#
+# CMakeLists.txt slightly modified to fully disable
+# building tests if `HWY_ENABLE_TESTS=OFF`, we intend
+# to upstream this change
+if(NOT EXTERNAL_HWY)
+  message(STATUS "Using internal highway")
+  set(HWY_FORCE_STATIC_LIBS
+      ON
+      CACHE BOOL "Ignore BUILD_SHARED_LIBS" FORCE
+  )
+  set(HWY_ENABLE_INSTALL
+      OFF
+      CACHE BOOL "Install library" FORCE
+  )
+  set(HWY_ENABLE_EXAMPLES
+      OFF
+      CACHE BOOL "Build examples" FORCE
+  )
+  set(HWY_ENABLE_TESTS
+      OFF
+      CACHE BOOL "Enable HWY tests" FORCE
+  )
+  add_subdirectory(highway)
+  add_library(hwy::hwy ALIAS hwy)
+  # hwy_contrib is only defined when HWY_ENABLE_CONTRIB is ON (the vendored
+  # default). A cached HWY_ENABLE_CONTRIB=OFF would otherwise make the ALIAS
+  # below a hard configure error, so only alias the target when it exists.
+  if(TARGET hwy_contrib)
+    add_library(hwy::hwy_contrib ALIAS hwy_contrib)
+  endif()
+endif()
diff --git a/lib/highway/CMakeLists.txt b/lib/highway/CMakeLists.txt
new file mode 100644
index 00000000000..82acaddc30d
--- /dev/null
+++ b/lib/highway/CMakeLists.txt
@@ -0,0 +1,994 @@
+# Copyright 2019 Google LLC
+# Copyright 2024 Arm Limited and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Licensed under the
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.10) + +# Set PIE flags for POSITION_INDEPENDENT_CODE targets, added in 3.14. +if(POLICY CMP0083) + cmake_policy(SET CMP0083 NEW) +endif() + +# Workaround for 3.19 raising error 'IMPORTED_LOCATION not set for imported +# target "GTest::gtest_main"'. +if(POLICY CMP0111) + cmake_policy(SET CMP0111 OLD) +endif() + +# Starting with GCC-13, we want to make sure to remove gnu extension (ie. +# explicit -std=c++17 instead of implicit `gnu++17`) +# Without this cmake property, CMAKE_CXX_EXTENSIONS=OFF was not properly +# considered +if(POLICY CMP0128) + cmake_policy(SET CMP0128 NEW) +endif() + +project(hwy VERSION 1.4.0) # Keep in sync with base.h version +# `hwy` is lowercase to handle find_package() in Config mode: +set(namespace "${PROJECT_NAME}::") + +# Directly define the ABI version from the cmake project() version values: +set(LIBRARY_VERSION "${hwy_VERSION}") +set(LIBRARY_SOVERSION ${hwy_VERSION_MAJOR}) + +set(CMAKE_CXX_EXTENSIONS OFF) + +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") +# Search for Atomics implementation: +find_package(Atomics REQUIRED) + +# Enabled PIE binaries by default if supported. 
+include(CheckPIESupported OPTIONAL RESULT_VARIABLE CHECK_PIE_SUPPORTED) +if(CHECK_PIE_SUPPORTED) + check_pie_supported(LANGUAGES CXX) + if(CMAKE_CXX_LINK_PIE_SUPPORTED) + set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) + endif() +endif() + +include(GNUInstallDirs) + +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE RelWithDebInfo) +endif() + +# The following is only required with GCC < 6.1.0 or CLANG < 16.0 +set(HWY_CMAKE_ARM7 OFF CACHE BOOL "Set copts for Armv7 with NEON (requires vfpv4)?") + +# This must be set on 32-bit x86 with GCC < 13.1, otherwise math_test will be +# skipped. For GCC 13.1+, you can also build with -fexcess-precision=standard. +set(HWY_CMAKE_SSE2 OFF CACHE BOOL "Set SSE2 as baseline for 32-bit x86?") + +# Currently this will compile the entire codebase with `-march=rvgcv1p0`: +set(HWY_CMAKE_RVV ON CACHE BOOL "Set copts for RISCV with RVV?") + +# Unconditionally adding -Werror risks breaking the build when new warnings +# arise due to compiler/platform changes. Enable this in CI/tests. 
+set(HWY_WARNINGS_ARE_ERRORS OFF CACHE BOOL "Add -Werror flag?")
+
+# Experimental support for header-only builds
+set(HWY_CMAKE_HEADER_ONLY OFF CACHE BOOL "Change to header-only?")
+
+set(HWY_ENABLE_CONTRIB ON CACHE BOOL "Include contrib/")
+set(HWY_ENABLE_EXAMPLES ON CACHE BOOL "Build examples")
+set(HWY_ENABLE_INSTALL ON CACHE BOOL "Install library")
+set(HWY_ENABLE_TESTS ON CACHE BOOL "Enable HWY tests")
+
+if (MSVC)
+set(HWY_TEST_STANDALONE ON CACHE BOOL "Disable use of googletest")
+else()
+set(HWY_TEST_STANDALONE OFF CACHE BOOL "Disable use of googletest")
+endif()
+
+if (NOT DEFINED CMAKE_CXX_STANDARD)
+  if ("cxx_std_17" IN_LIST CMAKE_CXX_COMPILE_FEATURES)
+    set(HWY_CXX_STD_TGT_COMPILE_FEATURE cxx_std_17)
+  else()
+    message(WARNING "cxx_std_17 not found in CMAKE_CXX_COMPILE_FEATURES but vqsort requires C++17")
+    set(HWY_CXX_STD_TGT_COMPILE_FEATURE cxx_std_14)
+  endif()
+else()
+  if (CMAKE_CXX_STANDARD GREATER_EQUAL 17 AND CMAKE_CXX_STANDARD LESS 98)
+    set(HWY_CXX_STD_TGT_COMPILE_FEATURE cxx_std_17)
+  else()
+    message(WARNING "CMAKE_CXX_STANDARD < 17 but vqsort requires C++17")
+    set(HWY_CXX_STD_TGT_COMPILE_FEATURE cxx_std_14)
+  endif()
+endif()
+
+include(CheckCXXSourceCompiles)
+check_cxx_source_compiles(
+   "int main() {
+      #if !defined(__EMSCRIPTEN__)
+      static_assert(false, \"__EMSCRIPTEN__ is not defined\");
+      #endif
+      return 0;
+    }"
+  HWY_EMSCRIPTEN
+)
+
+check_cxx_source_compiles(
+   "int main() {
+      #if !defined(__riscv)
+      static_assert(false, \"__riscv is not defined\");
+      #endif
+      return 0;
+    }"
+  HWY_RISCV
+)
+
+if (WIN32)
+  set (ORIG_CMAKE_REQUIRED_LIBRARIES ${CMAKE_REQUIRED_LIBRARIES})
+  set (CMAKE_REQUIRED_LIBRARIES synchronization)
+  check_cxx_source_compiles(
+    "#ifndef NOMINMAX
+     #define NOMINMAX
+     #endif
+
+     #include <windows.h>
+
+     int main() {
+       unsigned val1 = 0u;
+       unsigned val2 = 1u;
+       WaitOnAddress(&val1, &val2, sizeof(unsigned), 1);
+       WakeByAddressAll(&val1);
+       WakeByAddressSingle(&val1);
+       return 0;
+     }"
+    HWY_HAVE_WIN32_SYNCHRONIZATION_LIB)
+  set
(CMAKE_REQUIRED_LIBRARIES ${ORIG_CMAKE_REQUIRED_LIBRARIES}) +else() + set (HWY_HAVE_WIN32_SYNCHRONIZATION_LIB OFF) +endif () + +if (HWY_HAVE_WIN32_SYNCHRONIZATION_LIB OR NOT WIN32) + set (HWY_DISABLE_FUTEX OFF CACHE BOOL "Disable futex for thread_pool") +else() + # Force HWY_DISABLE_FUTEX to ON if compiling for Win32 and + # libsynchronization.a or synchronization.lib is not available + set (HWY_DISABLE_FUTEX ON CACHE BOOL "Disable futex for thread_pool" FORCE) +endif() + +find_package(Threads) + +if (NOT Threads_FOUND AND HWY_ENABLE_CONTRIB) + message(FATAL_ERROR "Threads must be available if HWY_ENABLE_CONTRIB is ON") +endif() + +if(Threads_FOUND) + if(THREADS_HAVE_PTHREAD_ARG) + set(HWY_THREAD_FLAGS "-pthread") + else() + set(HWY_THREAD_FLAGS "") + endif() + + if(CMAKE_THREAD_LIBS_INIT) + set(HWY_THREAD_LIBS ${CMAKE_THREAD_LIBS_INIT}) + else() + set(HWY_THREAD_LIBS "") + endif() +else() + set(HWY_THREAD_FLAGS "") + set(HWY_THREAD_LIBS "") +endif() + +if (HWY_RISCV OR CMAKE_CXX_COMPILER_ARCHITECTURE_ID MATCHES "RISCV32|RISCV64|RISCV128" OR CMAKE_SYSTEM_PROCESSOR MATCHES "riscv32|riscv64|riscv128") + include(CheckCSourceCompiles) + check_c_source_compiles(" + #if __riscv_xlen == 64 + int main() { return 0; } + #else + #error Not RISCV-64 + #endif + " IS_RISCV_XLEN_64) + + check_c_source_compiles(" + #if __riscv_xlen == 32 + int main() { return 0; } + #else + #error Not RISCV-32 + #endif + " IS_RISCV_XLEN_32) + + if(IS_RISCV_XLEN_32) + set(RISCV_XLEN 32) + elseif(IS_RISCV_XLEN_64) + set(RISCV_XLEN 64) + else() + message(WARNING "Unable to determine RISC-V XLEN") + endif() +endif() + +if (HWY_ENABLE_CONTRIB) +# Glob all the traits so we don't need to modify this file when adding +# additional special cases. 
+file(GLOB HWY_CONTRIB_SOURCES "hwy/contrib/sort/vqsort_*.cc") +list(APPEND HWY_CONTRIB_SOURCES + hwy/contrib/bit_pack/bit_pack-inl.h + hwy/contrib/dot/dot-inl.h + hwy/contrib/image/image.cc + hwy/contrib/image/image.h + hwy/contrib/math/fast_math-inl.h + hwy/contrib/math/math-inl.h + hwy/contrib/matvec/matvec-inl.h + hwy/contrib/random/random-inl.h + hwy/contrib/sort/order.h + hwy/contrib/sort/shared-inl.h + hwy/contrib/sort/sorting_networks-inl.h + hwy/contrib/sort/traits-inl.h + hwy/contrib/sort/traits128-inl.h + hwy/contrib/sort/vqsort-inl.h + hwy/contrib/sort/vqsort.cc + hwy/contrib/sort/vqsort.h + hwy/contrib/thread_pool/futex.h + hwy/contrib/thread_pool/spin.h + hwy/contrib/thread_pool/thread_pool.cc + hwy/contrib/thread_pool/thread_pool.h + hwy/contrib/thread_pool/topology.cc + hwy/contrib/thread_pool/topology.h + hwy/contrib/algo/copy-inl.h + hwy/contrib/algo/count-inl.h + hwy/contrib/algo/find-inl.h + hwy/contrib/algo/minmax-inl.h + hwy/contrib/algo/transform-inl.h + hwy/contrib/unroller/unroller-inl.h +) +endif() # HWY_ENABLE_CONTRIB + +set(HWY_SOURCES + hwy/abort.h + hwy/aligned_allocator.h + hwy/auto_tune.h + hwy/base.h + hwy/bit_set.h + hwy/cache_control.h + hwy/detect_compiler_arch.h # private + hwy/detect_targets.h # private + hwy/foreach_target.h + hwy/highway_export.h + hwy/highway.h + hwy/nanobenchmark.h + hwy/ops/arm_neon-inl.h + hwy/ops/arm_sve-inl.h + hwy/ops/emu128-inl.h + hwy/ops/generic_ops-inl.h + hwy/ops/inside-inl.h + hwy/ops/loongarch_lsx-inl.h + hwy/ops/loongarch_lasx-inl.h + hwy/ops/ppc_vsx-inl.h + hwy/ops/rvv-inl.h + hwy/ops/scalar-inl.h + hwy/ops/set_macros-inl.h + hwy/ops/shared-inl.h + hwy/ops/wasm_128-inl.h + hwy/ops/x86_128-inl.h + hwy/ops/x86_256-inl.h + hwy/ops/x86_512-inl.h + hwy/ops/x86_avx3-inl.h + hwy/per_target.h + hwy/print-inl.h + hwy/print.h + hwy/profiler.h + hwy/robust_statistics.h + hwy/targets.h + hwy/timer-inl.h + hwy/timer.h + hwy/x86_cpuid.h +) + +if (NOT HWY_CMAKE_HEADER_ONLY) + list(APPEND HWY_SOURCES + 
hwy/abort.cc + hwy/aligned_allocator.cc + hwy/nanobenchmark.cc + hwy/per_target.cc + hwy/perf_counters.cc + hwy/print.cc + hwy/profiler.cc + hwy/targets.cc + hwy/timer.cc + ) +endif() + +set(HWY_TEST_SOURCES + hwy/tests/hwy_gtest.h + hwy/tests/test_util-inl.h + hwy/tests/test_util.cc + hwy/tests/test_util.h +) + +if (MSVC) + set(HWY_FLAGS + # fix build error C1128 in blockwise*_test & arithmetic_test + /bigobj + + # Warnings + /W4 + # Disable some W4 warnings. Enable them individually after they are cleaned up. + /wd4100 + /wd4127 + /wd4324 + /wd4456 + /wd4701 + /wd4702 + /wd4723 + + # CMake automatically adds exception handling flags. Remove them. + /GR- + /EHs-c- + # Disable exceptions in STL code. + -D_HAS_EXCEPTIONS=0 + ) + + # This adds extra warnings for the clang-cl compiler on Windows. + # This is the same as the sections in the else part. + # These could be refactored. + if (${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + list(APPEND HWY_FLAGS + # These are not included in Wall nor Wextra: + -Wconversion + -Wsign-conversion + -Wvla + -Wnon-virtual-dtor + + -Wfloat-overflow-conversion + -Wfloat-zero-conversion + -Wfor-loop-analysis + -Wgnu-redeclared-enum + -Winfinite-recursion + -Wself-assign + -Wstring-conversion + -Wtautological-overlap-compare + -Wthread-safety-analysis + -Wundefined-func-template + ) + endif() + + if (HWY_WARNINGS_ARE_ERRORS) + list(APPEND HWY_FLAGS /WX) + endif() +else() + set(HWY_FLAGS + # Avoid changing binaries based on the current time and date. 
+ -Wno-builtin-macro-redefined + -D__DATE__="redacted" + -D__TIMESTAMP__="redacted" + -D__TIME__="redacted" + + # Optimizations + -fmerge-all-constants + + # Warnings + -Wall + -Wextra + # These are not included in Wall nor Wextra: + -Wconversion + -Wsign-conversion + -Wvla + -Wnon-virtual-dtor + -Wcast-align # see -Wcast-align=strict on x86 + ) + + if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + list(APPEND HWY_FLAGS + -Wfloat-overflow-conversion + -Wfloat-zero-conversion + -Wfor-loop-analysis + -Wgnu-redeclared-enum + -Winfinite-recursion + -Wself-assign + -Wstring-conversion + -Wtautological-overlap-compare + -Wthread-safety-analysis + -Wundefined-func-template + + -fno-cxx-exceptions + -fno-slp-vectorize + -fno-vectorize + + # Use color in messages + -fdiagnostics-show-option -fcolor-diagnostics + ) + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 6.0) + list(APPEND HWY_FLAGS -Wc++2a-extensions) + endif() + endif() + + if (WIN32) + if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + list(APPEND HWY_FLAGS + -Wno-global-constructors + -Wno-language-extension-token + -Wno-used-but-marked-unused + -Wno-shadow-field-in-constructor + -Wno-unused-member-function + -Wno-unused-template + -Wno-c++98-compat-pedantic + -Wno-used-but-marked-unused + -Wno-zero-as-null-pointer-constant + ) + endif() + + list(APPEND HWY_FLAGS + -Wno-cast-align + -Wno-double-promotion + -Wno-float-equal + -Wno-format-nonliteral + -Wno-shadow + -Wno-sign-conversion + ) + else() + list(APPEND HWY_FLAGS + -fmath-errno + -fno-exceptions + ) + endif() # WIN32 + + # Workaround for excess precision, see #1488. + if (HWY_CMAKE_SSE2) + list(APPEND HWY_FLAGS -msse2 -mfpmath=sse) + endif() + + # Suppress STL iterator warnings. Supported by GCC 4.4.7 and newer, which + # predates the C++14 (and C++17 for VQSort) we require. + if (${CMAKE_CXX_COMPILER_ID} MATCHES "GNU") + list(APPEND HWY_FLAGS -Wno-psabi) + endif() + # Clang supports this flag from 11.0. 
+ if (${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11.0) + list(APPEND HWY_FLAGS -Wno-psabi) + endif() + endif() + + if (HWY_CMAKE_ARM7) + list(APPEND HWY_FLAGS + -march=armv7-a + -mfpu=neon-vfpv4 + -mfloat-abi=hard # must match the toolchain specified as CXX= + -DHWY_HAVE_SCALAR_F16_TYPE=0 # See #2625 + -DHWY_NEON_HAVE_F16C=0 + ) + if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU") + # using GCC + list(APPEND HWY_FLAGS + -mfp16-format=ieee # required for vcvt_f32_f16 + ) + endif() + endif() # HWY_CMAKE_ARM7 + + if(HWY_RISCV) + # Add the gcv compiler flag so that RVV is available as a baseline target. + # Without this, Clang 19+ can still use RVV via runtime dispatch. + if(HWY_CMAKE_RVV) + if(RISCV_XLEN EQUAL 64) + list(APPEND HWY_FLAGS -march=rv64gcv1p0) + add_link_options(-march=rv64gcv1p0) + elseif(RISCV_XLEN EQUAL 32) + list(APPEND HWY_FLAGS -march=rv32gcv1p0) + add_link_options(-march=rv32gcv1p0) + endif() + if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + list(APPEND HWY_FLAGS -menable-experimental-extensions) + endif() + endif() + endif() + + if (HWY_WARNINGS_ARE_ERRORS) + list(APPEND HWY_FLAGS -Werror) + endif() + + # Prevent "wasm-ld: error: --shared-memory is disallowed by targets.cc.o + # because it was not compiled with 'atomics' or 'bulk-memory' features." 
+  if (HWY_EMSCRIPTEN)
+    list(APPEND HWY_FLAGS -matomics)
+  endif()
+
+endif() # !MSVC
+
+if (HWY_DISABLE_FUTEX)
+  list(APPEND HWY_FLAGS -DHWY_DISABLE_FUTEX)
+endif()
+
+if (HWY_CMAKE_HEADER_ONLY)
+  list(APPEND HWY_FLAGS -DHWY_HEADER_ONLY)
+endif()
+
+include(CheckIncludeFile)
+check_include_file(sys/auxv.h HAVE_SYS_AUXV_H)
+check_include_file(asm/hwcap.h HAVE_ASM_HWCAP_H)
+
+# By default prefer STATIC build (legacy behavior)
+option(BUILD_SHARED_LIBS "Build shared libraries" OFF)
+option(HWY_FORCE_STATIC_LIBS "Ignore BUILD_SHARED_LIBS" OFF)
+# only expose shared/static options to advanced users:
+mark_as_advanced(BUILD_SHARED_LIBS)
+mark_as_advanced(HWY_FORCE_STATIC_LIBS)
+# Define visibility settings globally:
+set(CMAKE_CXX_VISIBILITY_PRESET hidden)
+set(CMAKE_VISIBILITY_INLINES_HIDDEN 1)
+
+# Copy-cat "add_library" logic + add override.
+set(HWY_LIBRARY_TYPE "SHARED")
+if (NOT BUILD_SHARED_LIBS OR HWY_FORCE_STATIC_LIBS)
+  set(HWY_LIBRARY_TYPE "STATIC")
+endif()
+
+# This preprocessor define will drive the build, also used in the *.pc files:
+if("${HWY_LIBRARY_TYPE}" STREQUAL "SHARED")
+  set(DLLEXPORT_TO_DEFINE "HWY_SHARED_DEFINE")
+else()
+  set(DLLEXPORT_TO_DEFINE "HWY_STATIC_DEFINE")
+endif()
+
+add_library(hwy ${HWY_LIBRARY_TYPE} ${HWY_SOURCES})
+if(NOT HAVE_SYS_AUXV_H)
+  target_compile_definitions(hwy PUBLIC TOOLCHAIN_MISS_SYS_AUXV_H)
+endif()
+if(NOT HAVE_ASM_HWCAP_H)
+  target_compile_definitions(hwy PUBLIC TOOLCHAIN_MISS_ASM_HWCAP_H)
+endif()
+target_compile_definitions(hwy PUBLIC "${DLLEXPORT_TO_DEFINE}")
+target_compile_options(hwy PRIVATE ${HWY_FLAGS})
+set_property(TARGET hwy PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_target_properties(hwy PROPERTIES VERSION ${LIBRARY_VERSION} SOVERSION ${LIBRARY_SOVERSION})
+target_include_directories(hwy PUBLIC
+    $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
+    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>)
+target_compile_features(hwy PUBLIC cxx_std_11)
+if (NOT HWY_CXX_STD_TGT_COMPILE_FEATURE STREQUAL "cxx_std_11")
+  target_compile_features(hwy PRIVATE ${HWY_CXX_STD_TGT_COMPILE_FEATURE})
+endif()
+set_target_properties(hwy PROPERTIES
+  LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version)
+# For GCC __atomic_store_8, see #887
+target_link_libraries(hwy PRIVATE ${ATOMICS_LIBRARIES})
+# not supported by MSVC/Clang, safe to skip (we use DLLEXPORT annotations)
+if(UNIX AND NOT APPLE)
+  set_property(TARGET hwy APPEND_STRING PROPERTY
+    LINK_FLAGS " -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version")
+endif()
+
+if (HWY_ENABLE_CONTRIB)
+add_library(hwy_contrib ${HWY_LIBRARY_TYPE} ${HWY_CONTRIB_SOURCES})
+target_link_libraries(hwy_contrib PUBLIC hwy)
+target_compile_options(hwy_contrib PRIVATE ${HWY_FLAGS} ${HWY_THREAD_FLAGS})
+set_property(TARGET hwy_contrib PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_target_properties(hwy_contrib PROPERTIES VERSION ${LIBRARY_VERSION} SOVERSION ${LIBRARY_SOVERSION})
+target_include_directories(hwy_contrib PUBLIC
+    $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
+    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>)
+target_compile_features(hwy_contrib PUBLIC cxx_std_11)
+if (NOT HWY_CXX_STD_TGT_COMPILE_FEATURE STREQUAL "cxx_std_11")
+  target_compile_features(hwy_contrib PRIVATE ${HWY_CXX_STD_TGT_COMPILE_FEATURE})
+endif()
+set_target_properties(hwy_contrib PROPERTIES
+  LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version)
+# For GCC __atomic_store_8, see #887
+target_link_libraries(hwy_contrib PRIVATE ${ATOMICS_LIBRARIES})
+
+# Avoid linker errors if libpthread needs to be linked
+target_link_libraries(hwy_contrib PRIVATE ${HWY_THREAD_LIBS})
+
+if (WIN32 AND NOT MSVC AND NOT HWY_DISABLE_FUTEX)
+target_link_libraries(hwy_contrib PUBLIC synchronization)
+endif()
+
+# not supported by MSVC/Clang, safe to skip (we use DLLEXPORT annotations)
+if(UNIX AND NOT APPLE)
+  set_property(TARGET hwy_contrib APPEND_STRING PROPERTY
+    LINK_FLAGS " -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version")
+endif()
+endif() # HWY_ENABLE_CONTRIB
+
+if (HWY_ENABLE_TESTS)
+add_library(hwy_test ${HWY_LIBRARY_TYPE} ${HWY_TEST_SOURCES})
+target_link_libraries(hwy_test PUBLIC hwy)
+target_compile_options(hwy_test
PRIVATE ${HWY_FLAGS})
+set_property(TARGET hwy_test PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_target_properties(hwy_test PROPERTIES VERSION ${LIBRARY_VERSION} SOVERSION ${LIBRARY_SOVERSION})
+target_include_directories(hwy_test PUBLIC
+    $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
+    $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>)
+target_compile_features(hwy_test PUBLIC cxx_std_11)
+if (NOT HWY_CXX_STD_TGT_COMPILE_FEATURE STREQUAL "cxx_std_11")
+  target_compile_features(hwy_test PRIVATE ${HWY_CXX_STD_TGT_COMPILE_FEATURE})
+endif()
+set_target_properties(hwy_test PROPERTIES
+  LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version)
+# not supported by MSVC/Clang, safe to skip (we use DLLEXPORT annotations)
+if(UNIX AND NOT APPLE)
+  set_property(TARGET hwy_test APPEND_STRING PROPERTY
+    LINK_FLAGS " -Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/hwy/hwy.version")
+endif()
+endif() # HWY_ENABLE_TESTS
+
+if (CMAKE_SOURCE_DIR STREQUAL PROJECT_SOURCE_DIR)
+# -------------------------------------------------------- hwy_list_targets
+# Generate a tool to print the compiled-in targets as defined by the current
+# flags. This tool will print to stderr at build time, after building hwy.
+add_executable(hwy_list_targets hwy/tests/list_targets.cc)
+target_compile_options(hwy_list_targets PRIVATE ${HWY_FLAGS})
+target_compile_features(hwy_list_targets PRIVATE ${HWY_CXX_STD_TGT_COMPILE_FEATURE})
+target_link_libraries(hwy_list_targets PRIVATE hwy)
+target_include_directories(hwy_list_targets PRIVATE
+  $<TARGET_PROPERTY:hwy,INCLUDE_DIRECTORIES>)
+# TARGET_FILE always returns the path to executable
+# Naked target also not always could be run (due to the lack of '.\' prefix)
+# Thus effective command to run should contain the full path
+# and emulator prefix (if any).
+if (NOT CMAKE_CROSSCOMPILING OR CMAKE_CROSSCOMPILING_EMULATOR)
+add_custom_command(TARGET hwy_list_targets POST_BUILD
+    COMMAND ${CMAKE_CROSSCOMPILING_EMULATOR} $<TARGET_FILE:hwy_list_targets> || (exit 0))
+endif()
+endif() # CMAKE_SOURCE_DIR STREQUAL PROJECT_SOURCE_DIR
+
+# --------------------------------------------------------
+# Allow skipping the following sections for projects that do not need them:
+# tests, examples, benchmarks and installation.
+
+# -------------------------------------------------------- install library
+if (HWY_ENABLE_INSTALL)
+
+install(TARGETS hwy EXPORT hwy_targets
+  LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}"
+  ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}"
+  RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}")
+# Install all the headers keeping the relative path to the current directory
+# when installing them.
+foreach (source ${HWY_SOURCES})
+  if ("${source}" MATCHES "\.h$")
+    get_filename_component(dirname "${source}" DIRECTORY)
+    install(FILES "${source}"
+      DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${dirname}")
+  endif()
+endforeach()
+
+if (HWY_ENABLE_CONTRIB)
+install(TARGETS hwy_contrib EXPORT hwy_targets
+  LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}"
+  ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}"
+  RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}")
+# Install all the headers keeping the relative path to the current directory
+# when installing them.
+foreach (source ${HWY_CONTRIB_SOURCES})
+  if ("${source}" MATCHES "\.h$")
+    get_filename_component(dirname "${source}" DIRECTORY)
+    install(FILES "${source}"
+      DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${dirname}")
+  endif()
+endforeach()
+endif() # HWY_ENABLE_CONTRIB
+
+if (HWY_ENABLE_TESTS)
+install(TARGETS hwy_test EXPORT hwy_targets
+  LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}"
+  ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}"
+  RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}")
+# Install all the headers keeping the relative path to the current directory
+# when installing them.
+foreach (source ${HWY_TEST_SOURCES}) + if ("${source}" MATCHES "\.h$") + get_filename_component(dirname "${source}" DIRECTORY) + install(FILES "${source}" + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${dirname}") + endif() +endforeach() +endif() # HWY_ENABLE_TESTS + +# Add a pkg-config file for libhwy and the contrib/test libraries. +set(HWY_LIBRARY_VERSION "${CMAKE_PROJECT_VERSION}") +set(HWY_PC_FILES libhwy.pc) + +if (HWY_DISABLE_FUTEX) + set(HWY_PC_DISABLE_FUTEX_CFLAGS "-DHWY_DISABLE_FUTEX") +else() + set(HWY_PC_DISABLE_FUTEX_CFLAGS "") +endif() + +if (WIN32 AND NOT MSVC AND NOT HWY_DISABLE_FUTEX) + set(HWY_PC_WIN32_SYNCHRONIZATION_LIBS "-lsynchronization") +else() + set(HWY_PC_WIN32_SYNCHRONIZATION_LIBS "") +endif() + +if (HWY_ENABLE_CONTRIB) +list(APPEND HWY_PC_FILES libhwy-contrib.pc) +endif() # HWY_ENABLE_CONTRIB +if (HWY_ENABLE_TESTS) + +if (HWY_TEST_STANDALONE) + set(HWY_PC_HWY_TEST_REQUIRES "") + set(HWY_PC_HWY_TEST_CFLAGS "-DHWY_TEST_STANDALONE=1") +else() + set(HWY_PC_HWY_TEST_REQUIRES "gtest") + set(HWY_PC_HWY_TEST_CFLAGS "") +endif() + +list(APPEND HWY_PC_FILES libhwy-test.pc) +endif() # HWY_ENABLE_TESTS +foreach (pc ${HWY_PC_FILES}) + configure_file("${CMAKE_CURRENT_SOURCE_DIR}/${pc}.in" "${pc}" @ONLY) + install(FILES "${CMAKE_CURRENT_BINARY_DIR}/${pc}" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") +endforeach() + +endif() # HWY_ENABLE_INSTALL +# -------------------------------------------------------- Examples +if (HWY_ENABLE_EXAMPLES) + +# Avoids mismatch between GTest's static CRT and our dynamic. +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + +# Programming exercise with integrated benchmark +add_executable(hwy_benchmark hwy/examples/benchmark.cc) +target_sources(hwy_benchmark PRIVATE + hwy/nanobenchmark.h) +# Try adding one of -DHWY_COMPILE_ONLY_SCALAR, -DHWY_COMPILE_ONLY_EMU128 or +# -DHWY_COMPILE_ONLY_STATIC to observe the difference in targets printed. 
+target_compile_options(hwy_benchmark PRIVATE ${HWY_FLAGS}) +target_compile_features(hwy_benchmark PRIVATE ${HWY_CXX_STD_TGT_COMPILE_FEATURE}) +target_link_libraries(hwy_benchmark PRIVATE hwy) +target_link_libraries(hwy_benchmark PRIVATE ${ATOMICS_LIBRARIES}) +set_target_properties(hwy_benchmark + PROPERTIES RUNTIME_OUTPUT_DIRECTORY "examples/") + +# Profiler demo +if (HWY_ENABLE_CONTRIB) +add_executable(hwy_profiler_example hwy/examples/profiler_example.cc) +target_sources(hwy_profiler_example PRIVATE + hwy/profiler.h) +target_compile_options(hwy_profiler_example PRIVATE ${HWY_FLAGS} ${HWY_THREAD_FLAGS}) +target_compile_features(hwy_profiler_example PRIVATE ${HWY_CXX_STD_TGT_COMPILE_FEATURE}) +target_link_libraries(hwy_profiler_example PRIVATE hwy hwy_contrib) +target_link_libraries(hwy_profiler_example PRIVATE ${ATOMICS_LIBRARIES}) +target_link_libraries(hwy_profiler_example PRIVATE ${HWY_THREAD_LIBS}) +set_target_properties(hwy_profiler_example + PROPERTIES RUNTIME_OUTPUT_DIRECTORY "examples/") +endif() # HWY_ENABLE_CONTRIB + +# Simple array sum example +add_executable(sum_array_simple hwy/examples/sum_array_simple.cc) +target_compile_options(sum_array_simple PRIVATE ${HWY_FLAGS}) +target_compile_features(sum_array_simple PRIVATE ${HWY_CXX_STD_TGT_COMPILE_FEATURE}) +target_link_libraries(sum_array_simple PRIVATE hwy) +target_link_libraries(sum_array_simple PRIVATE ${ATOMICS_LIBRARIES}) +set_target_properties(sum_array_simple + PROPERTIES RUNTIME_OUTPUT_DIRECTORY "examples/") + +# Advanced array sum example +add_executable(sum_array_advanced hwy/examples/sum_array_advanced.cc) +target_compile_options(sum_array_advanced PRIVATE ${HWY_FLAGS}) +target_compile_features(sum_array_advanced PRIVATE ${HWY_CXX_STD_TGT_COMPILE_FEATURE}) +target_link_libraries(sum_array_advanced PRIVATE hwy) +target_link_libraries(sum_array_advanced PRIVATE ${ATOMICS_LIBRARIES}) +set_target_properties(sum_array_advanced + PROPERTIES RUNTIME_OUTPUT_DIRECTORY "examples/") + +endif() # 
HWY_ENABLE_EXAMPLES +# -------------------------------------------------------- Tests + +if(HWY_ENABLE_TESTS) +include(CTest) + +if(BUILD_TESTING) +enable_testing() +include(GoogleTest) + +set(HWY_SYSTEM_GTEST OFF CACHE BOOL "Use pre-installed googletest?") + +if(NOT HWY_TEST_STANDALONE) +if(HWY_SYSTEM_GTEST) +find_package(GTest REQUIRED) +else() +# Download and unpack googletest at configure time +configure_file(CMakeLists.txt.in googletest-download/CMakeLists.txt) +execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download ) +if(result) + message(FATAL_ERROR "CMake step for googletest failed: ${result}") +endif() +execute_process(COMMAND ${CMAKE_COMMAND} --build . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download ) +if(result) + message(FATAL_ERROR "Build step for googletest failed: ${result}") +endif() + +# Prevent overriding the parent project's compiler/linker +# settings on Windows +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + +# Add googletest directly to our build. This defines +# the gtest and gtest_main targets. 
+add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/googletest-src + ${CMAKE_CURRENT_BINARY_DIR}/googletest-build + EXCLUDE_FROM_ALL) +endif() # HWY_SYSTEM_GTEST +endif() # HWY_TEST_STANDALONE + +set(HWY_TEST_FILES + hwy/abort_test.cc + hwy/aligned_allocator_test.cc + hwy/base_test.cc + hwy/bit_set_test.cc + hwy/highway_test.cc + hwy/nanobenchmark_test.cc + hwy/perf_counters_test.cc + hwy/targets_test.cc + hwy/examples/skeleton_test.cc + hwy/tests/arithmetic_test.cc + hwy/tests/bit_permute_test.cc + hwy/tests/blockwise_combine_test.cc + hwy/tests/blockwise_shift_test.cc + hwy/tests/blockwise_test.cc + hwy/tests/cast_test.cc + hwy/tests/combine_test.cc + hwy/tests/compare_128_test.cc + hwy/tests/compare_test.cc + hwy/tests/complex_arithmetic_test.cc + hwy/tests/compress_test.cc + hwy/tests/concat_test.cc + hwy/tests/convert_test.cc + hwy/tests/count_test.cc + hwy/tests/crypto_test.cc + hwy/tests/demote_test.cc + hwy/tests/div_test.cc + hwy/tests/dup128_vec_test.cc + hwy/tests/expand_test.cc + hwy/tests/float_test.cc + hwy/tests/fma_test.cc + hwy/tests/foreach_vec_test.cc + hwy/tests/if_test.cc + hwy/tests/in_range_float_to_int_conv_test.cc + hwy/tests/interleaved_test.cc + hwy/tests/logical_test.cc + hwy/tests/mask_combine_test.cc + hwy/tests/mask_convert_test.cc + hwy/tests/mask_mem_test.cc + hwy/tests/mask_set_test.cc + hwy/tests/mask_slide_test.cc + hwy/tests/mask_test.cc + hwy/tests/masked_arithmetic_test.cc + hwy/tests/masked_compare_test.cc + hwy/tests/masked_minmax_test.cc + hwy/tests/memory_test.cc + hwy/tests/minmax_magnitude_test.cc + hwy/tests/minmax_number_test.cc + hwy/tests/minmax_test.cc + hwy/tests/minmax128_test.cc + hwy/tests/mul_by_pow2_test.cc + hwy/tests/mul_pairwise_test.cc + hwy/tests/mul_test.cc + hwy/tests/neg_test.cc + hwy/tests/reduction_test.cc + hwy/tests/resize_test.cc + hwy/tests/reverse_test.cc + hwy/tests/rotate_test.cc + hwy/tests/saturated_test.cc + hwy/tests/shift_test.cc + hwy/tests/shuffle4_test.cc + hwy/tests/sign_test.cc + 
hwy/tests/slide_up_down_test.cc + hwy/tests/sums_abs_diff_test.cc + hwy/tests/swizzle_block_test.cc + hwy/tests/swizzle_test.cc + hwy/tests/table_test.cc + hwy/tests/test_util_test.cc + hwy/tests/truncate_test.cc + hwy/tests/tuple_test.cc + hwy/tests/widen_mul_test.cc +) + +set(HWY_TEST_LIBS hwy hwy_test) + +if (HWY_ENABLE_CONTRIB) +list(APPEND HWY_TEST_LIBS hwy_contrib) + +list(APPEND HWY_TEST_FILES + hwy/auto_tune_test.cc + hwy/contrib/algo/copy_test.cc + hwy/contrib/algo/count_value_test.cc + hwy/contrib/algo/find_test.cc + hwy/contrib/algo/minmax_value_test.cc + hwy/contrib/algo/transform_test.cc + hwy/contrib/bit_pack/bit_pack_test.cc + hwy/contrib/dot/dot_test.cc + hwy/contrib/matvec/matvec_test.cc + hwy/contrib/image/image_test.cc + # Disabled due to SIGILL in clang7 debug build during gtest discovery phase, + # not reproducible locally. Still tested via bazel build. + hwy/contrib/math/math_test.cc + hwy/contrib/math/math_hyper_test.cc + hwy/contrib/math/math_tan_test.cc + hwy/contrib/math/math_trig_test.cc + hwy/contrib/random/random_test.cc + hwy/contrib/sort/bench_sort.cc + hwy/contrib/sort/sort_test.cc + hwy/contrib/sort/sort_unit_test.cc + hwy/contrib/thread_pool/spin_test.cc + hwy/contrib/thread_pool/thread_pool_test.cc + hwy/contrib/thread_pool/topology_test.cc + hwy/contrib/unroller/unroller_test.cc +) +endif() # HWY_ENABLE_CONTRIB + +if(HWY_TEST_STANDALONE) + set(HWY_GTEST_LIBS "") +else() + if(HWY_SYSTEM_GTEST) + if (CMAKE_VERSION VERSION_LESS 3.20) + set(HWY_GTEST_LIBS GTest::GTest GTest::Main) + else() + set(HWY_GTEST_LIBS GTest::gtest GTest::gtest_main) + endif() + else() + set(HWY_GTEST_LIBS gtest gtest_main) + endif() +endif() # HWY_TEST_STANDALONE + +file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tests) +foreach (TESTFILE IN LISTS HWY_TEST_FILES) + # The TESTNAME is the name without the extension or directory. 
+ get_filename_component(TESTNAME ${TESTFILE} NAME_WE) + add_executable(${TESTNAME} ${TESTFILE}) + target_compile_options(${TESTNAME} PRIVATE ${HWY_FLAGS} ${HWY_THREAD_FLAGS}) + # Test all targets, not just the best/baseline. This changes the default + # policy to all-attainable; note that setting -DHWY_COMPILE_* directly can + # cause compile errors because only one may be set, and other CMakeLists.txt + # that include us may set them. + target_compile_options(${TESTNAME} PRIVATE -DHWY_IS_TEST=1) + if(HWY_TEST_STANDALONE) + target_compile_options(${TESTNAME} PRIVATE -DHWY_TEST_STANDALONE=1) + endif() + target_compile_features(${TESTNAME} PRIVATE ${HWY_CXX_STD_TGT_COMPILE_FEATURE}) + + target_link_libraries(${TESTNAME} PRIVATE ${HWY_TEST_LIBS} ${HWY_GTEST_LIBS}) + # For GCC __atomic_store_8, see #887 + target_link_libraries(${TESTNAME} PRIVATE ${ATOMICS_LIBRARIES}) + + # Avoid linker errors if libpthread needs to be linked + target_link_libraries(${TESTNAME} PRIVATE ${HWY_THREAD_LIBS}) + + # Output test targets in the test directory. + set_target_properties(${TESTNAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "tests") + + if (HWY_EMSCRIPTEN) + set_target_properties(${TESTNAME} PROPERTIES LINK_FLAGS "-s SINGLE_FILE=1") + endif() + + if(${CMAKE_VERSION} VERSION_LESS "3.10.3") + gtest_discover_tests(${TESTNAME} TIMEOUT 60) + else () + gtest_discover_tests(${TESTNAME} DISCOVERY_TIMEOUT 60) + endif () +endforeach () + +# The skeleton test uses the skeleton library code. 
+target_sources(skeleton_test PRIVATE hwy/examples/skeleton.cc) +target_compile_definitions(skeleton_test PRIVATE hwy_EXPORTS) + +endif() # BUILD_TESTING +endif() # HWY_ENABLE_TESTS + +if (HWY_ENABLE_INSTALL) + # write hwy-config file to handle `Config` mode + include(CMakePackageConfigHelpers) + write_basic_package_version_file("${CMAKE_CURRENT_BINARY_DIR}/hwy-config-version.cmake" COMPATIBILITY SameMajorVersion) + install(FILES "${CMAKE_CURRENT_BINARY_DIR}/hwy-config-version.cmake" DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/hwy") + install(EXPORT hwy_targets NAMESPACE "${namespace}" FILE hwy-config.cmake DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/hwy") +endif() diff --git a/lib/highway/LICENSE b/lib/highway/LICENSE new file mode 100644 index 00000000000..1af4f15ca70 --- /dev/null +++ b/lib/highway/LICENSE @@ -0,0 +1,371 @@ +This project is primarily dual-licensed under your choice of either the Apache +License 2.0 or the BSD 3-Clause License. + +The following files are licensed under different terms: +* hwy/contrib/random/random-inl.h: CC0 1.0 Universal + +The full texts of all applicable licenses are included below, separated by +'---'. + +-------------------------------------------------------------------------------- +Apache License 2.0 +-------------------------------------------------------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +-------------------------------------------------------------------------------- +BSD 3-Clause License +-------------------------------------------------------------------------------- + +Copyright (c) The Highway Project Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +-------------------------------------------------------------------------------- +CC0 1.0 Universal +-------------------------------------------------------------------------------- + +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. + +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. 
+These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. 
To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. 
In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. 
Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. \ No newline at end of file diff --git a/lib/highway/README.md b/lib/highway/README.md new file mode 100644 index 00000000000..ba78a9de854 --- /dev/null +++ b/lib/highway/README.md @@ -0,0 +1,521 @@ +# Efficient and performance-portable vector software + +[//]: # (placeholder, do not remove) + +Highway is a C++ library that provides portable SIMD/vector intrinsics. + +[Documentation](https://google.github.io/highway/en/master/) + +Previously licensed under Apache 2, now dual-licensed as Apache 2 / BSD-3. + +## Why + +We are passionate about high-performance software. We see major untapped +potential in CPUs (servers, mobile, desktops). Highway is for engineers who want +to reliably and economically push the boundaries of what is possible in +software. + +## How + +CPUs provide SIMD/vector instructions that apply the same operation to multiple +data items. This can reduce energy usage e.g. *fivefold* because fewer +instructions are executed. We also often see *5-10x* speedups. + +Highway makes SIMD/vector programming practical and workable according to these +guiding principles: + +**Does what you expect**: Highway is a C++ library with carefully-chosen +functions that map well to CPU instructions without extensive compiler +transformations. The resulting code is more predictable and robust to code +changes/compiler updates than autovectorization. 
+ +**Works on widely-used platforms**: Highway supports seven architectures; the +same application code can target various instruction sets, including those with +'scalable' vectors (size unknown at compile time). Highway only requires C++17 +(language features, not necessarily the library) and supports four families of +compilers. If you want to use Highway on other platforms, please raise an issue. + +**Flexible to deploy**: Applications using Highway can run on heterogeneous +clouds or client devices, choosing the best available instruction set at +runtime. Alternatively, developers may choose to target a single instruction set +without any runtime overhead. In both cases, the application code is the same +except for swapping `HWY_STATIC_DISPATCH` with `HWY_DYNAMIC_DISPATCH` plus one +line of code. See also @kfjahnke's +[introduction to dispatching](https://github.com/kfjahnke/zimt/blob/main/examples/multi_isa_example/multi_simd_isa.md). + +**Suitable for a variety of domains**: Highway provides an extensive set of +operations, used for image processing (floating-point), compression, video +analysis, linear algebra, cryptography, sorting and random generation. We +recognise that new use-cases may require additional ops and are happy to add +them where it makes sense (e.g. no performance cliffs on some architectures). If +you would like to discuss, please file an issue. + +**Rewards data-parallel design**: Highway provides tools such as Gather, +MaskedLoad, and FixedTag to enable speedups for legacy data structures. However, +the biggest gains are unlocked by designing algorithms and data structures for +scalable vectors. Helpful techniques include batching, structure-of-array +layouts, and aligned/padded allocations. 
+ +We recommend these resources for getting started: + +- [SIMD programming with Highway talk](https://www.youtube.com/watch?v=R57biOOhnJM) +- [SIMD for C++ Developers](http://const.me/articles/simd/simd.pdf) +- [Algorithms for Modern Hardware](https://en.algorithmica.org/hpc/) +- [Optimizing software in C++](https://agner.org/optimize/optimizing_cpp.pdf) +- [Improving performance with SIMD intrinsics in three use cases](https://stackoverflow.blog/2020/07/08/improving-performance-with-simd-intrinsics-in-three-use-cases/) + +## Examples + +Online demos using Compiler Explorer: + +- [multiple targets with dynamic dispatch](https://gcc.godbolt.org/z/KM3ben7ET) + (more complicated, but flexible and uses best available SIMD) +- [single target using -m flags](https://gcc.godbolt.org/z/rGnjMevKG) + (simpler, but requires/only uses the instruction set enabled by compiler + flags) + +We observe that Highway is referenced in the following open source projects, +found via sourcegraph.com. Most are GitHub repositories. If you would like to +add your project or link to it directly, feel free to raise an issue or contact +us via the below email. 
+ +* Audio: [Zimtohrli perceptual metric](https://github.com/google/zimtohrli) +* Browsers: Chromium (+Vivaldi), Firefox (+floorp / foxhound / librewolf / + Waterfox) +* Computational biology: [RNA analysis](https://github.com/bnprks/BPCells), + [long-sequence preprocessing](https://github.com/OpenGene/fastplong) +* Computer graphics: ghostty-org/ghostty, + [Sparse voxel renderer](https://github.com/rools/voxl), + [tgfx 2D Graphics library](https://github.com/Tencent/tgfx) +* Cryptography: google/distributed_point_functions, google/shell-encryption +* Data structures: bkille/BitLib +* Image codecs: eustas/2im, + [Grok JPEG 2000](https://github.com/GrokImageCompression/grok), + [JPEG XL](https://github.com/libjxl/libjxl), + [JPEGenc](https://github.com/osamu620/JPEGenc), + [Jpegli](https://github.com/google/jpegli), + [libaom](https://aomedia.googlesource.com/aom/), + [OpenHTJ2K](https://github.com/osamu620/OpenHTJ2K) +* Image processing: awxkee/aire, cloudinary/ssimulacra2, + [libvips](https://github.com/libvips/libvips), m-ab-s/media-autobuild_suite, +* Image viewers: AlienCowEatCake/ImageViewer, diffractor/diffractor, + [Lux panorama/image viewer](https://bitbucket.org/kfj/pv/), + mirillis/jpegxl-wic +* Information retrieval: + [iresearch database index](https://github.com/iresearch-toolkit/iresearch), + michaeljclark/zvec, + [nebula interactive analytics / OLAP](https://github.com/varchar-io/nebula), + [`ScaNN` Scalable Nearest Neighbors](https://github.com/google-research/google-research/tree/7a269cb2ce0ae1db591fe11b62cbc0be7d72532a/scann), +* Machine learning: array2d/deepx, + [gemma.cpp](https://github.com/google/gemma.cpp), Tensorflow, Numpy, + zpye/SimpleInfer +* Programming languages: + [AOT-compiled python](https://github.com/exaloop/codon), oven-sh/bun, V8/V8, + yinqiwen/rapidudf +* Robotics: + [MIT Model-Based Design and Verification](https://github.com/RobotLocomotion/drake) +* Vector search: 1yefuwang1/vectorlite, vespa-engine/vespa + +Other + +* 
[Evaluation of C++ SIMD Libraries](https://www.mnm-team.org/pub/Fopras/rock23/): + "Highway excelled with a strong performance across multiple SIMD extensions + [..]. Thus, Highway may currently be the most suitable SIMD library for many + software projects." +* [zimt](https://github.com/kfjahnke/zimt): C++11 template library to process n-dimensional arrays with multi-threaded SIMD code +* [vectorized Quicksort](https://github.com/google/highway/tree/master/hwy/contrib/sort) ([paper](https://arxiv.org/abs/2205.05982)) + +If you'd like to get Highway, in addition to cloning from this GitHub repository +or using it as a Git submodule, you can also find it in the following package +managers or repositories: + +* alpinelinux +* conan-io +* conda-forge +* DragonFlyBSD +* fd00/yacp +* freebsd +* getsolus/packages +* ghostbsd +* microsoft/vcpkg +* MidnightBSD +* MSYS2 +* NetBSD +* openSUSE +* opnsense +* Xilinx/Vitis_Libraries +* xmake-io/xmake-repo + +See also the list at https://repology.org/project/highway-simd-library/versions +. + +## Current status + +### Targets + +Highway supports 27 targets, listed in alphabetical order of platform: + +- Any: `EMU128`, `SCALAR`; +- Armv7+: `NEON_WITHOUT_AES`, `NEON`, `NEON_BF16`, `SVE`, `SVE2`, `SVE_256`, + `SVE2_128`; +- IBM Z: `Z14`, `Z15`; +- LoongArch: `LSX`, `LASX`; +- POWER: `PPC8` (v2.07), `PPC9` (v3.0), `PPC10` (v3.1B, not yet supported due + to compiler bugs, see #1207; also requires QEMU 7.2); +- RISC-V: `RVV` (1.0); +- WebAssembly: `WASM`, `WASM_EMU256` (a 2x unrolled version of wasm128, + enabled if `HWY_WANT_WASM2` is defined. This will remain supported until it + is potentially superseded by a future version of WASM.); +- x86: + - `SSE2` + - `SSSE3` (~Intel Core) + - `SSE4` (~Nehalem, also includes AES + CLMUL).
+ - `AVX2` (~Haswell, also includes BMI2 + F16 + FMA) + - `AVX3` (~Skylake, AVX-512F/BW/CD/DQ/VL) + - `AVX3_DL` (~Icelake, includes `BitAlg` + `CLMUL` + `GFNI` + `VAES` + + `VBMI` + `VBMI2` + `VNNI` + `VPOPCNT`), + - `AVX3_ZEN4` (AVX3_DL plus BF16, optimized for AMD Zen4; requires opt-in + by defining `HWY_WANT_AVX3_ZEN4` if compiling for static dispatch, but + enabled by default for runtime dispatch), + - `AVX3_SPR` (~Sapphire Rapids, includes AVX-512FP16) + - `AVX10_2` (~Diamond Rapids) + +Our policy is that unless otherwise specified, targets will remain supported as +long as they can be (cross-)compiled with currently supported Clang or GCC, and +tested using QEMU. If the target can be compiled with LLVM trunk and tested +using our version of QEMU without extra flags, then it is eligible for inclusion +in our continuous testing infrastructure. Otherwise, the target will be manually +tested before releases with selected versions/configurations of Clang and GCC. + +SVE was initially tested using farm_sve (see acknowledgments). + +### Versioning + +Highway releases aim to follow the semver.org system (MAJOR.MINOR.PATCH), +incrementing MINOR after backward-compatible additions and PATCH after +backward-compatible fixes. We recommend using releases (rather than the Git tip) +because they are tested more extensively, see below. + +The current version 1.0 signals an increased focus on backwards compatibility. +Applications using documented functionality will remain compatible with future +updates that have the same major version number. + +### Testing + +Continuous integration tests build with a recent version of Clang (running on +native x86, or QEMU for RISC-V and Arm) and MSVC 2019 (v19.28, running on native +x86). + +Before releases, we also test on x86 with Clang and GCC, and Armv7/8 via GCC +cross-compile. See the [testing process](g3doc/release_testing_process.md) for +details. 
+ +### Related modules + +The `contrib` directory contains SIMD-related utilities: an image class with +aligned rows, a math library (16 functions already implemented, mostly +trigonometry), and functions for computing dot products and sorting. + +### Other libraries + +If you only require x86 support, you may also use Agner Fog's +[VCL vector class library](https://github.com/vectorclass). It includes many +functions including a complete math library. + +If you have existing code using x86/NEON intrinsics, you may be interested in +[SIMDe](https://github.com/simd-everywhere/simde), which emulates those +intrinsics using other platforms' intrinsics or autovectorization. + +[xSIMD](https://github.com/xtensor-stack/xsimd) is a header only C++ library. +It supports Arm, Power, RISC-V, WebAssembly and x86 targets. It has a high level +interface, but fewer supported operations. + +[NumKong](https://github.com/ashvardanian/NumKong) is a SIMD accelerated math C +library focused on operations such as dot products and mixed precision matrix +multiplications. It can be used from C++, Go, Python, Rust, Swift and +WebAssembly. Accelerated operations are available on ARM, LoongArch, Power, +RISC-V and x86. + +## Installation + +This project uses CMake to generate and build. In a Debian-based system you can +install it via: + +```bash +sudo apt install cmake +``` + +Highway's unit tests use [googletest](https://github.com/google/googletest). +By default, Highway's CMake downloads this dependency at configuration time. +You can avoid this by setting the `HWY_SYSTEM_GTEST` CMake variable to ON and +installing gtest separately: + +```bash +sudo apt install libgtest-dev +``` + +Alternatively, you can define `HWY_TEST_STANDALONE=1` and remove all occurrences +of `gtest_main` in each BUILD file, then tests avoid the dependency on GUnit. + +Running cross-compiled tests requires support from the OS, which on Debian is +provided by the `qemu-user-binfmt` package.
+ +To build Highway as a shared or static library (depending on BUILD_SHARED_LIBS), +the standard CMake workflow can be used: + +```bash +mkdir -p build && cd build +cmake .. +make -j && make test +``` + +Or you can run `run_tests.sh` (`run_tests.bat` on Windows). + +Bazel is also supported for building, but it is not as widely used/tested. + +When building for Armv7, a limitation of current compilers requires you to add +`-DHWY_CMAKE_ARM7:BOOL=ON` to the CMake command line; see #834 and #1032. We +understand that work is underway to remove this limitation. + +To benefit from Armv8/v9 vusdot and vusdotq instructions, you can add "+i8mm" to +the -march compiler flag, assuming the target CPU(s) support that. + +Building on 32-bit x86 is not officially supported, and AVX2/3 are disabled by +default there. Note that johnplatts has successfully built and run the Highway +tests on 32-bit x86, including AVX2/3, on GCC 7/8 and Clang 8/11/12. On Ubuntu +22.04, Clang 11 and 12, but not later versions, require extra compiler flags +`-m32 -isystem /usr/i686-linux-gnu/include`. Clang 10 and earlier require the +above plus `-isystem /usr/i686-linux-gnu/include/c++/12/i686-linux-gnu`. See +[#1279](https://github.com/google/highway/issues/1279). + +## Building highway - Using vcpkg + +highway is now available in [vcpkg](https://github.com/Microsoft/vcpkg) + +```bash +vcpkg install highway +``` + +The highway port in vcpkg is kept up to date by Microsoft team members and community contributors. If the version is out of date, please [create an issue or pull request](https://github.com/Microsoft/vcpkg) on the vcpkg repository. + +## Quick start + +You can use the `benchmark` inside examples/ as a starting point. + +A [quick-reference page](g3doc/quick_reference.md) briefly lists all operations +and their parameters, and the [instruction_matrix](g3doc/instruction_matrix.pdf) +indicates the number of instructions per operation. 
+ +The [FAQ](g3doc/faq.md) answers questions about portability, API design and +where to find more information. + +We recommend using full SIMD vectors whenever possible for maximum performance +portability. To obtain them, pass a `ScalableTag<float>` (or equivalently +`HWY_FULL(float)`) tag to functions such as `Zero/Set/Load`. There are two +alternatives for use-cases requiring an upper bound on the lanes: + +- For up to `N` lanes, specify `CappedTag<T, N>` or the equivalent + `HWY_CAPPED(T, N)`. The actual number of lanes will be `N` rounded down to + the nearest power of two, such as 4 if `N` is 5, or 8 if `N` is 8. This is + useful for data structures such as a narrow matrix. A loop is still required + because vectors may actually have fewer than `N` lanes. + +- For exactly a power of two `N` lanes, specify `FixedTag<T, N>`. The largest + supported `N` depends on the target, but is guaranteed to be at least + `16/sizeof(T)`. + +Due to ADL restrictions, user code calling Highway ops must either: + +* Reside inside `namespace hwy { namespace HWY_NAMESPACE {`; or +* prefix each op with an alias such as `namespace hn = hwy::HWY_NAMESPACE; + hn::Add()`; or +* add using-declarations for each op used: `using hwy::HWY_NAMESPACE::Add;`. + +Additionally, each function that calls Highway ops (such as `Load`) must either +be prefixed with `HWY_ATTR`, OR reside between `HWY_BEFORE_NAMESPACE()` and +`HWY_AFTER_NAMESPACE()`. Lambda functions currently require `HWY_ATTR` before +their opening brace. + +Do not use namespace-scope nor `static` initializers for SIMD vectors because +this can cause SIGILL when using runtime dispatch and the compiler chooses an +initializer compiled for a target not supported by the current CPU. Instead, +constants initialized via `Set` should generally be local (const) variables. + +The entry points into code using Highway differ slightly depending on whether +they use static or dynamic dispatch.
In both cases, we recommend that the +top-level function receives one or more pointers to arrays, rather than +target-specific vector types. + +* For static dispatch, `HWY_TARGET` will be the best available target among + `HWY_BASELINE_TARGETS`, i.e. those allowed for use by the compiler (see + [quick-reference](g3doc/quick_reference.md)). Functions inside + `HWY_NAMESPACE` can be called using `HWY_STATIC_DISPATCH(func)(args)` within + the same module they are defined in. You can call the function from other + modules by wrapping it in a regular function and declaring the regular + function in a header. + +* For dynamic dispatch, a table of function pointers is generated via the + `HWY_EXPORT` macro that is used by `HWY_DYNAMIC_DISPATCH(func)(args)` to + call the best function pointer for the current CPU's supported targets. A + module is automatically compiled for each target in `HWY_TARGETS` (see + [quick-reference](g3doc/quick_reference.md)) if `HWY_TARGET_INCLUDE` is + defined and `foreach_target.h` is included. Note that the first invocation + of `HWY_DYNAMIC_DISPATCH`, or each call to the pointer returned by the first + invocation of `HWY_DYNAMIC_POINTER`, involves some CPU detection overhead. + You can prevent this by calling the following before any invocation of + `HWY_DYNAMIC_*`: `hwy::GetChosenTarget().Update(hwy::SupportedTargets());`. + +See also a separate +[introduction to dynamic dispatch](https://github.com/kfjahnke/zimt/blob/multi_isa/examples/multi_isa_example/multi_simd_isa.md) +by @kfjahnke. + +When using dynamic dispatch, `foreach_target.h` is included from translation +units (.cc files), not headers. 
Headers containing vector code shared between +several translation units require a special include guard, for example the +following taken from `examples/skeleton-inl.h`: + +``` +#if defined(HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#undef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#else +#define HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_ +#endif + +#include "hwy/highway.h" +// Your vector code +#endif +``` + +By convention, we name such headers `-inl.h` because their contents (often +function templates) are usually inlined. + +## Compiler flags + +Applications should be compiled with optimizations enabled. Without inlining +SIMD code may slow down by factors of 10 to 100. For clang and GCC, `-O2` is +generally sufficient. + +For MSVC, we recommend compiling with `/Gv` to allow non-inlined functions to +pass vector arguments in registers. If intending to use the AVX2 target together +with half-width vectors (e.g. for `PromoteTo`), it is also important to compile +with `/arch:AVX2`. This seems to be the only way to reliably generate +VEX-encoded SSE instructions on MSVC. Sometimes MSVC generates VEX-encoded SSE +instructions, if they are mixed with AVX, but not always, see +[DevCom-10618264](https://developercommunity.visualstudio.com/t/10618264). +Otherwise, mixing VEX-encoded AVX2 instructions and non-VEX SSE may cause severe +performance degradation. Unfortunately, with `/arch:AVX2` option, the resulting +binary will then require AVX2. Note that no such flag is needed for clang and +GCC because they support target-specific attributes, which we use to ensure +proper VEX code generation for AVX2 targets. + +## Strip-mining loops + +When vectorizing a loop, an important question is whether and how to deal with +a number of iterations ('trip count', denoted `count`) that does not evenly +divide the vector size `N = Lanes(d)`. For example, it may be necessary to avoid +writing past the end of an array. 
+ +In this section, let `T` denote the element type and `d = ScalableTag<T>`. +Assume the loop body is given as a function `template <bool partial, class D> +void LoopBody(D d, size_t index, size_t max_n)`. + +"Strip-mining" is a technique for vectorizing a loop by transforming it into an +outer loop and inner loop, such that the number of iterations in the inner loop +matches the vector width. Then, the inner loop is replaced with vector +operations. + +Highway offers several strategies for loop vectorization: + +* Ensure all inputs/outputs are padded. Then the (outer) loop is simply + + ``` + for (size_t i = 0; i < count; i += N) LoopBody<false>(d, i, 0); + ``` + Here, the template parameter and third function argument are not needed. + + This is the preferred option, unless `N` is in the thousands and vector + operations are pipelined with long latencies. This was the case for + supercomputers in the 90s, but nowadays ALUs are cheap and we see most + implementations split vectors into 1, 2 or 4 parts, so there is little cost + to processing entire vectors even if we do not need all their lanes. Indeed + this avoids the (potentially large) cost of predication or partial + loads/stores on older targets, and does not duplicate code. + +* Process whole vectors and include previously processed elements + in the last vector: + ``` + for (size_t i = 0; i < count; i += N) LoopBody<false>(d, HWY_MIN(i, count - N), 0); + ``` + + This is the second preferred option provided that `count >= N` + and `LoopBody` is idempotent. Some elements might be processed twice, but + a single code path and full vectorization is usually worth it. Even if + `count < N`, it usually makes sense to pad inputs/outputs up to `N`. + +* Use the `Transform*` functions in hwy/contrib/algo/transform-inl.h.
This + takes care of the loop and remainder handling and you simply define a + generic lambda function (C++14) or functor which receives the current vector + from the input/output array, plus optionally vectors from up to two extra + input arrays, and returns the value to write to the input/output array. + + Here is an example implementing the BLAS function SAXPY (`alpha * x + y`): + + ``` + Transform1(d, x, n, y, [](auto d, const auto v, const auto v1) HWY_ATTR { + return MulAdd(Set(d, alpha), v, v1); + }); + ``` + +* Process whole vectors as above, followed by a scalar loop: + + ``` + size_t i = 0; + for (; i + N <= count; i += N) LoopBody<false>(d, i, 0); + for (; i < count; ++i) LoopBody<false>(CappedTag<T, 1>(), i, 0); + ``` + The template parameter and third function argument are again not needed. + + This avoids duplicating code, and is reasonable if `count` is large. + If `count` is small, the second loop may be slower than the next option. + +* Process whole vectors as above, followed by a single call to a modified + `LoopBody` with masking: + + ``` + size_t i = 0; + for (; i + N <= count; i += N) { + LoopBody<false>(d, i, 0); + } + if (i < count) { + LoopBody<true>(d, i, count - i); + } + ``` + Now the template parameter and third function argument can be used inside + `LoopBody` to non-atomically 'blend' the first `num_remaining` lanes of `v` + with the previous contents of memory at subsequent locations: + `BlendedStore(v, FirstN(d, num_remaining), d, pointer);`. Similarly, + `MaskedLoad(FirstN(d, num_remaining), d, pointer)` loads the first + `num_remaining` elements and returns zero in other lanes. + + This is a good default when it is infeasible to ensure vectors are padded, + but is only safe `#if !HWY_MEM_OPS_MIGHT_FAULT`! + In contrast to the scalar loop, only a single final iteration is needed. + The increased code size from two loop bodies is expected to be worthwhile + because it avoids the cost of masking in all but the final iteration.
+ +## Additional resources + +* [Highway introduction (slides)](g3doc/highway_intro.pdf) +* [Overview of instructions per operation on different architectures](g3doc/instruction_matrix.pdf) +* [Design philosophy and comparison](g3doc/design_philosophy.md) +* [Implementation details](g3doc/impl_details.md) + +## Acknowledgments + +We have used [farm-sve](https://gitlab.inria.fr/bramas/farm-sve) by Berenger +Bramas; it has proved useful for checking the SVE port on an x86 development +machine. + +This is not an officially supported Google product. +Contact: janwas@google.com diff --git a/lib/highway/cmake/FindAtomics.cmake b/lib/highway/cmake/FindAtomics.cmake new file mode 100644 index 00000000000..e866b73fac3 --- /dev/null +++ b/lib/highway/cmake/FindAtomics.cmake @@ -0,0 +1,56 @@ +# Original issue: +# * https://gitlab.kitware.com/cmake/cmake/-/issues/23021#note_1098733 +# +# For reference: +# * https://gcc.gnu.org/wiki/Atomic/GCCMM +# +# riscv64 specific: +# * https://lists.debian.org/debian-riscv/2022/01/msg00009.html +# +# ATOMICS_FOUND - system has c++ atomics +# ATOMICS_LIBRARIES - libraries needed to use c++ atomics + +include(CheckCXXSourceCompiles) + +# RISC-V only has 32-bit and 64-bit atomic instructions. GCC is supposed +# to convert smaller atomics to those larger ones via masking and +# shifting like LLVM, but it’s a known bug that it does not. This means +# anything that wants to use atomics on 1-byte or 2-byte types needs +# -latomic, but not 4-byte or 8-byte (though it does no harm). 
+set(atomic_code + " + #include + #include + std::atomic n8 (0); // riscv64 + std::atomic n64 (0); // armel, mipsel, powerpc + int main() { + ++n8; + ++n64; + return 0; + }") + +# https://gitlab.kitware.com/cmake/cmake/-/issues/24063 +set(CMAKE_CXX_STANDARD 11) +check_cxx_source_compiles("${atomic_code}" ATOMICS_LOCK_FREE_INSTRUCTIONS) + +if(ATOMICS_LOCK_FREE_INSTRUCTIONS) + set(ATOMICS_FOUND TRUE) + set(ATOMICS_LIBRARIES) +else() + set(CMAKE_REQUIRED_LIBRARIES "-latomic") + check_cxx_source_compiles("${atomic_code}" ATOMICS_IN_LIBRARY) + set(CMAKE_REQUIRED_LIBRARIES) + if(ATOMICS_IN_LIBRARY) + set(ATOMICS_LIBRARY atomic) + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(Atomics DEFAULT_MSG ATOMICS_LIBRARY) + set(ATOMICS_LIBRARIES ${ATOMICS_LIBRARY}) + unset(ATOMICS_LIBRARY) + else() + if(Atomics_FIND_REQUIRED) + message(FATAL_ERROR "Neither lock free instructions nor -latomic found.") + endif() + endif() +endif() +unset(atomic_code) +unset(CMAKE_CXX_STANDARD) diff --git a/lib/highway/hwy/abort.cc b/lib/highway/hwy/abort.cc new file mode 100644 index 00000000000..a67819bbd35 --- /dev/null +++ b/lib/highway/hwy/abort.cc @@ -0,0 +1,117 @@ +// Copyright 2019 Google LLC +// Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause + +#include "hwy/abort.h" + +#include +#include +#include + +#include +#include + +#include "hwy/base.h" + +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN +#include "sanitizer/common_interface_defs.h" // __sanitizer_print_stack_trace +#endif + +namespace hwy { + +namespace { + +std::atomic& AtomicWarnFunc() { + static std::atomic func; + return func; +} + +std::atomic& AtomicAbortFunc() { + static std::atomic func; + return func; +} + +std::string GetBaseName(std::string const& file_name) { + auto last_slash = file_name.find_last_of("/\\"); + return file_name.substr(last_slash + 1); +} + +} // namespace + +// Returning a reference is 
unfortunately incompatible with `std::atomic`, which +// is required to safely implement `SetWarnFunc`. As a workaround, we store a +// copy here, update it when called, and return a reference to the copy. This +// has the added benefit of protecting the actual pointer from modification. +HWY_DLLEXPORT WarnFunc& GetWarnFunc() { + static WarnFunc func; + func = AtomicWarnFunc().load(); + return func; +} + +HWY_DLLEXPORT AbortFunc& GetAbortFunc() { + static AbortFunc func; + func = AtomicAbortFunc().load(); + return func; +} + +HWY_DLLEXPORT WarnFunc SetWarnFunc(WarnFunc func) { + return AtomicWarnFunc().exchange(func); +} + +HWY_DLLEXPORT AbortFunc SetAbortFunc(AbortFunc func) { + return AtomicAbortFunc().exchange(func); +} + +HWY_DLLEXPORT void HWY_FORMAT(3, 4) + Warn(const char* file, int line, const char* format, ...) { + char buf[800]; + va_list args; + va_start(args, format); + vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + + WarnFunc handler = AtomicWarnFunc().load(); + if (handler != nullptr) { + handler(file, line, buf); + } else { + fprintf(stderr, "Warn at %s:%d: %s\n", GetBaseName(file).data(), line, buf); + } +} + +HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) + Abort(const char* file, int line, const char* format, ...) { + char buf[800]; + va_list args; + va_start(args, format); + vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + + AbortFunc handler = AtomicAbortFunc().load(); + if (handler != nullptr) { + handler(file, line, buf); + } else { + fprintf(stderr, "Abort at %s:%d: %s\n", GetBaseName(file).data(), line, + buf); + } + +// If compiled with any sanitizer, they can also print a stack trace. +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN + __sanitizer_print_stack_trace(); +#endif // HWY_IS_* + fflush(stderr); + +// Now terminate the program: +#if HWY_ARCH_RISCV + exit(1); // trap/abort just freeze Spike. 
+#elif HWY_IS_DEBUG_BUILD && !HWY_COMPILER_MSVC && !HWY_ARCH_ARM + // Facilitates breaking into a debugger, but don't use this in non-debug + // builds because it looks like "illegal instruction", which is misleading. + // Also does not work on Arm. + __builtin_trap(); +#else + abort(); // Compile error without this due to HWY_NORETURN. +#endif +} + +} // namespace hwy diff --git a/lib/highway/hwy/abort.h b/lib/highway/hwy/abort.h new file mode 100644 index 00000000000..931e9780504 --- /dev/null +++ b/lib/highway/hwy/abort.h @@ -0,0 +1,11 @@ +// Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef HIGHWAY_HWY_ABORT_H_ +#define HIGHWAY_HWY_ABORT_H_ + +// Empty header for compatibility. +// All Abort/Warn functionalities are in base.h. + +#endif // HIGHWAY_HWY_ABORT_H_ diff --git a/lib/highway/hwy/aligned_allocator.cc b/lib/highway/hwy/aligned_allocator.cc new file mode 100644 index 00000000000..e857b2288fb --- /dev/null +++ b/lib/highway/hwy/aligned_allocator.cc @@ -0,0 +1,156 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/aligned_allocator.h" + +#include +#include +#include // malloc + +#include +#include + +#include "hwy/base.h" + +namespace hwy { +namespace { + +#if HWY_ARCH_RISCV && defined(__riscv_v_intrinsic) && \ + __riscv_v_intrinsic >= 11000 +// Not actually an upper bound on the size, but this value prevents crossing a +// 4K boundary (relevant on Andes). +constexpr size_t kAlignment = HWY_MAX(HWY_ALIGNMENT, 4096); +#else +constexpr size_t kAlignment = HWY_ALIGNMENT; +#endif + +#if HWY_ARCH_X86 +// On x86, aliasing can only occur at multiples of 2K. To reduce the chance of +// allocations being equal mod 2K, we round up to kAlias and add a cyclic +// offset which is a multiple of kAlignment. Rounding up to only 1K decreases +// the number of alias-free allocations, but also wastes less memory. +constexpr size_t kAlias = HWY_MAX(kAlignment, 1024); +#else +constexpr size_t kAlias = kAlignment; +#endif + +#pragma pack(push, 1) +struct AllocationHeader { + void* allocated; + size_t payload_size; +}; +#pragma pack(pop) + +// Returns a 'random' (cyclical) offset for AllocateAlignedBytes. 
+size_t NextAlignedOffset() { + static std::atomic next{0}; + static_assert(kAlias % kAlignment == 0, "kAlias must be a multiple"); + constexpr size_t kGroups = kAlias / kAlignment; + const size_t group = next.fetch_add(1, std::memory_order_relaxed) % kGroups; + const size_t offset = kAlignment * group; + HWY_DASSERT((offset % kAlignment == 0) && offset <= kAlias); + return offset; +} + +} // namespace + +HWY_DLLEXPORT void* AllocateAlignedBytes(const size_t payload_size, + AllocPtr alloc_ptr, void* opaque_ptr) { + HWY_ASSERT(payload_size != 0); // likely a bug in caller + if (payload_size >= std::numeric_limits::max() / 2) { + HWY_DASSERT(false && "payload_size too large"); + return nullptr; + } + + size_t offset = NextAlignedOffset(); + + // What: | misalign | unused | AllocationHeader |payload + // Size: |<= kAlias | offset |payload_size + // ^allocated.^aligned.^header............^payload + // The header must immediately precede payload, which must remain aligned. + // To avoid wasting space, the header resides at the end of `unused`, + // which therefore cannot be empty (offset == 0). + if (offset == 0) { + offset = RoundUpTo(sizeof(AllocationHeader), kAlignment); + } + + const size_t allocated_size = kAlias + offset + payload_size; + void* allocated; + if (alloc_ptr == nullptr) { + allocated = malloc(allocated_size); + } else { + allocated = (*alloc_ptr)(opaque_ptr, allocated_size); + } + if (allocated == nullptr) return nullptr; + // Always round up even if already aligned - we already asked for kAlias + // extra bytes and there's no way to give them back. 
+ uintptr_t aligned = reinterpret_cast(allocated) + kAlias; + static_assert((kAlias & (kAlias - 1)) == 0, "kAlias must be a power of 2"); + static_assert(kAlias >= kAlignment, "Cannot align to more than kAlias"); + aligned &= ~(kAlias - 1); + + const uintptr_t payload = aligned + offset; // still aligned + HWY_DASSERT(payload % kAlignment == 0); + + // Stash `allocated` and payload_size inside header for FreeAlignedBytes(). + // The allocated_size can be reconstructed from the payload_size. + AllocationHeader* header = reinterpret_cast(payload) - 1; + HWY_DASSERT(reinterpret_cast(header) >= aligned); + header->allocated = allocated; + header->payload_size = payload_size; + + return HWY_ASSUME_ALIGNED(reinterpret_cast(payload), kAlignment); +} + +HWY_DLLEXPORT void FreeAlignedBytes(const void* aligned_pointer, + FreePtr free_ptr, void* opaque_ptr) { + if (aligned_pointer == nullptr) return; + + const uintptr_t payload = reinterpret_cast(aligned_pointer); + HWY_DASSERT(payload % kAlignment == 0); + const AllocationHeader* header = + reinterpret_cast(payload) - 1; + + if (free_ptr == nullptr) { + free(header->allocated); + } else { + (*free_ptr)(opaque_ptr, header->allocated); + } +} + +// static +HWY_DLLEXPORT void AlignedDeleter::DeleteAlignedArray(void* aligned_pointer, + FreePtr free_ptr, + void* opaque_ptr, + ArrayDeleter deleter) { + if (aligned_pointer == nullptr) return; + + const uintptr_t payload = reinterpret_cast(aligned_pointer); + HWY_DASSERT(payload % kAlignment == 0); + const AllocationHeader* header = + reinterpret_cast(payload) - 1; + + if (deleter) { + (*deleter)(aligned_pointer, header->payload_size); + } + + if (free_ptr == nullptr) { + free(header->allocated); + } else { + (*free_ptr)(opaque_ptr, header->allocated); + } +} + +} // namespace hwy diff --git a/lib/highway/hwy/aligned_allocator.h b/lib/highway/hwy/aligned_allocator.h new file mode 100644 index 00000000000..a682bdbac5e --- /dev/null +++ b/lib/highway/hwy/aligned_allocator.h @@ -0,0 
+1,469 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_ +#define HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_ + +// Memory allocator with support for alignment and offsets. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hwy/base.h" +#include "hwy/per_target.h" + +namespace hwy { + +// Minimum alignment of allocated memory for use in HWY_ASSUME_ALIGNED, which +// requires a literal. To prevent false sharing, this should be at least the +// L1 cache line size, usually 64 bytes. However, Intel's L2 prefetchers may +// access pairs of lines, and M1 L2 and POWER8 lines are also 128 bytes. +#define HWY_ALIGNMENT 128 + +// `align` is in bytes. +template +HWY_API constexpr bool IsAligned(T* ptr, size_t align = HWY_ALIGNMENT) { + return reinterpret_cast(ptr) % align == 0; +} + +// Pointers to functions equivalent to malloc/free with an opaque void* passed +// to them. +using AllocPtr = void* (*)(void* opaque, size_t bytes); +using FreePtr = void (*)(void* opaque, void* memory); + +// Returns null or a pointer to at least `payload_size` (which can be zero) +// bytes of newly allocated memory, aligned to the larger of HWY_ALIGNMENT and +// the vector size. Calls `alloc` with the passed `opaque` pointer to obtain +// memory or malloc() if it is null. 
+HWY_DLLEXPORT void* AllocateAlignedBytes(size_t payload_size, + AllocPtr alloc_ptr = nullptr, + void* opaque_ptr = nullptr); + +// Frees all memory. No effect if `aligned_pointer` == nullptr, otherwise it +// must have been returned from a previous call to `AllocateAlignedBytes`. +// Calls `free_ptr` with the passed `opaque_ptr` pointer to free the memory; if +// `free_ptr` function is null, uses the default free(). +HWY_DLLEXPORT void FreeAlignedBytes(const void* aligned_pointer, + FreePtr free_ptr, void* opaque_ptr); + +// Class that deletes the aligned pointer passed to operator() calling the +// destructor before freeing the pointer. This is equivalent to the +// std::default_delete but for aligned objects. For a similar deleter equivalent +// to free() for aligned memory see AlignedFreer(). +class AlignedDeleter { + public: + AlignedDeleter() : free_(nullptr), opaque_ptr_(nullptr) {} + AlignedDeleter(FreePtr free_ptr, void* opaque_ptr) + : free_(free_ptr), opaque_ptr_(opaque_ptr) {} + + template + void operator()(T* aligned_pointer) const { + return DeleteAlignedArray(aligned_pointer, free_, opaque_ptr_, + TypedArrayDeleter); + } + + private: + template + static void TypedArrayDeleter(void* ptr, size_t size_in_bytes) { + size_t elems = size_in_bytes / sizeof(T); + for (size_t i = 0; i < elems; i++) { + // Explicitly call the destructor on each element. + (static_cast(ptr) + i)->~T(); + } + } + + // Function prototype that calls the destructor for each element in a typed + // array. TypeArrayDeleter would match this prototype. + using ArrayDeleter = void (*)(void* t_ptr, size_t t_size); + + HWY_DLLEXPORT static void DeleteAlignedArray(void* aligned_pointer, + FreePtr free_ptr, + void* opaque_ptr, + ArrayDeleter deleter); + + FreePtr free_; + void* opaque_ptr_; +}; + +// Unique pointer to T with custom aligned deleter. This can be a single +// element U or an array of element if T is a U[]. 
The custom aligned deleter +// will call the destructor on U or each element of a U[] in the array case. +template +using AlignedUniquePtr = std::unique_ptr; + +// Aligned memory equivalent of make_unique using the custom allocators +// alloc/free with the passed `opaque` pointer. This function calls the +// constructor with the passed Args... and calls the destructor of the object +// when the AlignedUniquePtr is destroyed. +template +AlignedUniquePtr MakeUniqueAlignedWithAlloc(AllocPtr alloc, FreePtr free, + void* opaque, Args&&... args) { + T* ptr = static_cast(AllocateAlignedBytes(sizeof(T), alloc, opaque)); + if (HWY_UNLIKELY(ptr == nullptr)) { + return AlignedUniquePtr(nullptr, AlignedDeleter(free, opaque)); + } +#ifdef HWY_EXCEPTIONS_ENABLED + try { + return AlignedUniquePtr(new (ptr) T(std::forward(args)...), + AlignedDeleter(free, opaque)); + } catch (...) { + FreeAlignedBytes(ptr, free, opaque); + throw; + } +#else + return AlignedUniquePtr(new (ptr) T(std::forward(args)...), + AlignedDeleter(free, opaque)); +#endif +} + +// Similar to MakeUniqueAlignedWithAlloc but using the default alloc/free +// functions. +template +AlignedUniquePtr MakeUniqueAligned(Args&&... args) { + T* ptr = static_cast(AllocateAlignedBytes(sizeof(T))); + if (HWY_UNLIKELY(ptr == nullptr)) { + return AlignedUniquePtr(nullptr, AlignedDeleter()); + } +#ifdef HWY_EXCEPTIONS_ENABLED + try { + return AlignedUniquePtr(new (ptr) T(std::forward(args)...), + AlignedDeleter()); + } catch (...) 
{ + FreeAlignedBytes(ptr, nullptr, nullptr); + throw; + } +#else + return AlignedUniquePtr(new (ptr) T(std::forward(args)...), + AlignedDeleter()); +#endif +} + +template +struct AlignedAllocator { + using value_type = T; + + AlignedAllocator() = default; + + template + explicit AlignedAllocator(const AlignedAllocator&) noexcept {} + + template + value_type* allocate(V n) { + static_assert(std::is_integral::value, + "AlignedAllocator only supports integer types"); + static_assert(sizeof(V) <= sizeof(std::size_t), + "V n must be smaller or equal size_t to avoid overflow"); + const size_t count = static_cast(n); + if (HWY_LIKELY(count != 0) && sizeof(value_type) > SIZE_MAX / count) { + HWY_ABORT("AlignedAllocator: allocation size overflow " + "(%zu * %zu exceeds size_t)", count, sizeof(value_type)); + } + return static_cast( + AllocateAlignedBytes(count * sizeof(value_type))); + } + + template + void deallocate(value_type* p, HWY_MAYBE_UNUSED V n) { + return FreeAlignedBytes(p, nullptr, nullptr); + } +}; + +template +constexpr bool operator==(const AlignedAllocator&, + const AlignedAllocator&) noexcept { + return true; +} + +template +constexpr bool operator!=(const AlignedAllocator&, + const AlignedAllocator&) noexcept { + return false; +} + +template +using AlignedVector = std::vector>; + +// Helpers for array allocators (avoids overflow) +namespace detail { + +// Returns x such that 1u << x == n (if n is a power of two). +static inline constexpr size_t ShiftCount(size_t n) { + return (n <= 1) ? 0 : 1 + ShiftCount(n / 2); +} + +template +T* AllocateAlignedItems(size_t items, AllocPtr alloc_ptr, void* opaque_ptr) { + constexpr size_t kSize = sizeof(T); + + constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0; + constexpr size_t kBits = ShiftCount(kSize); + static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug"); + + const size_t bytes = kIsPow2 ? items << kBits : items * kSize; + const size_t check = kIsPow2 ? 
bytes >> kBits : bytes / kSize; + if (check != items) { + return nullptr; // overflowed + } + return static_cast(AllocateAlignedBytes(bytes, alloc_ptr, opaque_ptr)); +} + +} // namespace detail + +// Aligned memory equivalent of make_unique for array types using the +// custom allocators alloc/free. This function calls the constructor with the +// passed Args... on every created item. The destructor of each element will be +// called when the AlignedUniquePtr is destroyed. +template +AlignedUniquePtr MakeUniqueAlignedArrayWithAlloc( + size_t items, AllocPtr alloc, FreePtr free, void* opaque, Args&&... args) { + T* ptr = detail::AllocateAlignedItems(items, alloc, opaque); + if (ptr != nullptr) { + for (size_t i = 0; i < items; i++) { + new (ptr + i) T(args...); + } + } + return AlignedUniquePtr(ptr, AlignedDeleter(free, opaque)); +} + +template +AlignedUniquePtr MakeUniqueAlignedArray(size_t items, Args&&... args) { + return MakeUniqueAlignedArrayWithAlloc( + items, nullptr, nullptr, nullptr, std::forward(args)...); +} + +// Custom deleter for std::unique_ptr equivalent to using free() as a deleter +// but for aligned memory. +class AlignedFreer { + public: + // Pass address of this to ctor to skip deleting externally-owned memory. + static void DoNothing(void* /*opaque*/, void* /*aligned_pointer*/) {} + + AlignedFreer() : free_(nullptr), opaque_ptr_(nullptr) {} + AlignedFreer(FreePtr free_ptr, void* opaque_ptr) + : free_(free_ptr), opaque_ptr_(opaque_ptr) {} + + template + void operator()(T* aligned_pointer) const { + FreeAlignedBytes(aligned_pointer, free_, opaque_ptr_); + } + + private: + FreePtr free_; + void* opaque_ptr_; +}; + +// Unique pointer to single POD, or (if T is U[]) an array of POD. For non POD +// data use AlignedUniquePtr. +template +using AlignedFreeUniquePtr = std::unique_ptr; + +// Allocate an aligned and uninitialized array of POD values as a unique_ptr. +// Upon destruction of the unique_ptr the aligned array will be freed. 
+template +AlignedFreeUniquePtr AllocateAligned(const size_t items, AllocPtr alloc, + FreePtr free, void* opaque) { + static_assert(std::is_trivially_copyable::value, + "AllocateAligned: requires trivially copyable T"); + static_assert(std::is_trivially_destructible::value, + "AllocateAligned: requires trivially destructible T"); + return AlignedFreeUniquePtr( + detail::AllocateAlignedItems(items, alloc, opaque), + AlignedFreer(free, opaque)); +} + +// Same as previous AllocateAligned(), using default allocate/free functions. +template +AlignedFreeUniquePtr AllocateAligned(const size_t items) { + return AllocateAligned(items, nullptr, nullptr, nullptr); +} + +// A simple span containing data and size of data. +template +class Span { + public: + Span() = default; + Span(T* data, size_t size) : size_(size), data_(data) {} + template + Span(U u) : Span(u.data(), u.size()) {} + Span(std::initializer_list v) : Span(v.begin(), v.size()) {} + + // Copies the contents of the initializer list to the span. + Span& operator=(std::initializer_list v) { + HWY_DASSERT(size_ == v.size()); + CopyBytes(v.begin(), data_, sizeof(T) * std::min(size_, v.size())); + return *this; + } + + // Returns the size of the contained data. + size_t size() const { return size_; } + + // Returns a pointer to the contained data. + T* data() { return data_; } + T* data() const { return data_; } + + // Returns the element at index. + T& operator[](size_t index) const { return data_[index]; } + + // Returns an iterator pointing to the first element of this span. + T* begin() { return data_; } + + // Returns a const iterator pointing to the first element of this span. + constexpr const T* cbegin() const { return data_; } + + // Returns an iterator pointing just beyond the last element at the + // end of this span. + T* end() { return data_ + size_; } + + // Returns a const iterator pointing just beyond the last element at the + // end of this span. 
+ constexpr const T* cend() const { return data_ + size_; } + + private: + size_t size_ = 0; + T* data_ = nullptr; +}; + +// A multi dimensional array containing an aligned buffer. +// +// To maintain alignment, the innermost dimension will be padded to ensure all +// innermost arrays are aligned. +template +class AlignedNDArray { + static_assert(std::is_trivial::value, + "AlignedNDArray can only contain trivial types"); + + public: + AlignedNDArray(AlignedNDArray&& other) = default; + AlignedNDArray& operator=(AlignedNDArray&& other) = default; + + // Constructs an array of the provided shape and fills it with zeros. + explicit AlignedNDArray(std::array shape) : shape_(shape) { + sizes_ = ComputeSizes(shape_); + memory_shape_ = shape_; + // Round the innermost dimension up to the number of bytes available for + // SIMD operations on this architecture to make sure that each innermost + // array is aligned from the first element. + memory_shape_[axes - 1] = RoundUpTo(memory_shape_[axes - 1], VectorBytes()); + memory_sizes_ = ComputeSizes(memory_shape_); + buffer_ = hwy::AllocateAligned(memory_size()); + hwy::ZeroBytes(buffer_.get(), memory_size() * sizeof(T)); + } + + // Returns a span containing the innermost array at the provided indices. + Span operator[](std::array indices) { + return Span(buffer_.get() + Offset(indices), sizes_[indices.size()]); + } + + // Returns a const span containing the innermost array at the provided + // indices. + Span operator[](std::array indices) const { + return Span(buffer_.get() + Offset(indices), + sizes_[indices.size()]); + } + + // Returns the shape of the array, which might be smaller than the allocated + // buffer after padding the last axis to alignment. + const std::array& shape() const { return shape_; } + + // Returns the shape of the allocated buffer, which might be larger than the + // used size of the array after padding to alignment. 
+ const std::array& memory_shape() const { return memory_shape_; } + + // Returns the size of the array, which might be smaller than the allocated + // buffer after padding the last axis to alignment. + size_t size() const { return sizes_[0]; } + + // Returns the size of the allocated buffer, which might be larger than the + // used size of the array after padding to alignment. + size_t memory_size() const { return memory_sizes_[0]; } + + // Returns a pointer to the allocated buffer. + T* data() { return buffer_.get(); } + + // Returns a const pointer to the buffer. + const T* data() const { return buffer_.get(); } + + // Truncates the array by updating its shape. + // + // The new shape must be equal to or less than the old shape in all axes. + // + // Doesn't modify underlying memory. + void truncate(const std::array& new_shape) { +#if HWY_IS_DEBUG_BUILD + for (size_t axis_index = 0; axis_index < axes; ++axis_index) { + HWY_ASSERT(new_shape[axis_index] <= shape_[axis_index]); + } +#endif + shape_ = new_shape; + sizes_ = ComputeSizes(shape_); + } + + private: + std::array shape_; + std::array memory_shape_; + std::array sizes_; + std::array memory_sizes_; + hwy::AlignedFreeUniquePtr buffer_; + + // Computes offset in the buffer based on the provided indices. + size_t Offset(std::array indices) const { + size_t offset = 0; + size_t shape_index = 0; + for (const size_t axis_index : indices) { + if (HWY_UNLIKELY(axis_index >= shape_[shape_index])) { + HWY_ABORT("AlignedNDArray index %zu out of bounds (axis %zu, size %zu)", + axis_index, shape_index, shape_[shape_index]); + } + offset += memory_sizes_[shape_index + 1] * axis_index; + shape_index++; + } + return offset; + } + + // Computes the sizes of all sub arrays based on the sizes of each axis. + // + // Does this by multiplying the size of each axis with the previous one in + // reverse order, starting with the conceptual axis of size 1 containing the + // actual elements in the array. 
+ static std::array ComputeSizes( + std::array shape) { + std::array sizes; + size_t axis = shape.size(); + sizes[axis] = 1; + while (axis > 0) { + --axis; + // Check for integer overflow in dimension multiplication. + if (HWY_LIKELY(shape[axis] != 0) && + sizes[axis + 1] > SIZE_MAX / shape[axis]) { + HWY_ABORT("AlignedNDArray: dimension overflow at axis %zu " + "(%zu * %zu exceeds size_t)", + axis, sizes[axis + 1], shape[axis]); + } + sizes[axis] = sizes[axis + 1] * shape[axis]; + } + return sizes; + } +}; + +} // namespace hwy +#endif // HIGHWAY_HWY_ALIGNED_ALLOCATOR_H_ diff --git a/lib/highway/hwy/auto_tune.h b/lib/highway/hwy/auto_tune.h new file mode 100644 index 00000000000..39fb51cb352 --- /dev/null +++ b/lib/highway/hwy/auto_tune.h @@ -0,0 +1,520 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_AUTO_TUNE_H_ +#define HIGHWAY_HWY_AUTO_TUNE_H_ + +#include +#include +#include // memmove + +#include +#include + +#include "hwy/aligned_allocator.h" // Span +#include "hwy/base.h" // HWY_MIN + +// configuration to allow auto_tune to use std::sort instead of VQSort +// (also enabled in header only mode). +#if defined(HWY_HEADER_ONLY) +#define HWY_AUTOTUNE_STDSORT +#endif + +#ifdef HWY_AUTOTUNE_STDSORT +#include // std::sort +#else +#include "hwy/contrib/sort/vqsort.h" // VQSort +#endif + +// Infrastructure for auto-tuning (choosing optimal parameters at runtime). 
+ +namespace hwy { + +// O(1) storage to estimate the central tendency of hundreds of independent +// distributions (one per configuration). The number of samples per distribution +// (`kMinSamples`) varies from few to dozens. We support both by first storing +// values in a buffer, and when full, switching to online variance estimation. +// Modified from `hwy/stats.h`. +class CostDistribution { + public: + static constexpr size_t kMaxValues = 14; // for total size of 128 bytes + + void Notify(const double x) { + if (HWY_UNLIKELY(x < 0.0)) { + HWY_WARN("Ignoring negative cost %f.", x); + return; + } + + // Online phase after filling and warm-up. + if (HWY_LIKELY(IsOnline())) return OnlineNotify(x); + + // Fill phase: store up to `kMaxValues` values. + values_[num_values_++] = x; + HWY_DASSERT(num_values_ <= kMaxValues); + if (HWY_UNLIKELY(num_values_ == kMaxValues)) { + WarmUpOnline(); + HWY_DASSERT(IsOnline()); + } + } + + // Returns an estimate of the true cost, mitigating the impact of noise. + // + // Background and observations from time measurements in `thread_pool.h`: + // - We aim for O(1) storage because there may be hundreds of instances. + // - The mean is biased upwards by mostly additive noise: particularly + // interruptions such as context switches, but also contention. + // - The minimum is not a robust estimator because there are also "lucky + // shots" (1.2-1.6x lower values) where interruptions or contention happen + // to be low. + // - We want to preserve information about contention and a configuration's + // sensitivity to it. Otherwise, we are optimizing for the best-case, not + // the common case. + // - It is still important to minimize the influence of outliers, such as page + // faults, which can cause multiple times larger measurements. + // - Detecting outliers based only on the initial variance is too brittle. 
If + // the sample is narrow, measurements will fluctuate across runs because + // too many measurements are considered outliers. This would cause the + // 'best' configuration to vary. + // + // Approach: + // - Use Winsorization to reduce the impact of outliers, while preserving + // information on the central tendency. + // - Continually update the thresholds based on the online variance, with + // exponential smoothing for stability. + // - Trim the initial sample via MAD or skewness for a robust estimate of the + // variance. + double EstimateCost() { + if (!IsOnline()) { + WarmUpOnline(); + HWY_DASSERT(IsOnline()); + } + return Mean(); + } + + // Multiplex online state into values_ to allow higher `kMaxValues`. + // Public for inspection in tests. Do not use directly. + double& M1() { return values_[0]; } // Moments for variance. + double& M2() { return values_[1]; } + double& Mean() { return values_[2]; } // Exponential smoothing. + double& Stddev() { return values_[3]; } + double& Lower() { return values_[4]; } + double& Upper() { return values_[5]; } + + private: + static double Median(double* to_sort, size_t n) { + HWY_DASSERT(n >= 2); + +#ifdef HWY_AUTOTUNE_STDSORT + std::sort(to_sort, to_sort + n); +#else +// F64 is supported everywhere except Armv7. +#if !HWY_ARCH_ARM_V7 + VQSort(to_sort, n, SortAscending()); +#else + // Values are known to be finite and non-negative, hence sorting as U64 is + // equivalent. + VQSort(reinterpret_cast(to_sort), n, SortAscending()); +#endif +#endif + + if (n & 1) return to_sort[n / 2]; + // Even length: average of two middle elements. 
+ return (to_sort[n / 2] + to_sort[n / 2 - 1]) * 0.5; + } + + static double MAD(const double* values, size_t n, const double median) { + double abs_dev[kMaxValues]; + for (size_t i = 0; i < n; ++i) { + abs_dev[i] = ScalarAbs(values[i] - median); + } + return Median(abs_dev, n); + } + + // If `num_values_` is large enough, sorts and discards outliers: either via + // MAD, or if too many values are equal, by trimming according to skewness. + void RemoveOutliers() { + if (num_values_ < 3) return; // Not enough to discard two. + HWY_DASSERT(num_values_ <= kMaxValues); + + // Given the noise level in `auto_tune_test`, it can happen that 1/4 of the + // sample is an outlier *in either direction*. Use median absolute + // deviation, which is robust to almost half of the sample being outliers. + const double median = Median(values_, num_values_); // sorts in-place. + const double mad = MAD(values_, num_values_, median); + // At least half the sample is equal. + if (mad == 0.0) { + // Estimate skewness to decide which side to trim more. + const double skewness = + (values_[num_values_ - 1] - median) - (median - values_[0]); + + const size_t trim = HWY_MAX(num_values_ / 2, size_t{2}); + const size_t left = + HWY_MAX(skewness < 0.0 ? trim * 3 / 4 : trim / 4, size_t{1}); + num_values_ -= trim; + HWY_DASSERT(num_values_ >= 1); + memmove(values_, values_ + left, num_values_ * sizeof(values_[0])); + return; + } + + const double upper = median + 5.0 * mad; + const double lower = median - 5.0 * mad; + size_t right = num_values_ - 1; + while (values_[right] > upper) --right; + // Nonzero MAD implies no more than half are equal, so we did not advance + // beyond the median. 
+ HWY_DASSERT(right >= num_values_ / 2); + + size_t left = 0; + while (left < right && values_[left] < lower) ++left; + HWY_DASSERT(left <= num_values_ / 2); + num_values_ = right - left + 1; + memmove(values_, values_ + left, num_values_ * sizeof(values_[0])); + } + + double SampleMean() const { + // Only called in non-online phase, but buffer might not be full. + HWY_DASSERT(!IsOnline() && 0 != num_values_ && num_values_ <= kMaxValues); + double sum = 0.0; + for (size_t i = 0; i < num_values_; ++i) { + sum += values_[i]; + } + return sum / static_cast(num_values_); + } + + // Unbiased estimator for population variance even for small `num_values_`. + double SampleVariance(double sample_mean) const { + HWY_DASSERT(sample_mean >= 0.0); // we checked costs are non-negative. + // Only called in non-online phase, but buffer might not be full. + HWY_DASSERT(!IsOnline() && 0 != num_values_ && num_values_ <= kMaxValues); + if (HWY_UNLIKELY(num_values_ == 1)) return 0.0; // prevent divide-by-zero. + double sum2 = 0.0; + for (size_t i = 0; i < num_values_; ++i) { + const double d = values_[i] - sample_mean; + sum2 += d * d; + } + return sum2 / static_cast(num_values_ - 1); + } + + bool IsOnline() const { return online_n_ > 0.0; } + + void OnlineNotify(double x) { + // Winsorize. + x = HWY_MIN(HWY_MAX(Lower(), x), Upper()); + + // Welford's online variance estimator. + // https://media.thinkbrg.com/wp-content/uploads/2020/06/19094655/720_720_McCrary_ImplementingAlgorithms_Whitepaper_20151119_WEB.pdf#page=7.09 + const double n_minus_1 = online_n_; + online_n_ += 1.0; + const double d = x - M1(); + const double d_div_n = d / online_n_; + M1() += d_div_n; + HWY_DASSERT(M1() >= Lower()); + M2() += d * n_minus_1 * d_div_n; // d^2 * (N-1)/N + // HWY_MAX avoids divide-by-zero. + const double stddev = std::sqrt(M2() / HWY_MAX(1.0, n_minus_1)); + + // Exponential smoothing. 
+ constexpr double kNew = 0.2; // relatively fast update + constexpr double kOld = 1.0 - kNew; + Mean() = M1() * kNew + Mean() * kOld; + Stddev() = stddev * kNew + Stddev() * kOld; + + // Update thresholds from smoothed mean and stddev to enable recovering from + // a too narrow initial range due to excessive trimming. + Lower() = Mean() - 3.5 * Stddev(); + Upper() = Mean() + 3.5 * Stddev(); + } + + void WarmUpOnline() { + RemoveOutliers(); + + // Compute and copy before writing to `M1`, which overwrites `values_`! + const double sample_mean = SampleMean(); + const double sample_variance = SampleVariance(sample_mean); + double copy[kMaxValues]; + hwy::CopyBytes(values_, copy, num_values_ * sizeof(values_[0])); + + M1() = M2() = 0.0; + Mean() = sample_mean; + Stddev() = std::sqrt(sample_variance); + // For single-value or all-equal sample, widen the range, else we will only + // accept the same value. + if (Stddev() == 0.0) Stddev() = Mean() / 2; + + // High tolerance because the distribution is not actually Gaussian, and + // we trimmed up to *half*, and do not want to reject too many values in + // the online phase. + Lower() = Mean() - 4.0 * Stddev(); + Upper() = Mean() + 4.0 * Stddev(); + // Feed copied values into online estimator. + for (size_t i = 0; i < num_values_; ++i) { + OnlineNotify(copy[i]); + } + HWY_DASSERT(IsOnline()); + +#if SIZE_MAX == 0xFFFFFFFFu + (void)padding_; +#endif + } + + size_t num_values_ = 0; // size of `values_` <= `kMaxValues` +#if SIZE_MAX == 0xFFFFFFFFu + uint32_t padding_ = 0; +#endif + + double online_n_ = 0.0; // number of calls to `OnlineNotify`. + + double values_[kMaxValues]; +}; +static_assert(sizeof(CostDistribution) == 128, ""); + +// Implements a counter with wrap-around, plus the ability to skip values. +// O(1) time, O(N) space via doubly-linked list of indices. 
+class NextWithSkip { + public: + NextWithSkip() {} + explicit NextWithSkip(size_t num) { + links_.reserve(num); + for (size_t i = 0; i < num; ++i) { + links_.emplace_back(i, num); + } + } + + size_t Next(size_t pos) { + HWY_DASSERT(pos < links_.size()); + HWY_DASSERT(!links_[pos].IsRemoved()); + return links_[pos].Next(); + } + + // Must not be called for an already skipped position. Ignores an attempt to + // skip the last remaining position. + void Skip(size_t pos) { + HWY_DASSERT(!links_[pos].IsRemoved()); // not already skipped. + const size_t prev = links_[pos].Prev(); + const size_t next = links_[pos].Next(); + if (prev == pos || next == pos) return; // last remaining position. + links_[next].SetPrev(prev); + links_[prev].SetNext(next); + links_[pos].Remove(); + } + + private: + // Combine prev/next into one array to improve locality/reduce allocations. + class Link { + // Bit-shifts avoid potentially expensive 16-bit loads. Store `next` at the + // top and `prev` at the bottom for extraction with a single shift/AND. + // There may be hundreds of configurations, so 8 bits are not enough. + static constexpr size_t kBits = 14; + static constexpr size_t kShift = 32 - kBits; + static constexpr uint32_t kMaxNum = 1u << kBits; + + public: + Link(size_t pos, size_t num) { + HWY_DASSERT(num < kMaxNum); + const size_t prev = pos == 0 ? num - 1 : pos - 1; + const size_t next = pos == num - 1 ? 
0 : pos + 1; + bits_ = + (static_cast(next) << kShift) | static_cast(prev); + HWY_DASSERT(Next() == next && Prev() == prev); + HWY_DASSERT(!IsRemoved()); + } + + bool IsRemoved() const { return (bits_ & kMaxNum) != 0; } + void Remove() { bits_ |= kMaxNum; } + + size_t Next() const { return bits_ >> kShift; } + size_t Prev() const { return bits_ & (kMaxNum - 1); } + + void SetNext(size_t next) { + HWY_DASSERT(next < kMaxNum); + bits_ &= (~0u >> kBits); // clear old next + bits_ |= static_cast(next) << kShift; + HWY_DASSERT(Next() == next); + HWY_DASSERT(!IsRemoved()); + } + void SetPrev(size_t prev) { + HWY_DASSERT(prev < kMaxNum); + bits_ &= ~(kMaxNum - 1); // clear old prev + bits_ |= static_cast(prev); + HWY_DASSERT(Prev() == prev); + HWY_DASSERT(!IsRemoved()); + } + + private: + uint32_t bits_; + }; + std::vector links_; +}; + +// State machine for choosing at runtime the lowest-cost `Config`, which is +// typically a struct containing multiple parameters. For an introduction, see +// "Auto-Tuning and Performance Portability on Heterogeneous Hardware". +// +// **Which parameters** +// Note that simple parameters such as the L2 cache size can be directly queried +// via `hwy/contrib/thread_pool/topology.h`. Difficult to predict parameters +// such as task granularity are more appropriate for auto-tuning. We also +// suggest that at least some parameters should also be 'algorithm variants' +// such as parallel vs. serial, or 2D tiling vs. 1D striping. +// +// **Search strategy** +// To guarantee the optimal result, we use exhaustive search, which is suitable +// for around 10 parameters and a few hundred combinations of 'candidate' +// configurations. +// +// **How to generate candidates** +// To keep this framework simple and generic, applications enumerate the search +// space and pass the list of all feasible candidates to `SetCandidates` before +// the first call to `NextConfig`. Applications should prune the space as much +// as possible, e.g. 
by upper-bounding parameters based on the known cache +// sizes, and applying constraints such as one being a multiple of another. +// +// **Usage** +// Applications typically conditionally branch to the code implementing the +// configuration returned by `NextConfig`. They measure the cost of running it +// and pass that to `NotifyCost`. Branching avoids the complexity and +// opaqueness of a JIT. The number of branches can be reduced (at the cost of +// code size) by inlining low-level decisions into larger code regions, e.g. by +// hoisting them outside hot loops. +// +// **What is cost** +// Cost is an arbitrary `uint64_t`, with lower values being better. Most +// applications will use the elapsed time. If the tasks being tuned are short, +// it is important to use a high-resolution timer such as `hwy/timer.h`. Energy +// may also be useful [https://www.osti.gov/servlets/purl/1361296]. +// +// **Online vs. offline** +// Although applications can auto-tune once, offline, it may be difficult to +// ensure the stored configuration still applies to the current circumstances. +// Thus we recommend online auto-tuning, re-discovering the configuration on +// each run. We assume the overhead of bookkeeping and measuring cost is +// negligible relative to the actual work. The cost of auto-tuning is then that +// of running sub-optimal configurations. Assuming the best configuration is +// better than baseline, and the work is performed many thousands of times, the +// cost is outweighed by the benefits. +// +// **kMinSamples** +// To further reduce overhead, after `kMinSamples` rounds (= measurements of +// each configuration) we start excluding configurations from further +// measurements if they are sufficiently worse than the current best. +// `kMinSamples` can be several dozen when the tasks being tuned take a few +// microseconds. Even for longer tasks, it should be at least 2 for some noise +// tolerance. 
After this, there are another `kMinSamples / 2 + 1` rounds before +// declaring the winner. +template +class AutoTune { + public: + // Returns non-null best configuration if auto-tuning has already finished. + // Otherwise, callers continue calling `NextConfig` and `NotifyCost`. + // Points into `Candidates()`. + const Config* Best() const { return best_; } + + // If false, caller must call `SetCandidates` before `NextConfig`. + // NOTE: also called after Best() is non-null. + bool HasCandidates() const { return !candidates_.empty(); } + + // WARNING: invalidates `Best()`, do not call if that is non-null. + void SetCandidates(std::vector candidates) { + HWY_DASSERT(!Best() && !HasCandidates()); + candidates_.swap(candidates); + HWY_DASSERT(HasCandidates()); + costs_.resize(candidates_.size()); + list_ = NextWithSkip(candidates_.size()); + } + + // Typically called after Best() is non-null to compare all candidates' costs. + Span Candidates() const { + HWY_DASSERT(HasCandidates()); + return Span(candidates_.data(), candidates_.size()); + } + Span Costs() { + return Span(costs_.data(), costs_.size()); + } + + // Returns the current `Config` to measure. + const Config& NextConfig() const { + HWY_DASSERT(HasCandidates()); + return candidates_[config_idx_]; + } + + // O(1) except at the end of each round, which is O(N). + void NotifyCost(uint64_t cost) { + HWY_DASSERT(!Best() && HasCandidates()); + + costs_[config_idx_].Notify(static_cast(cost)); + // Save now before we update `config_idx_`. + const size_t my_idx = config_idx_; + // Only retrieve once we have enough samples, otherwise, we switch to + // online variance before the buffer is populated. + const double my_cost = rounds_complete_ >= kMinSamples + ? costs_[config_idx_].EstimateCost() + : 0.0; + + // Advance to next non-skipped config with wrap-around. This decorrelates + // measurements by not immediately re-measuring the same config. 
+ config_idx_ = list_.Next(config_idx_); + // Might still equal `my_idx` if this is the only non-skipped config. + + // Disqualify from future `NextConfig` if cost was too far beyond the + // current best. This reduces the number of measurements, while tolerating + // noise in the first few measurements. Must happen after advancing. + if (my_cost > skip_if_above_) { + list_.Skip(my_idx); + } + + // Wrap-around indicates the round is complete. + if (HWY_UNLIKELY(config_idx_ <= my_idx)) { + ++rounds_complete_; + + // Enough samples for stable estimates: update the thresholds. + if (rounds_complete_ >= kMinSamples) { + double best_cost = HighestValue(); + size_t idx_min = 0; + for (size_t i = 0; i < candidates_.size(); ++i) { + const double estimate = costs_[i].EstimateCost(); + if (estimate < best_cost) { + best_cost = estimate; + idx_min = i; + } + } + skip_if_above_ = best_cost * 1.25; + + // After sufficient rounds, declare the winner. + if (HWY_UNLIKELY(rounds_complete_ == 3 * kMinSamples / 2 + 1)) { + best_ = &candidates_[idx_min]; + HWY_DASSERT(Best()); + } + } + } + } + + // Avoid printing during the first few rounds, because those might be noisy + // and not yet skipped. + bool ShouldPrint() { return rounds_complete_ > kMinSamples; } + + private: + const Config* best_ = nullptr; + std::vector candidates_; + std::vector costs_; // one per candidate + size_t config_idx_ = 0; // [0, candidates_.size()) + NextWithSkip list_; + size_t rounds_complete_ = 0; + + double skip_if_above_ = 0.0; +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_AUTO_TUNE_H_ diff --git a/lib/highway/hwy/base.h b/lib/highway/hwy/base.h new file mode 100644 index 00000000000..342bfed214e --- /dev/null +++ b/lib/highway/hwy/base.h @@ -0,0 +1,3231 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_BASE_H_ +#define HIGHWAY_HWY_BASE_H_ + +// Target-independent definitions. + +// IWYU pragma: begin_exports +#include +#include +#if defined(HWY_HEADER_ONLY) +#include +#include +#endif + +#if !defined(HWY_NO_LIBCXX) +#include +#endif + +#include "hwy/detect_compiler_arch.h" +#include "hwy/highway_export.h" + +// API version (https://semver.org/); keep in sync with CMakeLists.txt and +// meson.build. +#define HWY_MAJOR 1 +#define HWY_MINOR 4 +#define HWY_PATCH 0 + +// True if the Highway version >= major.minor.0. Added in 1.2.0. +#define HWY_VERSION_GE(major, minor) \ + (HWY_MAJOR > (major) || (HWY_MAJOR == (major) && HWY_MINOR >= (minor))) +// True if the Highway version < major.minor.0. Added in 1.2.0. +#define HWY_VERSION_LT(major, minor) \ + (HWY_MAJOR < (major) || (HWY_MAJOR == (major) && HWY_MINOR < (minor))) + +// "IWYU pragma: keep" does not work for these includes, so hide from the IDE. 
+#if !HWY_IDE + +#if !defined(HWY_NO_LIBCXX) +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS // before inttypes.h +#endif +#include +#endif + +#endif // !HWY_IDE + +#if !defined(HWY_NO_LIBCXX) || HWY_COMPILER_MSVC +#include +#endif + +#ifndef HWY_HAVE_COMPARE_HEADER // allow override +#define HWY_HAVE_COMPARE_HEADER 0 +#if defined(__has_include) // note: wrapper macro fails on Clang ~17 +#if __has_include() +#undef HWY_HAVE_COMPARE_HEADER +#define HWY_HAVE_COMPARE_HEADER 1 +#endif // __has_include +#endif // defined(__has_include) +#endif // HWY_HAVE_COMPARE_HEADER + +#ifndef HWY_HAVE_CXX20_THREE_WAY_COMPARE // allow override +#if !defined(HWY_NO_LIBCXX) && defined(__cpp_impl_three_way_comparison) && \ + __cpp_impl_three_way_comparison >= 201907L && HWY_HAVE_COMPARE_HEADER +#include +#define HWY_HAVE_CXX20_THREE_WAY_COMPARE 1 +#else +#define HWY_HAVE_CXX20_THREE_WAY_COMPARE 0 +#endif +#endif // HWY_HAVE_CXX20_THREE_WAY_COMPARE + +// IWYU pragma: end_exports + +#if HWY_COMPILER_MSVC +#include // memcpy +#endif + +//------------------------------------------------------------------------------ +// Compiler-specific definitions + +#define HWY_STR_IMPL(macro) #macro +#define HWY_STR(macro) HWY_STR_IMPL(macro) + +#if HWY_COMPILER_MSVC + +#include + +#define HWY_FUNCTION __FUNCSIG__ // function name + template args +#define HWY_RESTRICT __restrict +#define HWY_INLINE __forceinline +#define HWY_NOINLINE __declspec(noinline) +#define HWY_FLATTEN +#define HWY_NORETURN __declspec(noreturn) +#define HWY_LIKELY(expr) (expr) +#define HWY_UNLIKELY(expr) (expr) +#define HWY_UNREACHABLE __assume(false) +#define HWY_PRAGMA(tokens) __pragma(tokens) +#define HWY_DIAGNOSTICS(tokens) HWY_PRAGMA(warning(tokens)) +#define HWY_DIAGNOSTICS_OFF(msc, gcc) HWY_DIAGNOSTICS(msc) +#define HWY_MAYBE_UNUSED +#define HWY_HAS_ASSUME_ALIGNED 0 +#if (_MSC_VER >= 1700) +#define HWY_MUST_USE_RESULT _Check_return_ +#else +#define HWY_MUST_USE_RESULT +#endif + +#else + +#define HWY_FUNCTION 
__PRETTY_FUNCTION__ // function name + template args +#define HWY_RESTRICT __restrict__ +// force inlining without optimization enabled creates very inefficient code +// that can cause compiler timeout +#ifdef __OPTIMIZE__ +#define HWY_INLINE inline __attribute__((always_inline)) +#else +#define HWY_INLINE inline +#endif +#define HWY_NOINLINE __attribute__((noinline)) +#define HWY_FLATTEN __attribute__((flatten)) +#define HWY_NORETURN __attribute__((noreturn)) +#define HWY_LIKELY(expr) __builtin_expect(!!(expr), 1) +#define HWY_UNLIKELY(expr) __builtin_expect(!!(expr), 0) +#if HWY_COMPILER_GCC || HWY_HAS_BUILTIN(__builtin_unreachable) +#define HWY_UNREACHABLE __builtin_unreachable() +#else +#define HWY_UNREACHABLE +#endif +#define HWY_PRAGMA(tokens) _Pragma(#tokens) +#define HWY_DIAGNOSTICS(tokens) HWY_PRAGMA(GCC diagnostic tokens) +#define HWY_DIAGNOSTICS_OFF(msc, gcc) HWY_DIAGNOSTICS(gcc) +// Encountered "attribute list cannot appear here" when using the C++17 +// [[maybe_unused]], so only use the old style attribute for now. 
+#define HWY_MAYBE_UNUSED __attribute__((unused)) +#define HWY_MUST_USE_RESULT __attribute__((warn_unused_result)) + +#endif // !HWY_COMPILER_MSVC + +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1200) || \ + (HWY_COMPILER_ICC && !HWY_COMPILER_ICX) +// The use of __attribute__((unused)) in private class member variables triggers +// a compiler warning with GCC 11 and earlier and ICC + +// GCC 11 and earlier and ICC also do not emit -Wunused-private-field warnings +// for unused private class member variables +#define HWY_MEMBER_VAR_MAYBE_UNUSED +#else +// Clang and ICX need __attribute__((unused)) in unused private class member +// variables to suppress -Wunused-private-field warnings unless this warning is +// ignored by using HWY_DIAGNOSTICS_OFF +#define HWY_MEMBER_VAR_MAYBE_UNUSED HWY_MAYBE_UNUSED +#endif + +//------------------------------------------------------------------------------ +// Builtin/attributes (no more #include after this point due to namespace!) + +namespace hwy { + +// Enables error-checking of format strings. +#if HWY_HAS_ATTRIBUTE(__format__) +#define HWY_FORMAT(idx_fmt, idx_arg) \ + __attribute__((__format__(__printf__, idx_fmt, idx_arg))) +#else +#define HWY_FORMAT(idx_fmt, idx_arg) +#endif + +// Returns a void* pointer which the compiler then assumes is N-byte aligned. +// Example: float* HWY_RESTRICT aligned = (float*)HWY_ASSUME_ALIGNED(in, 32); +// +// The assignment semantics are required by GCC/Clang. ICC provides an in-place +// __assume_aligned, whereas MSVC's __assume appears unsuitable. +#if HWY_HAS_BUILTIN(__builtin_assume_aligned) +#define HWY_ASSUME_ALIGNED(ptr, align) __builtin_assume_aligned((ptr), (align)) +#else +#define HWY_ASSUME_ALIGNED(ptr, align) (ptr) /* not supported */ +#endif + +// Returns a pointer whose type is `type` (T*), while allowing the compiler to +// assume that the untyped pointer `ptr` is aligned to a multiple of sizeof(T). 
+#define HWY_RCAST_ALIGNED(type, ptr) \ + reinterpret_cast( \ + HWY_ASSUME_ALIGNED((ptr), alignof(hwy::RemovePtr))) + +// Clang and GCC require attributes on each function into which SIMD intrinsics +// are inlined. Support both per-function annotation (HWY_ATTR) for lambdas and +// automatic annotation via pragmas. +#if HWY_COMPILER_ICC +// As of ICC 2021.{1-9} the pragma is neither implemented nor required. +#define HWY_PUSH_ATTRIBUTES(targets_str) +#define HWY_POP_ATTRIBUTES +#elif HWY_COMPILER_CLANG +#define HWY_PUSH_ATTRIBUTES(targets_str) \ + HWY_PRAGMA(clang attribute push(__attribute__((target(targets_str))), \ + apply_to = function)) +#define HWY_POP_ATTRIBUTES HWY_PRAGMA(clang attribute pop) +#elif HWY_COMPILER_GCC_ACTUAL +#define HWY_PUSH_ATTRIBUTES(targets_str) \ + HWY_PRAGMA(GCC push_options) HWY_PRAGMA(GCC target targets_str) +#define HWY_POP_ATTRIBUTES HWY_PRAGMA(GCC pop_options) +#else +#define HWY_PUSH_ATTRIBUTES(targets_str) +#define HWY_POP_ATTRIBUTES +#endif + +//------------------------------------------------------------------------------ +// Macros + +// Note: it is safe to remove `static` for users who want to use modules, but +// that might be a breaking change for some users, hence we do not by default. +#define HWY_API static HWY_INLINE HWY_FLATTEN + +#define HWY_CONCAT_IMPL(a, b) a##b +#define HWY_CONCAT(a, b) HWY_CONCAT_IMPL(a, b) + +#define HWY_MIN(a, b) ((a) < (b) ? (a) : (b)) +#define HWY_MAX(a, b) ((a) > (b) ? (a) : (b)) + +#if HWY_COMPILER_GCC_ACTUAL +// nielskm: GCC does not support '#pragma GCC unroll' without the factor. +#define HWY_UNROLL(factor) HWY_PRAGMA(GCC unroll factor) +#define HWY_DEFAULT_UNROLL HWY_UNROLL(4) +#elif HWY_COMPILER_CLANG || HWY_COMPILER_ICC || HWY_COMPILER_ICX +#define HWY_UNROLL(factor) HWY_PRAGMA(unroll factor) +#define HWY_DEFAULT_UNROLL HWY_UNROLL() +#else +#define HWY_UNROLL(factor) +#define HWY_DEFAULT_UNROLL +#endif + +// Tell a compiler that the expression always evaluates to true. 
+// The expression should be free from any side effects. +// Some older compilers may have trouble with complex expressions, therefore +// it is advisable to split multiple conditions into separate assume statements, +// and manually check the generated code. +// OK but could fail: +// HWY_ASSUME(x == 2 && y == 3); +// Better: +// HWY_ASSUME(x == 2); +// HWY_ASSUME(y == 3); +#if (HWY_CXX_LANG >= 202302L) && HWY_HAS_CPP_ATTRIBUTE(assume) +#define HWY_ASSUME(expr) [[assume(expr)]] +#elif HWY_COMPILER_MSVC || HWY_COMPILER_ICC +#define HWY_ASSUME(expr) __assume(expr) +// __builtin_assume() was added in clang 3.6. +#elif HWY_COMPILER_CLANG && HWY_HAS_BUILTIN(__builtin_assume) +#define HWY_ASSUME(expr) __builtin_assume(expr) +// __builtin_unreachable() was added in GCC 4.5, but __has_builtin() was added +// later, so check for the compiler version directly. +#elif HWY_COMPILER_GCC_ACTUAL >= 405 +#define HWY_ASSUME(expr) \ + ((expr) ? static_cast(0) : __builtin_unreachable()) +#else +#define HWY_ASSUME(expr) static_cast(0) +#endif + +// Compile-time fence to prevent undesirable code reordering. On Clang, the +// typical `asm volatile("" : : : "memory")` seems to be ignored. Note that +// `std::atomic_thread_fence` affects other threads, hence might generate a +// barrier instruction, but this does not. +#if !defined(HWY_NO_LIBCXX) +#define HWY_FENCE std::atomic_signal_fence(std::memory_order_seq_cst) +#elif HWY_COMPILER_GCC +#define HWY_FENCE asm volatile("" : : : "memory") +#else +#define HWY_FENCE +#endif + +// 4 instances of a given literal value, useful as input to LoadDup128. +#define HWY_REP4(literal) literal, literal, literal, literal + +//------------------------------------------------------------------------------ +// Abort / Warn + +#if defined(HWY_HEADER_ONLY) +HWY_DLLEXPORT inline void HWY_FORMAT(3, 4) + Warn(const char* file, int line, const char* format, ...) 
{ + char buf[800]; + va_list args; + va_start(args, format); + vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + + fprintf(stderr, "Warn at %s:%d: %s\n", file, line, buf); +} + +HWY_DLLEXPORT HWY_NORETURN inline void HWY_FORMAT(3, 4) + Abort(const char* file, int line, const char* format, ...) { + char buf[800]; + va_list args; + va_start(args, format); + vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + + fprintf(stderr, "Abort at %s:%d: %s\n", file, line, buf); + + fflush(stderr); + +// Now terminate the program: +#if HWY_ARCH_RISCV + exit(1); // trap/abort just freeze Spike. +#else + abort(); // Compile error without this due to HWY_NORETURN. +#endif +} +#else // !HWY_HEADER_ONLY +// Interfaces for custom Warn/Abort handlers. +typedef void (*WarnFunc)(const char* file, int line, const char* message); + +typedef void (*AbortFunc)(const char* file, int line, const char* message); + +// Returns current Warn() handler, or nullptr if no handler was yet registered, +// indicating Highway should print to stderr. +// DEPRECATED because this is thread-hostile and prone to misuse (modifying the +// underlying pointer through the reference). +HWY_DLLEXPORT WarnFunc& GetWarnFunc(); + +// Returns current Abort() handler, or nullptr if no handler was yet registered, +// indicating Highway should print to stderr and abort. +// DEPRECATED because this is thread-hostile and prone to misuse (modifying the +// underlying pointer through the reference). +HWY_DLLEXPORT AbortFunc& GetAbortFunc(); + +// Sets a new Warn() handler and returns the previous handler, which is nullptr +// if no previous handler was registered, and should otherwise be called from +// the new handler. Thread-safe. +HWY_DLLEXPORT WarnFunc SetWarnFunc(WarnFunc func); + +// Sets a new Abort() handler and returns the previous handler, which is nullptr +// if no previous handler was registered, and should otherwise be called from +// the new handler. 
If all handlers return, then Highway will terminate the app. +// Thread-safe. +HWY_DLLEXPORT AbortFunc SetAbortFunc(AbortFunc func); + +HWY_DLLEXPORT void HWY_FORMAT(3, 4) + Warn(const char* file, int line, const char* format, ...); + +HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4) + Abort(const char* file, int line, const char* format, ...); + +#endif // HWY_HEADER_ONLY + +#define HWY_WARN(format, ...) \ + ::hwy::Warn(__FILE__, __LINE__, format, ##__VA_ARGS__) + +#define HWY_ABORT(format, ...) \ + ::hwy::Abort(__FILE__, __LINE__, format, ##__VA_ARGS__) + +// Always enabled. +#define HWY_ASSERT_M(condition, msg) \ + do { \ + if (!(condition)) { \ + HWY_ABORT("Assert %s: %s", #condition, msg); \ + } \ + } while (0) +#define HWY_ASSERT(condition) HWY_ASSERT_M(condition, "") + +// For enabling HWY_DASSERT and shortening tests in slower debug builds +// +// Note: `HWY_IS_UBSAN` is specifically excluded from engaging debug +// builds. This is in service of Chromium's `-fsanitize=array-bounds` by +// default, where we don't want Highway to unconditionally build in +// debug mode. +// +// See also: +// https://docs.google.com/document/d/1eCtY4AZF-SiFHxhIYWzEytdIx3C24de7ccD6Y5Gn2H8/edit?tab=t.9zkn85hr82ms#heading=h.efcshvfql42c +#if !defined(HWY_IS_DEBUG_BUILD) +// Clang does not define NDEBUG, but it and GCC define __OPTIMIZE__, and recent +// MSVC defines NDEBUG (if not, could instead check _DEBUG). 
+#if (!defined(__OPTIMIZE__) && !defined(NDEBUG)) || \ + (HWY_IS_ASAN || (HWY_IS_SANITIZER && !HWY_IS_UBSAN)) || \ + defined(__clang_analyzer__) +#define HWY_IS_DEBUG_BUILD 1 +#else +#define HWY_IS_DEBUG_BUILD 0 +#endif +#endif // HWY_IS_DEBUG_BUILD + +#if HWY_IS_DEBUG_BUILD +#define HWY_DASSERT_M(condition, msg) HWY_ASSERT_M(condition, msg) +#define HWY_DASSERT(condition) HWY_ASSERT_M(condition, "") +#else +#define HWY_DASSERT_M(condition, msg) \ + do { \ + } while (0) +#define HWY_DASSERT(condition) \ + do { \ + } while (0) +#endif + +//------------------------------------------------------------------------------ +// CopyBytes / ZeroBytes + +#if HWY_COMPILER_MSVC +#pragma intrinsic(memcpy) +#pragma intrinsic(memset) +#endif + +template +HWY_API void CopyBytes(const From* HWY_RESTRICT from, To* HWY_RESTRICT to) { +#if HWY_COMPILER_MSVC + memcpy(to, from, kBytes); +#else + __builtin_memcpy(to, from, kBytes); +#endif +} + +HWY_API void CopyBytes(const void* HWY_RESTRICT from, void* HWY_RESTRICT to, + size_t num_of_bytes_to_copy) { +#if HWY_COMPILER_MSVC + memcpy(to, from, num_of_bytes_to_copy); +#else + __builtin_memcpy(to, from, num_of_bytes_to_copy); +#endif +} + +// Same as CopyBytes, but for same-sized objects; avoids a size argument. +template +HWY_API void CopySameSize(const From* HWY_RESTRICT from, To* HWY_RESTRICT to) { + static_assert(sizeof(From) == sizeof(To), ""); + CopyBytes(from, to); +} + +// Sets each byte to `byte_value`, which must be between 0 and 255. 
+HWY_API void FillBytes(void* to, int byte_value, size_t num_bytes) { +#if HWY_COMPILER_MSVC + memset(to, byte_value, num_bytes); +#else + __builtin_memset(to, byte_value, num_bytes); +#endif +} + +HWY_API void ZeroBytes(void* to, size_t num_bytes) { + FillBytes(to, 0, num_bytes); +} + +template +HWY_API void ZeroBytes(To* to) { + ZeroBytes(to, kBytes); +} + +//------------------------------------------------------------------------------ +// kMaxVectorSize (undocumented, pending removal) + +#if HWY_ARCH_X86 +static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 64; // AVX-512 +#elif HWY_ARCH_RISCV && defined(__riscv_v_intrinsic) && \ + __riscv_v_intrinsic >= 11000 +// Not actually an upper bound on the size. +static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 4096; +#else +static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 16; +#endif + +//------------------------------------------------------------------------------ +// Alignment + +// Potentially useful for LoadDup128 and capped vectors. In other cases, arrays +// should be allocated dynamically via aligned_allocator.h because Lanes() may +// exceed the stack size. +#if HWY_ARCH_X86 +#define HWY_ALIGN_MAX alignas(64) +#elif HWY_ARCH_RISCV && defined(__riscv_v_intrinsic) && \ + __riscv_v_intrinsic >= 11000 +#define HWY_ALIGN_MAX alignas(8) // only elements need be aligned +#else +#define HWY_ALIGN_MAX alignas(16) +#endif + +//------------------------------------------------------------------------------ +// Lane types + +// hwy::float16_t and hwy::bfloat16_t are forward declared here to allow +// BitCastScalar to be implemented before the implementations of the +// hwy::float16_t and hwy::bfloat16_t types +struct float16_t; +struct bfloat16_t; + +using float32_t = float; +using float64_t = double; + +#pragma pack(push, 1) + +// Aligned 128-bit type. 
Cannot use __int128 because clang doesn't yet align it: +// https://reviews.llvm.org/D86310 +struct alignas(16) uint128_t { + uint64_t lo; // little-endian layout + uint64_t hi; +}; + +// 64 bit key plus 64 bit value. Faster than using uint128_t when only the key +// field is to be compared (Lt128Upper instead of Lt128). +struct alignas(16) K64V64 { + uint64_t value; // little-endian layout + uint64_t key; +}; + +// 32 bit key plus 32 bit value. Allows vqsort recursions to terminate earlier +// than when considering both to be a 64-bit key. +struct alignas(8) K32V32 { + uint32_t value; // little-endian layout + uint32_t key; +}; + +#pragma pack(pop) + +static inline HWY_MAYBE_UNUSED bool operator<(const uint128_t& a, + const uint128_t& b) { + return (a.hi == b.hi) ? a.lo < b.lo : a.hi < b.hi; +} +// Required for std::greater. +static inline HWY_MAYBE_UNUSED bool operator>(const uint128_t& a, + const uint128_t& b) { + return b < a; +} +static inline HWY_MAYBE_UNUSED bool operator==(const uint128_t& a, + const uint128_t& b) { + return a.lo == b.lo && a.hi == b.hi; +} + +#if !defined(HWY_NO_LIBCXX) +static inline HWY_MAYBE_UNUSED std::ostream& operator<<(std::ostream& os, + const uint128_t& n) { + return os << "[hi=" << n.hi << ",lo=" << n.lo << "]"; +} +#endif + +static inline HWY_MAYBE_UNUSED bool operator<(const K64V64& a, + const K64V64& b) { + return a.key < b.key; +} +// Required for std::greater. +static inline HWY_MAYBE_UNUSED bool operator>(const K64V64& a, + const K64V64& b) { + return b < a; +} +static inline HWY_MAYBE_UNUSED bool operator==(const K64V64& a, + const K64V64& b) { + return a.key == b.key; +} + +#if !defined(HWY_NO_LIBCXX) +static inline HWY_MAYBE_UNUSED std::ostream& operator<<(std::ostream& os, + const K64V64& n) { + return os << "[k=" << n.key << ",v=" << n.value << "]"; +} +#endif + +static inline HWY_MAYBE_UNUSED bool operator<(const K32V32& a, + const K32V32& b) { + return a.key < b.key; +} +// Required for std::greater. 
+static inline HWY_MAYBE_UNUSED bool operator>(const K32V32& a, + const K32V32& b) { + return b < a; +} +static inline HWY_MAYBE_UNUSED bool operator==(const K32V32& a, + const K32V32& b) { + return a.key == b.key; +} + +#if !defined(HWY_NO_LIBCXX) +static inline HWY_MAYBE_UNUSED std::ostream& operator<<(std::ostream& os, + const K32V32& n) { + return os << "[k=" << n.key << ",v=" << n.value << "]"; +} +#endif + +//------------------------------------------------------------------------------ +// Controlling overload resolution (SFINAE) + +template +struct EnableIfT {}; +template <> +struct EnableIfT { + using type = void; +}; + +template +using EnableIf = typename EnableIfT::type; + +template +struct IsSameT { + enum { value = 0 }; +}; + +template +struct IsSameT { + enum { value = 1 }; +}; + +template +HWY_API constexpr bool IsSame() { + return IsSameT::value; +} + +// Returns whether T matches either of U1 or U2 +template +HWY_API constexpr bool IsSameEither() { + return IsSameT::value || IsSameT::value; +} + +template +struct IfT { + using type = Then; +}; + +template +struct IfT { + using type = Else; +}; + +template +using If = typename IfT::type; + +template +struct IsConstT { + enum { value = 0 }; +}; + +template +struct IsConstT { + enum { value = 1 }; +}; + +template +HWY_API constexpr bool IsConst() { + return IsConstT::value; +} + +template +struct RemoveConstT { + using type = T; +}; +template +struct RemoveConstT { + using type = T; +}; + +template +using RemoveConst = typename RemoveConstT::type; + +template +struct RemoveVolatileT { + using type = T; +}; +template +struct RemoveVolatileT { + using type = T; +}; + +template +using RemoveVolatile = typename RemoveVolatileT::type; + +template +struct RemoveRefT { + using type = T; +}; +template +struct RemoveRefT { + using type = T; +}; +template +struct RemoveRefT { + using type = T; +}; + +template +using RemoveRef = typename RemoveRefT::type; + +template +using RemoveCvRef = RemoveConst>>; + 
+template +struct RemovePtrT { + using type = T; +}; +template +struct RemovePtrT { + using type = T; +}; +template +struct RemovePtrT { + using type = T; +}; +template +struct RemovePtrT { + using type = T; +}; +template +struct RemovePtrT { + using type = T; +}; + +template +using RemovePtr = typename RemovePtrT::type; + +// Insert into template/function arguments to enable this overload only for +// vectors of exactly, at most (LE), or more than (GT) this many bytes. +// +// As an example, checking for a total size of 16 bytes will match both +// Simd and Simd. +#define HWY_IF_V_SIZE(T, kN, bytes) \ + hwy::EnableIf* = nullptr +#define HWY_IF_V_SIZE_LE(T, kN, bytes) \ + hwy::EnableIf* = nullptr +#define HWY_IF_V_SIZE_GT(T, kN, bytes) \ + hwy::EnableIf<(kN * sizeof(T) > bytes)>* = nullptr + +#define HWY_IF_LANES(kN, lanes) hwy::EnableIf<(kN == lanes)>* = nullptr +#define HWY_IF_LANES_LE(kN, lanes) hwy::EnableIf<(kN <= lanes)>* = nullptr +#define HWY_IF_LANES_GT(kN, lanes) hwy::EnableIf<(kN > lanes)>* = nullptr + +#define HWY_IF_UNSIGNED(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_NOT_UNSIGNED(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_SIGNED(T) \ + hwy::EnableIf() && !hwy::IsFloat() && \ + !hwy::IsSpecialFloat()>* = nullptr +#define HWY_IF_FLOAT(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_NOT_FLOAT(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_FLOAT3264(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_NOT_FLOAT3264(T) hwy::EnableIf()>* = nullptr +#define HWY_IF_SPECIAL_FLOAT(T) \ + hwy::EnableIf()>* = nullptr +#define HWY_IF_NOT_SPECIAL_FLOAT(T) \ + hwy::EnableIf()>* = nullptr +#define HWY_IF_FLOAT_OR_SPECIAL(T) \ + hwy::EnableIf() || hwy::IsSpecialFloat()>* = nullptr +#define HWY_IF_NOT_FLOAT_NOR_SPECIAL(T) \ + hwy::EnableIf() && !hwy::IsSpecialFloat()>* = nullptr +#define HWY_IF_INTEGER(T) hwy::EnableIf()>* = nullptr + +#define HWY_IF_T_SIZE(T, bytes) hwy::EnableIf* = nullptr +#define HWY_IF_NOT_T_SIZE(T, bytes) \ + hwy::EnableIf* = nullptr +// 
bit_array = 0x102 means 1 or 8 bytes. There is no NONE_OF because it sounds +// too similar. If you want the opposite of this (2 or 4 bytes), ask for those +// bits explicitly (0x14) instead of attempting to 'negate' 0x102. +#define HWY_IF_T_SIZE_ONE_OF(T, bit_array) \ + hwy::EnableIf<((size_t{1} << sizeof(T)) & (bit_array)) != 0>* = nullptr +#define HWY_IF_T_SIZE_LE(T, bytes) \ + hwy::EnableIf<(sizeof(T) <= (bytes))>* = nullptr +#define HWY_IF_T_SIZE_GT(T, bytes) \ + hwy::EnableIf<(sizeof(T) > (bytes))>* = nullptr + +#define HWY_IF_SAME(T, expected) \ + hwy::EnableIf, expected>()>* = nullptr +#define HWY_IF_NOT_SAME(T, expected) \ + hwy::EnableIf, expected>()>* = nullptr + +// One of two expected types +#define HWY_IF_SAME2(T, expected1, expected2) \ + hwy::EnableIf< \ + hwy::IsSameEither, expected1, expected2>()>* = \ + nullptr + +#define HWY_IF_U8(T) HWY_IF_SAME(T, uint8_t) +#define HWY_IF_U16(T) HWY_IF_SAME(T, uint16_t) +#define HWY_IF_U32(T) HWY_IF_SAME(T, uint32_t) +#define HWY_IF_U64(T) HWY_IF_SAME(T, uint64_t) + +#define HWY_IF_I8(T) HWY_IF_SAME(T, int8_t) +#define HWY_IF_I16(T) HWY_IF_SAME(T, int16_t) +#define HWY_IF_I32(T) HWY_IF_SAME(T, int32_t) +#define HWY_IF_I64(T) HWY_IF_SAME(T, int64_t) + +#define HWY_IF_BF16(T) HWY_IF_SAME(T, hwy::bfloat16_t) +#define HWY_IF_NOT_BF16(T) HWY_IF_NOT_SAME(T, hwy::bfloat16_t) + +#define HWY_IF_F16(T) HWY_IF_SAME(T, hwy::float16_t) +#define HWY_IF_NOT_F16(T) HWY_IF_NOT_SAME(T, hwy::float16_t) + +#define HWY_IF_F32(T) HWY_IF_SAME(T, float) +#define HWY_IF_F64(T) HWY_IF_SAME(T, double) + +// Use instead of HWY_IF_T_SIZE to avoid ambiguity with float16_t/float/double +// overloads. 
+#define HWY_IF_UI8(T) HWY_IF_SAME2(T, uint8_t, int8_t) +#define HWY_IF_UI16(T) HWY_IF_SAME2(T, uint16_t, int16_t) +#define HWY_IF_UI32(T) HWY_IF_SAME2(T, uint32_t, int32_t) +#define HWY_IF_UI64(T) HWY_IF_SAME2(T, uint64_t, int64_t) + +#define HWY_IF_LANES_PER_BLOCK(T, N, LANES) \ + hwy::EnableIf* = nullptr + +// Empty struct used as a size tag type. +template +struct SizeTag {}; + +template +class DeclValT { + private: + template + static URef TryAddRValRef(int); + template + static U TryAddRValRef(Arg); + + public: + using type = decltype(TryAddRValRef(0)); + enum { kDisableDeclValEvaluation = 1 }; +}; + +// hwy::DeclVal() can only be used in unevaluated contexts such as within an +// expression of a decltype specifier. + +// hwy::DeclVal() does not require that T have a public default constructor +template +HWY_API typename DeclValT::type DeclVal() noexcept { + static_assert(!DeclValT::kDisableDeclValEvaluation, + "DeclVal() cannot be used in an evaluated context"); +} + +template +struct IsArrayT { + enum { value = 0 }; +}; + +template +struct IsArrayT { + enum { value = 1 }; +}; + +template +struct IsArrayT { + enum { value = 1 }; +}; + +template +static constexpr bool IsArray() { + return IsArrayT::value; +} + +#if HWY_COMPILER_MSVC +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4180, ignored "-Wignored-qualifiers") +#endif + +template +class IsConvertibleT { + private: + template + static hwy::SizeTag<1> TestFuncWithToArg(T); + + template + static decltype(IsConvertibleT::template TestFuncWithToArg( + DeclVal())) + TryConvTest(int); + + template + static hwy::SizeTag<0> TryConvTest(Arg); + + public: + enum { + value = (IsSame>, void>() && + IsSame>, void>()) || + (!IsArray() && + (IsSame())>() || + !IsSame, RemoveConst>()) && + IsSame(0)), hwy::SizeTag<1>>()) + }; +}; + +#if HWY_COMPILER_MSVC +HWY_DIAGNOSTICS(pop) +#endif + +template +HWY_API constexpr bool IsConvertible() { + return IsConvertibleT::value; +} + +template +class IsStaticCastableT { + 
private: + template (DeclVal()))> + static hwy::SizeTag<1> TryStaticCastTest(int); + + template + static hwy::SizeTag<0> TryStaticCastTest(Arg); + + public: + enum { + value = IsSame(0)), hwy::SizeTag<1>>() + }; +}; + +template +static constexpr bool IsStaticCastable() { + return IsStaticCastableT::value; +} + +#define HWY_IF_CASTABLE(From, To) \ + hwy::EnableIf()>* = nullptr + +#define HWY_IF_OP_CASTABLE(op, T, Native) \ + HWY_IF_CASTABLE(decltype(DeclVal() op DeclVal()), Native) + +template +class IsAssignableT { + private: + template () = DeclVal())> + static hwy::SizeTag<1> TryAssignTest(int); + + template + static hwy::SizeTag<0> TryAssignTest(Arg); + + public: + enum { + value = IsSame(0)), hwy::SizeTag<1>>() + }; +}; + +template +static constexpr bool IsAssignable() { + return IsAssignableT::value; +} + +#define HWY_IF_ASSIGNABLE(T, From) \ + hwy::EnableIf()>* = nullptr + +// ---------------------------------------------------------------------------- +// IsSpecialFloat + +// These types are often special-cased and not supported in all ops. 
+template +HWY_API constexpr bool IsSpecialFloat() { + return IsSameEither, hwy::float16_t, hwy::bfloat16_t>(); +} + +// ----------------------------------------------------------------------------- +// IsIntegerLaneType and IsInteger + +template +HWY_API constexpr bool IsIntegerLaneType() { + return false; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} +template <> +HWY_INLINE constexpr bool IsIntegerLaneType() { + return true; +} + +namespace detail { + +template +static HWY_INLINE constexpr bool IsNonCvInteger() { + // NOTE: Do not add a IsNonCvInteger() specialization below as it is + // possible for IsSame() to be true when compiled with MSVC + // with the /Zc:wchar_t- option. 
+ return IsIntegerLaneType() || IsSame() || + IsSameEither() || + IsSameEither(); +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { // NOLINT + return true; +} +#if defined(__cpp_char8_t) && __cpp_char8_t >= 201811L +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +#endif +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} +template <> +HWY_INLINE constexpr bool IsNonCvInteger() { + return true; +} + +} // namespace detail + +template +HWY_API constexpr bool IsInteger() { + return detail::IsNonCvInteger>(); +} + +// ----------------------------------------------------------------------------- +// BitCastScalar + +#if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 +#define HWY_BITCASTSCALAR_CONSTEXPR constexpr +#else +#define HWY_BITCASTSCALAR_CONSTEXPR +#endif + +#if __cpp_constexpr >= 201304L +#define HWY_BITCASTSCALAR_CXX14_CONSTEXPR HWY_BITCASTSCALAR_CONSTEXPR +#else +#define HWY_BITCASTSCALAR_CXX14_CONSTEXPR +#endif + +#if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 +namespace 
detail { + +template +struct BitCastScalarSrcCastHelper { + static HWY_INLINE constexpr const From& CastSrcValRef(const From& val) { + return val; + } +}; + +#if HWY_COMPILER_CLANG >= 900 && HWY_COMPILER_CLANG < 1000 +// Workaround for Clang 9 constexpr __builtin_bit_cast bug +template >() && + hwy::IsInteger>()>* = nullptr> +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR To +BuiltinBitCastScalar(const From& val) { + static_assert(sizeof(To) == sizeof(From), + "sizeof(To) == sizeof(From) must be true"); + return static_cast(val); +} + +template >() && + hwy::IsInteger>())>* = nullptr> +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR To +BuiltinBitCastScalar(const From& val) { + return __builtin_bit_cast(To, val); +} +#endif // HWY_COMPILER_CLANG >= 900 && HWY_COMPILER_CLANG < 1000 + +} // namespace detail + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) { + // If From is hwy::float16_t or hwy::bfloat16_t, first cast val to either + // const typename From::Native& or const uint16_t& using + // detail::BitCastScalarSrcCastHelper>::CastSrcValRef to + // allow BitCastScalar from hwy::float16_t or hwy::bfloat16_t to be constexpr + // if To is not a pointer type, union type, or a struct/class containing a + // pointer, union, or reference subobject +#if HWY_COMPILER_CLANG >= 900 && HWY_COMPILER_CLANG < 1000 + return detail::BuiltinBitCastScalar( + detail::BitCastScalarSrcCastHelper>::CastSrcValRef( + val)); +#else + return __builtin_bit_cast( + To, detail::BitCastScalarSrcCastHelper>::CastSrcValRef( + val)); +#endif +} +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) { + // If To is hwy::float16_t or hwy::bfloat16_t, first do a BitCastScalar of val + // to uint16_t, and then bit cast the uint16_t value to To using To::FromBits + // as hwy::float16_t::FromBits and hwy::bfloat16_t::FromBits are guaranteed to + // be constexpr if the __builtin_bit_cast intrinsic is available. 
+ return To::FromBits(BitCastScalar(val)); +} +#else +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) { + To result; + CopySameSize(&val, &result); + return result; +} +#endif + +//------------------------------------------------------------------------------ +// F16 lane type + +#pragma pack(push, 1) + +#ifndef HWY_NEON_HAVE_F16C // allow override +// Compiler supports __fp16 and load/store/conversion NEON intrinsics, which are +// included in Armv8 and VFPv4 (except with MSVC). On Armv7 Clang requires +// __ARM_FP & 2 whereas Armv7 GCC requires -mfp16-format=ieee. +#if (HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC) || \ + (HWY_COMPILER_CLANG && defined(__ARM_FP) && (__ARM_FP & 2)) || \ + (HWY_COMPILER_GCC_ACTUAL && defined(__ARM_FP16_FORMAT_IEEE)) +#define HWY_NEON_HAVE_F16C 1 +#else +#define HWY_NEON_HAVE_F16C 0 +#endif +#endif // HWY_NEON_HAVE_F16C + +// RVV with f16 extension supports _Float16 and f16 vector ops. If set, implies +// HWY_HAVE_FLOAT16. +#if HWY_ARCH_RISCV && defined(__riscv_zvfh) && HWY_COMPILER_CLANG >= 1600 +#define HWY_RVV_HAVE_F16_VEC 1 +#else +#define HWY_RVV_HAVE_F16_VEC 0 +#endif + +// x86 compiler supports _Float16, not necessarily with operators. +// Avoid clang-cl because it lacks __extendhfsf2. +#if HWY_ARCH_X86 && defined(__SSE2__) && defined(__FLT16_MAX__) && \ + ((HWY_COMPILER_CLANG >= 1500 && !HWY_COMPILER_CLANGCL) || \ + HWY_COMPILER_GCC_ACTUAL >= 1200) +#define HWY_SSE2_HAVE_F16_TYPE 1 +#else +#define HWY_SSE2_HAVE_F16_TYPE 0 +#endif + +#ifndef HWY_HAVE_SCALAR_F16_TYPE // allow override +// Compiler supports _Float16, not necessarily with operators. +#if HWY_NEON_HAVE_F16C || HWY_RVV_HAVE_F16_VEC || HWY_SSE2_HAVE_F16_TYPE || \ + __SPIRV_DEVICE__ +#define HWY_HAVE_SCALAR_F16_TYPE 1 +#else +#define HWY_HAVE_SCALAR_F16_TYPE 0 +#endif +#endif // HWY_HAVE_SCALAR_F16_TYPE + +#ifndef HWY_HAVE_SCALAR_F16_OPERATORS +// Recent enough compiler also has operators. 
+#if HWY_HAVE_SCALAR_F16_TYPE && \ + (HWY_COMPILER_CLANG >= 1800 || HWY_COMPILER_GCC_ACTUAL >= 1200 || \ + (HWY_COMPILER_CLANG >= 1500 && !HWY_COMPILER_CLANGCL && \ + !defined(_WIN32)) || \ + (HWY_ARCH_ARM && \ + (HWY_COMPILER_CLANG >= 900 || HWY_COMPILER_GCC_ACTUAL >= 800))) +#define HWY_HAVE_SCALAR_F16_OPERATORS 1 +#else +#define HWY_HAVE_SCALAR_F16_OPERATORS 0 +#endif +#endif // HWY_HAVE_SCALAR_F16_OPERATORS + +namespace detail { + +template , bool = IsSpecialFloat()> +struct SpecialFloatUnwrapArithOpOperandT {}; + +template +struct SpecialFloatUnwrapArithOpOperandT { + using type = T; +}; + +template +using SpecialFloatUnwrapArithOpOperand = + typename SpecialFloatUnwrapArithOpOperandT::type; + +template > +struct NativeSpecialFloatToWrapperT { + using type = T; +}; + +template +using NativeSpecialFloatToWrapper = + typename NativeSpecialFloatToWrapperT::type; + +} // namespace detail + +// Match [u]int##_t naming scheme so rvv-inl.h macros can obtain the type name +// by concatenating base type and bits. We use a wrapper class instead of a +// typedef to the native type to ensure that the same symbols, e.g. for VQSort, +// are generated regardless of F16 support; see #1684. +struct alignas(2) float16_t { +#if HWY_HAVE_SCALAR_F16_TYPE +#if HWY_RVV_HAVE_F16_VEC || HWY_SSE2_HAVE_F16_TYPE || __SPIRV_DEVICE__ + using Native = _Float16; +#elif HWY_NEON_HAVE_F16C + using Native = __fp16; +#else +#error "Logic error: condition should be 'all but NEON_HAVE_F16C'" +#endif +#elif HWY_IDE + using Native = uint16_t; +#endif // HWY_HAVE_SCALAR_F16_TYPE + + union { +#if HWY_HAVE_SCALAR_F16_TYPE || HWY_IDE + // Accessed via NativeLaneType, and used directly if + // HWY_HAVE_SCALAR_F16_OPERATORS. + Native native; +#endif + // Only accessed via NativeLaneType or U16LaneType. + uint16_t bits; + }; + + // Default init and copying. 
+ float16_t() noexcept = default; + constexpr float16_t(const float16_t&) noexcept = default; + constexpr float16_t(float16_t&&) noexcept = default; + float16_t& operator=(const float16_t&) noexcept = default; + float16_t& operator=(float16_t&&) noexcept = default; + +#if HWY_HAVE_SCALAR_F16_TYPE + // NEON vget/set_lane intrinsics and SVE `svaddv` could use explicit + // float16_t(intrinsic()), but user code expects implicit conversions. + constexpr float16_t(Native arg) noexcept : native(arg) {} + constexpr operator Native() const noexcept { return native; } +#endif + +#if HWY_HAVE_SCALAR_F16_TYPE + static HWY_BITCASTSCALAR_CONSTEXPR float16_t FromBits(uint16_t bits) { + return float16_t(BitCastScalar(bits)); + } +#else + + private: + struct F16FromU16BitsTag {}; + constexpr float16_t(F16FromU16BitsTag /*tag*/, uint16_t u16_bits) + : bits(u16_bits) {} + + public: + static constexpr float16_t FromBits(uint16_t bits) { + return float16_t(F16FromU16BitsTag(), bits); + } +#endif + + // When backed by a native type, ensure the wrapper behaves like the native + // type by forwarding all operators. Unfortunately it seems difficult to reuse + // this code in a base class, so we repeat it in float16_t. 
+#if HWY_HAVE_SCALAR_F16_OPERATORS || HWY_IDE + template , float16_t>() && + IsConvertible()>* = nullptr> + constexpr float16_t(T&& arg) noexcept + : native(static_cast(static_cast(arg))) {} + + template , float16_t>() && + !IsConvertible() && + IsStaticCastable()>* = nullptr> + explicit constexpr float16_t(T&& arg) noexcept + : native(static_cast(static_cast(arg))) {} + + // pre-decrement operator (--x) + HWY_CXX14_CONSTEXPR float16_t& operator--() noexcept { + native = static_cast(native - Native{1}); + return *this; + } + + // post-decrement operator (x--) + HWY_CXX14_CONSTEXPR float16_t operator--(int) noexcept { + float16_t result = *this; + native = static_cast(native - Native{1}); + return result; + } + + // pre-increment operator (++x) + HWY_CXX14_CONSTEXPR float16_t& operator++() noexcept { + native = static_cast(native + Native{1}); + return *this; + } + + // post-increment operator (x++) + HWY_CXX14_CONSTEXPR float16_t operator++(int) noexcept { + float16_t result = *this; + native = static_cast(native + Native{1}); + return result; + } + + constexpr float16_t operator-() const noexcept { + return float16_t(static_cast(-native)); + } + constexpr float16_t operator+() const noexcept { return *this; } + + // Reduce clutter by generating `operator+` and `operator+=` etc. Note that + // we cannot token-paste `operator` and `+`, so pass it in as `op_func`. 
+#define HWY_FLOAT16_BINARY_OP(op, op_func, assign_func) \ + constexpr float16_t op_func(const float16_t& rhs) const noexcept { \ + return float16_t(static_cast(native op rhs.native)); \ + } \ + template , \ + typename RawResultT = \ + decltype(DeclVal() op DeclVal()), \ + typename ResultT = \ + detail::NativeSpecialFloatToWrapper, \ + HWY_IF_CASTABLE(RawResultT, ResultT)> \ + constexpr ResultT op_func(const T& rhs) const noexcept(noexcept( \ + static_cast(DeclVal() op DeclVal()))) { \ + return static_cast(native op static_cast(rhs)); \ + } \ + HWY_CXX14_CONSTEXPR hwy::float16_t& assign_func( \ + const hwy::float16_t& rhs) noexcept { \ + native = static_cast(native op rhs.native); \ + return *this; \ + } \ + template () op DeclVal()))> \ + HWY_CXX14_CONSTEXPR hwy::float16_t& assign_func(const T& rhs) noexcept( \ + noexcept( \ + static_cast(DeclVal() op DeclVal()))) { \ + native = static_cast(native op rhs); \ + return *this; \ + } + + HWY_FLOAT16_BINARY_OP(+, operator+, operator+=) + HWY_FLOAT16_BINARY_OP(-, operator-, operator-=) + HWY_FLOAT16_BINARY_OP(*, operator*, operator*=) + HWY_FLOAT16_BINARY_OP(/, operator/, operator/=) +#undef HWY_FLOAT16_BINARY_OP + +#endif // HWY_HAVE_SCALAR_F16_OPERATORS +}; +static_assert(sizeof(hwy::float16_t) == 2, "Wrong size of float16_t"); + +#if HWY_HAVE_SCALAR_F16_TYPE +namespace detail { + +#if HWY_HAVE_SCALAR_F16_OPERATORS +template +struct SpecialFloatUnwrapArithOpOperandT { + using type = hwy::float16_t::Native; +}; +#endif + +template +struct NativeSpecialFloatToWrapperT { + using type = hwy::float16_t; +}; + +} // namespace detail +#endif // HWY_HAVE_SCALAR_F16_TYPE + +#if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 +namespace detail { + +template <> +struct BitCastScalarSrcCastHelper { +#if HWY_HAVE_SCALAR_F16_TYPE + static HWY_INLINE constexpr const hwy::float16_t::Native& CastSrcValRef( + const hwy::float16_t& val) { + return val.native; + } +#else + static HWY_INLINE constexpr const uint16_t& 
CastSrcValRef( + const hwy::float16_t& val) { + return val.bits; + } +#endif +}; + +} // namespace detail +#endif // HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 + +#if HWY_HAVE_SCALAR_F16_OPERATORS +#define HWY_F16_CONSTEXPR constexpr +#else +#define HWY_F16_CONSTEXPR HWY_BITCASTSCALAR_CXX14_CONSTEXPR +#endif // HWY_HAVE_SCALAR_F16_OPERATORS + +HWY_API HWY_F16_CONSTEXPR float F32FromF16(float16_t f16) { +#if HWY_HAVE_SCALAR_F16_OPERATORS && !HWY_IDE + return static_cast(f16); +#endif +#if !HWY_HAVE_SCALAR_F16_OPERATORS || HWY_IDE + const uint16_t bits16 = BitCastScalar(f16); + const uint32_t sign = static_cast(bits16 >> 15); + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = + (1.0f / 16384) * (static_cast(mantissa) * (1.0f / 1024)); + return sign ? -subnormal : subnormal; + } + + // Normalized, infinity or NaN: convert the representation directly + // (faster than ldexp/tables). + const uint32_t biased_exp32 = + biased_exp == 31 ? 0xFF : biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + + return BitCastScalar(bits32); +#endif // !HWY_HAVE_SCALAR_F16_OPERATORS +} + +#if HWY_IS_DEBUG_BUILD && \ + (HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926) +#if defined(__cpp_if_consteval) && __cpp_if_consteval >= 202106L +// If C++23 if !consteval support is available, only execute +// HWY_DASSERT(condition) if F16FromF32 is not called from a constant-evaluated +// context to avoid compilation errors. 
+#define HWY_F16_FROM_F32_DASSERT(condition) \ + do { \ + if !consteval { \ + HWY_DASSERT(condition); \ + } \ + } while (0) +#elif HWY_HAS_BUILTIN(__builtin_is_constant_evaluated) || \ + HWY_COMPILER_MSVC >= 1926 +// If the __builtin_is_constant_evaluated() intrinsic is available, +// only do HWY_DASSERT(condition) if __builtin_is_constant_evaluated() returns +// false to avoid compilation errors if F16FromF32 is called from a +// constant-evaluated context. +#define HWY_F16_FROM_F32_DASSERT(condition) \ + do { \ + if (!__builtin_is_constant_evaluated()) { \ + HWY_DASSERT(condition); \ + } \ + } while (0) +#else +// If C++23 if !consteval support is not available, +// the __builtin_is_constant_evaluated() intrinsic is not available, +// HWY_IS_DEBUG_BUILD is 1, and the __builtin_bit_cast intrinsic is available, +// do not do a HWY_DASSERT to avoid compilation errors if F16FromF32 is +// called from a constant-evaluated context. +#define HWY_F16_FROM_F32_DASSERT(condition) \ + do { \ + } while (0) +#endif // defined(__cpp_if_consteval) && __cpp_if_consteval >= 202106L +#else +// If HWY_IS_DEBUG_BUILD is 0 or the __builtin_bit_cast intrinsic is not +// available, define HWY_F16_FROM_F32_DASSERT(condition) as +// HWY_DASSERT(condition) +#define HWY_F16_FROM_F32_DASSERT(condition) HWY_DASSERT(condition) +#endif // HWY_IS_DEBUG_BUILD && (HWY_HAS_BUILTIN(__builtin_bit_cast) || + // HWY_COMPILER_MSVC >= 1926) + +HWY_API HWY_F16_CONSTEXPR float16_t F16FromF32(float f32) { +#if HWY_HAVE_SCALAR_F16_OPERATORS && !HWY_IDE + return float16_t(static_cast(f32)); +#endif +#if !HWY_HAVE_SCALAR_F16_OPERATORS || HWY_IDE + const uint32_t bits32 = BitCastScalar(f32); + const uint32_t sign = bits32 >> 31; + const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; + constexpr uint32_t kMantissaMask = 0x7FFFFF; + const uint32_t mantissa32 = bits32 & kMantissaMask; + + // Before shifting (truncation), round to nearest even to reduce bias. 
If + // the lowest remaining mantissa bit is odd, increase the offset. Example + // with the lowest remaining bit (left) and next lower two bits; the + // latter, plus two more, will be truncated. + // 0[00] + 1 = 0[01] + // 0[01] + 1 = 0[10] + // 0[10] + 1 = 0[11] (round down toward even) + // 0[11] + 1 = 1[00] (round up) + // 1[00] + 10 = 1[10] + // 1[01] + 10 = 1[11] + // 1[10] + 10 = C0[00] (round up toward even with C=1 carry out) + // 1[11] + 10 = C0[01] (round up toward even with C=1 carry out) + + // If |f32| >= 2^-24, f16_ulp_bit_idx is the index of the F32 mantissa bit + // that will be shifted down into the ULP bit of the rounded down F16 result + + // The biased F32 exponent of 2^-14 (the smallest positive normal F16 value) + // is 113, and bit 13 of the F32 mantissa will be shifted down into the ULP + // bit of the rounded down F16 result if |f32| >= 2^-14 + + // If |f32| < 2^-24, f16_ulp_bit_idx is equal to 24 as there are 24 mantissa + // bits (including the implied 1 bit) in the mantissa of a normal F32 value + // and as we want to round up the mantissa if |f32| > 2^-25 && |f32| < 2^-24 + const int32_t f16_ulp_bit_idx = + HWY_MIN(HWY_MAX(126 - static_cast(biased_exp32), 13), 24); + const uint32_t odd_bit = ((mantissa32 | 0x800000u) >> f16_ulp_bit_idx) & 1; + const uint32_t rounded = + mantissa32 + odd_bit + (uint32_t{1} << (f16_ulp_bit_idx - 1)) - 1u; + const bool carry = rounded >= (1u << 23); + + const int32_t exp = static_cast(biased_exp32) - 127 + carry; + + // Tiny or zero => zero. + if (exp < -24) { + // restore original sign + return float16_t::FromBits(static_cast(sign << 15)); + } + + // If biased_exp16 would be >= 31, first check whether the input was NaN so we + // can set the mantissa to nonzero. + const bool is_nan = (biased_exp32 == 255) && mantissa32 != 0; + const bool overflowed = exp >= 16; + const uint32_t biased_exp16 = + static_cast(HWY_MIN(HWY_MAX(0, exp + 15), 31)); + // exp = [-24, -15] => subnormal, shift the mantissa.
+ const uint32_t sub_exp = static_cast(HWY_MAX(-14 - exp, 0)); + HWY_F16_FROM_F32_DASSERT(sub_exp < 11); + const uint32_t shifted_mantissa = + (rounded & kMantissaMask) >> (23 - 10 + sub_exp); + const uint32_t leading = sub_exp == 0u ? 0u : (1024u >> sub_exp); + const uint32_t mantissa16 = is_nan ? 0x3FF + : overflowed ? 0u + : (leading + shifted_mantissa); + +#if HWY_IS_DEBUG_BUILD + if (exp < -14) { + HWY_F16_FROM_F32_DASSERT(biased_exp16 == 0); + HWY_F16_FROM_F32_DASSERT(sub_exp >= 1); + } else if (exp <= 15) { + HWY_F16_FROM_F32_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); + HWY_F16_FROM_F32_DASSERT(sub_exp == 0); + } +#endif + + HWY_F16_FROM_F32_DASSERT(mantissa16 < 1024); + const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; + HWY_F16_FROM_F32_DASSERT(bits16 < 0x10000); + const uint16_t narrowed = static_cast(bits16); // big-endian safe + return float16_t::FromBits(narrowed); +#endif // !HWY_HAVE_SCALAR_F16_OPERATORS +} + +HWY_API HWY_F16_CONSTEXPR float16_t F16FromF64(double f64) { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return float16_t(static_cast(f64)); +#else + // The mantissa bits of f64 are first rounded using round-to-odd rounding + // to the nearest f64 value that has the lower 29 bits zeroed out to + // ensure that the result is correctly rounded to a F16. + + // The F64 round-to-odd operation below will round a normal F64 value + // (using round-to-odd rounding) to a F64 value that has 24 bits of precision. + + // It is okay if the magnitude of a denormal F64 value is rounded up in the + // F64 round-to-odd step below as the magnitude of a denormal F64 value is + // much smaller than 2^(-24) (the smallest positive denormal F16 value). + + // It is also okay if bit 29 of a NaN F64 value is changed by the F64 + // round-to-odd step below as the lower 13 bits of a F32 NaN value are usually + // discarded or ignored by the conversion of a F32 NaN value to a F16. 
+ + // If f64 is a NaN value, the result of the F64 round-to-odd step will be a + // NaN value as the result of the F64 round-to-odd step will have at least one + // mantissa bit if f64 is a NaN value. + + // The F64 round-to-odd step will ensure that the F64 to F32 conversion is + // exact if the magnitude of the rounded F64 value (using round-to-odd + // rounding) is between 2^(-126) (the smallest normal F32 value) and + // HighestValue() (the largest finite F32 value) + + // It is okay if the F64 to F32 conversion is inexact for F64 values that have + // a magnitude that is less than 2^(-126) as the magnitude of a denormal F32 + // value is much smaller than 2^(-24) (the smallest positive denormal F16 + // value). + + return F16FromF32( + static_cast(BitCastScalar(static_cast( + (BitCastScalar(f64) & 0xFFFFFFFFE0000000ULL) | + ((BitCastScalar(f64) + 0x000000001FFFFFFFULL) & + 0x0000000020000000ULL))))); +#endif +} + +// More convenient to define outside float16_t because these may use +// F32FromF16, which is defined after the struct. 
+HWY_F16_CONSTEXPR inline bool operator==(float16_t lhs, + float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native == rhs.native; +#else + return F32FromF16(lhs) == F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator!=(float16_t lhs, + float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native != rhs.native; +#else + return F32FromF16(lhs) != F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator<(float16_t lhs, float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native < rhs.native; +#else + return F32FromF16(lhs) < F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator<=(float16_t lhs, + float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native <= rhs.native; +#else + return F32FromF16(lhs) <= F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator>(float16_t lhs, float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native > rhs.native; +#else + return F32FromF16(lhs) > F32FromF16(rhs); +#endif +} +HWY_F16_CONSTEXPR inline bool operator>=(float16_t lhs, + float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native >= rhs.native; +#else + return F32FromF16(lhs) >= F32FromF16(rhs); +#endif +} +#if HWY_HAVE_CXX20_THREE_WAY_COMPARE +HWY_F16_CONSTEXPR inline std::partial_ordering operator<=>( + float16_t lhs, float16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_F16_OPERATORS + return lhs.native <=> rhs.native; +#else + return F32FromF16(lhs) <=> F32FromF16(rhs); +#endif +} +#endif // HWY_HAVE_CXX20_THREE_WAY_COMPARE + +//------------------------------------------------------------------------------ +// BF16 lane type + +// x86 compiler supports __bf16, not necessarily with operators. +// Disable in debug builds due to clang miscompiles as of 2025-07-22: casting +// bf16 <-> f32 in convert_test results in 0x2525 for 1.0 instead of 0x3f80. 
+// Reported at https://github.com/llvm/llvm-project/issues/151692. +#ifndef HWY_SSE2_HAVE_SCALAR_BF16_TYPE +#if HWY_ARCH_X86 && defined(__SSE2__) && \ + ((HWY_COMPILER_CLANG >= 1700 && !HWY_COMPILER_CLANGCL && \ + (!HWY_IS_DEBUG_BUILD || HWY_COMPILER3_CLANG >= 220101)) || \ + HWY_COMPILER_GCC_ACTUAL >= 1300) +#define HWY_SSE2_HAVE_SCALAR_BF16_TYPE 1 +#else +#define HWY_SSE2_HAVE_SCALAR_BF16_TYPE 0 +#endif +#endif // HWY_SSE2_HAVE_SCALAR_BF16_TYPE + +// Compiler supports __bf16, not necessarily with operators. +#if HWY_ARM_HAVE_SCALAR_BF16_TYPE || HWY_SSE2_HAVE_SCALAR_BF16_TYPE +#define HWY_HAVE_SCALAR_BF16_TYPE 1 +#else +#define HWY_HAVE_SCALAR_BF16_TYPE 0 +#endif + +#ifndef HWY_HAVE_SCALAR_BF16_OPERATORS +// Recent enough compiler also has operators. aarch64 clang 18 hits internal +// compiler errors on bf16 ToString, hence only enable on GCC for now. +// GCC >= 13 will insert a function call to the __extendbfsf2 helper function +// for scalar conversions from __bf16 to float. This is prohibitively expensive, +// so refrain from using scalar BF16 operators on these compiler versions. +// See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=121853 +#if HWY_HAVE_SCALAR_BF16_TYPE && (HWY_COMPILER_GCC_ACTUAL >= 1700) +#define HWY_HAVE_SCALAR_BF16_OPERATORS 1 +#else +#define HWY_HAVE_SCALAR_BF16_OPERATORS 0 +#endif +#endif // HWY_HAVE_SCALAR_BF16_OPERATORS + +#if HWY_HAVE_SCALAR_BF16_OPERATORS +#define HWY_BF16_CONSTEXPR constexpr +#else +#define HWY_BF16_CONSTEXPR HWY_BITCASTSCALAR_CONSTEXPR +#endif + +struct alignas(2) bfloat16_t { +#if HWY_HAVE_SCALAR_BF16_TYPE + using Native = __bf16; +#elif HWY_IDE + using Native = uint16_t; +#endif + + union { +#if HWY_HAVE_SCALAR_BF16_TYPE || HWY_IDE + // Accessed via NativeLaneType, and used directly if + // HWY_HAVE_SCALAR_BF16_OPERATORS. + Native native; +#endif + // Only accessed via NativeLaneType or U16LaneType. 
+ uint16_t bits; + }; + + // Default init and copying + bfloat16_t() noexcept = default; + constexpr bfloat16_t(bfloat16_t&&) noexcept = default; + constexpr bfloat16_t(const bfloat16_t&) noexcept = default; + bfloat16_t& operator=(bfloat16_t&& arg) noexcept = default; + bfloat16_t& operator=(const bfloat16_t& arg) noexcept = default; + +// Only enable implicit conversions if we have a native type. +#if HWY_HAVE_SCALAR_BF16_TYPE || HWY_IDE + constexpr bfloat16_t(Native arg) noexcept : native(arg) {} + constexpr operator Native() const noexcept { return native; } +#endif + +#if HWY_HAVE_SCALAR_BF16_TYPE + static HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t FromBits(uint16_t bits) { + return bfloat16_t(BitCastScalar(bits)); + } +#else + + private: + struct BF16FromU16BitsTag {}; + constexpr bfloat16_t(BF16FromU16BitsTag /*tag*/, uint16_t u16_bits) + : bits(u16_bits) {} + + public: + static constexpr bfloat16_t FromBits(uint16_t bits) { + return bfloat16_t(BF16FromU16BitsTag(), bits); + } +#endif + + // When backed by a native type, ensure the wrapper behaves like the native + // type by forwarding all operators. Unfortunately it seems difficult to reuse + // this code in a base class, so we repeat it in float16_t. 
+#if HWY_HAVE_SCALAR_BF16_OPERATORS || HWY_IDE + template , Native>() && + !IsSame, bfloat16_t>() && + IsConvertible()>* = nullptr> + constexpr bfloat16_t(T&& arg) noexcept( + noexcept(static_cast(DeclVal()))) + : native(static_cast(static_cast(arg))) {} + + template , Native>() && + !IsSame, bfloat16_t>() && + !IsConvertible() && + IsStaticCastable()>* = nullptr> + explicit constexpr bfloat16_t(T&& arg) noexcept( + noexcept(static_cast(DeclVal()))) + : native(static_cast(static_cast(arg))) {} + + HWY_CXX14_CONSTEXPR bfloat16_t& operator=(Native arg) noexcept { + native = arg; + return *this; + } + + // pre-decrement operator (--x) + HWY_CXX14_CONSTEXPR bfloat16_t& operator--() noexcept { + native = static_cast(native - Native{1}); + return *this; + } + + // post-decrement operator (x--) + HWY_CXX14_CONSTEXPR bfloat16_t operator--(int) noexcept { + bfloat16_t result = *this; + native = static_cast(native - Native{1}); + return result; + } + + // pre-increment operator (++x) + HWY_CXX14_CONSTEXPR bfloat16_t& operator++() noexcept { + native = static_cast(native + Native{1}); + return *this; + } + + // post-increment operator (x++) + HWY_CXX14_CONSTEXPR bfloat16_t operator++(int) noexcept { + bfloat16_t result = *this; + native = static_cast(native + Native{1}); + return result; + } + + constexpr bfloat16_t operator-() const noexcept { + return bfloat16_t(static_cast(-native)); + } + constexpr bfloat16_t operator+() const noexcept { return *this; } + + // Reduce clutter by generating `operator+` and `operator+=` etc. Note that + // we cannot token-paste `operator` and `+`, so pass it in as `op_func`. 
+#define HWY_BFLOAT16_BINARY_OP(op, op_func, assign_func) \ + constexpr bfloat16_t op_func(const bfloat16_t& rhs) const noexcept { \ + return bfloat16_t(static_cast(native op rhs.native)); \ + } \ + template , \ + typename RawResultT = \ + decltype(DeclVal() op DeclVal()), \ + typename ResultT = \ + detail::NativeSpecialFloatToWrapper, \ + HWY_IF_CASTABLE(RawResultT, ResultT)> \ + constexpr ResultT op_func(const T& rhs) const noexcept(noexcept( \ + static_cast(DeclVal() op DeclVal()))) { \ + return static_cast(native op static_cast(rhs)); \ + } \ + HWY_CXX14_CONSTEXPR hwy::bfloat16_t& assign_func( \ + const hwy::bfloat16_t& rhs) noexcept { \ + native = static_cast(native op rhs.native); \ + return *this; \ + } \ + template () op DeclVal()))> \ + HWY_CXX14_CONSTEXPR hwy::bfloat16_t& assign_func(const T& rhs) noexcept( \ + noexcept( \ + static_cast(DeclVal() op DeclVal()))) { \ + native = static_cast(native op rhs); \ + return *this; \ + } + HWY_BFLOAT16_BINARY_OP(+, operator+, operator+=) + HWY_BFLOAT16_BINARY_OP(-, operator-, operator-=) + HWY_BFLOAT16_BINARY_OP(*, operator*, operator*=) + HWY_BFLOAT16_BINARY_OP(/, operator/, operator/=) +#undef HWY_BFLOAT16_BINARY_OP + +#endif // HWY_HAVE_SCALAR_BF16_OPERATORS +}; +static_assert(sizeof(hwy::bfloat16_t) == 2, "Wrong size of bfloat16_t"); + +#pragma pack(pop) + +#if HWY_HAVE_SCALAR_BF16_TYPE +namespace detail { + +#if HWY_HAVE_SCALAR_BF16_OPERATORS +template +struct SpecialFloatUnwrapArithOpOperandT { + using type = hwy::bfloat16_t::Native; +}; +#endif + +template +struct NativeSpecialFloatToWrapperT { + using type = hwy::bfloat16_t; +}; + +} // namespace detail +#endif // HWY_HAVE_SCALAR_BF16_TYPE + +#if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 +namespace detail { + +template <> +struct BitCastScalarSrcCastHelper { +#if HWY_HAVE_SCALAR_BF16_TYPE + static HWY_INLINE constexpr const hwy::bfloat16_t::Native& CastSrcValRef( + const hwy::bfloat16_t& val) { + return val.native; + } +#else + static 
HWY_INLINE constexpr const uint16_t& CastSrcValRef( + const hwy::bfloat16_t& val) { + return val.bits; + } +#endif +}; + +} // namespace detail +#endif // HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 + +HWY_API HWY_BF16_CONSTEXPR float F32FromBF16(bfloat16_t bf) { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return static_cast(bf); +#else + return BitCastScalar(static_cast( + static_cast(BitCastScalar(bf)) << 16)); +#endif +} + +namespace detail { + +// Returns the increment to add to the bits of a finite F32 value to round a +// finite F32 to the nearest BF16 value +static HWY_INLINE HWY_MAYBE_UNUSED constexpr uint32_t F32BitsToBF16RoundIncr( + const uint32_t f32_bits) { + return static_cast(((f32_bits & 0x7FFFFFFFu) < 0x7F800000u) + ? (0x7FFFu + ((f32_bits >> 16) & 1u)) + : 0u); +} + +// If f32_bits is the bit representation of a NaN F32 value, make sure that +// bit 6 of the BF16 result is set to convert SNaN F32 values to QNaN BF16 +// values and to prevent NaN F32 values from being converted to an infinite +// BF16 value +static HWY_INLINE constexpr uint32_t BF16BitsIfSNAN(uint32_t f32_bits) { + return ((f32_bits & 0x7FFFFFFFu) > 0x7F800000u) ? (uint32_t{1} << 6) : 0; +} + +// Converts f32_bits (which is the bits of a F32 value) to BF16 bits, +// rounded to the nearest BF16 value +static HWY_INLINE HWY_MAYBE_UNUSED constexpr uint16_t F32BitsToBF16Bits( + const uint32_t f32_bits) { + return static_cast( + BF16BitsIfSNAN(f32_bits) | + ((f32_bits + F32BitsToBF16RoundIncr(f32_bits)) >> 16)); +} + +} // namespace detail + +HWY_API HWY_BF16_CONSTEXPR bfloat16_t BF16FromF32(float f) { + // The rounding mode is not specified in the C++ standard, so ignore + // `HWY_HAVE_SCALAR_BF16_OPERATORS` and only use our round to nearest.
+ return bfloat16_t::FromBits( + detail::F32BitsToBF16Bits(BitCastScalar(f))); +} + +HWY_API HWY_BF16_CONSTEXPR bfloat16_t BF16FromF64(double f64) { + // The mantissa bits of f64 are first rounded using round-to-odd rounding + // to the nearest f64 value that has the lower 38 bits zeroed out to + // ensure that the result is correctly rounded to a BF16. + + // The F64 round-to-odd operation below will round a normal F64 value + // (using round-to-odd rounding) to a F64 value that has 15 bits of precision. + + // It is okay if the magnitude of a denormal F64 value is rounded up in the + // F64 round-to-odd step below as the magnitude of a denormal F64 value is + // much smaller than 2^(-133) (the smallest positive denormal BF16 value). + + // It is also okay if bit 38 of a NaN F64 value is changed by the F64 + // round-to-odd step below as the lower 16 bits of a F32 NaN value are usually + // discarded or ignored by the conversion of a F32 NaN value to a BF16. + + // If f64 is a NaN value, the result of the F64 round-to-odd step will be a + // NaN value as the result of the F64 round-to-odd step will have at least one + // mantissa bit if f64 is a NaN value. + + // The F64 round-to-odd step below will ensure that the F64 to F32 conversion + // is exact if the magnitude of the rounded F64 value (using round-to-odd + // rounding) is between 2^(-135) (one-fourth of the smallest positive denormal + // BF16 value) and HighestValue() (the largest finite F32 value). + + // If |f64| is less than 2^(-135), the magnitude of the result of the F64 to + // F32 conversion is guaranteed to be less than or equal to 2^(-135), which + // ensures that the F32 to BF16 conversion is correctly rounded, even if the + // conversion of a rounded F64 value whose magnitude is less than 2^(-135) + // to a F32 is inexact. 
+ + return BF16FromF32( + static_cast(BitCastScalar(static_cast( + (BitCastScalar(f64) & 0xFFFFFFC000000000ULL) | + ((BitCastScalar(f64) + 0x0000003FFFFFFFFFULL) & + 0x0000004000000000ULL))))); +} + +// More convenient to define outside bfloat16_t because these may use +// F32FromBF16, which is defined after the struct. + +HWY_BF16_CONSTEXPR inline bool operator==(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native == rhs.native; +#else + return F32FromBF16(lhs) == F32FromBF16(rhs); +#endif +} + +HWY_BF16_CONSTEXPR inline bool operator!=(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native != rhs.native; +#else + return F32FromBF16(lhs) != F32FromBF16(rhs); +#endif +} +HWY_BF16_CONSTEXPR inline bool operator<(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native < rhs.native; +#else + return F32FromBF16(lhs) < F32FromBF16(rhs); +#endif +} +HWY_BF16_CONSTEXPR inline bool operator<=(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native <= rhs.native; +#else + return F32FromBF16(lhs) <= F32FromBF16(rhs); +#endif +} +HWY_BF16_CONSTEXPR inline bool operator>(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native > rhs.native; +#else + return F32FromBF16(lhs) > F32FromBF16(rhs); +#endif +} +HWY_BF16_CONSTEXPR inline bool operator>=(bfloat16_t lhs, + bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native >= rhs.native; +#else + return F32FromBF16(lhs) >= F32FromBF16(rhs); +#endif +} +#if HWY_HAVE_CXX20_THREE_WAY_COMPARE +HWY_BF16_CONSTEXPR inline std::partial_ordering operator<=>( + bfloat16_t lhs, bfloat16_t rhs) noexcept { +#if HWY_HAVE_SCALAR_BF16_OPERATORS + return lhs.native <=> rhs.native; +#else + return F32FromBF16(lhs) <=> F32FromBF16(rhs); +#endif +} +#endif // HWY_HAVE_CXX20_THREE_WAY_COMPARE + 
+//------------------------------------------------------------------------------ +// Type relations + +namespace detail { + +template +struct Relations; +template <> +struct Relations { + using Unsigned = uint8_t; + using Signed = int8_t; + using Wide = uint16_t; + enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint8_t; + using Signed = int8_t; + using Wide = int16_t; + enum { is_signed = 1, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint16_t; + using Signed = int16_t; + using Float = float16_t; + using Wide = uint32_t; + using Narrow = uint8_t; + enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint16_t; + using Signed = int16_t; + using Float = float16_t; + using Wide = int32_t; + using Narrow = int8_t; + enum { is_signed = 1, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; + using Wide = uint64_t; + using Narrow = uint16_t; + enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; + using Wide = int64_t; + using Narrow = int16_t; + enum { is_signed = 1, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; + using Wide = uint128_t; + using Narrow = uint32_t; + enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; + using Narrow = int32_t; + enum { is_signed = 1, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint128_t; + using Narrow = uint64_t; + enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint16_t; + 
using Signed = int16_t; + using Float = float16_t; + using Wide = float; + enum { is_signed = 1, is_float = 1, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint16_t; + using Signed = int16_t; + using Wide = float; + enum { is_signed = 1, is_float = 1, is_bf16 = 1 }; +}; +template <> +struct Relations { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; + using Wide = double; + using Narrow = float16_t; + enum { is_signed = 1, is_float = 1, is_bf16 = 0 }; +}; +template <> +struct Relations { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; + using Narrow = float; + enum { is_signed = 1, is_float = 1, is_bf16 = 0 }; +}; + +template +struct TypeFromSize; +template <> +struct TypeFromSize<1> { + using Unsigned = uint8_t; + using Signed = int8_t; +}; +template <> +struct TypeFromSize<2> { + using Unsigned = uint16_t; + using Signed = int16_t; + using Float = float16_t; +}; +template <> +struct TypeFromSize<4> { + using Unsigned = uint32_t; + using Signed = int32_t; + using Float = float; +}; +template <> +struct TypeFromSize<8> { + using Unsigned = uint64_t; + using Signed = int64_t; + using Float = double; +}; +template <> +struct TypeFromSize<16> { + using Unsigned = uint128_t; +}; + +} // namespace detail + +// Aliases for types of a different category, but the same size. +template +using MakeUnsigned = typename detail::Relations::Unsigned; +template +using MakeSigned = typename detail::Relations::Signed; +template +using MakeFloat = typename detail::Relations::Float; + +// Aliases for types of the same category, but different size. +template +using MakeWide = typename detail::Relations::Wide; +template +using MakeNarrow = typename detail::Relations::Narrow; + +// Obtain type from its size [bytes]. 
+template +using UnsignedFromSize = typename detail::TypeFromSize::Unsigned; +template +using SignedFromSize = typename detail::TypeFromSize::Signed; +template +using FloatFromSize = typename detail::TypeFromSize::Float; + +// Avoid confusion with SizeTag where the parameter is a lane size. +using UnsignedTag = SizeTag<0>; +using SignedTag = SizeTag<0x100>; // integer +using FloatTag = SizeTag<0x200>; +using SpecialTag = SizeTag<0x300>; + +template > +constexpr auto TypeTag() + -> hwy::SizeTag<((R::is_signed + R::is_float + R::is_bf16) << 8)> { + return hwy::SizeTag<((R::is_signed + R::is_float + R::is_bf16) << 8)>(); +} + +// For when we only want to distinguish FloatTag from everything else. +using NonFloatTag = SizeTag<0x400>; + +template > +constexpr auto IsFloatTag() -> hwy::SizeTag<(R::is_float ? 0x200 : 0x400)> { + return hwy::SizeTag<(R::is_float ? 0x200 : 0x400)>(); +} + +//------------------------------------------------------------------------------ +// Type traits + +template +HWY_API constexpr bool IsFloat3264() { + return IsSameEither, float, double>(); +} + +template +HWY_API constexpr bool IsFloat() { + // Cannot use T(1.25) != T(1) for float16_t, which can only be converted to or + // from a float, not compared. Include float16_t in case HWY_HAVE_FLOAT16=1. 
+ return IsSame, float16_t>() || IsFloat3264(); +} + +template +HWY_API constexpr bool IsSigned() { + return static_cast(0) > static_cast(-1); +} +template <> +constexpr bool IsSigned() { + return true; +} +template <> +constexpr bool IsSigned() { + return true; +} +template <> +constexpr bool IsSigned() { + return false; +} +template <> +constexpr bool IsSigned() { + return false; +} +template <> +constexpr bool IsSigned() { + return false; +} + +template +HWY_API constexpr bool IsUnsigned() { + return IsInteger() && !IsSigned(); +} + +template () && !IsIntegerLaneType()> +struct MakeLaneTypeIfIntegerT { + using type = T; +}; + +template +struct MakeLaneTypeIfIntegerT { + using type = hwy::If(), SignedFromSize, + UnsignedFromSize>; +}; + +template +using MakeLaneTypeIfInteger = typename MakeLaneTypeIfIntegerT::type; + +// Largest/smallest representable integer values. +template +HWY_API constexpr T LimitsMax() { + static_assert(IsInteger(), "Only for integer types"); + using TU = UnsignedFromSize; + return static_cast(IsSigned() ? (static_cast(~TU(0)) >> 1) + : static_cast(~TU(0))); +} +template +HWY_API constexpr T LimitsMin() { + static_assert(IsInteger(), "Only for integer types"); + return IsSigned() ? static_cast(-1) - LimitsMax() + : static_cast(0); +} + +// Largest/smallest representable value (integer or float). This naming avoids +// confusion with numeric_limits::min() (the smallest positive value). +// Cannot be constexpr because we use CopySameSize for [b]float16_t. 
+template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR T LowestValue() { + return LimitsMin(); +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t LowestValue() { + return bfloat16_t::FromBits(uint16_t{0xFF7Fu}); // -1.1111111 x 2^127 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t LowestValue() { + return float16_t::FromBits(uint16_t{0xFBFFu}); // -1.1111111111 x 2^15 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float LowestValue() { + return -3.402823466e+38F; +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double LowestValue() { + return -1.7976931348623158e+308; +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR T HighestValue() { + return LimitsMax(); +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t HighestValue() { + return bfloat16_t::FromBits(uint16_t{0x7F7Fu}); // 1.1111111 x 2^127 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t HighestValue() { + return float16_t::FromBits(uint16_t{0x7BFFu}); // 1.1111111111 x 2^15 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float HighestValue() { + return 3.402823466e+38F; +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double HighestValue() { + return 1.7976931348623158e+308; +} + +// Difference between 1.0 and the next representable value. Equal to +// 1 / (1ULL << MantissaBits()), but hard-coding ensures precision. 
+template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR T Epsilon() { + return 1; +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t Epsilon() { + return bfloat16_t::FromBits(uint16_t{0x3C00u}); // 0.0078125 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t Epsilon() { + return float16_t::FromBits(uint16_t{0x1400u}); // 0.0009765625 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float Epsilon() { + return 1.192092896e-7f; +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double Epsilon() { + return 2.2204460492503131e-16; +} + +// Returns width in bits of the mantissa field in IEEE binary16/32/64. +template +constexpr int MantissaBits() { + static_assert(sizeof(T) == 0, "Only instantiate the specializations"); + return 0; +} +template <> +constexpr int MantissaBits() { + return 7; +} +template <> +constexpr int MantissaBits() { + return 10; +} +template <> +constexpr int MantissaBits() { + return 23; +} +template <> +constexpr int MantissaBits() { + return 52; +} + +// Returns the (left-shifted by one bit) IEEE binary16/32/64 representation with +// the largest possible (biased) exponent field. Used by IsInf. +template +constexpr MakeSigned MaxExponentTimes2() { + return -(MakeSigned{1} << (MantissaBits() + 1)); +} + +// Returns bitmask of the sign bit in IEEE binary16/32/64. +template +constexpr MakeUnsigned SignMask() { + return MakeUnsigned{1} << (sizeof(T) * 8 - 1); +} + +// Returns bitmask of the exponent field in IEEE binary16/32/64. +template +constexpr MakeUnsigned ExponentMask() { + return (~(MakeUnsigned{1} << MantissaBits()) + 1) & + static_cast>(~SignMask()); +} + +// Returns bitmask of the mantissa field in IEEE binary16/32/64. +template +constexpr MakeUnsigned MantissaMask() { + return (MakeUnsigned{1} << MantissaBits()) - 1; +} + +// Returns 1 << mantissa_bits as a floating-point number. All integers whose +// absolute value are less than this can be represented exactly. 
+template +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T MantissaEnd() { + static_assert(sizeof(T) == 0, "Only instantiate the specializations"); + return 0; +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t MantissaEnd() { + return bfloat16_t::FromBits(uint16_t{0x4300u}); // 1.0 x 2^7 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t MantissaEnd() { + return float16_t::FromBits(uint16_t{0x6400u}); // 1.0 x 2^10 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float MantissaEnd() { + return 8388608.0f; // 1 << 23 +} +template <> +HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double MantissaEnd() { + // floating point literal with p52 requires C++17. + return 4503599627370496.0; // 1 << 52 +} + +// Returns width in bits of the exponent field in IEEE binary16/32/64. +template +constexpr int ExponentBits() { + // Exponent := remaining bits after deducting sign and mantissa. + return 8 * sizeof(T) - 1 - MantissaBits(); +} + +// Returns largest value of the biased exponent field in IEEE binary16/32/64, +// right-shifted so that the LSB is bit zero. Example: 0xFF for float. +// This is expressed as a signed integer for more efficient comparison. 
+template +constexpr MakeSigned MaxExponentField() { + return (MakeSigned{1} << ExponentBits()) - 1; +} + +namespace detail { + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BITCASTSCALAR_CONSTEXPR T +NegativeInfOrLowestValue(hwy::FloatTag /* tag */) { + return BitCastScalar( + static_cast>(SignMask() | ExponentMask())); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BITCASTSCALAR_CONSTEXPR T +NegativeInfOrLowestValue(hwy::NonFloatTag /* tag */) { + return LowestValue(); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BITCASTSCALAR_CONSTEXPR T +PositiveInfOrHighestValue(hwy::FloatTag /* tag */) { + return BitCastScalar(ExponentMask()); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BITCASTSCALAR_CONSTEXPR T +PositiveInfOrHighestValue(hwy::NonFloatTag /* tag */) { + return HighestValue(); +} + +} // namespace detail + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR T NegativeInfOrLowestValue() { + return detail::NegativeInfOrLowestValue(IsFloatTag()); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR T PositiveInfOrHighestValue() { + return detail::PositiveInfOrHighestValue(IsFloatTag()); +} + +//------------------------------------------------------------------------------ +// Additional F16/BF16 operators + +#if HWY_HAVE_SCALAR_F16_OPERATORS || HWY_HAVE_SCALAR_BF16_OPERATORS + +#define HWY_RHS_SPECIAL_FLOAT_ARITH_OP(op, op_func, T2) \ + template < \ + typename T1, \ + hwy::EnableIf>() || \ + hwy::IsFloat3264>()>* = nullptr, \ + typename RawResultT = decltype(DeclVal() op DeclVal()), \ + typename ResultT = detail::NativeSpecialFloatToWrapper, \ + HWY_IF_CASTABLE(RawResultT, ResultT)> \ + static HWY_INLINE constexpr ResultT op_func(T1 a, T2 b) noexcept { \ + return static_cast(a op b.native); \ + } + +#define HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(op, assign_op, T2) \ + template >() || \ + hwy::IsFloat3264>()>* = nullptr, \ + typename ResultT = \ + decltype(DeclVal() assign_op DeclVal())> \ + static HWY_INLINE constexpr ResultT operator 
assign_op(T1& a, \ + T2 b) noexcept { \ + return (a assign_op b.native); \ + } + +#define HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(op, op_func, T1) \ + HWY_RHS_SPECIAL_FLOAT_ARITH_OP(op, op_func, T1) \ + template < \ + typename T2, \ + hwy::EnableIf>() || \ + hwy::IsFloat3264>()>* = nullptr, \ + typename RawResultT = decltype(DeclVal() op DeclVal()), \ + typename ResultT = detail::NativeSpecialFloatToWrapper, \ + HWY_IF_CASTABLE(RawResultT, ResultT)> \ + static HWY_INLINE constexpr ResultT op_func(T1 a, T2 b) noexcept { \ + return static_cast(a.native op b); \ + } + +#if HWY_HAVE_SCALAR_F16_OPERATORS +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(+, operator+, float16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(-, operator-, float16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(*, operator*, float16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(/, operator/, float16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(+, +=, float16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(-, -=, float16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(*, *=, float16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(/, /=, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(==, operator==, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(!=, operator!=, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<, operator<, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=, operator<=, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>, operator>, float16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>=, operator>=, float16_t) +#if HWY_HAVE_CXX20_THREE_WAY_COMPARE +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=>, operator<=>, float16_t) +#endif +#endif // HWY_HAVE_SCALAR_F16_OPERATORS + +#if HWY_HAVE_SCALAR_BF16_OPERATORS +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(+, operator+, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(-, operator-, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(*, operator*, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ARITH_OP(/, operator/, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(+, +=, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(-, 
-=, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(*, *=, bfloat16_t) +HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP(/, /=, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(==, operator==, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(!=, operator!=, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<, operator<, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=, operator<=, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>, operator>, bfloat16_t) +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>=, operator>=, bfloat16_t) +#if HWY_HAVE_CXX20_THREE_WAY_COMPARE +HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=>, operator<=>, bfloat16_t) +#endif +#endif // HWY_HAVE_SCALAR_BF16_OPERATORS + +#undef HWY_RHS_SPECIAL_FLOAT_ARITH_OP +#undef HWY_RHS_SPECIAL_FLOAT_ASSIGN_OP +#undef HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP + +#endif // HWY_HAVE_SCALAR_F16_OPERATORS || HWY_HAVE_SCALAR_BF16_OPERATORS + +//------------------------------------------------------------------------------ +// Type conversions (after IsSpecialFloat) + +HWY_API float F32FromF16Mem(const void* ptr) { + float16_t f16; + CopyBytes<2>(HWY_ASSUME_ALIGNED(ptr, 2), &f16); + return F32FromF16(f16); +} + +HWY_API float F32FromBF16Mem(const void* ptr) { + bfloat16_t bf; + CopyBytes<2>(HWY_ASSUME_ALIGNED(ptr, 2), &bf); + return F32FromBF16(bf); +} + +#if HWY_HAVE_SCALAR_F16_OPERATORS +#define HWY_BF16_TO_F16_CONSTEXPR HWY_BF16_CONSTEXPR +#else +#define HWY_BF16_TO_F16_CONSTEXPR HWY_F16_CONSTEXPR +#endif + +namespace detail { + +template +static HWY_INLINE HWY_MAYBE_UNUSED constexpr TTo ConvertScalarToResult( + hwy::SizeTag<0> /*conv_to_tag*/, TFrom in) { + return static_cast(static_cast(in)); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_F16_CONSTEXPR TTo +ConvertScalarToResult(hwy::FloatTag /*conv_to_tag*/, float in) { + return F16FromF32(in); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_F16_CONSTEXPR TTo +ConvertScalarToResult(hwy::FloatTag /*conv_to_tag*/, 
double in) { + return F16FromF64(in); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BF16_CONSTEXPR TTo +ConvertScalarToResult(hwy::SpecialTag /*conv_to_tag*/, float in) { + return BF16FromF32(in); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BF16_CONSTEXPR TTo +ConvertScalarToResult(hwy::SpecialTag /*conv_to_tag*/, double in) { + return BF16FromF64(in); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BF16_CONSTEXPR float +ConvertScalarSpecialFloatToF32(hwy::SpecialTag /*conv_from_tag*/, TFrom in) { + return F32FromBF16(in); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_F16_CONSTEXPR float +ConvertScalarSpecialFloatToF32(hwy::SpecialTag /*conv_from_tag*/, TFrom in) { + return F32FromF16(in); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED constexpr auto +ConvertScalarSpecialFloatToF32(hwy::FloatTag /*conv_from_tag*/, TFrom in) + -> hwy::If, double>(), double, float> { + return static_cast< + hwy::If, double>(), double, float>>( + in); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED constexpr TFrom +ConvertScalarSpecialFloatToF32(hwy::SizeTag<0> /*conv_from_tag*/, TFrom in) { + return static_cast(in); +} + +} // namespace detail + +template +HWY_API constexpr TTo ConvertScalarTo(TFrom in) { + return detail::ConvertScalarToResult( + hwy::SizeTag< + (!hwy::IsSame, hwy::RemoveCvRef>() && + hwy::IsSpecialFloat()) + ? (hwy::IsSame, hwy::bfloat16_t>() ? 0x300 + : 0x200) + : 0>(), + detail::ConvertScalarSpecialFloatToF32( + hwy::SizeTag< + (!hwy::IsSame, hwy::RemoveCvRef>() && + (hwy::IsSpecialFloat() || hwy::IsSpecialFloat())) + ? (hwy::IsSpecialFloat() ? 0x300 : 0x200) + : 0>(), + static_cast(in))); +} + +//------------------------------------------------------------------------------ +// Helper functions + +template +constexpr inline T1 DivCeil(T1 a, T2 b) { +#if HWY_CXX_LANG >= 201703L + HWY_DASSERT(b != T2{0}); +#endif + return (a + b - 1) / b; +} + +// Works for any non-zero `align`; if a power of two, compiler emits ADD+AND. 
+constexpr inline size_t RoundUpTo(size_t what, size_t align) { + return DivCeil(what, align) * align; +} + +// Works for any `align`; if a power of two, compiler emits AND. +constexpr inline size_t RoundDownTo(size_t what, size_t align) { + return what - (what % align); +} + +namespace detail { + +// T is unsigned or T is signed and (val >> shift_amt) is an arithmetic right +// shift +template +static HWY_INLINE constexpr T ScalarShr(hwy::UnsignedTag /*type_tag*/, T val, + int shift_amt) { + return static_cast(val >> shift_amt); +} + +// T is signed and (val >> shift_amt) is a non-arithmetic right shift +template +static HWY_INLINE constexpr T ScalarShr(hwy::SignedTag /*type_tag*/, T val, + int shift_amt) { + using TU = MakeUnsigned>; + return static_cast( + (val < 0) ? static_cast( + ~(static_cast(~static_cast(val)) >> shift_amt)) + : static_cast(static_cast(val) >> shift_amt)); +} + +} // namespace detail + +// If T is an signed integer type, ScalarShr is guaranteed to perform an +// arithmetic right shift + +// Otherwise, if T is an unsigned integer type, ScalarShr is guaranteed to +// perform a logical right shift +template )> +HWY_API constexpr RemoveCvRef ScalarShr(T val, int shift_amt) { + using NonCvRefT = RemoveCvRef; + return detail::ScalarShr( + hwy::SizeTag<((IsSigned() && + (LimitsMin() >> (sizeof(T) * 8 - 1)) != + static_cast(-1)) + ? 0x100 + : 0)>(), + static_cast(val), shift_amt); +} + +// Undefined results for x == 0. 
+HWY_API size_t Num0BitsBelowLS1Bit_Nonzero32(const uint32_t x) { + HWY_DASSERT(x != 0); +#if HWY_COMPILER_MSVC + unsigned long index; // NOLINT + _BitScanForward(&index, x); + return index; +#else // HWY_COMPILER_MSVC + return static_cast(__builtin_ctz(x)); +#endif // HWY_COMPILER_MSVC +} + +HWY_API size_t Num0BitsBelowLS1Bit_Nonzero64(const uint64_t x) { + HWY_DASSERT(x != 0); +#if HWY_COMPILER_MSVC +#if HWY_ARCH_X86_64 + unsigned long index; // NOLINT + _BitScanForward64(&index, x); + return index; +#else // HWY_ARCH_X86_64 + // _BitScanForward64 not available + uint32_t lsb = static_cast(x & 0xFFFFFFFF); + unsigned long index; // NOLINT + if (lsb == 0) { + uint32_t msb = static_cast(x >> 32u); + _BitScanForward(&index, msb); + return 32 + index; + } else { + _BitScanForward(&index, lsb); + return index; + } +#endif // HWY_ARCH_X86_64 +#else // HWY_COMPILER_MSVC + return static_cast(__builtin_ctzll(x)); +#endif // HWY_COMPILER_MSVC +} + +// Undefined results for x == 0. +HWY_API size_t Num0BitsAboveMS1Bit_Nonzero32(const uint32_t x) { + HWY_DASSERT(x != 0); +#if HWY_COMPILER_MSVC + unsigned long index; // NOLINT + _BitScanReverse(&index, x); + return 31 - index; +#else // HWY_COMPILER_MSVC + return static_cast(__builtin_clz(x)); +#endif // HWY_COMPILER_MSVC +} + +HWY_API size_t Num0BitsAboveMS1Bit_Nonzero64(const uint64_t x) { + HWY_DASSERT(x != 0); +#if HWY_COMPILER_MSVC +#if HWY_ARCH_X86_64 + unsigned long index; // NOLINT + _BitScanReverse64(&index, x); + return 63 - index; +#else // HWY_ARCH_X86_64 + // _BitScanReverse64 not available + const uint32_t msb = static_cast(x >> 32u); + unsigned long index; // NOLINT + if (msb == 0) { + const uint32_t lsb = static_cast(x & 0xFFFFFFFF); + _BitScanReverse(&index, lsb); + return 63 - index; + } else { + _BitScanReverse(&index, msb); + return 31 - index; + } +#endif // HWY_ARCH_X86_64 +#else // HWY_COMPILER_MSVC + return static_cast(__builtin_clzll(x)); +#endif // HWY_COMPILER_MSVC +} + +template ), + 
HWY_IF_T_SIZE_ONE_OF(RemoveCvRef, (1 << 1) | (1 << 2) | (1 << 4))> +HWY_API size_t PopCount(T x) { + uint32_t u32_x = static_cast( + static_cast)>>(x)); + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + return static_cast(__builtin_popcountl(u32_x)); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_32 && defined(__AVX__) + return static_cast(_mm_popcnt_u32(u32_x)); +#else + u32_x -= ((u32_x >> 1) & 0x55555555u); + u32_x = (((u32_x >> 2) & 0x33333333u) + (u32_x & 0x33333333u)); + u32_x = (((u32_x >> 4) + u32_x) & 0x0F0F0F0Fu); + u32_x += (u32_x >> 8); + u32_x += (u32_x >> 16); + return static_cast(u32_x & 0x3Fu); +#endif +} + +template ), + HWY_IF_T_SIZE(RemoveCvRef, 8)> +HWY_API size_t PopCount(T x) { + uint64_t u64_x = static_cast( + static_cast)>>(x)); + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + return static_cast(__builtin_popcountll(u64_x)); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 && defined(__AVX__) + return _mm_popcnt_u64(u64_x); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_32 && defined(__AVX__) + return _mm_popcnt_u32(static_cast(u64_x & 0xFFFFFFFFu)) + + _mm_popcnt_u32(static_cast(u64_x >> 32)); +#else + u64_x -= ((u64_x >> 1) & 0x5555555555555555ULL); + u64_x = (((u64_x >> 2) & 0x3333333333333333ULL) + + (u64_x & 0x3333333333333333ULL)); + u64_x = (((u64_x >> 4) + u64_x) & 0x0F0F0F0F0F0F0F0FULL); + u64_x += (u64_x >> 8); + u64_x += (u64_x >> 16); + u64_x += (u64_x >> 32); + return static_cast(u64_x & 0x7Fu); +#endif +} + +// Skip HWY_API due to GCC "function not considered for inlining". Previously +// such errors were caused by underlying type mismatches, but it's not clear +// what is still mismatched despite all the casts. +template +/*HWY_API*/ constexpr size_t FloorLog2(TI x) { + return x == TI{1} + ? 0 + : static_cast(FloorLog2(static_cast(x >> 1)) + 1); +} + +template +/*HWY_API*/ constexpr size_t CeilLog2(TI x) { + return x == TI{1} + ? 
0 + : static_cast(FloorLog2(static_cast(x - 1)) + 1); +} + +template +HWY_INLINE constexpr T AddWithWraparound(T t, T2 increment) { + return t + static_cast(increment); +} + +template +HWY_INLINE constexpr T AddWithWraparound(T t, T2 increment) { + return ConvertScalarTo(ConvertScalarTo(t) + + ConvertScalarTo(increment)); +} + +template +HWY_INLINE constexpr T AddWithWraparound(T t, T2 n) { + using TU = MakeUnsigned; + // Sub-int types would promote to int, not unsigned, which would trigger + // warnings, so first promote to the largest unsigned type. Due to + // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87519, which affected GCC 8 + // until fixed in 9.3, we use built-in types rather than uint64_t. + return static_cast(static_cast( + static_cast(static_cast(t) + + static_cast(n)) & + uint64_t{hwy::LimitsMax()})); +} + +#if HWY_COMPILER_MSVC && HWY_ARCH_X86_64 +#pragma intrinsic(_mul128) +#pragma intrinsic(_umul128) +#endif + +// 64 x 64 = 128 bit multiplication +HWY_API uint64_t Mul128(uint64_t a, uint64_t b, uint64_t* HWY_RESTRICT upper) { +#if defined(__SIZEOF_INT128__) + __uint128_t product = + static_cast<__uint128_t>(a) * static_cast<__uint128_t>(b); + *upper = static_cast(product >> 64); + return static_cast(product & 0xFFFFFFFFFFFFFFFFULL); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 + return _umul128(a, b, upper); +#else + constexpr uint64_t kLo32 = 0xFFFFFFFFU; + const uint64_t lo_lo = (a & kLo32) * (b & kLo32); + const uint64_t hi_lo = (a >> 32) * (b & kLo32); + const uint64_t lo_hi = (a & kLo32) * (b >> 32); + const uint64_t hi_hi = (a >> 32) * (b >> 32); + const uint64_t t = (lo_lo >> 32) + (hi_lo & kLo32) + lo_hi; + *upper = (hi_lo >> 32) + (t >> 32) + hi_hi; + return (t << 32) | (lo_lo & kLo32); +#endif +} + +HWY_API int64_t Mul128(int64_t a, int64_t b, int64_t* HWY_RESTRICT upper) { +#if defined(__SIZEOF_INT128__) + __int128_t product = static_cast<__int128_t>(a) * static_cast<__int128_t>(b); + *upper = static_cast(product >> 64); + return 
static_cast(product & 0xFFFFFFFFFFFFFFFFULL); +#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 + return _mul128(a, b, upper); +#else + uint64_t unsigned_upper; + const int64_t lower = static_cast(Mul128( + static_cast(a), static_cast(b), &unsigned_upper)); + *upper = static_cast( + unsigned_upper - + (static_cast(ScalarShr(a, 63)) & static_cast(b)) - + (static_cast(ScalarShr(b, 63)) & static_cast(a))); + return lower; +#endif +} + +// Precomputation for fast n / divisor and n % divisor, where n is a variable +// and divisor is unchanging but unknown at compile-time. +class Divisor { + public: + explicit Divisor(uint32_t divisor) : divisor_(divisor) { + if (divisor <= 1) return; + + const uint32_t len = + static_cast(31 - Num0BitsAboveMS1Bit_Nonzero32(divisor - 1)); + const uint64_t u_hi = (2ULL << len) - divisor; + const uint32_t q = Truncate((u_hi << 32) / divisor); + + mul_ = q + 1; + shift1_ = 1; + shift2_ = len; + } + + uint32_t GetDivisor() const { return divisor_; } + + // Returns n / divisor_. + uint32_t Divide(uint32_t n) const { + const uint64_t mul = mul_; + const uint32_t t = Truncate((mul * n) >> 32); + return (t + ((n - t) >> shift1_)) >> shift2_; + } + + // Returns n % divisor_. + uint32_t Remainder(uint32_t n) const { return n - (Divide(n) * divisor_); } + + private: + static uint32_t Truncate(uint64_t x) { + return static_cast(x & 0xFFFFFFFFu); + } + + uint32_t divisor_; + uint32_t mul_ = 1; + uint32_t shift1_ = 0; + uint32_t shift2_ = 0; +}; + +#ifndef HWY_HAVE_DIV128 // allow override +// Exclude clang-cl because it calls __divti3 from clang_rt.builtins-x86_64, +// which is not linked in. +#if (HWY_COMPILER_MSVC >= 1920 && HWY_ARCH_X86_64) || \ + (defined(__SIZEOF_INT128__) && !HWY_COMPILER_CLANGCL) +#define HWY_HAVE_DIV128 1 +#else +#define HWY_HAVE_DIV128 0 +#endif +#endif // HWY_HAVE_DIV128 + +// Divisor64 can precompute the multiplicative inverse. 
+#if HWY_HAVE_DIV128 + +#if HWY_COMPILER_MSVC >= 1920 && HWY_ARCH_X86_64 +#pragma intrinsic(_udiv128) +#pragma intrinsic(__umulh) +#endif + +// As above, but for 64-bit divisors: more expensive to compute and initialize. +class Divisor64 { + public: + explicit Divisor64(uint64_t divisor) : divisor_(divisor) { + if (divisor <= 1) return; + + const uint64_t len = + static_cast(63 - Num0BitsAboveMS1Bit_Nonzero64(divisor - 1)); + const uint64_t u_hi = (2ULL << len) - divisor; + const uint64_t q = Div128(u_hi, divisor); + + mul_ = q + 1; + shift1_ = 1; + shift2_ = len; + } + + uint64_t GetDivisor() const { return divisor_; } + + // Returns n / divisor_. + uint64_t Divide(uint64_t n) const { + const uint64_t t = MulHigh(mul_, n); + return (t + ((n - t) >> shift1_)) >> shift2_; + } + + // Returns n % divisor_. + uint64_t Remainder(uint64_t n) const { return n - (Divide(n) * divisor_); } + + private: + uint64_t divisor_; + + static uint64_t Div128(uint64_t hi, uint64_t div) { +#if HWY_COMPILER_MSVC >= 1920 && HWY_ARCH_X86_64 + unsigned __int64 remainder; // unused + return _udiv128(hi, uint64_t{0}, div, &remainder); +#else + using u128 = unsigned __int128; + const u128 hi128 = static_cast(hi) << 64; + return static_cast(hi128 / static_cast(div)); +#endif + } + + static uint64_t MulHigh(uint64_t a, uint64_t b) { +#if HWY_COMPILER_MSVC >= 1920 && HWY_ARCH_X86_64 + return __umulh(a, b); +#else + using u128 = unsigned __int128; + const u128 a128 = static_cast(a); + const u128 b128 = static_cast(b); + return static_cast((a128 * b128) >> 64); +#endif + } + + uint64_t mul_ = 1; + uint64_t shift1_ = 0; + uint64_t shift2_ = 0; +}; +#else +// No Div128 available, use built-in 64-bit division on each call. 
+class Divisor64 { + public: + explicit Divisor64(uint64_t divisor) : divisor_(divisor) {} + + uint64_t GetDivisor() const { return divisor_; } + + uint64_t Divide(uint64_t n) const { return n / divisor_; } + uint64_t Remainder(uint64_t n) const { return n % divisor_; } + + private: + uint64_t divisor_; +}; +#endif // HWY_HAVE_DIV128 + +namespace detail { + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T ScalarAbs(hwy::FloatTag /*tag*/, + T val) { + using TU = MakeUnsigned; + return BitCastScalar( + static_cast(BitCastScalar(val) & (~SignMask()))); +} + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T +ScalarAbs(hwy::SpecialTag /*tag*/, T val) { + return ScalarAbs(hwy::FloatTag(), val); +} + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T +ScalarAbs(hwy::SignedTag /*tag*/, T val) { + using TU = MakeUnsigned; + return (val < T{0}) ? static_cast(TU{0} - static_cast(val)) : val; +} + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T +ScalarAbs(hwy::UnsignedTag /*tag*/, T val) { + return val; +} + +} // namespace detail + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR RemoveCvRef ScalarAbs(T val) { + using TVal = MakeLaneTypeIfInteger< + detail::NativeSpecialFloatToWrapper>>; + return detail::ScalarAbs(hwy::TypeTag(), static_cast(val)); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsNaN(T val) { + using TF = detail::NativeSpecialFloatToWrapper>; + using TU = MakeUnsigned; + return (BitCastScalar(ScalarAbs(val)) > ExponentMask()); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsInf(T val) { + using TF = detail::NativeSpecialFloatToWrapper>; + using TU = MakeUnsigned; + return static_cast(BitCastScalar(static_cast(val)) << 1) == + static_cast(MaxExponentTimes2()); +} + +namespace detail { + +template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsFinite( + hwy::FloatTag /*tag*/, T val) { + using TU = MakeUnsigned; + return (BitCastScalar(hwy::ScalarAbs(val)) < ExponentMask()); +} + 
+template +static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsFinite( + hwy::NonFloatTag /*tag*/, T /*val*/) { + // Integer values are always finite + return true; +} + +} // namespace detail + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsFinite(T val) { + using TVal = MakeLaneTypeIfInteger< + detail::NativeSpecialFloatToWrapper>>; + return detail::ScalarIsFinite(hwy::IsFloatTag(), + static_cast(val)); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR RemoveCvRef ScalarCopySign(T magn, + T sign) { + using TF = RemoveCvRef>>; + using TU = MakeUnsigned; + return BitCastScalar(static_cast( + (BitCastScalar(static_cast(magn)) & (~SignMask())) | + (BitCastScalar(static_cast(sign)) & SignMask()))); +} + +template +HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarSignBit(T val) { + using TVal = MakeLaneTypeIfInteger< + detail::NativeSpecialFloatToWrapper>>; + using TU = MakeUnsigned; + return ((BitCastScalar(static_cast(val)) & SignMask()) != 0); +} + +// Prevents the compiler from eliding the computations that led to "output". +#if HWY_ARCH_PPC && (HWY_COMPILER_GCC || HWY_COMPILER_CLANG) && \ + !defined(_SOFT_FLOAT) +// Workaround to avoid test failures on PPC if compiled with Clang +template +HWY_API void PreventElision(T&& output) { + asm volatile("" : "+f"(output)::"memory"); +} +template +HWY_API void PreventElision(T&& output) { + asm volatile("" : "+d"(output)::"memory"); +} +template +HWY_API void PreventElision(T&& output) { + asm volatile("" : "+r"(output)::"memory"); +} +#else +template +HWY_API void PreventElision(T&& output) { +#if HWY_COMPILER_MSVC + // MSVC does not support inline assembly anymore (and never supported GCC's + // RTL constraints). Self-assignment with #pragma optimize("off") might be + // expected to prevent elision, but it does not with MSVC 2015. Type-punning + // with volatile pointers generates inefficient code on MSVC 2017. 
+ static std::atomic> sink; + sink.store(output, std::memory_order_relaxed); +#else + // Works by indicating to the compiler that "output" is being read and + // modified. The +r constraint avoids unnecessary writes to memory, but only + // works for built-in types (typically FuncOutput). + asm volatile("" : "+r"(output) : : "memory"); +#endif +} +#endif + +} // namespace hwy + +#endif // HIGHWAY_HWY_BASE_H_ diff --git a/lib/highway/hwy/bit_set.h b/lib/highway/hwy/bit_set.h new file mode 100644 index 00000000000..94251d133c8 --- /dev/null +++ b/lib/highway/hwy/bit_set.h @@ -0,0 +1,410 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_BIT_SET_H_ +#define HIGHWAY_HWY_BIT_SET_H_ + +// Various BitSet for 64, up to 4096, or any number of bits. + +#include + +#include + +#include "hwy/base.h" + +namespace hwy { + +// 64-bit specialization of `std::bitset`, which lacks `Foreach`. +class BitSet64 { + public: + constexpr size_t MaxSize() const { return 64; } + + // No harm if `i` is already set. + void Set(size_t i) { + HWY_DASSERT(i < 64); + bits_ |= (1ULL << i); + HWY_DASSERT(Get(i)); + } + + // Equivalent to Set(i) for i in [0, 64) where (bits >> i) & 1. This does + // not clear any existing bits. 
+ void SetNonzeroBitsFrom64(uint64_t bits) { bits_ |= bits; } + + void Clear(size_t i) { + HWY_DASSERT(i < 64); + bits_ &= ~(1ULL << i); + } + + bool Get(size_t i) const { + HWY_DASSERT(i < 64); + return (bits_ & (1ULL << i)) != 0; + } + + // Returns true if Get(i) would return true for any i in [0, 64). + bool Any() const { return bits_ != 0; } + + // Returns true if Get(i) would return true for all i in [0, 64). + bool All() const { return bits_ == ~uint64_t{0}; } + + // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`! + size_t First() const { + HWY_DASSERT(Any()); + return Num0BitsBelowLS1Bit_Nonzero64(bits_); + } + + // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`! + size_t First0() const { + HWY_DASSERT(!All()); + return Num0BitsBelowLS1Bit_Nonzero64(~bits_); + } + + // Returns uint64_t(Get(i)) << i for i in [0, 64). + uint64_t Get64() const { return bits_; } + + // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify + // the set, but the current Foreach call is unaffected. + template + void Foreach(const Func& func) const { + uint64_t remaining_bits = bits_; + while (remaining_bits != 0) { + const size_t i = Num0BitsBelowLS1Bit_Nonzero64(remaining_bits); + remaining_bits &= remaining_bits - 1; // clear LSB + func(i); + } + } + + size_t Count() const { return PopCount(bits_); } + + private: + uint64_t bits_ = 0; +}; + +// Any number of bits, flat array. +template +class BitSet { + static_assert(kMaxSize != 0, "BitSet requires non-zero size"); + + public: + constexpr size_t MaxSize() const { return kMaxSize; } + + // No harm if `i` is already set. 
+ void Set(size_t i) { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + bits_[idx].Set(mod); + } + + void Clear(size_t i) { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + bits_[idx].Clear(mod); + HWY_DASSERT(!Get(i)); + } + + bool Get(size_t i) const { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + return bits_[idx].Get(mod); + } + + // Returns true if Get(i) would return true for any i in [0, kMaxSize). + bool Any() const { + for (const BitSet64& bits : bits_) { + if (bits.Any()) return true; + } + return false; + } + + // Returns true if Get(i) would return true for all i in [0, kMaxSize). + bool All() const { + for (size_t idx = 0; idx < kNum64 - 1; ++idx) { + if (!bits_[idx].All()) return false; + } + + constexpr size_t kRemainder = kMaxSize % 64; + if (kRemainder == 0) { + return bits_[kNum64 - 1].All(); + } + return bits_[kNum64 - 1].Count() == kRemainder; + } + + // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`! + size_t First() const { + HWY_DASSERT(Any()); + for (size_t idx = 0;; ++idx) { + HWY_DASSERT(idx < kNum64); + if (bits_[idx].Any()) return idx * 64 + bits_[idx].First(); + } + } + + // Returns lowest i such that `!Get(i)`. Caller must first ensure `All()`! + size_t First0() const { + HWY_DASSERT(!All()); + for (size_t idx = 0;; ++idx) { + HWY_DASSERT(idx < kNum64); + if (!bits_[idx].All()) { + const size_t first0 = idx * 64 + bits_[idx].First0(); + HWY_DASSERT(first0 < kMaxSize); + return first0; + } + } + } + + // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify + // the set, but the current Foreach call is only affected if changing one of + // the not yet visited BitSet64. 
+ template + void Foreach(const Func& func) const { + for (size_t idx = 0; idx < kNum64; ++idx) { + bits_[idx].Foreach([idx, &func](size_t mod) { func(idx * 64 + mod); }); + } + } + + size_t Count() const { + size_t total = 0; + for (const BitSet64& bits : bits_) { + total += bits.Count(); + } + return total; + } + + private: + static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64}); + BitSet64 bits_[kNum64]; +}; + +// Any number of bits, flat array, atomic updates to the u64. +template +class AtomicBitSet { + static_assert(kMaxSize != 0, "AtomicBitSet requires non-zero size"); + + // Bits may signal something to other threads, hence relaxed is insufficient. + // Acq/Rel ensures a happens-before relationship. + static constexpr auto kAcq = std::memory_order_acquire; + static constexpr auto kRel = std::memory_order_release; + + public: + constexpr size_t MaxSize() const { return kMaxSize; } + + // No harm if `i` is already set. + void Set(size_t i) { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + bits_[idx].fetch_or(1ULL << mod, kRel); + } + + void Clear(size_t i) { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + bits_[idx].fetch_and(~(1ULL << mod), kRel); + HWY_DASSERT(!Get(i)); + } + + bool Get(size_t i) const { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + return ((bits_[idx].load(kAcq) & (1ULL << mod))) != 0; + } + + // Returns true if Get(i) would return true for any i in [0, kMaxSize). + bool Any() const { + for (const std::atomic& bits : bits_) { + if (bits.load(kAcq)) return true; + } + return false; + } + + // Returns true if Get(i) would return true for all i in [0, kMaxSize). 
+ bool All() const { + for (size_t idx = 0; idx < kNum64 - 1; ++idx) { + if (bits_[idx].load(kAcq) != ~uint64_t{0}) return false; + } + + constexpr size_t kRemainder = kMaxSize % 64; + const uint64_t last_bits = bits_[kNum64 - 1].load(kAcq); + if (kRemainder == 0) { + return last_bits == ~uint64_t{0}; + } + return PopCount(last_bits) == kRemainder; + } + + // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`! + size_t First() const { + HWY_DASSERT(Any()); + for (size_t idx = 0;; ++idx) { + HWY_DASSERT(idx < kNum64); + const uint64_t bits = bits_[idx].load(kAcq); + if (bits != 0) { + return idx * 64 + Num0BitsBelowLS1Bit_Nonzero64(bits); + } + } + } + + // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`! + size_t First0() const { + HWY_DASSERT(!All()); + for (size_t idx = 0;; ++idx) { + HWY_DASSERT(idx < kNum64); + const uint64_t inv_bits = ~bits_[idx].load(kAcq); + if (inv_bits != 0) { + const size_t first0 = + idx * 64 + Num0BitsBelowLS1Bit_Nonzero64(inv_bits); + HWY_DASSERT(first0 < kMaxSize); + return first0; + } + } + } + + // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify + // the set, but the current Foreach call is only affected if changing one of + // the not yet visited uint64_t. + template + void Foreach(const Func& func) const { + for (size_t idx = 0; idx < kNum64; ++idx) { + uint64_t remaining_bits = bits_[idx].load(kAcq); + while (remaining_bits != 0) { + const size_t i = Num0BitsBelowLS1Bit_Nonzero64(remaining_bits); + remaining_bits &= remaining_bits - 1; // clear LSB + func(idx * 64 + i); + } + } + } + + size_t Count() const { + size_t total = 0; + for (const std::atomic& bits : bits_) { + total += PopCount(bits.load(kAcq)); + } + return total; + } + + private: + static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64}); + std::atomic bits_[kNum64] = {}; +}; + +// Two-level bitset for up to `kMaxSize` <= 4096 values. 
The iterators +// (`Any/First/Foreach/Count`) are more efficient than `BitSet` for sparse sets. +// This comes at the cost of slightly slower mutators (`Set/Clear`). +template +class BitSet4096 { + static_assert(kMaxSize != 0, "BitSet4096 requires non-zero size"); + + public: + constexpr size_t MaxSize() const { return kMaxSize; } + + // No harm if `i` is already set. + void Set(size_t i) { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + bits_[idx].Set(mod); + nonzero_.Set(idx); + HWY_DASSERT(Get(i)); + } + + // Equivalent to Set(i) for i in [0, 64) where (bits >> i) & 1. This does + // not clear any existing bits. + void SetNonzeroBitsFrom64(uint64_t bits) { + bits_[0].SetNonzeroBitsFrom64(bits); + if (bits) nonzero_.Set(0); + } + + void Clear(size_t i) { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + bits_[idx].Clear(mod); + if (!bits_[idx].Any()) { + nonzero_.Clear(idx); + } + HWY_DASSERT(!Get(i)); + } + + bool Get(size_t i) const { + HWY_DASSERT(i < kMaxSize); + const size_t idx = i / 64; + const size_t mod = i % 64; + return bits_[idx].Get(mod); + } + + // Returns true if `Get(i)` would return true for any i in [0, kMaxSize). + bool Any() const { return nonzero_.Any(); } + + // Returns true if `Get(i)` would return true for all i in [0, kMaxSize). + bool All() const { + // Do not check `nonzero_.All()` - that only works if `kMaxSize` is 4096. + if (nonzero_.Count() != kNum64) return false; + return Count() == kMaxSize; + } + + // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`! + size_t First() const { + HWY_DASSERT(Any()); + const size_t idx = nonzero_.First(); + return idx * 64 + bits_[idx].First(); + } + + // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`! + size_t First0() const { + HWY_DASSERT(!All()); + // It is likely not worthwhile to have a separate `BitSet64` for `not_all_`, + // hence iterate over all u64. 
+ for (size_t idx = 0;; ++idx) { + HWY_DASSERT(idx < kNum64); + if (!bits_[idx].All()) { + const size_t first0 = idx * 64 + bits_[idx].First0(); + HWY_DASSERT(first0 < kMaxSize); + return first0; + } + } + } + + // Returns uint64_t(Get(i)) << i for i in [0, 64). + uint64_t Get64() const { return bits_[0].Get64(); } + + // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify + // the set, but the current Foreach call is only affected if changing one of + // the not yet visited BitSet64 for which Any() is true. + template + void Foreach(const Func& func) const { + nonzero_.Foreach([&func, this](size_t idx) { + bits_[idx].Foreach([idx, &func](size_t mod) { func(idx * 64 + mod); }); + }); + } + + size_t Count() const { + size_t total = 0; + nonzero_.Foreach( + [&total, this](size_t idx) { total += bits_[idx].Count(); }); + return total; + } + + private: + static_assert(kMaxSize <= 64 * 64, "One BitSet64 insufficient"); + static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64}); + BitSet64 nonzero_; + BitSet64 bits_[kNum64]; +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_BIT_SET_H_ diff --git a/lib/highway/hwy/cache_control.h b/lib/highway/hwy/cache_control.h new file mode 100644 index 00000000000..90743cd3f28 --- /dev/null +++ b/lib/highway/hwy/cache_control.h @@ -0,0 +1,143 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef HIGHWAY_HWY_CACHE_CONTROL_H_ +#define HIGHWAY_HWY_CACHE_CONTROL_H_ + +#include "hwy/aligned_allocator.h" // HWY_ALIGNMENT +#include "hwy/base.h" + +// Requires SSE2; fails to compile on 32-bit Clang 7 (see +// https://github.com/gperftools/gperftools/issues/946). +#if !defined(__SSE2__) || (HWY_COMPILER_CLANG && HWY_ARCH_X86_32) +#undef HWY_DISABLE_CACHE_CONTROL +#define HWY_DISABLE_CACHE_CONTROL +#endif + +#ifndef HWY_DISABLE_CACHE_CONTROL +// intrin.h is sufficient on MSVC and already included by base.h. +#if HWY_ARCH_X86 && !HWY_COMPILER_MSVC +#include // SSE2 +#include // _mm_prefetch +#elif HWY_ARCH_ARM_A64 +#include +#endif +#endif // HWY_DISABLE_CACHE_CONTROL + +namespace hwy { + +// Even if N*sizeof(T) is smaller, Stream may write a multiple of this size. +#define HWY_STREAM_MULTIPLE 16 + +// The following functions may also require an attribute. +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) && !HWY_COMPILER_MSVC +#define HWY_ATTR_CACHE __attribute__((target("sse2"))) +#else +#define HWY_ATTR_CACHE +#endif + +// Windows.h #defines this, which causes infinite recursion. Temporarily +// undefine to avoid conflict with our function. +// TODO(janwas): remove when this function is removed. +#pragma push_macro("LoadFence") +#undef LoadFence + +// Delays subsequent loads until prior loads are visible. Beware of potentially +// differing behavior across architectures and vendors: on Intel but not +// AMD CPUs, also serves as a full fence (waits for all prior instructions to +// complete). +HWY_INLINE HWY_ATTR_CACHE void LoadFence() { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_lfence(); +#endif +} + +// TODO(janwas): remove when this function is removed. (See above.) +#pragma pop_macro("LoadFence") + +// Overwrites "to" while attempting to bypass the cache (read-for-ownership). +// Both pointers must be aligned. 
+static HWY_INLINE void StreamCacheLine(const uint64_t* HWY_RESTRICT from, + uint64_t* HWY_RESTRICT to) { + HWY_DASSERT(IsAligned(from)); + HWY_DASSERT(IsAligned(to)); +#if HWY_COMPILER_CLANG && !defined(HWY_DISABLE_CACHE_CONTROL) + for (size_t i = 0; i < HWY_ALIGNMENT / sizeof(uint64_t); ++i) { + __builtin_nontemporal_store(from[i], to + i); + } +#else + hwy::CopyBytes(from, to, HWY_ALIGNMENT); +#endif +} + +// Ensures values written by previous `Stream` calls are visible on the current +// core. This is NOT sufficient for synchronizing across cores; when `Stream` +// outputs are to be consumed by other core(s), the producer must publish +// availability (e.g. via mutex or atomic_flag) after `FlushStream`. +HWY_INLINE HWY_ATTR_CACHE void FlushStream() { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_sfence(); +#endif +} + +// Optionally begins loading the cache line containing "p" to reduce latency of +// subsequent actual loads. +template +HWY_INLINE HWY_ATTR_CACHE void Prefetch(const T* p) { + (void)p; +#ifndef HWY_DISABLE_CACHE_CONTROL +// Use _mm_prefetch on x86/x64, except when clang-cl is compiled with -mno-mmx. +#if HWY_ARCH_X86 && !(HWY_COMPILER_CLANGCL && !defined(__MMX__)) + _mm_prefetch(reinterpret_cast(p), _MM_HINT_T0); +#elif HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL // includes clang + // Hint=0 (NTA) behavior differs, but skipping outer caches is probably not + // desirable, so use the default 3 (keep in caches). + __builtin_prefetch(p, /*write=*/0, /*hint=*/3); +#endif +#endif // HWY_DISABLE_CACHE_CONTROL +} + +// Invalidates and flushes the cache line containing "p", if possible. +HWY_INLINE HWY_ATTR_CACHE void FlushCacheline(const void* p) { +#if HWY_ARCH_X86 && !defined(HWY_DISABLE_CACHE_CONTROL) + _mm_clflush(p); +#else + (void)p; +#endif +} + +// Hints that we are inside a spin loop and potentially reduces power +// consumption and coherency traffic. 
For example, x86 avoids multiple +// outstanding load requests, which reduces the memory order violation penalty +// when exiting the loop. +HWY_INLINE HWY_ATTR_CACHE void Pause() { +#ifndef HWY_DISABLE_CACHE_CONTROL +#if HWY_ARCH_X86 + _mm_pause(); +#elif HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG + // This is documented in ACLE and the YIELD instruction is also available in + // Armv7, but the intrinsic is broken for Armv7 clang, hence A64 only. + __yield(); +#elif HWY_ARCH_ARM && HWY_COMPILER_GCC // includes clang + __asm__ volatile("yield" ::: "memory"); +#elif HWY_ARCH_PPC && HWY_COMPILER_GCC // includes clang + __asm__ volatile("or 27,27,27" ::: "memory"); +#endif +#endif // HWY_DISABLE_CACHE_CONTROL +} + +} // namespace hwy + +#endif // HIGHWAY_HWY_CACHE_CONTROL_H_ diff --git a/lib/highway/hwy/contrib/algo/copy-inl.h b/lib/highway/hwy/contrib/algo/copy-inl.h new file mode 100644 index 00000000000..9945132f729 --- /dev/null +++ b/lib/highway/hwy/contrib/algo/copy-inl.h @@ -0,0 +1,145 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ +#endif + +#include +#include + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// These functions avoid having to write a loop plus remainder handling in the +// (unfortunately still common) case where arrays are not aligned/padded. If the +// inputs are known to be aligned/padded, it is more efficient to write a single +// loop using Load(). We do not provide a CopyAlignedPadded because it +// would be more verbose than such a loop. + +// Fills `to`[0, `count`) with `value`. +template > +void Fill(D d, T value, size_t count, T* HWY_RESTRICT to) { + const size_t N = Lanes(d); + const Vec v = Set(d, value); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + StoreU(v, d, to + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + SafeFillN(remaining, value, d, to + idx); +} + +// Copies `from`[0, `count`) to `to`, which must not overlap `from`. +template > +void Copy(D d, const T* HWY_RESTRICT from, size_t count, T* HWY_RESTRICT to) { + const size_t N = Lanes(d); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, from + idx); + StoreU(v, d, to + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. 
+ if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + SafeCopyN(remaining, d, from + idx, to + idx); +} + +// For idx in [0, count) in ascending order, appends `from[idx]` to `to` if the +// corresponding mask element of `func(d, v)` is true. Returns the STL-style end +// of the newly written elements in `to`. +// +// `func` is either a functor with a templated operator()(d, v) returning a +// mask, or a generic lambda if using C++14. Due to apparent limitations of +// Clang on Windows, it is currently necessary to add HWY_ATTR before the +// opening { of the lambda to avoid errors about "function .. requires target". +// +// NOTE: this is only supported for 16-, 32- or 64-bit types. +// NOTE: Func may be called a second time for elements it has already seen, but +// these elements will not be written to `to` again. +template > +T* CopyIf(D d, const T* HWY_RESTRICT from, size_t count, T* HWY_RESTRICT to, + const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, from + idx); + to += CompressBlendedStore(v, func(d, v), d, to); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return to; + +#if HWY_MEM_OPS_MIGHT_FAULT + // Proceed one by one. + const CappedTag d1; + for (; idx < count; ++idx) { + using V1 = Vec; + // Workaround for -Waggressive-loop-optimizations on GCC 8 + // (iteration 2305843009213693951 invokes undefined behavior for T=i64) + const uintptr_t addr = reinterpret_cast(from); + const T* HWY_RESTRICT from_idx = + reinterpret_cast(addr + (idx * sizeof(T))); + const V1 v = LoadU(d1, from_idx); + // Avoid storing to `to` unless we know it should be kept - otherwise, we + // might overrun the end if it was allocated for the exact count. 
+ if (CountTrue(d1, func(d1, v)) == 0) continue; + StoreU(v, d1, to); + to += 1; + } +#else + // Start index of the last unaligned whole vector, ending at the array end. + const size_t last = count - N; + // Number of elements before `from` or already written. + const size_t invalid = idx - last; + HWY_DASSERT(0 != invalid && invalid < N); + const Mask mask = Not(FirstN(d, invalid)); + const Vec v = MaskedLoad(mask, d, from + last); + to += CompressBlendedStore(v, And(mask, func(d, v)), d, to); +#endif + return to; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_COPY_INL_H_ diff --git a/lib/highway/hwy/contrib/algo/count-inl.h b/lib/highway/hwy/contrib/algo/count-inl.h new file mode 100644 index 00000000000..c7d5a895bb2 --- /dev/null +++ b/lib/highway/hwy/contrib/algo/count-inl.h @@ -0,0 +1,393 @@ +// Copyright 2026 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_COUNT_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_COUNT_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_COUNT_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_COUNT_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Returns the number of elements in `in[0, count)` equal to `value`. +template > +size_t Count(D d, T value, const T* HWY_RESTRICT in, size_t count) { + const size_t N = Lanes(d); + using V = Vec; + const V broadcasted = Set(d, value); + const RebindToSigned di; + using TI = TFromD; + using VI = Vec; + + size_t total = 0; + size_t i = 0; + + // Min 4 lanes needed for two pairwise widenings, 8->16->32 + if constexpr (sizeof(T) == 1 && HWY_MAX_LANES_D(D) >= 4) { + const VI k1 = Set(di, TI{1}); + const RebindToUnsigned du; + const RepartitionToWide di16; + const Repartition di32; + auto wide_sum = Zero(di32); + + if (count >= 4 * N && N >= 4) { + while (i <= count - 4 * N) { + VI acc0 = Zero(di); + VI acc1 = Zero(di); + VI acc2 = Zero(di); + VI acc3 = Zero(di); + const size_t cap = HWY_MIN(i + 128 * 4 * N, count); + + if constexpr (HWY_NATIVE_MASK) { + for (; i <= cap - 4 * N; i += 4 * N) { + const auto m0 = RebindMask(di, Eq(broadcasted, LoadU(d, in + i))); + const auto m1 = + RebindMask(di, Eq(broadcasted, LoadU(d, in + i + N))); + const auto m2 = + RebindMask(di, Eq(broadcasted, LoadU(d, in + i + 2 * N))); + const auto m3 = + RebindMask(di, Eq(broadcasted, LoadU(d, in + i + 3 * N))); + acc0 = MaskedAddOr(acc0, m0, acc0, k1); + acc1 = MaskedAddOr(acc1, m1, acc1, k1); + acc2 = MaskedAddOr(acc2, m2, acc2, k1); + acc3 = MaskedAddOr(acc3, m3, acc3, k1); + } + } else { + for (; i <= cap - 4 * N; i += 4 * N) { + const auto v0 = VecFromMask(d, Eq(broadcasted, LoadU(d, in + i))); + const auto v1 = + VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + N))); + const auto v2 = + 
VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + 2 * N))); + const auto v3 = + VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + 3 * N))); + acc0 = Add(acc0, BitCast(di, v0)); + acc1 = Add(acc1, BitCast(di, v1)); + acc2 = Add(acc2, BitCast(di, v2)); + acc3 = Add(acc3, BitCast(di, v3)); + } + + acc0 = Neg(acc0); + acc1 = Neg(acc1); + acc2 = Neg(acc2); + acc3 = Neg(acc3); + } + + const auto w0 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc0), k1); + const auto w1 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc1), k1); + const auto w2 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc2), k1); + const auto w3 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc3), k1); + const auto sum16 = Add(Add(w0, w1), Add(w2, w3)); + wide_sum = SatWidenMulPairwiseAccumulate( + di32, sum16, Set(di16, int16_t{1}), wide_sum); + } + } + total += static_cast(ReduceSum(di32, wide_sum)); + } else if constexpr (sizeof(T) == 2 && HWY_MAX_LANES_D(D) >= 2) { + // Min 2 lanes needed for pairwise widening, 16->32 + const Repartition di32; + auto wide_sum = Zero(di32); + + if (count >= 4 * N && N >= 2) { + while (i <= count - 4 * N) { + VI acc0 = Zero(di); + VI acc1 = Zero(di); + VI acc2 = Zero(di); + VI acc3 = Zero(di); + const size_t cap = HWY_MIN(i + 32768 * 4 * N, count); + + if constexpr (HWY_NATIVE_MASK) { + const auto k1 = Set(di, TI{1}); + for (; i <= cap - 4 * N; i += 4 * N) { + const auto m0 = RebindMask(di, Eq(broadcasted, LoadU(d, in + i))); + const auto m1 = + RebindMask(di, Eq(broadcasted, LoadU(d, in + i + N))); + const auto m2 = + RebindMask(di, Eq(broadcasted, LoadU(d, in + i + 2 * N))); + const auto m3 = + RebindMask(di, Eq(broadcasted, LoadU(d, in + i + 3 * N))); + acc0 = MaskedAddOr(acc0, m0, acc0, k1); + acc1 = MaskedAddOr(acc1, m1, acc1, k1); + acc2 = MaskedAddOr(acc2, m2, acc2, k1); + acc3 = MaskedAddOr(acc3, m3, acc3, k1); + } + } else { + for (; i <= cap - 4 * N; i += 4 * N) { + const auto v0 = VecFromMask(d, Eq(broadcasted, LoadU(d, in + i))); + const auto v1 = + VecFromMask(d, 
Eq(broadcasted, LoadU(d, in + i + N))); + const auto v2 = + VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + 2 * N))); + const auto v3 = + VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + 3 * N))); + acc0 = Add(acc0, BitCast(di, v0)); + acc1 = Add(acc1, BitCast(di, v1)); + acc2 = Add(acc2, BitCast(di, v2)); + acc3 = Add(acc3, BitCast(di, v3)); + } + } + const auto mul = Set(di, HWY_NATIVE_MASK ? TI{1} : TI{-1}); + wide_sum = SatWidenMulPairwiseAccumulate(di32, acc0, mul, wide_sum); + wide_sum = SatWidenMulPairwiseAccumulate(di32, acc1, mul, wide_sum); + wide_sum = SatWidenMulPairwiseAccumulate(di32, acc2, mul, wide_sum); + wide_sum = SatWidenMulPairwiseAccumulate(di32, acc3, mul, wide_sum); + } + } + total += static_cast(ReduceSum(di32, wide_sum)); + } else { + // Lane type wide enough to accumulate directly + if (count >= 4 * N) { + VI acc0 = Zero(di); + VI acc1 = Zero(di); + VI acc2 = Zero(di); + VI acc3 = Zero(di); + + if constexpr (HWY_NATIVE_MASK) { + const auto k1 = Set(di, TI{1}); + for (; i <= count - 4 * N; i += 4 * N) { + const auto m0 = RebindMask(di, Eq(broadcasted, LoadU(d, in + i))); + const auto m1 = RebindMask(di, Eq(broadcasted, LoadU(d, in + i + N))); + const auto m2 = + RebindMask(di, Eq(broadcasted, LoadU(d, in + i + 2 * N))); + const auto m3 = + RebindMask(di, Eq(broadcasted, LoadU(d, in + i + 3 * N))); + acc0 = MaskedAddOr(acc0, m0, acc0, k1); + acc1 = MaskedAddOr(acc1, m1, acc1, k1); + acc2 = MaskedAddOr(acc2, m2, acc2, k1); + acc3 = MaskedAddOr(acc3, m3, acc3, k1); + } + acc0 = Add(Add(acc0, acc1), Add(acc2, acc3)); + } else { + for (; i <= count - 4 * N; i += 4 * N) { + const auto v0 = VecFromMask(d, Eq(broadcasted, LoadU(d, in + i))); + const auto v1 = VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + N))); + const auto v2 = + VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + 2 * N))); + const auto v3 = + VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + 3 * N))); + acc0 = Add(acc0, BitCast(di, v0)); + acc1 = Add(acc1, BitCast(di, v1)); + 
acc2 = Add(acc2, BitCast(di, v2)); + acc3 = Add(acc3, BitCast(di, v3)); + } + acc0 = Neg(Add(Add(acc0, acc1), Add(acc2, acc3))); + } + total += static_cast(ReduceSum(di, acc0)); + } + } + + if (count >= N) { + for (; i <= count - N; i += N) { + total += CountTrue(d, Eq(broadcasted, LoadU(d, in + i))); + } + } + + if (i != count) { + const size_t remaining = count - i; + HWY_DASSERT(0 != remaining && remaining < N); + const V v = LoadN(d, in + i, remaining); + total += CountTrue(d, And(Eq(broadcasted, v), FirstN(d, remaining))); + } + + return total; +} + +// Returns the number of elements in `in[0, count)` for which `func(d, vec)` +// returns true. +template > +size_t CountIf(D d, const T* HWY_RESTRICT in, size_t count, const Func& func) { + const size_t N = Lanes(d); + using V = Vec; + const RebindToSigned di; + using TI = TFromD; + using VI = Vec; + using MI = Mask; + + size_t total = 0; + size_t i = 0; + + // Min 4 lanes needed for two pairwise widenings, 8->16->32 + if constexpr (sizeof(T) == 1 && HWY_MAX_LANES_D(D) >= 4) { + const VI k1 = Set(di, TI{1}); + const RebindToUnsigned du; + const RepartitionToWide di16; + const Repartition di32; + using VI16 = Vec; + using VI32 = Vec; + VI32 wide_sum = Zero(di32); + + if (count >= 4 * N && N >= 4) { + while (i <= count - 4 * N) { + VI acc0 = Zero(di); + VI acc1 = Zero(di); + VI acc2 = Zero(di); + VI acc3 = Zero(di); + const size_t cap = HWY_MIN(i + 128 * 4 * N, count); + + if constexpr (HWY_NATIVE_MASK) { + for (; i <= cap - 4 * N; i += 4 * N) { + const MI m0 = RebindMask(di, func(d, LoadU(d, in + i))); + const MI m1 = RebindMask(di, func(d, LoadU(d, in + i + N))); + const MI m2 = RebindMask(di, func(d, LoadU(d, in + i + 2 * N))); + const MI m3 = RebindMask(di, func(d, LoadU(d, in + i + 3 * N))); + acc0 = MaskedAddOr(acc0, m0, acc0, k1); + acc1 = MaskedAddOr(acc1, m1, acc1, k1); + acc2 = MaskedAddOr(acc2, m2, acc2, k1); + acc3 = MaskedAddOr(acc3, m3, acc3, k1); + } + } else { + for (; i <= cap - 4 * N; i += 4 * N) { 
+ const V v0 = VecFromMask(d, func(d, LoadU(d, in + i))); + const V v1 = VecFromMask(d, func(d, LoadU(d, in + i + N))); + const V v2 = VecFromMask(d, func(d, LoadU(d, in + i + 2 * N))); + const V v3 = VecFromMask(d, func(d, LoadU(d, in + i + 3 * N))); + acc0 = Add(acc0, BitCast(di, v0)); + acc1 = Add(acc1, BitCast(di, v1)); + acc2 = Add(acc2, BitCast(di, v2)); + acc3 = Add(acc3, BitCast(di, v3)); + } + + acc0 = Neg(acc0); + acc1 = Neg(acc1); + acc2 = Neg(acc2); + acc3 = Neg(acc3); + } + + const VI16 w0 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc0), k1); + const VI16 w1 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc1), k1); + const VI16 w2 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc2), k1); + const VI16 w3 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc3), k1); + const VI16 sum16 = Add(Add(w0, w1), Add(w2, w3)); + wide_sum = SatWidenMulPairwiseAccumulate( + di32, sum16, Set(di16, int16_t{1}), wide_sum); + } + } + total += static_cast(ReduceSum(di32, wide_sum)); + } else if constexpr (sizeof(T) == 2 && HWY_MAX_LANES_D(D) >= 2) { + // Min 2 lanes needed for pairwise widening, 16->32 + const Repartition di32; + using VI32 = Vec; + VI32 wide_sum = Zero(di32); + + if (count >= 4 * N && N >= 2) { + while (i <= count - 4 * N) { + VI acc0 = Zero(di); + VI acc1 = Zero(di); + VI acc2 = Zero(di); + VI acc3 = Zero(di); + const size_t cap = HWY_MIN(i + 32768 * 4 * N, count); + + if constexpr (HWY_NATIVE_MASK) { + const VI k1 = Set(di, TI{1}); + for (; i <= cap - 4 * N; i += 4 * N) { + const MI m0 = RebindMask(di, func(d, LoadU(d, in + i))); + const MI m1 = RebindMask(di, func(d, LoadU(d, in + i + N))); + const MI m2 = RebindMask(di, func(d, LoadU(d, in + i + 2 * N))); + const MI m3 = RebindMask(di, func(d, LoadU(d, in + i + 3 * N))); + acc0 = MaskedAddOr(acc0, m0, acc0, k1); + acc1 = MaskedAddOr(acc1, m1, acc1, k1); + acc2 = MaskedAddOr(acc2, m2, acc2, k1); + acc3 = MaskedAddOr(acc3, m3, acc3, k1); + } + } else { + for (; i <= cap - 4 * N; i += 4 * N) { + const V v0 = 
VecFromMask(d, func(d, LoadU(d, in + i))); + const V v1 = VecFromMask(d, func(d, LoadU(d, in + i + N))); + const V v2 = VecFromMask(d, func(d, LoadU(d, in + i + 2 * N))); + const V v3 = VecFromMask(d, func(d, LoadU(d, in + i + 3 * N))); + acc0 = Add(acc0, BitCast(di, v0)); + acc1 = Add(acc1, BitCast(di, v1)); + acc2 = Add(acc2, BitCast(di, v2)); + acc3 = Add(acc3, BitCast(di, v3)); + } + } + const VI mul = Set(di, HWY_NATIVE_MASK ? TI{1} : TI{-1}); + wide_sum = SatWidenMulPairwiseAccumulate(di32, acc0, mul, wide_sum); + wide_sum = SatWidenMulPairwiseAccumulate(di32, acc1, mul, wide_sum); + wide_sum = SatWidenMulPairwiseAccumulate(di32, acc2, mul, wide_sum); + wide_sum = SatWidenMulPairwiseAccumulate(di32, acc3, mul, wide_sum); + } + } + total += static_cast(ReduceSum(di32, wide_sum)); + } else { + // Lane type wide enough to accumulate directly + if (count >= 4 * N) { + VI acc0 = Zero(di); + VI acc1 = Zero(di); + VI acc2 = Zero(di); + VI acc3 = Zero(di); + + if constexpr (HWY_NATIVE_MASK) { + const VI k1 = Set(di, TI{1}); + for (; i <= count - 4 * N; i += 4 * N) { + const MI m0 = RebindMask(di, func(d, LoadU(d, in + i))); + const MI m1 = RebindMask(di, func(d, LoadU(d, in + i + N))); + const MI m2 = RebindMask(di, func(d, LoadU(d, in + i + 2 * N))); + const MI m3 = RebindMask(di, func(d, LoadU(d, in + i + 3 * N))); + acc0 = MaskedAddOr(acc0, m0, acc0, k1); + acc1 = MaskedAddOr(acc1, m1, acc1, k1); + acc2 = MaskedAddOr(acc2, m2, acc2, k1); + acc3 = MaskedAddOr(acc3, m3, acc3, k1); + } + acc0 = Add(Add(acc0, acc1), Add(acc2, acc3)); + } else { + for (; i <= count - 4 * N; i += 4 * N) { + const V v0 = VecFromMask(d, func(d, LoadU(d, in + i))); + const V v1 = VecFromMask(d, func(d, LoadU(d, in + i + N))); + const V v2 = VecFromMask(d, func(d, LoadU(d, in + i + 2 * N))); + const V v3 = VecFromMask(d, func(d, LoadU(d, in + i + 3 * N))); + acc0 = Add(acc0, BitCast(di, v0)); + acc1 = Add(acc1, BitCast(di, v1)); + acc2 = Add(acc2, BitCast(di, v2)); + acc3 = Add(acc3, 
BitCast(di, v3)); + } + acc0 = Neg(Add(Add(acc0, acc1), Add(acc2, acc3))); + } + total += static_cast(ReduceSum(di, acc0)); + } + } + + if (count >= N) { + for (; i <= count - N; i += N) { + total += CountTrue(d, func(d, LoadU(d, in + i))); + } + } + + if (i != count) { + const size_t remaining = count - i; + HWY_DASSERT(0 != remaining && remaining < N); + const V v = LoadN(d, in + i, remaining); + total += CountTrue(d, And(func(d, v), FirstN(d, remaining))); + } + + return total; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_COUNT_INL_H_ diff --git a/lib/highway/hwy/contrib/algo/find-inl.h b/lib/highway/hwy/contrib/algo/find-inl.h new file mode 100644 index 00000000000..dc0a8cac084 --- /dev/null +++ b/lib/highway/hwy/contrib/algo/find-inl.h @@ -0,0 +1,113 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Returns index of the first element equal to `value` in `in[0, count)`, or +// `count` if not found. +template > +size_t Find(D d, T value, const T* HWY_RESTRICT in, size_t count) { + const size_t N = Lanes(d); + const Vec broadcasted = Set(d, value); + + size_t i = 0; + if (count >= N) { + for (; i <= count - N; i += N) { + const intptr_t pos = FindFirstTrue(d, Eq(broadcasted, LoadU(d, in + i))); + if (pos >= 0) return i + static_cast(pos); + } + } + + if (i != count) { +#if HWY_MEM_OPS_MIGHT_FAULT + // Scan single elements. + const CappedTag d1; + using V1 = Vec; + const V1 broadcasted1 = Set(d1, GetLane(broadcasted)); + for (; i < count; ++i) { + if (AllTrue(d1, Eq(broadcasted1, LoadU(d1, in + i)))) { + return i; + } + } +#else + const size_t remaining = count - i; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask mask = FirstN(d, remaining); + const Vec v = MaskedLoad(mask, d, in + i); + // Apply mask so that we don't 'find' the zero-padding from MaskedLoad. + const intptr_t pos = FindFirstTrue(d, And(Eq(broadcasted, v), mask)); + if (pos >= 0) return i + static_cast(pos); +#endif // HWY_MEM_OPS_MIGHT_FAULT + } + + return count; // not found +} + +// Returns index of the first element in `in[0, count)` for which `func(d, vec)` +// returns true, otherwise `count`. 
+template > +size_t FindIf(D d, const T* HWY_RESTRICT in, size_t count, const Func& func) { + const size_t N = Lanes(d); + + size_t i = 0; + if (count >= N) { + for (; i <= count - N; i += N) { + const intptr_t pos = FindFirstTrue(d, func(d, LoadU(d, in + i))); + if (pos >= 0) return i + static_cast(pos); + } + } + + if (i != count) { +#if HWY_MEM_OPS_MIGHT_FAULT + // Scan single elements. + const CappedTag d1; + for (; i < count; ++i) { + if (AllTrue(d1, func(d1, LoadU(d1, in + i)))) { + return i; + } + } +#else + const size_t remaining = count - i; + HWY_DASSERT(0 != remaining && remaining < N); + const Mask mask = FirstN(d, remaining); + const Vec v = MaskedLoad(mask, d, in + i); + // Apply mask so that we don't 'find' the zero-padding from MaskedLoad. + const intptr_t pos = FindFirstTrue(d, And(func(d, v), mask)); + if (pos >= 0) return i + static_cast(pos); +#endif // HWY_MEM_OPS_MIGHT_FAULT + } + + return count; // not found +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_FIND_INL_H_ diff --git a/lib/highway/hwy/contrib/algo/minmax-inl.h b/lib/highway/hwy/contrib/algo/minmax-inl.h new file mode 100644 index 00000000000..6020d16d2b2 --- /dev/null +++ b/lib/highway/hwy/contrib/algo/minmax-inl.h @@ -0,0 +1,102 @@ +// Copyright 2026 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_MINMAX_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_MINMAX_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_MINMAX_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_MINMAX_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Returns the minimum value in `in[0, count)` or PositiveInfOrHighestValue() if count == 0. +template > +T MinValue(D d, const T* HWY_RESTRICT in, size_t count) { + const size_t N = Lanes(d); + const T identity = hwy::PositiveInfOrHighestValue(); + const Vec identity_vec = Set(d, identity); + + Vec acc0 = identity_vec; + Vec acc1 = identity_vec; + Vec acc2 = identity_vec; + Vec acc3 = identity_vec; + + size_t i = 0; + if (count >= 4 * N) { + for (; i <= count - 4 * N; i += 4 * N) { + acc0 = Min(acc0, LoadU(d, in + i)); + acc1 = Min(acc1, LoadU(d, in + i + N)); + acc2 = Min(acc2, LoadU(d, in + i + 2 * N)); + acc3 = Min(acc3, LoadU(d, in + i + 3 * N)); + } + } + + acc0 = Min(Min(acc0, acc1), Min(acc2, acc3)); + + for (; i < count; i += N) { + const size_t remaining = count - i; + const size_t n = HWY_MIN(remaining, N); + acc0 = Min(acc0, LoadNOr(identity_vec, d, in + i, n)); + } + + return ReduceMin(d, acc0); +} + +// Returns the maximum value in `in[0, count)` or NegativeInfOrLowestValue() if count == 0. 
+template > +T MaxValue(D d, const T* HWY_RESTRICT in, size_t count) { + const size_t N = Lanes(d); + const T identity = hwy::NegativeInfOrLowestValue(); + const Vec identity_vec = Set(d, identity); + + Vec acc0 = identity_vec; + Vec acc1 = identity_vec; + Vec acc2 = identity_vec; + Vec acc3 = identity_vec; + + size_t i = 0; + if (count >= 4 * N) { + for (; i <= count - 4 * N; i += 4 * N) { + acc0 = Max(acc0, LoadU(d, in + i)); + acc1 = Max(acc1, LoadU(d, in + i + N)); + acc2 = Max(acc2, LoadU(d, in + i + 2 * N)); + acc3 = Max(acc3, LoadU(d, in + i + 3 * N)); + } + } + + acc0 = Max(Max(acc0, acc1), Max(acc2, acc3)); + + for (; i < count; i += N) { + const size_t remaining = count - i; + const size_t n = HWY_MIN(remaining, N); + acc0 = Max(acc0, LoadNOr(identity_vec, d, in + i, n)); + } + + return ReduceMax(d, acc0); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_MINMAX_INL_H_ diff --git a/lib/highway/hwy/contrib/algo/transform-inl.h b/lib/highway/hwy/contrib/algo/transform-inl.h new file mode 100644 index 00000000000..f48476d15b4 --- /dev/null +++ b/lib/highway/hwy/contrib/algo/transform-inl.h @@ -0,0 +1,228 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Per-target include guard +#if defined(HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ +#endif + +#include + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// These functions avoid having to write a loop plus remainder handling in the +// (unfortunately still common) case where arrays are not aligned/padded. If the +// inputs are known to be aligned/padded, it is more efficient to write a single +// loop using Load(). We do not provide a TransformAlignedPadded because it +// would be more verbose than such a loop. +// +// Func is either a functor with a templated operator()(d, v[, v1[, v2]]), or a +// generic lambda if using C++14. The d argument is the same as was passed to +// the Generate etc. functions. Due to apparent limitations of Clang, it is +// currently necessary to add HWY_ATTR before the opening { of the lambda to +// avoid errors about "always_inline function .. requires target". +// +// We do not check HWY_MEM_OPS_MIGHT_FAULT because LoadN/StoreN do not fault. + +// Fills `out[0, count)` with the vectors returned by `func(d, index_vec)`, +// where `index_vec` is `Vec>`. On the first call to `func`, +// the value of its lane i is i, and increases by `Lanes(d)` after every call. +// Note that some of these indices may be `>= count`, but the elements that +// `func` returns in those lanes will not be written to `out`. 
+template > +void Generate(D d, T* HWY_RESTRICT out, size_t count, const Func& func) { + const RebindToUnsigned du; + using TU = TFromD; + const size_t N = Lanes(d); + + size_t idx = 0; + Vec vidx = Iota(du, 0); + if (count >= N) { + for (; idx <= count - N; idx += N) { + StoreU(func(d, vidx), d, out + idx); + vidx = Add(vidx, Set(du, static_cast(N))); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + StoreN(func(d, vidx), d, out + idx, remaining); +} + +// Calls `func(d, v)` for each input vector; out of bound lanes with index i >= +// `count` are instead taken from `no[i % Lanes(d)]`. +template > +void Foreach(D d, const T* HWY_RESTRICT in, const size_t count, const Vec no, + const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, in + idx); + func(d, v); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Vec v = LoadNOr(no, d, in + idx, remaining); + func(d, v); +} + +// Replaces `inout[idx]` with `func(d, inout[idx])`. Example usage: multiplying +// array elements by a constant. +template > +void Transform(D d, T* HWY_RESTRICT inout, size_t count, const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, inout + idx); + StoreU(func(d, v), d, inout + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. 
+ if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Vec v = LoadN(d, inout + idx, remaining); + StoreN(func(d, v), d, inout + idx, remaining); +} + +// Replaces `inout[idx]` with `func(d, inout[idx], in1[idx])`. Example usage: +// multiplying array elements by those of another array. +template > +void Transform1(D d, T* HWY_RESTRICT inout, size_t count, + const T* HWY_RESTRICT in1, const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, inout + idx); + const Vec v1 = LoadU(d, in1 + idx); + StoreU(func(d, v, v1), d, inout + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Vec v = LoadN(d, inout + idx, remaining); + const Vec v1 = LoadN(d, in1 + idx, remaining); + StoreN(func(d, v, v1), d, inout + idx, remaining); +} + +// Replaces `inout[idx]` with `func(d, inout[idx], in1[idx], in2[idx])`. Example +// usage: FMA of elements from three arrays, stored into the first array. +template > +void Transform2(D d, T* HWY_RESTRICT inout, size_t count, + const T* HWY_RESTRICT in1, const T* HWY_RESTRICT in2, + const Func& func) { + const size_t N = Lanes(d); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + const Vec v = LoadU(d, inout + idx); + const Vec v1 = LoadU(d, in1 + idx); + const Vec v2 = LoadU(d, in2 + idx); + StoreU(func(d, v, v1, v2), d, inout + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. 
+ if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Vec v = LoadN(d, inout + idx, remaining); + const Vec v1 = LoadN(d, in1 + idx, remaining); + const Vec v2 = LoadN(d, in2 + idx, remaining); + StoreN(func(d, v, v1, v2), d, inout + idx, remaining); +} + +template > +void Replace(D d, T* HWY_RESTRICT inout, size_t count, T new_t, T old_t) { + const size_t N = Lanes(d); + const Vec old_v = Set(d, old_t); + const Vec new_v = Set(d, new_t); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + Vec v = LoadU(d, inout + idx); + StoreU(IfThenElse(Eq(v, old_v), new_v, v), d, inout + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. + if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Vec v = LoadN(d, inout + idx, remaining); + StoreN(IfThenElse(Eq(v, old_v), new_v, v), d, inout + idx, remaining); +} + +template > +void ReplaceIf(D d, T* HWY_RESTRICT inout, size_t count, T new_t, + const Func& func) { + const size_t N = Lanes(d); + const Vec new_v = Set(d, new_t); + + size_t idx = 0; + if (count >= N) { + for (; idx <= count - N; idx += N) { + Vec v = LoadU(d, inout + idx); + StoreU(IfThenElse(func(d, v), new_v, v), d, inout + idx); + } + } + + // `count` was a multiple of the vector length `N`: already done. 
+ if (HWY_UNLIKELY(idx == count)) return; + + const size_t remaining = count - idx; + HWY_DASSERT(0 != remaining && remaining < N); + const Vec v = LoadN(d, inout + idx, remaining); + StoreN(IfThenElse(func(d, v), new_v, v), d, inout + idx, remaining); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_ALGO_TRANSFORM_INL_H_ diff --git a/lib/highway/hwy/contrib/bit_pack/bit_pack-inl.h b/lib/highway/hwy/contrib/bit_pack/bit_pack-inl.h new file mode 100644 index 00000000000..c041c13f02d --- /dev/null +++ b/lib/highway/hwy/contrib/bit_pack/bit_pack-inl.h @@ -0,0 +1,2851 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "hwy/base.h" + +// Per-target include guard +// clang-format off +#if defined(HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_) == defined(HWY_TARGET_TOGGLE) // NOLINT +// clang-format on +#ifdef HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// The entry points are class templates specialized below for each number of +// bits. 
Each provides Pack and Unpack member functions which load (Pack) or +// store (Unpack) B raw vectors, and store (Pack) or load (Unpack) a number of +// packed vectors equal to kBits. B denotes the bits per lane: 8 for Pack8, 16 +// for Pack16, 32 for Pack32 which is also the upper bound for kBits. +template // <= 8 +struct Pack8 {}; +template // <= 16 +struct Pack16 {}; + +template <> +struct Pack8<1> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + // 16-bit shifts avoid masking (bits will not cross 8-bit lanes). + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 packed = + Xor3(Or(ShiftLeft<7>(raw7), ShiftLeft<6>(raw6)), + Xor3(ShiftLeft<5>(raw5), ShiftLeft<4>(raw4), ShiftLeft<3>(raw3)), + Xor3(ShiftLeft<2>(raw2), ShiftLeft<1>(raw1), raw0)); + StoreU(BitCast(d8, packed), d8, packed_out); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x0101u); // LSB in each byte + + const VU16 packed = BitCast(d16, LoadU(d8, packed_in)); + + const VU16 raw0 = And(packed, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(ShiftRight<1>(packed), mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(ShiftRight<2>(packed), mask); + StoreU(BitCast(d8, raw2), d8, 
raw + 2 * N8); + + const VU16 raw3 = And(ShiftRight<3>(packed), mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw4 = And(ShiftRight<4>(packed), mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(ShiftRight<5>(packed), mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(ShiftRight<6>(packed), mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + const VU16 raw7 = And(ShiftRight<7>(packed), mask); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<1> + +template <> +struct Pack8<2> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + // 16-bit shifts avoid masking (bits will not cross 8-bit lanes). + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 packed0 = Xor3(ShiftLeft<6>(raw6), ShiftLeft<4>(raw4), + Or(ShiftLeft<2>(raw2), raw0)); + const VU16 packed1 = Xor3(ShiftLeft<6>(raw7), ShiftLeft<4>(raw5), + Or(ShiftLeft<2>(raw3), raw1)); + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x0303u); // Lowest 2 bits per byte + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 
packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(ShiftRight<2>(packed0), mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(ShiftRight<2>(packed1), mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw4 = And(ShiftRight<4>(packed0), mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(ShiftRight<4>(packed1), mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(ShiftRight<6>(packed0), mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + const VU16 raw7 = And(ShiftRight<6>(packed1), mask); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<2> + +template <> +struct Pack8<3> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + // The upper two bits of these three will be filled with packed3 (6 bits). 
+ VU16 packed0 = Or(ShiftLeft<3>(raw4), raw0); + VU16 packed1 = Or(ShiftLeft<3>(raw5), raw1); + VU16 packed2 = Or(ShiftLeft<3>(raw6), raw2); + const VU16 packed3 = Or(ShiftLeft<3>(raw7), raw3); + + const VU16 hi2 = Set(d16, 0xC0C0u); + packed0 = OrAnd(packed0, ShiftLeft<2>(packed3), hi2); + packed1 = OrAnd(packed1, ShiftLeft<4>(packed3), hi2); + packed2 = OrAnd(packed2, ShiftLeft<6>(packed3), hi2); + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x0707u); // Lowest 3 bits per byte + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(packed2, mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw4 = And(ShiftRight<3>(packed0), mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(ShiftRight<3>(packed1), mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(ShiftRight<3>(packed2), mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + // raw73 is the concatenation of the upper two bits in packed0..2. 
+ const VU16 hi2 = Set(d16, 0xC0C0u); + const VU16 raw73 = Xor3(ShiftRight<6>(And(packed2, hi2)), // + ShiftRight<4>(And(packed1, hi2)), + ShiftRight<2>(And(packed0, hi2))); + + const VU16 raw3 = And(mask, raw73); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw7 = And(mask, ShiftRight<3>(raw73)); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<3> + +template <> +struct Pack8<4> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + // 16-bit shifts avoid masking (bits will not cross 8-bit lanes). + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 packed0 = Or(ShiftLeft<4>(raw2), raw0); + const VU16 packed1 = Or(ShiftLeft<4>(raw3), raw1); + const VU16 packed2 = Or(ShiftLeft<4>(raw6), raw4); + const VU16 packed3 = Or(ShiftLeft<4>(raw7), raw5); + + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + StoreU(BitCast(d8, packed3), d8, packed_out + 3 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x0F0Fu); // Lowest 4 bits per byte + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in 
+ 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + const VU16 packed3 = BitCast(d16, LoadU(d8, packed_in + 3 * N8)); + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(ShiftRight<4>(packed0), mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(ShiftRight<4>(packed1), mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw4 = And(packed2, mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(packed3, mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(ShiftRight<4>(packed2), mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + const VU16 raw7 = And(ShiftRight<4>(packed3), mask); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<4> + +template <> +struct Pack8<5> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + // Fill upper three bits with upper bits from raw4..7. 
+ const VU16 hi3 = Set(d16, 0xE0E0u); + const VU16 packed0 = OrAnd(raw0, ShiftLeft<3>(raw4), hi3); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<3>(raw5), hi3); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<3>(raw6), hi3); + const VU16 packed3 = OrAnd(raw3, ShiftLeft<3>(raw7), hi3); + + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + StoreU(BitCast(d8, packed3), d8, packed_out + 3 * N8); + + // Combine lower two bits of raw4..7 into packed4. + const VU16 lo2 = Set(d16, 0x0303u); + const VU16 packed4 = Or(And(raw4, lo2), Xor3(ShiftLeft<2>(And(raw5, lo2)), + ShiftLeft<4>(And(raw6, lo2)), + ShiftLeft<6>(And(raw7, lo2)))); + StoreU(BitCast(d8, packed4), d8, packed_out + 4 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + const VU16 packed3 = BitCast(d16, LoadU(d8, packed_in + 3 * N8)); + const VU16 packed4 = BitCast(d16, LoadU(d8, packed_in + 4 * N8)); + + const VU16 mask = Set(d16, 0x1F1Fu); // Lowest 5 bits per byte + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(packed2, mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(packed3, mask); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + // The upper bits are the top 3 bits shifted right by three. 
+ const VU16 top4 = ShiftRight<3>(AndNot(mask, packed0)); + const VU16 top5 = ShiftRight<3>(AndNot(mask, packed1)); + const VU16 top6 = ShiftRight<3>(AndNot(mask, packed2)); + const VU16 top7 = ShiftRight<3>(AndNot(mask, packed3)); + + // Insert the lower 2 bits, which were concatenated into a byte. + const VU16 lo2 = Set(d16, 0x0303u); + const VU16 raw4 = OrAnd(top4, lo2, packed4); + const VU16 raw5 = OrAnd(top5, lo2, ShiftRight<2>(packed4)); + const VU16 raw6 = OrAnd(top6, lo2, ShiftRight<4>(packed4)); + const VU16 raw7 = OrAnd(top7, lo2, ShiftRight<6>(packed4)); + + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<5> + +template <> +struct Pack8<6> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 hi2 = Set(d16, 0xC0C0u); + // Each triplet of these stores raw3/raw7 (6 bits) in the upper 2 bits. 
+ const VU16 packed0 = OrAnd(raw0, ShiftLeft<2>(raw3), hi2); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<4>(raw3), hi2); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<6>(raw3), hi2); + const VU16 packed3 = OrAnd(raw4, ShiftLeft<2>(raw7), hi2); + const VU16 packed4 = OrAnd(raw5, ShiftLeft<4>(raw7), hi2); + const VU16 packed5 = OrAnd(raw6, ShiftLeft<6>(raw7), hi2); + + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + StoreU(BitCast(d8, packed3), d8, packed_out + 3 * N8); + StoreU(BitCast(d8, packed4), d8, packed_out + 4 * N8); + StoreU(BitCast(d8, packed5), d8, packed_out + 5 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 mask = Set(d16, 0x3F3Fu); // Lowest 6 bits per byte + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + const VU16 packed3 = BitCast(d16, LoadU(d8, packed_in + 3 * N8)); + const VU16 packed4 = BitCast(d16, LoadU(d8, packed_in + 4 * N8)); + const VU16 packed5 = BitCast(d16, LoadU(d8, packed_in + 5 * N8)); + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(packed2, mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw4 = And(packed3, mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(packed4, mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(packed5, mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + // raw3/7 are the concatenation of the upper two bits in packed0..2. 
+ const VU16 raw3 = Xor3(ShiftRight<6>(AndNot(mask, packed2)), + ShiftRight<4>(AndNot(mask, packed1)), + ShiftRight<2>(AndNot(mask, packed0))); + const VU16 raw7 = Xor3(ShiftRight<6>(AndNot(mask, packed5)), + ShiftRight<4>(AndNot(mask, packed4)), + ShiftRight<2>(AndNot(mask, packed3))); + StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<6> + +template <> +struct Pack8<7> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + const VU16 raw0 = BitCast(d16, LoadU(d8, raw + 0 * N8)); + const VU16 raw1 = BitCast(d16, LoadU(d8, raw + 1 * N8)); + const VU16 raw2 = BitCast(d16, LoadU(d8, raw + 2 * N8)); + const VU16 raw3 = BitCast(d16, LoadU(d8, raw + 3 * N8)); + const VU16 raw4 = BitCast(d16, LoadU(d8, raw + 4 * N8)); + const VU16 raw5 = BitCast(d16, LoadU(d8, raw + 5 * N8)); + const VU16 raw6 = BitCast(d16, LoadU(d8, raw + 6 * N8)); + // Inserted into top bit of packed0..6. 
+ const VU16 raw7 = BitCast(d16, LoadU(d8, raw + 7 * N8)); + + const VU16 hi1 = Set(d16, 0x8080u); + const VU16 packed0 = OrAnd(raw0, Add(raw7, raw7), hi1); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<2>(raw7), hi1); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<3>(raw7), hi1); + const VU16 packed3 = OrAnd(raw3, ShiftLeft<4>(raw7), hi1); + const VU16 packed4 = OrAnd(raw4, ShiftLeft<5>(raw7), hi1); + const VU16 packed5 = OrAnd(raw5, ShiftLeft<6>(raw7), hi1); + const VU16 packed6 = OrAnd(raw6, ShiftLeft<7>(raw7), hi1); + + StoreU(BitCast(d8, packed0), d8, packed_out + 0 * N8); + StoreU(BitCast(d8, packed1), d8, packed_out + 1 * N8); + StoreU(BitCast(d8, packed2), d8, packed_out + 2 * N8); + StoreU(BitCast(d8, packed3), d8, packed_out + 3 * N8); + StoreU(BitCast(d8, packed4), d8, packed_out + 4 * N8); + StoreU(BitCast(d8, packed5), d8, packed_out + 5 * N8); + StoreU(BitCast(d8, packed6), d8, packed_out + 6 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + const RepartitionToWide d16; + using VU16 = Vec; + const size_t N8 = Lanes(d8); + + const VU16 packed0 = BitCast(d16, LoadU(d8, packed_in + 0 * N8)); + const VU16 packed1 = BitCast(d16, LoadU(d8, packed_in + 1 * N8)); + const VU16 packed2 = BitCast(d16, LoadU(d8, packed_in + 2 * N8)); + const VU16 packed3 = BitCast(d16, LoadU(d8, packed_in + 3 * N8)); + const VU16 packed4 = BitCast(d16, LoadU(d8, packed_in + 4 * N8)); + const VU16 packed5 = BitCast(d16, LoadU(d8, packed_in + 5 * N8)); + const VU16 packed6 = BitCast(d16, LoadU(d8, packed_in + 6 * N8)); + + const VU16 mask = Set(d16, 0x7F7Fu); // Lowest 7 bits per byte + + const VU16 raw0 = And(packed0, mask); + StoreU(BitCast(d8, raw0), d8, raw + 0 * N8); + + const VU16 raw1 = And(packed1, mask); + StoreU(BitCast(d8, raw1), d8, raw + 1 * N8); + + const VU16 raw2 = And(packed2, mask); + StoreU(BitCast(d8, raw2), d8, raw + 2 * N8); + + const VU16 raw3 = And(packed3, mask); + 
StoreU(BitCast(d8, raw3), d8, raw + 3 * N8); + + const VU16 raw4 = And(packed4, mask); + StoreU(BitCast(d8, raw4), d8, raw + 4 * N8); + + const VU16 raw5 = And(packed5, mask); + StoreU(BitCast(d8, raw5), d8, raw + 5 * N8); + + const VU16 raw6 = And(packed6, mask); + StoreU(BitCast(d8, raw6), d8, raw + 6 * N8); + + const VU16 p0 = Xor3(ShiftRight<7>(AndNot(mask, packed6)), + ShiftRight<6>(AndNot(mask, packed5)), + ShiftRight<5>(AndNot(mask, packed4))); + const VU16 p1 = Xor3(ShiftRight<4>(AndNot(mask, packed3)), + ShiftRight<3>(AndNot(mask, packed2)), + ShiftRight<2>(AndNot(mask, packed1))); + const VU16 raw7 = Xor3(ShiftRight<1>(AndNot(mask, packed0)), p0, p1); + StoreU(BitCast(d8, raw7), d8, raw + 7 * N8); + } +}; // Pack8<7> + +template <> +struct Pack8<8> { + template + HWY_INLINE void Pack(D8 d8, const uint8_t* HWY_RESTRICT raw, + uint8_t* HWY_RESTRICT packed_out) const { + using VU8 = Vec; + const size_t N8 = Lanes(d8); + const VU8 raw0 = LoadU(d8, raw + 0 * N8); + const VU8 raw1 = LoadU(d8, raw + 1 * N8); + const VU8 raw2 = LoadU(d8, raw + 2 * N8); + const VU8 raw3 = LoadU(d8, raw + 3 * N8); + const VU8 raw4 = LoadU(d8, raw + 4 * N8); + const VU8 raw5 = LoadU(d8, raw + 5 * N8); + const VU8 raw6 = LoadU(d8, raw + 6 * N8); + const VU8 raw7 = LoadU(d8, raw + 7 * N8); + + StoreU(raw0, d8, packed_out + 0 * N8); + StoreU(raw1, d8, packed_out + 1 * N8); + StoreU(raw2, d8, packed_out + 2 * N8); + StoreU(raw3, d8, packed_out + 3 * N8); + StoreU(raw4, d8, packed_out + 4 * N8); + StoreU(raw5, d8, packed_out + 5 * N8); + StoreU(raw6, d8, packed_out + 6 * N8); + StoreU(raw7, d8, packed_out + 7 * N8); + } + + template + HWY_INLINE void Unpack(D8 d8, const uint8_t* HWY_RESTRICT packed_in, + uint8_t* HWY_RESTRICT raw) const { + using VU8 = Vec; + const size_t N8 = Lanes(d8); + const VU8 raw0 = LoadU(d8, packed_in + 0 * N8); + const VU8 raw1 = LoadU(d8, packed_in + 1 * N8); + const VU8 raw2 = LoadU(d8, packed_in + 2 * N8); + const VU8 raw3 = LoadU(d8, packed_in + 3 * N8); + 
const VU8 raw4 = LoadU(d8, packed_in + 4 * N8); + const VU8 raw5 = LoadU(d8, packed_in + 5 * N8); + const VU8 raw6 = LoadU(d8, packed_in + 6 * N8); + const VU8 raw7 = LoadU(d8, packed_in + 7 * N8); + + StoreU(raw0, d8, raw + 0 * N8); + StoreU(raw1, d8, raw + 1 * N8); + StoreU(raw2, d8, raw + 2 * N8); + StoreU(raw3, d8, raw + 3 * N8); + StoreU(raw4, d8, raw + 4 * N8); + StoreU(raw5, d8, raw + 5 * N8); + StoreU(raw6, d8, raw + 6 * N8); + StoreU(raw7, d8, raw + 7 * N8); + } +}; // Pack8<8> + +template <> +struct Pack16<1> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + const VU16 p0 = Xor3(ShiftLeft<2>(raw2), Add(raw1, raw1), raw0); + const VU16 p1 = + Xor3(ShiftLeft<5>(raw5), ShiftLeft<4>(raw4), ShiftLeft<3>(raw3)); + const VU16 p2 = + Xor3(ShiftLeft<8>(raw8), ShiftLeft<7>(raw7), ShiftLeft<6>(raw6)); + const VU16 p3 = + Xor3(ShiftLeft<0xB>(rawB), ShiftLeft<0xA>(rawA), ShiftLeft<9>(raw9)); + const VU16 p4 = + Xor3(ShiftLeft<0xE>(rawE), ShiftLeft<0xD>(rawD), ShiftLeft<0xC>(rawC)); + const VU16 packed = + Or(Xor3(ShiftLeft<0xF>(rawF), p0, p1), Xor3(p2, p3, p4)); + StoreU(packed, d, packed_out); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* 
HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 1u); // Lowest bit + + const VU16 packed = LoadU(d, packed_in); + + const VU16 raw0 = And(packed, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(ShiftRight<1>(packed), mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(ShiftRight<2>(packed), mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(ShiftRight<3>(packed), mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(ShiftRight<4>(packed), mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<5>(packed), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(ShiftRight<6>(packed), mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(ShiftRight<7>(packed), mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(ShiftRight<8>(packed), mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(ShiftRight<9>(packed), mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(ShiftRight<0xA>(packed), mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<0xB>(packed), mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<0xC>(packed), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<0xD>(packed), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(ShiftRight<0xE>(packed), mask); + StoreU(rawE, d, raw + 0xE * N); + + const VU16 rawF = ShiftRight<0xF>(packed); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<1> + +template <> +struct Pack16<2> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * 
N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + VU16 packed0 = Xor3(ShiftLeft<4>(raw4), ShiftLeft<2>(raw2), raw0); + VU16 packed1 = Xor3(ShiftLeft<4>(raw5), ShiftLeft<2>(raw3), raw1); + packed0 = Xor3(packed0, ShiftLeft<8>(raw8), ShiftLeft<6>(raw6)); + packed1 = Xor3(packed1, ShiftLeft<8>(raw9), ShiftLeft<6>(raw7)); + + packed0 = Xor3(packed0, ShiftLeft<12>(rawC), ShiftLeft<10>(rawA)); + packed1 = Xor3(packed1, ShiftLeft<12>(rawD), ShiftLeft<10>(rawB)); + + packed0 = Or(packed0, ShiftLeft<14>(rawE)); + packed1 = Or(packed1, ShiftLeft<14>(rawF)); + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 0x3u); // Lowest 2 bits + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(ShiftRight<2>(packed0), mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(ShiftRight<2>(packed1), mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(ShiftRight<4>(packed0), mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<4>(packed1), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = 
And(ShiftRight<6>(packed0), mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(ShiftRight<6>(packed1), mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(ShiftRight<8>(packed0), mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(ShiftRight<8>(packed1), mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(ShiftRight<0xA>(packed0), mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<0xA>(packed1), mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<0xC>(packed0), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<0xC>(packed1), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = ShiftRight<0xE>(packed0); + StoreU(rawE, d, raw + 0xE * N); + + const VU16 rawF = ShiftRight<0xE>(packed1); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<2> + +template <> +struct Pack16<3> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // We can fit 15 raw vectors in three packed vectors (five each). 
+ VU16 packed0 = Xor3(ShiftLeft<6>(raw6), ShiftLeft<3>(raw3), raw0); + VU16 packed1 = Xor3(ShiftLeft<6>(raw7), ShiftLeft<3>(raw4), raw1); + VU16 packed2 = Xor3(ShiftLeft<6>(raw8), ShiftLeft<3>(raw5), raw2); + + // rawF will be scattered into the upper bit of these three. + packed0 = Xor3(packed0, ShiftLeft<12>(rawC), ShiftLeft<9>(raw9)); + packed1 = Xor3(packed1, ShiftLeft<12>(rawD), ShiftLeft<9>(rawA)); + packed2 = Xor3(packed2, ShiftLeft<12>(rawE), ShiftLeft<9>(rawB)); + + const VU16 hi1 = Set(d, 0x8000u); + packed0 = Or(packed0, ShiftLeft<15>(rawF)); // MSB only, no mask + packed1 = OrAnd(packed1, ShiftLeft<14>(rawF), hi1); + packed2 = OrAnd(packed2, ShiftLeft<13>(rawF), hi1); + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 0x7u); // Lowest 3 bits + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + const VU16 packed2 = LoadU(d, packed_in + 2 * N); + + const VU16 raw0 = And(mask, packed0); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(mask, packed1); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(mask, packed2); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(mask, ShiftRight<3>(packed0)); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(mask, ShiftRight<3>(packed1)); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(mask, ShiftRight<3>(packed2)); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(mask, ShiftRight<6>(packed0)); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(mask, ShiftRight<6>(packed1)); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(mask, ShiftRight<6>(packed2)); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(mask, 
ShiftRight<9>(packed0)); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(mask, ShiftRight<9>(packed1)); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(mask, ShiftRight<9>(packed2)); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(mask, ShiftRight<12>(packed0)); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(mask, ShiftRight<12>(packed1)); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(mask, ShiftRight<12>(packed2)); + StoreU(rawE, d, raw + 0xE * N); + + // rawF is the concatenation of the upper bit of packed0..2. + const VU16 down0 = ShiftRight<15>(packed0); + const VU16 down1 = ShiftRight<15>(packed1); + const VU16 down2 = ShiftRight<15>(packed2); + const VU16 rawF = Xor3(ShiftLeft<2>(down2), Add(down1, down1), down0); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<3> + +template <> +struct Pack16<4> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + VU16 packed0 = Xor3(ShiftLeft<8>(raw4), ShiftLeft<4>(raw2), raw0); + VU16 packed1 = Xor3(ShiftLeft<8>(raw5), ShiftLeft<4>(raw3), raw1); + packed0 = Or(packed0, ShiftLeft<12>(raw6)); + packed1 = Or(packed1, ShiftLeft<12>(raw7)); + VU16 packed2 = 
Xor3(ShiftLeft<8>(rawC), ShiftLeft<4>(rawA), raw8); + VU16 packed3 = Xor3(ShiftLeft<8>(rawD), ShiftLeft<4>(rawB), raw9); + packed2 = Or(packed2, ShiftLeft<12>(rawE)); + packed3 = Or(packed3, ShiftLeft<12>(rawF)); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 0xFu); // Lowest 4 bits + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + const VU16 packed2 = LoadU(d, packed_in + 2 * N); + const VU16 packed3 = LoadU(d, packed_in + 3 * N); + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(ShiftRight<4>(packed0), mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(ShiftRight<4>(packed1), mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(ShiftRight<8>(packed0), mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<8>(packed1), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = ShiftRight<12>(packed0); // no mask required + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = ShiftRight<12>(packed1); // no mask required + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(packed2, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed3, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(ShiftRight<4>(packed2), mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<4>(packed3), mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<8>(packed2), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<8>(packed3), mask); + 
StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = ShiftRight<12>(packed2); // no mask required + StoreU(rawE, d, raw + 0xE * N); + + const VU16 rawF = ShiftRight<12>(packed3); // no mask required + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<4> + +template <> +struct Pack16<5> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // We can fit 15 raw vectors in five packed vectors (three each). + VU16 packed0 = Xor3(ShiftLeft<10>(rawA), ShiftLeft<5>(raw5), raw0); + VU16 packed1 = Xor3(ShiftLeft<10>(rawB), ShiftLeft<5>(raw6), raw1); + VU16 packed2 = Xor3(ShiftLeft<10>(rawC), ShiftLeft<5>(raw7), raw2); + VU16 packed3 = Xor3(ShiftLeft<10>(rawD), ShiftLeft<5>(raw8), raw3); + VU16 packed4 = Xor3(ShiftLeft<10>(rawE), ShiftLeft<5>(raw9), raw4); + + // rawF will be scattered into the upper bits of these five. 
+ const VU16 hi1 = Set(d, 0x8000u); + packed0 = Or(packed0, ShiftLeft<15>(rawF)); // MSB only, no mask + packed1 = OrAnd(packed1, ShiftLeft<14>(rawF), hi1); + packed2 = OrAnd(packed2, ShiftLeft<13>(rawF), hi1); + packed3 = OrAnd(packed3, ShiftLeft<12>(rawF), hi1); + packed4 = OrAnd(packed4, ShiftLeft<11>(rawF), hi1); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + const VU16 packed2 = LoadU(d, packed_in + 2 * N); + const VU16 packed3 = LoadU(d, packed_in + 3 * N); + const VU16 packed4 = LoadU(d, packed_in + 4 * N); + + const VU16 mask = Set(d, 0x1Fu); // Lowest 5 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<5>(packed0), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(ShiftRight<5>(packed1), mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(ShiftRight<5>(packed2), mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(ShiftRight<5>(packed3), mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(ShiftRight<5>(packed4), mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(ShiftRight<10>(packed0), mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(ShiftRight<10>(packed1), mask); + 
StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(ShiftRight<10>(packed2), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<10>(packed3), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(ShiftRight<10>(packed4), mask); + StoreU(rawE, d, raw + 0xE * N); + + // rawF is the concatenation of the lower bit of packed0..4. + const VU16 down0 = ShiftRight<15>(packed0); + const VU16 down1 = ShiftRight<15>(packed1); + const VU16 hi1 = Set(d, 0x8000u); + const VU16 p0 = + Xor3(ShiftRight<13>(And(packed2, hi1)), Add(down1, down1), down0); + const VU16 rawF = Xor3(ShiftRight<11>(And(packed4, hi1)), + ShiftRight<12>(And(packed3, hi1)), p0); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<5> + +template <> +struct Pack16<6> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + const VU16 packed3 = Or(ShiftLeft<6>(raw7), raw3); + const VU16 packed7 = Or(ShiftLeft<6>(rawF), rawB); + // Three vectors, two 6-bit raw each; packed3 (12 bits) is spread over the + // four remainder bits at the top of each vector. 
+ const VU16 packed0 = Xor3(ShiftLeft<12>(packed3), ShiftLeft<6>(raw4), raw0); + VU16 packed1 = Or(ShiftLeft<6>(raw5), raw1); + VU16 packed2 = Or(ShiftLeft<6>(raw6), raw2); + const VU16 packed4 = Xor3(ShiftLeft<12>(packed7), ShiftLeft<6>(rawC), raw8); + VU16 packed5 = Or(ShiftLeft<6>(rawD), raw9); + VU16 packed6 = Or(ShiftLeft<6>(rawE), rawA); + + const VU16 hi4 = Set(d, 0xF000u); + packed1 = OrAnd(packed1, ShiftLeft<8>(packed3), hi4); + packed2 = OrAnd(packed2, ShiftLeft<4>(packed3), hi4); + packed5 = OrAnd(packed5, ShiftLeft<8>(packed7), hi4); + packed6 = OrAnd(packed6, ShiftLeft<4>(packed7), hi4); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed4, d, packed_out + 3 * N); + StoreU(packed5, d, packed_out + 4 * N); + StoreU(packed6, d, packed_out + 5 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 mask = Set(d, 0x3Fu); // Lowest 6 bits + + const VU16 packed0 = LoadU(d, packed_in + 0 * N); + const VU16 packed1 = LoadU(d, packed_in + 1 * N); + const VU16 packed2 = LoadU(d, packed_in + 2 * N); + const VU16 packed4 = LoadU(d, packed_in + 3 * N); + const VU16 packed5 = LoadU(d, packed_in + 4 * N); + const VU16 packed6 = LoadU(d, packed_in + 5 * N); + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw4 = And(ShiftRight<6>(packed0), mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(ShiftRight<6>(packed1), mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(ShiftRight<6>(packed2), mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw8 = And(packed4, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed5, 
mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(packed6, mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawC = And(ShiftRight<6>(packed4), mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(ShiftRight<6>(packed5), mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(ShiftRight<6>(packed6), mask); + StoreU(rawE, d, raw + 0xE * N); + + // packed3 is the concatenation of the four upper bits in packed0..2. + const VU16 down0 = ShiftRight<12>(packed0); + const VU16 down4 = ShiftRight<12>(packed4); + const VU16 hi4 = Set(d, 0xF000u); + const VU16 packed3 = Xor3(ShiftRight<4>(And(packed2, hi4)), + ShiftRight<8>(And(packed1, hi4)), down0); + const VU16 packed7 = Xor3(ShiftRight<4>(And(packed6, hi4)), + ShiftRight<8>(And(packed5, hi4)), down4); + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 rawB = And(packed7, mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 raw7 = ShiftRight<6>(packed3); // upper bits already zero + StoreU(raw7, d, raw + 7 * N); + + const VU16 rawF = ShiftRight<6>(packed7); // upper bits already zero + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<6> +
+template <>
+struct Pack16<7> {
+  // Packs 16 vectors of 7-bit values, read from raw[0 .. 16 * N), into 7
+  // vectors written to packed_out[0 .. 7 * N), where N = Lanes(d).
+  //
+  // Layout of packed_i for i = 0..6:
+  //   bits  0..6  : raw_i
+  //   bits  7..13 : raw_(8+i)
+  //   bits 14..15 : bits 2i and 2i+1 of packed7 = (rawF << 7) | raw7
+  //
+  // Inputs are not masked here, so any bit set above the low 7 of a raw value
+  // would spill into a neighboring field.
+  template <class D>
+  HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw,
+                       uint16_t* HWY_RESTRICT packed_out) const {
+    using VU16 = Vec<D>;
+    const size_t N = Lanes(d);
+    const VU16 raw0 = LoadU(d, raw + 0 * N);
+    const VU16 raw1 = LoadU(d, raw + 1 * N);
+    const VU16 raw2 = LoadU(d, raw + 2 * N);
+    const VU16 raw3 = LoadU(d, raw + 3 * N);
+    const VU16 raw4 = LoadU(d, raw + 4 * N);
+    const VU16 raw5 = LoadU(d, raw + 5 * N);
+    const VU16 raw6 = LoadU(d, raw + 6 * N);
+    const VU16 raw7 = LoadU(d, raw + 7 * N);
+    const VU16 raw8 = LoadU(d, raw + 8 * N);
+    const VU16 raw9 = LoadU(d, raw + 9 * N);
+    const VU16 rawA = LoadU(d, raw + 0xA * N);
+    const VU16 rawB = LoadU(d, raw + 0xB * N);
+    const VU16 rawC = LoadU(d, raw + 0xC * N);
+    const VU16 rawD = LoadU(d, raw + 0xD * N);
+    const VU16 rawE = LoadU(d, raw + 0xE * N);
+    const VU16 rawF = LoadU(d, raw + 0xF * N);
+
+    const VU16 packed7 = Or(ShiftLeft<7>(rawF), raw7);
+    // Seven vectors, two 7-bit raw each; packed7 (14 bits) is spread over the
+    // two remainder bits at the top of each vector.
+    const VU16 packed0 = Xor3(ShiftLeft<14>(packed7), ShiftLeft<7>(raw8), raw0);
+    VU16 packed1 = Or(ShiftLeft<7>(raw9), raw1);
+    VU16 packed2 = Or(ShiftLeft<7>(rawA), raw2);
+    VU16 packed3 = Or(ShiftLeft<7>(rawB), raw3);
+    VU16 packed4 = Or(ShiftLeft<7>(rawC), raw4);
+    VU16 packed5 = Or(ShiftLeft<7>(rawD), raw5);
+    VU16 packed6 = Or(ShiftLeft<7>(rawE), raw6);
+
+    // Each shift moves the next bit pair of packed7 up to bits 14..15; hi2
+    // discards everything below.
+    const VU16 hi2 = Set(d, 0xC000u);
+    packed1 = OrAnd(packed1, ShiftLeft<12>(packed7), hi2);
+    packed2 = OrAnd(packed2, ShiftLeft<10>(packed7), hi2);
+    packed3 = OrAnd(packed3, ShiftLeft<8>(packed7), hi2);
+    packed4 = OrAnd(packed4, ShiftLeft<6>(packed7), hi2);
+    packed5 = OrAnd(packed5, ShiftLeft<4>(packed7), hi2);
+    packed6 = OrAnd(packed6, ShiftLeft<2>(packed7), hi2);
+
+    StoreU(packed0, d, packed_out + 0 * N);
+    StoreU(packed1, d, packed_out + 1 * N);
+    StoreU(packed2, d, packed_out + 2 * N);
+    StoreU(packed3, d, packed_out + 3 * N);
+    StoreU(packed4, d, packed_out + 4 * N);
+    StoreU(packed5, d, packed_out + 5 * N);
+    StoreU(packed6, d, packed_out + 6 * N);
+  }
+
+  // Inverse of Pack: reads 7 vectors from packed_in[0 .. 7 * N) and writes 16
+  // vectors of 7-bit values to raw[0 .. 16 * N), where N = Lanes(d).
+  template <class D>
+  HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in,
+                         uint16_t* HWY_RESTRICT raw) const {
+    using VU16 = Vec<D>;
+    const size_t N = Lanes(d);
+
+    const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N));
+    const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N));
+    const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N));
+    const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N));
+    const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N));
+    const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N));
+    const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N));
+
+    const VU16 mask = Set(d, 0x7Fu);  // Lowest 7 bits
+
+    // raw0..6 are the low 7 bits of packed0..6.
+    const VU16 raw0 = And(packed0, mask);
+    StoreU(raw0, d, raw + 0 * N);
+
+    const VU16 raw1 = And(packed1, mask);
+    StoreU(raw1, d, raw + 1 * N);
+
+    const VU16 raw2 = And(packed2, mask);
+    StoreU(raw2, d, raw + 2 * N);
+
+    const VU16 raw3 = And(packed3, mask);
+    StoreU(raw3, d, raw + 3 * N);
+
+    const VU16 raw4 = And(packed4, mask);
+    StoreU(raw4, d, raw + 4 * N);
+
+    const VU16 raw5 = And(packed5, mask);
+    StoreU(raw5, d, raw + 5 * N);
+
+    const VU16 raw6 = And(packed6, mask);
+    StoreU(raw6, d, raw + 6 * N);
+
+    // raw8..E are bits 7..13 of packed0..6.
+    const VU16 raw8 = And(ShiftRight<7>(packed0), mask);
+    StoreU(raw8, d, raw + 8 * N);
+
+    const VU16 raw9 = And(ShiftRight<7>(packed1), mask);
+    StoreU(raw9, d, raw + 9 * N);
+
+    const VU16 rawA = And(ShiftRight<7>(packed2), mask);
+    StoreU(rawA, d, raw + 0xA * N);
+
+    const VU16 rawB = And(ShiftRight<7>(packed3), mask);
+    StoreU(rawB, d, raw + 0xB * N);
+
+    const VU16 rawC = And(ShiftRight<7>(packed4), mask);
+    StoreU(rawC, d, raw + 0xC * N);
+
+    const VU16 rawD = And(ShiftRight<7>(packed5), mask);
+    StoreU(rawD, d, raw + 0xD * N);
+
+    const VU16 rawE = And(ShiftRight<7>(packed6), mask);
+    StoreU(rawE, d, raw + 0xE * N);
+
+    // packed7 is the concatenation of the two upper bits in packed0..6.
+    // Masking with hi2 before shifting keeps each pair in its own bit
+    // positions, so the Xor3 calls act as a plain OR.
+    const VU16 down0 = ShiftRight<14>(packed0);
+    const VU16 hi2 = Set(d, 0xC000u);
+    const VU16 p0 = Xor3(ShiftRight<12>(And(packed1, hi2)),
+                         ShiftRight<10>(And(packed2, hi2)), down0);
+    const VU16 p1 = Xor3(ShiftRight<8>(And(packed3, hi2)),  //
+                         ShiftRight<6>(And(packed4, hi2)),
+                         ShiftRight<4>(And(packed5, hi2)));
+    const VU16 packed7 = Xor3(ShiftRight<2>(And(packed6, hi2)), p1, p0);
+
+    const VU16 raw7 = And(packed7, mask);
+    StoreU(raw7, d, raw + 7 * N);
+
+    const VU16 rawF = ShiftRight<7>(packed7);  // upper bits already zero
+    StoreU(rawF, d, raw + 0xF * N);
+  }
+};  // Pack16<7>
+
+template <> +struct Pack16<8> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // This is equivalent to ConcatEven with 8-bit lanes, but much more + // efficient on RVV and slightly less efficient on SVE2.
+ const VU16 packed0 = Or(ShiftLeft<8>(raw2), raw0); + const VU16 packed1 = Or(ShiftLeft<8>(raw3), raw1); + const VU16 packed2 = Or(ShiftLeft<8>(raw6), raw4); + const VU16 packed3 = Or(ShiftLeft<8>(raw7), raw5); + const VU16 packed4 = Or(ShiftLeft<8>(rawA), raw8); + const VU16 packed5 = Or(ShiftLeft<8>(rawB), raw9); + const VU16 packed6 = Or(ShiftLeft<8>(rawE), rawC); + const VU16 packed7 = Or(ShiftLeft<8>(rawF), rawD); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 mask = Set(d, 0xFFu); // Lowest 8 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = ShiftRight<8>(packed0); // upper bits already zero + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = ShiftRight<8>(packed1); // upper bits already zero + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed2, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed3, mask); + StoreU(raw5, d, raw + 
5 * N); + + const VU16 raw6 = ShiftRight<8>(packed2); // upper bits already zero + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = ShiftRight<8>(packed3); // upper bits already zero + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(packed4, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed5, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = ShiftRight<8>(packed4); // upper bits already zero + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = ShiftRight<8>(packed5); // upper bits already zero + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(packed6, mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(packed7, mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = ShiftRight<8>(packed6); // upper bits already zero + StoreU(rawE, d, raw + 0xE * N); + + const VU16 rawF = ShiftRight<8>(packed7); // upper bits already zero + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<8> +
+template <>
+struct Pack16<9> {
+  // Packs 16 vectors of 9-bit values, read from raw[0 .. 16 * N), into 9
+  // vectors written to packed_out[0 .. 9 * N), where N = Lanes(d).
+  //
+  // Layout:
+  //   packed_i (i = 0..7): bits 0..8  = raw_i
+  //                        bits 9..15 = low 7 bits of raw_(8+i)
+  //   packed8            : bits 2i..2i+1 = bits 7..8 of raw_(8+i)
+  //
+  // raw0..7 are not masked here, so any bit set above their low 9 would spill
+  // into the upper field.
+  template <class D>
+  HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw,
+                       uint16_t* HWY_RESTRICT packed_out) const {
+    using VU16 = Vec<D>;
+    const size_t N = Lanes(d);
+    const VU16 raw0 = LoadU(d, raw + 0 * N);
+    const VU16 raw1 = LoadU(d, raw + 1 * N);
+    const VU16 raw2 = LoadU(d, raw + 2 * N);
+    const VU16 raw3 = LoadU(d, raw + 3 * N);
+    const VU16 raw4 = LoadU(d, raw + 4 * N);
+    const VU16 raw5 = LoadU(d, raw + 5 * N);
+    const VU16 raw6 = LoadU(d, raw + 6 * N);
+    const VU16 raw7 = LoadU(d, raw + 7 * N);
+    const VU16 raw8 = LoadU(d, raw + 8 * N);
+    const VU16 raw9 = LoadU(d, raw + 9 * N);
+    const VU16 rawA = LoadU(d, raw + 0xA * N);
+    const VU16 rawB = LoadU(d, raw + 0xB * N);
+    const VU16 rawC = LoadU(d, raw + 0xC * N);
+    const VU16 rawD = LoadU(d, raw + 0xD * N);
+    const VU16 rawE = LoadU(d, raw + 0xE * N);
+    const VU16 rawF = LoadU(d, raw + 0xF * N);
+    // 8 vectors, each with 9+7 bits; top 2 bits are concatenated into packed8.
+    // The left shift by 9 drops bits 7..8 of raw8..F from each 16-bit lane.
+    const VU16 packed0 = Or(ShiftLeft<9>(raw8), raw0);
+    const VU16 packed1 = Or(ShiftLeft<9>(raw9), raw1);
+    const VU16 packed2 = Or(ShiftLeft<9>(rawA), raw2);
+    const VU16 packed3 = Or(ShiftLeft<9>(rawB), raw3);
+    const VU16 packed4 = Or(ShiftLeft<9>(rawC), raw4);
+    const VU16 packed5 = Or(ShiftLeft<9>(rawD), raw5);
+    const VU16 packed6 = Or(ShiftLeft<9>(rawE), raw6);
+    const VU16 packed7 = Or(ShiftLeft<9>(rawF), raw7);
+
+    // We could shift down, OR and shift up, but two shifts are typically more
+    // expensive than AND, shift into position, and OR (which can be further
+    // reduced via Xor3).
+    const VU16 mid2 = Set(d, 0x180u);  // top 2 in lower 9
+    const VU16 part8 = ShiftRight<7>(And(raw8, mid2));
+    const VU16 part9 = ShiftRight<5>(And(raw9, mid2));
+    const VU16 partA = ShiftRight<3>(And(rawA, mid2));
+    const VU16 partB = ShiftRight<1>(And(rawB, mid2));
+    const VU16 partC = ShiftLeft<1>(And(rawC, mid2));
+    const VU16 partD = ShiftLeft<3>(And(rawD, mid2));
+    const VU16 partE = ShiftLeft<5>(And(rawE, mid2));
+    const VU16 partF = ShiftLeft<7>(And(rawF, mid2));
+    // The eight parts occupy disjoint bit pairs, so XOR and OR are equivalent.
+    const VU16 packed8 = Xor3(Xor3(part8, part9, partA),
+                              Xor3(partB, partC, partD), Or(partE, partF));
+
+    StoreU(packed0, d, packed_out + 0 * N);
+    StoreU(packed1, d, packed_out + 1 * N);
+    StoreU(packed2, d, packed_out + 2 * N);
+    StoreU(packed3, d, packed_out + 3 * N);
+    StoreU(packed4, d, packed_out + 4 * N);
+    StoreU(packed5, d, packed_out + 5 * N);
+    StoreU(packed6, d, packed_out + 6 * N);
+    StoreU(packed7, d, packed_out + 7 * N);
+    StoreU(packed8, d, packed_out + 8 * N);
+  }
+
+  // Inverse of Pack: reads 9 vectors from packed_in[0 .. 9 * N) and writes 16
+  // vectors of 9-bit values to raw[0 .. 16 * N), where N = Lanes(d).
+  template <class D>
+  HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in,
+                         uint16_t* HWY_RESTRICT raw) const {
+    using VU16 = Vec<D>;
+    const size_t N = Lanes(d);
+
+    const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N));
+    const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N));
+    const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N));
+    const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N));
+    const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N));
+    const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N));
+    const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N));
+    const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N));
+    const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N));
+
+    const VU16 mask = Set(d, 0x1FFu);  // Lowest 9 bits
+
+    const VU16 raw0 = And(packed0, mask);
+    StoreU(raw0, d, raw + 0 * N);
+
+    const VU16 raw1 = And(packed1, mask);
+    StoreU(raw1, d, raw + 1 * N);
+
+    const VU16 raw2 = And(packed2, mask);
+    StoreU(raw2, d, raw + 2 * N);
+
+    const VU16 raw3 = And(packed3, mask);
+    StoreU(raw3, d, raw + 3 * N);
+
+    const VU16 raw4 = And(packed4, mask);
+    StoreU(raw4, d, raw + 4 * N);
+
+    const VU16 raw5 = And(packed5, mask);
+    StoreU(raw5, d, raw + 5 * N);
+
+    const VU16 raw6 = And(packed6, mask);
+    StoreU(raw6, d, raw + 6 * N);
+
+    const VU16 raw7 = And(packed7, mask);
+    StoreU(raw7, d, raw + 7 * N);
+
+    // raw8..F: low 7 bits come from the top of packed0..7; bits 7..8 are the
+    // matching bit pair of packed8, shifted into place and selected by mid2.
+    const VU16 mid2 = Set(d, 0x180u);  // top 2 in lower 9
+    const VU16 raw8 =
+        OrAnd(ShiftRight<9>(packed0), ShiftLeft<7>(packed8), mid2);
+    const VU16 raw9 =
+        OrAnd(ShiftRight<9>(packed1), ShiftLeft<5>(packed8), mid2);
+    const VU16 rawA =
+        OrAnd(ShiftRight<9>(packed2), ShiftLeft<3>(packed8), mid2);
+    const VU16 rawB =
+        OrAnd(ShiftRight<9>(packed3), ShiftLeft<1>(packed8), mid2);
+    const VU16 rawC =
+        OrAnd(ShiftRight<9>(packed4), ShiftRight<1>(packed8), mid2);
+    const VU16 rawD =
+        OrAnd(ShiftRight<9>(packed5), ShiftRight<3>(packed8), mid2);
+    const VU16 rawE =
+        OrAnd(ShiftRight<9>(packed6), ShiftRight<5>(packed8), mid2);
+    const VU16 rawF =
+        OrAnd(ShiftRight<9>(packed7), ShiftRight<7>(packed8), mid2);
+
+    StoreU(raw8, d, raw + 8 * N);
+    StoreU(raw9, d, raw + 9 * N);
+    StoreU(rawA, d, raw + 0xA * N);
+    StoreU(rawB, d, raw + 0xB * N);
+    StoreU(rawC, d, raw + 0xC * N);
+    StoreU(rawD, d, raw + 0xD * N);
+    StoreU(rawE, d, raw + 0xE * N);
+    StoreU(rawF, d, raw + 0xF * N);
+  }
+};  // Pack16<9>
+
+template <> +struct
Pack16<10> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // 8 vectors, each with 10+6 bits; top 4 bits are concatenated into + // packed8 and packed9. + const VU16 packed0 = Or(ShiftLeft<10>(raw8), raw0); + const VU16 packed1 = Or(ShiftLeft<10>(raw9), raw1); + const VU16 packed2 = Or(ShiftLeft<10>(rawA), raw2); + const VU16 packed3 = Or(ShiftLeft<10>(rawB), raw3); + const VU16 packed4 = Or(ShiftLeft<10>(rawC), raw4); + const VU16 packed5 = Or(ShiftLeft<10>(rawD), raw5); + const VU16 packed6 = Or(ShiftLeft<10>(rawE), raw6); + const VU16 packed7 = Or(ShiftLeft<10>(rawF), raw7); + + // We could shift down, OR and shift up, but two shifts are typically more + // expensive than AND, shift into position, and OR (which can be further + // reduced via Xor3). 
+ const VU16 mid4 = Set(d, 0x3C0u); // top 4 in lower 10 + const VU16 part8 = ShiftRight<6>(And(raw8, mid4)); + const VU16 part9 = ShiftRight<2>(And(raw9, mid4)); + const VU16 partA = ShiftLeft<2>(And(rawA, mid4)); + const VU16 partB = ShiftLeft<6>(And(rawB, mid4)); + const VU16 partC = ShiftRight<6>(And(rawC, mid4)); + const VU16 partD = ShiftRight<2>(And(rawD, mid4)); + const VU16 partE = ShiftLeft<2>(And(rawE, mid4)); + const VU16 partF = ShiftLeft<6>(And(rawF, mid4)); + const VU16 packed8 = Or(Xor3(part8, part9, partA), partB); + const VU16 packed9 = Or(Xor3(partC, partD, partE), partF); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + + const VU16 mask = Set(d, 0x3FFu); // Lowest 10 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = 
And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 mid4 = Set(d, 0x3C0u); // top 4 in lower 10 + const VU16 raw8 = + OrAnd(ShiftRight<10>(packed0), ShiftLeft<6>(packed8), mid4); + const VU16 raw9 = + OrAnd(ShiftRight<10>(packed1), ShiftLeft<2>(packed8), mid4); + const VU16 rawA = + OrAnd(ShiftRight<10>(packed2), ShiftRight<2>(packed8), mid4); + const VU16 rawB = + OrAnd(ShiftRight<10>(packed3), ShiftRight<6>(packed8), mid4); + const VU16 rawC = + OrAnd(ShiftRight<10>(packed4), ShiftLeft<6>(packed9), mid4); + const VU16 rawD = + OrAnd(ShiftRight<10>(packed5), ShiftLeft<2>(packed9), mid4); + const VU16 rawE = + OrAnd(ShiftRight<10>(packed6), ShiftRight<2>(packed9), mid4); + const VU16 rawF = + OrAnd(ShiftRight<10>(packed7), ShiftRight<6>(packed9), mid4); + + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<10> +
+template <>
+struct Pack16<11> {
+  // Packs 16 vectors of 11-bit values, read from raw[0 .. 16 * N), into 11
+  // vectors written to packed_out[0 .. 11 * N), where N = Lanes(d).
+  //
+  // Layout:
+  //   packed_i (i = 0..7): bits 0..7  = low 8 bits of raw_(2i)
+  //                        bits 8..15 = low 8 bits of raw_(2i+1)
+  //   packed8 / packed9 / packedA hold the upper 3 bits (bits 8..10) of:
+  //     bits  0..2  : raw0 / raw1 / raw2
+  //     bits  3..5  : raw3 / raw4 / raw5
+  //     bits  6..8  : raw6 / raw7 / raw8
+  //     bits  9..11 : raw9 / rawA / rawB
+  //     bits 12..14 : rawC / rawD / rawE
+  //     bit  15     : bit 8 / bit 9 / bit 10 of rawF
+  //
+  // raw0..2 are shifted down without masking, so any bit set above their low
+  // 11 would spill into the neighboring group.
+  template <class D>
+  HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw,
+                       uint16_t* HWY_RESTRICT packed_out) const {
+    using VU16 = Vec<D>;
+    const size_t N = Lanes(d);
+    const VU16 raw0 = LoadU(d, raw + 0 * N);
+    const VU16 raw1 = LoadU(d, raw + 1 * N);
+    const VU16 raw2 = LoadU(d, raw + 2 * N);
+    const VU16 raw3 = LoadU(d, raw + 3 * N);
+    const VU16 raw4 = LoadU(d, raw + 4 * N);
+    const VU16 raw5 = LoadU(d, raw + 5 * N);
+    const VU16 raw6 = LoadU(d, raw + 6 * N);
+    const VU16 raw7 = LoadU(d, raw + 7 * N);
+    const VU16 raw8 = LoadU(d, raw + 8 * N);
+    const VU16 raw9 = LoadU(d, raw + 9 * N);
+    const VU16 rawA = LoadU(d, raw + 0xA * N);
+    const VU16 rawB = LoadU(d, raw + 0xB * N);
+    const VU16 rawC = LoadU(d, raw + 0xC * N);
+    const VU16 rawD = LoadU(d, raw + 0xD * N);
+    const VU16 rawE = LoadU(d, raw + 0xE * N);
+    const VU16 rawF = LoadU(d, raw + 0xF * N);
+
+    // It is not obvious what the optimal partitioning looks like. To reduce the
+    // number of constants, we want to minimize the number of distinct bit
+    // lengths. 11+5 also requires 6-bit remnants with 4-bit leftovers.
+    // 8+3 seems better: it is easier to scatter 3 bits into the MSBs.
+    const VU16 lo8 = Set(d, 0xFFu);
+
+    // Lower 8 bits of all raw
+    const VU16 packed0 = OrAnd(ShiftLeft<8>(raw1), raw0, lo8);
+    const VU16 packed1 = OrAnd(ShiftLeft<8>(raw3), raw2, lo8);
+    const VU16 packed2 = OrAnd(ShiftLeft<8>(raw5), raw4, lo8);
+    const VU16 packed3 = OrAnd(ShiftLeft<8>(raw7), raw6, lo8);
+    const VU16 packed4 = OrAnd(ShiftLeft<8>(raw9), raw8, lo8);
+    const VU16 packed5 = OrAnd(ShiftLeft<8>(rawB), rawA, lo8);
+    const VU16 packed6 = OrAnd(ShiftLeft<8>(rawD), rawC, lo8);
+    const VU16 packed7 = OrAnd(ShiftLeft<8>(rawF), rawE, lo8);
+
+    StoreU(packed0, d, packed_out + 0 * N);
+    StoreU(packed1, d, packed_out + 1 * N);
+    StoreU(packed2, d, packed_out + 2 * N);
+    StoreU(packed3, d, packed_out + 3 * N);
+    StoreU(packed4, d, packed_out + 4 * N);
+    StoreU(packed5, d, packed_out + 5 * N);
+    StoreU(packed6, d, packed_out + 6 * N);
+    StoreU(packed7, d, packed_out + 7 * N);
+
+    // Three vectors, five 3bit remnants each, plus one 3bit in their MSB.
+    const VU16 top0 = ShiftRight<8>(raw0);
+    const VU16 top1 = ShiftRight<8>(raw1);
+    const VU16 top2 = ShiftRight<8>(raw2);
+    // Insert top raw bits into 3-bit groups within packed8..A. Moving the
+    // mask along avoids masking each of raw0..E and enables OrAnd.
+    VU16 next = Set(d, 0x38u);  // 0x7 << 3
+    VU16 packed8 = OrAnd(top0, ShiftRight<5>(raw3), next);
+    VU16 packed9 = OrAnd(top1, ShiftRight<5>(raw4), next);
+    VU16 packedA = OrAnd(top2, ShiftRight<5>(raw5), next);
+    next = ShiftLeft<3>(next);  // = 0x1C0u
+    packed8 = OrAnd(packed8, ShiftRight<2>(raw6), next);
+    packed9 = OrAnd(packed9, ShiftRight<2>(raw7), next);
+    packedA = OrAnd(packedA, ShiftRight<2>(raw8), next);
+    next = ShiftLeft<3>(next);  // = 0xE00u
+    // Add(x, x) is a left shift by one.
+    packed8 = OrAnd(packed8, Add(raw9, raw9), next);
+    packed9 = OrAnd(packed9, Add(rawA, rawA), next);
+    packedA = OrAnd(packedA, Add(rawB, rawB), next);
+    next = ShiftLeft<3>(next);  // = 0x7000u
+    packed8 = OrAnd(packed8, ShiftLeft<4>(rawC), next);
+    packed9 = OrAnd(packed9, ShiftLeft<4>(rawD), next);
+    packedA = OrAnd(packedA, ShiftLeft<4>(rawE), next);
+
+    // Scatter upper 3 bits of rawF into the upper bits.
+    next = ShiftLeft<3>(next);  // = 0x8000u
+    packed8 = OrAnd(packed8, ShiftLeft<7>(rawF), next);
+    packed9 = OrAnd(packed9, ShiftLeft<6>(rawF), next);
+    packedA = OrAnd(packedA, ShiftLeft<5>(rawF), next);
+
+    StoreU(packed8, d, packed_out + 8 * N);
+    StoreU(packed9, d, packed_out + 9 * N);
+    StoreU(packedA, d, packed_out + 0xA * N);
+  }
+
+  // Inverse of Pack: reads 11 vectors from packed_in[0 .. 11 * N) and writes
+  // 16 vectors of 11-bit values to raw[0 .. 16 * N), where N = Lanes(d).
+  template <class D>
+  HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in,
+                         uint16_t* HWY_RESTRICT raw) const {
+    using VU16 = Vec<D>;
+    const size_t N = Lanes(d);
+
+    const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N));
+    const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N));
+    const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N));
+    const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N));
+    const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N));
+    const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N));
+    const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N));
+    const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N));
+    const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N));
+    const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N));
+    const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N));
+
+    const VU16 mask = Set(d, 0xFFu);  // Lowest 8 bits
+
+    // Low 8 bits of every raw value.
+    const VU16 down0 = And(packed0, mask);
+    const VU16 down1 = ShiftRight<8>(packed0);
+    const VU16 down2 = And(packed1, mask);
+    const VU16 down3 = ShiftRight<8>(packed1);
+    const VU16 down4 = And(packed2, mask);
+    const VU16 down5 = ShiftRight<8>(packed2);
+    const VU16 down6 = And(packed3, mask);
+    const VU16 down7 = ShiftRight<8>(packed3);
+    const VU16 down8 = And(packed4, mask);
+    const VU16 down9 = ShiftRight<8>(packed4);
+    const VU16 downA = And(packed5, mask);
+    const VU16 downB = ShiftRight<8>(packed5);
+    const VU16 downC = And(packed6, mask);
+    const VU16 downD = ShiftRight<8>(packed6);
+    const VU16 downE = And(packed7, mask);
+    const VU16 downF = ShiftRight<8>(packed7);
+
+    // Three bits from packed8..A, eight bits from down0..F. Each shift aligns
+    // one 3-bit group with bits 8..10; hi3 selects exactly that group.
+    const VU16 hi3 = Set(d, 0x700u);
+    const VU16 raw0 = OrAnd(down0, ShiftLeft<8>(packed8), hi3);
+    const VU16 raw1 = OrAnd(down1, ShiftLeft<8>(packed9), hi3);
+    const VU16 raw2 = OrAnd(down2, ShiftLeft<8>(packedA), hi3);
+
+    const VU16 raw3 = OrAnd(down3, ShiftLeft<5>(packed8), hi3);
+    const VU16 raw4 = OrAnd(down4, ShiftLeft<5>(packed9), hi3);
+    const VU16 raw5 = OrAnd(down5, ShiftLeft<5>(packedA), hi3);
+
+    const VU16 raw6 = OrAnd(down6, ShiftLeft<2>(packed8), hi3);
+    const VU16 raw7 = OrAnd(down7, ShiftLeft<2>(packed9), hi3);
+    const VU16 raw8 = OrAnd(down8, ShiftLeft<2>(packedA), hi3);
+
+    const VU16 raw9 = OrAnd(down9, ShiftRight<1>(packed8), hi3);
+    const VU16 rawA = OrAnd(downA, ShiftRight<1>(packed9), hi3);
+    const VU16 rawB = OrAnd(downB, ShiftRight<1>(packedA), hi3);
+
+    const VU16 rawC = OrAnd(downC, ShiftRight<4>(packed8), hi3);
+    const VU16 rawD = OrAnd(downD, ShiftRight<4>(packed9), hi3);
+    const VU16 rawE = OrAnd(downE, ShiftRight<4>(packedA), hi3);
+
+    // Bit 15 of packed8/9/A holds bit 8/9/10 of rawF. Isolate that single bit
+    // before shifting it down: masking the shifted value with hi3 instead
+    // would also let bit 14 of packed9 and bits 13..14 of packedA (upper bits
+    // of rawD/rawE) leak into bits 8..9 of rawF.
+    const VU16 msb = Set(d, 0x8000u);
+    const VU16 rawF = Or(downF, Xor3(ShiftRight<7>(And(packed8, msb)),
+                                     ShiftRight<6>(And(packed9, msb)),
+                                     ShiftRight<5>(And(packedA, msb))));
+
+    StoreU(raw0, d, raw + 0 * N);
+    StoreU(raw1, d, raw + 1 * N);
+    StoreU(raw2, d, raw + 2 * N);
+    StoreU(raw3, d, raw + 3 * N);
+    StoreU(raw4, d, raw + 4 * N);
+    StoreU(raw5, d, raw + 5 * N);
+    StoreU(raw6, d, raw + 6 * N);
+    StoreU(raw7, d, raw + 7 * N);
+    StoreU(raw8, d, raw + 8 * N);
+    StoreU(raw9, d, raw + 9 * N);
+    StoreU(rawA, d, raw + 0xA * N);
+    StoreU(rawB, d, raw + 0xB * N);
+    StoreU(rawC, d, raw + 0xC * N);
+    StoreU(rawD, d, raw + 0xD * N);
+    StoreU(rawE, d, raw + 0xE * N);
+    StoreU(rawF, d, raw + 0xF * N);
+  }
+};  // Pack16<11>
+
+template <> +struct Pack16<12> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // 8 vectors, each with 12+4 bits; top 8 bits are concatenated into + // packed8 to packedB.
+ const VU16 packed0 = Or(ShiftLeft<12>(raw8), raw0); + const VU16 packed1 = Or(ShiftLeft<12>(raw9), raw1); + const VU16 packed2 = Or(ShiftLeft<12>(rawA), raw2); + const VU16 packed3 = Or(ShiftLeft<12>(rawB), raw3); + const VU16 packed4 = Or(ShiftLeft<12>(rawC), raw4); + const VU16 packed5 = Or(ShiftLeft<12>(rawD), raw5); + const VU16 packed6 = Or(ShiftLeft<12>(rawE), raw6); + const VU16 packed7 = Or(ShiftLeft<12>(rawF), raw7); + + // Masking after shifting left enables OrAnd. + const VU16 hi8 = Set(d, 0xFF00u); + const VU16 packed8 = OrAnd(ShiftRight<4>(raw8), ShiftLeft<4>(raw9), hi8); + const VU16 packed9 = OrAnd(ShiftRight<4>(rawA), ShiftLeft<4>(rawB), hi8); + const VU16 packedA = OrAnd(ShiftRight<4>(rawC), ShiftLeft<4>(rawD), hi8); + const VU16 packedB = OrAnd(ShiftRight<4>(rawE), ShiftLeft<4>(rawF), hi8); + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + StoreU(packedA, d, packed_out + 0xA * N); + StoreU(packedB, d, packed_out + 0xB * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, 
packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 packedB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + + const VU16 mask = Set(d, 0xFFFu); // Lowest 12 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 mid8 = Set(d, 0xFF0u); // upper 8 in lower 12 + const VU16 raw8 = + OrAnd(ShiftRight<12>(packed0), ShiftLeft<4>(packed8), mid8); + const VU16 raw9 = + OrAnd(ShiftRight<12>(packed1), ShiftRight<4>(packed8), mid8); + const VU16 rawA = + OrAnd(ShiftRight<12>(packed2), ShiftLeft<4>(packed9), mid8); + const VU16 rawB = + OrAnd(ShiftRight<12>(packed3), ShiftRight<4>(packed9), mid8); + const VU16 rawC = + OrAnd(ShiftRight<12>(packed4), ShiftLeft<4>(packedA), mid8); + const VU16 rawD = + OrAnd(ShiftRight<12>(packed5), ShiftRight<4>(packedA), mid8); + const VU16 rawE = + OrAnd(ShiftRight<12>(packed6), ShiftLeft<4>(packedB), mid8); + const VU16 rawF = + OrAnd(ShiftRight<12>(packed7), ShiftRight<4>(packedB), mid8); + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<12> + +template <> +struct Pack16<13> { + template + 
HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // As with 11 bits, it is not obvious what the optimal partitioning looks + // like. We similarly go with an 8+5 split. + const VU16 lo8 = Set(d, 0xFFu); + + // Lower 8 bits of all raw + const VU16 packed0 = OrAnd(ShiftLeft<8>(raw1), raw0, lo8); + const VU16 packed1 = OrAnd(ShiftLeft<8>(raw3), raw2, lo8); + const VU16 packed2 = OrAnd(ShiftLeft<8>(raw5), raw4, lo8); + const VU16 packed3 = OrAnd(ShiftLeft<8>(raw7), raw6, lo8); + const VU16 packed4 = OrAnd(ShiftLeft<8>(raw9), raw8, lo8); + const VU16 packed5 = OrAnd(ShiftLeft<8>(rawB), rawA, lo8); + const VU16 packed6 = OrAnd(ShiftLeft<8>(rawD), rawC, lo8); + const VU16 packed7 = OrAnd(ShiftLeft<8>(rawF), rawE, lo8); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + + // Five vectors, three 5bit remnants each, plus one 5bit in their MSB. 
+ const VU16 top0 = ShiftRight<8>(raw0); + const VU16 top1 = ShiftRight<8>(raw1); + const VU16 top2 = ShiftRight<8>(raw2); + const VU16 top3 = ShiftRight<8>(raw3); + const VU16 top4 = ShiftRight<8>(raw4); + + // Insert top raw bits into 5-bit groups within packed8..C. Moving the + // mask along avoids masking each of raw0..E and enables OrAnd. + VU16 next = Set(d, 0x3E0u); // 0x1F << 5 + VU16 packed8 = OrAnd(top0, ShiftRight<3>(raw5), next); + VU16 packed9 = OrAnd(top1, ShiftRight<3>(raw6), next); + VU16 packedA = OrAnd(top2, ShiftRight<3>(raw7), next); + VU16 packedB = OrAnd(top3, ShiftRight<3>(raw8), next); + VU16 packedC = OrAnd(top4, ShiftRight<3>(raw9), next); + next = ShiftLeft<5>(next); + packed8 = OrAnd(packed8, ShiftLeft<2>(rawA), next); + packed9 = OrAnd(packed9, ShiftLeft<2>(rawB), next); + packedA = OrAnd(packedA, ShiftLeft<2>(rawC), next); + packedB = OrAnd(packedB, ShiftLeft<2>(rawD), next); + packedC = OrAnd(packedC, ShiftLeft<2>(rawE), next); + + // Scatter upper 5 bits of rawF into the upper bits. 
+ next = ShiftLeft<3>(next); // = 0x8000u + packed8 = OrAnd(packed8, ShiftLeft<7>(rawF), next); + packed9 = OrAnd(packed9, ShiftLeft<6>(rawF), next); + packedA = OrAnd(packedA, ShiftLeft<5>(rawF), next); + packedB = OrAnd(packedB, ShiftLeft<4>(rawF), next); + packedC = OrAnd(packedC, ShiftLeft<3>(rawF), next); + + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + StoreU(packedA, d, packed_out + 0xA * N); + StoreU(packedB, d, packed_out + 0xB * N); + StoreU(packedC, d, packed_out + 0xC * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 packedB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + const VU16 packedC = BitCast(d, LoadU(d, packed_in + 0xC * N)); + + const VU16 mask = Set(d, 0xFFu); // Lowest 8 bits + + const VU16 down0 = And(packed0, mask); + const VU16 down1 = ShiftRight<8>(packed0); + const VU16 down2 = And(packed1, mask); + const VU16 down3 = ShiftRight<8>(packed1); + const VU16 down4 = And(packed2, mask); + const VU16 down5 = ShiftRight<8>(packed2); + const VU16 down6 = And(packed3, mask); + const VU16 down7 = ShiftRight<8>(packed3); + const VU16 down8 = And(packed4, mask); + const VU16 
down9 = ShiftRight<8>(packed4); + const VU16 downA = And(packed5, mask); + const VU16 downB = ShiftRight<8>(packed5); + const VU16 downC = And(packed6, mask); + const VU16 downD = ShiftRight<8>(packed6); + const VU16 downE = And(packed7, mask); + const VU16 downF = ShiftRight<8>(packed7); + + // Upper five bits from packed8..C, eight bits from down0..F. + const VU16 hi5 = Set(d, 0x1F00u); + const VU16 raw0 = OrAnd(down0, ShiftLeft<8>(packed8), hi5); + const VU16 raw1 = OrAnd(down1, ShiftLeft<8>(packed9), hi5); + const VU16 raw2 = OrAnd(down2, ShiftLeft<8>(packedA), hi5); + const VU16 raw3 = OrAnd(down3, ShiftLeft<8>(packedB), hi5); + const VU16 raw4 = OrAnd(down4, ShiftLeft<8>(packedC), hi5); + + const VU16 raw5 = OrAnd(down5, ShiftLeft<3>(packed8), hi5); + const VU16 raw6 = OrAnd(down6, ShiftLeft<3>(packed9), hi5); + const VU16 raw7 = OrAnd(down7, ShiftLeft<3>(packedA), hi5); + const VU16 raw8 = OrAnd(down8, ShiftLeft<3>(packed9), hi5); + const VU16 raw9 = OrAnd(down9, ShiftLeft<3>(packedA), hi5); + + const VU16 rawA = OrAnd(downA, ShiftRight<2>(packed8), hi5); + const VU16 rawB = OrAnd(downB, ShiftRight<2>(packed9), hi5); + const VU16 rawC = OrAnd(downC, ShiftRight<2>(packedA), hi5); + const VU16 rawD = OrAnd(downD, ShiftRight<2>(packed9), hi5); + const VU16 rawE = OrAnd(downE, ShiftRight<2>(packedA), hi5); + + // Shift MSB into the top 5-of-11 and mask. 
+ const VU16 p0 = Xor3(And(ShiftRight<7>(packed8), hi5), // + And(ShiftRight<6>(packed9), hi5), + And(ShiftRight<5>(packedA), hi5)); + const VU16 p1 = Xor3(And(ShiftRight<4>(packedB), hi5), + And(ShiftRight<3>(packedC), hi5), downF); + const VU16 rawF = Or(p0, p1); + + StoreU(raw0, d, raw + 0 * N); + StoreU(raw1, d, raw + 1 * N); + StoreU(raw2, d, raw + 2 * N); + StoreU(raw3, d, raw + 3 * N); + StoreU(raw4, d, raw + 4 * N); + StoreU(raw5, d, raw + 5 * N); + StoreU(raw6, d, raw + 6 * N); + StoreU(raw7, d, raw + 7 * N); + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<13> + +template <> +struct Pack16<14> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // 14 vectors, each with 14+2 bits; two raw vectors are scattered + // across the upper 2 bits. 
+ const VU16 hi2 = Set(d, 0xC000u); + const VU16 packed0 = Or(raw0, ShiftLeft<14>(rawE)); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<12>(rawE), hi2); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<10>(rawE), hi2); + const VU16 packed3 = OrAnd(raw3, ShiftLeft<8>(rawE), hi2); + const VU16 packed4 = OrAnd(raw4, ShiftLeft<6>(rawE), hi2); + const VU16 packed5 = OrAnd(raw5, ShiftLeft<4>(rawE), hi2); + const VU16 packed6 = OrAnd(raw6, ShiftLeft<2>(rawE), hi2); + const VU16 packed7 = Or(raw7, ShiftLeft<14>(rawF)); + const VU16 packed8 = OrAnd(raw8, ShiftLeft<12>(rawF), hi2); + const VU16 packed9 = OrAnd(raw9, ShiftLeft<10>(rawF), hi2); + const VU16 packedA = OrAnd(rawA, ShiftLeft<8>(rawF), hi2); + const VU16 packedB = OrAnd(rawB, ShiftLeft<6>(rawF), hi2); + const VU16 packedC = OrAnd(rawC, ShiftLeft<4>(rawF), hi2); + const VU16 packedD = OrAnd(rawD, ShiftLeft<2>(rawF), hi2); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + StoreU(packedA, d, packed_out + 0xA * N); + StoreU(packedB, d, packed_out + 0xB * N); + StoreU(packedC, d, packed_out + 0xC * N); + StoreU(packedD, d, packed_out + 0xD * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = 
BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 packedB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + const VU16 packedC = BitCast(d, LoadU(d, packed_in + 0xC * N)); + const VU16 packedD = BitCast(d, LoadU(d, packed_in + 0xD * N)); + + const VU16 mask = Set(d, 0x3FFFu); // Lowest 14 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(packed8, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed9, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(packedA, mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(packedB, mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(packedC, mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(packedD, mask); + StoreU(rawD, d, raw + 0xD * N); + + // rawE is the concatenation of the top two bits in packed0..6. 
+ const VU16 E0 = Xor3(ShiftRight<14>(packed0), // + ShiftRight<12>(AndNot(mask, packed1)), + ShiftRight<10>(AndNot(mask, packed2))); + const VU16 E1 = Xor3(ShiftRight<8>(AndNot(mask, packed3)), + ShiftRight<6>(AndNot(mask, packed4)), + ShiftRight<4>(AndNot(mask, packed5))); + const VU16 rawE = Xor3(ShiftRight<2>(AndNot(mask, packed6)), E0, E1); + const VU16 F0 = Xor3(ShiftRight<14>(AndNot(mask, packed7)), + ShiftRight<12>(AndNot(mask, packed8)), + ShiftRight<10>(AndNot(mask, packed9))); + const VU16 F1 = Xor3(ShiftRight<8>(AndNot(mask, packedA)), + ShiftRight<6>(AndNot(mask, packedB)), + ShiftRight<4>(AndNot(mask, packedC))); + const VU16 rawF = Xor3(ShiftRight<2>(AndNot(mask, packedD)), F0, F1); + StoreU(rawE, d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<14> + +template <> +struct Pack16<15> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + // 15 vectors, each with 15+1 bits; one packed vector is scattered + // across the upper bit. 
+ const VU16 hi1 = Set(d, 0x8000u); + const VU16 packed0 = Or(raw0, ShiftLeft<15>(rawF)); + const VU16 packed1 = OrAnd(raw1, ShiftLeft<14>(rawF), hi1); + const VU16 packed2 = OrAnd(raw2, ShiftLeft<13>(rawF), hi1); + const VU16 packed3 = OrAnd(raw3, ShiftLeft<12>(rawF), hi1); + const VU16 packed4 = OrAnd(raw4, ShiftLeft<11>(rawF), hi1); + const VU16 packed5 = OrAnd(raw5, ShiftLeft<10>(rawF), hi1); + const VU16 packed6 = OrAnd(raw6, ShiftLeft<9>(rawF), hi1); + const VU16 packed7 = OrAnd(raw7, ShiftLeft<8>(rawF), hi1); + const VU16 packed8 = OrAnd(raw8, ShiftLeft<7>(rawF), hi1); + const VU16 packed9 = OrAnd(raw9, ShiftLeft<6>(rawF), hi1); + const VU16 packedA = OrAnd(rawA, ShiftLeft<5>(rawF), hi1); + const VU16 packedB = OrAnd(rawB, ShiftLeft<4>(rawF), hi1); + const VU16 packedC = OrAnd(rawC, ShiftLeft<3>(rawF), hi1); + const VU16 packedD = OrAnd(rawD, ShiftLeft<2>(rawF), hi1); + const VU16 packedE = OrAnd(rawE, ShiftLeft<1>(rawF), hi1); + + StoreU(packed0, d, packed_out + 0 * N); + StoreU(packed1, d, packed_out + 1 * N); + StoreU(packed2, d, packed_out + 2 * N); + StoreU(packed3, d, packed_out + 3 * N); + StoreU(packed4, d, packed_out + 4 * N); + StoreU(packed5, d, packed_out + 5 * N); + StoreU(packed6, d, packed_out + 6 * N); + StoreU(packed7, d, packed_out + 7 * N); + StoreU(packed8, d, packed_out + 8 * N); + StoreU(packed9, d, packed_out + 9 * N); + StoreU(packedA, d, packed_out + 0xA * N); + StoreU(packedB, d, packed_out + 0xB * N); + StoreU(packedC, d, packed_out + 0xC * N); + StoreU(packedD, d, packed_out + 0xD * N); + StoreU(packedE, d, packed_out + 0xE * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 packed0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 packed1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 packed2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 packed3 = BitCast(d, 
LoadU(d, packed_in + 3 * N)); + const VU16 packed4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 packed5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 packed6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 packed7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 packed8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 packed9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 packedA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 packedB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + const VU16 packedC = BitCast(d, LoadU(d, packed_in + 0xC * N)); + const VU16 packedD = BitCast(d, LoadU(d, packed_in + 0xD * N)); + const VU16 packedE = BitCast(d, LoadU(d, packed_in + 0xE * N)); + + const VU16 mask = Set(d, 0x7FFFu); // Lowest 15 bits + + const VU16 raw0 = And(packed0, mask); + StoreU(raw0, d, raw + 0 * N); + + const VU16 raw1 = And(packed1, mask); + StoreU(raw1, d, raw + 1 * N); + + const VU16 raw2 = And(packed2, mask); + StoreU(raw2, d, raw + 2 * N); + + const VU16 raw3 = And(packed3, mask); + StoreU(raw3, d, raw + 3 * N); + + const VU16 raw4 = And(packed4, mask); + StoreU(raw4, d, raw + 4 * N); + + const VU16 raw5 = And(packed5, mask); + StoreU(raw5, d, raw + 5 * N); + + const VU16 raw6 = And(packed6, mask); + StoreU(raw6, d, raw + 6 * N); + + const VU16 raw7 = And(packed7, mask); + StoreU(raw7, d, raw + 7 * N); + + const VU16 raw8 = And(packed8, mask); + StoreU(raw8, d, raw + 8 * N); + + const VU16 raw9 = And(packed9, mask); + StoreU(raw9, d, raw + 9 * N); + + const VU16 rawA = And(packedA, mask); + StoreU(rawA, d, raw + 0xA * N); + + const VU16 rawB = And(packedB, mask); + StoreU(rawB, d, raw + 0xB * N); + + const VU16 rawC = And(packedC, mask); + StoreU(rawC, d, raw + 0xC * N); + + const VU16 rawD = And(packedD, mask); + StoreU(rawD, d, raw + 0xD * N); + + const VU16 rawE = And(packedE, mask); + StoreU(rawE, d, raw + 0xE * N); + + // rawF is the concatenation of the top bit in packed0..E. 
+ const VU16 F0 = Xor3(ShiftRight<15>(packed0), // + ShiftRight<14>(AndNot(mask, packed1)), + ShiftRight<13>(AndNot(mask, packed2))); + const VU16 F1 = Xor3(ShiftRight<12>(AndNot(mask, packed3)), + ShiftRight<11>(AndNot(mask, packed4)), + ShiftRight<10>(AndNot(mask, packed5))); + const VU16 F2 = Xor3(ShiftRight<9>(AndNot(mask, packed6)), + ShiftRight<8>(AndNot(mask, packed7)), + ShiftRight<7>(AndNot(mask, packed8))); + const VU16 F3 = Xor3(ShiftRight<6>(AndNot(mask, packed9)), + ShiftRight<5>(AndNot(mask, packedA)), + ShiftRight<4>(AndNot(mask, packedB))); + const VU16 F4 = Xor3(ShiftRight<3>(AndNot(mask, packedC)), + ShiftRight<2>(AndNot(mask, packedD)), + ShiftRight<1>(AndNot(mask, packedE))); + const VU16 rawF = Xor3(F0, F1, Xor3(F2, F3, F4)); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<15> + +template <> +struct Pack16<16> { + template + HWY_INLINE void Pack(D d, const uint16_t* HWY_RESTRICT raw, + uint16_t* HWY_RESTRICT packed_out) const { + using VU16 = Vec; + const size_t N = Lanes(d); + const VU16 raw0 = LoadU(d, raw + 0 * N); + const VU16 raw1 = LoadU(d, raw + 1 * N); + const VU16 raw2 = LoadU(d, raw + 2 * N); + const VU16 raw3 = LoadU(d, raw + 3 * N); + const VU16 raw4 = LoadU(d, raw + 4 * N); + const VU16 raw5 = LoadU(d, raw + 5 * N); + const VU16 raw6 = LoadU(d, raw + 6 * N); + const VU16 raw7 = LoadU(d, raw + 7 * N); + const VU16 raw8 = LoadU(d, raw + 8 * N); + const VU16 raw9 = LoadU(d, raw + 9 * N); + const VU16 rawA = LoadU(d, raw + 0xA * N); + const VU16 rawB = LoadU(d, raw + 0xB * N); + const VU16 rawC = LoadU(d, raw + 0xC * N); + const VU16 rawD = LoadU(d, raw + 0xD * N); + const VU16 rawE = LoadU(d, raw + 0xE * N); + const VU16 rawF = LoadU(d, raw + 0xF * N); + + StoreU(raw0, d, packed_out + 0 * N); + StoreU(raw1, d, packed_out + 1 * N); + StoreU(raw2, d, packed_out + 2 * N); + StoreU(raw3, d, packed_out + 3 * N); + StoreU(raw4, d, packed_out + 4 * N); + StoreU(raw5, d, packed_out + 5 * N); + StoreU(raw6, d, packed_out + 6 * N); + 
StoreU(raw7, d, packed_out + 7 * N); + StoreU(raw8, d, packed_out + 8 * N); + StoreU(raw9, d, packed_out + 9 * N); + StoreU(rawA, d, packed_out + 0xA * N); + StoreU(rawB, d, packed_out + 0xB * N); + StoreU(rawC, d, packed_out + 0xC * N); + StoreU(rawD, d, packed_out + 0xD * N); + StoreU(rawE, d, packed_out + 0xE * N); + StoreU(rawF, d, packed_out + 0xF * N); + } + + template + HWY_INLINE void Unpack(D d, const uint16_t* HWY_RESTRICT packed_in, + uint16_t* HWY_RESTRICT raw) const { + using VU16 = Vec; + const size_t N = Lanes(d); + + const VU16 raw0 = BitCast(d, LoadU(d, packed_in + 0 * N)); + const VU16 raw1 = BitCast(d, LoadU(d, packed_in + 1 * N)); + const VU16 raw2 = BitCast(d, LoadU(d, packed_in + 2 * N)); + const VU16 raw3 = BitCast(d, LoadU(d, packed_in + 3 * N)); + const VU16 raw4 = BitCast(d, LoadU(d, packed_in + 4 * N)); + const VU16 raw5 = BitCast(d, LoadU(d, packed_in + 5 * N)); + const VU16 raw6 = BitCast(d, LoadU(d, packed_in + 6 * N)); + const VU16 raw7 = BitCast(d, LoadU(d, packed_in + 7 * N)); + const VU16 raw8 = BitCast(d, LoadU(d, packed_in + 8 * N)); + const VU16 raw9 = BitCast(d, LoadU(d, packed_in + 9 * N)); + const VU16 rawA = BitCast(d, LoadU(d, packed_in + 0xA * N)); + const VU16 rawB = BitCast(d, LoadU(d, packed_in + 0xB * N)); + const VU16 rawC = BitCast(d, LoadU(d, packed_in + 0xC * N)); + const VU16 rawD = BitCast(d, LoadU(d, packed_in + 0xD * N)); + const VU16 rawE = BitCast(d, LoadU(d, packed_in + 0xE * N)); + const VU16 rawF = BitCast(d, LoadU(d, packed_in + 0xF * N)); + + StoreU(raw0, d, raw + 0 * N); + StoreU(raw1, d, raw + 1 * N); + StoreU(raw2, d, raw + 2 * N); + StoreU(raw3, d, raw + 3 * N); + StoreU(raw4, d, raw + 4 * N); + StoreU(raw5, d, raw + 5 * N); + StoreU(raw6, d, raw + 6 * N); + StoreU(raw7, d, raw + 7 * N); + StoreU(raw8, d, raw + 8 * N); + StoreU(raw9, d, raw + 9 * N); + StoreU(rawA, d, raw + 0xA * N); + StoreU(rawB, d, raw + 0xB * N); + StoreU(rawC, d, raw + 0xC * N); + StoreU(rawD, d, raw + 0xD * N); + StoreU(rawE, 
d, raw + 0xE * N); + StoreU(rawF, d, raw + 0xF * N); + } +}; // Pack16<16> + +// The supported packing types for 32/64 bits. +enum BlockPackingType { + // Simple fixed bit-packing. + kBitPacked, + // Bit packing after subtracting a `frame of reference` value from input. + kFoRBitPacked, +}; + +namespace detail { + +// Generates the implementation for bit-packing/un-packing `T` type numbers +// where each number takes `kBits` bits. +// `S` is the remainder bits left from the previous bit-packed block. +// `kLoadPos` is the offset from which the next vector block should be loaded. +// `kStorePos` is the offset into which the next vector block should be stored. +// `BlockPackingType` is the type of packing/unpacking for this block. +template +struct BitPackUnroller { + static constexpr size_t B = sizeof(T) * 8; + + template + static inline void Pack(D d, const T* HWY_RESTRICT raw, + T* HWY_RESTRICT packed_out, const V& mask, + const V& frame_of_reference, V& in, V& out) { + // Avoid compilation errors and unnecessary template instantiation if + // compiling in C++14 mode + using NextUnroller = BitPackUnroller< + T, kBits, ((S <= B) ? (S + ((S < B) ? 
kBits : 0)) : (S % B)), + kLoadPos + static_cast(S < B), + kStorePos + static_cast(S > B), block_packing_type>; + + (void)raw; + (void)mask; + (void)in; + + const size_t N = Lanes(d); + HWY_IF_CONSTEXPR(S >= B) { + StoreU(out, d, packed_out + kStorePos * N); + HWY_IF_CONSTEXPR(S == B) { return; } + HWY_IF_CONSTEXPR(S != B) { + constexpr size_t shr_amount = (kBits - S % B) % B; + out = ShiftRight(in); + // NextUnroller is a typedef for + // Unroller if S > B is true + return NextUnroller::Pack(d, raw, packed_out, mask, frame_of_reference, + in, out); + } + } + HWY_IF_CONSTEXPR(S < B) { + HWY_IF_CONSTEXPR(block_packing_type == BlockPackingType::kBitPacked) { + in = LoadU(d, raw + kLoadPos * N); + } + HWY_IF_CONSTEXPR(block_packing_type == BlockPackingType::kFoRBitPacked) { + in = Sub(LoadU(d, raw + kLoadPos * N), frame_of_reference); + } + // Optimize for the case when `S` is zero. + // We can skip `Or` + ShiftLeft` to align `in`. + HWY_IF_CONSTEXPR(S == 0) { out = in; } + HWY_IF_CONSTEXPR(S != 0) { out = Or(out, ShiftLeft(in)); } + // NextUnroller is a typedef for + // Unroller if S < B is true + return NextUnroller::Pack(d, raw, packed_out, mask, frame_of_reference, + in, out); + } + } + + template + static inline void Unpack(D d, const T* HWY_RESTRICT packed_in, + T* HWY_RESTRICT raw, const V& mask, + const V& frame_of_reference, V& in, V& out) { + // Avoid compilation errors and unnecessary template instantiation if + // compiling in C++14 mode + using NextUnroller = BitPackUnroller< + T, kBits, ((S <= B) ? (S + ((S < B) ? 
kBits : 0)) : (S % B)), + kLoadPos + static_cast(S > B), + kStorePos + static_cast(S < B), block_packing_type>; + + (void)packed_in; + (void)mask; + (void)in; + + const size_t N = Lanes(d); + HWY_IF_CONSTEXPR(S >= B) { + HWY_IF_CONSTEXPR(S == B) { + V bitpacked_output = out; + HWY_IF_CONSTEXPR(block_packing_type == + BlockPackingType::kFoRBitPacked) { + bitpacked_output = Add(bitpacked_output, frame_of_reference); + } + StoreU(bitpacked_output, d, raw + kStorePos * N); + return; + } + HWY_IF_CONSTEXPR(S != B) { + in = LoadU(d, packed_in + kLoadPos * N); + constexpr size_t shl_amount = (kBits - S % B) % B; + out = And(Or(out, ShiftLeft(in)), mask); + // NextUnroller is a typedef for + // Unroller if S > B is true + return NextUnroller::Unpack(d, packed_in, raw, mask, frame_of_reference, + in, out); + } + } + HWY_IF_CONSTEXPR(S < B) { + V bitpacked_output = out; + HWY_IF_CONSTEXPR(block_packing_type == BlockPackingType::kFoRBitPacked) { + bitpacked_output = Add(bitpacked_output, frame_of_reference); + } + StoreU(bitpacked_output, d, raw + kStorePos * N); + HWY_IF_CONSTEXPR(S + kBits < B) { + // Optimize for the case when `S` is zero. + // We can skip the `ShiftRight` to align `in`. + HWY_IF_CONSTEXPR(S == 0) { out = And(in, mask); } + HWY_IF_CONSTEXPR(S != 0) { out = And(ShiftRight(in), mask); } + } + HWY_IF_CONSTEXPR(S + kBits >= B) { out = ShiftRight(in); } + // NextUnroller is a typedef for + // Unroller if S < B is true + return NextUnroller::Unpack(d, packed_in, raw, mask, frame_of_reference, + in, out); + } + } +}; + +// Computes the highest power of two that divides `kBits`. 
+template +constexpr size_t NumLoops() { + return (kBits & ~(kBits - 1)); +} + +template +constexpr size_t PackedIncr() { + return kBits / NumLoops(); +} + +template +constexpr size_t UnpackedIncr() { + return (sizeof(T) * 8) / NumLoops(); +} + +template +constexpr uint32_t MaskBits32() { + return static_cast((1ull << kBits) - 1); +} + +template +constexpr uint64_t MaskBits64() { + return (uint64_t{1} << kBits) - 1; +} +template <> +constexpr uint64_t MaskBits64<64>() { + return ~uint64_t{0}; +} + +} // namespace detail + +template // <= 32 +struct Pack32 { + template + HWY_INLINE void Pack(D d, const uint32_t* HWY_RESTRICT raw, + uint32_t* HWY_RESTRICT packed_out, + const uint32_t frame_of_reference_value = 0) const { + using V = VFromD; + const V mask = Set(d, detail::MaskBits32()); + const V frame_of_reference = Set(d, frame_of_reference_value); + for (size_t i = 0; i < detail::NumLoops(); ++i) { + V in = Zero(d); + V out = Zero(d); + detail::BitPackUnroller::Pack(d, raw, packed_out, + mask, + frame_of_reference, in, + out); + raw += detail::UnpackedIncr() * Lanes(d); + packed_out += detail::PackedIncr() * Lanes(d); + } + } + + template + HWY_INLINE void Unpack(D d, const uint32_t* HWY_RESTRICT packed_in, + uint32_t* HWY_RESTRICT raw, + const uint32_t frame_of_reference_value = 0) const { + using V = VFromD; + const V mask = Set(d, detail::MaskBits32()); + const V frame_of_reference = Set(d, frame_of_reference_value); + for (size_t i = 0; i < detail::NumLoops(); ++i) { + V in = LoadU(d, packed_in + 0 * Lanes(d)); + V out = And(in, mask); + detail::BitPackUnroller::Unpack(d, packed_in, raw, + mask, + frame_of_reference, + in, out); + raw += detail::UnpackedIncr() * Lanes(d); + packed_in += detail::PackedIncr() * Lanes(d); + } + } +}; + +template // <= 64 +struct Pack64 { + template + HWY_INLINE void Pack(D d, const uint64_t* HWY_RESTRICT raw, + uint64_t* HWY_RESTRICT packed_out, + const uint64_t frame_of_reference_value = 0) const { + using V = VFromD; + const V 
mask = Set(d, detail::MaskBits64()); + const V frame_of_reference = Set(d, frame_of_reference_value); + for (size_t i = 0; i < detail::NumLoops(); ++i) { + V in = Zero(d); + V out = Zero(d); + detail::BitPackUnroller::Pack(d, raw, packed_out, + mask, + frame_of_reference, in, + out); + raw += detail::UnpackedIncr() * Lanes(d); + packed_out += detail::PackedIncr() * Lanes(d); + } + } + + template + HWY_INLINE void Unpack(D d, const uint64_t* HWY_RESTRICT packed_in, + uint64_t* HWY_RESTRICT raw, + const uint64_t frame_of_reference_value = 0) const { + using V = VFromD; + const V mask = Set(d, detail::MaskBits64()); + const V frame_of_reference = Set(d, frame_of_reference_value); + for (size_t i = 0; i < detail::NumLoops(); ++i) { + V in = LoadU(d, packed_in + 0 * Lanes(d)); + V out = And(in, mask); + detail::BitPackUnroller::Unpack(d, packed_in, raw, + mask, + frame_of_reference, + in, out); + raw += detail::UnpackedIncr() * Lanes(d); + packed_in += detail::PackedIncr() * Lanes(d); + } + } +}; + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_BIT_PACK_INL_H_ diff --git a/lib/highway/hwy/contrib/dot/dot-inl.h b/lib/highway/hwy/contrib/dot/dot-inl.h new file mode 100644 index 00000000000..e48796aaa76 --- /dev/null +++ b/lib/highway/hwy/contrib/dot/dot-inl.h @@ -0,0 +1,460 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// clang-format off +#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == defined(HWY_TARGET_TOGGLE) // NOLINT +// clang-format on +#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ +#endif + +#include +#include + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// NOTE: the D argument describes the inputs, not the output, because both +// f32/f32, bf16/bf16, and f32/bf16 inputs accumulate to f32. +struct Dot { + // Specify zero or more of these, ORed together, as the kAssumptions template + // argument to Compute. Each one may improve performance or reduce code size, + // at the cost of additional requirements on the arguments. + enum Assumptions { + // num_elements is at least N, which may be up to HWY_MAX_BYTES / sizeof(T). + kAtLeastOneVector = 1, + // num_elements is divisible by N (a power of two, so this can be used if + // the problem size is known to be a power of two >= HWY_MAX_BYTES / + // sizeof(T)). + kMultipleOfVector = 2, + // RoundUpTo(num_elements, N) elements are accessible; their value does not + // matter (will be treated as if they were zero). + kPaddedToVector = 4, + }; + + // Returns sum{pa[i] * pb[i]} for floating-point inputs, including float16_t + // and double if HWY_HAVE_FLOAT16/64. Aligning the + // pointers to a multiple of N elements is helpful but not required. 
+ template > + static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa, + const T* const HWY_RESTRICT pb, + const size_t num_elements) { + static_assert(IsFloat(), "MulAdd requires float type"); + using V = decltype(Zero(d)); + + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + size_t i = 0; + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < N)) { + // Only 2x unroll to avoid excessive code size. + T sum0 = ConvertScalarTo(0); + T sum1 = ConvertScalarTo(0); + for (; i + 2 <= num_elements; i += 2) { + // For reasons unknown, fp16 += does not compile on clang (Arm). + sum0 = ConvertScalarTo(sum0 + pa[i + 0] * pb[i + 0]); + sum1 = ConvertScalarTo(sum1 + pa[i + 1] * pb[i + 1]); + } + if (i < num_elements) { + sum1 = ConvertScalarTo(sum1 + pa[i] * pb[i]); + } + return ConvertScalarTo(sum0 + sum1); + } + + // Compiler doesn't make independent sum* accumulators, so unroll manually. + // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive + // for unaligned inputs (each unaligned pointer halves the throughput + // because it occupies both L1 load ports for a cycle). We cannot have + // arrays of vectors on RVV/SVE, so always unroll 4x. 
+ V sum0 = Zero(d); + V sum1 = Zero(d); + V sum2 = Zero(d); + V sum3 = Zero(d); + + // Main loop: unrolled + for (; i + 4 * N <= num_elements; /* i += 4 * N */) { // incr in loop + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = MulAdd(a0, b0, sum0); + const auto a1 = LoadU(d, pa + i); + const auto b1 = LoadU(d, pb + i); + i += N; + sum1 = MulAdd(a1, b1, sum1); + const auto a2 = LoadU(d, pa + i); + const auto b2 = LoadU(d, pb + i); + i += N; + sum2 = MulAdd(a2, b2, sum2); + const auto a3 = LoadU(d, pa + i); + const auto b3 = LoadU(d, pb + i); + i += N; + sum3 = MulAdd(a3, b3, sum3); + } + + // Up to 3 iterations of whole vectors + for (; i + N <= num_elements; i += N) { + const auto a = LoadU(d, pa + i); + const auto b = LoadU(d, pb + i); + sum0 = MulAdd(a, b, sum0); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(d, remaining); + const auto a = LoadU(d, pa + i); + const auto b = LoadU(d, pb + i); + sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1); + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= N. + HWY_DASSERT(i >= N); + i += remaining - N; + const auto skip = FirstN(d, N - remaining); + const auto a = LoadU(d, pa + i); // always unaligned + const auto b = LoadU(d, pb + i); + sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1); + } + } + } // kMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes. 
+ sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return ReduceSum(d, sum0); + } + + // f32 * bf16 + template + static HWY_INLINE float Compute(const DF df, + const float* const HWY_RESTRICT pa, + const hwy::bfloat16_t* const HWY_RESTRICT pb, + const size_t num_elements) { +#if HWY_TARGET == HWY_SCALAR + const Rebind dbf; +#else + const Repartition dbf; + using VBF = decltype(Zero(dbf)); +#endif + const Half dbfh; + using VF = decltype(Zero(df)); + + HWY_LANES_CONSTEXPR size_t NF = Lanes(df); + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < NF)) { + // Only 2x unroll to avoid excessive code size. + float sum0 = 0.0f; + float sum1 = 0.0f; + size_t i = 0; + for (; i + 2 <= num_elements; i += 2) { + sum0 += pa[i + 0] * ConvertScalarTo(pb[i + 0]); + sum1 += pa[i + 1] * ConvertScalarTo(pb[i + 1]); + } + for (; i < num_elements; ++i) { + sum1 += pa[i] * ConvertScalarTo(pb[i]); + } + return sum0 + sum1; + } + + // Compiler doesn't make independent sum* accumulators, so unroll manually. + // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive + // for unaligned inputs (each unaligned pointer halves the throughput + // because it occupies both L1 load ports for a cycle). We cannot have + // arrays of vectors on RVV/SVE, so always unroll 4x. 
+ VF sum0 = Zero(df); + VF sum1 = Zero(df); + VF sum2 = Zero(df); + VF sum3 = Zero(df); + + size_t i = 0; + +#if HWY_TARGET != HWY_SCALAR // PromoteUpperTo supported + // Main loop: unrolled 4x; each bf16 vector feeds two MulAdd (lower/upper) + for (; i + 4 * NF <= num_elements; /* i += 4 * NF */) { // incr in loop + const VF a0 = LoadU(df, pa + i); + const VBF b0 = LoadU(dbf, pb + i); + i += NF; + sum0 = MulAdd(a0, PromoteLowerTo(df, b0), sum0); + const VF a1 = LoadU(df, pa + i); + i += NF; + sum1 = MulAdd(a1, PromoteUpperTo(df, b0), sum1); + const VF a2 = LoadU(df, pa + i); + const VBF b2 = LoadU(dbf, pb + i); + i += NF; + sum2 = MulAdd(a2, PromoteLowerTo(df, b2), sum2); + const VF a3 = LoadU(df, pa + i); + i += NF; + sum3 = MulAdd(a3, PromoteUpperTo(df, b2), sum3); + } +#endif // HWY_TARGET != HWY_SCALAR + + // Up to 3 whole vectors; all of them if the main loop is compiled out + for (; i + NF <= num_elements; i += NF) { + const VF a = LoadU(df, pa + i); + const VF b = PromoteTo(df, LoadU(dbfh, pb + i)); + sum0 = MulAdd(a, b, sum0); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(df, remaining); + const VF a = LoadU(df, pa + i); + const VF b = PromoteTo(df, LoadU(dbfh, pb + i)); + sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1); + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= NF. + HWY_DASSERT(i >= NF); + i += remaining - NF; + const auto skip = FirstN(df, NF - remaining); + const VF a = LoadU(df, pa + i); // always unaligned + const VF b = PromoteTo(df, LoadU(dbfh, pb + i)); + sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1); + } + } + } // !kIsMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes.
+ sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return ReduceSum(df, sum0); + } + + // Returns sum{pa[i] * pb[i]} for bfloat16 inputs. Aligning the pointers to a + // multiple of N elements is helpful but not required. + template + static HWY_INLINE float Compute(const D d, + const bfloat16_t* const HWY_RESTRICT pa, + const bfloat16_t* const HWY_RESTRICT pb, + const size_t num_elements) { + const RebindToUnsigned du16; + const Repartition df32; + + using V = decltype(Zero(df32)); + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + size_t i = 0; + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < N)) { + float sum0 = 0.0f; // Only 2x unroll to avoid excessive code size for.. + float sum1 = 0.0f; // this unlikely(?) case. + for (; i + 2 <= num_elements; i += 2) { + sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]); + sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]); + } + if (i < num_elements) { + sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]); + } + return sum0 + sum1; + } + + // See comment in the other Compute() overload. Unroll 2x, but we need + // twice as many sums for ReorderWidenMulAccumulate. 
+ V sum0 = Zero(df32); + V sum1 = Zero(df32); + V sum2 = Zero(df32); + V sum3 = Zero(df32); + + // Main loop: unrolled + for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1); + const auto a1 = LoadU(d, pa + i); + const auto b1 = LoadU(d, pb + i); + i += N; + sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3); + } + + // Possibly one more iteration of whole vectors + if (i + N <= num_elements) { + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(du16, remaining); + const auto va = LoadU(d, pa + i); + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3); + + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= N. + HWY_DASSERT(i >= N); + i += remaining - N; + const auto skip = FirstN(du16, N - remaining); + const auto va = LoadU(d, pa + i); // always unaligned + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3); + } + } + } // kMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes. 
+ sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return ReduceSum(df32, sum0); + } + + // Returns sum{i32(pa[i]) * i32(pb[i])} for i16 inputs. Aligning the pointers + // to a multiple of N elements is helpful but not required. + template + static HWY_INLINE int32_t Compute(const D d, + const int16_t* const HWY_RESTRICT pa, + const int16_t* const HWY_RESTRICT pb, + const size_t num_elements) { + const RebindToUnsigned du16; + const RepartitionToWide di32; + + using VI32 = Vec; + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + size_t i = 0; + + constexpr bool kIsAtLeastOneVector = + (kAssumptions & kAtLeastOneVector) != 0; + constexpr bool kIsMultipleOfVector = + (kAssumptions & kMultipleOfVector) != 0; + constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0; + + // Won't be able to do a full vector load without padding => scalar loop. + if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector && + HWY_UNLIKELY(num_elements < N)) { + int32_t sum0 = 0; // Only 2x unroll to avoid excessive code size for.. + int32_t sum1 = 0; // this unlikely(?) case. + for (; i + 2 <= num_elements; i += 2) { + sum0 += int32_t{pa[i + 0]} * int32_t{pb[i + 0]}; + sum1 += int32_t{pa[i + 1]} * int32_t{pb[i + 1]}; + } + if (i < num_elements) { + sum1 += int32_t{pa[i]} * int32_t{pb[i]}; + } + return sum0 + sum1; + } + + // See comment in the other Compute() overload. Unroll 2x, but we need + // twice as many sums for ReorderWidenMulAccumulate. 
+ VI32 sum0 = Zero(di32); + VI32 sum1 = Zero(di32); + VI32 sum2 = Zero(di32); + VI32 sum3 = Zero(di32); + + // Main loop: unrolled + for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(di32, a0, b0, sum0, sum1); + const auto a1 = LoadU(d, pa + i); + const auto b1 = LoadU(d, pb + i); + i += N; + sum2 = ReorderWidenMulAccumulate(di32, a1, b1, sum2, sum3); + } + + // Possibly one more iteration of whole vectors + if (i + N <= num_elements) { + const auto a0 = LoadU(d, pa + i); + const auto b0 = LoadU(d, pb + i); + i += N; + sum0 = ReorderWidenMulAccumulate(di32, a0, b0, sum0, sum1); + } + + if (!kIsMultipleOfVector) { + const size_t remaining = num_elements - i; + if (remaining != 0) { + if (kIsPaddedToVector) { + const auto mask = FirstN(du16, remaining); + const auto va = LoadU(d, pa + i); + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(di32, a16, b16, sum2, sum3); + + } else { + // Unaligned load such that the last element is in the highest lane - + // ensures we do not touch any elements outside the valid range. + // If we get here, then num_elements >= N. + HWY_DASSERT(i >= N); + i += remaining - N; + const auto skip = FirstN(du16, N - remaining); + const auto va = LoadU(d, pa + i); // always unaligned + const auto vb = LoadU(d, pb + i); + const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va))); + const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb))); + sum2 = ReorderWidenMulAccumulate(di32, a16, b16, sum2, sum3); + } + } + } // kMultipleOfVector + + // Reduction tree: sum of all accumulators by pairs, then across lanes. 
+ sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + return ReduceSum(di32, sum0); + } +}; + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_ diff --git a/lib/highway/hwy/contrib/image/image.cc b/lib/highway/hwy/contrib/image/image.cc new file mode 100644 index 00000000000..8557baf963e --- /dev/null +++ b/lib/highway/hwy/contrib/image/image.cc @@ -0,0 +1,144 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/image/image.h" + +#include +#include + +#include // std::swap + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/per_target.h" + +namespace hwy { + +size_t ImageBase::VectorSize() { + // Do not cache result - must return the current value, which may be greater + // than the first call if it was subject to DisableTargets! + return VectorBytes(); +} + +size_t ImageBase::BytesPerRow(const size_t xsize, const size_t sizeof_t) { + const size_t vec_size = VectorSize(); + // Check for integer overflow in xsize * sizeof_t. 
+ if (HWY_LIKELY(xsize != 0) && sizeof_t > SIZE_MAX / xsize) { + HWY_ABORT("ImageBase::BytesPerRow overflow: xsize=%zu, sizeof_t=%zu", + xsize, sizeof_t); + } + size_t valid_bytes = xsize * sizeof_t; + + // Allow unaligned accesses starting at the last valid value - this may raise + // msan errors unless the user calls InitializePaddingForUnalignedAccesses. + // Skip for the scalar case because no extra lanes will be loaded. + if (vec_size != 1) { + HWY_DASSERT(vec_size >= sizeof_t); + valid_bytes += vec_size - sizeof_t; + } + + // Round up to vector and cache line size. + const size_t align = HWY_MAX(vec_size, HWY_ALIGNMENT); + size_t bytes_per_row = RoundUpTo(valid_bytes, align); + + // Before writes are committed to memory, CPUs guard against read after write + // hazards by checking only the lower 11 address bits. Row sizes that are not + // multiples of 2 KiB avoid false dependencies between consecutive rows; + // Avoid2K does the same for Image3 planes. NOTE(review): this tests + // HWY_ALIGNMENT, not 2 KiB, and looks always-true here - confirm upstream. + if (bytes_per_row % HWY_ALIGNMENT == 0) { + bytes_per_row += align; + } + + HWY_DASSERT(bytes_per_row % align == 0); + return bytes_per_row; +} + +ImageBase::ImageBase(const size_t xsize, const size_t ysize, + const size_t sizeof_t) + : xsize_(static_cast(xsize)), + ysize_(static_cast(ysize)), + bytes_(nullptr, AlignedFreer(&AlignedFreer::DoNothing, nullptr)) { + HWY_ASSERT(sizeof_t == 1 || sizeof_t == 2 || sizeof_t == 4 || sizeof_t == 8); + // Validate dimensions fit in uint32_t to prevent silent truncation. + HWY_ASSERT(xsize <= UINT32_MAX); + HWY_ASSERT(ysize <= UINT32_MAX); + + bytes_per_row_ = 0; + // Dimensions can be zero, e.g. for lazily-allocated images. Only allocate + // if nonzero, because "zero" bytes still have padding/bookkeeping overhead. + if (xsize != 0 && ysize != 0) { + bytes_per_row_ = BytesPerRow(xsize, sizeof_t); + // Check for integer overflow in allocation size.
+ if (bytes_per_row_ > SIZE_MAX / ysize) { + HWY_ABORT("ImageBase: allocation overflow (%zu * %zu exceeds size_t)", + bytes_per_row_, ysize); + } + bytes_ = AllocateAligned(bytes_per_row_ * ysize); + HWY_ASSERT(bytes_.get() != nullptr); + InitializePadding(sizeof_t, Padding::kRoundUp); + } +} + +ImageBase::ImageBase(const size_t xsize, const size_t ysize, + const size_t bytes_per_row, void* const aligned) + : xsize_(static_cast(xsize)), + ysize_(static_cast(ysize)), + bytes_per_row_(bytes_per_row), + bytes_(static_cast(aligned), + AlignedFreer(&AlignedFreer::DoNothing, nullptr)) { + const size_t vec_size = VectorSize(); + HWY_ASSERT(bytes_per_row % vec_size == 0); + HWY_ASSERT(reinterpret_cast(aligned) % vec_size == 0); +} + +void ImageBase::InitializePadding(const size_t sizeof_t, Padding padding) { +#if HWY_IS_MSAN || HWY_IDE + if (xsize_ == 0 || ysize_ == 0) return; + + const size_t vec_size = VectorSize(); // Bytes, independent of sizeof_t! + if (vec_size == 1) return; // Scalar mode: no padding needed + + const size_t valid_size = xsize_ * sizeof_t; + const size_t initialize_size = padding == Padding::kRoundUp + ? RoundUpTo(valid_size, vec_size) + : valid_size + vec_size - sizeof_t; + if (valid_size == initialize_size) return; + + for (size_t y = 0; y < ysize_; ++y) { + uint8_t* HWY_RESTRICT row = static_cast(VoidRow(y)); +#if defined(__clang__) && (__clang_major__ <= 6) + // There's a bug in msan in clang-6 when handling AVX2 operations. This + // workaround allows tests to pass on msan, although it is slower and + // prevents msan warnings from uninitialized images. 
+ hwy::ZeroBytes(row, initialize_size); +#else + hwy::ZeroBytes(row + valid_size, initialize_size - valid_size); +#endif // clang6 + } +#else + (void)sizeof_t; + (void)padding; +#endif // HWY_IS_MSAN +} + +void ImageBase::Swap(ImageBase& other) { + std::swap(xsize_, other.xsize_); + std::swap(ysize_, other.ysize_); + std::swap(bytes_per_row_, other.bytes_per_row_); + std::swap(bytes_, other.bytes_); +} + +} // namespace hwy diff --git a/lib/highway/hwy/contrib/image/image.h b/lib/highway/hwy/contrib/image/image.h new file mode 100644 index 00000000000..4f578aaec42 --- /dev/null +++ b/lib/highway/hwy/contrib/image/image.h @@ -0,0 +1,467 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_ +#define HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_ + +// SIMD/multicore-friendly planar image representation with row accessors. + +#include + +#include // std::move + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" + +namespace hwy { + +// Type-independent parts of Image<> - reduces code duplication and facilitates +// moving member function implementations to cc file. +struct HWY_CONTRIB_DLLEXPORT ImageBase { + // Returns required alignment in bytes for externally allocated memory. + static size_t VectorSize(); + + // Returns distance [bytes] between the start of two consecutive rows, a + // multiple of VectorSize but NOT kAlias (see implementation). 
+ static size_t BytesPerRow(size_t xsize, size_t sizeof_t); + + // No allocation (for output params or unused images) + ImageBase() + : xsize_(0), + ysize_(0), + bytes_per_row_(0), + bytes_(nullptr, AlignedFreer(&AlignedFreer::DoNothing, nullptr)) {} + + // Allocates memory (this is the common case) + ImageBase(size_t xsize, size_t ysize, size_t sizeof_t); + + // References but does not take ownership of external memory. Useful for + // interoperability with other libraries. `aligned` must be aligned to a + // multiple of VectorSize() and `bytes_per_row` must also be a multiple of + // VectorSize() or preferably equal to BytesPerRow(). + ImageBase(size_t xsize, size_t ysize, size_t bytes_per_row, void* aligned); + + // Copy construction/assignment is forbidden to avoid inadvertent copies, + // which can be very expensive. Use CopyImageTo() instead. + ImageBase(const ImageBase& other) = delete; + ImageBase& operator=(const ImageBase& other) = delete; + + // Move constructor (required for returning Image from function) + ImageBase(ImageBase&& other) noexcept = default; + + // Move assignment (required for std::vector) + ImageBase& operator=(ImageBase&& other) noexcept = default; + + void Swap(ImageBase& other); + + // Useful for pre-allocating image with some padding for alignment purposes + // and later reporting the actual valid dimensions. Caller is responsible + // for ensuring xsize/ysize are <= the original dimensions. + void ShrinkTo(const size_t xsize, const size_t ysize) { + xsize_ = static_cast(xsize); + ysize_ = static_cast(ysize); + // NOTE: we can't recompute bytes_per_row for more compact storage and + // better locality because that would invalidate the image contents. + } + + // How many pixels. + HWY_INLINE size_t xsize() const { return xsize_; } + HWY_INLINE size_t ysize() const { return ysize_; } + + // NOTE: do not use this for copying rows - the valid xsize may be much less. 
+ HWY_INLINE size_t bytes_per_row() const { return bytes_per_row_; } + + // Raw access to byte contents, for interfacing with other libraries. + // Unsigned char instead of char to avoid surprises (sign extension). + HWY_INLINE uint8_t* bytes() { + void* p = bytes_.get(); + return static_cast(HWY_ASSUME_ALIGNED(p, 64)); + } + HWY_INLINE const uint8_t* bytes() const { + const void* p = bytes_.get(); + return static_cast(HWY_ASSUME_ALIGNED(p, 64)); + } + + protected: + // Returns pointer to the start of a row. + HWY_INLINE void* VoidRow(const size_t y) const { +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN + if (y >= ysize_) { + HWY_ABORT("Row(%d) >= %u\n", static_cast(y), ysize_); + } +#endif + + void* row = bytes_.get() + y * bytes_per_row_; + return HWY_ASSUME_ALIGNED(row, 64); + } + + enum class Padding { + // Allow Load(d, row + x) for x = 0; x < xsize(); x += Lanes(d). Default. + kRoundUp, + // Allow LoadU(d, row + x) for x <= xsize() - 1. This requires an extra + // vector to be initialized. If done by default, this would suppress + // legitimate msan warnings. We therefore require users to explicitly call + // InitializePadding before using unaligned loads (e.g. convolution). + kUnaligned + }; + + // Initializes the minimum bytes required to suppress msan warnings from + // legitimate (according to Padding mode) vector loads/stores on the right + // border, where some lanes are uninitialized and assumed to be unused. + void InitializePadding(size_t sizeof_t, Padding padding); + + // (Members are non-const to enable assignment during move-assignment.) + uint32_t xsize_; // In valid pixels, not including any padding. + uint32_t ysize_; + size_t bytes_per_row_; // Includes padding. + AlignedFreeUniquePtr bytes_; +}; + +// Single channel, aligned rows separated by padding. T must be POD. 
+// +// 'Single channel' (one 2D array per channel) simplifies vectorization +// (repeating the same operation on multiple adjacent components) without the +// complexity of a hybrid layout (8 R, 8 G, 8 B, ...). In particular, clients +// can easily iterate over all components in a row and Image requires no +// knowledge of the pixel format beyond the component type "T". +// +// 'Aligned' means each row is aligned to the L1 cache line size. This prevents +// false sharing between two threads operating on adjacent rows. +// +// 'Padding' is still relevant because vectors could potentially be larger than +// a cache line. By rounding up row sizes to the vector size, we allow +// reading/writing ALIGNED vectors whose first lane is a valid sample. This +// avoids needing a separate loop to handle remaining unaligned lanes. +// +// This image layout could also be achieved with a vector and a row accessor +// function, but a class wrapper with support for "deleter" allows wrapping +// existing memory allocated by clients without copying the pixels. It also +// provides convenient accessors for xsize/ysize, which shortens function +// argument lists. Supports move-construction so it can be stored in containers. +template +class Image : public ImageBase { + public: + using T = ComponentType; + + Image() = default; + Image(const size_t xsize, const size_t ysize) + : ImageBase(xsize, ysize, sizeof(T)) {} + Image(const size_t xsize, const size_t ysize, size_t bytes_per_row, + void* aligned) + : ImageBase(xsize, ysize, bytes_per_row, aligned) {} + + void InitializePaddingForUnalignedAccesses() { + InitializePadding(sizeof(T), Padding::kUnaligned); + } + + HWY_INLINE const T* ConstRow(const size_t y) const { + return static_cast(VoidRow(y)); + } + HWY_INLINE const T* ConstRow(const size_t y) { + return static_cast(VoidRow(y)); + } + + // Returns pointer to non-const. 
This allows passing const Image* parameters + // when the callee is only supposed to fill the pixels, as opposed to + // allocating or resizing the image. + HWY_INLINE T* MutableRow(const size_t y) const { + return static_cast(VoidRow(y)); + } + HWY_INLINE T* MutableRow(const size_t y) { + return static_cast(VoidRow(y)); + } + + // Returns number of pixels (some of which are padding) per row. Useful for + // computing other rows via pointer arithmetic. WARNING: this must + // NOT be used to determine xsize. + HWY_INLINE intptr_t PixelsPerRow() const { + return static_cast(bytes_per_row_ / sizeof(T)); + } +}; + +using ImageF = Image; + +// A bundle of 3 same-sized images. To fill an existing Image3 using +// single-channel producers, we also need access to each const Image*. Const +// prevents breaking the same-size invariant, while still allowing pixels to be +// changed via MutableRow. +template +class Image3 { + public: + using T = ComponentType; + using ImageT = Image; + static constexpr size_t kNumPlanes = 3; + + Image3() : planes_{ImageT(), ImageT(), ImageT()} {} + + Image3(const size_t xsize, const size_t ysize) + : planes_{ImageT(xsize, ysize), ImageT(xsize, ysize), + ImageT(xsize, ysize)} {} + + Image3(Image3&& other) noexcept { + for (size_t i = 0; i < kNumPlanes; i++) { + planes_[i] = std::move(other.planes_[i]); + } + } + + Image3(ImageT&& plane0, ImageT&& plane1, ImageT&& plane2) { + if (!SameSize(plane0, plane1) || !SameSize(plane0, plane2)) { + HWY_ABORT( + "Not same size: %d x %d, %d x %d, %d x %d\n", + static_cast(plane0.xsize()), static_cast(plane0.ysize()), + static_cast(plane1.xsize()), static_cast(plane1.ysize()), + static_cast(plane2.xsize()), static_cast(plane2.ysize())); + } + planes_[0] = std::move(plane0); + planes_[1] = std::move(plane1); + planes_[2] = std::move(plane2); + } + + // Copy construction/assignment is forbidden to avoid inadvertent copies, + // which can be very expensive. Use CopyImageTo instead. 
+ Image3(const Image3& other) = delete; + Image3& operator=(const Image3& other) = delete; + + Image3& operator=(Image3&& other) noexcept { + for (size_t i = 0; i < kNumPlanes; i++) { + planes_[i] = std::move(other.planes_[i]); + } + return *this; + } + + HWY_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) const { + return static_cast(VoidPlaneRow(c, y)); + } + HWY_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) { + return static_cast(VoidPlaneRow(c, y)); + } + + HWY_INLINE T* MutablePlaneRow(const size_t c, const size_t y) const { + return static_cast(VoidPlaneRow(c, y)); + } + HWY_INLINE T* MutablePlaneRow(const size_t c, const size_t y) { + return static_cast(VoidPlaneRow(c, y)); + } + + HWY_INLINE const ImageT& Plane(size_t idx) const { return planes_[idx]; } + + void Swap(Image3& other) { + for (size_t c = 0; c < 3; ++c) { + other.planes_[c].Swap(planes_[c]); + } + } + + void ShrinkTo(const size_t xsize, const size_t ysize) { + for (ImageT& plane : planes_) { + plane.ShrinkTo(xsize, ysize); + } + } + + // Sizes of all three images are guaranteed to be equal. + HWY_INLINE size_t xsize() const { return planes_[0].xsize(); } + HWY_INLINE size_t ysize() const { return planes_[0].ysize(); } + // Returns offset [bytes] from one row to the next row of the same plane. + // WARNING: this must NOT be used to determine xsize, nor for copying rows - + // the valid xsize may be much less. + HWY_INLINE size_t bytes_per_row() const { return planes_[0].bytes_per_row(); } + // Returns number of pixels (some of which are padding) per row. Useful for + // computing other rows via pointer arithmetic. WARNING: this must NOT be used + // to determine xsize. + HWY_INLINE intptr_t PixelsPerRow() const { return planes_[0].PixelsPerRow(); } + + private: + // Returns pointer to the start of a row. 
+ HWY_INLINE void* VoidPlaneRow(const size_t c, const size_t y) const { +#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN + if (c >= kNumPlanes || y >= ysize()) { + HWY_ABORT("PlaneRow(%d, %d) >= %d\n", static_cast(c), + static_cast(y), static_cast(ysize())); + } +#endif + // Use the first plane's stride because the compiler might not realize they + // are all equal. Thus we only need a single multiplication for all planes. + const size_t row_offset = y * planes_[0].bytes_per_row(); + const void* row = planes_[c].bytes() + row_offset; + return static_cast( + HWY_ASSUME_ALIGNED(row, HWY_ALIGNMENT)); + } + + private: + ImageT planes_[kNumPlanes]; +}; + +using Image3F = Image3; + +// Rectangular region in image(s). Factoring this out of Image instead of +// shifting the pointer by x0/y0 allows this to apply to multiple images with +// different resolutions. Can compare size via SameSize(rect1, rect2). +class Rect { + public: + // Most windows are xsize_max * ysize_max, except those on the borders where + // begin + size_max > end. + constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize_max, + size_t ysize_max, size_t xend, size_t yend) + : x0_(xbegin), + y0_(ybegin), + xsize_(ClampedSize(xbegin, xsize_max, xend)), + ysize_(ClampedSize(ybegin, ysize_max, yend)) {} + + // Construct with origin and known size (typically from another Rect). + constexpr Rect(size_t xbegin, size_t ybegin, size_t xsize, size_t ysize) + : x0_(xbegin), y0_(ybegin), xsize_(xsize), ysize_(ysize) {} + + // Construct a rect that covers a whole image. 
+ template + explicit Rect(const Image& image) + : Rect(0, 0, image.xsize(), image.ysize()) {} + + Rect() : Rect(0, 0, 0, 0) {} + + Rect(const Rect&) = default; + Rect& operator=(const Rect&) = default; + + Rect Subrect(size_t xbegin, size_t ybegin, size_t xsize_max, + size_t ysize_max) { + return Rect(x0_ + xbegin, y0_ + ybegin, xsize_max, ysize_max, x0_ + xsize_, + y0_ + ysize_); + } + + template + const T* ConstRow(const Image* image, size_t y) const { + return image->ConstRow(y + y0_) + x0_; + } + + template + T* MutableRow(const Image* image, size_t y) const { + return image->MutableRow(y + y0_) + x0_; + } + + template + const T* ConstPlaneRow(const Image3& image, size_t c, size_t y) const { + return image.ConstPlaneRow(c, y + y0_) + x0_; + } + + template + T* MutablePlaneRow(Image3* image, const size_t c, size_t y) const { + return image->MutablePlaneRow(c, y + y0_) + x0_; + } + + // Returns true if this Rect fully resides in the given image. ImageT could be + // Image or Image3; however if ImageT is Rect, results are nonsensical. + template + bool IsInside(const ImageT& image) const { + return (x0_ + xsize_ <= image.xsize()) && (y0_ + ysize_ <= image.ysize()); + } + + size_t x0() const { return x0_; } + size_t y0() const { return y0_; } + size_t xsize() const { return xsize_; } + size_t ysize() const { return ysize_; } + + private: + // Returns size_max, or whatever is left in [begin, end). + static constexpr size_t ClampedSize(size_t begin, size_t size_max, + size_t end) { + return (begin + size_max <= end) ? size_max + : (end > begin ? end - begin : 0); + } + + size_t x0_; + size_t y0_; + + size_t xsize_; + size_t ysize_; +}; + +// Works for any image-like input type(s). +template +HWY_MAYBE_UNUSED bool SameSize(const Image1& image1, const Image2& image2) { + return image1.xsize() == image2.xsize() && image1.ysize() == image2.ysize(); +} + +// Mirrors out of bounds coordinates and returns valid coordinates unchanged. 
+// We assume the radius (distance outside the image) is small compared to the +// image size, otherwise this might not terminate. +// The mirror is outside the last column (border pixel is also replicated). +static HWY_INLINE HWY_MAYBE_UNUSED size_t Mirror(int64_t x, + const int64_t xsize) { + HWY_DASSERT(xsize != 0); + + // TODO(janwas): replace with branchless version + while (x < 0 || x >= xsize) { + if (x < 0) { + x = -x - 1; + } else { + x = 2 * xsize - 1 - x; + } + } + return static_cast(x); +} + +// Wrap modes for ensuring X/Y coordinates are in the valid range [0, size): + +// Mirrors (repeating the edge pixel once). Useful for convolutions. +struct WrapMirror { + HWY_INLINE size_t operator()(const int64_t coord, const size_t size) const { + return Mirror(coord, static_cast(size)); + } +}; + +// Returns the same coordinate, for when we know "coord" is already valid (e.g. +// interior of an image). +struct WrapUnchanged { + HWY_INLINE size_t operator()(const int64_t coord, size_t /*size*/) const { + return static_cast(coord); + } +}; + +// Similar to Wrap* but for row pointers (reduces Row() multiplications). + +class WrapRowMirror { + public: + template + WrapRowMirror(const View& image, size_t ysize) + : first_row_(image.ConstRow(0)), last_row_(image.ConstRow(ysize - 1)) {} + + const float* operator()(const float* const HWY_RESTRICT row, + const int64_t stride) const { + if (row < first_row_) { + const int64_t num_before = first_row_ - row; + // Mirrored; one row before => row 0, two before = row 1, ... + return first_row_ + num_before - stride; + } + if (row > last_row_) { + const int64_t num_after = row - last_row_; + // Mirrored; one row after => last row, two after = last - 1, ... 
+ return last_row_ - num_after + stride; + } + return row; + } + + private: + const float* const HWY_RESTRICT first_row_; + const float* const HWY_RESTRICT last_row_; +}; + +struct WrapRowUnchanged { + HWY_INLINE const float* operator()(const float* const HWY_RESTRICT row, + int64_t /*stride*/) const { + return row; + } +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_IMAGE_IMAGE_H_ diff --git a/lib/highway/hwy/contrib/math/fast_math-inl.h b/lib/highway/hwy/contrib/math/fast_math-inl.h new file mode 100644 index 00000000000..3dd1c5f532c --- /dev/null +++ b/lib/highway/hwy/contrib/math/fast_math-inl.h @@ -0,0 +1,1736 @@ +// Copyright 2026 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Include guard (still compiled once per target) +#if defined(HIGHWAY_HWY_CONTRIB_MATH_FAST_MATH_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef HIGHWAY_HWY_CONTRIB_MATH_FAST_MATH_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_MATH_FAST_MATH_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_MATH_FAST_MATH_INL_H_ +#endif + +#include +#include + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +namespace impl { + +// Port of reduce_angle_tan_SIMD +template > +HWY_INLINE void ReduceAngleTan(D d, V ang, V& x_red, V& sign) { + using T = TFromD; + const auto pi = Set(d, static_cast(3.14159265358979323846)); + const auto zero = Set(d, static_cast(0.0)); + const auto one = Set(d, static_cast(1.0)); + const auto minus_one = Set(d, static_cast(-1.0)); + + const auto inv_pi = Set(d, static_cast(0.31830988618379067153777)); + + // Modulo pi + auto quotient = Mul(ang, inv_pi); + quotient = Round(quotient); + auto ang_mod = NegMulAdd(quotient, pi, ang); + + // Determine sign + auto mask_neg = Lt(ang_mod, zero); + sign = IfThenElse(mask_neg, minus_one, one); + + // Absolute value + x_red = Abs(ang_mod); +} + +// Range reduction and exponent extraction for logarithm functions. +// Normalizes x to y in [0.707, 1.414] and extracts the exponent as a float in +// 'exp'. If kHandleSubnormals is true, scales subnormal inputs to prevent +// underflow. +template +HWY_INLINE void FastLogRangeReduction(D d, V x, V& y, V& exp) { + using T = TFromD; + const RebindToSigned di; + const RebindToUnsigned du; + using TI = TFromD; + using VI = decltype(Zero(di)); + + constexpr bool kIsF32 = (sizeof(T) == 4); + + const VI kExpMagicDiff = Set( + di, kIsF32 + ? static_cast(0x3F800000L - 0x3F3504F3L) + : static_cast(0x3FF0000000000000LL - 0x3FE6A09E00000000LL)); + + MFromD is_denormal; + if constexpr (kHandleSubnormals) { + const V kMinNormal = + Set(d, kIsF32 ? 
static_cast(1.175494351e-38f) + : static_cast(2.2250738585072014e-308)); + const V kScale = Set(d, kIsF32 ? static_cast(3.355443200e+7f) + : static_cast(1.8014398509481984e+16)); + is_denormal = Lt(x, kMinNormal); + x = MaskedMulOr(x, is_denormal, x, kScale); + } else { + (void)is_denormal; + } + + auto exp_bits = Add(BitCast(di, x), kExpMagicDiff); + + constexpr int kMantissaShift = kIsF32 ? 23 : 52; + const auto kBias = Set(di, kIsF32 ? 0x7F : 0x3FF); + const auto exp_int = Sub( + BitCast(di, ShiftRight(BitCast(du, exp_bits))), kBias); + exp = ConvertTo(d, exp_int); + + if constexpr (kHandleSubnormals) { + const V kExpScaleFloat = + Set(d, kIsF32 ? static_cast(-25.0) : static_cast(-54.0)); + exp = MaskedAddOr(exp, is_denormal, exp, kExpScaleFloat); + } + + const VI exp_int_shifted = ShiftLeft(exp_int); + const VI y_bits = Sub(BitCast(di, x), exp_int_shifted); + y = BitCast(d, y_bits); +} + +} // namespace impl + +namespace impl { + +template +struct FastExpImpl {}; + +template <> +struct FastExpImpl { + // Rounds float toward zero and returns as int32_t. + template , HWY_IF_F32_D(D)> + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return ConvertInRangeTo(Rebind(), x); + } + + // Computes 2^x, where x is an integer. + template >, HWY_IF_F32_D(D)> + HWY_INLINE Vec Pow2I(D d, VI32 x) { + const Rebind di32; + const VI32 kOffset = Set(di32, 0x7F); + return BitCast(d, ShiftLeft<23>(Add(x, kOffset))); + } + + // Sets the exponent of 'x' to 2^e. + template , class VI32 = Vec>, + HWY_IF_F32_D(D)> + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y))); + } + + template , class VI32 = Vec>, + HWY_IF_F32_D(D)> + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kMinusLn2 ~= -ln(2) + const V kMinusLn2 = Set(d, -0.69314718056f); + + // Extended precision modular arithmetic. 
+ const V qf = ConvertTo(d, q); + return MulAdd(qf, kMinusLn2, x); + } + + template , class VI32 = Vec>, + HWY_IF_F32_D(D)> + HWY_INLINE V Exp2Reduce(D d, V x, VI32 q) { + const V qf = ConvertTo(d, q); + return Sub(x, qf); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 +template <> +struct FastExpImpl { + // Rounds double toward zero and returns as int32_t. + template , HWY_IF_F64_D(D)> + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return DemoteInRangeTo(Rebind(), x); + } + + // Computes 2^x, where x is an integer. + template >, HWY_IF_F64_D(D)> + HWY_INLINE Vec Pow2I(D d, VI32 x) { + const Rebind di32; + const Rebind di64; + const VI32 kOffset = Set(di32, 0x3FF); + return BitCast(d, ShiftLeft<52>(PromoteTo(di64, Add(x, kOffset)))); + } + + // Sets the exponent of 'x' to 2^e. + template , class VI32 = Vec>, + HWY_IF_F64_D(D)> + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y))); + } + + template , class VI32 = Vec>, + HWY_IF_F64_D(D)> + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kMinusLn2 ~= -ln(2) + const V kMinusLn2 = Set(d, -0.6931471805599453); + + // Extended precision modular arithmetic. + const V qf = PromoteTo(d, q); + return MulAdd(qf, kMinusLn2, x); + } + + template , class VI32 = Vec>, + HWY_IF_F64_D(D)> + HWY_INLINE V Exp2Reduce(D d, V x, VI32 q) { + const V qf = PromoteTo(d, q); + return Sub(x, qf); + } +}; +#endif + +} // namespace impl + +/** + * Fast approximation of tan(x). + * + * Valid Lane Types: float32, float64 + * Max Relative Error: < 0.35% for angles equivalent to falling between [-89.99, + * +89.99] degrees (float32) and + * [-89.9999999, +89.9999999] degrees (float64). + * Valid Range: float32 : [-20, +20]rads + * float64 : [-39000, +39000]rads + * + * Note: Inputs extremely close to asymptotes may result in + * a sign flip due to precision limits. 
+ * + * @return tangent of 'x' + */ +template +HWY_INLINE V FastTan(D d, V x) { + using T = TFromD; + + // Reduction + V x_red, sign; + impl::ReduceAngleTan(d, x, x_red, sign); + + constexpr size_t kLanes = HWY_MAX_LANES_D(D); + V b, c, d_val; + + if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) || + (HWY_HAVE_SCALABLE && sizeof(T) == 4 && detail::IsFull(d))) { + // --- Table Lookup --- + const auto scale = Set(d, static_cast(3.8197186342)); + auto idx_float = Floor(Mul(x_red, scale)); + + // Convert to Integer Vector (Signed) + auto idx_int = ConvertTo(RebindToSigned(), idx_float); + + HWY_ALIGN static constexpr T arr_b[8] = { + static_cast(0), + static_cast(0.0174532925199432955), + static_cast(0.133808575986231942), + static_cast(0.378736447682769484), + static_cast(1.29590696960578966), + static_cast(9.45968454580926554), + static_cast(9.45968454580926554), + static_cast(9.45968454580926554)}; + + HWY_ALIGN static constexpr T arr_c[8] = { + static_cast(-0.0909090909092633431), + static_cast(-0.400000000000000022), + static_cast(-0.83333333333333337), + static_cast(-1.29999999999999982), + static_cast(-2.5), + static_cast(-10.9999999999791349), + static_cast(-10.9999999999791349), + static_cast(-10.9999999999791349)}; + + HWY_ALIGN static constexpr T arr_d[8] = { + static_cast(1.00277098842046231), + static_cast(1.14668131856027444), + static_cast(1.57370520888155374), + static_cast(2.18515222349690053), + static_cast(3.97062404828709958), + static_cast(17.2787595947438639), + static_cast(17.2787595947438639), + static_cast(17.2787595947438639)}; + + // Since Lookup8 is available for HWY_MIN_BYTES / sizeof(T) >= 4, this + // condition covers all cases we encounter inside the top level if block + // inside FastTan + b = Lookup8(d, arr_b, idx_int); + c = Lookup8(d, arr_c, idx_int); + d_val = Lookup8(d, arr_d, idx_int); + } else { + // --- FALLBACK PATH: Blend Chain --- + if constexpr (HWY_REGISTERS >= 32) { + // Split into two parallel chains to reduce dependency 
latency. + const auto t0 = Set(d, static_cast(0.2617993877995256)); + const auto t1 = Set(d, static_cast(0.5235987755990512)); + const auto t2 = Set(d, static_cast(0.7853981633985767)); + const auto t3 = Set(d, static_cast(1.0471975511981024)); + const auto t4 = Set(d, static_cast(1.3089969389976279)); + + // -- Chain 1: Indices 0 to 2 (Evaluated starting from t1 down to t0) + auto b_low = Set(d, static_cast(0.133808575986231942)); // idx 2 + auto c_low = Set(d, static_cast(-0.83333333333333337)); + auto d_low = Set(d, static_cast(1.57370520888155374)); + + auto mask = Lt(x_red, t1); + b_low = IfThenElse(mask, Set(d, static_cast(0.0174532925199432955)), + b_low); + c_low = IfThenElse(mask, Set(d, static_cast(-0.400000000000000022)), + c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(1.14668131856027444)), d_low); + + mask = Lt(x_red, t0); + b_low = IfThenZeroElse(mask, b_low); + c_low = IfThenElse(mask, Set(d, static_cast(-0.0909090909092633431)), + c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(1.00277098842046231)), d_low); + + // -- Chain 2: Indices 3 to 5 (Evaluated starting from t4 down to t3) + auto b_high = Set(d, static_cast(9.45968454580926554)); // idx 5 + auto c_high = Set(d, static_cast(-10.9999999999791349)); + auto d_high = Set(d, static_cast(17.2787595947438639)); + + mask = Lt(x_red, t4); + b_high = + IfThenElse(mask, Set(d, static_cast(1.29590696960578966)), b_high); + c_high = IfThenElse(mask, Set(d, static_cast(-2.5)), c_high); + d_high = + IfThenElse(mask, Set(d, static_cast(3.97062404828709958)), d_high); + + mask = Lt(x_red, t3); + b_high = IfThenElse(mask, Set(d, static_cast(0.378736447682769484)), + b_high); + c_high = IfThenElse(mask, Set(d, static_cast(-1.29999999999999982)), + c_high); + d_high = + IfThenElse(mask, Set(d, static_cast(2.18515222349690053)), d_high); + + // -- Merge the two chains + auto merge_mask = Lt(x_red, t2); + b = IfThenElse(merge_mask, b_low, b_high); + c = IfThenElse(merge_mask, c_low, c_high); + 
d_val = IfThenElse(merge_mask, d_low, d_high); + } else { + b = Set(d, static_cast(9.45968454580926554)); + c = Set(d, static_cast(-10.9999999999791349)); + d_val = Set(d, static_cast(17.2787595947438639)); + + auto mask = Lt(x_red, Set(d, static_cast(1.3089969389976279))); + b = IfThenElse(mask, Set(d, static_cast(1.29590696960578966)), b); + c = IfThenElse(mask, Set(d, static_cast(-2.5)), c); + d_val = + IfThenElse(mask, Set(d, static_cast(3.97062404828709958)), d_val); + + mask = Lt(x_red, Set(d, static_cast(1.0471975511981024))); + b = IfThenElse(mask, Set(d, static_cast(0.378736447682769484)), b); + c = IfThenElse(mask, Set(d, static_cast(-1.29999999999999982)), c); + d_val = + IfThenElse(mask, Set(d, static_cast(2.18515222349690053)), d_val); + + mask = Lt(x_red, Set(d, static_cast(0.7853981633985767))); + b = IfThenElse(mask, Set(d, static_cast(0.133808575986231942)), b); + c = IfThenElse(mask, Set(d, static_cast(-0.83333333333333337)), c); + d_val = + IfThenElse(mask, Set(d, static_cast(1.57370520888155374)), d_val); + + mask = Lt(x_red, Set(d, static_cast(0.5235987755990512))); + b = IfThenElse(mask, Set(d, static_cast(0.0174532925199432955)), b); + c = IfThenElse(mask, Set(d, static_cast(-0.400000000000000022)), c); + d_val = + IfThenElse(mask, Set(d, static_cast(1.14668131856027444)), d_val); + + mask = Lt(x_red, Set(d, static_cast(0.2617993877995256))); + b = IfThenZeroElse(mask, b); + c = IfThenElse(mask, Set(d, static_cast(-0.0909090909092633431)), c); + d_val = + IfThenElse(mask, Set(d, static_cast(1.00277098842046231)), d_val); + } + } + + // Math: y=(x + b)/(cx + d) + auto num = Add(x_red, b); + auto den = MulAdd(c, x_red, d_val); + + // Guard against denominator underflow/sign-flip near singularities + T epsilon_val; + if constexpr (sizeof(T) == 8) { + epsilon_val = static_cast(1e-15); + } else { + epsilon_val = static_cast(1e-6); + } + const auto kMinDenom = Set(d, epsilon_val); + // We use Abs() because on the reduced interval [0, pi/2], the 
tangent + // magnitude must be positive. If the polynomial approximation calculates a + // negative denominator (overshoot), it is an error, and we force it to be + // positive. + den = Max(Abs(den), kMinDenom); + + auto result = Div(num, den); + + // Apply Sign + return CopySign(result, sign); +} + +/** + * Fast approximation of atan(x). + * + * Valid Lane Types: float32, float64 + * Max Relative Error: 0.0034% + * Average Relative Error : 0.0002% for float32 + * 0.0002% for float64 + * Valid Range: float32: [-1e35, +1e35] + * float64: [-1e305, +1e305] + * + * @return arctangent of 'x' + */ +// if kAssumePositive is true, we assume inputs are non-negative. +template +HWY_INLINE V FastAtan(D d, V val) { + using T = TFromD; + + // Abs(val) and preserve sign for later (if needed) + V y; + if constexpr (kAssumePositive) { + y = val; + } else { + y = Abs(val); + } + + const V kOne = Set(d, static_cast(1.0)); + const auto gt1_mask = Gt(y, kOne); + // Domain reduction: map [1, inf) to [0, 1] + const V mapped_y = MaskedDivOr(y, gt1_mask, kOne, y); + + // Degree 4 polynomial for atan(x) / x over [0, 1] + const V c0 = Set(d, static_cast(0.9999653683169244)); + const V c1 = Set(d, static_cast(-0.3315525587266785)); + const V c2 = Set(d, static_cast(0.1844770291758270)); + const V c3 = Set(d, static_cast(-0.0907475543745560)); + const V c4 = Set(d, static_cast(0.0232748721030191)); + + const V z = Mul(mapped_y, mapped_y); + const V z2 = Mul(z, z); + + const V p01 = MulAdd(c1, z, c0); + const V p23 = MulAdd(c3, z, c2); + const V p234 = MulAdd(z2, c4, p23); + const V p = MulAdd(z2, p234, p01); + + const V poly = Mul(mapped_y, p); + + const V kPiOverTwo = Set(d, static_cast(1.57079632679489661923)); + auto result = MaskedSubOr(poly, gt1_mask, kPiOverTwo, poly); + + if constexpr (kAssumePositive) { + return result; + } else { + return CopySign(result, val); + } +} + +/** + * Fast approximation of atan2(y, x). 
+ * + * Valid Lane Types: float32, float64 + * Valid Range: As long as y/x is in Valid Range for FastAtan() + * Correctly handles negative zero, infinities, and NaN. + * @return atan2 of 'y', 'x' + */ +template +HWY_INLINE V FastAtan2(const D d, V y, V x) { + using T = TFromD; + using M = MFromD; + + const V kPi = Set(d, static_cast(3.14159265358979323846264)); + const V kPiOverTwo = Set(d, static_cast(1.57079632679489661923)); + const V kOne = Set(d, static_cast(1.0)); + const V k0 = Zero(d); + + const V ax = Abs(x); + const V ay = Abs(y); + + const V num = Min(ax, ay); + const V den = Max(ax, ay); + + const M is_inf = IsInf(num); + V mapped_y = MaskedDivOr(k0, Ne(den, k0), num, den); + mapped_y = IfThenElse(is_inf, kOne, mapped_y); + + // Degree 4 polynomial for atan(x) / x over [0, 1] + const V c0 = Set(d, static_cast(0.9999653683169244)); + const V c1 = Set(d, static_cast(-0.3315525587266785)); + const V c2 = Set(d, static_cast(0.1844770291758270)); + const V c3 = Set(d, static_cast(-0.0907475543745560)); + const V c4 = Set(d, static_cast(0.0232748721030191)); + + const V z = Mul(mapped_y, mapped_y); + const V z2 = Mul(z, z); + + const V p01 = MulAdd(c1, z, c0); + const V p23 = MulAdd(c3, z, c2); + const V p234 = MulAdd(z2, c4, p23); + const V p = MulAdd(z2, p234, p01); + + const V poly = Mul(mapped_y, p); + + const M ay_gt_ax = Gt(ay, ax); + V angle = MaskedSubOr(poly, ay_gt_ax, kPiOverTwo, poly); + + const M x_neg = Lt(x, k0); + angle = MaskedSubOr(angle, x_neg, kPi, angle); + + const M is_nan = IsEitherNaN(y, x); + return IfThenElse(is_nan, NaN(d), CopySign(angle, y)); +} + +/** + * Fast approximation of tanh(x). 
+ * + * Valid Lane Types: float32, float64 + * Max Relative Error : 0.02% for float32, 0.02% for float64 + * Average Relative Error : 0.0001% for float32, 0.000001% for float64 + * Max Relative Error for [-0.01, 0.01] : 0.003% + * Average Relative Error for [-0.01, 0.01] : 0.000002% + * Valid Range: float32: [-1e35, +1e35] + * float64: [-1e305, +1e305] + * + * @return hyperbolic tangent of 'x' + */ +template +HWY_INLINE V FastTanh(D d, V val) { + using T = TFromD; + + // Abs(val) and preserve sign for later + auto y = Abs(val); + + constexpr size_t kLanes = HWY_MAX_LANES_D(D); + V a, b, c, d_val; + + if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) || + (HWY_HAVE_SCALABLE && sizeof(T) == 4 && detail::IsFull(d))) { + // Coefficients for P(y) ~ index using CF algo + const auto k0 = Set(d, static_cast(-0.1145426548151546)); + const auto k1 = Set(d, static_cast(6.911330994691481)); + const auto k2 = Set(d, static_cast(-2.511392313950185)); + const auto k3 = Set(d, static_cast(0.3465107224049977)); + + // Index calculation: idx = P(y) + // Estrin's scheme + // k0 + y * k1 + y^2 * (k2 + y * k3) + const auto y2 = Mul(y, y); + const auto p01 = MulAdd(k1, y, k0); + const auto p23 = MulAdd(k3, y, k2); + auto idx_poly = MulAdd(y2, p23, p01); + + // Convert to integer index + using DI = RebindToSigned; + auto idx_i = ConvertTo(DI(), idx_poly); + + // Clamp index to 7 + idx_i = Min(idx_i, Set(DI(), 7)); + + HWY_ALIGN static constexpr T arr_a[8] = { + static_cast(-0.32124064137467889), + static_cast(-0.25063809221086503), + static_cast(-0.12743143099276930), + static_cast(0.00879257493380024), + static_cast(0.09774602019349406), + static_cast(0.09746926160817335), + static_cast(0.03461152231207073), + static_cast(0.00190152088495461)}; + HWY_ALIGN static constexpr T arr_b[8] = { + static_cast(-0.00191824037528361), + static_cast(-0.04124816646249752), + static_cast(-0.17298439557734449), + static_cast(-0.40057378897277868), + static_cast(-0.61249080708786208), + 
static_cast(-0.60213708038791991), + static_cast(-0.27823168655690367), + static_cast(-0.02476274996694735)}; + HWY_ALIGN static constexpr T arr_c[8] = { + static_cast(1.00009167744117367), + static_cast(1.00761071731437957), + static_cast(1.05510014553372966), + static_cast(1.18269984996630884), + static_cast(1.35178582776199496), + static_cast(1.32900602936237289), + static_cast(0.76904955465038061), + static_cast(0.10755808660951535)}; + HWY_ALIGN static constexpr T arr_d[8] = { + static_cast(-0.00000105381180317), + static_cast(-0.00049073110164177), + static_cast(-0.00625688495915542), + static_cast(-0.03025960590948565), + static_cast(-0.07544878909500170), + static_cast(-0.06251270311746010), + static_cast(0.26202315539906595), + static_cast(0.84371089018138146)}; + + // Since Lookup8 is available for HWY_MIN_BYTES / sizeof(T) >= 4, this + // condition covers all cases we encounter inside the top level if block + // inside FastTanh + a = Lookup8(d, arr_a, idx_i); + b = Lookup8(d, arr_b, idx_i); + c = Lookup8(d, arr_c, idx_i); + d_val = Lookup8(d, arr_d, idx_i); + } else { + // --- FALLBACK PATH: Blend Chain --- + // Thresholds for intervals + const auto t0 = Set(d, static_cast(0.1717248723716211)); + const auto t1 = Set(d, static_cast(0.3477988003593246)); + const auto t2 = Set(d, static_cast(0.5534457063834466)); + const auto t3 = Set(d, static_cast(0.8043240819114703)); + const auto t4 = Set(d, static_cast(1.1345195608232461)); + const auto t5 = Set(d, static_cast(1.6442012736785494)); + const auto t6 = Set(d, static_cast(2.6358900094985716)); + + if constexpr (HWY_REGISTERS >= 32) { + // Split into two parallel chains to reduce dependency latency. 
+ + // -- Chain 1: Indices 0 to 3 (Evaluated starting from t3 down to t0) + auto a_low = Set(d, static_cast(0.00879257493380024)); // idx 3 + auto b_low = Set(d, static_cast(-0.40057378897277868)); + auto c_low = Set(d, static_cast(1.18269984996630884)); + auto d_low = Set(d, static_cast(-0.03025960590948565)); + + auto mask = Lt(y, t2); + a_low = + IfThenElse(mask, Set(d, static_cast(-0.12743143099276930)), a_low); + b_low = + IfThenElse(mask, Set(d, static_cast(-0.17298439557734449)), b_low); + c_low = + IfThenElse(mask, Set(d, static_cast(1.05510014553372966)), c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(-0.00625688495915542)), d_low); + + mask = Lt(y, t1); + a_low = + IfThenElse(mask, Set(d, static_cast(-0.25063809221086503)), a_low); + b_low = + IfThenElse(mask, Set(d, static_cast(-0.04124816646249752)), b_low); + c_low = + IfThenElse(mask, Set(d, static_cast(1.00761071731437957)), c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(-0.00049073110164177)), d_low); + + mask = Lt(y, t0); + a_low = + IfThenElse(mask, Set(d, static_cast(-0.32124064137467889)), a_low); + b_low = + IfThenElse(mask, Set(d, static_cast(-0.00191824037528361)), b_low); + c_low = + IfThenElse(mask, Set(d, static_cast(1.00009167744117367)), c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(-0.00000105381180317)), d_low); + + // -- Chain 2: Indices 4 to 7 (Evaluated starting from t6 down to t4) + auto a_high = Set(d, static_cast(0.00190152088495461)); // idx 7 + auto b_high = Set(d, static_cast(-0.02476274996694735)); + auto c_high = Set(d, static_cast(0.10755808660951535)); + auto d_high = Set(d, static_cast(0.84371089018138146)); + + mask = Lt(y, t6); + a_high = + IfThenElse(mask, Set(d, static_cast(0.03461152231207073)), a_high); + b_high = IfThenElse(mask, Set(d, static_cast(-0.27823168655690367)), + b_high); + c_high = + IfThenElse(mask, Set(d, static_cast(0.76904955465038061)), c_high); + d_high = + IfThenElse(mask, Set(d, static_cast(0.26202315539906595)), 
d_high); + + mask = Lt(y, t5); + a_high = + IfThenElse(mask, Set(d, static_cast(0.09746926160817335)), a_high); + b_high = IfThenElse(mask, Set(d, static_cast(-0.60213708038791991)), + b_high); + c_high = + IfThenElse(mask, Set(d, static_cast(1.32900602936237289)), c_high); + d_high = IfThenElse(mask, Set(d, static_cast(-0.06251270311746010)), + d_high); + + mask = Lt(y, t4); + a_high = + IfThenElse(mask, Set(d, static_cast(0.09774602019349406)), a_high); + b_high = IfThenElse(mask, Set(d, static_cast(-0.61249080708786208)), + b_high); + c_high = + IfThenElse(mask, Set(d, static_cast(1.35178582776199496)), c_high); + d_high = IfThenElse(mask, Set(d, static_cast(-0.07544878909500170)), + d_high); + + // -- Merge the two chains + auto merge_mask = Lt(y, t3); + a = IfThenElse(merge_mask, a_low, a_high); + b = IfThenElse(merge_mask, b_low, b_high); + c = IfThenElse(merge_mask, c_low, c_high); + d_val = IfThenElse(merge_mask, d_low, d_high); + } else { + // Start with highest index (7) + a = Set(d, static_cast(0.00190152088495461)); + b = Set(d, static_cast(-0.02476274996694735)); + c = Set(d, static_cast(0.10755808660951535)); + d_val = Set(d, static_cast(0.84371089018138146)); + + // If y < t6 (idx 6) + auto mask = Lt(y, t6); + a = IfThenElse(mask, Set(d, static_cast(0.03461152231207073)), a); + b = IfThenElse(mask, Set(d, static_cast(-0.27823168655690367)), b); + c = IfThenElse(mask, Set(d, static_cast(0.76904955465038061)), c); + d_val = + IfThenElse(mask, Set(d, static_cast(0.26202315539906595)), d_val); + + // If y < t5 (idx 5) + mask = Lt(y, t5); + a = IfThenElse(mask, Set(d, static_cast(0.09746926160817335)), a); + b = IfThenElse(mask, Set(d, static_cast(-0.60213708038791991)), b); + c = IfThenElse(mask, Set(d, static_cast(1.32900602936237289)), c); + d_val = + IfThenElse(mask, Set(d, static_cast(-0.06251270311746010)), d_val); + + // If y < t4 (idx 4) + mask = Lt(y, t4); + a = IfThenElse(mask, Set(d, static_cast(0.09774602019349406)), a); + b = IfThenElse(mask, 
Set(d, static_cast(-0.61249080708786208)), b); + c = IfThenElse(mask, Set(d, static_cast(1.35178582776199496)), c); + d_val = + IfThenElse(mask, Set(d, static_cast(-0.07544878909500170)), d_val); + + // If y < t3 (idx 3) + mask = Lt(y, t3); + a = IfThenElse(mask, Set(d, static_cast(0.00879257493380024)), a); + b = IfThenElse(mask, Set(d, static_cast(-0.40057378897277868)), b); + c = IfThenElse(mask, Set(d, static_cast(1.18269984996630884)), c); + d_val = + IfThenElse(mask, Set(d, static_cast(-0.03025960590948565)), d_val); + + // If y < t2 (idx 2) + mask = Lt(y, t2); + a = IfThenElse(mask, Set(d, static_cast(-0.12743143099276930)), a); + b = IfThenElse(mask, Set(d, static_cast(-0.17298439557734449)), b); + c = IfThenElse(mask, Set(d, static_cast(1.05510014553372966)), c); + d_val = + IfThenElse(mask, Set(d, static_cast(-0.00625688495915542)), d_val); + + // If y < t1 (idx 1) + mask = Lt(y, t1); + a = IfThenElse(mask, Set(d, static_cast(-0.25063809221086503)), a); + b = IfThenElse(mask, Set(d, static_cast(-0.04124816646249752)), b); + c = IfThenElse(mask, Set(d, static_cast(1.00761071731437957)), c); + d_val = + IfThenElse(mask, Set(d, static_cast(-0.00049073110164177)), d_val); + + // If y < t0 (idx 0) + mask = Lt(y, t0); + a = IfThenElse(mask, Set(d, static_cast(-0.32124064137467889)), a); + b = IfThenElse(mask, Set(d, static_cast(-0.00191824037528361)), b); + c = IfThenElse(mask, Set(d, static_cast(1.00009167744117367)), c); + d_val = + IfThenElse(mask, Set(d, static_cast(-0.00000105381180317)), d_val); + } + } + + // Math: f(y) = ay^3 + by^2 + cy + d + const auto y2 = Mul(y, y); + const auto pcd = MulAdd(c, y, d_val); + const auto pab = MulAdd(a, y, b); + auto result = MulAdd(y2, pab, pcd); + + const auto kSmall = Set(d, static_cast(0.01)); + result = IfThenElse(Lt(y, kSmall), y, result); + + const auto k1 = Set(d, static_cast(1.0)); + // We can take Min since cubic approximation for index 7 is monotonically + // increasing, so for inputs >5 the polynomial 
approximation will output >1.0 + // allowing us to use Min() directly instead of IfThenElse() + result = Min(result, k1); + + return CopySign(result, val); // Restore sign +} + +/** + * Fast approximation of log(x). + * + * Valid Lane Types: float32, float64 + * Max Relative Error: 0.0095% + * Average Relative Error : 0.000014% + * Valid Range: float32: (0, +FLT_MAX] + * float64: (0, +DBL_MAX] + * + * @return natural logarithm of 'x' + */ +// If false, subnormals are treated as zero. +template +HWY_INLINE V FastLog(D d, V x) { + using T = TFromD; + const V kLn2 = Set(d, static_cast(0.6931471805599453)); + V y, exp; + impl::FastLogRangeReduction(d, x, y, exp); + + constexpr size_t kLanes = HWY_MAX_LANES_D(D); + V b, c, d_coef; + + if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) || + (HWY_HAVE_SCALABLE && sizeof(T) == 4 && detail::IsFull(d))) { + // --- Table Lookup --- + const auto scale = Set(d, static_cast(11.3137085)); + // Input is always non-negative, so Floor() + ConvertTo() + // can be replaced by direct ConvertTo() (truncation), which is faster. + // We use MulAdd(y, scale, -8.0) instead of Mul(Sub(y, lower_bound), scale) + // to save instructions. 0.70710678 * 11.3137085 ~= 8.0. 
+ auto idx_i = ConvertInRangeTo( + RebindToSigned(), MulAdd(y, scale, Set(d, static_cast(-8.0)))); + + // Clamp index to 7 to handle overshoots + idx_i = Min(idx_i, Set(RebindToSigned(), 7)); + + HWY_ALIGN static constexpr T arr_b[8] = { + static_cast(-1.00194730895928918), + static_cast(-1.00042661239958708), + static_cast(-1.0000255203465902), + static_cast(-1), + static_cast(-0.999929163668789478), + static_cast(-0.999558969823431065), + static_cast(-0.998743736501089163), + static_cast(-0.997397894886509873)}; + + HWY_ALIGN static constexpr T arr_c[8] = { + static_cast(0.58385589069067223), + static_cast(0.548174514768112076), + static_cast(0.519613079391819999), + static_cast(0.497367242550162236), + static_cast(0.476391677761481835), + static_cast(0.459525070958496262), + static_cast(0.44490172854808846), + static_cast(0.432070989622927948)}; + + HWY_ALIGN static constexpr T arr_d[8] = { + static_cast(0.437891917978712797), + static_cast(0.459658304416673158), + static_cast(0.481694216614368509), + static_cast(0.502574248959839265), + static_cast(0.525922172040079627), + static_cast(0.547948723977362273), + static_cast(0.569860763464220654), + static_cast(0.591637568597068619)}; + + // Since Lookup8 is available for HWY_MIN_BYTES / sizeof(T) >= 4, this + // condition covers all cases we encounter inside the top level if block + // inside FastLog + b = Lookup8(d, arr_b, idx_i); + c = Lookup8(d, arr_c, idx_i); + d_coef = Lookup8(d, arr_d, idx_i); + } else { + // --- FALLBACK PATH: Blend Chain --- + // Polynomial Approximation + const auto t0 = Set(d, static_cast(0.7954951287634819)); + const auto t1 = Set(d, static_cast(0.8838834764038688)); + const auto t2 = Set(d, static_cast(0.9722718240442556)); + const auto t3 = Set(d, static_cast(1.0606601716846424)); + const auto t4 = Set(d, static_cast(1.1490485193250295)); + const auto t5 = Set(d, static_cast(1.2374368669654163)); + const auto t6 = Set(d, static_cast(1.3258252146058032)); + + if constexpr 
(HWY_REGISTERS >= 32) { + // Split into two parallel chains to reduce dependency latency. + + // -- Chain 1: Indices 0 to 3 (Evaluated starting from t3 down to t0) + auto b_low = Set(d, static_cast(-1)); // idx 3 + auto c_low = Set(d, static_cast(0.497367242550162236)); + auto d_low = Set(d, static_cast(0.502574248959839265)); + + auto mask = Lt(y, t2); + b_low = + IfThenElse(mask, Set(d, static_cast(-1.0000255203465902)), b_low); + c_low = + IfThenElse(mask, Set(d, static_cast(0.519613079391819999)), c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(0.481694216614368509)), d_low); + + mask = Lt(y, t1); + b_low = + IfThenElse(mask, Set(d, static_cast(-1.00042661239958708)), b_low); + c_low = + IfThenElse(mask, Set(d, static_cast(0.548174514768112076)), c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(0.459658304416673158)), d_low); + + mask = Lt(y, t0); + b_low = + IfThenElse(mask, Set(d, static_cast(-1.00194730895928918)), b_low); + c_low = + IfThenElse(mask, Set(d, static_cast(0.58385589069067223)), c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(0.437891917978712797)), d_low); + + // -- Chain 2: Indices 4 to 7 (Evaluated starting from t6 down to t4) + auto b_high = Set(d, static_cast(-0.997397894886509873)); // idx 7 + auto c_high = Set(d, static_cast(0.432070989622927948)); + auto d_high = Set(d, static_cast(0.591637568597068619)); + + mask = Lt(y, t6); + b_high = IfThenElse(mask, Set(d, static_cast(-0.998743736501089163)), + b_high); + c_high = + IfThenElse(mask, Set(d, static_cast(0.44490172854808846)), c_high); + d_high = IfThenElse(mask, Set(d, static_cast(0.569860763464220654)), + d_high); + + mask = Lt(y, t5); + b_high = IfThenElse(mask, Set(d, static_cast(-0.999558969823431065)), + b_high); + c_high = IfThenElse(mask, Set(d, static_cast(0.459525070958496262)), + c_high); + d_high = IfThenElse(mask, Set(d, static_cast(0.547948723977362273)), + d_high); + + mask = Lt(y, t4); + b_high = IfThenElse(mask, Set(d, 
static_cast(-0.999929163668789478)), + b_high); + c_high = IfThenElse(mask, Set(d, static_cast(0.476391677761481835)), + c_high); + d_high = IfThenElse(mask, Set(d, static_cast(0.525922172040079627)), + d_high); + + // -- Merge the two chains + auto merge_mask = Lt(y, t3); + b = IfThenElse(merge_mask, b_low, b_high); + c = IfThenElse(merge_mask, c_low, c_high); + d_coef = IfThenElse(merge_mask, d_low, d_high); + } else { + // Start with highest index (7) + b = Set(d, static_cast(-0.997397894886509873)); + c = Set(d, static_cast(0.432070989622927948)); + d_coef = Set(d, static_cast(0.591637568597068619)); + + // If y < t6 (idx 6) + auto mask = Lt(y, t6); + b = IfThenElse(mask, Set(d, static_cast(-0.998743736501089163)), b); + c = IfThenElse(mask, Set(d, static_cast(0.44490172854808846)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(0.569860763464220654)), + d_coef); + + // If y < t5 (idx 5) + mask = Lt(y, t5); + b = IfThenElse(mask, Set(d, static_cast(-0.999558969823431065)), b); + c = IfThenElse(mask, Set(d, static_cast(0.459525070958496262)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(0.547948723977362273)), + d_coef); + + // If y < t4 (idx 4) + mask = Lt(y, t4); + b = IfThenElse(mask, Set(d, static_cast(-0.999929163668789478)), b); + c = IfThenElse(mask, Set(d, static_cast(0.476391677761481835)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(0.525922172040079627)), + d_coef); + + // If y < t3 (idx 3) + mask = Lt(y, t3); + b = IfThenElse(mask, Set(d, static_cast(-1)), b); + c = IfThenElse(mask, Set(d, static_cast(0.497367242550162236)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(0.502574248959839265)), + d_coef); + + // If y < t2 (idx 2) + mask = Lt(y, t2); + b = IfThenElse(mask, Set(d, static_cast(-1.0000255203465902)), b); + c = IfThenElse(mask, Set(d, static_cast(0.519613079391819999)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(0.481694216614368509)), + d_coef); + + // If y < t1 (idx 1) + mask = Lt(y, t1); + b = 
IfThenElse(mask, Set(d, static_cast(-1.00042661239958708)), b); + c = IfThenElse(mask, Set(d, static_cast(0.548174514768112076)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(0.459658304416673158)), + d_coef); + + // If y < t0 (idx 0) + mask = Lt(y, t0); + b = IfThenElse(mask, Set(d, static_cast(-1.00194730895928918)), b); + c = IfThenElse(mask, Set(d, static_cast(0.58385589069067223)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(0.437891917978712797)), + d_coef); + } + } + + // Math: y = (x + b)/(cx + d_coef) + auto num = Add(y, b); + auto den = MulAdd(c, y, d_coef); + + auto approx = Div(num, den); + + return MulAdd(exp, kLn2, approx); +} + +/** + * Fast approximation of exp(x). + * + * Valid Lane Types: float32, float64 + * Max Relative Error: 0.0007% for float32 [-87, 88] + * Max Relative Error: 0.0007% for float64 [-708, 706] + * Average Relative Error: 0.00002% for float32 [-87, 88] + * Average Relative Error: 0.00001% for float64 [-708, 706] + * Max Relative Error for Subnormals: 2.4% for float32 [-FLT_MAX, -87] + * Max Relative Error for Subnormals: 0.006% for float64 [-DBL_MAX, -708] + * Valid Range: float32[-FLT_MAX, +88], float64[-DBL_MAX, +706] + * + * @return e^x + */ +template +HWY_INLINE V FastExp(D d, V x) { + using T = TFromD; + impl::FastExpImpl impl; + + T lower_bound_val; + if constexpr (kHandleSubnormals) { + lower_bound_val = sizeof(T) == 4 ? -104.0 : -1000.0; + } else { + lower_bound_val = sizeof(T) == 4 ? 
-88.0 : -709.0; + } + const V kLowerBound = Set(d, static_cast(lower_bound_val)); + + const V kHalf = Set(d, static_cast(+0.5)); + const V kNegZero = Set(d, static_cast(-0.0)); + + const V kOneOverLog2 = Set(d, static_cast(+1.442695040888963407359924681)); + + using TI = MakeSigned; + const Rebind di; + + V x_clamped = x; + if constexpr (!kHandleSubnormals) { + x_clamped = Max(x, kLowerBound); + } + + const auto rounded_offs = BitCast( + d, + OrAnd(BitCast(di, kHalf), BitCast(di, x_clamped), BitCast(di, kNegZero))); + + const auto q = impl.ToInt32(d, MulAdd(x_clamped, kOneOverLog2, rounded_offs)); + + const auto x_red = impl.ExpReduce(d, x_clamped, q); + + // Degree 4 polynomial approximation of e^x on [-ln2/2, ln2/2] + // Generated via Caratheodory-Fejer approximation. + const auto c0 = Set(d, static_cast(1.0000001510806224569)); + const auto c1 = Set(d, static_cast(0.99996228117046825901)); + const auto c2 = Set(d, static_cast(0.49998365704575670199)); + const auto c3 = Set(d, static_cast(0.16792157982876812494)); + const auto c4 = Set(d, static_cast(0.041959439862987071845)); + + // Estrin's scheme + const auto x2 = Mul(x_red, x_red); + // term0 = c1*x + c0 + const auto term0 = MulAdd(c1, x_red, c0); + // term1 = c3*x + c2 + const auto term1 = MulAdd(c3, x_red, c2); + // term2 = c4*x^2 + term1 + const auto term2 = MulAdd(c4, x2, term1); + // approx = term2 * x^2 + term0 + const auto approx = MulAdd(term2, x2, term0); + + if constexpr (kHandleSubnormals) { + const V res = impl.LoadExpShortRange(d, approx, q); + // Handle underflow + return IfThenElseZero(Ge(x, kLowerBound), res); + } else { + // Optimization: avoid splitting the exponent since 'q' is guaranteed + // to fall within the normal floating-point ranges. + return Mul(approx, impl.Pow2I(d, q)); + } +} + +/** + * Fast approximation of exp2(x). 
+ * + * Valid Lane Types: float32, float64 + * Max Relative Error: 0.0007% for float32 [-150, 128] + * Max Relative Error: 0.0007% for float64 [-1075, 1024] + * Average Relative Error: 0.00002% for float32 [-150, 128] + * Average Relative Error: 0.00001% for float64 [-1075, 1024] + * Max Relative Error for Subnormals: 0.08% for float32 [-FLT_MAX, -150] + * Max Relative Error for Subnormals: 0.03% for float64 [-DBL_MAX, -1075] + * Valid Range: float32[-FLT_MAX, +128], float64[-DBL_MAX, +1024] + * + * @return 2^x + */ +template +HWY_INLINE V FastExp2(D d, V x) { + using T = TFromD; + impl::FastExpImpl impl; + + T lower_bound_val; + if constexpr (kHandleSubnormals) { + // FastExp uses kLowerBound = -104.0 / -1000.0 since it operates on e^x. For + // FastExp2, we use lower limits correspondingly to -150.0 and -1075.0. + lower_bound_val = sizeof(T) == 4 ? -150.0 : -1075.0; + } else { + lower_bound_val = sizeof(T) == 4 ? -127.0 : -1023.0; + } + const V kLowerBound = Set(d, static_cast(lower_bound_val)); + + const V kHalf = Set(d, static_cast(+0.5)); + const V kNegZero = Set(d, static_cast(-0.0)); + + using TI = MakeSigned; + const Rebind di; + + V x_clamped = x; + if constexpr (!kHandleSubnormals) { + x_clamped = Max(x, kLowerBound); + } + + const auto rounded_offs = BitCast( + d, + OrAnd(BitCast(di, kHalf), BitCast(di, x_clamped), BitCast(di, kNegZero))); + + // FastExp calculates q = ToInt32(x * (1/ln(2)) + rounded_offs) + // FastExp2 does not need the (1/ln(2)) scaling factor since the input is + // already in base 2. + const auto q = impl.ToInt32(d, Add(x_clamped, rounded_offs)); + + const auto x_red = impl.Exp2Reduce(d, x_clamped, q); + + // Degree 4 polynomial approximation of 2^x on [-1/2, 1/2] + // Derived from FastExp coefficients by pre-absorbing ln2: + // c_fast_exp2[i] = c_fast_exp[i] * (ln2)^i. 
+ const auto c0 = Set(d, static_cast(1.0000001510806224569)); + const auto c1 = Set(d, static_cast(0.69312104523363065471)); + const auto c2 = Set(d, static_cast(0.24021865239713606622)); + const auto c3 = Set(d, static_cast(0.05592203117565365516)); + const auto c4 = Set(d, static_cast(0.00968574163456345638)); + + // Estrin's scheme + const auto x2 = Mul(x_red, x_red); + // term0 = c1*x + c0 + const auto term0 = MulAdd(c1, x_red, c0); + // term1 = c3*x + c2 + const auto term1 = MulAdd(c3, x_red, c2); + // term2 = c4*x^2 + term1 + const auto term2 = MulAdd(c4, x2, term1); + // approx = term2 * x^2 + term0 + const auto approx = MulAdd(term2, x2, term0); + + if constexpr (kHandleSubnormals) { + const V res = impl.LoadExpShortRange(d, approx, q); + // Handle underflow + return IfThenElseZero(Ge(x, kLowerBound), res); + } else { + // Optimization: avoid splitting the exponent since 'q' is guaranteed + // to fall within the normal floating-point ranges. + return Mul(approx, impl.Pow2I(d, q)); + } +} + +/** + * Fast approximation of exp(x) for x <= 0. Subnormals are flushed to zero. + * + * Valid Lane Types: float32, float64 + * Max Relative Error: 0.0007% for float32 [-87, 0] + * Max Relative Error: 0.0007% for float64 [-708, 0] + * Average Relative Error: 0.00002% for float32 [-87, 0] + * Average Relative Error: 0.00001% for float64 [-708, 0] + * Valid Range: float32[-FLT_MAX, +0.0], float64[-DBL_MAX, +0.0] + * + * @return e^x + */ +template +HWY_INLINE V FastExpMinusOrZero(D d, V x) { + using T = TFromD; + impl::FastExpImpl impl; + + const V kHalfMinus = Set(d, static_cast(-0.5)); + const V kLowerBound = + Set(d, static_cast((sizeof(T) == 4 ? -88.0 : -709.0))); + + const V kOneOverLog2 = Set(d, static_cast(+1.442695040888963407359924681)); + + // Optimization for x <= 0: + // FastExp computes `rounded_offs = sign(x) ? -0.5 : 0.5` to round the + // multiplied argument towards zero. 
Since x <= 0, we avoid the dynamic + // calculation and simply use a constant -0.5 (kHalfMinus). + // + // We clamp x to be >= kLowerBound. For x < kLowerBound, the remapped + // exponent q becomes -127 (f32) or -1023 (f64), which Pow2I converts to + // exactly 0.0. This avoids subnormals and the need for a final mask. + const auto x_clamped = Max(x, kLowerBound); + const auto q = impl.ToInt32(d, MulAdd(x_clamped, kOneOverLog2, kHalfMinus)); + + const auto x_red = impl.ExpReduce(d, x_clamped, q); + + // Degree 4 polynomial approximation of e^x on [-ln2/2, ln2/2] + // Generated via Caratheodory-Fejer approximation. + const auto c0 = Set(d, static_cast(1.0000001510806224569)); + const auto c1 = Set(d, static_cast(0.99996228117046825901)); + const auto c2 = Set(d, static_cast(0.49998365704575670199)); + const auto c3 = Set(d, static_cast(0.16792157982876812494)); + const auto c4 = Set(d, static_cast(0.041959439862987071845)); + + // Estrin's scheme + const auto x2 = Mul(x_red, x_red); + // term0 = c1*x + c0 + const auto term0 = MulAdd(c1, x_red, c0); + // term1 = c3*x + c2 + const auto term1 = MulAdd(c3, x_red, c2); + // term2 = c4*x^2 + term1 + const auto term2 = MulAdd(c4, x2, term1); + // approx = term2 * x^2 + term0 + const auto approx = MulAdd(term2, x2, term0); + + // Since inputs < -88.0 (f32) and < -709.0 (f64) are flushed to zero, + // we do not generate subnormals. Therefore, q is guaranteed to be >= -127 + // and we can use Pow2I directly without splitting the exponent computation. + return Mul(approx, impl.Pow2I(d, q)); +} + +/** + * Fast approximation of log2(x). + * + * Valid Lane Types: float32, float64 + * Max Relative Error: 0.008% + * Valid Range: float32: (0, +FLT_MAX] + * float64: (0, +DBL_MAX] + * + * @return base 2 logarithm of 'x' + */ +// If false, subnormals are treated as zero. 
+template +HWY_INLINE V FastLog2(D d, V x) { + using T = TFromD; + V y, exp; + impl::FastLogRangeReduction(d, x, y, exp); + + constexpr size_t kLanes = HWY_MAX_LANES_D(D); + V b, c, d_coef; + + if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) || + (HWY_HAVE_SCALABLE && sizeof(T) == 4 && detail::IsFull(d))) { + const auto scale = Set(d, static_cast(11.3137085)); + auto idx_i = ConvertInRangeTo( + RebindToSigned(), MulAdd(y, scale, Set(d, static_cast(-8.0)))); + + idx_i = Min(idx_i, Set(RebindToSigned(), 7)); + + HWY_ALIGN static constexpr T arr_b[8] = { + static_cast(-1.00194730895928918), + static_cast(-1.00042661239958708), + static_cast(-1.0000255203465902), + static_cast(-1), + static_cast(-0.999929163668789478), + static_cast(-0.999558969823431065), + static_cast(-0.998743736501089163), + static_cast(-0.997397894886509873)}; + + HWY_ALIGN static constexpr T arr_c[8] = { + static_cast(0.58385589069067223), + static_cast(0.548174514768112076), + static_cast(0.519613079391819999), + static_cast(0.497367242550162236), + static_cast(0.476391677761481835), + static_cast(0.459525070958496262), + static_cast(0.44490172854808846), + static_cast(0.432070989622927948)}; + + HWY_ALIGN static constexpr T arr_d[8] = { + static_cast(0.437891917978712797), + static_cast(0.459658304416673158), + static_cast(0.481694216614368509), + static_cast(0.502574248959839265), + static_cast(0.525922172040079627), + static_cast(0.547948723977362273), + static_cast(0.569860763464220654), + static_cast(0.591637568597068619)}; + + b = Lookup8(d, arr_b, idx_i); + c = Lookup8(d, arr_c, idx_i); + d_coef = Lookup8(d, arr_d, idx_i); + } else { + const auto t0 = Set(d, static_cast(0.7954951287634819)); + const auto t1 = Set(d, static_cast(0.8838834764038688)); + const auto t2 = Set(d, static_cast(0.9722718240442556)); + const auto t3 = Set(d, static_cast(1.0606601716846424)); + const auto t4 = Set(d, static_cast(1.1490485193250295)); + const auto t5 = Set(d, static_cast(1.2374368669654163)); + 
const auto t6 = Set(d, static_cast(1.3258252146058032)); + + if constexpr (HWY_REGISTERS >= 32) { + auto b_low = Set(d, static_cast(-1)); + auto c_low = Set(d, static_cast(0.497367242550162236)); + auto d_low = Set(d, static_cast(0.502574248959839265)); + + auto mask = Lt(y, t2); + b_low = + IfThenElse(mask, Set(d, static_cast(-1.0000255203465902)), b_low); + c_low = + IfThenElse(mask, Set(d, static_cast(0.519613079391819999)), c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(0.481694216614368509)), d_low); + + mask = Lt(y, t1); + b_low = + IfThenElse(mask, Set(d, static_cast(-1.00042661239958708)), b_low); + c_low = + IfThenElse(mask, Set(d, static_cast(0.548174514768112076)), c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(0.459658304416673158)), d_low); + + mask = Lt(y, t0); + b_low = + IfThenElse(mask, Set(d, static_cast(-1.00194730895928918)), b_low); + c_low = + IfThenElse(mask, Set(d, static_cast(0.58385589069067223)), c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(0.437891917978712797)), d_low); + + auto b_high = Set(d, static_cast(-0.997397894886509873)); + auto c_high = Set(d, static_cast(0.432070989622927948)); + auto d_high = Set(d, static_cast(0.591637568597068619)); + + mask = Lt(y, t6); + b_high = IfThenElse(mask, Set(d, static_cast(-0.998743736501089163)), + b_high); + c_high = + IfThenElse(mask, Set(d, static_cast(0.44490172854808846)), c_high); + d_high = IfThenElse(mask, Set(d, static_cast(0.569860763464220654)), + d_high); + + mask = Lt(y, t5); + b_high = IfThenElse(mask, Set(d, static_cast(-0.999558969823431065)), + b_high); + c_high = IfThenElse(mask, Set(d, static_cast(0.459525070958496262)), + c_high); + d_high = IfThenElse(mask, Set(d, static_cast(0.547948723977362273)), + d_high); + + mask = Lt(y, t4); + b_high = IfThenElse(mask, Set(d, static_cast(-0.999929163668789478)), + b_high); + c_high = IfThenElse(mask, Set(d, static_cast(0.476391677761481835)), + c_high); + d_high = IfThenElse(mask, Set(d, 
static_cast(0.525922172040079627)), + d_high); + + auto merge_mask = Lt(y, t3); + b = IfThenElse(merge_mask, b_low, b_high); + c = IfThenElse(merge_mask, c_low, c_high); + d_coef = IfThenElse(merge_mask, d_low, d_high); + } else { + b = Set(d, static_cast(-0.997397894886509873)); + c = Set(d, static_cast(0.432070989622927948)); + d_coef = Set(d, static_cast(0.591637568597068619)); + + auto mask = Lt(y, t6); + b = IfThenElse(mask, Set(d, static_cast(-0.998743736501089163)), b); + c = IfThenElse(mask, Set(d, static_cast(0.44490172854808846)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(0.569860763464220654)), + d_coef); + + mask = Lt(y, t5); + b = IfThenElse(mask, Set(d, static_cast(-0.999558969823431065)), b); + c = IfThenElse(mask, Set(d, static_cast(0.459525070958496262)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(0.547948723977362273)), + d_coef); + + mask = Lt(y, t4); + b = IfThenElse(mask, Set(d, static_cast(-0.999929163668789478)), b); + c = IfThenElse(mask, Set(d, static_cast(0.476391677761481835)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(0.525922172040079627)), + d_coef); + + mask = Lt(y, t3); + b = IfThenElse(mask, Set(d, static_cast(-1)), b); + c = IfThenElse(mask, Set(d, static_cast(0.497367242550162236)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(0.502574248959839265)), + d_coef); + + mask = Lt(y, t2); + b = IfThenElse(mask, Set(d, static_cast(-1.0000255203465902)), b); + c = IfThenElse(mask, Set(d, static_cast(0.519613079391819999)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(0.481694216614368509)), + d_coef); + + mask = Lt(y, t1); + b = IfThenElse(mask, Set(d, static_cast(-1.00042661239958708)), b); + c = IfThenElse(mask, Set(d, static_cast(0.548174514768112076)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(0.459658304416673158)), + d_coef); + + mask = Lt(y, t0); + b = IfThenElse(mask, Set(d, static_cast(-1.00194730895928918)), b); + c = IfThenElse(mask, Set(d, static_cast(0.58385589069067223)), c); + 
d_coef = IfThenElse(mask, Set(d, static_cast(0.437891917978712797)), + d_coef); + } + } + + auto num = Add(y, b); + auto den = MulAdd(c, y, d_coef); + + auto approx = Div(num, den); + + const auto kInvLn2 = Set(d, static_cast(1.4426950408889634)); + return MulAdd(approx, kInvLn2, exp); +} + +/** + * Fast approximation of log10(x). + * + * Valid Lane Types: float32, float64 + * Max Relative Error: 0.008% + * Valid Range: float32: (0, +FLT_MAX] + * float64: (0, +DBL_MAX] + * + * @return base 10 logarithm of 'x' + */ +// If false, subnormals are treated as zero. +template +HWY_INLINE V FastLog10(D d, V x) { + using T = TFromD; + V y, exp; + impl::FastLogRangeReduction(d, x, y, exp); + + constexpr size_t kLanes = HWY_MAX_LANES_D(D); + V b, c, d_coef; + + if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) || + (HWY_HAVE_SCALABLE && sizeof(T) == 4 && detail::IsFull(d))) { + const auto scale = Set(d, static_cast(11.3137085)); + auto idx_i = ConvertInRangeTo( + RebindToSigned(), MulAdd(y, scale, Set(d, static_cast(-8.0)))); + + idx_i = Min(idx_i, Set(RebindToSigned(), 7)); + + HWY_ALIGN static constexpr T arr_b[8] = { + static_cast(-1.00194730895928918), + static_cast(-1.00042661239958708), + static_cast(-1.0000255203465902), + static_cast(-1), + static_cast(-0.999929163668789478), + static_cast(-0.999558969823431065), + static_cast(-0.998743736501089163), + static_cast(-0.997397894886509873)}; + + // Denominator coefficients are pre-multiplied by Ln(10) compared to FastLog + // to save a multiply instruction in the final step (preabsorption). 
+ HWY_ALIGN static constexpr T arr_c[8] = { + static_cast(1.344377870361103122), + static_cast(1.262218466064299438), + static_cast(1.196453330732336395), + static_cast(1.145230398439557540), + static_cast(1.096932375640011115), + static_cast(1.058095578246064594), + static_cast(1.024424088002112043), + static_cast(0.994880219820938994)}; + + HWY_ALIGN static constexpr T arr_d[8] = { + static_cast(1.008283402680355545), + static_cast(1.058402359620750799), + static_cast(1.109141922557689730), + static_cast(1.157219973777604327), + static_cast(1.210980553414537253), + static_cast(1.261698563555383457), + static_cast(1.312152899034920495), + static_cast(1.362295845906852376)}; + + b = Lookup8(d, arr_b, idx_i); + c = Lookup8(d, arr_c, idx_i); + d_coef = Lookup8(d, arr_d, idx_i); + } else { + const auto t0 = Set(d, static_cast(0.7954951287634819)); + const auto t1 = Set(d, static_cast(0.8838834764038688)); + const auto t2 = Set(d, static_cast(0.9722718240442556)); + const auto t3 = Set(d, static_cast(1.0606601716846424)); + const auto t4 = Set(d, static_cast(1.1490485193250295)); + const auto t5 = Set(d, static_cast(1.2374368669654163)); + const auto t6 = Set(d, static_cast(1.3258252146058032)); + + if constexpr (HWY_REGISTERS >= 32) { + auto b_low = Set(d, static_cast(-1)); // idx 3 + // Denominator coefficients pre-multiplied by Ln(10). 
+ auto c_low = Set(d, static_cast(1.145230398439557540)); + auto d_low = Set(d, static_cast(1.157219973777604327)); + + auto mask = Lt(y, t2); + b_low = + IfThenElse(mask, Set(d, static_cast(-1.0000255203465902)), b_low); + c_low = + IfThenElse(mask, Set(d, static_cast(1.196453330732336395)), c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(1.109141922557689730)), d_low); + + mask = Lt(y, t1); + b_low = + IfThenElse(mask, Set(d, static_cast(-1.00042661239958708)), b_low); + c_low = + IfThenElse(mask, Set(d, static_cast(1.262218466064299438)), c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(1.058402359620750799)), d_low); + + mask = Lt(y, t0); + b_low = + IfThenElse(mask, Set(d, static_cast(-1.00194730895928918)), b_low); + c_low = + IfThenElse(mask, Set(d, static_cast(1.344377870361103122)), c_low); + d_low = + IfThenElse(mask, Set(d, static_cast(1.008283402680355545)), d_low); + + auto b_high = Set(d, static_cast(-0.997397894886509873)); // idx 7 + // Denominator coefficients pre-multiplied by Ln(10). 
+ auto c_high = Set(d, static_cast(0.994880219820938994)); + auto d_high = Set(d, static_cast(1.362295845906852376)); + + mask = Lt(y, t6); + b_high = IfThenElse(mask, Set(d, static_cast(-0.998743736501089163)), + b_high); + c_high = IfThenElse(mask, Set(d, static_cast(1.024424088002112043)), + c_high); + d_high = IfThenElse(mask, Set(d, static_cast(1.312152899034920495)), + d_high); + + mask = Lt(y, t5); + b_high = IfThenElse(mask, Set(d, static_cast(-0.999558969823431065)), + b_high); + c_high = IfThenElse(mask, Set(d, static_cast(1.058095578246064594)), + c_high); + d_high = IfThenElse(mask, Set(d, static_cast(1.261698563555383457)), + d_high); + + mask = Lt(y, t4); + b_high = IfThenElse(mask, Set(d, static_cast(-0.999929163668789478)), + b_high); + c_high = IfThenElse(mask, Set(d, static_cast(1.096932375640011115)), + c_high); + d_high = IfThenElse(mask, Set(d, static_cast(1.210980553414537253)), + d_high); + + auto merge_mask = Lt(y, t3); + b = IfThenElse(merge_mask, b_low, b_high); + c = IfThenElse(merge_mask, c_low, c_high); + d_coef = IfThenElse(merge_mask, d_low, d_high); + } else { + b = Set(d, static_cast(-0.997397894886509873)); + // Denominator coefficients pre-multiplied by Ln(10). 
+ c = Set(d, static_cast(0.994880219820938994)); + d_coef = Set(d, static_cast(1.362295845906852376)); + + auto mask = Lt(y, t6); + b = IfThenElse(mask, Set(d, static_cast(-0.998743736501089163)), b); + c = IfThenElse(mask, Set(d, static_cast(1.024424088002112043)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(1.312152899034920495)), + d_coef); + + mask = Lt(y, t5); + b = IfThenElse(mask, Set(d, static_cast(-0.999558969823431065)), b); + c = IfThenElse(mask, Set(d, static_cast(1.058095578246064594)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(1.261698563555383457)), + d_coef); + + mask = Lt(y, t4); + b = IfThenElse(mask, Set(d, static_cast(-0.999929163668789478)), b); + c = IfThenElse(mask, Set(d, static_cast(1.096932375640011115)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(1.210980553414537253)), + d_coef); + + mask = Lt(y, t3); + b = IfThenElse(mask, Set(d, static_cast(-1)), b); + c = IfThenElse(mask, Set(d, static_cast(1.145230398439557540)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(1.157219973777604327)), + d_coef); + + mask = Lt(y, t2); + b = IfThenElse(mask, Set(d, static_cast(-1.0000255203465902)), b); + c = IfThenElse(mask, Set(d, static_cast(1.196453330732336395)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(1.109141922557689730)), + d_coef); + + mask = Lt(y, t1); + b = IfThenElse(mask, Set(d, static_cast(-1.00042661239958708)), b); + c = IfThenElse(mask, Set(d, static_cast(1.262218466064299438)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(1.058402359620750799)), + d_coef); + + mask = Lt(y, t0); + b = IfThenElse(mask, Set(d, static_cast(-1.00194730895928918)), b); + c = IfThenElse(mask, Set(d, static_cast(1.344377870361103122)), c); + d_coef = IfThenElse(mask, Set(d, static_cast(1.008283402680355545)), + d_coef); + } + } + + auto num = Add(y, b); + auto den = MulAdd(c, y, d_coef); + auto approx = Div(num, den); + + const auto kLog10_2 = Set(d, static_cast(0.3010299956639812)); // log10(2) + // Computes exp * 
log10(2) + approx. Since approx was scaled by 1/Ln(10) + // via the pre-scaled denominator, this yields the correct log10 result + // using a single MulAdd instruction. + return MulAdd(exp, kLog10_2, approx); +} + +/** + * Fast approximation of log(1 + x). + * + * Valid Lane Types: float32, float64 + * Max Relative Error: 0.0081% + * Valid Range: float32: [-1 + epsilon, +FLT_MAX] + * float64: [-1 + epsilon, +DBL_MAX] + * + * @return natural logarithm of '1 + x' + */ +// If false, subnormals are treated as zero. +template +HWY_INLINE V FastLog1p(const D d, V x) { + using T = TFromD; + const V kOne = Set(d, static_cast(+1.0)); + + const V y = Add(x, kOne); + const Mask not_pole = Ne(y, kOne); + // If y == 1, divisor becomes 1 (dummy), avoiding division by zero. + const V divisor = MaskedSubOr(y, not_pole, y, kOne); + // Ensure exactly 1.0 when x == divisor. This is necessary because some + // platforms (like Armv7) use Newton-Raphson for division, which can return + // 0.0, instead of 1.0 when the reciprocal calculation underflows + // for very large x. + const V div_res = MaskedDivOr(kOne, Ne(x, divisor), x, divisor); + const V non_pole = Mul(FastLog(d, y), div_res); + return IfThenElse(not_pole, non_pole, x); +} + +/** + * Fast approximation of base^exp. + * + * Valid Lane Types: float32, float64 + * Valid Range: float32: base in (0, +FLT_MAX], exp * log(base) in [-25.0, + * +25.0] float64: base in (0, +DBL_MAX], exp * log(base) in [-25.0, +25.0] Max + * Relative Error for Valid Range: float32 : 0.27%, float64 : 0.22% + * @return base^exp + */ +// If false, subnormals are treated as zero. 
+template +HWY_INLINE V FastPow(D d, V base, V exp) { + return FastExp( + d, Mul(exp, FastLog(d, base))); +} + +template +HWY_NOINLINE V CallFastAtan(const D d, VecArg x) { + return FastAtan(d, x); +} + +template +HWY_NOINLINE V CallFastTan(const D d, VecArg x) { + return FastTan(d, x); +} + +template +HWY_NOINLINE V CallFastAtan2(const D d, VecArg y, VecArg x) { + return FastAtan2(d, y, x); +} + +template +HWY_NOINLINE V CallFastTanh(const D d, VecArg x) { + return FastTanh(d, x); +} + +template +HWY_NOINLINE V CallFastLog(const D d, VecArg x) { + return FastLog<>(d, x); +} + +template +HWY_NOINLINE V CallFastExp(const D d, VecArg x) { + return FastExp(d, x); +} + +template +HWY_NOINLINE V CallFastExp2(const D d, VecArg x) { + return FastExp2(d, x); +} + +template +HWY_NOINLINE V CallFastExpMinusOrZero(const D d, VecArg x) { + return FastExpMinusOrZero(d, x); +} +template +HWY_NOINLINE V CallFastLog2(const D d, VecArg x) { + return FastLog2<>(d, x); +} + +template +HWY_NOINLINE V CallFastLog10(const D d, VecArg x) { + return FastLog10<>(d, x); +} + +template +HWY_NOINLINE V CallFastLog1p(const D d, VecArg x) { + return FastLog1p<>(d, x); +} + +template +HWY_NOINLINE V CallFastPow(const D d, VecArg base, VecArg exp) { + return FastPow<>(d, base, exp); +} + +template +HWY_NOINLINE V CallFastExpNormal(const D d, VecArg x) { + return FastExp(d, x); +} + +template +HWY_NOINLINE V CallFastExp2Normal(const D d, VecArg x) { + return FastExp2(d, x); +} + +template +HWY_NOINLINE V CallFastLogPositiveNormal(const D d, VecArg x) { + return FastLog(d, x); +} + +template +HWY_NOINLINE V CallFastLog2PositiveNormal(const D d, VecArg x) { + return FastLog2(d, x); +} + +template +HWY_NOINLINE V CallFastLog10PositiveNormal(const D d, VecArg x) { + return FastLog10(d, x); +} + +template +HWY_NOINLINE V CallFastLog1pPositiveNormal(const D d, VecArg x) { + return FastLog1p(d, x); +} + +template +HWY_NOINLINE V CallFastAtanPositive(const D d, VecArg x) { + return FastAtan(d, x); +} +} 
// namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_MATH_FAST_MATH_INL_H_ diff --git a/lib/highway/hwy/contrib/math/math-inl.h b/lib/highway/hwy/contrib/math/math-inl.h new file mode 100644 index 00000000000..e131d2777b8 --- /dev/null +++ b/lib/highway/hwy/contrib/math/math-inl.h @@ -0,0 +1,1775 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard (still compiled once per target) +#if defined(HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ +#endif + +#include +#include + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +/** + * Highway SIMD version of std::acos(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: [-1, +1] + * @return arc cosine of 'x' + */ +template +HWY_INLINE V Acos(D d, V x); +template +HWY_NOINLINE V CallAcos(const D d, VecArg x) { + return Acos(d, x); +} + +/** + * Highway SIMD version of std::acosh(x). 
+ * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[1, +FLT_MAX], float64[1, +DBL_MAX] + * @return hyperbolic arc cosine of 'x' + */ +template +HWY_INLINE V Acosh(D d, V x); +template +HWY_NOINLINE V CallAcosh(const D d, VecArg x) { + return Acosh(d, x); +} + +/** + * Highway SIMD version of std::asin(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: [-1, +1] + * @return arc sine of 'x' + */ +template +HWY_INLINE V Asin(D d, V x); +template +HWY_NOINLINE V CallAsin(const D d, VecArg x) { + return Asin(d, x); +} + +/** + * Highway SIMD version of std::asinh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return hyperbolic arc sine of 'x' + */ +template +HWY_INLINE V Asinh(D d, V x); +template +HWY_NOINLINE V CallAsinh(const D d, VecArg x) { + return Asinh(d, x); +} + +/** + * Highway SIMD version of std::atan(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return arc tangent of 'x' + */ +template +HWY_INLINE V Atan(D d, V x); +template +HWY_NOINLINE V CallAtan(const D d, VecArg x) { + return Atan(d, x); +} + +/** + * Highway SIMD version of std::atanh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: (-1, +1) + * @return hyperbolic arc tangent of 'x' + */ +template +HWY_INLINE V Atanh(D d, V x); +template +HWY_NOINLINE V CallAtanh(const D d, VecArg x) { + return Atanh(d, x); +} + +// Atan2 was added later and some users may be implementing it themselves, so +// notify them that this version of Highway defines it already. +#ifndef HWY_HAVE_ATAN2 +#define HWY_HAVE_ATAN2 1 +#endif + +/** + * Highway SIMD version of std::atan2(x). + * + * Valid Lane Types: float32, float64 + * Correctly handles negative zero, infinities, and NaN. 
+ * @return atan2 of 'y', 'x' + */ +template +HWY_INLINE V Atan2(D d, V y, V x); +template +HWY_NOINLINE V CallAtan2(const D d, VecArg y, VecArg x) { + return Atan2(d, y, x); +} + +/** + * Highway SIMD version of std::cos(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: [-39000, +39000] + * @return cosine of 'x' + */ +template +HWY_INLINE V Cos(D d, V x); +template +HWY_NOINLINE V CallCos(const D d, VecArg x) { + return Cos(d, x); +} + +/** + * Highway SIMD version of std::exp(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 1 + * Valid Range: float32[-FLT_MAX, +104], float64[-DBL_MAX, +706] + * @return e^x + */ +template +HWY_INLINE V Exp(D d, V x); +template +HWY_NOINLINE V CallExp(const D d, VecArg x) { + return Exp(d, x); +} + +/** + * Highway SIMD version of std::exp2(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32[-FLT_MAX, +128], float64[-DBL_MAX, +1024] + * @return 2^x + */ +template +HWY_INLINE V Exp2(D d, V x); +template +HWY_NOINLINE V CallExp2(const D d, VecArg x) { + return Exp2(d, x); +} + +/** + * Highway SIMD version of std::expm1(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-FLT_MAX, +104], float64[-DBL_MAX, +706] + * @return e^x - 1 + */ +template +HWY_INLINE V Expm1(D d, V x); +template +HWY_NOINLINE V CallExpm1(const D d, VecArg x) { + return Expm1(d, x); +} + +/** + * Highway SIMD version of std::log(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return natural logarithm of 'x' + */ +template +HWY_INLINE V Log(D d, V x); +template +HWY_NOINLINE V CallLog(const D d, VecArg x) { + return Log(d, x); +} + +/** + * Highway SIMD version of std::log10(x). 
+ * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return base 10 logarithm of 'x' + */ +template +HWY_INLINE V Log10(D d, V x); +template +HWY_NOINLINE V CallLog10(const D d, VecArg x) { + return Log10(d, x); +} + +/** + * Highway SIMD version of std::log1p(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32[0, +FLT_MAX], float64[0, +DBL_MAX] + * @return log(1 + x) + */ +template +HWY_INLINE V Log1p(D d, V x); +template +HWY_NOINLINE V CallLog1p(const D d, VecArg x) { + return Log1p(d, x); +} + +/** + * Highway SIMD version of std::log2(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 2 + * Valid Range: float32(0, +FLT_MAX], float64(0, +DBL_MAX] + * @return base 2 logarithm of 'x' + */ +template +HWY_INLINE V Log2(D d, V x); +template +HWY_NOINLINE V CallLog2(const D d, VecArg x) { + return Log2(d, x); +} + +/** + * Highway SIMD version of std::sin(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 3 + * Valid Range: [-39000, +39000] + * @return sine of 'x' + */ +template +HWY_INLINE V Sin(D d, V x); +template +HWY_NOINLINE V CallSin(const D d, VecArg x) { + return Sin(d, x); +} + +/** + * Highway SIMD version of std::sinh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-88.7228, +88.7228], float64[-709, +709] + * @return hyperbolic sine of 'x' + */ +template +HWY_INLINE V Sinh(D d, V x); +template +HWY_NOINLINE V CallSinh(const D d, VecArg x) { + return Sinh(d, x); +} + +/** + * Highway SIMD version of std::tanh(x). + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return hyperbolic tangent of 'x' + */ +template +HWY_INLINE V Tanh(D d, V x); +template +HWY_NOINLINE V CallTanh(const D d, VecArg x) { + return Tanh(d, x); +} + +/** + * Highway SIMD version of SinCos. 
+ * Compute the sine and cosine at the same time + * The performance should be around the same as calling Sin. + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 1 + * Valid Range: [-39000, +39000] + */ +template +HWY_INLINE void SinCos(D d, V x, V& s, V& c); +template +HWY_NOINLINE void CallSinCos(const D d, VecArg x, V& s, V& c) { + SinCos(d, x, s, c); +} + +/** + * Highway SIMD version of Hypot + * + * Valid Lane Types: float32, float64 + * Max Error: ULP = 4 + * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX] + * @return hypotenuse of a and b + */ +template +HWY_INLINE V Hypot(D d, V a, V b); +template +HWY_NOINLINE V CallHypot(const D d, VecArg a, VecArg b) { + return Hypot(d, a, b); +} + +//////////////////////////////////////////////////////////////////////////////// +// Implementation +//////////////////////////////////////////////////////////////////////////////// +namespace impl { + +// Estrin's Scheme is a faster method for evaluating large polynomials on +// super scalar architectures. It works by factoring the Horner's Method +// polynomial into power of two sub-trees that can be evaluated in parallel. 
+// Wikipedia Link: https://en.wikipedia.org/wiki/Estrin%27s_scheme +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1) { + return MulAdd(c1, x, c0); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2) { + T x2 = Mul(x, x); + return MulAdd(x2, c2, MulAdd(c1, x, c0)); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3) { + T x2 = Mul(x, x); + return MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, c4, MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, MulAdd(c5, x, c4), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, MulAdd(x2, c6, MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + return MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, c8, + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, MulAdd(c9, x, c8), + MulAdd(x4, MulAdd(x2, MulAdd(c7, 
x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, MulAdd(x2, c10, MulAdd(c9, x, c8)), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8)), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd( + x8, MulAdd(x4, c12, MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(c13, x, c12), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(x2, c14, MulAdd(c13, x, c12)), + MulAdd(x2, 
MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + return MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0)))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + T x16 = Mul(x8, x8); + return MulAdd( + x16, c16, + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16, T c17) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); + T x16 = Mul(x8, x8); + return MulAdd( + x16, MulAdd(c17, x, c16), + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} +template +HWY_INLINE HWY_MAYBE_UNUSED T Estrin(T x, T c0, T c1, T c2, T c3, T c4, T c5, + T c6, T c7, T c8, T c9, T c10, T c11, + T c12, T c13, T c14, T c15, T c16, T c17, + T c18) { + T x2 = Mul(x, x); + T x4 = Mul(x2, x2); + T x8 = Mul(x4, x4); 
+ T x16 = Mul(x8, x8); + return MulAdd( + x16, MulAdd(x2, c18, MulAdd(c17, x, c16)), + MulAdd(x8, + MulAdd(x4, MulAdd(x2, MulAdd(c15, x, c14), MulAdd(c13, x, c12)), + MulAdd(x2, MulAdd(c11, x, c10), MulAdd(c9, x, c8))), + MulAdd(x4, MulAdd(x2, MulAdd(c7, x, c6), MulAdd(c5, x, c4)), + MulAdd(x2, MulAdd(c3, x, c2), MulAdd(c1, x, c0))))); +} + +template +struct AsinImpl {}; +template +struct AtanImpl {}; +template +struct CosSinImpl {}; +template +struct ExpImpl {}; +template +struct LogImpl {}; +template +struct SinCosImpl {}; + +template <> +struct AsinImpl { + // Polynomial approximation for asin(x) over the range [0, 0.5). + template , HWY_IF_F32_D(D)> + HWY_INLINE V AsinPoly(D d, V x2, V /*x*/) { + const auto k0 = Set(d, +0.1666677296f); + const auto k1 = Set(d, +0.07495029271f); + const auto k2 = Set(d, +0.04547423869f); + const auto k3 = Set(d, +0.02424046025f); + const auto k4 = Set(d, +0.04197454825f); + + return Estrin(x2, k0, k1, k2, k3, k4); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template <> +struct AsinImpl { + // Polynomial approximation for asin(x) over the range [0, 0.5). 
+ template , HWY_IF_F64_D(D)> + HWY_INLINE V AsinPoly(D d, V x2, V /*x*/) { + const auto k0 = Set(d, +0.1666666666666497543); + const auto k1 = Set(d, +0.07500000000378581611); + const auto k2 = Set(d, +0.04464285681377102438); + const auto k3 = Set(d, +0.03038195928038132237); + const auto k4 = Set(d, +0.02237176181932048341); + const auto k5 = Set(d, +0.01735956991223614604); + const auto k6 = Set(d, +0.01388715184501609218); + const auto k7 = Set(d, +0.01215360525577377331); + const auto k8 = Set(d, +0.006606077476277170610); + const auto k9 = Set(d, +0.01929045477267910674); + const auto k10 = Set(d, -0.01581918243329996643); + const auto k11 = Set(d, +0.03161587650653934628); + + return Estrin(x2, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11); + } +}; + +#endif + +template <> +struct AtanImpl { + // Polynomial approximation for atan(x) over the range [0, 1.0). + template , HWY_IF_F32_D(D)> + HWY_INLINE V AtanPoly(D d, V x) { + const auto k0 = Set(d, -0.333331018686294555664062f); + const auto k1 = Set(d, +0.199926957488059997558594f); + const auto k2 = Set(d, -0.142027363181114196777344f); + const auto k3 = Set(d, +0.106347933411598205566406f); + const auto k4 = Set(d, -0.0748900920152664184570312f); + const auto k5 = Set(d, +0.0425049886107444763183594f); + const auto k6 = Set(d, -0.0159569028764963150024414f); + const auto k7 = Set(d, +0.00282363896258175373077393f); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7), Mul(y, x), x); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template <> +struct AtanImpl { + // Polynomial approximation for atan(x) over the range [0, 1.0). 
+ template , HWY_IF_F64_D(D)> + HWY_INLINE V AtanPoly(D d, V x) { + const auto k0 = Set(d, -0.333333333333311110369124); + const auto k1 = Set(d, +0.199999999996591265594148); + const auto k2 = Set(d, -0.14285714266771329383765); + const auto k3 = Set(d, +0.111111105648261418443745); + const auto k4 = Set(d, -0.090908995008245008229153); + const auto k5 = Set(d, +0.0769219538311769618355029); + const auto k6 = Set(d, -0.0666573579361080525984562); + const auto k7 = Set(d, +0.0587666392926673580854313); + const auto k8 = Set(d, -0.0523674852303482457616113); + const auto k9 = Set(d, +0.0466667150077840625632675); + const auto k10 = Set(d, -0.0407629191276836500001934); + const auto k11 = Set(d, +0.0337852580001353069993897); + const auto k12 = Set(d, -0.0254517624932312641616861); + const auto k13 = Set(d, +0.016599329773529201970117); + const auto k14 = Set(d, -0.00889896195887655491740809); + const auto k15 = Set(d, +0.00370026744188713119232403); + const auto k16 = Set(d, -0.00110611831486672482563471); + const auto k17 = Set(d, +0.000209850076645816976906797); + const auto k18 = Set(d, -1.88796008463073496563746e-5); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, + k12, k13, k14, k15, k16, k17, k18), + Mul(y, x), x); + } +}; + +#endif + +template <> +struct CosSinImpl { + // Rounds float toward zero and returns as int32_t. 
+ template , HWY_IF_F32_D(D)> + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return ConvertTo(Rebind(), x); + } + + template , HWY_IF_F32_D(D)> + HWY_INLINE V Poly(D d, V x) { + const auto k0 = Set(d, -1.66666597127914428710938e-1f); + const auto k1 = Set(d, +8.33307858556509017944336e-3f); + const auto k2 = Set(d, -1.981069071916863322258e-4f); + const auto k3 = Set(d, +2.6083159809786593541503e-6f); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3), Mul(y, x), x); + } + + template , class VI32 = Vec>, + HWY_IF_F32_D(D)> + HWY_INLINE V CosReduce(D d, V x, VI32 q) { + // kHalfPiPart0f + kHalfPiPart1f + kHalfPiPart2f + kHalfPiPart3f ~= -pi/2 + const V kHalfPiPart0f = Set(d, -0.5f * 3.140625f); + const V kHalfPiPart1f = Set(d, -0.5f * 0.0009670257568359375f); + const V kHalfPiPart2f = Set(d, -0.5f * 6.2771141529083251953e-7f); + const V kHalfPiPart3f = Set(d, -0.5f * 1.2154201256553420762e-10f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kHalfPiPart0f, x); + x = MulAdd(qf, kHalfPiPart1f, x); + x = MulAdd(qf, kHalfPiPart2f, x); + x = MulAdd(qf, kHalfPiPart3f, x); + return x; + } + + template , class VI32 = Vec>, + HWY_IF_F32_D(D)> + HWY_INLINE V SinReduce(D d, V x, VI32 q) { + // kPiPart0f + kPiPart1f + kPiPart2f + kPiPart3f ~= -pi + const V kPiPart0f = Set(d, -3.140625f); + const V kPiPart1f = Set(d, -0.0009670257568359375f); + const V kPiPart2f = Set(d, -6.2771141529083251953e-7f); + const V kPiPart3f = Set(d, -1.2154201256553420762e-10f); + + // Extended precision modular arithmetic. + const V qf = ConvertTo(d, q); + x = MulAdd(qf, kPiPart0f, x); + x = MulAdd(qf, kPiPart1f, x); + x = MulAdd(qf, kPiPart2f, x); + x = MulAdd(qf, kPiPart3f, x); + return x; + } + + // (q & 2) == 0 ? -0.0 : +0.0 + template >, HWY_IF_F32_D(D)> + HWY_INLINE Vec> CosSignFromQuadrant(D d, VI32 q) { + const VI32 kTwo = Set(Rebind(), 2); + return BitCast(d, ShiftLeft<30>(AndNot(q, kTwo))); + } + + // ((q & 1) ? 
-0.0 : +0.0) + template >, HWY_IF_F32_D(D)> + HWY_INLINE Vec> SinSignFromQuadrant(D d, VI32 q) { + const VI32 kOne = Set(Rebind(), 1); + return BitCast(d, ShiftLeft<31>(And(q, kOne))); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template <> +struct CosSinImpl { + // Rounds double toward zero and returns as int32_t. + template , HWY_IF_F64_D(D)> + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return DemoteTo(Rebind(), x); + } + + template , HWY_IF_F64_D(D)> + HWY_INLINE V Poly(D d, V x) { + const auto k0 = Set(d, -0.166666666666666657414808); + const auto k1 = Set(d, +0.00833333333333332974823815); + const auto k2 = Set(d, -0.000198412698412696162806809); + const auto k3 = Set(d, +2.75573192239198747630416e-6); + const auto k4 = Set(d, -2.50521083763502045810755e-8); + const auto k5 = Set(d, +1.60590430605664501629054e-10); + const auto k6 = Set(d, -7.64712219118158833288484e-13); + const auto k7 = Set(d, +2.81009972710863200091251e-15); + const auto k8 = Set(d, -7.97255955009037868891952e-18); + + const auto y = Mul(x, x); + return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7, k8), Mul(y, x), x); + } + + template , class VI32 = Vec>, + HWY_IF_F64_D(D)> + HWY_INLINE V CosReduce(D d, V x, VI32 q) { + // kHalfPiPart0d + kHalfPiPart1d + kHalfPiPart2d + kHalfPiPart3d ~= -pi/2 + const V kHalfPiPart0d = Set(d, -0.5 * 3.1415926218032836914); + const V kHalfPiPart1d = Set(d, -0.5 * 3.1786509424591713469e-8); + const V kHalfPiPart2d = Set(d, -0.5 * 1.2246467864107188502e-16); + const V kHalfPiPart3d = Set(d, -0.5 * 1.2736634327021899816e-24); + + // Extended precision modular arithmetic. 
+ const V qf = PromoteTo(d, q); + x = MulAdd(qf, kHalfPiPart0d, x); + x = MulAdd(qf, kHalfPiPart1d, x); + x = MulAdd(qf, kHalfPiPart2d, x); + x = MulAdd(qf, kHalfPiPart3d, x); + return x; + } + + template , class VI32 = Vec>, + HWY_IF_F64_D(D)> + HWY_INLINE V SinReduce(D d, V x, VI32 q) { + // kPiPart0d + kPiPart1d + kPiPart2d + kPiPart3d ~= -pi + const V kPiPart0d = Set(d, -3.1415926218032836914); + const V kPiPart1d = Set(d, -3.1786509424591713469e-8); + const V kPiPart2d = Set(d, -1.2246467864107188502e-16); + const V kPiPart3d = Set(d, -1.2736634327021899816e-24); + + // Extended precision modular arithmetic. + const V qf = PromoteTo(d, q); + x = MulAdd(qf, kPiPart0d, x); + x = MulAdd(qf, kPiPart1d, x); + x = MulAdd(qf, kPiPart2d, x); + x = MulAdd(qf, kPiPart3d, x); + return x; + } + + // (q & 2) == 0 ? -0.0 : +0.0 + template >, HWY_IF_F64_D(D)> + HWY_INLINE Vec> CosSignFromQuadrant(D d, VI32 q) { + const VI32 kTwo = Set(Rebind(), 2); + return BitCast( + d, ShiftLeft<62>(PromoteTo(Rebind(), AndNot(q, kTwo)))); + } + + // ((q & 1) ? -0.0 : +0.0) + template >, HWY_IF_F64_D(D)> + HWY_INLINE Vec> SinSignFromQuadrant(D d, VI32 q) { + const VI32 kOne = Set(Rebind(), 1); + return BitCast( + d, ShiftLeft<63>(PromoteTo(Rebind(), And(q, kOne)))); + } +}; + +#endif + +template <> +struct ExpImpl { + // Rounds float toward zero and returns as int32_t. 
+ template , HWY_IF_F32_D(D)> + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return ConvertTo(Rebind(), x); + } + + // Rounds float to nearest int32_t + template , HWY_IF_F32_D(D)> + HWY_INLINE Vec> ToNearestInt32(D /*unused*/, V x) { + return NearestInt(x); + } + + template , HWY_IF_F32_D(D)> + HWY_INLINE V ExpPoly(D d, V x) { + const auto k0 = Set(d, +0.5f); + const auto k1 = Set(d, +0.166666671633720397949219f); + const auto k2 = Set(d, +0.0416664853692054748535156f); + const auto k3 = Set(d, +0.00833336077630519866943359f); + const auto k4 = Set(d, +0.00139304355252534151077271f); + const auto k5 = Set(d, +0.000198527617612853646278381f); + + return MulAdd(Estrin(x, k0, k1, k2, k3, k4, k5), Mul(x, x), x); + } + + // Computes 2^x for integer x: bits (x + 0x7F) << 23 viewed as float. + template >, HWY_IF_F32_D(D)> + HWY_INLINE Vec Pow2I(D d, VI32 x) { + const Rebind di32; + const VI32 kOffset = Set(di32, 0x7F); + return BitCast(d, ShiftLeft<23>(Add(x, kOffset))); + } + + // Returns x * 2^e, computed as x * 2^(e >> 1) * 2^(e - (e >> 1)). + template , class VI32 = Vec>, + HWY_IF_F32_D(D)> + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y))); + } + + template , class VI32 = Vec>, + HWY_IF_F32_D(D)> + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kLn2Part0f + kLn2Part1f ~= -ln(2) + const V kLn2Part0f = Set(d, -0.693145751953125f); + const V kLn2Part1f = Set(d, -1.428606765330187045e-6f); + + // Extended precision modular arithmetic. 
+ const V qf = ConvertTo(d, q); + x = MulAdd(qf, kLn2Part0f, x); + x = MulAdd(qf, kLn2Part1f, x); + return x; + } + + template , class VI32 = Vec>, + HWY_IF_F32_D(D)> + HWY_INLINE V Exp2Reduce(D d, V x, VI32 q) { + const V x_frac = Sub(x, ConvertTo(d, q)); + return MulAdd(x_frac, Set(d, 0.193147182464599609375f), + Mul(x_frac, Set(d, 0.5f))); + } +}; + +template <> +struct LogImpl { + template , HWY_IF_F32_D(D)> + HWY_INLINE Vec> Log2p1NoSubnormal( + D /*d*/, Vec> x) { + const Rebind di32; + const Rebind du32; + const auto kBias = Set(di32, 0x7F); + return Sub(BitCast(di32, ShiftRight<23>(BitCast(du32, x))), kBias); + } + + // Approximates Log(x) over the range [sqrt(2) / 2, sqrt(2)]. + template , HWY_IF_F32_D(D)> + HWY_INLINE V LogPoly(D d, V x) { + const V k0 = Set(d, 0.66666662693f); + const V k1 = Set(d, 0.40000972152f); + const V k2 = Set(d, 0.28498786688f); + const V k3 = Set(d, 0.24279078841f); + + const V x2 = Mul(x, x); + const V x4 = Mul(x2, x2); + return MulAdd(MulAdd(k2, x4, k0), x2, Mul(MulAdd(k3, x4, k1), x4)); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 +template <> +struct ExpImpl { + // Rounds double toward zero and returns as int32_t. 
+ template , HWY_IF_F64_D(D)> + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return DemoteTo(Rebind(), x); + } + + // Rounds double to nearest int32_t + template , HWY_IF_F64_D(D)> + HWY_INLINE Vec> ToNearestInt32(D /*unused*/, V x) { + return DemoteToNearestInt(Rebind(), x); + } + + template , HWY_IF_F64_D(D)> + HWY_INLINE V ExpPoly(D d, V x) { + const auto k0 = Set(d, +0.5); + const auto k1 = Set(d, +0.166666666666666851703837); + const auto k2 = Set(d, +0.0416666666666665047591422); + const auto k3 = Set(d, +0.00833333333331652721664984); + const auto k4 = Set(d, +0.00138888888889774492207962); + const auto k5 = Set(d, +0.000198412698960509205564975); + const auto k6 = Set(d, +2.4801587159235472998791e-5); + const auto k7 = Set(d, +2.75572362911928827629423e-6); + const auto k8 = Set(d, +2.75573911234900471893338e-7); + const auto k9 = Set(d, +2.51112930892876518610661e-8); + const auto k10 = Set(d, +2.08860621107283687536341e-9); + + return MulAdd(Estrin(x, k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10), + Mul(x, x), x); + } + + // Computes 2^x for integer x: bits (x + 0x3FF) << 52 viewed as double. + template >, HWY_IF_F64_D(D)> + HWY_INLINE Vec Pow2I(D d, VI32 x) { + const Rebind di32; + const Rebind di64; + const VI32 kOffset = Set(di32, 0x3FF); + return BitCast(d, ShiftLeft<52>(PromoteTo(di64, Add(x, kOffset)))); + } + + // Returns x * 2^e, computed as x * 2^(e >> 1) * 2^(e - (e >> 1)). + template , class VI32 = Vec>, + HWY_IF_F64_D(D)> + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y))); + } + + template , class VI32 = Vec>, + HWY_IF_F64_D(D)> + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kLn2Part0d + kLn2Part1d ~= -ln(2) + const V kLn2Part0d = Set(d, -0.6931471805596629565116018); + const V kLn2Part1d = Set(d, -0.28235290563031577122588448175e-12); + + // Extended precision modular arithmetic. 
+ const V qf = PromoteTo(d, q); + x = MulAdd(qf, kLn2Part0d, x); + x = MulAdd(qf, kLn2Part1d, x); + return x; + } + + template , class VI32 = Vec>, + HWY_IF_F64_D(D)> + HWY_INLINE V Exp2Reduce(D d, V x, VI32 q) { + const V x_frac = Sub(x, PromoteTo(d, q)); + return MulAdd(x_frac, Set(d, 0.1931471805599453139823396), + Mul(x_frac, Set(d, 0.5))); + } +}; + +template <> +struct LogImpl { + template , HWY_IF_F64_D(D)> + HWY_INLINE Vec> Log2p1NoSubnormal( + D /*d*/, Vec> x) { + const Rebind di64; + const Rebind du64; + return Sub(BitCast(di64, ShiftRight<52>(BitCast(du64, x))), + Set(di64, 0x3FF)); + } + + // Approximates Log(x) over the range [sqrt(2) / 2, sqrt(2)]. + template , HWY_IF_F64_D(D)> + HWY_INLINE V LogPoly(D d, V x) { + const V k0 = Set(d, 0.6666666666666735130); + const V k1 = Set(d, 0.3999999999940941908); + const V k2 = Set(d, 0.2857142874366239149); + const V k3 = Set(d, 0.2222219843214978396); + const V k4 = Set(d, 0.1818357216161805012); + const V k5 = Set(d, 0.1531383769920937332); + const V k6 = Set(d, 0.1479819860511658591); + + const V x2 = Mul(x, x); + const V x4 = Mul(x2, x2); + return MulAdd(MulAdd(MulAdd(MulAdd(k6, x4, k4), x4, k2), x4, k0), x2, + (Mul(MulAdd(MulAdd(k5, x4, k3), x4, k1), x4))); + } +}; + +#endif + +template +HWY_INLINE V Log(const D d, V x) { + // http://git.musl-libc.org/cgit/musl/tree/src/math/log.c for more info. + using T = TFromD; + impl::LogImpl impl; + + constexpr bool kIsF32 = (sizeof(T) == 4); + + // Float Constants + const V kLn2Hi = Set(d, kIsF32 ? static_cast(0.69313812256f) + : static_cast(0.693147180369123816490)); + const V kLn2Lo = Set(d, kIsF32 ? static_cast(9.0580006145e-6f) + : static_cast(1.90821492927058770002e-10)); + const V kOne = Set(d, static_cast(+1.0)); + const V kMinNormal = Set(d, kIsF32 ? static_cast(1.175494351e-38f) + : static_cast(2.2250738585072014e-308)); + const V kScale = Set(d, kIsF32 ? 
static_cast(3.355443200e+7f) + : static_cast(1.8014398509481984e+16)); + + // Integer Constants + using TI = MakeSigned; + const Rebind di; + using VI = decltype(Zero(di)); + const VI kLowerBits = Set(di, kIsF32 ? static_cast(0x00000000L) + : static_cast(0xFFFFFFFFLL)); + const VI kMagic = Set(di, kIsF32 ? static_cast(0x3F3504F3L) + : static_cast(0x3FE6A09E00000000LL)); + const VI kExpMagicDiff = Set( + di, kIsF32 + ? static_cast(0x3F800000L - 0x3F3504F3L) + : static_cast(0x3FF0000000000000LL - 0x3FE6A09E00000000LL)); + const VI kExpScale = + Set(di, kIsF32 ? static_cast(-25) : static_cast(-54)); + const VI kManMask = Set(di, kIsF32 ? static_cast(0x7FFFFFL) + : static_cast(0xFFFFF00000000LL)); + + // Scale up 'x' so that it is no longer denormalized. + VI exp_bits; + V exp; + if (kAllowSubnormals == true) { + const auto is_denormal = Lt(x, kMinNormal); + x = IfThenElse(is_denormal, Mul(x, kScale), x); + + // Compute the new exponent. + exp_bits = Add(BitCast(di, x), kExpMagicDiff); + const VI exp_scale = + BitCast(di, IfThenElseZero(is_denormal, BitCast(d, kExpScale))); + exp = ConvertTo(d, Add(exp_scale, impl.Log2p1NoSubnormal(d, exp_bits))); + } else { + // Compute the new exponent. + exp_bits = Add(BitCast(di, x), kExpMagicDiff); + exp = ConvertTo(d, impl.Log2p1NoSubnormal(d, exp_bits)); + } + + // Renormalize. + const V y = Or(And(x, BitCast(d, kLowerBits)), + BitCast(d, Add(And(exp_bits, kManMask), kMagic))); + + // Approximate and reconstruct. 
+ const V ym1 = Sub(y, kOne); + const V z = Div(ym1, Add(y, kOne)); + + return MulSub( + exp, kLn2Hi, + Sub(MulSub(z, Sub(ym1, impl.LogPoly(d, z)), Mul(exp, kLn2Lo)), ym1)); +} + +// SinCos +// Based on "sse_mathfun.h", by Julien Pommier +// http://gruntthepeon.free.fr/ssemath/ + +// Third degree poly +template > +HWY_INLINE void SinCos3(D d, TFromD dp1, TFromD dp2, TFromD dp3, V x, + V& s, V& c) { + using T = TFromD; + using TI = MakeSigned; + using DI = Rebind; + const DI di; + using VI = decltype(Zero(di)); + using M = Mask; + + static constexpr size_t bits = sizeof(TI) * 8; + const VI sign_mask = SignBit(di); + const VI ci_0 = Zero(di); + const VI ci_1 = Set(di, 1); + const VI ci_2 = Set(di, 2); + const VI ci_4 = Set(di, 4); + const V cos_p0 = Set(d, ConvertScalarTo(2.443315711809948E-005)); + const V cos_p1 = Set(d, ConvertScalarTo(-1.388731625493765E-003)); + const V cos_p2 = Set(d, ConvertScalarTo(4.166664568298827E-002)); + const V sin_p0 = Set(d, ConvertScalarTo(-1.9515295891E-4)); + const V sin_p1 = Set(d, ConvertScalarTo(8.3321608736E-3)); + const V sin_p2 = Set(d, ConvertScalarTo(-1.6666654611E-1)); + const V FOPI = Set(d, ConvertScalarTo(1.27323954473516)); // 4 / M_PI + const V DP1 = Set(d, dp1); + const V DP2 = Set(d, dp2); + const V DP3 = Set(d, dp3); + + V xmm1, xmm2, sign_bit_sin, y; + VI imm0, imm2, imm4; + + sign_bit_sin = x; + x = Abs(x); + + /* extract the sign bit (upper one) */ + sign_bit_sin = And(sign_bit_sin, BitCast(d, sign_mask)); + + /* scale by 4/Pi */ + y = Mul(x, FOPI); + + /* store the integer part of y in imm2 */ + imm2 = ConvertTo(di, y); + + /* j=(j+1) & (~1) (see the cephes sources) */ + imm2 = Add(imm2, ci_1); + imm2 = AndNot(ci_1, imm2); + + y = ConvertTo(d, imm2); + imm4 = imm2; + + /* get the swap sign flag for the sine */ + imm0 = And(imm2, ci_4); + imm0 = ShiftLeft(imm0); + + V swap_sign_bit_sin = BitCast(d, imm0); + + /* get the polynomial selection mask for the sine*/ + imm2 = And(imm2, ci_2); + M poly_mask = 
RebindMask(d, Eq(imm2, ci_0)); + + /* The magic pass: "Extended precision modular arithmetic" + x = ((x + y * DP1) + y * DP2) + y * DP3; callers pass negated DPn */ + x = MulAdd(y, DP1, x); + x = MulAdd(y, DP2, x); + x = MulAdd(y, DP3, x); + + imm4 = Sub(imm4, ci_2); + imm4 = AndNot(imm4, ci_4); + imm4 = ShiftLeft(imm4); + + V sign_bit_cos = BitCast(d, imm4); + + sign_bit_sin = Xor(sign_bit_sin, swap_sign_bit_sin); + + /* Evaluate the first (cosine) polynomial (0 <= x <= Pi/4) */ + V z = Mul(x, x); + + y = MulAdd(cos_p0, z, cos_p1); + y = MulAdd(y, z, cos_p2); + y = Mul(y, z); + y = Mul(y, z); + y = NegMulAdd(z, Set(d, 0.5f), y); + y = Add(y, Set(d, 1)); + + /* Second (sine) polynomial; TODO confirm range, was "Pi/4 <= x <= 0" */ + V y2 = MulAdd(sin_p0, z, sin_p1); + y2 = MulAdd(y2, z, sin_p2); + y2 = Mul(y2, z); + y2 = MulAdd(y2, x, x); + + /* select the correct result from the two polynomials */ + xmm1 = IfThenElse(poly_mask, y2, y); + xmm2 = IfThenElse(poly_mask, y, y2); + + /* update the sign */ + s = Xor(xmm1, sign_bit_sin); + c = Xor(xmm2, sign_bit_cos); +} + +// Sixth degree poly +template > +HWY_INLINE void SinCos6(D d, TFromD dp1, TFromD dp2, TFromD dp3, V x, + V& s, V& c) { + using T = TFromD; + using TI = MakeSigned; + using DI = Rebind; + const DI di; + using VI = decltype(Zero(di)); + using M = Mask; + + static constexpr size_t bits = sizeof(TI) * 8; + const VI sign_mask = SignBit(di); + const VI ci_0 = Zero(di); + const VI ci_1 = Set(di, 1); + const VI ci_2 = Set(di, 2); + const VI ci_4 = Set(di, 4); + const V cos_p0 = Set(d, ConvertScalarTo(-1.13585365213876817300E-11)); + const V cos_p1 = Set(d, ConvertScalarTo(2.08757008419747316778E-9)); + const V cos_p2 = Set(d, ConvertScalarTo(-2.75573141792967388112E-7)); + const V cos_p3 = Set(d, ConvertScalarTo(2.48015872888517045348E-5)); + const V cos_p4 = Set(d, ConvertScalarTo(-1.38888888888730564116E-3)); + const V cos_p5 = Set(d, ConvertScalarTo(4.16666666666665929218E-2)); + const V sin_p0 = Set(d, ConvertScalarTo(1.58962301576546568060E-10)); + const V 
sin_p1 = Set(d, ConvertScalarTo(-2.50507477628578072866E-8)); + const V sin_p2 = Set(d, ConvertScalarTo(2.75573136213857245213E-6)); + const V sin_p3 = Set(d, ConvertScalarTo(-1.98412698295895385996E-4)); + const V sin_p4 = Set(d, ConvertScalarTo(8.33333333332211858878E-3)); + const V sin_p5 = Set(d, ConvertScalarTo(-1.66666666666666307295E-1)); + const V FOPI = // 4 / M_PI + Set(d, ConvertScalarTo(1.2732395447351626861510701069801148)); + const V DP1 = Set(d, dp1); + const V DP2 = Set(d, dp2); + const V DP3 = Set(d, dp3); + + V xmm1, xmm2, sign_bit_sin, y; + VI imm0, imm2, imm4; + + sign_bit_sin = x; + x = Abs(x); + + /* extract the sign bit (upper one) */ + sign_bit_sin = And(sign_bit_sin, BitCast(d, sign_mask)); + + /* scale by 4/Pi */ + y = Mul(x, FOPI); + + /* store the integer part of y in imm2 */ + imm2 = ConvertTo(di, y); + + /* j=(j+1) & (~1) (see the cephes sources) */ + imm2 = Add(imm2, ci_1); + imm2 = AndNot(ci_1, imm2); + + y = ConvertTo(d, imm2); + imm4 = imm2; + + /* get the swap sign flag for the sine */ + imm0 = And(imm2, ci_4); + imm0 = ShiftLeft(imm0); + + V swap_sign_bit_sin = BitCast(d, imm0); + + /* get the polynomial selection mask for the sine*/ + imm2 = And(imm2, ci_2); + M poly_mask = RebindMask(d, Eq(imm2, ci_0)); + + /* The magic pass: "Extended precision modular arithmetic" + x = ((x + y * DP1) + y * DP2) + y * DP3; callers pass negated DPn */ + x = MulAdd(y, DP1, x); + x = MulAdd(y, DP2, x); + x = MulAdd(y, DP3, x); + + imm4 = Sub(imm4, ci_2); + imm4 = AndNot(imm4, ci_4); + imm4 = ShiftLeft(imm4); + + V sign_bit_cos = BitCast(d, imm4); + sign_bit_sin = Xor(sign_bit_sin, swap_sign_bit_sin); + + /* Evaluate the first (cosine) polynomial (0 <= x <= Pi/4) */ + V z = Mul(x, x); + + y = MulAdd(cos_p0, z, cos_p1); + y = MulAdd(y, z, cos_p2); + y = MulAdd(y, z, cos_p3); + y = MulAdd(y, z, cos_p4); + y = MulAdd(y, z, cos_p5); + y = Mul(y, z); + y = Mul(y, z); + y = NegMulAdd(z, Set(d, 0.5f), y); + y = Add(y, Set(d, 1.0f)); + + /* Second (sine) polynomial; TODO confirm range, was "Pi/4 <= x <= 0" 
*/ + V y2 = MulAdd(sin_p0, z, sin_p1); + y2 = MulAdd(y2, z, sin_p2); + y2 = MulAdd(y2, z, sin_p3); + y2 = MulAdd(y2, z, sin_p4); + y2 = MulAdd(y2, z, sin_p5); + y2 = Mul(y2, z); + y2 = MulAdd(y2, x, x); + + /* select the correct result from the two polynomials */ + xmm1 = IfThenElse(poly_mask, y2, y); + xmm2 = IfThenElse(poly_mask, y, y2); + + /* update the sign */ + s = Xor(xmm1, sign_bit_sin); + c = Xor(xmm2, sign_bit_cos); +} + +template <> +struct SinCosImpl { + template , HWY_IF_F32_D(D)> + HWY_INLINE void SinCos(D d, V x, V& s, V& c) { + SinCos3(d, -0.78515625f, -2.4187564849853515625e-4f, + -3.77489497744594108e-8f, x, s, c); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 +template <> +struct SinCosImpl { + template , HWY_IF_F64_D(D)> + HWY_INLINE void SinCos(D d, V x, V& s, V& c) { + SinCos6(d, -7.85398125648498535156E-1, -3.77489470793079817668E-8, + -2.69515142907905952645E-15, x, s, c); + } +}; +#endif + +} // namespace impl + +template +HWY_INLINE V Acos(const D d, V x) { + using T = TFromD; + + const V kZero = Zero(d); + const V kHalf = Set(d, static_cast(+0.5)); + const V kPi = Set(d, static_cast(+3.14159265358979323846264)); + const V kPiOverTwo = Set(d, static_cast(+1.57079632679489661923132169)); + + const V sign_x = And(SignBit(d), x); + const V abs_x = Xor(x, sign_x); + const auto mask = Lt(abs_x, kHalf); + const V yy = + IfThenElse(mask, Mul(abs_x, abs_x), NegMulAdd(abs_x, kHalf, kHalf)); + const V y = IfThenElse(mask, abs_x, Sqrt(yy)); + + impl::AsinImpl impl; + const V t = Mul(impl.AsinPoly(d, yy, y), Mul(y, yy)); + + const V t_plus_y = Add(t, y); + const V z = + IfThenElse(mask, Sub(kPiOverTwo, Add(Xor(y, sign_x), Xor(t, sign_x))), + Add(t_plus_y, t_plus_y)); + return IfThenElse(Or(mask, Ge(x, kZero)), z, Sub(kPi, z)); +} + +template +HWY_INLINE V Acosh(const D d, V x) { + using T = TFromD; + + const V kLarge = Set(d, static_cast(268435456.0)); + const V kLog2 = Set(d, static_cast(0.693147180559945286227)); + const V kOne = Set(d, 
static_cast(+1.0)); + const V kTwo = Set(d, static_cast(+2.0)); + + const auto is_x_large = Gt(x, kLarge); + const auto is_x_gt_2 = Gt(x, kTwo); + + const V x_minus_1 = Sub(x, kOne); + const V y0 = MulSub(kTwo, x, Div(kOne, Add(Sqrt(MulSub(x, x, kOne)), x))); + const V y1 = + Add(Sqrt(MulAdd(x_minus_1, kTwo, Mul(x_minus_1, x_minus_1))), x_minus_1); + const V y2 = + IfThenElse(is_x_gt_2, IfThenElse(is_x_large, x, y0), Add(y1, kOne)); + const V z = impl::Log(d, y2); + + const auto is_pole = Eq(y2, kOne); + const auto divisor = Sub(IfThenZeroElse(is_pole, y2), kOne); + return Add(IfThenElse(is_x_gt_2, z, + IfThenElse(is_pole, y1, Div(Mul(z, y1), divisor))), + IfThenElseZero(is_x_large, kLog2)); +} + +template +HWY_INLINE V Asin(const D d, V x) { + using T = TFromD; + + const V kHalf = Set(d, static_cast(+0.5)); + const V kTwo = Set(d, static_cast(+2.0)); + const V kPiOverTwo = Set(d, static_cast(+1.57079632679489661923132169)); + + const V sign_x = And(SignBit(d), x); + const V abs_x = Xor(x, sign_x); + const auto mask = Lt(abs_x, kHalf); + const V yy = + IfThenElse(mask, Mul(abs_x, abs_x), NegMulAdd(abs_x, kHalf, kHalf)); + const V y = IfThenElse(mask, abs_x, Sqrt(yy)); + + impl::AsinImpl impl; + const V z0 = MulAdd(impl.AsinPoly(d, yy, y), Mul(yy, y), y); + const V z1 = NegMulAdd(z0, kTwo, kPiOverTwo); + return Or(IfThenElse(mask, z0, z1), sign_x); +} + +template +HWY_INLINE V Asinh(const D d, V x) { + using T = TFromD; + + const V kSmall = Set(d, static_cast(1.0 / 268435456.0)); + const V kLarge = Set(d, static_cast(268435456.0)); + const V kLog2 = Set(d, static_cast(0.693147180559945286227)); + const V kOne = Set(d, static_cast(+1.0)); + const V kTwo = Set(d, static_cast(+2.0)); + + const V sign_x = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign_x); + + const auto is_x_large = Gt(abs_x, kLarge); + const auto is_x_lt_2 = Lt(abs_x, kTwo); + + const V x2 = Mul(x, x); + const V sqrt_x2_plus_1 = Sqrt(Add(x2, kOne)); + + const V y0 = 
MulAdd(abs_x, kTwo, Div(kOne, Add(sqrt_x2_plus_1, abs_x))); + const V y1 = Add(Div(x2, Add(sqrt_x2_plus_1, kOne)), abs_x); + const V y2 = + IfThenElse(is_x_lt_2, Add(y1, kOne), IfThenElse(is_x_large, abs_x, y0)); + const V z = impl::Log(d, y2); + + const auto is_pole = Eq(y2, kOne); + const auto divisor = Sub(IfThenZeroElse(is_pole, y2), kOne); + const auto large = IfThenElse(is_pole, y1, Div(Mul(z, y1), divisor)); + const V y = IfThenElse(Lt(abs_x, kSmall), x, large); + return Or(Add(IfThenElse(is_x_lt_2, y, z), IfThenElseZero(is_x_large, kLog2)), + sign_x); +} + +template +HWY_INLINE V Atan(const D d, V x) { + using T = TFromD; + + const V kOne = Set(d, static_cast(+1.0)); + const V kPiOverTwo = Set(d, static_cast(+1.57079632679489661923132169)); + + const V sign = And(SignBit(d), x); + const V abs_x = Xor(x, sign); + const auto mask = Gt(abs_x, kOne); + + impl::AtanImpl impl; + const auto divisor = IfThenElse(mask, abs_x, kOne); + const V y = impl.AtanPoly(d, IfThenElse(mask, Div(kOne, divisor), abs_x)); + return Or(IfThenElse(mask, Sub(kPiOverTwo, y), y), sign); +} + +template +HWY_INLINE V Atan2(const D d, V y, V x) { + using T = TFromD; + using M = MFromD; + + const V kPi = Set(d, static_cast(3.14159265358979323846264)); + const V kPiOverTwo = Set(d, static_cast(1.57079632679489661923132169)); + const V kOne = Set(d, static_cast(1.0)); + const V k0 = Zero(d); + + const V ax = Abs(x); + const V ay = Abs(y); + + const V num = Min(ax, ay); + const V den = Max(ax, ay); + + const M is_inf = IsInf(num); + V mapped_y = MaskedDivOr(k0, Ne(den, k0), num, den); + mapped_y = IfThenElse(is_inf, kOne, mapped_y); + + impl::AtanImpl impl; + const V poly = impl.AtanPoly(d, mapped_y); + + const M ay_gt_ax = Gt(ay, ax); + V angle = MaskedSubOr(poly, ay_gt_ax, kPiOverTwo, poly); + + const M x_neg = Lt(x, k0); + angle = MaskedSubOr(angle, x_neg, kPi, angle); + + const M is_nan = IsEitherNaN(y, x); + return IfThenElse(is_nan, NaN(d), CopySign(angle, y)); +} + +template 
+HWY_INLINE V Atanh(const D d, V x) { + using T = TFromD; + + const V kHalf = Set(d, static_cast(+0.5)); + const V kOne = Set(d, static_cast(+1.0)); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + return Mul(Log1p(d, Div(Add(abs_x, abs_x), Sub(kOne, abs_x))), + Xor(kHalf, sign)); +} + +template +HWY_INLINE V Cos(const D d, V x) { + using T = TFromD; + impl::CosSinImpl impl; + + // Float Constants + const V kOneOverPi = Set(d, static_cast(0.31830988618379067153)); + + // Integer Constants + const Rebind di32; + using VI32 = decltype(Zero(di32)); + const VI32 kOne = Set(di32, 1); + + const V y = Abs(x); // cos(x) == cos(|x|) + + // Compute the quadrant, q = int(|x| / pi) * 2 + 1 + const VI32 q = Add(ShiftLeft<1>(impl.ToInt32(d, Mul(y, kOneOverPi))), kOne); + + // Reduce range, apply sign, and approximate. + return impl.Poly( + d, Xor(impl.CosReduce(d, y, q), impl.CosSignFromQuadrant(d, q))); +} + +template +HWY_INLINE V Exp(const D d, V x) { + using T = TFromD; + + const V kHalf = Set(d, static_cast(+0.5)); + const V kLowerBound = + Set(d, static_cast((sizeof(T) == 4 ? -104.0 : -1000.0))); + const V kNegZero = Set(d, static_cast(-0.0)); + const V kOne = Set(d, static_cast(+1.0)); + const V kOneOverLog2 = Set(d, static_cast(+1.442695040888963407359924681)); + + impl::ExpImpl impl; + + // q = static_cast((x / log(2)) + ((x < 0) ? -0.5 : +0.5)) + const auto q = + impl.ToInt32(d, MulAdd(x, kOneOverLog2, Or(kHalf, And(x, kNegZero)))); + + // Reduce, approximate, and then reconstruct. + const V y = impl.LoadExpShortRange( + d, Add(impl.ExpPoly(d, impl.ExpReduce(d, x, q)), kOne), q); + return IfThenElseZero(Ge(x, kLowerBound), y); +} + +template +HWY_INLINE V Exp2(const D d, V x) { + using T = TFromD; + + const V kLowerBound = + Set(d, static_cast((sizeof(T) == 4 ? 
-150.0 : -1075.0))); + const V kOne = Set(d, static_cast(+1.0)); + + impl::ExpImpl impl; + + // q = static_cast(std::lrint(x)) + const auto q = impl.ToNearestInt32(d, x); + + // Reduce, approximate, and then reconstruct. + const V y = impl.LoadExpShortRange( + d, Add(impl.ExpPoly(d, impl.Exp2Reduce(d, x, q)), kOne), q); + return IfThenElseZero(Ge(x, kLowerBound), y); +} + +template +HWY_INLINE V Expm1(const D d, V x) { + using T = TFromD; + + const V kHalf = Set(d, static_cast(+0.5)); + const V kLowerBound = + Set(d, static_cast((sizeof(T) == 4 ? -104.0 : -1000.0))); + const V kLn2Over2 = Set(d, static_cast(+0.346573590279972654708616)); + const V kNegOne = Set(d, static_cast(-1.0)); + const V kNegZero = Set(d, static_cast(-0.0)); + const V kOne = Set(d, static_cast(+1.0)); + const V kOneOverLog2 = Set(d, static_cast(+1.442695040888963407359924681)); + + impl::ExpImpl impl; + + // q = static_cast((x / log(2)) + ((x < 0) ? -0.5 : +0.5)) + const auto q = + impl.ToInt32(d, MulAdd(x, kOneOverLog2, Or(kHalf, And(x, kNegZero)))); + + // Reduce, approximate, and then reconstruct. + const V y = impl.ExpPoly(d, impl.ExpReduce(d, x, q)); + const V z = IfThenElse(Lt(Abs(x), kLn2Over2), y, + Sub(impl.LoadExpShortRange(d, Add(y, kOne), q), kOne)); + return IfThenElse(Lt(x, kLowerBound), kNegOne, z); +} + +template +HWY_INLINE V Log(const D d, V x) { + return impl::Log(d, x); +} + +template +HWY_INLINE V Log10(const D d, V x) { + using T = TFromD; + return Mul(Log(d, x), Set(d, static_cast(0.4342944819032518276511))); +} + +template +HWY_INLINE V Log1p(const D d, V x) { + using T = TFromD; + const V kOne = Set(d, static_cast(+1.0)); + + const V y = Add(x, kOne); + const Mask not_pole = Ne(y, kOne); + // If y == 1, divisor becomes 1 (dummy), avoiding division by zero. + const V divisor = MaskedSubOr(y, not_pole, y, kOne); + // Ensure exactly 1.0 when x == divisor. 
This is necessary because some + // platforms (like Armv7) use Newton-Raphson for division, which can return + // 0.0, instead of 1.0 when the reciprocal calculation underflows + // for very large x. + const V div_res = MaskedDivOr(kOne, Ne(x, divisor), x, divisor); + const auto non_pole = + Mul(impl::Log(d, y), div_res); + return IfThenElse(not_pole, non_pole, x); +} + +template +HWY_INLINE V Log2(const D d, V x) { + using T = TFromD; + return Mul(Log(d, x), Set(d, static_cast(1.44269504088896340735992))); +} + +template +HWY_INLINE V Sin(const D d, V x) { + using T = TFromD; + impl::CosSinImpl impl; + + // Float Constants + const V kOneOverPi = Set(d, static_cast(0.31830988618379067153)); + const V kHalf = Set(d, static_cast(0.5)); + + // Integer Constants + const Rebind di32; + using VI32 = decltype(Zero(di32)); + + const V abs_x = Abs(x); + const V sign_x = Xor(abs_x, x); + + // Compute the quadrant, q = int((|x| / pi) + 0.5) + const VI32 q = impl.ToInt32(d, MulAdd(abs_x, kOneOverPi, kHalf)); + + // Reduce range, apply sign, and approximate. 
+ return impl.Poly(d, Xor(impl.SinReduce(d, abs_x, q), + Xor(impl.SinSignFromQuadrant(d, q), sign_x))); +} + +template +HWY_INLINE V Sinh(const D d, V x) { + using T = TFromD; + const V kHalf = Set(d, static_cast(+0.5)); + const V kOne = Set(d, static_cast(+1.0)); + const V kTwo = Set(d, static_cast(+2.0)); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + const V y = Expm1(d, abs_x); + const V z = Mul(Div(Add(y, kTwo), Add(y, kOne)), Mul(y, kHalf)); + return Xor(z, sign); // Reapply the sign bit +} + +template +HWY_INLINE V Tanh(const D d, V x) { + using T = TFromD; + const V kLimit = Set(d, static_cast(18.714973875)); + const V kOne = Set(d, static_cast(+1.0)); + const V kTwo = Set(d, static_cast(+2.0)); + + const V sign = And(SignBit(d), x); // Extract the sign bit + const V abs_x = Xor(x, sign); + const V y = Expm1(d, Mul(abs_x, kTwo)); + const V z = IfThenElse(Gt(abs_x, kLimit), kOne, Div(y, Add(y, kTwo))); + return Xor(z, sign); // Reapply the sign bit +} + +template +HWY_INLINE void SinCos(const D d, V x, V& s, V& c) { + using T = TFromD; + impl::SinCosImpl impl; + impl.SinCos(d, x, s, c); +} + +template +HWY_INLINE V Hypot(const D d, V a, V b) { + using T = TFromD; + using TI = MakeSigned; + const RebindToUnsigned du; + const RebindToSigned di; + using VI = VFromD; + + constexpr int kMaxBiasedExp = static_cast(MaxExponentField()); + static_assert(kMaxBiasedExp > 0, "kMaxBiasedExp > 0 must be true"); + + constexpr int kNumOfMantBits = MantissaBits(); + static_assert(kNumOfMantBits > 0, "kNumOfMantBits > 0 must be true"); + + constexpr int kExpBias = kMaxBiasedExp / 2; + + static_assert( + static_cast(kExpBias) + static_cast(kNumOfMantBits) < + static_cast(kMaxBiasedExp), + "kExpBias + kNumOfMantBits < kMaxBiasedExp must be true"); + + // kMinValToSquareBiasedExp is the smallest biased exponent such that + // pow(pow(2, kMinValToSquareBiasedExp - kExpBias) * x, 2) is either a normal + // floating-point value or 
infinity if x is a non-zero, non-NaN value + constexpr int kMinValToSquareBiasedExp = (kExpBias / 2) + kNumOfMantBits; + static_assert(kMinValToSquareBiasedExp < kExpBias, + "kMinValToSquareBiasedExp < kExpBias must be true"); + + // kMaxValToSquareBiasedExp is the largest biased exponent such that + // pow(pow(2, kMaxValToSquareBiasedExp - kExpBias) * x, 2) * 2 is guaranteed + // to be a finite value if x is a finite value + constexpr int kMaxValToSquareBiasedExp = kExpBias + ((kExpBias / 2) - 1); + static_assert(kMaxValToSquareBiasedExp > kExpBias, + "kMaxValToSquareBiasedExp > kExpBias must be true"); + static_assert(kMaxValToSquareBiasedExp < kMaxBiasedExp, + "kMaxValToSquareBiasedExp < kMaxBiasedExp must be true"); + +#if HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128 || \ + HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15 + using TExpSatSub = MakeUnsigned; + using TExpMinMax = TI; +#else + using TExpSatSub = uint16_t; + using TExpMinMax = int16_t; +#endif + + const Repartition d_exp_sat_sub; + const Repartition d_exp_min_max; + + const V abs_a = Abs(a); + const V abs_b = Abs(b); + + const MFromD either_inf = Or(IsInf(a), IsInf(b)); + + const VI zero = Zero(di); + + // exp_a[i] is the biased exponent of abs_a[i] + const VI exp_a = BitCast(di, ShiftRight(BitCast(du, abs_a))); + + // exp_b[i] is the biased exponent of abs_b[i] + const VI exp_b = BitCast(di, ShiftRight(BitCast(du, abs_b))); + + // max_exp[i] is equal to HWY_MAX(exp_a[i], exp_b[i]) + + // If abs_a[i] and abs_b[i] are both NaN values, max_exp[i] will be equal to + // the biased exponent of the larger value. Otherwise, if either abs_a[i] or + // abs_b[i] is NaN, max_exp[i] will be equal to kMaxBiasedExp. + const VI max_exp = BitCast( + di, Max(BitCast(d_exp_min_max, exp_a), BitCast(d_exp_min_max, exp_b))); + + // If either abs_a[i] or abs_b[i] is zero, min_exp[i] is equal to max_exp[i]. 
+ // Otherwise, if abs_a[i] and abs_b[i] are both nonzero, min_exp[i] is equal + // to HWY_MIN(exp_a[i], exp_b[i]). + const VI min_exp = IfThenElse( + Or(Eq(BitCast(di, abs_a), zero), Eq(BitCast(di, abs_b), zero)), max_exp, + BitCast(di, Min(BitCast(d_exp_min_max, exp_a), + BitCast(d_exp_min_max, exp_b)))); + + // scl_pow2[i] is the power of 2 to scale abs_a[i] and abs_b[i] by + + // abs_a[i] and abs_b[i] should be scaled by a factor that is greater than + // zero but less than or equal to + // pow(2, kMaxValToSquareBiasedExp - max_exp[i]) to ensure that the + // multiplications or addition operations do not overflow if + // std::hypot(abs_a[i], abs_b[i]) is finite + + // If either abs_a[i] or abs_b[i] is a positive value that is less than + // pow(2, kMinValToSquareBiasedExp - kExpBias), then scaling up abs_a[i] and + // abs_b[i] by pow(2, kMinValToSquareBiasedExp - min_exp[i]) will ensure that + // the multiplications and additions result in normal floating point values, + // infinities, or NaNs. + + // If HWY_MAX(kMinValToSquareBiasedExp - min_exp[i], 0) is greater than + // kMaxValToSquareBiasedExp - max_exp[i], scale abs_a[i] and abs_b[i] up by + // pow(2, kMaxValToSquareBiasedExp - max_exp[i]) to ensure that the + // multiplication and addition operations result in a finite result if + // std::hypot(abs_a[i], abs_b[i]) is finite. 
+ + const VI scl_pow2 = BitCast( + di, + Min(BitCast(d_exp_min_max, + SaturatedSub(BitCast(d_exp_sat_sub, + Set(di, static_cast( + kMinValToSquareBiasedExp))), + BitCast(d_exp_sat_sub, min_exp))), + BitCast(d_exp_min_max, + Sub(Set(di, static_cast(kMaxValToSquareBiasedExp)), + max_exp)))); + + const VI exp_bias = Set(di, static_cast(kExpBias)); + + const V ab_scl_factor = + BitCast(d, ShiftLeft(Add(exp_bias, scl_pow2))); + const V hypot_scl_factor = + BitCast(d, ShiftLeft(Sub(exp_bias, scl_pow2))); + + const V scl_a = Mul(abs_a, ab_scl_factor); + const V scl_b = Mul(abs_b, ab_scl_factor); + + const V scl_hypot = Sqrt(MulAdd(scl_a, scl_a, Mul(scl_b, scl_b))); + // std::hypot returns inf if one input is +/- inf, even if the other is NaN. + return IfThenElse(either_inf, Inf(d), Mul(scl_hypot, hypot_scl_factor)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_ diff --git a/lib/highway/hwy/contrib/math/math_benchmark.cc b/lib/highway/hwy/contrib/math/math_benchmark.cc new file mode 100644 index 00000000000..23299f46497 --- /dev/null +++ b/lib/highway/hwy/contrib/math/math_benchmark.cc @@ -0,0 +1,326 @@ +// Copyright 2026 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/math/math_benchmark.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/contrib/math/math-inl.h" +#include "hwy/contrib/math/fast_math-inl.h" +#include "hwy/nanobenchmark.h" +#include "hwy/tests/test_util-inl.h" +// clang-format on + +#include +#include + +// ============================================================================ +// You can dynamically select two highway math functions from the list to +// benchmark by passing their names via command line flags `--fxn1` and +// `--fxn2`. Defaults: +// If no flags are passed, the benchmark automatically defaults to +// benchmarking `FastExp` and `Exp`. +// +// Note: +// - Either provide 2 function flags or none. +// - Function names must be from the list of valid functions. +// ============================================================================ + +namespace hwy { +extern const char* g_fxn1; +extern const char* g_fxn2; +} // namespace hwy + +#define HWY_MATH_BENCHMARKS(V) \ + V(Exp) \ + V(FastExp) \ + V(FastExpNormal) \ + V(Exp2) \ + V(FastExp2) \ + V(FastExp2Normal) \ + V(FastExpMinusOrZero) \ + V(Log) \ + V(FastLog) \ + V(FastLogPositiveNormal) \ + V(Log2) \ + V(FastLog2) \ + V(FastLog2PositiveNormal) \ + V(Log10) \ + V(FastLog10) \ + V(FastLog10PositiveNormal) \ + V(Log1p) \ + V(FastLog1p) \ + V(FastLog1pPositiveNormal) \ + V(Atan) \ + V(FastAtan) \ + V(FastAtanPositive) \ + V(Tanh) \ + V(FastTanh) \ + V(Atan2) \ + V(FastAtan2) + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +namespace hn = hwy::HWY_NAMESPACE; + +// Helper to safely convert float/double results to benchmark output +// without triggering SIGILL on Infinity or compile errors on size mismatch. 
+template +hwy::FuncOutput SafeCast(Val v) { + using BitsT = hwy::MakeUnsigned; + auto bits = hwy::BitCastScalar(v); + return static_cast(bits); +} + +// Macro to define benchmark body +#define DEFINE_MATH_BENCH(NAME, FUNC, MAP_EXPR) \ + template \ + void Bench##NAME(D d) { \ + using T = hn::TFromD; \ + printf("Benchmarking " #NAME " for %s:\n", \ + hwy::TypeName(T(), hn::Lanes(d)).c_str()); \ + auto func = [d](const hwy::FuncInput in) -> hwy::FuncOutput { \ + const double val = MAP_EXPR; \ + const auto v = hn::Set(d, static_cast(val)); \ + const auto res = FUNC; \ + return SafeCast(hn::GetLane(res)); \ + }; \ + const size_t kNumInputs = 16; \ + hwy::FuncInput inputs[kNumInputs]; \ + for (size_t i = 0; i < kNumInputs; ++i) inputs[i] = i; \ + hwy::Result results[kNumInputs]; \ + hwy::Params params = hwy::DefaultBenchmarkParams(); \ + const size_t num_results = \ + hwy::MeasureClosure(func, inputs, kNumInputs, results, params); \ + double sum_ticks = 0; \ + for (size_t i = 0; i < num_results; ++i) sum_ticks += results[i].ticks; \ + if (num_results > 0) { \ + printf(" Avg ticks: %f\n", sum_ticks / num_results); \ + } \ + } + +// Macro to define benchmark body for 2-element variants +#define DEFINE_MATH_BENCH_2ARG(NAME, FUNC, MAP_EXPR_Y, MAP_EXPR_X) \ + template \ + void Bench##NAME(D d) { \ + using T = hn::TFromD; \ + printf("Benchmarking " #NAME " for %s:\n", \ + hwy::TypeName(T(), hn::Lanes(d)).c_str()); \ + auto func = [d](const hwy::FuncInput in) -> hwy::FuncOutput { \ + const double val_y = MAP_EXPR_Y; \ + const double val_x = MAP_EXPR_X; \ + const auto y = hn::Set(d, static_cast(val_y)); \ + const auto x = hn::Set(d, static_cast(val_x)); \ + const auto res = FUNC; \ + return SafeCast(hn::GetLane(res)); \ + }; \ + const size_t kNumInputs = 16; \ + hwy::FuncInput inputs[kNumInputs]; \ + for (size_t i = 0; i < kNumInputs; ++i) inputs[i] = i; \ + hwy::Result results[kNumInputs]; \ + hwy::Params params = hwy::DefaultBenchmarkParams(); \ + const size_t num_results = \ 
+ hwy::MeasureClosure(func, inputs, kNumInputs, results, params); \ + double sum_ticks = 0; \ + for (size_t i = 0; i < num_results; ++i) sum_ticks += results[i].ticks; \ + if (num_results > 0) { \ + printf(" Avg ticks: %f\n", sum_ticks / num_results); \ + } \ + } + +// Exp / FastExp +DEFINE_MATH_BENCH(CallExp, hn::CallExp(d, v), + -10.0 + static_cast(in) * (20.0 / 15.0)) +DEFINE_MATH_BENCH(CallFastExp, hn::CallFastExp(d, v), + -10.0 + static_cast(in) * (20.0 / 15.0)) +DEFINE_MATH_BENCH(CallFastExpNormal, hn::CallFastExpNormal(d, v), + -10.0 + static_cast(in) * (20.0 / 15.0)) + +// Exp2 / FastExp2 +DEFINE_MATH_BENCH(CallExp2, hn::CallExp2(d, v), + -10.0 + static_cast(in) * (20.0 / 15.0)) +DEFINE_MATH_BENCH(CallFastExp2, hn::CallFastExp2(d, v), + -10.0 + static_cast(in) * (20.0 / 15.0)) +DEFINE_MATH_BENCH(CallFastExp2Normal, hn::CallFastExp2Normal(d, v), + -10.0 + static_cast(in) * (20.0 / 15.0)) + +// FastExpMinusOrZero +DEFINE_MATH_BENCH(CallFastExpMinusOrZero, hn::CallFastExpMinusOrZero(d, v), + -10.0 + static_cast(in) * (10.0 / 15.0)) + +// Log / FastLog +DEFINE_MATH_BENCH(CallLog, hn::CallLog(d, v), + 0.1 + static_cast(in) * 1.0) +DEFINE_MATH_BENCH(CallFastLog, hn::CallFastLog(d, v), + 0.1 + static_cast(in) * 1.0) +DEFINE_MATH_BENCH(CallFastLogPositiveNormal, + hn::CallFastLogPositiveNormal(d, v), + 0.1 + static_cast(in) * 1.0) + +// Log2 / FastLog2 +DEFINE_MATH_BENCH(CallLog2, hn::CallLog2(d, v), + 0.1 + static_cast(in) * 1.0) +DEFINE_MATH_BENCH(CallFastLog2, hn::CallFastLog2(d, v), + 0.1 + static_cast(in) * 1.0) +DEFINE_MATH_BENCH(CallFastLog2PositiveNormal, + hn::CallFastLog2PositiveNormal(d, v), + 0.1 + static_cast(in) * 1.0) + +// Log10 / FastLog10 +DEFINE_MATH_BENCH(CallLog10, hn::CallLog10(d, v), + 0.1 + static_cast(in) * 1.0) +DEFINE_MATH_BENCH(CallFastLog10, hn::CallFastLog10(d, v), + 0.1 + static_cast(in) * 1.0) +DEFINE_MATH_BENCH(CallFastLog10PositiveNormal, + hn::CallFastLog10PositiveNormal(d, v), + 0.1 + static_cast(in) * 1.0) + +// Log1p / 
FastLog1p +DEFINE_MATH_BENCH(CallLog1p, hn::CallLog1p(d, v), + 0.1 + static_cast(in) * 1.0) +DEFINE_MATH_BENCH(CallFastLog1p, hn::CallFastLog1p(d, v), + 0.1 + static_cast(in) * 1.0) +DEFINE_MATH_BENCH(CallFastLog1pPositiveNormal, + hn::CallFastLog1pPositiveNormal(d, v), + 0.1 + static_cast(in) * 1.0) + +// Atan / FastAtan +DEFINE_MATH_BENCH(CallAtan, hn::CallAtan(d, v), + -10.0 + static_cast(in) * (20.0 / 15.0)) +DEFINE_MATH_BENCH(CallFastAtan, hn::CallFastAtan(d, v), + -10.0 + static_cast(in) * (20.0 / 15.0)) +DEFINE_MATH_BENCH(CallFastAtanPositive, hn::CallFastAtanPositive(d, v), + 0.1 + static_cast(in) * (20.0 / 15.0)) + +// Tanh / FastTanh +DEFINE_MATH_BENCH(CallTanh, hn::CallTanh(d, v), + -10.0 + static_cast(in) * (20.0 / 15.0)) +DEFINE_MATH_BENCH(CallFastTanh, hn::CallFastTanh(d, v), + -10.0 + static_cast(in) * (20.0 / 15.0)) + +// Atan2 / FastAtan2 +DEFINE_MATH_BENCH_2ARG(CallAtan2, hn::CallAtan2(d, y, x), + -10.0 + static_cast(in / 4) * (20.0 / 3.0), + -10.0 + static_cast(in % 4) * (20.0 / 3.0)) +DEFINE_MATH_BENCH_2ARG(CallFastAtan2, hn::CallFastAtan2(d, y, x), + -10.0 + static_cast(in / 4) * (20.0 / 3.0), + -10.0 + static_cast(in % 4) * (20.0 / 3.0)) + +struct RunBenchmarks { + template + HWY_NOINLINE void operator()(T, D d) { + auto run_fxn = [&](const char* fxn) { + struct BenchmarkTable { + const char* name; + void (*func)(D); + }; + + static const BenchmarkTable table[] = { +#define V(NAME) {#NAME, &BenchCall##NAME}, + HWY_MATH_BENCHMARKS(V) +#undef V + }; + + for (const auto& entry : table) { + if (strcmp(fxn, entry.name) == 0) { + entry.func(d); + return; + } + } + }; + + if (hwy::g_fxn1) run_fxn(hwy::g_fxn1); + if (hwy::g_fxn2) run_fxn(hwy::g_fxn2); + } +}; + +HWY_NOINLINE void RunAllBenchmarks() { + hn::ForFloat3264Types(hn::ForPartialVectors()); +} + +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +const char* g_fxn1 = nullptr; +const char* g_fxn2 = nullptr; + +namespace { 
+HWY_EXPORT(RunAllBenchmarks); +} // namespace +} // namespace hwy + +int main(int argc, char** argv) { + const char* fxn1 = nullptr; + const char* fxn2 = nullptr; + + for (int i = 1; i < argc; ++i) { + if (strncmp(argv[i], "--fxn1=", 7) == 0) { + fxn1 = argv[i] + 7; + } else if (strncmp(argv[i], "--fxn2=", 7) == 0) { + fxn2 = argv[i] + 7; + } + } + + if ((fxn1 == nullptr) != (fxn2 == nullptr)) { + fprintf(stderr, + "Error: Either pass both --fxn1 and --fxn2, or pass neither to use " + "defaults.\n"); + return 1; + } + + if (fxn1 == nullptr && fxn2 == nullptr) { + fxn1 = "FastExp"; + fxn2 = "Exp"; + } + + const char* valid_fxns[] = { +#define V(NAME) #NAME, + HWY_MATH_BENCHMARKS(V) +#undef V + }; + + bool valid1 = false; + bool valid2 = false; + for (const char* v : valid_fxns) { + if (strcmp(fxn1, v) == 0) valid1 = true; + if (strcmp(fxn2, v) == 0) valid2 = true; + } + + if (!valid1 || !valid2) { + fprintf(stderr, + "Error: One or both function names do not match any defined " + "benchmark functions.\n"); + fprintf(stderr, "Valid functions are: "); + const size_t num_valid = sizeof(valid_fxns) / sizeof(valid_fxns[0]); + for (size_t i = 0; i < num_valid; ++i) { + fprintf(stderr, "%s%s", valid_fxns[i], + (i == num_valid - 1) ? ".\n" : ", "); + } + return 1; + } + + hwy::g_fxn1 = fxn1; + hwy::g_fxn2 = fxn2; + + HWY_DYNAMIC_DISPATCH(hwy::RunAllBenchmarks)(); + return 0; +} +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/matvec/matvec-inl.h b/lib/highway/hwy/contrib/matvec/matvec-inl.h new file mode 100644 index 00000000000..594c1afb87d --- /dev/null +++ b/lib/highway/hwy/contrib/matvec/matvec-inl.h @@ -0,0 +1,452 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard (still compiled once per target) +#if defined(HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_ +#endif + +#include +#include + +#include "hwy/cache_control.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +TA AddScalar(TA a, TB b) { + return ConvertScalarTo(ConvertScalarTo(a) + + ConvertScalarTo(b)); +} + +template +HWY_NOINLINE void MatVecAddImpl(const T* HWY_RESTRICT mat, + const T* HWY_RESTRICT vec, + const T* HWY_RESTRICT add, T* HWY_RESTRICT out, + hwy::ThreadPool& pool) { + (void)add; + + // Process multiple rows at a time so that we write multiples of a cache line + // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little + // parallelization potential. + constexpr size_t kChunkSize2 = 64 / sizeof(T); + const uint64_t num_chunks = static_cast(kOuter / kChunkSize2); + + const ScalableTag d; + const size_t N = Lanes(d); + // Required for Stream loop, otherwise we might have partial vectors. + HWY_DASSERT(kChunkSize2 >= N); + pool.Run(0, num_chunks, + [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR { + // MSVC workaround: duplicate to ensure constexpr. + constexpr size_t kChunkSize = 64 / sizeof(T); + // Software write-combining to avoid cache pollution from out. 
+ // Although `out` may be used later, keeping it out of the cache + // now and avoiding RFOs is a consistent 5% overall win. + HWY_ALIGN T buf[kChunkSize]; + + // Only handle entire chunks here because the Stream is not masked. + // Remaining rows are handled after the pool.Run. + const size_t begin = static_cast(chunk * kChunkSize); + for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) { + auto sum0 = Zero(d); + auto sum1 = Zero(d); + // 4x unrolling barely helps SKX but likely helps Arm V2. + auto sum2 = Zero(d); + auto sum3 = Zero(d); + + const T* HWY_RESTRICT row = &mat[(begin + idx_row) * kInner]; + size_t i = 0; + // No clear win from prefetching from the next 1..3 rows. + // clflush &row[i] is slow, clflushopt less so but not helping. + HWY_UNROLL(1) + for (; i + 4 * N <= kInner; i += 4 * N) { + const auto a0 = LoadU(d, row + i + 0 * N); + const auto v0 = LoadU(d, vec + i + 0 * N); + sum0 = MulAdd(a0, v0, sum0); + + const auto a1 = LoadU(d, row + i + 1 * N); + const auto v1 = LoadU(d, vec + i + 1 * N); + sum1 = MulAdd(a1, v1, sum1); + + const auto a2 = LoadU(d, row + i + 2 * N); + const auto v2 = LoadU(d, vec + i + 2 * N); + sum2 = MulAdd(a2, v2, sum2); + + const auto a3 = LoadU(d, row + i + 3 * N); + const auto v3 = LoadU(d, vec + i + 3 * N); + sum3 = MulAdd(a3, v3, sum3); + } + // Last entire vectors + for (; i + N <= kInner; i += N) { + const auto a0 = LoadU(d, row + i); + const auto v0 = LoadU(d, vec + i); + sum0 = MulAdd(a0, v0, sum0); + } + const size_t remainder = kInner - i; + if (remainder != 0) { + const auto a0 = LoadN(d, row + i, remainder); + const auto v0 = LoadN(d, vec + i, remainder); + sum1 = MulAdd(a0, v0, sum1); + } + // Reduction tree: sum of all accumulators, then their lanes + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum1); + sum0 = Add(sum0, sum2); + buf[idx_row] = ReduceSum(d, sum0); + HWY_IF_CONSTEXPR(kAdd) { + buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]); + } + } // idx_row + HWY_UNROLL(4) // 1..4 iterations 
+ for (size_t i = 0; i != kChunkSize; i += N) { + Stream(Load(d, buf + i), d, out + begin + i); + } + }); + hwy::FlushStream(); + + // Handle remainder rows which are not a multiple of the chunk size. + for (size_t r = num_chunks * kChunkSize2; r < kOuter; ++r) { + auto sum0 = Zero(d); + + const T* HWY_RESTRICT row = &mat[r * kInner]; + size_t i = 0; + HWY_UNROLL(1) + for (; i + N <= kInner; i += N) { + const auto a0 = LoadU(d, row + i); + const auto v0 = LoadU(d, vec + i); + sum0 = MulAdd(a0, v0, sum0); + } + const size_t remainder = kInner - i; + if (remainder != 0) { + const auto a0 = LoadN(d, row + i, remainder); + const auto v0 = LoadN(d, vec + i, remainder); + sum0 = MulAdd(a0, v0, sum0); + } + out[r] = ReduceSum(d, sum0); + HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); } + } // r +} + +// Multiplies mat with vec, adds add and puts the result in out. +// +// mat is a (kOuter, kInner)-shaped array, where element [i,j] is located at +// index i * kInner + j. +// +// vec is a (kInner,)-shaped array. +// +// add is a (kOuter,)-shaped array. +// +// out is a (kOuter,)-shaped array that will be set to mat @ vec + add. +template +HWY_NOINLINE void MatVecAdd(const T* HWY_RESTRICT mat, + const T* HWY_RESTRICT vec, + const T* HWY_RESTRICT add, T* HWY_RESTRICT out, + hwy::ThreadPool& pool) { + MatVecAddImpl(mat, vec, add, out, pool); +} + +// Multiplies mat with vec and puts the result in out. +// +// mat is a (kOuter, kInner)-shaped array, where element [i,j] is located at +// index i * kInner + j. +// +// vec is a (kInner,)-shaped array. +// +// out is a (kOuter,)-shaped array that will be set to mat @ vec. +template +HWY_NOINLINE void MatVec(const T* HWY_RESTRICT mat, const T* HWY_RESTRICT vec, + T* HWY_RESTRICT out, hwy::ThreadPool& pool) { + MatVecAddImpl(mat, vec, /*add=*/nullptr, out, pool); +} + +// This target lacks too many ops required in our implementation, use +// HWY_EMU128 instead. 
+#if HWY_TARGET != HWY_SCALAR + +// Specialization for bf16 matrix, which halves memory bandwidth requirements. +template +HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat, + const float* HWY_RESTRICT vec, + const float* HWY_RESTRICT add, + float* HWY_RESTRICT out, + hwy::ThreadPool& pool) { + // Process multiple rows at a time so that we write multiples of a cache line + // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little + // parallelization potential. + constexpr size_t kChunkSize2 = 64 / sizeof(float); + const uint64_t num_chunks = static_cast(kOuter / kChunkSize2); + + const ScalableTag d; + const Repartition d16; + // In the remainder loop, we only process a single f32 vector, so load half + // vectors of bf16 to avoid overrun. + const Half d16h; + using V = Vec; + using V16 = Vec; + using V16H = Vec; + const size_t N = Lanes(d); + // Required for Stream loop, otherwise we might have partial vectors. + HWY_DASSERT(kChunkSize2 >= N); + pool.Run(0, num_chunks, + [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR { + // MSVC workaround: duplicate to ensure constexpr. + constexpr size_t kChunkSize = 64 / sizeof(float); + // Software write-combining to avoid cache pollution from out. + // Although `out` may be used later, keeping it out of the cache + // now and avoiding RFOs is a consistent 5% overall win. + HWY_ALIGN float buf[kChunkSize]; + + // Only handle entire chunks here because the Stream is not masked. + // Remaining rows are handled after the pool.Run. + const size_t begin = static_cast(chunk * kChunkSize); + for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) { + auto sum0 = Zero(d); + auto sum1 = Zero(d); + // 4x unrolling barely helps SKX but likely helps Arm V2. + auto sum2 = Zero(d); + auto sum3 = Zero(d); + + const hwy::bfloat16_t* HWY_RESTRICT row = + &mat[(begin + idx_row) * kInner]; + size_t i = 0; + // No clear win from prefetching from the next 1..3 rows. 
+ // clflush &row[i] is slow, clflushopt less so but not helping. + HWY_UNROLL(1) + for (; i + 4 * N <= kInner; i += 4 * N) { + const V16 b0 = LoadU(d16, row + i + 0 * N); + const V a0 = PromoteLowerTo(d, b0); + const V a1 = PromoteUpperTo(d, b0); + + const V16 b1 = LoadU(d16, row + i + 2 * N); + const V a2 = PromoteLowerTo(d, b1); + const V a3 = PromoteUpperTo(d, b1); + + const V v0 = LoadU(d, vec + i + 0 * N); + sum0 = MulAdd(a0, v0, sum0); + + const V v1 = LoadU(d, vec + i + 1 * N); + sum1 = MulAdd(a1, v1, sum1); + + const V v2 = LoadU(d, vec + i + 2 * N); + sum2 = MulAdd(a2, v2, sum2); + + const V v3 = LoadU(d, vec + i + 3 * N); + sum3 = MulAdd(a3, v3, sum3); + } + // Last entire vectors + for (; i + N <= kInner; i += N) { + const V16H b0 = LoadU(d16h, row + i); + const V a0 = PromoteTo(d, b0); + const V v0 = LoadU(d, vec + i); + sum0 = MulAdd(a0, v0, sum0); + } + const size_t remainder = kInner - i; + if (remainder != 0) { + const V16H b0 = LoadN(d16h, row + i, remainder); + const V a0 = PromoteTo(d, b0); + const V v0 = LoadN(d, vec + i, remainder); + sum1 = MulAdd(a0, v0, sum1); + } + // Reduction tree: sum of all accumulators, then their lanes + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum1); + sum0 = Add(sum0, sum2); + buf[idx_row] = ReduceSum(d, sum0); + HWY_IF_CONSTEXPR(kAdd) { + buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]); + } + } // idx_row + HWY_UNROLL(4) // 1..4 iterations + for (size_t i = 0; i != kChunkSize; i += N) { + Stream(Load(d, buf + i), d, out + begin + i); + } + }); + hwy::FlushStream(); + + // Handle remainder rows which are not a multiple of the chunk size. 
+ for (size_t r = num_chunks * kChunkSize2; r < kOuter; ++r) { + auto sum0 = Zero(d); + + const hwy::bfloat16_t* HWY_RESTRICT row = &mat[r * kInner]; + size_t i = 0; + HWY_UNROLL(1) + for (; i + N <= kInner; i += N) { + const V16H b0 = LoadU(d16h, row + i); + const V a0 = PromoteTo(d, b0); + const V v0 = LoadU(d, vec + i); + sum0 = MulAdd(a0, v0, sum0); + } + const size_t remainder = kInner - i; + if (remainder != 0) { + const V16H b0 = LoadN(d16h, row + i, remainder); + const V a0 = PromoteTo(d, b0); + const V v0 = LoadN(d, vec + i, remainder); + sum0 = MulAdd(a0, v0, sum0); + } + out[r] = ReduceSum(d, sum0); + HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); } + } // r +} + +template +HWY_NOINLINE void MatVecAdd(const hwy::bfloat16_t* HWY_RESTRICT mat, + const float* HWY_RESTRICT vec, + const float* HWY_RESTRICT add, + float* HWY_RESTRICT out, hwy::ThreadPool& pool) { + MatVecAddImpl(mat, vec, add, out, pool); +} + +template +HWY_NOINLINE void MatVec(const hwy::bfloat16_t* HWY_RESTRICT mat, + const float* HWY_RESTRICT vec, float* HWY_RESTRICT out, + hwy::ThreadPool& pool) { + MatVecAddImpl(mat, vec, /*add=*/nullptr, out, pool); +} + +// Both mat and vec are bf16. +template +HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat, + const hwy::bfloat16_t* HWY_RESTRICT vec, + const hwy::bfloat16_t* HWY_RESTRICT add, + float* HWY_RESTRICT out, + hwy::ThreadPool& pool) { + // Process multiple rows at a time so that we write multiples of a cache line + // to avoid false sharing (>= 64). 128 is better than 256. 512 has too little + // parallelization potential. + constexpr size_t kChunkSize2 = 64 / sizeof(bfloat16_t); + const uint64_t num_chunks = static_cast(kOuter / kChunkSize2); + + const ScalableTag df; + const Repartition d16; + using V16 = Vec; + const size_t N = Lanes(d16); + // Required for Stream loop, otherwise we might have partial vectors. 
+ HWY_DASSERT(kChunkSize2 >= N); + pool.Run(0, num_chunks, + [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR { + // MSVC workaround: duplicate to ensure constexpr. + constexpr size_t kChunkSize = 64 / sizeof(bfloat16_t); + // Software write-combining to avoid cache pollution from out. + // Although `out` may be used later, keeping it out of the cache + // now and avoiding RFOs is a consistent 5% overall win. + HWY_ALIGN float buf[kChunkSize]; + + // Only handle entire chunks here because the Stream is not masked. + // Remaining rows are handled after the pool.Run. + const size_t begin = static_cast(chunk * kChunkSize); + for (size_t idx_row = 0; idx_row < kChunkSize; ++idx_row) { + auto sum0 = Zero(df); + auto sum1 = Zero(df); + auto sum2 = Zero(df); + auto sum3 = Zero(df); + + const hwy::bfloat16_t* HWY_RESTRICT row = + &mat[(begin + idx_row) * kInner]; + size_t i = 0; + // No clear win from prefetching from the next 1..3 rows. + // clflush &row[i] is slow, clflushopt less so but not helping. 
+ HWY_UNROLL(1) + for (; i + 2 * N <= kInner; i += 2 * N) { + const V16 b0 = LoadU(d16, row + i + 0 * N); + const V16 b1 = LoadU(d16, row + i + 1 * N); + const V16 v0 = LoadU(d16, vec + i + 0 * N); + const V16 v1 = LoadU(d16, vec + i + 1 * N); + sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1); + sum2 = ReorderWidenMulAccumulate(df, b1, v1, sum2, sum3); + } + // Last entire vector + for (; i + N <= kInner; i += N) { + const V16 b0 = LoadU(d16, row + i); + const V16 v0 = LoadU(d16, vec + i); + sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1); + } + const size_t remainder = kInner - i; + if (remainder != 0) { + const V16 b0 = LoadN(d16, row + i, remainder); + const V16 v0 = LoadN(d16, vec + i, remainder); + sum2 = ReorderWidenMulAccumulate(df, b0, v0, sum2, sum3); + } + // Reduction tree: sum of all accumulators, then their lanes + sum0 = Add(sum0, sum1); + sum2 = Add(sum2, sum3); + sum0 = Add(sum0, sum2); + buf[idx_row] = ReduceSum(df, sum0); + HWY_IF_CONSTEXPR(kAdd) { + buf[idx_row] = AddScalar(buf[idx_row], add[begin + idx_row]); + } + } // idx_row + HWY_UNROLL(4) // 1..4 iterations + for (size_t i = 0; i != kChunkSize; i += N / 2) { + Stream(Load(df, buf + i), df, out + begin + i); + } + }); + hwy::FlushStream(); + + // Handle remainder rows which are not a multiple of the chunk size. 
+ for (size_t r = num_chunks * kChunkSize2; r < kOuter; ++r) { + auto sum0 = Zero(df); + auto sum1 = Zero(df); + + const hwy::bfloat16_t* HWY_RESTRICT row = &mat[r * kInner]; + size_t i = 0; + HWY_UNROLL(1) + for (; i + N <= kInner; i += N) { + const V16 b0 = LoadU(d16, row + i); + const V16 v0 = LoadU(d16, vec + i); + sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1); + } + const size_t remainder = kInner - i; + if (remainder != 0) { + const V16 b0 = LoadN(d16, row + i, remainder); + const V16 v0 = LoadN(d16, vec + i, remainder); + sum0 = ReorderWidenMulAccumulate(df, b0, v0, sum0, sum1); + } + out[r] = ReduceSum(df, Add(sum0, sum1)); + HWY_IF_CONSTEXPR(kAdd) { out[r] = AddScalar(out[r], add[r]); } + } // r +} + +template +HWY_NOINLINE void MatVecAdd(const hwy::bfloat16_t* HWY_RESTRICT mat, + const hwy::bfloat16_t* HWY_RESTRICT vec, + const hwy::bfloat16_t* HWY_RESTRICT add, + float* HWY_RESTRICT out, hwy::ThreadPool& pool) { + MatVecAddImpl(mat, vec, add, out, pool); +} + +template +HWY_NOINLINE void MatVec(const hwy::bfloat16_t* HWY_RESTRICT mat, + const hwy::bfloat16_t* HWY_RESTRICT vec, + float* HWY_RESTRICT out, hwy::ThreadPool& pool) { + MatVecAddImpl(mat, vec, /*add=*/nullptr, out, pool); +} + +#endif // HWY_TARGET != HWY_SCALAR + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_MATVEC_MATVEC_INL_H_ diff --git a/lib/highway/hwy/contrib/random/random-inl.h b/lib/highway/hwy/contrib/random/random-inl.h new file mode 100644 index 00000000000..d14598387e1 --- /dev/null +++ b/lib/highway/hwy/contrib/random/random-inl.h @@ -0,0 +1,383 @@ +/* + * Original implementation written in 2019 + * by David Blackman and Sebastiano Vigna (vigna@acm.org) + * Available at https://prng.di.unimi.it/ with creative commons license: + * To the extent possible under law, the author has dedicated all copyright + * and related and neighboring rights to this 
software to the public domain + * worldwide. This software is distributed without any warranty. + * See <http://creativecommons.org/publicdomain/zero/1.0/>. + * + * This implementation is a Vector port of the original implementation + * written by Marco Barbone (m.barbone19@imperial.ac.uk). + * I take no credit for the original implementation. + * The code is provided as is and the original license applies. + */ + +#if defined(HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_ +#undef HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_ +#else +#define HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_ +#endif + +#include <array> +#include <cstdint> +#include <limits> + +#include "hwy/aligned_allocator.h" +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); // required if not using HWY_ATTR + +namespace hwy { + +namespace HWY_NAMESPACE { // required: unique per target +namespace internal { + +#if HWY_HAVE_FLOAT64 +// Multiplier applied to a 64-bit output after it has been shifted right by 11 +// bits (see Uniform()). C++ < 17 does not support hexfloat, hence the decimal +// fallback. __cpp_hex_float has the value 201603L when hexfloat literals are +// available, so the test must be >=; a strict > never selects the first branch. +#if __cpp_hex_float >= 201603L +constexpr double kMulConst = 0x1.0p-53; +#else +constexpr double kMulConst = + 0.00000000000000011102230246251565404236316680908203125; +#endif // __cpp_hex_float + +#endif // HWY_HAVE_FLOAT64 + +// Constants whose set bits select the states XOR-ed together by Xoshiro::Jump() +// and Xoshiro::LongJump() respectively. +constexpr std::uint64_t kJump[] = {0x180ec6d33cfd0aba, 0xd5a61266f0c9392c, + 0xa9582618e03fc9aa, 0x39abdc4529b1661c}; + +constexpr std::uint64_t kLongJump[] = {0x76e15d3efefdcbbf, 0xc5004e441c522fb3, + 0x77710069854ee241, 0x39109bb02acbe635}; + +// Small 64-bit generator; Xoshiro uses it to expand one seed into its four +// state words. +class SplitMix64 { + public: + constexpr explicit SplitMix64(const std::uint64_t state) noexcept + : state_(state) {} + + HWY_CXX14_CONSTEXPR std::uint64_t operator()() { + std::uint64_t z = (state_ += 0x9e3779b97f4a7c15); + z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9; + z = (z ^ (z >> 27)) * 0x94d049bb133111eb; + return z ^ (z >> 31); + } + + private: + std::uint64_t state_; +}; + +class Xoshiro { + public: + HWY_CXX14_CONSTEXPR explicit Xoshiro(const std::uint64_t seed) noexcept + : state_{} { + SplitMix64 splitMix64{seed}; + for (auto &element : state_) { + element = splitMix64(); + } 
+ } + + HWY_CXX14_CONSTEXPR explicit Xoshiro(const std::uint64_t seed, + const std::uint64_t thread_id) noexcept + : Xoshiro(seed) { + for (auto i = UINT64_C(0); i < thread_id; ++i) { + Jump(); + } + } + + HWY_CXX14_CONSTEXPR std::uint64_t operator()() noexcept { return Next(); } + +#if HWY_HAVE_FLOAT64 + HWY_CXX14_CONSTEXPR double Uniform() noexcept { + return static_cast(Next() >> 11) * kMulConst; + } +#endif + + HWY_CXX14_CONSTEXPR std::array GetState() const { + return {state_[0], state_[1], state_[2], state_[3]}; + } + + HWY_CXX17_CONSTEXPR void SetState( + std::array state) noexcept { + state_[0] = state[0]; + state_[1] = state[1]; + state_[2] = state[2]; + state_[3] = state[3]; + } + + static constexpr std::uint64_t StateSize() noexcept { return 4; } + + /* This is the jump function for the generator. It is equivalent to 2^128 + * calls to next(); it can be used to generate 2^128 non-overlapping + * subsequences for parallel computations. */ + HWY_CXX14_CONSTEXPR void Jump() noexcept { Jump(kJump); } + + /* This is the long-jump function for the generator. It is equivalent to 2^192 + * calls to next(); it can be used to generate 2^64 starting points, from each + * of which jump() will generate 2^64 non-overlapping subsequences for + * parallel distributed computations. 
*/ + HWY_CXX14_CONSTEXPR void LongJump() noexcept { Jump(kLongJump); } + + private: + std::uint64_t state_[4]; + + static constexpr std::uint64_t Rotl(const std::uint64_t x, int k) noexcept { + return (x << k) | (x >> (64 - k)); + } + + HWY_CXX14_CONSTEXPR std::uint64_t Next() noexcept { + const std::uint64_t result = Rotl(state_[0] + state_[3], 23) + state_[0]; + const std::uint64_t t = state_[1] << 17; + + state_[2] ^= state_[0]; + state_[3] ^= state_[1]; + state_[1] ^= state_[2]; + state_[0] ^= state_[3]; + + state_[2] ^= t; + + state_[3] = Rotl(state_[3], 45); + + return result; + } + + HWY_CXX14_CONSTEXPR void Jump(const std::uint64_t (&jumpArray)[4]) noexcept { + std::uint64_t s0 = 0; + std::uint64_t s1 = 0; + std::uint64_t s2 = 0; + std::uint64_t s3 = 0; + + for (const std::uint64_t i : jumpArray) + for (std::uint_fast8_t b = 0; b < 64; b++) { + if (i & std::uint64_t{1UL} << b) { + s0 ^= state_[0]; + s1 ^= state_[1]; + s2 ^= state_[2]; + s3 ^= state_[3]; + } + Next(); + } + + state_[0] = s0; + state_[1] = s1; + state_[2] = s2; + state_[3] = s3; + } +}; + +} // namespace internal + +class VectorXoshiro { + private: + using VU64 = Vec>; + using StateType = AlignedNDArray; +#if HWY_HAVE_FLOAT64 + using VF64 = Vec>; +#endif + + public: + explicit VectorXoshiro(const std::uint64_t seed, + const std::uint64_t threadNumber = 0) + : state_{{internal::Xoshiro::StateSize(), + Lanes(ScalableTag{})}}, + streams{state_.shape().back()} { + internal::Xoshiro xoshiro{seed}; + + for (std::uint64_t i = 0; i < threadNumber; ++i) { + xoshiro.LongJump(); + } + + for (size_t i = 0UL; i < streams; ++i) { + const auto state = xoshiro.GetState(); + for (size_t j = 0UL; j < internal::Xoshiro::StateSize(); ++j) { + state_[{j}][i] = state[j]; + } + xoshiro.Jump(); + } + } + + HWY_INLINE VU64 operator()() noexcept { return Next(); } + + AlignedVector operator()(const std::size_t n) { + AlignedVector result(n); + const ScalableTag tag{}; + auto s0 = Load(tag, state_[{0}].data()); + auto 
s1 = Load(tag, state_[{1}].data()); + auto s2 = Load(tag, state_[{2}].data()); + auto s3 = Load(tag, state_[{3}].data()); + for (std::uint64_t i = 0; i < n; i += Lanes(tag)) { + const auto next = Update(s0, s1, s2, s3); + Store(next, tag, result.data() + i); + } + Store(s0, tag, state_[{0}].data()); + Store(s1, tag, state_[{1}].data()); + Store(s2, tag, state_[{2}].data()); + Store(s3, tag, state_[{3}].data()); + return result; + } + + template + std::array operator()() noexcept { + alignas(HWY_ALIGNMENT) std::array result; + const ScalableTag tag{}; + auto s0 = Load(tag, state_[{0}].data()); + auto s1 = Load(tag, state_[{1}].data()); + auto s2 = Load(tag, state_[{2}].data()); + auto s3 = Load(tag, state_[{3}].data()); + for (std::uint64_t i = 0; i < N; i += Lanes(tag)) { + const auto next = Update(s0, s1, s2, s3); + Store(next, tag, result.data() + i); + } + Store(s0, tag, state_[{0}].data()); + Store(s1, tag, state_[{1}].data()); + Store(s2, tag, state_[{2}].data()); + Store(s3, tag, state_[{3}].data()); + return result; + } + + std::uint64_t StateSize() const noexcept { + return streams * internal::Xoshiro::StateSize(); + } + + const StateType &GetState() const { return state_; } + +#if HWY_HAVE_FLOAT64 + + HWY_INLINE VF64 Uniform() noexcept { + const ScalableTag real_tag{}; + const auto MUL_VALUE = Set(real_tag, internal::kMulConst); + const auto bits = ShiftRight<11>(Next()); + const auto real = ConvertTo(real_tag, bits); + return Mul(real, MUL_VALUE); + } + + AlignedVector Uniform(const std::size_t n) { + AlignedVector result(n); + const ScalableTag tag{}; + const ScalableTag real_tag{}; + const auto MUL_VALUE = Set(real_tag, internal::kMulConst); + + auto s0 = Load(tag, state_[{0}].data()); + auto s1 = Load(tag, state_[{1}].data()); + auto s2 = Load(tag, state_[{2}].data()); + auto s3 = Load(tag, state_[{3}].data()); + + for (std::uint64_t i = 0; i < n; i += Lanes(real_tag)) { + const auto next = Update(s0, s1, s2, s3); + const auto bits = 
ShiftRight<11>(next); + const auto real = ConvertTo(real_tag, bits); + const auto uniform = Mul(real, MUL_VALUE); + Store(uniform, real_tag, result.data() + i); + } + + Store(s0, tag, state_[{0}].data()); + Store(s1, tag, state_[{1}].data()); + Store(s2, tag, state_[{2}].data()); + Store(s3, tag, state_[{3}].data()); + return result; + } + + template + std::array Uniform() noexcept { + alignas(HWY_ALIGNMENT) std::array result; + const ScalableTag tag{}; + const ScalableTag real_tag{}; + const auto MUL_VALUE = Set(real_tag, internal::kMulConst); + + auto s0 = Load(tag, state_[{0}].data()); + auto s1 = Load(tag, state_[{1}].data()); + auto s2 = Load(tag, state_[{2}].data()); + auto s3 = Load(tag, state_[{3}].data()); + + for (std::uint64_t i = 0; i < N; i += Lanes(real_tag)) { + const auto next = Update(s0, s1, s2, s3); + const auto bits = ShiftRight<11>(next); + const auto real = ConvertTo(real_tag, bits); + const auto uniform = Mul(real, MUL_VALUE); + Store(uniform, real_tag, result.data() + i); + } + + Store(s0, tag, state_[{0}].data()); + Store(s1, tag, state_[{1}].data()); + Store(s2, tag, state_[{2}].data()); + Store(s3, tag, state_[{3}].data()); + return result; + } + +#endif + + private: + StateType state_; + const std::uint64_t streams; + + HWY_INLINE static VU64 Update(VU64 &s0, VU64 &s1, VU64 &s2, + VU64 &s3) noexcept { + const auto result = Add(RotateRight<41>(Add(s0, s3)), s0); + const auto t = ShiftLeft<17>(s1); + s2 = Xor(s2, s0); + s3 = Xor(s3, s1); + s1 = Xor(s1, s2); + s0 = Xor(s0, s3); + s2 = Xor(s2, t); + s3 = RotateRight<19>(s3); + return result; + } + + HWY_INLINE VU64 Next() noexcept { + const ScalableTag tag{}; + auto s0 = Load(tag, state_[{0}].data()); + auto s1 = Load(tag, state_[{1}].data()); + auto s2 = Load(tag, state_[{2}].data()); + auto s3 = Load(tag, state_[{3}].data()); + auto result = Update(s0, s1, s2, s3); + Store(s0, tag, state_[{0}].data()); + Store(s1, tag, state_[{1}].data()); + Store(s2, tag, state_[{2}].data()); + Store(s3, 
tag, state_[{3}].data()); + return result; + } +}; + +// Adapter that serves one std::uint64_t per call from a block of `size` values +// produced by VectorXoshiro; the block is refilled once it has been used up. +// Exposes result_type, min() and max() alongside operator(). +template <std::uint64_t size = 1024> +class CachedXoshiro { + public: + using result_type = std::uint64_t; + + static constexpr result_type(min)() { + return (std::numeric_limits<result_type>::min)(); + } + + static constexpr result_type(max)() { + return (std::numeric_limits<result_type>::max)(); + } + + explicit CachedXoshiro(const result_type seed, + const result_type threadNumber = 0) + : generator_{seed, threadNumber}, + cache_{generator_.operator()<size>()}, + index_{0} {} + + result_type operator()() noexcept { + if (HWY_UNLIKELY(index_ == size)) { + // The call already yields a prvalue, so no std::move is needed here. + cache_ = generator_.operator()<size>(); + index_ = 0; + } + return cache_[index_++]; + } + + private: + VectorXoshiro generator_; + alignas(HWY_ALIGNMENT) std::array<result_type, size> cache_; + std::size_t index_; + + static_assert((size & (size - 1)) == 0 && size != 0, + "only power of 2 are supported"); +}; + +} // namespace HWY_NAMESPACE +} // namespace hwy + +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_ diff --git a/lib/highway/hwy/contrib/sort/BUILD b/lib/highway/hwy/contrib/sort/BUILD new file mode 100644 index 00000000000..a146efbe3a2 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/BUILD @@ -0,0 +1,310 @@ +load("@rules_cc//cc:cc_binary.bzl", "cc_binary") +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +# Unused on Bazel builds, where this is not defined/known; Copybara replaces +# usages with an empty list. +COMPAT = [ + "//buildenv/target:non_prod", # includes mobile/vendor. 
+] + +cc_library( + name = "intel", + # hdrs = select({ + # "//third_party/bazel_platforms/cpu:x86_64": [ + # "avx512-16bit-common.h", + # "avx512-16bit-qsort.hpp", + # "avx512-32bit-qsort.hpp", + # "avx512-64bit-common.h", + # "avx512-64bit-qsort.hpp", + # "avx512-common-qsort.h", + # ], + # "//conditions:default": [], + # }), + compatible_with = [], +) + +cc_library( + name = "vxsort", + srcs = [ + # "vxsort/isa_detection.cpp", + # "vxsort/isa_detection_msvc.cpp", + # "vxsort/isa_detection_sane.cpp", + # "vxsort/machine_traits.avx2.cpp", + # "vxsort/smallsort/avx2_load_mask_tables.cpp", + # "vxsort/smallsort/bitonic_sort.AVX2.double.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX2.float.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX2.int32_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX2.int64_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX2.uint32_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX2.uint64_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX512.double.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX512.float.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX512.int32_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX512.int64_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX512.uint32_t.generated.cpp", + # "vxsort/smallsort/bitonic_sort.AVX512.uint64_t.generated.cpp", + # "vxsort/vxsort_stats.cpp", + ], + hdrs = [ + # "vxsort/alignment.h", + # "vxsort/defs.h", + # "vxsort/isa_detection.h", + # "vxsort/machine_traits.avx2.h", + # "vxsort/machine_traits.avx512.h", + # "vxsort/machine_traits.h", + # "vxsort/packer.h", + # "vxsort/smallsort/bitonic_sort.AVX2.double.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX2.float.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX2.int32_t.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX2.int64_t.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX2.uint32_t.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX2.uint64_t.generated.h", + 
# "vxsort/smallsort/bitonic_sort.AVX512.double.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX512.float.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX512.int32_t.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX512.int64_t.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX512.uint32_t.generated.h", + # "vxsort/smallsort/bitonic_sort.AVX512.uint64_t.generated.h", + # "vxsort/smallsort/bitonic_sort.h", + # "vxsort/vxsort.h", + # "vxsort/vxsort_stats.h", + ], + compatible_with = [], + textual_hdrs = [ + # "vxsort/vxsort_targets_disable.h", + # "vxsort/vxsort_targets_enable_avx2.h", + # "vxsort/vxsort_targets_enable_avx512.h", + ], +) + +VQSORT_SRCS = [ + # Split into separate files to reduce MSVC build time. + "vqsort_128a.cc", + "vqsort_128d.cc", + "vqsort_f16a.cc", + "vqsort_f16d.cc", + "vqsort_f32a.cc", + "vqsort_f32d.cc", + "vqsort_f64a.cc", + "vqsort_f64d.cc", + "vqsort_i16a.cc", + "vqsort_i16d.cc", + "vqsort_i32a.cc", + "vqsort_i32d.cc", + "vqsort_i64a.cc", + "vqsort_i64d.cc", + # vqsort_kv64a.cc is in :vqsort_k32v32 and vqsort.cc is in :vqsort_shared. + "vqsort_kv128a.cc", + "vqsort_kv128d.cc", + "vqsort_u16a.cc", + "vqsort_u16d.cc", + "vqsort_u32a.cc", + "vqsort_u32d.cc", + "vqsort_u64a.cc", + "vqsort_u64d.cc", +] + +VQSORT_TEXTUAL_HDRS = [ + "shared-inl.h", + "sorting_networks-inl.h", + "traits-inl.h", + "traits128-inl.h", + "vqsort-inl.h", + # Placeholder for internal instrumentation. Do not remove. +] + +# both :vqsort_k32v32 and :vqsort depend on this. 
+cc_library( + name = "vqsort_shared", + srcs = [ + "vqsort.cc", + ], + hdrs = [ + "order.h", # part of public interface, included by vqsort.h + "vqsort.h", # public interface + ], + compatible_with = [], + local_defines = ["hwy_contrib_EXPORTS"], + textual_hdrs = VQSORT_TEXTUAL_HDRS, + deps = [ + "//:algo", + "//:hwy", + ], +) + +cc_library( + name = "vqsort_k32v32", + srcs = [ + "vqsort_kv64a.cc", + "vqsort_kv64d.cc", + ], + hdrs = [ + "order.h", # part of public interface, included by vqsort.h + "vqsort.h", # public interface + ], + compatible_with = [], + local_defines = ["hwy_contrib_EXPORTS"], + textual_hdrs = VQSORT_TEXTUAL_HDRS, + deps = [ + ":vqsort_shared", + "//:algo", + "//:hwy", + ], +) + +cc_library( + name = "vqsort", + srcs = VQSORT_SRCS, + hdrs = [ + "order.h", # part of public interface, included by vqsort.h + "vqsort.h", # public interface + ], + compatible_with = [], + local_defines = ["hwy_contrib_EXPORTS"], + textual_hdrs = VQSORT_TEXTUAL_HDRS, + deps = [ + ":intel", # required if HAVE_INTEL + ":vqsort_k32v32", + ":vqsort_shared", + ":vxsort", # required if HAVE_VXSORT + "//:algo", + "//:hwy", + "//:nanobenchmark", + ], +) + +# ----------------------------------------------------------------------------- +# Internal-only targets + +# Same as vqsort, but add HWY_COMPILE_ALL_ATTAINABLE to ensure we cover all +# targets. Do not enable this in the main vqsort because it increases +# compile times. +cc_library( + name = "vqsort_for_test", + srcs = VQSORT_SRCS, + hdrs = [ + "order.h", # part of public interface, included by vqsort.h + "vqsort.h", # public interface + ], + compatible_with = [], + local_defines = [ + "hwy_contrib_EXPORTS", + # Build for all targets because sort_test will dynamic-dispatch to all. 
+ "HWY_COMPILE_ALL_ATTAINABLE", + ], + textual_hdrs = VQSORT_TEXTUAL_HDRS, + deps = [ + "//:algo", + "//:hwy", + "//:nanobenchmark", + ], +) + +cc_library( + name = "helpers", + testonly = 1, + textual_hdrs = [ + "algo-inl.h", + "result-inl.h", + ], + deps = [ + ":vqsort", + "//:nanobenchmark", + # Required for HAVE_PDQSORT, but that is unused and this is + # unavailable to Bazel builds, hence commented out. + # "//third_party/boost/allowed", + # Avoid ips4o and thus TBB to work around hwloc build failure. + ], +) + +cc_binary( + name = "print_network", + testonly = 1, + srcs = ["print_network.cc"], + deps = [ + ":helpers", + ":vqsort", + "//:hwy", + ], +) + +TEST_MAIN = select({ + "//:compiler_msvc": [], + "//conditions:default": ["@com_google_googletest//:gtest_main"], +}) + +cc_test( + name = "sort_unit_test", + size = "small", + srcs = ["sort_unit_test.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + # for test_suite. + tags = ["hwy_ops_test"], + deps = [ + ":helpers", + ":vqsort_for_test", + "//:hwy", + "//:hwy_test_util", + ] + TEST_MAIN, +) + +cc_test( + name = "sort_test", + size = "medium", + timeout = "long", + srcs = ["sort_test.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + # for test_suite. + tags = ["hwy_ops_test"], + deps = [ + ":helpers", + ":vqsort_for_test", + "//:hwy", + "//:hwy_test_util", + "//:thread_pool", + "//:topology", + ] + TEST_MAIN, +) + +cc_test( + name = "bench_sort", + size = "medium", + srcs = ["bench_sort.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + # for test_suite. 
+ tags = ["hwy_ops_test"], + deps = [ + ":helpers", + ":vqsort", + "//:hwy", + "//:hwy_test_util", + "//:nanobenchmark", + "//:thread_pool", + ] + TEST_MAIN, +) + +cc_binary( + name = "bench_parallel", + testonly = 1, + srcs = ["bench_parallel.cc"], + # Do not enable fully_static_link (pthread crash on bazel) + local_defines = ["HWY_IS_TEST"], + deps = [ + ":helpers", + ":vqsort", + "//:hwy", + "//:hwy_test_util", + "//:nanobenchmark", + ] + TEST_MAIN, +) diff --git a/lib/highway/hwy/contrib/sort/README.md b/lib/highway/hwy/contrib/sort/README.md new file mode 100644 index 00000000000..800d2cfcdab --- /dev/null +++ b/lib/highway/hwy/contrib/sort/README.md @@ -0,0 +1,361 @@ +# Vectorized and performance-portable Quicksort + +## Introduction + +As of 2022-06-07 this sorts large arrays of built-in types about ten times as +fast as LLVM's `std::sort`. Note that other algorithms such as pdqsort can be +about twice as fast as LLVM's std::sort as of 2023-06. + +See also our +[blog post](https://opensource.googleblog.com/2022/06/Vectorized%20and%20performance%20portable%20Quicksort.html) +and [paper](https://arxiv.org/abs/2205.05982). + +## Instructions + +Here are instructions for reproducing our results with cross-platform CMake, +Linux, or AWS (SVE, NEON). + +### CMake, any platform + +Please first ensure that Clang (tested with 13.0.1 and 15.0.6) is installed, and +if it is not the default compiler, point the CC and CXX environment variables to +it, e.g. + +``` +export CC=clang-15 +export CXX=clang++-15 +``` + +Then run the usual CMake workflow, also documented in the Highway README, e.g.: + +``` +mkdir -p build && cd build && cmake .. && make -j +taskset -c 2 tests/bench_sort +``` + +The optional `taskset -c 2` part reduces the variability of measurements by +preventing the OS from migrating the benchmark between cores. + +### Linux + +Please first ensure golang, and Clang (tested with 13.0.1) are installed via +your system's package manager. 
+ +``` +go install github.com/bazelbuild/bazelisk@latest +git clone https://github.com/google/highway +cd highway +CC=clang CXX=clang++ ~/go/bin/bazelisk build -c opt hwy/contrib/sort:all +bazel-bin/hwy/contrib/sort/sort_test +bazel-bin/hwy/contrib/sort/bench_sort +``` + +### AWS Graviton3 + +Instance config: amazon linux 5.10 arm64, c7g.8xlarge (largest allowed config is +32 vCPU). Initial launch will fail. Wait a few minutes for an email saying the +config is verified, then re-launch. See IPv4 hostname in list of instances. + +`ssh -i /path/key.pem ec2-user@hostname` + +Note that the AWS CMake package is too old for llvm, so we build it first: +``` +wget https://cmake.org/files/v3.23/cmake-3.23.2.tar.gz +tar -xvzf cmake-3.23.2.tar.gz && cd cmake-3.23.2/ +./bootstrap -- -DCMAKE_USE_OPENSSL=OFF +make -j8 && sudo make install +cd .. +``` + +AWS clang is at version 11.1, which generates unnecessary `AND` instructions +which slow down the sort by 1.15x. We tested with clang trunk as of June 13 +(which reports Git hash 8f6512fea000c3a0d394864bb94e524bee375069). To build: + +``` +git clone --depth 1 https://github.com/llvm/llvm-project.git +cd llvm-project +mkdir -p build && cd build +/usr/local/bin/cmake ../llvm -DLLVM_ENABLE_PROJECTS="clang" -DLLVM_ENABLE_RUNTIMES="libcxx;libcxxabi" -DCMAKE_BUILD_TYPE=Release +make -j32 && sudo make install +``` + +``` +sudo yum install go +go install github.com/bazelbuild/bazelisk@latest +git clone https://github.com/google/highway +cd highway +CC=/usr/local/bin/clang CXX=/usr/local/bin/clang++ ~/go/bin/bazelisk build -c opt --copt=-march=armv8.2-a+sve hwy/contrib/sort:all +bazel-bin/hwy/contrib/sort/sort_test +bazel-bin/hwy/contrib/sort/bench_sort +``` + +The above command line enables SVE, which is currently only available on +Graviton 3. You can also test NEON on the same processor, or other Arm CPUs, by +changing the `-march=` option to `--copt=-march=armv8.2-a+crypto`. 
Note that +such flags will be unnecessary once Clang supports `#pragma target` for NEON and +SVE intrinsics, as it does for x86. + +## Results + +`bench_sort` outputs the instruction set (AVX3 refers to AVX-512), the sort +algorithm (std for `std::sort`, vq for our vqsort), the type of keys being +sorted (f32 is float), the distribution of keys (uniform32 for uniform random +with range 0-2^32), the number of keys, then the throughput of sorted keys (i.e. +number of key bytes output per second). + +Example excerpt from Xeon 6154 (Skylake-X) CPU clocked at 3 GHz: + +``` +[ RUN ] BenchSortGroup/BenchSort.BenchAllSort/AVX3 + AVX3: std: f32: uniform32: 1.00E+06 54 MB/s ( 1 threads) + AVX3: vq: f32: uniform32: 1.00E+06 1143 MB/s ( 1 threads) +``` + +## Additional results + +Thanks to Lukas Bergdoll, who did a thorough [performance analysis](https://github.com/Voultapher/sort-research-rs/blob/main/writeup/intel_avx512/text.md) +on various sort implementations. This helped us identify a performance bug, +caused by obtaining entropy from the OS on each call. This was fixed in #1334 +and we look forward to the updated results. + +### Optimizations for small arrays + +Our initial focus was on large arrays. Since the VQSort paper was published, +we have improved its performance for small arrays: + +- Previously, each call to VQSort obtained entropy from the OS. Unpredictable + seeding does help avoid worst-cases, and the cost is negligible when the + input size is at least 100K elements. However, the overhead is very costly + for arrays of just 100 or 1000, so we now obtain entropy only once per + thread and cache the seeds in TLS. This significantly improves the + performance on subsequent calls. Users can also explicitly initialize the + random generator. + +- We also improved the efficiency of our sorting network for inputs shorter + than half its size. Our approach avoids costly transposes by interpreting + inputs as a 2D matrix. 
Previously, we always used 16 rows, which means only + a single vector lane is active for up to 16 elements. We have added 8x2 and + 8x4 networks which use more lanes when available, and also 4x1 and 8x1 + networks for very small inputs. + +- Previously we also loaded (overlapping) full vectors, with the offsets + determined by the number of columns. Now we use the minimum vector size + sufficient for the number of columns, which enables higher IPC on Skylake + and reduces the cost of unaligned loads. + + Unfortunately this decreases code reuse; VQSort now consists of about 1500 + instructions (https://gcc.godbolt.org/z/ojYKfjPe6). The size of sorting + networks has nearly doubled to 10.8 KiB, 70% of the total. Although large, + this still fits comfortably within 32 KiB instruction caches, and possibly + even in micro-op caches (DSB, 1500-2300 micro-ops), especially given that + not all instructions are guaranteed to execute. + +### Study of AVX-512 downclocking +We study whether AVX-512 downclocking affects performance. Using the GHz +reported by perf, we find an upper bound on the effects of downclocking, and +observe that its effect is negligible when compared to scalar code. + +This issue has somehow attracted far more attention than seems warranted. An +attempt by Daniel Lemire to measure the +[worst-case](https://lemire.me/blog/2018/08/15/the-dangers-of-avx-512-throttling-a-3-impact/) +only saw a **3% decrease**, and Intel CPUs since Icelake, as well as AMD Zen4, +are much less impacted by throttling, if at all. By contrast, "Silver" and +"Bronze" Intel Xeons have more severe throttling and would require a large(r) +speedup from AVX-512 to outweigh the downclocking. However, these CPUs are +marketed towards "entry compute, network and storage" and "small business and +storage server solutions", and are thus less suitable for the high-performance +workloads we consider. 
+ +Our test workstation runs Linux (6.1.20-2rodete1-amd64) and has the same Xeon +Gold 6154 CPU used in our paper because its Skylake microarchitecture is the +most (potentially) affected. The compiler is a Clang similar to the LLVM trunk. + +We added a new 'cold' benchmark that initializes random seeds, fills an array +with a constant except at one random index, calls VQSort, and then prints a +random element to ensure the computations are not elided. To run it, we build +bench_sort with `-DSORT_ONLY_COLD=1` and then invoke +`taskset -c 6 setarch -R x86_64 perf stat -r 15 -d bench_sort`. The taskset and +setarch serve to reduce variability by avoiding thread migration, and disabling +address space randomization. `-r 15` requests 15 runs so that perf can display +the variability of the measurements: < 1% for cycles, instructions, L1 dcache +loads; LLC miss variability is much higher (> 10%) presumably due to the +remaining background activity on this machine. + +For our measurements, we use the GHz value reported by `perf`. This does not +include time spent in the kernel, and is thus noisy for short runtimes. Note +that running `perf` under `sudo` is not an option because it results in +"Workload failed: Cannot allocate memory". We see results between 2.6 - 2.9 GHz +when running AVX-512 code. This is relative to 3.0 GHz nominal; we disabled +Turbo Boost via MSR and ran `sudo cpupower frequency-set --governor performance` +to prevent unnecessary frequency reductions. To the best of our knowledge, the +remaining gap is explained by time spent in the kernel (in particular handling +page faults) and downclocking. Thus an *upper-bound* for the latter is +(3 - 2.9)/3 to (3 - 2.6)/3, or **1.03 - 1.13x**. Such a frequency reduction +would already be negligible compared to the 2-4x increase in work per cycle from +512-bit SIMD relative to 256 or 128-bit SIMD, which is typically less or not at +all affected by downclocking. 
+ +To further tighten this bound, we compare AVX-512 code vs. non-AVX-512 code, in +the form of `std::sort`. Ensuring the remainder of the binary does not use +AVX-512 is nontrivial. Library functions such as `memset` are known to use +AVX-512, and they would not show up in a disassembly of our binary. Neither +would they raise exceptions if run on a CPU lacking AVX-512 support, because +software typically verifies CPU support before running AVX-512. As a first step, +we take care to avoid calls to such library functions in our test, which is more +feasible with a self-contained small binary. In particular, array +zero-initialization typically compiles to `memset` (verified with clang-16), so +we manually initialize the array to the return value of an `Unpredictable1` +function whose implementation is not visible to the compiler. This indeed +compiles to a scalar loop. To further increase confidence that the binary lacks +AVX-512 instructions before VQSort, we replace the initialization loop with +AVX-512 stores. This indeed raises the measured throughput from a fairly +consistent 9 GB/s to 9-15 GB/s, likely because some of the AVX-512 startup now +occurs outside of our timings. We examine this effect in the next section, but +for now we can conclude that because adding AVX-512 makes a difference, the +binary was otherwise not using it. Now we can revert to scalar initialization +and compare the GHz reported for VQSort vs. `std::sort`. Across three runs, the +ranges are 2.8-2.9 and 2.8-2.8 GHz. Thus we conclude: if there is any +downclocking for a single core running AVX-512 on this Skylake-X CPU, the effect +is **under the noise floor of our measurement**, and certainly far below any +speedup one can reasonably predict from 512-bit SIMD. We expect this result to +generalize to AMD Zen4 and any Gold/Platinum Intel Xeon. 
+ +### Study of AVX-512 startup overhead + +In the previous section, we saw that downclocking is negligible on our system, +but there is a noticeable benefit to warming up AVX-512 before the sort. To +understand why, we refer to Travis Downs' excellent +[measurements](https://travisdowns.github.io/blog/2020/01/17/avxfreq1.html#summary) +of how Skylake reacts to an AVX-512 instruction: 8-20 us of reduced instruction +throughput, an additional potential halt of 10 us, and then downclocking. +Note that downclocking is negligible on a single core per the previous section. + +We choose the array length of 10K unsigned 64-bit keys such that VQSort +completes in 7-10 us. Thus in this benchmark, VQSort (almost) finishes before +AVX-512 is fully warmed up, and the speedup is reduced because the startup costs +are amortized over relatively little data. Across five series of 15 runs, the +average of average throughputs is 9.3 GB/s, implying a runtime of 8.6 us +including startup costs. + +Note that the two-valued, almost all-equal input distribution is quite skewed. +The above throughput does not reflect the performance attainable on other +distributions, especially uniform random. However, this choice is deliberate +because Quicksort can terminate early if all values in a partition are equal. +When measuring such a 'best-case' input, we are more likely to observe the cost +of startup overhead in surrounding code. Otherwise, this overhead might be +hidden by the increase in sorting time. + +Now let us compare this throughput to the previously mentioned measurement with +AVX-512 warmed up (via slow scatter instructions so that initialization takes +about 100 us, well in excess of the warmup period): 15.2 GB/s, or 5.3 us without +startup cost. It appears the 10 us halt is not happening, possibly because we do +not use SIMD floating-point nor multiplication instructions. Thus we only +experience reduced instruction throughput and/or increased latency. 
The ratio +between cold and warmed-up time is only 1.6, which is plausible if the Skylake +throttling is actually rounding latencies up to a multiple of four cycles, as +Downs speculates. Indeed a large fraction of the SIMD instructions especially in +the VQSort base case are cross-lane or 64-bit min/max operations with latencies +of 3 cycles on Skylake, so their slowdown might only be 1.3x. The measured 1.6x +could plausibly derive from 7/8 of 1.3x and 1/8 of 4x for single-cycle latency +instructions. + +Assuming this understanding of AVX-512 startup cost is valid, how long does it +remain active before the CPU reverts to the previous settings? The CPU cannot +know what future instructions are coming, and to prevent unnecessary +transitions, it has a hysteresis (delay after the last AVX-512 instruction +before shutting down) which Downs measures as 680 us. Thus our benchmark +subsequently sleeps for 100 ms to ensure the next run of the binary sees the +original CPU state. Indeed we find for the five series that the slopes of the +lines of best fit are negative in one case, positive in two, and flat in two, +indicating there is no consistent pattern of benefit for earlier or later runs. + +What are the implications for users of VQSort? If the surrounding code executes +an AVX-512 instruction at least every 500 us, then AVX-512 remains active and +**any call to VQSort will benefit from it, no matter how small the input**. +This is a reasonable expectation for modern systems whose designers were aware +of data-oriented programming principles, because many (though not all) domains +and operations can benefit from SIMD. By contrast, consider the case of dropping +VQSort into an existing legacy system that does not yet use SIMD. In the case of +10K input sizes, we still observe a 2.3x speedup vs. `std::sort`. However, the +following code may have to deal with throttling for the remainder of the 20 us +startup period. 
With VQSort we have 8.6 us runtime plus up to 11.4 us throttled +code (potentially running at quarter speed) plus the remaining 3/4 of 11.4 for a +total of 28.6 us. With `std::sort` we have 19.5 us runtime plus 20 us of normal +subsequent code, or 39.5 us. Thus the overall speedup for the 20 us region plus +VQSort **shrinks to 1.4x**, and it is possible to imagine an actual slowdown for +sufficiently small inputs, when factoring in the throttling of subsequent code. +This unfortunate 'beggar thy neighbor' effect cannot be solved at the level of +individual building blocks such as a sort, and must instead be addressed at the +system level. For example: + +- vectorizing more and more parts of the code to amortize startup cost; +- relying on newer CPUs than Skylake (launched 2015!) which have little or no + AVX-512 startup overhead, such as Intel Icelake (2021) or AMD Zen4 (2022); +- ensuring sorts (or anything else using AVX-512) process at least 100 KiB + of data, such that the expected speedup outweighs any startup cost. + +Any of these solutions is sufficient to render AVX-512 startup overhead a +non-issue. + +### Comparison with Intel's x86-simd-sort and vxsort + +Our May 2022 paper compared performance with `ips4o` and `std::sort`. We now add +results for Intel's [x86-simd-sort](https://github.com/intel/x86-simd-sort), +released as open source around October 2022, and +[vxsort](https://github.com/damageboy/vxsort-cpp/tree/master). We find that +VQSort is generally about 1.4 times as fast as either, and in a few cases equal +or up to 2% slower. + +Note that vxsort was open-sourced around May 2020; we were unaware of it at the +time of writing because it had been published in the form of a blog series. We +imported both from GitHub on 2023-06-06 at about 10:15 UTC. Both are integrated +into our bench_sort, running on the same Linux OS and Xeon 6154 CPU mentioned +above.
We use uniform random inputs, because vxsort and x86-simd-sort appear to +have much less robust handling of skewed input distributions. They choose the +pivot as the median of three keys, or of 64 bytes, respectively. By contrast, +VQSort draws a 384-byte sample and analyzes its distribution, which improves +load balance and prevents recursing into all-equal partitions. Lacking this, the +other algorithms are more vulnerable to worst cases. Choosing uniform random +thus prevents disadvantaging the other algorithms. + +We sample performance across a range of input sizes and types: + +- To isolate the performance of the sorting networks used by all three + algorithms, we start with powers of two up to 128. VQSort is generally the + fastest for 64-bit keys with the following exceptions: tie with vxsort at + N=2 (537 MB/s), slower than vxsort at N=16 (2114 vs. 2147), tie with + x86-simd-sort at N=32 (2643 MB/s). Note that VQSort is about 1.6 times as + fast as both others for N=128; possibly because its 2D structure enables + larger networks. + +- The `kPow10` mode in bench_sort measures power-of-ten input sizes between + 10 and 100K. Note that this covers non-power-of-two sizes, as well as the + crossover point between sorting networks and Quicksort recursion. The + speedups of VQSort relative to x86-simd-sort range from 1.33 to 1.81 + (32-bit keys), and 1.25 to 1.68 (64-bit keys), with geomeans of 1.48 and + 1.44. The speedups of VQSort relative to vxsort range from 1.08 to 2.10 + (32-bit keys), and 1.00 to 1.47 (64-bit keys), with geomeans of 1.41 and + 1.20. Note that vxsort matches VQSort at 10 64-bit elements; in all other + cases, VQSort is strictly faster. + +- Finally, we study the effect of key type at a fixed input size of 10K + elements. x86-simd-sort requires AVX512-VBMI2 for int16, which our CPU does + not support. Also, neither of the other algorithms supports 128-bit keys, thus + we only consider 32/64-bit integer and float types.
The results in MB/s are: + + |Type|VQSort|x86-simd-sort|vxsort| + |---|---|---|---| + |f32|**1551**| 798| 823| + |f64|**1773**|1147| 745| + |i32|**1509**|1042| 968| + |i64|**1365**|1043|1145| + + VQSort is the fastest for each type, in some cases even about twice as fast. + Interestingly, vxsort performs at its best on i64, whereas the others are at + their best for f64. A potential explanation is that this CPU can execute two + f64 min/max per cycle, but only one i64. + +In conclusion, VQSort is generally more efficient than vxsort and x86-simd-sort +across a range of input sizes and types. Occasionally, it is up to 2% slower, +but the geomean of its speedup (32-bit keys and power-of-ten sizes) vs. vxsort +is **1.41**, and **1.48** vs. x86-simd-sort. diff --git a/lib/highway/hwy/contrib/sort/algo-inl.h b/lib/highway/hwy/contrib/sort/algo-inl.h new file mode 100644 index 00000000000..654d97d53e5 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/algo-inl.h @@ -0,0 +1,600 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ + +#include +#include + +#include // std::sort +#include // std::less, std::greater +#include + +#include "hwy/base.h" +#include "hwy/contrib/sort/vqsort.h" +#include "hwy/highway.h" +#include "hwy/print.h" + +// Third-party algorithms +#define HAVE_AVX2SORT 0 +#define HAVE_IPS4O 0 +// When enabling, consider changing max_threads (required for Table 1a) +#define HAVE_PARALLEL_IPS4O (HAVE_IPS4O && 1) +#define HAVE_PDQSORT 0 +#define HAVE_SORT512 0 +#define HAVE_VXSORT 0 +#if HWY_ARCH_X86 +#define HAVE_INTEL 0 +#else +#define HAVE_INTEL 0 +#endif + +#if HAVE_PARALLEL_IPS4O +#include // NOLINT +#endif + +#if HAVE_AVX2SORT +HWY_PUSH_ATTRIBUTES("avx2,avx") +#include "avx2sort.h" //NOLINT +HWY_POP_ATTRIBUTES +#endif +#if HAVE_IPS4O || HAVE_PARALLEL_IPS4O +#include "third_party/ips4o/include/ips4o.hpp" +#include "third_party/ips4o/include/ips4o/thread_pool.hpp" +#endif +#if HAVE_PDQSORT +#include "third_party/boost/allowed/sort/sort.hpp" +#endif +#if HAVE_SORT512 +#include "sort512.h" //NOLINT +#endif + +// vxsort is difficult to compile for multiple targets because it also uses +// .cpp files, and we'd also have to #undef its include guards. Instead, compile +// only for AVX2 or AVX3 depending on this macro. 
+#define VXSORT_AVX3 1 +#if HAVE_VXSORT +// inlined from vxsort_targets_enable_avx512 (must close before end of header) +#ifdef __GNUC__ +#ifdef __clang__ +#if VXSORT_AVX3 +#pragma clang attribute push(__attribute__((target("avx512f,avx512dq"))), \ + apply_to = any(function)) +#else +#pragma clang attribute push(__attribute__((target("avx2"))), \ + apply_to = any(function)) +#endif // VXSORT_AVX3 + +#else +#pragma GCC push_options +#if VXSORT_AVX3 +#pragma GCC target("avx512f,avx512dq") +#else +#pragma GCC target("avx2") +#endif // VXSORT_AVX3 +#endif +#endif + +#if VXSORT_AVX3 +#include "vxsort/machine_traits.avx512.h" +#else +#include "vxsort/machine_traits.avx2.h" +#endif // VXSORT_AVX3 +#include "vxsort/vxsort.h" +#ifdef __GNUC__ +#ifdef __clang__ +#pragma clang attribute pop +#else +#pragma GCC pop_options +#endif +#endif +#endif // HAVE_VXSORT + +namespace hwy { + +enum class Dist { kUniform8, kUniform16, kUniform32 }; // selects how many random bits MaskForDist keeps per lane + +static inline std::vector AllDist() { + // Also include lower-entropy distributions to test MaybePartitionTwoValue. + return {Dist::kUniform8, /*Dist::kUniform16,*/ Dist::kUniform32}; +} + +static inline const char* DistName(Dist dist) { + switch (dist) { + case Dist::kUniform8: + return "uniform8"; + case Dist::kUniform16: + return "uniform16"; + case Dist::kUniform32: + return "uniform32"; + } + return "unreachable"; +} + +template +class InputStats { // count/min/max/checksum of the values passed to Notify + public: + void Notify(T value) { + min_ = HWY_MIN(min_, value); + max_ = HWY_MAX(max_, value); + // Converting to integer would truncate floats, multiplying to save digits + // risks overflow especially when casting, so instead take the sum of the + // bit representations as the checksum.
+ uint64_t bits = 0; + static_assert(sizeof(T) <= 8, "Expected a built-in type"); + CopyBytes(&value, &bits); // not same size + sum_ += bits; + count_ += 1; + } + + bool operator==(const InputStats& other) const { // aborts with a diagnostic on mismatch, else returns true + char type_name[100]; + detail::TypeName(hwy::detail::MakeTypeInfo(), 1, type_name); + + if (count_ != other.count_) { + HWY_ABORT("Sort %s: count %d vs %d\n", type_name, + static_cast(count_), static_cast(other.count_)); + } + + if (min_ != other.min_ || max_ != other.max_) { + HWY_ABORT("Sort %s: minmax %f/%f vs %f/%f\n", type_name, + static_cast(min_), static_cast(max_), + static_cast(other.min_), + static_cast(other.max_)); + } + + // Sum helps detect duplicated/lost values + if (sum_ != other.sum_) { + HWY_ABORT("Sort %s: Sum mismatch %g %g; min %g max %g\n", type_name, + static_cast(sum_), static_cast(other.sum_), + static_cast(min_), static_cast(max_)); + } + + return true; + } + + private: + T min_ = hwy::HighestValue(); + T max_ = hwy::LowestValue(); + uint64_t sum_ = 0; + size_t count_ = 0; +}; + +enum class Algo { +#if HAVE_INTEL + kIntel, +#endif +#if HAVE_AVX2SORT + kSEA, +#endif +#if HAVE_IPS4O + kIPS4O, +#endif +#if HAVE_PARALLEL_IPS4O + kParallelIPS4O, +#endif +#if HAVE_PDQSORT + kPDQ, +#endif +#if HAVE_SORT512 + kSort512, +#endif +#if HAVE_VXSORT + kVXSort, +#endif + kStdSort, + kStdSelect, + kStdPartialSort, + kVQSort, + kVQPartialSort, + kVQSelect, + kHeapSort, + kHeapPartialSort, + kHeapSelect, +}; + +static inline bool IsVQ(Algo algo) { + return algo == Algo::kVQSort || algo == Algo::kVQPartialSort || + algo == Algo::kVQSelect; +} + +static inline bool IsSelect(Algo algo) { + return algo == Algo::kStdSelect || algo == Algo::kVQSelect || + algo == Algo::kHeapSelect; +} + +static inline bool IsPartialSort(Algo algo) { + return algo == Algo::kStdPartialSort || algo == Algo::kVQPartialSort || + algo == Algo::kHeapPartialSort; +} + +static inline Algo ReferenceAlgoFor(Algo algo) { // baseline: std partial sort for partial sorts, else pdqsort if enabled, else std::sort + if (IsPartialSort(algo)) return Algo::kStdPartialSort;
+#if HAVE_PDQSORT + return Algo::kPDQ; +#else + return Algo::kStdSort; +#endif +} + +static inline const char* AlgoName(Algo algo) { + switch (algo) { +#if HAVE_INTEL + case Algo::kIntel: + return "intel"; +#endif +#if HAVE_AVX2SORT + case Algo::kSEA: + return "sea"; +#endif +#if HAVE_IPS4O + case Algo::kIPS4O: + return "ips4o"; +#endif +#if HAVE_PARALLEL_IPS4O + case Algo::kParallelIPS4O: + return "par_ips4o"; +#endif +#if HAVE_PDQSORT + case Algo::kPDQ: + return "pdq"; +#endif +#if HAVE_SORT512 + case Algo::kSort512: + return "sort512"; +#endif +#if HAVE_VXSORT + case Algo::kVXSort: + return "vxsort"; +#endif + case Algo::kStdSort: + return "std"; + case Algo::kStdPartialSort: + return "std_partial"; + case Algo::kStdSelect: + return "std_select"; + case Algo::kVQSort: + return "vq"; + case Algo::kVQPartialSort: + return "vq_partial"; + case Algo::kVQSelect: + return "vq_select"; + case Algo::kHeapSort: + return "heap"; + case Algo::kHeapPartialSort: + return "heap_partial"; + case Algo::kHeapSelect: + return "heap_select"; + } + return "unreachable"; +} + +} // namespace hwy +#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_ + +// Per-target +// clang-format off +#if defined(HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE) == defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE +#endif +// clang-format on + +#include "hwy/aligned_allocator.h" +#include "hwy/contrib/sort/traits-inl.h" +#include "hwy/contrib/sort/traits128-inl.h" +#include "hwy/contrib/sort/vqsort-inl.h" // HeapSort + +HWY_BEFORE_NAMESPACE(); + +// Requires target pragma set by HWY_BEFORE_NAMESPACE +#if HAVE_INTEL && HWY_TARGET <= HWY_AVX3 +// #include "avx512-16bit-qsort.hpp" // requires AVX512-VBMI2 +#include "avx512-32bit-qsort.hpp" +#include "avx512-64bit-qsort.hpp" +#endif + +namespace hwy { +namespace HWY_NAMESPACE { + +#if HAVE_INTEL || HAVE_VXSORT // only supports ascending
order +template +using OtherOrder = detail::OrderAscending; +#else +template +using OtherOrder = detail::OrderDescending; +#endif + +class Xorshift128Plus { + static HWY_INLINE uint64_t SplitMix64(uint64_t z) { + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBull; + return z ^ (z >> 31); + } + + public: + // Generates two vectors of 64-bit seeds via SplitMix64 and stores into + // `seeds`. Generating these afresh in each ChoosePivot is too expensive. + template + static void GenerateSeeds(DU64 du64, TFromD* HWY_RESTRICT seeds) { + seeds[0] = SplitMix64(0x9E3779B97F4A7C15ull); + for (size_t i = 1; i < 2 * Lanes(du64); ++i) { + seeds[i] = SplitMix64(seeds[i - 1]); + } + } + + // Need to pass in the state because vectors cannot be class members. + template + static VU64 RandomBits(VU64& state0, VU64& state1) { + VU64 s1 = state0; + VU64 s0 = state1; + const VU64 bits = Add(s1, s0); + state0 = s0; + s1 = Xor(s1, ShiftLeft<23>(s1)); + state1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0)))); + return bits; + } +}; + +template +Vec RandomValues(D d, VU64& s0, VU64& s1, const VU64 mask) { // masked random bits reinterpreted as lanes of d + const VU64 bits = Xorshift128Plus::RandomBits(s0, s1); + return BitCast(d, And(bits, mask)); +} + +// It is important to avoid denormals, which are flushed to zero by SIMD but not +// scalar sorts, and NaN, which may be ordered differently in scalar vs. SIMD. +template +Vec RandomValues(DF df, VU64& s0, VU64& s1, const VU64 mask) { + using TF = TFromD; + const RebindToUnsigned du; + using VU = Vec; + + const VU64 bits64 = And(Xorshift128Plus::RandomBits(s0, s1), mask); + +#if HWY_TARGET == HWY_SCALAR // Cannot repartition u64 to smaller types + using TU = MakeUnsigned; + const VU bits = Set(du, static_cast(GetLane(bits64) & LimitsMax())); +#else + const VU bits = BitCast(du, bits64); +#endif + // Avoid NaN/denormal by only generating values in [1, 2), i.e. random + // mantissas with the exponent taken from the representation of 1.0.
+ const VU k1 = BitCast(du, Set(df, TF{1.0})); + const VU mantissa_mask = Set(du, MantissaMask()); + const VU representation = OrAnd(k1, bits, mantissa_mask); + return BitCast(df, representation); +} + +template +Vec MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t) { + switch (sizeof_t) { + case 2: + return Set(du64, (dist == Dist::kUniform8) ? 0x00FF00FF00FF00FFull + : 0xFFFFFFFFFFFFFFFFull); + case 4: + return Set(du64, (dist == Dist::kUniform8) ? 0x000000FF000000FFull + : (dist == Dist::kUniform16) ? 0x0000FFFF0000FFFFull + : 0xFFFFFFFFFFFFFFFFull); + case 8: + return Set(du64, (dist == Dist::kUniform8) ? 0x00000000000000FFull + : (dist == Dist::kUniform16) ? 0x000000000000FFFFull + : 0x00000000FFFFFFFFull); + default: + HWY_ABORT("Logic error"); + return Zero(du64); + } +} + +template +InputStats GenerateInput(const Dist dist, T* v, size_t num_lanes) { + SortTag du64; + using VU64 = Vec; + const size_t N64 = Lanes(du64); + auto seeds = hwy::AllocateAligned(2 * N64); + Xorshift128Plus::GenerateSeeds(du64, seeds.get()); + VU64 s0 = Load(du64, seeds.get()); + VU64 s1 = Load(du64, seeds.get() + N64); + +#if HWY_TARGET == HWY_SCALAR + const Sisd d; +#else + const Repartition d; +#endif + using V = Vec; + const size_t N = Lanes(d); + const VU64 mask = MaskForDist(du64, dist, sizeof(T)); + auto buf = hwy::AllocateAligned(N); + + size_t i = 0; + for (; i + N <= num_lanes; i += N) { + const V values = RandomValues(d, s0, s1, mask); + StoreU(values, d, v + i); + } + if (i < num_lanes) { + const V values = RandomValues(d, s0, s1, mask); + StoreU(values, d, buf.get()); + CopyBytes(buf.get(), v + i, (num_lanes - i) * sizeof(T)); + } + + InputStats input_stats; + for (size_t j = 0; j < num_lanes; ++j) { + input_stats.Notify(v[j]); + } + return input_stats; +} + +struct SharedState { +#if HAVE_PARALLEL_IPS4O + const unsigned max_threads = hwy::LimitsMax(); // 16 for Table 1a + ips4o::StdThreadPool pool{static_cast( + HWY_MIN(max_threads,
std::thread::hardware_concurrency() / 2))}; +#endif +}; + +// Adapters from Run's num_keys to vqsort-inl.h num_lanes. +template +void CallHeapSort(KeyType* keys, const size_t num_keys, Order) { + const detail::MakeTraits st; + using LaneType = typename decltype(st)::LaneType; + return detail::HeapSort(st, reinterpret_cast(keys), + num_keys * st.LanesPerKey()); +} +template +void CallHeapPartialSort(KeyType* keys, const size_t num_keys, + const size_t k_keys, Order) { + const detail::MakeTraits st; + using LaneType = typename decltype(st)::LaneType; + detail::HeapPartialSort(st, reinterpret_cast(keys), + num_keys * st.LanesPerKey(), + k_keys * st.LanesPerKey()); +} +template +void CallHeapSelect(KeyType* keys, const size_t num_keys, const size_t k_keys, + Order) { + const detail::MakeTraits st; + using LaneType = typename decltype(st)::LaneType; + detail::HeapSelect(st, reinterpret_cast(keys), + num_keys * st.LanesPerKey(), k_keys * st.LanesPerKey()); +} + +template +void Run(Algo algo, KeyType* inout, size_t num_keys, SharedState& shared, + size_t /*thread*/, size_t k_keys, Order) { // dispatches on algo; k_keys is only used by the partial-sort/select cases + const std::less less; + const std::greater greater; + + constexpr bool kAscending = Order::IsAscending(); + +#if !HAVE_PARALLEL_IPS4O + (void)shared; +#endif + + switch (algo) { +#if HAVE_INTEL && HWY_TARGET <= HWY_AVX3 + case Algo::kIntel: + return avx512_qsort(inout, static_cast(num_keys)); +#endif + +#if HAVE_AVX2SORT + case Algo::kSEA: + return avx2::quicksort(inout, static_cast(num_keys)); +#endif + +#if HAVE_IPS4O + case Algo::kIPS4O: + if (kAscending) { + return ips4o::sort(inout, inout + num_keys, less); + } else { + return ips4o::sort(inout, inout + num_keys, greater); + } +#endif + +#if HAVE_PARALLEL_IPS4O + case Algo::kParallelIPS4O: + if (kAscending) { + return ips4o::parallel::sort(inout, inout + num_keys, less, + shared.pool); + } else { + return ips4o::parallel::sort(inout, inout + num_keys, greater, + shared.pool); + } +#endif + +#if HAVE_SORT512 + case Algo::kSort512:
+ HWY_ABORT("not supported"); + // return Sort512::Sort(inout, num_keys); +#endif + +#if HAVE_PDQSORT + case Algo::kPDQ: + if (kAscending) { + return boost::sort::pdqsort_branchless(inout, inout + num_keys, less); + } else { + return boost::sort::pdqsort_branchless(inout, inout + num_keys, + greater); + } +#endif + +#if HAVE_VXSORT + case Algo::kVXSort: { +#if (VXSORT_AVX3 && HWY_TARGET != HWY_AVX3) || \ + (!VXSORT_AVX3 && HWY_TARGET != HWY_AVX2) + HWY_WARN("Do not call for target %s\n", hwy::TargetName(HWY_TARGET)); + return; +#else +#if VXSORT_AVX3 + vxsort::vxsort vx; +#else + vxsort::vxsort vx; +#endif + if (kAscending) { + return vx.sort(inout, inout + num_keys - 1); + } else { + HWY_WARN("Skipping VX - does not support descending order\n"); + return; + } +#endif // enabled for this target + } +#endif // HAVE_VXSORT + + case Algo::kStdSort: + if (kAscending) { + return std::sort(inout, inout + num_keys, less); + } else { + return std::sort(inout, inout + num_keys, greater); + } + case Algo::kStdPartialSort: + if (kAscending) { + return std::partial_sort(inout, inout + k_keys, inout + num_keys, less); + } else { + return std::partial_sort(inout, inout + k_keys, inout + num_keys, + greater); + } + case Algo::kStdSelect: + if (kAscending) { + return std::nth_element(inout, inout + k_keys, inout + num_keys, less); + } else { + return std::nth_element(inout, inout + k_keys, inout + num_keys, + greater); + } + + case Algo::kVQSort: + return VQSort(inout, num_keys, Order()); + case Algo::kVQPartialSort: + return VQPartialSort(inout, num_keys, k_keys, Order()); + case Algo::kVQSelect: + return VQSelect(inout, num_keys, k_keys, Order()); + + case Algo::kHeapSort: + return CallHeapSort(inout, num_keys, Order()); + case Algo::kHeapPartialSort: + return CallHeapPartialSort(inout, num_keys, k_keys, Order()); + case Algo::kHeapSelect: + return CallHeapSelect(inout, num_keys, k_keys, Order()); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace
HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE diff --git a/lib/highway/hwy/contrib/sort/bench_parallel.cc b/lib/highway/hwy/contrib/sort/bench_parallel.cc new file mode 100644 index 00000000000..66cc3ed7c1f --- /dev/null +++ b/lib/highway/hwy/contrib/sort/bench_parallel.cc @@ -0,0 +1,242 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Concurrent, independent sorts for generating more memory traffic and testing +// scalability when bandwidth-limited. If you want to use multiple threads for +// a single sort, you can use ips4o and integrate vqsort by calling it from +// `baseCaseSort` and increasing `IPS4OML_BASE_CASE_SIZE` to say 8192.
+ +#include +#include + +#include //NOLINT +#include +#include //NOLINT +#include //NOLINT +#include + +#include "hwy/timer.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/bench_parallel.cc" //NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/algo-inl.h" +#include "hwy/contrib/sort/result-inl.h" +#include "hwy/aligned_allocator.h" +// Last +#include "hwy/tests/test_util-inl.h" +// clang-format on + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +class ThreadPool { + public: + // Starts the given number of worker threads and blocks until they are ready. + explicit ThreadPool( + const size_t num_threads = std::thread::hardware_concurrency()) + : num_threads_(num_threads) { + HWY_ASSERT(num_threads_ > 0); + threads_.reserve(num_threads_); + for (size_t i = 0; i < num_threads_; ++i) { + threads_.emplace_back(ThreadFunc, this, i); + } + + WorkersReadyBarrier(); + } + + ThreadPool(const ThreadPool&) = delete; + ThreadPool& operator=(const ThreadPool&) = delete; // non-copyable: workers keep a pointer to this pool + + // Waits for all threads to exit. + ~ThreadPool() { + StartWorkers(kWorkerExit); + + for (std::thread& thread : threads_) { + thread.join(); + } + } + + size_t NumThreads() const { return threads_.size(); } + + template + void RunOnThreads(size_t max_threads, const Func& func) { // runs func(thread) for thread < max_threads; blocks until all workers are ready again + task_ = &CallClosure; + data_ = &func; + StartWorkers(max_threads); + WorkersReadyBarrier(); + } + + private: + // After construction and between calls to Run, workers are "ready", i.e. + // waiting on worker_start_cv_. They are "started" by sending a "command" + // and notifying all worker_start_cv_ waiters. (That is why all workers + // must be ready/waiting - otherwise, the notification will not reach all of + // them and the main thread waits in vain for them to report readiness.)
+ using WorkerCommand = uint64_t; + + static constexpr WorkerCommand kWorkerWait = ~1ULL; + static constexpr WorkerCommand kWorkerExit = ~2ULL; + + // Calls a closure (lambda with captures). + template + static void CallClosure(const void* f, size_t thread) { + (*reinterpret_cast(f))(thread); + } + + void WorkersReadyBarrier() { + std::unique_lock lock(mutex_); + // Typically only a single iteration. + while (workers_ready_ != threads_.size()) { + workers_ready_cv_.wait(lock); + } + workers_ready_ = 0; + + // Safely handle spurious worker wakeups. + worker_start_command_ = kWorkerWait; + } + + // Precondition: all workers are ready. + void StartWorkers(const WorkerCommand worker_command) { + std::unique_lock lock(mutex_); + worker_start_command_ = worker_command; + // Workers will need this lock, so release it before they wake up. + lock.unlock(); + worker_start_cv_.notify_all(); + } + + static void ThreadFunc(ThreadPool* self, size_t thread) { + // Until kWorkerExit command received: + for (;;) { + std::unique_lock lock(self->mutex_); + // Notify main thread that this thread is ready. + if (++self->workers_ready_ == self->num_threads_) { + self->workers_ready_cv_.notify_one(); + } + RESUME_WAIT: + // Wait for a command. + self->worker_start_cv_.wait(lock); + const WorkerCommand command = self->worker_start_command_; + switch (command) { + case kWorkerWait: // spurious wakeup: + goto RESUME_WAIT; // lock still held, avoid incrementing ready. + case kWorkerExit: + return; // exits thread + default: + break; + } + + lock.unlock(); + // Command is the maximum number of threads that should run the task. + HWY_ASSERT(command < self->NumThreads()); + if (thread < command) { + self->task_(self->data_, thread); + } + } + } + + const size_t num_threads_; + + // Unmodified after ctor, but cannot be const because we call thread::join(). + std::vector threads_; + + std::mutex mutex_; // guards both cv and their variables.
+ std::condition_variable workers_ready_cv_; + size_t workers_ready_ = 0; + std::condition_variable worker_start_cv_; + WorkerCommand worker_start_command_ = kWorkerWait; // a wakeup before the first command must read "keep waiting" + + // Written by main thread, read by workers (after mutex lock/unlock). + std::function task_; // points to CallClosure + const void* data_; // points to caller's Func +}; + +template +void RunWithoutVerify(Traits st, const Dist dist, const size_t num_keys, + const Algo algo, SharedState& shared, size_t thread) { // sorts freshly generated input; only a first-vs-last lane sanity check + using LaneType = typename Traits::LaneType; + using KeyType = typename Traits::KeyType; + using Order = typename Traits::Order; + const size_t num_lanes = num_keys * st.LanesPerKey(); + auto aligned = hwy::AllocateAligned(num_lanes); + + (void)GenerateInput(dist, aligned.get(), num_lanes); + + const Timestamp t0; + Run(algo, reinterpret_cast(aligned.get()), num_keys, shared, thread, + /*k_keys=*/0, Order()); + HWY_ASSERT(aligned[0] < aligned[num_lanes - 1]); +} + +void BenchParallel() { // times concurrent independent sorts for increasing thread counts + // Not interested in benchmark results for other targets on x86 + if (HWY_ARCH_X86 && + (HWY_TARGET != HWY_AVX2 && HWY_TARGET != HWY_AVX3 && + HWY_TARGET != HWY_AVX3_ZEN4 && HWY_TARGET != HWY_AVX3_SPR)) { + return; + } + + ThreadPool pool; + const size_t NT = pool.NumThreads(); + + detail::SharedTraits>> st; + using KeyType = typename decltype(st)::KeyType; + const size_t num_keys = size_t{100} * 1000 * 1000; + +#if HAVE_IPS4O + const Algo algo = Algo::kIPS4O; +#else + const Algo algo = Algo::kVQSort; +#endif + const Dist dist = Dist::kUniform32; + + SharedState shared; + + std::vector results; + for (size_t nt = 1; nt < NT; nt += HWY_MAX(1, NT / 16)) { + Timestamp t0; + // Default capture because MSVC wants algo/dist but clang does not.
+ pool.RunOnThreads(nt, [=, &shared](size_t thread) { + RunWithoutVerify(st, dist, num_keys, algo, shared, thread); + }); + const double sec = SecondsSince(t0); + results.emplace_back(algo, dist, num_keys, nt, sec, sizeof(KeyType), + st.KeyString()); + results.back().Print(); + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace hwy { +namespace { +HWY_BEFORE_TEST(BenchParallel); +HWY_EXPORT_AND_TEST_P(BenchParallel, BenchParallel); +HWY_AFTER_TEST(); +} // namespace +} // namespace hwy + +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/bench_sort.cc b/lib/highway/hwy/contrib/sort/bench_sort.cc new file mode 100644 index 00000000000..87ef7e210b0 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/bench_sort.cc @@ -0,0 +1,473 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
// NOTE(review): in this copy the `<...>` argument lists appear to have been
// stripped (bare `template`, `static_cast(`, `std::vector seconds`,
// `BenchSort>>(num_keys)`). Code tokens below are reproduced as found; the
// key types of the individual BenchSort/BenchBase/BenchPartition calls cannot
// be recovered from this text and must be restored from upstream.

// Mode for larger sorts because M1 is able to access more than the per-core
// share of L2, so 1M elements might still be in cache.
#define SORT_100M 0

// When nonzero, only BenchAllColdSort is registered (see the HWY_ONCE section).
#ifndef SORT_ONLY_COLD
#define SORT_ONLY_COLD 0
#endif
// Disabled by default (the `&& 0`); enables BenchAllPartition/BenchAllBase.
#ifndef SORT_BENCH_BASE_AND_PARTITION
#define SORT_BENCH_BASE_AND_PARTITION (!SORT_ONLY_COLD && 0)
#endif

HWY_BEFORE_NAMESPACE();
namespace hwy {
// Defined within HWY_ONCE, used by BenchAllSort.
extern int64_t first_sort_target;
extern int64_t first_cold_target;  // for BenchAllColdSort

namespace HWY_NAMESPACE {
namespace {
using detail::OrderAscending;
using detail::OrderDescending;
using detail::SharedTraits;
using detail::TraitsLane;

#if HWY_TARGET != HWY_SCALAR
using detail::OrderAscending128;
using detail::OrderAscendingKV128;
using detail::Traits128;
#endif  // HWY_TARGET != HWY_SCALAR

// Times a single sort of kSize (10,000) uint64_t keys and prints throughput
// to stderr. Only the first target that reaches this function runs it; later
// targets return immediately. Also returns early (with a warning) when
// platform::HaveTimerStop reports that the timer is unavailable.
HWY_NOINLINE void BenchAllColdSort() {
  // Only run the best(first) enabled target
  if (first_cold_target == 0) first_cold_target = HWY_TARGET;
  if (HWY_TARGET != first_cold_target) {
    return;
  }

  char cpu100[100];
  if (!platform::HaveTimerStop(cpu100)) {
    HWY_WARN("CPU '%s' does not support RDTSCP, skipping benchmark.\n", cpu100);
    return;
  }

  // Initialize random seeds
#if VQSORT_ENABLED
  HWY_ASSERT(GetGeneratorState() != nullptr);  // vqsort
#endif
  RandomState rng(static_cast(Unpredictable1() * 129));  // this test

  using T = uint64_t;
  constexpr size_t kSize = 10 * 1000;
  HWY_ALIGN T items[kSize];

  // Initialize array
#if 0  // optional: deliberate AVX-512 to verify VQSort performance improves
  const ScalableTag d;
  const RebindToSigned di;
  const size_t N = Lanes(d);
  size_t i = 0;
  for (; i + N <= kSize; i += N) {
    // Super-slow scatter so that we spend enough time to warm up SKX.
    const Vec val = Set(d, static_cast(Unpredictable1()));
    const Vec idx =
        Iota(di, static_cast(Unpredictable1() - 1));
    ScatterIndex(val, d, items + i, idx);
  }
  for (; i < kSize; ++i) {
    items[i] = static_cast(Unpredictable1());
  }
#else  // scalar-only, verified with clang-16
  for (size_t i = 0; i < kSize; ++i) {
    items[i] = static_cast(Unpredictable1());
  }
#endif
  // Make exactly one randomly chosen element differ from the rest.
  items[Random32(&rng) % kSize] = static_cast(Unpredictable1() + 1);

  const timer::Ticks t0 = timer::Start();
  const SortAscending order;
#if VQSORT_ENABLED && 1  // change to && 0 to switch to std::sort.
  VQSort(items, kSize, order);
#else
  SharedState shared;
  Run(Algo::kStdSort, items, kSize, shared, /*thread=*/0, /*k_keys=*/0, order);
#endif
  const timer::Ticks t1 = timer::Stop();

  const double ticks = static_cast(t1 - t0);
  const double elapsed = ticks / platform::InvariantTicksPerSecond();
  const double GBps = kSize * sizeof(T) * 1E-9 / elapsed;

  // Printing a randomly chosen output element keeps the sort result "used".
  fprintf(stderr, "N=%zu GB/s=%.2f ns=%.1f random output: %g\n", kSize, GBps,
          elapsed * 1E9, static_cast(items[Random32(&rng) % kSize]));

#if SORT_ONLY_COLD
  // Long enough for the CPU to switch off AVX-512 mode before the next run.
  NanoSleep(100 * 1000 * 1000);
#endif
}

#if (VQSORT_ENABLED && SORT_BENCH_BASE_AND_PARTITION) || HWY_IDE

// Benchmarks detail::Partition alone for one Traits type. Each repetition
// regenerates the input and chooses a pivot outside the timed region; only the
// Partition call is timed. Prints one SortResult per input size.
template
HWY_NOINLINE void BenchPartition() {
  using LaneType = typename Traits::LaneType;
  using KeyType = typename Traits::KeyType;
  const SortTag d;
  detail::SharedTraits st;
  const Dist dist = Dist::kUniform8;
  double sum = 0.0;

  constexpr size_t kLPK = st.LanesPerKey();
  HWY_ALIGN LaneType
      buf[SortConstants::BufBytes(HWY_MAX_BYTES) /
          sizeof(LaneType)];
  uint64_t* HWY_RESTRICT state = GetGeneratorState();

  const size_t max_log2 = AdjustedLog2Reps(20);
  // As written, this loop executes exactly once (log2 == max_log2).
  for (size_t log2 = max_log2; log2 < max_log2 + 1; ++log2) {
    const size_t num_lanes = 1ull << log2;
    const size_t num_keys = num_lanes / kLPK;
    auto aligned = hwy::AllocateAligned(num_lanes);

    std::vector seconds;
    // Fewer repetitions for larger inputs.
    const size_t num_reps = (1ull << (14 - log2 / 2)) * 30;
    for (size_t rep = 0; rep < num_reps; ++rep) {
      (void)GenerateInput(dist, aligned.get(), num_lanes);

      // The pivot value can influence performance. Do exactly what vqsort will
      // do so that the performance (influenced by prefetching and branch
      // prediction) is likely to predict the actual performance inside vqsort.
      detail::DrawSamples(d, st, aligned.get(), num_lanes, buf, state);
      detail::SortSamples(d, st, buf);
      auto pivot = detail::ChoosePivotByRank(d, st, buf);

      const Timestamp t0;
      detail::Partition(d, st, aligned.get(), num_lanes - 1, pivot, buf);
      seconds.push_back(SecondsSince(t0));
      // 'Use' the result to prevent optimizing out the partition.
      sum += static_cast(aligned.get()[num_lanes / 2]);
    }

    SortResult(Algo::kVQSort, dist, num_keys, 1, SummarizeMeasurements(seconds),
               sizeof(KeyType), st.KeyString())
        .Print();
  }
  HWY_ASSERT(sum != 999999);  // Prevent optimizing out
}

// Runs BenchPartition for a fixed list of key types/orders.
HWY_NOINLINE void BenchAllPartition() {
  // Not interested in benchmark results for these targets
  if (HWY_TARGET == HWY_SSSE3) {
    return;
  }

  BenchPartition>>();
  BenchPartition>>();
  BenchPartition>>();
  BenchPartition>();
  // BenchPartition>();
  BenchPartition>();
}

// Benchmarks detail::BaseCase for one Traits type on an input of
// BaseCaseNumLanes(N) lanes: 30 repetitions, each timing kMul back-to-back
// BaseCase calls and then checking the order with SortOrderVerifier. Appends
// one SortResult to `results` instead of printing.
template
HWY_NOINLINE void BenchBase(std::vector& results) {
  // Not interested in benchmark results for these targets
  if (HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4) {
    return;
  }

  using LaneType = typename Traits::LaneType;
  using KeyType = typename Traits::KeyType;
  const SortTag d;
  detail::SharedTraits st;
  const Dist dist = Dist::kUniform32;
  const Algo algo = Algo::kVQSort;

  const size_t N = Lanes(d);
  constexpr size_t kLPK = st.LanesPerKey();
  const size_t num_lanes = SortConstants::BaseCaseNumLanes(N);
  const size_t num_keys = num_lanes / kLPK;
  auto keys = hwy::AllocateAligned(num_lanes);
  // N extra lanes beyond the input size.
  auto buf = hwy::AllocateAligned(num_lanes + N);

  std::vector seconds;
  double sum = 0;  // prevents elision
  constexpr size_t kMul = AdjustedReps(600);  // ensures long enough to measure

  for (size_t rep = 0; rep < 30; ++rep) {
    InputStats input_stats =
        GenerateInput(dist, keys.get(), num_lanes);

    const Timestamp t0;
    // Only the first iteration sees the freshly generated input; `keys` is not
    // regenerated between the kMul calls.
    for (size_t i = 0; i < kMul; ++i) {
      detail::BaseCase(d, st, keys.get(), num_lanes, buf.get());
      sum += static_cast(keys[0]);
    }
    seconds.push_back(SecondsSince(t0));
    // printf("%f\n", seconds.back());

    SortOrderVerifier()(algo, input_stats, keys.get(), num_keys,
                        num_keys);
  }
  HWY_ASSERT(sum < 1E99);
  // The reported key count is scaled by kMul because each timing covers kMul
  // BaseCase calls.
  results.emplace_back(algo, dist, num_keys * kMul, 1,
                       SummarizeMeasurements(seconds), sizeof(KeyType),
                       st.KeyString());
}

// Runs BenchBase for a fixed list of key types and prints all results at the
// end.
HWY_NOINLINE void BenchAllBase() {
  // Not interested in benchmark results for these targets
  if (HWY_TARGET == HWY_SSSE3) {
    return;
  }

  std::vector results;
  BenchBase>>(results);
  BenchBase>>(results);
  BenchBase>(results);
  for (const SortResult& r : results) {
    r.Print();
  }
}

#endif  // VQSORT_ENABLED && SORT_BENCH_BASE_AND_PARTITION

// Returns the algorithms to benchmark; the list is assembled at compile time
// from the HAVE_* / VQSORT_ENABLED macros and the current HWY_TARGET.
std::vector AlgoForBench() {
  return {
#if HAVE_AVX2SORT
    Algo::kSEA,
#endif
#if HAVE_PARALLEL_IPS4O
    Algo::kParallelIPS4O,
#elif HAVE_IPS4O
    Algo::kIPS4O,
#endif
#if HAVE_PDQSORT
    Algo::kPDQ,
#endif
#if HAVE_SORT512
    Algo::kSort512,
#endif
// Only include if we're compiling for the target it supports.
#if HAVE_VXSORT && ((VXSORT_AVX3 && HWY_TARGET == HWY_AVX3) || \
                    (!VXSORT_AVX3 && HWY_TARGET == HWY_AVX2))
    Algo::kVXSort,
#endif
// Only include if we're compiling for the target it supports.
#if HAVE_INTEL && HWY_TARGET <= HWY_AVX3
    Algo::kIntel,
#endif

#if !HAVE_PARALLEL_IPS4O
#if !SORT_100M
    // 10-20x slower, but that's OK for the default size when we are not
    // testing the parallel nor 100M modes.
    // Algo::kStdSort,
#endif

#if VQSORT_ENABLED
    Algo::kVQSort,
#endif
#endif  // !HAVE_PARALLEL_IPS4O
  };
}

// Benchmarks every algorithm from AlgoForBench() on every distribution from
// AllDist() for `num_keys` keys of the given Traits, verifying the output
// order after each run and printing one SortResult per (algo, dist) pair.
template
HWY_NOINLINE void BenchSort(size_t num_keys) {
  // Remember the first target that gets here; see the `continue` below.
  if (first_sort_target == 0) first_sort_target = HWY_TARGET;

  SharedState shared;
  detail::SharedTraits st;
  using Order = typename Traits::Order;
  using LaneType = typename Traits::LaneType;
  using KeyType = typename Traits::KeyType;
  const size_t num_lanes = num_keys * st.LanesPerKey();
  auto aligned = hwy::AllocateAligned(num_lanes);

  // Fewer repetitions for inputs above one million keys.
  const size_t reps = num_keys > 1000 * 1000 ? 10 : 30;

  for (Algo algo : AlgoForBench()) {
    // Other algorithms don't depend on the vector instructions, so only run
    // them for the first target.
#if !HAVE_VXSORT
    if (algo != Algo::kVQSort && HWY_TARGET != first_sort_target) {
      continue;
    }
#endif

    for (Dist dist : AllDist()) {
      std::vector seconds;
      for (size_t rep = 0; rep < reps; ++rep) {
        // Input generation is outside the timed region.
        InputStats input_stats =
            GenerateInput(dist, aligned.get(), num_lanes);

        const Timestamp t0;
        Run(algo, HWY_RCAST_ALIGNED(KeyType*, aligned.get()), num_keys, shared,
            /*thread=*/0, /*k_keys=*/0, Order());
        seconds.push_back(SecondsSince(t0));
        // printf("%f\n", seconds.back());

        SortOrderVerifier()(algo, input_stats, aligned.get(), num_keys,
                            num_keys);
      }
      SortResult(algo, dist, num_keys, 1, SummarizeMeasurements(seconds),
                 sizeof(KeyType), st.KeyString())
          .Print();
    }  // dist
  }  // algo
}

// Selects which set of input sizes SizesToBenchmark returns.
enum class BenchmarkModes {
  kDefault,
  k1M,
  k10K,
  kAllSmall,
  kSmallPow2,
  kSmallPow2Between,  // includes padding
  kPow4,
  kPow10
};

// Returns the list of key counts to benchmark for `mode`.
std::vector SizesToBenchmark(BenchmarkModes mode) {
  std::vector sizes;
  switch (mode) {
    case BenchmarkModes::kDefault:
      // 100M keys in parallel/100M builds, otherwise 100 and 100K.
#if HAVE_PARALLEL_IPS4O || SORT_100M
      sizes.push_back(100 * 1000 * size_t{1000});
#else
      sizes.push_back(100);
      sizes.push_back(100 * 1000);
#endif
      break;
    case BenchmarkModes::k1M:
      sizes.push_back(1000 * 1000);
      break;
    case BenchmarkModes::k10K:
      sizes.push_back(10 * 1000);
      break;

    case BenchmarkModes::kAllSmall:
      // Every size from 1 to 128 inclusive.
      sizes.reserve(128);
      for (size_t i = 1; i <= 128; ++i) {
        sizes.push_back(i);
      }
      break;
    case BenchmarkModes::kSmallPow2:
      // 2, 4, ..., 128.
      for (size_t size = 2; size <= 128; size *= 2) {
        sizes.push_back(size);
      }
      break;
    case BenchmarkModes::kSmallPow2Between:
      // 3, 6, ..., 192: 1.5x each power of two from 2 to 128.
      for (size_t size = 2; size <= 128; size *= 2) {
        sizes.push_back(3 * size / 2);
      }
      break;

    case BenchmarkModes::kPow4:
      // 4, 16, ..., 256K.
      for (size_t size = 4; size <= 256 * 1024; size *= 4) {
        sizes.push_back(size);
      }
      break;
    case BenchmarkModes::kPow10:
      // 10, 100, ..., 100K.
      for (size_t size = 10; size <= 100 * 1000; size *= 10) {
        sizes.push_back(size);
      }
      break;
  }
  return sizes;
}

// Entry point of the main sort benchmark: runs BenchSort for a fixed set of
// key types over the kSmallPow2 sizes, skipping targets that are not of
// interest.
HWY_NOINLINE void BenchAllSort() {
  // Not interested in benchmark results for these targets. Note that SSE4 is
  // numerically less than SSE2, hence it is the lower bound.
  if (HWY_SSE4 <= HWY_TARGET && HWY_TARGET <= HWY_SSE2 && Unpredictable1()) {
    return;
  }
#if HAVE_INTEL
  if (HWY_TARGET > HWY_AVX3) return;
#endif

  for (size_t num_keys : SizesToBenchmark(BenchmarkModes::kSmallPow2)) {
#if !HAVE_INTEL
#if HWY_HAVE_FLOAT16
    if (hwy::HaveFloat16()) {
      BenchSort>>(num_keys);
    }
#endif
    BenchSort>>(num_keys);
#if HWY_HAVE_FLOAT64
    if (hwy::HaveFloat64()) {
      // BenchSort>>(num_keys);
    }
#endif
#endif  // !HAVE_INTEL
    // BenchSort>>(num_keys);
    BenchSort>>(num_keys);
    BenchSort>>(num_keys);
    // BenchSort>>(num_keys);
    // BenchSort>>(num_keys);
    // BenchSort>>(num_keys);

#if !HAVE_VXSORT && !HAVE_INTEL && HWY_TARGET != HWY_SCALAR
    BenchSort>(num_keys);
    BenchSort>(num_keys);
#endif
  }
}

}  // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

// Target-independent part: defines the "first target" globals declared above
// and registers the benchmarks with the test harness.
#if HWY_ONCE

namespace hwy {
int64_t first_sort_target = 0;  // none run yet
int64_t first_cold_target = 0;  // none run yet
HWY_BEFORE_TEST(BenchSort);
HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllColdSort);
#if SORT_BENCH_BASE_AND_PARTITION
HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllPartition);
HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllBase);
#endif

#if !SORT_ONLY_COLD  // skip (warms up vector unit for next run)
HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllSort);
#endif
HWY_AFTER_TEST();
}  // namespace hwy

HWY_TEST_MAIN();

#endif  // HWY_ONCE
namespace hwy {

// Order tags: empty types passed by value to select the direction of a sort.
// Each exposes a single compile-time query, IsAscending().

// Selects ascending order.
struct SortAscending {
  static constexpr bool IsAscending() {
    return true;
  }
};

// Selects descending order.
struct SortDescending {
  static constexpr bool IsAscending() {
    return false;
  }
};

}  // namespace hwy
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "hwy/base.h" + +// Based on A.7 in "Entwurf und Implementierung vektorisierter +// Sortieralgorithmen" and code by Mark Blacher. +static void PrintMergeNetwork(int rows, int cols) { + printf("\n%d x %d:\n", rows, cols); + // Powers of two + HWY_ASSERT(rows != 0 && (rows & (rows - 1)) == 0); + HWY_ASSERT(cols != 0 && (cols & (cols - 1)) == 0); + HWY_ASSERT(rows >= 4); + HWY_ASSERT(cols >= 2); // otherwise no cross-column merging required + HWY_ASSERT(cols <= 16); // SortTraits lacks Reverse32 + + // Log(rows) times: sort half of the vectors with reversed groups of the + // other half. Group size halves until we are sorting adjacent vectors. + int group_size = rows; + int num_groups = 1; + for (; group_size >= 2; group_size /= 2, num_groups *= 2) { + // All vectors except those being reversed. Allows us to group the + // ReverseKeys and Sort2 operations, which is easier to read and may help + // in-order machines with high-latency ReverseKeys. + std::vector all_vi; + for (int group = 0; group < num_groups; ++group) { + for (int i = 0; i < group_size / 2; ++i) { + all_vi.push_back(group * group_size + i); + } + } + for (int vi : all_vi) { + const int vr = vi ^ (group_size - 1); + printf("v%x = st.ReverseKeys%d(d, v%x);\n", vr, cols, vr); + } + for (int vi : all_vi) { + const int vr = vi ^ (group_size - 1); + printf("st.Sort2(d, v%x, v%x);\n", vi, vr); + } + printf("\n"); + } + + // Now merge across columns in all vectors. 
+ if (cols > 2) { + for (int i = 0; i < rows; ++i) { + printf("v%x = st.SortPairsReverse%d(d, v%x);\n", i, cols, i); + } + printf("\n"); + } + if (cols >= 16) { + for (int i = 0; i < rows; ++i) { + printf("v%x = st.SortPairsDistance4(d, v%x);\n", i, i); + } + printf("\n"); + } + if (cols >= 8) { + for (int i = 0; i < rows; ++i) { + printf("v%x = st.SortPairsDistance2(d, v%x);\n", i, i); + } + printf("\n"); + } + for (int i = 0; i < rows; ++i) { + printf("v%x = st.SortPairsDistance1(d, v%x);\n", i, i); + } + printf("\n"); +} + +int main(int /*argc*/, char** /*argv*/) { + PrintMergeNetwork(8, 2); + PrintMergeNetwork(8, 4); + PrintMergeNetwork(16, 4); + PrintMergeNetwork(16, 8); + PrintMergeNetwork(16, 16); + return 0; +} diff --git a/lib/highway/hwy/contrib/sort/result-inl.h b/lib/highway/hwy/contrib/sort/result-inl.h new file mode 100644 index 00000000000..a11b759b4e7 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/result-inl.h @@ -0,0 +1,291 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/algo-inl.h" + +// Normal include guard for non-SIMD parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ + +#include +#include +#include + +#include // std::sort +#include +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/per_target.h" // DispatchedTarget +#include "hwy/targets.h" // TargetName + +namespace hwy { + +// Returns trimmed mean (we don't want to run an out-of-L3-cache sort often +// enough for the mode to be reliable). +static inline double SummarizeMeasurements(std::vector& seconds) { + std::sort(seconds.begin(), seconds.end()); + double sum = 0; + int count = 0; + const size_t num = seconds.size(); + for (size_t i = num / 4; i < num / 2; ++i) { + sum += seconds[i]; + count += 1; + } + return sum / count; +} + +struct SortResult { + SortResult() {} + SortResult(Algo algo_in, Dist dist_in, size_t num_keys_in, + size_t num_threads_in, double sec_in, size_t sizeof_key_in, + const char* key_name_in) + : target(DispatchedTarget()), + algo(algo_in), + dist(dist_in), + num_keys(num_keys_in), + num_threads(num_threads_in), + sec(sec_in), + sizeof_key(sizeof_key_in), + key_name(key_name_in) {} + + void Print() const { + const double bytes = static_cast(num_keys) * + static_cast(num_threads) * + static_cast(sizeof_key); + printf("%10s: %12s: %7s: %9s: %05g %4.0f MB/s (%2zu threads)\n", + hwy::TargetName(target), AlgoName(algo), key_name.c_str(), + DistName(dist), static_cast(num_keys), bytes * 1E-6 / sec, + num_threads); + } + + int64_t target; + Algo algo; + Dist dist; + size_t num_keys = 0; + size_t num_threads = 0; + double sec = 0.0; + size_t sizeof_key = 0; + std::string key_name; +}; + +} // namespace hwy +#endif // HIGHWAY_HWY_CONTRIB_SORT_RESULT_INL_H_ + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_RESULT_TOGGLE +#undef 
// NOTE(review): in this copy the `<...>` argument lists appear to have been
// stripped (bare `template`, `CappedTag d`, `InputStats`, `AllocateAligned(`,
// `AlignedFreeUniquePtr`). Code tokens below are reproduced as found; restore
// the arguments from upstream before compiling.

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

// Copies the input, and compares results to that of a reference algorithm.
//
// Usage: construct with the unsorted input (a private copy is taken), run the
// algorithm under test on the caller's buffer, then call operator() with the
// output. Mismatches abort via HWY_ABORT.
template
class ReferenceSortVerifier {
  using LaneType = typename Traits::LaneType;
  using KeyType = typename Traits::KeyType;
  using Order = typename Traits::Order;
  static constexpr bool kAscending = Order::IsAscending();
  // Lanes per key.
  static constexpr size_t kLPK = Traits().LanesPerKey();

 public:
  // Takes a private aligned copy of `num_lanes` lanes starting at `in_lanes`.
  ReferenceSortVerifier(const LaneType* in_lanes, size_t num_lanes) {
    num_lanes_ = num_lanes;
    num_keys_ = num_lanes / kLPK;
    in_lanes_ = hwy::AllocateAligned(num_lanes);
    HWY_ASSERT(in_lanes_);
    CopyBytes(in_lanes, in_lanes_.get(), num_lanes * sizeof(LaneType));
  }

  // For full sorts, k_keys == num_keys.
  //
  // Runs ReferenceAlgoFor(algo) on the private copy (sorting it in place) and
  // compares it against `out_lanes`: for Select only the key at index k_keys,
  // otherwise every key in [0, k_keys). Keys are treated as matching when
  // neither compares less than the other.
  void operator()(Algo algo, const LaneType* out_lanes, size_t k_keys) {
    SharedState shared;
    const Traits st;
    const CappedTag d;

    HWY_ASSERT(hwy::IsAligned(in_lanes_.get(), sizeof(KeyType)));
    KeyType* in_keys = HWY_RCAST_ALIGNED(KeyType*, in_lanes_.get());

    // Buffer for the "%4zu " row labels; snprintf truncates if needed.
    char caption[10];
    const char* algo_type = IsPartialSort(algo) ? "PartialSort" : "Sort";

    HWY_ASSERT(k_keys <= num_keys_);
    Run(ReferenceAlgoFor(algo), in_keys, num_keys_, shared, /*thread=*/0,
        k_keys, Order());

    if (IsSelect(algo)) {
      // Print lanes centered around k_keys.
      if (VQSORT_PRINT >= 3) {
        const size_t begin_lane = k_keys < 3 ? 0 : (k_keys - 3) * kLPK;
        const size_t end_lane = HWY_MIN(num_lanes_, (k_keys + 3) * kLPK);
        fprintf(stderr, "\nExpected:\n");
        for (size_t i = begin_lane; i < end_lane; i += kLPK) {
          snprintf(caption, sizeof(caption), "%4zu ", i / kLPK);
          Print(d, caption, st.SetKey(d, &in_lanes_[i]));
        }
        fprintf(stderr, "\n\nActual:\n");
        for (size_t i = begin_lane; i < end_lane; i += kLPK) {
          snprintf(caption, sizeof(caption), "%4zu ", i / kLPK);
          Print(d, caption, st.SetKey(d, &out_lanes[i]));
        }
        fprintf(stderr, "\n\n");
      }

      // At k_keys: should be equivalent, i.e. neither a < b nor b < a.
      // SortOrderVerifier will also check the ordering of the rest of the keys.
      // NOTE(review): the assert above allows k_keys == num_keys_, in which
      // case `k` indexes one key past the end of both buffers - confirm
      // callers never pass that for Select.
      const size_t k = k_keys * kLPK;
      if (st.Compare1(&in_lanes_[k], &out_lanes[k]) ||
          st.Compare1(&out_lanes[k], &in_lanes_[k])) {
        Print(d, "Expected", st.SetKey(d, &in_lanes_[k]));
        Print(d, " Actual", st.SetKey(d, &out_lanes[k]));
        HWY_ABORT("Select %s asc=%d: mismatch at k_keys=%zu, num_keys=%zu\n",
                  st.KeyString(), kAscending, k_keys, num_keys_);
      }
    } else {
      if (VQSORT_PRINT >= 3) {
        // Print at most the first 40 lanes.
        const size_t lanes_to_print = HWY_MIN(40, k_keys * kLPK);
        fprintf(stderr, "\nExpected:\n");
        for (size_t i = 0; i < lanes_to_print; i += kLPK) {
          snprintf(caption, sizeof(caption), "%4zu ", i / kLPK);
          Print(d, caption, st.SetKey(d, &in_lanes_[i]));
        }
        fprintf(stderr, "\n\nActual:\n");
        for (size_t i = 0; i < lanes_to_print; i += kLPK) {
          snprintf(caption, sizeof(caption), "%4zu ", i / kLPK);
          Print(d, caption, st.SetKey(d, &out_lanes[i]));
        }
        fprintf(stderr, "\n\n");
      }

      // Full or partial sort: all elements up to k_keys are equivalent to the
      // reference sort. SortOrderVerifier also checks the output's ordering.
      for (size_t i = 0; i < k_keys * kLPK; i += kLPK) {
        // All up to k_keys should be equivalent, i.e. neither a < b nor b < a.
        if (st.Compare1(&in_lanes_[i], &out_lanes[i]) ||
            st.Compare1(&out_lanes[i], &in_lanes_[i])) {
          Print(d, "Expected", st.SetKey(d, &in_lanes_[i]));
          Print(d, " Actual", st.SetKey(d, &out_lanes[i]));
          HWY_ABORT("%s %s asc=%d: mismatch at %zu, k_keys=%zu, num_keys=%zu\n",
                    algo_type, st.KeyString(), kAscending, i / kLPK, k_keys,
                    num_keys_);
        }
      }
    }
  }

 private:
  // Private copy of the input; holds the reference result after operator().
  hwy::AlignedFreeUniquePtr in_lanes_;
  size_t num_lanes_;
  size_t num_keys_;
};

// Faster than ReferenceSortVerifier, for use in bench_sort. Only verifies
// order, without running a slow reference sorter. This means it can't verify
// Select places the correct key at `k_keys`, nor that input and output keys are
// the same.
//
// In addition to the order check, both methods feed every output lane into an
// InputStats accumulator and assert that it equals `input_stats`.
template
class SortOrderVerifier {
  using LaneType = typename Traits::LaneType;
  using Order = typename Traits::Order;
  static constexpr bool kAscending = Order::IsAscending();
  // Lanes per key.
  static constexpr size_t kLPK = Traits().LanesPerKey();

 public:
  // Dispatches to the Select or Sort check depending on `algo`.
  void operator()(Algo algo, const InputStats& input_stats,
                  const LaneType* output, size_t num_keys, size_t k_keys) {
    if (IsSelect(algo)) {
      CheckSelectOrder(input_stats, output, num_keys, k_keys);
    } else {
      CheckSortedOrder(algo, input_stats, output, num_keys, k_keys);
    }
  }

 private:
  // For full or partial sorts: ensures keys are in sorted order.
  // NOTE(review): `num_lanes - kLPK` and `k - kLPK` are unsigned. With
  // num_keys == 0 or k_keys == 0 they wrap around, making the loop bound or
  // the order check effectively unbounded - confirm callers never pass zero.
  void CheckSortedOrder(const Algo algo,
                        const InputStats& input_stats,
                        const LaneType* output, const size_t num_keys,
                        const size_t k_keys) {
    const Traits st;
    const CappedTag d;
    const size_t num_lanes = num_keys * kLPK;
    const size_t k = k_keys * kLPK;
    const char* algo_type = IsPartialSort(algo) ? "PartialSort" : "Sort";

    InputStats output_stats;
    // Even for partial sorts, loop over all keys to verify none disappeared.
    // The loop stops one key early because each step compares key i with
    // key i + 1; the last key is added to the stats after the loop.
    for (size_t i = 0; i < num_lanes - kLPK; i += kLPK) {
      output_stats.Notify(output[i]);
      if (kLPK == 2) output_stats.Notify(output[i + 1]);

      // Only check the first k_keys (== num_keys for a full sort).
      // Reverse order instead of checking !Compare1 so we accept equal keys.
      if (i < k - kLPK && st.Compare1(output + i + kLPK, output + i)) {
        Print(d, " cur", st.SetKey(d, &output[i]));
        Print(d, "next", st.SetKey(d, &output[i + kLPK]));
        HWY_ABORT(
            "%s %s asc=%d: wrong order at %zu, k_keys=%zu, num_keys=%zu\n",
            algo_type, st.KeyString(), kAscending, i / kLPK, k_keys, num_keys);
      }
    }
    output_stats.Notify(output[num_lanes - kLPK]);
    if (kLPK == 2) output_stats.Notify(output[num_lanes - kLPK + 1]);

    HWY_ASSERT(input_stats == output_stats);
  }

  // Ensures keys below index k_keys are less, and all above are greater.
  // NOTE(review): same unsigned wrap-around as above when num_keys == 0. The
  // loop also never visits the last key, so its position relative to the
  // k-th key is not checked - confirm that is acceptable.
  void CheckSelectOrder(const InputStats& input_stats,
                        const LaneType* output, const size_t num_keys,
                        const size_t k_keys) {
    const Traits st;
    const CappedTag d;
    const size_t num_lanes = num_keys * kLPK;
    const size_t k = k_keys * kLPK;

    InputStats output_stats;
    for (size_t i = 0; i < num_lanes - kLPK; i += kLPK) {
      output_stats.Notify(output[i]);
      if (kLPK == 2) output_stats.Notify(output[i + 1]);
      // Reverse order instead of checking !Compare1 so we accept equal keys.
      // Left of k: fail if the k-th key compares before key i.
      // At or right of k: fail if key i compares before the k-th key.
      if (i < k ? st.Compare1(output + k, output + i)
                : st.Compare1(output + i, output + k)) {
        Print(d, "cur", st.SetKey(d, &output[i]));
        Print(d, "kth", st.SetKey(d, &output[k]));
        HWY_ABORT(
            "Select %s asc=%d: wrong order at %zu, k_keys=%zu, num_keys=%zu\n",
            st.KeyString(), kAscending, i / kLPK, k_keys, num_keys);
      }
    }
    output_stats.Notify(output[num_lanes - kLPK]);
    if (kLPK == 2) output_stats.Notify(output[num_lanes - kLPK + 1]);

    HWY_ASSERT(input_stats == output_stats);
  }
};

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();
// NOTE(review): in this copy the `<...>` argument lists appear to have been
// stripped (bare `template`, `BufBytes(vector_size)`, `ScalableTag;`). Code
// tokens below are reproduced as found; restore the arguments from upstream
// before compiling.

namespace hwy {

// Based on https://github.com/numpy/numpy/issues/16313#issuecomment-641897028
//
// Pseudo-random generator step: updates the three 64-bit words at `state` in
// place and returns the next 64-bit value. state[2] acts as a counter that is
// incremented on every call.
static HWY_INLINE uint64_t RandomBits(uint64_t* HWY_RESTRICT state) {
  const uint64_t a = state[0];
  const uint64_t b = state[1];
  const uint64_t w = state[2] + 1;
  const uint64_t next = a ^ w;
  state[0] = (b + (b << 3)) ^ (b >> 11);
  // 64-bit rotate-left of b by 24.
  const uint64_t rot = (b << 24) | (b >> 40);
  state[1] = rot + next;
  state[2] = w;
  return next;
}

// Internal constants - these are to avoid magic numbers/literals and cannot be
// changed without also changing the associated code.
struct SortConstants {
  // SortingNetwork reshapes its input into a matrix. This is the maximum number
  // of *lanes* per vector. Must be at least 8 because SortSamples assumes the
  // sorting network can handle 128 bytes with 8 rows, so 16 bytes per vector,
  // which means 8 lanes for 16-bit types.
#if HWY_COMPILER_MSVC || HWY_IS_DEBUG_BUILD
  static constexpr size_t kMaxCols = 8;  // avoid build timeout/stack overflow
#else
  static constexpr size_t kMaxCols = 16;  // enough for u32 in 512-bit vector
#endif

  // 16 rows is a compromise between using the 32 AVX-512/SVE/RVV registers,
  // fitting within 16 AVX2 registers with only a few spills, keeping BaseCase
  // code size reasonable, and minimizing the extra logN factor for larger
  // networks (for which only loose upper bounds on size are known).
  static constexpr size_t kMaxRows = 16;

  // Template argument ensures there is no actual division instruction.
  // Returns rows * columns: kMaxRows rows (or 8 when N / kLPK < 4) times
  // min(N, kMaxCols) columns.
  template
  static constexpr HWY_INLINE size_t BaseCaseNumLanes(size_t N) {
    // We use 8, 8x2, 8x4, and 16x{4..} networks, in units of keys. For N/kLPK
    // < 4, we cannot use the 16-row networks.
    return (((N / kLPK) >= 4) ? kMaxRows : 8) * HWY_MIN(N, kMaxCols);
  }

  // Unrolling is important (pipelining and amortizing branch mispredictions);
  // 2x is sufficient to reach full memory bandwidth on SKX in Partition, but
  // somewhat slower for sorting than 4x.
  //
  // To change, must also update left + 3 * N etc. in the loop.
  static constexpr size_t kPartitionUnroll = 4;

  // Chunk := group of keys loaded for sampling a pivot. Matches the typical
  // cache line size of 64 bytes to get maximum benefit per L2 miss. Sort()
  // ensures vectors are no larger than that, so this can be independent of the
  // vector size and thus constexpr.
  // `sizeof_t` is the lane size in bytes; the result is lanes per 64 bytes.
  static constexpr HWY_INLINE size_t LanesPerChunk(size_t sizeof_t) {
    return 64 / sizeof_t;
  }

  // Number of lanes reserved for samples: two chunks.
  template
  static constexpr HWY_INLINE size_t SampleLanes() {
    return 2 * LanesPerChunk(sizeof(T));  // Stored samples
  }

  // Number of buffer lanes Partition needs for vectors of N lanes.
  static constexpr HWY_INLINE size_t PartitionBufNum(size_t N) {
    // The main loop reads kPartitionUnroll vectors, and first loads from
    // both left and right beforehand, so it requires 2 * kPartitionUnroll
    // vectors. To handle amounts between that and BaseCaseNumLanes(), we
    // partition up 3 * kPartitionUnroll + 1 vectors into a two-part buffer.
    return 2 * (3 * kPartitionUnroll + 1) * N;
  }

  // Max across the three buffer usages.
  // Result is in lanes, for vectors of N lanes.
  template
  static constexpr HWY_INLINE size_t BufNum(size_t N) {
    // BaseCase may write one padding vector, and SortSamples uses the space
    // after samples as the buffer.
    return HWY_MAX(SampleLanes() + BaseCaseNumLanes(N) + N,
                   PartitionBufNum(N));
  }

  // Translates vector_size to lanes and returns size in bytes.
  // `vector_size` is in bytes.
  template
  static constexpr HWY_INLINE size_t BufBytes(size_t vector_size) {
    return BufNum(vector_size / sizeof(T)) * sizeof(T);
  }

  // Returns max for any type.
  // For kLPK != 2 this is the maximum of three BufBytes instantiations.
  template
  static constexpr HWY_INLINE size_t MaxBufBytes(size_t vector_size) {
    // If 2 lanes per key, it's a 128-bit key with u64 lanes.
    return kLPK == 2 ? BufBytes(vector_size)
                     : HWY_MAX((BufBytes(vector_size)),
                               HWY_MAX((BufBytes(vector_size)),
                                       (BufBytes(vector_size))));
  }
};

// Compile-time guard: the buffer for 64-byte vectors must not exceed 1664
// bytes for either 1 or 2 lanes per key.
static_assert(SortConstants::MaxBufBytes<1>(64) <= 1664, "Unexpectedly high");
static_assert(SortConstants::MaxBufBytes<2>(64) <= 1664, "Unexpectedly high");

}  // namespace hwy

// vqsort isn't available on HWY_SCALAR, and builds time out on MSVC opt and
// Armv7 debug, and Armv8 GCC 11 asan hits an internal compiler error likely
// due to https://gcc.gnu.org/bugzilla/show_bug.cgi?id=97696. Armv8 Clang
// hwasan/msan/tsan/asan also fail to build SVE (b/335157772), and SVE2
// hits a compiler crash, though SVE2_128 is fine.
// Both macros are redefined per target, hence the #undef first.
#undef VQSORT_ENABLED
#undef VQSORT_COMPILER_COMPATIBLE

// 0 for the compiler/architecture/build-mode combinations listed below.
#if (HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD) || \
    (HWY_ARCH_ARM_V7 && HWY_IS_DEBUG_BUILD) ||    \
    (HWY_ARCH_ARM_A64 && HWY_IS_ASAN) ||          \
    (HWY_ARCH_RISCV && HWY_COMPILER_GCC_ACTUAL < 1400)
#define VQSORT_COMPILER_COMPATIBLE 0
#else
#define VQSORT_COMPILER_COMPATIBLE 1
#endif

// 0 for the scalar target, incompatible compilers, the SVE cases below, or
// when the user defines HWY_DISABLE_VQSORT.
#if (HWY_TARGET == HWY_SCALAR) || !VQSORT_COMPILER_COMPATIBLE ||   \
    ((HWY_TARGET & HWY_ALL_SVE) && defined(SAFESTACK_SANITIZER)) || \
    ((HWY_TARGET & HWY_ALL_SVE) && HWY_HAVE_SCALABLE) ||            \
    defined(HWY_DISABLE_VQSORT)
#define VQSORT_ENABLED 0
#else
#define VQSORT_ENABLED 1
#endif

namespace hwy {
namespace HWY_NAMESPACE {

// Default tag / vector width selector.
#if HWY_TARGET == HWY_RVV
// Use LMUL = 1/2; for SEW=64 this ends up emulated via VSETVLI.
template
using SortTag = ScalableTag;
#else
template
using SortTag = ScalableTag;
#endif

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
+template +struct SharedTraits : public Base { + using SharedTraitsForSortingNetwork = + SharedTraits; + + // Conditionally swaps lane 0 with 2, 1 with 3 etc. + template + HWY_INLINE Vec SortPairsDistance2(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->SwapAdjacentPairs(d, v); + base->Sort2(d, v, swapped); + return base->OddEvenPairs(d, swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 8 keys. + template + HWY_INLINE Vec SortPairsReverse8(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys8(d, v); + base->Sort2(d, v, swapped); + return base->OddEvenQuads(d, swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 16 keys. + template + HWY_INLINE Vec SortPairsReverse16(D d, Vec v) const { + const Base* base = static_cast(this); + static_assert(Constants::kMaxCols <= 16, "Need actual Reverse16"); + Vec swapped = base->ReverseKeys(d, v); + base->Sort2(d, v, swapped); + return ConcatUpperLower(d, swapped, v); // 8 = half of the vector + } +}; + +// ------------------------------ Sorting network + +// Sorting networks for independent columns in 2, 4 and 8 vectors from +// https://bertdobbelaere.github.io/sorting_networks.html. 
+ +template > +HWY_INLINE void Sort2(D d, Traits st, V& v0, V& v1) { + st.Sort2(d, v0, v1); +} + +template > +HWY_INLINE void Sort4(D d, Traits st, V& v0, V& v1, V& v2, V& v3) { + st.Sort2(d, v0, v2); + st.Sort2(d, v1, v3); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v1, v2); +} + +template > +HWY_INLINE void Sort8(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7) { + st.Sort2(d, v0, v2); + st.Sort2(d, v1, v3); + st.Sort2(d, v4, v6); + st.Sort2(d, v5, v7); + + st.Sort2(d, v0, v4); + st.Sort2(d, v1, v5); + st.Sort2(d, v2, v6); + st.Sort2(d, v3, v7); + + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + + st.Sort2(d, v2, v4); + st.Sort2(d, v3, v5); + + st.Sort2(d, v1, v4); + st.Sort2(d, v3, v6); + + st.Sort2(d, v1, v2); + st.Sort2(d, v3, v4); + st.Sort2(d, v5, v6); +} + +// (Green's irregular) sorting network for independent columns in 16 vectors. +template > +HWY_INLINE void Sort16(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5, + V& v6, V& v7, V& v8, V& v9, V& va, V& vb, V& vc, V& vd, + V& ve, V& vf) { + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + st.Sort2(d, v0, v2); + st.Sort2(d, v1, v3); + st.Sort2(d, v4, v6); + st.Sort2(d, v5, v7); + st.Sort2(d, v8, va); + st.Sort2(d, v9, vb); + st.Sort2(d, vc, ve); + st.Sort2(d, vd, vf); + st.Sort2(d, v0, v4); + st.Sort2(d, v1, v5); + st.Sort2(d, v2, v6); + st.Sort2(d, v3, v7); + st.Sort2(d, v8, vc); + st.Sort2(d, v9, vd); + st.Sort2(d, va, ve); + st.Sort2(d, vb, vf); + st.Sort2(d, v0, v8); + st.Sort2(d, v1, v9); + st.Sort2(d, v2, va); + st.Sort2(d, v3, vb); + st.Sort2(d, v4, vc); + st.Sort2(d, v5, vd); + st.Sort2(d, v6, ve); + st.Sort2(d, v7, vf); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v3, vc); + st.Sort2(d, v7, vb); + st.Sort2(d, vd, ve); + st.Sort2(d, v4, v8); + st.Sort2(d, 
v1, v2); + st.Sort2(d, v1, v4); + st.Sort2(d, v7, vd); + st.Sort2(d, v2, v8); + st.Sort2(d, vb, ve); + st.Sort2(d, v2, v4); + st.Sort2(d, v5, v6); + st.Sort2(d, v9, va); + st.Sort2(d, vb, vd); + st.Sort2(d, v3, v8); + st.Sort2(d, v7, vc); + st.Sort2(d, v3, v5); + st.Sort2(d, v6, v8); + st.Sort2(d, v7, v9); + st.Sort2(d, va, vc); + st.Sort2(d, v3, v4); + st.Sort2(d, v5, v6); + st.Sort2(d, v7, v8); + st.Sort2(d, v9, va); + st.Sort2(d, vb, vc); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); +} + +// ------------------------------ Merging networks + +// Blacher's hybrid bitonic/odd-even networks, generated by print_network.cc. +// For acceptable performance, these must be inlined, otherwise vectors are +// loaded from the stack. The kKeysPerVector allows calling from generic code +// but skipping the functions when vectors have too few lanes for +// st.SortPairsDistance1 to compile. `if constexpr` in the caller would also +// work, but is not available in C++11. We write out the (unused) argument types +// rather than `...` because GCC 9 (but not 10) fails to compile with `...`. +// TODO: use C++17. 
+ +template +HWY_INLINE void Merge8x2(D, Traits, V, V, V, V, V, V, V, V) {} +template +HWY_INLINE void Merge8x4(D, Traits, V, V, V, V, V, V, V, V) {} + +template +HWY_INLINE void Merge16x2(D, Traits, V, V, V, V, V, V, V, V, V, V, V, V, V, V, + V, V) {} +template +HWY_INLINE void Merge16x4(D, Traits, V, V, V, V, V, V, V, V, V, V, V, V, V, V, + V, V) {} +template +HWY_INLINE void Merge16x8(D, Traits, V, V, V, V, V, V, V, V, V, V, V, V, V, V, + V, V) {} +template +HWY_INLINE void Merge16x16(D, Traits, V, V, V, V, V, V, V, V, V, V, V, V, V, V, + V, V) {} + +template , + HWY_IF_LANES_GT(kKeysPerVector, 1)> +HWY_INLINE void Merge8x2(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7) { + v7 = st.ReverseKeys2(d, v7); + v6 = st.ReverseKeys2(d, v6); + v5 = st.ReverseKeys2(d, v5); + v4 = st.ReverseKeys2(d, v4); + st.Sort2(d, v0, v7); + st.Sort2(d, v1, v6); + st.Sort2(d, v2, v5); + st.Sort2(d, v3, v4); + + v3 = st.ReverseKeys2(d, v3); + v2 = st.ReverseKeys2(d, v2); + v7 = st.ReverseKeys2(d, v7); + v6 = st.ReverseKeys2(d, v6); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + + v1 = st.ReverseKeys2(d, v1); + v3 = st.ReverseKeys2(d, v3); + v5 = st.ReverseKeys2(d, v5); + v7 = st.ReverseKeys2(d, v7); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); +} + +template , + HWY_IF_LANES_GT(kKeysPerVector, 2)> +HWY_INLINE void Merge8x4(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7) { + v7 = st.ReverseKeys4(d, v7); + v6 = st.ReverseKeys4(d, v6); + v5 = st.ReverseKeys4(d, v5); + v4 = st.ReverseKeys4(d, v4); + st.Sort2(d, v0, v7); + st.Sort2(d, v1, v6); + 
st.Sort2(d, v2, v5); + st.Sort2(d, v3, v4); + + v3 = st.ReverseKeys4(d, v3); + v2 = st.ReverseKeys4(d, v2); + v7 = st.ReverseKeys4(d, v7); + v6 = st.ReverseKeys4(d, v6); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + + v1 = st.ReverseKeys4(d, v1); + v3 = st.ReverseKeys4(d, v3); + v5 = st.ReverseKeys4(d, v5); + v7 = st.ReverseKeys4(d, v7); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + + v0 = st.SortPairsReverse4(d, v0); + v1 = st.SortPairsReverse4(d, v1); + v2 = st.SortPairsReverse4(d, v2); + v3 = st.SortPairsReverse4(d, v3); + v4 = st.SortPairsReverse4(d, v4); + v5 = st.SortPairsReverse4(d, v5); + v6 = st.SortPairsReverse4(d, v6); + v7 = st.SortPairsReverse4(d, v7); + + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); +} + +// Only used by the now-deprecated SortingNetwork(). 
+template , + HWY_IF_LANES_GT(kKeysPerVector, 1)> +HWY_INLINE void Merge16x2(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb, + V& vc, V& vd, V& ve, V& vf) { + vf = st.ReverseKeys2(d, vf); + ve = st.ReverseKeys2(d, ve); + vd = st.ReverseKeys2(d, vd); + vc = st.ReverseKeys2(d, vc); + vb = st.ReverseKeys2(d, vb); + va = st.ReverseKeys2(d, va); + v9 = st.ReverseKeys2(d, v9); + v8 = st.ReverseKeys2(d, v8); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + + v7 = st.ReverseKeys2(d, v7); + v6 = st.ReverseKeys2(d, v6); + v5 = st.ReverseKeys2(d, v5); + v4 = st.ReverseKeys2(d, v4); + vf = st.ReverseKeys2(d, vf); + ve = st.ReverseKeys2(d, ve); + vd = st.ReverseKeys2(d, vd); + vc = st.ReverseKeys2(d, vc); + st.Sort2(d, v0, v7); + st.Sort2(d, v1, v6); + st.Sort2(d, v2, v5); + st.Sort2(d, v3, v4); + st.Sort2(d, v8, vf); + st.Sort2(d, v9, ve); + st.Sort2(d, va, vd); + st.Sort2(d, vb, vc); + + v3 = st.ReverseKeys2(d, v3); + v2 = st.ReverseKeys2(d, v2); + v7 = st.ReverseKeys2(d, v7); + v6 = st.ReverseKeys2(d, v6); + vb = st.ReverseKeys2(d, vb); + va = st.ReverseKeys2(d, va); + vf = st.ReverseKeys2(d, vf); + ve = st.ReverseKeys2(d, ve); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + + v1 = st.ReverseKeys2(d, v1); + v3 = st.ReverseKeys2(d, v3); + v5 = st.ReverseKeys2(d, v5); + v7 = st.ReverseKeys2(d, v7); + v9 = st.ReverseKeys2(d, v9); + vb = st.ReverseKeys2(d, vb); + vd = st.ReverseKeys2(d, vd); + vf = st.ReverseKeys2(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + + v0 = st.SortPairsDistance1(d, v0); + v1 = 
st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +template , + HWY_IF_LANES_GT(kKeysPerVector, 2)> +HWY_INLINE void Merge16x4(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb, + V& vc, V& vd, V& ve, V& vf) { + vf = st.ReverseKeys4(d, vf); + ve = st.ReverseKeys4(d, ve); + vd = st.ReverseKeys4(d, vd); + vc = st.ReverseKeys4(d, vc); + vb = st.ReverseKeys4(d, vb); + va = st.ReverseKeys4(d, va); + v9 = st.ReverseKeys4(d, v9); + v8 = st.ReverseKeys4(d, v8); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + + v7 = st.ReverseKeys4(d, v7); + v6 = st.ReverseKeys4(d, v6); + v5 = st.ReverseKeys4(d, v5); + v4 = st.ReverseKeys4(d, v4); + vf = st.ReverseKeys4(d, vf); + ve = st.ReverseKeys4(d, ve); + vd = st.ReverseKeys4(d, vd); + vc = st.ReverseKeys4(d, vc); + st.Sort2(d, v0, v7); + st.Sort2(d, v1, v6); + st.Sort2(d, v2, v5); + st.Sort2(d, v3, v4); + st.Sort2(d, v8, vf); + st.Sort2(d, v9, ve); + st.Sort2(d, va, vd); + st.Sort2(d, vb, vc); + + v3 = st.ReverseKeys4(d, v3); + v2 = st.ReverseKeys4(d, v2); + v7 = st.ReverseKeys4(d, v7); + v6 = st.ReverseKeys4(d, v6); + vb = st.ReverseKeys4(d, vb); + va = st.ReverseKeys4(d, va); + vf = st.ReverseKeys4(d, vf); + ve = st.ReverseKeys4(d, ve); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + 
st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + + v1 = st.ReverseKeys4(d, v1); + v3 = st.ReverseKeys4(d, v3); + v5 = st.ReverseKeys4(d, v5); + v7 = st.ReverseKeys4(d, v7); + v9 = st.ReverseKeys4(d, v9); + vb = st.ReverseKeys4(d, vb); + vd = st.ReverseKeys4(d, vd); + vf = st.ReverseKeys4(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + + v0 = st.SortPairsReverse4(d, v0); + v1 = st.SortPairsReverse4(d, v1); + v2 = st.SortPairsReverse4(d, v2); + v3 = st.SortPairsReverse4(d, v3); + v4 = st.SortPairsReverse4(d, v4); + v5 = st.SortPairsReverse4(d, v5); + v6 = st.SortPairsReverse4(d, v6); + v7 = st.SortPairsReverse4(d, v7); + v8 = st.SortPairsReverse4(d, v8); + v9 = st.SortPairsReverse4(d, v9); + va = st.SortPairsReverse4(d, va); + vb = st.SortPairsReverse4(d, vb); + vc = st.SortPairsReverse4(d, vc); + vd = st.SortPairsReverse4(d, vd); + ve = st.SortPairsReverse4(d, ve); + vf = st.SortPairsReverse4(d, vf); + + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +template , + HWY_IF_LANES_GT(kKeysPerVector, 4)> +HWY_INLINE void Merge16x8(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb, + V& vc, V& vd, V& ve, V& vf) { + vf = st.ReverseKeys8(d, vf); + ve = st.ReverseKeys8(d, ve); + vd = st.ReverseKeys8(d, vd); + vc = 
st.ReverseKeys8(d, vc); + vb = st.ReverseKeys8(d, vb); + va = st.ReverseKeys8(d, va); + v9 = st.ReverseKeys8(d, v9); + v8 = st.ReverseKeys8(d, v8); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); + st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + + v7 = st.ReverseKeys8(d, v7); + v6 = st.ReverseKeys8(d, v6); + v5 = st.ReverseKeys8(d, v5); + v4 = st.ReverseKeys8(d, v4); + vf = st.ReverseKeys8(d, vf); + ve = st.ReverseKeys8(d, ve); + vd = st.ReverseKeys8(d, vd); + vc = st.ReverseKeys8(d, vc); + st.Sort2(d, v0, v7); + st.Sort2(d, v1, v6); + st.Sort2(d, v2, v5); + st.Sort2(d, v3, v4); + st.Sort2(d, v8, vf); + st.Sort2(d, v9, ve); + st.Sort2(d, va, vd); + st.Sort2(d, vb, vc); + + v3 = st.ReverseKeys8(d, v3); + v2 = st.ReverseKeys8(d, v2); + v7 = st.ReverseKeys8(d, v7); + v6 = st.ReverseKeys8(d, v6); + vb = st.ReverseKeys8(d, vb); + va = st.ReverseKeys8(d, va); + vf = st.ReverseKeys8(d, vf); + ve = st.ReverseKeys8(d, ve); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + + v1 = st.ReverseKeys8(d, v1); + v3 = st.ReverseKeys8(d, v3); + v5 = st.ReverseKeys8(d, v5); + v7 = st.ReverseKeys8(d, v7); + v9 = st.ReverseKeys8(d, v9); + vb = st.ReverseKeys8(d, vb); + vd = st.ReverseKeys8(d, vd); + vf = st.ReverseKeys8(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + + v0 = st.SortPairsReverse8(d, v0); + v1 = st.SortPairsReverse8(d, v1); + v2 = st.SortPairsReverse8(d, v2); + v3 = st.SortPairsReverse8(d, v3); + v4 = st.SortPairsReverse8(d, v4); + v5 = st.SortPairsReverse8(d, v5); + v6 = st.SortPairsReverse8(d, v6); + v7 = st.SortPairsReverse8(d, v7); + v8 = st.SortPairsReverse8(d, v8); + v9 = st.SortPairsReverse8(d, v9); + 
va = st.SortPairsReverse8(d, va); + vb = st.SortPairsReverse8(d, vb); + vc = st.SortPairsReverse8(d, vc); + vd = st.SortPairsReverse8(d, vd); + ve = st.SortPairsReverse8(d, ve); + vf = st.SortPairsReverse8(d, vf); + + v0 = st.SortPairsDistance2(d, v0); + v1 = st.SortPairsDistance2(d, v1); + v2 = st.SortPairsDistance2(d, v2); + v3 = st.SortPairsDistance2(d, v3); + v4 = st.SortPairsDistance2(d, v4); + v5 = st.SortPairsDistance2(d, v5); + v6 = st.SortPairsDistance2(d, v6); + v7 = st.SortPairsDistance2(d, v7); + v8 = st.SortPairsDistance2(d, v8); + v9 = st.SortPairsDistance2(d, v9); + va = st.SortPairsDistance2(d, va); + vb = st.SortPairsDistance2(d, vb); + vc = st.SortPairsDistance2(d, vc); + vd = st.SortPairsDistance2(d, vd); + ve = st.SortPairsDistance2(d, ve); + vf = st.SortPairsDistance2(d, vf); + + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +// Unused on MSVC, see below +#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD + +template , + HWY_IF_LANES_GT(kKeysPerVector, 8)> +HWY_INLINE void Merge16x16(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, + V& v5, V& v6, V& v7, V& v8, V& v9, V& va, V& vb, + V& vc, V& vd, V& ve, V& vf) { + vf = st.ReverseKeys16(d, vf); + ve = st.ReverseKeys16(d, ve); + vd = st.ReverseKeys16(d, vd); + vc = st.ReverseKeys16(d, vc); + vb = st.ReverseKeys16(d, vb); + va = st.ReverseKeys16(d, va); + v9 = st.ReverseKeys16(d, v9); + v8 = st.ReverseKeys16(d, v8); + st.Sort2(d, v0, vf); + st.Sort2(d, v1, ve); 
+ st.Sort2(d, v2, vd); + st.Sort2(d, v3, vc); + st.Sort2(d, v4, vb); + st.Sort2(d, v5, va); + st.Sort2(d, v6, v9); + st.Sort2(d, v7, v8); + + v7 = st.ReverseKeys16(d, v7); + v6 = st.ReverseKeys16(d, v6); + v5 = st.ReverseKeys16(d, v5); + v4 = st.ReverseKeys16(d, v4); + vf = st.ReverseKeys16(d, vf); + ve = st.ReverseKeys16(d, ve); + vd = st.ReverseKeys16(d, vd); + vc = st.ReverseKeys16(d, vc); + st.Sort2(d, v0, v7); + st.Sort2(d, v1, v6); + st.Sort2(d, v2, v5); + st.Sort2(d, v3, v4); + st.Sort2(d, v8, vf); + st.Sort2(d, v9, ve); + st.Sort2(d, va, vd); + st.Sort2(d, vb, vc); + + v3 = st.ReverseKeys16(d, v3); + v2 = st.ReverseKeys16(d, v2); + v7 = st.ReverseKeys16(d, v7); + v6 = st.ReverseKeys16(d, v6); + vb = st.ReverseKeys16(d, vb); + va = st.ReverseKeys16(d, va); + vf = st.ReverseKeys16(d, vf); + ve = st.ReverseKeys16(d, ve); + st.Sort2(d, v0, v3); + st.Sort2(d, v1, v2); + st.Sort2(d, v4, v7); + st.Sort2(d, v5, v6); + st.Sort2(d, v8, vb); + st.Sort2(d, v9, va); + st.Sort2(d, vc, vf); + st.Sort2(d, vd, ve); + + v1 = st.ReverseKeys16(d, v1); + v3 = st.ReverseKeys16(d, v3); + v5 = st.ReverseKeys16(d, v5); + v7 = st.ReverseKeys16(d, v7); + v9 = st.ReverseKeys16(d, v9); + vb = st.ReverseKeys16(d, vb); + vd = st.ReverseKeys16(d, vd); + vf = st.ReverseKeys16(d, vf); + st.Sort2(d, v0, v1); + st.Sort2(d, v2, v3); + st.Sort2(d, v4, v5); + st.Sort2(d, v6, v7); + st.Sort2(d, v8, v9); + st.Sort2(d, va, vb); + st.Sort2(d, vc, vd); + st.Sort2(d, ve, vf); + + v0 = st.SortPairsReverse16(d, v0); + v1 = st.SortPairsReverse16(d, v1); + v2 = st.SortPairsReverse16(d, v2); + v3 = st.SortPairsReverse16(d, v3); + v4 = st.SortPairsReverse16(d, v4); + v5 = st.SortPairsReverse16(d, v5); + v6 = st.SortPairsReverse16(d, v6); + v7 = st.SortPairsReverse16(d, v7); + v8 = st.SortPairsReverse16(d, v8); + v9 = st.SortPairsReverse16(d, v9); + va = st.SortPairsReverse16(d, va); + vb = st.SortPairsReverse16(d, vb); + vc = st.SortPairsReverse16(d, vc); + vd = st.SortPairsReverse16(d, vd); + ve = 
st.SortPairsReverse16(d, ve); + vf = st.SortPairsReverse16(d, vf); + + v0 = st.SortPairsDistance4(d, v0); + v1 = st.SortPairsDistance4(d, v1); + v2 = st.SortPairsDistance4(d, v2); + v3 = st.SortPairsDistance4(d, v3); + v4 = st.SortPairsDistance4(d, v4); + v5 = st.SortPairsDistance4(d, v5); + v6 = st.SortPairsDistance4(d, v6); + v7 = st.SortPairsDistance4(d, v7); + v8 = st.SortPairsDistance4(d, v8); + v9 = st.SortPairsDistance4(d, v9); + va = st.SortPairsDistance4(d, va); + vb = st.SortPairsDistance4(d, vb); + vc = st.SortPairsDistance4(d, vc); + vd = st.SortPairsDistance4(d, vd); + ve = st.SortPairsDistance4(d, ve); + vf = st.SortPairsDistance4(d, vf); + + v0 = st.SortPairsDistance2(d, v0); + v1 = st.SortPairsDistance2(d, v1); + v2 = st.SortPairsDistance2(d, v2); + v3 = st.SortPairsDistance2(d, v3); + v4 = st.SortPairsDistance2(d, v4); + v5 = st.SortPairsDistance2(d, v5); + v6 = st.SortPairsDistance2(d, v6); + v7 = st.SortPairsDistance2(d, v7); + v8 = st.SortPairsDistance2(d, v8); + v9 = st.SortPairsDistance2(d, v9); + va = st.SortPairsDistance2(d, va); + vb = st.SortPairsDistance2(d, vb); + vc = st.SortPairsDistance2(d, vc); + vd = st.SortPairsDistance2(d, vd); + ve = st.SortPairsDistance2(d, ve); + vf = st.SortPairsDistance2(d, vf); + + v0 = st.SortPairsDistance1(d, v0); + v1 = st.SortPairsDistance1(d, v1); + v2 = st.SortPairsDistance1(d, v2); + v3 = st.SortPairsDistance1(d, v3); + v4 = st.SortPairsDistance1(d, v4); + v5 = st.SortPairsDistance1(d, v5); + v6 = st.SortPairsDistance1(d, v6); + v7 = st.SortPairsDistance1(d, v7); + v8 = st.SortPairsDistance1(d, v8); + v9 = st.SortPairsDistance1(d, v9); + va = st.SortPairsDistance1(d, va); + vb = st.SortPairsDistance1(d, vb); + vc = st.SortPairsDistance1(d, vc); + vd = st.SortPairsDistance1(d, vd); + ve = st.SortPairsDistance1(d, ve); + vf = st.SortPairsDistance1(d, vf); +} + +#endif // !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD + +// Reshapes `buf` into a matrix, sorts columns independently, and then merges +// into a 
sorted 1D array without transposing. +// +// DEPRECATED, use BaseCase() instead. +template +HWY_INLINE void SortingNetwork(Traits st, size_t cols, V& v0, V& v1, V& v2, + V& v3, V& v4, V& v5, V& v6, V& v7, V& v8, V& v9, + V& va, V& vb, V& vc, V& vd, V& ve, V& vf) { + // traits*-inl assume 'full' vectors (but still capped to kMaxCols). + const CappedTag d; + + HWY_DASSERT(cols <= Constants::kMaxCols); + + // The network width depends on the number of keys, not lanes. + constexpr size_t kLanesPerKey = st.LanesPerKey(); + const size_t keys = cols / kLanesPerKey; + constexpr size_t kMaxKeys = MaxLanes(d) / kLanesPerKey; + + Sort16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf); + + // Checking MaxLanes avoids generating HWY_ASSERT code for the unreachable + // code paths: if MaxLanes < 2, then keys <= cols < 2. + if (HWY_LIKELY(keys >= 2 && kMaxKeys >= 2)) { + Merge16x2(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, + vc, vd, ve, vf); + + if (HWY_LIKELY(keys >= 4 && kMaxKeys >= 4)) { + Merge16x4(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, + vc, vd, ve, vf); + + if (HWY_LIKELY(keys >= 8 && kMaxKeys >= 8)) { + Merge16x8(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, + vb, vc, vd, ve, vf); + + // Avoids build timeout. Must match #if condition in kMaxCols. +#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD + if (HWY_LIKELY(keys >= 16 && kMaxKeys >= 16)) { + Merge16x16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + va, vb, vc, vd, ve, vf); + + static_assert(Constants::kMaxCols <= 16, "Add more branches"); + } +#endif + } + } + } +} + +// As above, but loads from/stores to `buf`. This ensures full vectors are +// aligned, and enables loads/stores without bounds checks. +// +// DEPRECATED, use BaseCase() instead. +template +HWY_NOINLINE void SortingNetwork(Traits st, T* HWY_RESTRICT buf, size_t cols) { + // traits*-inl assume 'full' vectors (but still capped to kMaxCols). 
+ // However, for smaller arrays and sub-maximal `cols` we have overlapping + // loads where only the lowest `cols` are valid, and we skip Merge16 etc. + const CappedTag d; + using V = decltype(Zero(d)); + + HWY_DASSERT(cols <= Constants::kMaxCols); + + // These are aligned iff cols == Lanes(d). We prefer unaligned/non-constexpr + // offsets to duplicating this code for every value of cols. + static_assert(Constants::kMaxRows == 16, "Update loads/stores/args"); + V v0 = LoadU(d, buf + 0x0 * cols); + V v1 = LoadU(d, buf + 0x1 * cols); + V v2 = LoadU(d, buf + 0x2 * cols); + V v3 = LoadU(d, buf + 0x3 * cols); + V v4 = LoadU(d, buf + 0x4 * cols); + V v5 = LoadU(d, buf + 0x5 * cols); + V v6 = LoadU(d, buf + 0x6 * cols); + V v7 = LoadU(d, buf + 0x7 * cols); + V v8 = LoadU(d, buf + 0x8 * cols); + V v9 = LoadU(d, buf + 0x9 * cols); + V va = LoadU(d, buf + 0xa * cols); + V vb = LoadU(d, buf + 0xb * cols); + V vc = LoadU(d, buf + 0xc * cols); + V vd = LoadU(d, buf + 0xd * cols); + V ve = LoadU(d, buf + 0xe * cols); + V vf = LoadU(d, buf + 0xf * cols); + + SortingNetwork(st, cols, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, + vd, ve, vf); + + StoreU(v0, d, buf + 0x0 * cols); + StoreU(v1, d, buf + 0x1 * cols); + StoreU(v2, d, buf + 0x2 * cols); + StoreU(v3, d, buf + 0x3 * cols); + StoreU(v4, d, buf + 0x4 * cols); + StoreU(v5, d, buf + 0x5 * cols); + StoreU(v6, d, buf + 0x6 * cols); + StoreU(v7, d, buf + 0x7 * cols); + StoreU(v8, d, buf + 0x8 * cols); + StoreU(v9, d, buf + 0x9 * cols); + StoreU(va, d, buf + 0xa * cols); + StoreU(vb, d, buf + 0xb * cols); + StoreU(vc, d, buf + 0xc * cols); + StoreU(vd, d, buf + 0xd * cols); + StoreU(ve, d, buf + 0xe * cols); + StoreU(vf, d, buf + 0xf * cols); +} + +#else +template +struct SharedTraits : public Base {}; + +namespace detail { + +// Empty function to avoid a possible -Wpragma-clang-attribute warning if +// compiling with Clang +static HWY_INLINE HWY_MAYBE_UNUSED void HWY_CONCAT(UnusedSortingNetworksFunc, + __LINE__)() {} + 
+} // namespace detail + +#endif // VQSORT_ENABLED + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_SORTING_NETWORKS_TOGGLE diff --git a/lib/highway/hwy/contrib/sort/traits-inl.h b/lib/highway/hwy/contrib/sort/traits-inl.h new file mode 100644 index 00000000000..b02ed8b9ab0 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/traits-inl.h @@ -0,0 +1,624 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE +#endif + +#include +#include + +#include "hwy/contrib/sort/order.h" // SortDescending +#include "hwy/contrib/sort/shared-inl.h" // SortConstants +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +// Base class of both KeyLane variants +template +struct KeyLaneBase { + static constexpr bool Is128() { return false; } + constexpr size_t LanesPerKey() const { return 1; } + + // What type bench_sort should allocate for generating inputs. + using LaneType = LaneTypeArg; + // What type to pass to VQSort. 
+ using KeyType = KeyTypeArg; + + const char* KeyString() const { // label picked by the IsSame chain; "?" if none match + return IsSame() ? "f16" + : IsSame() ? "f32" + : IsSame() ? "f64" + : IsSame() ? "i16" + : IsSame() ? "i32" + : IsSame() ? "i64" + : IsSame() ? "u16" + : IsSame() ? "u32" + : IsSame() ? "u64" + : IsSame() ? "k+v=64" + : "?"; + } +}; + +// Wrapper functions so we can specialize for floats - infinity trumps +// HighestValue (the normal value with the largest magnitude). Must be outside +// Order* classes to enable SFINAE. + +template +Vec LargestSortValue(D d) { + return Inf(d); +} +template +Vec LargestSortValue(D d) { + return Set(d, hwy::HighestValue>()); +} + +template +Vec SmallestSortValue(D d) { + return Neg(Inf(d)); +} +template +Vec SmallestSortValue(D d) { + return Set(d, hwy::LowestValue>()); +} + +// Returns the next distinct larger value unless already +inf. +template +Vec LargerSortValue(D d, Vec v) { + HWY_DASSERT(AllFalse(d, IsNaN(v))); // we replaced all NaN with LastValue. + using T = TFromD; + const RebindToUnsigned du; + using VU = Vec; + + const VU vu = BitCast(du, Abs(v)); + + // The direction depends on the original sign. Integer comparison is cheaper + // than float comparison and treats -0 as 0 (so we return +epsilon). + const Mask was_pos = Le(BitCast(du, v), SignBit(du)); + // If positive, add 1, else -1. +#if HWY_ARCH_ARM_V7 + // Workaround for incorrect codegen. ~x - x is equivalent to x? 1 : -1. + const VU was_pos_u = VecFromMask(du, was_pos); + const VU add = Not(was_pos_u) - was_pos_u; +#else + using TU = TFromD; + const VU add = IfThenElse(was_pos, Set(du, 1u), Set(du, LimitsMax())); +#endif + // Prev/next integer is the prev/next value, even if mantissa under/overflows. + v = BitCast(d, Add(vu, add)); + // But we may have overflowed into inf or NaN; replace with inf if positive, + // but the largest (later negated!) value if the input was -inf. 
+ const Mask was_pos_f = RebindMask(d, was_pos); + v = IfThenElse(IsFinite(v), v, + IfThenElse(was_pos_f, Inf(d), Set(d, HighestValue()))); + // Restore the original sign - not via CopySignToAbs because we used a mask. + return IfThenElse(was_pos_f, v, Neg(v)); +} + +// Returns the next distinct smaller value unless already -inf. +template +Vec SmallerSortValue(D d, Vec v) { + HWY_DASSERT(AllFalse(d, IsNaN(v))); // we replaced all NaN with LastValue. + using T = TFromD; + const RebindToUnsigned du; + using VU = Vec; + using TU = TFromD; + + const VU vu = BitCast(du, Abs(v)); + + // The direction depends on the original sign. Float comparison because we + // want to treat 0 as -0 so we return -epsilon. + const Mask was_pos = Gt(v, Zero(d)); + // If positive, add -1, else 1. + const VU add = + IfThenElse(RebindMask(du, was_pos), Set(du, LimitsMax()), Set(du, 1)); + // Prev/next integer is the prev/next value, even if mantissa under/overflows. + v = BitCast(d, Add(vu, add)); + // But we may have overflowed into inf or NaN; replace with +inf (which will + // later be negated) if negative, but the largest value if the input was +inf. + v = IfThenElse(IsFinite(v), v, + IfThenElse(was_pos, Set(d, HighestValue()), Inf(d))); + // Restore the original sign - not via CopySignToAbs because we used a mask. + return IfThenElse(was_pos, v, Neg(v)); +} + +template +Vec LargerSortValue(D d, Vec v) { + return Add(v, Set(d, TFromD{1})); +} + +template +Vec SmallerSortValue(D d, Vec v) { + return Sub(v, Set(d, TFromD{1})); +} + +// Highway does not provide a lane type for 128-bit keys, so we use uint64_t +// along with an abstraction layer for single-lane vs. lane-pair, which is +// independent of the order. 
+template +struct KeyLane : public KeyLaneBase { + // For HeapSort + HWY_INLINE void Swap(LaneType* a, LaneType* b) const { + const LaneType temp = *a; + *a = *b; + *b = temp; + } + + template + HWY_INLINE V CompressKeys(V keys, M mask) const { + return CompressNot(keys, mask); + } + + // Broadcasts one key into a vector + template + HWY_INLINE Vec SetKey(D d, const LaneType* key) const { + return Set(d, *key); + } + + template + HWY_INLINE Mask EqualKeys(D /*tag*/, Vec a, Vec b) const { + return Eq(a, b); + } + + template + HWY_INLINE Mask NotEqualKeys(D /*tag*/, Vec a, Vec b) const { + return Ne(a, b); + } + + // For keys=lanes, any difference counts. + template + HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec diff) const { + // Must avoid floating-point comparisons (for -0) + const RebindToUnsigned du; + return AllTrue(du, Eq(BitCast(du, diff), Zero(du))); + } + + HWY_INLINE bool Equal1(const LaneType* a, const LaneType* b) const { + return *a == *b; + } + + template + HWY_INLINE Vec ReverseKeys(D d, Vec v) const { + return Reverse(d, v); + } + + template + HWY_INLINE Vec ReverseKeys2(D d, Vec v) const { + return Reverse2(d, v); + } + + template + HWY_INLINE Vec ReverseKeys4(D d, Vec v) const { + return Reverse4(d, v); + } + + template + HWY_INLINE Vec ReverseKeys8(D d, Vec v) const { + return Reverse8(d, v); + } + + template + HWY_INLINE Vec ReverseKeys16(D d, Vec v) const { + static_assert(SortConstants::kMaxCols <= 16, "Assumes u32x16 = 512 bit"); + return ReverseKeys(d, v); + } + + template + HWY_INLINE V OddEvenKeys(const V odd, const V even) const { + return OddEven(odd, even); + } + + template + HWY_INLINE Vec SwapAdjacentPairs(D d, const Vec v) const { + const Repartition du32; + return BitCast(d, Shuffle2301(BitCast(du32, v))); + } + template + HWY_INLINE Vec SwapAdjacentPairs(D /* tag */, const Vec v) const { + return Shuffle1032(v); + } + template + HWY_INLINE Vec SwapAdjacentPairs(D /* tag */, const Vec v) const { + return SwapAdjacentBlocks(v); + } 
+ + template + HWY_INLINE Vec SwapAdjacentQuads(D d, const Vec v) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide dw; +#else + const RepartitionToWide> dw; +#endif + return BitCast(d, SwapAdjacentPairs(dw, BitCast(dw, v))); + } + template + HWY_INLINE Vec SwapAdjacentQuads(D d, const Vec v) const { + // Assumes max vector size = 512 + return ConcatLowerUpper(d, v, v); + } + + template + HWY_INLINE Vec OddEvenPairs(D d, const Vec odd, + const Vec even) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide dw; +#else + const RepartitionToWide> dw; +#endif + return BitCast(d, OddEven(BitCast(dw, odd), BitCast(dw, even))); + } + template + HWY_INLINE Vec OddEvenPairs(D /* tag */, Vec odd, Vec even) const { + return OddEvenBlocks(odd, even); + } + + template + HWY_INLINE Vec OddEvenQuads(D d, Vec odd, Vec even) const { +#if HWY_HAVE_FLOAT64 // in case D is float32 + const RepartitionToWide dw; +#else + const RepartitionToWide> dw; +#endif + return BitCast(d, OddEvenPairs(dw, BitCast(dw, odd), BitCast(dw, even))); + } + template + HWY_INLINE Vec OddEvenQuads(D d, Vec odd, Vec even) const { + return ConcatUpperLower(d, odd, even); + } +}; + +// Anything order-related depends on the key traits *and* the order (see +// FirstOfLanes). We cannot implement just one Compare function because Lt128 +// only compiles if the lane type is u64. Thus we need either overloaded +// functions with a tag type, class specializations, or separate classes. +// We avoid overloaded functions because we want all functions to be callable +// from a SortTraits without per-function wrappers. Specializing would work, but +// we are anyway going to specialize at a higher level. +template +struct OrderAscending : public KeyLane { + // False indicates the entire key (i.e. lane) should be compared. KV stands + // for key-value. 
+ static constexpr bool IsKV() { return false; } + + using Order = SortAscending; + using OrderForSortingNetwork = OrderAscending; + + HWY_INLINE bool Compare1(const T* a, const T* b) const { return *a < *b; } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(a, b); + } + + // Two halves of Sort2, used in ScanMinMax. + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + T* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + T* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template + HWY_INLINE Vec FirstValue(D d) const { + return SmallestSortValue(d); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return LargestSortValue(d); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + return SmallerSortValue(d, v); + } +}; + +template +struct OrderDescending : public KeyLane { + // False indicates the entire key (i.e. lane) should be compared. KV stands + // for key-value. 
+ static constexpr bool IsKV() { return false; } + + using Order = SortDescending; + using OrderForSortingNetwork = OrderDescending; + + HWY_INLINE bool Compare1(const T* a, const T* b) const { return *b < *a; } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(b, a); + } + + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + T* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + T* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec FirstValue(D d) const { + return LargestSortValue(d); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return SmallestSortValue(d); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + return LargerSortValue(d, v); + } +}; + +struct KeyValue64 : public KeyLane { + // True indicates only part of the key (i.e. lane) should be compared. KV + // stands for key-value. + static constexpr bool IsKV() { return true; } + + template + HWY_INLINE Mask EqualKeys(D /*tag*/, Vec a, Vec b) const { + return Eq(ShiftRight<32>(a), ShiftRight<32>(b)); + } + + template + HWY_INLINE Mask NotEqualKeys(D /*tag*/, Vec a, Vec b) const { + return Ne(ShiftRight<32>(a), ShiftRight<32>(b)); + } + + HWY_INLINE bool Equal1(const uint64_t* a, const uint64_t* b) const { + return (*a >> 32) == (*b >> 32); + } + + // Only count differences in the actual key, not the value. 
+ template + HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec diff) const { + // Must avoid floating-point comparisons (for -0) + const RebindToUnsigned du; + const Vec zero = Zero(du); + const Vec keys = ShiftRight<32>(diff); // clear values + return AllTrue(du, Eq(BitCast(du, keys), zero)); + } +}; + +struct OrderAscendingKV64 : public KeyValue64 { + using Order = SortAscending; + using OrderForSortingNetwork = OrderAscending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const { + return (*a >> 32) < (*b >> 32); + } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(ShiftRight<32>(a), ShiftRight<32>(b)); + } + + // Not required to be stable (preserving the order of equivalent keys), so + // we can include the value in the comparison. + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + // Same as for regular lanes. 
+ template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::LowestValue>()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::HighestValue>()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + return Sub(v, Set(d, uint64_t{1} << 32)); + } +}; + +struct OrderDescendingKV64 : public KeyValue64 { + using Order = SortDescending; + using OrderForSortingNetwork = OrderDescending; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const { + return (*b >> 32) < (*a >> 32); + } + + template + HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { + return Lt(ShiftRight<32>(b), ShiftRight<32>(a)); + } + + // Not required to be stable (preserving the order of equivalent keys), so + // we can include the value in the comparison. + template + HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { + return Max(a, b); + } + + template + HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { + return Min(a, b); + } + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MaxOfLanes(d, v); + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + uint64_t* HWY_RESTRICT /* buf */) const { + return MinOfLanes(d, v); + } + + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::HighestValue>()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::LowestValue>()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + return Add(v, Set(d, uint64_t{1} << 32)); + } +}; + +// Shared code that depends on Order. +template +struct TraitsLane : public Base { + using TraitsForSortingNetwork = + TraitsLane; + + // For each lane i: replaces a[i] with the first and b[i] with the second + // according to Base. + // Corresponds to a conditional swap, which is one "node" of a sorting + // network. Min/Max are cheaper than compare + blend at least for integers. 
+ template + HWY_INLINE void Sort2(D d, Vec& a, Vec& b) const { + const Base* base = static_cast(this); + + const Vec a_copy = a; + // Prior to AVX3, there is no native 64-bit Min/Max, so they compile to 4 + // instructions. We can reduce it to a compare + 2 IfThenElse. +#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3 + if (sizeof(TFromD) == 8) { + const Mask cmp = base->Compare(d, a, b); + a = IfThenElse(cmp, a, b); + b = IfThenElse(cmp, b, a_copy); + return; + } +#endif + a = base->First(d, a, b); + b = base->Last(d, a_copy, b); + } + + // Conditionally swaps even-numbered lanes with their odd-numbered neighbor. + template + HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys2(d, v); + // Further to the above optimization, Sort2+OddEvenKeys compile to four + // instructions; we can save one by combining two blends. +#if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3 + const Vec cmp = VecFromMask(d, base->Compare(d, v, swapped)); + return IfVecThenElse(DupOdd(cmp), swapped, v); +#else + Sort2(d, v, swapped); + return base->OddEvenKeys(swapped, v); +#endif + } + + // (See above - we use Sort2 for non-64-bit types.) + template + HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys2(d, v); + Sort2(d, v, swapped); + return base->OddEvenKeys(swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of 4 keys. + template + HWY_INLINE Vec SortPairsReverse4(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys4(d, v); + Sort2(d, v, swapped); + return base->OddEvenPairs(d, swapped, v); + } + + // Conditionally swaps lane 0 with 4, 1 with 5 etc. 
+ template + HWY_INLINE Vec SortPairsDistance4(D d, Vec v) const { + const Base* base = static_cast(this); + Vec swapped = base->SwapAdjacentQuads(d, v); + // Only used in Merge16, so this will not be used on AVX2 (which only has 4 + // u64 lanes), so skip the above optimization for 64-bit AVX2. + Sort2(d, v, swapped); + return base->OddEvenQuads(d, swapped, v); + } +}; + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE diff --git a/lib/highway/hwy/contrib/sort/traits128-inl.h b/lib/highway/hwy/contrib/sort/traits128-inl.h new file mode 100644 index 00000000000..f6fa9379521 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/traits128-inl.h @@ -0,0 +1,549 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Per-target +#if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#else +#define HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE +#endif + +#include +#include + +#include "hwy/contrib/sort/order.h" // SortDescending +#include "hwy/contrib/sort/shared-inl.h" +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +// Also used by HeapSort, so do not require VQSORT_ENABLED. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +// Highway does not provide a lane type for 128-bit keys, so we use uint64_t +// along with an abstraction layer for single-lane vs. lane-pair, which is +// independent of the order. +struct KeyAny128 { + static constexpr bool Is128() { return true; } + constexpr size_t LanesPerKey() const { return 2; } + + // What type bench_sort should allocate for generating inputs. + using LaneType = uint64_t; + // KeyType and KeyString are defined by derived classes. + + HWY_INLINE void Swap(LaneType* a, LaneType* b) const { + const FixedTag d; + const auto temp = LoadU(d, a); + StoreU(LoadU(d, b), d, a); + StoreU(temp, d, b); + } + + template + HWY_INLINE V CompressKeys(V keys, M mask) const { + return CompressBlocksNot(keys, mask); + } + + template + HWY_INLINE Vec SetKey(D d, const TFromD* key) const { + return LoadDup128(d, key); + } + + template + HWY_INLINE Vec ReverseKeys(D d, Vec v) const { + return ReverseBlocks(d, v); + } + + template + HWY_INLINE Vec ReverseKeys2(D /* tag */, const Vec v) const { + HWY_DASSERT(Lanes(D()) >= 4); // at least 2 keys + return SwapAdjacentBlocks(v); + } + + // Only called for 4 keys because we do not support >512-bit vectors. 
+ template + HWY_INLINE Vec ReverseKeys4(D d, const Vec v) const { + HWY_DASSERT(Lanes(D()) == 8); // exactly 4 keys: the 512-bit limit + return ReverseKeys(d, v); + } + + // Only called for 4 keys because we do not support >512-bit vectors. + template + HWY_INLINE Vec OddEvenPairs(D d, const Vec odd, + const Vec even) const { + HWY_DASSERT(Lanes(D()) == 8); // exactly 4 keys: the 512-bit limit + return ConcatUpperLower(d, odd, even); + } + + template + HWY_INLINE V OddEvenKeys(const V odd, const V even) const { + return OddEvenBlocks(odd, even); + } + + template + HWY_INLINE Vec ReverseKeys8(D, Vec) const { + HWY_ASSERT(0); // not supported: would require 1024-bit vectors + } + + template + HWY_INLINE Vec ReverseKeys16(D, Vec) const { + HWY_ASSERT(0); // not supported: would require 2048-bit vectors + } + + // This is only called for 8/16 col networks (not supported). + template + HWY_INLINE Vec SwapAdjacentPairs(D, Vec) const { + HWY_ASSERT(0); + } + + // This is only called for 16 col networks (not supported). + template + HWY_INLINE Vec SwapAdjacentQuads(D, Vec) const { + HWY_ASSERT(0); + } + + // This is only called for 8 col networks (not supported). + template + HWY_INLINE Vec OddEvenQuads(D, Vec, Vec) const { + HWY_ASSERT(0); + } +}; + +// Base class shared between OrderAscending128, OrderDescending128. +struct Key128 : public KeyAny128 { + // False indicates the entire key should be compared. KV means key-value. + static constexpr bool IsKV() { return false; } + + // What type to pass to VQSort. + using KeyType = hwy::uint128_t; + + const char* KeyString() const { return "U128"; } + + template + HWY_INLINE Mask EqualKeys(D d, Vec a, Vec b) const { + return Eq128(d, a, b); + } + + template + HWY_INLINE Mask NotEqualKeys(D d, Vec a, Vec b) const { + return Ne128(d, a, b); + } + + // For keys=entire 128 bits, any difference counts. 
+ template + HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec diff) const { + // Must avoid floating-point comparisons (for -0) + const RebindToUnsigned du; + return AllTrue(du, Eq(BitCast(du, diff), Zero(du))); + } + + HWY_INLINE bool Equal1(const LaneType* a, const LaneType* b) const { + return a[0] == b[0] && a[1] == b[1]; + } + + // Returns vector with only the top half of each block valid. This allows + // fusing the "replicate upper to lower half" step with a subsequent permute. + template + HWY_INLINE HWY_MAYBE_UNUSED Vec CompareTop(D d, Vec a, Vec b) const { + const Mask eqHL = Eq(a, b); + const Vec ltHL = VecFromMask(d, Order().CompareLanes(a, b)); +#if HWY_TARGET <= HWY_AVX2 // slightly faster + const Vec ltLX = ShiftLeftLanes<1>(ltHL); + return OrAnd(ltHL, VecFromMask(d, eqHL), ltLX); +#else + return IfThenElse(eqHL, DupEven(ltHL), ltHL); +#endif + } +}; + +// Anything order-related depends on the key traits *and* the order (see +// FirstOfLanes). We cannot implement just one Compare function because Lt128 +// only compiles if the lane type is u64. Thus we need either overloaded +// functions with a tag type, class specializations, or separate classes. +// We avoid overloaded functions because we want all functions to be callable +// from a SortTraits without per-function wrappers. Specializing would work, but +// we are anyway going to specialize at a higher level. +struct OrderAscending128 : public Key128 { + using Order = SortAscending; + using OrderForSortingNetwork = OrderAscending128; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const { + return (a[1] == b[1]) ? 
a[0] < b[0] : a[1] < b[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128(d, a, b); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Min128(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Max128(d, a, b); + } + + // FirstOfLanes/LastOfLanes are implemented in Traits128. + + // Same as for regular lanes because 128-bit keys are u64. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + const Vec k0 = Zero(d); + const Vec k1 = OddEven(k0, Set(d, uint64_t{1})); + const Mask borrow = Eq(v, k0); // don't-care, lo == 0 + // lo == 0? 1 : 0, 0 + const Vec adjust = ShiftLeftLanes<1>(IfThenElseZero(borrow, k1)); + return Sub(Sub(v, k1), adjust); + } + + // 'Private', used by base class Key128::CompareTop. + template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(a, b); + } +}; + +struct OrderDescending128 : public Key128 { + using Order = SortDescending; + using OrderForSortingNetwork = OrderDescending128; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const { + return (a[1] == b[1]) ? b[0] < a[0] : b[1] < a[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128(d, b, a); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Max128(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Min128(d, a, b); + } + + // FirstOfLanes/LastOfLanes are implemented in Traits128. + + // Same as for regular lanes because 128-bit keys are u64. 
+ template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + const Vec k1 = OddEven(Zero(d), Set(d, uint64_t{1})); + const Vec added = Add(v, k1); + const Mask overflowed = Lt(added, v); // false, overflowed + // overflowed? 1 : 0, 0 + const Vec adjust = ShiftLeftLanes<1>(IfThenElseZero(overflowed, k1)); + return Add(added, adjust); + } + + // 'Private', used by base class Key128::CompareTop. + template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(b, a); + } +}; + +// Base class shared between OrderAscendingKV128, OrderDescendingKV128. +struct KeyValue128 : public KeyAny128 { + // True indicates only part of the key (the more significant lane) should be + // compared. KV stands for key-value. + static constexpr bool IsKV() { return true; } + + // What type to pass to VQSort. + using KeyType = K64V64; + + const char* KeyString() const { return "k+v=128"; } + + template + HWY_INLINE Mask EqualKeys(D d, Vec a, Vec b) const { + return Eq128Upper(d, a, b); + } + + template + HWY_INLINE Mask NotEqualKeys(D d, Vec a, Vec b) const { + return Ne128Upper(d, a, b); + } + + HWY_INLINE bool Equal1(const LaneType* a, const LaneType* b) const { + return a[1] == b[1]; + } + + // Only count differences in the actual key, not the value. + template + HWY_INLINE bool NoKeyDifference(D /*tag*/, Vec diff) const { + // Must avoid floating-point comparisons (for -0) + const RebindToUnsigned du; + const Vec zero = Zero(du); + const Vec keys = OddEven(diff, zero); // clear values + return AllTrue(du, Eq(BitCast(du, keys), zero)); + } + + // Returns vector with only the top half of each block valid. This allows + // fusing the "replicate upper to lower half" step with a subsequent permute. 
+ template + HWY_INLINE HWY_MAYBE_UNUSED Vec CompareTop(D d, Vec a, Vec b) const { + // Only the upper lane of each block is a key, and only that lane is + // required to be valid, so comparing all lanes is sufficient. + return VecFromMask(d, Order().CompareLanes(a, b)); + } +}; + +struct OrderAscendingKV128 : public KeyValue128 { + using Order = SortAscending; + using OrderForSortingNetwork = OrderAscending128; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const { + return a[1] < b[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128Upper(d, a, b); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Min128Upper(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Max128Upper(d, a, b); + } + + // FirstOfLanes/LastOfLanes are implemented in Traits128. + + // Same as for regular lanes because 128-bit keys are u64. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + const Vec k1 = OddEven(Set(d, uint64_t{1}), Zero(d)); + return Sub(v, k1); + } + + // 'Private', used by base class KeyValue128::CompareTop. 
+ template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(a, b); + } +}; + +struct OrderDescendingKV128 : public KeyValue128 { + using Order = SortDescending; + using OrderForSortingNetwork = OrderDescending128; + + HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) const { + return b[1] < a[1]; + } + + template + HWY_INLINE Mask Compare(D d, Vec a, Vec b) const { + return Lt128Upper(d, b, a); + } + + template + HWY_INLINE Vec First(D d, const Vec a, const Vec b) const { + return Max128Upper(d, a, b); + } + + template + HWY_INLINE Vec Last(D d, const Vec a, const Vec b) const { + return Min128Upper(d, a, b); + } + + // FirstOfLanes/LastOfLanes are implemented in Traits128. + + // Same as for regular lanes because 128-bit keys are u64. + template + HWY_INLINE Vec FirstValue(D d) const { + return Set(d, hwy::HighestValue >()); + } + + template + HWY_INLINE Vec LastValue(D d) const { + return Set(d, hwy::LowestValue >()); + } + + template + HWY_INLINE Vec PrevValue(D d, Vec v) const { + const Vec k1 = OddEven(Set(d, uint64_t{1}), Zero(d)); + return Add(v, k1); + } + + // 'Private', used by base class KeyValue128::CompareTop. + template + HWY_INLINE Mask > CompareLanes(V a, V b) const { + return Lt(b, a); + } +}; + +// We want to swap 2 u128, i.e. 4 u64 lanes, based on the 0 or FF..FF mask in +// the most-significant of those lanes (the result of CompareTop), so +// replicate it 4x. Only called for >= 256-bit vectors. 
+ +#if HWY_TARGET <= HWY_AVX3 +template +HWY_INLINE V ReplicateTop4x(V v) { + return V{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))}; +} +#endif // HWY_TARGET <= HWY_AVX3 + +#if HWY_TARGET <= HWY_AVX2 + +template +HWY_INLINE V ReplicateTop4x(V v) { + return V{_mm256_permute4x64_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))}; +} + +#else // HWY_TARGET > HWY_AVX2 + +template +HWY_INLINE V ReplicateTop4x(V v) { +#if HWY_TARGET == HWY_SVE_256 + return svdup_lane_u64(v, 3); +#else + const ScalableTag d; + HWY_DASSERT(Lanes(d) == 4 || Lanes(d) == 8); // for table below + HWY_ALIGN static constexpr uint64_t kIndices[8] = {3, 3, 3, 3, 7, 7, 7, 7}; + return TableLookupLanes(v, SetTableIndices(d, kIndices)); +#endif +} + +#endif // HWY_TARGET <= HWY_AVX2 + +// Shared code that depends on Order. +template +struct Traits128 : public Base { + using TraitsForSortingNetwork = + Traits128; + + template + HWY_INLINE Vec FirstOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT buf) const { + const Base* base = static_cast(this); + const size_t N = Lanes(d); + Store(v, d, buf); + v = base->SetKey(d, buf + 0); // result must be broadcasted + for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) { + v = base->First(d, v, base->SetKey(d, buf + i)); + } + return v; + } + + template + HWY_INLINE Vec LastOfLanes(D d, Vec v, + TFromD* HWY_RESTRICT buf) const { + const Base* base = static_cast(this); + const size_t N = Lanes(d); + Store(v, d, buf); + v = base->SetKey(d, buf + 0); // result must be broadcasted + for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) { + v = base->Last(d, v, base->SetKey(d, buf + i)); + } + return v; + } + + template + HWY_INLINE void Sort2(D d, Vec& a, Vec& b) const { + const Base* base = static_cast(this); + + const Vec a_copy = a; + const auto lt = base->Compare(d, a, b); + a = IfThenElse(lt, a, b); + b = IfThenElse(lt, b, a_copy); + } + + // Conditionally swaps even-numbered keys with their odd-numbered neighbor. 
+ template + HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { + HWY_DASSERT(Lanes(d) >= 4); // required by ReplicateTop4x + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys2(d, v); + const Vec cmpHx = base->template CompareTop(d, v, swapped); + return IfVecThenElse(ReplicateTop4x(cmpHx), swapped, v); + } + + // Swaps with the vector formed by reversing contiguous groups of four 128-bit + // keys, which implies 512-bit vectors (we do not support more than that). + template + HWY_INLINE Vec SortPairsReverse4(D d, Vec v) const { + HWY_DASSERT(Lanes(d) == 8); // For TableLookupLanes below + const Base* base = static_cast(this); + Vec swapped = base->ReverseKeys4(d, v); + + const Vec cmpHx = base->template CompareTop(d, v, swapped); + // Similar to ReplicateTop4x, we want to gang together 2 comparison results + // (4 lanes). They are not contiguous, so use permute to replicate 4x. + HWY_ALIGN uint64_t kIndices[8] = {7, 7, 5, 5, 5, 5, 7, 7}; + const Vec select = TableLookupLanes(cmpHx, SetTableIndices(d, kIndices)); + return IfVecThenElse(select, swapped, v); + } + + // Conditionally swaps lane 0 with 4, 1 with 5 etc. + template + HWY_INLINE Vec SortPairsDistance4(D, Vec) const { + // Only used by Merge16, which would require 2048 bit vectors (unsupported). 
+ HWY_ASSERT(0); + } +}; + +#endif // HWY_TARGET != HWY_SCALAR + +} // namespace detail +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE diff --git a/lib/highway/hwy/contrib/sort/vqsort-inl.h b/lib/highway/hwy/contrib/sort/vqsort-inl.h new file mode 100644 index 00000000000..1f98bcccf27 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort-inl.h @@ -0,0 +1,2212 @@ +// Copyright 2021 Google LLC +// Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Normal include guard for target-independent parts +#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_ + +// unconditional #include so we can use if(VQSORT_PRINT), which unlike #if does +// not interfere with code-folding. +#include +#include // clock + +// IWYU pragma: begin_exports +#include "hwy/base.h" +#include "hwy/contrib/sort/order.h" // SortAscending +// IWYU pragma: end_exports + +#include "hwy/cache_control.h" // Prefetch +#include "hwy/print.h" // unconditional, see above. + +// If 1, VQSortStatic can be called without including vqsort.h, and we avoid +// any DLLEXPORT. This simplifies integration into other build systems, but +// decreases the security of random seeds. 
+#ifndef VQSORT_ONLY_STATIC
+#define VQSORT_ONLY_STATIC 0
+#endif
+
+// Verbosity: 0 for none, 1 for brief per-sort, 2+ for more details.
+#ifndef VQSORT_PRINT
+#define VQSORT_PRINT 0
+#endif
+
+#if !VQSORT_ONLY_STATIC
+#include "hwy/contrib/sort/vqsort.h"  // Fill16BytesSecure
+#endif
+
+namespace hwy {
+namespace detail {
+
+// Writes 16 bytes (two uint64_t words) of seed material to `bytes`.
+// Unless VQSORT_ONLY_STATIC is set, first tries Fill16BytesSecure and returns
+// if that succeeds. Otherwise falls back to mixing a stack address, a code
+// address and clock() with fixed constants.
+// NOTE(review): `bytes` is accessed through a uint64_t*, so it presumably
+// must be 8-byte aligned and at least 16 bytes long - confirm at call sites.
+HWY_INLINE void Fill16BytesStatic(void* bytes) {
+#if !VQSORT_ONLY_STATIC
+  if (Fill16BytesSecure(bytes)) return;
+#endif
+
+  uint64_t* words = reinterpret_cast<uint64_t*>(bytes);
+
+  // Static-only, or Fill16BytesSecure failed. Get some entropy from the
+  // stack/code location, and the clock() timer.
+  uint64_t** seed_stack = &words;
+  void (*seed_code)(void*) = &Fill16BytesStatic;
+  const uintptr_t bits_stack = reinterpret_cast<uintptr_t>(seed_stack);
+  const uintptr_t bits_code = reinterpret_cast<uintptr_t>(seed_code);
+  const uint64_t bits_time = static_cast<uint64_t>(clock());
+  words[0] = bits_stack ^ bits_time ^ 0xFEDCBA98;  // "Nothing up my sleeve"
+  words[1] = bits_code ^ bits_time ^ 0x01234567;   // constants.
+}
+
+// Returns a pointer to a thread_local array of three uint64_t. The first two
+// words are seeded by Fill16BytesStatic on the first call in each thread; the
+// third is a counter that is set to 1 once seeding has happened.
+HWY_INLINE uint64_t* GetGeneratorStateStatic() {
+  thread_local uint64_t state[3] = {0};
+  // This is a counter; zero indicates not yet initialized.
+  if (HWY_UNLIKELY(state[2] == 0)) {
+    Fill16BytesStatic(state);
+    state[2] = 1;
+  }
+  return state;
+}
+
+}  // namespace detail
+}  // namespace hwy
+
+#endif  // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
+
+// Per-target
+#if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \
+    defined(HWY_TARGET_TOGGLE)
+#ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
+#undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
+#else
+#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
+#endif
+
+#if VQSORT_PRINT
+#include "hwy/print-inl.h"
+#endif
+
+#include "hwy/contrib/algo/copy-inl.h"
+#include "hwy/contrib/sort/shared-inl.h"
+#include "hwy/contrib/sort/sorting_networks-inl.h"
+#include "hwy/contrib/sort/traits-inl.h"
+#include "hwy/contrib/sort/traits128-inl.h"
+// Placeholder for internal instrumentation. Do not remove.
+#include "hwy/highway.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace detail {
+
+using Constants = hwy::SortConstants;
+
+// Wrapper avoids #if in user code (interferes with code folding)
+template <class D>
+HWY_INLINE void MaybePrintVector(D d, const char* label, Vec<D> v,
+                                 size_t start = 0, size_t max_lanes = 16) {
+#if VQSORT_PRINT >= 2  // Print is only defined #if
+  Print(d, label, v, start, max_lanes);
+#else
+  (void)d;
+  (void)label;
+  (void)v;
+  (void)start;
+  (void)max_lanes;
+#endif
+}
+
+// ------------------------------ HeapSort
+
+// Restores the heap property for the subtree rooted at lane index `start`
+// within `lanes[0, num_lanes)`. Indices are in units of lanes; each key
+// occupies `st.LanesPerKey()` lanes. A parent is swapped with a child whenever
+// `st.Compare(parent, child)` holds for all lanes.
+template <class Traits, typename T>
+void SiftDown(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
+              size_t start) {
+  constexpr size_t N1 = st.LanesPerKey();
+  const FixedTag<T, N1> d;
+
+  while (start < num_lanes) {
+    const size_t left = 2 * start + N1;
+    const size_t right = 2 * start + 2 * N1;
+    if (left >= num_lanes) break;
+    size_t idx_larger = start;
+    const auto key_j = st.SetKey(d, lanes + start);
+    if (AllTrue(d, st.Compare(d, key_j, st.SetKey(d, lanes + left)))) {
+      idx_larger = left;
+    }
+    if (right < num_lanes &&
+        AllTrue(d, st.Compare(d, st.SetKey(d, lanes + idx_larger),
+                              st.SetKey(d, lanes + right)))) {
+      idx_larger = right;
+    }
+    if (idx_larger == start) break;
+    st.Swap(lanes + start, lanes + idx_larger);
+    start = idx_larger;
+  }
+}
+
+// Heapsort: O(1) space, O(N*logN) worst-case comparisons.
+// Based on LLVM sanitizer_common.h, licensed under Apache-2.0.
+// `num_lanes` must be a multiple of the lanes per key. Zero or one key is
+// already sorted and returns immediately.
+template <class Traits, typename T>
+void HeapSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes) {
+  constexpr size_t N1 = st.LanesPerKey();
+  HWY_DASSERT(num_lanes % N1 == 0);
+  // `<=` rather than `==`: with num_lanes == 0, `num_lanes - N1` below would
+  // wrap around and the loops would index far outside `lanes`.
+  if (num_lanes <= N1) return;
+
+  // Build heap. The loop ends when `i` wraps below zero, i.e. equals -N1.
+  for (size_t i = ((num_lanes - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) {
+    SiftDown(st, lanes, num_lanes, i);
+  }
+
+  for (size_t i = num_lanes - N1; i != 0; i -= N1) {
+// Workaround for -Waggressive-loop-optimizations warning that might be emitted
+// by GCC
+#if HWY_COMPILER_GCC_ACTUAL
+    HWY_DIAGNOSTICS(push)
+    HWY_DIAGNOSTICS_OFF(disable : 4756,
+                        ignored "-Waggressive-loop-optimizations")
+#endif
+    // Swap root with last
+    st.Swap(lanes + 0, lanes + i);
+
+#if HWY_COMPILER_GCC_ACTUAL
+    HWY_DIAGNOSTICS(pop)
+#endif
+
+    // Sift down the new root.
+    SiftDown(st, lanes, i, 0);
+  }
+}
+
+// Heap-based selection: builds a heap over the first `k_lanes + N1` lanes,
+// replaces its root with any later key for which `st.Compare(key, root)`
+// holds, and finally swaps the root into lane index `k_lanes`.
+// Requires `k_lanes + N1 <= num_lanes`; zero or one key returns immediately.
+template <class Traits, typename T>
+void HeapSelect(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
+                const size_t k_lanes) {
+  constexpr size_t N1 = st.LanesPerKey();
+  const size_t k = k_lanes + N1;
+  HWY_DASSERT(num_lanes % N1 == 0);
+  // `<=` rather than `==`: `num_lanes - N1` below must not wrap around.
+  if (num_lanes <= N1) return;
+
+  const FixedTag<T, N1> d;
+
+  // Build heap. The loop ends when `i` wraps below zero, i.e. equals -N1.
+  for (size_t i = ((k - N1) / N1 / 2) * N1; i != (~N1 + 1); i -= N1) {
+    SiftDown(st, lanes, k, i);
+  }
+
+  for (size_t i = k; i <= num_lanes - N1; i += N1) {
+    if (AllTrue(d, st.Compare(d, st.SetKey(d, lanes + i),
+                              st.SetKey(d, lanes + 0)))) {
+// Workaround for -Waggressive-loop-optimizations warning that might be emitted
+// by GCC
+#if HWY_COMPILER_GCC_ACTUAL
+      HWY_DIAGNOSTICS(push)
+      HWY_DIAGNOSTICS_OFF(disable : 4756,
+                          ignored "-Waggressive-loop-optimizations")
+#endif
+
+      // Swap root with last
+      st.Swap(lanes + 0, lanes + i);
+
+#if HWY_COMPILER_GCC_ACTUAL
+      HWY_DIAGNOSTICS(pop)
+#endif
+
+      // Sift down the new root.
+      SiftDown(st, lanes, k, 0);
+    }
+  }
+
+  st.Swap(lanes + 0, lanes + k - N1);
+}
+
+// Sorts the first `k_lanes` lanes of `lanes[0, num_lanes)` into their final
+// order: HeapSelect moves the required keys to the front, HeapSort orders
+// them.
+template <class Traits, typename T>
+void HeapPartialSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
+                     const size_t k_lanes) {
+  // Nothing requested; HeapSort on an empty range must not be reached.
+  if (k_lanes == 0) return;
+  // Everything requested: HeapSelect would build a heap over k_lanes + N1
+  // lanes, i.e. past the end of the input, so sort the whole array instead.
+  if (k_lanes >= num_lanes) {
+    HeapSort(st, lanes, num_lanes);
+    return;
+  }
+  HeapSelect(st, lanes, num_lanes, k_lanes);
+  HeapSort(st, lanes, k_lanes);
+}
+
+#if VQSORT_ENABLED || HWY_IDE
+
+// ------------------------------ BaseCase
+
+// Special cases where `num_lanes` is in the specified range (inclusive).
+template +HWY_INLINE void Sort2To2(Traits st, T* HWY_RESTRICT keys, size_t num_lanes, + T* HWY_RESTRICT /* buf */) { + constexpr size_t kLPK = st.LanesPerKey(); + const size_t num_keys = num_lanes / kLPK; + HWY_DASSERT(num_keys == 2); + HWY_ASSUME(num_keys == 2); + + // One key per vector, required to avoid reading past the end of `keys`. + const CappedTag d; + using V = Vec; + + V v0 = LoadU(d, keys + 0x0 * kLPK); + V v1 = LoadU(d, keys + 0x1 * kLPK); + + Sort2(d, st, v0, v1); + + StoreU(v0, d, keys + 0x0 * kLPK); + StoreU(v1, d, keys + 0x1 * kLPK); +} + +template +HWY_INLINE void Sort3To4(Traits st, T* HWY_RESTRICT keys, size_t num_lanes, + T* HWY_RESTRICT buf) { + constexpr size_t kLPK = st.LanesPerKey(); + const size_t num_keys = num_lanes / kLPK; + HWY_DASSERT(3 <= num_keys && num_keys <= 4); + HWY_ASSUME(num_keys >= 3); + HWY_ASSUME(num_keys <= 4); // reduces branches + + // One key per vector, required to avoid reading past the end of `keys`. + const CappedTag d; + using V = Vec; + + // If num_keys == 3, initialize padding for the last sorting network element + // so that it does not influence the other elements. + Store(st.LastValue(d), d, buf); + + // Points to a valid key, or padding. This avoids special-casing + // HWY_MEM_OPS_MIGHT_FAULT because there is only a single key per vector. + T* in_out3 = num_keys == 3 ? 
buf : keys + 0x3 * kLPK; + + V v0 = LoadU(d, keys + 0x0 * kLPK); + V v1 = LoadU(d, keys + 0x1 * kLPK); + V v2 = LoadU(d, keys + 0x2 * kLPK); + V v3 = LoadU(d, in_out3); + + Sort4(d, st, v0, v1, v2, v3); + + StoreU(v0, d, keys + 0x0 * kLPK); + StoreU(v1, d, keys + 0x1 * kLPK); + StoreU(v2, d, keys + 0x2 * kLPK); + StoreU(v3, d, in_out3); +} + +#if HWY_MEM_OPS_MIGHT_FAULT + +template > +HWY_INLINE void CopyHalfToPaddedBuf(D d, Traits st, T* HWY_RESTRICT keys, + size_t num_lanes, T* HWY_RESTRICT buf) { + constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow; + // Must cap for correctness: we will load up to the last valid lane, so + // Lanes(dmax) must not exceed `num_lanes` (known to be at least kMinLanes). + const CappedTag dmax; + const size_t Nmax = Lanes(dmax); + HWY_DASSERT(Nmax < num_lanes); + HWY_ASSUME(Nmax <= kMinLanes); + + // Fill with padding - last in sort order, not copied to keys. + const Vec kPadding = st.LastValue(dmax); + + // Rounding down allows aligned stores, which are typically faster. + size_t i = num_lanes & ~(Nmax - 1); + HWY_ASSUME(i != 0); // because Nmax <= num_lanes; avoids branch + do { + Store(kPadding, dmax, buf + i); + i += Nmax; + // Initialize enough for the last vector even if Nmax > kLanesPerRow. + } while (i < (kRows - 1) * kLanesPerRow + Lanes(d)); + + // Ensure buf contains all we will read, and perhaps more before. + ptrdiff_t end = static_cast(num_lanes); + do { + end -= static_cast(Nmax); + StoreU(LoadU(dmax, keys + end), dmax, buf + end); + } while (end > static_cast(kRows / 2 * kLanesPerRow)); +} + +#endif // HWY_MEM_OPS_MIGHT_FAULT + +template +HWY_NOINLINE void Sort8Rows(Traits st, T* HWY_RESTRICT keys, size_t num_lanes, + T* HWY_RESTRICT buf) { + // kKeysPerRow <= 4 because 8 64-bit keys implies 512-bit vectors, which + // are likely slower than 16x4, so 8x4 is the largest we handle here. 
+ static_assert(kKeysPerRow <= 4, ""); + + constexpr size_t kLPK = st.LanesPerKey(); + + // We reshape the 1D keys into kRows x kKeysPerRow. + constexpr size_t kRows = 8; + constexpr size_t kLanesPerRow = kKeysPerRow * kLPK; + constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow; + HWY_DASSERT(kMinLanes < num_lanes && num_lanes <= kRows * kLanesPerRow); + + const CappedTag d; + using V = Vec; + V v4, v5, v6, v7; + + // At least half the kRows are valid, otherwise a different function would + // have been called to handle this num_lanes. + V v0 = LoadU(d, keys + 0x0 * kLanesPerRow); + V v1 = LoadU(d, keys + 0x1 * kLanesPerRow); + V v2 = LoadU(d, keys + 0x2 * kLanesPerRow); + V v3 = LoadU(d, keys + 0x3 * kLanesPerRow); +#if HWY_MEM_OPS_MIGHT_FAULT + CopyHalfToPaddedBuf(d, st, keys, num_lanes, buf); + v4 = LoadU(d, buf + 0x4 * kLanesPerRow); + v5 = LoadU(d, buf + 0x5 * kLanesPerRow); + v6 = LoadU(d, buf + 0x6 * kLanesPerRow); + v7 = LoadU(d, buf + 0x7 * kLanesPerRow); +#endif // HWY_MEM_OPS_MIGHT_FAULT +#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE + (void)buf; + const V vnum_lanes = Set(d, ConvertScalarTo(num_lanes)); + // First offset where not all vector are guaranteed valid. + const V kIota = Iota(d, static_cast(kMinLanes)); + const V k1 = Set(d, static_cast(kLanesPerRow)); + const V k2 = Add(k1, k1); + + using M = Mask; + const M m4 = Gt(vnum_lanes, kIota); + const M m5 = Gt(vnum_lanes, Add(kIota, k1)); + const M m6 = Gt(vnum_lanes, Add(kIota, k2)); + const M m7 = Gt(vnum_lanes, Add(kIota, Add(k2, k1))); + + const V kPadding = st.LastValue(d); // Not copied to keys. + v4 = MaskedLoadOr(kPadding, m4, d, keys + 0x4 * kLanesPerRow); + v5 = MaskedLoadOr(kPadding, m5, d, keys + 0x5 * kLanesPerRow); + v6 = MaskedLoadOr(kPadding, m6, d, keys + 0x6 * kLanesPerRow); + v7 = MaskedLoadOr(kPadding, m7, d, keys + 0x7 * kLanesPerRow); +#endif // !HWY_MEM_OPS_MIGHT_FAULT + + Sort8(d, st, v0, v1, v2, v3, v4, v5, v6, v7); + + // Merge8x2 is a no-op if kKeysPerRow < 2 etc. 
+ Merge8x2(d, st, v0, v1, v2, v3, v4, v5, v6, v7); + Merge8x4(d, st, v0, v1, v2, v3, v4, v5, v6, v7); + + StoreU(v0, d, keys + 0x0 * kLanesPerRow); + StoreU(v1, d, keys + 0x1 * kLanesPerRow); + StoreU(v2, d, keys + 0x2 * kLanesPerRow); + StoreU(v3, d, keys + 0x3 * kLanesPerRow); + +#if HWY_MEM_OPS_MIGHT_FAULT + // Store remaining vectors into buf and safely copy them into keys. + StoreU(v4, d, buf + 0x4 * kLanesPerRow); + StoreU(v5, d, buf + 0x5 * kLanesPerRow); + StoreU(v6, d, buf + 0x6 * kLanesPerRow); + StoreU(v7, d, buf + 0x7 * kLanesPerRow); + + const ScalableTag dmax; + const size_t Nmax = Lanes(dmax); + + // The first half of vectors have already been stored unconditionally into + // `keys`, so we do not copy them. + size_t i = kMinLanes; + HWY_UNROLL(1) + for (; i + Nmax <= num_lanes; i += Nmax) { + StoreU(LoadU(dmax, buf + i), dmax, keys + i); + } + + // Last iteration: copy partial vector + const size_t remaining = num_lanes - i; + HWY_ASSUME(remaining < 256); // helps FirstN + SafeCopyN(remaining, dmax, buf + i, keys + i); +#endif // HWY_MEM_OPS_MIGHT_FAULT +#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE + BlendedStore(v4, m4, d, keys + 0x4 * kLanesPerRow); + BlendedStore(v5, m5, d, keys + 0x5 * kLanesPerRow); + BlendedStore(v6, m6, d, keys + 0x6 * kLanesPerRow); + BlendedStore(v7, m7, d, keys + 0x7 * kLanesPerRow); +#endif // !HWY_MEM_OPS_MIGHT_FAULT +} + +template +HWY_NOINLINE void Sort16Rows(Traits st, T* HWY_RESTRICT keys, size_t num_lanes, + T* HWY_RESTRICT buf) { + static_assert(kKeysPerRow <= SortConstants::kMaxCols, ""); + + constexpr size_t kLPK = st.LanesPerKey(); + + // We reshape the 1D keys into kRows x kKeysPerRow. 
+ constexpr size_t kRows = 16; + constexpr size_t kLanesPerRow = kKeysPerRow * kLPK; + constexpr size_t kMinLanes = kRows / 2 * kLanesPerRow; + HWY_DASSERT(kMinLanes < num_lanes && num_lanes <= kRows * kLanesPerRow); + + const CappedTag d; + using V = Vec; + V v8, v9, va, vb, vc, vd, ve, vf; + + // At least half the kRows are valid, otherwise a different function would + // have been called to handle this num_lanes. + V v0 = LoadU(d, keys + 0x0 * kLanesPerRow); + V v1 = LoadU(d, keys + 0x1 * kLanesPerRow); + V v2 = LoadU(d, keys + 0x2 * kLanesPerRow); + V v3 = LoadU(d, keys + 0x3 * kLanesPerRow); + V v4 = LoadU(d, keys + 0x4 * kLanesPerRow); + V v5 = LoadU(d, keys + 0x5 * kLanesPerRow); + V v6 = LoadU(d, keys + 0x6 * kLanesPerRow); + V v7 = LoadU(d, keys + 0x7 * kLanesPerRow); +#if HWY_MEM_OPS_MIGHT_FAULT + CopyHalfToPaddedBuf(d, st, keys, num_lanes, buf); + v8 = LoadU(d, buf + 0x8 * kLanesPerRow); + v9 = LoadU(d, buf + 0x9 * kLanesPerRow); + va = LoadU(d, buf + 0xa * kLanesPerRow); + vb = LoadU(d, buf + 0xb * kLanesPerRow); + vc = LoadU(d, buf + 0xc * kLanesPerRow); + vd = LoadU(d, buf + 0xd * kLanesPerRow); + ve = LoadU(d, buf + 0xe * kLanesPerRow); + vf = LoadU(d, buf + 0xf * kLanesPerRow); +#endif // HWY_MEM_OPS_MIGHT_FAULT +#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE + (void)buf; + const V vnum_lanes = Set(d, ConvertScalarTo(num_lanes)); + // First offset where not all vector are guaranteed valid. 
+ const V kIota = Iota(d, static_cast(kMinLanes)); + const V k1 = Set(d, static_cast(kLanesPerRow)); + const V k2 = Add(k1, k1); + const V k4 = Add(k2, k2); + const V k8 = Add(k4, k4); + + using M = Mask; + const M m8 = Gt(vnum_lanes, kIota); + const M m9 = Gt(vnum_lanes, Add(kIota, k1)); + const M ma = Gt(vnum_lanes, Add(kIota, k2)); + const M mb = Gt(vnum_lanes, Add(kIota, Sub(k4, k1))); + const M mc = Gt(vnum_lanes, Add(kIota, k4)); + const M md = Gt(vnum_lanes, Add(kIota, Add(k4, k1))); + const M me = Gt(vnum_lanes, Add(kIota, Add(k4, k2))); + const M mf = Gt(vnum_lanes, Add(kIota, Sub(k8, k1))); + + const V kPadding = st.LastValue(d); // Not copied to keys. + v8 = MaskedLoadOr(kPadding, m8, d, keys + 0x8 * kLanesPerRow); + v9 = MaskedLoadOr(kPadding, m9, d, keys + 0x9 * kLanesPerRow); + va = MaskedLoadOr(kPadding, ma, d, keys + 0xa * kLanesPerRow); + vb = MaskedLoadOr(kPadding, mb, d, keys + 0xb * kLanesPerRow); + vc = MaskedLoadOr(kPadding, mc, d, keys + 0xc * kLanesPerRow); + vd = MaskedLoadOr(kPadding, md, d, keys + 0xd * kLanesPerRow); + ve = MaskedLoadOr(kPadding, me, d, keys + 0xe * kLanesPerRow); + vf = MaskedLoadOr(kPadding, mf, d, keys + 0xf * kLanesPerRow); +#endif // !HWY_MEM_OPS_MIGHT_FAULT + + Sort16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf); + + // Merge16x4 is a no-op if kKeysPerRow < 4 etc. 
+ Merge16x2(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, + vc, vd, ve, vf); + Merge16x4(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, + vc, vd, ve, vf); + Merge16x8(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, + vc, vd, ve, vf); +#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD + Merge16x16(d, st, v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, + vc, vd, ve, vf); +#endif + + StoreU(v0, d, keys + 0x0 * kLanesPerRow); + StoreU(v1, d, keys + 0x1 * kLanesPerRow); + StoreU(v2, d, keys + 0x2 * kLanesPerRow); + StoreU(v3, d, keys + 0x3 * kLanesPerRow); + StoreU(v4, d, keys + 0x4 * kLanesPerRow); + StoreU(v5, d, keys + 0x5 * kLanesPerRow); + StoreU(v6, d, keys + 0x6 * kLanesPerRow); + StoreU(v7, d, keys + 0x7 * kLanesPerRow); + +#if HWY_MEM_OPS_MIGHT_FAULT + // Store remaining vectors into buf and safely copy them into keys. + StoreU(v8, d, buf + 0x8 * kLanesPerRow); + StoreU(v9, d, buf + 0x9 * kLanesPerRow); + StoreU(va, d, buf + 0xa * kLanesPerRow); + StoreU(vb, d, buf + 0xb * kLanesPerRow); + StoreU(vc, d, buf + 0xc * kLanesPerRow); + StoreU(vd, d, buf + 0xd * kLanesPerRow); + StoreU(ve, d, buf + 0xe * kLanesPerRow); + StoreU(vf, d, buf + 0xf * kLanesPerRow); + + const ScalableTag dmax; + const size_t Nmax = Lanes(dmax); + + // The first half of vectors have already been stored unconditionally into + // `keys`, so we do not copy them. 
+ size_t i = kMinLanes; + HWY_UNROLL(1) + for (; i + Nmax <= num_lanes; i += Nmax) { + StoreU(LoadU(dmax, buf + i), dmax, keys + i); + } + + // Last iteration: copy partial vector + const size_t remaining = num_lanes - i; + HWY_ASSUME(remaining < 256); // helps FirstN + SafeCopyN(remaining, dmax, buf + i, keys + i); +#endif // HWY_MEM_OPS_MIGHT_FAULT +#if !HWY_MEM_OPS_MIGHT_FAULT || HWY_IDE + BlendedStore(v8, m8, d, keys + 0x8 * kLanesPerRow); + BlendedStore(v9, m9, d, keys + 0x9 * kLanesPerRow); + BlendedStore(va, ma, d, keys + 0xa * kLanesPerRow); + BlendedStore(vb, mb, d, keys + 0xb * kLanesPerRow); + BlendedStore(vc, mc, d, keys + 0xc * kLanesPerRow); + BlendedStore(vd, md, d, keys + 0xd * kLanesPerRow); + BlendedStore(ve, me, d, keys + 0xe * kLanesPerRow); + BlendedStore(vf, mf, d, keys + 0xf * kLanesPerRow); +#endif // !HWY_MEM_OPS_MIGHT_FAULT +} + +// Sorts `keys` within the range [0, num_lanes) via sorting network. +// Reshapes into a matrix, sorts columns independently, and then merges +// into a sorted 1D array without transposing. +// +// `TraitsKV` is SharedTraits>. This abstraction layer bridges +// differences in sort order and single-lane vs 128-bit keys. For key-value +// types, items with the same key are not equivalent. Our sorting network +// does not preserve order, thus we prevent mixing padding into the items by +// comparing all the item bits, including the value (see *ForSortingNetwork). +// +// See M. Blacher's thesis: https://github.com/mark-blacher/masterthesis +template +HWY_NOINLINE void BaseCase(D d, TraitsKV, T* HWY_RESTRICT keys, + size_t num_lanes, T* buf) { + using Traits = typename TraitsKV::SharedTraitsForSortingNetwork; + Traits st; + constexpr size_t kLPK = st.LanesPerKey(); + HWY_DASSERT(num_lanes <= Constants::BaseCaseNumLanes(Lanes(d))); + const size_t num_keys = num_lanes / kLPK; + + // Can be zero when called through HandleSpecialCases, but also 1 (in which + // case the array is already sorted). 
Also ensures num_lanes - 1 != 0. + if (HWY_UNLIKELY(num_keys <= 1)) return; + + const size_t ceil_log2 = + 32 - Num0BitsAboveMS1Bit_Nonzero32(static_cast(num_keys - 1)); + + // Checking kMaxKeysPerVector avoids generating unreachable codepaths. + constexpr size_t kMaxKeysPerVector = MaxLanes(d) / kLPK; + + using FuncPtr = decltype(&Sort2To2); + const FuncPtr funcs[9] = { + /* <= 1 */ nullptr, // We ensured num_keys > 1. + /* <= 2 */ &Sort2To2, + /* <= 4 */ &Sort3To4, + /* <= 8 */ &Sort8Rows<1, Traits, T>, // 1 key per row + /* <= 16 */ kMaxKeysPerVector >= 2 ? &Sort8Rows<2, Traits, T> : nullptr, + /* <= 32 */ kMaxKeysPerVector >= 4 ? &Sort8Rows<4, Traits, T> : nullptr, + /* <= 64 */ kMaxKeysPerVector >= 4 ? &Sort16Rows<4, Traits, T> : nullptr, + /* <= 128 */ kMaxKeysPerVector >= 8 ? &Sort16Rows<8, Traits, T> : nullptr, +#if !HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD + /* <= 256 */ kMaxKeysPerVector >= 16 ? &Sort16Rows<16, Traits, T> + : nullptr, +#endif + }; + funcs[ceil_log2](st, keys, num_lanes, buf); +} + +// ------------------------------ Partition + +// Partitions O(1) of the *rightmost* keys, at least `N`, until a multiple of +// kUnroll*N remains, or all keys if there are too few for that. +// +// Returns how many remain to partition at the *start* of `keys`, sets `bufL` to +// the number of keys for the left partition written to `buf`, and `writeR` to +// the start of the finished right partition at the end of `keys`. +template +HWY_INLINE size_t PartitionRightmost(D d, Traits st, T* const keys, + const size_t num, const Vec pivot, + size_t& bufL, size_t& writeR, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + HWY_DASSERT(num > 2 * N); // BaseCase handles smaller arrays + + constexpr size_t kUnroll = Constants::kPartitionUnroll; + size_t num_here; // how many to process here + size_t num_main; // how many for main Partition loop (return value) + { + // The main Partition loop increments by kUnroll * N, so at least handle + // the remainders here. 
+ const size_t remainder = num & (kUnroll * N - 1); + // Ensure we handle at least one vector to prevent overruns (see below), but + // still leave a multiple of kUnroll * N. + const size_t min = remainder + (remainder < N ? kUnroll * N : 0); + // Do not exceed the input size. + num_here = HWY_MIN(min, num); + num_main = num - num_here; + // Before the main Partition loop we load two blocks; if not enough left for + // that, handle everything here. + if (num_main < 2 * kUnroll * N) { + num_here = num; + num_main = 0; + } + } + + // Note that `StoreLeftRight` uses `CompressBlendedStore`, which may load and + // store a whole vector starting at `writeR`, and thus overrun `keys`. To + // prevent this, we partition at least `N` of the rightmost `keys` so that + // `StoreLeftRight` will be able to safely blend into them. + HWY_DASSERT(num_here >= N); + + // We cannot use `CompressBlendedStore` for the same reason, so we instead + // write the right-of-partition keys into a buffer in ascending order. + // `min` may be up to (kUnroll + 1) * N, hence `num_here` could be as much as + // (3 * kUnroll + 1) * N, and they might all fall on one side of the pivot. + const size_t max_buf = (3 * kUnroll + 1) * N; + HWY_DASSERT(num_here <= max_buf); + + const T* pReadR = keys + num; // pre-decremented by N + + bufL = 0; + size_t bufR = max_buf; // starting position, not the actual count. + + size_t i = 0; + // For whole vectors, we can LoadU. + for (; i <= num_here - N; i += N) { + pReadR -= N; + HWY_DASSERT(pReadR >= keys); + const Vec v = LoadU(d, pReadR); + + const Mask comp = st.Compare(d, pivot, v); + const size_t numL = CompressStore(v, Not(comp), d, buf + bufL); + bufL += numL; + (void)CompressStore(v, comp, d, buf + bufR); + bufR += (N - numL); + } + + // Last iteration: avoid reading past the end. 
+ const size_t remaining = num_here - i; + if (HWY_LIKELY(remaining != 0)) { + const Mask mask = FirstN(d, remaining); + pReadR -= remaining; + HWY_DASSERT(pReadR >= keys); + const Vec v = LoadN(d, pReadR, remaining); + + const Mask comp = st.Compare(d, pivot, v); + const size_t numL = CompressStore(v, AndNot(comp, mask), d, buf + bufL); + bufL += numL; + (void)CompressStore(v, comp, d, buf + bufR); + bufR += (remaining - numL); + } + + const size_t numWrittenR = bufR - max_buf; +// Prior to 2022-10, Clang MSAN did not understand AVX-512 CompressStore. +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1600 + detail::MaybeUnpoison(buf, bufL); + detail::MaybeUnpoison(buf + max_buf, numWrittenR); +#endif + + // Overwrite already-read end of keys with bufR. + writeR = num - numWrittenR; + hwy::CopyBytes(buf + max_buf, keys + writeR, numWrittenR * sizeof(T)); + // Ensure we finished reading/writing all we wanted + HWY_DASSERT(pReadR == keys + num_main); + HWY_DASSERT(bufL + numWrittenR == num_here); + return num_main; +} + +// Note: we could track the OrXor of v and pivot to see if the entire left +// partition is equal, but that happens rarely and thus is a net loss. +template +HWY_INLINE void StoreLeftRight(D d, Traits st, const Vec v, + const Vec pivot, T* HWY_RESTRICT keys, + size_t& writeL, size_t& remaining) { + const size_t N = Lanes(d); + + const Mask comp = st.Compare(d, pivot, v); + + // Otherwise StoreU/CompressStore overwrites right keys. + HWY_DASSERT(remaining >= 2 * N); + + remaining -= N; + if (hwy::HWY_NAMESPACE::CompressIsPartition::value || + (HWY_MAX_BYTES == 16 && st.Is128())) { + // Non-native Compress (e.g. AVX2): we are able to partition a vector using + // a single Compress+two StoreU instead of two Compress[Blended]Store. The + // latter are more expensive. Because we store entire vectors, the contents + // between the updated writeL and writeR are ignored and will be overwritten + // by subsequent calls. 
This works because writeL and writeR are at least + // two vectors apart. + const Vec lr = st.CompressKeys(v, comp); + const size_t num_left = N - CountTrue(d, comp); + StoreU(lr, d, keys + writeL); + // Now write the right-side elements (if any), such that the previous writeR + // is one past the end of the newly written right elements, then advance. + StoreU(lr, d, keys + remaining + writeL); + writeL += num_left; + } else { + // Native Compress[Store] (e.g. AVX3), which only keep the left or right + // side, not both, hence we require two calls. + const size_t num_left = CompressStore(v, Not(comp), d, keys + writeL); + writeL += num_left; + + (void)CompressBlendedStore(v, comp, d, keys + remaining + writeL); + } +} + +template +HWY_INLINE void StoreLeftRight4(D d, Traits st, const Vec v0, + const Vec v1, const Vec v2, + const Vec v3, const Vec pivot, + T* HWY_RESTRICT keys, size_t& writeL, + size_t& remaining) { + StoreLeftRight(d, st, v0, pivot, keys, writeL, remaining); + StoreLeftRight(d, st, v1, pivot, keys, writeL, remaining); + StoreLeftRight(d, st, v2, pivot, keys, writeL, remaining); + StoreLeftRight(d, st, v3, pivot, keys, writeL, remaining); +} + +// For the last two vectors, we cannot use StoreLeftRight because it might +// overwrite prior right-side keys. Instead write R and append L into `buf`. +template +HWY_INLINE void StoreRightAndBuf(D d, Traits st, const Vec v, + const Vec pivot, T* HWY_RESTRICT keys, + size_t& writeR, T* HWY_RESTRICT buf, + size_t& bufL) { + const size_t N = Lanes(d); + const Mask comp = st.Compare(d, pivot, v); + const size_t numL = CompressStore(v, Not(comp), d, buf + bufL); + const size_t numR = N - numL; + bufL += numL; + writeR -= numR; + StoreN(Compress(v, comp), d, keys + writeR, numR); +} + +// Moves "<= pivot" keys to the front, and others to the back. pivot is +// broadcasted. Returns the index of the first key in the right partition. 
+//
+// Time-critical, but aligned loads do not seem to be worthwhile because we
+// are not bottlenecked by load ports.
+template <class D, class Traits, typename T>
+HWY_INLINE size_t Partition(D d, Traits st, T* const keys, const size_t num,
+                            const Vec<D> pivot, T* HWY_RESTRICT buf) {
+  using V = decltype(Zero(d));
+  const size_t N = Lanes(d);
+
+  size_t bufL, writeR;
+  const size_t num_main =
+      PartitionRightmost(d, st, keys, num, pivot, bufL, writeR, buf);
+  HWY_DASSERT(num_main <= num && writeR <= num);
+  HWY_DASSERT(bufL <= Constants::PartitionBufNum(N));
+  HWY_DASSERT(num_main + bufL == writeR);
+
+  if (VQSORT_PRINT >= 3) {
+    fprintf(stderr, "  num_main %zu bufL %zu writeR %zu\n", num_main, bufL,
+            writeR);
+  }
+
+  constexpr size_t kUnroll = Constants::kPartitionUnroll;
+
+  // Partition splits the vector into 3 sections, left to right: Elements
+  // smaller or equal to the pivot, unpartitioned elements and elements larger
+  // than the pivot. To write elements unconditionally on the loop body without
+  // overwriting existing data, we maintain two regions of the loop where all
+  // elements have been copied elsewhere (e.g. vector registers.). I call these
+  // bufferL and bufferR, for left and right respectively.
+  //
+  // These regions are tracked by the indices (writeL, writeR, left, right) as
+  // presented in the diagram below.
+  //
+  //              writeL                                  writeR
+  //               \/                                       \/
+  //  |  <= pivot   | bufferL |   unpartitioned   | bufferR |   > pivot   |
+  //                          \/                  \/                      \/
+  //                         readL               readR                   num
+  //
+  // In the main loop body below we choose a side, load some elements out of the
+  // vector and move either `readL` or `readR`. Next we call into StoreLeftRight
+  // to partition the data, and the partitioned elements will be written either
+  // to writeR or writeL and the corresponding index will be moved accordingly.
+  //
+  // Note that writeR is not explicitly tracked as an optimization for platforms
+  // with conditional operations. Instead we track writeL and the number of
+  // not yet written elements (`remaining`). From the diagram above we can see
+  // that:
+  //    writeR - writeL = remaining => writeR = remaining + writeL
+  //
+  // Tracking `remaining` is advantageous because each iteration reduces the
+  // number of unpartitioned elements by a fixed amount, so we can compute
+  // `remaining` without data dependencies.
+  size_t writeL = 0;
+  size_t remaining = writeR - writeL;
+
+  const T* readL = keys;
+  const T* readR = keys + num_main;
+  // Cannot load if there were fewer than 2 * kUnroll * N.
+  if (HWY_LIKELY(num_main != 0)) {
+    HWY_DASSERT(num_main >= 2 * kUnroll * N);
+    HWY_DASSERT((num_main & (kUnroll * N - 1)) == 0);
+
+    // Make space for writing in-place by reading from readL/readR.
+    const V vL0 = LoadU(d, readL + 0 * N);
+    const V vL1 = LoadU(d, readL + 1 * N);
+    const V vL2 = LoadU(d, readL + 2 * N);
+    const V vL3 = LoadU(d, readL + 3 * N);
+    readL += kUnroll * N;
+    readR -= kUnroll * N;
+    const V vR0 = LoadU(d, readR + 0 * N);
+    const V vR1 = LoadU(d, readR + 1 * N);
+    const V vR2 = LoadU(d, readR + 2 * N);
+    const V vR3 = LoadU(d, readR + 3 * N);
+
+    // readL/readR changed above, so check again before the loop.
+    while (readL != readR) {
+      V v0, v1, v2, v3;
+
+      // Data-dependent but branching is faster than forcing branch-free.
+      const size_t capacityL =
+          static_cast<size_t>((readL - keys) - static_cast<ptrdiff_t>(writeL));
+      HWY_DASSERT(capacityL <= num_main);  // >= 0
+      // Load data from the end of the vector with less data (front or back).
+      // The next paragraphs explain how this works.
+      //
+      // let block_size = (kUnroll * N)
+      // On the loop prelude we load block_size elements from the front of the
+      // vector and an additional block_size elements from the back. On each
+      // iteration k elements are written to the front of the vector and
+      // (block_size - k) to the back.
+      //
+      // This creates a loop invariant where the capacity on the front
+      // (capacityL) and on the back (capacityR) always add to 2 * block_size.
+      // In other words:
+      //    capacityL + capacityR = 2 * block_size
+      //    capacityR = 2 * block_size - capacityL
+      //
+      // This means that:
+      //    capacityL > capacityR <=>
+      //    capacityL > 2 * block_size - capacityL <=>
+      //    2 * capacityL > 2 * block_size <=>
+      //    capacityL > block_size
+      if (capacityL > kUnroll * N) {  // equivalent to capacityL > capacityR.
+        readR -= kUnroll * N;
+        v0 = LoadU(d, readR + 0 * N);
+        v1 = LoadU(d, readR + 1 * N);
+        v2 = LoadU(d, readR + 2 * N);
+        v3 = LoadU(d, readR + 3 * N);
+        hwy::Prefetch(readR - 3 * kUnroll * N);
+      } else {
+        v0 = LoadU(d, readL + 0 * N);
+        v1 = LoadU(d, readL + 1 * N);
+        v2 = LoadU(d, readL + 2 * N);
+        v3 = LoadU(d, readL + 3 * N);
+        readL += kUnroll * N;
+        hwy::Prefetch(readL + 3 * kUnroll * N);
+      }
+
+      StoreLeftRight4(d, st, v0, v1, v2, v3, pivot, keys, writeL, remaining);
+    }
+
+    // Now finish writing the saved vectors to the middle.
+    StoreLeftRight4(d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, remaining);
+
+    StoreLeftRight(d, st, vR0, pivot, keys, writeL, remaining);
+    StoreLeftRight(d, st, vR1, pivot, keys, writeL, remaining);
+
+    // Switch back to updating writeR for clarity. The middle is missing vR2/3
+    // and what is in the buffer.
+    HWY_DASSERT(remaining == bufL + 2 * N);
+    writeR = writeL + remaining;
+    // Switch to StoreRightAndBuf for the last two vectors because
+    // StoreLeftRight may overwrite prior keys.
+    StoreRightAndBuf(d, st, vR2, pivot, keys, writeR, buf, bufL);
+    StoreRightAndBuf(d, st, vR3, pivot, keys, writeR, buf, bufL);
+    HWY_DASSERT(writeR <= num);  // >= 0
+    HWY_DASSERT(bufL <= Constants::PartitionBufNum(N));
+  }
+
+  // We have partitioned [0, num) into [0, writeL) and [writeR, num).
+  // Now insert left keys from `buf` to empty space starting at writeL.
+  HWY_DASSERT(writeL + bufL == writeR);
+  CopyBytes(buf, keys + writeL, bufL * sizeof(T));
+
+  return writeL + bufL;
+}
+
+// Returns true and partitions if [keys, keys + num) contains only {valueL,
+// valueR}.
Otherwise, sets third to the first differing value; keys may have +// been reordered and a regular Partition is still necessary. +// Called from two locations, hence NOINLINE. +template +HWY_NOINLINE bool MaybePartitionTwoValue(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, const Vec valueL, + const Vec valueR, Vec& third, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + // No guarantee that num >= N because this is called for subarrays! + + size_t i = 0; + size_t writeL = 0; + + // As long as all lanes are equal to L or R, we can overwrite with valueL. + // This is faster than first counting, then backtracking to fill L and R. + if (num >= N) { + for (; i <= num - N; i += N) { + const Vec v = LoadU(d, keys + i); + // It is not clear how to apply OrXor here - that can check if *both* + // comparisons are true, but here we want *either*. Comparing the unsigned + // min of differences to zero works, but is expensive for u64 prior to + // AVX3. + const Mask eqL = st.EqualKeys(d, v, valueL); + const Mask eqR = st.EqualKeys(d, v, valueR); + // At least one other value present; will require a regular partition. + // On AVX-512, Or + AllTrue are folded into a single kortest if we are + // careful with the FindKnownFirstTrue argument, see below. + if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) { + // If we repeat Or(eqL, eqR) here, the compiler will hoist it into the + // loop, which is a pessimization because this if-true branch is cold. + // We can defeat this via Not(Xor), which is equivalent because eqL and + // eqR cannot be true at the same time. Can we elide the additional Not? + // FindFirstFalse instructions are generally unavailable, but we can + // fuse Not and Xor/Or into one ExclusiveNeither. 
+ const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR)); + third = st.SetKey(d, keys + i + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at vec %zu; writeL %zu\n", i, + writeL); + } + // 'Undo' what we did by filling the remainder of what we read with R. + if (i >= N) { + for (; writeL <= i - N; writeL += N) { + StoreU(valueR, d, keys + writeL); + } + } + StoreN(valueR, d, keys + writeL, i - writeL); + return false; + } + StoreU(valueL, d, keys + writeL); + writeL += CountTrue(d, eqL); + } + } + + // Final vector, masked comparison (no effect if i == num) + const size_t remaining = num - i; + SafeCopyN(remaining, d, keys + i, buf); + const Vec v = Load(d, buf); + const Mask valid = FirstN(d, remaining); + const Mask eqL = And(st.EqualKeys(d, v, valueL), valid); + const Mask eqR = st.EqualKeys(d, v, valueR); + // Invalid lanes are considered equal. + const Mask eq = Or(Or(eqL, eqR), Not(valid)); + // At least one other value present; will require a regular partition. + if (HWY_UNLIKELY(!AllTrue(d, eq))) { + const size_t lane = FindKnownFirstTrue(d, Not(eq)); + third = st.SetKey(d, keys + i + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at partial vec %zu; writeL %zu\n", i, + writeL); + } + // 'Undo' what we did by filling the remainder of what we read with R. + if (i >= N) { + for (; writeL <= i - N; writeL += N) { + StoreU(valueR, d, keys + writeL); + } + } + StoreN(valueR, d, keys + writeL, i - writeL); + return false; + } + StoreN(valueL, d, keys + writeL, remaining); + writeL += CountTrue(d, eqL); + + // Fill right side + i = writeL; + if (num >= N) { + for (; i <= num - N; i += N) { + StoreU(valueR, d, keys + i); + } + } + StoreN(valueR, d, keys + i, num - i); + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Successful MaybePartitionTwoValue\n"); + } + return true; +} + +// Same as above, except that the pivot equals valueR, so scan right to left. 
+template +HWY_INLINE bool MaybePartitionTwoValueR(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, const Vec valueL, + const Vec valueR, Vec& third, + T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + + HWY_DASSERT(num >= N); + size_t pos = num - N; // current read/write position + size_t countR = 0; // number of valueR found + + // For whole vectors, in descending address order: as long as all lanes are + // equal to L or R, overwrite with valueR. This is faster than counting, then + // filling both L and R. Loop terminates after unsigned wraparound. + for (; pos < num; pos -= N) { + const Vec v = LoadU(d, keys + pos); + // It is not clear how to apply OrXor here - that can check if *both* + // comparisons are true, but here we want *either*. Comparing the unsigned + // min of differences to zero works, but is expensive for u64 prior to AVX3. + const Mask eqL = st.EqualKeys(d, v, valueL); + const Mask eqR = st.EqualKeys(d, v, valueR); + // If there is a third value, stop and undo what we've done. On AVX-512, + // Or + AllTrue are folded into a single kortest, but only if we are + // careful with the FindKnownFirstTrue argument - see prior comment on that. + if (HWY_UNLIKELY(!AllTrue(d, Or(eqL, eqR)))) { + const size_t lane = FindKnownFirstTrue(d, ExclusiveNeither(eqL, eqR)); + third = st.SetKey(d, keys + pos + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at vec %zu; countR %zu\n", pos, + countR); + MaybePrintVector(d, "third", third, 0, st.LanesPerKey()); + } + pos += N; // rewind: we haven't yet committed changes in this iteration. + // We have filled [pos, num) with R, but only countR of them should have + // been written. Rewrite [pos, num - countR) to L. 
+ HWY_DASSERT(countR <= num - pos); + const size_t endL = num - countR; + if (endL >= N) { + for (; pos <= endL - N; pos += N) { + StoreU(valueL, d, keys + pos); + } + } + StoreN(valueL, d, keys + pos, endL - pos); + return false; + } + StoreU(valueR, d, keys + pos); + countR += CountTrue(d, eqR); + } + + // Final partial (or empty) vector, masked comparison. + const size_t remaining = pos + N; + HWY_DASSERT(remaining <= N); + const Vec v = LoadU(d, keys); // Safe because num >= N. + const Mask valid = FirstN(d, remaining); + const Mask eqL = st.EqualKeys(d, v, valueL); + const Mask eqR = And(st.EqualKeys(d, v, valueR), valid); + // Invalid lanes are considered equal. + const Mask eq = Or(Or(eqL, eqR), Not(valid)); + // At least one other value present; will require a regular partition. + if (HWY_UNLIKELY(!AllTrue(d, eq))) { + const size_t lane = FindKnownFirstTrue(d, Not(eq)); + third = st.SetKey(d, keys + lane); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "found 3rd value at partial vec %zu; writeR %zu\n", pos, + countR); + MaybePrintVector(d, "third", third, 0, st.LanesPerKey()); + } + pos += N; // rewind: we haven't yet committed changes in this iteration. + // We have filled [pos, num) with R, but only countR of them should have + // been written. Rewrite [pos, num - countR) to L. + HWY_DASSERT(countR <= num - pos); + const size_t endL = num - countR; + if (endL >= N) { + for (; pos <= endL - N; pos += N) { + StoreU(valueL, d, keys + pos); + } + } + StoreN(valueL, d, keys + pos, endL - pos); + return false; + } + const size_t lastR = CountTrue(d, eqR); + countR += lastR; + + // First finish writing valueR - [0, N) lanes were not yet written. + StoreU(valueR, d, keys); // Safe because num >= N. 
+ + // Fill left side (ascending order for clarity) + const size_t endL = num - countR; + size_t i = 0; + if (endL >= N) { + for (; i <= endL - N; i += N) { + StoreU(valueL, d, keys + i); + } + } + Store(valueL, d, buf); + SafeCopyN(endL - i, d, buf, keys + i); // avoids ASan overrun + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, + "MaybePartitionTwoValueR countR %zu pos %zu i %zu endL %zu\n", + countR, pos, i, endL); + } + + return true; +} + +// `idx_second` is `first_mismatch` from `AllEqual` and thus the index of the +// second key. This is the first path into `MaybePartitionTwoValue`, called +// when all samples are equal. Returns false if there are at least a third +// value and sets `third`. Otherwise, partitions the array and returns true. +template +HWY_INLINE bool PartitionIfTwoKeys(D d, Traits st, const Vec pivot, + T* HWY_RESTRICT keys, size_t num, + const size_t idx_second, const Vec second, + Vec& third, T* HWY_RESTRICT buf) { + // True if second comes before pivot. + const bool is_pivotR = AllFalse(d, st.Compare(d, pivot, second)); + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "Samples all equal, diff at %zu, isPivotR %d\n", idx_second, + is_pivotR); + } + HWY_DASSERT(AllFalse(d, st.EqualKeys(d, second, pivot))); + + // If pivot is R, we scan backwards over the entire array. Otherwise, + // we already scanned up to idx_second and can leave those in place. + return is_pivotR ? MaybePartitionTwoValueR(d, st, keys, num, second, pivot, + third, buf) + : MaybePartitionTwoValue(d, st, keys + idx_second, + num - idx_second, pivot, second, + third, buf); +} + +// Second path into `MaybePartitionTwoValue`, called when not all samples are +// equal. `samples` is sorted. 
+template +HWY_INLINE bool PartitionIfTwoSamples(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, T* HWY_RESTRICT samples) { + constexpr size_t kSampleLanes = Constants::SampleLanes(); + constexpr size_t N1 = st.LanesPerKey(); + const Vec valueL = st.SetKey(d, samples); + const Vec valueR = st.SetKey(d, samples + kSampleLanes - N1); + HWY_DASSERT(AllTrue(d, st.Compare(d, valueL, valueR))); + HWY_DASSERT(AllFalse(d, st.EqualKeys(d, valueL, valueR))); + const Vec prev = st.PrevValue(d, valueR); + // If the sample has more than two values, then the keys have at least that + // many, and thus this special case is inapplicable. + if (HWY_UNLIKELY(!AllTrue(d, st.EqualKeys(d, valueL, prev)))) { + return false; + } + + // Must not overwrite samples because if this returns false, caller wants to + // read the original samples again. + T* HWY_RESTRICT buf = samples + kSampleLanes; + Vec third; // unused + return MaybePartitionTwoValue(d, st, keys, num, valueL, valueR, third, buf); +} + +// ------------------------------ Pivot sampling + +template +HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) { + const DFromV d; + // Slightly faster for 128-bit, apparently because not serially dependent. + if (st.Is128()) { + // Median = XOR-sum 'minus' the first and last. Calling First twice is + // slightly faster than Compare + 2 IfThenElse or even IfThenElse + XOR. + const V sum = Xor(Xor(v0, v1), v2); + const V first = st.First(d, st.First(d, v0, v1), v2); + const V last = st.Last(d, st.Last(d, v0, v1), v2); + return Xor(Xor(sum, first), last); + } + st.Sort2(d, v0, v2); + v1 = st.Last(d, v0, v1); + v1 = st.First(d, v1, v2); + return v1; +} + +// Returns slightly biased random index of a chunk in [0, num_chunks). +// See https://www.pcg-random.org/posts/bounded-rands.html. 
+HWY_INLINE size_t RandomChunkIndex(const uint32_t num_chunks, uint32_t bits) { + const uint64_t chunk_index = (static_cast(bits) * num_chunks) >> 32; + HWY_DASSERT(chunk_index < num_chunks); + return static_cast(chunk_index); +} + +// Writes samples from `keys[0, num)` into `buf`. +template +HWY_INLINE void DrawSamples(D d, Traits st, T* HWY_RESTRICT keys, size_t num, + T* HWY_RESTRICT buf, uint64_t* HWY_RESTRICT state) { + using V = decltype(Zero(d)); + const size_t N = Lanes(d); + + // Power of two + constexpr size_t kLanesPerChunk = Constants::LanesPerChunk(sizeof(T)); + + // Align start of keys to chunks. We have at least 2 chunks (x 64 bytes) + // because the base case handles anything up to 8 vectors (x 16 bytes). + HWY_DASSERT(num >= Constants::SampleLanes()); + const size_t misalign = + (reinterpret_cast(keys) / sizeof(T)) & (kLanesPerChunk - 1); + if (misalign != 0) { + const size_t consume = kLanesPerChunk - misalign; + keys += consume; + num -= consume; + } + + // Generate enough random bits for 6 uint32 + uint32_t bits[6]; + for (size_t i = 0; i < 6; i += 2) { + const uint64_t bits64 = RandomBits(state); + CopyBytes<8>(&bits64, bits + i); + } + + const size_t num_chunks64 = num / kLanesPerChunk; + // Clamp to uint32 for RandomChunkIndex + const uint32_t num_chunks = + static_cast(HWY_MIN(num_chunks64, 0xFFFFFFFFull)); + + const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) * kLanesPerChunk; + const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) * kLanesPerChunk; + const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) * kLanesPerChunk; + const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) * kLanesPerChunk; + const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) * kLanesPerChunk; + const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) * kLanesPerChunk; + for (size_t i = 0; i < kLanesPerChunk; i += N) { + const V v0 = Load(d, keys + offset0 + i); + const V v1 = Load(d, keys + offset1 + i); + const V v2 = 
Load(d, keys + offset2 + i); + const V medians0 = MedianOf3(st, v0, v1, v2); + Store(medians0, d, buf + i); + + const V v3 = Load(d, keys + offset3 + i); + const V v4 = Load(d, keys + offset4 + i); + const V v5 = Load(d, keys + offset5 + i); + const V medians1 = MedianOf3(st, v3, v4, v5); + Store(medians1, d, buf + i + kLanesPerChunk); + } +} + +template +V OrXor(const V o, const V x1, const V x2) { + return Or(o, Xor(x1, x2)); // TERNLOG on AVX3 +} + +// For detecting inputs where (almost) all keys are equal. +template +HWY_INLINE bool UnsortedSampleEqual(D d, Traits st, + const TFromD* HWY_RESTRICT samples) { + constexpr size_t kSampleLanes = Constants::SampleLanes>(); + const size_t N = Lanes(d); + // Both are powers of two, so there will be no remainders. + HWY_DASSERT(N < kSampleLanes); + using V = Vec; + + const V first = st.SetKey(d, samples); + + if (!hwy::IsFloat>()) { + // OR of XOR-difference may be faster than comparison. + V diff = Zero(d); + for (size_t i = 0; i < kSampleLanes; i += N) { + const V v = Load(d, samples + i); + diff = OrXor(diff, first, v); + } + return st.NoKeyDifference(d, diff); + } else { + // Disable the OrXor optimization for floats because OrXor will not treat + // subnormals the same as actual comparisons, leading to logic errors for + // 2-value cases. + for (size_t i = 0; i < kSampleLanes; i += N) { + const V v = Load(d, samples + i); + if (!AllTrue(d, st.EqualKeys(d, v, first))) { + return false; + } + } + return true; + } +} + +template +HWY_INLINE void SortSamples(D d, Traits st, T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + constexpr size_t kSampleLanes = Constants::SampleLanes(); + // Network must be large enough to sort two chunks. 
+ HWY_DASSERT(Constants::BaseCaseNumLanes(N) >= kSampleLanes); + + BaseCase(d, st, buf, kSampleLanes, buf + kSampleLanes); + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Samples:\n"); + for (size_t i = 0; i < kSampleLanes; i += N) { + MaybePrintVector(d, "", Load(d, buf + i), 0, N); + } + } +} + +// ------------------------------ Pivot selection + +enum class PivotResult { + kDone, // stop without partitioning (all equal, or two-value partition) + kNormal, // partition and recurse left and right + kIsFirst, // partition but skip left recursion + kWasLast, // partition but skip right recursion +}; + +HWY_INLINE const char* PivotResultString(PivotResult result) { + switch (result) { + case PivotResult::kDone: + return "done"; + case PivotResult::kNormal: + return "normal"; + case PivotResult::kIsFirst: + return "first"; + case PivotResult::kWasLast: + return "last"; + } + return "unknown"; +} + +// (Could vectorize, but only 0.2% of total time) +template +HWY_INLINE size_t PivotRank(Traits st, const T* HWY_RESTRICT samples) { + constexpr size_t kSampleLanes = Constants::SampleLanes(); + constexpr size_t N1 = st.LanesPerKey(); + + constexpr size_t kRankMid = kSampleLanes / 2; + static_assert(kRankMid % N1 == 0, "Mid is not an aligned key"); + + // Find the previous value not equal to the median. + size_t rank_prev = kRankMid - N1; + for (; st.Equal1(samples + rank_prev, samples + kRankMid); rank_prev -= N1) { + // All previous samples are equal to the median. + if (rank_prev == 0) return 0; + } + + size_t rank_next = rank_prev + N1; + for (; st.Equal1(samples + rank_next, samples + kRankMid); rank_next += N1) { + // The median is also the largest sample. If it is also the largest key, + // we'd end up with an empty right partition, so choose the previous key. 
+ if (rank_next == kSampleLanes - N1) return rank_prev; + } + + // If we choose the median as pivot, the ratio of keys ending in the left + // partition will likely be rank_next/kSampleLanes (if the sample is + // representative). This is because equal-to-pivot values also land in the + // left - it's infeasible to do an in-place vectorized 3-way partition. + // Check whether prev would lead to a more balanced partition. + const size_t excess_if_median = rank_next - kRankMid; + const size_t excess_if_prev = kRankMid - rank_prev; + return excess_if_median < excess_if_prev ? kRankMid : rank_prev; +} + +// Returns pivot chosen from `samples`. It will never be the largest key +// (thus the right partition will never be empty). +template +HWY_INLINE Vec ChoosePivotByRank(D d, Traits st, + const T* HWY_RESTRICT samples) { + const size_t pivot_rank = PivotRank(st, samples); + const Vec pivot = st.SetKey(d, samples + pivot_rank); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, " Pivot rank %3zu\n", pivot_rank); + HWY_ALIGN T pivot_lanes[MaxLanes(d)]; + Store(pivot, d, pivot_lanes); + using Key = typename Traits::KeyType; + Key key; + CopyBytes(pivot_lanes, &key); + PrintValue(key); + } + // Verify pivot is not equal to the last sample. + constexpr size_t kSampleLanes = Constants::SampleLanes(); + constexpr size_t N1 = st.LanesPerKey(); + const Vec last = st.SetKey(d, samples + kSampleLanes - N1); + const bool all_neq = AllTrue(d, st.NotEqualKeys(d, pivot, last)); + (void)all_neq; + HWY_DASSERT(all_neq); + return pivot; +} + +// Returns true if all keys equal `pivot`, otherwise returns false and sets +// `*first_mismatch' to the index of the first differing key. +template +HWY_INLINE bool AllEqual(D d, Traits st, const Vec pivot, + const T* HWY_RESTRICT keys, size_t num, + size_t* HWY_RESTRICT first_mismatch) { + const size_t N = Lanes(d); + // Ensures we can use overlapping loads for the tail; see HandleSpecialCases. 
+ HWY_DASSERT(num >= N); + const Vec zero = Zero(d); + + // Vector-align keys + i. + const size_t misalign = + (reinterpret_cast(keys) / sizeof(T)) & (N - 1); + HWY_DASSERT(misalign % st.LanesPerKey() == 0); + const size_t consume = N - misalign; + { + const Vec v = LoadU(d, keys); + // Only check masked lanes; consider others to be equal. + const Mask diff = And(FirstN(d, consume), st.NotEqualKeys(d, v, pivot)); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = lane; + return false; + } + } + size_t i = consume; + HWY_DASSERT(((reinterpret_cast(keys + i) / sizeof(T)) & (N - 1)) == + 0); + + // Disable the OrXor optimization for floats because OrXor will not treat + // subnormals the same as actual comparisons, leading to logic errors for + // 2-value cases. + if (!hwy::IsFloat()) { + // Sticky bits registering any difference between `keys` and the first key. + // We use vector XOR because it may be cheaper than comparisons, especially + // for 128-bit. 2x unrolled for more ILP. + Vec diff0 = zero; + Vec diff1 = zero; + + // We want to stop once a difference has been found, but without slowing + // down the loop by comparing during each iteration. The compromise is to + // compare after a 'group', which consists of kLoops times two vectors. + constexpr size_t kLoops = 8; + const size_t lanes_per_group = kLoops * 2 * N; + + if (num >= lanes_per_group) { + for (; i <= num - lanes_per_group; i += lanes_per_group) { + HWY_DEFAULT_UNROLL + for (size_t loop = 0; loop < kLoops; ++loop) { + const Vec v0 = Load(d, keys + i + loop * 2 * N); + const Vec v1 = Load(d, keys + i + loop * 2 * N + N); + diff0 = OrXor(diff0, v0, pivot); + diff1 = OrXor(diff1, v1, pivot); + } + + // If there was a difference in the entire group: + if (HWY_UNLIKELY(!st.NoKeyDifference(d, Or(diff0, diff1)))) { + // .. then loop until the first one, with termination guarantee. 
+ for (;; i += N) { + const Vec v = Load(d, keys + i); + const Mask diff = st.NotEqualKeys(d, v, pivot); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = i + lane; + return false; + } + } + } + } + } + } // !hwy::IsFloat() + + // Whole vectors, no unrolling, compare directly + for (; i <= num - N; i += N) { + const Vec v = Load(d, keys + i); + const Mask diff = st.NotEqualKeys(d, v, pivot); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = i + lane; + return false; + } + } + // Always re-check the last (unaligned) vector to reduce branching. + i = num - N; + const Vec v = LoadU(d, keys + i); + const Mask diff = st.NotEqualKeys(d, v, pivot); + if (HWY_UNLIKELY(!AllFalse(d, diff))) { + const size_t lane = FindKnownFirstTrue(d, diff); + *first_mismatch = i + lane; + return false; + } + + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "All keys equal\n"); + } + return true; // all equal +} + +// Called from 'two locations', but only one is active (IsKV is constexpr). 
+template +HWY_INLINE bool ExistsAnyBefore(D d, Traits st, const T* HWY_RESTRICT keys, + size_t num, const Vec pivot) { + const size_t N = Lanes(d); + HWY_DASSERT(num >= N); // See HandleSpecialCases + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Scanning for before\n"); + } + + size_t i = 0; + + constexpr size_t kLoops = 16; + const size_t lanes_per_group = kLoops * N; + + Vec first = pivot; + + // Whole group, unrolled + if (num >= lanes_per_group) { + for (; i <= num - lanes_per_group; i += lanes_per_group) { + HWY_DEFAULT_UNROLL + for (size_t loop = 0; loop < kLoops; ++loop) { + const Vec curr = LoadU(d, keys + i + loop * N); + first = st.First(d, first, curr); + } + + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, first, pivot)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at end of group %zu\n", + i + lanes_per_group); + } + return true; + } + } + } + // Whole vectors, no unrolling + for (; i <= num - N; i += N) { + const Vec curr = LoadU(d, keys + i); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at %zu\n", i); + } + return true; + } + } + // If there are remainders, re-check the last whole vector. + if (HWY_LIKELY(i != num)) { + const Vec curr = LoadU(d, keys + num - N); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, curr, pivot)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at last %zu\n", num - N); + } + return true; + } + } + + return false; // pivot is the first +} + +// Called from 'two locations', but only one is active (IsKV is constexpr). 
+template +HWY_INLINE bool ExistsAnyAfter(D d, Traits st, const T* HWY_RESTRICT keys, + size_t num, const Vec pivot) { + const size_t N = Lanes(d); + HWY_DASSERT(num >= N); // See HandleSpecialCases + + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Scanning for after\n"); + } + + size_t i = 0; + + constexpr size_t kLoops = 16; + const size_t lanes_per_group = kLoops * N; + + Vec last = pivot; + + // Whole group, unrolled + if (num >= lanes_per_group) { + for (; i + lanes_per_group <= num; i += lanes_per_group) { + HWY_DEFAULT_UNROLL + for (size_t loop = 0; loop < kLoops; ++loop) { + const Vec curr = LoadU(d, keys + i + loop * N); + last = st.Last(d, last, curr); + } + + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, last)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at end of group %zu\n", + i + lanes_per_group); + } + return true; + } + } + } + // Whole vectors, no unrolling + for (; i <= num - N; i += N) { + const Vec curr = LoadU(d, keys + i); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at %zu\n", i); + } + return true; + } + } + // If there are remainders, re-check the last whole vector. + if (HWY_LIKELY(i != num)) { + const Vec curr = LoadU(d, keys + num - N); + if (HWY_UNLIKELY(!AllFalse(d, st.Compare(d, pivot, curr)))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Stopped scanning at last %zu\n", num - N); + } + return true; + } + } + + return false; // pivot is the last +} + +// Returns pivot chosen from `keys[0, num)`. It will never be the largest key +// (thus the right partition will never be empty). +template +HWY_INLINE Vec ChoosePivotForEqualSamples(D d, Traits st, + T* HWY_RESTRICT keys, size_t num, + T* HWY_RESTRICT samples, + Vec second, Vec third, + PivotResult& result) { + const Vec pivot = st.SetKey(d, samples); // the single unique sample + + // Early out for mostly-0 arrays, where pivot is often FirstValue. 
+ if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.FirstValue(d))))) { + result = PivotResult::kIsFirst; + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Pivot equals first possible value\n"); + } + return pivot; + } + if (HWY_UNLIKELY(AllTrue(d, st.EqualKeys(d, pivot, st.LastValue(d))))) { + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "Pivot equals last possible value\n"); + } + result = PivotResult::kWasLast; + return st.PrevValue(d, pivot); + } + + // If key-value, we didn't run PartitionIfTwo* and thus `third` is unknown and + // cannot be used. + if (st.IsKV()) { + // If true, pivot is either middle or last. + const bool before = !AllFalse(d, st.Compare(d, second, pivot)); + if (HWY_UNLIKELY(before)) { + // Not last, so middle. + if (HWY_UNLIKELY(ExistsAnyAfter(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + + // We didn't find anything after pivot, so it is the last. Because keys + // equal to the pivot go to the left partition, the right partition would + // be empty and Partition will not have changed anything. Instead use the + // previous value in sort order, which is not necessarily an actual key. + result = PivotResult::kWasLast; + return st.PrevValue(d, pivot); + } + + // Otherwise, pivot is first or middle. Rule out it being first: + if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + // It is first: fall through to shared code below. + } else { + // Check if pivot is between two known values. If so, it is not the first + // nor the last and we can avoid scanning. + st.Sort2(d, second, third); + HWY_DASSERT(AllTrue(d, st.Compare(d, second, third))); + const bool before = !AllFalse(d, st.Compare(d, second, pivot)); + const bool after = !AllFalse(d, st.Compare(d, pivot, third)); + // Only reached if there are three keys, which means pivot is either first, + // last, or in between. Thus there is another key that comes before or + // after. 
+ HWY_DASSERT(before || after); + if (HWY_UNLIKELY(before)) { + // Neither first nor last. + if (HWY_UNLIKELY(after || ExistsAnyAfter(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + + // We didn't find anything after pivot, so it is the last. Because keys + // equal to the pivot go to the left partition, the right partition would + // be empty and Partition will not have changed anything. Instead use the + // previous value in sort order, which is not necessarily an actual key. + result = PivotResult::kWasLast; + return st.PrevValue(d, pivot); + } + + // Has after, and we found one before: in the middle. + if (HWY_UNLIKELY(ExistsAnyBefore(d, st, keys, num, pivot))) { + result = PivotResult::kNormal; + return pivot; + } + } + + // Pivot is first. We could consider a special partition mode that only + // reads from and writes to the right side, and later fills in the left + // side, which we know is equal to the pivot. However, that leads to more + // cache misses if the array is large, and doesn't save much, hence is a + // net loss. + result = PivotResult::kIsFirst; + return pivot; +} + +// ------------------------------ Quicksort recursion + +enum class RecurseMode { + kSort, // Sort mode. + kSelect, // Select mode. + // The element pointed at by nth is changed to whatever element + // would occur in that position if [first, last) were sorted. All of + // the elements before this new nth element are less than or equal + // to the elements after the new nth element. + kLooseSelect, // Loose select mode. 
+ // The first n elements will contain the n smallest elements in + // unspecified order +}; + +template +HWY_NOINLINE void PrintMinMax(D d, Traits st, const T* HWY_RESTRICT keys, + size_t num, T* HWY_RESTRICT buf) { + if (VQSORT_PRINT >= 2) { + const size_t N = Lanes(d); + if (num < N) return; + + Vec first = st.LastValue(d); + Vec last = st.FirstValue(d); + + size_t i = 0; + for (; i <= num - N; i += N) { + const Vec v = LoadU(d, keys + i); + first = st.First(d, v, first); + last = st.Last(d, v, last); + } + if (HWY_LIKELY(i != num)) { + HWY_DASSERT(num >= N); // See HandleSpecialCases + const Vec v = LoadU(d, keys + num - N); + first = st.First(d, v, first); + last = st.Last(d, v, last); + } + + first = st.FirstOfLanes(d, first, buf); + last = st.LastOfLanes(d, last, buf); + MaybePrintVector(d, "first", first, 0, st.LanesPerKey()); + MaybePrintVector(d, "last", last, 0, st.LanesPerKey()); + } +} + +template +HWY_NOINLINE void Recurse(D d, Traits st, T* HWY_RESTRICT keys, + const size_t num, T* HWY_RESTRICT buf, + uint64_t* HWY_RESTRICT state, + const size_t remaining_levels, const size_t k = 0) { + HWY_DASSERT(num != 0); + + const size_t N = Lanes(d); + constexpr size_t kLPK = st.LanesPerKey(); + if (HWY_UNLIKELY(num <= Constants::BaseCaseNumLanes(N))) { + BaseCase(d, st, keys, num, buf); + return; + } + + // Move after BaseCase so we skip printing for small subarrays. + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "\n\n=== Recurse depth=%zu len=%zu k=%zu\n", + remaining_levels, num, k); + PrintMinMax(d, st, keys, num, buf); + } + + DrawSamples(d, st, keys, num, buf, state); + + Vec pivot; + PivotResult result = PivotResult::kNormal; + if (HWY_UNLIKELY(UnsortedSampleEqual(d, st, buf))) { + pivot = st.SetKey(d, buf); + size_t idx_second = 0; + if (HWY_UNLIKELY(AllEqual(d, st, pivot, keys, num, &idx_second))) { + return; + } + HWY_DASSERT(idx_second % st.LanesPerKey() == 0); + // Must capture the value before PartitionIfTwoKeys may overwrite it. 
+ const Vec second = st.SetKey(d, keys + idx_second); + MaybePrintVector(d, "pivot", pivot, 0, st.LanesPerKey()); + MaybePrintVector(d, "second", second, 0, st.LanesPerKey()); + + Vec third = Zero(d); + // Not supported for key-value types because two 'keys' may be equivalent + // but not interchangeable (their values may differ). + if (HWY_UNLIKELY(!st.IsKV() && + PartitionIfTwoKeys(d, st, pivot, keys, num, idx_second, + second, third, buf))) { + return; // Done, skip recursion because each side has all-equal keys. + } + + // We can no longer start scanning from idx_second because + // PartitionIfTwoKeys may have reordered keys. + pivot = ChoosePivotForEqualSamples(d, st, keys, num, buf, second, third, + result); + // If kNormal, `pivot` is very common but not the first/last. It is + // tempting to do a 3-way partition (to avoid moving the =pivot keys a + // second time), but that is a net loss due to the extra comparisons. + } else { + SortSamples(d, st, buf); + + // Not supported for key-value types because two 'keys' may be equivalent + // but not interchangeable (their values may differ). + if (HWY_UNLIKELY(!st.IsKV() && + PartitionIfTwoSamples(d, st, keys, num, buf))) { + return; + } + + pivot = ChoosePivotByRank(d, st, buf); + } + + // Too many recursions. This is unlikely to happen because we select pivots + // from large (though still O(1)) samples. + if (HWY_UNLIKELY(remaining_levels == 0)) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "HeapSort reached, size=%zu\n", num); + } + HeapSort(st, keys, num); // Slow but N*logN. + return; + } + + const size_t bound = Partition(d, st, keys, num, pivot, buf); + if (VQSORT_PRINT >= 2) { + fprintf(stderr, "bound %zu num %zu result %s\n", bound, num, + PivotResultString(result)); + } + // The left partition is not empty because the pivot is usually one of the + // keys. 
Exception: if kWasLast, we set pivot to PrevValue(pivot), but we + // still have at least one value <= pivot because AllEqual ruled out the case + // of only one unique value. Note that for floating-point, PrevValue can + // return the same value (for -inf inputs), but that would just mean the + // pivot is again one of the keys. + using Order = typename Traits::Order; + (void)Order::IsAscending(); + HWY_DASSERT_M(bound != 0, + (Order::IsAscending() ? "Ascending" : "Descending")); + // ChoosePivot* ensure pivot != last, so the right partition is never empty + // except in the rare case of the pivot matching the last-in-sort-order value, + // which implies we anyway skip the right partition due to kWasLast. + HWY_DASSERT(bound != num || result == PivotResult::kWasLast); + + HWY_IF_CONSTEXPR(mode == RecurseMode::kSelect) { + if (HWY_LIKELY(result != PivotResult::kIsFirst) && k < bound) { + Recurse(d, st, keys, bound, buf, state, + remaining_levels - 1, k); + } else if (HWY_LIKELY(result != PivotResult::kWasLast) && k >= bound) { + Recurse(d, st, keys + bound, num - bound, buf, + state, remaining_levels - 1, k - bound); + } + } + HWY_IF_CONSTEXPR(mode == RecurseMode::kSort) { + if (HWY_LIKELY(result != PivotResult::kIsFirst)) { + Recurse(d, st, keys, bound, buf, state, + remaining_levels - 1); + } + if (HWY_LIKELY(result != PivotResult::kWasLast)) { + Recurse(d, st, keys + bound, num - bound, buf, state, + remaining_levels - 1); + } + } +} + +// Returns true if sorting is finished. +template +HWY_INLINE bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys, + size_t num, T* HWY_RESTRICT buf) { + const size_t N = Lanes(d); + constexpr size_t kLPK = st.LanesPerKey(); + const size_t base_case_num = Constants::BaseCaseNumLanes(N); + + // Recurse will also check this, but doing so here first avoids setting up + // the random generator state. 
+ if (HWY_UNLIKELY(num <= base_case_num)) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "Special-casing small, %zu lanes\n", num); + } + BaseCase(d, st, keys, num, buf); + return true; + } + + // 128-bit keys require vectors with at least two u64 lanes, which is always + // the case unless `d` requests partial vectors (e.g. fraction = 1/2) AND the + // hardware vector width is less than 128bit / fraction. + const bool partial_128 = !IsFull(d) && N < 2 && st.Is128(); + // Partition assumes its input is at least two vectors. If vectors are huge, + // base_case_num may actually be smaller. If so, which is only possible on + // RVV, pass a capped or partial d (LMUL < 1). Use HWY_MAX_BYTES instead of + // HWY_LANES to account for the largest possible LMUL. + constexpr bool kPotentiallyHuge = + HWY_MAX_BYTES / sizeof(T) > Constants::kMaxRows * Constants::kMaxCols; + const bool huge_vec = kPotentiallyHuge && (2 * N > base_case_num); + if (partial_128 || huge_vec) { + if (VQSORT_PRINT >= 1) { + HWY_WARN("using slow HeapSort: partial %d huge %d\n", partial_128, + huge_vec); + } + HeapSort(st, keys, num); + return true; + } + + // We could also check for already sorted/reverse/equal, but that's probably + // counterproductive if vqsort is used as a base case. + + return false; // not finished sorting +} + +#endif // VQSORT_ENABLED + +template +HWY_INLINE size_t CountAndReplaceNaN(D d, Traits st, T* HWY_RESTRICT keys, + size_t num) { + const size_t N = Lanes(d); + // Will be sorted to the back of the array. 
+ const Vec sentinel = st.LastValue(d); + size_t num_nan = 0; + size_t i = 0; + if (num >= N) { + for (; i <= num - N; i += N) { + const Mask is_nan = IsNaN(LoadU(d, keys + i)); + BlendedStore(sentinel, is_nan, d, keys + i); + num_nan += CountTrue(d, is_nan); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < N); + const Vec v = LoadN(d, keys + i, remaining); + const Mask is_nan = IsNaN(v); + StoreN(IfThenElse(is_nan, sentinel, v), d, keys + i, remaining); + num_nan += CountTrue(d, is_nan); + return num_nan; +} + +// IsNaN is not implemented for non-float, so skip it. +template +HWY_INLINE size_t CountAndReplaceNaN(D, Traits, T* HWY_RESTRICT, size_t) { + return 0; +} + +} // namespace detail + +// Old interface with user-specified buffer, retained for compatibility. Called +// by the newer overload below. `buf` must be vector-aligned and hold at least +// SortConstants::BufBytes(HWY_MAX_BYTES, st.LanesPerKey()). +template +void Sort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num, + T* HWY_RESTRICT buf) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "=============== Sort %s num=%zu, vec bytes=%zu\n", + st.KeyString(), num, sizeof(T) * Lanes(d)); + } + +#if HWY_MAX_BYTES > 64 + // sorting_networks-inl and traits assume no more than 512 bit vectors. + if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) { + return Sort(CappedTag(), st, keys, num, buf); + } +#endif // HWY_MAX_BYTES > 64 + + const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num); + +#if VQSORT_ENABLED || HWY_IDE + if (!detail::HandleSpecialCases(d, st, keys, num, buf)) { + uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic(); + // Introspection: switch to worst-case N*logN heapsort after this many. + // Should never be reached, so computing log2 exactly does not help. 
+ const size_t max_levels = 50; + detail::Recurse(d, st, keys, num, buf, state, + max_levels); + } +#else // !VQSORT_ENABLED + (void)d; + (void)buf; + if (VQSORT_PRINT >= 1) { + HWY_WARN("using slow HeapSort because vqsort disabled\n"); + } + detail::HeapSort(st, keys, num); +#endif // VQSORT_ENABLED + + if (num_nan != 0) { + Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan); + } +} + +template +void PartialSort(D d, Traits st, T* HWY_RESTRICT keys, size_t num, size_t k, + T* HWY_RESTRICT buf) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, + "=============== PartialSort %s num=%zu, k=%zu vec bytes=%zu\n", + st.KeyString(), num, k, sizeof(T) * Lanes(d)); + } + HWY_DASSERT(k <= num); + +#if HWY_MAX_BYTES > 64 + // sorting_networks-inl and traits assume no more than 512 bit vectors. + if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) { + return PartialSort(CappedTag(), st, keys, num, k, buf); + } +#endif // HWY_MAX_BYTES > 64 + + const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num); + +#if VQSORT_ENABLED || HWY_IDE + if (!detail::HandleSpecialCases(d, st, keys, num, buf)) { // TODO + uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic(); + // Introspection: switch to worst-case N*logN heapsort after this many. + // Should never be reached, so computing log2 exactly does not help. 
+ const size_t max_levels = 50; + // TODO: optimize to use kLooseSelect + detail::Recurse(d, st, keys, num, buf, state, + max_levels, k); + detail::Recurse(d, st, keys, k, buf, state, + max_levels); + } +#else // !VQSORT_ENABLED + (void)d; + (void)buf; + if (VQSORT_PRINT >= 1) { + HWY_WARN("using slow HeapSort because vqsort disabled\n"); + } + detail::HeapPartialSort(st, keys, num, k); +#endif // VQSORT_ENABLED + + if (num_nan != 0) { + Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan); + } +} + +template +void Select(D d, Traits st, T* HWY_RESTRICT keys, const size_t num, + const size_t k, T* HWY_RESTRICT buf) { + if (VQSORT_PRINT >= 1) { + fprintf(stderr, "=============== Select %s num=%zu, k=%zu vec bytes=%zu\n", + st.KeyString(), num, k, sizeof(T) * Lanes(d)); + } + HWY_DASSERT(k < num); + +#if HWY_MAX_BYTES > 64 + // sorting_networks-inl and traits assume no more than 512 bit vectors. + if (HWY_UNLIKELY(Lanes(d) > 64 / sizeof(T))) { + return Select(CappedTag(), st, keys, num, k, buf); + } +#endif // HWY_MAX_BYTES > 64 + + const size_t num_nan = detail::CountAndReplaceNaN(d, st, keys, num); + +#if VQSORT_ENABLED || HWY_IDE + if (!detail::HandleSpecialCases(d, st, keys, num, buf)) { // TODO + uint64_t* HWY_RESTRICT state = hwy::detail::GetGeneratorStateStatic(); + // Introspection: switch to worst-case N*logN heapsort after this many. + // Should never be reached, so computing log2 exactly does not help. + const size_t max_levels = 50; + detail::Recurse(d, st, keys, num, buf, state, + max_levels, k); + } +#else // !VQSORT_ENABLED + (void)d; + (void)buf; + if (VQSORT_PRINT >= 1) { + HWY_WARN("using slow HeapSort because vqsort disabled\n"); + } + detail::HeapSelect(st, keys, num, k); +#endif // VQSORT_ENABLED + + if (num_nan != 0) { + Fill(d, GetLane(NaN(d)), num_nan, keys + num - num_nan); + } +} + +// Sorts `keys[0..num-1]` according to the order defined by `st.Compare`. +// In-place i.e. O(1) additional storage. Worst-case N*logN comparisons. 
+// Non-stable (order of equal keys may change), except for the common case where +// the upper bits of T are the key, and the lower bits are a sequential or at +// least unique ID. Any NaN will be moved to the back of the array and replaced +// with the canonical NaN(d). +// There is no upper limit on `num`, but note that pivots may be chosen by +// sampling only from the first 256 GiB. +// +// `d` is typically SortTag (chooses between full and partial vectors). +// `st` is SharedTraits>. This abstraction layer bridges +// differences in sort order and single-lane vs 128-bit keys. +// `num` is in units of `T`, not keys! +template +HWY_API void Sort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num) { + constexpr size_t kLPK = st.LanesPerKey(); + HWY_ALIGN T buf[SortConstants::BufBytes(HWY_MAX_BYTES) / sizeof(T)]; + Sort(d, st, keys, num, buf); +} + +// Rearranges elements such that the range [0, k) contains the sorted first `k` +// elements in the range [0, n) ordered by `st.Compare`. See also the comment +// for `Sort()`; note that `num` and `k` are in units of `T`, not keys! +template +HWY_API void PartialSort(D d, Traits st, T* HWY_RESTRICT keys, const size_t num, + const size_t k) { + constexpr size_t kLPK = st.LanesPerKey(); + HWY_ALIGN T buf[SortConstants::BufBytes(HWY_MAX_BYTES) / sizeof(T)]; + PartialSort(d, st, keys, num, k, buf); +} + +// Reorders `keys[0..num-1]` such that `keys+k` is the k-th element if keys was +// sorted by `st.Compare`, and all of the elements before it are ordered +// by `st.Compare` relative to `keys[k]`. See also the comment for `Sort()`; +// note that `num` and `k` are in units of `T`, not keys! 
+template +HWY_API void Select(D d, Traits st, T* HWY_RESTRICT keys, const size_t num, + const size_t k) { + constexpr size_t kLPK = st.LanesPerKey(); + HWY_ALIGN T buf[SortConstants::BufBytes(HWY_MAX_BYTES) / sizeof(T)]; + Select(d, st, keys, num, k, buf); +} + +// Translates Key and Order (SortAscending or SortDescending) to SharedTraits. +namespace detail { + +// Primary template for built-in key types = lane type. +template +struct KeyAdapter { + template + using Traits = TraitsLane< + hwy::If, OrderDescending>>; +}; + +template <> +struct KeyAdapter { + template + using Traits = TraitsLane< + hwy::If>; +}; + +// 128-bit keys require 128-bit SIMD. +#if HWY_TARGET != HWY_SCALAR + +template <> +struct KeyAdapter { + template + using Traits = Traits128< + hwy::If>; +}; + +template <> +struct KeyAdapter { + template + using Traits = Traits128< + hwy::If>; +}; + +#endif // HWY_TARGET != HWY_SCALAR + +template +using MakeTraits = + SharedTraits::template Traits>; + +} // namespace detail + +// Simpler interface matching VQSort(), but without dynamic dispatch. Uses the +// instructions available in the current target (HWY_NAMESPACE). Supported key +// types: 16-64 bit unsigned/signed/floating-point (but float16/64 only #if +// HWY_HAVE_FLOAT16/64), uint128_t, K64V64, K32V32. Note that `num`, and for +// VQPartialSortStatic/VQSelectStatic also `k`, are in units of *keys*, whereas +// for all functions above this point, they are in units of `T`. Order is either +// SortAscending or SortDescending. 
+template +void VQSortStatic(Key* HWY_RESTRICT keys, const size_t num_keys, Order) { + const detail::MakeTraits st; + using LaneType = typename decltype(st)::LaneType; + const SortTag d; + Sort(d, st, reinterpret_cast(keys), num_keys * st.LanesPerKey()); +} + +template +void VQPartialSortStatic(Key* HWY_RESTRICT keys, const size_t num_keys, + const size_t k_keys, Order) { + const detail::MakeTraits st; + using LaneType = typename decltype(st)::LaneType; + const SortTag d; + PartialSort(d, st, reinterpret_cast(keys), + num_keys * st.LanesPerKey(), k_keys * st.LanesPerKey()); +} + +template +void VQSelectStatic(Key* HWY_RESTRICT keys, const size_t num_keys, + const size_t k_keys, Order) { + const detail::MakeTraits st; + using LaneType = typename decltype(st)::LaneType; + const SortTag d; + Select(d, st, reinterpret_cast(keys), num_keys * st.LanesPerKey(), + k_keys * st.LanesPerKey()); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE diff --git a/lib/highway/hwy/contrib/sort/vqsort.cc b/lib/highway/hwy/contrib/sort/vqsort.cc new file mode 100644 index 00000000000..68036b6513f --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort.cc @@ -0,0 +1,121 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/vqsort.h" + +#include "hwy/base.h" +#include "hwy/contrib/sort/vqsort-inl.h" +#include "hwy/per_target.h" + +// Check if we have getrandom from . Because is +// unavailable on Android and non-Linux RVV, we assume that those systems lack +// getrandom. Note that the only supported sources of entropy are getrandom or +// Windows, thus VQSORT_SECURE_SEED=0 when this is 0 and we are not on Windows. +#if defined(ANDROID) || defined(__ANDROID__) || \ + (HWY_ARCH_RISCV && !HWY_OS_LINUX) +#define VQSORT_GETRANDOM 0 +#endif + +#if !defined(VQSORT_GETRANDOM) && HWY_OS_LINUX +#include + +// ---- which libc +#if defined(__UCLIBC__) +#define VQSORT_GETRANDOM 1 // added Mar 2015, before uclibc-ng 1.0 + +#elif defined(__GLIBC__) && defined(__GLIBC_PREREQ) +#if __GLIBC_PREREQ(2, 25) +#define VQSORT_GETRANDOM 1 +#else +#define VQSORT_GETRANDOM 0 +#endif + +#else +// Assume MUSL, which has getrandom since 2018. There is no macro to test, see +// https://www.openwall.com/lists/musl/2013/03/29/13. +#define VQSORT_GETRANDOM 1 + +#endif // ---- which libc +#endif // linux + +#if !defined(VQSORT_GETRANDOM) +#define VQSORT_GETRANDOM 0 +#endif + +// Choose a seed source for SFC generator: 1=getrandom, 2=CryptGenRandom. +// Allow user override - not all Android support the getrandom wrapper. +#ifndef VQSORT_SECURE_SEED + +#if VQSORT_GETRANDOM +#define VQSORT_SECURE_SEED 1 +#elif defined(_WIN32) || defined(_WIN64) +#define VQSORT_SECURE_SEED 2 +#else +#define VQSORT_SECURE_SEED 0 +#endif + +#endif // VQSORT_SECURE_SEED + +// Pull in dependencies of the chosen seed source. +#if VQSORT_SECURE_SEED == 1 +#include +#elif VQSORT_SECURE_SEED == 2 +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif // WIN32_LEAN_AND_MEAN +#include +#if HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL +#pragma comment(lib, "advapi32.lib") +#endif // HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL +// Must come after windows.h. 
+#include +#endif // VQSORT_SECURE_SEED + +namespace hwy { + +// Returns false or performs the equivalent of `memcpy(bytes, r, 16)`, where r +// is high-quality (unpredictable, uniformly distributed) random bits. +bool Fill16BytesSecure(void* bytes) { +#if VQSORT_SECURE_SEED == 1 + // May block if urandom is not yet initialized. + const ssize_t ret = getrandom(bytes, 16, /*flags=*/0); + if (ret == 16) return true; +#elif VQSORT_SECURE_SEED == 2 + HCRYPTPROV hProvider{}; + if (CryptAcquireContextA(&hProvider, nullptr, nullptr, PROV_RSA_FULL, + CRYPT_VERIFYCONTEXT)) { + const BOOL ok = + CryptGenRandom(hProvider, 16, reinterpret_cast(bytes)); + CryptReleaseContext(hProvider, 0); + if (ok) return true; + } +#else + (void)bytes; +#endif + + return false; +} + +// Unused, only for ABI compatibility +void Sorter::Fill24Bytes(const void*, size_t, void*) {} +bool Sorter::HaveFloat64() { return hwy::HaveFloat64(); } +Sorter::Sorter() {} +void Sorter::Delete() {} +uint64_t* GetGeneratorState() { return hwy::detail::GetGeneratorStateStatic(); } + +} // namespace hwy diff --git a/lib/highway/hwy/contrib/sort/vqsort.h b/lib/highway/hwy/contrib/sort/vqsort.h new file mode 100644 index 00000000000..8f05415800f --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort.h @@ -0,0 +1,303 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Interface to vectorized quicksort with dynamic dispatch. 
For static dispatch +// without any DLLEXPORT, avoid including this header and instead define +// VQSORT_ONLY_STATIC, then call VQSortStatic* in vqsort-inl.h. +// +// Blog post: https://tinyurl.com/vqsort-blog +// Paper with measurements: https://arxiv.org/abs/2205.05982 +// +// To ensure the overhead of using wide vectors (e.g. AVX2 or AVX-512) is +// worthwhile, we recommend using this code for sorting arrays whose size is at +// least 100 KiB. See the README for details. + +#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_ +#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_ + +// IWYU pragma: begin_exports +#include + +#include "hwy/base.h" +#include "hwy/contrib/sort/order.h" // SortAscending +// IWYU pragma: end_exports + +namespace hwy { + +// Vectorized Quicksort: sorts keys[0, n). Does not preserve the ordering of +// equivalent keys (defined as: neither greater nor less than another). +// Dispatches to the best available instruction set. Does not allocate memory. +// Uses about 1.2 KiB stack plus an internal 3-word TLS cache for random state. 
+HWY_CONTRIB_DLLEXPORT void VQSort(uint16_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(uint16_t* HWY_RESTRICT keys, size_t n, + SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSort(uint32_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(uint32_t* HWY_RESTRICT keys, size_t n, + SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSort(uint64_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(uint64_t* HWY_RESTRICT keys, size_t n, + SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSort(int16_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(int16_t* HWY_RESTRICT keys, size_t n, + SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSort(int32_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(int32_t* HWY_RESTRICT keys, size_t n, + SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSort(int64_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(int64_t* HWY_RESTRICT keys, size_t n, + SortDescending); + +// These two must only be called if hwy::HaveFloat16() is true. +HWY_CONTRIB_DLLEXPORT void VQSort(float16_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(float16_t* HWY_RESTRICT keys, size_t n, + SortDescending); + +HWY_CONTRIB_DLLEXPORT void VQSort(float* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(float* HWY_RESTRICT keys, size_t n, + SortDescending); + +// These two must only be called if hwy::HaveFloat64() is true. 
+HWY_CONTRIB_DLLEXPORT void VQSort(double* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(double* HWY_RESTRICT keys, size_t n, + SortDescending); + +HWY_CONTRIB_DLLEXPORT void VQSort(K32V32* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(K32V32* HWY_RESTRICT keys, size_t n, + SortDescending); + +// 128-bit types: `n` is still in units of the 128-bit keys. +HWY_CONTRIB_DLLEXPORT void VQSort(uint128_t* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(uint128_t* HWY_RESTRICT keys, size_t n, + SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSort(K64V64* HWY_RESTRICT keys, size_t n, + SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSort(K64V64* HWY_RESTRICT keys, size_t n, + SortDescending); + +// Vectorized partial Quicksort: +// Rearranges elements such that the range [0, k) contains the sorted first k +// elements in the range [0, n). Does not preserve the ordering of equivalent +// keys (defined as: neither greater nor less than another). +// Dispatches to the best available instruction set. Does not allocate memory. +// Uses about 1.2 KiB stack plus an internal 3-word TLS cache for random state. 
+HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(int16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(int16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(int32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(int32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(int64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(int64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// These two must only be called if hwy::HaveFloat16() is true. +HWY_CONTRIB_DLLEXPORT void VQPartialSort(float16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(float16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +HWY_CONTRIB_DLLEXPORT void VQPartialSort(float* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(float* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// These two must only be called if hwy::HaveFloat64() is true. 
+HWY_CONTRIB_DLLEXPORT void VQPartialSort(double* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(double* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +HWY_CONTRIB_DLLEXPORT void VQPartialSort(K32V32* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(K32V32* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// 128-bit types: `n` and `k` are still in units of the 128-bit keys. +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint128_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(uint128_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(K64V64* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQPartialSort(K64V64* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// Vectorized Quickselect: +// rearranges elements in [0, n) such that: +// The element pointed at by kth is changed to whatever element would occur in +// that position if [0, n) were sorted. All of the elements before this new kth +// element are less than or equal to the elements after the new kth element. 
+HWY_CONTRIB_DLLEXPORT void VQSelect(uint16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(uint16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSelect(uint32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(uint32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSelect(uint64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(uint64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSelect(int16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(int16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSelect(int32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(int32_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSelect(int64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(int64_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// These two must only be called if hwy::HaveFloat16() is true. +HWY_CONTRIB_DLLEXPORT void VQSelect(float16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(float16_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +HWY_CONTRIB_DLLEXPORT void VQSelect(float* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(float* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// These two must only be called if hwy::HaveFloat64() is true. 
+HWY_CONTRIB_DLLEXPORT void VQSelect(double* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(double* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +HWY_CONTRIB_DLLEXPORT void VQSelect(K32V32* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(K32V32* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// 128-bit types: `n` and `k` are still in units of the 128-bit keys. +HWY_CONTRIB_DLLEXPORT void VQSelect(uint128_t* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(uint128_t* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); +HWY_CONTRIB_DLLEXPORT void VQSelect(K64V64* HWY_RESTRICT keys, size_t n, + size_t k, SortAscending); +HWY_CONTRIB_DLLEXPORT void VQSelect(K64V64* HWY_RESTRICT keys, size_t n, + size_t k, SortDescending); + +// User-level caching is no longer required, so this class is no longer +// beneficial. We recommend using the simpler VQSort() interface instead, and +// retain this class only for compatibility. It now just calls VQSort. 
+class HWY_CONTRIB_DLLEXPORT Sorter { + public: + Sorter(); + ~Sorter() { Delete(); } + + // Move-only + Sorter(const Sorter&) = delete; + Sorter& operator=(const Sorter&) = delete; + Sorter(Sorter&& /*other*/) {} + Sorter& operator=(Sorter&& /*other*/) { return *this; } + + void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint32_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint64_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(int16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(int32_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(int64_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + // These two must only be called if hwy::HaveFloat16() is true. + void operator()(float16_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(float16_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(float* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(float* HWY_RESTRICT keys, size_t n, SortDescending) const; + + // These two must only be called if hwy::HaveFloat64() is true. 
+ void operator()(double* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(double* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(uint128_t* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(K64V64* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(K64V64* HWY_RESTRICT keys, size_t n, SortDescending) const; + + void operator()(K32V32* HWY_RESTRICT keys, size_t n, SortAscending) const; + void operator()(K32V32* HWY_RESTRICT keys, size_t n, SortDescending) const; + + // Unused + static void Fill24Bytes(const void*, size_t, void*); + static bool HaveFloat64(); // Can also use hwy::HaveFloat64 directly. + + private: + void Delete(); + + template + T* Get() const { + return unused_; + } + +#if HWY_COMPILER_CLANG + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wunused-private-field") +#endif + void* unused_ = nullptr; +#if HWY_COMPILER_CLANG + HWY_DIAGNOSTICS(pop) +#endif +}; + +// Used by vqsort-inl.h unless VQSORT_ONLY_STATIC. +HWY_CONTRIB_DLLEXPORT bool Fill16BytesSecure(void* bytes); + +// Unused, only provided for binary compatibility. +HWY_CONTRIB_DLLEXPORT uint64_t* GetGeneratorState(); + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_H_ diff --git a/lib/highway/hwy/contrib/sort/vqsort_128a.cc b/lib/highway/hwy/contrib/sort/vqsort_128a.cc new file mode 100644 index 00000000000..3c4bcbf1e50 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_128a.cc @@ -0,0 +1,98 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_128a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +// uint128_t keys, ascending: thin forwards to the VQ*Static functions. +void Sort128Asc(uint128_t* HWY_RESTRICT keys, const size_t num) { +// 128-bit keys require 128-bit SIMD. +#if HWY_TARGET != HWY_SCALAR + return VQSortStatic(keys, num, SortAscending()); +#else // HWY_TARGET == HWY_SCALAR: no sort is performed + (void)keys; + (void)num; +#endif // HWY_TARGET != HWY_SCALAR +} + +void PartialSort128Asc(uint128_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { +// 128-bit keys require 128-bit SIMD. +#if HWY_TARGET != HWY_SCALAR + return VQPartialSortStatic(keys, num, k, SortAscending()); +#else // HWY_TARGET == HWY_SCALAR: no sort is performed + (void)keys; + (void)num; + (void)k; +#endif // HWY_TARGET != HWY_SCALAR +} + +void Select128Asc(uint128_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { +// 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR + return VQSelectStatic(keys, num, k, SortAscending()); +#else // HWY_TARGET == HWY_SCALAR: no sort is performed + (void)keys; + (void)num; + (void)k; +#endif // HWY_TARGET != HWY_SCALAR +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(Sort128Asc); +HWY_EXPORT(PartialSort128Asc); +HWY_EXPORT(Select128Asc); +} // namespace +// Entry points: dispatch to the helpers exported above. +void VQSort(uint128_t* HWY_RESTRICT keys, const size_t n, SortAscending) { + HWY_DYNAMIC_DISPATCH(Sort128Asc)(keys, n); +} + +void VQPartialSort(uint128_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(PartialSort128Asc)(keys, n, k); +} + +void VQSelect(uint128_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(Select128Asc)(keys, n, k); +} + +void Sorter::operator()(uint128_t* HWY_RESTRICT keys, size_t n, + SortAscending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_128d.cc b/lib/highway/hwy/contrib/sort/vqsort_128d.cc new file mode 100644 index 00000000000..36580abc24f --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_128d.cc @@ -0,0 +1,98 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_128d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +// uint128_t keys, descending: thin forwards to the VQ*Static functions. +void Sort128Desc(uint128_t* HWY_RESTRICT keys, const size_t num) { + // 128-bit keys require 128-bit SIMD. +#if HWY_TARGET != HWY_SCALAR + return VQSortStatic(keys, num, SortDescending()); +#else // HWY_TARGET == HWY_SCALAR: no sort is performed + (void)keys; + (void)num; +#endif // HWY_TARGET != HWY_SCALAR +} + +void PartialSort128Desc(uint128_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { +// 128-bit keys require 128-bit SIMD. +#if HWY_TARGET != HWY_SCALAR + return VQPartialSortStatic(keys, num, k, SortDescending()); +#else // HWY_TARGET == HWY_SCALAR: no sort is performed + (void)keys; + (void)num; + (void)k; +#endif // HWY_TARGET != HWY_SCALAR +} + +void Select128Desc(uint128_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { +// 128-bit keys require 128-bit SIMD. +#if HWY_TARGET != HWY_SCALAR + return VQSelectStatic(keys, num, k, SortDescending()); +#else // HWY_TARGET == HWY_SCALAR: no sort is performed + (void)keys; + (void)num; + (void)k; +#endif // HWY_TARGET != HWY_SCALAR +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(Sort128Desc); +HWY_EXPORT(PartialSort128Desc); +HWY_EXPORT(Select128Desc); +} // namespace +// Entry points: dispatch to the helpers exported above. 
+void VQSort(uint128_t* HWY_RESTRICT keys, const size_t n, SortDescending) { + HWY_DYNAMIC_DISPATCH(Sort128Desc)(keys, n); +} + +void VQPartialSort(uint128_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(PartialSort128Desc)(keys, n, k); +} + +void VQSelect(uint128_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(Select128Desc)(keys, n, k); +} + +void Sorter::operator()(uint128_t* HWY_RESTRICT keys, size_t n, + SortDescending tag) const { + VQSort(keys, n,
tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_f16a.cc b/lib/highway/hwy/contrib/sort/vqsort_f16a.cc new file mode 100644 index 00000000000..a5db4f76eee --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_f16a.cc @@ -0,0 +1,99 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" // VQSort +#include "hwy/nanobenchmark.h" // Unpredictable1 + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f16a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +// float16_t keys, ascending: thin forwards to the VQ*Static functions. 
+void SortF16Asc(float16_t* HWY_RESTRICT keys, const size_t num) { +#if HWY_HAVE_FLOAT16 + return VQSortStatic(keys, num, SortAscending()); +#else // !HWY_HAVE_FLOAT16: no sort, only the guarded HWY_ASSERT(0) + (void)keys; + (void)num; + if (Unpredictable1()) HWY_ASSERT(0); +#endif // HWY_HAVE_FLOAT16 +} + +void PartialSortF16Asc(float16_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { +#if HWY_HAVE_FLOAT16 + return VQPartialSortStatic(keys, num, k, SortAscending()); +#else // !HWY_HAVE_FLOAT16: no sort, only the guarded HWY_ASSERT(0) + (void)keys; + (void)num; + (void)k; + if (Unpredictable1()) HWY_ASSERT(0); +#endif // HWY_HAVE_FLOAT16 +} + +void SelectF16Asc(float16_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { +#if HWY_HAVE_FLOAT16 + return VQSelectStatic(keys, num, k, SortAscending()); +#else // !HWY_HAVE_FLOAT16: no sort, only the guarded HWY_ASSERT(0) +
(void)keys; + (void)num; + (void)k; + if (Unpredictable1()) HWY_ASSERT(0); +#endif // HWY_HAVE_FLOAT16 +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF16Asc); +HWY_EXPORT(PartialSortF16Asc); +HWY_EXPORT(SelectF16Asc); +} // namespace +// Entry points: dispatch to the helpers exported above. 
+void VQSort(float16_t* HWY_RESTRICT keys, const size_t n, SortAscending) { + HWY_DYNAMIC_DISPATCH(SortF16Asc)(keys, n); +} + +void VQPartialSort(float16_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(PartialSortF16Asc)(keys, n, k); +} + +void VQSelect(float16_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(SelectF16Asc)(keys, n, k); +} + +void Sorter::operator()(float16_t* HWY_RESTRICT keys, size_t n, + SortAscending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_f16d.cc b/lib/highway/hwy/contrib/sort/vqsort_f16d.cc new file mode 100644 index 00000000000..ccedd49fc1a --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_f16d.cc @@ -0,0 +1,99 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "hwy/contrib/sort/vqsort.h" // VQSort +#include "hwy/nanobenchmark.h" // Unpredictable1 + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f16d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +// float16_t keys, descending: thin forwards to the VQ*Static functions. +void SortF16Desc(float16_t* HWY_RESTRICT keys, const size_t num) { +#if HWY_HAVE_FLOAT16 + return VQSortStatic(keys, num, SortDescending()); +#else // !HWY_HAVE_FLOAT16: no sort, only the guarded HWY_ASSERT(0) + (void)keys; + (void)num; + if (Unpredictable1()) HWY_ASSERT(0); +#endif // HWY_HAVE_FLOAT16 +} + +void PartialSortF16Desc(float16_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { +#if HWY_HAVE_FLOAT16 + return VQPartialSortStatic(keys, num, k, SortDescending()); +#else // !HWY_HAVE_FLOAT16: no sort, only the guarded HWY_ASSERT(0) + (void)keys; + (void)num; + (void)k; + if (Unpredictable1()) HWY_ASSERT(0); +#endif // HWY_HAVE_FLOAT16 +} + +void SelectF16Desc(float16_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { +#if HWY_HAVE_FLOAT16 + return VQSelectStatic(keys, num, k, SortDescending()); +#else // !HWY_HAVE_FLOAT16: no sort, only the guarded HWY_ASSERT(0) + (void)keys; + (void)num; + (void)k; + if (Unpredictable1()) HWY_ASSERT(0); +#endif // HWY_HAVE_FLOAT16 +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF16Desc); +HWY_EXPORT(PartialSortF16Desc); +HWY_EXPORT(SelectF16Desc); +} // namespace +// Entry points: dispatch to the helpers exported above. 
+void VQSort(float16_t* HWY_RESTRICT keys, const size_t n, SortDescending) { + HWY_DYNAMIC_DISPATCH(SortF16Desc)(keys, n); +} + +void VQPartialSort(float16_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(PartialSortF16Desc)(keys, n, k); +} + +void VQSelect(float16_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(SelectF16Desc)(keys, n, k); +} + +void Sorter::operator()(float16_t* HWY_RESTRICT keys, size_t n, + SortDescending tag) const { +
VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_f32a.cc b/lib/highway/hwy/contrib/sort/vqsort_f32a.cc new file mode 100644 index 00000000000..fc243c1aaca --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_f32a.cc @@ -0,0 +1,77 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f32a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +// float keys, ascending: thin forwards to the VQ*Static functions. 
+void SortF32Asc(float* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortAscending()); +} + +void PartialSortF32Asc(float* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortAscending()); +} + +void SelectF32Asc(float* HWY_RESTRICT keys, const size_t num, const size_t k) { + return VQSelectStatic(keys, num, k, SortAscending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF32Asc); +HWY_EXPORT(PartialSortF32Asc); +HWY_EXPORT(SelectF32Asc); +} // namespace +// Entry points: dispatch to the helpers exported above.
+void VQSort(float* HWY_RESTRICT keys, const size_t n, SortAscending) { + HWY_DYNAMIC_DISPATCH(SortF32Asc)(keys, n); +} + +void VQPartialSort(float* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(PartialSortF32Asc)(keys, n, k); +} + +void VQSelect(float* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(SelectF32Asc)(keys, n, k); +} + +void Sorter::operator()(float* HWY_RESTRICT keys, size_t n, + SortAscending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_f32d.cc b/lib/highway/hwy/contrib/sort/vqsort_f32d.cc new file mode 100644 index 00000000000..1017bbb8c6a --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_f32d.cc @@ -0,0 +1,77 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f32d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +// float keys, descending: thin forwards to the VQ*Static functions. +void SortF32Desc(float* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortDescending()); +} + +void PartialSortF32Desc(float* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortDescending()); +} + +void SelectF32Desc(float* HWY_RESTRICT keys, const size_t num, const size_t k) { + return VQSelectStatic(keys, num, k, SortDescending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF32Desc); +HWY_EXPORT(PartialSortF32Desc); +HWY_EXPORT(SelectF32Desc); +} // namespace +// Entry points: dispatch to the helpers exported above. 
+void VQSort(float* HWY_RESTRICT keys, const size_t n, SortDescending) { + HWY_DYNAMIC_DISPATCH(SortF32Desc)(keys, n); +} + +void VQPartialSort(float* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(PartialSortF32Desc)(keys, n, k); +} + +void VQSelect(float* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(SelectF32Desc)(keys, n, k); +} + +void Sorter::operator()(float* HWY_RESTRICT keys, size_t n, + SortDescending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_f64a.cc b/lib/highway/hwy/contrib/sort/vqsort_f64a.cc new file mode 100644 index 00000000000..39e1c1642a7 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_f64a.cc @@ -0,0 +1,97 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the
Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f64a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +// double keys, ascending: thin forwards to the VQ*Static functions. +void SortF64Asc(double* HWY_RESTRICT keys, const size_t num) { +#if HWY_HAVE_FLOAT64 + return VQSortStatic(keys, num, SortAscending()); +#else // !HWY_HAVE_FLOAT64: no sort, only HWY_ASSERT(0) + (void)keys; + (void)num; + HWY_ASSERT(0); +#endif // HWY_HAVE_FLOAT64 +} + +void PartialSortF64Asc(double* HWY_RESTRICT keys, const size_t num, + const size_t k) { +#if HWY_HAVE_FLOAT64 + return VQPartialSortStatic(keys, num, k, SortAscending()); +#else // !HWY_HAVE_FLOAT64: no sort, only HWY_ASSERT(0) + (void)keys; + (void)num; + (void)k; + HWY_ASSERT(0); +#endif // HWY_HAVE_FLOAT64 +} + +void SelectF64Asc(double* HWY_RESTRICT keys, const size_t num, const size_t k) { +#if HWY_HAVE_FLOAT64 + return VQSelectStatic(keys, num, k, SortAscending()); +#else // !HWY_HAVE_FLOAT64: no sort, only HWY_ASSERT(0) + (void)keys; + (void)num; + (void)k; + HWY_ASSERT(0); +#endif // HWY_HAVE_FLOAT64 +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if 
HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF64Asc); +HWY_EXPORT(PartialSortF64Asc); +HWY_EXPORT(SelectF64Asc); +} // namespace +// Entry points: dispatch to the helpers exported above. +void VQSort(double* HWY_RESTRICT keys, const size_t n, SortAscending) { + HWY_DYNAMIC_DISPATCH(SortF64Asc)(keys, n); +} + +void
VQPartialSort(double* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(PartialSortF64Asc)(keys, n, k); +} + +void VQSelect(double* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(SelectF64Asc)(keys, n, k); +} + +void Sorter::operator()(double* HWY_RESTRICT keys, size_t n, + SortAscending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_f64d.cc b/lib/highway/hwy/contrib/sort/vqsort_f64d.cc new file mode 100644 index 00000000000..585fd2b1de1 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_f64d.cc @@ -0,0 +1,98 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f64d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +// double keys, descending: thin forwards to the VQ*Static functions. +void SortF64Desc(double* HWY_RESTRICT keys, const size_t num) { +#if HWY_HAVE_FLOAT64 + return VQSortStatic(keys, num, SortDescending()); +#else // !HWY_HAVE_FLOAT64: no sort, only HWY_ASSERT(0) + (void)keys; + (void)num; + HWY_ASSERT(0); +#endif // HWY_HAVE_FLOAT64 +} + +void PartialSortF64Desc(double* HWY_RESTRICT keys, const size_t num, + const size_t k) { +#if HWY_HAVE_FLOAT64 + return VQPartialSortStatic(keys, num, k, SortDescending()); +#else // !HWY_HAVE_FLOAT64: no sort, only HWY_ASSERT(0) + (void)keys; + (void)num; + (void)k; + HWY_ASSERT(0); +#endif // HWY_HAVE_FLOAT64 +} + +void SelectF64Desc(double* HWY_RESTRICT keys, const size_t num, + const size_t k) { +#if HWY_HAVE_FLOAT64 + return VQSelectStatic(keys, num, k, SortDescending()); +#else // !HWY_HAVE_FLOAT64: no sort, only HWY_ASSERT(0) + (void)keys; + (void)num; + (void)k; + HWY_ASSERT(0); +#endif // HWY_HAVE_FLOAT64 +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortF64Desc); +HWY_EXPORT(PartialSortF64Desc); +HWY_EXPORT(SelectF64Desc); +} // namespace +// Entry points: dispatch to the helpers exported above. 
+void VQSort(double* HWY_RESTRICT keys, const size_t n, SortDescending) { + HWY_DYNAMIC_DISPATCH(SortF64Desc)(keys, n); +} + +void VQPartialSort(double* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(PartialSortF64Desc)(keys, n, k); +} + +void VQSelect(double* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(SelectF64Desc)(keys, n, k); +} + +void Sorter::operator()(double* HWY_RESTRICT keys, size_t n, + SortDescending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_i16a.cc
b/lib/highway/hwy/contrib/sort/vqsort_i16a.cc new file mode 100644 index 00000000000..0be242ceb89 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_i16a.cc @@ -0,0 +1,78 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i16a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +// int16_t keys, ascending: thin forwards to the VQ*Static functions. +void SortI16Asc(int16_t* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortAscending()); +} + +void PartialSortI16Asc(int16_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortAscending()); +} + +void SelectI16Asc(int16_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortAscending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI16Asc); +HWY_EXPORT(PartialSortI16Asc); +HWY_EXPORT(SelectI16Asc); +} // namespace +// Entry points: dispatch to the helpers exported above. 
+void VQSort(int16_t* HWY_RESTRICT keys, const size_t n, SortAscending) { + HWY_DYNAMIC_DISPATCH(SortI16Asc)(keys,
n); +} + +void VQPartialSort(int16_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(PartialSortI16Asc)(keys, n, k); +} + +void VQSelect(int16_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(SelectI16Asc)(keys, n, k); +} + +void Sorter::operator()(int16_t* HWY_RESTRICT keys, size_t n, + SortAscending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_i16d.cc b/lib/highway/hwy/contrib/sort/vqsort_i16d.cc new file mode 100644 index 00000000000..c5008353ded --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_i16d.cc @@ -0,0 +1,78 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i16d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +// int16_t keys, descending: thin forwards to the VQ*Static functions. +void SortI16Desc(int16_t* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortDescending()); +} + +void PartialSortI16Desc(int16_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortDescending()); +} + +void SelectI16Desc(int16_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortDescending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI16Desc); +HWY_EXPORT(PartialSortI16Desc); +HWY_EXPORT(SelectI16Desc); +} // namespace +// Entry points: dispatch to the helpers exported above. 
+void VQSort(int16_t* HWY_RESTRICT keys, const size_t n, SortDescending) { + HWY_DYNAMIC_DISPATCH(SortI16Desc)(keys, n); +} + +void VQPartialSort(int16_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(PartialSortI16Desc)(keys, n, k); +} + +void VQSelect(int16_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(SelectI16Desc)(keys, n, k); +} + +void Sorter::operator()(int16_t* HWY_RESTRICT keys, size_t n, + SortDescending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_i32a.cc b/lib/highway/hwy/contrib/sort/vqsort_i32a.cc new file mode 100644 index 00000000000..d40a3bfa314 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_i32a.cc @@ -0,0 +1,78 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +//
Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i32a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +// int32_t keys, ascending: thin forwards to the VQ*Static functions. +void SortI32Asc(int32_t* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortAscending()); +} + +void PartialSortI32Asc(int32_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortAscending()); +} + +void SelectI32Asc(int32_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortAscending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI32Asc); +HWY_EXPORT(PartialSortI32Asc); +HWY_EXPORT(SelectI32Asc); +} // namespace +// Entry points: dispatch to the helpers exported above. 
+void VQSort(int32_t* HWY_RESTRICT keys, const size_t n, SortAscending) { + HWY_DYNAMIC_DISPATCH(SortI32Asc)(keys, n); +} + +void VQPartialSort(int32_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(PartialSortI32Asc)(keys, n, k); +} + +void VQSelect(int32_t* HWY_RESTRICT keys, const size_t n, const size_t k, +
SortAscending) { + HWY_DYNAMIC_DISPATCH(SelectI32Asc)(keys, n, k); +} + +void Sorter::operator()(int32_t* HWY_RESTRICT keys, size_t n, + SortAscending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_i32d.cc b/lib/highway/hwy/contrib/sort/vqsort_i32d.cc new file mode 100644 index 00000000000..5f11f4fbec4 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_i32d.cc @@ -0,0 +1,78 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i32d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +// int32_t keys, descending: thin forwards to the VQ*Static functions. +void SortI32Desc(int32_t* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortDescending()); +} + +void PartialSortI32Desc(int32_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortDescending()); +} + +void SelectI32Desc(int32_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortDescending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI32Desc); +HWY_EXPORT(PartialSortI32Desc); +HWY_EXPORT(SelectI32Desc); +} // namespace +// Entry points: dispatch to the helpers exported above. 
+void VQSort(int32_t* HWY_RESTRICT keys, const size_t n, SortDescending) { + HWY_DYNAMIC_DISPATCH(SortI32Desc)(keys, n); +} + +void VQPartialSort(int32_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(PartialSortI32Desc)(keys, n, k); +} + +void VQSelect(int32_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(SelectI32Desc)(keys, n, k); +} + +void Sorter::operator()(int32_t* HWY_RESTRICT keys, size_t n, + SortDescending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_i64a.cc b/lib/highway/hwy/contrib/sort/vqsort_i64a.cc new file mode 100644 index 00000000000..d0b808bc1b7 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_i64a.cc @@ -0,0 +1,78 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +//
Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i64a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +// int64_t keys, ascending: thin forwards to the VQ*Static functions. +void SortI64Asc(int64_t* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortAscending()); +} + +void PartialSortI64Asc(int64_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortAscending()); +} + +void SelectI64Asc(int64_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortAscending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI64Asc); +HWY_EXPORT(PartialSortI64Asc); +HWY_EXPORT(SelectI64Asc); +} // namespace +// Entry points: dispatch to the helpers exported above. 
+void VQSort(int64_t* HWY_RESTRICT keys, const size_t n, SortAscending) { + HWY_DYNAMIC_DISPATCH(SortI64Asc)(keys, n); +} + +void VQPartialSort(int64_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(PartialSortI64Asc)(keys, n, k); +} + +void VQSelect(int64_t* HWY_RESTRICT keys, const size_t n, const size_t k, +
SortAscending) { + HWY_DYNAMIC_DISPATCH(SelectI64Asc)(keys, n, k); +} + +void Sorter::operator()(int64_t* HWY_RESTRICT keys, size_t n, + SortAscending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_i64d.cc b/lib/highway/hwy/contrib/sort/vqsort_i64d.cc new file mode 100644 index 00000000000..0231f59618d --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_i64d.cc @@ -0,0 +1,78 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i64d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +void SortI64Desc(int64_t* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortDescending()); +} + +void PartialSortI64Desc(int64_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortDescending()); +} + +void SelectI64Desc(int64_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortDescending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortI64Desc); +HWY_EXPORT(PartialSortI64Desc); +HWY_EXPORT(SelectI64Desc); +} // namespace + +void VQSort(int64_t* HWY_RESTRICT keys, const size_t n, SortDescending) { + HWY_DYNAMIC_DISPATCH(SortI64Desc)(keys, n); +} + +void VQPartialSort(int64_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(PartialSortI64Desc)(keys, n, k); +} + +void VQSelect(int64_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(SelectI64Desc)(keys, n, k); +} + +void Sorter::operator()(int64_t* HWY_RESTRICT keys, size_t n, + SortDescending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_kv128a.cc b/lib/highway/hwy/contrib/sort/vqsort_kv128a.cc new file mode 100644 index 00000000000..87ffb89059f --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_kv128a.cc @@ -0,0 +1,101 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// 
Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +// clang-format off +// (avoid line break, which would prevent Copybara rules from matching) +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv128a.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +void SortKV128Asc(K64V64* HWY_RESTRICT keys, const size_t num) { + // 128-bit keys require 128-bit SIMD. +#if HWY_TARGET != HWY_SCALAR + return VQSortStatic(keys, num, SortAscending()); +#else + (void)keys; + (void)num; +#endif +} + +void PartialSortKV128Asc(K64V64* HWY_RESTRICT keys, const size_t num, + const size_t k) { + // 128-bit keys require 128-bit SIMD. +#if HWY_TARGET != HWY_SCALAR + return VQPartialSortStatic(keys, num, k, SortAscending()); +#else + (void)keys; + (void)num; + (void)k; +#endif +} + +void SelectKV128Asc(K64V64* HWY_RESTRICT keys, const size_t num, + const size_t k) { + // 128-bit keys require 128-bit SIMD. 
+#if HWY_TARGET != HWY_SCALAR + return VQSelectStatic(keys, num, k, SortAscending()); +#else + (void)keys; + (void)num; + (void)k; +#endif +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortKV128Asc); +HWY_EXPORT(PartialSortKV128Asc); +HWY_EXPORT(SelectKV128Asc); +} // namespace + +void VQSort(K64V64* HWY_RESTRICT keys, const size_t n, SortAscending) { + HWY_DYNAMIC_DISPATCH(SortKV128Asc)(keys, n); +} + +void VQPartialSort(K64V64* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(PartialSortKV128Asc)(keys, n, k); +} + +void VQSelect(K64V64* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(SelectKV128Asc)(keys, n, k); +} + +void Sorter::operator()(K64V64* HWY_RESTRICT keys, size_t n, + SortAscending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_kv128d.cc b/lib/highway/hwy/contrib/sort/vqsort_kv128d.cc new file mode 100644 index 00000000000..7e652c70058 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_kv128d.cc @@ -0,0 +1,101 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +// clang-format off +// (avoid line break, which would prevent Copybara rules from matching) +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv128d.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +void SortKV128Desc(K64V64* HWY_RESTRICT keys, const size_t num) { + // 128-bit keys require 128-bit SIMD. +#if HWY_TARGET != HWY_SCALAR + return VQSortStatic(keys, num, SortDescending()); +#else + (void)keys; + (void)num; +#endif +} + +void PartialSortKV128Desc(K64V64* HWY_RESTRICT keys, const size_t num, + const size_t k) { + // 128-bit keys require 128-bit SIMD. +#if HWY_TARGET != HWY_SCALAR + return VQPartialSortStatic(keys, num, k, SortDescending()); +#else + (void)keys; + (void)num; + (void)k; +#endif +} + +void SelectKV128Desc(K64V64* HWY_RESTRICT keys, const size_t num, + const size_t k) { + // 128-bit keys require 128-bit SIMD. 
+#if HWY_TARGET != HWY_SCALAR + return VQSelectStatic(keys, num, k, SortDescending()); +#else + (void)keys; + (void)num; + (void)k; +#endif +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortKV128Desc); +HWY_EXPORT(PartialSortKV128Desc); +HWY_EXPORT(SelectKV128Desc); +} // namespace + +void VQSort(K64V64* HWY_RESTRICT keys, const size_t n, SortDescending) { + HWY_DYNAMIC_DISPATCH(SortKV128Desc)(keys, n); +} + +void VQPartialSort(K64V64* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(PartialSortKV128Desc)(keys, n, k); +} + +void VQSelect(K64V64* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(SelectKV128Desc)(keys, n, k); +} + +void Sorter::operator()(K64V64* HWY_RESTRICT keys, size_t n, + SortDescending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_kv64a.cc b/lib/highway/hwy/contrib/sort/vqsort_kv64a.cc new file mode 100644 index 00000000000..2b3468a797e --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_kv64a.cc @@ -0,0 +1,81 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +// clang-format off +// (avoid line break, which would prevent Copybara rules from matching) +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv64a.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +void SortKV64Asc(K32V32* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortAscending()); +} + +void PartialSortKV64Asc(K32V32* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortAscending()); +} + +void SelectKV64Asc(K32V32* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortAscending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortKV64Asc); +HWY_EXPORT(PartialSortKV64Asc); +HWY_EXPORT(SelectKV64Asc); +} // namespace + +void VQSort(K32V32* HWY_RESTRICT keys, const size_t n, SortAscending) { + HWY_DYNAMIC_DISPATCH(SortKV64Asc)(keys, n); +} + +void VQPartialSort(K32V32* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(PartialSortKV64Asc)(keys, n, k); +} + +void VQSelect(K32V32* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(SelectKV64Asc)(keys, n, k); +} + +void Sorter::operator()(K32V32* HWY_RESTRICT keys, size_t n, + SortAscending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_kv64d.cc b/lib/highway/hwy/contrib/sort/vqsort_kv64d.cc new file mode 100644 index 00000000000..4d968bdd407 --- /dev/null +++ 
b/lib/highway/hwy/contrib/sort/vqsort_kv64d.cc @@ -0,0 +1,81 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +// clang-format off +// (avoid line break, which would prevent Copybara rules from matching) +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv64d.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +void SortKV64Desc(K32V32* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortDescending()); +} + +void PartialSortKV64Desc(K32V32* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortDescending()); +} + +void SelectKV64Desc(K32V32* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortDescending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortKV64Desc); +HWY_EXPORT(PartialSortKV64Desc); +HWY_EXPORT(SelectKV64Desc); +} // namespace + +void VQSort(K32V32* HWY_RESTRICT keys, const size_t n, SortDescending) { + 
HWY_DYNAMIC_DISPATCH(SortKV64Desc)(keys, n); +} + +void VQPartialSort(K32V32* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(PartialSortKV64Desc)(keys, n, k); +} + +void VQSelect(K32V32* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(SelectKV64Desc)(keys, n, k); +} + +void Sorter::operator()(K32V32* HWY_RESTRICT keys, size_t n, + SortDescending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_u16a.cc b/lib/highway/hwy/contrib/sort/vqsort_u16a.cc new file mode 100644 index 00000000000..e2d177611e2 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_u16a.cc @@ -0,0 +1,78 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u16a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +void SortU16Asc(uint16_t* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortAscending()); +} + +void PartialSortU16Asc(uint16_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortAscending()); +} + +void SelectU16Asc(uint16_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortAscending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU16Asc); +HWY_EXPORT(PartialSortU16Asc); +HWY_EXPORT(SelectU16Asc); +} // namespace + +void VQSort(uint16_t* HWY_RESTRICT keys, const size_t n, SortAscending) { + HWY_DYNAMIC_DISPATCH(SortU16Asc)(keys, n); +} + +void VQPartialSort(uint16_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(PartialSortU16Asc)(keys, n, k); +} + +void VQSelect(uint16_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(SelectU16Asc)(keys, n, k); +} + +void Sorter::operator()(uint16_t* HWY_RESTRICT keys, size_t n, + SortAscending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_u16d.cc b/lib/highway/hwy/contrib/sort/vqsort_u16d.cc new file mode 100644 index 00000000000..bdb022b73c9 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_u16d.cc @@ -0,0 +1,78 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under 
the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u16d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +void SortU16Desc(uint16_t* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortDescending()); +} + +void PartialSortU16Desc(uint16_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortDescending()); +} + +void SelectU16Desc(uint16_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortDescending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU16Desc); +HWY_EXPORT(PartialSortU16Desc); +HWY_EXPORT(SelectU16Desc); +} // namespace + +void VQSort(uint16_t* HWY_RESTRICT keys, const size_t n, SortDescending) { + HWY_DYNAMIC_DISPATCH(SortU16Desc)(keys, n); +} + +void VQPartialSort(uint16_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(PartialSortU16Desc)(keys, n, k); +} + +void VQSelect(uint16_t* HWY_RESTRICT keys, const size_t n, const size_t k, + 
SortDescending) { + HWY_DYNAMIC_DISPATCH(SelectU16Desc)(keys, n, k); +} + +void Sorter::operator()(uint16_t* HWY_RESTRICT keys, size_t n, + SortDescending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_u32a.cc b/lib/highway/hwy/contrib/sort/vqsort_u32a.cc new file mode 100644 index 00000000000..6f82ce0e03b --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_u32a.cc @@ -0,0 +1,78 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u32a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +void SortU32Asc(uint32_t* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortAscending()); +} + +void PartialSortU32Asc(uint32_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortAscending()); +} + +void SelectU32Asc(uint32_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortAscending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU32Asc); +HWY_EXPORT(PartialSortU32Asc); +HWY_EXPORT(SelectU32Asc); +} // namespace + +void VQSort(uint32_t* HWY_RESTRICT keys, const size_t n, SortAscending) { + HWY_DYNAMIC_DISPATCH(SortU32Asc)(keys, n); +} + +void VQPartialSort(uint32_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(PartialSortU32Asc)(keys, n, k); +} + +void VQSelect(uint32_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(SelectU32Asc)(keys, n, k); +} + +void Sorter::operator()(uint32_t* HWY_RESTRICT keys, size_t n, + SortAscending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_u32d.cc b/lib/highway/hwy/contrib/sort/vqsort_u32d.cc new file mode 100644 index 00000000000..a58e5f26003 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_u32d.cc @@ -0,0 +1,78 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under 
the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u32d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +void SortU32Desc(uint32_t* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortDescending()); +} + +void PartialSortU32Desc(uint32_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortDescending()); +} + +void SelectU32Desc(uint32_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortDescending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU32Desc); +HWY_EXPORT(PartialSortU32Desc); +HWY_EXPORT(SelectU32Desc); +} // namespace + +void VQSort(uint32_t* HWY_RESTRICT keys, const size_t n, SortDescending) { + HWY_DYNAMIC_DISPATCH(SortU32Desc)(keys, n); +} + +void VQPartialSort(uint32_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(PartialSortU32Desc)(keys, n, k); +} + +void VQSelect(uint32_t* HWY_RESTRICT keys, const size_t n, const size_t k, + 
SortDescending) { + HWY_DYNAMIC_DISPATCH(SelectU32Desc)(keys, n, k); +} + +void Sorter::operator()(uint32_t* HWY_RESTRICT keys, size_t n, + SortDescending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_u64a.cc b/lib/highway/hwy/contrib/sort/vqsort_u64a.cc new file mode 100644 index 00000000000..f6b3bb89375 --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_u64a.cc @@ -0,0 +1,78 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u64a.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +void SortU64Asc(uint64_t* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortAscending()); +} + +void PartialSortU64Asc(uint64_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortAscending()); +} + +void SelectU64Asc(uint64_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortAscending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU64Asc); +HWY_EXPORT(PartialSortU64Asc); +HWY_EXPORT(SelectU64Asc); +} // namespace + +void VQSort(uint64_t* HWY_RESTRICT keys, const size_t n, SortAscending) { + HWY_DYNAMIC_DISPATCH(SortU64Asc)(keys, n); +} + +void VQPartialSort(uint64_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(PartialSortU64Asc)(keys, n, k); +} + +void VQSelect(uint64_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortAscending) { + HWY_DYNAMIC_DISPATCH(SelectU64Asc)(keys, n, k); +} + +void Sorter::operator()(uint64_t* HWY_RESTRICT keys, size_t n, + SortAscending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/sort/vqsort_u64d.cc b/lib/highway/hwy/contrib/sort/vqsort_u64d.cc new file mode 100644 index 00000000000..901dd183e6a --- /dev/null +++ b/lib/highway/hwy/contrib/sort/vqsort_u64d.cc @@ -0,0 +1,78 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under 
the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/sort/vqsort.h" // VQSort + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u64d.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep + +// After foreach_target +#include "hwy/contrib/sort/vqsort-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { + +void SortU64Desc(uint64_t* HWY_RESTRICT keys, const size_t num) { + return VQSortStatic(keys, num, SortDescending()); +} + +void PartialSortU64Desc(uint64_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQPartialSortStatic(keys, num, k, SortDescending()); +} + +void SelectU64Desc(uint64_t* HWY_RESTRICT keys, const size_t num, + const size_t k) { + return VQSelectStatic(keys, num, k, SortDescending()); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(SortU64Desc); +HWY_EXPORT(PartialSortU64Desc); +HWY_EXPORT(SelectU64Desc); +} // namespace + +void VQSort(uint64_t* HWY_RESTRICT keys, const size_t n, SortDescending) { + HWY_DYNAMIC_DISPATCH(SortU64Desc)(keys, n); +} + +void VQPartialSort(uint64_t* HWY_RESTRICT keys, const size_t n, const size_t k, + SortDescending) { + HWY_DYNAMIC_DISPATCH(PartialSortU64Desc)(keys, n, k); +} + +void VQSelect(uint64_t* HWY_RESTRICT keys, const size_t n, const size_t k, + 
SortDescending) { + HWY_DYNAMIC_DISPATCH(SelectU64Desc)(keys, n, k); +} + +void Sorter::operator()(uint64_t* HWY_RESTRICT keys, size_t n, + SortDescending tag) const { + VQSort(keys, n, tag); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/contrib/thread_pool/futex.h b/lib/highway/hwy/contrib/thread_pool/futex.h new file mode 100644 index 00000000000..9050fdb8a3c --- /dev/null +++ b/lib/highway/hwy/contrib/thread_pool/futex.h @@ -0,0 +1,290 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_CONTRIB_THREAD_POOL_FUTEX_H_ +#define HIGHWAY_HWY_CONTRIB_THREAD_POOL_FUTEX_H_ + +// Keyed event (futex): kernel queue of blocked threads, identified by the +// address of an atomic u32 called `current` within the same process (do NOT +// use with shared-memory mappings). +// +// Futex equivalents: https://outerproduct.net/futex-dictionary.html; we +// support Linux/Emscripten/FreeBSD/Apple/Windows and C++20 std::atomic::wait, +// plus a NanoSleep fallback. + +#include + +#include +#include // INT_MAX + +#include "hwy/base.h" + +#if HWY_OS_APPLE +#include +// __ulock* were added in OS X 10.12 (Sierra, 2016). 
+#if MAC_OS_X_VERSION_MAX_ALLOWED < 101200 && !defined(HWY_DISABLE_FUTEX) +#define HWY_DISABLE_FUTEX +#endif +#endif // HWY_OS_APPLE + +#if HWY_OS_WIN +// Need to include on Windows, even if HWY_DISABLE_FUTEX is defined, +// since hwy::NanoSleep uses Windows API's that are defined in windows.h. +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif // WIN32_LEAN_AND_MEAN +#include +#endif + +#if HWY_ARCH_WASM +#include +#include // INFINITY + +#elif HWY_OS_LINUX +#include // IWYU pragma: keep +#include // FUTEX_* +#include +#include // SYS_* +#include +// Android may not declare these: +#ifndef SYS_futex +#ifdef SYS_futex_time64 // 32-bit with 64-bit time_t +#define SYS_futex SYS_futex_time64 +#else +#define SYS_futex __NR_futex +#endif // SYS_futex_time64 +#endif // SYS_futex +#ifndef FUTEX_WAIT_PRIVATE +#define FUTEX_WAIT_PRIVATE (FUTEX_WAIT | 128) +#endif +#ifndef FUTEX_WAKE_PRIVATE +#define FUTEX_WAKE_PRIVATE (FUTEX_WAKE | 128) +#endif + +#elif HWY_OS_FREEBSD && !defined(HWY_DISABLE_FUTEX) +#include // __FreeBSD_version +#if __FreeBSD_version >= 600000 +#include +#include +#include +#else +#define HWY_DISABLE_FUTEX +#endif + +#elif HWY_OS_APPLE && !defined(HWY_DISABLE_FUTEX) +// These are private APIs, so add an opt-out. +extern "C" { +int __ulock_wait(uint32_t op, void* address, uint64_t val, uint32_t max_us); +int __ulock_wake(uint32_t op, void* address, uint64_t zero); +} // extern "C" +#define UL_COMPARE_AND_WAIT 1 +#define ULF_WAKE_ALL 0x00000100 + +#elif HWY_OS_WIN && !defined(HWY_DISABLE_FUTEX) +// WakeByAddressAll requires Windows 8, so add an opt-out. +#if HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL +#pragma comment(lib, "synchronization.lib") +#endif + +#elif HWY_CXX_LANG < 202002L // NOT C++20, which has native support +#define HWY_FUTEX_SLEEP +#endif + +namespace hwy { + +// Attempts to pause for the specified nanoseconds, though the resolution is +// closer to 0.1 microseconds. 
Returns false if no wait happened. Thread-safe. +static inline bool NanoSleep(uint64_t ns) { +#if HWY_OS_WIN + static thread_local HANDLE hTimer = nullptr; + if (HWY_UNLIKELY(hTimer == nullptr)) { + // Must be manual reset: auto-reset would immediately signal after the next + // SetWaitableTimer. + hTimer = CreateWaitableTimer(nullptr, TRUE, nullptr); + if (hTimer == nullptr) return false; + } + + // Negative means relative, in units of 100 ns. + LARGE_INTEGER time; + time.QuadPart = -static_cast(ns / 100); + const LONG period = 0; // signal once + if (!SetWaitableTimer(hTimer, &time, period, nullptr, nullptr, FALSE)) { + return false; + } + + (void)WaitForSingleObject(hTimer, INFINITE); + return true; +#else + timespec duration; + duration.tv_sec = static_cast(ns / 1000000000); + duration.tv_nsec = static_cast(ns % 1000000000); + timespec remainder; + // Repeat if interrupted by a signal. Note that the remainder may be rounded + // up, which could cause an infinite loop if continually interrupted. Using + // clock_nanosleep would work, but we'd have to get the current time. We + // assume durations are short, and instead just cap the number of retries. + for (int rep = 0; rep < 3; ++rep) { + if (nanosleep(&duration, &remainder) == 0 || errno != EINTR) break; + duration = remainder; + } + return true; +#endif +} + +// Waits until `current != prev` and returns the new value. May return +// immediately if `current` already changed, or after blocking and waking. +static inline uint32_t BlockUntilDifferent( + const uint32_t prev, const std::atomic& current) { + const auto acq = std::memory_order_acquire; + +#if HWY_ARCH_WASM + // It is always safe to cast to void. 
+ volatile void* address = + const_cast(static_cast(¤t)); + const double max_ms = INFINITY; + for (;;) { + const uint32_t next = current.load(acq); + if (next != prev) return next; + const int ret = emscripten_futex_wait(address, prev, max_ms); + HWY_DASSERT(ret >= 0); + (void)ret; + } + +#elif HWY_OS_LINUX + // Safe to cast because std::atomic is a standard layout type. + const uint32_t* address = reinterpret_cast(¤t); + // _PRIVATE requires this only be used in the same process, and avoids + // virtual->physical lookups and atomic reference counting. + const int op = FUTEX_WAIT_PRIVATE; + for (;;) { + const uint32_t next = current.load(acq); + if (next != prev) return next; + // timeout=null may prevent interrupts via signal. No lvalue because + // the timespec type is only standardized since C++17 or C11. + const auto ret = syscall(SYS_futex, address, op, prev, nullptr, nullptr, 0); + if (ret == -1) { + HWY_DASSERT(errno == EAGAIN); // otherwise an actual error + } + } + +#elif HWY_OS_FREEBSD && !defined(HWY_DISABLE_FUTEX) // >= 6.0 + // _umtx_op with UMTX_OP_WAIT_UINT_PRIVATE: process-private futex on FreeBSD. + void* address = const_cast(static_cast(¤t)); + for (;;) { + const uint32_t next = current.load(acq); + if (next != prev) return next; + const int ret = _umtx_op(address, UMTX_OP_WAIT_UINT_PRIVATE, + static_cast(prev), nullptr, nullptr); + if (ret == -1) { + HWY_DASSERT(errno == EAGAIN || errno == EINTR); + } + } + +#elif HWY_OS_WIN && !defined(HWY_DISABLE_FUTEX) + // It is always safe to cast to void. + volatile void* address = + const_cast(static_cast(¤t)); + // API is not const-correct, but only loads from the pointer. 
+ PVOID pprev = const_cast(static_cast(&prev)); + const DWORD max_ms = INFINITE; + for (;;) { + const uint32_t next = current.load(acq); + if (next != prev) return next; + const BOOL ok = WaitOnAddress(address, pprev, sizeof(prev), max_ms); + HWY_DASSERT(ok); + (void)ok; + } + +#elif HWY_OS_APPLE && !defined(HWY_DISABLE_FUTEX) + // It is always safe to cast to void. + void* address = const_cast(static_cast(¤t)); + for (;;) { + const uint32_t next = current.load(acq); + if (next != prev) return next; + __ulock_wait(UL_COMPARE_AND_WAIT, address, prev, 0); + } + +#elif defined(HWY_FUTEX_SLEEP) + for (;;) { + const uint32_t next = current.load(acq); + if (next != prev) return next; + NanoSleep(2000); + } + +#elif HWY_CXX_LANG >= 202002L + current.wait(prev, acq); // No spurious wakeup. + const uint32_t next = current.load(acq); + HWY_DASSERT(next != prev); + return next; + +#else +#error "Logic error, should have reached HWY_FUTEX_SLEEP" +#endif // HWY_OS_* +} // BlockUntilDifferent + +// Wakes all threads, if any, that are waiting because they called +// `BlockUntilDifferent` with the same `current`. +static inline void WakeAll(std::atomic& current) { +#if HWY_ARCH_WASM + // It is always safe to cast to void. + volatile void* address = static_cast(¤t); + const int max_to_wake = INT_MAX; // actually signed + const int ret = emscripten_futex_wake(address, max_to_wake); + HWY_DASSERT(ret >= 0); + (void)ret; + +#elif HWY_OS_LINUX + // Safe to cast because std::atomic is a standard layout type. 
+ uint32_t* address = reinterpret_cast(¤t); + const int max_to_wake = INT_MAX; // actually signed + const auto ret = syscall(SYS_futex, address, FUTEX_WAKE_PRIVATE, max_to_wake, + nullptr, nullptr, 0); + HWY_DASSERT(ret >= 0); // number woken + (void)ret; + +#elif HWY_OS_FREEBSD && !defined(HWY_DISABLE_FUTEX) // >= 6.0 + void* address = static_cast(¤t); + const int ret = _umtx_op(address, UMTX_OP_WAKE_PRIVATE, INT_MAX, nullptr, + nullptr); + HWY_DASSERT(ret >= 0); + (void)ret; + +#elif HWY_OS_WIN && !defined(HWY_DISABLE_FUTEX) + // It is always safe to cast to void. + void* address = static_cast(¤t); + WakeByAddressAll(address); + +#elif HWY_OS_APPLE && !defined(HWY_DISABLE_FUTEX) + // It is always safe to cast to void. + void* address = static_cast(¤t); + __ulock_wake(UL_COMPARE_AND_WAIT | ULF_WAKE_ALL, address, 0); + +#elif defined(HWY_FUTEX_SLEEP) + // NanoSleep loop does not require wakeup. + (void)current; +#elif HWY_CXX_LANG >= 202002L + current.notify_all(); + +#else +#error "Logic error, should have reached HWY_FUTEX_SLEEP" +#endif +} // WakeAll + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_THREAD_POOL_FUTEX_H_ diff --git a/lib/highway/hwy/contrib/thread_pool/spin.h b/lib/highway/hwy/contrib/thread_pool/spin.h new file mode 100644 index 00000000000..c70d62d0a65 --- /dev/null +++ b/lib/highway/hwy/contrib/thread_pool/spin.h @@ -0,0 +1,336 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_CONTRIB_THREAD_POOL_SPIN_H_ +#define HIGHWAY_HWY_CONTRIB_THREAD_POOL_SPIN_H_ + +// Relatively power-efficient spin lock for low-latency synchronization. + +#include + +#include + +#include "hwy/base.h" +#include "hwy/cache_control.h" // Pause + +#ifndef HWY_ENABLE_MONITORX // allow override +// Clang 3.9 suffices for mwaitx, but the target pragma requires 9.0. +#if HWY_ARCH_X86 && ((HWY_COMPILER_CLANG >= 900) || \ + (HWY_COMPILER_GCC_ACTUAL >= 502) || defined(__MWAITX__)) +#define HWY_ENABLE_MONITORX 1 +#else +#define HWY_ENABLE_MONITORX 0 +#endif +#endif // HWY_ENABLE_MONITORX + +#ifndef HWY_ENABLE_UMONITOR // allow override +#if HWY_ARCH_X86 && ((HWY_COMPILER_CLANG >= 900) || \ + (HWY_COMPILER_GCC_ACTUAL >= 901) || defined(__WAITPKG__)) +#define HWY_ENABLE_UMONITOR 1 +#else +#define HWY_ENABLE_UMONITOR 0 +#endif +#endif // HWY_ENABLE_UMONITOR + +// Inline assembly is preferred because it allows inlining of `UntilDifferent` +// etc, but we also support intrinsics for MSVC. +#ifndef HWY_ENABLE_SPIN_ASM // allow override +#if (HWY_COMPILER_CLANG || HWY_COMPILER_GCC) && HWY_ARCH_X86_64 +#define HWY_ENABLE_SPIN_ASM 1 +#else +#define HWY_ENABLE_SPIN_ASM 0 +#endif +#endif // HWY_ENABLE_SPIN_ASM + +#if HWY_ENABLE_MONITORX || HWY_ENABLE_UMONITOR +#if HWY_ENABLE_SPIN_ASM +#define HWY_INLINE_SPIN HWY_INLINE // can inline functions with inline assembly +#else +// Intrinsics require attributes, which prevent inlining. +#define HWY_INLINE_SPIN +#include +#endif // HWY_ENABLE_SPIN_ASM + +#include "hwy/x86_cpuid.h" +#endif // HWY_ENABLE_MONITORX || HWY_ENABLE_UMONITOR + +namespace hwy { + +// Returned by `UntilDifferent` in a single register. +struct SpinResult { + // We also use u32 because that is all that futex.h supports. 
+ uint32_t current; + // Number of retries before returning, useful for checking that the + // monitor/wait did not just return immediately. + uint32_t reps; +}; + +// User-space monitor/wait are supported on Zen2+ AMD and SPR+ Intel. Spin waits +// are rarely called from SIMD code, hence we do not integrate this into +// `HWY_TARGET` and its runtime dispatch mechanism. Returned by `Type()`, also +// used by callers to set the `disabled` argument for `DetectSpin`. +enum class SpinType : uint8_t { +#if HWY_ENABLE_MONITORX + kMonitorX = 1, // AMD +#endif +#if HWY_ENABLE_UMONITOR + kUMonitor = 2, // Intel +#endif + kPause = 3, + kSentinel // for iterating over all enumerators. Must be last. +}; + +// For printing which is in use. +static inline const char* ToString(SpinType type) { + switch (type) { +#if HWY_ENABLE_MONITORX + case SpinType::kMonitorX: + return "MonitorX_C1"; +#endif +#if HWY_ENABLE_UMONITOR + case SpinType::kUMonitor: + return "UMonitor_C0.2"; +#endif + case SpinType::kPause: + return "Pause"; + case SpinType::kSentinel: + default: + return nullptr; + } +} + +// Indirect function calls turn out to be too expensive because this is called +// multiple times per ThreadPool barrier. We will instead inline the spin and +// barrier using policy classes. This one is always available; use it as a +// reference for the interface. Note that Pause varies across CPUs: it can be +// a no-op, or wait 140 cycles. +struct SpinPause { + SpinType Type() const { return SpinType::kPause; } + + // Spins until `watched != prev` and returns the new value, similar to + // `BlockUntilDifferent` in `futex.h`. + HWY_INLINE SpinResult UntilDifferent( + const uint32_t prev, const std::atomic& watched) const { + for (uint32_t reps = 0;; ++reps) { + const uint32_t current = watched.load(std::memory_order_acquire); + if (current != prev) return SpinResult{current, reps}; + hwy::Pause(); + } + } + + // Returns number of retries until `watched == expected`. 
+ HWY_INLINE size_t UntilEqual(const uint32_t expected, + const std::atomic& watched) const { + for (size_t reps = 0;; ++reps) { + const uint32_t current = watched.load(std::memory_order_acquire); + if (current == expected) return reps; + hwy::Pause(); + } + } +}; + +#if HWY_ENABLE_MONITORX || HWY_IDE +#if !HWY_ENABLE_SPIN_ASM +HWY_PUSH_ATTRIBUTES("mwaitx") +#endif + +// AMD's user-mode monitor/wait (Zen2+). +class SpinMonitorX { + public: + SpinType Type() const { return SpinType::kMonitorX; } + + HWY_INLINE_SPIN SpinResult UntilDifferent( + const uint32_t prev, const std::atomic& watched) const { + for (uint32_t reps = 0;; ++reps) { + uint32_t current = watched.load(std::memory_order_acquire); + if (current != prev) return SpinResult{current, reps}; + Monitor(&watched); + // Double-checked 'lock' to avoid missed events: + current = watched.load(std::memory_order_acquire); + if (current != prev) return SpinResult{current, reps}; + Wait(); + } + } + + HWY_INLINE_SPIN size_t UntilEqual( + const uint32_t expected, const std::atomic& watched) const { + for (size_t reps = 0;; ++reps) { + uint32_t current = watched.load(std::memory_order_acquire); + if (current == expected) return reps; + Monitor(&watched); + // Double-checked 'lock' to avoid missed events: + current = watched.load(std::memory_order_acquire); + if (current == expected) return reps; + Wait(); + } + } + + private: + static HWY_INLINE void Monitor(const void* addr) { + // No extensions/hints currently defined. +#if HWY_ENABLE_SPIN_ASM + asm volatile("monitorx" ::"a"(addr), "c"(0), "d"(0)); +#else + _mm_monitorx(const_cast(addr), 0, 0); +#endif + } + + static HWY_INLINE void Wait() { +#if HWY_ENABLE_SPIN_ASM + // EBX=0 cycles means no timeout/infinite. + asm volatile("mwaitx" ::"a"(kHints), "b"(0), "c"(kExtensions)); +#else + _mm_mwaitx(kExtensions, kHints, /*cycles=*/0); +#endif + } + + // 0xF would be C0. 
Its wakeup latency is less than 0.1 us shorter, and + // package power is sometimes actually higher than with Pause. The + // difference in spurious wakeups is minor. + static constexpr unsigned kHints = 0x0; // C1: a bit deeper than C0 + // No timeout required, we assume the mwaitx does not miss stores, see + // https://www.usenix.org/system/files/usenixsecurity23-zhang-ruiyi.pdf.] + static constexpr unsigned kExtensions = 0; +}; + +#if !HWY_ENABLE_SPIN_ASM +HWY_POP_ATTRIBUTES +#endif +#endif // HWY_ENABLE_MONITORX + +#if HWY_ENABLE_UMONITOR || HWY_IDE +#if !HWY_ENABLE_SPIN_ASM +HWY_PUSH_ATTRIBUTES("waitpkg") +#endif + +// Intel's user-mode monitor/wait (SPR+). +class SpinUMonitor { + public: + SpinType Type() const { return SpinType::kUMonitor; } + + HWY_INLINE_SPIN SpinResult UntilDifferent( + const uint32_t prev, const std::atomic& watched) const { + for (uint32_t reps = 0;; ++reps) { + uint32_t current = watched.load(std::memory_order_acquire); + if (current != prev) return SpinResult{current, reps}; + Monitor(&watched); + // Double-checked 'lock' to avoid missed events: + current = watched.load(std::memory_order_acquire); + if (current != prev) return SpinResult{current, reps}; + Wait(); + } + } + + HWY_INLINE_SPIN size_t UntilEqual( + const uint32_t expected, const std::atomic& watched) const { + for (size_t reps = 0;; ++reps) { + uint32_t current = watched.load(std::memory_order_acquire); + if (current == expected) return reps; + Monitor(&watched); + // Double-checked 'lock' to avoid missed events: + current = watched.load(std::memory_order_acquire); + if (current == expected) return reps; + Wait(); + } + } + + private: + static HWY_INLINE void Monitor(const void* addr) { +#if HWY_ENABLE_SPIN_ASM + asm volatile("umonitor %%rcx" ::"c"(addr)); +#else + _umonitor(const_cast(addr)); +#endif + } + + static HWY_INLINE void Wait() { +#if HWY_ENABLE_SPIN_ASM + asm volatile("umwait %%ecx" ::"c"(kControl), "d"(kDeadline >> 32), + "a"(kDeadline & 0xFFFFFFFFu)); +#else 
+ _umwait(kControl, kDeadline); +#endif + } + + // 1 would be C0.1. C0.2 has 20x fewer spurious wakeups and additional 4% + // package power savings vs Pause on SPR. It comes at the cost of + // 0.4-0.6us higher wake latency, but the total is comparable to Zen4. + static constexpr unsigned kControl = 0; // C0.2 for deeper sleep + static constexpr uint64_t kDeadline = ~uint64_t{0}; // no timeout, see above +}; + +#if !HWY_ENABLE_SPIN_ASM +HWY_POP_ATTRIBUTES +#endif +#endif // HWY_ENABLE_UMONITOR + +// TODO(janwas): add WFE on Arm. May wake at 10 kHz, but still worthwhile. + +// Returns the best-available type whose bit in `disabled` is not set. Example: +// to disable kUMonitor, pass `1 << static_cast(SpinType::kUMonitor)`. +// Ignores `disabled` for `kPause` if it is the only supported and enabled type. +// Somewhat expensive, typically called during initialization. +static inline SpinType DetectSpin(int disabled = 0) { + const auto enabled = [disabled](SpinType type) { + return (disabled & (1 << static_cast(type))) == 0; + }; + (void)enabled; + +#if HWY_ENABLE_MONITORX + if (enabled(SpinType::kMonitorX) && x86::IsAMD()) { + uint32_t abcd[4]; + x86::Cpuid(0x80000001U, 0, abcd); + if (x86::IsBitSet(abcd[2], 29)) return SpinType::kMonitorX; + } +#endif // HWY_ENABLE_MONITORX + +#if HWY_ENABLE_UMONITOR + if (enabled(SpinType::kUMonitor) && x86::MaxLevel() >= 7) { + uint32_t abcd[4]; + x86::Cpuid(7, 0, abcd); + if (x86::IsBitSet(abcd[2], 5)) return SpinType::kUMonitor; + } +#endif // HWY_ENABLE_UMONITOR + + if (!enabled(SpinType::kPause)) { + HWY_WARN("Ignoring attempt to disable Pause, it is the only option left."); + } + return SpinType::kPause; +} + +// Calls `func(spin, args)` for the given `spin_type`. +template +HWY_INLINE void CallWithSpin(SpinType spin_type, Func&& func, Args&&... 
args) { + switch (spin_type) { +#if HWY_ENABLE_MONITORX + case SpinType::kMonitorX: + func(SpinMonitorX(), std::forward(args)...); + break; +#endif +#if HWY_ENABLE_UMONITOR + case SpinType::kUMonitor: + func(SpinUMonitor(), std::forward(args)...); + break; +#endif + case SpinType::kPause: + default: + func(SpinPause(), std::forward(args)...); + break; + } +} + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_THREAD_POOL_SPIN_H_ diff --git a/lib/highway/hwy/contrib/thread_pool/thread_pool.cc b/lib/highway/hwy/contrib/thread_pool/thread_pool.cc new file mode 100644 index 00000000000..e3f81087e61 --- /dev/null +++ b/lib/highway/hwy/contrib/thread_pool/thread_pool.cc @@ -0,0 +1,31 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/thread_pool/thread_pool.h" + +#include "hwy/highway_export.h" + +namespace hwy { +namespace pool { + +// TODO: move implementation here. 
+ +HWY_CONTRIB_DLLEXPORT Shared& Shared::Get() { + static Shared* shared = new Shared(); + return *shared; +} + +} // namespace pool +} // namespace hwy diff --git a/lib/highway/hwy/contrib/thread_pool/thread_pool.h b/lib/highway/hwy/contrib/thread_pool/thread_pool.h new file mode 100644 index 00000000000..dc7bd0e77f6 --- /dev/null +++ b/lib/highway/hwy/contrib/thread_pool/thread_pool.h @@ -0,0 +1,1719 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Modified from BSD-licensed code +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// See https://github.com/libjxl/libjxl/blob/main/LICENSE. 
+ +#ifndef HIGHWAY_HWY_CONTRIB_THREAD_POOL_THREAD_POOL_H_ +#define HIGHWAY_HWY_CONTRIB_THREAD_POOL_THREAD_POOL_H_ + +#include +#include +#include // snprintf +#include + +#include +#include +#include +#include // NOLINT +#include + +#include "hwy/aligned_allocator.h" // HWY_ALIGNMENT +#include "hwy/auto_tune.h" +#include "hwy/base.h" +#include "hwy/cache_control.h" // Pause +#include "hwy/contrib/thread_pool/futex.h" +#include "hwy/contrib/thread_pool/spin.h" +#include "hwy/contrib/thread_pool/topology.h" +#include "hwy/profiler.h" +#include "hwy/stats.h" +#include "hwy/timer.h" + +#if HWY_OS_APPLE +#include +#endif + +#if PROFILER_ENABLED +#include // std::sort + +#include "hwy/bit_set.h" +#endif + +namespace hwy { + +// Sets the name of the current thread to the format string `format`, which must +// include %d for `thread`. Currently only implemented for pthreads (*nix and +// OSX); Windows involves throwing an exception. +static inline void SetThreadName(const char* format, int thread) { + char buf[16] = {}; // Linux limit, including \0 + const int chars_written = snprintf(buf, sizeof(buf), format, thread); + HWY_ASSERT(0 < chars_written && + chars_written <= static_cast(sizeof(buf) - 1)); + +#if (HWY_OS_LINUX && (!defined(__ANDROID__) || __ANDROID_API__ >= 19)) || \ + HWY_OS_FREEBSD + // Note that FreeBSD pthread_set_name_np does not return a value (#2669). + HWY_ASSERT(0 == pthread_setname_np(pthread_self(), buf)); +#elif HWY_OS_APPLE && (MAC_OS_X_VERSION_MIN_REQUIRED >= 1060) + // Different interface: single argument, current thread only. + HWY_ASSERT(0 == pthread_setname_np(buf)); +#elif defined(__EMSCRIPTEN__) + emscripten_set_thread_name(pthread_self(), buf); +#else + (void)format; + (void)thread; +#endif +} + +// Whether workers should block or spin. +enum class PoolWaitMode : uint8_t { kBlock = 1, kSpin }; + +enum class Exit : uint32_t { kNone, kLoop, kThread }; + +// Upper bound on non-empty `ThreadPool` (single-worker pools do not count). 
+// Turin has 16 clusters. Add one for the across-cluster pool. +HWY_INLINE_VAR constexpr size_t kMaxClusters = 32 + 1; + +// Use the last slot so that `PoolWorkerMapping` does not have to know the +// total number of clusters. +HWY_INLINE_VAR constexpr size_t kAllClusters = kMaxClusters - 1; + +// Argument to `ThreadPool`: how to map local worker_idx to global. +class PoolWorkerMapping { + public: + // Backward-compatible mode: returns local worker index. + PoolWorkerMapping() : cluster_idx_(0), max_cluster_workers_(0) {} + PoolWorkerMapping(size_t cluster_idx, size_t max_cluster_workers) + : cluster_idx_(cluster_idx), max_cluster_workers_(max_cluster_workers) { + HWY_DASSERT(cluster_idx <= kAllClusters); + // Only use this ctor for the new global worker index mode. If this were + // zero, we would still return local indices. + HWY_DASSERT(max_cluster_workers != 0); + } + + size_t ClusterIdx() const { return cluster_idx_; } + size_t MaxClusterWorkers() const { return max_cluster_workers_; } + + // Returns global_idx, or unchanged local worker_idx if default-constructed. + size_t operator()(size_t worker_idx) const { + if (cluster_idx_ == kAllClusters) { + const size_t cluster_idx = worker_idx; + HWY_DASSERT(cluster_idx < kAllClusters); + // First index within the N-th cluster. The main thread is the first. + return cluster_idx * max_cluster_workers_; + } + HWY_DASSERT(max_cluster_workers_ == 0 || worker_idx < max_cluster_workers_); + return cluster_idx_ * max_cluster_workers_ + worker_idx; + } + + private: + size_t cluster_idx_; + size_t max_cluster_workers_; +}; + +namespace pool { + +#ifndef HWY_POOL_VERBOSITY +#define HWY_POOL_VERBOSITY 0 +#endif + +static constexpr int kVerbosity = HWY_POOL_VERBOSITY; + +// Some CPUs already have more than this many threads, but rather than one +// large pool, we assume applications create multiple pools, ideally per +// cluster (cores sharing a cache), because this improves locality and barrier +// latency. 
+// In that case, this is a generous upper bound.
+static constexpr size_t kMaxThreads = 127;
+
+// Generates a random permutation of [0, size). O(1) storage.
+class ShuffledIota {
+ public:
+  ShuffledIota() : coprime_(1) {}  // for Worker
+  explicit ShuffledIota(uint32_t coprime) : coprime_(coprime) {}
+
+  // Returns the next after `current`, using an LCG-like generator.
+  // `current` must be less than `divisor.GetDivisor()`.
+  uint32_t Next(uint32_t current, const Divisor64& divisor) const {
+    HWY_DASSERT(current < divisor.GetDivisor());
+    // (coprime * i + current) % size, see https://lemire.me/blog/2017/09/18/.
+    return static_cast<uint32_t>(divisor.Remainder(current + coprime_));
+  }
+
+  // Returns true if a and b have no common divisor except 1. Based on
+  // binary GCD. Assumes a and b are nonzero. Also used in tests.
+  static bool CoprimeNonzero(uint32_t a, uint32_t b) {
+    const size_t trailing_a = Num0BitsBelowLS1Bit_Nonzero32(a);
+    const size_t trailing_b = Num0BitsBelowLS1Bit_Nonzero32(b);
+    // If both have at least one trailing zero, they are both divisible by 2.
+    if (HWY_MIN(trailing_a, trailing_b) != 0) return false;
+
+    // If one of them has a trailing zero, shift it out.
+    a >>= trailing_a;
+    b >>= trailing_b;
+
+    for (;;) {
+      // Swap such that a >= b.
+      const uint32_t tmp_a = a;
+      a = HWY_MAX(tmp_a, b);
+      b = HWY_MIN(tmp_a, b);
+
+      // When the smaller number is 1, they were coprime.
+      if (b == 1) return true;
+
+      a -= b;
+      // a == b means there was a common factor, so not coprime.
+      if (a == 0) return false;
+      a >>= Num0BitsBelowLS1Bit_Nonzero32(a);
+    }
+  }
+
+  // Returns another coprime >= `start`, or 1 for small `size`.
+  // Used to seed independent ShuffledIota instances.
+  static uint32_t FindAnotherCoprime(uint32_t size, uint32_t start) {
+    if (size <= 2) {
+      return 1;
+    }
+
+    // Avoids even x for even sizes, which are sure to be rejected.
+    const uint32_t inc = (size & 1) ? 1 : 2;
+
+    // Candidates start at `start | 1` and are bounded by `start + size * 16`.
+    for (uint32_t x = start | 1; x < start + size * 16; x += inc) {
+      if (CoprimeNonzero(x, static_cast<uint32_t>(size))) {
+        return x;
+      }
+    }
+
+    HWY_UNREACHABLE;
+  }
+
+  uint32_t coprime_;
+};
+
+// 'Policies' suitable for various worker counts and locality. To define a
+// new class, add an enum and update `ToString` plus `CallWithConfig`. The
+// enumerators must be contiguous so we can iterate over them.
+enum class WaitType : uint8_t {
+  kBlock,
+  kSpin1,
+  kSpinSeparate,
+  kSentinel  // Must be last.
+};
+
+// For printing which is in use. Returns nullptr for `kSentinel` and for any
+// value that is not a named enumerator.
+static inline const char* ToString(WaitType type) {
+  // No `default:` label so that compilers can still warn about enumerators
+  // added later without a matching case.
+  switch (type) {
+    case WaitType::kBlock:
+      return "Block";
+    case WaitType::kSpin1:
+      return "Single";
+    case WaitType::kSpinSeparate:
+      return "Separate";
+    case WaitType::kSentinel:
+      return nullptr;
+  }
+  // Reached only for out-of-range values; without this, control would flow
+  // off the end of a non-void function.
+  return nullptr;
+}
+
+// Parameters governing the main and worker thread behavior. Can be updated at
+// runtime via `SetWaitMode`, which calls `SendConfig`. Both have copies which
+// are carefully synchronized. 32 bits leave room for two future fields.
+// 64 bits would also be fine because this does not go through futex.
+struct Config {  // 4 bytes
+  static std::vector<Config> AllCandidates(PoolWaitMode wait_mode) {
+    std::vector<Config> candidates;
+
+    if (wait_mode == PoolWaitMode::kSpin) {
+      std::vector<SpinType> spin_types;
+      spin_types.reserve(2);
+      spin_types.push_back(DetectSpin());
+      // Monitor-based spin may be slower, so also try Pause.
+      if (spin_types[0] != SpinType::kPause) {
+        spin_types.push_back(SpinType::kPause);
+      }
+
+      // All except `kBlock`.
+ std::vector wait_types; + for (size_t wait = 0;; ++wait) { + const WaitType wait_type = static_cast(wait); + if (wait_type == WaitType::kSentinel) break; + if (wait_type != WaitType::kBlock) wait_types.push_back(wait_type); + } + + candidates.reserve(spin_types.size() * wait_types.size()); + for (const SpinType spin_type : spin_types) { + for (const WaitType wait_type : wait_types) { + candidates.emplace_back(spin_type, wait_type); + } + } + } else { + // kBlock does not use spin, so there is only one candidate. + candidates.emplace_back(SpinType::kPause, WaitType::kBlock); + } + + return candidates; + } + + std::string ToString() const { + char buf[128]; + snprintf(buf, sizeof(buf), "%-14s %-9s", hwy::ToString(spin_type), + pool::ToString(wait_type)); + return buf; + } + + Config(SpinType spin_type_in, WaitType wait_type_in) + : spin_type(spin_type_in), wait_type(wait_type_in) {} + // Workers initially spin until ThreadPool sends them their actual config. + Config() : Config(SpinType::kPause, WaitType::kSpinSeparate) {} + + SpinType spin_type; + WaitType wait_type; + HWY_MEMBER_VAR_MAYBE_UNUSED uint8_t reserved[2]; +}; +static_assert(sizeof(Config) == 4, ""); + +#if PROFILER_ENABLED + +// Accumulates timings and stats from main thread and workers. +class Stats { + // Up to `HWY_ALIGNMENT / 8` slots/offsets, passed to `PerThread`. 
+ static constexpr size_t kDWait = 0; + static constexpr size_t kWaitReps = 1; + static constexpr size_t kTBeforeRun = 2; + static constexpr size_t kDRun = 3; + static constexpr size_t kTasksStatic = 4; + static constexpr size_t kTasksDynamic = 5; + static constexpr size_t kTasksStolen = 6; + static constexpr size_t kDFuncStatic = 7; + static constexpr size_t kDFuncDynamic = 8; + static constexpr size_t kSentinel = 9; + + public: + Stats() { + for (size_t thread_idx = 0; thread_idx < kMaxThreads; ++thread_idx) { + for (size_t offset = 0; offset < kSentinel; ++offset) { + PerThread(thread_idx, offset) = 0; + } + } + Reset(); + } + + // Called by the N lowest-indexed workers that got one of the N tasks, which + // includes the main thread because its index is 0. + // `d_*` denotes "difference" (of timestamps) and thus also duration. + void NotifyRunStatic(size_t worker_idx, timer::Ticks d_func) { + if (worker_idx == 0) { // main thread + num_run_static_++; + sum_tasks_static_++; + sum_d_func_static_ += d_func; + } else { + const size_t thread_idx = worker_idx - 1; + // Defer the sums until `NotifyMainRun` to avoid atomic RMW. + PerThread(thread_idx, kTasksStatic)++; + PerThread(thread_idx, kDFuncStatic) += d_func; + } + } + + // Called by all workers, including the main thread, regardless of whether + // they actually stole or even ran a task. + void NotifyRunDynamic(size_t worker_idx, size_t tasks, size_t stolen, + timer::Ticks d_func) { + if (worker_idx == 0) { // main thread + num_run_dynamic_++; + sum_tasks_dynamic_ += tasks; + sum_tasks_stolen_ += stolen; + sum_d_func_dynamic_ += d_func; + } else { + const size_t thread_idx = worker_idx - 1; + // Defer the sums until `NotifyMainRun` to avoid atomic RMW. + PerThread(thread_idx, kTasksDynamic) += tasks; + PerThread(thread_idx, kTasksStolen) += stolen; + PerThread(thread_idx, kDFuncDynamic) += d_func; + } + } + + // Called concurrently by non-main worker threads after their `WorkerRun` and + // before the barrier. 
+ void NotifyThreadRun(size_t worker_idx, timer::Ticks d_wait, size_t wait_reps, + timer::Ticks t_before_run, timer::Ticks d_run) { + HWY_DASSERT(worker_idx != 0); // Not called by main thread. + const size_t thread_idx = worker_idx - 1; + HWY_DASSERT(PerThread(thread_idx, kDWait) == 0); + HWY_DASSERT(PerThread(thread_idx, kWaitReps) == 0); + HWY_DASSERT(PerThread(thread_idx, kTBeforeRun) == 0); + HWY_DASSERT(PerThread(thread_idx, kDRun) == 0); + PerThread(thread_idx, kDWait) = d_wait; + PerThread(thread_idx, kWaitReps) = wait_reps; + PerThread(thread_idx, kTBeforeRun) = t_before_run; // For wake latency. + PerThread(thread_idx, kDRun) = d_run; + } + + // Called by the main thread after the barrier, whose store-release and + // load-acquire publishes all prior writes. Note: only the main thread can + // store `after_barrier`. If workers did, which by definition happens after + // the barrier, then they would race with this function's reads. + void NotifyMainRun(size_t num_threads, timer::Ticks t_before_wake, + timer::Ticks d_wake, timer::Ticks d_main_run, + timer::Ticks d_barrier) { + HWY_DASSERT(num_threads <= kMaxThreads); + + timer::Ticks min_d_run = ~timer::Ticks{0}; + timer::Ticks max_d_run = 0; + timer::Ticks sum_d_run = 0; + for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) { + sum_tasks_static_ += PerThread(thread_idx, kTasksStatic); + sum_tasks_dynamic_ += PerThread(thread_idx, kTasksDynamic); + sum_tasks_stolen_ += PerThread(thread_idx, kTasksStolen); + sum_d_func_static_ += PerThread(thread_idx, kDFuncStatic); + sum_d_func_dynamic_ += PerThread(thread_idx, kDFuncDynamic); + sum_d_wait_ += PerThread(thread_idx, kDWait); + sum_wait_reps_ += PerThread(thread_idx, kWaitReps); + const timer::Ticks d_thread_run = PerThread(thread_idx, kDRun); + min_d_run = HWY_MIN(min_d_run, d_thread_run); + max_d_run = HWY_MAX(max_d_run, d_thread_run); + sum_d_run += d_thread_run; + const timer::Ticks t_before_run = PerThread(thread_idx, kTBeforeRun); + + 
for (size_t offset = 0; offset < kSentinel; ++offset) { + PerThread(thread_idx, offset) = 0; + } + + HWY_DASSERT(t_before_run != 0); + const timer::Ticks d_latency = t_before_run - t_before_wake; + sum_wake_latency_ += d_latency; + max_wake_latency_ = HWY_MAX(max_wake_latency_, d_latency); + } + const double inv_avg_d_run = + static_cast(num_threads) / static_cast(sum_d_run); + // Ratios of min and max run times to the average, for this pool.Run. + const double r_min = static_cast(min_d_run) * inv_avg_d_run; + const double r_max = static_cast(max_d_run) * inv_avg_d_run; + + num_run_++; // `num_run_*` are incremented by `NotifyRun*`. + sum_d_run_ += sum_d_run; + sum_r_min_ += r_min; // For average across all pool.Run. + sum_r_max_ += r_max; + + sum_d_wake_ += d_wake; // `*wake_latency_` are updated above. + sum_d_barrier_ += d_barrier; + + sum_d_run_ += d_main_run; + sum_d_run_main_ += d_main_run; + } + + void PrintAndReset(size_t num_threads, timer::Ticks d_thread_lifetime_ticks) { + // This is unconditionally called via `ProfilerFunc`. If the pool was unused + // in this invocation, skip it. 
+ if (num_run_ == 0) return; + HWY_ASSERT(num_run_ == num_run_static_ + num_run_dynamic_); + + const double d_func_static = Seconds(sum_d_func_static_); + const double d_func_dynamic = Seconds(sum_d_func_dynamic_); + const double sum_d_run = Seconds(sum_d_run_); + const double func_div_run = (d_func_static + d_func_dynamic) / sum_d_run; + if (!(0.95 <= func_div_run && func_div_run <= 1.0)) { + HWY_WARN("Func time %f should be similar to total run %f.", + d_func_static + d_func_dynamic, sum_d_run); + } + const double sum_d_run_main = Seconds(sum_d_run_main_); + const double max_wake_latency = Seconds(max_wake_latency_); + const double sum_d_wait = Seconds(sum_d_wait_); + const double d_thread_lifetime = Seconds(d_thread_lifetime_ticks); + + const double inv_run = 1.0 / static_cast(num_run_); + const auto per_run = [inv_run](double sum) { return sum * inv_run; }; + const double avg_d_wake = per_run(Seconds(sum_d_wake_)); + const double avg_wake_latency = per_run(Seconds(sum_wake_latency_)); + const double avg_d_wait = per_run(sum_d_wait); + const double avg_wait_reps = per_run(static_cast(sum_wait_reps_)); + const double avg_d_barrier = per_run(Seconds(sum_d_barrier_)); + const double avg_r_min = per_run(sum_r_min_); + const double avg_r_max = per_run(sum_r_max_); + + const size_t num_workers = 1 + num_threads; + const double avg_tasks_static = + Avg(sum_tasks_static_, num_run_static_ * num_workers); + const double avg_tasks_dynamic = + Avg(sum_tasks_dynamic_, num_run_dynamic_ * num_workers); + const double avg_steals = + Avg(sum_tasks_stolen_, num_run_dynamic_ * num_workers); + const double avg_d_run = sum_d_run / num_workers; + + const double pc_wait = sum_d_wait / d_thread_lifetime * 100.0; + const double pc_run = sum_d_run / d_thread_lifetime * 100.0; + const double pc_main = sum_d_run_main / avg_d_run * 100.0; + + const auto us = [](double sec) { return sec * 1E6; }; + const auto ns = [](double sec) { return sec * 1E9; }; + printf( + "%3zu: %5d x %.2f/%5d x 
%4.1f tasks, %.2f steals; " + "wake %7.3f ns, latency %6.3f < %7.3f us, barrier %7.3f us; " + "wait %.1f us (%6.0f reps, %4.1f%%), balance %4.1f%%-%5.1f%%, " + "func: %6.3f + %7.3f, " + "%.1f%% of thread time %7.3f s; main:worker %5.1f%%\n", + num_threads, num_run_static_, avg_tasks_static, num_run_dynamic_, + avg_tasks_dynamic, avg_steals, ns(avg_d_wake), us(avg_wake_latency), + us(max_wake_latency), us(avg_d_barrier), us(avg_d_wait), avg_wait_reps, + pc_wait, avg_r_min * 100.0, avg_r_max * 100.0, d_func_static, + d_func_dynamic, pc_run, d_thread_lifetime, pc_main); + + Reset(num_threads); + } + + void Reset(size_t num_threads = kMaxThreads) { + num_run_ = 0; + num_run_static_ = 0; + num_run_dynamic_ = 0; + + sum_tasks_stolen_ = 0; + sum_tasks_static_ = 0; + sum_tasks_dynamic_ = 0; + + sum_d_wake_ = 0; + sum_wake_latency_ = 0; + max_wake_latency_ = 0; + sum_d_wait_ = 0; + sum_wait_reps_ = 0; + sum_d_barrier_ = 0; + + sum_d_func_static_ = 0; + sum_d_func_dynamic_ = 0; + sum_r_min_ = 0.0; + sum_r_max_ = 0.0; + sum_d_run_ = 0; + sum_d_run_main_ = 0; + // ctor and `NotifyMainRun` already reset `PerThread`. + } + + private: + template + static double Avg(T sum, size_t div) { + return div == 0 ? 
0.0 : static_cast(sum) / static_cast(div); + } + + static constexpr size_t kU64PerLine = HWY_ALIGNMENT / sizeof(uint64_t); + + uint64_t& PerThread(size_t thread_idx, size_t offset) { + HWY_DASSERT(thread_idx < kMaxThreads); + HWY_DASSERT(offset < kSentinel); + return per_thread_[thread_idx * kU64PerLine + offset]; + } + + int32_t num_run_; + int32_t num_run_static_; + int32_t num_run_dynamic_; + + int32_t sum_tasks_stolen_; + int64_t sum_tasks_static_; + int64_t sum_tasks_dynamic_; + + timer::Ticks sum_d_wake_; + timer::Ticks sum_wake_latency_; + timer::Ticks max_wake_latency_; + timer::Ticks sum_d_wait_; + uint64_t sum_wait_reps_; + timer::Ticks sum_d_barrier_; + + timer::Ticks sum_d_func_static_; + timer::Ticks sum_d_func_dynamic_; + double sum_r_min_; + double sum_r_max_; + timer::Ticks sum_d_run_; + timer::Ticks sum_d_run_main_; + + // One cache line per pool thread to avoid false sharing. + uint64_t per_thread_[kMaxThreads * kU64PerLine]; +}; +// Enables shift rather than multiplication. +static_assert(sizeof(Stats) == (kMaxThreads + 1) * HWY_ALIGNMENT, "Wrong size"); + +// Non-power of two to avoid 2K aliasing. +HWY_INLINE_VAR constexpr size_t kMaxCallers = 60; + +// Per-caller stats, stored in `PerCluster`. 
+class CallerAccumulator { + public: + bool Any() const { return calls_ != 0; } + + void Add(size_t tasks, size_t workers, bool is_root, timer::Ticks wait_before, + timer::Ticks elapsed) { + calls_++; + root_ += is_root; + workers_ += workers; + min_tasks_ = HWY_MIN(min_tasks_, tasks); + max_tasks_ = HWY_MAX(max_tasks_, tasks); + tasks_ += tasks; + wait_before_ += wait_before; + elapsed_ += elapsed; + } + + void AddFrom(const CallerAccumulator& other) { + calls_ += other.calls_; + root_ += other.root_; + workers_ += other.workers_; + min_tasks_ = HWY_MIN(min_tasks_, other.min_tasks_); + max_tasks_ = HWY_MAX(max_tasks_, other.max_tasks_); + tasks_ += other.tasks_; + wait_before_ += other.wait_before_; + elapsed_ += other.elapsed_; + } + + bool operator>(const CallerAccumulator& other) const { + return elapsed_ > other.elapsed_; + } + + void PrintAndReset(const char* caller, size_t active_clusters) { + if (!Any()) return; + HWY_ASSERT(root_ <= calls_); + const double inv_calls = 1.0 / static_cast(calls_); + const double pc_root = static_cast(root_) * inv_calls * 100.0; + const double avg_workers = static_cast(workers_) * inv_calls; + const double avg_tasks = static_cast(tasks_) * inv_calls; + const double avg_tasks_per_worker = avg_tasks / avg_workers; + const double inv_freq = 1.0 / platform::InvariantTicksPerSecond(); + const double sum_wait_before = static_cast(wait_before_) * inv_freq; + const double avg_wait_before = + root_ ? 
sum_wait_before / static_cast(root_) : 0.0; + const double elapsed = static_cast(elapsed_) * inv_freq; + const double avg_elapsed = elapsed * inv_calls; + const double task_len = avg_elapsed / avg_tasks_per_worker; + printf( + "%40s: %7.0f x (%3.0f%%) %2zu clusters, %4.1f workers @ " + "%5.1f tasks (%5u-%5u), " + "%5.0f us wait, %6.1E us run (task len %6.1E us), total %6.2f s\n", + caller, static_cast(calls_), pc_root, active_clusters, + avg_workers, avg_tasks_per_worker, static_cast(min_tasks_), + static_cast(max_tasks_), avg_wait_before * 1E6, + avg_elapsed * 1E6, task_len * 1E6, elapsed); + *this = CallerAccumulator(); + } + + // For the grand total, only print calls and elapsed because averaging the + // the other stats is not very useful. No need to reset because this is called + // on a temporary. + void PrintTotal() { + if (!Any()) return; + HWY_ASSERT(root_ <= calls_); + const double elapsed = + static_cast(elapsed_) / platform::InvariantTicksPerSecond(); + printf("TOTAL: %7.0f x run %6.2f s\n", static_cast(calls_), + elapsed); + } + + private: + int64_t calls_ = 0; + int64_t root_ = 0; + uint64_t workers_ = 0; + uint64_t min_tasks_ = ~uint64_t{0}; + uint64_t max_tasks_ = 0; + uint64_t tasks_ = 0; + // both are wall time for root Run, otherwise CPU time. + timer::Ticks wait_before_ = 0; + timer::Ticks elapsed_ = 0; +}; +static_assert(sizeof(CallerAccumulator) == 64, ""); + +class PerCluster { + public: + CallerAccumulator& Get(size_t caller_idx) { + HWY_DASSERT(caller_idx < kMaxCallers); + callers_.Set(caller_idx); + return accumulators_[caller_idx]; + } + + template + void ForeachCaller(Func&& func) { + callers_.Foreach([&](size_t caller_idx) { + func(caller_idx, accumulators_[caller_idx]); + }); + } + + // Returns indices (required for `StringTable::Name`) in descending order of + // elapsed time. 
+ std::vector Sorted() { + std::vector vec; + vec.reserve(kMaxCallers); + ForeachCaller([&](size_t caller_idx, CallerAccumulator&) { + vec.push_back(caller_idx); + }); + std::sort(vec.begin(), vec.end(), [&](size_t a, size_t b) { + return accumulators_[a] > accumulators_[b]; + }); + return vec; + } + + // Caller takes care of resetting `accumulators_`. + void ResetBits() { callers_ = hwy::BitSet(); } + + private: + CallerAccumulator accumulators_[kMaxCallers]; + hwy::BitSet callers_; +}; + +// Type-safe wrapper. +class Caller { + public: + Caller() : idx_(0) {} // `AddCaller` never returns 0. + explicit Caller(size_t idx) : idx_(idx) { HWY_DASSERT(idx < kMaxCallers); } + size_t Idx() const { return idx_; } + + private: + size_t idx_; +}; + +// Singleton, shared by all ThreadPool. +class Shared { + public: + static HWY_CONTRIB_DLLEXPORT Shared& Get(); // Thread-safe. + + Stopwatch MakeStopwatch() const { return Stopwatch(timer_); } + Stopwatch& LastRootEnd() { return last_root_end_; } + + // Thread-safe. Calls with the same `name` return the same `Caller`. + Caller AddCaller(const char* name) { return Caller(callers_.Add(name)); } + + PerCluster& Cluster(size_t cluster_idx) { + HWY_DASSERT(cluster_idx < kMaxClusters); + return per_cluster_[cluster_idx]; + } + + // Called from the main thread via `Profiler::PrintResults`. + void PrintAndReset() { + // Start counting pools (= one per cluster) invoked by each caller. + size_t active_clusters[kMaxCallers] = {}; + per_cluster_[0].ForeachCaller( + [&](size_t caller_idx, CallerAccumulator& acc) { + active_clusters[caller_idx] = acc.Any(); + }); + // Reduce per-cluster accumulators into the first cluster. 
+ for (size_t cluster_idx = 1; cluster_idx < kMaxClusters; ++cluster_idx) { + per_cluster_[cluster_idx].ForeachCaller( + [&](size_t caller_idx, CallerAccumulator& acc) { + active_clusters[caller_idx] += acc.Any(); + per_cluster_[0].Get(caller_idx).AddFrom(acc); + acc = CallerAccumulator(); + }); + per_cluster_[cluster_idx].ResetBits(); + } + + CallerAccumulator total; + for (size_t caller_idx : per_cluster_[0].Sorted()) { + CallerAccumulator& acc = per_cluster_[0].Get(caller_idx); + total.AddFrom(acc); // must be before PrintAndReset. + acc.PrintAndReset(callers_.Name(caller_idx), active_clusters[caller_idx]); + } + total.PrintTotal(); + per_cluster_[0].ResetBits(); + } + + private: + Shared() // called via Get(). + : last_root_end_(timer_), + send_config(callers_.Add("SendConfig")), + dtor(callers_.Add("PoolDtor")), + print_stats(callers_.Add("PrintStats")) { + Profiler::Get().AddFunc(this, [this]() { PrintAndReset(); }); + // Can skip `RemoveFunc` because the singleton never dies. + } + + const Timer timer_; + Stopwatch last_root_end_; + + PerCluster per_cluster_[kMaxClusters]; + StringTable callers_; + + public: + // Returned from `callers_.Add`: + Caller send_config; + Caller dtor; + Caller print_stats; +}; + +#else + +struct Stats { + void NotifyRunStatic(size_t, timer::Ticks) {} + void NotifyRunDynamic(size_t, size_t, size_t, timer::Ticks) {} + void NotifyThreadRun(size_t, timer::Ticks, size_t, timer::Ticks, + timer::Ticks) {} + void NotifyMainRun(size_t, timer::Ticks, timer::Ticks, timer::Ticks, + timer::Ticks) {} + void PrintAndReset(size_t, timer::Ticks) {} + void Reset(size_t = kMaxThreads) {} +}; + +struct Caller {}; + +class Shared { + public: + static HWY_CONTRIB_DLLEXPORT Shared& Get(); // Thread-safe. 
+ + Stopwatch MakeStopwatch() const { return Stopwatch(timer_); } + + Caller AddCaller(const char*) { return Caller(); } + + private: + Shared() {} + + const Timer timer_; + + public: + Caller send_config; + Caller dtor; + Caller print_stats; +}; + +#endif // PROFILER_ENABLED + +// Per-worker state used by both main and worker threads. `ThreadFunc` +// (threads) and `ThreadPool` (main) have a few additional members of their own. +class alignas(HWY_ALIGNMENT) Worker { // HWY_ALIGNMENT bytes + static constexpr size_t kMaxVictims = 4; + + static constexpr auto kAcq = std::memory_order_acquire; + static constexpr auto kRel = std::memory_order_release; + + bool OwnsGlobalIdx() const { +#if PROFILER_ENABLED + if (global_idx_ >= profiler::kMaxWorkers) { + HWY_WARN("Windows-only bug? global_idx %zu >= %zu.", global_idx_, + profiler::kMaxWorkers); + } +#endif // PROFILER_ENABLED + // Across-cluster pool owns all except the main thread, which is reserved by + // profiler.cc. + if (cluster_idx_ == kAllClusters) return global_idx_ != 0; + // Within-cluster pool owns all except *its* main thread, because that is + // owned by the across-cluster pool. + return worker_ != 0; + } + + public: + Worker(const size_t worker, const size_t num_threads, + const PoolWorkerMapping mapping, const Divisor64& div_workers, + const Stopwatch& stopwatch) + : workers_(this - worker), + worker_(worker), + num_threads_(num_threads), + stopwatch_(stopwatch), + // If `num_threads == 0`, we might be in an inner pool and must use + // the `global_idx` we are currently running on. + global_idx_(num_threads == 0 ? Profiler::GlobalIdx() : mapping(worker)), + cluster_idx_(mapping.ClusterIdx()) { + HWY_DASSERT(IsAligned(this, HWY_ALIGNMENT)); + HWY_DASSERT(worker <= num_threads); + const size_t num_workers = static_cast(div_workers.GetDivisor()); + num_victims_ = static_cast(HWY_MIN(kMaxVictims, num_workers)); + + // Increase gap between coprimes to reduce collisions. 
+ const uint32_t coprime = ShuffledIota::FindAnotherCoprime( + static_cast(num_workers), + static_cast((worker + 1) * 257 + worker * 13)); + const ShuffledIota shuffled_iota(coprime); + + // To simplify `WorkerRun`, this worker is the first to 'steal' from. + victims_[0] = static_cast(worker); + for (uint32_t i = 1; i < num_victims_; ++i) { + victims_[i] = shuffled_iota.Next(victims_[i - 1], div_workers); + HWY_DASSERT(victims_[i] != worker); + } + + HWY_IF_CONSTEXPR(PROFILER_ENABLED) { + if (HWY_LIKELY(OwnsGlobalIdx())) { + Profiler::Get().ReserveWorker(global_idx_); + } + } + } + + ~Worker() { + HWY_IF_CONSTEXPR(PROFILER_ENABLED) { + if (HWY_LIKELY(OwnsGlobalIdx())) { + Profiler::Get().FreeWorker(global_idx_); + } + } + } + + // Placement-newed by `WorkerLifecycle`, we do not expect any copying. + Worker(const Worker&) = delete; + Worker& operator=(const Worker&) = delete; + + size_t Index() const { return worker_; } + // For work stealing. + Worker* AllWorkers() { return workers_; } + const Worker* AllWorkers() const { return workers_; } + size_t NumThreads() const { return num_threads_; } + + size_t GlobalIdx() const { return global_idx_; } + size_t ClusterIdx() const { return cluster_idx_; } + + void SetStartTime() { stopwatch_.Reset(); } + timer::Ticks ElapsedTime() { return stopwatch_.Elapsed(); } + + // ------------------------ Per-worker storage for `SendConfig` + + Config NextConfig() const { return next_config_; } + // Called during `SendConfig` by workers and now also the main thread. This + // avoids a separate `ThreadPool` member which risks going out of sync. + void SetNextConfig(Config copy) { next_config_ = copy; } + + Exit GetExit() const { return exit_; } + void SetExit(Exit exit) { exit_ = exit; } + + uint32_t WorkerEpoch() const { return worker_epoch_; } + uint32_t AdvanceWorkerEpoch() { return ++worker_epoch_; } + + // ------------------------ Task assignment + + // Called from the main thread. 
+ void SetRange(const uint64_t begin, const uint64_t end) { + my_begin_.store(begin, kRel); + my_end_.store(end, kRel); + } + + uint64_t MyEnd() const { return my_end_.load(kAcq); } + + Span Victims() const { + return hwy::Span(victims_.data(), + static_cast(num_victims_)); + } + + // Returns the next task to execute. If >= MyEnd(), it must be skipped. + uint64_t WorkerReserveTask() { + // TODO(janwas): replace with cooperative work-stealing. + return my_begin_.fetch_add(1, std::memory_order_relaxed); + } + + // ------------------------ Waiter: Threads wait for tasks + + // WARNING: some `Wait*` do not set this for all Worker instances. For + // example, `WaitType::kBlock` only uses the first worker's `Waiter` because + // one futex can wake multiple waiters. Hence we never load this directly + // without going through `Wait*` policy classes, and must ensure all threads + // use the same wait mode. + + const std::atomic& Waiter() const { return wait_epoch_; } + std::atomic& MutableWaiter() { return wait_epoch_; } // futex + void StoreWaiter(uint32_t epoch) { wait_epoch_.store(epoch, kRel); } + + // ------------------------ Barrier: Main thread waits for workers + + // For use by `HasReached` and `UntilReached`. + const std::atomic& Barrier() const { return barrier_epoch_; } + // Setting to `epoch` signals that the worker has reached the barrier. + void StoreBarrier(uint32_t epoch) { barrier_epoch_.store(epoch, kRel); } + + private: + // Set by `SetRange`: + std::atomic my_begin_; + std::atomic my_end_; + + Worker* const workers_; + const size_t worker_; + const size_t num_threads_; + + Stopwatch stopwatch_; // Reset by `SetStartTime`. + const size_t global_idx_; + const size_t cluster_idx_; + + // Use u32 to match futex.h. These must start at the initial value of + // `worker_epoch_`. 
+ std::atomic wait_epoch_{1}; + std::atomic barrier_epoch_{1}; + + uint32_t num_victims_; // <= kPoolMaxVictims + std::array victims_; + + // Written and read by the same thread, hence not atomic. + Config next_config_; + Exit exit_ = Exit::kNone; + // thread_pool_test requires nonzero epoch. + uint32_t worker_epoch_ = 1; + + HWY_MEMBER_VAR_MAYBE_UNUSED uint8_t + padding_[HWY_ALIGNMENT - 56 - 6 * sizeof(void*) - sizeof(victims_)]; +}; +static_assert(sizeof(Worker) == HWY_ALIGNMENT, ""); + +// Creates/destroys `Worker` using preallocated storage. See comment at +// `ThreadPool::worker_bytes_` for why we do not dynamically allocate. +class WorkerLifecycle { // 0 bytes + public: + // Placement new for `Worker` into `storage` because its ctor requires + // the worker index. Returns array of all workers. + static Worker* Init(uint8_t* storage, size_t num_threads, + PoolWorkerMapping mapping, const Divisor64& div_workers, + Shared& shared) { + Worker* workers = new (storage) + Worker(0, num_threads, mapping, div_workers, shared.MakeStopwatch()); + for (size_t worker = 1; worker <= num_threads; ++worker) { + new (Addr(storage, worker)) Worker(worker, num_threads, mapping, + div_workers, shared.MakeStopwatch()); + // Ensure pointer arithmetic is the same (will be used in Destroy). + HWY_DASSERT(reinterpret_cast(workers + worker) == + reinterpret_cast(Addr(storage, worker))); + } + + // Publish non-atomic stores in `workers`. + std::atomic_thread_fence(std::memory_order_release); + + return workers; + } + + static void Destroy(Worker* workers, size_t num_threads) { + for (size_t worker = 0; worker <= num_threads; ++worker) { + workers[worker].~Worker(); + } + } + + private: + static uint8_t* Addr(uint8_t* storage, size_t worker) { + return storage + worker * sizeof(Worker); + } +}; + +// Stores arguments to `Run`: the function and range of task indices. Set by +// the main thread, read by workers including the main thread. 
+class Tasks { + static constexpr auto kAcq = std::memory_order_acquire; + + // Signature of the (internal) function called from workers(s) for each + // `task` in the [`begin`, `end`) passed to Run(). Closures (lambdas) do not + // receive the first argument, which points to the lambda object. + typedef void (*RunFunc)(const void* opaque, uint64_t task, size_t worker); + + public: + Tasks() { HWY_DASSERT(IsAligned(this, 8)); } + + template + void Set(uint64_t begin, uint64_t end, const Closure& closure) { + constexpr auto kRel = std::memory_order_release; + // `TestTasks` and `SetWaitMode` call this with `begin == end`. + HWY_DASSERT(begin <= end); + begin_.store(begin, kRel); + end_.store(end, kRel); + func_.store(static_cast(&CallClosure), kRel); + opaque_.store(reinterpret_cast(&closure), kRel); + } + + // Assigns workers their share of `[begin, end)`. Called from the main + // thread; workers are initializing or waiting for a command. + // Negligible CPU time. + static void DivideRangeAmongWorkers(const uint64_t begin, const uint64_t end, + const Divisor64& div_workers, + Worker* workers) { + const size_t num_workers = static_cast(div_workers.GetDivisor()); + HWY_DASSERT(num_workers > 1); // Else Run() runs on the main thread. + HWY_DASSERT(begin <= end); + const size_t num_tasks = static_cast(end - begin); + + // Assigning all remainders to the last worker causes imbalance. We instead + // give one more to each worker whose index is less. This may be zero when + // called from `TestTasks`. 
+ const size_t min_tasks = static_cast(div_workers.Divide(num_tasks)); + const size_t remainder = + static_cast(div_workers.Remainder(num_tasks)); + + uint64_t my_begin = begin; + for (size_t worker = 0; worker < num_workers; ++worker) { + const uint64_t my_end = my_begin + min_tasks + (worker < remainder); + workers[worker].SetRange(my_begin, my_end); + my_begin = my_end; + } + HWY_DASSERT(my_begin == end); + } + + // Runs the worker's assigned range of tasks, plus work stealing if needed. + void WorkerRun(Worker* worker, const Shared& shared, Stats& stats) const { + if (NumTasks() > worker->NumThreads() + 1) { + WorkerRunDynamic(worker, shared, stats); + } else { + WorkerRunStatic(worker, shared, stats); + } + } + + private: + // Special case for <= 1 task per worker, where stealing is unnecessary. + void WorkerRunStatic(Worker* worker, const Shared& shared, + Stats& stats) const { + const uint64_t begin = begin_.load(kAcq); + const uint64_t end = end_.load(kAcq); + HWY_DASSERT(begin <= end); + const size_t index = worker->Index(); + + const uint64_t task = begin + index; + // We might still have more workers than tasks, so check first. + if (HWY_LIKELY(task < end)) { + const void* opaque = Opaque(); + const RunFunc func = Func(); + Stopwatch stopwatch = shared.MakeStopwatch(); + func(opaque, task, index); + stats.NotifyRunStatic(index, stopwatch.Elapsed()); + } + } + + // Must be called for each `worker` in [0, num_workers). + // + // A prior version of this code attempted to assign only as much work as a + // worker will actually use. As with OpenMP's 'guided' strategy, we assigned + // remaining/(k*num_threads) in each iteration. Although the worst-case + // imbalance is bounded, this required several rounds of work allocation, and + // the atomic counter did not scale to > 30 threads. + // + // We now use work stealing instead, where already-finished workers look for + // and perform work from others, as if they were that worker. 
This deals with + // imbalances as they arise, but care is required to reduce contention. We + // randomize the order in which threads choose victims to steal from. + void WorkerRunDynamic(Worker* worker, const Shared& shared, + Stats& stats) const { + Worker* workers = worker->AllWorkers(); + const size_t index = worker->Index(); + const RunFunc func = Func(); + const void* opaque = Opaque(); + + size_t sum_tasks = 0; + size_t sum_stolen = 0; + timer::Ticks sum_d_func = 0; + // For each worker in random order, starting with our own, attempt to do + // all their work. + for (uint32_t victim : worker->Victims()) { + Worker* other_worker = workers + victim; + + // Until all of other_worker's work is done: + const uint64_t other_end = other_worker->MyEnd(); + for (;;) { + // The worker that first sets `task` to `other_end` exits this loop. + // After that, `task` can be incremented up to `num_workers - 1` times, + // once per other worker. + const uint64_t task = other_worker->WorkerReserveTask(); + if (HWY_UNLIKELY(task >= other_end)) { + hwy::Pause(); // Reduce coherency traffic while stealing. + break; + } + Stopwatch stopwatch = shared.MakeStopwatch(); + // Pass the index we are actually running on; this is important + // because it is the TLS index for user code. + func(opaque, task, index); + sum_tasks++; + sum_stolen += worker != other_worker; + sum_d_func += stopwatch.Elapsed(); + } + } + stats.NotifyRunDynamic(index, sum_tasks, sum_stolen, sum_d_func); + } + + size_t NumTasks() const { + return static_cast(end_.load(kAcq) - begin_.load(kAcq)); + } + + const void* Opaque() const { return opaque_.load(kAcq); } + RunFunc Func() const { return func_.load(kAcq); } + + // Calls closure(task, worker). Signature must match `RunFunc`. 
+ template + static void CallClosure(const void* opaque, uint64_t task, size_t worker) { + (*reinterpret_cast(opaque))(task, worker); + } + + std::atomic begin_; + std::atomic end_; + std::atomic func_; + std::atomic opaque_; +}; +static_assert(sizeof(Tasks) == 16 + 2 * sizeof(void*), ""); + +// ------------------------------ Threads wait, main wakes them + +// Considerations: +// - uint32_t storage per `Worker` so we can use `futex.h`. +// - avoid atomic read-modify-write. These are implemented on x86 using a LOCK +// prefix, which interferes with other cores' cache-coherency transactions +// and drains our core's store buffer. We use only store-release and +// load-acquire. Although expressed using `std::atomic`, these are normal +// loads/stores in the strong x86 memory model. +// - prefer to avoid resetting the state. "Sense-reversing" (flipping a flag) +// would work, but we we prefer an 'epoch' counter because it is more useful +// and easier to understand/debug, and as fast. + +// Both the main thread and each worker maintain their own counter, which are +// implicitly synchronized by the barrier. To wake, the main thread does a +// store-release, and each worker does a load-acquire. The policy classes differ +// in whether they block or spin (with pause/monitor to reduce power), and +// whether workers check their own counter or a shared one. +// +// All methods are const because they only use storage in `Worker`, and we +// prefer to pass const-references to empty classes to enable type deduction. + +// Futex: blocking reduces apparent CPU usage, but has higher wake latency. +struct WaitBlock { + // Wakes all workers by storing the current `epoch`. + void WakeWorkers(Worker* workers, const uint32_t epoch) const { + HWY_DASSERT(epoch != 0); + workers[1].StoreWaiter(epoch); + WakeAll(workers[1].MutableWaiter()); // futex: expensive syscall + } + + // Waits until `WakeWorkers(_, epoch)` has been called. 
+ template + size_t UntilWoken(const Worker& worker, const Spin& /*spin*/) const { + HWY_DASSERT(worker.Index() != 0); // main is 0 + const uint32_t epoch = worker.WorkerEpoch(); + const Worker* workers = worker.AllWorkers(); + BlockUntilDifferent(epoch - 1, workers[1].Waiter()); + return 1; // iterations + } +}; + +// Single u32: single store by the main thread. All worker threads poll this +// one cache line and thus have it in a shared state, which means the store +// will invalidate each of them, leading to more transactions than SpinSeparate. +struct WaitSpin1 { + void WakeWorkers(Worker* workers, const uint32_t epoch) const { + workers[1].StoreWaiter(epoch); + } + + // Returns the number of spin-wait iterations. + template + size_t UntilWoken(const Worker& worker, const Spin& spin) const { + HWY_DASSERT(worker.Index() != 0); // main is 0 + const Worker* workers = worker.AllWorkers(); + const uint32_t epoch = worker.WorkerEpoch(); + return spin.UntilEqual(epoch, workers[1].Waiter()); + } +}; + +// Separate u32 per thread: more stores for the main thread, but each worker +// only polls its own cache line, leading to fewer cache-coherency transactions. +struct WaitSpinSeparate { + void WakeWorkers(Worker* workers, const uint32_t epoch) const { + for (size_t thread = 0; thread < workers->NumThreads(); ++thread) { + workers[1 + thread].StoreWaiter(epoch); + } + } + + template + size_t UntilWoken(const Worker& worker, const Spin& spin) const { + HWY_DASSERT(worker.Index() != 0); // main is 0 + const uint32_t epoch = worker.WorkerEpoch(); + return spin.UntilEqual(epoch, worker.Waiter()); + } +}; + +// Calls unrolled code selected by all config enums. +template +HWY_INLINE void CallWithConfig(const Config& config, Func&& func, + Args&&... 
args) { + switch (config.wait_type) { + case WaitType::kBlock: + return func(SpinPause(), WaitBlock(), std::forward(args)...); + case WaitType::kSpin1: + return CallWithSpin(config.spin_type, func, WaitSpin1(), + std::forward(args)...); + case WaitType::kSpinSeparate: + return CallWithSpin(config.spin_type, func, WaitSpinSeparate(), + std::forward(args)...); + case WaitType::kSentinel: + HWY_UNREACHABLE; + } +} + +// ------------------------------ Barrier: Main thread waits for workers + +// Similar to `WaitSpinSeparate`, a store-release of the same local epoch +// counter serves as a "have arrived" flag that does not require resetting. +class Barrier { + public: + void WorkerReached(Worker& worker, uint32_t epoch) const { + HWY_DASSERT(worker.Index() != 0); // main is 0 + worker.StoreBarrier(epoch); + } + + // Returns true if `worker` (can be the main thread) reached the barrier. + bool HasReached(const Worker* worker, uint32_t epoch) const { + const uint32_t barrier = worker->Barrier().load(std::memory_order_acquire); + HWY_DASSERT(barrier <= epoch); + return barrier == epoch; + } + + // Main thread loops over each worker. A "group of 2 or 4" barrier was not + // competitive on Skylake, Granite Rapids and Zen5. + template + void UntilReached(size_t num_threads, Worker* workers, const Spin& spin, + uint32_t epoch) const { + workers[0].StoreBarrier(epoch); // for main thread HasReached. + + for (size_t i = 0; i < num_threads; ++i) { + // TODO: log number of spin-wait iterations. + (void)spin.UntilEqual(epoch, workers[1 + i].Barrier()); + } + } +}; + +// In debug builds, detects when functions are re-entered. 
+class BusyFlag { + public: + void Set() { HWY_DASSERT(!busy_.test_and_set()); } + void Clear() { HWY_IF_CONSTEXPR(HWY_IS_DEBUG_BUILD) busy_.clear(); } + + private: + std::atomic_flag busy_ = ATOMIC_FLAG_INIT; +}; + +} // namespace pool + +// Highly efficient parallel-for, intended for workloads with thousands of +// fork-join regions which consist of calling tasks[t](i) for a few hundred i, +// using dozens of threads. +// +// To reduce scheduling overhead, we assume that tasks are statically known and +// that threads do not schedule new work themselves. This allows us to avoid +// queues and only store a counter plus the current task. The latter is a +// pointer to a lambda function, without the allocation/indirection required for +// `std::function`. +// +// To reduce fork/join latency, we choose an efficient barrier, optionally +// enable spin-waits via `SetWaitMode`, and avoid any mutex/lock. We largely +// even avoid atomic RMW operations (LOCK prefix): currently for the wait and +// barrier, in future hopefully also for work stealing. +// +// To eliminate false sharing and enable reasoning about cache line traffic, the +// class is aligned and holds all worker state. +// +// For load-balancing, we use work stealing in random order. +class alignas(HWY_ALIGNMENT) ThreadPool { + // Used to initialize `num_threads_` from the ctor argument. + static size_t ClampedNumThreads(size_t num_threads) { + // Upper bound is required for `worker_bytes_`. + if (HWY_UNLIKELY(num_threads > pool::kMaxThreads)) { + HWY_WARN("ThreadPool: clamping num_threads %zu to %zu.", num_threads, + pool::kMaxThreads); + num_threads = pool::kMaxThreads; + } + return num_threads; + } + + public: + // This typically includes hyperthreads, hence it is a loose upper bound. + // -1 because these are in addition to the main thread. 
+ static size_t MaxThreads() { + LogicalProcessorSet lps; + // This is OS dependent, but more accurate if available because it takes + // into account restrictions set by cgroups or numactl/taskset. + if (GetThreadAffinity(lps)) { + return lps.Count() - 1; + } + return static_cast(std::thread::hardware_concurrency() - 1); + } + + // `num_threads` is the number of *additional* threads to spawn, which should + // not exceed `MaxThreads()`. Note that the main thread also performs work. + // `mapping` indicates how to map local worker_idx to global. + ThreadPool(size_t num_threads, + PoolWorkerMapping mapping = PoolWorkerMapping()) + : num_threads_(ClampedNumThreads(num_threads)), + div_workers_(1 + num_threads_), + shared_(pool::Shared::Get()), // on first call, calls ReserveWorker(0)! + workers_(pool::WorkerLifecycle::Init(worker_bytes_, num_threads_, + mapping, div_workers_, shared_)) { + // Leaves the default wait mode as `kBlock`, which means futex, because + // spinning only makes sense when threads are pinned and wake latency is + // important, so it must explicitly be requested by calling `SetWaitMode`. + for (PoolWaitMode mode : {PoolWaitMode::kSpin, PoolWaitMode::kBlock}) { + wait_mode_ = mode; // for AutoTuner + AutoTuner().SetCandidates( + pool::Config::AllCandidates(mode)); + } + + // Skip empty pools because they do not update stats anyway. + if (num_threads_ > 0) { + Profiler::Get().AddFunc(this, [this]() { PrintStats(); }); + } + + threads_.reserve(num_threads_); + for (size_t thread = 0; thread < num_threads_; ++thread) { + threads_.emplace_back( + ThreadFunc(workers_[1 + thread], tasks_, shared_, stats_)); + } + + // Threads' `Config` defaults to spinning. Change to `kBlock` (see above). + // This also ensures all threads have started before we return, so that + // startup latency is billed to the ctor, not the first `Run`. + SendConfig(AutoTuner().Candidates()[0]); + } + + // If we created threads, waits for them all to exit. 
+ ~ThreadPool() { + // There is no portable way to request threads to exit like `ExitThread` on + // Windows, otherwise we could call that from `Run`. Instead, we must cause + // the thread to wake up and exit. We can just use `Run`. + (void)RunWithoutAutotune( + 0, NumWorkers(), shared_.dtor, + [this](HWY_MAYBE_UNUSED uint64_t task, size_t worker) { + HWY_DASSERT(task == worker); + workers_[worker].SetExit(Exit::kThread); + }); + + for (std::thread& thread : threads_) { + HWY_DASSERT(thread.joinable()); + thread.join(); + } + + if (num_threads_ > 0) { + Profiler::Get().RemoveFunc(this); + } + + pool::WorkerLifecycle::Destroy(workers_, num_threads_); + } + + ThreadPool(const ThreadPool&) = delete; + ThreadPool& operator&(const ThreadPool&) = delete; + + // Returns number of Worker, i.e., one more than the largest `worker` + // argument. Useful for callers that want to allocate thread-local storage. + size_t NumWorkers() const { + return static_cast(div_workers_.GetDivisor()); + } + + // `mode` defaults to `kBlock`, which means futex. Switching to `kSpin` + // reduces fork-join overhead especially when there are many calls to `Run`, + // but wastes power when waiting over long intervals. Must not be called + // concurrently with any `Run`, because this uses the same waiter/barrier. + void SetWaitMode(PoolWaitMode mode) { + wait_mode_ = mode; + SendConfig(AutoTuneComplete() ? *AutoTuner().Best() + : AutoTuner().NextConfig()); + } + + // For printing which is in use. + pool::Config config() const { return workers_[0].NextConfig(); } + + bool AutoTuneComplete() const { return AutoTuner().Best(); } + Span AutoTuneCosts() { return AutoTuner().Costs(); } + + static pool::Caller AddCaller(const char* name) { + return pool::Shared::Get().AddCaller(name); + } + + // parallel-for: Runs `closure(task, worker)` on workers for every `task` in + // `[begin, end)`. 
Note that the unit of work should be large enough to + // amortize the function call overhead, but small enough that each worker + // processes a few tasks. Thus each `task` is usually a loop. + // + // Not thread-safe - concurrent parallel-for in the same `ThreadPool` are + // forbidden unless `NumWorkers() == 1` or `end <= begin + 1`. + template + void Run(uint64_t begin, uint64_t end, pool::Caller caller, + const Closure& closure) { + AutoTuneT& auto_tuner = AutoTuner(); + // Already finished tuning: run without time measurement. + if (HWY_LIKELY(auto_tuner.Best())) { + // Don't care whether threads ran, we are done either way. + (void)RunWithoutAutotune(begin, end, caller, closure); + return; + } + + // Not yet finished: measure time and notify autotuner. + Stopwatch stopwatch(shared_.MakeStopwatch()); + // Skip update if threads didn't actually run. + if (!RunWithoutAutotune(begin, end, caller, closure)) return; + auto_tuner.NotifyCost(stopwatch.Elapsed()); + + pool::Config next = auto_tuner.NextConfig(); // may be overwritten below + if (auto_tuner.Best()) { // just finished + next = *auto_tuner.Best(); + HWY_IF_CONSTEXPR(pool::kVerbosity >= 1) { + const size_t idx_best = static_cast( + auto_tuner.Best() - auto_tuner.Candidates().data()); + HWY_DASSERT(idx_best < auto_tuner.Costs().size()); + auto& AT = auto_tuner.Costs()[idx_best]; + const double best_cost = AT.EstimateCost(); + HWY_DASSERT(best_cost > 0.0); // will divide by this below + + Stats s_ratio; + for (size_t i = 0; i < auto_tuner.Costs().size(); ++i) { + if (i == idx_best) continue; + const double cost = auto_tuner.Costs()[i].EstimateCost(); + s_ratio.Notify(static_cast(cost / best_cost)); + } + + fprintf(stderr, + "Pool %3zu: %s %8.0f +/- %6.0f. 
Gain %.2fx [%.2fx, %.2fx]\n", + NumWorkers(), auto_tuner.Best()->ToString().c_str(), best_cost, + AT.Stddev(), s_ratio.GeometricMean(), + static_cast(s_ratio.Min()), + static_cast(s_ratio.Max())); + } + } + SendConfig(next); + } + + // Backward-compatible version without Caller. + template + void Run(uint64_t begin, uint64_t end, const Closure& closure) { + Run(begin, end, pool::Caller(), closure); + } + + private: + // Called via `CallWithConfig`. + struct MainWakeAndBarrier { + template + void operator()(const Spin& spin, const Wait& wait, pool::Worker& main, + const pool::Tasks& tasks, const pool::Shared& shared, + pool::Stats& stats) const { + const pool::Barrier barrier; + pool::Worker* workers = main.AllWorkers(); + HWY_DASSERT(&main == main.AllWorkers()); // main is first. + const size_t num_threads = main.NumThreads(); + const uint32_t epoch = main.AdvanceWorkerEpoch(); + + HWY_IF_CONSTEXPR(HWY_IS_DEBUG_BUILD) { + for (size_t i = 0; i < 1 + num_threads; ++i) { + HWY_DASSERT(!barrier.HasReached(workers + i, epoch)); + } + } + + Stopwatch stopwatch(shared.MakeStopwatch()); + const timer::Ticks t_before_wake = stopwatch.Origin(); + wait.WakeWorkers(workers, epoch); + const timer::Ticks d_wake = stopwatch.Elapsed(); + + // Also perform work on the main thread before the barrier. + tasks.WorkerRun(&main, shared, stats); + const timer::Ticks d_run = stopwatch.Elapsed(); + + // Spin-waits until all worker *threads* (not `main`, because it already + // knows it is here) called `WorkerReached`. + barrier.UntilReached(num_threads, workers, spin, epoch); + const timer::Ticks d_barrier = stopwatch.Elapsed(); + stats.NotifyMainRun(main.NumThreads(), t_before_wake, d_wake, d_run, + d_barrier); + + HWY_IF_CONSTEXPR(HWY_IS_DEBUG_BUILD) { + for (size_t i = 0; i < 1 + num_threads; ++i) { + HWY_DASSERT(barrier.HasReached(workers + i, epoch)); + } + } + + // Threads are or will soon be waiting `UntilWoken`, which serves as the + // 'release' phase of the barrier. 
+ } + }; + + // Called by `std::thread`. Could also be a lambda. + class ThreadFunc { + // Functor called by `CallWithConfig`. Loops until `SendConfig` changes the + // Spin or Wait policy or the pool is destroyed. + struct WorkerLoop { + template + void operator()(const Spin& spin, const Wait& wait, pool::Worker& worker, + pool::Tasks& tasks, const pool::Shared& shared, + pool::Stats& stats) const { + do { + // Main worker also calls this, so their epochs match. + const uint32_t epoch = worker.AdvanceWorkerEpoch(); + + Stopwatch stopwatch(shared.MakeStopwatch()); + + const size_t wait_reps = wait.UntilWoken(worker, spin); + const timer::Ticks d_wait = stopwatch.Elapsed(); + const timer::Ticks t_before_run = stopwatch.Origin(); + + tasks.WorkerRun(&worker, shared, stats); + const timer::Ticks d_run = stopwatch.Elapsed(); + stats.NotifyThreadRun(worker.Index(), d_wait, wait_reps, t_before_run, + d_run); + + // Notify barrier after `WorkerRun`. Note that we cannot send an + // after-barrier timestamp, see above. + pool::Barrier().WorkerReached(worker, epoch); + // Check after `WorkerReached`, otherwise the main thread deadlocks. + } while (worker.GetExit() == Exit::kNone); + } + }; + + public: + ThreadFunc(pool::Worker& worker, pool::Tasks& tasks, + const pool::Shared& shared, pool::Stats& stats) + : worker_(worker), tasks_(tasks), shared_(shared), stats_(stats) {} + + void operator()() { + // Ensure main thread's writes are visible (synchronizes with fence in + // `WorkerLifecycle::Init`). + std::atomic_thread_fence(std::memory_order_acquire); + + HWY_DASSERT(worker_.Index() != 0); // main is 0 + SetThreadName("worker%03zu", static_cast(worker_.Index() - 1)); + + worker_.SetStartTime(); + Profiler& profiler = Profiler::Get(); + profiler.SetGlobalIdx(worker_.GlobalIdx()); + // No Zone here because it would only exit after `GetExit`, which may be + // after the main thread's `PROFILER_END_ROOT_RUN`, and thus too late to + // be counted. 
Instead, `ProfilerFunc` records the elapsed time. + + // Loop termination via `GetExit` is triggered by `~ThreadPool`. + for (;;) { + // Uses the initial config, or the last one set during WorkerRun. + CallWithConfig(worker_.NextConfig(), WorkerLoop(), worker_, tasks_, + shared_, stats_); + + // Exit or reset the flag and return to WorkerLoop with a new config. + if (worker_.GetExit() == Exit::kThread) break; + worker_.SetExit(Exit::kNone); + } + + profiler.SetGlobalIdx(~size_t{0}); + + // Defer `FreeWorker` until workers are destroyed to ensure the profiler + // is not still using the worker. + } + + private: + pool::Worker& worker_; + pool::Tasks& tasks_; + const pool::Shared& shared_; + pool::Stats& stats_; + }; + + void PrintStats() { + // Total run time from all non-main threads. + std::atomic sum_thread_elapsed{0}; + (void)RunWithoutAutotune( + 0, NumWorkers(), shared_.print_stats, + [this, &sum_thread_elapsed](HWY_MAYBE_UNUSED uint64_t task, + size_t worker) { + HWY_DASSERT(task == worker); + // Skip any main thread(s) because they did not init the stopwatch. + if (worker != 0) { + sum_thread_elapsed.fetch_add(workers_[worker].ElapsedTime()); + } + }); + const timer::Ticks thread_total = + sum_thread_elapsed.load(std::memory_order_acquire); + stats_.PrintAndReset(num_threads_, thread_total); + } + + // Returns whether threads were used. If not, there is no need to update + // the autotuner config. + template + bool RunWithoutAutotune(uint64_t begin, uint64_t end, pool::Caller caller, + const Closure& closure) { + pool::Worker& main = workers_[0]; + + const size_t num_tasks = static_cast(end - begin); + const size_t num_workers = NumWorkers(); + + // If zero or one task, or no extra threads, run on the main thread without + // setting any member variables, because we may be re-entering Run. 
+ if (HWY_UNLIKELY(num_tasks <= 1 || num_workers == 1)) { + for (uint64_t task = begin; task < end; ++task) { + closure(task, /*worker=*/0); + } + return false; + } + + busy_.Set(); + +#if PROFILER_ENABLED + const bool is_root = PROFILER_IS_ROOT_RUN(); + Stopwatch stopwatch(shared_.MakeStopwatch()); + const timer::Ticks wait_before = + is_root ? shared_.LastRootEnd().Elapsed() : 0; +#endif + + tasks_.Set(begin, end, closure); + + // More than one task per worker: use work stealing. + if (HWY_LIKELY(num_tasks > num_workers)) { + pool::Tasks::DivideRangeAmongWorkers(begin, end, div_workers_, workers_); + } + + // Runs `MainWakeAndBarrier` with the first worker slot. + CallWithConfig(config(), MainWakeAndBarrier(), main, tasks_, shared_, + stats_); + +#if PROFILER_ENABLED + pool::CallerAccumulator& acc = + shared_.Cluster(main.ClusterIdx()).Get(caller.Idx()); + acc.Add(num_tasks, num_workers, is_root, wait_before, stopwatch.Elapsed()); + if (is_root) { + PROFILER_END_ROOT_RUN(); + shared_.LastRootEnd().Reset(); + } +#else + (void)caller; +#endif + + busy_.Clear(); + return true; + } + + // Sends `next_config` to workers: + // - Main wakes threads using the current config. + // - Threads copy `next_config` into their `Worker` during `WorkerRun`. + // - Threads notify the (same) barrier and already wait for the next wake + // using `next_config`. + HWY_NOINLINE void SendConfig(pool::Config next_config) { + (void)RunWithoutAutotune( + 0, NumWorkers(), shared_.send_config, + [this, next_config](HWY_MAYBE_UNUSED uint64_t task, size_t worker) { + HWY_DASSERT(task == worker); // one task per worker + workers_[worker].SetNextConfig(next_config); + workers_[worker].SetExit(Exit::kLoop); + }); + + // All have woken and are, or will be, waiting per `next_config`. Now we + // can entirely switch the main thread's config for the next wake. 
+ workers_[0].SetNextConfig(next_config); + } + + using AutoTuneT = AutoTune; + AutoTuneT& AutoTuner() { + static_assert(static_cast(PoolWaitMode::kBlock) == 1, ""); + return auto_tune_[static_cast(wait_mode_) - 1]; + } + const AutoTuneT& AutoTuner() const { + return auto_tune_[static_cast(wait_mode_) - 1]; + } + + const size_t num_threads_; // not including main thread + const Divisor64 div_workers_; + pool::Shared& shared_; + pool::Worker* const workers_; // points into `worker_bytes_` + + alignas(HWY_ALIGNMENT) pool::Stats stats_; + + // This is written by the main thread and read by workers, via reference + // passed to `ThreadFunc`. Padding ensures that the workers' cache lines are + // not unnecessarily invalidated when the main thread writes other members. + alignas(HWY_ALIGNMENT) pool::Tasks tasks_; + HWY_MEMBER_VAR_MAYBE_UNUSED char + padding_[HWY_ALIGNMENT - sizeof(pool::Tasks)]; + + pool::BusyFlag busy_; + + // Unmodified after ctor, but cannot be const because we call thread::join(). + std::vector threads_; + + PoolWaitMode wait_mode_; + AutoTuneT auto_tune_[2]; // accessed via `AutoTuner` + + // Last because it is large. Store inside `ThreadPool` so that callers can + // bind it to the NUMA node's memory. Not stored inside `WorkerLifecycle` + // because that class would be initialized after `workers_`. + alignas(HWY_ALIGNMENT) uint8_t + worker_bytes_[sizeof(pool::Worker) * (pool::kMaxThreads + 1)]; +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_THREAD_POOL_THREAD_POOL_H_ diff --git a/lib/highway/hwy/contrib/thread_pool/topology.cc b/lib/highway/hwy/contrib/thread_pool/topology.cc new file mode 100644 index 00000000000..d73170507f2 --- /dev/null +++ b/lib/highway/hwy/contrib/thread_pool/topology.cc @@ -0,0 +1,1280 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/contrib/thread_pool/topology.h" + +#include // isspace +#include +#include +#include +#include // strchr + +#include +#include +#include +#include + +#include "hwy/base.h" // HWY_OS_WIN, HWY_WARN + +#if HWY_OS_APPLE +#include + +#include "hwy/aligned_allocator.h" // HWY_ALIGNMENT +#endif + +#if HWY_OS_WIN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef _WIN32_WINNT +#define _WIN32_WINNT 0x0601 // Windows 7 / Server 2008 +#endif +#include +#endif // HWY_OS_WIN + +#if HWY_OS_LINUX || HWY_OS_FREEBSD +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif +#include +#include +#include +#include +#include +#include +#include +#include // sysconf +#endif // HWY_OS_LINUX || HWY_OS_FREEBSD + +#if HWY_OS_FREEBSD +#include +// After param.h / types.h. +#include +#endif // HWY_OS_FREEBSD + +#if HWY_ARCH_WASM +#include +#endif + +namespace hwy { + +HWY_CONTRIB_DLLEXPORT bool HaveThreadingSupport() { +#if HWY_ARCH_WASM + return emscripten_has_threading_support() != 0; +#else + return true; +#endif +} + +namespace { + +// Returns `whole / part`, with a check that `part` evenly divides `whole`, +// which implies the result is exact. 
+HWY_MAYBE_UNUSED size_t DivByFactor(size_t whole, size_t part) { + HWY_ASSERT(part != 0); + const size_t div = whole / part; + const size_t mul = div * part; + if (mul != whole) { + HWY_ABORT("%zu / %zu = %zu; *%zu = %zu\n", whole, part, div, part, mul); + } + return div; +} + +#if HWY_OS_WIN + +using SLPI = SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX; + +template +bool ForEachSLPI(LOGICAL_PROCESSOR_RELATIONSHIP rel, Func&& func) { + // Get required buffer size. + DWORD buf_bytes = 0; + HWY_ASSERT(!GetLogicalProcessorInformationEx(rel, nullptr, &buf_bytes)); + // Observed when `rel` is not supported: + if (HWY_UNLIKELY(buf_bytes == 0 && GetLastError() == ERROR_GEN_FAILURE)) { + if (rel != RelationNumaNodeEx && rel != RelationProcessorDie) { + HWY_WARN("Unexpected err %lx for GLPI relationship %d\n", GetLastError(), + static_cast(rel)); + } + return false; + } + HWY_ASSERT(GetLastError() == ERROR_INSUFFICIENT_BUFFER); + // Note: `buf_bytes` may be less than `sizeof(SLPI)`, which has padding. + // `calloc` zero-initializes the `Reserved` field, part of which has been + // repurposed into `GroupCount` in SDKs, 10.0.22000.0 or possibly earlier. + uint8_t* buf = static_cast(calloc(1, buf_bytes)); + HWY_ASSERT(buf); + + // Fill the buffer. + SLPI* info = reinterpret_cast(buf); + if (HWY_UNLIKELY(!GetLogicalProcessorInformationEx(rel, info, &buf_bytes))) { + free(buf); + return false; + } + + // Iterate over each SLPI. `sizeof(SLPI)` is unreliable, see above. 
+ uint8_t* pos = buf; + while (pos < buf + buf_bytes) { + info = reinterpret_cast(pos); + HWY_ASSERT(rel == RelationAll || info->Relationship == rel); + func(*info); + pos += info->Size; + } + if (pos != buf + buf_bytes) { + HWY_WARN("unexpected pos %p, end %p, buf_bytes %lu, sizeof(SLPI) %zu\n", + pos, buf + buf_bytes, buf_bytes, sizeof(SLPI)); + } + + free(buf); + return true; +} + +size_t NumBits(size_t num_groups, const GROUP_AFFINITY* affinity) { + size_t total_bits = 0; + for (size_t i = 0; i < num_groups; ++i) { + size_t bits = 0; + hwy::CopyBytes(&affinity[i].Mask, &bits); + total_bits += hwy::PopCount(bits); + } + return total_bits; +} + +// Calls `func(lp, lps)` for each index `lp` in the set, after ensuring that +// `lp < lps.size()`. `line` is for debugging via Warn(). +template +void ForeachBit(size_t num_groups, const GROUP_AFFINITY* affinity, + std::vector& lps, int line, const Func& func) { + for (size_t group = 0; group < num_groups; ++group) { + size_t bits = 0; + hwy::CopyBytes(&affinity[group].Mask, &bits); + while (bits != 0) { + size_t lp = group * 64 + Num0BitsBelowLS1Bit_Nonzero64(bits); + bits &= bits - 1; // clear LSB + if (HWY_UNLIKELY(lp >= lps.size())) { + Warn(__FILE__, line, "Clamping lp %zu to lps.size() %zu, groups %zu\n", + lp, lps.size(), num_groups); + lp = lps.size() - 1; + } + func(lp, lps); + } + } +} + +#elif HWY_OS_APPLE + +// Returns whether sysctlbyname() succeeded; if so, writes `val / div` to +// `out`, otherwise sets `err`. +template +bool Sysctl(const char* name, size_t div, int& err, T* out) { + size_t val = 0; + size_t size = sizeof(val); + // Last two arguments are for updating the value, which we do not want. + const int ret = sysctlbyname(name, &val, &size, nullptr, 0); + if (HWY_UNLIKELY(ret != 0)) { + // Do not print warnings because some `name` are expected to fail. 
+ err = ret; + return false; + } + *out = static_cast(DivByFactor(val, div)); + return true; +} + +#endif // HWY_OS_* + +} // namespace + +HWY_CONTRIB_DLLEXPORT size_t TotalLogicalProcessors() { + size_t total_lps = 0; +#if HWY_ARCH_WASM + const int num_cores = emscripten_num_logical_cores(); + if (num_cores > 0) total_lps = static_cast(num_cores); +#elif HWY_OS_WIN + // If there are multiple groups, this should return them all, rather than + // just the first 64, but VMs report less. + (void)ForEachSLPI(RelationProcessorCore, [&total_lps](const SLPI& info) { + const PROCESSOR_RELATIONSHIP& p = info.Processor; + total_lps += NumBits(p.GroupCount, p.GroupMask); + }); +#elif HWY_OS_LINUX + // Only check "online" because sysfs entries such as topology are missing for + // offline CPUs, which will cause `DetectPackages` to fail. + const long ret = sysconf(_SC_NPROCESSORS_ONLN); // NOLINT(runtime/int) + if (ret < 0) { + HWY_WARN("Unexpected _SC_NPROCESSORS_CONF = %d\n", static_cast(ret)); + } else { + total_lps = static_cast(ret); + } +#elif HWY_OS_APPLE + int err; + // Only report P processors. + if (!Sysctl("hw.perflevel0.logicalcpu", 1, err, &total_lps)) { + total_lps = 0; + } +#endif + + if (HWY_UNLIKELY(total_lps == 0)) { // Failed to detect. + HWY_WARN( + "Unknown TotalLogicalProcessors, assuming 1. " + "HWY_OS_: WIN=%d LINUX=%d APPLE=%d;\n" + "HWY_ARCH_: WASM=%d X86=%d PPC=%d ARM=%d RISCV=%d S390X=%d\n", + HWY_OS_WIN, HWY_OS_LINUX, HWY_OS_APPLE, HWY_ARCH_WASM, HWY_ARCH_X86, + HWY_ARCH_PPC, HWY_ARCH_ARM, HWY_ARCH_RISCV, HWY_ARCH_S390X); + return 1; + } + + // Warn that we are clamping. 
+ if (HWY_UNLIKELY(total_lps > kMaxLogicalProcessors)) { + HWY_WARN("OS reports %zu processors but clamping to %zu\n", total_lps, + kMaxLogicalProcessors); + total_lps = kMaxLogicalProcessors; + } + + return total_lps; +} + +// ------------------------------ Affinity + +#if HWY_OS_LINUX || HWY_OS_FREEBSD +namespace { + +#if HWY_OS_LINUX +using CpuSet = cpu_set_t; +#else +using CpuSet = cpuset_t; +#endif + +// Helper functions reduce the number of #if in GetThreadAffinity. +int GetAffinity(CpuSet* set) { + // To specify the current thread, pass 0 on Linux/Android and -1 on FreeBSD. +#if defined(__ANDROID__) && __ANDROID_API__ < 12 + return syscall(__NR_sched_getaffinity, 0, sizeof(CpuSet), set); +#elif HWY_OS_FREEBSD + return cpuset_getaffinity(CPU_LEVEL_WHICH, CPU_WHICH_TID, -1, sizeof(CpuSet), + set); +#else // normal Linux + return sched_getaffinity(0, sizeof(CpuSet), set); +#endif +} + +int SetAffinity(CpuSet* set) { + // To specify the current thread, pass 0 on Linux/Android and -1 on FreeBSD. 
+#if defined(__ANDROID__) && __ANDROID_API__ < 12 + return syscall(__NR_sched_setaffinity, 0, sizeof(CpuSet), set); +#elif HWY_OS_FREEBSD + return cpuset_setaffinity(CPU_LEVEL_WHICH, CPU_WHICH_TID, -1, sizeof(CpuSet), + set); +#else // normal Linux + return sched_setaffinity(0, sizeof(CpuSet), set); +#endif +} + +bool IsSet(size_t lp, const CpuSet* set) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for GCC compiler warning with CPU_ISSET macro + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4305 4309, ignored "-Wsign-conversion") +#endif + const int is_set = CPU_ISSET(static_cast(lp), set); +#if HWY_COMPILER_GCC_ACTUAL + HWY_DIAGNOSTICS(pop) +#endif + return is_set != 0; +} + +void Set(size_t lp, CpuSet* set) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for GCC compiler warning with CPU_SET macro + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4305 4309, ignored "-Wsign-conversion") +#endif + CPU_SET(static_cast(lp), set); +#if HWY_COMPILER_GCC_ACTUAL + HWY_DIAGNOSTICS(pop) +#endif +} + +} // namespace +#endif // HWY_OS_LINUX || HWY_OS_FREEBSD + +HWY_CONTRIB_DLLEXPORT bool GetThreadAffinity(LogicalProcessorSet& lps) { +#if HWY_OS_WIN + // Only support the first 64 because WINE does not support processor groups. + const HANDLE hThread = GetCurrentThread(); + const DWORD_PTR prev = SetThreadAffinityMask(hThread, ~DWORD_PTR(0)); + if (!prev) return false; + (void)SetThreadAffinityMask(hThread, prev); + lps = LogicalProcessorSet(); // clear all + lps.SetNonzeroBitsFrom64(prev); + return true; +#elif HWY_OS_LINUX || HWY_OS_FREEBSD + CpuSet set; + CPU_ZERO(&set); + const int err = GetAffinity(&set); + if (err != 0) return false; + for (size_t lp = 0; lp < kMaxLogicalProcessors; ++lp) { + if (IsSet(lp, &set)) lps.Set(lp); + } + return true; +#else + // For HWY_OS_APPLE, affinity is not supported. Do not even set lp=0 to force + // callers to handle this case. 
+ (void)lps; + return false; +#endif +} + +HWY_CONTRIB_DLLEXPORT bool SetThreadAffinity(const LogicalProcessorSet& lps) { +#if HWY_OS_WIN + const HANDLE hThread = GetCurrentThread(); + const DWORD_PTR prev = SetThreadAffinityMask(hThread, lps.Get64()); + return prev != 0; +#elif HWY_OS_LINUX || HWY_OS_FREEBSD + CpuSet set; + CPU_ZERO(&set); + lps.Foreach([&set](size_t lp) { Set(lp, &set); }); + const int err = SetAffinity(&set); + if (err != 0) return false; + return true; +#else + // Apple THREAD_AFFINITY_POLICY is only an (often ignored) hint. + (void)lps; + return false; +#endif +} + +namespace { + +struct PackageSizes { + size_t num_clusters; + size_t num_cores; +}; + +#if HWY_OS_LINUX + +class File { + public: + explicit File(const char* path) { + for (;;) { + fd_ = open(path, O_RDONLY); + if (fd_ > 0) return; // success + if (errno == EINTR) continue; // signal: retry + if (errno == ENOENT) return; // not found, give up + HWY_WARN("Unexpected error opening %s: %d\n", path, errno); + return; // unknown error, give up + } + } + + ~File() { + if (fd_ > 0) { + for (;;) { + const int ret = close(fd_); + if (ret == 0) break; // success + if (errno == EINTR) continue; // signal: retry + HWY_WARN("Unexpected error closing file: %d\n", errno); + return; // unknown error, ignore + } + } + } + + // Returns number of bytes read or 0 on failure. + size_t Read(char* buf200) const { + if (fd_ < 0) return 0; + size_t pos = 0; + for (;;) { + // read instead of `pread`, which might not work for sysfs. + const auto bytes_read = read(fd_, buf200 + pos, 200 - pos); + if (bytes_read == 0) { // EOF: done + buf200[pos++] = '\0'; + return pos; + } + if (bytes_read == -1) { + if (errno == EINTR) continue; // signal: retry + HWY_WARN("Unexpected error reading file: %d\n", errno); + return 0; + } + pos += static_cast(bytes_read); + HWY_ASSERT(pos <= 200); + } + } + + private: + int fd_; +}; + +// Returns bytes read, or 0 on failure. 
+size_t ReadSysfs(const char* format, size_t lp, char* buf200) { + char path[200]; + const int bytes_written = snprintf(path, sizeof(path), format, lp); + HWY_ASSERT(0 < bytes_written && + bytes_written < static_cast(sizeof(path) - 1)); + + const File file(path); + return file.Read(buf200); +} + +// Interprets [str + pos, str + end) as base-10 ASCII. Stops when any non-digit +// is found, or at end. Returns false if no digits found. +bool ParseDigits(const char* str, const size_t end, size_t& pos, size_t* out) { + HWY_ASSERT(pos <= end); + // 9 digits cannot overflow even 32-bit size_t. + const size_t stop = pos + 9; + *out = 0; + for (; pos < HWY_MIN(end, stop); ++pos) { + const int c = str[pos]; + if (c < '0' || c > '9') break; + *out *= 10; + *out += static_cast(c - '0'); + } + if (pos == 0) { // No digits found + *out = 0; + return false; + } + return true; +} + +// Number, plus optional K or M suffix, plus terminator. +bool ParseNumberWithOptionalSuffix(const char* str, size_t len, size_t* out) { + size_t pos = 0; + if (!ParseDigits(str, len, pos, out)) return false; + if (str[pos] == 'K') { + *out <<= 10; + ++pos; + } + if (str[pos] == 'M') { + *out <<= 20; + ++pos; + } + if (str[pos] != '\0' && str[pos] != '\n') { + HWY_ABORT("Expected [suffix] terminator at %zu %s\n", pos, str); + } + return true; +} + +bool ReadNumberWithOptionalSuffix(const char* format, size_t lp, size_t* out) { + char buf200[200]; + const size_t pos = ReadSysfs(format, lp, buf200); + if (pos == 0) return false; + return ParseNumberWithOptionalSuffix(buf200, pos, out); +} + +const char* kPackage = + "/sys/devices/system/cpu/cpu%zu/topology/physical_package_id"; +const char* kCluster = "/sys/devices/system/cpu/cpu%zu/cache/index3/id"; +const char* kCore = "/sys/devices/system/cpu/cpu%zu/topology/core_id"; +const char* kL2Size = "/sys/devices/system/cpu/cpu%zu/cache/index2/size"; +const char* kL3Size = "/sys/devices/system/cpu/cpu%zu/cache/index3/size"; +const char* kNode = 
"/sys/devices/system/node/node%zu/cpulist"; + +// sysfs values can be arbitrarily large, so store in a map and replace with +// indices in order of appearance. +class Remapper { + public: + // Returns false on error, or sets `out_index` to the index of the sysfs + // value selected by `format` and `lp`. + template + bool operator()(const char* format, size_t lp, T* HWY_RESTRICT out_index) { + size_t opaque; + if (!ReadNumberWithOptionalSuffix(format, lp, &opaque)) return false; + + const auto ib = indices_.insert({opaque, num_}); + num_ += ib.second; // increment if inserted + const size_t index = ib.first->second; // new or existing + HWY_ASSERT(index < num_); + HWY_ASSERT(index < hwy::LimitsMax()); + *out_index = static_cast(index); + return true; + } + + size_t Num() const { return num_; } + + private: + std::map indices_; + size_t num_ = 0; +}; + +// For internal use by `DetectPackages`. +struct PerPackage { + Remapper clusters; + Remapper cores; + // We rely on this zero-init and increment it below. + uint8_t smt_per_core[kMaxLogicalProcessors] = {0}; +}; + +// Initializes `lps` and returns a PackageSizes vector (empty on failure) +// indicating the number of clusters and cores per package. +std::vector DetectPackages(std::vector& lps) { + std::vector empty; + + Remapper packages; + for (size_t lp = 0; lp < lps.size(); ++lp) { + if (!packages(kPackage, lp, &lps[lp].package)) { + HWY_WARN("Failed to read sysfs package for LP %zu\n", lp); + return empty; + } + } + std::vector per_package(packages.Num()); + HWY_ASSERT(!per_package.empty()); + + for (size_t lp = 0; lp < lps.size(); ++lp) { + PerPackage& pp = per_package[lps[lp].package]; + // Not a failure: some CPUs lack a (shared) L3 cache. 
+ if (!pp.clusters(kCluster, lp, &lps[lp].cluster)) { + lps[lp].cluster = 0; + } + + if (!pp.cores(kCore, lp, &lps[lp].core)) { + HWY_WARN("Failed to read sysfs core for LP %zu\n", lp); + return empty; + } + + // SMT ID is how many LP we have already seen assigned to the same core. + HWY_ASSERT(lps[lp].core < kMaxLogicalProcessors); + lps[lp].smt = pp.smt_per_core[lps[lp].core]++; + HWY_ASSERT(lps[lp].smt < 16); + } + + std::vector package_sizes(per_package.size()); + for (size_t p = 0; p < package_sizes.size(); ++p) { + // Was zero if the package has no shared L3, see above. + package_sizes[p].num_clusters = HWY_MAX(1, per_package[p].clusters.Num()); + package_sizes[p].num_cores = per_package[p].cores.Num(); + HWY_ASSERT(package_sizes[p].num_cores != 0); + } + return package_sizes; +} + +std::vector ExpandList(const char* list, size_t list_end, + size_t max_lp) { + std::vector expanded; + constexpr size_t kNotFound = ~size_t{0}; + size_t pos = 0; + + // Gracefully handle empty lists, happens on GH200 systems (#2668). + if (isspace(list[0]) && list_end <= 2) return expanded; + + // Returns first `found_pos >= pos` where `list[found_pos] == c`, or + // `kNotFound`. + const auto find = [list, list_end, &pos](char c) -> size_t { + const char* found_ptr = strchr(list + pos, c); + if (found_ptr == nullptr) return kNotFound; + const size_t found_pos = static_cast(found_ptr - list); + HWY_ASSERT(found_pos < list_end && list[found_pos] == c); + return found_pos; + }; + + // Reads LP number and advances `pos`. `end` is for verifying we did not + // read past a known terminator, or the end of string. + const auto parse_lp = [list, list_end, &pos, max_lp](size_t end) -> size_t { + end = HWY_MIN(end, list_end); + size_t lp; + HWY_ASSERT(ParseDigits(list, end, pos, &lp)); + HWY_IF_CONSTEXPR(HWY_ARCH_RISCV) { + // On RISC-V, both TotalLogicalProcessors and GetThreadAffinity may + // under-report the count, hence clamp. 
+ lp = HWY_MIN(lp, max_lp); + } + HWY_ASSERT(lp <= max_lp); + HWY_ASSERT(pos <= end); + return lp; + }; + + // Parse all [first-]last separated by commas. + for (;;) { + // Single number or first of range: ends with dash, comma, or end. + const size_t lp_range_first = parse_lp(HWY_MIN(find('-'), find(','))); + + if (list[pos] == '-') { // range + ++pos; // skip dash + // Last of range ends with comma or end. + const size_t lp_range_last = parse_lp(find(',')); + + expanded.reserve(expanded.size() + lp_range_last - lp_range_first + 1); + for (size_t lp = lp_range_first; lp <= lp_range_last; ++lp) { + expanded.push_back(lp); + } + } else { // single number + expanded.push_back(lp_range_first); + } + + // Done if reached end of string. + if (pos == list_end || list[pos] == '\0' || list[pos] == '\n') { + break; + } + // Comma means at least one more term is coming. + if (list[pos] == ',') { + ++pos; + continue; + } + HWY_ABORT("Unexpected character at %zu in %s\n", pos, list); + } // for pos + + return expanded; +} + +// Sets LP.node for all `lps`. +void SetNodes(std::vector& lps) { + // For each NUMA node found via sysfs: + for (size_t node = 0;; node++) { + // Read its cpulist so we can scatter `node` to all its `lps`. 
+ char buf200[200]; + const size_t bytes_read = ReadSysfs(kNode, node, buf200); + if (bytes_read == 0) break; + const std::vector list = + ExpandList(buf200, bytes_read, lps.size() - 1); + for (size_t lp : list) { + lps[lp].node = static_cast(node); + } + } +} + +void SetClusterCacheSizes(std::vector& packages) { + for (size_t ip = 0; ip < packages.size(); ++ip) { + Topology::Package& p = packages[ip]; + for (size_t ic = 0; ic < p.clusters.size(); ++ic) { + Topology::Cluster& c = p.clusters[ic]; + const size_t lp = c.lps.First(); + size_t bytes; + if (ReadNumberWithOptionalSuffix(kL2Size, lp, &bytes)) { + c.private_kib = bytes >> 10; + } + if (ReadNumberWithOptionalSuffix(kL3Size, lp, &bytes)) { + c.shared_kib = bytes >> 10; + } + } + } +} + +#elif HWY_OS_WIN + +// See #2734. GroupCount was added around Windows 10, but SDK docs do not +// mention the actual version required. It is known to be absent in 8.1 and +// MinGW 5.0.1, and present in the 10.0.22000.0 SDK. However, the OS must also +// know about the field. Thus we zero-initialize the reserved field, assume it +// remains zero, and return 1 if zero (old style single GroupMask), otherwise +// the number of groups. There are two such structures, but note that +// `PROCESSOR_RELATIONSHIP` already had this field. +static size_t GroupCount(const CACHE_RELATIONSHIP& cr) { + // Added as the last u16 in the reserved area before GroupMask. We only read + // one byte because 256*64 processor bits are plenty. + const uint8_t* pcount = + reinterpret_cast(&cr.GroupMask) - sizeof(uint16_t); + return HWY_MAX(pcount[HWY_IS_BIG_ENDIAN], 1); +} + +static size_t GroupCount(const NUMA_NODE_RELATIONSHIP& nn) { + const uint8_t* pcount = + reinterpret_cast(&nn.GroupMask) - sizeof(uint16_t); + return HWY_MAX(pcount[HWY_IS_BIG_ENDIAN], 1); +} + +struct PerPackage { + size_t clusters = 0; + size_t cores = 0; + size_t max_lps_per_core = 0; +}; + +// Returns per-package vector and assigns LP.package to an index within it. 
+std::vector AssignPackageIndices(std::vector& lps) { + size_t package_idx = 0; + (void)ForEachSLPI( + RelationProcessorPackage, [&lps, &package_idx](const SLPI& info) { + const PROCESSOR_RELATIONSHIP& p = info.Processor; + ForeachBit(p.GroupCount, p.GroupMask, lps, __LINE__, + [&package_idx](size_t lp, std::vector& lps) { + lps[lp].package = static_cast(package_idx); + }); + ++package_idx; + }); + return std::vector(package_idx); +} + +// Sets LP.core and LP.smt and updates `PerPackage.cores/max_lps_per_core`. +void AssignCoreSmtIndices(std::vector& lps, + std::vector& per_package) { + (void)ForEachSLPI( + RelationProcessorCore, [&lps, &per_package](const SLPI& info) { + const PROCESSOR_RELATIONSHIP& p = info.Processor; + PerPackage* pp = nullptr; + // Foreach LP in this core: assign its core and smt. + size_t smt = 0; + ForeachBit(p.GroupCount, p.GroupMask, lps, __LINE__, + [&](size_t lp, std::vector& lps) { + pp = &per_package[lps[lp].package]; + lps[lp].core = static_cast(pp->cores); + lps[lp].smt = static_cast(smt++); + }); + HWY_ASSERT(pp != nullptr); + ++pp->cores; + pp->max_lps_per_core = + HWY_MAX(pp->max_lps_per_core, NumBits(p.GroupCount, p.GroupMask)); + HWY_ASSERT(pp->max_lps_per_core != 0); + }); +} + +// Interprets cluster (typically a shared L3 cache) as a "processor die". Sets +// LP.cluster and updates `PerPackage.clusters`. +void AssignClusterIndices(std::vector& lps, + std::vector& per_package) { + // Shared between `foreach_die` and `foreach_l3`. Assigns all LPs to this + // cluster and increments the cluster index. + const auto foreach_cluster = [&](size_t num_groups, + const GROUP_AFFINITY* groups) { + PerPackage* pp = nullptr; + ForeachBit(num_groups, groups, lps, __LINE__, + [&per_package, &pp](size_t lp, std::vector& lps) { + pp = &per_package[lps[lp].package]; + lps[lp].cluster = static_cast(pp->clusters); + }); + if (pp != nullptr) { + ++pp->clusters; + } + }; + + // Passes group bits to `foreach_cluster`, depending on relationship type. 
+ const auto foreach_die = [&foreach_cluster](const SLPI& info) { + const PROCESSOR_RELATIONSHIP& p = info.Processor; + foreach_cluster(p.GroupCount, p.GroupMask); + }; + const auto foreach_l3 = [&foreach_cluster](const SLPI& info) { + const CACHE_RELATIONSHIP& cr = info.Cache; + if (cr.Type != CacheUnified && cr.Type != CacheData) return; + if (cr.Level != 3) return; + foreach_cluster(GroupCount(cr), cr.GroupMasks); + }; + + if (!ForEachSLPI(RelationProcessorDie, foreach_die)) { + // Has been observed to fail; also check for shared L3 caches. + (void)ForEachSLPI(RelationCache, foreach_l3); + } + + // All packages should have the same number of clusters. + for (size_t package_idx = 1; package_idx < per_package.size(); + ++package_idx) { + if (per_package[package_idx].clusters != per_package[0].clusters) { + HWY_ABORT("pkg %zu has %zu clusters, expected %zu\n", package_idx, + per_package[package_idx].clusters, per_package[0].clusters); + } + } + + if (per_package[0].clusters == 0) { + HWY_WARN("No clusters found, assuming 1 cluster\n"); + for (PerPackage& pp : per_package) { + pp.clusters = 1; + } + for (Topology::LP& lp : lps) { + lp.cluster = 0; + } + } +} + +// Initializes `lps` and returns a `PackageSizes` vector (empty on failure) +// indicating the number of clusters and cores per package. +std::vector DetectPackages(std::vector& lps) { + std::vector per_package = AssignPackageIndices(lps); + if (per_package.empty()) return {}; + AssignCoreSmtIndices(lps, per_package); + AssignClusterIndices(lps, per_package); + + std::vector packages(per_package.size()); + for (size_t package_idx = 0; package_idx < per_package.size(); + ++package_idx) { + packages[package_idx].num_clusters = per_package[package_idx].clusters; + packages[package_idx].num_cores = per_package[package_idx].cores; + } + return packages; +} + +// Sets LP.node for all `lps`. +void SetNodes(std::vector& lps) { + // Zero-initialize all nodes in case the below fails. 
+ for (size_t lp = 0; lp < lps.size(); ++lp) { + lps[lp].node = 0; + } + + // We want the full NUMA nodes, but Windows Server 2022 truncates the results + // of `RelationNumaNode` to a single 64-LP group. To get the old, unlimited + // behavior without using the new `RelationNumaNodeEx` symbol, use the old + // `RelationAll` and filter the SLPI we want. + (void)ForEachSLPI(RelationAll, [&](const SLPI& info) { + if (info.Relationship != RelationNumaNode) return; + const NUMA_NODE_RELATIONSHIP& nn = info.NumaNode; + // This field was previously reserved/zero. There is at least one group. + const size_t num_groups = HWY_MAX(1, GroupCount(nn)); + const uint8_t node = static_cast(nn.NodeNumber); + ForeachBit(num_groups, nn.GroupMasks, lps, __LINE__, + [node](size_t lp, std::vector& lps) { + lps[lp].node = node; + }); + }); +} + +#elif HWY_OS_APPLE + +// Initializes `lps` and returns a `PackageSizes` vector (empty on failure) +// indicating the number of clusters and cores per package. +std::vector DetectPackages(std::vector& lps) { + int err; + + size_t total_cores = 0; + if (!Sysctl("hw.perflevel0.physicalcpu", 1, err, &total_cores)) { + HWY_WARN("Error %d detecting total_cores, assuming one per LP\n", err); + total_cores = lps.size(); + } + + if (lps.size() % total_cores != 0) { + HWY_WARN("LPs %zu not a multiple of total_cores %zu\n", lps.size(), + total_cores); + } + const size_t lp_per_core = DivCeil(lps.size(), total_cores); + + size_t cores_per_cluster = 0; + if (!Sysctl("hw.perflevel0.cpusperl2", 1, err, &cores_per_cluster)) { + HWY_WARN("Error %d detecting cores_per_cluster\n", err); + cores_per_cluster = HWY_MIN(4, total_cores); + } + + if (total_cores % cores_per_cluster != 0) { + HWY_WARN("total_cores %zu not a multiple of cores_per_cluster %zu\n", + total_cores, cores_per_cluster); + } + + for (size_t lp = 0; lp < lps.size(); ++lp) { + lps[lp].package = 0; // single package + lps[lp].core = static_cast(lp / lp_per_core); + lps[lp].smt = static_cast(lp % 
lp_per_core); + lps[lp].cluster = static_cast(lps[lp].core / cores_per_cluster); + } + + PackageSizes ps; + ps.num_clusters = DivCeil(total_cores, cores_per_cluster); + ps.num_cores = total_cores; + return std::vector{ps}; +} + +// Sets LP.node for all `lps`. +void SetNodes(std::vector& lps) { + for (size_t lp = 0; lp < lps.size(); ++lp) { + lps[lp].node = 0; // no NUMA + } +} + +#endif // HWY_OS_* + +#if HWY_OS_WIN || HWY_OS_APPLE + +void SetClusterCacheSizes(std::vector& packages) { + // Assumes clusters are homogeneous. Otherwise, we would have to scan + // `RelationCache` again and find the corresponding package_idx. + const Cache* caches = DataCaches(); + const size_t private_kib = caches ? caches[2].size_kib : 0; + const size_t shared_kib = caches ? caches[3].size_kib : 0; + + for (size_t ip = 0; ip < packages.size(); ++ip) { + Topology::Package& p = packages[ip]; + for (size_t ic = 0; ic < p.clusters.size(); ++ic) { + Topology::Cluster& c = p.clusters[ic]; + c.private_kib = private_kib; + c.shared_kib = shared_kib; + } + } +} + +#endif // HWY_OS_WIN || HWY_OS_APPLE + +} // namespace + +HWY_CONTRIB_DLLEXPORT Topology::Topology() { +#if HWY_OS_LINUX || HWY_OS_WIN || HWY_OS_APPLE + lps.resize(TotalLogicalProcessors()); + const std::vector& package_sizes = DetectPackages(lps); + if (package_sizes.empty()) return; + SetNodes(lps); + + // Allocate per-package/cluster/core vectors. This indicates to callers that + // detection succeeded. + packages.resize(package_sizes.size()); + for (size_t p = 0; p < packages.size(); ++p) { + packages[p].clusters.resize(package_sizes[p].num_clusters); + packages[p].cores.resize(package_sizes[p].num_cores); + } + + // Populate the per-cluster/core sets of LP. 
+ for (size_t lp = 0; lp < lps.size(); ++lp) { + Package& p = packages[lps[lp].package]; + p.clusters[lps[lp].cluster].lps.Set(lp); + p.cores[lps[lp].core].lps.Set(lp); + } + + SetClusterCacheSizes(packages); +#endif // HWY_OS_* +} + +// ------------------------------ Cache detection + +namespace { + +using Caches = std::array; + +// We assume homogeneous caches across all clusters because some OS APIs return +// a single value for a class of CPUs. + +#if HWY_OS_LINUX +std::string ReadString(const char* name, size_t index) { + // First CPU is usually a P core. + const std::string path("/sys/devices/system/cpu/cpu0/cache/index%zu/"); + char buf200[200]; + size_t end = ReadSysfs((path + name).c_str(), index, buf200); + // Remove trailing newline/null to simplify string comparison. + for (; end != 0; --end) { + if (buf200[end - 1] != '\0' && buf200[end - 1] != '\n') break; + } + return std::string(buf200, buf200 + end); +} + +template +bool WriteSysfs(const char* name, size_t index, T* out) { + const std::string str = ReadString(name, index); + // Do not call `ParseNumberWithOptionalSuffix` because it acts on the + // K suffix in "size", but we actually want KiB. + size_t pos = 0; + size_t val; + if (!ParseDigits(str.c_str(), str.length(), pos, &val)) return false; + HWY_ASSERT(pos <= str.length()); + *out = static_cast(val); + return true; +} + +// Reading from sysfs is preferred because sysconf returns L3 associativity = 0 +// on some CPUs, and does not indicate sharing across cores. +// https://www.kernel.org/doc/Documentation/ABI/testing/sysfs-devices-system-cpu +bool InitCachesSysfs(Caches& caches) { + // For computing shared cache sizes. + std::vector lps(TotalLogicalProcessors()); + const std::vector package_sizes = DetectPackages(lps); + // `package_sizes` is only used to check that `lps` were filled. 
+ if (package_sizes.empty()) { + HWY_WARN("no packages, shared cache sizes may be incorrect\n"); + return false; + } + + for (size_t i = 0;; ++i) { + const std::string type = ReadString("type", i); + if (type.empty()) break; // done, no more entries + if (type != "Data" && type != "Unified") continue; + uint32_t level; + if (!WriteSysfs("level", i, &level)) continue; + if (level != 1 && level != 2 && level != 3) continue; + Cache& c = caches[level]; + + // Check before overwriting any fields. + if (c.size_kib != 0) { + HWY_WARN("ignoring another L%u, first size %u\n", level, c.size_kib); + continue; + } + + const bool ok = WriteSysfs("size", i, &c.size_kib) && + WriteSysfs("ways_of_associativity", i, &c.associativity) && + WriteSysfs("number_of_sets", i, &c.sets); + if (HWY_UNLIKELY(!ok)) { + HWY_WARN("skipping partially-detected L%u, error %d\n", level, errno); + c = Cache(); + continue; + } + + // Compute line size *before* adjusting the size for sharing. Note that + // `coherency_line_size` exists, but we are not sure that is the line size. + const size_t bytes = static_cast(c.size_kib) * 1024; + const size_t lines = c.associativity * c.sets; + c.bytes_per_line = static_cast(DivByFactor(bytes, lines)); + + // Divide by number of *cores* sharing the cache. 
+ const std::string shared_str = ReadString("shared_cpu_list", i); + if (HWY_UNLIKELY(shared_str.empty())) { + HWY_WARN("no shared_cpu_list for L%u %s\n", level, type.c_str()); + c.cores_sharing = 1; + } else { + const std::vector shared_lps = + ExpandList(shared_str.c_str(), shared_str.length(), lps.size() - 1); + size_t num_cores = 0; + for (size_t lp : shared_lps) { + if (HWY_LIKELY(lp < lps.size())) { + num_cores += lps[lp].smt == 0; + } else { + HWY_WARN("out of bounds lp %zu of %zu from %s\n", lp, lps.size(), + shared_str.c_str()); + } + } + if (num_cores == 0) { + HWY_WARN("no cores sharing L%u %s, setting to 1\n", level, + type.c_str()); + num_cores = 1; + } + c.cores_sharing = static_cast(num_cores); + // There exist CPUs for which L3 is not evenly divisible by `num_cores`, + // hence do not use `DivByFactor`. It is safer to round down. + c.size_kib = static_cast(c.size_kib / num_cores); + c.sets = static_cast(c.sets / num_cores); + } + } + + // Require L1 and L2 cache. + if (HWY_UNLIKELY(caches[1].size_kib == 0 || caches[2].size_kib == 0)) { +// Don't complain on Android because this is known to happen there. We are +// unaware of good alternatives: `getauxval(AT_L1D_CACHEGEOMETRY)` and +// `sysconf(_SC_LEVEL1_DCACHE_SIZE)` are unreliable, detecting via timing seems +// difficult to do reliably, and we do not want to maintain lists of known CPUs +// and their properties. It's OK to return false; callers are responsible for +// assuming reasonable defaults. +#ifndef __ANDROID__ + HWY_WARN("sysfs detected L1=%u L2=%u, err %d\n", caches[1].size_kib, + caches[2].size_kib, errno); +#endif + return false; + } + + // L3 is optional; if not found, its size is already zero from static init. 
+ return true; +} + +#elif HWY_OS_WIN + +bool InitCachesWin(Caches& caches) { + std::vector lps(TotalLogicalProcessors()); + std::vector per_package = AssignPackageIndices(lps); + if (per_package.empty()) return false; + AssignCoreSmtIndices(lps, per_package); + + (void)ForEachSLPI(RelationCache, [&per_package, &caches](const SLPI& info) { + const CACHE_RELATIONSHIP& cr = info.Cache; + if (cr.Type != CacheUnified && cr.Type != CacheData) return; + if (1 <= cr.Level && cr.Level <= 3) { + Cache& c = caches[cr.Level]; + // If the size is non-zero then we (probably) have already detected this + // cache and can skip the CR. + if (c.size_kib > 0) return; + c.size_kib = static_cast(DivByFactor(cr.CacheSize, 1024)); + c.bytes_per_line = static_cast(cr.LineSize); + c.associativity = (cr.Associativity == CACHE_FULLY_ASSOCIATIVE) + ? Cache::kMaxAssociativity + : cr.Associativity; + + // How many cores share this cache? + size_t shared_with = NumBits(GroupCount(cr), cr.GroupMasks); + // Divide out hyperthreads. This core may have fewer than + // `max_lps_per_core`, hence round up. + shared_with = DivCeil(shared_with, per_package[0].max_lps_per_core); + if (shared_with == 0) { + HWY_WARN("no cores sharing L%u, setting to 1\n", cr.Level); + shared_with = 1; + } + + // Update `size_kib` to *per-core* portion. + // There exist CPUs for which L3 is not evenly divisible by `shared_with`, + // hence do not use `DivByFactor`. It is safer to round down. + c.size_kib = static_cast(c.size_kib / shared_with); + c.cores_sharing = static_cast(shared_with); + } + }); + + // Require L1 and L2 cache. + if (HWY_UNLIKELY(caches[1].size_kib == 0 || caches[2].size_kib == 0)) { + HWY_WARN("Windows detected L1=%u, L2=%u, err %lx\n", caches[1].size_kib, + caches[2].size_kib, GetLastError()); + return false; + } + + // L3 is optional; if not found, its size is already zero from static init. 
+ return true; +} + +#elif HWY_OS_APPLE + +bool InitCachesApple(Caches& caches) { + int err = 0; + Cache& L1 = caches[1]; + Cache& L2 = caches[2]; + Cache& L3 = caches[3]; + + // Total L1 and L2 size can be reliably queried, but prefer perflevel0 + // (P-cores) because hw.l1dcachesize etc. are documented to describe the + // "least performant core". + bool ok = Sysctl("hw.perflevel0.l1dcachesize", 1024, err, &L1.size_kib) || + Sysctl("hw.l1dcachesize", 1024, err, &L1.size_kib); + ok &= Sysctl("hw.perflevel0.l2cachesize", 1024, err, &L2.size_kib) || + Sysctl("hw.l2cachesize", 1024, err, &L2.size_kib); + if (HWY_UNLIKELY(!ok)) { + HWY_WARN("Apple cache detection failed, error %d\n", err); + return false; + } + L1.cores_sharing = 1; + if (Sysctl("hw.perflevel0.cpusperl2", 1, err, &L2.cores_sharing)) { + // There exist CPUs for which L2 is not evenly divisible by `cores_sharing`, + // hence do not use `DivByFactor`. It is safer to round down. + L2.size_kib /= L2.cores_sharing; + } else { + L2.cores_sharing = 1; + } + + // Other properties are not always reported. Set `associativity` and + // `bytes_per_line` based on known models. + char brand[128] = {0}; + size_t size = sizeof(brand); + if (!sysctlbyname("machdep.cpu.brand_string", brand, &size, nullptr, 0)) { + if (strncmp(brand, "Apple ", 6) != 0) { + // Unexpected, but we will continue check the string suffixes. + HWY_WARN("unexpected Apple brand %s\n", brand); + } + + if (brand[6] == 'M') { + // https://dougallj.github.io/applecpu/firestorm.html, + // https://www.7-cpu.com/cpu/Apple_M1.html: + L1.bytes_per_line = 64; + L1.associativity = 8; + L2.bytes_per_line = 128; + if (brand[7] == '1') { // M1 + L2.associativity = 12; + } else if ('2' <= brand[7] && brand[7] <= '4') { // M2/M3, maybe also M4 + L2.associativity = 16; + } else { + L2.associativity = 0; // Unknown, set below via sysctl. 
+ } + + // Although Wikipedia lists SLC sizes per model, we do not know how it is + // partitioned/allocated, so do not treat it as a reliable L3. + } // M* + } // brand string + + // This sysctl does not distinguish between L1 and L2 line sizes, so only use + // it if we have not already set `bytes_per_line` above. + uint16_t bytes_per_line; + if (!Sysctl("hw.cachelinesize", 1, err, &bytes_per_line)) { + bytes_per_line = static_cast(HWY_ALIGNMENT); // guess + } + for (size_t level = 1; level <= 3; ++level) { + if (caches[level].bytes_per_line == 0) { + caches[level].bytes_per_line = bytes_per_line; + } + } + + // Fill in associativity if not already set. Unfortunately this is only + // reported on x86, not on M*. + if (L1.associativity == 0 && !Sysctl("machdep.cpu.cache.L1_associativity", 1, + err, &L1.associativity)) { + L1.associativity = 8; // guess + } + if (L2.associativity == 0 && !Sysctl("machdep.cpu.cache.L2_associativity", 1, + err, &L2.associativity)) { + L2.associativity = 12; // guess + } + // There is no L3_associativity. + if (L3.associativity == 0) { + L3.associativity = 12; // guess + } + + // Now attempt to query L3. Although this sysctl is documented, M3 does not + // report an L3 cache. + if (L3.size_kib == 0 && + (Sysctl("hw.perflevel0.l3cachesize", 1024, err, &L3.size_kib) || + Sysctl("hw.l3cachesize", 1024, err, &L3.size_kib))) { + // There exist CPUs for which L3 is not evenly divisible by `cores_sharing`, + // hence do not use `DivByFactor`. It is safer to round down. + if (Sysctl("hw.perflevel0.cpusperl3", 1, err, &L3.cores_sharing)) { + L3.size_kib /= L3.cores_sharing; + } else { + L3.cores_sharing = 1; + } + } + // If no L3 cache, reset all fields for consistency. + if (L3.size_kib == 0) { + L3 = Cache(); + } + + // Are there other useful sysctls? hw.cacheconfig appears to be how many + // cores share the memory and caches, though this is not documented, and + // duplicates information in hw.perflevel0.cpusperl*. 
+ + return true; +} + +#endif // HWY_OS_* + +// Most APIs do not set the `sets` field, so compute it from the size and +// associativity, and if a value is already set, ensure it matches. +HWY_MAYBE_UNUSED void ComputeSets(Cache& c) { + // If there is no such cache, avoid division by zero. + if (HWY_UNLIKELY(c.size_kib == 0)) { + c.sets = 0; + return; + } + const size_t bytes = static_cast(c.size_kib) * 1024; + // `size_kib` may have been rounded down, hence `lines` and `sets` are not + // necessarily evenly divisible, so round down instead of `DivByFactor`. + const size_t lines = bytes / c.bytes_per_line; + const size_t sets = lines / c.associativity; + + if (c.sets == 0) { + c.sets = static_cast(sets); + } else { + const size_t diff = c.sets - sets; + if (diff > 1) { + HWY_ABORT("Inconsistent cache sets %u != %zu\n", c.sets, sets); + } + } +} + +const Cache* InitDataCaches() { + alignas(64) static Caches caches; + + // On failure, return immediately because InitCaches*() already warn. +#if HWY_OS_LINUX + if (HWY_UNLIKELY(!InitCachesSysfs(caches))) return nullptr; +#elif HWY_OS_WIN + if (HWY_UNLIKELY(!InitCachesWin(caches))) return nullptr; +#elif HWY_OS_APPLE + if (HWY_UNLIKELY(!InitCachesApple(caches))) return nullptr; +#else + HWY_WARN("Cache detection not implemented for this platform.\n"); + (void)caches; + return nullptr; +#define HWY_NO_CACHE_DETECTION +#endif + + // Prevents "code not reached" warnings on WASM. +#ifndef HWY_NO_CACHE_DETECTION + for (size_t level = 1; level <= 3; ++level) { + ComputeSets(caches[level]); + } + + // Heuristic to ignore SLCs such as on Ampere Altra, which should not be + // treated as a reliable L3 because of their cache inclusion policy. + // On Apple M*, these are not even reported as an L3. 
+ if (caches[3].cores_sharing >= 16 && caches[3].size_kib <= 512) { + caches[3] = Cache(); + } + + return &caches[0]; +#endif // HWY_NO_CACHE_DETECTION +} + +} // namespace + +HWY_CONTRIB_DLLEXPORT const Cache* DataCaches() { + static const Cache* caches = InitDataCaches(); + return caches; +} + +} // namespace hwy diff --git a/lib/highway/hwy/contrib/thread_pool/topology.h b/lib/highway/hwy/contrib/thread_pool/topology.h new file mode 100644 index 00000000000..84780a8745f --- /dev/null +++ b/lib/highway/hwy/contrib/thread_pool/topology.h @@ -0,0 +1,141 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_CONTRIB_THREAD_POOL_TOPOLOGY_H_ +#define HIGHWAY_HWY_CONTRIB_THREAD_POOL_TOPOLOGY_H_ + +// OS-specific functions for processor topology and thread affinity. + +#include + +#include + +#include "hwy/base.h" +#include "hwy/bit_set.h" + +namespace hwy { + +// Returns false if std::thread should not be used. +HWY_CONTRIB_DLLEXPORT bool HaveThreadingSupport(); + +// Upper bound on logical processors, including hyperthreads. +static constexpr size_t kMaxLogicalProcessors = 1024; // matches glibc + +// Set used by Get/SetThreadAffinity. +using LogicalProcessorSet = BitSet4096; + +// Returns false, or sets `lps` to all logical processors which are online and +// available to the current thread. 
+HWY_CONTRIB_DLLEXPORT bool GetThreadAffinity(LogicalProcessorSet& lps); + +// Ensures the current thread can only run on the logical processors in `lps`. +// Returns false if not supported (in particular on Apple), or if the +// intersection between `lps` and `GetThreadAffinity` is the empty set. +HWY_CONTRIB_DLLEXPORT bool SetThreadAffinity(const LogicalProcessorSet& lps); + +// Returns false, or ensures the current thread will only run on `lp`, which +// must not exceed `TotalLogicalProcessors`. Note that this merely calls +// `SetThreadAffinity`, see the comment there. +static inline bool PinThreadToLogicalProcessor(size_t lp) { + LogicalProcessorSet lps; + lps.Set(lp); + return SetThreadAffinity(lps); +} + +// Returns 1 if unknown, otherwise the total number of logical processors +// provided by the hardware clamped to `kMaxLogicalProcessors`. +// These processors are not necessarily all usable; you can determine which are +// via GetThreadAffinity(). +HWY_CONTRIB_DLLEXPORT size_t TotalLogicalProcessors(); + +struct Topology { + // Caller must check packages.empty(); if so, do not use any fields. + HWY_CONTRIB_DLLEXPORT Topology(); + + // Clique of cores with lower latency to each other. On Apple M1 these are + // four cores sharing an L2. On Zen4 these 'CCX' are up to eight cores sharing + // an L3 and a memory controller, or for Zen4c up to 16 and half the L3 size. + struct Cluster { + LogicalProcessorSet lps; + uint64_t private_kib = 0; // 0 if unknown + uint64_t shared_kib = 0; // 0 if unknown + uint64_t reserved1 = 0; + uint64_t reserved2 = 0; + uint64_t reserved3 = 0; + }; + + struct Core { + LogicalProcessorSet lps; + uint64_t reserved = 0; + }; + + struct Package { + std::vector clusters; + std::vector cores; + }; + + std::vector packages; + + // Several hundred instances, so prefer a compact representation. 
+#pragma pack(push, 1) + struct LP { + uint16_t cluster = 0; // < packages[package].clusters.size() + uint16_t core = 0; // < packages[package].cores.size() + uint8_t package = 0; // < packages.size() + uint8_t smt = 0; // < packages[package].cores[core].lps.Count() + uint8_t node = 0; + + uint8_t reserved = 0; + }; +#pragma pack(pop) + std::vector lps; // size() == TotalLogicalProcessors(). +}; + +#pragma pack(push, 1) +// Cache parameters. Note the overlap with `HWY_ALIGNMENT`, which is intended +// but not guaranteed to be an upper bound for L1/L2 line sizes, and +// `Topology::Cluster::private_kib/shared_kib`, which are intended but not +// guaranteed to be the L2/L3 sizes. Getting the exact parameters, including the +// ways of associativity, can be useful for modeling cache conflicts. +// +// Uses packed fields so the array of `Cache` fits in a typical cache line. +struct Cache { + // Arbitrary upper bound for sanity checking. + static constexpr uint16_t kMaxAssociativity = 128; + + // Zero if the level does not exist; *per-core* portion for shared caches. + uint32_t size_kib = 0; + // Also per-core portion, computed as number of lines / associativity. + uint32_t sets = 0; + uint16_t bytes_per_line = 0; + uint16_t associativity = 0; // number of ways + uint16_t cores_sharing = 0; // usually 1 for L1 + uint16_t reserved = 0; +}; +static_assert(sizeof(Cache) == 16, "Unexpected size"); +#pragma pack(pop) + +// Returns null if unknown, otherwise pointer to an array of `Cache` instances, +// where entry 0 is reserved, entry 1 describes the L1 data cache, entry 2 +// describes the (possibly unified or shared) L2, and entry 3 describes the L3 +// if its `size_kib != 0`. +// +// Initializes on-demand, which has some overhead for thread safety, hence +// callers should cache the result. 
+HWY_CONTRIB_DLLEXPORT const Cache* DataCaches(); + +} // namespace hwy + +#endif // HIGHWAY_HWY_CONTRIB_THREAD_POOL_TOPOLOGY_H_ diff --git a/lib/highway/hwy/contrib/unroller/README.md b/lib/highway/hwy/contrib/unroller/README.md new file mode 100644 index 00000000000..7f492b24c98 --- /dev/null +++ b/lib/highway/hwy/contrib/unroller/README.md @@ -0,0 +1,31 @@ +# Unroller + +All contents of the `unroller` folder are experimental and subject to change. + +`Unroller` is a templated function that automatically implements common optimizations that are usually handled by compilers when writing scalar code. Modern CPUs operate much more efficiently when non-dependent calculations are packed into an instruction pipeline. For scalar code, this often means a compiler will take a one-line loop, and compile it down to hundreds of lines of machine code in order to fully capture these efficiencies. + +As of today (2023-07-06), compilers are not nearly as good at implementing these optimizations for code written in SIMD intrinsics. `Unroller` is a templated function that takes in an `UnrollerUnit` of SIMD instructions, and then implements unrolling, reordering, hoisting and tail-handling (URHT optimizations) of arrays of data being processed with SIMD intrinsics. + +### `UnrollerUnit` + +`UnrollerUnit` and `UnrollerUnit2D` are base classes of functions that `Unroller` needs implemented in order to properly handle URHT. `UnrollerUnit` has default implementations for all but the `Func` method, which defines the SIMD operation to be applied. Many examples of how to implement these functions are in the tests. 
+ +### Doubling values of an array example + +``` +struct DoubleUnit : UnrollerUnit { + using TT = ScalableTag; + inline Vec Func(ptrdiff_t idx, Vec x, Vec y) { + TT d; + return Mul(x, Set(d, 2)); + } +}; +``` + +Leaving all other methods in their default state, the following code will double all the values in array `a` and place them in `r` + +``` +DoubleUnit dblunit; +int r[N]; +Unroller(dblunit, a, r, N); +``` \ No newline at end of file diff --git a/lib/highway/hwy/contrib/unroller/unroller-inl.h b/lib/highway/hwy/contrib/unroller/unroller-inl.h new file mode 100644 index 00000000000..e5d9661b655 --- /dev/null +++ b/lib/highway/hwy/contrib/unroller/unroller-inl.h @@ -0,0 +1,473 @@ +// Copyright 2023 Matthew Kolbe +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ +#undef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ +#else +#define HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ +#endif + +#include // std::abs + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +namespace hn = hwy::HWY_NAMESPACE; + +template +struct UnrollerUnit { + static constexpr size_t kMaxTSize = HWY_MAX(sizeof(IN_T), sizeof(OUT_T)); + using LargerT = SignedFromSize; // only the size matters. 
+ + DERIVED* me() { return static_cast(this); } + + static constexpr size_t MaxUnitLanes() { + return HWY_MAX_LANES_D(hn::ScalableTag); + } + static size_t ActualLanes() { return Lanes(hn::ScalableTag()); } + + using LargerD = hn::CappedTag; + using IT = hn::Rebind; + using OT = hn::Rebind; + IT d_in; + OT d_out; + using Y_VEC = hn::Vec; + using X_VEC = hn::Vec; + + Y_VEC Func(const ptrdiff_t idx, const X_VEC x, const Y_VEC y) { + return me()->Func(idx, x, y); + } + + X_VEC X0Init() { return me()->X0InitImpl(); } + + X_VEC X0InitImpl() { return hn::Zero(d_in); } + + Y_VEC YInit() { return me()->YInitImpl(); } + + Y_VEC YInitImpl() { return hn::Zero(d_out); } + + X_VEC Load(const ptrdiff_t idx, const IN_T* from) { + return me()->LoadImpl(idx, from); + } + + X_VEC LoadImpl(const ptrdiff_t idx, const IN_T* from) { + return hn::LoadU(d_in, from + idx); + } + + // MaskLoad can take in either a positive or negative number for `places`. if + // the number is positive, then it loads the top `places` values, and if it's + // negative, it loads the bottom |places| values. 
example: places = 3 + // | o | o | o | x | x | x | x | x | + // example places = -3 + // | x | x | x | x | x | o | o | o | + X_VEC MaskLoad(const ptrdiff_t idx, const IN_T* from, + const ptrdiff_t places) { + return me()->MaskLoadImpl(idx, from, places); + } + + X_VEC MaskLoadImpl(const ptrdiff_t idx, const IN_T* from, + const ptrdiff_t places) { + auto mask = hn::FirstN(d_in, static_cast(places)); + auto maskneg = hn::Not(hn::FirstN( + d_in, + static_cast(places + static_cast(ActualLanes())))); + if (places < 0) mask = maskneg; + + return hn::MaskedLoad(mask, d_in, from + idx); + } + + bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { + return me()->StoreAndShortCircuitImpl(idx, to, x); + } + + bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { + hn::StoreU(x, d_out, to + idx); + return true; + } + + ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, + ptrdiff_t const places) { + return me()->MaskStoreImpl(idx, to, x, places); + } + + ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, + const ptrdiff_t places) { + auto mask = hn::FirstN(d_out, static_cast(places)); + auto maskneg = hn::Not(hn::FirstN( + d_out, + static_cast(places + static_cast(ActualLanes())))); + if (places < 0) mask = maskneg; + + hn::BlendedStore(x, mask, d_out, to + idx); + return std::abs(places); + } + + ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); } + + ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) { + // default does nothing + (void)x; + (void)to; + return 0; + } + + void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { + me()->ReduceImpl(x0, x1, x2, y); + } + + void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { + // default does nothing + (void)x0; + (void)x1; + (void)x2; + (void)y; + } +}; + +template +struct UnrollerUnit2D { + DERIVED* me() { return static_cast(this); } + + static constexpr size_t kMaxTSize = + 
HWY_MAX(sizeof(IN0_T), HWY_MAX(sizeof(IN1_T), sizeof(OUT_T))); + using LargerT = SignedFromSize; // only the size matters. + + static constexpr size_t MaxUnitLanes() { + return HWY_MAX_LANES_D(hn::ScalableTag); + } + static size_t ActualLanes() { return Lanes(hn::ScalableTag()); } + + using LargerD = hn::CappedTag; + + using I0T = hn::Rebind; + using I1T = hn::Rebind; + using OT = hn::Rebind; + I0T d_in0; + I1T d_in1; + OT d_out; + using Y_VEC = hn::Vec; + using X0_VEC = hn::Vec; + using X1_VEC = hn::Vec; + + hn::Vec Func(const ptrdiff_t idx, const hn::Vec x0, + const hn::Vec x1, const Y_VEC y) { + return me()->Func(idx, x0, x1, y); + } + + X0_VEC X0Init() { return me()->X0InitImpl(); } + + X0_VEC X0InitImpl() { return hn::Zero(d_in0); } + + X1_VEC X1Init() { return me()->X1InitImpl(); } + + X1_VEC X1InitImpl() { return hn::Zero(d_in1); } + + Y_VEC YInit() { return me()->YInitImpl(); } + + Y_VEC YInitImpl() { return hn::Zero(d_out); } + + X0_VEC Load0(const ptrdiff_t idx, const IN0_T* from) { + return me()->Load0Impl(idx, from); + } + + X0_VEC Load0Impl(const ptrdiff_t idx, const IN0_T* from) { + return hn::LoadU(d_in0, from + idx); + } + + X1_VEC Load1(const ptrdiff_t idx, const IN1_T* from) { + return me()->Load1Impl(idx, from); + } + + X1_VEC Load1Impl(const ptrdiff_t idx, const IN1_T* from) { + return hn::LoadU(d_in1, from + idx); + } + + // maskload can take in either a positive or negative number for `places`. if + // the number is positive, then it loads the top `places` values, and if it's + // negative, it loads the bottom |places| values. 
example: places = 3 + // | o | o | o | x | x | x | x | x | + // example places = -3 + // | x | x | x | x | x | o | o | o | + X0_VEC MaskLoad0(const ptrdiff_t idx, const IN0_T* from, + const ptrdiff_t places) { + return me()->MaskLoad0Impl(idx, from, places); + } + + X0_VEC MaskLoad0Impl(const ptrdiff_t idx, const IN0_T* from, + const ptrdiff_t places) { + auto mask = hn::FirstN(d_in0, static_cast(places)); + auto maskneg = hn::Not(hn::FirstN( + d_in0, + static_cast(places + static_cast(ActualLanes())))); + if (places < 0) mask = maskneg; + + return hn::MaskedLoad(mask, d_in0, from + idx); + } + + hn::Vec MaskLoad1(const ptrdiff_t idx, const IN1_T* from, + const ptrdiff_t places) { + return me()->MaskLoad1Impl(idx, from, places); + } + + hn::Vec MaskLoad1Impl(const ptrdiff_t idx, const IN1_T* from, + const ptrdiff_t places) { + auto mask = hn::FirstN(d_in1, static_cast(places)); + auto maskneg = hn::Not(hn::FirstN( + d_in1, + static_cast(places + static_cast(ActualLanes())))); + if (places < 0) mask = maskneg; + + return hn::MaskedLoad(mask, d_in1, from + idx); + } + + // store returns a bool that is `false` when + bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { + return me()->StoreAndShortCircuitImpl(idx, to, x); + } + + bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { + hn::StoreU(x, d_out, to + idx); + return true; + } + + ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, + const ptrdiff_t places) { + return me()->MaskStoreImpl(idx, to, x, places); + } + + ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, + const ptrdiff_t places) { + auto mask = hn::FirstN(d_out, static_cast(places)); + auto maskneg = hn::Not(hn::FirstN( + d_out, + static_cast(places + static_cast(ActualLanes())))); + if (places < 0) mask = maskneg; + + hn::BlendedStore(x, mask, d_out, to + idx); + return std::abs(places); + } + + ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, 
to); } + + ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) { + // default does nothing + (void)x; + (void)to; + return 0; + } + + void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { + me()->ReduceImpl(x0, x1, x2, y); + } + + void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { + // default does nothing + (void)x0; + (void)x1; + (void)x2; + (void)y; + } +}; + +template +inline void Unroller(FUNC& f, const IN_T* HWY_RESTRICT x, OUT_T* HWY_RESTRICT y, + const ptrdiff_t n) { + auto xx = f.X0Init(); + auto yy = f.YInit(); + ptrdiff_t i = 0; + +#if HWY_MEM_OPS_MIGHT_FAULT + constexpr auto lane_sz = + static_cast(RemoveRef::MaxUnitLanes()); + if (n < lane_sz) { + const DFromV d; + // this may not fit on the stack for HWY_RVV, but we do not reach this code + // there + HWY_ALIGN IN_T xtmp[static_cast(lane_sz)]; + HWY_ALIGN OUT_T ytmp[static_cast(lane_sz)]; + + CopyBytes(x, xtmp, static_cast(n) * sizeof(IN_T)); + xx = f.MaskLoad(0, xtmp, n); + yy = f.Func(0, xx, yy); + Store(Zero(d), d, ytmp); + i += f.MaskStore(0, ytmp, yy, n); + i += f.Reduce(yy, ytmp); + CopyBytes(ytmp, y, static_cast(i) * sizeof(OUT_T)); + return; + } +#endif + + const ptrdiff_t actual_lanes = + static_cast(RemoveRef::ActualLanes()); + if (n > 4 * actual_lanes) { + auto xx1 = f.X0Init(); + auto yy1 = f.YInit(); + auto xx2 = f.X0Init(); + auto yy2 = f.YInit(); + auto xx3 = f.X0Init(); + auto yy3 = f.YInit(); + + while (i + 4 * actual_lanes - 1 < n) { + xx = f.Load(i, x); + i += actual_lanes; + xx1 = f.Load(i, x); + i += actual_lanes; + xx2 = f.Load(i, x); + i += actual_lanes; + xx3 = f.Load(i, x); + i -= 3 * actual_lanes; + + yy = f.Func(i, xx, yy); + yy1 = f.Func(i + actual_lanes, xx1, yy1); + yy2 = f.Func(i + 2 * actual_lanes, xx2, yy2); + yy3 = f.Func(i + 3 * actual_lanes, xx3, yy3); + + if (!f.StoreAndShortCircuit(i, y, yy)) return; + i += actual_lanes; + if (!f.StoreAndShortCircuit(i, y, yy1)) return; + i += actual_lanes; + if (!f.StoreAndShortCircuit(i, y, 
yy2)) return; + i += actual_lanes; + if (!f.StoreAndShortCircuit(i, y, yy3)) return; + i += actual_lanes; + } + + f.Reduce(yy3, yy2, yy1, &yy); + } + + while (i + actual_lanes - 1 < n) { + xx = f.Load(i, x); + yy = f.Func(i, xx, yy); + if (!f.StoreAndShortCircuit(i, y, yy)) return; + i += actual_lanes; + } + + if (i != n) { + xx = f.MaskLoad(n - actual_lanes, x, i - n); + yy = f.Func(n - actual_lanes, xx, yy); + f.MaskStore(n - actual_lanes, y, yy, i - n); + } + + f.Reduce(yy, y); +} + +template +inline void Unroller(FUNC& HWY_RESTRICT f, IN0_T* HWY_RESTRICT x0, + IN1_T* HWY_RESTRICT x1, OUT_T* HWY_RESTRICT y, + const ptrdiff_t n) { + const ptrdiff_t lane_sz = + static_cast(RemoveRef::ActualLanes()); + + auto xx00 = f.X0Init(); + auto xx10 = f.X1Init(); + auto yy = f.YInit(); + + ptrdiff_t i = 0; + +#if HWY_MEM_OPS_MIGHT_FAULT + if (n < lane_sz) { + const DFromV d; + // this may not fit on the stack for HWY_RVV, but we do not reach this code + // there + constexpr auto max_lane_sz = + static_cast(RemoveRef::MaxUnitLanes()); + HWY_ALIGN IN0_T xtmp0[static_cast(max_lane_sz)]; + HWY_ALIGN IN1_T xtmp1[static_cast(max_lane_sz)]; + HWY_ALIGN OUT_T ytmp[static_cast(max_lane_sz)]; + + CopyBytes(x0, xtmp0, static_cast(n) * sizeof(IN0_T)); + CopyBytes(x1, xtmp1, static_cast(n) * sizeof(IN1_T)); + xx00 = f.MaskLoad0(0, xtmp0, n); + xx10 = f.MaskLoad1(0, xtmp1, n); + yy = f.Func(0, xx00, xx10, yy); + Store(Zero(d), d, ytmp); + i += f.MaskStore(0, ytmp, yy, n); + i += f.Reduce(yy, ytmp); + CopyBytes(ytmp, y, static_cast(i) * sizeof(OUT_T)); + return; + } +#endif + + if (n > 4 * lane_sz) { + auto xx01 = f.X0Init(); + auto xx11 = f.X1Init(); + auto yy1 = f.YInit(); + auto xx02 = f.X0Init(); + auto xx12 = f.X1Init(); + auto yy2 = f.YInit(); + auto xx03 = f.X0Init(); + auto xx13 = f.X1Init(); + auto yy3 = f.YInit(); + + while (i + 4 * lane_sz - 1 < n) { + xx00 = f.Load0(i, x0); + xx10 = f.Load1(i, x1); + i += lane_sz; + xx01 = f.Load0(i, x0); + xx11 = f.Load1(i, x1); + i += 
lane_sz; + xx02 = f.Load0(i, x0); + xx12 = f.Load1(i, x1); + i += lane_sz; + xx03 = f.Load0(i, x0); + xx13 = f.Load1(i, x1); + i -= 3 * lane_sz; + + yy = f.Func(i, xx00, xx10, yy); + yy1 = f.Func(i + lane_sz, xx01, xx11, yy1); + yy2 = f.Func(i + 2 * lane_sz, xx02, xx12, yy2); + yy3 = f.Func(i + 3 * lane_sz, xx03, xx13, yy3); + + if (!f.StoreAndShortCircuit(i, y, yy)) return; + i += lane_sz; + if (!f.StoreAndShortCircuit(i, y, yy1)) return; + i += lane_sz; + if (!f.StoreAndShortCircuit(i, y, yy2)) return; + i += lane_sz; + if (!f.StoreAndShortCircuit(i, y, yy3)) return; + i += lane_sz; + } + + f.Reduce(yy3, yy2, yy1, &yy); + } + + while (i + lane_sz - 1 < n) { + xx00 = f.Load0(i, x0); + xx10 = f.Load1(i, x1); + yy = f.Func(i, xx00, xx10, yy); + if (!f.StoreAndShortCircuit(i, y, yy)) return; + i += lane_sz; + } + + if (i != n) { + xx00 = f.MaskLoad0(n - lane_sz, x0, i - n); + xx10 = f.MaskLoad1(n - lane_sz, x1, i - n); + yy = f.Func(n - lane_sz, xx00, xx10, yy); + f.MaskStore(n - lane_sz, y, yy, i - n); + } + + f.Reduce(yy, y); +} + +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ diff --git a/lib/highway/hwy/detect_compiler_arch.h b/lib/highway/hwy/detect_compiler_arch.h new file mode 100644 index 00000000000..792ec5c756c --- /dev/null +++ b/lib/highway/hwy/detect_compiler_arch.h @@ -0,0 +1,527 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_ +#define HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_ + +// Detects compiler and arch from predefined macros. Zero dependencies for +// inclusion by foreach_target.h. + +// Add to #if conditions to prevent IDE from graying out code. +// Note for clangd users: There is no predefined macro in clangd, so you must +// manually add these two lines (without the preceding '// ') to your project's +// `.clangd` file: +// CompileFlags: +// Add: [-D__CLANGD__] +#if (defined __CDT_PARSER__) || (defined __INTELLISENSE__) || \ + (defined Q_CREATOR_RUN) || (defined __CLANGD__) || \ + (defined GROK_ELLIPSIS_BUILD) || (defined __JETBRAINS_IDE__) +#define HWY_IDE 1 +#else +#define HWY_IDE 0 +#endif + +//------------------------------------------------------------------------------ +// Compiler + +// Actual MSVC, not clang-cl, which defines _MSC_VER but doesn't behave like +// MSVC in other aspects (e.g. HWY_DIAGNOSTICS). +#if defined(_MSC_VER) && !defined(__clang__) +#define HWY_COMPILER_MSVC _MSC_VER +#else +#define HWY_COMPILER_MSVC 0 +#endif + +#if defined(_MSC_VER) && defined(__clang__) +#define HWY_COMPILER_CLANGCL _MSC_VER +#else +#define HWY_COMPILER_CLANGCL 0 +#endif + +#ifdef __INTEL_COMPILER +#define HWY_COMPILER_ICC __INTEL_COMPILER +#else +#define HWY_COMPILER_ICC 0 +#endif + +#ifdef __INTEL_LLVM_COMPILER +#define HWY_COMPILER_ICX __INTEL_LLVM_COMPILER +#else +#define HWY_COMPILER_ICX 0 +#endif + +// HWY_COMPILER_GCC is a generic macro for all compilers implementing the GNU +// compiler extensions (eg. Clang, Intel...) +#ifdef __GNUC__ +#define HWY_COMPILER_GCC (__GNUC__ * 100 + __GNUC_MINOR__) +#else +#define HWY_COMPILER_GCC 0 +#endif + +#ifndef HWY_COMPILER_CLANG // Allow user override. +#ifdef __clang__ // Clang or clang-cl, not GCC. 
+// In case of Apple LLVM (whose version number is unrelated to that of LLVM) or +// an invalid version number, deduce it from the presence of warnings. +// Originally based on +// https://github.com/simd-everywhere/simde/blob/47d6e603de9d04ee05cdfbc57cf282a02be1bf2a/simde/simde-detect-clang.h#L59. +// Please send updates below to them as well, thanks! +#if defined(__apple_build_version__) || __clang_major__ >= 999 +#if __has_builtin(__builtin_elementwise_fshl) +#define HWY_COMPILER_CLANG 2201 +#elif __has_builtin(__builtin_structured_binding_size) +#define HWY_COMPILER_CLANG 2101 +#elif __has_builtin(__builtin_common_type) +#define HWY_COMPILER_CLANG 2001 +#elif __has_warning("-Wreturn-mismatch") +#define HWY_COMPILER_CLANG 1901 +#elif __has_warning("-Woverriding-option") +#define HWY_COMPILER_CLANG 1801 +// No new warnings in 17.0, and Apple LLVM 15.3, which should be 1600, already +// has the unsafe_buffer_usage attribute, so we instead check for new builtins. +#elif __has_builtin(__builtin_nondeterministic_value) +#define HWY_COMPILER_CLANG 1700 +#elif __has_attribute(nouwtable) // no new warnings in 16.0 +#define HWY_COMPILER_CLANG 1600 +#elif __has_warning("-Warray-parameter") +#define HWY_COMPILER_CLANG 1500 +#elif __has_warning("-Wbitwise-instead-of-logical") +#define HWY_COMPILER_CLANG 1400 +#elif __has_warning("-Wreserved-identifier") +#define HWY_COMPILER_CLANG 1300 +#elif __has_warning("-Wformat-insufficient-args") +#define HWY_COMPILER_CLANG 1200 +#elif __has_warning("-Wimplicit-const-int-float-conversion") +#define HWY_COMPILER_CLANG 1100 +#elif __has_warning("-Wmisleading-indentation") +#define HWY_COMPILER_CLANG 1000 +#elif defined(__FILE_NAME__) +#define HWY_COMPILER_CLANG 900 +#elif __has_warning("-Wextra-semi-stmt") || \ + __has_builtin(__builtin_rotateleft32) +#define HWY_COMPILER_CLANG 800 +// For reasons unknown, XCode 10.3 (Apple LLVM version 10.0.1) is apparently +// based on Clang 7, but does not support the warning we test. 
+// See https://en.wikipedia.org/wiki/Xcode#Toolchain_versions and +// https://trac.macports.org/wiki/XcodeVersionInfo. +#elif __has_warning("-Wc++98-compat-extra-semi") || \ + (defined(__apple_build_version__) && __apple_build_version__ >= 10010000) +#define HWY_COMPILER_CLANG 700 +#else // Anything older than 7.0 is not recommended for Highway. +#define HWY_COMPILER_CLANG 600 +#endif // __has_warning chain +#else // use normal version +#define HWY_COMPILER_CLANG (__clang_major__ * 100 + __clang_minor__) +#define HWY_COMPILER3_CLANG \ + (__clang_major__ * 10000 + __clang_minor__ * 100 + __clang_patchlevel__) +#endif +#else // Not clang +#define HWY_COMPILER_CLANG 0 +#define HWY_COMPILER3_CLANG 0 +#endif // __clang__ +#endif // HWY_COMPILER_CLANG + +// User-defined or deduced HWY_COMPILER_CLANG: derive HWY_COMPILER3_CLANG. +#ifndef HWY_COMPILER3_CLANG +#define HWY_COMPILER3_CLANG (HWY_COMPILER_CLANG * 100) +#endif + +#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG && !HWY_COMPILER_ICC && \ + !HWY_COMPILER_ICX +#define HWY_COMPILER_GCC_ACTUAL HWY_COMPILER_GCC +#else +#define HWY_COMPILER_GCC_ACTUAL 0 +#endif + +// More than one may be nonzero, but we want at least one. 
+#if 0 == (HWY_COMPILER_MSVC + HWY_COMPILER_CLANGCL + HWY_COMPILER_ICC + \ + HWY_COMPILER_ICX + HWY_COMPILER_GCC + HWY_COMPILER_CLANG) +#error "Unsupported compiler" +#endif + +// We should only detect one of these (only clang/clangcl/icx overlap) +#if 1 < (!!HWY_COMPILER_MSVC + (!!HWY_COMPILER_ICC & !HWY_COMPILER_ICX) + \ + !!HWY_COMPILER_GCC_ACTUAL + \ + !!(HWY_COMPILER_ICX | HWY_COMPILER_CLANGCL | HWY_COMPILER_CLANG)) +#error "Detected multiple compilers" +#endif + +//------------------------------------------------------------------------------ +// Compiler features and C++ version + +#ifdef __has_builtin +#define HWY_HAS_BUILTIN(name) __has_builtin(name) +#else +#define HWY_HAS_BUILTIN(name) 0 +#endif + +#ifdef __has_attribute +#define HWY_HAS_ATTRIBUTE(name) __has_attribute(name) +#else +#define HWY_HAS_ATTRIBUTE(name) 0 +#endif + +#ifdef __has_cpp_attribute +#define HWY_HAS_CPP_ATTRIBUTE(name) __has_cpp_attribute(name) +#else +#define HWY_HAS_CPP_ATTRIBUTE(name) 0 +#endif + +#ifdef __has_feature +#define HWY_HAS_FEATURE(name) __has_feature(name) +#else +#define HWY_HAS_FEATURE(name) 0 +#endif + +// NOTE: clang ~17 does not correctly handle wrapping __has_include in a macro. + +#if HWY_COMPILER_MSVC && defined(_MSVC_LANG) && _MSVC_LANG > __cplusplus +#define HWY_CXX_LANG _MSVC_LANG +#else +#define HWY_CXX_LANG __cplusplus +#endif + +// Use instead of constexpr to avoid compiler errors for older compilers. This +// macro is for when a constexpr function involves multiple statements and +// loops, which is allowed in C++14 but not before. If the compiler does not +// support C++14 constexpr, this evaluates to nothing. +#if defined(__cpp_constexpr) && __cpp_constexpr >= 201304L +#define HWY_CXX14_CONSTEXPR constexpr +#else +#define HWY_CXX14_CONSTEXPR +#endif + +// Same as above, but for C++17 constexpr, which adds support for lambdas, the +// standard library, and capturing *this. 
+// Note that C++17 constexpr still disallows allocating and virtual functions, +// which are allowed in C++20, but we do not have a use case yet. +#if defined(__cpp_constexpr) && __cpp_constexpr >= 201603L +#define HWY_CXX17_CONSTEXPR constexpr +#else +#define HWY_CXX17_CONSTEXPR +#endif + +// Use instead of `if constexpr` to avoid compiler errors for older compilers. +// When compilers lack C++17 support, this evaluates to a normal if statement. +#if HWY_CXX_LANG >= 201703L || \ + (defined(__cpp_if_constexpr) && __cpp_if_constexpr >= 201606L) +#define HWY_IF_CONSTEXPR if constexpr +#else +#define HWY_IF_CONSTEXPR if +#endif + +// Use for constexpr variables at namespace scope in headers. Constexpr is +// separate to allow using `HWY_CXX14_CONSTEXPR` if required. +#ifndef HWY_INLINE_VAR +#if __cplusplus > 201402L +// C++17: mark as COMDAT to ensure linkers de-duplicate it. See +// https://quuxplusone.github.io/blog/2022/07/08/inline-constexpr/ +#define HWY_INLINE_VAR inline +#else +#define HWY_INLINE_VAR +#endif +#endif + +//------------------------------------------------------------------------------ +// Sanitizers + +#if HWY_HAS_FEATURE(memory_sanitizer) || defined(MEMORY_SANITIZER) || \ + defined(__SANITIZE_MEMORY__) +#define HWY_IS_MSAN 1 +#else +#define HWY_IS_MSAN 0 +#endif + +#if HWY_HAS_FEATURE(address_sanitizer) || defined(ADDRESS_SANITIZER) || \ + defined(__SANITIZE_ADDRESS__) +#define HWY_IS_ASAN 1 +#else +#define HWY_IS_ASAN 0 +#endif + +#if HWY_HAS_FEATURE(hwaddress_sanitizer) || defined(HWADDRESS_SANITIZER) || \ + defined(__SANITIZE_HWADDRESS__) +#define HWY_IS_HWASAN 1 +#else +#define HWY_IS_HWASAN 0 +#endif + +#if HWY_HAS_FEATURE(thread_sanitizer) || defined(THREAD_SANITIZER) || \ + defined(__SANITIZE_THREAD__) +#define HWY_IS_TSAN 1 +#else +#define HWY_IS_TSAN 0 +#endif + +#if HWY_HAS_FEATURE(undefined_behavior_sanitizer) || \ + defined(UNDEFINED_BEHAVIOR_SANITIZER) +#define HWY_IS_UBSAN 1 +#else +#define HWY_IS_UBSAN 0 +#endif + +// MSAN may 
cause lengthy build times or false positives e.g. in AVX3 DemoteTo. +// You can disable MSAN by adding this attribute to the function that fails. +#if HWY_IS_MSAN +#define HWY_ATTR_NO_MSAN __attribute__((no_sanitize_memory)) +#else +#define HWY_ATTR_NO_MSAN +#endif + +#if HWY_IS_ASAN || HWY_IS_HWASAN || HWY_IS_MSAN || HWY_IS_TSAN || HWY_IS_UBSAN +#define HWY_IS_SANITIZER 1 +#else +#define HWY_IS_SANITIZER 0 +#endif + +//------------------------------------------------------------------------------ +// Architecture + +#if defined(__i386__) || defined(_M_IX86) +#define HWY_ARCH_X86_32 1 +#else +#define HWY_ARCH_X86_32 0 +#endif + +#if defined(__x86_64__) || defined(_M_X64) +#define HWY_ARCH_X86_64 1 +#else +#define HWY_ARCH_X86_64 0 +#endif + +#if HWY_ARCH_X86_32 && HWY_ARCH_X86_64 +#error "Cannot have both x86-32 and x86-64" +#endif + +#if HWY_ARCH_X86_32 || HWY_ARCH_X86_64 +#define HWY_ARCH_X86 1 +#else +#define HWY_ARCH_X86 0 +#endif + +#if defined(__powerpc64__) || defined(_M_PPC) || defined(__powerpc__) +#define HWY_ARCH_PPC 1 +#else +#define HWY_ARCH_PPC 0 +#endif + +#if defined(__powerpc64__) || (HWY_ARCH_PPC && defined(__64BIT__)) +#define HWY_ARCH_PPC_64 1 +#else +#define HWY_ARCH_PPC_64 0 +#endif + +// aarch32 is currently not supported; please raise an issue if you want it. +#if defined(__ARM_ARCH_ISA_A64) || defined(__aarch64__) || defined(_M_ARM64) +#define HWY_ARCH_ARM_A64 1 +#else +#define HWY_ARCH_ARM_A64 0 +#endif + +#if (defined(__ARM_ARCH) && __ARM_ARCH == 7) || (defined(_M_ARM) && _M_ARM == 7) +#define HWY_ARCH_ARM_V7 1 +#else +#define HWY_ARCH_ARM_V7 0 +#endif + +#if HWY_ARCH_ARM_A64 && HWY_ARCH_ARM_V7 +#error "Cannot have both A64 and V7" +#endif + +// Any *supported* version of Arm, i.e. 7 or later +#if HWY_ARCH_ARM_A64 || HWY_ARCH_ARM_V7 +#define HWY_ARCH_ARM 1 +#else +#define HWY_ARCH_ARM 0 +#endif + +// Older than Armv7 (e.g. armel aka Armv5) => we do not support SIMD. 
+#if (defined(__arm__) || defined(_M_ARM)) && !HWY_ARCH_ARM +#define HWY_ARCH_ARM_OLD 1 +#else +#define HWY_ARCH_ARM_OLD 0 +#endif + +#if defined(__EMSCRIPTEN__) || defined(__wasm__) || defined(__WASM__) +#define HWY_ARCH_WASM 1 +#else +#define HWY_ARCH_WASM 0 +#endif + +#ifdef __riscv +#define HWY_ARCH_RISCV 1 +#else +#define HWY_ARCH_RISCV 0 +#endif +// DEPRECATED names; please use HWY_ARCH_RISCV instead. +#define HWY_ARCH_RVV HWY_ARCH_RISCV + +#if HWY_ARCH_RISCV && defined(__riscv_xlen) + +#if __riscv_xlen == 32 +#define HWY_ARCH_RISCV_32 1 +#else +#define HWY_ARCH_RISCV_32 0 +#endif + +#if __riscv_xlen == 64 +#define HWY_ARCH_RISCV_64 1 +#else +#define HWY_ARCH_RISCV_64 0 +#endif + +#else // !HWY_ARCH_RISCV || !defined(__riscv_xlen) +#define HWY_ARCH_RISCV_32 0 +#define HWY_ARCH_RISCV_64 0 +#endif // HWY_ARCH_RISCV && defined(__riscv_xlen) + +#if HWY_ARCH_RISCV_32 && HWY_ARCH_RISCV_64 +#error "Cannot have both RISCV_32 and RISCV_64" +#endif + +#if defined(__s390x__) +#define HWY_ARCH_S390X 1 +#else +#define HWY_ARCH_S390X 0 +#endif + +#if defined(__loongarch64__) || defined(__loongarch64) || \ + (defined(__loongarch_grlen) && __loongarch_grlen == 64) +#define HWY_ARCH_LOONGARCH_64 1 +#else +#define HWY_ARCH_LOONGARCH_64 0 +#endif + +#if defined(__loongarch__) && !HWY_ARCH_LOONGARCH_64 +#define HWY_ARCH_LOONGARCH_32 1 +#else +#define HWY_ARCH_LOONGARCH_32 0 +#endif + +#if HWY_ARCH_LOONGARCH_64 || HWY_ARCH_LOONGARCH_32 +#define HWY_ARCH_LOONGARCH 1 +#else +#define HWY_ARCH_LOONGARCH 0 +#endif + +#if defined(__hexagon__) || defined(__HEXAGON_ARCH__) +#define HWY_ARCH_HEXAGON 1 +#else +#define HWY_ARCH_HEXAGON 0 +#endif + +// It is an error to detect multiple architectures at the same time, but OK to +// detect none of the above. 
+#if (HWY_ARCH_X86 + HWY_ARCH_PPC + HWY_ARCH_ARM + HWY_ARCH_ARM_OLD + \ + HWY_ARCH_WASM + HWY_ARCH_RISCV + HWY_ARCH_S390X + HWY_ARCH_LOONGARCH + \ + HWY_ARCH_HEXAGON) > 1 +#error "Must not detect more than one architecture" +#endif + +#if HWY_ARCH_RISCV +#define HWY_ARCH_MAX_BYTES 65536 +#elif HWY_ARCH_ARM_A64 +#define HWY_ARCH_MAX_BYTES 256 +#elif HWY_ARCH_HEXAGON +#define HWY_ARCH_MAX_BYTES 128 +#elif HWY_ARCH_X86 +#define HWY_ARCH_MAX_BYTES 64 +#elif HWY_ARCH_WASM || HWY_ARCH_LOONGARCH +#define HWY_ARCH_MAX_BYTES 32 +#elif HWY_ARCH_PPC || HWY_ARCH_S390X || HWY_ARCH_ARM_V7 || HWY_ARCH_ARM_OLD +#define HWY_ARCH_MAX_BYTES 16 +#else +#error "Missing case for HWY_ARCH_*" +#endif + +//------------------------------------------------------------------------------ +// Operating system + +#if defined(_WIN32) || defined(_WIN64) +#define HWY_OS_WIN 1 +#else +#define HWY_OS_WIN 0 +#endif + +#if defined(linux) || defined(__linux__) +#define HWY_OS_LINUX 1 +#else +#define HWY_OS_LINUX 0 +#endif + +// iOS or Mac +#if defined(__APPLE__) +#define HWY_OS_APPLE 1 +#else +#define HWY_OS_APPLE 0 +#endif + +#if defined(__FreeBSD__) +#define HWY_OS_FREEBSD 1 +#else +#define HWY_OS_FREEBSD 0 +#endif + +// It is an error to detect multiple OSes at the same time, but OK to +// detect none of the above. 
+#if (HWY_OS_WIN + HWY_OS_LINUX + HWY_OS_APPLE + HWY_OS_FREEBSD) > 1 +#error "Must not detect more than one OS" +#endif + +//------------------------------------------------------------------------------ +// Endianness + +#if HWY_COMPILER_MSVC +#if HWY_ARCH_PPC && defined(_XBOX_VER) && _XBOX_VER >= 200 +// XBox 360 is big-endian +#define HWY_IS_LITTLE_ENDIAN 0 +#define HWY_IS_BIG_ENDIAN 1 +#else +// All other targets supported by MSVC are little-endian +#define HWY_IS_LITTLE_ENDIAN 1 +#define HWY_IS_BIG_ENDIAN 0 +#endif // HWY_ARCH_PPC && defined(_XBOX_VER) && _XBOX_VER >= 200 +#elif defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) && \ + __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define HWY_IS_LITTLE_ENDIAN 1 +#define HWY_IS_BIG_ENDIAN 0 +#elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \ + __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define HWY_IS_LITTLE_ENDIAN 0 +#define HWY_IS_BIG_ENDIAN 1 +#else +#error "Unable to detect endianness or unsupported byte order" +#endif + +#if (HWY_IS_LITTLE_ENDIAN + HWY_IS_BIG_ENDIAN) != 1 +#error "Must only detect one byte order" +#endif + +//------------------------------------------------------------------------------ +// Features checked in set_macros-inl.h + +// Compiler supports ACLE __bf16, not necessarily with operators. +// +// Disable the __bf16 type on AArch64 with GCC 13 or earlier as there is a bug +// in GCC 13 and earlier that sometimes causes BF16 constant values to be +// incorrectly loaded on AArch64, and this GCC bug on AArch64 is +// described at https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111867. 
+#if HWY_ARCH_ARM_A64 && \ + (HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400) +#define HWY_ARM_HAVE_SCALAR_BF16_TYPE 1 +#else +#define HWY_ARM_HAVE_SCALAR_BF16_TYPE 0 +#endif + +#endif // HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_ diff --git a/lib/highway/hwy/detect_targets.h b/lib/highway/hwy/detect_targets.h new file mode 100644 index 00000000000..16d59c21d24 --- /dev/null +++ b/lib/highway/hwy/detect_targets.h @@ -0,0 +1,983 @@ +// Copyright 2021 Google LLC +// Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_DETECT_TARGETS_H_ +#define HIGHWAY_HWY_DETECT_TARGETS_H_ + +// Defines targets and chooses which to enable. + +#include "hwy/detect_compiler_arch.h" + +//------------------------------------------------------------------------------ +// Optional configuration + +// See g3doc/quick_reference.md for documentation of these macros. 
+ +// Uncomment to override the default baseline determined from predefined macros: +// #define HWY_BASELINE_TARGETS (HWY_SSE4 | HWY_SCALAR) + +// Uncomment to override the default blocklist: +// #define HWY_BROKEN_TARGETS HWY_AVX3 + +// Uncomment to definitely avoid generating those target(s): +// #define HWY_DISABLED_TARGETS HWY_SSE4 + +// Uncomment to avoid emitting BMI/BMI2/FMA instructions (allows generating +// AVX2 target for VMs which support AVX2 but not the other instruction sets) +// #define HWY_DISABLE_BMI2_FMA + +// Uncomment to enable these on MSVC even if the predefined macros are not set. +// #define HWY_WANT_SSE2 1 +// #define HWY_WANT_SSSE3 1 +// #define HWY_WANT_SSE4 1 + +//------------------------------------------------------------------------------ +// Targets + +// Unique bit value for each target. A lower value is "better" (e.g. more lanes) +// than a higher value within the same group/platform - see HWY_STATIC_TARGET. +// +// All values are unconditionally defined so we can test HWY_TARGETS without +// first checking the HWY_ARCH_*. +// +// The C99 preprocessor evaluates #if expressions using intmax_t types. This +// holds at least 64 bits in practice (verified 2022-07-18 via Godbolt on +// 32-bit clang/GCC/MSVC compilers for x86/Arm7/AArch32/RISC-V/WASM). We now +// avoid overflow when computing HWY_TARGETS (subtracting one instead of +// left-shifting 2^62), but still do not use bit 63 because it is the sign bit. + +// --------------------------- x86: 15 targets (+ one fallback) +// Bits 0..2 reserved (3 targets) +#define HWY_AVX10_2 (1LL << 3) // AVX10.2 with 512-bit vectors +#define HWY_AVX3_SPR (1LL << 4) +// Bit 5: reserved (1 target) +// Currently `HWY_AVX3_DL` plus `AVX512BF16` and a special case for +// `CompressStore` (10x as fast, still useful on Zen5). We may later also use +// `VPCONFLICT`. Note that `VP2INTERSECT` is available in Zen5. 
+#define HWY_AVX3_ZEN4 (1LL << 6) // see HWY_WANT_AVX3_ZEN4 below + +// Currently satisfiable by Ice Lake (`VNNI`, `VPCLMULQDQ`, `VPOPCNTDQ`, +// `VBMI`, `VBMI2`, `VAES`, `BITALG`, `GFNI`). +#define HWY_AVX3_DL (1LL << 7) +#define HWY_AVX3 (1LL << 8) // HWY_AVX2 plus AVX-512F/BW/CD/DQ/VL +#define HWY_AVX2 (1LL << 9) // HWY_SSE4 plus BMI2 + F16 + FMA +// Bit 10: reserved +#define HWY_SSE4 (1LL << 11) // SSE4.2 plus AES + CLMUL +#define HWY_SSSE3 (1LL << 12) // S-SSE3 +// Bit 13: reserved for SSE3 +#define HWY_SSE2 (1LL << 14) +// The highest bit in the HWY_TARGETS mask that a x86 target can have. Used for +// dynamic dispatch. All x86 target bits must be lower or equal to +// (1 << HWY_HIGHEST_TARGET_BIT_X86) and they can only use +// HWY_MAX_DYNAMIC_TARGETS in total. +#define HWY_HIGHEST_TARGET_BIT_X86 14 + +// --------------------------- Arm: 15 targets (+ one fallback) +// Bits 15..17 reserved (3 targets) +#define HWY_SVE2_128 (1LL << 18) // specialized (e.g. Neoverse V2/N2/N3) +#define HWY_SVE_256 (1LL << 19) // specialized (Neoverse V1) +// Bits 20-22 reserved for later SVE (3 targets) +#define HWY_SVE2 (1LL << 23) +#define HWY_SVE (1LL << 24) +// Bit 25 reserved for NEON +#define HWY_NEON_BF16 (1LL << 26) // fp16/dot/bf16 (e.g. 
Neoverse V2/N2) +// Bit 27 reserved for NEON +#define HWY_NEON (1LL << 28) // Implies support for AES +#define HWY_NEON_WITHOUT_AES (1LL << 29) +#define HWY_HIGHEST_TARGET_BIT_ARM 29 + +#define HWY_ALL_NEON (HWY_NEON_WITHOUT_AES | HWY_NEON | HWY_NEON_BF16) +#define HWY_ALL_SVE (HWY_SVE | HWY_SVE2 | HWY_SVE_256 | HWY_SVE2_128) + +// --------------------------- RISC-V: 9 targets (+ one fallback) +// Bits 30..36 reserved (7 targets) +#define HWY_RVV (1LL << 37) +// Bit 38 reserved +#define HWY_HIGHEST_TARGET_BIT_RVV 38 + +// --------------------------- LoongArch: 3 targets (+ one fallback) +// Bits 39 reserved (1 target) +#define HWY_LASX (1LL << 40) +#define HWY_LSX (1LL << 41) +#define HWY_HIGHEST_TARGET_BIT_LOONGARCH 41 + +// --------------------------- Future expansion: 1 target +// Bits 42 reserved + +// --------------------------- IBM Power/ZSeries: 9 targets (+ one fallback) +// Bits 43..46 reserved (4 targets) +#define HWY_PPC10 (1LL << 47) // v3.1 +#define HWY_PPC9 (1LL << 48) // v3.0 +#define HWY_PPC8 (1LL << 49) // v2.07 +#define HWY_Z15 (1LL << 50) // Z15 +#define HWY_Z14 (1LL << 51) // Z14 +#define HWY_HIGHEST_TARGET_BIT_PPC 51 + +#define HWY_ALL_PPC (HWY_PPC8 | HWY_PPC9 | HWY_PPC10) + +// --------------------------- WebAssembly: 9 targets (+ one fallback) +// Bits 52..57 reserved (6 targets) +#define HWY_WASM_EMU256 (1LL << 58) // Experimental +#define HWY_WASM (1LL << 59) +// Bits 60 reserved +#define HWY_HIGHEST_TARGET_BIT_WASM 60 + +// --------------------------- Emulation: 2 targets + +#define HWY_EMU128 (1LL << 61) +// We do not add/left-shift, so this will not overflow to a negative number. +#define HWY_SCALAR (1LL << 62) +#define HWY_HIGHEST_TARGET_BIT_SCALAR 62 + +// Do not use bit 63 - would be confusing to have negative numbers. + +//------------------------------------------------------------------------------ +// Set default blocklists + +// Disabled means excluded from enabled at user's request. 
A separate config +// macro allows disabling without deactivating the blocklist below. +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS 0 +#endif + +// Broken means excluded from enabled due to known compiler issues. We define +// separate HWY_BROKEN_* and then OR them together (more than one might apply). + +#ifndef HWY_BROKEN_CLANG6 // allow override +// x86 clang-6: we saw multiple AVX2/3 compile errors and in one case invalid +// SSE4 codegen (possibly only for msan), so disable all those targets. +#if HWY_ARCH_X86 && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700) +#define HWY_BROKEN_CLANG6 (HWY_SSE4 | (HWY_SSE4 - 1)) +// This entails a major speed reduction, so warn unless the user explicitly +// opts in to scalar-only. +#if !defined(HWY_COMPILE_ONLY_SCALAR) +#pragma message("x86 Clang <= 6: define HWY_COMPILE_ONLY_SCALAR or upgrade.") +#endif + +#else +#define HWY_BROKEN_CLANG6 0 +#endif +#endif // HWY_BROKEN_CLANG6 + +#ifndef HWY_BROKEN_32BIT // allow override +// 32-bit may fail to compile AVX2/3. +#if HWY_ARCH_X86_32 +// GCC-13 is ok with AVX2: +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1300) +#define HWY_BROKEN_32BIT (HWY_AVX3 | (HWY_AVX3 - 1)) +#else +#define HWY_BROKEN_32BIT (HWY_AVX2 | (HWY_AVX2 - 1)) +#endif +#else +#define HWY_BROKEN_32BIT 0 +#endif +#endif // HWY_BROKEN_32BIT + +#ifndef HWY_BROKEN_MSVC // allow override +// MSVC AVX3 support is buggy: https://github.com/Mysticial/Flops/issues/16 +#if HWY_COMPILER_MSVC != 0 +#define HWY_BROKEN_MSVC (HWY_AVX3 | (HWY_AVX3 - 1)) +#else +#define HWY_BROKEN_MSVC 0 +#endif +#endif // HWY_BROKEN_MSVC + +#ifndef HWY_BROKEN_AVX10_2 // allow override +// AVX10_2 requires clang >= 20.1 (postpone to 23 due to "avx10.2-512" remnant, +// only removed in https://github.com/llvm/llvm-project/pull/157034) or +// gcc >= 15.2 with binutils 2.44. 
+#if (HWY_COMPILER_CLANG < 2300) && (HWY_COMPILER_GCC_ACTUAL < 1502) +#define HWY_BROKEN_AVX10_2 HWY_AVX10_2 +#else +#define HWY_BROKEN_AVX10_2 0 +#endif +#endif // HWY_BROKEN_AVX10_2 + +#ifndef HWY_BROKEN_AVX3_DL_ZEN4 // allow override +// AVX3_DL and AVX3_ZEN4 require clang >= 7 (ensured above), gcc >= 10.1 or ICC +// 2021. +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1001) || \ + (HWY_COMPILER_ICC && HWY_COMPILER_ICC < 2021) +#define HWY_BROKEN_AVX3_DL_ZEN4 (HWY_AVX3_DL | HWY_AVX3_ZEN4) +#else +#define HWY_BROKEN_AVX3_DL_ZEN4 0 +#endif +#endif // HWY_BROKEN_AVX3_DL_ZEN4 + +#ifndef HWY_BROKEN_AVX3_SPR // allow override +// AVX3_SPR requires clang >= 14, gcc >= 12, or ICC 2021. +#if (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1400) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1200) || \ + (HWY_COMPILER_ICC && HWY_COMPILER_ICC < 2021) +#define HWY_BROKEN_AVX3_SPR (HWY_AVX3_SPR) +#else +#define HWY_BROKEN_AVX3_SPR 0 +#endif +#endif // HWY_BROKEN_AVX3_SPR + +#ifndef HWY_BROKEN_ARM7_BIG_ENDIAN // allow override +// armv7be has not been tested and is not yet supported. +#if HWY_ARCH_ARM_V7 && HWY_IS_BIG_ENDIAN +#define HWY_BROKEN_ARM7_BIG_ENDIAN HWY_ALL_NEON +#else +#define HWY_BROKEN_ARM7_BIG_ENDIAN 0 +#endif +#endif // HWY_BROKEN_ARM7_BIG_ENDIAN + +#ifdef __ARM_NEON_FP +#define HWY_HAVE_NEON_FP __ARM_NEON_FP +#else +#define HWY_HAVE_NEON_FP 0 +#endif + +#ifndef HWY_BROKEN_ARM7_WITHOUT_VFP4 // allow override +// armv7-a without a detected vfpv4 is not supported +// (for example Cortex-A8, Cortex-A9) +// vfpv4 always have neon half-float _and_ FMA. 
+#if HWY_ARCH_ARM_V7 && (__ARM_ARCH_PROFILE == 'A') && \ + !defined(__ARM_VFPV4__) && \ + !((HWY_HAVE_NEON_FP & 0x2 /* half-float */) && (__ARM_FEATURE_FMA == 1)) +#define HWY_BROKEN_ARM7_WITHOUT_VFP4 HWY_ALL_NEON +#else +#define HWY_BROKEN_ARM7_WITHOUT_VFP4 0 +#endif +#endif // HWY_BROKEN_ARM7_WITHOUT_VFP4 + +#ifndef HWY_BROKEN_NEON_BF16 // allow override +// Broken on older compilers: +#if (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1700) || \ + (HWY_COMPILER_GCC_ACTUAL != 0 && HWY_COMPILER_GCC_ACTUAL < 1302) || \ + (defined(__apple_build_version__) && __apple_build_version__ <= 17000000) +#define HWY_BROKEN_NEON_BF16 (HWY_NEON_BF16) +#else +#define HWY_BROKEN_NEON_BF16 0 +#endif +#endif // HWY_BROKEN_NEON_BF16 + +// SVE[2] require recent clang or gcc versions. + +#ifndef HWY_BROKEN_SVE // allow override +// Clang 22+, GCC 10+, except MSAN does not yet support SVE. +// No Apple CPU (at least up to and including M4 and A18) has SVE. +#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2200) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \ + HWY_OS_APPLE || HWY_IS_MSAN +#define HWY_BROKEN_SVE (HWY_SVE | HWY_SVE_256) +#else +#define HWY_BROKEN_SVE 0 +#endif +#endif // HWY_BROKEN_SVE + +#ifndef HWY_BROKEN_SVE2 // allow override +// Clang 22+, GCC 10+, except MSAN does not yet support SVE2. +// No Apple CPU (at least up to and including M4 and A18) has SVE2. +#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2200) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \ + HWY_OS_APPLE || HWY_IS_MSAN +#define HWY_BROKEN_SVE2 (HWY_SVE2) +#else +#define HWY_BROKEN_SVE2 0 +#endif +#endif // HWY_BROKEN_SVE2 + +#ifndef HWY_BROKEN_SVE2_128 // allow override +// GCC 10+. Clang 21 works for SVE2_128, but not for SVE2 nor MSAN. 
+#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2100) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \ + HWY_OS_APPLE || HWY_IS_MSAN +#define HWY_BROKEN_SVE2_128 (HWY_SVE2_128) +#else +#define HWY_BROKEN_SVE2_128 0 +#endif +#endif // HWY_BROKEN_SVE2_128 + +#ifndef HWY_BROKEN_PPC10 // allow override +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1100) +// GCC 10 supports the -mcpu=power10 option but does not support the PPC10 +// vector intrinsics +#define HWY_BROKEN_PPC10 (HWY_PPC10) +#elif HWY_ARCH_PPC && HWY_IS_BIG_ENDIAN && \ + ((HWY_COMPILER3_CLANG && HWY_COMPILER3_CLANG < 160001) || \ + (HWY_COMPILER_GCC_ACTUAL >= 1200 && HWY_COMPILER_GCC_ACTUAL <= 1203) || \ + (HWY_COMPILER_GCC_ACTUAL >= 1300 && HWY_COMPILER_GCC_ACTUAL <= 1301)) +// GCC 12.0 through 12.3 and GCC 13.0 through 13.1 have a compiler bug where the +// vsldoi instruction is sometimes incorrectly optimized out (and this causes +// some of the Highway unit tests to fail on big-endian PPC10). Details about +// this compiler bug can be found at +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109069, and this bug will be +// fixed in the upcoming GCC 12.4 and 13.2 releases. + +// Clang 16.0.0 and earlier (but not Clang 16.0.1 and later) have a compiler +// bug in the LLVM DAGCombiner that causes a zero-extend followed by an +// element insert into a vector, followed by a vector shuffle to be incorrectly +// optimized on big-endian PPC (and which caused some of the Highway unit tests +// to fail on big-endian PPC10). + +// Details about this bug, which has already been fixed in Clang 16.0.1 and +// later, can be found at https://github.com/llvm/llvm-project/issues/61315. 
+#define HWY_BROKEN_PPC10 (HWY_PPC10) +#else +#define HWY_BROKEN_PPC10 0 +#endif +#endif // HWY_BROKEN_PPC10 + +#ifndef HWY_BROKEN_PPC_32BIT // allow override +// PPC8/PPC9/PPC10 targets may fail to compile on 32-bit PowerPC +#if HWY_ARCH_PPC && !HWY_ARCH_PPC_64 +#define HWY_BROKEN_PPC_32BIT (HWY_PPC8 | HWY_PPC9 | HWY_PPC10) +#else +#define HWY_BROKEN_PPC_32BIT 0 +#endif +#endif // HWY_BROKEN_PPC_32BIT + +#ifndef HWY_BROKEN_RVV // allow override +// HWY_RVV fails to compile with GCC < 13 or Clang < 16. +#if HWY_ARCH_RISCV && \ + ((HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1600) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1300)) +#define HWY_BROKEN_RVV (HWY_RVV) +#else +#define HWY_BROKEN_RVV 0 +#endif +#endif // HWY_BROKEN_RVV + +#ifndef HWY_BROKEN_LOONGARCH // allow override +// Using __loongarch_sx and __loongarch_asx macros to +// check whether LSX/LASX targets are available. +// GCC does not work yet, see https://gcc.gnu.org/PR121875. +#if !defined(__loongarch_sx) && \ + !(HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1800) +#define HWY_BROKEN_LOONGARCH (HWY_LSX | HWY_LASX) +#elif !defined(__loongarch_asx) && \ + !(HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1800) +#define HWY_BROKEN_LOONGARCH (HWY_LASX) +#else +#define HWY_BROKEN_LOONGARCH 0 +#endif +#endif // HWY_BROKEN_LOONGARCH + +#ifndef HWY_BROKEN_Z14 // allow override +#if HWY_ARCH_S390X +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1900 +// Clang 18 and earlier have bugs with some ZVector intrinsics +#define HWY_BROKEN_Z14 (HWY_Z14 | HWY_Z15) +#elif HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 +// Z15 target requires GCC 9 or later +#define HWY_BROKEN_Z14 (HWY_Z15) +#else +#define HWY_BROKEN_Z14 0 +#endif +#else // !HWY_ARCH_S390X +#define HWY_BROKEN_Z14 0 +#endif // HWY_ARCH_S390X +#endif // HWY_BROKEN_Z14 + +// Allow the user to override this without any guarantee of success. 
+#ifndef HWY_BROKEN_TARGETS + +#define HWY_BROKEN_TARGETS \ + (HWY_BROKEN_CLANG6 | HWY_BROKEN_32BIT | HWY_BROKEN_MSVC | \ + HWY_BROKEN_AVX10_2 | HWY_BROKEN_AVX3_DL_ZEN4 | HWY_BROKEN_AVX3_SPR | \ + HWY_BROKEN_ARM7_BIG_ENDIAN | HWY_BROKEN_ARM7_WITHOUT_VFP4 | \ + HWY_BROKEN_NEON_BF16 | HWY_BROKEN_SVE | HWY_BROKEN_SVE2 | \ + HWY_BROKEN_SVE2_128 | HWY_BROKEN_PPC10 | HWY_BROKEN_PPC_32BIT | \ + HWY_BROKEN_RVV | HWY_BROKEN_LOONGARCH | HWY_BROKEN_Z14) + +#endif // HWY_BROKEN_TARGETS + +// Enabled means not disabled nor blocklisted. +#define HWY_ENABLED(targets) \ + ((targets) & ~((HWY_DISABLED_TARGETS) | (HWY_BROKEN_TARGETS))) + +// Opt-out for EMU128 (affected by a GCC bug on multiple arches, fixed in 12.3: +// see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=106322). An issue still +// remains with 13.2, see #1683. This is separate from HWY_BROKEN_TARGETS +// because it affects the fallback target, which must always be enabled. If 1, +// we instead choose HWY_SCALAR even without HWY_COMPILE_ONLY_SCALAR being set. +#if !defined(HWY_BROKEN_EMU128) // allow overriding +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1600) || \ + defined(HWY_NO_LIBCXX) +#define HWY_BROKEN_EMU128 1 +#else +#define HWY_BROKEN_EMU128 0 +#endif +#endif // HWY_BROKEN_EMU128 + +//------------------------------------------------------------------------------ +// Detect baseline targets using predefined macros + +// Baseline means the targets for which the compiler is allowed to generate +// instructions, implying the target CPU would have to support them. This does +// not take the blocklist into account. + +#if defined(HWY_COMPILE_ONLY_SCALAR) || HWY_BROKEN_EMU128 +#define HWY_BASELINE_SCALAR HWY_SCALAR +#else +#define HWY_BASELINE_SCALAR HWY_EMU128 +#endif + +// Also check HWY_ARCH to ensure that simulating unknown platforms ends up with +// HWY_TARGET == HWY_BASELINE_SCALAR. 
+ +#if HWY_ARCH_WASM && defined(__wasm_simd128__) +#if defined(HWY_WANT_WASM2) +#define HWY_BASELINE_WASM HWY_WASM_EMU256 +#else +#define HWY_BASELINE_WASM HWY_WASM +#endif // HWY_WANT_WASM2 +#else +#define HWY_BASELINE_WASM 0 +#endif + +// GCC or Clang. +#if HWY_ARCH_PPC && HWY_COMPILER_GCC && defined(__ALTIVEC__) && \ + defined(__VSX__) && defined(__POWER8_VECTOR__) && \ + (defined(__CRYPTO__) || defined(HWY_DISABLE_PPC8_CRYPTO)) +#define HWY_BASELINE_PPC8 HWY_PPC8 +#else +#define HWY_BASELINE_PPC8 0 +#endif + +#if HWY_BASELINE_PPC8 != 0 && defined(__POWER9_VECTOR__) +#define HWY_BASELINE_PPC9 HWY_PPC9 +#else +#define HWY_BASELINE_PPC9 0 +#endif + +#if HWY_BASELINE_PPC9 != 0 && \ + (defined(_ARCH_PWR10) || defined(__POWER10_VECTOR__)) +#define HWY_BASELINE_PPC10 HWY_PPC10 +#else +#define HWY_BASELINE_PPC10 0 +#endif + +#if HWY_ARCH_S390X && defined(__VEC__) && defined(__ARCH__) && __ARCH__ >= 12 +#define HWY_BASELINE_Z14 HWY_Z14 +#else +#define HWY_BASELINE_Z14 0 +#endif + +#if HWY_BASELINE_Z14 && __ARCH__ >= 13 +#define HWY_BASELINE_Z15 HWY_Z15 +#else +#define HWY_BASELINE_Z15 0 +#endif + +#define HWY_BASELINE_SVE2 0 +#define HWY_BASELINE_SVE 0 +#define HWY_BASELINE_NEON 0 + +#if HWY_ARCH_ARM + +// Also check compiler version as done for HWY_ATTAINABLE_SVE2 because the +// static target (influenced here) must be one of the attainable targets. +#if defined(__ARM_FEATURE_SVE2) && \ + (HWY_COMPILER_CLANG >= 1400 || HWY_COMPILER_GCC_ACTUAL >= 1200) +#undef HWY_BASELINE_SVE2 // was 0, will be re-defined +// If user specified -msve-vector-bits=128, they assert the vector length is +// 128 bits and we should use the HWY_SVE2_128 (more efficient for some ops). +#if defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS == 128 +#define HWY_BASELINE_SVE2 HWY_SVE2_128 +// Otherwise we're not sure what the vector length will be. The baseline must be +// unconditionally valid, so we can only assume HWY_SVE2. 
However, when running +// on a CPU with 128-bit vectors, user code that supports dynamic dispatch will +// still benefit from HWY_SVE2_128 because we add it to HWY_ATTAINABLE_TARGETS. +#else +#define HWY_BASELINE_SVE2 HWY_SVE2 +#endif // __ARM_FEATURE_SVE_BITS +#endif // __ARM_FEATURE_SVE2 + +#if defined(__ARM_FEATURE_SVE) && \ + (HWY_COMPILER_CLANG >= 900 || HWY_COMPILER_GCC_ACTUAL >= 800) +#undef HWY_BASELINE_SVE // was 0, will be re-defined +// See above. If user-specified vector length matches our optimization, use it. +#if defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS == 256 +#define HWY_BASELINE_SVE HWY_SVE_256 +#else +#define HWY_BASELINE_SVE HWY_SVE +#endif // __ARM_FEATURE_SVE_BITS +#endif // __ARM_FEATURE_SVE + +// GCC 4.5.4 only defines __ARM_NEON__; 5.4 defines both. +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#undef HWY_BASELINE_NEON +#if defined(__ARM_FEATURE_AES) && \ + defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && \ + defined(__ARM_FEATURE_DOTPROD) && \ + defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) && \ + defined(__ARM_FEATURE_MATMUL_INT8) +#define HWY_BASELINE_NEON HWY_ALL_NEON +#elif defined(__ARM_FEATURE_AES) +#define HWY_BASELINE_NEON (HWY_NEON_WITHOUT_AES | HWY_NEON) +#else +#define HWY_BASELINE_NEON (HWY_NEON_WITHOUT_AES) +#endif // __ARM_FEATURE* +#endif // __ARM_NEON + +#endif // HWY_ARCH_ARM + +// Special handling for MSVC because it has fewer predefined macros: +#if HWY_COMPILER_MSVC + +#if HWY_ARCH_X86_32 +#if _M_IX86_FP >= 2 +#define HWY_CHECK_SSE2 1 +#else +#define HWY_CHECK_SSE2 0 +#endif +#elif HWY_ARCH_X86_64 +#define HWY_CHECK_SSE2 1 +#else +#define HWY_CHECK_SSE2 0 +#endif + +// 1) We can only be sure SSSE3/SSE4 are enabled if AVX is: +// https://stackoverflow.com/questions/18563978/. 
+#if defined(__AVX__) +#define HWY_CHECK_SSSE3 1 +#define HWY_CHECK_SSE4 1 +#else +#define HWY_CHECK_SSSE3 0 +#define HWY_CHECK_SSE4 0 +#endif + +// 2) Cannot check for PCLMUL/AES and BMI2/FMA/F16C individually; we assume +// PCLMUL/AES are available if SSE4 is, and BMI2/FMA/F16C if AVX2 is. +#define HWY_CHECK_PCLMUL_AES 1 +#define HWY_CHECK_BMI2_FMA 1 +#define HWY_CHECK_F16C 1 + +#else // non-MSVC + +#if defined(__SSE2__) +#define HWY_CHECK_SSE2 1 +#else +#define HWY_CHECK_SSE2 0 +#endif + +#if defined(__SSSE3__) +#define HWY_CHECK_SSSE3 1 +#else +#define HWY_CHECK_SSSE3 0 +#endif + +#if defined(__SSE4_1__) && defined(__SSE4_2__) +#define HWY_CHECK_SSE4 1 +#else +#define HWY_CHECK_SSE4 0 +#endif + +// If these are disabled, they should not gate the availability of SSE4/AVX2. +#if defined(HWY_DISABLE_PCLMUL_AES) || (defined(__PCLMUL__) && defined(__AES__)) +#define HWY_CHECK_PCLMUL_AES 1 +#else +#define HWY_CHECK_PCLMUL_AES 0 +#endif + +#if defined(HWY_DISABLE_BMI2_FMA) || (defined(__BMI2__) && defined(__FMA__)) +#define HWY_CHECK_BMI2_FMA 1 +#else +#define HWY_CHECK_BMI2_FMA 0 +#endif + +#if defined(HWY_DISABLE_F16C) || defined(__F16C__) +#define HWY_CHECK_F16C 1 +#else +#define HWY_CHECK_F16C 0 +#endif + +#endif // non-MSVC + +#if HWY_ARCH_X86 && \ + ((defined(HWY_WANT_SSE2) && HWY_WANT_SSE2) || HWY_CHECK_SSE2) +#define HWY_BASELINE_SSE2 HWY_SSE2 +#else +#define HWY_BASELINE_SSE2 0 +#endif + +#if HWY_ARCH_X86 && \ + ((defined(HWY_WANT_SSSE3) && HWY_WANT_SSSE3) || HWY_CHECK_SSSE3) +#define HWY_BASELINE_SSSE3 HWY_SSSE3 +#else +#define HWY_BASELINE_SSSE3 0 +#endif + +#if HWY_ARCH_X86 && ((defined(HWY_WANT_SSE4) && HWY_WANT_SSE4) || \ + (HWY_CHECK_SSE4 && HWY_CHECK_PCLMUL_AES)) +#define HWY_BASELINE_SSE4 HWY_SSE4 +#else +#define HWY_BASELINE_SSE4 0 +#endif + +#if HWY_BASELINE_SSE4 != 0 && HWY_CHECK_BMI2_FMA && HWY_CHECK_F16C && \ + defined(__AVX2__) +#define HWY_BASELINE_AVX2 HWY_AVX2 +#else +#define HWY_BASELINE_AVX2 0 +#endif + +// Require everything in AVX2 plus 
AVX-512 flags (also set by MSVC) +#if HWY_BASELINE_AVX2 != 0 && \ + ((defined(__AVX512F__) && defined(__AVX512BW__) && \ + defined(__AVX512DQ__) && defined(__AVX512VL__)) || \ + defined(__AVX10_2__)) && \ + ((!HWY_COMPILER_GCC_ACTUAL && !HWY_COMPILER_CLANG) || \ + HWY_COMPILER_GCC_ACTUAL < 1400 || HWY_COMPILER_CLANG < 1800 || \ + defined(__EVEX512__)) +#define HWY_BASELINE_AVX3 HWY_AVX3 +#else +#define HWY_BASELINE_AVX3 0 +#endif + +// TODO(janwas): not yet known whether these will be set by MSVC +#if HWY_BASELINE_AVX3 != 0 && \ + ((defined(__AVX512VNNI__) && defined(__VAES__) && \ + defined(__VPCLMULQDQ__) && defined(__AVX512VBMI__) && \ + defined(__AVX512VBMI2__) && defined(__AVX512VPOPCNTDQ__) && \ + defined(__AVX512BITALG__)) || \ + defined(__AVX10_2__)) +#define HWY_BASELINE_AVX3_DL HWY_AVX3_DL +#else +#define HWY_BASELINE_AVX3_DL 0 +#endif + +// The ZEN4-optimized AVX3 target is numerically lower than AVX3_DL and is thus +// considered better. Do not enable it unless the user explicitly requests it - +// we do not want to choose the ZEN4 path on Intel because it could be slower. +#if defined(HWY_WANT_AVX3_ZEN4) && HWY_BASELINE_AVX3_DL != 0 +#define HWY_BASELINE_AVX3_ZEN4 HWY_AVX3_ZEN4 +#else +#define HWY_BASELINE_AVX3_ZEN4 0 +#endif + +#if HWY_BASELINE_AVX3_DL != 0 && \ + ((defined(__AVX512BF16__) && defined(__AVX512FP16__)) || \ + defined(__AVX10_2__)) +#define HWY_BASELINE_AVX3_SPR HWY_AVX3_SPR +#else +#define HWY_BASELINE_AVX3_SPR 0 +#endif + +#if HWY_BASELINE_AVX3_SPR != 0 && defined(__AVX10_2__) +#define HWY_BASELINE_AVX10_2 HWY_AVX10_2 +#else +#define HWY_BASELINE_AVX10_2 0 +#endif + +// RVV requires intrinsics 0.11 or later, see #1156. + +// Also check that the __riscv_v macro is defined as GCC or Clang will define +// the __riscv_v macro if the RISC-V "V" extension is enabled.
+ +#if HWY_ARCH_RISCV && defined(__riscv_v) && defined(__riscv_v_intrinsic) && \ + __riscv_v_intrinsic >= 11000 +#define HWY_BASELINE_RVV HWY_RVV +#else +#define HWY_BASELINE_RVV 0 +#endif + +#if HWY_ARCH_LOONGARCH && defined(__loongarch_sx) && defined(__loongarch_asx) +#define HWY_BASELINE_LOONGARCH (HWY_LSX | HWY_LASX) +#elif HWY_ARCH_LOONGARCH && defined(__loongarch_sx) +#define HWY_BASELINE_LOONGARCH (HWY_LSX) +#else +#define HWY_BASELINE_LOONGARCH 0 +#endif + +// Workaround for libaom, which unconditionally defines HWY_BASELINE_TARGETS +// even when that would be disabled/broken. If so, at least use AVX2. +#if defined(HWY_BASELINE_TARGETS) +#if HWY_BASELINE_TARGETS == HWY_AVX3_DL && \ + ((HWY_BROKEN_TARGETS | HWY_DISABLED_TARGETS) & HWY_AVX3_DL) +#undef HWY_BASELINE_TARGETS +#define HWY_BASELINE_TARGETS HWY_AVX2 +#endif +#endif // HWY_BASELINE_TARGETS + +// Allow the user to override this without any guarantee of success. If the +// compiler invocation considers that target to be broken/disabled, then +// `HWY_ENABLED_BASELINE` will be 0 and users will have to check for that and +// skip their code. +#ifndef HWY_BASELINE_TARGETS +#define HWY_BASELINE_TARGETS \ + (HWY_BASELINE_SCALAR | HWY_BASELINE_WASM | HWY_BASELINE_PPC8 | \ + HWY_BASELINE_PPC9 | HWY_BASELINE_PPC10 | HWY_BASELINE_Z14 | \ + HWY_BASELINE_Z15 | HWY_BASELINE_SVE2 | HWY_BASELINE_SVE | \ + HWY_BASELINE_NEON | HWY_BASELINE_SSE2 | HWY_BASELINE_SSSE3 | \ + HWY_BASELINE_SSE4 | HWY_BASELINE_AVX2 | HWY_BASELINE_AVX3 | \ + HWY_BASELINE_AVX3_DL | HWY_BASELINE_AVX3_ZEN4 | HWY_BASELINE_AVX3_SPR | \ + HWY_BASELINE_AVX10_2 | HWY_BASELINE_RVV | HWY_BASELINE_LOONGARCH) +#endif // HWY_BASELINE_TARGETS + +//------------------------------------------------------------------------------ +// Choose target for static dispatch + +#define HWY_ENABLED_BASELINE HWY_ENABLED(HWY_BASELINE_TARGETS) +#if HWY_ENABLED_BASELINE == 0 +#pragma message \ + "All baseline targets are disabled or considered broken." 
\ + "This is typically due to very restrictive HWY_BASELINE_TARGETS, or " \ + "too expansive HWY_BROKEN_TARGETS or HWY_DISABLED_TAREGTS. User code " \ + "must also check for this and skip any usage of SIMD." +#endif + +// Best baseline, used for static dispatch. This is the least-significant 1-bit +// within HWY_ENABLED_BASELINE and lower bit values imply "better". +#define HWY_STATIC_TARGET (HWY_ENABLED_BASELINE & -HWY_ENABLED_BASELINE) + +// Start by assuming static dispatch. If we later use dynamic dispatch, this +// will be defined to other targets during the multiple-inclusion, and finally +// return to the initial value. Defining this outside begin/end_target ensures +// inl headers successfully compile by themselves (required by Bazel). +#define HWY_TARGET HWY_STATIC_TARGET + +//------------------------------------------------------------------------------ +// Choose targets for dynamic dispatch according to one of four policies + +#if 1 < (defined(HWY_COMPILE_ONLY_SCALAR) + defined(HWY_COMPILE_ONLY_EMU128) + \ + defined(HWY_COMPILE_ONLY_STATIC)) +#error "Can only define one of HWY_COMPILE_ONLY_{SCALAR|EMU128|STATIC} - bug?" +#endif +// Defining one of HWY_COMPILE_ONLY_* will trump HWY_COMPILE_ALL_ATTAINABLE. 
+ +#ifndef HWY_HAVE_ASM_HWCAP // allow override +#ifdef TOOLCHAIN_MISS_ASM_HWCAP_H +#define HWY_HAVE_ASM_HWCAP 0 // CMake failed to find the header +#elif defined(__has_include) // note: wrapper macro fails on Clang ~17 +// clang-format off +#if __has_include(<asm/hwcap.h>) +// clang-format on +#define HWY_HAVE_ASM_HWCAP 1 // header present +#else +#define HWY_HAVE_ASM_HWCAP 0 // header not present +#endif // __has_include +#else // compiler lacks __has_include +#define HWY_HAVE_ASM_HWCAP 0 +#endif +#endif // HWY_HAVE_ASM_HWCAP + +#ifndef HWY_HAVE_AUXV // allow override +#ifdef TOOLCHAIN_MISS_SYS_AUXV_H +#define HWY_HAVE_AUXV 0 // CMake failed to find the header +// glibc 2.16 added auxv, but checking for that requires features.h, and we do +// not want to include system headers here. Instead check for the header +// directly, which has been supported at least since GCC 5.4 and Clang 3. +#elif defined(__has_include) // note: wrapper macro fails on Clang ~17 +// clang-format off +#if __has_include(<sys/auxv.h>) +// clang-format on +#define HWY_HAVE_AUXV 1 // header present +#else +#define HWY_HAVE_AUXV 0 // header not present +#endif // __has_include +#else // compiler lacks __has_include +#define HWY_HAVE_AUXV 0 +#endif +#endif // HWY_HAVE_AUXV + +#ifndef HWY_HAVE_RUNTIME_DISPATCH_RVV // allow override +// Clang 19+ supports target attributes for RVV intrinsics (resolved in +// https://github.com/llvm/llvm-project/issues/56592 and +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115325).
+#if HWY_ARCH_RISCV && HWY_COMPILER_CLANG >= 1900 +#define HWY_HAVE_RUNTIME_DISPATCH_RVV 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH_RVV 0 +#endif +#endif // HWY_HAVE_RUNTIME_DISPATCH_RVV + +#ifndef HWY_HAVE_RUNTIME_DISPATCH_APPLE // allow override +#if HWY_ARCH_ARM_A64 && HWY_OS_APPLE && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 1700) +#define HWY_HAVE_RUNTIME_DISPATCH_APPLE 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH_APPLE 0 +#endif +#endif // HWY_HAVE_RUNTIME_DISPATCH_APPLE + +#ifndef HWY_HAVE_RUNTIME_DISPATCH_LOONGARCH // allow override +#if HWY_ARCH_LOONGARCH && HWY_HAVE_AUXV && !defined(__loongarch_asx) && \ + HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1800 +#define HWY_HAVE_RUNTIME_DISPATCH_LOONGARCH 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH_LOONGARCH 0 +#endif +#endif // HWY_HAVE_RUNTIME_DISPATCH_LOONGARCH + +#ifndef HWY_HAVE_RUNTIME_DISPATCH_LINUX // allow override +#if (HWY_ARCH_ARM || HWY_ARCH_PPC || HWY_ARCH_S390X) && HWY_OS_LINUX && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 1700) && HWY_HAVE_AUXV +#define HWY_HAVE_RUNTIME_DISPATCH_LINUX 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH_LINUX 0 +#endif +#endif // HWY_HAVE_RUNTIME_DISPATCH_LINUX + +// Allow opting out, and without a guarantee of success, opting-in. +#ifndef HWY_HAVE_RUNTIME_DISPATCH +// Clang, GCC and MSVC allow OS-independent runtime dispatch on x86. +// Wasm does not, because browsers reject a binary containing any SIMD +// instructions when the browser does not support them. Typical practice there +// is to build two binaries, one with the -msimd128 flag. 
+#if HWY_ARCH_X86 || HWY_HAVE_RUNTIME_DISPATCH_RVV || \ + HWY_HAVE_RUNTIME_DISPATCH_APPLE || HWY_HAVE_RUNTIME_DISPATCH_LOONGARCH || \ + HWY_HAVE_RUNTIME_DISPATCH_LINUX +#define HWY_HAVE_RUNTIME_DISPATCH 1 +#else +#define HWY_HAVE_RUNTIME_DISPATCH 0 +#endif +#endif // HWY_HAVE_RUNTIME_DISPATCH + +#if HWY_ARCH_ARM_A64 && HWY_HAVE_RUNTIME_DISPATCH +#define HWY_ATTAINABLE_NEON HWY_ALL_NEON +#elif HWY_ARCH_ARM // static dispatch, or HWY_ARCH_ARM_V7 +#define HWY_ATTAINABLE_NEON (HWY_BASELINE_NEON) +#else +#define HWY_ATTAINABLE_NEON 0 +#endif + +#if HWY_ARCH_ARM_A64 && \ + (HWY_COMPILER_CLANG >= 900 || HWY_COMPILER_GCC_ACTUAL >= 800) && \ + (HWY_HAVE_RUNTIME_DISPATCH || \ + (HWY_ENABLED_BASELINE & (HWY_SVE | HWY_SVE_256))) +#define HWY_ATTAINABLE_SVE (HWY_SVE | HWY_SVE_256) +#else +#define HWY_ATTAINABLE_SVE 0 +#endif + +#if HWY_ARCH_ARM_A64 && \ + (HWY_COMPILER_CLANG >= 1400 || HWY_COMPILER_GCC_ACTUAL >= 1200) && \ + (HWY_HAVE_RUNTIME_DISPATCH || \ + (HWY_ENABLED_BASELINE & (HWY_SVE2 | HWY_SVE2_128))) +#define HWY_ATTAINABLE_SVE2 (HWY_SVE2 | HWY_SVE2_128) +#else +#define HWY_ATTAINABLE_SVE2 0 +#endif + +#if HWY_ARCH_PPC && defined(__ALTIVEC__) && \ + (!HWY_COMPILER_CLANG || HWY_BASELINE_PPC8 != 0) + +#if (HWY_BASELINE_PPC9 | HWY_BASELINE_PPC10) && \ + !defined(HWY_SKIP_NON_BEST_BASELINE) +// On POWER with -m flags, we get compile errors (#1707) for targets older than +// the baseline specified via -m, so only generate the static target and better. +// Note that some Linux distros actually do set POWER9 as the baseline. +// This works by skipping case 3 below, so case 4 is reached. 
+#define HWY_SKIP_NON_BEST_BASELINE +#endif + +#define HWY_ATTAINABLE_PPC (HWY_PPC8 | HWY_PPC9 | HWY_PPC10) + +#else +#define HWY_ATTAINABLE_PPC 0 +#endif + +#if HWY_ARCH_S390X && HWY_BASELINE_Z14 != 0 +#define HWY_ATTAINABLE_S390X (HWY_Z14 | HWY_Z15) +#else +#define HWY_ATTAINABLE_S390X 0 +#endif + +#if HWY_ARCH_RISCV && HWY_HAVE_RUNTIME_DISPATCH +#define HWY_ATTAINABLE_RISCV HWY_RVV +#else +#define HWY_ATTAINABLE_RISCV HWY_BASELINE_RVV +#endif + +#if HWY_ARCH_LOONGARCH && HWY_HAVE_RUNTIME_DISPATCH +#define HWY_ATTAINABLE_LOONGARCH (HWY_LSX | HWY_LASX) +#else +#define HWY_ATTAINABLE_LOONGARCH HWY_BASELINE_LOONGARCH +#endif + +#ifndef HWY_ATTAINABLE_TARGETS_X86 // allow override +#if HWY_COMPILER_MSVC && defined(HWY_SLOW_MSVC) +// Fewer targets for faster builds. +#define HWY_ATTAINABLE_TARGETS_X86 \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_STATIC_TARGET | HWY_AVX2) +#else // !HWY_COMPILER_MSVC +#define HWY_ATTAINABLE_TARGETS_X86 \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | \ + HWY_AVX2 | HWY_AVX3 | HWY_AVX3_DL | HWY_AVX3_ZEN4 | \ + HWY_AVX3_SPR | HWY_AVX10_2) +#endif // !HWY_COMPILER_MSVC +#endif // HWY_ATTAINABLE_TARGETS_X86 + +// Attainable means enabled and the compiler allows intrinsics (even when not +// allowed to auto-vectorize). Used in 3 and 4. 
+#if HWY_ARCH_X86 +#define HWY_ATTAINABLE_TARGETS HWY_ATTAINABLE_TARGETS_X86 +#elif HWY_ARCH_ARM +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_NEON | HWY_ATTAINABLE_SVE | \ + HWY_ATTAINABLE_SVE2) +#elif HWY_ARCH_PPC +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_PPC) +#elif HWY_ARCH_S390X +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_S390X) +#elif HWY_ARCH_RISCV +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_RISCV) +#elif HWY_ARCH_LOONGARCH +#define HWY_ATTAINABLE_TARGETS \ + HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_ATTAINABLE_LOONGARCH) +#else +#define HWY_ATTAINABLE_TARGETS (HWY_ENABLED_BASELINE) +#endif // HWY_ARCH_* + +// 1) For older compilers: avoid SIMD intrinsics, but still support all ops. +#if defined(HWY_COMPILE_ONLY_EMU128) && !HWY_BROKEN_EMU128 +#undef HWY_STATIC_TARGET +#define HWY_STATIC_TARGET HWY_EMU128 // override baseline +#define HWY_TARGETS HWY_EMU128 + +// 1b) HWY_SCALAR is less capable than HWY_EMU128 (which supports all ops), but +// we currently still support it for backwards compatibility. +#elif defined(HWY_COMPILE_ONLY_SCALAR) || \ + (defined(HWY_COMPILE_ONLY_EMU128) && HWY_BROKEN_EMU128) +#undef HWY_STATIC_TARGET +#define HWY_STATIC_TARGET HWY_SCALAR // override baseline +#define HWY_TARGETS HWY_SCALAR + +// 2) For forcing static dispatch without code changes (removing HWY_EXPORT) +#elif defined(HWY_COMPILE_ONLY_STATIC) +#define HWY_TARGETS HWY_STATIC_TARGET + +// 3) For tests: include all attainable targets (in particular: scalar) +#elif (defined(HWY_COMPILE_ALL_ATTAINABLE) || defined(HWY_IS_TEST)) && \ + !defined(HWY_SKIP_NON_BEST_BASELINE) +#define HWY_TARGETS HWY_ATTAINABLE_TARGETS + +// 4) Default: attainable WITHOUT non-best baseline. This reduces code size by +// excluding superseded targets, in particular scalar. 
Note: HWY_STATIC_TARGET +// may be 2^62 (HWY_SCALAR), so we must not left-shift/add it. Subtracting one +// sets all lower bits (better targets), then we also include the static target. +#else +#define HWY_TARGETS \ + (HWY_ATTAINABLE_TARGETS & ((HWY_STATIC_TARGET - 1LL) | HWY_STATIC_TARGET)) + +#endif // target policy + +// HWY_ONCE and the multiple-inclusion mechanism rely on HWY_STATIC_TARGET being +// one of the dynamic targets. This also implies HWY_TARGETS != 0 and +// (HWY_TARGETS & HWY_ENABLED_BASELINE) != 0. +#if (HWY_TARGETS & HWY_STATIC_TARGET) == 0 && HWY_ENABLED_BASELINE != 0 +#error "Logic error: best baseline should be included in dynamic targets" +#endif + +#endif // HIGHWAY_HWY_DETECT_TARGETS_H_ diff --git a/lib/highway/hwy/foreach_target.h b/lib/highway/hwy/foreach_target.h new file mode 100644 index 00000000000..18a6f3387bc --- /dev/null +++ b/lib/highway/hwy/foreach_target.h @@ -0,0 +1,410 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_FOREACH_TARGET_H_ +#define HIGHWAY_HWY_FOREACH_TARGET_H_ + +// Re-includes the translation unit zero or more times to compile for any +// targets except HWY_STATIC_TARGET. Defines unique HWY_TARGET each time so that +// highway.h defines the corresponding macro/namespace. + +#include "hwy/detect_targets.h" + +// *_inl.h may include other headers, which requires include guards to prevent +// repeated inclusion. 
The guards must be reset after compiling each target, so +// the header is again visible. This is done by flipping HWY_TARGET_TOGGLE, +// defining it if undefined and vice versa. This macro is initially undefined +// so that IDEs don't gray out the contents of each header. +#ifdef HWY_TARGET_TOGGLE +#error "This macro must not be defined outside foreach_target.h" +#endif + +#ifdef HWY_HIGHWAY_INCLUDED // highway.h include guard +// Trigger fixup at the bottom of this header. +#define HWY_ALREADY_INCLUDED + +// The next highway.h must re-include set_macros-inl.h because the first +// highway.h chose the static target instead of what we will set below. +#undef HWY_SET_MACROS_PER_TARGET +#endif + +// Disable HWY_EXPORT in user code until we have generated all targets. Note +// that a subsequent highway.h will not override this definition. +#undef HWY_ONCE +#define HWY_ONCE (0 || HWY_IDE) + +// Avoid warnings on #include HWY_TARGET_INCLUDE by hiding them from the IDE; +// also skip if only 1 target defined (no re-inclusion will be necessary). 
+#if !HWY_IDE && (HWY_TARGETS != HWY_STATIC_TARGET) + +#if !defined(HWY_TARGET_INCLUDE) +#error ">1 target enabled => define HWY_TARGET_INCLUDE before foreach_target.h" +#endif + +// ------------------------------ HWY_ARCH_X86 + +#if (HWY_TARGETS & HWY_SSE2) && (HWY_STATIC_TARGET != HWY_SSE2) +#undef HWY_TARGET +#define HWY_TARGET HWY_SSE2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SSSE3) && (HWY_STATIC_TARGET != HWY_SSSE3) +#undef HWY_TARGET +#define HWY_TARGET HWY_SSSE3 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SSE4) && (HWY_STATIC_TARGET != HWY_SSE4) +#undef HWY_TARGET +#define HWY_TARGET HWY_SSE4 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX2) && (HWY_STATIC_TARGET != HWY_AVX2) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX3) && (HWY_STATIC_TARGET != HWY_AVX3) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX3 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX3_DL) && (HWY_STATIC_TARGET != HWY_AVX3_DL) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX3_DL +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX3_ZEN4) && (HWY_STATIC_TARGET != HWY_AVX3_ZEN4) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX3_ZEN4 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & 
HWY_AVX3_SPR) && (HWY_STATIC_TARGET != HWY_AVX3_SPR) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX3_SPR +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_AVX10_2) && (HWY_STATIC_TARGET != HWY_AVX10_2) +#undef HWY_TARGET +#define HWY_TARGET HWY_AVX10_2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ HWY_ARCH_ARM + +#if (HWY_TARGETS & HWY_NEON_WITHOUT_AES) && \ + (HWY_STATIC_TARGET != HWY_NEON_WITHOUT_AES) +#undef HWY_TARGET +#define HWY_TARGET HWY_NEON_WITHOUT_AES +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_NEON) && (HWY_STATIC_TARGET != HWY_NEON) +#undef HWY_TARGET +#define HWY_TARGET HWY_NEON +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_NEON_BF16) && (HWY_STATIC_TARGET != HWY_NEON_BF16) +#undef HWY_TARGET +#define HWY_TARGET HWY_NEON_BF16 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE) && (HWY_STATIC_TARGET != HWY_SVE) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE2) && (HWY_STATIC_TARGET != HWY_SVE2) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE2 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE_256) && (HWY_STATIC_TARGET != HWY_SVE_256) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE_256 +#include HWY_TARGET_INCLUDE +#ifdef 
HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SVE2_128) && (HWY_STATIC_TARGET != HWY_SVE2_128) +#undef HWY_TARGET +#define HWY_TARGET HWY_SVE2_128 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ HWY_ARCH_WASM + +#if (HWY_TARGETS & HWY_WASM_EMU256) && (HWY_STATIC_TARGET != HWY_WASM_EMU256) +#undef HWY_TARGET +#define HWY_TARGET HWY_WASM_EMU256 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_WASM) && (HWY_STATIC_TARGET != HWY_WASM) +#undef HWY_TARGET +#define HWY_TARGET HWY_WASM +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ HWY_ARCH_PPC + +#if (HWY_TARGETS & HWY_PPC8) && (HWY_STATIC_TARGET != HWY_PPC8) +#undef HWY_TARGET +#define HWY_TARGET HWY_PPC8 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_PPC9) && (HWY_STATIC_TARGET != HWY_PPC9) +#undef HWY_TARGET +#define HWY_TARGET HWY_PPC9 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_PPC10) && (HWY_STATIC_TARGET != HWY_PPC10) +#undef HWY_TARGET +#define HWY_TARGET HWY_PPC10 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ HWY_ARCH_S390X + +#if (HWY_TARGETS & HWY_Z14) && (HWY_STATIC_TARGET != HWY_Z14) +#undef HWY_TARGET +#define HWY_TARGET HWY_Z14 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif 
+#endif + +#if (HWY_TARGETS & HWY_Z15) && (HWY_STATIC_TARGET != HWY_Z15) +#undef HWY_TARGET +#define HWY_TARGET HWY_Z15 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ HWY_ARCH_RISCV + +#if (HWY_TARGETS & HWY_RVV) && (HWY_STATIC_TARGET != HWY_RVV) +#undef HWY_TARGET +#define HWY_TARGET HWY_RVV +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ HWY_ARCH_LOONGARCH + +#if (HWY_TARGETS & HWY_LSX) && (HWY_STATIC_TARGET != HWY_LSX) +#undef HWY_TARGET +#define HWY_TARGET HWY_LSX +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_LASX) && (HWY_STATIC_TARGET != HWY_LASX) +#undef HWY_TARGET +#define HWY_TARGET HWY_LASX +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +// ------------------------------ Scalar + +#if (HWY_TARGETS & HWY_EMU128) && (HWY_STATIC_TARGET != HWY_EMU128) +#undef HWY_TARGET +#define HWY_TARGET HWY_EMU128 +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#if (HWY_TARGETS & HWY_SCALAR) && (HWY_STATIC_TARGET != HWY_SCALAR) +#undef HWY_TARGET +#define HWY_TARGET HWY_SCALAR +#include HWY_TARGET_INCLUDE +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif +#endif + +#endif // !HWY_IDE && (HWY_TARGETS != HWY_STATIC_TARGET) + +// Now that all but the static target have been generated, re-enable HWY_EXPORT. +#undef HWY_ONCE +#define HWY_ONCE 1 + +// If we re-include once per enabled target, the translation unit's +// implementation would have to be skipped via #if to avoid redefining symbols. 
+// We instead skip the re-include for HWY_STATIC_TARGET, and generate its +// implementation when resuming compilation of the translation unit. +#undef HWY_TARGET +#define HWY_TARGET HWY_STATIC_TARGET + +#ifdef HWY_ALREADY_INCLUDED +// Revert the previous toggle to prevent redefinitions for the static target. +#ifdef HWY_TARGET_TOGGLE +#undef HWY_TARGET_TOGGLE +#else +#define HWY_TARGET_TOGGLE +#endif + +// Force re-inclusion of set_macros-inl.h now that HWY_TARGET is restored. +#ifdef HWY_SET_MACROS_PER_TARGET +#undef HWY_SET_MACROS_PER_TARGET +#else +#define HWY_SET_MACROS_PER_TARGET +#endif +#endif + +#endif // HIGHWAY_HWY_FOREACH_TARGET_H_ diff --git a/lib/highway/hwy/highway.h b/lib/highway/hwy/highway.h new file mode 100644 index 00000000000..c2744efb1e1 --- /dev/null +++ b/lib/highway/hwy/highway.h @@ -0,0 +1,731 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Main header required before using vector types. + +// IWYU pragma: begin_exports +#include "hwy/base.h" +#include "hwy/detect_compiler_arch.h" +#include "hwy/detect_targets.h" +#include "hwy/highway_export.h" +#include "hwy/targets.h" +// IWYU pragma: end_exports + +#if HWY_CXX_LANG < 201703L +#define HWY_DISPATCH_MAP 1 +#else +#define HWY_DISPATCH_MAP 0 +#endif + +// This include guard is checked by foreach_target, so avoid the usual _H_ +// suffix to prevent copybara from renaming it. 
NOTE: ops/*-inl.h are included +// after/outside this include guard. +#ifndef HWY_HIGHWAY_INCLUDED +#define HWY_HIGHWAY_INCLUDED + +namespace hwy { + +//------------------------------------------------------------------------------ +// Shorthand for tags (defined in shared-inl.h) used to select overloads. +// Note that ScalableTag is preferred over HWY_FULL, and CappedTag over +// HWY_CAPPED(T, N). + +// HWY_FULL(T[,LMUL=1]) is a native vector/group. LMUL is the number of +// registers in the group, and is ignored on targets that do not support groups. +#define HWY_FULL1(T) hwy::HWY_NAMESPACE::ScalableTag<T> +#define HWY_FULL2(T, LMUL) \ + hwy::HWY_NAMESPACE::ScalableTag<T, hwy::CeilLog2(HWY_MAX(0, LMUL))> +// HWY_FULL counts its arguments: they are prepended to (HWY_FULL2, HWY_FULL1), +// and HWY_3TH_ARG returns the third element, i.e. HWY_FULL1 for one argument +// and HWY_FULL2 for two. +#define HWY_3TH_ARG(arg1, arg2, arg3, ...) arg3 +// Workaround for MSVC grouping __VA_ARGS__ into a single argument +#define HWY_FULL_RECOMPOSER(args_with_paren) HWY_3TH_ARG args_with_paren +// Trailing comma avoids -pedantic false alarm +#define HWY_CHOOSE_FULL(...) \ + HWY_FULL_RECOMPOSER((__VA_ARGS__, HWY_FULL2, HWY_FULL1, )) +#define HWY_FULL(...) HWY_CHOOSE_FULL(__VA_ARGS__())(__VA_ARGS__) + +// Vector of up to MAX_N lanes. It's better to use full vectors where possible. +#define HWY_CAPPED(T, MAX_N) hwy::HWY_NAMESPACE::CappedTag<T, MAX_N> + +//------------------------------------------------------------------------------ +// Export user functions for static/dynamic dispatch + +// The static target is the best baseline. When using foreach_target.h, this is +// the last target compiled. Otherwise, it is the only target. + +// Evaluates to 0 inside a translation unit if it is generating anything but the +// static target. Used to prevent redefinitions of HWY_EXPORT. Unless +// foreach_target.h is included, we only compile once anyway, so this is 1 +// unless it is or has been included. +#ifndef HWY_ONCE +#define HWY_ONCE 1 +#endif + +// `HWY_STATIC_NAMESPACE` expands to its namespace name, e.g. `N_AVX2`.
+#if HWY_STATIC_TARGET == HWY_SCALAR +#define HWY_STATIC_NAMESPACE N_SCALAR +#elif HWY_STATIC_TARGET == HWY_EMU128 +#define HWY_STATIC_NAMESPACE N_EMU128 +#elif HWY_STATIC_TARGET == HWY_WASM +#define HWY_STATIC_NAMESPACE N_WASM +#elif HWY_STATIC_TARGET == HWY_WASM_EMU256 +#define HWY_STATIC_NAMESPACE N_WASM_EMU256 +#elif HWY_STATIC_TARGET == HWY_Z14 +#define HWY_STATIC_NAMESPACE N_Z14 +#elif HWY_STATIC_TARGET == HWY_Z15 +#define HWY_STATIC_NAMESPACE N_Z15 +#elif HWY_STATIC_TARGET == HWY_PPC8 +#define HWY_STATIC_NAMESPACE N_PPC8 +#elif HWY_STATIC_TARGET == HWY_PPC9 +#define HWY_STATIC_NAMESPACE N_PPC9 +#elif HWY_STATIC_TARGET == HWY_PPC10 +#define HWY_STATIC_NAMESPACE N_PPC10 +#elif HWY_STATIC_TARGET == HWY_LSX +#define HWY_STATIC_NAMESPACE N_LSX +#elif HWY_STATIC_TARGET == HWY_LASX +#define HWY_STATIC_NAMESPACE N_LASX +#elif HWY_STATIC_TARGET == HWY_RVV +#define HWY_STATIC_NAMESPACE N_RVV +#elif HWY_STATIC_TARGET == HWY_NEON_WITHOUT_AES +#define HWY_STATIC_NAMESPACE N_NEON_WITHOUT_AES +#elif HWY_STATIC_TARGET == HWY_NEON +#define HWY_STATIC_NAMESPACE N_NEON +#elif HWY_STATIC_TARGET == HWY_NEON_BF16 +#define HWY_STATIC_NAMESPACE N_NEON_BF16 +#elif HWY_STATIC_TARGET == HWY_SVE +#define HWY_STATIC_NAMESPACE N_SVE +#elif HWY_STATIC_TARGET == HWY_SVE2 +#define HWY_STATIC_NAMESPACE N_SVE2 +#elif HWY_STATIC_TARGET == HWY_SVE_256 +#define HWY_STATIC_NAMESPACE N_SVE_256 +#elif HWY_STATIC_TARGET == HWY_SVE2_128 +#define HWY_STATIC_NAMESPACE N_SVE2_128 +#elif HWY_STATIC_TARGET == HWY_SSE2 +#define HWY_STATIC_NAMESPACE N_SSE2 +#elif HWY_STATIC_TARGET == HWY_SSSE3 +#define HWY_STATIC_NAMESPACE N_SSSE3 +#elif HWY_STATIC_TARGET == HWY_SSE4 +#define HWY_STATIC_NAMESPACE N_SSE4 +#elif HWY_STATIC_TARGET == HWY_AVX2 +#define HWY_STATIC_NAMESPACE N_AVX2 +#elif HWY_STATIC_TARGET == HWY_AVX3 +#define HWY_STATIC_NAMESPACE N_AVX3 +#elif HWY_STATIC_TARGET == HWY_AVX3_DL +#define HWY_STATIC_NAMESPACE N_AVX3_DL +#elif HWY_STATIC_TARGET == HWY_AVX3_ZEN4 +#define HWY_STATIC_NAMESPACE
N_AVX3_ZEN4 +#elif HWY_STATIC_TARGET == HWY_AVX3_SPR +#define HWY_STATIC_NAMESPACE N_AVX3_SPR +#elif HWY_STATIC_TARGET == HWY_AVX10_2 +#define HWY_STATIC_NAMESPACE N_AVX10_2 +#endif + +// NOTE: the chain above has no #else branch, so HWY_STATIC_NAMESPACE is left +// undefined when HWY_STATIC_TARGET equals none of the listed targets. + +// `HWY_STATIC_DISPATCH(FUNC_NAME)` is the namespace-qualified FUNC_NAME for +// `HWY_STATIC_TARGET`, and can be used to deduce the return type of Choose*. +#define HWY_STATIC_DISPATCH(FUNC_NAME) HWY_STATIC_NAMESPACE::FUNC_NAME + +// `HWY_CHOOSE_*(FUNC_NAME)` expands to the function pointer for that target or +// nullptr if that target was not compiled. +// `HWY_VISIT_*(VISITOR)` expands to `VISITOR(TARGET, NAMESPACE)` or nothing if +// that target was not compiled. +#if HWY_TARGETS & HWY_EMU128 +#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &N_EMU128::FUNC_NAME +#define HWY_VISIT_FALLBACK(VISITOR) VISITOR(HWY_EMU128, N_EMU128) +#elif HWY_TARGETS & HWY_SCALAR +#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &N_SCALAR::FUNC_NAME +#define HWY_VISIT_FALLBACK(VISITOR) VISITOR(HWY_SCALAR, N_SCALAR) +#else +// When HWY_SCALAR/HWY_EMU128 are not present and other targets were disabled at +// runtime, fall back to the baseline with HWY_STATIC_DISPATCH().
+#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &HWY_STATIC_DISPATCH(FUNC_NAME) +#define HWY_VISIT_FALLBACK(VISITOR) \ + VISITOR(HWY_STATIC_TARGET, HWY_STATIC_NAMESPACE) +#endif + +// Per-target stanzas: each one tests that target's bit in HWY_TARGETS. If it is +// set, HWY_CHOOSE_* is the address of FUNC_NAME in that target's namespace and +// HWY_VISIT_* invokes VISITOR; otherwise HWY_CHOOSE_* is nullptr and +// HWY_VISIT_* expands to nothing. +#if HWY_TARGETS & HWY_WASM +#define HWY_CHOOSE_WASM(FUNC_NAME) &N_WASM::FUNC_NAME +#define HWY_VISIT_WASM(VISITOR) VISITOR(HWY_WASM, N_WASM) +#else +#define HWY_CHOOSE_WASM(FUNC_NAME) nullptr +#define HWY_VISIT_WASM(VISITOR) +#endif + +#if HWY_TARGETS & HWY_WASM_EMU256 +#define HWY_CHOOSE_WASM_EMU256(FUNC_NAME) &N_WASM_EMU256::FUNC_NAME +#define HWY_VISIT_WASM_EMU256(VISITOR) VISITOR(HWY_WASM_EMU256, N_WASM_EMU256) +#else +#define HWY_CHOOSE_WASM_EMU256(FUNC_NAME) nullptr +#define HWY_VISIT_WASM_EMU256(VISITOR) +#endif + +#if HWY_TARGETS & HWY_Z14 +#define HWY_CHOOSE_Z14(FUNC_NAME) &N_Z14::FUNC_NAME +#define HWY_VISIT_Z14(VISITOR) VISITOR(HWY_Z14, N_Z14) +#else +#define HWY_CHOOSE_Z14(FUNC_NAME) nullptr +#define HWY_VISIT_Z14(VISITOR) +#endif + +#if HWY_TARGETS & HWY_Z15 +#define HWY_CHOOSE_Z15(FUNC_NAME) &N_Z15::FUNC_NAME +#define HWY_VISIT_Z15(VISITOR) VISITOR(HWY_Z15, N_Z15) +#else +#define HWY_CHOOSE_Z15(FUNC_NAME) nullptr +#define HWY_VISIT_Z15(VISITOR) +#endif + +#if HWY_TARGETS & HWY_PPC8 +#define HWY_CHOOSE_PPC8(FUNC_NAME) &N_PPC8::FUNC_NAME +#define HWY_VISIT_PPC8(VISITOR) VISITOR(HWY_PPC8, N_PPC8) +#else +#define HWY_CHOOSE_PPC8(FUNC_NAME) nullptr +#define HWY_VISIT_PPC8(VISITOR) +#endif + +#if HWY_TARGETS & HWY_PPC9 +#define HWY_CHOOSE_PPC9(FUNC_NAME) &N_PPC9::FUNC_NAME +#define HWY_VISIT_PPC9(VISITOR) VISITOR(HWY_PPC9, N_PPC9) +#else +#define HWY_CHOOSE_PPC9(FUNC_NAME) nullptr +#define HWY_VISIT_PPC9(VISITOR) +#endif + +#if HWY_TARGETS & HWY_LSX +#define HWY_CHOOSE_LSX(FUNC_NAME) &N_LSX::FUNC_NAME +#define HWY_VISIT_LSX(VISITOR) VISITOR(HWY_LSX, N_LSX) +#else +#define HWY_CHOOSE_LSX(FUNC_NAME) nullptr +#define HWY_VISIT_LSX(VISITOR) +#endif + +#if HWY_TARGETS & HWY_LASX +#define HWY_CHOOSE_LASX(FUNC_NAME) &N_LASX::FUNC_NAME +#define HWY_VISIT_LASX(VISITOR) VISITOR(HWY_LASX, N_LASX)
+#else +#define HWY_CHOOSE_LASX(FUNC_NAME) nullptr +#define HWY_VISIT_LASX(VISITOR) +#endif + +#if HWY_TARGETS & HWY_PPC10 +#define HWY_CHOOSE_PPC10(FUNC_NAME) &N_PPC10::FUNC_NAME +#define HWY_VISIT_PPC10(VISITOR) VISITOR(HWY_PPC10, N_PPC10) +#else +#define HWY_CHOOSE_PPC10(FUNC_NAME) nullptr +#define HWY_VISIT_PPC10(VISITOR) +#endif + +#if HWY_TARGETS & HWY_RVV +#define HWY_CHOOSE_RVV(FUNC_NAME) &N_RVV::FUNC_NAME +#define HWY_VISIT_RVV(VISITOR) VISITOR(HWY_RVV, N_RVV) +#else +#define HWY_CHOOSE_RVV(FUNC_NAME) nullptr +#define HWY_VISIT_RVV(VISITOR) +#endif + +#if HWY_TARGETS & HWY_NEON_WITHOUT_AES +#define HWY_CHOOSE_NEON_WITHOUT_AES(FUNC_NAME) &N_NEON_WITHOUT_AES::FUNC_NAME +#define HWY_VISIT_NEON_WITHOUT_AES(VISITOR) \ + VISITOR(HWY_NEON_WITHOUT_AES, N_NEON_WITHOUT_AES) +#else +#define HWY_CHOOSE_NEON_WITHOUT_AES(FUNC_NAME) nullptr +#define HWY_VISIT_NEON_WITHOUT_AES(VISITOR) +#endif + +#if HWY_TARGETS & HWY_NEON +#define HWY_CHOOSE_NEON(FUNC_NAME) &N_NEON::FUNC_NAME +#define HWY_VISIT_NEON(VISITOR) VISITOR(HWY_NEON, N_NEON) +#else +#define HWY_CHOOSE_NEON(FUNC_NAME) nullptr +#define HWY_VISIT_NEON(VISITOR) +#endif + +#if HWY_TARGETS & HWY_NEON_BF16 +#define HWY_CHOOSE_NEON_BF16(FUNC_NAME) &N_NEON_BF16::FUNC_NAME +#define HWY_VISIT_NEON_BF16(VISITOR) VISITOR(HWY_NEON_BF16, N_NEON_BF16) +#else +#define HWY_CHOOSE_NEON_BF16(FUNC_NAME) nullptr +#define HWY_VISIT_NEON_BF16(VISITOR) +#endif + +#if HWY_TARGETS & HWY_SVE +#define HWY_CHOOSE_SVE(FUNC_NAME) &N_SVE::FUNC_NAME +#define HWY_VISIT_SVE(VISITOR) VISITOR(HWY_SVE, N_SVE) +#else +#define HWY_CHOOSE_SVE(FUNC_NAME) nullptr +#define HWY_VISIT_SVE(VISITOR) +#endif + +#if HWY_TARGETS & HWY_SVE2 +#define HWY_CHOOSE_SVE2(FUNC_NAME) &N_SVE2::FUNC_NAME +#define HWY_VISIT_SVE2(VISITOR) VISITOR(HWY_SVE2, N_SVE2) +#else +#define HWY_CHOOSE_SVE2(FUNC_NAME) nullptr +#define HWY_VISIT_SVE2(VISITOR) +#endif + +#if HWY_TARGETS & HWY_SVE_256 +#define HWY_CHOOSE_SVE_256(FUNC_NAME) &N_SVE_256::FUNC_NAME +#define
HWY_VISIT_SVE_256(VISITOR) VISITOR(HWY_SVE_256, N_SVE_256) +#else +#define HWY_CHOOSE_SVE_256(FUNC_NAME) nullptr +#define HWY_VISIT_SVE_256(VISITOR) +#endif + +#if HWY_TARGETS & HWY_SVE2_128 +#define HWY_CHOOSE_SVE2_128(FUNC_NAME) &N_SVE2_128::FUNC_NAME +#define HWY_VISIT_SVE2_128(VISITOR) VISITOR(HWY_SVE2_128, N_SVE2_128) +#else +#define HWY_CHOOSE_SVE2_128(FUNC_NAME) nullptr +#define HWY_VISIT_SVE2_128(VISITOR) +#endif + +#if HWY_TARGETS & HWY_SSE2 +#define HWY_CHOOSE_SSE2(FUNC_NAME) &N_SSE2::FUNC_NAME +#define HWY_VISIT_SSE2(VISITOR) VISITOR(HWY_SSE2, N_SSE2) +#else +#define HWY_CHOOSE_SSE2(FUNC_NAME) nullptr +#define HWY_VISIT_SSE2(VISITOR) +#endif + +#if HWY_TARGETS & HWY_SSSE3 +#define HWY_CHOOSE_SSSE3(FUNC_NAME) &N_SSSE3::FUNC_NAME +#define HWY_VISIT_SSSE3(VISITOR) VISITOR(HWY_SSSE3, N_SSSE3) +#else +#define HWY_CHOOSE_SSSE3(FUNC_NAME) nullptr +#define HWY_VISIT_SSSE3(VISITOR) +#endif + +#if HWY_TARGETS & HWY_SSE4 +#define HWY_CHOOSE_SSE4(FUNC_NAME) &N_SSE4::FUNC_NAME +#define HWY_VISIT_SSE4(VISITOR) VISITOR(HWY_SSE4, N_SSE4) +#else +#define HWY_CHOOSE_SSE4(FUNC_NAME) nullptr +#define HWY_VISIT_SSE4(VISITOR) +#endif + +#if HWY_TARGETS & HWY_AVX2 +#define HWY_CHOOSE_AVX2(FUNC_NAME) &N_AVX2::FUNC_NAME +#define HWY_VISIT_AVX2(VISITOR) VISITOR(HWY_AVX2, N_AVX2) +#else +#define HWY_CHOOSE_AVX2(FUNC_NAME) nullptr +#define HWY_VISIT_AVX2(VISITOR) +#endif + +#if HWY_TARGETS & HWY_AVX3 +#define HWY_CHOOSE_AVX3(FUNC_NAME) &N_AVX3::FUNC_NAME +#define HWY_VISIT_AVX3(VISITOR) VISITOR(HWY_AVX3, N_AVX3) +#else +#define HWY_CHOOSE_AVX3(FUNC_NAME) nullptr +#define HWY_VISIT_AVX3(VISITOR) +#endif + +#if HWY_TARGETS & HWY_AVX3_DL +#define HWY_CHOOSE_AVX3_DL(FUNC_NAME) &N_AVX3_DL::FUNC_NAME +#define HWY_VISIT_AVX3_DL(VISITOR) VISITOR(HWY_AVX3_DL, N_AVX3_DL) +#else +#define HWY_CHOOSE_AVX3_DL(FUNC_NAME) nullptr +#define HWY_VISIT_AVX3_DL(VISITOR) +#endif + +#if HWY_TARGETS & HWY_AVX3_ZEN4 +#define HWY_CHOOSE_AVX3_ZEN4(FUNC_NAME) &N_AVX3_ZEN4::FUNC_NAME +#define
HWY_VISIT_AVX3_ZEN4(VISITOR) VISITOR(HWY_AVX3_ZEN4, N_AVX3_ZEN4) +#else +#define HWY_CHOOSE_AVX3_ZEN4(FUNC_NAME) nullptr +#define HWY_VISIT_AVX3_ZEN4(VISITOR) +#endif + +#if HWY_TARGETS & HWY_AVX3_SPR +#define HWY_CHOOSE_AVX3_SPR(FUNC_NAME) &N_AVX3_SPR::FUNC_NAME +#define HWY_VISIT_AVX3_SPR(VISITOR) VISITOR(HWY_AVX3_SPR, N_AVX3_SPR) +#else +#define HWY_CHOOSE_AVX3_SPR(FUNC_NAME) nullptr +#define HWY_VISIT_AVX3_SPR(VISITOR) +#endif + +#if HWY_TARGETS & HWY_AVX10_2 +#define HWY_CHOOSE_AVX10_2(FUNC_NAME) &N_AVX10_2::FUNC_NAME +#define HWY_VISIT_AVX10_2(VISITOR) VISITOR(HWY_AVX10_2, N_AVX10_2) +#else +#define HWY_CHOOSE_AVX10_2(FUNC_NAME) nullptr +#define HWY_VISIT_AVX10_2(VISITOR) +#endif + +// MSVC 2017 workaround: the non-type template parameter to ChooseAndCall +// apparently cannot be an array. Use a function pointer instead, which has the +// disadvantage that we call the static (not best) target on the first call to +// any HWY_DYNAMIC_DISPATCH. +#if (HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1915) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700) +#define HWY_DISPATCH_WORKAROUND 1 +#else +#define HWY_DISPATCH_WORKAROUND 0 +#endif + +#if HWY_DISPATCH_MAP +// Holds one static table pointer per (FuncPtr, ExportsKey, kHash) +// instantiation; AddExport stores into it and ChooseAndCall reads it. +struct AllExports { + template <class FuncPtr, class ExportsKey, uint64_t kHash> + static const FuncPtr*& GetRefToExportsPtr() { + static const FuncPtr* s_exports = nullptr; + return s_exports; + } +}; +#endif + +// Provides a static member function which is what is called during the first +// HWY_DYNAMIC_DISPATCH, where GetIndex is still zero, and instantiations of +// this function are the first entry in the tables created by HWY_EXPORT[_T]. +template <typename RetType, typename... Args> +struct FunctionCache { + public: + typedef RetType(FuncType)(Args...); + using FuncPtr = FuncType*; + + // A template function that when instantiated has the same signature as the + // function being called. This function initializes the bit array of targets + // supported by the current CPU and then calls the appropriate entry within + // the HWY_EXPORT table.
Subsequent calls via HWY_DYNAMIC_DISPATCH to any + // exported functions, even those defined by different translation units, + // will dispatch directly to the best available target. +#if HWY_DISPATCH_MAP + template <class ExportsKey, uint64_t kHash> + static RetType ChooseAndCall(Args... args) { + ChosenTarget& chosen_target = GetChosenTarget(); + chosen_target.Update(SupportedTargets()); + + const FuncPtr* table = AllExports::template GetRefToExportsPtr< + FuncPtr, RemoveCvRef<ExportsKey>, kHash>(); + HWY_ASSERT(table); + + return (table[chosen_target.GetIndex()])(args...); + } + +#if !HWY_DISPATCH_WORKAROUND + template <const FuncPtr* table> + static RetType TableChooseAndCall(Args... args) { + ChosenTarget& chosen_target = GetChosenTarget(); + chosen_target.Update(SupportedTargets()); + return (table[chosen_target.GetIndex()])(args...); + } +#endif // !HWY_DISPATCH_WORKAROUND + +#else // !HWY_DISPATCH_MAP: zero-overhead, but requires C++17 + template <const FuncPtr* table> + static RetType ChooseAndCall(Args... args) { + ChosenTarget& chosen_target = GetChosenTarget(); + chosen_target.Update(SupportedTargets()); + return (table[chosen_target.GetIndex()])(args...); + } +#endif // HWY_DISPATCH_MAP +}; + +// Used to deduce the template parameters RetType and Args from a function. +template <typename RetType, typename... Args> +FunctionCache<RetType, Args...> DeduceFunctionCache(RetType (*)(Args...)) { + return FunctionCache<RetType, Args...>(); +} + +#define HWY_DISPATCH_TABLE(FUNC_NAME) \ + HWY_CONCAT(FUNC_NAME, HighwayDispatchTable) + +// HWY_EXPORT(FUNC_NAME); expands to a static array that is used by +// HWY_DYNAMIC_DISPATCH() to call the appropriate function at runtime. +// After being exported, it can be called from other parts of the same source +// file using HWY_DYNAMIC_DISPATCH(), in particular from a function wrapper +// like in the following example: +// +// #include "hwy/highway.h" +// HWY_BEFORE_NAMESPACE(); +// namespace skeleton { +// namespace HWY_NAMESPACE { +// +// void MyFunction(int a, char b, const char* c) { ...
} +// +// // NOLINTNEXTLINE(google-readability-namespace-comments) +// } // namespace HWY_NAMESPACE +// } // namespace skeleton +// HWY_AFTER_NAMESPACE(); +// +// namespace skeleton { +// HWY_EXPORT(MyFunction); // Defines the dispatch table in this scope. +// +// void MyFunction(int a, char b, const char* c) { +// return HWY_DYNAMIC_DISPATCH(MyFunction)(a, b, c); +// } +// } // namespace skeleton +// +// For templated code with a single type parameter, instead use HWY_EXPORT_T and +// its HWY_DYNAMIC_DISPATCH_T counterpart: +// +// template <typename T> +// void MyFunctionCaller(T ...) { +// // First argument to both HWY_EXPORT_T and HWY_DYNAMIC_DISPATCH_T is an +// // arbitrary table name; you must provide the same name for each call. +// // It is fine to have multiple HWY_EXPORT_T in a function, but a 64-bit +// // FNV hash collision among *any* table names will trigger HWY_ABORT. +// HWY_EXPORT_T(Table1, MyFunction<T>) +// HWY_DYNAMIC_DISPATCH_T(Table1)(a, b, c); +// } +// +// Note that HWY_EXPORT_T must be invoked inside a template (in the above +// example: `MyFunctionCaller`), so that a separate table will be created for +// each template instantiation. For convenience, we also provide a macro that +// combines both steps and avoids the need to pick a table name: +// +// template <typename T> +// void MyFunctionCaller(T ...) { +// // Table name is automatically chosen. Note that this variant must be +// // called in statement context; it is not a valid expression. +// HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(MyFunction<T>)(a, b, c); +// } + +// Simplified version for IDE or the dynamic dispatch case with only one target. +#if HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0) + +// We use a table to provide the same compile error conditions as with the +// non-simplified case, but the table only has a single entry.
+#define HWY_EXPORT_T(TABLE_NAME, FUNC_NAME) \ + HWY_MAYBE_UNUSED static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const \ + HWY_DISPATCH_TABLE(TABLE_NAME)[1] = {&HWY_STATIC_DISPATCH(FUNC_NAME)} + +// Use the table, not just STATIC_DISPATCH as in DYNAMIC_DISPATCH, because +// TABLE_NAME might not match the function name. +#define HWY_DYNAMIC_POINTER_T(TABLE_NAME) (HWY_DISPATCH_TABLE(TABLE_NAME)[0]) +#define HWY_DYNAMIC_DISPATCH_T(TABLE_NAME) \ + (*(HWY_DYNAMIC_POINTER_T(TABLE_NAME))) + +// HWY_EXPORT still defines the one-entry table, but the non-_T pointer and +// dispatch macros below name the static target's function directly. +#define HWY_EXPORT(FUNC_NAME) HWY_EXPORT_T(FUNC_NAME, FUNC_NAME) +#define HWY_DYNAMIC_POINTER(FUNC_NAME) &HWY_STATIC_DISPATCH(FUNC_NAME) +#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) HWY_STATIC_DISPATCH(FUNC_NAME) + +#else // not simplified: full table + +// Pre-C++17 workaround: non-type template arguments must have linkage, which +// means we cannot pass &table as a template argument to ChooseAndCall. +// ChooseAndCall must find a way to access the table in order to dispatch to the +// chosen target: +// 0) Skipping this by dispatching to the static target would be surprising to +// users and may have serious performance implications. +// 1) An extra function parameter would be unacceptable because it changes the +// user-visible function signature. +// 2) Declaring a table, then defining a pointer to it would work, but requires +// an additional DECLARE step outside the function so that the pointer has +// linkage, which breaks existing code. +// 3) We instead associate the function with the table using an instance of an +// unnamed struct and the hash of the table name as the key. Because +// ChooseAndCall has the type information, it can then cast to the function +// pointer type. However, we cannot simply pass the name as a template +// argument to ChooseAndCall because this requires char*, which hits the same +// linkage problem. We instead hash the table name, which assumes the +// function names do not have collisions.
+#if HWY_DISPATCH_MAP + +// Compile-time 64-bit FNV-style hash of `name`. The recursion bottoms out at +// the terminating NUL (returning the offset basis), so the last character is +// mixed in first. +static constexpr uint64_t FNV(const char* name) { + return *name ? static_cast<uint64_t>(static_cast<uint8_t>(*name)) ^ + (0x100000001b3ULL * FNV(name + 1)) + : 0xcbf29ce484222325ULL; +} + +// Constructing an AddExport<kHash> records `table` in the AllExports slot keyed +// by (function pointer type, exports key type, kHash) and aborts if that slot +// already holds a different table. +template <uint64_t kHash> +struct AddExport { + template <class ExportsKey, class FuncPtr> + AddExport(ExportsKey /*exports_key*/, const char* table_name, + const FuncPtr* table) { + using FuncCache = decltype(DeduceFunctionCache(hwy::DeclVal<FuncPtr>())); + static_assert( + hwy::IsSame<RemoveCvRef<FuncPtr>, typename FuncCache::FuncPtr>(), + "FuncPtr should be same type as FuncCache::FuncPtr"); + + const FuncPtr*& exports_ptr = AllExports::template GetRefToExportsPtr< + RemoveCvRef<FuncPtr>, RemoveCvRef<ExportsKey>, kHash>(); + if (exports_ptr && exports_ptr != table) { + HWY_ABORT("Hash collision for %s, rename the function\n", table_name); + } else { + exports_ptr = table; + } + } +}; + +// Dynamic dispatch: defines table of function pointers. This must be invoked +// from inside the function template that calls the template we are exporting. +// TABLE_NAME must match the one passed to HWY_DYNAMIC_DISPATCH_T. This +// argument allows multiple exports within one function. +#define HWY_EXPORT_T(TABLE_NAME, FUNC_NAME) \ + static const struct { \ + } HWY_CONCAT(TABLE_NAME, HighwayDispatchExportsKey) = {}; \ + static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \ + TABLE_NAME)[static_cast<size_t>(HWY_MAX_DYNAMIC_TARGETS + 2)] = { \ + /* The first entry in the table initializes the global cache and \ + * calls the appropriate function. */ \ + &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH(FUNC_NAME))):: \ + template ChooseAndCall<decltype(HWY_CONCAT( \ + TABLE_NAME, HighwayDispatchExportsKey)), \ + hwy::FNV(#TABLE_NAME)>, \ + HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \ + HWY_CHOOSE_FALLBACK(FUNC_NAME), \ + }; \ + HWY_MAYBE_UNUSED static hwy::AddExport<hwy::FNV(#TABLE_NAME)> HWY_CONCAT( \ + HighwayAddTable, __LINE__)( \ + HWY_CONCAT(TABLE_NAME, HighwayDispatchExportsKey), #TABLE_NAME, \ + HWY_DISPATCH_TABLE(TABLE_NAME)) + +// For non-template functions.
Not necessarily invoked within a function, hence +// we derive the string and variable names from FUNC_NAME, not HWY_FUNCTION. +// Table layout (HWY_MAX_DYNAMIC_TARGETS + 2 entries): the choose-and-call +// trampoline, then HWY_CHOOSE_TARGET_LIST, then HWY_CHOOSE_FALLBACK. +#if HWY_DISPATCH_WORKAROUND +#define HWY_EXPORT(FUNC_NAME) HWY_EXPORT_T(FUNC_NAME, FUNC_NAME) +#else +#define HWY_EXPORT(FUNC_NAME) \ + static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \ + FUNC_NAME)[static_cast<size_t>(HWY_MAX_DYNAMIC_TARGETS + 2)] = { \ + /* The first entry in the table initializes the global cache and \ + * calls the appropriate function. */ \ + &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH(FUNC_NAME))):: \ + template TableChooseAndCall<HWY_DISPATCH_TABLE(FUNC_NAME)>, \ + HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \ + HWY_CHOOSE_FALLBACK(FUNC_NAME), \ + } +#endif // HWY_DISPATCH_WORKAROUND + +#else // !HWY_DISPATCH_MAP + +// Zero-overhead, but requires C++17 for non-type template arguments without +// linkage, because HWY_EXPORT_T tables are local static variables. +#define HWY_EXPORT_T(TABLE_NAME, FUNC_NAME) \ + static decltype(&HWY_STATIC_DISPATCH(FUNC_NAME)) const HWY_DISPATCH_TABLE( \ + TABLE_NAME)[static_cast<size_t>(HWY_MAX_DYNAMIC_TARGETS + 2)] = { \ + /* The first entry in the table initializes the global cache and \ + * calls the appropriate function. */ \ + &decltype(hwy::DeduceFunctionCache(&HWY_STATIC_DISPATCH(FUNC_NAME))):: \ + template ChooseAndCall<HWY_DISPATCH_TABLE(TABLE_NAME)>, \ + HWY_CHOOSE_TARGET_LIST(FUNC_NAME), \ + HWY_CHOOSE_FALLBACK(FUNC_NAME), \ + } + +#define HWY_EXPORT(FUNC_NAME) HWY_EXPORT_T(FUNC_NAME, FUNC_NAME) + +#endif // HWY_DISPATCH_MAP + +// HWY_DISPATCH_MAP only affects how tables are created, not their usage. + +// Evaluates to the function pointer for the chosen target. +#define HWY_DYNAMIC_POINTER(FUNC_NAME) \ + (HWY_DISPATCH_TABLE(FUNC_NAME)[hwy::GetChosenTarget().GetIndex()]) + +// Calls the function pointer for the chosen target. +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + +// On GCC or Clang, we call hwy::PreventElision(...)
to work around a compiler +// crash where the LLVM inliner crashes due to inlining incompatible intrinsics. +// The statement expression stores the chosen target's function pointer in a +// local whose name is derived from __LINE__, passes it to hwy::PreventElision, +// and evaluates to it. + +#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) \ + __extension__({ \ + auto HWY_CONCAT(hwy_tmp_, __LINE__) = *(HWY_DYNAMIC_POINTER(FUNC_NAME)); \ + hwy::PreventElision(HWY_CONCAT(hwy_tmp_, __LINE__)); \ + HWY_CONCAT(hwy_tmp_, __LINE__); \ + }) + +#else // !(HWY_COMPILER_GCC || HWY_COMPILER_CLANG) + +#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) (*(HWY_DYNAMIC_POINTER(FUNC_NAME))) + +#endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG + +// Same as DISPATCH, but provide a different arg name to clarify usage. +#define HWY_DYNAMIC_DISPATCH_T(TABLE_NAME) HWY_DYNAMIC_DISPATCH(TABLE_NAME) +#define HWY_DYNAMIC_POINTER_T(TABLE_NAME) HWY_DYNAMIC_POINTER(TABLE_NAME) + +#endif // HWY_IDE || ((HWY_TARGETS & (HWY_TARGETS - 1)) == 0) + +// Returns the name of an anonymous dispatch table that is only shared with +// macro invocations coming from the same source line. +#define HWY_DISPATCH_TABLE_T() HWY_CONCAT(HighwayDispatchTableT, __LINE__) + +// For templated code, combines export and dispatch using an anonymous table. +// The expansion is two statements separated by `;`, hence statement context +// only. +#define HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC_NAME) \ + HWY_EXPORT_T(HWY_DISPATCH_TABLE_T(), FUNC_NAME); \ + HWY_DYNAMIC_DISPATCH_T(HWY_DISPATCH_TABLE_T()) + +// DEPRECATED names; please use HWY_HAVE_* instead. +#define HWY_CAP_INTEGER64 HWY_HAVE_INTEGER64 +#define HWY_CAP_FLOAT16 HWY_HAVE_FLOAT16 +#define HWY_CAP_FLOAT64 HWY_HAVE_FLOAT64 + +} // namespace hwy + +#endif // HWY_HIGHWAY_INCLUDED + +//------------------------------------------------------------------------------ + +// NOTE: the following definitions and ops/*.h depend on HWY_TARGET, so we want +// to include them once per target, which is ensured by the toggle check. +// Because ops/*.h are included under it, they do not need their own guard.
+#if defined(HWY_HIGHWAY_PER_TARGET) == defined(HWY_TARGET_TOGGLE) +#ifdef HWY_HIGHWAY_PER_TARGET +#undef HWY_HIGHWAY_PER_TARGET +#else +#define HWY_HIGHWAY_PER_TARGET +#endif + +// No SIMD target enabled, skip header inclusion. +#if HWY_ENABLED_BASELINE == 0 + +// We would expect that HWY_TARGET and HWY_STATIC_TARGET are now both 0. +#if HWY_TARGET != 0 +#error "Why is HWY_TARGET not 0 when HWY_ENABLED_BASELINE == 0?" +#endif +#if HWY_STATIC_TARGET != 0 +#error "Why is HWY_STATIC_TARGET not 0 when HWY_ENABLED_BASELINE == 0?" +#endif + +#else + +// These define ops inside namespace hwy::HWY_NAMESPACE. +#if HWY_TARGET == HWY_SSE2 || HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4 +#include "hwy/ops/x86_128-inl.h" +#elif HWY_TARGET == HWY_AVX2 +#include "hwy/ops/x86_256-inl.h" +#elif HWY_TARGET == HWY_AVX3 || HWY_TARGET == HWY_AVX3_DL || \ + HWY_TARGET == HWY_AVX3_ZEN4 || HWY_TARGET == HWY_AVX3_SPR || \ + HWY_TARGET == HWY_AVX10_2 +#include "hwy/ops/x86_avx3-inl.h" +#elif HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15 || \ + (HWY_TARGET & HWY_ALL_PPC) +#include "hwy/ops/ppc_vsx-inl.h" +#elif HWY_TARGET & HWY_ALL_NEON +#include "hwy/ops/arm_neon-inl.h" +#elif HWY_TARGET & HWY_ALL_SVE +#include "hwy/ops/arm_sve-inl.h" +#elif HWY_TARGET == HWY_WASM_EMU256 +#include "hwy/ops/wasm_256-inl.h" +#elif HWY_TARGET == HWY_WASM +#include "hwy/ops/wasm_128-inl.h" +#elif HWY_TARGET == HWY_RVV +#include "hwy/ops/rvv-inl.h" +#elif HWY_TARGET == HWY_LSX +#include "hwy/ops/loongarch_lsx-inl.h" +#elif HWY_TARGET == HWY_LASX +#include "hwy/ops/loongarch_lasx-inl.h" +#elif HWY_TARGET == HWY_EMU128 +#include "hwy/ops/emu128-inl.h" +#elif HWY_TARGET == HWY_SCALAR +#include "hwy/ops/scalar-inl.h" +#else +#pragma message("HWY_TARGET does not match any known target") +#endif // HWY_TARGET + +#include "hwy/ops/generic_ops-inl.h" + +#endif // HWY_ENABLED_BASELINE + +#endif // HWY_HIGHWAY_PER_TARGET diff --git a/lib/highway/hwy/highway_export.h b/lib/highway/hwy/highway_export.h new file mode 
100644 index 00000000000..30edc17d013 --- /dev/null +++ b/lib/highway/hwy/highway_export.h @@ -0,0 +1,74 @@ +// Pseudo-generated file to handle both cmake & bazel build system. + +// Initial generation done using cmake code: +// include(GenerateExportHeader) +// generate_export_header(hwy EXPORT_MACRO_NAME HWY_DLLEXPORT EXPORT_FILE_NAME +// hwy/highway_export.h) +// code reformatted using clang-format --style=Google + +#ifndef HWY_DLLEXPORT_H +#define HWY_DLLEXPORT_H + +#if !defined(HWY_SHARED_DEFINE) +#define HWY_DLLEXPORT +#define HWY_CONTRIB_DLLEXPORT +#define HWY_TEST_DLLEXPORT +#else // !HWY_SHARED_DEFINE + +#ifndef HWY_DLLEXPORT +#if defined(hwy_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_DLLEXPORT __declspec(dllexport) +#else +#define HWY_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else // defined(hwy_EXPORTS) +/* We are using this library */ +#ifdef _WIN32 +#define HWY_DLLEXPORT __declspec(dllimport) +#else +#define HWY_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif // defined(hwy_EXPORTS) +#endif // HWY_DLLEXPORT + +#ifndef HWY_CONTRIB_DLLEXPORT +#if defined(hwy_contrib_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_CONTRIB_DLLEXPORT __declspec(dllexport) +#else +#define HWY_CONTRIB_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else // defined(hwy_contrib_EXPORTS) +/* We are using this library */ +#ifdef _WIN32 +#define HWY_CONTRIB_DLLEXPORT __declspec(dllimport) +#else +#define HWY_CONTRIB_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif // defined(hwy_contrib_EXPORTS) +#endif // HWY_CONTRIB_DLLEXPORT + +#ifndef HWY_TEST_DLLEXPORT +#if defined(hwy_test_EXPORTS) +/* We are building this library */ +#ifdef _WIN32 +#define HWY_TEST_DLLEXPORT __declspec(dllexport) +#else +#define HWY_TEST_DLLEXPORT __attribute__((visibility("default"))) +#endif +#else // defined(hwy_test_EXPORTS) +/* We are using this library */ +#ifdef _WIN32 +#define 
HWY_TEST_DLLEXPORT __declspec(dllimport) +#else +#define HWY_TEST_DLLEXPORT __attribute__((visibility("default"))) +#endif +#endif // defined(hwy_test_EXPORTS) +#endif // HWY_TEST_DLLEXPORT + +#endif // !HWY_SHARED_DEFINE + +#endif /* HWY_DLLEXPORT_H */ diff --git a/lib/highway/hwy/hwy.version b/lib/highway/hwy/hwy.version new file mode 100644 index 00000000000..9ff6be6a2d7 --- /dev/null +++ b/lib/highway/hwy/hwy.version @@ -0,0 +1,19 @@ +HWY_0 { + global: + extern "C++" { + *hwy::*; + }; + + local: + # Hide all the std namespace symbols. std namespace is explicitly marked + # as visibility(default) and header-only functions or methods (such as those + # from templates) should be exposed in shared libraries as weak symbols but + # this is only needed when we expose those types in the shared library API + # in any way. We don't use C++ std types in the API and we also don't + # support exceptions in the library. + # See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=36022 for a discussion + # about this. + extern "C++" { + *std::*; + }; +}; diff --git a/lib/highway/hwy/nanobenchmark.cc b/lib/highway/hwy/nanobenchmark.cc new file mode 100644 index 00000000000..0a885d2d090 --- /dev/null +++ b/lib/highway/hwy/nanobenchmark.cc @@ -0,0 +1,302 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "hwy/nanobenchmark.h" + +#include +#include +#include // clock_gettime + +#include // std::sort, std::find_if +#include // std::iota +#include +#include + +#include "hwy/base.h" +#include "hwy/robust_statistics.h" +#include "hwy/timer.h" + +namespace hwy { +namespace { +const timer::Ticks& GetTimerResolution() { + static const timer::Ticks timer_resolution = platform::TimerResolution(); + return timer_resolution; +} + +// Estimates the expected value of "lambda" values with a variable number of +// samples until the variability "rel_mad" is less than "max_rel_mad". +template +timer::Ticks SampleUntilStable(const double max_rel_mad, double* rel_mad, + const Params& p, const Lambda& lambda) { + // Choose initial samples_per_eval based on a single estimated duration. + timer::Ticks t0 = timer::Start(); + lambda(); + timer::Ticks t1 = timer::Stop(); // Caller checks HaveTimerStop + timer::Ticks est = t1 - t0; + static const double ticks_per_second = platform::InvariantTicksPerSecond(); + const size_t ticks_per_eval = + static_cast(ticks_per_second * p.seconds_per_eval); + size_t samples_per_eval = est == 0 + ? p.min_samples_per_eval + : static_cast(ticks_per_eval / est); + samples_per_eval = HWY_MAX(samples_per_eval, p.min_samples_per_eval); + + std::vector samples; + samples.reserve(1 + samples_per_eval); + samples.push_back(est); + + // Percentage is too strict for tiny differences, so also allow a small + // absolute "median absolute deviation". 
+ const timer::Ticks max_abs_mad = (GetTimerResolution() + 99) / 100; + *rel_mad = 0.0; // ensure initialized + + for (size_t eval = 0; eval < p.max_evals; ++eval, samples_per_eval *= 2) { + samples.reserve(samples.size() + samples_per_eval); + for (size_t i = 0; i < samples_per_eval; ++i) { + t0 = timer::Start(); + lambda(); + t1 = timer::Stop(); // Caller checks HaveTimerStop + samples.push_back(t1 - t0); + } + + if (samples.size() >= p.min_mode_samples) { + est = robust_statistics::Mode(samples.data(), samples.size()); + } else { + // For "few" (depends also on the variance) samples, Median is safer. + est = robust_statistics::Median(samples.data(), samples.size()); + } + if (est == 0) { + HWY_WARN("estimated duration is 0\n"); + } + + // Median absolute deviation (mad) is a robust measure of 'variability'. + const timer::Ticks abs_mad = robust_statistics::MedianAbsoluteDeviation( + samples.data(), samples.size(), est); + *rel_mad = static_cast(abs_mad) / static_cast(est); + + if (*rel_mad <= max_rel_mad || abs_mad <= max_abs_mad) { + if (p.verbose) { + printf("%6d samples => %5d (abs_mad=%4d, rel_mad=%4.2f%%)\n", + static_cast(samples.size()), static_cast(est), + static_cast(abs_mad), *rel_mad * 100.0); + } + return est; + } + } + + if (p.verbose) { + printf("WARNING: rel_mad=%4.2f%% still exceeds %4.2f%% after %6d samples\n", + *rel_mad * 100.0, max_rel_mad * 100.0, + static_cast(samples.size())); + } + return est; +} + +using InputVec = std::vector; + +// Returns vector of unique input values. +InputVec UniqueInputs(const FuncInput* inputs, const size_t num_inputs) { + InputVec unique(inputs, inputs + num_inputs); + std::sort(unique.begin(), unique.end()); + unique.erase(std::unique(unique.begin(), unique.end()), unique.end()); + return unique; +} + +// Returns how often we need to call func for sufficient precision. +size_t NumSkip(const Func func, const uint8_t* arg, const InputVec& unique, + const Params& p) { + // Min elapsed ticks for any input. 
+ timer::Ticks min_duration = ~timer::Ticks(0); + + for (const FuncInput input : unique) { + double rel_mad; + const timer::Ticks total = SampleUntilStable( + p.target_rel_mad, &rel_mad, p, + [func, arg, input]() { PreventElision(func(arg, input)); }); + min_duration = HWY_MIN(min_duration, total - GetTimerResolution()); + } + + // Number of repetitions required to reach the target resolution. + const size_t max_skip = p.precision_divisor; + // Number of repetitions given the estimated duration. + const size_t num_skip = + min_duration == 0 + ? 0 + : static_cast((max_skip + min_duration - 1) / min_duration); + if (p.verbose) { + printf("res=%d max_skip=%d min_dur=%d num_skip=%d\n", + static_cast(GetTimerResolution()), static_cast(max_skip), + static_cast(min_duration), static_cast(num_skip)); + } + return num_skip; +} + +// Replicates inputs until we can omit "num_skip" occurrences of an input. +InputVec ReplicateInputs(const FuncInput* inputs, const size_t num_inputs, + const size_t num_unique, const size_t num_skip, + const Params& p) { + InputVec full; + if (num_unique == 1) { + full.assign(p.subset_ratio * num_skip, inputs[0]); + return full; + } + + full.reserve(p.subset_ratio * num_skip * num_inputs); + for (size_t i = 0; i < p.subset_ratio * num_skip; ++i) { + full.insert(full.end(), inputs, inputs + num_inputs); + } + std::mt19937 rng; + std::shuffle(full.begin(), full.end(), rng); + return full; +} + +// Copies the "full" to "subset" in the same order, but with "num_skip" +// randomly selected occurrences of "input_to_skip" removed. +void FillSubset(const InputVec& full, const FuncInput input_to_skip, + const size_t num_skip, InputVec* subset) { + const size_t count = + static_cast(std::count(full.begin(), full.end(), input_to_skip)); + // Generate num_skip random indices: which occurrence to skip. 
+ std::vector omit(count); + std::iota(omit.begin(), omit.end(), 0); + // omit[] is the same on every call, but that's OK because they identify the + // Nth instance of input_to_skip, so the position within full[] differs. + std::mt19937 rng; + std::shuffle(omit.begin(), omit.end(), rng); + omit.resize(num_skip); + std::sort(omit.begin(), omit.end()); + + uint32_t occurrence = ~0u; // 0 after preincrement + size_t idx_omit = 0; // cursor within omit[] + size_t idx_subset = 0; // cursor within *subset + for (const FuncInput next : full) { + if (next == input_to_skip) { + ++occurrence; + // Haven't removed enough already + if (idx_omit < num_skip) { + // This one is up for removal + if (occurrence == omit[idx_omit]) { + ++idx_omit; + continue; + } + } + } + if (idx_subset < subset->size()) { + (*subset)[idx_subset++] = next; + } + } + HWY_DASSERT(idx_subset == subset->size()); + HWY_DASSERT(idx_omit == omit.size()); + HWY_DASSERT(occurrence == count - 1); +} + +// Returns total ticks elapsed for all inputs. +timer::Ticks TotalDuration(const Func func, const uint8_t* arg, + const InputVec* inputs, const Params& p, + double* max_rel_mad) { + double rel_mad; + const timer::Ticks duration = + SampleUntilStable(p.target_rel_mad, &rel_mad, p, [func, arg, inputs]() { + for (const FuncInput input : *inputs) { + PreventElision(func(arg, input)); + } + }); + *max_rel_mad = HWY_MAX(*max_rel_mad, rel_mad); + return duration; +} + +// (Nearly) empty Func for measuring timer overhead/resolution. +HWY_NOINLINE FuncOutput EmptyFunc(const void* /*arg*/, const FuncInput input) { + return input; +} + +// Returns overhead of accessing inputs[] and calling a function; this will +// be deducted from future TotalDuration return values. +timer::Ticks Overhead(const uint8_t* arg, const InputVec* inputs, + const Params& p) { + double rel_mad; + // Zero tolerance because repeatability is crucial and EmptyFunc is fast. 
+ return SampleUntilStable(0.0, &rel_mad, p, [arg, inputs]() { + for (const FuncInput input : *inputs) { + PreventElision(EmptyFunc(arg, input)); + } + }); +} + +} // namespace + +HWY_DLLEXPORT int Unpredictable1() { return timer::Start() != ~0ULL; } + +HWY_DLLEXPORT size_t Measure(const Func func, const uint8_t* arg, + const FuncInput* inputs, const size_t num_inputs, + Result* results, const Params& p) { + HWY_DASSERT(num_inputs != 0); + + char cpu100[100]; + if (!platform::HaveTimerStop(cpu100)) { + HWY_WARN("CPU '%s' does not support RDTSCP, skipping benchmark.\n", cpu100); + return 0; + } + + const InputVec& unique = UniqueInputs(inputs, num_inputs); + + const size_t num_skip = NumSkip(func, arg, unique, p); // may be 0, e.g. if min_duration == 0 + if (num_skip == 0) return 0; // avoids 1/0 in mul below; no Result written + // (slightly less work on x86 to cast from signed integer) + const float mul = 1.0f / static_cast(static_cast(num_skip)); + + const InputVec& full = + ReplicateInputs(inputs, num_inputs, unique.size(), num_skip, p); + InputVec subset(full.size() - num_skip); + + const timer::Ticks overhead = Overhead(arg, &full, p); + const timer::Ticks overhead_skip = Overhead(arg, &subset, p); + if (overhead < overhead_skip) { + HWY_WARN("Measurement failed: overhead %d < %d\n", + static_cast(overhead), static_cast(overhead_skip)); + return 0; + } + + if (p.verbose) { + printf("#inputs=%5d,%5d overhead=%5d,%5d\n", static_cast(full.size()), + static_cast(subset.size()), static_cast(overhead), + static_cast(overhead_skip)); + } + + double max_rel_mad = 0.0; + const timer::Ticks total = TotalDuration(func, arg, &full, p, &max_rel_mad); + + for (size_t i = 0; i < unique.size(); ++i) { + FillSubset(full, unique[i], num_skip, &subset); + const timer::Ticks total_skip = + TotalDuration(func, arg, &subset, p, &max_rel_mad); + + if (total < total_skip) { + HWY_WARN("Measurement failed: total %f < %f\n", + static_cast(total), static_cast(total_skip)); + return 0; + } + + const timer::Ticks 
duration = + (total - overhead) - (total_skip - overhead_skip); + results[i].input = unique[i]; + results[i].ticks = static_cast(duration) * mul; + results[i].variability = static_cast(max_rel_mad); + } + + return unique.size(); +} + +} // namespace hwy diff --git a/lib/highway/hwy/nanobenchmark.h b/lib/highway/hwy/nanobenchmark.h new file mode 100644 index 00000000000..16f011a1a4c --- /dev/null +++ b/lib/highway/hwy/nanobenchmark.h @@ -0,0 +1,162 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_NANOBENCHMARK_H_ +#define HIGHWAY_HWY_NANOBENCHMARK_H_ + +// Benchmarks functions of a single integer argument with realistic branch +// prediction hit rates. Uses a robust estimator to summarize the measurements. +// The precision is about 0.2%. +// +// Examples: see nanobenchmark_test.cc. +// +// Background: Microbenchmarks such as http://github.com/google/benchmark +// can measure elapsed times on the order of a microsecond. Shorter functions +// are typically measured by repeating them thousands of times and dividing +// the total elapsed time by this count. Unfortunately, repetition (especially +// with the same input parameter!) influences the runtime. In time-critical +// code, it is reasonable to expect warm instruction/data caches and TLBs, +// but a perfect record of which branches will be taken is unrealistic. 
+// Unless the application also repeatedly invokes the measured function with +// the same parameter, the benchmark is measuring something very different - +// a best-case result, almost as if the parameter were made a compile-time +// constant. This may lead to erroneous conclusions about branch-heavy +// algorithms outperforming branch-free alternatives. +// +// Our approach differs in three ways. Adding fences to the timer functions +// reduces variability due to instruction reordering, improving the timer +// resolution to about 40 CPU cycles. However, shorter functions must still +// be invoked repeatedly. For more realistic branch prediction performance, +// we vary the input parameter according to a user-specified distribution. +// Thus, instead of VaryInputs(Measure(Repeat(func))), we change the +// loop nesting to Measure(Repeat(VaryInputs(func))). We also estimate the +// central tendency of the measurement samples with the "half sample mode", +// which is more robust to outliers and skewed data than the mean or median. + +#include +#include + +#include "hwy/highway_export.h" +#include "hwy/timer.h" // IWYU pragma: export +#include "hwy/base.h" + +namespace hwy { + +// Returns 1, but without the compiler knowing what the value is. This prevents +// optimizing out code. +HWY_DLLEXPORT int Unpredictable1(); + +// Input influencing the function being measured (e.g. number of bytes to copy). +using FuncInput = size_t; + +// "Proof of work" returned by Func to ensure the compiler does not elide it. +using FuncOutput = uint64_t; + +// Function to measure: either 1) a captureless lambda or function with two +// arguments or 2) a lambda with capture, in which case the first argument +// is reserved for use by MeasureClosure. +using Func = FuncOutput (*)(const void*, FuncInput); + +// Internal parameters that determine precision/resolution/measuring time. +struct Params { + // Best-case precision, expressed as a divisor of the timer resolution. 
+ // Larger => more calls to Func and higher precision. + size_t precision_divisor = 1024; + + // Ratio between full and subset input distribution sizes. Cannot be less + // than 2; larger values increase measurement time but more faithfully + // model the given input distribution. + size_t subset_ratio = 2; + + // Together with the estimated Func duration, determines how many times to + // call Func before checking the sample variability. Larger values increase + // measurement time, memory/cache use and precision. + double seconds_per_eval = 4E-3; + + // The minimum number of samples before estimating the central tendency. + size_t min_samples_per_eval = 7; + + // The mode is better than median for estimating the central tendency of + // skewed/fat-tailed distributions, but it requires sufficient samples + // relative to the width of half-ranges. + size_t min_mode_samples = 64; + + // Maximum permissible variability (= median absolute deviation / center). + double target_rel_mad = 0.002; + + // Abort after this many evals without reaching target_rel_mad. This + // prevents infinite loops. + size_t max_evals = 9; + + // Whether to print additional statistics to stdout. + bool verbose = true; +}; + +// Measurement result for each unique input. +struct Result { + FuncInput input; + + // Robust estimate (mode or median) of duration. + float ticks; + + // Measure of variability (median absolute deviation relative to "ticks"). + float variability; +}; + +// Returns a Params struct with customized configuration for benchmarks. +// Specifically limits `max_evals` to prevent timeout in tests. +static inline Params DefaultBenchmarkParams() { + Params p; + p.max_evals = HWY_IS_DEBUG_BUILD ? 3 : 4; + return p; +} + +// Precisely measures the number of ticks elapsed when calling "func" with the +// given inputs, shuffled to ensure realistic branch prediction hit rates. +// +// "func" returns a 'proof of work' to ensure its computations are not elided. 
+// "arg" is passed to Func, or reserved for internal use by MeasureClosure. +// "inputs" is an array of "num_inputs" (not necessarily unique) arguments to +// "func". The values should be chosen to maximize coverage of "func". This +// represents a distribution, so a value's frequency should reflect its +// probability in the real application. Order does not matter; for example, a +// uniform distribution over [0, 4) could be represented as {3,0,2,1}. +// Returns how many Result were written to "results": one per unique input, or +// zero if the measurement failed (an error message goes to stderr). +HWY_DLLEXPORT size_t Measure(Func func, const uint8_t* arg, + const FuncInput* inputs, size_t num_inputs, + Result* results, const Params& p = Params()); + +// Calls operator() of the given closure (lambda function). +template +static FuncOutput CallClosure(const void* f, const FuncInput input) { + return (*reinterpret_cast(f))(input); +} + +// Same as Measure, except "closure" is typically a lambda function of +// FuncInput -> FuncOutput with a capture list. +template +static inline size_t MeasureClosure(const Closure& closure, + const FuncInput* inputs, + const size_t num_inputs, Result* results, + const Params& p = Params()) { + return Measure(static_cast(&CallClosure), + reinterpret_cast(&closure), inputs, num_inputs, + results, p); +} + +} // namespace hwy + +#endif // HIGHWAY_HWY_NANOBENCHMARK_H_ diff --git a/lib/highway/hwy/ops/arm_neon-inl.h b/lib/highway/hwy/ops/arm_neon-inl.h new file mode 100644 index 00000000000..1a610ebb4f7 --- /dev/null +++ b/lib/highway/hwy/ops/arm_neon-inl.h @@ -0,0 +1,10656 @@ +// Copyright 2019 Google LLC +// Copyright 2024-2026 Arm Limited and/or its affiliates +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit Arm NEON vectors and operations. +// External include guard in highway.h - see comment there. + +// Arm NEON intrinsics are documented at: +// https://developer.arm.com/architectures/instruction-sets/intrinsics/#f:@navigationhierarchiessimdisa=[Neon] + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wuninitialized") +#include // NOLINT(build/include_order) +HWY_DIAGNOSTICS(pop) + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +namespace detail { // for code folding and Raw128 + +// Macros used to define single and double function calls for multiple types +// for full and half vectors. These macros are undefined at the end of the file. + +// HWY_NEON_BUILD_TPL_* is the template<...> prefix to the function. +#define HWY_NEON_BUILD_TPL_1 +#define HWY_NEON_BUILD_TPL_2 +#define HWY_NEON_BUILD_TPL_3 + +// HWY_NEON_BUILD_RET_* is return type; type arg is without _t suffix so we can +// extend it to int32x4x2_t packs. +#define HWY_NEON_BUILD_RET_1(type, size) Vec128 +#define HWY_NEON_BUILD_RET_2(type, size) Vec128 +#define HWY_NEON_BUILD_RET_3(type, size) Vec128 + +// HWY_NEON_BUILD_PARAM_* is the list of parameters the function receives. 
+#define HWY_NEON_BUILD_PARAM_1(type, size) const Vec128 a +#define HWY_NEON_BUILD_PARAM_2(type, size) \ + const Vec128 a, const Vec128 b +#define HWY_NEON_BUILD_PARAM_3(type, size) \ + const Vec128 a, const Vec128 b, \ + const Vec128 c + +// HWY_NEON_BUILD_ARG_* is the list of arguments passed to the underlying +// function. +#define HWY_NEON_BUILD_ARG_1 a.raw +#define HWY_NEON_BUILD_ARG_2 a.raw, b.raw +#define HWY_NEON_BUILD_ARG_3 a.raw, b.raw, c.raw + +// We use HWY_NEON_EVAL(func, ...) to delay the evaluation of func until after +// the __VA_ARGS__ have been expanded. This allows "func" to be a macro +// itself, as with some of the library "functions" such as vshlq_u8. For +// example, HWY_NEON_EVAL(vshlq_u8, MY_PARAMS) where MY_PARAMS is defined as +// "a, b" (without the quotes) will end up expanding to "vshlq_u8(a, b)" if needed. +// Directly writing vshlq_u8(MY_PARAMS) would fail since the vshlq_u8() macro +// expects two arguments. +#define HWY_NEON_EVAL(func, ...) func(__VA_ARGS__) + +// Main macro definition that defines a single function for the given type and +// size of vector, using the underlying (prefix##infix##suffix) function and +// the template, return type, parameters and arguments defined by the "args" +// parameters passed here (see HWY_NEON_BUILD_* macros defined before). +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + HWY_CONCAT(HWY_NEON_BUILD_TPL_, args) \ + HWY_API HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size) \ + name(HWY_CONCAT(HWY_NEON_BUILD_PARAM_, args)(type, size)) { \ + return HWY_CONCAT(HWY_NEON_BUILD_RET_, args)(type, size)( \ + HWY_NEON_EVAL(prefix##infix##suffix, HWY_NEON_BUILD_ARG_##args)); \ + } + +// The HWY_NEON_DEF_FUNCTION_* macros define all the variants of a function +// called "name" using the set of neon functions starting with the given +// "prefix" for all the variants of certain types, as specified next to each +// macro.
For example, the prefix "vsub" can be used to define the operator- +// using args=2. + +// uint8_t +#define HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 16, name, prefix##q, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 8, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 4, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 2, name, prefix, infix, u8, args) \ + HWY_NEON_DEF_FUNCTION(uint8, 1, name, prefix, infix, u8, args) + +// int8_t +#define HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int8, 16, name, prefix##q, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 8, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 4, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 2, name, prefix, infix, s8, args) \ + HWY_NEON_DEF_FUNCTION(int8, 1, name, prefix, infix, s8, args) + +// uint16_t +#define HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 8, name, prefix##q, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 4, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 2, name, prefix, infix, u16, args) \ + HWY_NEON_DEF_FUNCTION(uint16, 1, name, prefix, infix, u16, args) + +// int16_t +#define HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int16, 8, name, prefix##q, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 4, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 2, name, prefix, infix, s16, args) \ + HWY_NEON_DEF_FUNCTION(int16, 1, name, prefix, infix, s16, args) + +// uint32_t +#define HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 4, name, prefix##q, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 2, name, prefix, infix, u32, args) \ + HWY_NEON_DEF_FUNCTION(uint32, 1, name, prefix, infix, u32, args) + +// int32_t +#define HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args) \ 
+ HWY_NEON_DEF_FUNCTION(int32, 4, name, prefix##q, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32, 2, name, prefix, infix, s32, args) \ + HWY_NEON_DEF_FUNCTION(int32, 1, name, prefix, infix, s32, args) + +// uint64_t +#define HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 2, name, prefix##q, infix, u64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) + +// int64_t +#define HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) + +// bfloat16_t +#if HWY_NEON_HAVE_BFLOAT16 +#define HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(bfloat16, 8, name, prefix##q, infix, bf16, args) \ + HWY_NEON_DEF_FUNCTION(bfloat16, 4, name, prefix, infix, bf16, args) \ + HWY_NEON_DEF_FUNCTION(bfloat16, 2, name, prefix, infix, bf16, args) \ + HWY_NEON_DEF_FUNCTION(bfloat16, 1, name, prefix, infix, bf16, args) +#else +#define HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) +#endif + +// Used for conversion instructions if HWY_NEON_HAVE_F16C. +#define HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(name, prefix, infix, \ + args) \ + HWY_NEON_DEF_FUNCTION(float16, 8, name, prefix##q, infix, f16, args) \ + HWY_NEON_DEF_FUNCTION(float16, 4, name, prefix, infix, f16, args) \ + HWY_NEON_DEF_FUNCTION(float16, 2, name, prefix, infix, f16, args) \ + HWY_NEON_DEF_FUNCTION(float16, 1, name, prefix, infix, f16, args) + +// float16_t +#if HWY_HAVE_FLOAT16 +#define HWY_NEON_DEF_FUNCTION_FLOAT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(name, prefix, infix, args) +#else +#define HWY_NEON_DEF_FUNCTION_FLOAT_16(name, prefix, infix, args) +#endif + +// Enable generic functions for whichever of (f16, bf16) are not supported. 
+#if !HWY_HAVE_FLOAT16 && !HWY_NEON_HAVE_BFLOAT16 +#define HWY_NEON_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#define HWY_NEON_IF_NOT_EMULATED_D(D) HWY_IF_NOT_SPECIAL_FLOAT_D(D) +#elif !HWY_HAVE_FLOAT16 && HWY_NEON_HAVE_BFLOAT16 +#define HWY_NEON_IF_EMULATED_D(D) HWY_IF_F16_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_F16_D(D) +#define HWY_NEON_IF_NOT_EMULATED_D(D) HWY_IF_NOT_F16_D(D) +#elif HWY_HAVE_FLOAT16 && !HWY_NEON_HAVE_BFLOAT16 +#define HWY_NEON_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_NEON_IF_NOT_EMULATED_D(D) HWY_IF_NOT_BF16_D(D) +#elif HWY_HAVE_FLOAT16 && HWY_NEON_HAVE_BFLOAT16 +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the D template +// argument +#define HWY_NEON_IF_EMULATED_D(D) hwy::EnableIf()>* = nullptr +#define HWY_GENERIC_IF_EMULATED_D(D) \ + hwy::EnableIf()>* = nullptr +#define HWY_NEON_IF_NOT_EMULATED_D(D) hwy::EnableIf* = nullptr +#else +#error "Logic error, handled all four cases" +#endif + +// float +#define HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float32, 4, name, prefix##q, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float32, 2, name, prefix, infix, f32, args) \ + HWY_NEON_DEF_FUNCTION(float32, 1, name, prefix, infix, f32, args) + +// double +#if HWY_HAVE_FLOAT64 +#define HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(float64, 2, name, prefix##q, infix, f64, args) \ + HWY_NEON_DEF_FUNCTION(float64, 1, name, prefix, infix, f64, args) +#else +#define HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args) +#endif + +// Helper macros to define for more than one type. 
+// Each aggregate below forwards the same (name, prefix, infix, args) to the
+// per-type generator macros it lists, so a single invocation defines the
+// overloads for every lane type in that group.
+
+// uint8_t, uint16_t and uint32_t
+#define HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_UINT_8(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_UINT_16(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_UINT_32(name, prefix, infix, args)
+
+// int8_t, int16_t and int32_t
+#define HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_INT_8(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_INT_16(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_INT_32(name, prefix, infix, args)
+
+// uint8_t, uint16_t, uint32_t and uint64_t
+#define HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args)
+
+// int8_t, int16_t, int32_t and int64_t
+#define HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args)
+
+// All int*_t and uint*_t up to 64
+#define HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_INTS(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_UINTS(name, prefix, infix, args)
+
+// float16_t (only if HWY_NEON_DEF_FUNCTION_FLOAT_16 is non-empty) and float
+#define HWY_NEON_DEF_FUNCTION_FLOAT_16_32(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_FLOAT_16(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_FLOAT_32(name, prefix, infix, args)
+
+// The 16/32-bit float group plus double
+#define HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_FLOAT_16_32(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args)
+
+// All previous types.
+#define HWY_NEON_DEF_FUNCTION_ALL_TYPES(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_INTS_UINTS(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_ALL_FLOATS(name, prefix, infix, args)
+
+// Unsigned and signed integers of 8, 16 and 32 bits.
+#define HWY_NEON_DEF_FUNCTION_UI_8_16_32(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args)
+
+// The 8/16/32-bit integers above plus the 16/32-bit float group.
+#define HWY_NEON_DEF_FUNCTION_UIF_8_16_32(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_UI_8_16_32(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_FLOAT_16_32(name, prefix, infix, args)
+
+// All 64-bit lane types: uint64_t, int64_t and double.
+#define HWY_NEON_DEF_FUNCTION_UIF_64(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_UINT_64(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_INT_64(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_FLOAT_64(name, prefix, infix, args)
+
+// For vzip1/2
+// Only the two-lane (q-suffixed) 64-bit variants are generated here.
+#define HWY_NEON_DEF_FUNCTION_FULL_UI_64(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION(uint64, 2, name, prefix##q, infix, u64, args) \
+  HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args)
+#define HWY_NEON_DEF_FUNCTION_FULL_UIF_64(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION_FULL_UI_64(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION(float64, 2, name, prefix##q, infix, f64, args)
+
+// For eor3q, which is only defined for full vectors.
+#define HWY_NEON_DEF_FUNCTION_FULL_UI(name, prefix, infix, args) \
+  HWY_NEON_DEF_FUNCTION(uint8, 16, name, prefix##q, infix, u8, args) \
+  HWY_NEON_DEF_FUNCTION(uint16, 8, name, prefix##q, infix, u16, args) \
+  HWY_NEON_DEF_FUNCTION(uint32, 4, name, prefix##q, infix, u32, args) \
+  HWY_NEON_DEF_FUNCTION(int8, 16, name, prefix##q, infix, s8, args) \
+  HWY_NEON_DEF_FUNCTION(int16, 8, name, prefix##q, infix, s16, args) \
+  HWY_NEON_DEF_FUNCTION(int32, 4, name, prefix##q, infix, s32, args) \
+  HWY_NEON_DEF_FUNCTION_FULL_UI_64(name, prefix, infix, args)
+// Emulation of some intrinsics on armv7.
+#if HWY_ARCH_ARM_V7
+// Each vuzp1*/vuzp2*/vzip1*/vzip2* name is mapped onto the combined
+// vuzp*/vzip* intrinsic, selecting element 0 (for the *1 forms) or element 1
+// (for the *2 forms) of the pair it returns.
+#define vuzp1_s8(x, y) vuzp_s8(x, y).val[0]
+#define vuzp1_u8(x, y) vuzp_u8(x, y).val[0]
+#define vuzp1_s16(x, y) vuzp_s16(x, y).val[0]
+#define vuzp1_u16(x, y) vuzp_u16(x, y).val[0]
+#define vuzp1_s32(x, y) vuzp_s32(x, y).val[0]
+#define vuzp1_u32(x, y) vuzp_u32(x, y).val[0]
+#define vuzp1_f32(x, y) vuzp_f32(x, y).val[0]
+#define vuzp1q_s8(x, y) vuzpq_s8(x, y).val[0]
+#define vuzp1q_u8(x, y) vuzpq_u8(x, y).val[0]
+#define vuzp1q_s16(x, y) vuzpq_s16(x, y).val[0]
+#define vuzp1q_u16(x, y) vuzpq_u16(x, y).val[0]
+#define vuzp1q_s32(x, y) vuzpq_s32(x, y).val[0]
+#define vuzp1q_u32(x, y) vuzpq_u32(x, y).val[0]
+#define vuzp1q_f32(x, y) vuzpq_f32(x, y).val[0]
+#define vuzp2_s8(x, y) vuzp_s8(x, y).val[1]
+#define vuzp2_u8(x, y) vuzp_u8(x, y).val[1]
+#define vuzp2_s16(x, y) vuzp_s16(x, y).val[1]
+#define vuzp2_u16(x, y) vuzp_u16(x, y).val[1]
+#define vuzp2_s32(x, y) vuzp_s32(x, y).val[1]
+#define vuzp2_u32(x, y) vuzp_u32(x, y).val[1]
+#define vuzp2_f32(x, y) vuzp_f32(x, y).val[1]
+#define vuzp2q_s8(x, y) vuzpq_s8(x, y).val[1]
+#define vuzp2q_u8(x, y) vuzpq_u8(x, y).val[1]
+#define vuzp2q_s16(x, y) vuzpq_s16(x, y).val[1]
+#define vuzp2q_u16(x, y) vuzpq_u16(x, y).val[1]
+#define vuzp2q_s32(x, y) vuzpq_s32(x, y).val[1]
+#define vuzp2q_u32(x, y) vuzpq_u32(x, y).val[1]
+#define vuzp2q_f32(x, y) vuzpq_f32(x, y).val[1]
+#define vzip1_s8(x, y) vzip_s8(x, y).val[0]
+#define vzip1_u8(x, y) vzip_u8(x, y).val[0]
+#define vzip1_s16(x, y) vzip_s16(x, y).val[0]
+#define vzip1_u16(x, y) vzip_u16(x, y).val[0]
+#define vzip1_f32(x, y) vzip_f32(x, y).val[0]
+#define vzip1_u32(x, y) vzip_u32(x, y).val[0]
+#define vzip1_s32(x, y) vzip_s32(x, y).val[0]
+#define vzip1q_s8(x, y) vzipq_s8(x, y).val[0]
+#define vzip1q_u8(x, y) vzipq_u8(x, y).val[0]
+#define vzip1q_s16(x, y) vzipq_s16(x, y).val[0]
+#define vzip1q_u16(x, y) vzipq_u16(x, y).val[0]
+#define vzip1q_s32(x, y) vzipq_s32(x, y).val[0]
+#define vzip1q_u32(x, y) vzipq_u32(x, y).val[0]
+#define vzip1q_f32(x, y) vzipq_f32(x, y).val[0]
+#define vzip2_s8(x, y) vzip_s8(x, y).val[1]
+#define vzip2_u8(x, y) vzip_u8(x, y).val[1]
+#define vzip2_s16(x, y) vzip_s16(x, y).val[1]
+#define vzip2_u16(x, y) vzip_u16(x, y).val[1]
+#define vzip2_s32(x, y) vzip_s32(x, y).val[1]
+#define vzip2_u32(x, y) vzip_u32(x, y).val[1]
+#define vzip2_f32(x, y) vzip_f32(x, y).val[1]
+#define vzip2q_s8(x, y) vzipq_s8(x, y).val[1]
+#define vzip2q_u8(x, y) vzipq_u8(x, y).val[1]
+#define vzip2q_s16(x, y) vzipq_s16(x, y).val[1]
+#define vzip2q_u16(x, y) vzipq_u16(x, y).val[1]
+#define vzip2q_s32(x, y) vzipq_s32(x, y).val[1]
+#define vzip2q_u32(x, y) vzipq_u32(x, y).val[1]
+#define vzip2q_f32(x, y) vzipq_f32(x, y).val[1]
+#endif
+
+// Wrappers over uint8x16x2_t etc. so we can define StoreInterleaved2
+// overloads for all vector types, even those (bfloat16_t) where the
+// underlying vector is the same as others (uint16_t).
+//
+// For every lane type T there is one explicit specialization for the full
+// 128-bit lane count, wrapping the 128-bit tuple type, and one partial
+// specialization for any other N, wrapping the 64-bit tuple type.
+template <typename T, size_t N>
+struct Tuple2;
+template <typename T, size_t N>
+struct Tuple3;
+template <typename T, size_t N>
+struct Tuple4;
+
+template <>
+struct Tuple2<uint8_t, 16> {
+  uint8x16x2_t raw;
+};
+template <size_t N>
+struct Tuple2<uint8_t, N> {
+  uint8x8x2_t raw;
+};
+template <>
+struct Tuple2<int8_t, 16> {
+  int8x16x2_t raw;
+};
+template <size_t N>
+struct Tuple2<int8_t, N> {
+  int8x8x2_t raw;
+};
+template <>
+struct Tuple2<uint16_t, 8> {
+  uint16x8x2_t raw;
+};
+template <size_t N>
+struct Tuple2<uint16_t, N> {
+  uint16x4x2_t raw;
+};
+template <>
+struct Tuple2<int16_t, 8> {
+  int16x8x2_t raw;
+};
+template <size_t N>
+struct Tuple2<int16_t, N> {
+  int16x4x2_t raw;
+};
+template <>
+struct Tuple2<uint32_t, 4> {
+  uint32x4x2_t raw;
+};
+template <size_t N>
+struct Tuple2<uint32_t, N> {
+  uint32x2x2_t raw;
+};
+template <>
+struct Tuple2<int32_t, 4> {
+  int32x4x2_t raw;
+};
+template <size_t N>
+struct Tuple2<int32_t, N> {
+  int32x2x2_t raw;
+};
+template <>
+struct Tuple2<uint64_t, 2> {
+  uint64x2x2_t raw;
+};
+template <size_t N>
+struct Tuple2<uint64_t, N> {
+  uint64x1x2_t raw;
+};
+template <>
+struct Tuple2<int64_t, 2> {
+  int64x2x2_t raw;
+};
+template <size_t N>
+struct Tuple2<int64_t, N> {
+  int64x1x2_t raw;
+};
+
+template <>
+struct Tuple2<float32_t, 4> {
+  float32x4x2_t raw;
+};
+template <size_t N>
+struct Tuple2<float32_t, N> {
+  float32x2x2_t raw;
+};
+#if HWY_HAVE_FLOAT64
+template <>
+struct Tuple2<float64_t, 2> {
+  float64x2x2_t raw;
+};
+template <size_t N>
+struct Tuple2<float64_t, N> {
+  float64x1x2_t raw;
+};
+#endif  // HWY_HAVE_FLOAT64
+
+template <>
+struct Tuple3<uint8_t, 16> {
+  uint8x16x3_t raw;
+};
+template <size_t N>
+struct Tuple3<uint8_t, N> {
+  uint8x8x3_t raw;
+};
+template <>
+struct Tuple3<int8_t, 16> {
+  int8x16x3_t raw;
+};
+template <size_t N>
+struct Tuple3<int8_t, N> {
+  int8x8x3_t raw;
+};
+template <>
+struct Tuple3<uint16_t, 8> {
+  uint16x8x3_t raw;
+};
+template <size_t N>
+struct Tuple3<uint16_t, N> {
+  uint16x4x3_t raw;
+};
+template <>
+struct Tuple3<int16_t, 8> {
+  int16x8x3_t raw;
+};
+template <size_t N>
+struct Tuple3<int16_t, N> {
+  int16x4x3_t raw;
+};
+template <>
+struct Tuple3<uint32_t, 4> {
+  uint32x4x3_t raw;
+};
+template <size_t N>
+struct Tuple3<uint32_t, N> {
+  uint32x2x3_t raw;
+};
+template <>
+struct Tuple3<int32_t, 4> {
+  int32x4x3_t raw;
+};
+template <size_t N>
+struct Tuple3<int32_t, N> {
+  int32x2x3_t raw;
+};
+template <>
+struct Tuple3<uint64_t, 2> {
+  uint64x2x3_t raw;
+};
+template <size_t N>
+struct Tuple3<uint64_t, N> {
+  uint64x1x3_t raw;
+};
+template <>
+struct Tuple3<int64_t, 2> {
+  int64x2x3_t raw;
+};
+template <size_t N>
+struct Tuple3<int64_t, N> {
+  int64x1x3_t raw;
+};
+
+template <>
+struct Tuple3<float32_t, 4> {
+  float32x4x3_t raw;
+};
+template <size_t N>
+struct Tuple3<float32_t, N> {
+  float32x2x3_t raw;
+};
+#if HWY_HAVE_FLOAT64
+template <>
+struct Tuple3<float64_t, 2> {
+  float64x2x3_t raw;
+};
+template <size_t N>
+struct Tuple3<float64_t, N> {
+  float64x1x3_t raw;
+};
+#endif  // HWY_HAVE_FLOAT64
+
+template <>
+struct Tuple4<uint8_t, 16> {
+  uint8x16x4_t raw;
+};
+template <size_t N>
+struct Tuple4<uint8_t, N> {
+  uint8x8x4_t raw;
+};
+template <>
+struct Tuple4<int8_t, 16> {
+  int8x16x4_t raw;
+};
+template <size_t N>
+struct Tuple4<int8_t, N> {
+  int8x8x4_t raw;
+};
+template <>
+struct Tuple4<uint16_t, 8> {
+  uint16x8x4_t raw;
+};
+template <size_t N>
+struct Tuple4<uint16_t, N> {
+  uint16x4x4_t raw;
+};
+template <>
+struct Tuple4<int16_t, 8> {
+  int16x8x4_t raw;
+};
+template <size_t N>
+struct Tuple4<int16_t, N> {
+  int16x4x4_t raw;
+};
+template <>
+struct Tuple4<uint32_t, 4> {
+  uint32x4x4_t raw;
+};
+template <size_t N>
+struct Tuple4<uint32_t, N> {
+  uint32x2x4_t raw;
+};
+template <>
+struct Tuple4<int32_t, 4> {
+  int32x4x4_t raw;
+};
+template <size_t N>
+struct Tuple4<int32_t, N> {
+  int32x2x4_t raw;
+};
+template <>
+struct Tuple4<uint64_t, 2> {
+  uint64x2x4_t raw;
+};
+template <size_t N>
+struct Tuple4<uint64_t, N> {
+  uint64x1x4_t raw;
+};
+template <>
+struct Tuple4<int64_t, 2> {
+  int64x2x4_t raw;
+};
+template <size_t N>
+struct Tuple4<int64_t, N> {
+  int64x1x4_t raw;
+};
+
+template <>
+struct Tuple4<float32_t, 4> {
+  float32x4x4_t raw;
+};
+template <size_t N>
+struct Tuple4<float32_t, N> {
+  float32x2x4_t raw;
+};
+#if HWY_HAVE_FLOAT64
+template <>
+struct Tuple4<float64_t, 2> {
+  float64x2x4_t raw;
+};
+template <size_t N>
+struct Tuple4<float64_t, N> {
+  float64x1x4_t raw;
+};
+#endif  // HWY_HAVE_FLOAT64
+
+// Maps (lane type, lane count) to the raw NEON vector type: the explicit
+// specialization for the full 128-bit lane count yields the 128-bit type,
+// any other lane count yields the 64-bit type. The 64-bit lane types only
+// have the two explicit specializations N = 2 and N = 1.
+template <typename T, size_t N>
+struct Raw128;
+
+template <>
+struct Raw128<uint8_t, 16> {
+  using type = uint8x16_t;
+};
+template <size_t N>
+struct Raw128<uint8_t, N> {
+  using type = uint8x8_t;
+};
+
+template <>
+struct Raw128<uint16_t, 8> {
+  using type = uint16x8_t;
+};
+template <size_t N>
+struct Raw128<uint16_t, N> {
+  using type = uint16x4_t;
+};
+
+template <>
+struct Raw128<uint32_t, 4> {
+  using type = uint32x4_t;
+};
+template <size_t N>
+struct Raw128<uint32_t, N> {
+  using type = uint32x2_t;
+};
+
+template <>
+struct Raw128<uint64_t, 2> {
+  using type = uint64x2_t;
+};
+template <>
+struct Raw128<uint64_t, 1> {
+  using type = uint64x1_t;
+};
+
+template <>
+struct Raw128<int8_t, 16> {
+  using type = int8x16_t;
+};
+template <size_t N>
+struct Raw128<int8_t, N> {
+  using type = int8x8_t;
+};
+
+template <>
+struct Raw128<int16_t, 8> {
+  using type = int16x8_t;
+};
+template <size_t N>
+struct Raw128<int16_t, N> {
+  using type = int16x4_t;
+};
+
+template <>
+struct Raw128<int32_t, 4> {
+  using type = int32x4_t;
+};
+template <size_t N>
+struct Raw128<int32_t, N> {
+  using type = int32x2_t;
+};
+
+template <>
+struct Raw128<int64_t, 2> {
+  using type = int64x2_t;
+};
+template <>
+struct Raw128<int64_t, 1> {
+  using type = int64x1_t;
+};
+
+template <>
+struct Raw128<float32_t, 4> {
+  using type = float32x4_t;
+};
+template <size_t N>
+struct Raw128<float32_t, N> {
+  using type = float32x2_t;
+};
+
+#if HWY_HAVE_FLOAT64
+template <>
+struct Raw128<float64_t, 2> {
+  using type = float64x2_t;
+};
+template <>
+struct Raw128<float64_t, 1> {
+  using type = float64x1_t;
+};
+#endif  // HWY_HAVE_FLOAT64
+
+#if HWY_NEON_HAVE_F16C
+
+template <>
+struct Tuple2<float16_t, 8> {
+  float16x8x2_t raw;
+};
+template <size_t N>
+struct Tuple2<float16_t, N> {
+  float16x4x2_t raw;
+};
+
+template <>
+struct Tuple3<float16_t, 8> {
+  float16x8x3_t raw;
+};
+template <size_t N>
+struct Tuple3<float16_t, N> {
+  float16x4x3_t raw;
+};
+
+template <>
+struct Tuple4<float16_t, 8> {
+  float16x8x4_t raw;
+};
+template <size_t N>
+struct Tuple4<float16_t, N> {
+  float16x4x4_t raw;
+};
+
+template <>
+struct Raw128<float16_t, 8> {
+  using type = float16x8_t;
+};
+template <size_t N>
+struct Raw128<float16_t, N> {
+  using type = float16x4_t;
+};
+
+#else  // !HWY_NEON_HAVE_F16C
+
+// No float16 vector types available: reuse the uint16_t wrappers.
+template <size_t N>
+struct Tuple2<float16_t, N> : public Tuple2<uint16_t, N> {};
+template <size_t N>
+struct Tuple3<float16_t, N> : public Tuple3<uint16_t, N> {};
+template <size_t N>
+struct Tuple4<float16_t, N> : public Tuple4<uint16_t, N> {};
+template <size_t N>
+struct Raw128<float16_t, N> : public Raw128<uint16_t, N> {};
+
+#endif  // HWY_NEON_HAVE_F16C
+
+#if HWY_NEON_HAVE_BFLOAT16
+
+template <>
+struct Tuple2<bfloat16_t, 8> {
+  bfloat16x8x2_t raw;
+};
+template <size_t N>
+struct Tuple2<bfloat16_t, N> {
+  bfloat16x4x2_t raw;
+};
+
+template <>
+struct Tuple3<bfloat16_t, 8> {
+  bfloat16x8x3_t raw;
+};
+template <size_t N>
+struct Tuple3<bfloat16_t, N> {
+  bfloat16x4x3_t raw;
+};
+
+template <>
+struct Tuple4<bfloat16_t, 8> {
+  bfloat16x8x4_t raw;
+};
+template <size_t N>
+struct Tuple4<bfloat16_t, N> {
+  bfloat16x4x4_t raw;
+};
+
+template <>
+struct Raw128<bfloat16_t, 8> {
+  using type = bfloat16x8_t;
+};
+template <size_t N>
+struct Raw128<bfloat16_t, N> {
+  using type = bfloat16x4_t;
+};
+
+#else  // !HWY_NEON_HAVE_BFLOAT16
+
+// No bfloat16 vector types available: reuse the uint16_t wrappers.
+template <size_t N>
+struct Tuple2<bfloat16_t, N> : public Tuple2<uint16_t, N> {};
+template <size_t N>
+struct Tuple3<bfloat16_t, N> : public Tuple3<uint16_t, N> {};
+template <size_t N>
+struct Tuple4<bfloat16_t, N> : public Tuple4<uint16_t, N> {};
+template <size_t N>
+struct Raw128<bfloat16_t, N> : public Raw128<uint16_t, N> {};
+
+#endif  // HWY_NEON_HAVE_BFLOAT16
+
+}  // namespace detail
+
+// Thin wrapper holding one raw NEON vector of up to N lanes of type T. The
+// raw type is looked up via detail::Raw128<T, N>.
+template <typename T, size_t N = 16 / sizeof(T)>
+class Vec128 {
+ public:
+  using Raw = typename detail::Raw128<T, N>::type;
+  using PrivateT = T;                     // only for DFromV
+  static constexpr size_t kPrivateN = N;  // only for DFromV
+
+  // The default constructor intentionally leaves `raw` uninitialized.
+  HWY_INLINE Vec128() {}
+  Vec128(const Vec128&) = default;
+  Vec128& operator=(const Vec128&) = default;
+  HWY_INLINE explicit Vec128(const Raw raw) : raw(raw) {}
+
+  // Compound assignment. Only usable if there is a corresponding non-member
+  // binary operator overload. For example, only f32 and f64 support division.
+  HWY_INLINE Vec128& operator*=(const Vec128 other) {
+    return *this = (*this * other);
+  }
+  HWY_INLINE Vec128& operator/=(const Vec128 other) {
+    return *this = (*this / other);
+  }
+  HWY_INLINE Vec128& operator+=(const Vec128 other) {
+    return *this = (*this + other);
+  }
+  HWY_INLINE Vec128& operator-=(const Vec128 other) {
+    return *this = (*this - other);
+  }
+  HWY_INLINE Vec128& operator%=(const Vec128 other) {
+    return *this = (*this % other);
+  }
+  HWY_INLINE Vec128& operator&=(const Vec128 other) {
+    return *this = (*this & other);
+  }
+  HWY_INLINE Vec128& operator|=(const Vec128 other) {
+    return *this = (*this | other);
+  }
+  HWY_INLINE Vec128& operator^=(const Vec128 other) {
+    return *this = (*this ^ other);
+  }
+
+  Raw raw;
+};
+
+// Vectors holding 64, 32 and 16 bits worth of T lanes.
+template <typename T>
+using Vec64 = Vec128<T, 8 / sizeof(T)>;
+
+template <typename T>
+using Vec32 = Vec128<T, 4 / sizeof(T)>;
+
+template <typename T>
+using Vec16 = Vec128<T, 2 / sizeof(T)>;
+
+// FF..FF or 0.
+template <typename T, size_t N = 16 / sizeof(T)>
+class Mask128 {
+ public:
+  // Arm C Language Extensions return and expect unsigned type.
+  using Raw = typename detail::Raw128<MakeUnsigned<T>, N>::type;
+
+  using PrivateT = T;                     // only for DFromM
+  static constexpr size_t kPrivateN = N;  // only for DFromM
+
+  HWY_INLINE Mask128() {}
+  Mask128(const Mask128&) = default;
+  Mask128& operator=(const Mask128&) = default;
+  HWY_INLINE explicit Mask128(const Raw raw) : raw(raw) {}
+
+  Raw raw;
+};
+
+template <typename T>
+using Mask64 = Mask128<T, 8 / sizeof(T)>;
+
+// Tag type recovered from a vector or mask via its PrivateT / kPrivateN.
+template <class V>
+using DFromV = Simd<typename V::PrivateT, V::kPrivateN, 0>;
+
+template <class M>
+using DFromM = Simd<typename M::PrivateT, M::kPrivateN, 0>;
+
+// Lane type of a vector.
+template <class V>
+using TFromV = typename V::PrivateT;
+
+// TODO(janwas): ForDemoteVectors, in convert_test and demote_test, appear to
+// instantiate this with D = double x 4. The cause is unknown. Previously,
+// defining this in terms of Set rejected that via SFINAE because only
+// V_SIZE = 16 and V_SIZE <= 8 overloads were defined. As a workaround,
+// truncate the lane count to 128 bits.
+template +using VFromD = + Vec128, HWY_MIN(16 / sizeof(TFromD), MaxLanes(D()))>; + +// ------------------------------ BitCast + +namespace detail { + +// Converts from Vec128 to Vec128 using the +// vreinterpret*_u8_*() set of functions. +#define HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 +#define HWY_NEON_BUILD_RET_HWY_CAST_TO_U8(type, size) \ + Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8(type, size) Vec128 v +#define HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 v.raw + +// Special case of u8 to u8 since vreinterpret*_u8_u8 is obviously not defined. +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return v; +} + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(BitCastToByte, vreinterpret, _u8_, + HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_BFLOAT_16(BitCastToByte, vreinterpret, _u8_, + HWY_CAST_TO_U8) + +HWY_NEON_DEF_FUNCTION_INTS(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_16(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_32(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) +HWY_NEON_DEF_FUNCTION_UINT_64(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8) + +#if !HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_F16C +HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(BitCastToByte, vreinterpret, _u8_, + HWY_CAST_TO_U8) +#else +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return BitCastToByte(Vec128(v.raw)); +} +#endif // HWY_NEON_HAVE_F16C +#endif // !HWY_HAVE_FLOAT16 + +#if !HWY_NEON_HAVE_BFLOAT16 +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return BitCastToByte(Vec128(v.raw)); +} +#endif // !HWY_NEON_HAVE_BFLOAT16 + +#undef HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_RET_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8 +#undef HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, VFromD v) { + return v; +} + +// 64-bit or less: + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + VFromD> v) { + return VFromD(vreinterpret_s8_u8(v.raw)); +} 
+template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + VFromD> v) { + return VFromD(vreinterpret_u16_u8(v.raw)); +} +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + VFromD> v) { + return VFromD(vreinterpret_s16_u8(v.raw)); +} +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + VFromD> v) { + return VFromD(vreinterpret_u32_u8(v.raw)); +} +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + VFromD> v) { + return VFromD(vreinterpret_s32_u8(v.raw)); +} + +template +HWY_INLINE Vec64 BitCastFromByte(D /* tag */, Vec64 v) { + return Vec64(vreinterpret_u64_u8(v.raw)); +} +template +HWY_INLINE Vec64 BitCastFromByte(D /* tag */, Vec64 v) { + return Vec64(vreinterpret_s64_u8(v.raw)); +} + +// Cannot use HWY_NEON_IF_EMULATED_D due to the extra HWY_NEON_HAVE_F16C. +template +HWY_INLINE VFromD BitCastFromByte(D, VFromD> v) { +#if HWY_HAVE_FLOAT16 || HWY_NEON_HAVE_F16C + return VFromD(vreinterpret_f16_u8(v.raw)); +#else + const RebindToUnsigned du; + return VFromD(BitCastFromByte(du, v).raw); +#endif +} + +template +HWY_INLINE VFromD BitCastFromByte(D, VFromD> v) { +#if HWY_NEON_HAVE_BFLOAT16 + return VFromD(vreinterpret_bf16_u8(v.raw)); +#else + const RebindToUnsigned du; + return VFromD(BitCastFromByte(du, v).raw); +#endif +} + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + VFromD> v) { + return VFromD(vreinterpret_f32_u8(v.raw)); +} + +#if HWY_HAVE_FLOAT64 +template +HWY_INLINE Vec64 BitCastFromByte(D /* tag */, Vec64 v) { + return Vec64(vreinterpret_f64_u8(v.raw)); +} +#endif // HWY_HAVE_FLOAT64 + +// 128-bit full: + +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_s8_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_u16_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_s16_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) 
{ + return Vec128(vreinterpretq_u32_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_s32_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_u64_u8(v.raw)); +} +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_s64_u8(v.raw)); +} + +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_f32_u8(v.raw)); +} + +#if HWY_HAVE_FLOAT64 +template +HWY_INLINE Vec128 BitCastFromByte(D /* tag */, Vec128 v) { + return Vec128(vreinterpretq_f64_u8(v.raw)); +} +#endif // HWY_HAVE_FLOAT64 + +// Cannot use HWY_NEON_IF_EMULATED_D due to the extra HWY_NEON_HAVE_F16C. +template +HWY_INLINE VFromD BitCastFromByte(D, Vec128 v) { +#if HWY_HAVE_FLOAT16 || HWY_NEON_HAVE_F16C + return VFromD(vreinterpretq_f16_u8(v.raw)); +#else + return VFromD(BitCastFromByte(RebindToUnsigned(), v).raw); +#endif +} + +template +HWY_INLINE VFromD BitCastFromByte(D, Vec128 v) { +#if HWY_NEON_HAVE_BFLOAT16 + return VFromD(vreinterpretq_bf16_u8(v.raw)); +#else + return VFromD(BitCastFromByte(RebindToUnsigned(), v).raw); +#endif +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, + Vec128().MaxLanes()> v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ ResizeBitCast + +// <= 8 byte vector to <= 8 byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const Repartition du8; + return BitCast(d, VFromD{detail::BitCastToByte(v).raw}); +} + +// 16-byte vector to 16-byte vector: same as BitCast +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, v); +} + +// 16-byte vector to <= 8-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const DFromV d_from; + const Half dh_from; + return ResizeBitCast(d, LowerHalf(dh_from, v)); +} + +// <= 8-byte vector to 16-byte vector
+template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const Full64> d_full64_from; + const Full128> d_full128_from; + return BitCast(d, Combine(d_full128_from, Zero(d_full64_from), + ResizeBitCast(d_full64_from, v))); +} + +// ------------------------------ Set + +namespace detail { +// We want to route any combination of N/kPow2 to the intrinsics depending on +// whether the requested size is <= 64 bits or 128. HWY_NEON_BUILD_TPL is +// unconditional and currently does not accept inputs (such as whether the +// vector is 64 or 128-bit). Thus we are not able to use HWY_IF_V_SIZE_D for +// SFINAE. We instead define a private NativeSet which receives a Simd<> whose +// kPow2 has already been folded into its N. +#define HWY_NEON_BUILD_TPL_HWY_SET +#define HWY_NEON_BUILD_RET_HWY_SET(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_SET(type, size) \ + Simd /* tag */, type##_t t +#define HWY_NEON_BUILD_ARG_HWY_SET t + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(NativeSet, vdup, _n_, HWY_SET) +#if !HWY_HAVE_FLOAT16 && HWY_NEON_HAVE_F16C && HWY_HAVE_SCALAR_F16_TYPE +HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(NativeSet, vdup, _n_, HWY_SET) +#endif +HWY_NEON_DEF_FUNCTION_BFLOAT_16(NativeSet, vdup, _n_, HWY_SET) + +#if !HWY_NEON_HAVE_F16C || !HWY_HAVE_SCALAR_F16_TYPE +template +HWY_API VFromD NativeSet(D d, TFromD t) { + const uint16_t tu = BitCastScalar(t); + return BitCast(d, Set(RebindToUnsigned(), tu)); +} +#endif + +#if !HWY_NEON_HAVE_BFLOAT16 +template +HWY_API VFromD NativeSet(D d, TFromD t) { + const uint16_t tu = BitCastScalar(t); + return BitCast(d, Set(RebindToUnsigned(), tu)); +} +#endif + +#undef HWY_NEON_BUILD_TPL_HWY_SET +#undef HWY_NEON_BUILD_RET_HWY_SET +#undef HWY_NEON_BUILD_PARAM_HWY_SET +#undef HWY_NEON_BUILD_ARG_HWY_SET + +} // namespace detail + +// Full vector. +// Do not use a typename T = TFromD argument because T will be deduced from +// the actual argument type, which can differ from TFromD. 
+template +HWY_INLINE VFromD Set(D /* tag */, T t) { + return detail::NativeSet(Full128>(), static_cast>(t)); +} + +// Partial vector: create 64-bit and return wrapper. +template +HWY_API VFromD Set(D /* tag */, T t) { + const Full64> dfull; + return VFromD(detail::NativeSet(dfull, static_cast>(t)).raw); +} + +template +HWY_API VFromD Zero(D d) { + // Default ctor also works for bfloat16_t and float16_t. + return Set(d, TFromD{}); +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wmaybe-uninitialized") +#endif + +template +HWY_API VFromD Undefined(D /*tag*/) { +#if HWY_HAS_BUILTIN(__builtin_nondeterministic_value) + return VFromD{__builtin_nondeterministic_value(Zero(D()).raw)}; +#else + VFromD v; + return v; +#endif +} + +HWY_DIAGNOSTICS(pop) + +#if !HWY_COMPILER_GCC && !HWY_COMPILER_CLANGCL +namespace detail { + +#pragma pack(push, 1) + +template +struct alignas(8) Vec64ValsWrapper { + static_assert(sizeof(T) >= 1, "sizeof(T) >= 1 must be true"); + static_assert(sizeof(T) <= 8, "sizeof(T) <= 8 must be true"); + T vals[8 / sizeof(T)]; +}; + +#pragma pack(pop) + +} // namespace detail +#endif // !HWY_COMPILER_GCC && !HWY_COMPILER_CLANGCL + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD /*t8*/, TFromD /*t9*/, + TFromD /*t10*/, TFromD /*t11*/, + TFromD /*t12*/, TFromD /*t13*/, + TFromD /*t14*/, TFromD /*t15*/) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int8_t GccI8RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccI8RawVectType raw = { + static_cast(t0), static_cast(t1), static_cast(t2), + static_cast(t3), static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + return ResizeBitCast( + d, Set(Full64(), + BitCastScalar(detail::Vec64ValsWrapper>{ + 
{t0, t1, t2, t3, t4, t5, t6, t7}}))); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, + TFromD /*t4*/, TFromD /*t5*/, + TFromD /*t6*/, TFromD /*t7*/) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int16_t GccI16RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccI16RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + return ResizeBitCast( + d, Set(Full64(), + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1, t2, t3}}))); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD /*t2*/, TFromD /*t3*/) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int32_t GccI32RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccI32RawVectType raw = {static_cast(t0), + static_cast(t1)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + return ResizeBitCast(d, + Set(Full64(), + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1}}))); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD /*t2*/, TFromD /*t3*/) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef float GccF32RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccF32RawVectType raw = {t0, t1}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + return ResizeBitCast(d, + Set(Full64(), + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1}}))); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD /*t1*/) { + return Set(d, t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int8_t GccI8RawVectType 
__attribute__((__vector_size__(16))); + (void)d; + const GccI8RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), + static_cast(t8), static_cast(t9), + static_cast(t10), static_cast(t11), + static_cast(t12), static_cast(t13), + static_cast(t14), static_cast(t15)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + const Half dh; + return Combine(d, + Dup128VecFromValues(dh, t8, t9, t10, t11, t12, t13, t14, t15, + t8, t9, t10, t11, t12, t13, t14, t15), + Dup128VecFromValues(dh, t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, + t2, t3, t4, t5, t6, t7)); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int16_t GccI16RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccI16RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + const Half dh; + return Combine(d, Dup128VecFromValues(dh, t4, t5, t6, t7, t4, t5, t6, t7), + Dup128VecFromValues(dh, t0, t1, t2, t3, t0, t1, t2, t3)); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int32_t GccI32RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccI32RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + const Half dh; + return Combine(d, Dup128VecFromValues(dh, t2, t3, t2, t3), + Dup128VecFromValues(dh, t0, t1, t0, t1)); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { +#if 
HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccF32RawVectType raw = {t0, t1, t2, t3}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + const Half dh; + return Combine(d, Dup128VecFromValues(dh, t2, t3, t2, t3), + Dup128VecFromValues(dh, t0, t1, t0, t1)); +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef int64_t GccI64RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccI64RawVectType raw = {static_cast(t0), + static_cast(t1)}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + const Half dh; + return Combine(d, Set(dh, t1), Set(dh, t0)); +#endif +} + +#if HWY_HAVE_FLOAT64 +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1) { +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccF64RawVectType raw = {t0, t1}; + return VFromD(reinterpret_cast::Raw>(raw)); +#else + const Half dh; + return Combine(d, Set(dh, t1), Set(dh, t0)); +#endif +} +#endif + +// Generic for all vector lengths +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} + +#if (HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL) && HWY_NEON_HAVE_F16C && \ + HWY_HAVE_SCALAR_F16_TYPE +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, + TFromD /*t4*/, TFromD /*t5*/, + TFromD /*t6*/, TFromD /*t7*/) { + typedef __fp16 GccF16RawVectType __attribute__((__vector_size__(8))); + (void)d; + const GccF16RawVectType raw = { + 
static_cast<__fp16>(t0), static_cast<__fp16>(t1), static_cast<__fp16>(t2), + static_cast<__fp16>(t3)}; + return VFromD(reinterpret_cast::Raw>(raw)); +} +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + typedef __fp16 GccF16RawVectType __attribute__((__vector_size__(16))); + (void)d; + const GccF16RawVectType raw = { + static_cast<__fp16>(t0), static_cast<__fp16>(t1), static_cast<__fp16>(t2), + static_cast<__fp16>(t3), static_cast<__fp16>(t4), static_cast<__fp16>(t5), + static_cast<__fp16>(t6), static_cast<__fp16>(t7)}; + return VFromD(reinterpret_cast::Raw>(raw)); +} +#else +// Generic for all vector lengths if MSVC or !HWY_NEON_HAVE_F16C +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} +#endif // (HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL) && HWY_NEON_HAVE_F16C + +namespace detail { + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues( + d, TFromD{0}, TFromD{1}, TFromD{2}, TFromD{3}, TFromD{4}, + TFromD{5}, TFromD{6}, TFromD{7}, TFromD{8}, TFromD{9}, + TFromD{10}, TFromD{11}, TFromD{12}, TFromD{13}, TFromD{14}, + TFromD{15}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues(d, TFromD{0}, TFromD{1}, TFromD{2}, + TFromD{3}, TFromD{4}, TFromD{5}, + TFromD{6}, TFromD{7}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + const RebindToUnsigned du; + return BitCast(d, Dup128VecFromValues(du, uint16_t{0}, uint16_t{0x3C00}, + uint16_t{0x4000}, uint16_t{0x4200}, + uint16_t{0x4400}, uint16_t{0x4500}, + uint16_t{0x4600}, uint16_t{0x4700})); +} + +template +HWY_INLINE VFromD Iota0(D d) { + return 
Dup128VecFromValues(d, TFromD{0}, TFromD{1}, TFromD{2}, + TFromD{3}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues(d, TFromD{0}, TFromD{1}); +} + +#if HWY_COMPILER_MSVC +template +static HWY_INLINE V MaskOutIota(V v) { + constexpr size_t kVecSizeInBytes = HWY_MAX_LANES_V(V) * sizeof(TFromV); + constexpr uint64_t kU64MaskOutMask = + hwy::LimitsMax>(); + + const DFromV d; + const Repartition du8; + using VU8 = VFromD; + const auto mask_out_mask = + BitCast(d, VU8(vreinterpret_u8_u64(vdup_n_u64(kU64MaskOutMask)))); + return v & mask_out_mask; +} +template +static HWY_INLINE V MaskOutIota(V v) { + return v; +} +#endif + +} // namespace detail + +template +HWY_API VFromD Iota(D d, const T2 first) { + const auto result_iota = + detail::Iota0(d) + Set(d, static_cast>(first)); +#if HWY_COMPILER_MSVC + return detail::MaskOutIota(result_iota); +#else + return result_iota; +#endif +} + +// ------------------------------ Combine + +// Full result +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_u8(lo.raw, hi.raw)); +} +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_u16(lo.raw, hi.raw)); +} +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_u32(lo.raw, hi.raw)); +} +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_u64(lo.raw, hi.raw)); +} + +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s8(lo.raw, hi.raw)); +} +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s16(lo.raw, hi.raw)); +} +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s32(lo.raw, hi.raw)); +} +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_s64(lo.raw, hi.raw)); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 
Combine(D, Vec64 hi, Vec64 lo) { + return Vec128(vcombine_f16(lo.raw, hi.raw)); +} +#endif // HWY_HAVE_FLOAT16 + +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API VFromD Combine(D, Vec64 hi, Vec64 lo) { + return VFromD(vcombine_bf16(lo.raw, hi.raw)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + +template , HWY_NEON_IF_EMULATED_D(D)> +HWY_API VFromD Combine(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + const Half duh; + return BitCast(d, Combine(du, BitCast(duh, hi), BitCast(duh, lo))); +} + +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, Vec64 lo) { + return Vec128(vcombine_f32(lo.raw, hi.raw)); +} +#if HWY_HAVE_FLOAT64 +template +HWY_API Vec128 Combine(D /* tag */, Vec64 hi, + Vec64 lo) { + return Vec128(vcombine_f64(lo.raw, hi.raw)); +} +#endif // HWY_HAVE_FLOAT64 + +// ------------------------------ GetLane + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_GET template +#define HWY_NEON_BUILD_RET_HWY_GET(type, size) type##_t +#define HWY_NEON_BUILD_PARAM_HWY_GET(type, size) Vec128 v +#define HWY_NEON_BUILD_ARG_HWY_GET v.raw, kLane + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(GetLane, vget, _lane_, HWY_GET) +HWY_NEON_DEF_FUNCTION_BFLOAT_16(GetLane, vget, _lane_, HWY_GET) + +template )> +static HWY_INLINE HWY_MAYBE_UNUSED TFromV GetLane(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCastScalar>(GetLane(BitCast(du, v))); +} + +#undef HWY_NEON_BUILD_TPL_HWY_GET +#undef HWY_NEON_BUILD_RET_HWY_GET +#undef HWY_NEON_BUILD_PARAM_HWY_GET +#undef HWY_NEON_BUILD_ARG_HWY_GET + +} // namespace detail + +template +HWY_API TFromV GetLane(const V v) { + return detail::GetLane<0>(v); +} + +// ------------------------------ ExtractLane + +// Requires one overload per vector length because GetLane<3> is a compile error +// if v is a uint32x2_t. 
+template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return detail::GetLane<0>(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + case 4: + return detail::GetLane<4>(v); + case 5: + return detail::GetLane<5>(v); + case 6: + return detail::GetLane<6>(v); + case 7: + return detail::GetLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::GetLane<0>(v); + case 1: + return detail::GetLane<1>(v); + case 2: + return detail::GetLane<2>(v); + case 3: + return detail::GetLane<3>(v); + case 4: + return detail::GetLane<4>(v); + case 5: + return detail::GetLane<5>(v); + case 6: + return detail::GetLane<6>(v); + case 7: + 
return detail::GetLane<7>(v); + case 8: + return detail::GetLane<8>(v); + case 9: + return detail::GetLane<9>(v); + case 10: + return detail::GetLane<10>(v); + case 11: + return detail::GetLane<11>(v); + case 12: + return detail::GetLane<12>(v); + case 13: + return detail::GetLane<13>(v); + case 14: + return detail::GetLane<14>(v); + case 15: + return detail::GetLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_INSERT template +#define HWY_NEON_BUILD_RET_HWY_INSERT(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_INSERT(type, size) \ + Vec128 v, type##_t t +#define HWY_NEON_BUILD_ARG_HWY_INSERT t, v.raw, kLane + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(InsertLane, vset, _lane_, HWY_INSERT) +HWY_NEON_DEF_FUNCTION_BFLOAT_16(InsertLane, vset, _lane_, HWY_INSERT) + +#undef HWY_NEON_BUILD_TPL_HWY_INSERT +#undef HWY_NEON_BUILD_RET_HWY_INSERT +#undef HWY_NEON_BUILD_PARAM_HWY_INSERT +#undef HWY_NEON_BUILD_ARG_HWY_INSERT + +template , HWY_NEON_IF_EMULATED_D(D)> +HWY_API V InsertLane(const V v, TFromD t) { + const D d; + const RebindToUnsigned du; + const uint16_t tu = BitCastScalar(t); + return BitCast(d, InsertLane(BitCast(du, v), tu)); +} + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error. 
+ +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return 
detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition +HWY_NEON_DEF_FUNCTION_UINTS(operator+, vadd, _, 2) +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator+, vadd, _, 2) + +template +HWY_API Vec128 operator+(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) + BitCast(du, b)); +} + +template +HWY_API Vec128 operator+(Vec128 a, + Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) + BitCast(du, b)); +} + +template +HWY_API Vec128 operator+(Vec128 a, + Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) + BitCast(du, b)); +} + +template +HWY_API Vec128 operator+(Vec128 a, + Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) + BitCast(du, b)); +} + +// ------------------------------ Subtraction +HWY_NEON_DEF_FUNCTION_UINTS(operator-, vsub, _, 2) +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator-, vsub, _, 2) + +template +HWY_API Vec128 operator-(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return 
BitCast(d, BitCast(du, a) - BitCast(du, b)); +} + +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) - BitCast(du, b)); +} + +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) - BitCast(du, b)); +} + +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) - BitCast(du, b)); +} + +// ------------------------------ SumsOf8 + +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128(vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(v.raw)))); +} +HWY_API Vec64 SumsOf8(const Vec64 v) { + return Vec64(vpaddl_u32(vpaddl_u16(vpaddl_u8(v.raw)))); +} +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128(vpaddlq_s32(vpaddlq_s16(vpaddlq_s8(v.raw)))); +} +HWY_API Vec64 SumsOf8(const Vec64 v) { + return Vec64(vpaddl_s32(vpaddl_s16(vpaddl_s8(v.raw)))); +} + +// ------------------------------ SumsOf2 +namespace detail { + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_s8(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_s8(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_u8(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_u8(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_s16(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_s16(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, 
hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_u16(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_u16(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_s32(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_s32(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddl_u32(v.raw)); +} + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return VFromD>>(vpaddlq_u32(v.raw)); +} + +} // namespace detail + +// ------------------------------ SaturatedAdd + +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB +#undef HWY_NATIVE_U32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB +#undef HWY_NATIVE_U64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U64_SATURATED_ADDSUB +#endif + +// Returns a + b clamped to the destination range. +HWY_NEON_DEF_FUNCTION_INTS_UINTS(SaturatedAdd, vqadd, _, 2) + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. 
+HWY_NEON_DEF_FUNCTION_INTS_UINTS(SaturatedSub, vqsub, _, 2) + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +HWY_NEON_DEF_FUNCTION_UI_8_16_32(AverageRound, vrhadd, _, 2) + +// ------------------------------ Neg + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Neg, vneg, _, 1) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Neg, vneg, _, 1) // i64 implemented below + +#if !HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Neg(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + return BitCast(d, Xor(BitCast(du, v), Set(du, SignMask()))); +} +#endif // !HWY_HAVE_FLOAT16 + +// There is no vneg for bf16, but we can cast to f16 (emulated or native). +template +HWY_API Vec128 Neg(const Vec128 v) { + const DFromV d; + const Rebind df16; + return BitCast(d, Neg(BitCast(df16, v))); +} + +HWY_API Vec64 Neg(const Vec64 v) { +#if HWY_ARCH_ARM_A64 + return Vec64(vneg_s64(v.raw)); +#else + return Zero(DFromV()) - v; +#endif +} + +HWY_API Vec128 Neg(const Vec128 v) { +#if HWY_ARCH_ARM_A64 + return Vec128(vnegq_s64(v.raw)); +#else + return Zero(DFromV()) - v; +#endif +} + +// ------------------------------ SaturatedNeg +#ifdef HWY_NATIVE_SATURATED_NEG_8_16_32 +#undef HWY_NATIVE_SATURATED_NEG_8_16_32 +#else +#define HWY_NATIVE_SATURATED_NEG_8_16_32 +#endif + +HWY_NEON_DEF_FUNCTION_INT_8_16_32(SaturatedNeg, vqneg, _, 1) + +#if HWY_ARCH_ARM_A64 +#ifdef HWY_NATIVE_SATURATED_NEG_64 +#undef HWY_NATIVE_SATURATED_NEG_64 +#else +#define HWY_NATIVE_SATURATED_NEG_64 +#endif + +HWY_API Vec64 SaturatedNeg(const Vec64 v) { + return Vec64(vqneg_s64(v.raw)); +} + +HWY_API Vec128 SaturatedNeg(const Vec128 v) { + return Vec128(vqnegq_s64(v.raw)); +} +#endif + +// ------------------------------ ShiftLeft + +#ifdef HWY_NATIVE_ROUNDING_SHR +#undef HWY_NATIVE_ROUNDING_SHR +#else +#define HWY_NATIVE_ROUNDING_SHR +#endif + +// Customize 
HWY_NEON_DEF_FUNCTION to special-case count=0 (not supported). +#pragma push_macro("HWY_NEON_DEF_FUNCTION") +#undef HWY_NEON_DEF_FUNCTION +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + template \ + HWY_API Vec128 name(const Vec128 v) { \ + return kBits == 0 ? v \ + : Vec128(HWY_NEON_EVAL( \ + prefix##infix##suffix, v.raw, HWY_MAX(1, kBits))); \ + } + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(ShiftLeft, vshl, _n_, ignored) + +HWY_NEON_DEF_FUNCTION_UINTS(ShiftRight, vshr, _n_, ignored) +HWY_NEON_DEF_FUNCTION_INTS(ShiftRight, vshr, _n_, ignored) +HWY_NEON_DEF_FUNCTION_UINTS(RoundingShiftRight, vrshr, _n_, ignored) +HWY_NEON_DEF_FUNCTION_INTS(RoundingShiftRight, vrshr, _n_, ignored) + +#pragma pop_macro("HWY_NEON_DEF_FUNCTION") + +// ------------------------------ RotateRight (ShiftRight, Or) +template +HWY_API Vec128 RotateRight(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); +} + +// NOTE: vxarq_u64 can be applied to uint64_t, but we do not yet have a +// mechanism for checking for extensions to Armv8. 
+ +// ------------------------------ Shl + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_u8(v.raw, vreinterpretq_s8_u8(bits.raw))); +} +template +HWY_API Vec128 operator<<(Vec128 v, + Vec128 bits) { + return Vec128(vshl_u8(v.raw, vreinterpret_s8_u8(bits.raw))); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_u16(v.raw, vreinterpretq_s16_u16(bits.raw))); +} +template +HWY_API Vec128 operator<<(Vec128 v, + Vec128 bits) { + return Vec128(vshl_u16(v.raw, vreinterpret_s16_u16(bits.raw))); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_u32(v.raw, vreinterpretq_s32_u32(bits.raw))); +} +template +HWY_API Vec128 operator<<(Vec128 v, + Vec128 bits) { + return Vec128(vshl_u32(v.raw, vreinterpret_s32_u32(bits.raw))); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_u64(v.raw, vreinterpretq_s64_u64(bits.raw))); +} +HWY_API Vec64 operator<<(Vec64 v, Vec64 bits) { + return Vec64(vshl_u64(v.raw, vreinterpret_s64_u64(bits.raw))); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s8(v.raw, bits.raw)); +} +template +HWY_API Vec128 operator<<(Vec128 v, + Vec128 bits) { + return Vec128(vshl_s8(v.raw, bits.raw)); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s16(v.raw, bits.raw)); +} +template +HWY_API Vec128 operator<<(Vec128 v, + Vec128 bits) { + return Vec128(vshl_s16(v.raw, bits.raw)); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s32(v.raw, bits.raw)); +} +template +HWY_API Vec128 operator<<(Vec128 v, + Vec128 bits) { + return Vec128(vshl_s32(v.raw, bits.raw)); +} + +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s64(v.raw, bits.raw)); +} +HWY_API Vec64 operator<<(Vec64 v, Vec64 bits) { + return Vec64(vshl_s64(v.raw, bits.raw)); +} + +// ------------------------------ Shr (Neg) + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + const 
RebindToSigned> di; + const int8x16_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshlq_u8(v.raw, neg_bits)); +} +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int8x8_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshl_u8(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + const RebindToSigned> di; + const int16x8_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshlq_u16(v.raw, neg_bits)); +} +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int16x4_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshl_u16(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + const RebindToSigned> di; + const int32x4_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshlq_u32(v.raw, neg_bits)); +} +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int32x2_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshl_u32(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + const RebindToSigned> di; + const int64x2_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vshlq_u64(v.raw, neg_bits)); +} +HWY_API Vec64 operator>>(Vec64 v, Vec64 bits) { + const RebindToSigned> di; + const int64x1_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec64(vshl_u64(v.raw, neg_bits)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s8(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { + return Vec128(vshl_s8(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s16(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { + return Vec128(vshl_s16(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s32(v.raw, Neg(bits).raw)); 
+} +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { + return Vec128(vshl_s32(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + return Vec128(vshlq_s64(v.raw, Neg(bits).raw)); +} +HWY_API Vec64 operator>>(Vec64 v, Vec64 bits) { + return Vec64(vshl_s64(v.raw, Neg(bits).raw)); +} + +// ------------------------------ RoundingShr (Neg) + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + const RebindToSigned> di; + const int8x16_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshlq_u8(v.raw, neg_bits)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int8x8_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshl_u8(v.raw, neg_bits)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int16x8_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshlq_u16(v.raw, neg_bits)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int16x4_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshl_u16(v.raw, neg_bits)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int32x4_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshlq_u32(v.raw, neg_bits)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int32x2_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshl_u32(v.raw, neg_bits)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + const RebindToSigned> di; + const int64x2_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec128(vrshlq_u64(v.raw, neg_bits)); +} +HWY_API Vec64 RoundingShr(Vec64 v, Vec64 bits) { + const RebindToSigned> di; + const int64x1_t neg_bits = Neg(BitCast(di, bits)).raw; + return Vec64(vrshl_u64(v.raw, neg_bits)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + return 
Vec128(vrshlq_s8(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128(vrshl_s8(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + return Vec128(vrshlq_s16(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128(vrshl_s16(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + return Vec128(vrshlq_s32(v.raw, Neg(bits).raw)); +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128(vrshl_s32(v.raw, Neg(bits).raw)); +} + +HWY_API Vec128 RoundingShr(Vec128 v, Vec128 bits) { + return Vec128(vrshlq_s64(v.raw, Neg(bits).raw)); +} +HWY_API Vec64 RoundingShr(Vec64 v, Vec64 bits) { + return Vec64(vrshl_s64(v.raw, Neg(bits).raw)); +} + +// ------------------------------ ShiftLeftSame (Shl) + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, int bits) { + return v << Set(DFromV(), static_cast(bits)); +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, int bits) { + return v >> Set(DFromV(), static_cast(bits)); +} + +// ------------------------------ RoundingShiftRightSame (RoundingShr) + +template +HWY_API Vec128 RoundingShiftRightSame(const Vec128 v, int bits) { + return RoundingShr(v, Set(DFromV(), static_cast(bits))); +} + +// ------------------------------ Int/float multiplication + +// Per-target flag to prevent generic_ops-inl.h from defining 8-bit operator*. 
+#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif + +// All except ui64 +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator*, vmul, _, 2) +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator*, vmul, _, 2) + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) * BitCast(du, b)); +} + +template +HWY_API Vec128 operator*(Vec128 a, + Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) * BitCast(du, b)); +} + +template +HWY_API Vec128 operator*(Vec128 a, + Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) * BitCast(du, b)); +} + +// ------------------------------ Integer multiplication + +// Returns the upper sizeof(T)*8 bits of a * b in each lane. +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int16x8_t rlo = vmull_s8(vget_low_s8(a.raw), vget_low_s8(b.raw)); +#if HWY_ARCH_ARM_A64 + int16x8_t rhi = vmull_high_s8(a.raw, b.raw); +#else + int16x8_t rhi = vmull_s8(vget_high_s8(a.raw), vget_high_s8(b.raw)); +#endif + return Vec128( + vuzp2q_s8(vreinterpretq_s8_s16(rlo), vreinterpretq_s8_s16(rhi))); +} +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + uint16x8_t rlo = vmull_u8(vget_low_u8(a.raw), vget_low_u8(b.raw)); +#if HWY_ARCH_ARM_A64 + uint16x8_t rhi = vmull_high_u8(a.raw, b.raw); +#else + uint16x8_t rhi = vmull_u8(vget_high_u8(a.raw), vget_high_u8(b.raw)); +#endif + return Vec128( + vuzp2q_u8(vreinterpretq_u8_u16(rlo), vreinterpretq_u8_u16(rhi))); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int8x16_t hi_lo = vreinterpretq_s8_s16(vmull_s8(a.raw, b.raw)); + return Vec128(vget_low_s8(vuzp2q_s8(hi_lo, hi_lo))); +} +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + uint8x16_t hi_lo = vreinterpretq_u8_u16(vmull_u8(a.raw, b.raw)); + return Vec128(vget_low_u8(vuzp2q_u8(hi_lo, hi_lo))); +} + +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int32x4_t rlo = 
vmull_s16(vget_low_s16(a.raw), vget_low_s16(b.raw)); +#if HWY_ARCH_ARM_A64 + int32x4_t rhi = vmull_high_s16(a.raw, b.raw); +#else + int32x4_t rhi = vmull_s16(vget_high_s16(a.raw), vget_high_s16(b.raw)); +#endif + return Vec128( + vuzp2q_s16(vreinterpretq_s16_s32(rlo), vreinterpretq_s16_s32(rhi))); +} +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + uint32x4_t rlo = vmull_u16(vget_low_u16(a.raw), vget_low_u16(b.raw)); +#if HWY_ARCH_ARM_A64 + uint32x4_t rhi = vmull_high_u16(a.raw, b.raw); +#else + uint32x4_t rhi = vmull_u16(vget_high_u16(a.raw), vget_high_u16(b.raw)); +#endif + return Vec128( + vuzp2q_u16(vreinterpretq_u16_u32(rlo), vreinterpretq_u16_u32(rhi))); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int16x8_t hi_lo = vreinterpretq_s16_s32(vmull_s16(a.raw, b.raw)); + return Vec128(vget_low_s16(vuzp2q_s16(hi_lo, hi_lo))); +} +template +HWY_API Vec128 MulHigh(Vec128 a, + Vec128 b) { + uint16x8_t hi_lo = vreinterpretq_u16_u32(vmull_u16(a.raw, b.raw)); + return Vec128(vget_low_u16(vuzp2q_u16(hi_lo, hi_lo))); +} + +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int64x2_t rlo = vmull_s32(vget_low_s32(a.raw), vget_low_s32(b.raw)); +#if HWY_ARCH_ARM_A64 + int64x2_t rhi = vmull_high_s32(a.raw, b.raw); +#else + int64x2_t rhi = vmull_s32(vget_high_s32(a.raw), vget_high_s32(b.raw)); +#endif + return Vec128( + vuzp2q_s32(vreinterpretq_s32_s64(rlo), vreinterpretq_s32_s64(rhi))); +} +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + uint64x2_t rlo = vmull_u32(vget_low_u32(a.raw), vget_low_u32(b.raw)); +#if HWY_ARCH_ARM_A64 + uint64x2_t rhi = vmull_high_u32(a.raw, b.raw); +#else + uint64x2_t rhi = vmull_u32(vget_high_u32(a.raw), vget_high_u32(b.raw)); +#endif + return Vec128( + vuzp2q_u32(vreinterpretq_u32_u64(rlo), vreinterpretq_u32_u64(rhi))); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + int32x4_t hi_lo = vreinterpretq_s32_s64(vmull_s32(a.raw, b.raw)); + return Vec128(vget_low_s32(vuzp2q_s32(hi_lo, hi_lo))); +} +template +HWY_API Vec128 
MulHigh(Vec128 a, + Vec128 b) { + uint32x4_t hi_lo = vreinterpretq_u32_u64(vmull_u32(a.raw, b.raw)); + return Vec128(vget_low_u32(vuzp2q_u32(hi_lo, hi_lo))); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + T hi_0; + T hi_1; + + Mul128(GetLane(a), GetLane(b), &hi_0); + Mul128(detail::GetLane<1>(a), detail::GetLane<1>(b), &hi_1); + + return Dup128VecFromValues(Full128(), hi_0, hi_1); +} + +template +HWY_API Vec64 MulHigh(Vec64 a, Vec64 b) { + T hi; + Mul128(GetLane(a), GetLane(b), &hi); + return Set(Full64(), hi); +} + +HWY_API Vec128 MulFixedPoint15(Vec128 a, Vec128 b) { + return Vec128(vqrdmulhq_s16(a.raw, b.raw)); +} +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + return Vec128(vqrdmulh_s16(a.raw, b.raw)); +} + +// ------------------------------ Floating-point division + +// Emulate missing intrinsic +#if HWY_HAVE_FLOAT64 && HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 +HWY_INLINE float64x1_t vrecpe_f64(float64x1_t raw) { + const CappedTag d; + const Twice dt; + using VT = VFromD; + return LowerHalf(d, VT(vrecpeq_f64(Combine(dt, v, v).raw))).raw; +} +#endif + +// Approximate reciprocal +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(ApproximateReciprocal, vrecpe, _, 1) + +#if HWY_HAVE_FLOAT64 +#ifdef HWY_NATIVE_F64_APPROX_RECIP +#undef HWY_NATIVE_F64_APPROX_RECIP +#else +#define HWY_NATIVE_F64_APPROX_RECIP +#endif + +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator/, vdiv, _, 2) +#else // !HWY_HAVE_FLOAT64 +namespace detail { +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(ReciprocalNewtonRaphsonStep, vrecps, _, 2) +} // namespace detail + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { + auto x = ApproximateReciprocal(b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + x *= detail::ReciprocalNewtonRaphsonStep(x, b); + return a * x; +} +#endif // HWY_HAVE_FLOAT64 + +// ------------------------------ Absolute value of difference. 
+ +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(AbsDiff, vabd, _, 2) +HWY_NEON_DEF_FUNCTION_UI_8_16_32(AbsDiff, vabd, _, 2) // no UI64 + +#ifdef HWY_NATIVE_INTEGER_ABS_DIFF +#undef HWY_NATIVE_INTEGER_ABS_DIFF +#else +#define HWY_NATIVE_INTEGER_ABS_DIFF +#endif + +// ------------------------------ Integer multiply-add + +// Per-target flag to prevent generic_ops-inl.h from defining int MulAdd. +#ifdef HWY_NATIVE_INT_FMA +#undef HWY_NATIVE_INT_FMA +#else +#define HWY_NATIVE_INT_FMA +#endif + +// Wrappers for changing argument order to what intrinsics expect. +namespace detail { +// All except ui64 +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(MulAdd, vmla, _, 3) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(MulAdd, vmla, _, 3) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(NegMulAdd, vmls, _, 3) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(NegMulAdd, vmls, _, 3) +} // namespace detail + +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return detail::MulAdd(add, mul, x); +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return detail::NegMulAdd(add, mul, x); +} + +// 64-bit integer +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return Add(Mul(mul, x), add); +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return Sub(add, Mul(mul, x)); +} + +// ------------------------------ Floating-point multiply-add variants + +namespace detail { + +#if HWY_NATIVE_FMA +// Wrappers for changing argument order to what intrinsics expect. +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(MulAdd, vfma, _, 3) +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(NegMulAdd, vfms, _, 3) +#else +// Emulate. Matches intrinsics arg order. 
+template +HWY_API Vec128 MulAdd(Vec128 add, Vec128 mul, + Vec128 x) { + return mul * x + add; +} + +template +HWY_API Vec128 NegMulAdd(Vec128 add, Vec128 mul, + Vec128 x) { + return add - mul * x; +} + +#endif // HWY_NATIVE_FMA +} // namespace detail + +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return detail::MulAdd(add, mul, x); +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return detail::NegMulAdd(add, mul, x); +} + +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return MulAdd(mul, x, Neg(sub)); +} + +template +HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return Neg(MulAdd(mul, x, sub)); +} + +// ------------------------------ Floating-point square root (IfThenZeroElse) + +// Emulate missing intrinsic +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 490 +HWY_INLINE float64x1_t vrsqrte_f64(float64x1_t raw) { + const CappedTag d; + const Twice dt; + using VT = VFromD; + const VFromD v(raw); + return LowerHalf(d, VT(vrsqrteq_f64(Combine(dt, v, v).raw))).raw; +} +#endif + +// Approximate reciprocal square root +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(ApproximateReciprocalSqrt, vrsqrte, _, 1) + +#if HWY_HAVE_FLOAT64 +#ifdef HWY_NATIVE_F64_APPROX_RSQRT +#undef HWY_NATIVE_F64_APPROX_RSQRT +#else +#define HWY_NATIVE_F64_APPROX_RSQRT +#endif + +// Full precision square root +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Sqrt, vsqrt, _, 1) +#else // !HWY_HAVE_FLOAT64 +namespace detail { +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(ReciprocalSqrtStep, vrsqrts, _, 2) +} // namespace detail + +template +HWY_API Vec128 Sqrt(const Vec128 v) { + auto recip = ApproximateReciprocalSqrt(v); + + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + recip *= detail::ReciprocalSqrtStep(v * recip, recip); + + const auto root = v * recip; + return IfThenZeroElse(v == Zero(Simd()), root); +} +#endif // HWY_HAVE_FLOAT64 + +// 
================================================== LOGICAL + +// ------------------------------ Not + +// There is no 64-bit vmvn, so cast instead of using HWY_NEON_DEF_FUNCTION. +template +HWY_API Vec128 Not(const Vec128 v) { + const DFromV d; + const Repartition d8; + return BitCast(d, Vec128(vmvnq_u8(BitCast(d8, v).raw))); +} +template +HWY_API Vec128 Not(const Vec128 v) { + const DFromV d; + const Repartition d8; + using V8 = decltype(Zero(d8)); + return BitCast(d, V8(vmvn_u8(BitCast(d8, v).raw))); +} + +// ------------------------------ And +HWY_NEON_DEF_FUNCTION_INTS_UINTS(And, vand, _, 2) + +// Uses the u32/64 defined above. +template +HWY_API Vec128 And(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) & BitCast(du, b)); +} + +// ------------------------------ AndNot + +namespace detail { +// AndNotSwap returns a & ~b, whereas AndNot is defined as ~a & b. +HWY_NEON_DEF_FUNCTION_INTS_UINTS(AndNotSwap, vbic, _, 2) +} // namespace detail + +// Returns ~not_mask & mask. +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + return detail::AndNotSwap(mask, not_mask); +} + +// Uses the u32/64 defined above. +template +HWY_API Vec128 AndNot(const Vec128 not_mask, + const Vec128 mask) { + const DFromV d; + const RebindToUnsigned du; + VFromD ret = + detail::AndNotSwap(BitCast(du, mask), BitCast(du, not_mask)); + return BitCast(d, ret); +} + +// ------------------------------ Or + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(Or, vorr, _, 2) + +// Uses the u32/64 defined above. +template +HWY_API Vec128 Or(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) | BitCast(du, b)); +} + +// ------------------------------ Xor + +HWY_NEON_DEF_FUNCTION_INTS_UINTS(Xor, veor, _, 2) + +// Uses the u32/64 defined above. 
+template +HWY_API Vec128 Xor(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) ^ BitCast(du, b)); +} + +// ------------------------------ Xor3 +#if HWY_ARCH_ARM_A64 && defined(__ARM_FEATURE_SHA3) + +#ifdef HWY_NATIVE_XOR3 +#undef HWY_NATIVE_XOR3 +#else +#define HWY_NATIVE_XOR3 +#endif + +HWY_NEON_DEF_FUNCTION_FULL_UI(Xor3, veor3, _, 3) + +// Half vectors are not natively supported. Two Xor are likely more efficient +// than Combine to 128-bit. +template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { + return Xor(x1, Xor(x2, x3)); +} + +template +HWY_API Vec128 Xor3(const Vec128 x1, const Vec128 x2, + const Vec128 x3) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3))); +} + +#endif + +// ------------------------------ Or3 +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ XorAndNot +#if HWY_ARCH_ARM_A64 && defined(__ARM_FEATURE_SHA3) + +#ifdef HWY_NATIVE_BCAX +#undef HWY_NATIVE_BCAX +#else +#define HWY_NATIVE_BCAX +#endif + +namespace detail { +HWY_NEON_DEF_FUNCTION_FULL_UI(XorAndNotSwap, vbcax, _, 3) +} // namespace detail + +// As with AndNot, swap the last two arguments because our "negated first" +// convention mismatches the intrinsics, which have the negated arg last. +template +HWY_API V XorAndNot(V x, V a1, V a2) { + return detail::XorAndNotSwap(x, a2, a1); +} + +// Half vectors are not natively supported. Two ops are likely more efficient +// than Combine to 128-bit. 
+template +HWY_API Vec128 XorAndNot(Vec128 x, Vec128 a1, + Vec128 a2) { + return Xor(x, AndNot(a1, a2)); +} + +template +HWY_API Vec128 XorAndNot(const Vec128 x, const Vec128 a1, + const Vec128 a2) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, + XorAndNot(BitCast(du, x), BitCast(du, a1), BitCast(du, a2))); +} + +#endif + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ I64/U64 AbsDiff + +template +HWY_API Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Max(a, b) - Min(a, b); +} + +template +HWY_API Vec128 AbsDiff(const Vec128 a, + const Vec128 b) { + return Or(SaturatedSub(a, b), SaturatedSub(b, a)); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, Vec128 v) { + const Full128 d8; + return Vec128(vcntq_u8(BitCast(d8, v).raw)); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, + Vec128 v) { + const Simd d8; + return Vec128(vcnt_u8(BitCast(d8, v).raw)); +} + +// NEON lacks popcount for lane sizes > 1, so take pairwise sums of the bytes. 
+template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, Vec128 v) { + const Full128 d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128(vpaddlq_u8(bytes)); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, + Vec128 v) { + const Repartition> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128(vpaddl_u8(bytes)); +} + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, Vec128 v) { + const Full128 d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128(vpaddlq_u16(vpaddlq_u8(bytes))); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, + Vec128 v) { + const Repartition> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128(vpaddl_u16(vpaddl_u8(bytes))); +} + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, Vec128 v) { + const Full128 d8; + const uint8x16_t bytes = vcntq_u8(BitCast(d8, v).raw); + return Vec128(vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(bytes)))); +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, + Vec128 v) { + const Repartition> d8; + const uint8x8_t bytes = vcnt_u8(BitCast(d8, v).raw); + return Vec128(vpaddl_u32(vpaddl_u16(vpaddl_u8(bytes)))); +} + +} // namespace detail + +template +HWY_API Vec128 PopulationCount(Vec128 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +// ================================================== SIGN + +// ------------------------------ Abs +// i64 is implemented after BroadcastSignBit. 
+HWY_NEON_DEF_FUNCTION_INT_8_16_32(Abs, vabs, _, 1) +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Abs, vabs, _, 1) + +// ------------------------------ SaturatedAbs +#ifdef HWY_NATIVE_SATURATED_ABS +#undef HWY_NATIVE_SATURATED_ABS +#else +#define HWY_NATIVE_SATURATED_ABS +#endif + +HWY_NEON_DEF_FUNCTION_INT_8_16_32(SaturatedAbs, vqabs, _, 1) + +// ------------------------------ CopySignToAbs +template +HWY_API Vec128 CopySignToAbs(Vec128 abs, Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return OrAnd(abs, SignBit(d), sign); +} + +// ------------------------------ BroadcastSignBit + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight(v); +} + +// ================================================== MASK + +// ------------------------------ To/from vector + +// Mask and Vec have the same representation (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const Simd, N, 0> du; + return Mask128(BitCast(du, v).raw); +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +template +HWY_API VFromD VecFromMask(D d, const MFromD m) { + // Raw type of masks is unsigned. + const RebindToUnsigned du; + return BitCast(d, VFromD(m.raw)); +} + +// ------------------------------ RebindMask (MaskFromVec) + +template +HWY_API MFromD RebindMask(DTo /* tag */, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + return MFromD(m.raw); +} + +// ------------------------------ IfThenElse + +// Workaround for incorrect codegen. 
+#if HWY_ARCH_ARM_V7 + +template > +HWY_API V IfThenElse(MFromD mask, V yes, V no) { + const RebindToUnsigned du; + using VU = VFromD; + const VU no_u = BitCast(du, no); + const VU diff_u = BitCast(du, yes) ^ no_u; + const VU mask_u = BitCast(du, VecFromMask(D(), mask)); + return BitCast(D(), no_u ^ (diff_u & mask_u)); +} + +#else // normal VBSL instruction + +#define HWY_NEON_BUILD_TPL_HWY_IF +#define HWY_NEON_BUILD_RET_HWY_IF(type, size) Vec128 +#define HWY_NEON_BUILD_PARAM_HWY_IF(type, size) \ + const Mask128 mask, const Vec128 yes, \ + const Vec128 no +#define HWY_NEON_BUILD_ARG_HWY_IF mask.raw, yes.raw, no.raw + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(IfThenElse, vbsl, _, HWY_IF) + +#endif // HWY_ARCH_ARM_V7 + +#if HWY_HAVE_FLOAT16 +#define HWY_NEON_IF_EMULATED_IF_THEN_ELSE(V) HWY_IF_BF16(TFromV) +#else +#define HWY_NEON_IF_EMULATED_IF_THEN_ELSE(V) HWY_IF_SPECIAL_FLOAT_V(V) +#endif + +template +HWY_API V IfThenElse(MFromD> mask, V yes, V no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, IfThenElse(RebindMask(du, mask), BitCast(du, yes), BitCast(du, no))); +} + +#undef HWY_NEON_IF_EMULATED_IF_THEN_ELSE +#undef HWY_NEON_BUILD_TPL_HWY_IF +#undef HWY_NEON_BUILD_RET_HWY_IF +#undef HWY_NEON_BUILD_PARAM_HWY_IF +#undef HWY_NEON_BUILD_ARG_HWY_IF + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, IfThenElseZero(RebindMask(du, mask), BitCast(du, yes))); +} + +// mask ? 
0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, IfThenZeroElse(RebindMask(du, mask), BitCast(du, no))); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + Mask128 m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ BitwiseIfThenElse + +#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#else +#define HWY_NATIVE_BITWISE_IF_THEN_ELSE +#endif + +template +HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { + return IfVecThenElse(mask, yes, no); +} + +// ------------------------------ CopySign (BitwiseIfThenElse) +template +HWY_API Vec128 CopySign(Vec128 magn, Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return BitwiseIfThenElse(SignBit(d), sign, magn); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + return MaskFromVec(Not(VecFromMask(DFromM(), m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + 
const DFromM d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +// ------------------------------ Shuffle2301 (for i64 compares) + +// Swap 32-bit halves in 64-bits +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_u32(v.raw)); +} +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_s32(v.raw)); +} +HWY_API Vec64 Shuffle2301(const Vec64 v) { + return Vec64(vrev64_f32(v.raw)); +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_u32(v.raw)); +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_s32(v.raw)); +} +HWY_API Vec128 Shuffle2301(const Vec128 v) { + return Vec128(vrev64q_f32(v.raw)); +} + +#define HWY_NEON_BUILD_TPL_HWY_COMPARE +#define HWY_NEON_BUILD_RET_HWY_COMPARE(type, size) Mask128 +#define HWY_NEON_BUILD_PARAM_HWY_COMPARE(type, size) \ + const Vec128 a, const Vec128 b +#define HWY_NEON_BUILD_ARG_HWY_COMPARE a.raw, b.raw + +// ------------------------------ Equality +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator==, vceq, _, HWY_COMPARE) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator==, vceq, _, HWY_COMPARE) +#else +// No 64-bit comparisons on armv7: emulate them below, after Shuffle2301. 
+HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator==, vceq, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator==, vceq, _, HWY_COMPARE) +#endif + +// ------------------------------ Strict inequality (signed, float) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator<, vclt, _, HWY_COMPARE) +#else +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator<, vclt, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator<, vclt, _, HWY_COMPARE) +#endif +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<, vclt, _, HWY_COMPARE) + +// ------------------------------ Weak inequality (float) +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(operator<=, vcle, _, HWY_COMPARE) +#else +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator<=, vcle, _, HWY_COMPARE) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator<=, vcle, _, HWY_COMPARE) +#endif +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator<=, vcle, _, HWY_COMPARE) + +#undef HWY_NEON_BUILD_TPL_HWY_COMPARE +#undef HWY_NEON_BUILD_RET_HWY_COMPARE +#undef HWY_NEON_BUILD_PARAM_HWY_COMPARE +#undef HWY_NEON_BUILD_ARG_HWY_COMPARE + +// ------------------------------ Armv7 i64 compare (Shuffle2301, Eq) + +#if HWY_ARCH_ARM_V7 + +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + const Simd d32; + const Simd d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +} + +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + const Simd d32; + const Simd d64; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +} + +HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + const int64x2_t sub = vqsubq_s64(a.raw, b.raw); + return MaskFromVec(BroadcastSignBit(Vec128(sub))); +} +HWY_API Mask128 operator<(const Vec64 a, + const Vec64 b) { + const int64x1_t sub = vqsub_s64(a.raw, 
b.raw); + return MaskFromVec(BroadcastSignBit(Vec64(sub))); +} + +template +HWY_API Mask128 operator<(const Vec128 a, + const Vec128 b) { + const DFromV du; + const RebindToSigned di; + const Vec128 msb = AndNot(a, b) | AndNot(a ^ b, a - b); + return MaskFromVec(BitCast(du, BroadcastSignBit(BitCast(di, msb)))); +} + +template +HWY_API Mask128 operator<=(const Vec128 a, + const Vec128 b) { + return Not(b < a); +} + +template +HWY_API Mask128 operator<=(const Vec128 a, + const Vec128 b) { + return Not(b < a); +} + +#endif + +// ------------------------------ operator!= (operator==) + +// Customize HWY_NEON_DEF_FUNCTION to call 2 functions. +#pragma push_macro("HWY_NEON_DEF_FUNCTION") +#undef HWY_NEON_DEF_FUNCTION +// This cannot have _any_ template argument (in x86_128 we can at least have N +// as an argument), otherwise it is not more specialized than rewritten +// operator== in C++20, leading to compile errors. +#define HWY_NEON_DEF_FUNCTION(type, size, name, prefix, infix, suffix, args) \ + HWY_API Mask128 name(Vec128 a, \ + Vec128 b) { \ + return Not(a == b); \ + } + +HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator!=, ignored, ignored, ignored) + +#pragma pop_macro("HWY_NEON_DEF_FUNCTION") + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return operator<(b, a); +} +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return operator<=(b, a); +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API MFromD FirstN(D d, size_t num) { + const RebindToSigned di; // Signed comparisons are cheaper. 
+ using TI = TFromD; + return RebindMask(d, detail::Iota0(di) < Set(di, static_cast(num))); +} + +// ------------------------------ TestBit (Eq) + +#define HWY_NEON_BUILD_TPL_HWY_TESTBIT +#define HWY_NEON_BUILD_RET_HWY_TESTBIT(type, size) Mask128 +#define HWY_NEON_BUILD_PARAM_HWY_TESTBIT(type, size) \ + Vec128 v, Vec128 bit +#define HWY_NEON_BUILD_ARG_HWY_TESTBIT v.raw, bit.raw + +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_INTS_UINTS(TestBit, vtst, _, HWY_TESTBIT) +#else +// No 64-bit versions on armv7 +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(TestBit, vtst, _, HWY_TESTBIT) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(TestBit, vtst, _, HWY_TESTBIT) + +template +HWY_API Mask128 TestBit(Vec128 v, + Vec128 bit) { + return (v & bit) == bit; +} +template +HWY_API Mask128 TestBit(Vec128 v, + Vec128 bit) { + return (v & bit) == bit; +} + +#endif +#undef HWY_NEON_BUILD_TPL_HWY_TESTBIT +#undef HWY_NEON_BUILD_RET_HWY_TESTBIT +#undef HWY_NEON_BUILD_PARAM_HWY_TESTBIT +#undef HWY_NEON_BUILD_ARG_HWY_TESTBIT + +// ------------------------------ Abs i64 (IfNegativeThenElse, Neg) +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_ARCH_ARM_A64 + return Vec128(vabsq_s64(v.raw)); +#else + return IfNegativeThenElse(v, Neg(v), v); +#endif +} +HWY_API Vec64 Abs(const Vec64 v) { +#if HWY_ARCH_ARM_A64 + return Vec64(vabs_s64(v.raw)); +#else + return IfNegativeThenElse(v, Neg(v), v); +#endif +} + +HWY_API Vec128 SaturatedAbs(const Vec128 v) { +#if HWY_ARCH_ARM_A64 + return Vec128(vqabsq_s64(v.raw)); +#else + const auto zero = Zero(DFromV()); + return IfNegativeThenElse(v, SaturatedSub(zero, v), v); +#endif +} +HWY_API Vec64 SaturatedAbs(const Vec64 v) { +#if HWY_ARCH_ARM_A64 + return Vec64(vqabs_s64(v.raw)); +#else + const auto zero = Zero(DFromV()); + return IfNegativeThenElse(v, SaturatedSub(zero, v), v); +#endif +} + +// ------------------------------ Min (IfThenElse, BroadcastSignBit) + +// Unsigned +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Min, vmin, _, 2) + +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { 
+#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, b, a); +#else + const DFromV du; + const RebindToSigned di; + return BitCast(du, BitCast(di, a) - BitCast(di, SaturatedSub(a, b))); +#endif +} + +// Signed +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Min, vmin, _, 2) + +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, b, a); +#else + const Vec128 sign = SaturatedSub(a, b); + return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), a, b); +#endif +} + +// Float: IEEE minimumNumber on v8 +#if HWY_ARCH_ARM_A64 + +HWY_NEON_DEF_FUNCTION_FLOAT_16_32(Min, vminnm, _, 2) + +// GCC 6.5 and earlier are missing the 64-bit (non-q) intrinsic, so define +// in terms of the 128-bit intrinsic. +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 +namespace detail { + +template +HWY_INLINE V F64Vec64Min(V a, V b) { + const DFromV d; + const Twice dt; + return LowerHalf(d, Min(ZeroExtendVector(dt, a), ZeroExtendVector(dt, b))); +} + +} // namespace detail +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + +HWY_API Vec64 Min(Vec64 a, Vec64 b) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + return detail::F64Vec64Min(a, b); +#else + return Vec64(vminnm_f64(a.raw, b.raw)); +#endif +} + +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128(vminnmq_f64(a.raw, b.raw)); +} + +#else +// Armv7: NaN if any is NaN. 
+HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Min, vmin, _, 2) +#endif // HWY_ARCH_ARM_A64 + +// ------------------------------ Max (IfThenElse, BroadcastSignBit) + +// Unsigned (no u64) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(Max, vmax, _, 2) + +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, a, b); +#else + const DFromV du; + const RebindToSigned di; + return BitCast(du, BitCast(di, b) + BitCast(di, SaturatedSub(a, b))); +#endif +} + +// Signed (no i64) +HWY_NEON_DEF_FUNCTION_INT_8_16_32(Max, vmax, _, 2) + +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_ARCH_ARM_A64 + return IfThenElse(b < a, a, b); +#else + const Vec128 sign = SaturatedSub(a, b); + return IfThenElse(MaskFromVec(BroadcastSignBit(sign)), b, a); +#endif +} + +// Float: IEEE minimumNumber on v8 +#if HWY_ARCH_ARM_A64 + +HWY_NEON_DEF_FUNCTION_FLOAT_16_32(Max, vmaxnm, _, 2) + +// GCC 6.5 and earlier are missing the 64-bit (non-q) intrinsic, so define +// in terms of the 128-bit intrinsic. +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 +namespace detail { + +template +HWY_INLINE V F64Vec64Max(V a, V b) { + const DFromV d; + const Twice dt; + return LowerHalf(d, Max(ZeroExtendVector(dt, a), ZeroExtendVector(dt, b))); +} + +} // namespace detail +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + +HWY_API Vec64 Max(Vec64 a, Vec64 b) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + return detail::F64Vec64Max(a, b); +#else + return Vec64(vmaxnm_f64(a.raw, b.raw)); +#endif +} + +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128(vmaxnmq_f64(a.raw, b.raw)); +} + +#else +// Armv7: NaN if any is NaN. 
+HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Max, vmax, _, 2) +#endif // HWY_ARCH_ARM_A64 + +// ------------------------------ MinNumber and MaxNumber + +#if !HWY_ARCH_ARM_A64 + +#ifdef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER +#undef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER +#else +#define HWY_NATIVE_FLOAT_MIN_MAX_NUMBER +#endif + +template +HWY_API V MinNumber(V a, V b) { + return Min(IfThenElse(IsNaN(a), b, a), IfThenElse(IsNaN(b), a, b)); +} + +template +HWY_API V MaxNumber(V a, V b) { + return Max(IfThenElse(IsNaN(a), b, a), IfThenElse(IsNaN(b), a, b)); +} + +#endif + +// ================================================== MEMORY + +// ------------------------------ Load 128 + +template +HWY_API Vec128 LoadU(D /* tag */, + const uint8_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u8(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const uint16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u16(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const uint32_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u32(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const uint64_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_u64(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const int8_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s8(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const int16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s16(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const int32_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s32(unaligned)); +} +template +HWY_API Vec128 LoadU(D /* tag */, + const int64_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_s64(unaligned)); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 LoadU(D /* tag */, + const float16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_f16(detail::NativeLanePointer(unaligned))); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec128 LoadU(D /* tag */, + const 
bfloat16_t* HWY_RESTRICT unaligned) { + return Vec128(vld1q_bf16(detail::NativeLanePointer(unaligned))); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec128 LoadU(D /* tag */, const float* HWY_RESTRICT unaligned) { + return Vec128(vld1q_f32(unaligned)); +} +#if HWY_HAVE_FLOAT64 +template +HWY_API Vec128 LoadU(D /* tag */, + const double* HWY_RESTRICT unaligned) { + return Vec128(vld1q_f64(unaligned)); +} +#endif // HWY_HAVE_FLOAT64 + +// ------------------------------ Load 64 + +template +HWY_API Vec64 LoadU(D /* tag */, const uint8_t* HWY_RESTRICT p) { + return Vec64(vld1_u8(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const uint16_t* HWY_RESTRICT p) { + return Vec64(vld1_u16(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const uint32_t* HWY_RESTRICT p) { + return Vec64(vld1_u32(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const uint64_t* HWY_RESTRICT p) { + return Vec64(vld1_u64(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const int8_t* HWY_RESTRICT p) { + return Vec64(vld1_s8(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const int16_t* HWY_RESTRICT p) { + return Vec64(vld1_s16(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const int32_t* HWY_RESTRICT p) { + return Vec64(vld1_s32(p)); +} +template +HWY_API Vec64 LoadU(D /* tag */, const int64_t* HWY_RESTRICT p) { + return Vec64(vld1_s64(p)); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec64 LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { + return Vec64(vld1_f16(detail::NativeLanePointer(p))); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec64 LoadU(D /* tag */, const bfloat16_t* HWY_RESTRICT p) { + return Vec64(vld1_bf16(detail::NativeLanePointer(p))); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec64 LoadU(D /* tag */, const float* HWY_RESTRICT p) { + return Vec64(vld1_f32(p)); +} +#if HWY_HAVE_FLOAT64 +template +HWY_API Vec64 LoadU(D /* tag */, const double* HWY_RESTRICT p) { + return Vec64(vld1_f64(p)); +} 
+#endif // HWY_HAVE_FLOAT64 + +// ------------------------------ Load 32 + +// Actual 32-bit broadcast load - used to implement the other lane types +// because reinterpret_cast of the pointer leads to incorrect codegen on GCC. +template +HWY_API Vec32 LoadU(D /*tag*/, const uint32_t* HWY_RESTRICT p) { + return Vec32(vld1_dup_u32(p)); +} +template +HWY_API Vec32 LoadU(D /*tag*/, const int32_t* HWY_RESTRICT p) { + return Vec32(vld1_dup_s32(p)); +} +template +HWY_API Vec32 LoadU(D /*tag*/, const float* HWY_RESTRICT p) { + return Vec32(vld1_dup_f32(p)); +} + +// {u,i}{8,16} +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const Repartition d32; + uint32_t buf; + CopyBytes<4>(p, &buf); + return BitCast(d, LoadU(d32, &buf)); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const Repartition d32; + uint32_t buf; + CopyBytes<4>(p, &buf); + return BitCast(d, LoadU(d32, &buf)); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const Repartition d32; + uint32_t buf; + CopyBytes<4>(p, &buf); + return BitCast(d, LoadU(d32, &buf)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + +// ------------------------------ Load 16 + +// Actual 16-bit broadcast load - used to implement the other lane types +// because reinterpret_cast of the pointer leads to incorrect codegen on GCC. 
+template +HWY_API VFromD LoadU(D /* tag */, const uint16_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_u16(p)); +} +template +HWY_API VFromD LoadU(D /* tag */, const int16_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_s16(p)); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_f16(detail::NativeLanePointer(p))); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API VFromD LoadU(D /* tag */, const bfloat16_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_bf16(detail::NativeLanePointer(p))); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + +// 8-bit x2 +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const Repartition d16; + uint16_t buf; + CopyBytes<2>(p, &buf); + return BitCast(d, LoadU(d16, &buf)); +} + +// ------------------------------ Load 8 +template +HWY_API VFromD LoadU(D /* tag */, const uint8_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_u8(p)); +} +template +HWY_API VFromD LoadU(D /* tag */, const int8_t* HWY_RESTRICT p) { + return VFromD(vld1_dup_s8(p)); +} + +// ------------------------------ Load misc + +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, LoadU(du, detail::U16LanePointer(p))); +} + +// On Arm, Load is the same as LoadU. +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT aligned) { + return IfThenElse(m, Load(d, aligned), v); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. 
+template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); +} + +// ------------------------------ Store 128 + +template +HWY_API void StoreU(Vec128 v, D /* tag */, + uint8_t* HWY_RESTRICT unaligned) { + vst1q_u8(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + uint16_t* HWY_RESTRICT unaligned) { + vst1q_u16(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + uint32_t* HWY_RESTRICT unaligned) { + vst1q_u32(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + uint64_t* HWY_RESTRICT unaligned) { + vst1q_u64(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + int8_t* HWY_RESTRICT unaligned) { + vst1q_s8(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + int16_t* HWY_RESTRICT unaligned) { + vst1q_s16(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + int32_t* HWY_RESTRICT unaligned) { + vst1q_s32(unaligned, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, + int64_t* HWY_RESTRICT unaligned) { + vst1q_s64(unaligned, v.raw); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec128 v, D /* tag */, + float16_t* HWY_RESTRICT unaligned) { + vst1q_f16(detail::NativeLanePointer(unaligned), v.raw); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(Vec128 v, D /* tag */, + bfloat16_t* HWY_RESTRICT unaligned) { + vst1q_bf16(detail::NativeLanePointer(unaligned), v.raw); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(Vec128 v, D /* tag */, + float* HWY_RESTRICT unaligned) { + vst1q_f32(unaligned, v.raw); +} +#if HWY_HAVE_FLOAT64 +template +HWY_API void StoreU(Vec128 v, D /* tag */, + double* HWY_RESTRICT unaligned) { + vst1q_f64(unaligned, v.raw); +} +#endif // HWY_HAVE_FLOAT64 + +// ------------------------------ Store 64 + +template +HWY_API void StoreU(Vec64 v, D /* tag */, uint8_t* HWY_RESTRICT p) { + 
vst1_u8(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, uint16_t* HWY_RESTRICT p) { + vst1_u16(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, uint32_t* HWY_RESTRICT p) { + vst1_u32(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, uint64_t* HWY_RESTRICT p) { + vst1_u64(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, int8_t* HWY_RESTRICT p) { + vst1_s8(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, int16_t* HWY_RESTRICT p) { + vst1_s16(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, int32_t* HWY_RESTRICT p) { + vst1_s32(p, v.raw); +} +template +HWY_API void StoreU(Vec64 v, D /* tag */, int64_t* HWY_RESTRICT p) { + vst1_s64(p, v.raw); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec64 v, D /* tag */, + float16_t* HWY_RESTRICT p) { + vst1_f16(detail::NativeLanePointer(p), v.raw); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(Vec64 v, D /* tag */, + bfloat16_t* HWY_RESTRICT p) { + vst1_bf16(detail::NativeLanePointer(p), v.raw); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(Vec64 v, D /* tag */, float* HWY_RESTRICT p) { + vst1_f32(p, v.raw); +} +#if HWY_HAVE_FLOAT64 +template +HWY_API void StoreU(Vec64 v, D /* tag */, double* HWY_RESTRICT p) { + vst1_f64(p, v.raw); +} +#endif // HWY_HAVE_FLOAT64 + +// ------------------------------ Store 32 + +template +HWY_API void StoreU(Vec32 v, D, uint32_t* HWY_RESTRICT p) { + vst1_lane_u32(p, v.raw, 0); +} +template +HWY_API void StoreU(Vec32 v, D, int32_t* HWY_RESTRICT p) { + vst1_lane_s32(p, v.raw, 0); +} +template +HWY_API void StoreU(Vec32 v, D, float* HWY_RESTRICT p) { + vst1_lane_f32(p, v.raw, 0); +} + +// {u,i}{8,16} +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + Repartition d32; + uint32_t buf = GetLane(BitCast(d32, v)); + CopyBytes<4>(&buf, p); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(VFromD 
v, D d, TFromD* HWY_RESTRICT p) { + Repartition d32; + uint32_t buf = GetLane(BitCast(d32, v)); + CopyBytes<4>(&buf, p); +} +#endif +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + Repartition d32; + uint32_t buf = GetLane(BitCast(d32, v)); + CopyBytes<4>(&buf, p); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + +// ------------------------------ Store 16 + +template +HWY_API void StoreU(Vec16 v, D, uint16_t* HWY_RESTRICT p) { + vst1_lane_u16(p, v.raw, 0); +} +template +HWY_API void StoreU(Vec16 v, D, int16_t* HWY_RESTRICT p) { + vst1_lane_s16(p, v.raw, 0); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec16 v, D, float16_t* HWY_RESTRICT p) { + vst1_lane_f16(detail::NativeLanePointer(p), v.raw, 0); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API void StoreU(Vec16 v, D, bfloat16_t* HWY_RESTRICT p) { + vst1_lane_bf16(detail::NativeLanePointer(p), v.raw, 0); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + const Repartition d16; + const uint16_t buf = GetLane(BitCast(d16, v)); + CopyBytes<2>(&buf, p); +} + +// ------------------------------ Store 8 + +template +HWY_API void StoreU(Vec128 v, D, uint8_t* HWY_RESTRICT p) { + vst1_lane_u8(p, v.raw, 0); +} +template +HWY_API void StoreU(Vec128 v, D, int8_t* HWY_RESTRICT p) { + vst1_lane_s8(p, v.raw, 0); +} + +// ------------------------------ Store misc + +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return StoreU(BitCast(du, v), du, detail::U16LanePointer(p)); +} + +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4701, ignored "-Wmaybe-uninitialized") +#endif + +// On Arm, Store is the same as StoreU. 
+template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + StoreU(v, d, aligned); +} + +HWY_DIAGNOSTICS(pop) + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + // Treat as unsigned so that we correctly support float16. + const RebindToUnsigned du; + const auto blended = + IfThenElse(RebindMask(du, m), BitCast(du, v), BitCast(du, LoadU(d, p))); + StoreU(BitCast(d, blended), d, p); +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. + +template +HWY_API void Stream(const VFromD v, D d, TFromD* HWY_RESTRICT aligned) { +#if HWY_ARCH_ARM_A64 +#if HWY_COMPILER_GCC + __builtin_prefetch(aligned, 1, 0); +#elif HWY_COMPILER_MSVC + __prefetch2(aligned, 0x11); +#endif +#endif + Store(v, d, aligned); +} + +// ================================================== CONVERT + +// ------------------------------ ConvertTo + +#if HWY_ARCH_ARM_A64 && HWY_HAVE_FLOAT16 + +// TODO(janwas): use macro generator instead of handwritten +template +HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { + return Vec128(vcvtq_f16_s16(v.raw)); +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD(vcvt_f16_s16(v.raw)); +} + +template +HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { + return Vec128(vcvtq_f16_u16(v.raw)); +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD(vcvt_f16_u16(v.raw)); +} + +#endif // HWY_ARCH_ARM_A64 && HWY_HAVE_FLOAT16 + +template +HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { + return Vec128(vcvtq_f32_s32(v.raw)); +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD(vcvt_f32_s32(v.raw)); +} + +template +HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { + return Vec128(vcvtq_f32_u32(v.raw)); +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD(vcvt_f32_u32(v.raw)); +} + +#if HWY_HAVE_FLOAT64 + +template +HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { + 
return Vec128(vcvtq_f64_s64(v.raw)); +} +template +HWY_API Vec64 ConvertTo(D /* tag */, Vec64 v) { +// GCC 6.5 and earlier are missing the 64-bit (non-q) intrinsic. +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + return Set(Full64(), static_cast(GetLane(v))); +#else + return Vec64(vcvt_f64_s64(v.raw)); +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 +} + +template +HWY_API Vec128 ConvertTo(D /* tag */, Vec128 v) { + return Vec128(vcvtq_f64_u64(v.raw)); +} +template +HWY_API Vec64 ConvertTo(D /* tag */, Vec64 v) { + // GCC 6.5 and earlier are missing the 64-bit (non-q) intrinsic. +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + return Set(Full64(), static_cast(GetLane(v))); +#else + return Vec64(vcvt_f64_u64(v.raw)); +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 +} + +#endif // HWY_HAVE_FLOAT64 + +namespace detail { +// Truncates (rounds toward zero). +template +HWY_INLINE Vec128 ConvertFToI(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && \ + ((HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200) || HWY_ARCH_ARM_V7) + // If compiling for AArch64 NEON with Clang 11 or earlier or if compiling for + // Armv7 NEON, use inline assembly to avoid undefined behavior if v[i] is + // outside of the range of an int32_t. + + int32x4_t raw_result; + __asm__( +#if HWY_ARCH_ARM_A64 + "fcvtzs %0.4s, %1.4s" +#else + "vcvt.s32.f32 %0, %1" +#endif + : "=w"(raw_result) + : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_s32_f32(v.raw)); +#endif +} +template +HWY_INLINE VFromD ConvertFToI(D /* tag */, VFromD> v) { +#if HWY_COMPILER_CLANG && \ + ((HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200) || HWY_ARCH_ARM_V7) + // If compiling for AArch64 NEON with Clang 11 or earlier or if compiling for + // Armv7 NEON, use inline assembly to avoid undefined behavior if v[i] is + // outside of the range of an int32_t. 
+ + int32x2_t raw_result; + __asm__( +#if HWY_ARCH_ARM_A64 + "fcvtzs %0.2s, %1.2s" +#else + "vcvt.s32.f32 %0, %1" +#endif + : "=w"(raw_result) + : "w"(v.raw)); + return VFromD(raw_result); +#else + return VFromD(vcvt_s32_f32(v.raw)); +#endif +} +template +HWY_INLINE Vec128 ConvertFToU(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && \ + ((HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200) || HWY_ARCH_ARM_V7) + // If compiling for AArch64 NEON with Clang 11 or earlier or if compiling for + // Armv7 NEON, use inline assembly to avoid undefined behavior if v[i] is + // outside of the range of an uint32_t. + + uint32x4_t raw_result; + __asm__( +#if HWY_ARCH_ARM_A64 + "fcvtzu %0.4s, %1.4s" +#else + "vcvt.u32.f32 %0, %1" +#endif + : "=w"(raw_result) + : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_u32_f32(v.raw)); +#endif +} +template +HWY_INLINE VFromD ConvertFToU(D /* tag */, VFromD> v) { +#if HWY_COMPILER_CLANG && \ + ((HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200) || HWY_ARCH_ARM_V7) + // If compiling for AArch64 NEON with Clang 11 or earlier or if compiling for + // Armv7 NEON, use inline assembly to avoid undefined behavior if v[i] is + // outside of the range of an uint32_t. + + uint32x2_t raw_result; + __asm__( +#if HWY_ARCH_ARM_A64 + "fcvtzu %0.2s, %1.2s" +#else + "vcvt.u32.f32 %0, %1" +#endif + : "=w"(raw_result) + : "w"(v.raw)); + return VFromD(raw_result); +#else + return VFromD(vcvt_u32_f32(v.raw)); +#endif +} + +#if HWY_HAVE_FLOAT64 + +// Truncates (rounds toward zero). +template +HWY_INLINE Vec128 ConvertFToI(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an int64_t. 
+ int64x2_t raw_result; + __asm__("fcvtzs %0.2d, %1.2d" : "=w"(raw_result) : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_s64_f64(v.raw)); +#endif +} +template +HWY_INLINE Vec64 ConvertFToI(D /* tag */, Vec64 v) { +#if HWY_ARCH_ARM_A64 && \ + ((HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200)) + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an int64_t. + // If compiling for AArch64 NEON with GCC 6 or earlier, use inline assembly to + // work around the missing vcvt_s64_f64 intrinsic. + int64x1_t raw_result; + __asm__("fcvtzs %d0, %d1" : "=w"(raw_result) : "w"(v.raw)); + return Vec64(raw_result); +#else + return Vec64(vcvt_s64_f64(v.raw)); +#endif +} +template +HWY_INLINE Vec128 ConvertFToU(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && HWY_ARCH_ARM_A64 && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an uint64_t. + uint64x2_t raw_result; + __asm__("fcvtzu %0.2d, %1.2d" : "=w"(raw_result) : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_u64_f64(v.raw)); +#endif +} +template +HWY_INLINE Vec64 ConvertFToU(D /* tag */, Vec64 v) { +#if HWY_ARCH_ARM_A64 && \ + ((HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200)) + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an uint64_t. + + // Inline assembly is also used if compiling for AArch64 NEON with GCC 6 or + // earlier to work around the issue of the missing vcvt_u64_f64 intrinsic. 
+ uint64x1_t raw_result; + __asm__("fcvtzu %d0, %d1" : "=w"(raw_result) : "w"(v.raw)); + return Vec64(raw_result); +#else + return Vec64(vcvt_u64_f64(v.raw)); +#endif +} + +#endif // HWY_HAVE_FLOAT64 + +#if HWY_ARCH_ARM_A64 && HWY_HAVE_FLOAT16 + +// Truncates (rounds toward zero). +template +HWY_INLINE Vec128 ConvertFToI(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an int16_t. + int16x8_t raw_result; + __asm__("fcvtzs %0.8h, %1.8h" : "=w"(raw_result) : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_s16_f16(v.raw)); +#endif +} +template +HWY_INLINE VFromD ConvertFToI(D /* tag */, VFromD> v) { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an int16_t. + int16x4_t raw_result; + __asm__("fcvtzs %0.4h, %1.4h" : "=w"(raw_result) : "w"(v.raw)); + return VFromD(raw_result); +#else + return VFromD(vcvt_s16_f16(v.raw)); +#endif +} + +template +HWY_INLINE Vec128 ConvertFToU(D /* tag */, Vec128 v) { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an uint16_t. + uint16x8_t raw_result; + __asm__("fcvtzu %0.8h, %1.8h" : "=w"(raw_result) : "w"(v.raw)); + return Vec128(raw_result); +#else + return Vec128(vcvtq_u16_f16(v.raw)); +#endif +} +template +HWY_INLINE VFromD ConvertFToU(D /* tag */, VFromD> v) { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 + // If compiling for AArch64 NEON with Clang 11 or earlier, use inline assembly + // to avoid undefined behavior if v[i] is outside of the range of an uint16_t. 
+ uint16x4_t raw_result; + __asm__("fcvtzu %0.4h, %1.4h" : "=w"(raw_result) : "w"(v.raw)); + return VFromD(raw_result); +#else + return VFromD(vcvt_u16_f16(v.raw)); +#endif +} + +#endif // HWY_ARCH_ARM_A64 && HWY_HAVE_FLOAT16 +} // namespace detail + +template +HWY_API VFromD ConvertTo(D di, VFromD> v) { + return detail::ConvertFToI(di, v); +} + +template +HWY_API VFromD ConvertTo(D du, VFromD> v) { + return detail::ConvertFToU(du, v); +} + +// ------------------------------ PromoteTo (ConvertTo) + +// Unsigned: zero-extend to full vector. +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vmovl_u8(v.raw)); +} +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec32 v) { + uint16x8_t a = vmovl_u8(v.raw); + return Vec128(vmovl_u16(vget_low_u16(a))); +} +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vmovl_u16(v.raw)); +} +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vmovl_u32(v.raw)); +} +template +HWY_API Vec128 PromoteTo(D d, Vec64 v) { + return BitCast(d, Vec128(vmovl_u8(v.raw))); +} +template +HWY_API Vec128 PromoteTo(D d, Vec32 v) { + uint16x8_t a = vmovl_u8(v.raw); + return BitCast(d, Vec128(vmovl_u16(vget_low_u16(a)))); +} +template +HWY_API Vec128 PromoteTo(D d, Vec64 v) { + return BitCast(d, Vec128(vmovl_u16(v.raw))); +} +template +HWY_API Vec128 PromoteTo(D d, Vec64 v) { + return BitCast(d, Vec128(vmovl_u32(v.raw))); +} + +// Unsigned: zero-extend to half vector. 
+template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_u16(vmovl_u8(v.raw))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_u32(vmovl_u16(vget_low_u16(vmovl_u8(v.raw))))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_u32(vmovl_u16(v.raw))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_u64(vmovl_u32(v.raw))); +} +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + using VU16 = VFromD>; + return BitCast(d, VU16(vget_low_u16(vmovl_u8(v.raw)))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + const uint32x4_t u32 = vmovl_u16(vget_low_u16(vmovl_u8(v.raw))); + return VFromD(vget_low_s32(vreinterpretq_s32_u32(u32))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_s32(vreinterpretq_s32_u32(vmovl_u16(v.raw)))); +} +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + using DU = RebindToUnsigned; + return BitCast(d, VFromD(vget_low_u64(vmovl_u32(v.raw)))); +} + +// U8/U16 to U64/I64: First, zero-extend to U32, and then zero-extend to +// TFromD +template +HWY_API VFromD PromoteTo(D d, V v) { + const Rebind du32; + return PromoteTo(d, PromoteTo(du32, v)); +} + +// Signed: replicate sign bit to full vector. +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vmovl_s8(v.raw)); +} +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec32 v) { + int16x8_t a = vmovl_s8(v.raw); + return Vec128(vmovl_s16(vget_low_s16(a))); +} +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vmovl_s16(v.raw)); +} +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vmovl_s32(v.raw)); +} + +// Signed: replicate sign bit to half vector. 
+template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_s16(vmovl_s8(v.raw))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_s32(vmovl_s16(vget_low_s16(vmovl_s8(v.raw))))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_s32(vmovl_s16(v.raw))); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_s64(vmovl_s32(v.raw))); +} + +// I8/I16 to I64: First, promote to I32, and then promote to I64 +template +HWY_API VFromD PromoteTo(D d, V v) { + const Rebind di32; + return PromoteTo(d, PromoteTo(di32, v)); +} + +#if HWY_NEON_HAVE_F16C + +// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. +#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vcvt_f32_f16(v.raw)); +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD(vget_low_f32(vcvt_f32_f16(v.raw))); +} + +#endif // HWY_NEON_HAVE_F16C + +#if HWY_HAVE_FLOAT64 + +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + return Vec128(vcvt_f64_f32(v.raw)); +} + +template +HWY_API Vec64 PromoteTo(D /* tag */, Vec32 v) { + return Vec64(vget_low_f64(vcvt_f64_f32(v.raw))); +} + +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + const int64x2_t i64 = vmovl_s32(v.raw); + return Vec128(vcvtq_f64_s64(i64)); +} + +template +HWY_API Vec64 PromoteTo(D d, Vec32 v) { + return ConvertTo(d, Vec64(vget_low_s64(vmovl_s32(v.raw)))); +} + +template +HWY_API Vec128 PromoteTo(D /* tag */, Vec64 v) { + const uint64x2_t u64 = vmovl_u32(v.raw); + return Vec128(vcvtq_f64_u64(u64)); +} + +template +HWY_API Vec64 PromoteTo(D d, Vec32 v) { + return ConvertTo(d, Vec64(vget_low_u64(vmovl_u32(v.raw)))); +} + +template +HWY_API VFromD PromoteTo(D d64, VFromD> v) { + const RebindToFloat df64; + return ConvertTo(d64, 
PromoteTo(df64, v)); +} + +#else // !HWY_HAVE_FLOAT64 + +template +HWY_API VFromD PromoteTo(D di64, VFromD> v) { + const Rebind di32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + const auto exponent_adj = BitCast( + du32, + Min(SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, uint32_t{157}))), + BitCast(du32_as_du8, Set(du32, uint32_t{32})))); + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + + const auto f32_to_i32_result = ConvertTo(di32, adj_v); + const auto lo64_or_mask = PromoteTo( + di64, + BitCast(du32, VecFromMask(di32, Eq(f32_to_i32_result, + Set(di32, LimitsMax()))))); + + return Or(PromoteTo(di64, BitCast(di32, f32_to_i32_result)) + << PromoteTo(di64, exponent_adj), + lo64_or_mask); +} + +template +HWY_API VFromD PromoteTo(D du64, VFromD> v) { + const Rebind du32; + const RebindToFloat df32; + const Repartition du32_as_du8; + + const auto exponent_adj = BitCast( + du32, + Min(SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, uint32_t{158}))), + BitCast(du32_as_du8, Set(du32, uint32_t{32})))); + + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + const auto f32_to_u32_result = ConvertTo(du32, adj_v); + const auto lo32_or_mask = PromoteTo( + du64, + VecFromMask(du32, f32_to_u32_result == Set(du32, LimitsMax()))); + + return Or(PromoteTo(du64, f32_to_u32_result) << PromoteTo(du64, exponent_adj), + lo32_or_mask); +} + +#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD PromoteInRangeTo(D d64, VFromD> v) { + const Rebind>, decltype(d64)> d32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + constexpr uint32_t kExpAdjDecr = + 0xFFFFFF9Du + 
static_cast(!IsSigned>()); + + const auto exponent_adj = BitCast( + du32, SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, kExpAdjDecr)))); + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + + return PromoteTo(d64, ConvertTo(d32, adj_v)) << PromoteTo(d64, exponent_adj); +} + +#endif // HWY_HAVE_FLOAT64 + +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "hwy/ops/inside-inl.h" + +// ------------------------------ PromoteUpperTo + +#if HWY_ARCH_ARM_A64 + +// Per-target flag to prevent generic_ops-inl.h from defining PromoteUpperTo. +#ifdef HWY_NATIVE_PROMOTE_UPPER_TO +#undef HWY_NATIVE_PROMOTE_UPPER_TO +#else +#define HWY_NATIVE_PROMOTE_UPPER_TO +#endif + +// Unsigned: zero-extend to full vector. +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vmovl_high_u8(v.raw)); +} +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vmovl_high_u16(v.raw)); +} +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vmovl_high_u32(v.raw)); +} +template +HWY_API Vec128 PromoteUpperTo(D d, Vec128 v) { + return BitCast(d, Vec128(vmovl_high_u8(v.raw))); +} +template +HWY_API Vec128 PromoteUpperTo(D d, Vec128 v) { + return BitCast(d, Vec128(vmovl_high_u16(v.raw))); +} +template +HWY_API Vec128 PromoteUpperTo(D d, Vec128 v) { + return BitCast(d, Vec128(vmovl_high_u32(v.raw))); +} + +// Signed: replicate sign bit to full vector. 
+template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vmovl_high_s8(v.raw)); +} +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vmovl_high_s16(v.raw)); +} +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vmovl_high_s32(v.raw)); +} + +#if HWY_NEON_HAVE_F16C + +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vcvt_high_f32_f16(v.raw)); +} + +#endif // HWY_NEON_HAVE_F16C + +template +HWY_API VFromD PromoteUpperTo(D df32, VFromD> v) { + const Repartition du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteUpperTo(di32, BitCast(du16, v)))); +} + +#if HWY_HAVE_FLOAT64 + +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + return Vec128(vcvt_high_f64_f32(v.raw)); +} + +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + const int64x2_t i64 = vmovl_high_s32(v.raw); + return Vec128(vcvtq_f64_s64(i64)); +} + +template +HWY_API Vec128 PromoteUpperTo(D /* tag */, Vec128 v) { + const uint64x2_t u64 = vmovl_high_u32(v.raw); + return Vec128(vcvtq_f64_u64(u64)); +} + +#endif // HWY_HAVE_FLOAT64 + +template +HWY_API VFromD PromoteUpperTo(D d64, Vec128 v) { +#if HWY_HAVE_FLOAT64 + const RebindToFloat df64; + return ConvertTo(d64, PromoteUpperTo(df64, v)); +#else + const Rebind dh; + return PromoteTo(d, UpperHalf(dh, v)); +#endif +} + +// Generic version for <=64 bit input/output (_high is only for full vectors). 
+template +HWY_API VFromD PromoteUpperTo(D d, V v) { + const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +#endif // HWY_ARCH_ARM_A64 + +// ------------------------------ DemoteTo (ConvertTo) + +// From full vector to half or quarter +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovun_s32(v.raw)); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovn_s32(v.raw)); +} +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec128 v) { + const uint16x4_t a = vqmovun_s32(v.raw); + return Vec32(vqmovn_u16(vcombine_u16(a, a))); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovun_s16(v.raw)); +} +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec128 v) { + const int16x4_t a = vqmovn_s32(v.raw); + return Vec32(vqmovn_s16(vcombine_s16(a, a))); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovn_s16(v.raw)); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovn_u32(v.raw)); +} +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec128 v) { + const uint16x4_t a = vqmovn_u32(v.raw); + return Vec32(vqmovn_u16(vcombine_u16(a, a))); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovn_u16(v.raw)); +} + +// From half vector to partial half +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vqmovun_s32(vcombine_s32(v.raw, v.raw))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vqmovn_s32(vcombine_s32(v.raw, v.raw))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const uint16x4_t a = vqmovun_s32(vcombine_s32(v.raw, v.raw)); + return VFromD(vqmovn_u16(vcombine_u16(a, a))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vqmovun_s16(vcombine_s16(v.raw, v.raw))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const int16x4_t a = vqmovn_s32(vcombine_s32(v.raw, 
v.raw)); + return VFromD(vqmovn_s16(vcombine_s16(a, a))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vqmovn_s16(vcombine_s16(v.raw, v.raw))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vqmovn_u32(vcombine_u32(v.raw, v.raw))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const uint16x4_t a = vqmovn_u32(vcombine_u32(v.raw, v.raw)); + return VFromD(vqmovn_u16(vcombine_u16(a, a))); +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vqmovn_u16(vcombine_u16(v.raw, v.raw))); +} + +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovn_s64(v.raw)); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovun_s64(v.raw)); +} +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vqmovn_u64(v.raw)); +} +template +HWY_API VFromD DemoteTo(D d, Vec128 v) { + const Rebind di32; + return DemoteTo(d, DemoteTo(di32, v)); +} +template +HWY_API VFromD DemoteTo(D d, Vec128 v) { + const Rebind du32; + return DemoteTo(d, DemoteTo(du32, v)); +} +template +HWY_API VFromD DemoteTo(D d, Vec128 v) { + const Rebind du32; + return DemoteTo(d, DemoteTo(du32, v)); +} + +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { + return Vec32(vqmovn_s64(vcombine_s64(v.raw, v.raw))); +} +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { + return Vec32(vqmovun_s64(vcombine_s64(v.raw, v.raw))); +} +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { + return Vec32(vqmovn_u64(vcombine_u64(v.raw, v.raw))); +} +template +HWY_API VFromD DemoteTo(D d, Vec64 v) { + const Rebind di32; + return DemoteTo(d, DemoteTo(di32, v)); +} +template +HWY_API VFromD DemoteTo(D d, Vec64 v) { + const Rebind du32; + return DemoteTo(d, DemoteTo(du32, v)); +} +template +HWY_API VFromD DemoteTo(D d, Vec64 v) { + const Rebind du32; + return DemoteTo(d, DemoteTo(du32, v)); +} + +#if HWY_NEON_HAVE_F16C + +// We already toggled 
HWY_NATIVE_F16C above. + +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64{vcvt_f16_f32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD(vcvt_f16_f32(vcombine_f32(v.raw, v.raw))); +} + +#endif // HWY_NEON_HAVE_F16C + +#if HWY_NEON_HAVE_F32_TO_BF16C +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +namespace detail { +#if HWY_NEON_HAVE_BFLOAT16 +// If HWY_NEON_HAVE_BFLOAT16 is true, detail::Vec128::type is +// bfloat16x4_t or bfloat16x8_t. +static HWY_INLINE bfloat16x4_t BitCastFromRawNeonBF16(bfloat16x4_t raw) { + return raw; +} +#else +// If HWY_NEON_HAVE_F32_TO_BF16C && !HWY_NEON_HAVE_BFLOAT16 is true, +// detail::Vec128::type is uint16x4_t or uint16x8_t vector to +// work around compiler bugs that are there with GCC 13 or earlier or Clang 16 +// or earlier on AArch64. + +// The bfloat16x4_t vector returned by vcvt_bf16_f32 needs to be bitcasted to +// an uint16x4_t vector if HWY_NEON_HAVE_F32_TO_BF16C && +// !HWY_NEON_HAVE_BFLOAT16 is true. 
+static HWY_INLINE uint16x4_t BitCastFromRawNeonBF16(bfloat16x4_t raw) { + return vreinterpret_u16_bf16(raw); +} +#endif +} // namespace detail + +template +HWY_API VFromD DemoteTo(D /*dbf16*/, VFromD> v) { + return VFromD(detail::BitCastFromRawNeonBF16(vcvt_bf16_f32(v.raw))); +} +template +HWY_API VFromD DemoteTo(D /*dbf16*/, VFromD> v) { + return VFromD(detail::BitCastFromRawNeonBF16( + vcvt_bf16_f32(vcombine_f32(v.raw, v.raw)))); +} +#endif // HWY_NEON_HAVE_F32_TO_BF16C + +#if HWY_HAVE_FLOAT64 + +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec128 v) { + return Vec64(vcvt_f32_f64(v.raw)); +} +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { + return Vec32(vcvt_f32_f64(vcombine_f64(v.raw, v.raw))); +} + +template +HWY_API VFromD DemoteTo(D d32, VFromD> v) { + const Rebind>, D> d64; + return DemoteTo(d32, ConvertTo(d64, v)); +} + +#endif // HWY_HAVE_FLOAT64 + +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { + const Rebind di64; + const RebindToUnsigned du64; + +#if HWY_ARCH_ARM_A64 + const RebindToFloat df64; + + const auto k2p64_63 = Set(df64, 27670116110564327424.0); + const auto f64_hi52 = + Xor(BitCast(df64, ShiftRight<12>(BitCast(du64, v))), k2p64_63) - k2p64_63; + const auto f64_lo12 = + ConvertTo(df64, And(BitCast(du64, v), Set(du64, uint64_t{0x00000FFF}))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + + const auto f64_sum_is_inexact = + ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); + const auto f64_bits_decrement = + And(ShiftRight<63>(BitCast(du64, Xor(f64_sum, f64_carry))), + f64_sum_is_inexact); + + const auto adj_f64_val = BitCast( + df64, + Or(BitCast(du64, f64_sum) - f64_bits_decrement, f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +#else + const RebindToUnsigned du32; + const auto hi23 = TruncateTo(du32, ShiftRight<41>(BitCast(du64, v))); + const auto mid23 = And(TruncateTo(du32, ShiftRight<18>(BitCast(du64, v))), + Set(du32, 
uint32_t{0x007FFFFFu})); + const auto lo18 = + And(TruncateTo(du32, BitCast(du64, v)), Set(du32, uint32_t{0x0003FFFFu})); + + const auto k2p41_f32 = Set(df32, 2199023255552.0f); + const auto k2p64_63_f32 = Set(df32, 27670116110564327424.0f); + + const auto hi23_f32 = + BitCast(df32, Xor(hi23, BitCast(du32, k2p64_63_f32))) - k2p64_63_f32; + const auto mid23_f32 = + BitCast(df32, Or(mid23, BitCast(du32, k2p41_f32))) - k2p41_f32; + const auto lo18_f32 = ConvertTo(df32, lo18); + + const auto s_hi46 = hi23_f32 + mid23_f32; + const auto c_hi46 = (hi23_f32 - s_hi46) + mid23_f32; + + auto s_lo = c_hi46 + lo18_f32; + const auto c_lo = (c_hi46 - s_lo) + lo18_f32; + + const auto s_lo_inexact_mask = + VecFromMask(du32, RebindMask(du32, c_lo != Zero(df32))); + const auto s_lo_mag_adj = ShiftRight<31>( + And(s_lo_inexact_mask, Xor(BitCast(du32, s_lo), BitCast(du32, c_lo)))); + + s_lo = BitCast(df32, BitCast(du32, s_lo) - s_lo_mag_adj); + s_lo = + BitCast(df32, Or(BitCast(du32, s_lo), ShiftRight<31>(s_lo_inexact_mask))); + return s_hi46 + s_lo; +#endif +} + +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { +#if HWY_ARCH_ARM_A64 + const Rebind du64; + const RebindToFloat df64; + + const auto k2p64 = Set(df64, 18446744073709551616.0); + const auto f64_hi52 = Or(BitCast(df64, ShiftRight<12>(v)), k2p64) - k2p64; + const auto f64_lo12 = + ConvertTo(df64, And(v, Set(du64, uint64_t{0x00000FFF}))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + const auto f64_sum_is_inexact = + ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); + + const auto adj_f64_val = BitCast( + df64, + Or(BitCast(du64, f64_sum) - ShiftRight<63>(BitCast(du64, f64_carry)), + f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +#else + const RebindToUnsigned du32; + + const auto hi23 = TruncateTo(du32, ShiftRight<41>(v)); + const auto mid23 = And(TruncateTo(du32, ShiftRight<18>(v)), + Set(du32, uint32_t{0x007FFFFFu})); + 
const auto lo18 = And(TruncateTo(du32, v), Set(du32, uint32_t{0x0003FFFFu})); + + const auto k2p41_f32 = Set(df32, 2199023255552.0f); + const auto k2p64_f32 = Set(df32, 18446744073709551616.0f); + + const auto hi23_f32 = + BitCast(df32, Or(hi23, BitCast(du32, k2p64_f32))) - k2p64_f32; + const auto mid23_f32 = + BitCast(df32, Or(mid23, BitCast(du32, k2p41_f32))) - k2p41_f32; + const auto lo18_f32 = ConvertTo(df32, lo18); + + const auto s_hi46 = hi23_f32 + mid23_f32; + const auto c_hi46 = (hi23_f32 - s_hi46) + mid23_f32; + + auto s_lo = c_hi46 + lo18_f32; + const auto c_lo = (c_hi46 - s_lo) + lo18_f32; + + const auto s_lo_inexact_mask = + VecFromMask(du32, RebindMask(du32, c_lo != Zero(df32))); + const auto s_lo_mag_adj = ShiftRight<31>( + And(s_lo_inexact_mask, Xor(BitCast(du32, s_lo), BitCast(du32, c_lo)))); + + s_lo = BitCast(df32, BitCast(du32, s_lo) - s_lo_mag_adj); + s_lo = + BitCast(df32, Or(BitCast(du32, s_lo), ShiftRight<31>(s_lo_inexact_mask))); + return s_hi46 + s_lo; +#endif +} + +HWY_API Vec32 U8FromU32(Vec128 v) { + const uint8x16_t org_v = detail::BitCastToByte(v).raw; + const uint8x16_t w = vuzp1q_u8(org_v, org_v); + return Vec32(vget_low_u8(vuzp1q_u8(w, w))); +} +template +HWY_API Vec128 U8FromU32(Vec128 v) { + const uint8x8_t org_v = detail::BitCastToByte(v).raw; + const uint8x8_t w = vuzp1_u8(org_v, org_v); + return Vec128(vuzp1_u8(w, w)); +} + +// ------------------------------ Round (IfThenElse, mask, logical) + +#if HWY_ARCH_ARM_A64 +// Toward nearest integer +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Round, vrndn, _, 1) + +// Toward zero, aka truncate +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Trunc, vrnd, _, 1) + +// Toward +infinity, aka ceiling +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Ceil, vrndp, _, 1) + +// Toward -infinity, aka floor +HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Floor, vrndm, _, 1) +#else + +// ------------------------------ Trunc + +// Armv7 only supports truncation to integer. 
We can either convert back to +// float (3 floating-point and 2 logic operations) or manipulate the binary32 +// representation, clearing the lowest 23-exp mantissa bits. This requires 9 +// integer operations and 3 constants, which is likely more expensive. + +namespace detail { + +// The original value is already the desired result if NaN or the magnitude is +// large (i.e. the value is already an integer). +template +HWY_INLINE Mask128 UseInt(const Vec128 v) { + return Abs(v) < Set(Simd(), MantissaEnd()); +} + +} // namespace detail + +template +HWY_API Vec128 Trunc(const Vec128 v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), int_f, v); +} + +template +HWY_API Vec128 Round(const Vec128 v) { + const DFromV df; + + // Armv7 also lacks a native NearestInt, but we can instead rely on rounding + // (we assume the current mode is nearest-even) after addition with a large + // value such that no mantissa bits remain. We may need a compiler flag for + // precise floating-point to prevent this from being "optimized" out. + const auto max = Set(df, MantissaEnd()); + const auto large = CopySignToAbs(max, v); + const auto added = large + v; + const auto rounded = added - large; + + // Keep original if NaN or the magnitude is large (already an int). + return IfThenElse(Abs(v) < max, rounded, v); +} + +template +HWY_API Vec128 Ceil(const Vec128 v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. 
+ const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); + + return IfThenElse(detail::UseInt(v), int_f - neg1, v); +} + +template +HWY_API Vec128 Floor(const Vec128 v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v))); + + return IfThenElse(detail::UseInt(v), int_f + neg1, v); +} + +#endif + +// ------------------------------ CeilInt/FloorInt +#if HWY_ARCH_ARM_A64 + +#ifdef HWY_NATIVE_CEIL_FLOOR_INT +#undef HWY_NATIVE_CEIL_FLOOR_INT +#else +#define HWY_NATIVE_CEIL_FLOOR_INT +#endif + +#if HWY_HAVE_FLOAT16 +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtpq_s16_f16(v.raw)); +} + +template +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtp_s16_f16(v.raw)); +} + +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtmq_s16_f16(v.raw)); +} + +template +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtm_s16_f16(v.raw)); +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtpq_s32_f32(v.raw)); +} + +template +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtp_s32_f32(v.raw)); +} + +HWY_API Vec128 CeilInt(const Vec128 v) { + return Vec128(vcvtpq_s64_f64(v.raw)); +} + +template +HWY_API Vec128 CeilInt(const Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 610 + // Workaround for missing vcvtp_s64_f64 intrinsic + const DFromV d; + const RebindToSigned di; + const Twice dt; + return LowerHalf(di, CeilInt(Combine(dt, v, v))); +#else + return Vec128(vcvtp_s64_f64(v.raw)); +#endif +} + +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtmq_s32_f32(v.raw)); +} + +template +HWY_API Vec128 FloorInt(const Vec128 v) { + return Vec128(vcvtm_s32_f32(v.raw)); +} + +HWY_API Vec128 
FloorInt(const Vec128 v) { + return Vec128(vcvtmq_s64_f64(v.raw)); +} + +template +HWY_API Vec128 FloorInt(const Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 610 + // Workaround for missing vcvtm_s64_f64 intrinsic + const DFromV d; + const RebindToSigned di; + const Twice dt; + return LowerHalf(di, FloorInt(Combine(dt, v, v))); +#else + return Vec128(vcvtm_s64_f64(v.raw)); +#endif +} + +#endif // HWY_ARCH_ARM_A64 + +// ------------------------------ NearestInt (Round) + +#if HWY_HAVE_FLOAT16 +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtnq_s16_f16(v.raw)); +} +template +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtn_s16_f16(v.raw)); +} +#endif + +#if HWY_ARCH_ARM_A64 + +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtnq_s32_f32(v.raw)); +} +template +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtn_s32_f32(v.raw)); +} + +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128(vcvtnq_s64_f64(v.raw)); +} + +template +HWY_API Vec128 NearestInt(const Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 610 + // Workaround for missing vcvtn_s64_f64 intrinsic + const DFromV d; + const RebindToSigned di; + const Twice dt; + return LowerHalf(di, NearestInt(Combine(dt, v, v))); +#else + return Vec128(vcvtn_s64_f64(v.raw)); +#endif +} + +template +HWY_API VFromD DemoteToNearestInt(DI32 di32, + VFromD> v) { + return DemoteTo(di32, NearestInt(v)); +} + +#else + +template +HWY_API Vec128 NearestInt(const Vec128 v) { + const RebindToSigned> di; + return ConvertTo(di, Round(v)); +} + +#endif + +// ------------------------------ Floating-point classification + +#if !HWY_COMPILER_CLANG || HWY_COMPILER_CLANG > 1801 || HWY_ARCH_ARM_V7 +template +HWY_API Mask128 IsNaN(const Vec128 v) { + return v != v; +} +#else +// Clang up to 18.1 generates less efficient code than the expected FCMEQ, see +// https://github.com/numpy/numpy/issues/27313 and +// 
https://github.com/numpy/numpy/pull/22954/files and +// https://github.com/llvm/llvm-project/issues/59855 + +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.8h, %1.8h, %1.8h" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.4h, %1.4h, %1.4h" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.4s, %1.4s, %1.4s" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.2s, %1.2s, %1.2s" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} + +#if HWY_HAVE_FLOAT64 +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %0.2d, %1.2d, %1.2d" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { + typename Mask128::Raw ret; + __asm__ volatile("fcmeq %d0, %d1, %d1" : "=w"(ret) : "w"(v.raw)); + return Not(Mask128(ret)); +} +#endif // HWY_HAVE_FLOAT64 + +#endif // HWY_COMPILER_CLANG + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +// <= 64 bit: just return different type +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return Vec128(v.raw); +} + +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_u8(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_u16(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_u32(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_u64(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_s8(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) 
{ + return Vec64(vget_low_s16(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_s32(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_s64(v.raw)); +} +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_f32(v.raw)); +} +#if HWY_HAVE_FLOAT16 +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_f16(v.raw)); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_bf16(v.raw)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +#if HWY_HAVE_FLOAT64 +HWY_API Vec64 LowerHalf(Vec128 v) { + return Vec64(vget_low_f64(v.raw)); +} +#endif // HWY_HAVE_FLOAT64 + +template ), HWY_IF_V_SIZE_V(V, 16)> +HWY_API VFromD>> LowerHalf(V v) { + const Full128 du; + const Half> dh; + return BitCast(dh, LowerHalf(BitCast(du, v))); +} + +template +HWY_API VFromD LowerHalf(DH /* tag */, VFromD> v) { + return LowerHalf(v); +} + +// ------------------------------ CombineShiftRightBytes + +// 128-bit +template > +HWY_API Vec128 CombineShiftRightBytes(D d, Vec128 hi, Vec128 lo) { + static_assert(0 < kBytes && kBytes < 16, "kBytes must be in [1, 15]"); + const Repartition d8; + uint8x16_t v8 = vextq_u8(BitCast(d8, lo).raw, BitCast(d8, hi).raw, kBytes); + return BitCast(d, Vec128(v8)); +} + +// 64-bit +template > +HWY_API Vec64 CombineShiftRightBytes(D d, Vec64 hi, Vec64 lo) { + static_assert(0 < kBytes && kBytes < 8, "kBytes must be in [1, 7]"); + const Repartition d8; + uint8x8_t v8 = vext_u8(BitCast(d8, lo).raw, BitCast(d8, hi).raw, kBytes); + return BitCast(d, VFromD(v8)); +} + +// <= 32-bit defined after ShiftLeftBytes. + +// ------------------------------ Shift vector by constant #bytes + +namespace detail { + +// Partially specialize because kBytes = 0 and >= size are compile errors; +// callers replace the latter with 0xFF for easier specialization. 
+template +struct ShiftLeftBytesT { + // Full + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + const Full128 d; + return CombineShiftRightBytes<16 - kBytes>(d, v, Zero(d)); + } + + // Partial + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + // Expand to 64-bit so we only use the native EXT instruction. + const Full64 d64; + const auto zero64 = Zero(d64); + const decltype(zero64) v64(v.raw); + return Vec128( + CombineShiftRightBytes<8 - kBytes>(d64, v64, zero64).raw); + } +}; +template <> +struct ShiftLeftBytesT<0> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return v; + } +}; +template <> +struct ShiftLeftBytesT<0xFF> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return Xor(v, v); + } +}; + +template +struct ShiftRightBytesT { + template + HWY_INLINE Vec128 operator()(Vec128 v) { + const DFromV d; + // For < 64-bit vectors, zero undefined lanes so we shift in zeros. + if (d.MaxBytes() < 8) { + constexpr size_t kReg = d.MaxBytes() == 16 ? 16 : 8; + const Simd dreg; + v = Vec128( + IfThenElseZero(FirstN(dreg, N), VFromD(v.raw)).raw); + } + return CombineShiftRightBytes(d, Zero(d), v); + } +}; +template <> +struct ShiftRightBytesT<0> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return v; + } +}; +template <> +struct ShiftRightBytesT<0xFF> { + template + HWY_INLINE Vec128 operator()(const Vec128 v) { + return Xor(v, v); + } +}; + +} // namespace detail + +template +HWY_API VFromD ShiftLeftBytes(D d, VFromD v) { + return detail::ShiftLeftBytesT<(kBytes >= d.MaxBytes() ? 
0xFF : kBytes)>()(v); +} + +template +HWY_API Vec128 ShiftLeftBytes(Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +template +HWY_API VFromD ShiftLeftLanes(D d, VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes)>(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template +HWY_API VFromD ShiftRightBytes(D d, VFromD v) { + return detail::ShiftRightBytesT<(kBytes >= d.MaxBytes() ? 0xFF : kBytes)>()( + v); +} + +template +HWY_API VFromD ShiftRightLanes(D d, VFromD v) { + const Repartition d8; + return BitCast( + d, ShiftRightBytes)>(d8, BitCast(d8, v))); +} + +// Calls ShiftLeftBytes +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + constexpr size_t kSize = d.MaxBytes(); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + const Full64 d_full8; + const Repartition, decltype(d_full8)> d_full; + using V64 = VFromD; + const V64 hi64(BitCast(d8, hi).raw); + // Move into most-significant bytes + const V64 lo64 = ShiftLeftBytes<8 - kSize>(V64(BitCast(d8, lo).raw)); + const V64 r = CombineShiftRightBytes<8 - kSize + kBytes>(d_full8, hi64, lo64); + // After casting to full 64-bit vector of correct type, shrink to 32-bit + return VFromD(BitCast(d_full, r).raw); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_u8(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_u16(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_u32(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_u64(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_s8(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D 
/* tag */, Vec128 v) { + return Vec64(vget_high_s16(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_s32(v.raw)); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_s64(v.raw)); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_f16(v.raw)); +} +#endif +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_bf16(v.raw)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_f32(v.raw)); +} +#if HWY_HAVE_FLOAT64 +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64(vget_high_f64(v.raw)); +} +#endif // HWY_HAVE_FLOAT64 + +template +HWY_API VFromD UpperHalf(D dh, VFromD> v) { + const RebindToUnsigned> du; + const Half duh; + return BitCast(dh, UpperHalf(duh, BitCast(du, v))); +} + +// Partial +template +HWY_API VFromD UpperHalf(DH dh, VFromD> v) { + const Twice d; + const RebindToUnsigned du; + const VFromD upper = + ShiftRightBytes(du, BitCast(du, v)); + return VFromD(BitCast(d, upper).raw); +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec128 Broadcast(Vec128 v) { + return v; +} + +#if HWY_ARCH_ARM_A64 +// Unsigned +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 16, "Invalid lane"); + return Vec128(vdupq_laneq_u8(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u8(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_u16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u16(v.raw, kLane)); +} +template 
+HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_u32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_u64(v.raw, kLane)); +} + +// Signed +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 16, "Invalid lane"); + return Vec128(vdupq_laneq_s8(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s8(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_s16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_s32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_s64(v.raw, kLane)); +} + +// Float +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_f16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f16(v.raw, kLane)); +} +#endif // HWY_HAVE_FLOAT16 + +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec128 
Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_laneq_bf16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_bf16(v.raw, kLane)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 + +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_laneq_f32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_laneq_f64(v.raw, kLane)); +} + +#else // !HWY_ARCH_ARM_A64 +// No vdupq_laneq_* on armv7: use vgetq_lane_* + vdupq_n_*. + +// Unsigned +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 16, "Invalid lane"); + return Vec128(vdupq_n_u8(vgetq_lane_u8(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u8(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_u16(vgetq_lane_u16(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_u32(vgetq_lane_u32(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_u32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return 
Vec128(vdupq_n_u64(vgetq_lane_u64(v.raw, kLane))); +} + +// Signed +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 16, "Invalid lane"); + return Vec128(vdupq_n_s8(vgetq_lane_s8(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s8(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_s16(vgetq_lane_s16(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s16(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_s32(vgetq_lane_s32(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_s32(v.raw, kLane)); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec128(vdupq_n_s64(vgetq_lane_s64(v.raw, kLane))); +} + +// Float +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_f16(vgetq_lane_f16(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f16(v.raw, kLane)); +} +#endif // HWY_HAVE_FLOAT16 +#if HWY_NEON_HAVE_BFLOAT16 +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + return Vec128(vdupq_n_bf16(vgetq_lane_bf16(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_bf16(v.raw, kLane)); +} +#endif // HWY_NEON_HAVE_BFLOAT16 +template +HWY_API 
Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec128(vdupq_n_f32(vgetq_lane_f32(v.raw, kLane))); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128(vdup_lane_f32(v.raw, kLane)); +} + +#endif // HWY_ARCH_ARM_A64 + +template ), + HWY_IF_LANES_GT_D(DFromV, 1)> +HWY_API V Broadcast(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Broadcast(BitCast(du, v))); +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices128 { + typename detail::Raw128::type raw; +}; + +namespace detail { + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + return Iota(d8, 0); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + return Zero(d8); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}; + return Load(d8, kByteOffsets); +} + +template 
+HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + return Load(d8, kByteOffsets); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7}; + return Load(d8, kByteOffsets); +} + +} // namespace detail + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + (void)d; + return Indices128, MaxLanes(D())>{BitCast(d, vec).raw}; +} + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + const Repartition d8; + using V8 = VFromD; + + // Broadcast each lane index to all bytes of T and shift to bytes + const V8 lane_indices = TableLookupBytes( + BitCast(d8, vec), detail::IndicesFromVecBroadcastLaneBytes(d)); + constexpr int kIndexShiftAmt = static_cast(FloorLog2(sizeof(T))); + const V8 byte_indices = ShiftLeft(lane_indices); + const V8 sum = Add(byte_indices, detail::IndicesFromVecByteOffsets(d)); + return Indices128, MaxLanes(D())>{BitCast(d, sum).raw}; +} + +template +HWY_API Indices128, MaxLanes(D())> SetTableIndices(D d, + const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + const DFromV d; + const 
RebindToSigned di; + return BitCast( + d, TableLookupBytes(BitCast(di, v), BitCast(di, Vec128{idx.raw}))); +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + const Twice dt; +// TableLookupLanes currently requires table and index vectors to be the same +// size, though a half-length index vector would be sufficient here. +#if HWY_IS_MSAN + const Vec128 idx_vec{idx.raw}; + const Indices128 idx2{Combine(dt, idx_vec, idx_vec).raw}; +#else + // We only keep LowerHalf of the result, which is valid in idx. + const Indices128 idx2{idx.raw}; +#endif + return LowerHalf(d, TableLookupLanes(Combine(dt, b, a), idx2)); +} + +template +HWY_API Vec64 TwoTablesLookupLanes(Vec64 a, Vec64 b, + Indices128 idx) { + const DFromV d; + const Repartition du8; + const auto a_u8 = BitCast(du8, a); + const auto b_u8 = BitCast(du8, b); + const auto idx_u8 = BitCast(du8, Vec64{idx.raw}); + +#if HWY_ARCH_ARM_A64 + const Twice dt_u8; + return BitCast( + d, Vec64{vqtbl1_u8(Combine(dt_u8, b_u8, a_u8).raw, idx_u8.raw)}); +#else + detail::Tuple2 tup = {{{a_u8.raw, b_u8.raw}}}; + return BitCast(d, Vec64{vtbl2_u8(tup.raw, idx_u8.raw)}); +#endif +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + const Repartition du8; + const auto a_u8 = BitCast(du8, a); + const auto b_u8 = BitCast(du8, b); + const auto idx_u8 = BitCast(du8, Vec128{idx.raw}); + +#if HWY_ARCH_ARM_A64 + detail::Tuple2 tup = {{{a_u8.raw, b_u8.raw}}}; + return BitCast(d, Vec128{vqtbl2q_u8(tup.raw, idx_u8.raw)}); +#else + const Half dh; + const Repartition dh_u8; + const auto a_lo_u8 = LowerHalf(dh_u8, a_u8); + const auto a_hi_u8 = UpperHalf(dh_u8, a_u8); + const auto b_lo_u8 = LowerHalf(dh_u8, b_u8); + const auto b_hi_u8 = UpperHalf(dh_u8, b_u8); + const auto idx_lo_u8 = LowerHalf(dh_u8, idx_u8); + const auto idx_hi_u8 = UpperHalf(dh_u8, idx_u8); + + detail::Tuple4 tup = { + {{a_lo_u8.raw, a_hi_u8.raw, b_lo_u8.raw, 
b_hi_u8.raw}}}; + const auto lo_result = + BitCast(dh, Vec64{vtbl4_u8(tup.raw, idx_lo_u8.raw)}); + const auto hi_result = + BitCast(dh, Vec64{vtbl4_u8(tup.raw, idx_hi_u8.raw)}); + return Combine(d, hi_result, lo_result); +#endif +} + +// ------------------------------ Reverse2 (CombineShiftRightBytes) + +// Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. +#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +template +HWY_API VFromD Reverse2(D d, VFromD v) { + const RebindToUnsigned du; + return BitCast(d, VFromD(vrev16_u8(BitCast(du, v).raw))); +} +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 Reverse2(D d, Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev16q_u8(BitCast(du, v).raw))); +} + +template +HWY_API VFromD Reverse2(D d, VFromD v) { + const RebindToUnsigned du; + return BitCast(d, VFromD(vrev32_u16(BitCast(du, v).raw))); +} +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec128 Reverse2(D d, Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev32q_u16(BitCast(du, v).raw))); +} + +template +HWY_API VFromD Reverse2(D d, VFromD v) { + const RebindToUnsigned du; + return BitCast(d, VFromD(vrev64_u32(BitCast(du, v).raw))); +} +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec128 Reverse2(D d, Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64q_u32(BitCast(du, v).raw))); +} + +template +HWY_API VFromD Reverse2(D d, VFromD v) { + return CombineShiftRightBytes<8>(d, v, v); +} + +// ------------------------------ Reverse4 (Reverse2) + +template +HWY_API VFromD Reverse4(D d, VFromD v) { + const RebindToUnsigned du; + return BitCast(d, VFromD(vrev32_u8(BitCast(du, v).raw))); +} +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 Reverse4(D d, Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev32q_u8(BitCast(du, v).raw))); +} + +template +HWY_API VFromD Reverse4(D d, VFromD v) { + const RebindToUnsigned du; + return 
BitCast(d, VFromD(vrev64_u16(BitCast(du, v).raw))); +} +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec128 Reverse4(D d, Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64q_u16(BitCast(du, v).raw))); +} + +template +HWY_API VFromD Reverse4(D d, VFromD v) { + const RepartitionToWide> duw; + return BitCast(d, Reverse2(duw, BitCast(duw, Reverse2(d, v)))); +} + +template +HWY_API VFromD Reverse4(D /* tag */, VFromD) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// ------------------------------ Reverse8 (Reverse2, Reverse4) + +template +HWY_API VFromD Reverse8(D d, VFromD v) { + const RebindToUnsigned du; + return BitCast(d, VFromD(vrev64_u8(BitCast(du, v).raw))); +} +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 Reverse8(D d, Vec128 v) { + const RebindToUnsigned du; + return BitCast(d, Vec128(vrev64q_u8(BitCast(du, v).raw))); +} + +template +HWY_API VFromD Reverse8(D d, VFromD v) { + const Repartition du64; + return BitCast(d, Reverse2(du64, BitCast(du64, Reverse4(d, v)))); +} + +template +HWY_API VFromD Reverse8(D, VFromD) { + HWY_ASSERT(0); // don't have 8 lanes if larger than 16-bit +} + +// ------------------------------ Reverse (Reverse2, Reverse4, Reverse8) + +template , HWY_IF_LANES_D(D, 1)> +HWY_API Vec128 Reverse(D /* tag */, Vec128 v) { + return v; +} + +template , HWY_IF_LANES_D(D, 2)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + return Reverse2(d, v); +} + +template , HWY_IF_LANES_D(D, 4)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + return Reverse4(d, v); +} + +template , HWY_IF_LANES_D(D, 8)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + return Reverse8(d, v); +} + +template , HWY_IF_LANES_D(D, 16)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + const Repartition du64; + return BitCast(d, Reverse2(du64, BitCast(du64, Reverse8(d, v)))); +} + +// ------------------------------ ReverseBits + +#if HWY_ARCH_ARM_A64 + +#ifdef HWY_NATIVE_REVERSE_BITS_UI8 +#undef HWY_NATIVE_REVERSE_BITS_UI8 +#else +#define HWY_NATIVE_REVERSE_BITS_UI8 +#endif + 
+HWY_NEON_DEF_FUNCTION_INT_8(ReverseBits, vrbit, _, 1) +HWY_NEON_DEF_FUNCTION_UINT_8(ReverseBits, vrbit, _, 1) + +#endif // HWY_ARCH_ARM_A64 + +// ------------------------------ Other shuffles (TableLookupBytes) + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle1032(Vec128 v) { + return CombineShiftRightBytes<8>(DFromV(), v, v); +} +template +HWY_API Vec128 Shuffle01(Vec128 v) { + return CombineShiftRightBytes<8>(DFromV(), v, v); +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(Vec128 v) { + return CombineShiftRightBytes<4>(DFromV(), v, v); +} + +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(Vec128 v) { + return CombineShiftRightBytes<12>(DFromV(), v, v); +} + +// Reverse +template +HWY_API Vec128 Shuffle0123(Vec128 v) { + return Reverse4(DFromV(), v); +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). +HWY_NEON_DEF_FUNCTION_UIF_8_16_32(InterleaveLower, vzip1, _, 2) +#if HWY_ARCH_ARM_A64 +// N=1 makes no sense (in that case, there would be no upper/lower). +HWY_NEON_DEF_FUNCTION_FULL_UIF_64(InterleaveLower, vzip1, _, 2) +#else +// Emulated version for Armv7. 
+template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + const DFromV d; + return CombineShiftRightBytes<8>(d, b, Shuffle01(a)); +} +#endif + +#if !HWY_HAVE_FLOAT16 +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, InterleaveLower(BitCast(du, a), BitCast(du, b))); +} +#endif // !HWY_HAVE_FLOAT16 + +// < 64 bit parts +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128(InterleaveLower(Vec64(a.raw), Vec64(b.raw)).raw); +} + +// Additional overload for the optional Simd<> tag. +template +HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. +namespace detail { +HWY_NEON_DEF_FUNCTION_UIF_8_16_32(InterleaveUpper, vzip2, _, 2) + +#if HWY_ARCH_ARM_A64 +// N=1 makes no sense (in that case, there would be no upper/lower). +HWY_NEON_DEF_FUNCTION_FULL_UIF_64(InterleaveUpper, vzip2, _, 2) +#else +// Emulated version for Armv7. +template +HWY_API Vec128 InterleaveUpper(Vec128 a, Vec128 b) { + const DFromV d; + return CombineShiftRightBytes<8>(d, Shuffle01(b), a); +} +#endif +} // namespace detail + +// Full register +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const Half d2; + const VFromD a2(UpperHalf(d2, a).raw); + const VFromD b2(UpperHalf(d2, b).raw); + return InterleaveLower(d, a2, b2); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. 
+template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ------------------------------ Per4LaneBlockShuffle +namespace detail { + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + +#ifdef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#undef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#else +#define HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#endif + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t /*x3*/, + const uint32_t /*x2*/, + const uint32_t x1, + const uint32_t x0) { + typedef uint32_t GccU32RawVectType __attribute__((__vector_size__(8))); + const GccU32RawVectType raw = {x0, x1}; + return ResizeBitCast(d, Vec64(reinterpret_cast(raw))); +} + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { + typedef uint32_t GccU32RawVectType __attribute__((__vector_size__(16))); + const GccU32RawVectType raw = {x0, x1, x2, x3}; + return ResizeBitCast(d, Vec128(reinterpret_cast(raw))); +} +#endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG + +template , 4)> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x88> /*idx_3210_tag*/, + hwy::SizeTag /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + const auto evens = BitCast(dw, ConcatEven(d, v, v)); + return BitCast(d, InterleaveLower(dw, evens, evens)); +} + +template , 4)> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xDD> /*idx_3210_tag*/, + hwy::SizeTag /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + const auto odds = BitCast(dw, 
ConcatOdd(d, v, v)); + return BitCast(d, InterleaveLower(dw, odds, odds)); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xFA> /*idx_3210_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, + hwy::SizeTag<8> /*vect_size_tag*/, V v) { + const DFromV d; + return InterleaveUpper(d, v, v); +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +namespace detail { + +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + using TU = UnsignedFromSize; + const Repartition du; + return BitCast(d, BitCast(du, v) << Set( + du, static_cast(amt * sizeof(TFromV) * 8))); +} + +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Repartition du8; + const auto idx = + Iota(du8, static_cast(size_t{0} - amt * sizeof(TFromV))); + return BitCast(d, TableLookupBytesOr0(BitCast(du8, v), idx)); +} + +} // namespace detail + +template +HWY_API VFromD SlideUpLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + 
return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + case 8: + return ShiftLeftLanes<8>(d, v); + case 9: + return ShiftLeftLanes<9>(d, v); + case 10: + return ShiftLeftLanes<10>(d, v); + case 11: + return ShiftLeftLanes<11>(d, v); + case 12: + return ShiftLeftLanes<12>(d, v); + case 13: + return ShiftLeftLanes<13>(d, v); + case 14: + return ShiftLeftLanes<14>(d, v); + case 15: + return ShiftLeftLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +// ------------------------------ SlideDownLanes + +namespace detail { + +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + using TU = UnsignedFromSize; + const Repartition du; + return BitCast(d, + BitCast(du, v) << Set( + du, static_cast(TU{0} - amt * sizeof(TFromV) * 8))); +} + +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition di8; + auto idx = Iota(di8, static_cast(amt * sizeof(TFromV))); + idx = Or(idx, VecFromMask(di8, idx > Set(di8, int8_t{15}))); + return BitCast(d, TableLookupBytesOr0(BitCast(di8, v), idx)); +} + +} // namespace detail + +template +HWY_API VFromD SlideDownLanes(D /*d*/, VFromD v, size_t /*amt*/) { + 
return v; +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + case 8: + return ShiftRightLanes<8>(d, v); + case 9: + return ShiftRightLanes<9>(d, v); + case 10: 
+ return ShiftRightLanes<10>(d, v); + case 11: + return ShiftRightLanes<11>(d, v); + case 12: + return ShiftRightLanes<12>(d, v); + case 13: + return ShiftRightLanes<13>(d, v); + case 14: + return ShiftRightLanes<14>(d, v); + case 15: + return ShiftRightLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +// ------------------------------- WidenHighMulAdd + +#ifdef HWY_NATIVE_WIDEN_HIGH_MUL_ADD +#undef HWY_NATIVE_WIDEN_HIGH_MUL_ADD +#else +#define HWY_NATIVE_WIDEN_HIGH_MUL_ADD +#endif + +namespace detail { + +template, + HWY_IF_LANES_GT_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_u32(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_u32(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_LE_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulResult = Vec128(vmull_u32(mul.raw, x.raw)); + return UpperHalf(d, mulResult) + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_s32(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_s32(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_LE_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulResult = Vec128(vmull_s32(mul.raw, x.raw)); + return UpperHalf(d, mulResult) + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_s16(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_s16(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + 
HWY_IF_LANES_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_s16(mul.raw, x.raw)); + Vec64 hi = UpperHalf(d, widen); + return hi + add; +} + +template, + HWY_IF_LANES_D(DN, 2)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_s16(mul.raw, x.raw)); + Vec32 hi = UpperHalf(d, Vec64(vget_high_s32(widen.raw))); + return hi + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_u16(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_u16(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_u16(mul.raw, x.raw)); + VFromD hi = UpperHalf(d, widen); + return hi + add; +} + +template> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_u16(mul.raw, x.raw)); + VFromD hi = UpperHalf(d, Vec64(vget_high_u32(widen.raw))); + return hi + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 8)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_u8(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_u8(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_D(DN, 8)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_u8(mul.raw, x.raw)); + VFromD hi = UpperHalf(d, widen); + return hi + add; +} + +template), class DN = RepartitionToNarrow, + HWY_IF_LANES_LE_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_u8(mul.raw, x.raw)); + const Twice d16F; + VFromD 
hi = UpperHalf(d, VFromD(vget_high_u16(widen.raw))); + return hi + add; +} + +template, + HWY_IF_LANES_GT_D(DN, 8)> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { +#if HWY_ARCH_ARM_A64 + return Vec128(vmlal_high_s8(add.raw, mul.raw, x.raw)); +#else + const Full64 dh; + return Vec128( + vmlal_s8(add.raw, UpperHalf(dh, mul).raw, UpperHalf(dh, x).raw)); +#endif +} + +template, + HWY_IF_LANES_D(DN, 8)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_s8(mul.raw, x.raw)); + VFromD hi = UpperHalf(d, widen); + return hi + add; +} + +template, + HWY_IF_LANES_LE_D(DN, 4)> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + Vec128 widen = Vec128(vmull_s8(mul.raw, x.raw)); + const Twice d16F; + VFromD hi = UpperHalf(d, VFromD(vget_high_s16(widen.raw))); + return hi + add; +} + +#if 0 +#if HWY_HAVE_FLOAT16 +template> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return VFromD(vfmlalq_high_f16(add.raw, mul.raw, x.raw)); +} + +template> +HWY_API VFromD WidenHighMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec64(vfmlal_high_f16(add.raw, mul.raw, x.raw)); +} + +template> +HWY_API VFromD WidenHighMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + return MulAdd(add, PromoteUpperTo(d, mul), PromoteUpperTo(d, x)); +} +#endif +#endif + +} // namespace detail + +// ------------------------------- WidenMulAdd + +#ifdef HWY_NATIVE_WIDEN_MUL_ADD +#undef HWY_NATIVE_WIDEN_MUL_ADD +#else +#define HWY_NATIVE_WIDEN_MUL_ADD +#endif + +namespace detail { + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec128(vmlal_u8(add.raw, mul.raw, x.raw)); +} + +template >, D>> +HWY_API VFromD WidenMulAdd(D d, VFromD mul, VFromD x, + VFromD add) { + return MulAdd(add, PromoteTo(d, mul), PromoteTo(d, x)); +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, 
+ VFromD x, VFromD add) { + return VFromD(vmlal_s8(add.raw, mul.raw, x.raw)); +} + +template >, D>> +HWY_API VFromD WidenMulAdd(D d, VFromD mul, VFromD x, + VFromD add) { + return MulAdd(add, PromoteTo(d, mul), PromoteTo(d, x)); +} + +template>, D>, + HWY_IF_LANES_GT_D(DN, 2)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec128(vmlal_s16(add.raw, mul.raw, x.raw)); +} + +template>, D>, + HWY_IF_LANES_D(DN, 2)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulRs = Vec128(vmull_s16(mul.raw, x.raw)); + const VFromD mul10 = LowerHalf(mulRs); + return add + mul10; +} + +template>, D>, + HWY_IF_LANES_D(D, 1)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec64 mulRs = LowerHalf(Vec128(vmull_s16(mul.raw, x.raw))); + const Vec32 mul10(LowerHalf(mulRs)); + return add + mul10; +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec128(vmlal_u16(add.raw, mul.raw, x.raw)); +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulRs = Vec128(vmull_u16(mul.raw, x.raw)); + const Vec64 mul10(LowerHalf(mulRs)); + return add + mul10; +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec64 mulRs = + LowerHalf(Vec128(vmull_u16(mul.raw, x.raw))); + const Vec32 mul10(LowerHalf(mulRs)); + return add + mul10; +} + +template>, D>, + HWY_IF_LANES_D(DN, 2)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return VFromD(vmlal_s32(add.raw, mul.raw, x.raw)); +} + +template>, D>> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulRs = Vec128(vmull_s32(mul.raw, x.raw)); + const VFromD mul10(LowerHalf(mulRs)); + return add + mul10; +} + +template>, D>, + HWY_IF_LANES_D(DN, 2)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD 
add) { + return VFromD(vmlal_u32(add.raw, mul.raw, x.raw)); +} + +template>, D>, + HWY_IF_LANES_D(DN, 1)> +HWY_API VFromD WidenMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + Vec128 mulRs = Vec128(vmull_u32(mul.raw, x.raw)); + const VFromD mul10(LowerHalf(mulRs)); + return add + mul10; +} + +#if 0 +#if HWY_HAVE_FLOAT16 +template, + HWY_IF_LANES_D(D, 4)> +HWY_API VFromD WidenLowMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return VFromD(vfmlalq_low_f16(add.raw, mul.raw, x.raw)); +} + +template, + HWY_IF_LANES_D(DN, 4)> +HWY_API VFromD WidenLowMulAdd(D /* tag */, VFromD mul, + VFromD x, VFromD add) { + return Vec64(vfmlal_low_f16(add.raw, mul.raw, x.raw)); +} + +template> +HWY_API VFromD WidenLowMulAdd(D d, VFromD mul, + VFromD x, VFromD add) { + return MulAdd(add, PromoteLowerTo(d, mul), PromoteLowerTo(d, x)); +} +#endif +#endif + +} // namespace detail + +// ------------------------------ WidenMulAccumulate + +#ifdef HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#undef HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#else +#define HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#endif + +template), class DN = RepartitionToNarrow> +HWY_API VFromD WidenMulAccumulate(D d, VFromD mul, VFromD x, + VFromD low, VFromD& high) { + high = detail::WidenHighMulAdd(d, mul, x, high); + return detail::WidenMulAdd(d, LowerHalf(mul), LowerHalf(x), low); +} + +#if 0 +#ifdef HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#undef HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#else +#define HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#endif + +#if HWY_HAVE_FLOAT16 + +template> +HWY_API VFromD WidenMulAccumulate(D d, VFromD mul, VFromD x, + VFromD low, VFromD& high) { + high = detail::WidenHighMulAdd(d, mul, x, high); + return detail::WidenLowMulAdd(d, mul, x, low); +} + +#endif +#endif + +// ------------------------------ SatWidenMulAccumFixedPoint + +#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#else +#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#endif + +template 
+HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD(vqdmlal_s16(sum.raw, a.raw, b.raw)); +} + +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + const Full128> di32_full; + const Rebind di16_full64; + return ResizeBitCast( + di32, SatWidenMulAccumFixedPoint(di32_full, ResizeBitCast(di16_full64, a), + ResizeBitCast(di16_full64, b), + ResizeBitCast(di32_full, sum))); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +#if HWY_NEON_HAVE_F32_TO_BF16C + +#ifdef HWY_NATIVE_MUL_EVEN_BF16 +#undef HWY_NATIVE_MUL_EVEN_BF16 +#else +#define HWY_NATIVE_MUL_EVEN_BF16 +#endif + +#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#else +#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#endif + +namespace detail { +#if HWY_NEON_HAVE_BFLOAT16 +// If HWY_NEON_HAVE_BFLOAT16 is true, detail::Vec128::type is +// bfloat16x4_t or bfloat16x8_t. +static HWY_INLINE bfloat16x4_t BitCastToRawNeonBF16(bfloat16x4_t raw) { + return raw; +} +static HWY_INLINE bfloat16x8_t BitCastToRawNeonBF16(bfloat16x8_t raw) { + return raw; +} +#else +// If HWY_NEON_HAVE_F32_TO_BF16C && !HWY_NEON_HAVE_BFLOAT16 is true, +// detail::Vec128::type is uint16x4_t or uint16x8_t vector to +// work around compiler bugs that are there with GCC 13 or earlier or Clang 16 +// or earlier on AArch64. + +// The uint16x4_t or uint16x8_t vector neets to be bitcasted to a bfloat16x4_t +// or a bfloat16x8_t vector for the vbfdot_f32 and vbfdotq_f32 intrinsics if +// HWY_NEON_HAVE_F32_TO_BF16C && !HWY_NEON_HAVE_BFLOAT16 is true + +// NOTE: vbfdot uses round to odd unless the additional FEAT_EBF16 feature is +// available and enabled. 
+static HWY_INLINE bfloat16x4_t BitCastToRawNeonBF16(uint16x4_t raw) { + return vreinterpret_bf16_u16(raw); +} +static HWY_INLINE bfloat16x8_t BitCastToRawNeonBF16(uint16x8_t raw) { + return vreinterpretq_bf16_u16(raw); +} +#endif +} // namespace detail + +template +HWY_API Vec128 MulEvenAdd(D /*d32*/, Vec128 a, + Vec128 b, const Vec128 c) { + return Vec128(vbfmlalbq_f32(c.raw, detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +template +HWY_API Vec128 MulOddAdd(D /*d32*/, Vec128 a, + Vec128 b, const Vec128 c) { + return Vec128(vbfmlaltq_f32(c.raw, detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +template +HWY_API Vec128 ReorderWidenMulAccumulate(D /*d32*/, Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& /*sum1*/) { + return Vec128(vbfdotq_f32(sum0.raw, + detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +// There is no non-q version of these instructions. +template +HWY_API VFromD MulEvenAdd(D d32, VFromD> a, + VFromD> b, + const VFromD c) { + const Full128 d32f; + const Full128 d16f; + return ResizeBitCast( + d32, MulEvenAdd(d32f, ResizeBitCast(d16f, a), ResizeBitCast(d16f, b), + ResizeBitCast(d32f, c))); +} + +template +HWY_API VFromD MulOddAdd(D d32, VFromD> a, + VFromD> b, + const VFromD c) { + const Full128 d32f; + const Full128 d16f; + return ResizeBitCast( + d32, MulOddAdd(d32f, ResizeBitCast(d16f, a), ResizeBitCast(d16f, b), + ResizeBitCast(d32f, c))); +} + +template +HWY_API VFromD ReorderWidenMulAccumulate( + D /*d32*/, VFromD> a, + VFromD> b, const VFromD sum0, + VFromD& /*sum1*/) { + return VFromD(vbfdot_f32(sum0.raw, detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +template +HWY_API Vec128 RearrangeToOddPlusEven(Vec128 sum0, + Vec128) { + return sum0; +} + +#endif // HWY_NEON_HAVE_F32_TO_BF16C + +template +HWY_API Vec128 ReorderWidenMulAccumulate(D /*d32*/, Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& 
sum1) { +#if HWY_ARCH_ARM_A64 + sum1 = Vec128(vmlal_high_s16(sum1.raw, a.raw, b.raw)); +#else + const Full64 dh; + sum1 = Vec128( + vmlal_s16(sum1.raw, UpperHalf(dh, a).raw, UpperHalf(dh, b).raw)); +#endif + return Vec128( + vmlal_s16(sum0.raw, LowerHalf(a).raw, LowerHalf(b).raw)); +} + +template +HWY_API Vec64 ReorderWidenMulAccumulate(D d32, Vec64 a, + Vec64 b, + const Vec64 sum0, + Vec64& sum1) { + // vmlal writes into the upper half, which the caller cannot use, so + // split into two halves. + const Vec128 mul_3210(vmull_s16(a.raw, b.raw)); + const Vec64 mul_32 = UpperHalf(d32, mul_3210); + sum1 += mul_32; + return sum0 + LowerHalf(mul_3210); +} + +template +HWY_API Vec32 ReorderWidenMulAccumulate(D d32, Vec32 a, + Vec32 b, + const Vec32 sum0, + Vec32& sum1) { + const Vec128 mul_xx10(vmull_s16(a.raw, b.raw)); + const Vec64 mul_10(LowerHalf(mul_xx10)); + const Vec32 mul0 = LowerHalf(d32, mul_10); + const Vec32 mul1 = UpperHalf(d32, mul_10); + sum1 += mul1; + return sum0 + mul0; +} + +template +HWY_API Vec128 ReorderWidenMulAccumulate(D /*d32*/, + Vec128 a, + Vec128 b, + const Vec128 sum0, + Vec128& sum1) { +#if HWY_ARCH_ARM_A64 + sum1 = Vec128(vmlal_high_u16(sum1.raw, a.raw, b.raw)); +#else + const Full64 dh; + sum1 = Vec128( + vmlal_u16(sum1.raw, UpperHalf(dh, a).raw, UpperHalf(dh, b).raw)); +#endif + return Vec128( + vmlal_u16(sum0.raw, LowerHalf(a).raw, LowerHalf(b).raw)); +} + +template +HWY_API Vec64 ReorderWidenMulAccumulate(D d32, Vec64 a, + Vec64 b, + const Vec64 sum0, + Vec64& sum1) { + // vmlal writes into the upper half, which the caller cannot use, so + // split into two halves. 
+ const Vec128 mul_3210(vmull_u16(a.raw, b.raw)); + const Vec64 mul_32 = UpperHalf(d32, mul_3210); + sum1 += mul_32; + return sum0 + LowerHalf(mul_3210); +} + +template +HWY_API Vec32 ReorderWidenMulAccumulate(D du32, Vec32 a, + Vec32 b, + const Vec32 sum0, + Vec32& sum1) { + const Vec128 mul_xx10(vmull_u16(a.raw, b.raw)); + const Vec64 mul_10(LowerHalf(mul_xx10)); + const Vec32 mul0 = LowerHalf(du32, mul_10); + const Vec32 mul1 = UpperHalf(du32, mul_10); + sum1 += mul1; + return sum0 + mul0; +} + +// ------------------------------ Combine partial (InterleaveLower) +// < 64bit input, <= 64 bit result +template +HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { + // First double N (only lower halves will be used). + const VFromD hi2(hi.raw); + const VFromD lo2(lo.raw); + // Repartition to two unsigned lanes (each the size of the valid input). + const Simd, 2, 0> du; + return BitCast(d, InterleaveLower(BitCast(du, lo2), BitCast(du, hi2))); +} + +// ------------------------------ RearrangeToOddPlusEven (Combine) + +namespace detail { +// Armv7 only provides 64-bit (half-vector) pairwise operations. +#define HWY_NEON_DEF_PAIRWISE_OP(T, name, prefix, suffix) \ + HWY_INLINE Vec64 Pairwise##name(Vec64 a, Vec64 b) { \ + return Vec64(prefix##_##suffix(a.raw, b.raw)); \ + } + +// Note that Armv7 also lacks [u]int64 instructions, which are handled by +// generic_ops-inl.h SumOfLanes etc., hence no 64-bit overloads here. 
+#define HWY_NEON_DEF_PAIRWISE_OPS(name, prefix) \ + HWY_NEON_DEF_PAIRWISE_OP(uint32_t, name, prefix, u32) \ + HWY_NEON_DEF_PAIRWISE_OP(uint16_t, name, prefix, u16) \ + HWY_NEON_DEF_PAIRWISE_OP(uint8_t, name, prefix, u8) \ + HWY_NEON_DEF_PAIRWISE_OP(int32_t, name, prefix, s32) \ + HWY_NEON_DEF_PAIRWISE_OP(int16_t, name, prefix, s16) \ + HWY_NEON_DEF_PAIRWISE_OP(int8_t, name, prefix, s8) \ + HWY_NEON_DEF_PAIRWISE_OP(float32_t, name, prefix, f32) + +HWY_NEON_DEF_PAIRWISE_OPS(Sum, vpadd) +HWY_NEON_DEF_PAIRWISE_OPS(Min, vpmin) +HWY_NEON_DEF_PAIRWISE_OPS(Max, vpmax) +#undef HWY_NEON_DEF_PAIRWISE_OPS +#undef HWY_NEON_DEF_PAIRWISE_OP +} // namespace detail + +HWY_API Vec128 RearrangeToOddPlusEven(Vec128 sum0, + Vec128 sum1) { +// vmlal_s16 multiplied the lower half into sum0 and upper into sum1. +#if HWY_ARCH_ARM_A64 // pairwise sum is available and what we want + return Vec128(vpaddq_s32(sum0.raw, sum1.raw)); +#else + const Full128 d; + const Half d64; + const Vec64 hi = + detail::PairwiseSum(LowerHalf(d64, sum1), UpperHalf(d64, sum1)); + const Vec64 lo( + detail::PairwiseSum(LowerHalf(d64, sum0), UpperHalf(d64, sum0))); + return Combine(d, hi, lo); +#endif +} + +HWY_API Vec64 RearrangeToOddPlusEven(Vec64 sum0, + Vec64 sum1) { + // vmlal_s16 multiplied the lower half into sum0 and upper into sum1. + return detail::PairwiseSum(sum0, sum1); +} + +HWY_API Vec32 RearrangeToOddPlusEven(Vec32 sum0, + Vec32 sum1) { + // Only one widened sum per register, so add them for sum of odd and even. + return sum0 + sum1; +} + +HWY_API Vec128 RearrangeToOddPlusEven(Vec128 sum0, + Vec128 sum1) { +// vmlal_s16 multiplied the lower half into sum0 and upper into sum1. 
+#if HWY_ARCH_ARM_A64 // pairwise sum is available and what we want + return Vec128(vpaddq_u32(sum0.raw, sum1.raw)); +#else + const Full128 d; + const Half d64; + const Vec64 hi = + detail::PairwiseSum(LowerHalf(d64, sum1), UpperHalf(d64, sum1)); + const Vec64 lo = + detail::PairwiseSum(LowerHalf(d64, sum0), UpperHalf(d64, sum0)); + return Combine(d, hi, lo); +#endif +} + +HWY_API Vec64 RearrangeToOddPlusEven(Vec64 sum0, + Vec64 sum1) { + // vmlal_u16 multiplied the lower half into sum0 and upper into sum1. + return detail::PairwiseSum(sum0, sum1); +} + +HWY_API Vec32 RearrangeToOddPlusEven(Vec32 sum0, + Vec32 sum1) { + // Only one widened sum per register, so add them for sum of odd and even. + return sum0 + sum1; +} + +// ------------------------------ SumOfMulQuadAccumulate + + +#if HWY_TARGET == HWY_NEON_BF16 || defined(__ARM_FEATURE_DOTPROD) + +#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD(vdot_s32(sum.raw, a.raw, b.raw)); +} + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD(vdotq_s32(sum.raw, a.raw, b.raw)); +} + +#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD(vdot_u32(sum.raw, a.raw, b.raw)); +} + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD(vdotq_u32(sum.raw, a.raw, b.raw)); +} + +#endif //__ARM_FEATURE_DOTPROD || HWY_TARGET == HWY_NEON_BF16 + +#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#else 
+#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#endif + +#if defined(__ARM_FEATURE_MATMUL_INT8) || \ + (HWY_TARGET == HWY_NEON_BF16 && HWY_OS_APPLE && HWY_ARCH_ARM_A64 && \ + HWY_HAVE_RUNTIME_DISPATCH) + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 /*di32*/, VFromD> a_u, + VFromD> b_i, VFromD sum) { + return VFromD(vusdot_s32(sum.raw, a_u.raw, b_i.raw)); +} + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 /*di32*/, VFromD> a_u, + VFromD> b_i, VFromD sum) { + return VFromD(vusdotq_s32(sum.raw, a_u.raw, b_i.raw)); +} + +#else + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 di32, VFromD> a_u, + VFromD> b_i, VFromD sum) { + const RebindToUnsigned du32; + const Repartition du8; + + const auto b_u = BitCast(du8, b_i); + const auto result_sum0 = + SumOfMulQuadAccumulate(du32, a_u, b_u, BitCast(du32, sum)); + const auto result_sum1 = ShiftLeft<8>( + SumOfMulQuadAccumulate(du32, a_u, ShiftRight<7>(b_u), Zero(du32))); + + return BitCast(di32, Sub(result_sum0, result_sum1)); +} + +#endif // __ARM_FEATURE_MATMUL_INT8 + +// ------------------------------ WidenMulPairwiseAdd + +#if HWY_NEON_HAVE_F32_TO_BF16C + +template +HWY_API Vec128 WidenMulPairwiseAdd(DF df, Vec128 a, + Vec128 b) { + return Vec128(vbfdotq_f32(Zero(df).raw, + detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +template +HWY_API VFromD WidenMulPairwiseAdd(DF df, + VFromD> a, + VFromD> b) { + return VFromD(vbfdot_f32(Zero(df).raw, + detail::BitCastToRawNeonBF16(a.raw), + detail::BitCastToRawNeonBF16(b.raw))); +} + +#else +template +HWY_API VFromD WidenMulPairwiseAdd(DF df, + VFromD> a, + VFromD> b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); +} +#endif // HWY_NEON_HAVE_F32_TO_BF16C + +template +HWY_API Vec128 WidenMulPairwiseAdd(D /*d32*/, Vec128 a, + Vec128 b) { + Vec128 sum1; +#if HWY_ARCH_ARM_A64 + sum1 = Vec128(vmull_high_s16(a.raw, b.raw)); +#else + const Full64 dh; + sum1 
= Vec128(vmull_s16(UpperHalf(dh, a).raw, UpperHalf(dh, b).raw)); +#endif + Vec128 sum0 = + Vec128(vmull_s16(LowerHalf(a).raw, LowerHalf(b).raw)); + return RearrangeToOddPlusEven(sum0, sum1); +} + +template +HWY_API Vec64 WidenMulPairwiseAdd(D d32, Vec64 a, + Vec64 b) { + // vmlal writes into the upper half, which the caller cannot use, so + // split into two halves. + const Vec128 mul_3210(vmull_s16(a.raw, b.raw)); + const Vec64 mul0 = LowerHalf(mul_3210); + const Vec64 mul1 = UpperHalf(d32, mul_3210); + return RearrangeToOddPlusEven(mul0, mul1); +} + +template +HWY_API Vec32 WidenMulPairwiseAdd(D d32, Vec32 a, + Vec32 b) { + const Vec128 mul_xx10(vmull_s16(a.raw, b.raw)); + const Vec64 mul_10(LowerHalf(mul_xx10)); + const Vec32 mul0 = LowerHalf(d32, mul_10); + const Vec32 mul1 = UpperHalf(d32, mul_10); + return RearrangeToOddPlusEven(mul0, mul1); +} + +template +HWY_API Vec128 WidenMulPairwiseAdd(D /*d32*/, Vec128 a, + Vec128 b) { + Vec128 sum1; +#if HWY_ARCH_ARM_A64 + sum1 = Vec128(vmull_high_u16(a.raw, b.raw)); +#else + const Full64 dh; + sum1 = + Vec128(vmull_u16(UpperHalf(dh, a).raw, UpperHalf(dh, b).raw)); +#endif + Vec128 sum0 = + Vec128(vmull_u16(LowerHalf(a).raw, LowerHalf(b).raw)); + return RearrangeToOddPlusEven(sum0, sum1); +} + +template +HWY_API Vec64 WidenMulPairwiseAdd(D d32, Vec64 a, + Vec64 b) { + // vmlal writes into the upper half, which the caller cannot use, so + // split into two halves. 
+ const Vec128 mul_3210(vmull_u16(a.raw, b.raw)); + const Vec64 mul0 = LowerHalf(mul_3210); + const Vec64 mul1 = UpperHalf(d32, mul_3210); + return RearrangeToOddPlusEven(mul0, mul1); +} + +template +HWY_API Vec32 WidenMulPairwiseAdd(D d32, Vec32 a, + Vec32 b) { + const Vec128 mul_xx10(vmull_u16(a.raw, b.raw)); + const Vec64 mul_10(LowerHalf(mul_xx10)); + const Vec32 mul0 = LowerHalf(d32, mul_10); + const Vec32 mul1 = UpperHalf(d32, mul_10); + return RearrangeToOddPlusEven(mul0, mul1); +} + +// ------------------------------ ZeroExtendVector (Combine) + +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + return Combine(d, Zero(Half()), lo); +} + +// ------------------------------ ConcatLowerLower + +// 64 or 128-bit input: just interleave +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + // Treat half-width input as a single lane and interleave them. + const Repartition, decltype(d)> du; + return BitCast(d, InterleaveLower(BitCast(du, lo), BitCast(du, hi))); +} + +namespace detail { +#if HWY_ARCH_ARM_A64 +HWY_NEON_DEF_FUNCTION_UIF_8_16_32(InterleaveEven, vtrn1, _, 2) +HWY_NEON_DEF_FUNCTION_UIF_8_16_32(InterleaveOdd, vtrn2, _, 2) +#else + +// vtrn returns a struct with even and odd result. +#define HWY_NEON_BUILD_TPL_HWY_TRN +#define HWY_NEON_BUILD_RET_HWY_TRN(type, size) type##x##size##x2_t +// Pass raw args so we can accept uint16x2 args, for which there is no +// corresponding uint16x2x2 return type. +#define HWY_NEON_BUILD_PARAM_HWY_TRN(TYPE, size) \ + Raw128::type a, Raw128::type b +#define HWY_NEON_BUILD_ARG_HWY_TRN a, b + +// Cannot use UINT8 etc. type macros because the x2_t tuples are only defined +// for full and half vectors. 
+HWY_NEON_DEF_FUNCTION(uint8, 16, InterleaveEvenOdd, vtrnq, _, u8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint8, 8, InterleaveEvenOdd, vtrn, _, u8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint16, 8, InterleaveEvenOdd, vtrnq, _, u16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint16, 4, InterleaveEvenOdd, vtrn, _, u16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint32, 4, InterleaveEvenOdd, vtrnq, _, u32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(uint32, 2, InterleaveEvenOdd, vtrn, _, u32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int8, 16, InterleaveEvenOdd, vtrnq, _, s8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int8, 8, InterleaveEvenOdd, vtrn, _, s8, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int16, 8, InterleaveEvenOdd, vtrnq, _, s16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int16, 4, InterleaveEvenOdd, vtrn, _, s16, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int32, 4, InterleaveEvenOdd, vtrnq, _, s32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(int32, 2, InterleaveEvenOdd, vtrn, _, s32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(float32, 4, InterleaveEvenOdd, vtrnq, _, f32, HWY_TRN) +HWY_NEON_DEF_FUNCTION(float32, 2, InterleaveEvenOdd, vtrn, _, f32, HWY_TRN) + +#undef HWY_NEON_BUILD_TPL_HWY_TRN +#undef HWY_NEON_BUILD_RET_HWY_TRN +#undef HWY_NEON_BUILD_PARAM_HWY_TRN +#undef HWY_NEON_BUILD_ARG_HWY_TRN + +#endif // HWY_ARCH_ARM_A64 +} // namespace detail + +// <= 32-bit input/output +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + // Treat half-width input as two lanes and take every second one. + const Repartition, decltype(d)> du; +#if HWY_ARCH_ARM_A64 + return BitCast(d, detail::InterleaveEven(BitCast(du, lo), BitCast(du, hi))); +#else + using VU = VFromD; + return BitCast( + d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) + .val[0])); +#endif +} + +// ------------------------------ ConcatUpperUpper + +// 64 or 128-bit input: just interleave +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + // Treat half-width input as a single lane and interleave them. 
+ const Repartition, decltype(d)> du; + return BitCast(d, InterleaveUpper(du, BitCast(du, lo), BitCast(du, hi))); +} + +// <= 32-bit input/output +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + // Treat half-width input as two lanes and take every second one. + const Repartition, decltype(d)> du; +#if HWY_ARCH_ARM_A64 + return BitCast(d, detail::InterleaveOdd(BitCast(du, lo), BitCast(du, hi))); +#else + using VU = VFromD; + return BitCast( + d, VU(detail::InterleaveEvenOdd(BitCast(du, lo).raw, BitCast(du, hi).raw) + .val[1])); +#endif +} + +// ------------------------------ ConcatLowerUpper (ShiftLeftBytes) + +// 64 or 128-bit input: extract from concatenated +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + return CombineShiftRightBytes(d, hi, lo); +} + +// <= 32-bit input/output +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + constexpr size_t kSize = d.MaxBytes(); + const Repartition d8; + const Full64 d8x8; + const Full64> d64; + using V8x8 = VFromD; + const V8x8 hi8x8(BitCast(d8, hi).raw); + // Move into most-significant bytes + const V8x8 lo8x8 = ShiftLeftBytes<8 - kSize>(V8x8(BitCast(d8, lo).raw)); + const V8x8 r = CombineShiftRightBytes<8 - kSize / 2>(d8x8, hi8x8, lo8x8); + // Back to original lane type, then shrink N. + return VFromD(BitCast(d64, r).raw); +} + +// ------------------------------ ConcatUpperLower + +// Works for all N. +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ ConcatOdd (InterleaveUpper) + +namespace detail { +// There is no vuzpq_u64. 
+HWY_NEON_DEF_FUNCTION_UIF_8_16_32(ConcatEven, vuzp1, _, 2) +HWY_NEON_DEF_FUNCTION_UIF_8_16_32(ConcatOdd, vuzp2, _, 2) + +#if !HWY_HAVE_FLOAT16 +template +HWY_INLINE Vec128 ConcatEven(Vec128 hi, + Vec128 lo) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, ConcatEven(BitCast(du, hi), BitCast(du, lo))); +} +template +HWY_INLINE Vec128 ConcatOdd(Vec128 hi, + Vec128 lo) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, ConcatOdd(BitCast(du, hi), BitCast(du, lo))); +} +#endif // !HWY_HAVE_FLOAT16 +} // namespace detail + +// Full/half vector +template +HWY_API VFromD ConcatOdd(D /* tag */, VFromD hi, VFromD lo) { + return detail::ConcatOdd(lo, hi); +} + +// 8-bit x4 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 ConcatOdd(D d, Vec32 hi, Vec32 lo) { + const Twice d2; + const Repartition dw2; + const VFromD hi2(hi.raw); + const VFromD lo2(lo.raw); + const VFromD Hx1Lx1 = BitCast(dw2, ConcatOdd(d2, hi2, lo2)); + // Compact into two pairs of u8, skipping the invalid x lanes. Could also use + // vcopy_lane_u16, but that's A64-only. + return Vec32(BitCast(d2, ConcatEven(dw2, Hx1Lx1, Hx1Lx1)).raw); +} + +// Any type x2 +template > +HWY_API Vec128 ConcatOdd(D d, Vec128 hi, Vec128 lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// Full/half vector +template +HWY_API VFromD ConcatEven(D /* tag */, VFromD hi, VFromD lo) { + return detail::ConcatEven(lo, hi); +} + +// 8-bit x4 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 ConcatEven(D d, Vec32 hi, Vec32 lo) { + const Twice d2; + const Repartition dw2; + const VFromD hi2(hi.raw); + const VFromD lo2(lo.raw); + const VFromD Hx0Lx0 = BitCast(dw2, ConcatEven(d2, hi2, lo2)); + // Compact into two pairs of u8, skipping the invalid x lanes. Could also use + // vcopy_lane_u16, but that's A64-only. 
+ return Vec32(BitCast(d2, ConcatEven(dw2, Hx0Lx0, Hx0Lx0)).raw); +} + +// Any type x2 +template > +HWY_API Vec128 ConcatEven(D d, Vec128 hi, Vec128 lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveEven(v, v); +#else + return Vec128(detail::InterleaveEvenOdd(v.raw, v.raw).val[0]); +#endif +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveOdd(v, v); +#else + return Vec128(detail::InterleaveEvenOdd(v.raw, v.raw).val[1]); +#endif +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ OddEven (IfThenElse) + +template +HWY_API Vec128 OddEven(const Vec128 a, const Vec128 b) { + const DFromV d; + const Repartition d8; + alignas(16) static constexpr uint8_t kBytes[16] = { + ((0 / sizeof(T)) & 1) ? 0 : 0xFF, ((1 / sizeof(T)) & 1) ? 0 : 0xFF, + ((2 / sizeof(T)) & 1) ? 0 : 0xFF, ((3 / sizeof(T)) & 1) ? 0 : 0xFF, + ((4 / sizeof(T)) & 1) ? 0 : 0xFF, ((5 / sizeof(T)) & 1) ? 0 : 0xFF, + ((6 / sizeof(T)) & 1) ? 0 : 0xFF, ((7 / sizeof(T)) & 1) ? 0 : 0xFF, + ((8 / sizeof(T)) & 1) ? 0 : 0xFF, ((9 / sizeof(T)) & 1) ? 0 : 0xFF, + ((10 / sizeof(T)) & 1) ? 0 : 0xFF, ((11 / sizeof(T)) & 1) ? 0 : 0xFF, + ((12 / sizeof(T)) & 1) ? 0 : 0xFF, ((13 / sizeof(T)) & 1) ? 0 : 0xFF, + ((14 / sizeof(T)) & 1) ? 0 : 0xFF, ((15 / sizeof(T)) & 1) ? 
0 : 0xFF, + }; + const auto vec = BitCast(d, Load(d8, kBytes)); + return IfThenElse(MaskFromVec(vec), b, a); +} + +// ------------------------------ InterleaveEven +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveEven(a, b); +#else + return VFromD(detail::InterleaveEvenOdd(a.raw, b.raw).val[0]); +#endif +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveOdd +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { +#if HWY_ARCH_ARM_A64 + return detail::InterleaveOdd(a, b); +#else + return VFromD(detail::InterleaveEvenOdd(a.raw, b.raw).val[1]); +#endif +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return InterleaveUpper(d, a, b); +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ InterleaveEvenBlocks +template > +HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveOddBlocks +template > +HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ InterleaveLowerBlocks +template > +HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveUpperBlocks +template > +HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ ReverseBlocks +// Single block: no change +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return v; +} + +// ------------------------------ ReorderDemote2To (OddEven) + +#if HWY_NEON_HAVE_F32_TO_BF16C +template +HWY_API VFromD ReorderDemote2To(D dbf16, VFromD> a, + 
VFromD> b) { + const Half dh_bf16; + return Combine(dbf16, DemoteTo(dh_bf16, b), DemoteTo(dh_bf16, a)); +} +#endif // HWY_NEON_HAVE_F32_TO_BF16C + +template +HWY_API Vec128 ReorderDemote2To(D d32, Vec128 a, + Vec128 b) { + const Vec64 a32(vqmovn_s64(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d32; + return Vec128(vqmovn_high_s64(a32.raw, b.raw)); +#else + const Vec64 b32(vqmovn_s64(b.raw)); + return Combine(d32, b32, a32); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D d32, VFromD> a, + VFromD> b) { + const Rebind dt; + return DemoteTo(d32, Combine(dt, b, a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d32, Vec128 a, + Vec128 b) { + const Vec64 a32(vqmovun_s64(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d32; + return Vec128(vqmovun_high_s64(a32.raw, b.raw)); +#else + const Vec64 b32(vqmovun_s64(b.raw)); + return Combine(d32, b32, a32); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D d32, VFromD> a, + VFromD> b) { + const Rebind dt; + return DemoteTo(d32, Combine(dt, b, a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d32, Vec128 a, + Vec128 b) { + const Vec64 a32(vqmovn_u64(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d32; + return Vec128(vqmovn_high_u64(a32.raw, b.raw)); +#else + const Vec64 b32(vqmovn_u64(b.raw)); + return Combine(d32, b32, a32); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D d32, VFromD> a, + VFromD> b) { + const Rebind dt; + return DemoteTo(d32, Combine(dt, b, a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d16, Vec128 a, + Vec128 b) { + const Vec64 a16(vqmovn_s32(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d16; + return Vec128(vqmovn_high_s32(a16.raw, b.raw)); +#else + const Vec64 b16(vqmovn_s32(b.raw)); + return Combine(d16, b16, a16); +#endif +} + +template +HWY_API Vec64 ReorderDemote2To(D /*d16*/, Vec64 a, + Vec64 b) { + const Full128 d32; + const Vec128 ab = Combine(d32, b, a); + return Vec64(vqmovn_s32(ab.raw)); +} + +template +HWY_API Vec32 ReorderDemote2To(D /*d16*/, Vec32 a, + Vec32 b) { + const Full128 
d32; + const Vec64 ab(vzip1_s32(a.raw, b.raw)); + return Vec32(vqmovn_s32(Combine(d32, ab, ab).raw)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d16, Vec128 a, + Vec128 b) { + const Vec64 a16(vqmovun_s32(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d16; + return Vec128(vqmovun_high_s32(a16.raw, b.raw)); +#else + const Vec64 b16(vqmovun_s32(b.raw)); + return Combine(d16, b16, a16); +#endif +} + +template +HWY_API Vec64 ReorderDemote2To(D /*d16*/, Vec64 a, + Vec64 b) { + const Full128 d32; + const Vec128 ab = Combine(d32, b, a); + return Vec64(vqmovun_s32(ab.raw)); +} + +template +HWY_API Vec32 ReorderDemote2To(D /*d16*/, Vec32 a, + Vec32 b) { + const Full128 d32; + const Vec64 ab(vzip1_s32(a.raw, b.raw)); + return Vec32(vqmovun_s32(Combine(d32, ab, ab).raw)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d16, Vec128 a, + Vec128 b) { + const Vec64 a16(vqmovn_u32(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d16; + return Vec128(vqmovn_high_u32(a16.raw, b.raw)); +#else + const Vec64 b16(vqmovn_u32(b.raw)); + return Combine(d16, b16, a16); +#endif +} + +template +HWY_API Vec64 ReorderDemote2To(D /*d16*/, Vec64 a, + Vec64 b) { + const Full128 d32; + const Vec128 ab = Combine(d32, b, a); + return Vec64(vqmovn_u32(ab.raw)); +} + +template +HWY_API Vec32 ReorderDemote2To(D /*d16*/, Vec32 a, + Vec32 b) { + const Full128 d32; + const Vec64 ab(vzip1_u32(a.raw, b.raw)); + return Vec32(vqmovn_u32(Combine(d32, ab, ab).raw)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d8, Vec128 a, + Vec128 b) { + const Vec64 a8(vqmovn_s16(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d8; + return Vec128(vqmovn_high_s16(a8.raw, b.raw)); +#else + const Vec64 b8(vqmovn_s16(b.raw)); + return Combine(d8, b8, a8); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D d8, VFromD> a, + VFromD> b) { + const Rebind dt; + return DemoteTo(d8, Combine(dt, b, a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d8, Vec128 a, + Vec128 b) { + const Vec64 a8(vqmovun_s16(a.raw)); +#if HWY_ARCH_ARM_A64 + 
(void)d8; + return Vec128(vqmovun_high_s16(a8.raw, b.raw)); +#else + const Vec64 b8(vqmovun_s16(b.raw)); + return Combine(d8, b8, a8); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D d8, VFromD> a, + VFromD> b) { + const Rebind dt; + return DemoteTo(d8, Combine(dt, b, a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D d8, Vec128 a, + Vec128 b) { + const Vec64 a8(vqmovn_u16(a.raw)); +#if HWY_ARCH_ARM_A64 + (void)d8; + return Vec128(vqmovn_high_u16(a8.raw, b.raw)); +#else + const Vec64 b8(vqmovn_u16(b.raw)); + return Combine(d8, b8, a8); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D d8, VFromD> a, + VFromD> b) { + const Rebind dt; + return DemoteTo(d8, Combine(dt, b, a)); +} + +template ), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); +} + +#if HWY_NEON_HAVE_F32_TO_BF16C +template +HWY_API VFromD OrderedDemote2To(D dbf16, VFromD> a, + VFromD> b) { + return ReorderDemote2To(dbf16, a, b); +} +#endif // HWY_NEON_HAVE_F32_TO_BF16C + +// ================================================== CRYPTO + +// (aarch64 or Arm7) and (__ARM_FEATURE_AES or HWY_HAVE_RUNTIME_DISPATCH). +// Otherwise, rely on generic_ops-inl.h to emulate AESRound / CLMul*. +#if HWY_TARGET != HWY_NEON_WITHOUT_AES + +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec128 AESRound(Vec128 state, + Vec128 round_key) { + // NOTE: it is important that AESE and AESMC be consecutive instructions so + // they can be fused. AESE includes AddRoundKey, which is a different ordering + // than the AES-NI semantics we adopted, so XOR by 0 and later with the actual + // round key (the compiler will hopefully optimize this for multiple rounds). 
+ return Vec128(vaesmcq_u8(vaeseq_u8(state.raw, vdupq_n_u8(0)))) ^ + round_key; +} + +HWY_API Vec128 AESLastRound(Vec128 state, + Vec128 round_key) { + return Vec128(vaeseq_u8(state.raw, vdupq_n_u8(0))) ^ round_key; +} + +HWY_API Vec128 AESInvMixColumns(Vec128 state) { + return Vec128{vaesimcq_u8(state.raw)}; +} + +HWY_API Vec128 AESRoundInv(Vec128 state, + Vec128 round_key) { + // NOTE: it is important that AESD and AESIMC be consecutive instructions so + // they can be fused. AESD includes AddRoundKey, which is a different ordering + // than the AES-NI semantics we adopted, so XOR by 0 and later with the actual + // round key (the compiler will hopefully optimize this for multiple rounds). + return Vec128(vaesimcq_u8(vaesdq_u8(state.raw, vdupq_n_u8(0)))) ^ + round_key; +} + +HWY_API Vec128 AESLastRoundInv(Vec128 state, + Vec128 round_key) { + return Vec128(vaesdq_u8(state.raw, vdupq_n_u8(0))) ^ round_key; +} + +HWY_API Vec128 CLMulLower(Vec128 a, Vec128 b) { + return Vec128((uint64x2_t)vmull_p64(GetLane(a), GetLane(b))); +} + +HWY_API Vec128 CLMulUpper(Vec128 a, Vec128 b) { + return Vec128( + (uint64x2_t)vmull_high_p64((poly64x2_t)a.raw, (poly64x2_t)b.raw)); +} + +#endif // HWY_TARGET != HWY_NEON_WITHOUT_AES + +// ================================================== MISC + +template +HWY_API VFromD PromoteTo(D df32, VFromD> v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ Truncations + +template , typename TFrom, + HWY_IF_UNSIGNED(TFrom), HWY_IF_UNSIGNED(TTo), + hwy::EnableIf<(sizeof(TTo) < sizeof(TFrom))>* = nullptr> +HWY_API Vec128 TruncateTo(DTo /* tag */, Vec128 v) { + const Repartition> d; + return Vec128{BitCast(d, v).raw}; +} + +template +HWY_API Vec16 TruncateTo(D /* tag */, Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + 
const auto v4 = detail::ConcatEven(v3, v3); + return LowerHalf(LowerHalf(LowerHalf(v4))); +} + +template +HWY_API Vec32 TruncateTo(D /* tag */, Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + return LowerHalf(LowerHalf(v3)); +} + +template +HWY_API Vec64 TruncateTo(D /* tag */, Vec128 v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + const auto v3 = detail::ConcatEven(v2, v2); + return LowerHalf(LowerHalf(v3)); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const Repartition> d; + const auto v1 = BitCast(d, v); + const auto v2 = detail::ConcatEven(v1, v1); + return LowerHalf(v2); +} + +// ------------------------------ MulEven (ConcatEven) + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. 
+HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + int8x16_t a_packed = ConcatEven(d, a, a).raw; + int8x16_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_s8(vget_low_s8(a_packed), vget_low_s8(b_packed))); +} +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + uint8x16_t a_packed = ConcatEven(d, a, a).raw; + uint8x16_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_u8(vget_low_u8(a_packed), vget_low_u8(b_packed))); +} +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + int16x8_t a_packed = ConcatEven(d, a, a).raw; + int16x8_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_s16(vget_low_s16(a_packed), vget_low_s16(b_packed))); +} +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + uint16x8_t a_packed = ConcatEven(d, a, a).raw; + uint16x8_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_u16(vget_low_u16(a_packed), vget_low_u16(b_packed))); +} +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + int32x4_t a_packed = ConcatEven(d, a, a).raw; + int32x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_s32(vget_low_s32(a_packed), vget_low_s32(b_packed))); +} +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + uint32x4_t a_packed = ConcatEven(d, a, a).raw; + uint32x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vmull_u32(vget_low_u32(a_packed), vget_low_u32(b_packed))); +} + +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + const DFromV d; + int8x8_t a_packed = ConcatEven(d, a, a).raw; + int8x8_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_s16(vmull_s8(a_packed, b_packed))); +} +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + const DFromV d; + uint8x8_t a_packed = ConcatEven(d, a, a).raw; + uint8x8_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_u16(vmull_u8(a_packed, b_packed))); +} +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + const 
DFromV d; + int16x4_t a_packed = ConcatEven(d, a, a).raw; + int16x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_s32(vmull_s16(a_packed, b_packed))); +} +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + const DFromV d; + uint16x4_t a_packed = ConcatEven(d, a, a).raw; + uint16x4_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_u32(vmull_u16(a_packed, b_packed))); +} +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + const DFromV d; + int32x2_t a_packed = ConcatEven(d, a, a).raw; + int32x2_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_s64(vmull_s32(a_packed, b_packed))); +} +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + const DFromV d; + uint32x2_t a_packed = ConcatEven(d, a, a).raw; + uint32x2_t b_packed = ConcatEven(d, b, b).raw; + return Vec128( + vget_low_u64(vmull_u32(a_packed, b_packed))); +} + +template +HWY_INLINE Vec128 MulEven(Vec128 a, Vec128 b) { + T hi; + T lo = Mul128(GetLane(a), GetLane(b), &hi); + return Dup128VecFromValues(Full128(), lo, hi); +} + +// Multiplies odd lanes (1, 3 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. 
+HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + int8x16_t a_packed = ConcatOdd(d, a, a).raw; + int8x16_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vmull_s8(vget_low_s8(a_packed), vget_low_s8(b_packed))); +} +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + uint8x16_t a_packed = ConcatOdd(d, a, a).raw; + uint8x16_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vmull_u8(vget_low_u8(a_packed), vget_low_u8(b_packed))); +} +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + int16x8_t a_packed = ConcatOdd(d, a, a).raw; + int16x8_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vmull_s16(vget_low_s16(a_packed), vget_low_s16(b_packed))); +} +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + uint16x8_t a_packed = ConcatOdd(d, a, a).raw; + uint16x8_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vmull_u16(vget_low_u16(a_packed), vget_low_u16(b_packed))); +} +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + int32x4_t a_packed = ConcatOdd(d, a, a).raw; + int32x4_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vmull_s32(vget_low_s32(a_packed), vget_low_s32(b_packed))); +} +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + uint32x4_t a_packed = ConcatOdd(d, a, a).raw; + uint32x4_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vmull_u32(vget_low_u32(a_packed), vget_low_u32(b_packed))); +} + +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + const DFromV d; + int8x8_t a_packed = ConcatOdd(d, a, a).raw; + int8x8_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vget_low_s16(vmull_s8(a_packed, b_packed))); +} +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + const DFromV d; + uint8x8_t a_packed = ConcatOdd(d, a, a).raw; + uint8x8_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vget_low_u16(vmull_u8(a_packed, b_packed))); +} +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + const DFromV d; + int16x4_t 
a_packed = ConcatOdd(d, a, a).raw; + int16x4_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vget_low_s32(vmull_s16(a_packed, b_packed))); +} +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + const DFromV d; + uint16x4_t a_packed = ConcatOdd(d, a, a).raw; + uint16x4_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vget_low_u32(vmull_u16(a_packed, b_packed))); +} +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + const DFromV d; + int32x2_t a_packed = ConcatOdd(d, a, a).raw; + int32x2_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vget_low_s64(vmull_s32(a_packed, b_packed))); +} +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + const DFromV d; + uint32x2_t a_packed = ConcatOdd(d, a, a).raw; + uint32x2_t b_packed = ConcatOdd(d, b, b).raw; + return Vec128( + vget_low_u64(vmull_u32(a_packed, b_packed))); +} + +template +HWY_INLINE Vec128 MulOdd(Vec128 a, Vec128 b) { + T hi; + T lo = Mul128(detail::GetLane<1>(a), detail::GetLane<1>(b), &hi); + return Dup128VecFromValues(Full128(), lo, hi); +} + +// ------------------------------ TableLookupBytes (Combine, LowerHalf) + +// Both full +template +HWY_API Vec128 TableLookupBytes(Vec128 bytes, Vec128 from) { + const DFromV d; + const Repartition d8; +#if HWY_ARCH_ARM_A64 + return BitCast(d, Vec128(vqtbl1q_u8(BitCast(d8, bytes).raw, + BitCast(d8, from).raw))); +#else + uint8x16_t table0 = BitCast(d8, bytes).raw; + uint8x8x2_t table; + table.val[0] = vget_low_u8(table0); + table.val[1] = vget_high_u8(table0); + uint8x16_t idx = BitCast(d8, from).raw; + uint8x8_t low = vtbl2_u8(table, vget_low_u8(idx)); + uint8x8_t hi = vtbl2_u8(table, vget_high_u8(idx)); + return BitCast(d, Vec128(vcombine_u8(low, hi))); +#endif +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(Vec128 bytes, Vec128 from) { + const Full128 d_full; + const Vec64 from64(from.raw); + const auto idx_full = Combine(d_full, from64, from64); + const auto out_full = TableLookupBytes(bytes, 
idx_full); + return Vec128(LowerHalf(Half(), out_full).raw); +} + +// Partial table vector +template +HWY_API Vec128 TableLookupBytes(Vec128 bytes, Vec128 from) { + const Full128 d_full; + return TableLookupBytes(Combine(d_full, bytes, bytes), from); +} + +// Partial both +template +HWY_API Vec128 TableLookupBytes(Vec128 bytes, + Vec128 from) { + const DFromV d; + const Simd d_idx; + const Repartition d_idx8; + // uint8x8 + const auto bytes8 = BitCast(Repartition(), bytes); + const auto from8 = BitCast(d_idx8, from); + const VFromD v8(vtbl1_u8(bytes8.raw, from8.raw)); + return BitCast(d_idx, v8); +} + +// For all vector widths; Arm anyway zeroes if >= 0x10. +template +HWY_API VI TableLookupBytesOr0(V bytes, VI from) { + return TableLookupBytes(bytes, from); +} + +// ---------------------------- AESKeyGenAssist (AESLastRound, TableLookupBytes) + +#if HWY_TARGET != HWY_NEON_WITHOUT_AES +template +HWY_API Vec128 AESKeyGenAssist(Vec128 v) { + alignas(16) static constexpr uint8_t kRconXorMask[16] = { + 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0}; + alignas(16) static constexpr uint8_t kRotWordShuffle[16] = { + 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12}; + const DFromV d; + const Repartition du32; + const auto w13 = BitCast(d, DupOdd(BitCast(du32, v))); + const auto sub_word_result = AESLastRound(w13, Load(d, kRconXorMask)); + return TableLookupBytes(sub_word_result, Load(d, kRotWordShuffle)); +} +#endif // HWY_TARGET != HWY_NEON_WITHOUT_AES + +// ------------------------------ Scatter in generic_ops-inl.h +// ------------------------------ Gather in generic_ops-inl.h + +// ------------------------------ Reductions + +// On Armv8 we define ReduceSum and generic_ops defines SumOfLanes via Set. +#if HWY_ARCH_ARM_A64 + +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + +// TODO(janwas): use normal HWY_NEON_DEF, then FULL type list. 
+#define HWY_NEON_DEF_REDUCTION(type, size, name, prefix, infix, suffix) \ + template \ + HWY_API type##_t name(D /* tag */, Vec128 v) { \ + return HWY_NEON_EVAL(prefix##infix##suffix, v.raw); \ + } + +// Excludes u64/s64 (missing minv/maxv) and f16 (missing addv). +#define HWY_NEON_DEF_REDUCTION_CORE_TYPES(name, prefix) \ + HWY_NEON_DEF_REDUCTION(uint8, 8, name, prefix, _, u8) \ + HWY_NEON_DEF_REDUCTION(uint8, 16, name, prefix##q, _, u8) \ + HWY_NEON_DEF_REDUCTION(uint16, 4, name, prefix, _, u16) \ + HWY_NEON_DEF_REDUCTION(uint16, 8, name, prefix##q, _, u16) \ + HWY_NEON_DEF_REDUCTION(uint32, 2, name, prefix, _, u32) \ + HWY_NEON_DEF_REDUCTION(uint32, 4, name, prefix##q, _, u32) \ + HWY_NEON_DEF_REDUCTION(int8, 8, name, prefix, _, s8) \ + HWY_NEON_DEF_REDUCTION(int8, 16, name, prefix##q, _, s8) \ + HWY_NEON_DEF_REDUCTION(int16, 4, name, prefix, _, s16) \ + HWY_NEON_DEF_REDUCTION(int16, 8, name, prefix##q, _, s16) \ + HWY_NEON_DEF_REDUCTION(int32, 2, name, prefix, _, s32) \ + HWY_NEON_DEF_REDUCTION(int32, 4, name, prefix##q, _, s32) \ + HWY_NEON_DEF_REDUCTION(float32, 2, name, prefix, _, f32) \ + HWY_NEON_DEF_REDUCTION(float32, 4, name, prefix##q, _, f32) \ + HWY_NEON_DEF_REDUCTION(float64, 2, name, prefix##q, _, f64) + +// Different interface than HWY_NEON_DEF_FUNCTION_FULL_UI_64. 
+#define HWY_NEON_DEF_REDUCTION_UI64(name, prefix) \ + HWY_NEON_DEF_REDUCTION(uint64, 2, name, prefix##q, _, u64) \ + HWY_NEON_DEF_REDUCTION(int64, 2, name, prefix##q, _, s64) + +#if HWY_HAVE_FLOAT16 +#define HWY_NEON_DEF_REDUCTION_F16(name, prefix) \ + HWY_NEON_DEF_REDUCTION(float16, 4, name, prefix, _, f16) \ + HWY_NEON_DEF_REDUCTION(float16, 8, name, prefix##q, _, f16) +#else +#define HWY_NEON_DEF_REDUCTION_F16(name, prefix) +#endif + +HWY_NEON_DEF_REDUCTION_CORE_TYPES(ReduceMin, vminv) +HWY_NEON_DEF_REDUCTION_CORE_TYPES(ReduceMax, vmaxv) +HWY_NEON_DEF_REDUCTION_F16(ReduceMin, vminv) +HWY_NEON_DEF_REDUCTION_F16(ReduceMax, vmaxv) + +HWY_NEON_DEF_REDUCTION_CORE_TYPES(ReduceSum, vaddv) +HWY_NEON_DEF_REDUCTION_UI64(ReduceSum, vaddv) + +// Emulate missing UI64 and partial N=2. +template +HWY_API TFromD ReduceSum(D /* tag */, VFromD v10) { + return GetLane(v10) + ExtractLane(v10, 1); +} + +template +HWY_API TFromD ReduceMin(D /* tag */, VFromD v10) { + return HWY_MIN(GetLane(v10), ExtractLane(v10, 1)); +} + +template +HWY_API TFromD ReduceMax(D /* tag */, VFromD v10) { + return HWY_MAX(GetLane(v10), ExtractLane(v10, 1)); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API float16_t ReduceMin(D d, VFromD v10) { + return GetLane(Min(v10, Reverse2(d, v10))); +} + +template +HWY_API float16_t ReduceMax(D d, VFromD v10) { + return GetLane(Max(v10, Reverse2(d, v10))); +} + +template +HWY_API float16_t ReduceSum(D /* tag */, VFromD v) { + const float16x4_t x2 = vpadd_f16(v.raw, v.raw); + return GetLane(VFromD(vpadd_f16(x2, x2))); +} +template +HWY_API float16_t ReduceSum(D d, VFromD v) { + const Half dh; + return ReduceSum(dh, LowerHalf(dh, VFromD(vpaddq_f16(v.raw, v.raw)))); +} +#endif // HWY_HAVE_FLOAT16 + +#undef HWY_NEON_DEF_REDUCTION_CORE_TYPES +#undef HWY_NEON_DEF_REDUCTION_F16 +#undef HWY_NEON_DEF_REDUCTION_UI64 +#undef HWY_NEON_DEF_REDUCTION + +// ------------------------------ SumOfLanes + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, ReduceSum(d, 
v)); +} +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + return Set(d, ReduceMin(d, v)); +} +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + return Set(d, ReduceMax(d, v)); +} + +// On Armv7 we define SumOfLanes and generic_ops defines ReduceSum via GetLane. +#else // !HWY_ARCH_ARM_A64 + +// Armv7 lacks N=2 (except 32-bit) and 8-bit x4, so enable them in generic_ops. +#undef HWY_IF_SUM_OF_LANES_D +#define HWY_IF_SUM_OF_LANES_D(D) \ + hwy::EnableIf<(sizeof(TFromD) != 4 && HWY_MAX_LANES_D(D) == 2) || \ + (sizeof(TFromD) == 1 && HWY_MAX_LANES_D(D) == 4)>* = \ + nullptr +#undef HWY_IF_MINMAX_OF_LANES_D +#define HWY_IF_MINMAX_OF_LANES_D(D) \ + hwy::EnableIf<(sizeof(TFromD) != 4 && HWY_MAX_LANES_D(D) == 2) || \ + (sizeof(TFromD) == 1 && HWY_MAX_LANES_D(D) == 4)>* = \ + nullptr + +// For arm7, we implement reductions using a series of pairwise operations. This +// produces the full vector result, so we express Reduce* in terms of *OfLanes. + +#define HWY_NEON_DEF_PAIRWISE_REDUCTION(name) \ + /* generic_ops-inl.h handles 64-bit types. */ \ + template \ + HWY_API VFromD name##OfLanes(D d, VFromD v) { \ + HWY_LANES_CONSTEXPR size_t N = Lanes(d); \ + VFromD tmp = detail::Pairwise##name(v, v); \ + if ((N / 2) > 1) tmp = detail::Pairwise##name(tmp, tmp); \ + if ((N / 4) > 1) tmp = detail::Pairwise##name(tmp, tmp); \ + return tmp; \ + } \ + /* Armv7 lacks q (full-vector) instructions, so first reduce 128-bit v */ \ + /* into a half-vector, then reduce that. 
*/ \ + template \ + HWY_API VFromD name##OfLanes(D d, VFromD v) { \ + const Half dh; \ + VFromD upper = UpperHalf(dh, v); \ + VFromD lower = LowerHalf(dh, v); \ + VFromD half = detail::Pairwise##name(upper, lower); \ + half = name##OfLanes(dh, half); \ + return Combine(d, half, half); \ + } + +HWY_NEON_DEF_PAIRWISE_REDUCTION(Sum) +HWY_NEON_DEF_PAIRWISE_REDUCTION(Min) +HWY_NEON_DEF_PAIRWISE_REDUCTION(Max) +#undef HWY_NEON_DEF_PAIRWISE_REDUCTION + +// GetLane(SumsOf4(v)) is more efficient on ArmV7 NEON than the default +// N=4 I8/U8 ReduceSum implementation in generic_ops-inl.h +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 +#endif + +template +HWY_API TFromD ReduceSum(D /*d*/, VFromD v) { + return static_cast>(GetLane(SumsOf4(v))); +} + +#endif // HWY_ARCH_ARM_A64 + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +// Helper function to set 64 bits and potentially return a smaller vector. The +// overload is required to call the q vs non-q intrinsics. Note that 8-bit +// LoadMaskBits only requires 16 bits, but 64 avoids casting. +template +HWY_INLINE VFromD Set64(D /* tag */, uint64_t mask_bits) { + const auto v64 = Vec64(vdup_n_u64(mask_bits)); + return VFromD(BitCast(Full64>(), v64).raw); +} +template +HWY_INLINE Vec128 Set64(Full128 d, uint64_t mask_bits) { + return BitCast(d, Vec128(vdupq_n_u64(mask_bits))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const auto vmask_bits = Set64(du, mask_bits); + + // Replicate bytes 8x such that each byte contains the bit that governs it. 
+ alignas(16) static constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vmask_bits, Load(du, kRep8)); + + alignas(16) static constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + CopyBytes<(d.MaxLanes() + 7) / 8>(bits, &mask_bits); + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Mask + +namespace detail { + +// Returns mask[i]? 0xF : 0 in each nibble. This is more efficient than +// BitsFromMask for use in (partial) CountTrue, FindFirstTrue and AllFalse. 
+template +HWY_INLINE uint64_t NibblesFromMask(D d, MFromD mask) { + const Full128 du16; + const Vec128 vu16 = BitCast(du16, VecFromMask(d, mask)); + const Vec64 nib(vshrn_n_u16(vu16.raw, 4)); + return GetLane(BitCast(Full64(), nib)); +} + +template +HWY_INLINE uint64_t NibblesFromMask(D d, MFromD mask) { + // There is no vshrn_n_u16 for uint16x4, so zero-extend. + const Twice d2; + const VFromD v128 = ZeroExtendVector(d2, VecFromMask(d, mask)); + // No need to mask, upper half is zero thanks to ZeroExtendVector. + return NibblesFromMask(d2, MaskFromVec(v128)); +} + +template +HWY_INLINE uint64_t NibblesFromMask(D d, MFromD mask) { + const Mask64> mask64(mask.raw); + const uint64_t nib = NibblesFromMask(Full64>(), mask64); + // Clear nibbles from upper half of 64-bits + return nib & ((1ull << (d.MaxBytes() * 4)) - 1); +} + +// Returns the lowest N for the BitsFromMask result. +template +constexpr uint64_t OnlyActive(D d, uint64_t bits) { + return (d.MaxBytes() >= 8) ? bits : (bits & ((1ull << d.MaxLanes()) - 1)); +} + +} // namespace detail + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + alignas(16) static constexpr uint8_t kSliceLanes[16] = { + 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, + }; + const RebindToUnsigned du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); + +#if HWY_ARCH_ARM_A64 + // Can't vaddv - we need two separate bytes (16 bits). + const uint8x8_t x2 = vget_low_u8(vpaddq_u8(values.raw, values.raw)); + const uint8x8_t x4 = vpadd_u8(x2, x2); + const uint8x8_t x8 = vpadd_u8(x4, x4); + return vget_lane_u64(vreinterpret_u64_u8(x8), 0) & 0xFFFF; +#else + // Don't have vpaddq, so keep doubling lane size. 
+ const uint16x8_t x2 = vpaddlq_u8(values.raw); + const uint32x4_t x4 = vpaddlq_u16(x2); + const uint64x2_t x8 = vpaddlq_u32(x4); + return (vgetq_lane_u64(x8, 1) << 8) | vgetq_lane_u64(x8, 0); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. + alignas(8) static constexpr uint8_t kSliceLanes[8] = {1, 2, 4, 8, + 0x10, 0x20, 0x40, 0x80}; + const RebindToUnsigned du; + using VU = VFromD; + const VU slice(Load(Full64(), kSliceLanes).raw); + const VU values = BitCast(du, VecFromMask(d, mask)) & slice; + +#if HWY_ARCH_ARM_A64 + return detail::OnlyActive(d, vaddv_u8(values.raw)); +#else + const uint16x4_t x2 = vpaddl_u8(values.raw); + const uint32x2_t x4 = vpaddl_u16(x2); + const uint64x1_t x8 = vpaddl_u32(x4); + return detail::OnlyActive(d, vget_lane_u64(x8, 0)); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + alignas(16) static constexpr uint16_t kSliceLanes[8] = { + 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80}; + const RebindToUnsigned du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return detail::OnlyActive(d, vaddvq_u16(values.raw)); +#else + const uint32x4_t x2 = vpaddlq_u16(values.raw); + const uint64x2_t x4 = vpaddlq_u32(x2); + return detail::OnlyActive(d, vgetq_lane_u64(x4, 0) + vgetq_lane_u64(x4, 1)); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. 
+ alignas(8) static constexpr uint16_t kSliceLanes[4] = {1, 2, 4, 8}; + const RebindToUnsigned du; + using VU = VFromD; + const VU slice(Load(Full64(), kSliceLanes).raw); + const VU values = BitCast(du, VecFromMask(d, mask)) & slice; +#if HWY_ARCH_ARM_A64 + return detail::OnlyActive(d, vaddv_u16(values.raw)); +#else + const uint32x2_t x2 = vpaddl_u16(values.raw); + const uint64x1_t x4 = vpaddl_u32(x2); + return detail::OnlyActive(d, vget_lane_u64(x4, 0)); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + alignas(16) static constexpr uint32_t kSliceLanes[4] = {1, 2, 4, 8}; + const RebindToUnsigned du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return detail::OnlyActive(d, vaddvq_u32(values.raw)); +#else + const uint64x2_t x2 = vpaddlq_u32(values.raw); + return detail::OnlyActive(d, vgetq_lane_u64(x2, 0) + vgetq_lane_u64(x2, 1)); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + // Upper lanes of partial loads are undefined. OnlyActive will fix this if + // we load all kSliceLanes so the upper lanes do not pollute the valid bits. 
+ alignas(8) static constexpr uint32_t kSliceLanes[2] = {1, 2}; + const RebindToUnsigned du; + using VU = VFromD; + const VU slice(Load(Full64(), kSliceLanes).raw); + const VU values = BitCast(du, VecFromMask(d, mask)) & slice; +#if HWY_ARCH_ARM_A64 + return detail::OnlyActive(d, vaddv_u32(values.raw)); +#else + const uint64x1_t x2 = vpaddl_u32(values.raw); + return detail::OnlyActive(d, vget_lane_u64(x2, 0)); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + alignas(16) static constexpr uint64_t kSliceLanes[2] = {1, 2}; + const RebindToUnsigned du; + const Vec128 values = + BitCast(du, VecFromMask(d, mask)) & Load(du, kSliceLanes); +#if HWY_ARCH_ARM_A64 + return detail::OnlyActive(d, vaddvq_u64(values.raw)); +#else + return detail::OnlyActive( + d, vgetq_lane_u64(values.raw, 0) + vgetq_lane_u64(values.raw, 1)); +#endif +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToUnsigned du; + const Vec64 values = BitCast(du, VecFromMask(d, mask)) & Set(du, 1); + return vget_lane_u64(values.raw, 0); +} + +namespace detail { + +// Returns number of lanes whose mask is set. +// +// Masks are either FF..FF or 0. Unfortunately there is no reduce-sub op +// ("vsubv"). ANDing with 1 would work but requires a constant. Negating also +// changes each lane to 1 (if mask set) or 0. +// NOTE: PopCount also operates on vectors, so we still have to do horizontal +// sums separately. We specialize CountTrue for full vectors (negating instead +// of PopCount because it avoids an extra shift), and use PopCount of +// NibblesFromMask for partial vectors. 
+ +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<1> /*tag*/, Mask128 mask) { + const Full128 di; + const int8x16_t ones = + vnegq_s8(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast(vaddvq_s8(ones)); +#else + const int16x8_t x2 = vpaddlq_s8(ones); + const int32x4_t x4 = vpaddlq_s16(x2); + const int64x2_t x8 = vpaddlq_s32(x4); + return static_cast(vgetq_lane_s64(x8, 0) + vgetq_lane_s64(x8, 1)); +#endif +} +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<2> /*tag*/, Mask128 mask) { + const Full128 di; + const int16x8_t ones = + vnegq_s16(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast(vaddvq_s16(ones)); +#else + const int32x4_t x2 = vpaddlq_s16(ones); + const int64x2_t x4 = vpaddlq_s32(x2); + return static_cast(vgetq_lane_s64(x4, 0) + vgetq_lane_s64(x4, 1)); +#endif +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, Mask128 mask) { + const Full128 di; + const int32x4_t ones = + vnegq_s32(BitCast(di, VecFromMask(Full128(), mask)).raw); + +#if HWY_ARCH_ARM_A64 + return static_cast(vaddvq_s32(ones)); +#else + const int64x2_t x2 = vpaddlq_s32(ones); + return static_cast(vgetq_lane_s64(x2, 0) + vgetq_lane_s64(x2, 1)); +#endif +} + +template +HWY_INLINE size_t CountTrue(hwy::SizeTag<8> /*tag*/, Mask128 mask) { +#if HWY_ARCH_ARM_A64 + const Full128 di; + const int64x2_t ones = + vnegq_s64(BitCast(di, VecFromMask(Full128(), mask)).raw); + return static_cast(vaddvq_s64(ones)); +#else + const Full128 du; + const auto mask_u = VecFromMask(du, RebindMask(du, mask)); + const uint64x2_t ones = vshrq_n_u64(mask_u.raw, 63); + return static_cast(vgetq_lane_u64(ones, 0) + vgetq_lane_u64(ones, 1)); +#endif +} + +} // namespace detail + +// Full +template > +HWY_API size_t CountTrue(D /* tag */, Mask128 mask) { + return detail::CountTrue(hwy::SizeTag(), mask); +} + +// Partial +template +HWY_API size_t CountTrue(D d, MFromD mask) { + constexpr int kDiv = 4 * sizeof(TFromD); 
+ return PopCount(detail::NibblesFromMask(d, mask)) / kDiv; +} + +template +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + constexpr size_t kDiv = 4 * sizeof(TFromD); + return Num0BitsBelowLS1Bit_Nonzero64(nib) / kDiv; +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + if (nib == 0) return -1; + constexpr size_t kDiv = 4 * sizeof(TFromD); + return static_cast(Num0BitsBelowLS1Bit_Nonzero64(nib) / kDiv); +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + constexpr size_t kDiv = 4 * sizeof(TFromD); + return (63 - Num0BitsAboveMS1Bit_Nonzero64(nib)) / kDiv; +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + const uint64_t nib = detail::NibblesFromMask(d, mask); + if (nib == 0) return -1; + constexpr size_t kDiv = 4 * sizeof(TFromD); + return static_cast((63 - Num0BitsAboveMS1Bit_Nonzero64(nib)) / + kDiv); +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + const uint64_t mask_bits = BitsFromMask(d, mask); + const size_t kNumBytes = (d.MaxLanes() + 7) / 8; + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +template +HWY_API bool AllFalse(D d, MFromD m) { + return detail::NibblesFromMask(d, m) == 0; +} + +// Full +template > +HWY_API bool AllTrue(D d, Mask128 m) { + return detail::NibblesFromMask(d, m) == ~0ull; +} +// Partial +template +HWY_API bool AllTrue(D d, MFromD m) { + return detail::NibblesFromMask(d, m) == (1ull << (d.MaxBytes() * 4)) - 1; +} + +// ------------------------------ Compress + +template +struct CompressIsPartition { + enum { value = (sizeof(T) != 1) }; +}; + +namespace detail { + +// Load 8 bytes, replicate into upper half so ZipLower can use the lower half. 
+template +HWY_INLINE Vec128 Load8Bytes(D /*tag*/, const uint8_t* bytes) { + return Vec128(vreinterpretq_u8_u64( + vld1q_dup_u64(HWY_RCAST_ALIGNED(const uint64_t*, bytes)))); +} + +// Load 8 bytes and return half-reg with N <= 8 bytes. +template +HWY_INLINE VFromD Load8Bytes(D d, const uint8_t* bytes) { + return Load(d, bytes); +} + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<2> /*tag*/, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Repartition d8; + const Simd du; + + // NEON does not provide an equivalent of AVX2 permutevar, so we need byte + // indices for VTBL (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. 
+ alignas(16) static constexpr uint8_t table[256 * 8] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 
14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 
8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 
4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 
2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx = Load8Bytes(d8, table + mask_bits * 8); + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<2> /*tag*/, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Repartition d8; + const Simd du; + + // NEON does not provide an equivalent of AVX2 permutevar, so we need byte + // indices for VTBL (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. 
+ alignas(16) static constexpr uint8_t table[256 * 8] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 
8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 
4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 
10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 
0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx = Load8Bytes(d8, table + mask_bits * 8); + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<4> /*tag*/, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<4> /*tag*/, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector 
directly. + alignas(16) static constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +#if HWY_HAVE_INTEGER64 || HWY_HAVE_FLOAT64 + +template +HWY_INLINE Vec128 IdxFromBits(hwy::SizeTag<8> /*tag*/, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(hwy::SizeTag<8> /*tag*/, + uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. 
+ alignas(16) static constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +#endif + +// Helper function called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. +template +HWY_INLINE Vec128 Compress(Vec128 v, uint64_t mask_bits) { + const auto idx = + detail::IdxFromBits(hwy::SizeTag(), mask_bits); + using D = DFromV; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +template +HWY_INLINE Vec128 CompressNot(Vec128 v, uint64_t mask_bits) { + const auto idx = + detail::IdxFromNotBits(hwy::SizeTag(), mask_bits); + using D = DFromV; + const RebindToSigned di; + return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +} // namespace detail + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
+ const DFromV d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + const DFromV d; + return detail::Compress(v, BitsFromMask(d, mask)); +} + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const DFromV d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + const DFromV d; + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. 
+ if (N < 16 / sizeof(T)) { + return detail::Compress(v, BitsFromMask(d, Not(mask))); + } + return detail::CompressNot(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressBits + +template +HWY_INLINE Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = BitsFromMask(d, mask); + StoreU(detail::Compress(v, mask_bits), d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; // so we can support fp16/bf16 + const uint64_t mask_bits = BitsFromMask(d, m); + const size_t count = PopCount(mask_bits); + const MFromD store_mask = RebindMask(d, FirstN(du, count)); + const VFromD compressed = + detail::Compress(BitCast(du, v), mask_bits); + BlendedStore(BitCast(d, compressed), store_mask, d, unaligned); + return count; +} + +// ------------------------------ CompressBitsStore + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (d.MaxLanes() + 7) / 8; + CopyBytes(bits, &mask_bits); + if (d.MaxLanes() < 8) { + mask_bits &= (1ull << d.MaxLanes()) - 1; + } + + StoreU(detail::Compress(v, mask_bits), d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ LoadInterleaved2 + +// Per-target flag 
to prevent generic_ops-inl.h from defining LoadInterleaved2. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +namespace detail { + +#define HWY_NEON_BUILD_TPL_HWY_LOAD_INT +#define HWY_NEON_BUILD_ARG_HWY_LOAD_INT from + +#if HWY_ARCH_ARM_A64 +#define HWY_IF_LOAD_INT(D) \ + HWY_IF_V_SIZE_GT_D(D, 4), HWY_NEON_IF_NOT_EMULATED_D(D) +#define HWY_NEON_DEF_FUNCTION_LOAD_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_ALL_TYPES(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) +#else +// Exclude 64x2 and f64x1, which are only supported on aarch64; also exclude any +// emulated types. +#define HWY_IF_LOAD_INT(D) \ + HWY_IF_V_SIZE_GT_D(D, 4), HWY_NEON_IF_NOT_EMULATED_D(D), \ + hwy::EnableIf<(HWY_MAX_LANES_D(D) == 1 || sizeof(TFromD) < 8)>* = \ + nullptr +#define HWY_NEON_DEF_FUNCTION_LOAD_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) +#endif // HWY_ARCH_ARM_A64 + +// Must return raw tuple because Tuple2 lack a ctor, and we cannot use +// brace-initialization in HWY_NEON_DEF_FUNCTION because some functions return +// void. 
+#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple2().raw) +// Tuple tag arg allows overloading (cannot just overload on return type) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const NativeLaneType*from, Tuple2 +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved2, vld2, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT + +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple3().raw) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const NativeLaneType*from, Tuple3 +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved3, vld3, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT + +#define HWY_NEON_BUILD_RET_HWY_LOAD_INT(type, size) \ + decltype(Tuple4().raw) +#define HWY_NEON_BUILD_PARAM_HWY_LOAD_INT(type, size) \ + const NativeLaneType*from, Tuple4 +HWY_NEON_DEF_FUNCTION_LOAD_INT(LoadInterleaved4, vld4, _, HWY_LOAD_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_LOAD_INT +#undef HWY_NEON_BUILD_RET_HWY_LOAD_INT + +#undef HWY_NEON_DEF_FUNCTION_LOAD_INT +#undef HWY_NEON_BUILD_TPL_HWY_LOAD_INT +#undef HWY_NEON_BUILD_ARG_HWY_LOAD_INT + +} // namespace detail + +template > +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + auto raw = detail::LoadInterleaved2(detail::NativeLanePointer(unaligned), + detail::Tuple2()); + v0 = VFromD(raw.val[0]); + v1 = VFromD(raw.val[1]); +} + +// <= 32 bits: avoid loading more than N bytes by copying to buffer +template > +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + // The smallest vector registers are 64-bits and we want space for two. 
+ alignas(16) T buf[2 * 8 / sizeof(T)] = {}; + CopyBytes(unaligned, buf); + auto raw = detail::LoadInterleaved2(detail::NativeLanePointer(buf), + detail::Tuple2()); + v0 = VFromD(raw.val[0]); + v1 = VFromD(raw.val[1]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> +HWY_API void LoadInterleaved2(D d, T* HWY_RESTRICT unaligned, Vec128& v0, + Vec128& v1) { + const Half dh; + VFromD v00, v10, v01, v11; + LoadInterleaved2(dh, detail::NativeLanePointer(unaligned), v00, v10); + LoadInterleaved2(dh, detail::NativeLanePointer(unaligned + 2), v01, v11); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ LoadInterleaved3 + +template > +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + auto raw = detail::LoadInterleaved3(detail::NativeLanePointer(unaligned), + detail::Tuple3()); + v0 = VFromD(raw.val[0]); + v1 = VFromD(raw.val[1]); + v2 = VFromD(raw.val[2]); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template > +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + // The smallest vector registers are 64-bits and we want space for three. 
+ alignas(16) T buf[3 * 8 / sizeof(T)] = {}; + CopyBytes(unaligned, buf); + auto raw = detail::LoadInterleaved3(detail::NativeLanePointer(buf), + detail::Tuple3()); + v0 = VFromD(raw.val[0]); + v1 = VFromD(raw.val[1]); + v2 = VFromD(raw.val[2]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1, Vec128& v2) { + const Half dh; + VFromD v00, v10, v20, v01, v11, v21; + LoadInterleaved3(dh, detail::NativeLanePointer(unaligned), v00, v10, v20); + LoadInterleaved3(dh, detail::NativeLanePointer(unaligned + 3), v01, v11, v21); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); + v2 = Combine(d, v21, v20); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ LoadInterleaved4 + +template > +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + auto raw = detail::LoadInterleaved4(detail::NativeLanePointer(unaligned), + detail::Tuple4()); + v0 = VFromD(raw.val[0]); + v1 = VFromD(raw.val[1]); + v2 = VFromD(raw.val[2]); + v3 = VFromD(raw.val[3]); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template > +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + alignas(16) T buf[4 * 8 / sizeof(T)] = {}; + CopyBytes(unaligned, buf); + auto raw = detail::LoadInterleaved4(detail::NativeLanePointer(buf), + detail::Tuple4()); + v0 = VFromD(raw.val[0]); + v1 = VFromD(raw.val[1]); + v2 = VFromD(raw.val[2]); + v3 = VFromD(raw.val[3]); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + Vec128& v0, Vec128& v1, Vec128& v2, + Vec128& v3) { + const Half dh; + VFromD v00, v10, v20, v30, v01, v11, v21, 
v31; + LoadInterleaved4(dh, detail::NativeLanePointer(unaligned), v00, v10, v20, + v30); + LoadInterleaved4(dh, detail::NativeLanePointer(unaligned + 4), v01, v11, v21, + v31); + v0 = Combine(d, v01, v00); + v1 = Combine(d, v11, v10); + v2 = Combine(d, v21, v20); + v3 = Combine(d, v31, v30); +} +#endif // HWY_ARCH_ARM_V7 + +#undef HWY_IF_LOAD_INT + +// ------------------------------ StoreInterleaved2 + +namespace detail { +#define HWY_NEON_BUILD_TPL_HWY_STORE_INT +#define HWY_NEON_BUILD_RET_HWY_STORE_INT(type, size) void +#define HWY_NEON_BUILD_ARG_HWY_STORE_INT to, tup.raw + +#if HWY_ARCH_ARM_A64 +#define HWY_IF_STORE_INT(D) \ + HWY_IF_V_SIZE_GT_D(D, 4), HWY_NEON_IF_NOT_EMULATED_D(D) +#define HWY_NEON_DEF_FUNCTION_STORE_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_ALL_TYPES(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) +#else +// Exclude 64x2 and f64x1, which are only supported on aarch64; also exclude any +// emulated types. +#define HWY_IF_STORE_INT(D) \ + HWY_IF_V_SIZE_GT_D(D, 4), HWY_NEON_IF_NOT_EMULATED_D(D), \ + hwy::EnableIf<(HWY_MAX_LANES_D(D) == 1 || sizeof(TFromD) < 8)>* = \ + nullptr +#define HWY_NEON_DEF_FUNCTION_STORE_INT(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_INT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_UINT_8_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION_FLOAT_16_32(name, prefix, infix, args) \ + HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args) \ + HWY_NEON_DEF_FUNCTION(uint64, 1, name, prefix, infix, u64, args) +#endif // HWY_ARCH_ARM_A64 + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple2 tup, NativeLaneType*to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved2, vst2, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple3 tup, NativeLaneType*to 
+HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved3, vst3, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#define HWY_NEON_BUILD_PARAM_HWY_STORE_INT(type, size) \ + Tuple4 tup, NativeLaneType*to +HWY_NEON_DEF_FUNCTION_STORE_INT(StoreInterleaved4, vst4, _, HWY_STORE_INT) +#undef HWY_NEON_BUILD_PARAM_HWY_STORE_INT + +#undef HWY_NEON_DEF_FUNCTION_STORE_INT +#undef HWY_NEON_BUILD_TPL_HWY_STORE_INT +#undef HWY_NEON_BUILD_RET_HWY_STORE_INT +#undef HWY_NEON_BUILD_ARG_HWY_STORE_INT +} // namespace detail + +template > +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + T* HWY_RESTRICT unaligned) { + detail::Tuple2 tup = {{{v0.raw, v1.raw}}}; + detail::StoreInterleaved2(tup, detail::NativeLanePointer(unaligned)); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template > +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + T* HWY_RESTRICT unaligned) { + alignas(16) T buf[2 * 8 / sizeof(T)]; + detail::Tuple2 tup = {{{v0.raw, v1.raw}}}; + detail::StoreInterleaved2(tup, detail::NativeLanePointer(buf)); + CopyBytes(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved2(Vec128 v0, Vec128 v1, D d, + T* HWY_RESTRICT unaligned) { + const Half dh; + StoreInterleaved2(LowerHalf(dh, v0), LowerHalf(dh, v1), dh, + detail::NativeLanePointer(unaligned)); + StoreInterleaved2(UpperHalf(dh, v0), UpperHalf(dh, v1), dh, + detail::NativeLanePointer(unaligned + 2)); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ StoreInterleaved3 + +template > +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + T* HWY_RESTRICT unaligned) { + detail::Tuple3 tup = {{{v0.raw, v1.raw, v2.raw}}}; + detail::StoreInterleaved3(tup, detail::NativeLanePointer(unaligned)); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template > +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, 
VFromD v2, D d, + T* HWY_RESTRICT unaligned) { + alignas(16) T buf[3 * 8 / sizeof(T)]; + detail::Tuple3 tup = {{{v0.raw, v1.raw, v2.raw}}}; + detail::StoreInterleaved3(tup, detail::NativeLanePointer(buf)); + CopyBytes(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved3(Vec128 v0, Vec128 v1, Vec128 v2, D d, + T* HWY_RESTRICT unaligned) { + const Half dh; + StoreInterleaved3(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2), dh, + detail::NativeLanePointer(unaligned)); + StoreInterleaved3(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), dh, + detail::NativeLanePointer(unaligned + 3)); +} +#endif // HWY_ARCH_ARM_V7 + +// ------------------------------ StoreInterleaved4 + +template > +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, T* HWY_RESTRICT unaligned) { + detail::Tuple4 tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}}; + detail::StoreInterleaved4(tup, detail::NativeLanePointer(unaligned)); +} + +// <= 32 bits: avoid writing more than N bytes by copying to buffer +template > +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, T* HWY_RESTRICT unaligned) { + alignas(16) T buf[4 * 8 / sizeof(T)]; + detail::Tuple4 tup = {{{v0.raw, v1.raw, v2.raw, v3.raw}}}; + detail::StoreInterleaved4(tup, detail::NativeLanePointer(buf)); + CopyBytes(buf, unaligned); +} + +#if HWY_ARCH_ARM_V7 +// 64x2: split into two 64x1 +template , HWY_IF_T_SIZE(T, 8), + HWY_NEON_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved4(Vec128 v0, Vec128 v1, Vec128 v2, + Vec128 v3, D d, T* HWY_RESTRICT unaligned) { + const Half dh; + StoreInterleaved4(LowerHalf(dh, v0), LowerHalf(dh, v1), LowerHalf(dh, v2), + LowerHalf(dh, v3), dh, + detail::NativeLanePointer(unaligned)); + StoreInterleaved4(UpperHalf(dh, v0), UpperHalf(dh, v1), UpperHalf(dh, v2), + UpperHalf(dh, v3), dh, + detail::NativeLanePointer(unaligned + 
4)); +} +#endif // HWY_ARCH_ARM_V7 + +#undef HWY_IF_STORE_INT + +// Fall back on generic Load/StoreInterleaved[234] for any emulated types. +// Requires HWY_GENERIC_IF_EMULATED_D mirrors HWY_NEON_IF_EMULATED_D. + +// ------------------------------ Additional mask logical operations +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const FixedTag d; + const auto vmask = VecFromMask(d, mask); + return MaskFromVec(Or(vmask, InterleaveLower(vmask, vmask))); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Simd d; + const auto vmask = VecFromMask(d, mask); + const auto neg_vmask = + ResizeBitCast(d, Neg(ResizeBitCast(Full64(), vmask))); + return MaskFromVec(Or(vmask, neg_vmask)); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Full128 d; + const Repartition di64; + + auto vmask = BitCast(di64, VecFromMask(d, mask)); + vmask = Or(vmask, Neg(vmask)); + + // Copy the sign bit of the first int64_t lane to the second int64_t lane + const auto vmask2 = BroadcastSignBit(InterleaveLower(Zero(di64), vmask)); + return MaskFromVec(BitCast(d, Or(vmask, vmask2))); +} + +template +HWY_API Mask128 SetBeforeFirst(Mask128 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const FixedTag d; + const RebindToSigned di; + + const auto vmask = BitCast(di, VecFromMask(d, mask)); + const auto zero = Zero(di); + const auto vmask2 = VecFromMask(di, InterleaveLower(zero, vmask) == zero); + return MaskFromVec(BitCast(d, And(vmask, vmask2))); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Simd d; + const RebindToSigned di; + + const auto vmask = ResizeBitCast(Full64(), VecFromMask(d, mask)); + const auto only_first_vmask = + BitCast(d, Neg(ResizeBitCast(di, And(vmask, Neg(vmask))))); + return 
MaskFromVec(only_first_vmask); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Full128 d; + const RebindToSigned di; + const Repartition di64; + + const auto zero = Zero(di64); + const auto vmask = BitCast(di64, VecFromMask(d, mask)); + const auto vmask2 = VecFromMask(di64, InterleaveLower(zero, vmask) == zero); + const auto only_first_vmask = Neg(BitCast(di, And(vmask, Neg(vmask)))); + return MaskFromVec(BitCast(d, And(only_first_vmask, BitCast(di, vmask2)))); +} + +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 /*mask*/) { + const FixedTag d; + const RebindToSigned di; + using TI = MakeSigned; + + return RebindMask(d, MaskFromVec(Set(di, TI(-1)))); +} +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { + const Simd d; + return SetBeforeFirst(MaskFromVec(ShiftLeftLanes<1>(VecFromMask(d, mask)))); +} + +// ------------------------------ Lt128 + +template +HWY_INLINE MFromD Lt128(D d, VFromD a, VFromD b) { + static_assert(IsSame, uint64_t>(), "T must be u64"); + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const MFromD eqHL = Eq(a, b); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + // We need to bring cL to the upper lane/bit corresponding to cH. Comparing + // the result of InterleaveUpper/Lower requires 9 ops, whereas shifting the + // comparison result leftwards requires only 4. IfThenElse compiles to the + // same code as OrAnd(). 
+ const VFromD ltLx = DupEven(ltHL); + const VFromD outHx = IfThenElse(eqHL, ltLx, ltHL); + return MaskFromVec(DupOdd(outHx)); +} + +template +HWY_INLINE MFromD Lt128Upper(D d, VFromD a, VFromD b) { + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + return MaskFromVec(InterleaveUpper(d, ltHL, ltHL)); +} + +// ------------------------------ Eq128 + +template +HWY_INLINE MFromD Eq128(D d, VFromD a, VFromD b) { + static_assert(IsSame, uint64_t>(), "T must be u64"); + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(And(Reverse2(d, eqHL), eqHL)); +} + +template +HWY_INLINE MFromD Eq128Upper(D d, VFromD a, VFromD b) { + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(InterleaveUpper(d, eqHL, eqHL)); +} + +// ------------------------------ Ne128 + +template +HWY_INLINE MFromD Ne128(D d, VFromD a, VFromD b) { + static_assert(IsSame, uint64_t>(), "T must be u64"); + const VFromD neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(Or(Reverse2(d, neHL), neHL)); +} + +template +HWY_INLINE MFromD Ne128Upper(D d, VFromD a, VFromD b) { + const VFromD neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(InterleaveUpper(d, neHL, neHL)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Without a native OddEven, it seems infeasible to go faster than Lt128. 
+template +HWY_INLINE VFromD Min128(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template +HWY_INLINE VFromD Min128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex + +#ifdef HWY_NATIVE_LEADING_ZERO_COUNT +#undef HWY_NATIVE_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_LEADING_ZERO_COUNT +#endif + +HWY_NEON_DEF_FUNCTION_INT_8_16_32(LeadingZeroCount, vclz, _, 1) +HWY_NEON_DEF_FUNCTION_UINT_8_16_32(LeadingZeroCount, vclz, _, 1) + +template )> +HWY_API V LeadingZeroCount(V v) { + const DFromV d; + const RebindToUnsigned du; + const Repartition du32; + + const auto v_k32 = BitCast(du32, Set(du, 32)); + const auto v_u32_lzcnt = LeadingZeroCount(BitCast(du32, v)) + v_k32; + const auto v_u32_lo_lzcnt = + And(v_u32_lzcnt, BitCast(du32, Set(du, 0xFFFFFFFFu))); + const auto v_u32_hi_lzcnt = + BitCast(du32, ShiftRight<32>(BitCast(du, v_u32_lzcnt))); + + return BitCast( + d, IfThenElse(v_u32_hi_lzcnt == v_k32, v_u32_lo_lzcnt, v_u32_hi_lzcnt)); +} + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + using T = TFromD; + return BitCast(d, Set(d, T{sizeof(T) * 8 - 1}) - LeadingZeroCount(v)); +} + +template +HWY_API V TrailingZeroCount(V v) { + return LeadingZeroCount(ReverseBits(v)); +} + +template +HWY_API V TrailingZeroCount(V v) { + const DFromV d; + const Repartition du8; + return LeadingZeroCount( + ReverseLaneBytes(BitCast(d, ReverseBits(BitCast(du8, v))))); +} + +namespace detail { // for code folding +#if HWY_ARCH_ARM_V7 +#undef vuzp1_s8 +#undef vuzp1_u8 +#undef vuzp1_s16 +#undef vuzp1_u16 +#undef vuzp1_s32 +#undef vuzp1_u32 +#undef vuzp1_f32 +#undef vuzp1q_s8 +#undef vuzp1q_u8 
+#undef vuzp1q_s16 +#undef vuzp1q_u16 +#undef vuzp1q_s32 +#undef vuzp1q_u32 +#undef vuzp1q_f32 +#undef vuzp2_s8 +#undef vuzp2_u8 +#undef vuzp2_s16 +#undef vuzp2_u16 +#undef vuzp2_s32 +#undef vuzp2_u32 +#undef vuzp2_f32 +#undef vuzp2q_s8 +#undef vuzp2q_u8 +#undef vuzp2q_s16 +#undef vuzp2q_u16 +#undef vuzp2q_s32 +#undef vuzp2q_u32 +#undef vuzp2q_f32 +#undef vzip1_s8 +#undef vzip1_u8 +#undef vzip1_s16 +#undef vzip1_u16 +#undef vzip1_s32 +#undef vzip1_u32 +#undef vzip1_f32 +#undef vzip1q_s8 +#undef vzip1q_u8 +#undef vzip1q_s16 +#undef vzip1q_u16 +#undef vzip1q_s32 +#undef vzip1q_u32 +#undef vzip1q_f32 +#undef vzip2_s8 +#undef vzip2_u8 +#undef vzip2_s16 +#undef vzip2_u16 +#undef vzip2_s32 +#undef vzip2_u32 +#undef vzip2_f32 +#undef vzip2q_s8 +#undef vzip2q_u8 +#undef vzip2q_s16 +#undef vzip2q_u16 +#undef vzip2q_s32 +#undef vzip2q_u32 +#undef vzip2q_f32 +#endif + +#undef HWY_NEON_BUILD_ARG_1 +#undef HWY_NEON_BUILD_ARG_2 +#undef HWY_NEON_BUILD_ARG_3 +#undef HWY_NEON_BUILD_PARAM_1 +#undef HWY_NEON_BUILD_PARAM_2 +#undef HWY_NEON_BUILD_PARAM_3 +#undef HWY_NEON_BUILD_RET_1 +#undef HWY_NEON_BUILD_RET_2 +#undef HWY_NEON_BUILD_RET_3 +#undef HWY_NEON_BUILD_TPL_1 +#undef HWY_NEON_BUILD_TPL_2 +#undef HWY_NEON_BUILD_TPL_3 +#undef HWY_NEON_DEF_FUNCTION +#undef HWY_NEON_DEF_FUNCTION_ALL_FLOATS +#undef HWY_NEON_DEF_FUNCTION_ALL_TYPES +#undef HWY_NEON_DEF_FUNCTION_BFLOAT_16 +#undef HWY_NEON_DEF_FUNCTION_FLOAT_16 +#undef HWY_NEON_DEF_FUNCTION_FLOAT_16_32 +#undef HWY_NEON_DEF_FUNCTION_FLOAT_32 +#undef HWY_NEON_DEF_FUNCTION_FLOAT_64 +#undef HWY_NEON_DEF_FUNCTION_FULL_UI +#undef HWY_NEON_DEF_FUNCTION_FULL_UI_64 +#undef HWY_NEON_DEF_FUNCTION_FULL_UIF_64 +#undef HWY_NEON_DEF_FUNCTION_INT_16 +#undef HWY_NEON_DEF_FUNCTION_INT_32 +#undef HWY_NEON_DEF_FUNCTION_INT_64 +#undef HWY_NEON_DEF_FUNCTION_INT_8 +#undef HWY_NEON_DEF_FUNCTION_INT_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_INTS +#undef HWY_NEON_DEF_FUNCTION_INTS_UINTS +#undef HWY_NEON_DEF_FUNCTION_UI_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_UIF_64 
+#undef HWY_NEON_DEF_FUNCTION_UIF_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_UINT_16 +#undef HWY_NEON_DEF_FUNCTION_UINT_32 +#undef HWY_NEON_DEF_FUNCTION_UINT_64 +#undef HWY_NEON_DEF_FUNCTION_UINT_8 +#undef HWY_NEON_DEF_FUNCTION_UINT_8_16_32 +#undef HWY_NEON_DEF_FUNCTION_UINTS +#undef HWY_NEON_EVAL +#undef HWY_NEON_IF_EMULATED_D +#undef HWY_NEON_IF_NOT_EMULATED_D +} // namespace detail + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/lib/highway/hwy/ops/arm_sve-inl.h b/lib/highway/hwy/ops/arm_sve-inl.h new file mode 100644 index 00000000000..c90bf19dde5 --- /dev/null +++ b/lib/highway/hwy/ops/arm_sve-inl.h @@ -0,0 +1,7166 @@ +// Copyright 2021 Google LLC +// Copyright 2025 Arm Limited and/or its affiliates +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Arm SVE[2] vectors (length not known at compile time). +// External include guard in highway.h - see comment there. + +#include + +#include "hwy/ops/shared-inl.h" + +// Arm C215 declares that SVE vector lengths will always be a power of two. +// We default to relying on this, which makes some operations more efficient. +// You can still opt into fixups by setting this to 0 (unsupported). 
+#ifndef HWY_SVE_IS_POW2 +#define HWY_SVE_IS_POW2 1 +#endif + +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 +#define HWY_SVE_HAVE_2 1 +#else +#define HWY_SVE_HAVE_2 0 +#endif + +// HWY_SVE_HAVE_BF16_VEC is defined to 1 if the SVE svbfloat16_t vector type +// is supported, even if HWY_SVE_HAVE_BF16_FEATURE (= intrinsics) is 0. +#if HWY_SVE_HAVE_BF16_FEATURE || \ + (HWY_COMPILER_CLANG >= 1200 && defined(__ARM_FEATURE_SVE_BF16)) || \ + HWY_COMPILER_GCC_ACTUAL >= 1000 +#define HWY_SVE_HAVE_BF16_VEC 1 +#else +#define HWY_SVE_HAVE_BF16_VEC 0 +#endif + +// HWY_SVE_HAVE_F32_TO_BF16C is defined to 1 if the SVE svcvt_bf16_f32_x +// and svcvtnt_bf16_f32_x intrinsics are available, even if the __bf16 type +// is disabled +#if HWY_SVE_HAVE_BF16_VEC && defined(__ARM_FEATURE_SVE_BF16) +#define HWY_SVE_HAVE_F32_TO_BF16C 1 +#else +#define HWY_SVE_HAVE_F32_TO_BF16C 0 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +struct DFromV_t {}; // specialized in macros +template +using DFromV = typename DFromV_t>::type; + +template +using TFromV = TFromD>; + +// ================================================== MACROS + +// Generate specializations and function definitions using X macros. Although +// harder to read and debug, writing everything manually is too bulky. 
+ +namespace detail { // for code folding + +// Args: BASE, CHAR, BITS, HALF, NAME, OP + +// Unsigned: +#define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, 8, NAME, OP) +#define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, 8, NAME, OP) +#define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + X_MACRO(uint, u, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \ + X_MACRO(uint, u, 64, 32, NAME, OP) + +// Signed: +#define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, 8, NAME, OP) +#define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, 8, NAME, OP) +#define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) X_MACRO(int, s, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) X_MACRO(int, s, 64, 32, NAME, OP) + +// Float: +#define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 16, 16, NAME, OP) +#define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 32, 16, NAME, OP) +#define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) \ + X_MACRO(float, f, 64, 32, NAME, OP) + +#define HWY_SVE_FOREACH_BF16_UNCONDITIONAL(X_MACRO, NAME, OP) \ + X_MACRO(bfloat, bf, 16, 16, NAME, OP) + +#if HWY_SVE_HAVE_BF16_FEATURE +#define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_BF16_UNCONDITIONAL(X_MACRO, NAME, OP) +// We have both f16 and bf16, so nothing is emulated. 
+ +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the D template +// argument +#define HWY_SVE_IF_EMULATED_D(D) hwy::EnableIf()>* = nullptr +#define HWY_GENERIC_IF_EMULATED_D(D) \ + hwy::EnableIf()>* = nullptr +#define HWY_SVE_IF_NOT_EMULATED_D(D) hwy::EnableIf* = nullptr +#else +#define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP) +#define HWY_SVE_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_SVE_IF_NOT_EMULATED_D(D) HWY_IF_NOT_BF16_D(D) +#endif // HWY_SVE_HAVE_BF16_FEATURE + +// For all element sizes: +#define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) + +// HWY_SVE_FOREACH_F does not include HWY_SVE_FOREACH_BF16 because SVE lacks +// bf16 overloads for some intrinsics (especially less-common arithmetic). +// However, this does include f16 because SVE supports it unconditionally. 
+#define HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) + +// Commonly used type categories for a given element size: +#define HWY_SVE_FOREACH_UI08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_UIF3264(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) + +// Commonly used type categories: +#define HWY_SVE_FOREACH_UI(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH_IF(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) + +#define HWY_SVE_FOREACH(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \ + HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) + +// Assemble types for use in x-macros +#define HWY_SVE_T(BASE, BITS) BASE##BITS##_t +#define HWY_SVE_D(BASE, BITS, N, POW2) Simd +#define HWY_SVE_V(BASE, BITS) sv##BASE##BITS##_t +#define HWY_SVE_TUPLE(BASE, BITS, MUL) sv##BASE##BITS##x##MUL##_t + +} // namespace detail + +#define HWY_SPECIALIZE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template <> \ + struct DFromV_t { \ + using type = ScalableTag; \ + }; + +HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _) +#endif +#undef HWY_SPECIALIZE + +// 
Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX +// instructions, and we anyway only use it when the predicate is ptrue. + +// vector = f(vector), e.g. Not +#define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } +#define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } +#define HWY_SVE_RETV_ARGMV_M(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_V(BASE, BITS) a) { \ + return sv##OP##_##CHAR##BITS##_m(no, m, a); \ + } +#define HWY_SVE_RETV_ARGMV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(m, v); \ + } +#define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \ + return sv##OP##_##CHAR##BITS##_z(m, a); \ + } + +// vector = f(vector, scalar), e.g. detail::AddN +#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ + } +#define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +// vector = f(vector, vector), e.g. 
Add +#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } +// All-true mask +#define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \ + } +// User-specified mask. Mask=false value is undefined and must be set by caller +// because SVE instructions take it from one of the two inputs, whereas +// AVX-512, RVV and Highway allow a third argument. +#define HWY_SVE_RETV_ARGMVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(m, a, b); \ + } +// User-specified mask. Mask=false value is zero. +#define HWY_SVE_RETV_ARGMVV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_z(m, a, b); \ + } + +#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ + HWY_SVE_V(BASE, BITS) c) { \ + return sv##OP##_##CHAR##BITS(a, b, c); \ + } +#define HWY_SVE_RETV_ARGMVVV(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ + HWY_SVE_V(BASE, BITS) c) { \ + return sv##OP##_##CHAR##BITS##_x(m, a, b, c); \ + } +#define HWY_SVE_RETV_ARGMVVV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x, \ + HWY_SVE_V(BASE, BITS) add) { \ + return sv##OP##_##CHAR##BITS##_z(m, x, mul, add); \ + } + +// ------------------------------ Lanes + +namespace detail { + +// Returns 
actual lanes of a hardware vector without rounding to a power of two. +template +HWY_INLINE size_t AllHardwareLanes() { + return svcntb_pat(SV_ALL); +} +template +HWY_INLINE size_t AllHardwareLanes() { + return svcnth_pat(SV_ALL); +} +template +HWY_INLINE size_t AllHardwareLanes() { + return svcntw_pat(SV_ALL); +} +template +HWY_INLINE size_t AllHardwareLanes() { + return svcntd_pat(SV_ALL); +} + +// All-true mask from a macro + +#if HWY_SVE_IS_POW2 +#define HWY_SVE_ALL_PTRUE(BITS) svptrue_b##BITS() +#define HWY_SVE_PTRUE(BITS) svptrue_b##BITS() +#else +#define HWY_SVE_ALL_PTRUE(BITS) svptrue_pat_b##BITS(SV_ALL) +#define HWY_SVE_PTRUE(BITS) svptrue_pat_b##BITS(SV_POW2) +#endif // HWY_SVE_IS_POW2 + +} // namespace detail + +#if HWY_HAVE_SCALABLE + +// Returns actual number of lanes after capping by N and shifting. May return 0 +// (e.g. for "1/8th" of a u32x4 - would be 1 for 1/8th of u32x8). +template +HWY_API size_t Lanes(Simd d) { + const size_t actual = detail::AllHardwareLanes(); + constexpr size_t kMaxLanes = MaxLanes(d); + constexpr int kClampedPow2 = HWY_MIN(kPow2, 0); + // Common case of full vectors: avoid any extra instructions. + if (detail::IsFull(d)) return actual; + return HWY_MIN(detail::ScaleByPower(actual, kClampedPow2), kMaxLanes); +} + +#endif // HWY_HAVE_SCALABLE + +// ================================================== MASK INIT + +// One mask bit per byte; only the one belonging to the lowest byte is valid. + +// ------------------------------ FirstN +#define HWY_SVE_FIRSTN(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, size_t count) { \ + const size_t limit = detail::IsFull(d) ? 
count : HWY_MIN(Lanes(d), count); \ + return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast(limit)); \ + } +HWY_SVE_FOREACH(HWY_SVE_FIRSTN, FirstN, whilelt) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_FIRSTN, FirstN, whilelt) +#endif + +template +svbool_t FirstN(D /* tag */, size_t count) { + return FirstN(RebindToUnsigned(), count); +} + +#undef HWY_SVE_FIRSTN + +template +using MFromD = svbool_t; + +namespace detail { + +#define HWY_SVE_WRAP_PTRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return HWY_SVE_PTRUE(BITS); \ + } \ + template \ + HWY_API svbool_t All##NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return HWY_SVE_ALL_PTRUE(BITS); \ + } + +HWY_SVE_FOREACH(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) // return all-true +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_WRAP_PTRUE, PTrue, ptrue) +#undef HWY_SVE_WRAP_PTRUE + +HWY_API svbool_t PFalse() { return svpfalse_b(); } + +// Returns all-true if d is HWY_FULL or FirstN(N) after capping N. +// +// This is used in functions that load/store memory; other functions (e.g. +// arithmetic) can ignore d and use PTrue instead. +// +// Always use FirstN(N) for HWY_TARGET == HWY_SVE2_128 to avoid vector length +// information loss when using PTrue(d) predicates in memory intrinsics. +// +// SVE2_256 is untested due to unavailable hardware and cannot assume +// equal minimum and maximum vector lengths as SVE2_128 can. 
+template +svbool_t MakeMask(D d) { +#if HWY_TARGET != HWY_SVE2_128 + HWY_IF_CONSTEXPR(IsFull(d)) { return PTrue(d); } +#endif + return FirstN(d, Lanes(d)); +} + +} // namespace detail + +#ifdef HWY_NATIVE_MASK_FALSE +#undef HWY_NATIVE_MASK_FALSE +#else +#define HWY_NATIVE_MASK_FALSE +#endif + +template +HWY_API svbool_t MaskFalse(const D /*d*/) { + return detail::PFalse(); +} + +#ifdef HWY_NATIVE_SET_MASK +#undef HWY_NATIVE_SET_MASK +#else +#define HWY_NATIVE_SET_MASK +#endif + +template +HWY_API svbool_t SetMask(D d, bool val) { + // The SVE svdup_n_b* intrinsics are equivalent to the FirstN op below if + // detail::IsFull(d) is true since svdup_n_b* is simply a wrapper around the + // SVE whilelo instruction. + return FirstN(d, size_t{0} - static_cast(val)); +} + +// ================================================== INIT + +// ------------------------------ Set +// vector = f(d, scalar), e.g. Set +#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) arg) { \ + return sv##OP##_##CHAR##BITS(arg); \ + } + +HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n) +#if HWY_SVE_HAVE_BF16_FEATURE // for if-elif chain +HWY_SVE_FOREACH_BF16(HWY_SVE_SET, Set, dup_n) +#elif HWY_SVE_HAVE_BF16_VEC +// Required for Zero and VFromD +template +HWY_API svbfloat16_t Set(D d, bfloat16_t arg) { + return svreinterpret_bf16_u16( + Set(RebindToUnsigned(), BitCastScalar(arg))); +} +#else // neither bf16 feature nor vector: emulate with u16 +// Required for Zero and VFromD +template +HWY_API svuint16_t Set(D d, bfloat16_t arg) { + const RebindToUnsigned du; + return Set(du, BitCastScalar(arg)); +} +#endif // HWY_SVE_HAVE_BF16_FEATURE +#undef HWY_SVE_SET + +template +using VFromD = decltype(Set(D(), TFromD())); + +using VBF16 = VFromD>; + +// ------------------------------ MaskedSetOr/MaskedSet + +#define HWY_SVE_MASKED_SET_OR(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API 
HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_T(BASE, BITS) op) { \ + return sv##OP##_##CHAR##BITS##_m(no, m, op); \ + } + +HWY_SVE_FOREACH(HWY_SVE_MASKED_SET_OR, MaskedSetOr, dup_n) +#undef HWY_SVE_MASKED_SET_OR + +#define HWY_SVE_MASKED_SET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + svbool_t m, HWY_SVE_T(BASE, BITS) op) { \ + return sv##OP##_##CHAR##BITS##_z(m, op); \ + } + +HWY_SVE_FOREACH(HWY_SVE_MASKED_SET, MaskedSet, dup_n) +#undef HWY_SVE_MASKED_SET + +// ------------------------------ Zero + +template +VFromD Zero(D d) { + // Cast to support bfloat16_t. + const RebindToUnsigned du; + return BitCast(d, Set(du, 0)); +} + +// ------------------------------ BitCast + +namespace detail { + +// u8: no change +#define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ + return v; \ + } \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return v; \ + } + +// All other types +#define HWY_SVE_CAST(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_u8_##CHAR##BITS(v); \ + } \ + template \ + HWY_INLINE HWY_SVE_V(BASE, BITS) \ + BitCastFromByte(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svuint8_t v) { \ + return sv##OP##_##CHAR##BITS##_u8(v); \ + } + +// U08 is special-cased, hence do not use FOREACH. 
+HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _) +HWY_SVE_FOREACH_I08(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI16(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI32(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_UI64(HWY_SVE_CAST, _, reinterpret) +HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret) + +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CAST, _, reinterpret) +#else // !(HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC) +template )> +HWY_INLINE svuint8_t BitCastToByte(V v) { + const RebindToUnsigned> du; + return BitCastToByte(BitCast(du, v)); +} + +template +HWY_INLINE VFromD BitCastFromByte(D d, svuint8_t v) { + const RebindToUnsigned du; + return BitCastFromByte(du, v); +} +#endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC + +#undef HWY_SVE_CAST_NOP +#undef HWY_SVE_CAST + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, FromV v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Undefined + +#define HWY_SVE_UNDEFINED(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \ + return sv##OP##_##CHAR##BITS(); \ + } + +HWY_SVE_FOREACH(HWY_SVE_UNDEFINED, Undefined, undef) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_UNDEFINED, Undefined, undef) +#endif + +template +VFromD Undefined(D d) { + const RebindToUnsigned du; + return BitCast(d, Undefined(du)); +} + +// ------------------------------ Tuple + +// tuples = f(d, v..), e.g. 
Create2 +#define HWY_SVE_CREATE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_TUPLE(BASE, BITS, 2) \ + NAME##2(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1) { \ + return sv##OP##2_##CHAR##BITS(v0, v1); \ + } \ + template \ + HWY_API HWY_SVE_TUPLE(BASE, BITS, 3) NAME##3( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v0, \ + HWY_SVE_V(BASE, BITS) v1, HWY_SVE_V(BASE, BITS) v2) { \ + return sv##OP##3_##CHAR##BITS(v0, v1, v2); \ + } \ + template \ + HWY_API HWY_SVE_TUPLE(BASE, BITS, 4) \ + NAME##4(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3) { \ + return sv##OP##4_##CHAR##BITS(v0, v1, v2, v3); \ + } + +HWY_SVE_FOREACH(HWY_SVE_CREATE, Create, create) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CREATE, Create, create) +#endif +#undef HWY_SVE_CREATE + +template +using Vec2 = decltype(Create2(D(), Zero(D()), Zero(D()))); +template +using Vec3 = decltype(Create3(D(), Zero(D()), Zero(D()), Zero(D()))); +template +using Vec4 = decltype(Create4(D(), Zero(D()), Zero(D()), Zero(D()), Zero(D()))); + +#define HWY_SVE_GET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME##2(HWY_SVE_TUPLE(BASE, BITS, 2) tuple) { \ + return sv##OP##2_##CHAR##BITS(tuple, kIndex); \ + } \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME##3(HWY_SVE_TUPLE(BASE, BITS, 3) tuple) { \ + return sv##OP##3_##CHAR##BITS(tuple, kIndex); \ + } \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME##4(HWY_SVE_TUPLE(BASE, BITS, 4) tuple) { \ + return sv##OP##4_##CHAR##BITS(tuple, kIndex); \ + } + +HWY_SVE_FOREACH(HWY_SVE_GET, Get, get) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_GET, Get, get) +#endif +#undef HWY_SVE_GET + +#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP) 
\ + template \ + HWY_API HWY_SVE_TUPLE(BASE, BITS, 2) \ + NAME##2(HWY_SVE_TUPLE(BASE, BITS, 2) tuple, HWY_SVE_V(BASE, BITS) vec) { \ + return sv##OP##2_##CHAR##BITS(tuple, kIndex, vec); \ + } \ + template \ + HWY_API HWY_SVE_TUPLE(BASE, BITS, 3) \ + NAME##3(HWY_SVE_TUPLE(BASE, BITS, 3) tuple, HWY_SVE_V(BASE, BITS) vec) { \ + return sv##OP##3_##CHAR##BITS(tuple, kIndex, vec); \ + } \ + template \ + HWY_API HWY_SVE_TUPLE(BASE, BITS, 4) \ + NAME##4(HWY_SVE_TUPLE(BASE, BITS, 4) tuple, HWY_SVE_V(BASE, BITS) vec) { \ + return sv##OP##4_##CHAR##BITS(tuple, kIndex, vec); \ + } + +HWY_SVE_FOREACH(HWY_SVE_SET, Set, set) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_SET, Set, set) +#endif +#undef HWY_SVE_SET + +// ------------------------------ ResizeBitCast + +// Same as BitCast on SVE +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, v); +} + +// ------------------------------ Dup128VecFromValues + +template +HWY_API svint8_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return svdupq_n_s8(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, + t14, t15); +} + +template +HWY_API svuint8_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return svdupq_n_u8(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, + t14, t15); +} + +template +HWY_API svint16_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return svdupq_n_s16(t0, t1, t2, t3, t4, t5, t6, t7); +} + +template +HWY_API svuint16_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, 
TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return svdupq_n_u16(t0, t1, t2, t3, t4, t5, t6, t7); +} + +template +HWY_API svfloat16_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, + TFromD t4, TFromD t5, + TFromD t6, TFromD t7) { + return svdupq_n_f16(t0, t1, t2, t3, t4, t5, t6, t7); +} + +template +HWY_API VBF16 Dup128VecFromValues(D d, TFromD t0, TFromD t1, TFromD t2, + TFromD t3, TFromD t4, TFromD t5, + TFromD t6, TFromD t7) { +#if HWY_SVE_HAVE_BF16_FEATURE + (void)d; + return svdupq_n_bf16(t0, t1, t2, t3, t4, t5, t6, t7); +#else + const RebindToUnsigned du; + return BitCast( + d, Dup128VecFromValues( + du, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +#endif +} + +template +HWY_API svint32_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return svdupq_n_s32(t0, t1, t2, t3); +} + +template +HWY_API svuint32_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return svdupq_n_u32(t0, t1, t2, t3); +} + +template +HWY_API svfloat32_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return svdupq_n_f32(t0, t1, t2, t3); +} + +template +HWY_API svint64_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return svdupq_n_s64(t0, t1); +} + +template +HWY_API svuint64_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return svdupq_n_u64(t0, t1); +} + +template +HWY_API svfloat64_t Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return svdupq_n_f64(t0, t1); +} + +// ------------------------------ GetLane + +namespace detail { +#define HWY_SVE_GET_LANE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE HWY_SVE_T(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_GET_LANE, GetLaneM, lasta) +HWY_SVE_FOREACH(HWY_SVE_GET_LANE, 
ExtractLastMatchingLaneM, lastb) +#undef HWY_SVE_GET_LANE +} // namespace detail + +template +HWY_API TFromV GetLane(V v) { + return detail::GetLaneM(v, detail::PFalse()); +} + +// ================================================== LOGICAL + +// detail::*N() functions accept a scalar argument to avoid extra Set(). + +// ------------------------------ Not +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPV, Not, not ) // NOLINT + +// ------------------------------ And + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, AndN, and_n) +} // namespace detail + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, And, and) + +template +HWY_API V And(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, And(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ Or + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, OrN, orr_n) +} // namespace detail + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Or, orr) + +template +HWY_API V Or(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, Or(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ MaskedOr +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedOr, orr) + +// ------------------------------ Xor + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, XorN, eor_n) +} // namespace detail + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Xor, eor) + +template +HWY_API V Xor(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, Xor(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ AndNot + +namespace detail { +#define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN_SWAP, AndNotN, bic_n) +#undef HWY_SVE_RETV_ARGPVN_SWAP +} // namespace detail 
+ +#define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a); \ + } +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP, AndNot, bic) +#undef HWY_SVE_RETV_ARGPVV_SWAP + +template +HWY_API V AndNot(const V a, const V b) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, AndNot(BitCast(du, a), BitCast(du, b))); +} + +// ------------------------------ Xor3 + +#if HWY_SVE_HAVE_2 + +#ifdef HWY_NATIVE_XOR3 +#undef HWY_NATIVE_XOR3 +#else +#define HWY_NATIVE_XOR3 +#endif + +#ifdef HWY_NATIVE_BCAX +#undef HWY_NATIVE_BCAX +#else +#define HWY_NATIVE_BCAX +#endif + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVVV, Xor3, eor3) + +// As with AndNot, we follow the x86 convention where the first argument is +// negated, whereas the intrinsic has it last. +#define HWY_SVE_RETV_ARGVVV_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ + HWY_SVE_V(BASE, BITS) c) { \ + return sv##OP##_##CHAR##BITS(a, c, b); \ + } +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVVV_SWAP, XorAndNot, bcax) +#undef HWY_SVE_RETV_ARGVVV_SWAP + +template +HWY_API V Xor3(const V x1, const V x2, const V x3) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3))); +} + +template +HWY_API V XorAndNot(const V x, const V a1, const V a2) { + const DFromV df; + const RebindToUnsigned du; + return BitCast(df, + XorAndNot(BitCast(du, x), BitCast(du, a1), BitCast(du, a2))); +} + +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ Or3 +template +HWY_API V Or3(V o1, V o2, V o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API V OrAnd(const V o, const V a1, const V a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ PopulationCount + 
+#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +// Need to return original type instead of unsigned. +#define HWY_SVE_POPCNT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return BitCast(DFromV(), \ + sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v)); \ + } +HWY_SVE_FOREACH_UI(HWY_SVE_POPCNT, PopulationCount, cnt) +#undef HWY_SVE_POPCNT + +// ================================================== SIGN + +// ------------------------------ Neg +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Neg, neg) + +HWY_API VBF16 Neg(VBF16 v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + return BitCast(d, Xor(BitCast(du, v), Set(du, SignMask()))); +} + +// ------------------------------ SaturatedNeg +#if HWY_SVE_HAVE_2 +#ifdef HWY_NATIVE_SATURATED_NEG_8_16_32 +#undef HWY_NATIVE_SATURATED_NEG_8_16_32 +#else +#define HWY_NATIVE_SATURATED_NEG_8_16_32 +#endif + +#ifdef HWY_NATIVE_SATURATED_NEG_64 +#undef HWY_NATIVE_SATURATED_NEG_64 +#else +#define HWY_NATIVE_SATURATED_NEG_64 +#endif + +HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedNeg, qneg) +#endif // HWY_SVE_HAVE_2 + +// ================================================== ARITHMETIC + +// Per-target flags to prevent generic_ops-inl.h defining Add etc. +#ifdef HWY_NATIVE_OPERATOR_REPLACEMENTS +#undef HWY_NATIVE_OPERATOR_REPLACEMENTS +#else +#define HWY_NATIVE_OPERATOR_REPLACEMENTS +#endif + +// ------------------------------ Add + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN, AddN, add_n) +} // namespace detail + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Add, add) + +// ------------------------------ Sub + +namespace detail { +// Can't use HWY_SVE_RETV_ARGPVN because caller wants to specify pg. 
+#define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_z(pg, a, b); \ + } + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN_MASK, SubN, sub_n) +#undef HWY_SVE_RETV_ARGPVN_MASK +} // namespace detail + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Sub, sub) + +// ------------------------------ SumsOf8 +HWY_API svuint64_t SumsOf8(const svuint8_t v) { + const ScalableTag du32; + const ScalableTag du64; + const svbool_t pg = detail::PTrue(du64); + + const svuint32_t sums_of_4 = svdot_n_u32(Zero(du32), v, 1); + // Compute pairwise sum of u32 and extend to u64. + +#if HWY_SVE_HAVE_2 + return svadalp_u64_x(pg, Zero(du64), sums_of_4); +#else + const svuint64_t hi = svlsr_n_u64_x(pg, BitCast(du64, sums_of_4), 32); + // Isolate the lower 32 bits (to be added to the upper 32 and zero-extended) + const svuint64_t lo = svextw_u64_x(pg, BitCast(du64, sums_of_4)); + return Add(hi, lo); +#endif +} + +HWY_API svint64_t SumsOf8(const svint8_t v) { + const ScalableTag di32; + const ScalableTag di64; + const svbool_t pg = detail::PTrue(di64); + + const svint32_t sums_of_4 = svdot_n_s32(Zero(di32), v, 1); +#if HWY_SVE_HAVE_2 + return svadalp_s64_x(pg, Zero(di64), sums_of_4); +#else + const svint64_t hi = svasr_n_s64_x(pg, BitCast(di64, sums_of_4), 32); + // Isolate the lower 32 bits (to be added to the upper 32 and sign-extended) + const svint64_t lo = svextw_s64_x(pg, BitCast(di64, sums_of_4)); + return Add(hi, lo); +#endif +} + +// ------------------------------ SumsOf2 +#if HWY_SVE_HAVE_2 +namespace detail { + +HWY_INLINE svint16_t SumsOf2(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, svint8_t v) { + const ScalableTag di16; + const svbool_t pg = detail::PTrue(di16); + return svadalp_s16_x(pg, Zero(di16), v); +} + +HWY_INLINE svuint16_t SumsOf2(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, svuint8_t v) { + 
const ScalableTag du16; + const svbool_t pg = detail::PTrue(du16); + return svadalp_u16_x(pg, Zero(du16), v); +} + +HWY_INLINE svint32_t SumsOf2(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, svint16_t v) { + const ScalableTag di32; + const svbool_t pg = detail::PTrue(di32); + return svadalp_s32_x(pg, Zero(di32), v); +} + +HWY_INLINE svuint32_t SumsOf2(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, svuint16_t v) { + const ScalableTag du32; + const svbool_t pg = detail::PTrue(du32); + return svadalp_u32_x(pg, Zero(du32), v); +} + +HWY_INLINE svint64_t SumsOf2(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, svint32_t v) { + const ScalableTag di64; + const svbool_t pg = detail::PTrue(di64); + return svadalp_s64_x(pg, Zero(di64), v); +} + +HWY_INLINE svuint64_t SumsOf2(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, svuint32_t v) { + const ScalableTag du64; + const svbool_t pg = detail::PTrue(du64); + return svadalp_u64_x(pg, Zero(du64), v); +} + +} // namespace detail +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ SumsOf4 +namespace detail { + +HWY_INLINE svint32_t SumsOf4(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, svint8_t v) { + return svdot_n_s32(Zero(ScalableTag()), v, 1); +} + +HWY_INLINE svuint32_t SumsOf4(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, svuint8_t v) { + return svdot_n_u32(Zero(ScalableTag()), v, 1); +} + +HWY_INLINE svint64_t SumsOf4(hwy::SignedTag /*type_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, svint16_t v) { + return svdot_n_s64(Zero(ScalableTag()), v, 1); +} + +HWY_INLINE svuint64_t SumsOf4(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, svuint16_t v) { + return svdot_n_u64(Zero(ScalableTag()), v, 1); +} + +} // namespace detail + +// ------------------------------ SaturatedAdd + +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define 
HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB +#undef HWY_NATIVE_U32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB +#undef HWY_NATIVE_U64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U64_SATURATED_ADDSUB +#endif + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd) + +// ------------------------------ SaturatedSub + +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub) + +// ------------------------------ AbsDiff +#ifdef HWY_NATIVE_INTEGER_ABS_DIFF +#undef HWY_NATIVE_INTEGER_ABS_DIFF +#else +#define HWY_NATIVE_INTEGER_ABS_DIFF +#endif + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, AbsDiff, abd) + +// ------------------------------ ShiftLeft[Same] + +#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits); \ + } \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME##Same(HWY_SVE_V(BASE, BITS) v, int bits) { \ + return sv##OP##_##CHAR##BITS##_x( \ + HWY_SVE_PTRUE(BITS), v, static_cast(bits)); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_N, ShiftLeft, lsl_n) + +// ------------------------------ ShiftRight[Same] + +HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_N, ShiftRight, lsr_n) +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n) + +#undef HWY_SVE_SHIFT_N + +// ------------------------------ MaskedShift[Left/Right] + +#define HWY_SVE_SHIFT_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \ + auto shifts = static_cast(kBits); \ + return sv##OP##_##CHAR##BITS##_z(m, v, shifts); \ + } +HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_Z, MaskedShiftLeft, lsl_n) +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_Z, MaskedShiftRight, asr_n) 
+HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_Z, MaskedShiftRight, lsr_n) + +#undef HWY_SVE_SHIFT_Z + +// ------------------------------ MaskedShiftRightOr + +#define HWY_SVE_SHIFT_OR(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) no, svbool_t m, HWY_SVE_V(BASE, BITS) v) { \ + auto shifts = static_cast(kBits); \ + return svsel##_##CHAR##BITS(m, sv##OP##_##CHAR##BITS##_z(m, v, shifts), \ + no); \ + } +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_OR, MaskedShiftRightOr, asr_n) +HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_OR, MaskedShiftRightOr, lsr_n) + +#undef HWY_SVE_SHIFT_OR + +// ------------------------------ RotateRight + +#if HWY_SVE_HAVE_2 + +#define HWY_SVE_ROTATE_RIGHT_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + if (kBits == 0) return v; \ + return sv##OP##_##CHAR##BITS(v, Zero(DFromV()), \ + HWY_MAX(kBits, 1)); \ + } + +HWY_SVE_FOREACH_U(HWY_SVE_ROTATE_RIGHT_N, RotateRight, xar_n) +HWY_SVE_FOREACH_I(HWY_SVE_ROTATE_RIGHT_N, RotateRight, xar_n) + +#undef HWY_SVE_ROTATE_RIGHT_N + +#else // !HWY_SVE_HAVE_2 +template +HWY_API V RotateRight(const V v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(TFromV) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); +} +#endif + +// ------------------------------ Shl, Shr + +#define HWY_SVE_SHIFT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \ + const RebindToUnsigned> du; \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, \ + BitCast(du, bits)); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT, Shl, lsl) + +HWY_SVE_FOREACH_U(HWY_SVE_SHIFT, Shr, lsr) +HWY_SVE_FOREACH_I(HWY_SVE_SHIFT, Shr, asr) + +#undef HWY_SVE_SHIFT + +// ------------------------------ 
RoundingShiftLeft[Same]/RoundingShr + +#if HWY_SVE_HAVE_2 + +#ifdef HWY_NATIVE_ROUNDING_SHR +#undef HWY_NATIVE_ROUNDING_SHR +#else +#define HWY_NATIVE_ROUNDING_SHR +#endif + +#define HWY_SVE_ROUNDING_SHR_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + HWY_IF_CONSTEXPR(kBits == 0) { return v; } \ + \ + return sv##OP##_##CHAR##BITS##_x( \ + HWY_SVE_PTRUE(BITS), v, static_cast(HWY_MAX(kBits, 1))); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ROUNDING_SHR_N, RoundingShiftRight, rshr_n) + +#undef HWY_SVE_ROUNDING_SHR_N + +#define HWY_SVE_ROUNDING_SHR(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \ + const RebindToSigned> di; \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, \ + Neg(BitCast(di, bits))); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ROUNDING_SHR, RoundingShr, rshl) + +#undef HWY_SVE_ROUNDING_SHR + +template +HWY_API V RoundingShiftRightSame(V v, int bits) { + const DFromV d; + using T = TFromD; + return RoundingShr(v, Set(d, static_cast(bits))); +} + +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ BroadcastSignBit (ShiftRight) +template +HWY_API V BroadcastSignBit(const V v) { + return ShiftRight) * 8 - 1>(v); +} + +// ------------------------------ Abs (ShiftRight, Add, Xor, AndN) + +// Workaround for incorrect results with `svabs`. 
+#if HWY_COMPILER_CLANG +template +HWY_API V Abs(V v) { + const V sign = BroadcastSignBit(v); + return Xor(Add(v, sign), sign); +} + +template +HWY_NOINLINE V Abs(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = MakeUnsigned>; + return BitCast( + d, detail::AndN(BitCast(du, v), static_cast(~SignMask()))); +} + +#else +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs) +#endif + +// ------------------------------ SaturatedAbs +#if HWY_SVE_HAVE_2 +#ifdef HWY_NATIVE_SATURATED_ABS +#undef HWY_NATIVE_SATURATED_ABS +#else +#define HWY_NATIVE_SATURATED_ABS +#endif + +HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedAbs, qabs) +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ MaskedAbsOr +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_M, MaskedAbsOr, abs) + +// ------------------------------ MaskedAbs +HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGMV_Z, MaskedAbs, abs) + +// ------------------------------ Mul + +// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*. +#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Mul, mul) + +// ------------------------------ MulHigh +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, MulHigh, mulh) + +// ------------------------------ MulFixedPoint15 +HWY_API svint16_t MulFixedPoint15(svint16_t a, svint16_t b) { +#if HWY_SVE_HAVE_2 + return svqrdmulh_s16(a, b); +#else + const DFromV d; + const RebindToUnsigned du; + + const svuint16_t lo = BitCast(du, Mul(a, b)); + const svint16_t hi = MulHigh(a, b); + // We want (lo + 0x4000) >> 15, but that can overflow, and if it does we must + // carry that into the result. Instead isolate the top two bits because only + // they can influence the result. + const svuint16_t lo_top2 = ShiftRight<14>(lo); + // Bits 11: add 2, 10: add 1, 01: add 1, 00: add 0. 
+ const svuint16_t rounding = ShiftRight<1>(detail::AddN(lo_top2, 1)); + return Add(Add(hi, hi), BitCast(d, rounding)); +#endif +} + +// ------------------------------ Div +#ifdef HWY_NATIVE_INT_DIV +#undef HWY_NATIVE_INT_DIV +#else +#define HWY_NATIVE_INT_DIV +#endif + +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, Div, div) +HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGPVV, Div, div) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Div, div) + +// ------------------------------ ApproximateReciprocal +#ifdef HWY_NATIVE_F64_APPROX_RECIP +#undef HWY_NATIVE_F64_APPROX_RECIP +#else +#define HWY_NATIVE_F64_APPROX_RECIP +#endif + +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe) + +// ------------------------------ Sqrt +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt) + +// ------------------------------ MaskedSqrt +#ifdef HWY_NATIVE_MASKED_SQRT +#undef HWY_NATIVE_MASKED_SQRT +#else +#define HWY_NATIVE_MASKED_SQRT +#endif + +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV_Z, MaskedSqrt, sqrt) + +// ------------------------------ ApproximateReciprocalSqrt +#ifdef HWY_NATIVE_F64_APPROX_RSQRT +#undef HWY_NATIVE_F64_APPROX_RSQRT +#else +#define HWY_NATIVE_F64_APPROX_RSQRT +#endif + +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocalSqrt, rsqrte) + +// ------------------------------ MulAdd + +// Per-target flag to prevent generic_ops-inl.h from defining int MulAdd. 
+#ifdef HWY_NATIVE_INT_FMA +#undef HWY_NATIVE_INT_FMA +#else +#define HWY_NATIVE_INT_FMA +#endif + +#define HWY_SVE_FMA(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x, \ + HWY_SVE_V(BASE, BITS) add) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), x, mul, add); \ + } + +HWY_SVE_FOREACH(HWY_SVE_FMA, MulAdd, mad) + +// ------------------------------ NegMulAdd +HWY_SVE_FOREACH(HWY_SVE_FMA, NegMulAdd, msb) + +// ------------------------------ MulSub +HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulSub, nmsb) + +// ------------------------------ NegMulSub +HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad) + +#undef HWY_SVE_FMA + +// ------------------------------ Round etc. + +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Floor, rintm) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Ceil, rintp) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Trunc, rintz) + +// ================================================== MASK + +// ------------------------------ RebindMask +template +HWY_API svbool_t RebindMask(const D /*d*/, const MFrom mask) { + return mask; +} + +// ------------------------------ Mask logical + +HWY_API svbool_t Not(svbool_t m) { + // We don't know the lane type, so assume 8-bit. For larger types, this will + // de-canonicalize the predicate, i.e. set bits to 1 even though they do not + // correspond to the lowest byte in the lane. Arm says such bits are ignored. + return svnot_b_z(HWY_SVE_PTRUE(8), m); +} +HWY_API svbool_t And(svbool_t a, svbool_t b) { + return svand_b_z(b, b, a); // same order as AndNot for consistency +} +HWY_API svbool_t AndNot(svbool_t a, svbool_t b) { + return svbic_b_z(b, b, a); // reversed order like NEON +} +HWY_API svbool_t Or(svbool_t a, svbool_t b) { + return svsel_b(a, a, b); // a ? true : b +} +HWY_API svbool_t Xor(svbool_t a, svbool_t b) { + return svsel_b(a, svnand_b_z(a, a, b), b); // a ? !(a & b) : b. 
+} + +HWY_API svbool_t ExclusiveNeither(svbool_t a, svbool_t b) { + return svnor_b_z(HWY_SVE_PTRUE(8), a, b); // !a && !b, undefined if a && b. +} + +// ------------------------------ CountTrue + +#define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, svbool_t m) { \ + return sv##OP##_b##BITS(detail::MakeMask(d), m); \ + } + +HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp) +#undef HWY_SVE_COUNT_TRUE + +// For 16-bit Compress: full vector, not limited to SV_POW2. +namespace detail { + +#define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m) { \ + return sv##OP##_b##BITS(svptrue_b##BITS(), m); \ + } + +HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp) +#undef HWY_SVE_COUNT_TRUE_FULL + +} // namespace detail + +// ------------------------------ AllFalse +template +HWY_API bool AllFalse(D d, svbool_t m) { + return !svptest_any(detail::MakeMask(d), m); +} + +// ------------------------------ AllTrue +template +HWY_API bool AllTrue(D d, svbool_t m) { + return CountTrue(d, m) == Lanes(d); +} + +// ------------------------------ FindFirstTrue +template +HWY_API intptr_t FindFirstTrue(D d, svbool_t m) { + return AllFalse(d, m) ? 
intptr_t{-1} + : static_cast( + CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m))); +} + +// ------------------------------ FindKnownFirstTrue +template +HWY_API size_t FindKnownFirstTrue(D d, svbool_t m) { + return CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m)); +} + +// ------------------------------ IfThenElse +#define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_V(BASE, BITS) yes, HWY_SVE_V(BASE, BITS) no) { \ + return sv##OP##_##CHAR##BITS(m, yes, no); \ + } + +HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel) +HWY_SVE_FOREACH_BF16(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel) +#undef HWY_SVE_IF_THEN_ELSE + +template , HWY_SVE_IF_EMULATED_D(D)> +HWY_API V IfThenElse(const svbool_t mask, V yes, V no) { + const RebindToUnsigned du; + return BitCast( + D(), IfThenElse(RebindMask(du, mask), BitCast(du, yes), BitCast(du, no))); +} + +// ------------------------------ IfThenElseZero + +template , HWY_SVE_IF_NOT_EMULATED_D(D)> +HWY_API V IfThenElseZero(const svbool_t mask, const V yes) { + return IfThenElse(mask, yes, Zero(D())); +} + +template , HWY_SVE_IF_EMULATED_D(D)> +HWY_API V IfThenElseZero(const svbool_t mask, V yes) { + const RebindToUnsigned du; + return BitCast(D(), IfThenElseZero(RebindMask(du, mask), BitCast(du, yes))); +} + +// ------------------------------ IfThenZeroElse + +template , HWY_SVE_IF_NOT_EMULATED_D(D)> +HWY_API V IfThenZeroElse(const svbool_t mask, const V no) { + return IfThenElse(mask, Zero(D()), no); +} + +template , HWY_SVE_IF_EMULATED_D(D)> +HWY_API V IfThenZeroElse(const svbool_t mask, V no) { + const RebindToUnsigned du; + return BitCast(D(), IfThenZeroElse(RebindMask(du, mask), BitCast(du, no))); +} + +// ------------------------------ Additional mask logical operations +HWY_API svbool_t SetBeforeFirst(svbool_t m) { + // We don't know the lane type, so assume 8-bit. For larger types, this will + // de-canonicalize the predicate, i.e. 
set bits to 1 even though they do not + // correspond to the lowest byte in the lane. Arm says such bits are ignored. + return svbrkb_b_z(HWY_SVE_PTRUE(8), m); +} + +HWY_API svbool_t SetAtOrBeforeFirst(svbool_t m) { + // We don't know the lane type, so assume 8-bit. For larger types, this will + // de-canonicalize the predicate, i.e. set bits to 1 even though they do not + // correspond to the lowest byte in the lane. Arm says such bits are ignored. + return svbrka_b_z(HWY_SVE_PTRUE(8), m); +} + +HWY_API svbool_t SetOnlyFirst(svbool_t m) { return svbrka_b_z(m, m); } + +HWY_API svbool_t SetAtOrAfterFirst(svbool_t m) { + return Not(SetBeforeFirst(m)); +} + +// ------------------------------ PromoteMaskTo + +#ifdef HWY_NATIVE_PROMOTE_MASK_TO +#undef HWY_NATIVE_PROMOTE_MASK_TO +#else +#define HWY_NATIVE_PROMOTE_MASK_TO +#endif + +template ) * 2)> +HWY_API svbool_t PromoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { + return svunpklo_b(m); +} + +template ) * 2)> +HWY_API svbool_t PromoteMaskTo(DTo d_to, DFrom d_from, svbool_t m) { + using TFrom = TFromD; + using TWFrom = MakeWide>; + static_assert(sizeof(TWFrom) > sizeof(TFrom), + "sizeof(TWFrom) > sizeof(TFrom) must be true"); + + const Rebind dw_from; + return PromoteMaskTo(d_to, dw_from, PromoteMaskTo(dw_from, d_from, m)); +} + +// ------------------------------ DemoteMaskTo + +#ifdef HWY_NATIVE_DEMOTE_MASK_TO +#undef HWY_NATIVE_DEMOTE_MASK_TO +#else +#define HWY_NATIVE_DEMOTE_MASK_TO +#endif + +template +HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { + return svuzp1_b8(m, m); +} + +template +HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { + return svuzp1_b16(m, m); +} + +template +HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) { + return svuzp1_b32(m, m); +} + +template ) / 4)> +HWY_API svbool_t DemoteMaskTo(DTo d_to, DFrom d_from, svbool_t m) { + using TFrom = TFromD; + using TNFrom = MakeNarrow>; + 
static_assert(sizeof(TNFrom) < sizeof(TFrom), + "sizeof(TNFrom) < sizeof(TFrom) must be true"); + + const Rebind dn_from; + return DemoteMaskTo(d_to, dn_from, DemoteMaskTo(dn_from, d_from, m)); +} + +// ------------------------------ LowerHalfOfMask +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif + +template +HWY_API svbool_t LowerHalfOfMask(D /*d*/, svbool_t m) { + return m; +} + +// ------------------------------ MaskedAddOr etc. (IfThenElse) + +#ifdef HWY_NATIVE_MASKED_ARITH +#undef HWY_NATIVE_MASKED_ARITH +#else +#define HWY_NATIVE_MASKED_ARITH +#endif + +namespace detail { +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedMin, minnm) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedMax, maxnm) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedMin, min) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedMax, max) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedAdd, add) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedSub, sub) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMul, mul) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) +HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGMVV, MaskedDiv, div) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV, MaskedSqrt, sqrt) +#if HWY_SVE_HAVE_2 +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatAdd, qadd) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatSub, qsub) +#endif +} // namespace detail + +template +HWY_API V MaskedMinOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedMin(m, a, b), no); +} + +template +HWY_API V MaskedMaxOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedMax(m, a, b), no); +} + +template +HWY_API V MaskedAddOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedAdd(m, a, b), no); +} + +template +HWY_API V MaskedSubOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedSub(m, a, b), no); +} + +template +HWY_API V MaskedMulOr(V no, M m, V a, V b) { + return 
IfThenElse(m, detail::MaskedMul(m, a, b), no); +} + +template , hwy::float16_t>() ? (1 << 2) : 0) | + (1 << 4) | (1 << 8))> +HWY_API V MaskedDivOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedDiv(m, a, b), no); +} + +// I8/U8/I16/U16 MaskedDivOr is implemented after I8/U8/I16/U16 Div + +#if HWY_SVE_HAVE_2 +template +HWY_API V MaskedSatAddOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedSatAdd(m, a, b), no); +} + +template +HWY_API V MaskedSatSubOr(V no, M m, V a, V b) { + return IfThenElse(m, detail::MaskedSatSub(m, a, b), no); +} +#else +template +HWY_API V MaskedSatAddOr(V no, M m, V a, V b) { + return IfThenElse(m, SaturatedAdd(a, b), no); +} + +template +HWY_API V MaskedSatSubOr(V no, M m, V a, V b) { + return IfThenElse(m, SaturatedSub(a, b), no); +} +#endif + +// ------------------------------ MaskedMulAddOr +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV, MaskedMulAdd, mad) +} + +// Per-target flag to prevent generic_ops-inl.h from defining int +// MaskedMulAddOr. +#ifdef HWY_NATIVE_MASKED_INT_FMA +#undef HWY_NATIVE_MASKED_INT_FMA +#else +#define HWY_NATIVE_MASKED_INT_FMA +#endif + +template +HWY_API V MaskedMulAddOr(V no, M m, V mul, V x, V add) { + return IfThenElse(m, detail::MaskedMulAdd(m, mul, x, add), no); +} + +template +HWY_API V MaskedSqrtOr(V no, M m, V v) { + return IfThenElse(m, detail::MaskedSqrt(m, v), no); +} + +// ================================================== REDUCE + +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + +// These return T, suitable for ReduceSum. +namespace detail { +#define HWY_SVE_REDUCE_ADD(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \ + /* The intrinsic returns [u]int64_t; truncate to T so we can broadcast. 
*/ \ + using T = HWY_SVE_T(BASE, BITS); \ + using TU = MakeUnsigned; \ + constexpr uint64_t kMask = LimitsMax(); \ + return static_cast(static_cast( \ + static_cast(sv##OP##_##CHAR##BITS(pg, v)) & kMask)); \ + } + +#define HWY_SVE_REDUCE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_T(BASE, BITS) NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(pg, v); \ + } + +// TODO: Remove SumOfLanesM in favor of using MaskedReduceSum +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE_ADD, SumOfLanesM, addv) +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, SumOfLanesM, addv) + +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MinOfLanesM, minv) +HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE, MaxOfLanesM, maxv) +// NaN if all are +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MinOfLanesM, minnmv) +HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, MaxOfLanesM, maxnmv) + +#undef HWY_SVE_REDUCE +#undef HWY_SVE_REDUCE_ADD +} // namespace detail + +// detail::SumOfLanesM, detail::MinOfLanesM, and detail::MaxOfLanesM is more +// efficient for N=4 I8/U8 reductions on SVE than the default implementations +// of the N=4 I8/U8 ReduceSum/ReduceMin/ReduceMax operations in +// generic_ops-inl.h +#undef HWY_IF_REDUCE_D +#define HWY_IF_REDUCE_D(D) hwy::EnableIf* = nullptr + +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 +#endif + +#ifdef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#undef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#else +#define HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#endif + +template +HWY_API TFromD ReduceSum(D d, VFromD v) { + return detail::SumOfLanesM(detail::MakeMask(d), v); +} + +template +HWY_API TFromD ReduceMin(D d, VFromD v) { + return detail::MinOfLanesM(detail::MakeMask(d), v); +} + +template +HWY_API TFromD ReduceMax(D d, VFromD v) { + return detail::MaxOfLanesM(detail::MakeMask(d), v); +} + +#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR +#undef HWY_NATIVE_MASKED_REDUCE_SCALAR +#else +#define HWY_NATIVE_MASKED_REDUCE_SCALAR +#endif + +template +HWY_API TFromD MaskedReduceSum(D 
/*d*/, M m, VFromD v) { + return detail::SumOfLanesM(m, v); +} +template +HWY_API TFromD MaskedReduceMin(D /*d*/, M m, VFromD v) { + return detail::MinOfLanesM(m, v); +} +template +HWY_API TFromD MaskedReduceMax(D /*d*/, M m, VFromD v) { + return detail::MaxOfLanesM(m, v); +} + +// ------------------------------ SumOfLanes + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, ReduceSum(d, v)); +} +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + return Set(d, ReduceMin(d, v)); +} +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + return Set(d, ReduceMax(d, v)); +} + +// ------------------------------ MaskedAdd etc. (IfThenElse) + +#ifdef HWY_NATIVE_ZERO_MASKED_ARITH +#undef HWY_NATIVE_ZERO_MASKED_ARITH +#else +#define HWY_NATIVE_ZERO_MASKED_ARITH +#endif + +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedMax, max) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedAdd, add) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedSub, sub) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_Z, MaskedMul, mul) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV_Z, MaskedDiv, div) +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGMVV_Z, MaskedDiv, div) +HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGMVV_Z, MaskedDiv, div) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV_Z, MaskedMulAdd, mad) +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV_Z, MaskedNegMulAdd, msb) + +// I8/U8/I16/U16 MaskedDiv is implemented after I8/U8/I16/U16 Div + +#if HWY_SVE_HAVE_2 +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedSaturatedAdd, qadd) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedSaturatedSub, qsub) +#else +template +HWY_API V MaskedSaturatedAdd(M m, V a, V b) { + return IfThenElseZero(m, SaturatedAdd(a, b)); +} + +template +HWY_API V MaskedSaturatedSub(M m, V a, V b) { + return IfThenElseZero(m, SaturatedSub(a, b)); +} +#endif + +template , HWY_IF_I16_D(D)> +HWY_API V MaskedMulFixedPoint15(M m, V a, V b) { + return IfThenElseZero(m, MulFixedPoint15(a, b)); +} + +template >> +HWY_API VFromD MaskedWidenMulPairwiseAdd(D d32, M m, V16 a, V16 
b) { + return IfThenElseZero(m, WidenMulPairwiseAdd(d32, a, b)); +} + +template +HWY_API VFromD MaskedWidenMulPairwiseAdd(DF df, M m, VBF a, VBF b) { + return IfThenElseZero(m, WidenMulPairwiseAdd(df, a, b)); +} + +// ================================================== COMPARE + +// mask = f(vector, vector) +#define HWY_SVE_COMPARE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ + } +#define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b); \ + } + +// ------------------------------ Eq +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Eq, cmpeq) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, EqN, cmpeq_n) +} // namespace detail + +// ------------------------------ Ne +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Ne, cmpne) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, NeN, cmpne_n) +} // namespace detail + +// ------------------------------ Lt +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Lt, cmplt) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LtN, cmplt_n) +} // namespace detail + +// ------------------------------ Le +HWY_SVE_FOREACH(HWY_SVE_COMPARE, Le, cmple) +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LeN, cmple_n) +} // namespace detail + +// ------------------------------ Gt/Ge (swapped order) +template +HWY_API svbool_t Gt(const V a, const V b) { + return Lt(b, a); +} +template +HWY_API svbool_t Ge(const V a, const V b) { + return Le(b, a); +} +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, GeN, cmpge_n) +HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, GtN, cmpgt_n) +} // namespace detail + +#undef HWY_SVE_COMPARE +#undef HWY_SVE_COMPARE_N + +// ------------------------------ TestBit +template +HWY_API svbool_t TestBit(const V a, const V bit) { + return detail::NeN(And(a, bit), 0); +} + +// 
------------------------------ Min/Max (Lt, IfThenElse) + +HWY_SVE_FOREACH_U(HWY_SVE_RETV_ARGPVV, Min, min) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Max, max) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Max, maxnm) + +// Workaround for incorrect results with `svmin`. +#if HWY_COMPILER_CLANG +template +HWY_API V Min(V a, V b) { + return IfThenElse(Lt(a, b), a, b); +} +template +HWY_API V Min(V a, V b) { + return IfThenElse(Or(Lt(a, b), Ne(b, b)), a, b); +} +#else +HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPVV, Min, min) +HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Min, minnm) +#endif + +namespace detail { +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MinN, min_n) +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MaxN, max_n) +} // namespace detail + +// ================================================== SWIZZLE + +// ------------------------------ ConcatEven/ConcatOdd + +// WARNING: the upper half of these needs fixing up (uzp1/uzp2 use the +// full vector length, not rounded down to a power of two as we require). +namespace detail { + +#define HWY_SVE_CONCAT_EVERY_SECOND(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_INLINE HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ + return sv##OP##_##CHAR##BITS(lo, hi); \ + } +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, uzp1) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, uzp2) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenFull, + uzp1) +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddFull, + uzp2) +#endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatEvenBlocks, uzp1q) +HWY_SVE_FOREACH(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, uzp2q) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, + ConcatEvenBlocks, 
uzp1q) +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_CONCAT_EVERY_SECOND, ConcatOddBlocks, + uzp2q) +#endif // HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +#endif // defined(__ARM_FEATURE_SVE_MATMUL_FP64) +#undef HWY_SVE_CONCAT_EVERY_SECOND + +// Used to slide up / shift whole register left; mask indicates which range +// to take from lo, and the rest is filled from hi starting at its lowest. +#define HWY_SVE_SPLICE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME( \ + HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, lo, hi); \ + } +HWY_SVE_FOREACH(HWY_SVE_SPLICE, Splice, splice) +#if HWY_SVE_HAVE_BF16_FEATURE +HWY_SVE_FOREACH_BF16(HWY_SVE_SPLICE, Splice, splice) +#else +template )> +HWY_INLINE V Splice(V hi, V lo, svbool_t mask) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Splice(BitCast(du, hi), BitCast(du, lo), mask)); +} +#endif // HWY_SVE_HAVE_BF16_FEATURE +#undef HWY_SVE_SPLICE + +} // namespace detail + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { +#if HWY_SVE_IS_POW2 + if (detail::IsFull(d)) return detail::ConcatOddFull(hi, lo); +#endif + const VFromD hi_odd = detail::ConcatOddFull(hi, hi); + const VFromD lo_odd = detail::ConcatOddFull(lo, lo); + return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +#if HWY_SVE_IS_POW2 + if (detail::IsFull(d)) return detail::ConcatEvenFull(hi, lo); +#endif + const VFromD hi_odd = detail::ConcatEvenFull(hi, hi); + const VFromD lo_odd = detail::ConcatEvenFull(lo, lo); + return detail::Splice(hi_odd, lo_odd, FirstN(d, Lanes(d) / 2)); +} + +HWY_API svuint8_t U8FromU32(const svuint32_t v) { + const DFromV du32; + const RepartitionToNarrow du16; + const RepartitionToNarrow du8; + + const svuint16_t cast16 = BitCast(du16, v); + const svuint16_t x2 = svuzp1_u16(cast16, cast16); + const svuint8_t cast8 = BitCast(du8, x2); 
+ return svuzp1_u8(cast8, cast8); +} + +// ================================================== MASK + +// ------------------------------ MaskFromVec (Ne) +template +HWY_API svbool_t MaskFromVec(const V v) { + using T = TFromV; + return detail::NeN(v, ConvertScalarTo(0)); +} + +// ------------------------------ VecFromMask +template +HWY_API VFromD VecFromMask(const D d, svbool_t mask) { + const RebindToSigned di; + // This generates MOV imm, whereas svdup_n_s8_z generates MOV scalar, which + // requires an extra instruction plus M0 pipeline. + return BitCast(d, IfThenElseZero(mask, Set(di, -1))); +} + +// ------------------------------ BitsFromMask (AndN, Shl, ReduceSum, GetLane +// ConcatEvenFull, U8FromU32) + +namespace detail { + +// For each mask lane (governing lane type T), store 1 or 0 in BYTE lanes. +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + return svdup_n_u8_z(m, 1); +} +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + const ScalableTag d8; + const svuint8_t b16 = BitCast(d8, svdup_n_u16_z(m, 1)); + return detail::ConcatEvenFull(b16, b16); // lower half +} +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + return U8FromU32(svdup_n_u32_z(m, 1)); +} +template +HWY_INLINE svuint8_t BoolFromMask(svbool_t m) { + const ScalableTag d32; + const svuint32_t b64 = BitCast(d32, svdup_n_u64_z(m, 1)); + return U8FromU32(detail::ConcatEvenFull(b64, b64)); // lower half +} + +// Compacts groups of 8 u8 into 8 contiguous bits in a 64-bit lane. +HWY_INLINE svuint64_t BitsFromBool(svuint8_t x) { + const ScalableTag d8; + const ScalableTag d16; + const ScalableTag d32; + const ScalableTag d64; + // TODO(janwas): could use SVE2 BDEP, but it's optional. 
+ x = Or(x, BitCast(d8, ShiftRight<7>(BitCast(d16, x)))); + x = Or(x, BitCast(d8, ShiftRight<14>(BitCast(d32, x)))); + x = Or(x, BitCast(d8, ShiftRight<28>(BitCast(d64, x)))); + return BitCast(d64, x); +} + +} // namespace detail + +// BitsFromMask is required if `HWY_MAX_BYTES <= 64`, which is true for the +// fixed-size SVE targets. +#if HWY_TARGET == HWY_SVE2_128 || HWY_TARGET == HWY_SVE_256 +template +HWY_API uint64_t BitsFromMask(D d, svbool_t mask) { + const Repartition du64; + svuint64_t bits_in_u64 = detail::BitsFromBool(detail::BoolFromMask(mask)); + + constexpr size_t N = MaxLanes(d); + static_assert(N < 64, "SVE2_128 and SVE_256 are only 128 or 256 bits"); + const uint64_t valid = (1ull << N) - 1; + HWY_IF_CONSTEXPR(N <= 8) { + // Upper bits are undefined even if N == 8, hence mask. + return GetLane(bits_in_u64) & valid; + } + + // Up to 8 of the least-significant bits of each u64 lane are valid. + bits_in_u64 = detail::AndN(bits_in_u64, 0xFF); + + // 128-bit vector: only two u64, so avoid ReduceSum. + HWY_IF_CONSTEXPR(HWY_TARGET == HWY_SVE2_128) { + alignas(16) uint64_t lanes[2]; + Store(bits_in_u64, du64, lanes); + // lanes[0] is always valid because we know N > 8, but lanes[1] might + // not be - we may mask it out below. + const uint64_t result = lanes[0] + (lanes[1] << 8); + // 8-bit lanes, no further masking + HWY_IF_CONSTEXPR(N == 16) return result; + return result & valid; + } + + // Shift the 8-bit groups into place in each u64 lane. 
+ alignas(32) uint64_t kShifts[4] = {0 * 8, 1 * 8, 2 * 8, 3 * 8}; + bits_in_u64 = Shl(bits_in_u64, Load(du64, kShifts)); + return ReduceSum(du64, bits_in_u64) & valid; +} + +#endif // HWY_TARGET == HWY_SVE2_128 || HWY_TARGET == HWY_SVE_256 + +// ------------------------------ IsNegative (Lt) +#ifdef HWY_NATIVE_IS_NEGATIVE +#undef HWY_NATIVE_IS_NEGATIVE +#else +#define HWY_NATIVE_IS_NEGATIVE +#endif + +template +HWY_API svbool_t IsNegative(V v) { + const DFromV d; + const RebindToSigned di; + using TI = TFromD; + + return detail::LtN(BitCast(di, v), static_cast(0)); +} + +// ------------------------------ IfVecThenElse (MaskFromVec, IfThenElse) + +#if HWY_SVE_HAVE_2 + +#define HWY_SVE_IF_VEC(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) yes, \ + HWY_SVE_V(BASE, BITS) no) { \ + return sv##OP##_##CHAR##BITS(yes, no, mask); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_IF_VEC, IfVecThenElse, bsl) +#undef HWY_SVE_IF_VEC + +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, IfVecThenElse(BitCast(du, mask), BitCast(du, yes), BitCast(du, no))); +} + +#else + +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + return Or(And(mask, yes), AndNot(mask, no)); +} + +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ BitwiseIfThenElse + +#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#else +#define HWY_NATIVE_BITWISE_IF_THEN_ELSE +#endif + +template +HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { + return IfVecThenElse(mask, yes, no); +} + +// ------------------------------ CopySign (BitwiseIfThenElse) +template +HWY_API V CopySign(const V magn, const V sign) { + const DFromV d; + return BitwiseIfThenElse(SignBit(d), sign, magn); +} + +// ------------------------------ CopySignToAbs +template +HWY_API V CopySignToAbs(const V abs, const V sign) { 
+#if HWY_SVE_HAVE_2 // CopySign is more efficient than OrAnd + return CopySign(abs, sign); +#else + const DFromV d; + return OrAnd(abs, SignBit(d), sign); +#endif +} + +// ------------------------------ Floating-point classification (Ne) + +template +HWY_API svbool_t IsNaN(const V v) { + return Ne(v, v); // could also use cmpuo +} + +// Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite. +// We use a fused Set/comparison for IsFinite. +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + +template +HWY_API svbool_t IsInf(const V v) { + using T = TFromV; + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + + // 'Shift left' to clear the sign bit + const VFromD vu = BitCast(du, v); + const VFromD v2 = Add(vu, vu); + // Check for exponent=max and mantissa=0. + const VFromD max2 = Set(di, hwy::MaxExponentTimes2()); + return RebindMask(d, Eq(v2, BitCast(du, max2))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API svbool_t IsFinite(const V v) { + using T = TFromV; + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). 
+ const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField())); +} + +// ------------------------------ MulByPow2/MulByFloorPow2 + +#define HWY_SVE_MUL_BY_POW2(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(int, BITS) exp) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, exp); \ + } + +HWY_SVE_FOREACH_F(HWY_SVE_MUL_BY_POW2, MulByPow2, scale) + +#undef HWY_SVE_MUL_BY_POW2 + +// ------------------------------ MaskedEq etc. +#ifdef HWY_NATIVE_MASKED_COMP +#undef HWY_NATIVE_MASKED_COMP +#else +#define HWY_NATIVE_MASKED_COMP +#endif + +// mask = f(mask, vector, vector) +#define HWY_SVE_COMPARE_Z(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API svbool_t NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, \ + HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS(m, a, b); \ + } + +HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedEq, cmpeq) +HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedNe, cmpne) +HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLt, cmplt) +HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLe, cmple) + +#undef HWY_SVE_COMPARE_Z + +template > +HWY_API MFromD MaskedGt(M m, V a, V b) { + // Swap args to reverse comparison + return MaskedLt(m, b, a); +} + +template > +HWY_API MFromD MaskedGe(M m, V a, V b) { + // Swap args to reverse comparison + return MaskedLe(m, b, a); +} + +template > +HWY_API MFromD MaskedIsNaN(const M m, const V v) { + return MaskedNe(m, v, v); +} + +// ================================================== MEMORY + +// ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream + +#define HWY_SVE_MEM(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + LoadU(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return svld1_##CHAR##BITS(detail::MakeMask(d), \ + detail::NativeLanePointer(p)); \ + } \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + 
MaskedLoad(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + return svld1_##CHAR##BITS(m, detail::NativeLanePointer(p)); \ + } \ + template \ + HWY_API void StoreU(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + svst1_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), v); \ + } \ + template \ + HWY_API void Stream(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + svstnt1_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), \ + v); \ + } \ + template \ + HWY_API void BlendedStore(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + svst1_##CHAR##BITS(m, detail::NativeLanePointer(p), v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_MEM, _, _) +HWY_SVE_FOREACH_BF16(HWY_SVE_MEM, _, _) + +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, LoadU(du, detail::U16LanePointer(p))); +} + +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + StoreU(BitCast(du, v), du, detail::U16LanePointer(p)); +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, + MaskedLoad(RebindMask(du, m), du, detail::U16LanePointer(p))); +} + +// MaskedLoadOr is generic and does not require emulation. 
+ +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + BlendedStore(BitCast(du, v), RebindMask(du, m), du, + detail::U16LanePointer(p)); +} + +#undef HWY_SVE_MEM + +#if HWY_TARGET != HWY_SVE2_128 +namespace detail { +#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) { \ + /* All-true predicate to load all 128 bits. */ \ + return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8), \ + detail::NativeLanePointer(p)); \ + } + +HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq) +HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq) + +template +HWY_API VFromD LoadDupFull128(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, LoadDupFull128(du, detail::U16LanePointer(p))); +} + +} // namespace detail +#endif // HWY_TARGET != HWY_SVE2_128 + +#if HWY_TARGET == HWY_SVE2_128 +// On the HWY_SVE2_128 target, LoadDup128 is the same as LoadU since vectors +// cannot exceed 16 bytes on the HWY_SVE2_128 target. +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); +} +#else // HWY_TARGET != HWY_SVE2_128 +// If D().MaxBytes() <= 16 is true, simply do a LoadU operation. 
+template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); +} + +// If D().MaxBytes() > 16 is true, need to load the vector using ld1rq +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return detail::LoadDupFull128(d, p); +} + +#endif // HWY_TARGET != HWY_SVE2_128 + +// Truncate to smaller size and store +#ifdef HWY_NATIVE_STORE_TRUNCATED +#undef HWY_NATIVE_STORE_TRUNCATED +#else +#define HWY_NATIVE_STORE_TRUNCATED +#endif + +#define HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, TO_BITS) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ + const HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, TO_BITS) * HWY_RESTRICT p) { \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), \ + v); \ + } + +#define HWY_SVE_STORE_TRUNCATED_BYTE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 8) +#define HWY_SVE_STORE_TRUNCATED_HALF(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 16) +#define HWY_SVE_STORE_TRUNCATED_WORD(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, 32) + +HWY_SVE_FOREACH_UI16(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b) +HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b) +HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_BYTE, TruncateStore, st1b) +HWY_SVE_FOREACH_UI32(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h) +HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_HALF, TruncateStore, st1h) +HWY_SVE_FOREACH_UI64(HWY_SVE_STORE_TRUNCATED_WORD, TruncateStore, st1w) + +#undef HWY_SVE_STORE_TRUNCATED + +// ------------------------------ Load/Store + +// SVE only requires lane alignment, not natural alignment of the entire +// vector, so Load/Store are the same as LoadU/StoreU. 
+template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); +} + +template +HWY_API void Store(const V v, D d, TFromD* HWY_RESTRICT p) { + StoreU(v, d, p); +} + +// ------------------------------ MaskedLoadOr + +// SVE MaskedLoad hard-codes zero, so this requires an extra blend. +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElse(m, MaskedLoad(m, d, p), v); +} + +// ------------------------------ ScatterOffset/Index + +#ifdef HWY_NATIVE_SCATTER +#undef HWY_NATIVE_SCATTER +#else +#define HWY_NATIVE_SCATTER +#endif + +#define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) offset) { \ + sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, offset, \ + v); \ + } + +#define HWY_SVE_MASKED_SCATTER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, svbool_t m, \ + HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) indices) { \ + sv##OP##_s##BITS##index_##CHAR##BITS(m, base, indices, v); \ + } + +HWY_SVE_FOREACH_UIF3264(HWY_SVE_SCATTER_OFFSET, ScatterOffset, st1_scatter) +HWY_SVE_FOREACH_UIF3264(HWY_SVE_MASKED_SCATTER_INDEX, MaskedScatterIndex, + st1_scatter) +#undef HWY_SVE_SCATTER_OFFSET +#undef HWY_SVE_MASKED_SCATTER_INDEX + +template +HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT p, + VFromD> indices) { + MaskedScatterIndex(v, detail::MakeMask(d), d, p, indices); +} + +// ------------------------------ GatherOffset/Index + +#ifdef HWY_NATIVE_GATHER +#undef HWY_NATIVE_GATHER +#else +#define HWY_NATIVE_GATHER +#endif + +#define HWY_SVE_GATHER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, 
kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) offset) { \ + return sv##OP##_s##BITS##offset_##CHAR##BITS(detail::MakeMask(d), base, \ + offset); \ + } +#define HWY_SVE_MASKED_GATHER_INDEX(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT base, \ + HWY_SVE_V(int, BITS) indices) { \ + const RebindToSigned di; \ + (void)di; /* for HWY_DASSERT */ \ + HWY_DASSERT(AllFalse(di, Lt(indices, Zero(di)))); \ + return sv##OP##_s##BITS##index_##CHAR##BITS(m, base, indices); \ + } + +HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_OFFSET, GatherOffset, ld1_gather) +HWY_SVE_FOREACH_UIF3264(HWY_SVE_MASKED_GATHER_INDEX, MaskedGatherIndex, + ld1_gather) +#undef HWY_SVE_GATHER_OFFSET +#undef HWY_SVE_MASKED_GATHER_INDEX + +template +HWY_API VFromD MaskedGatherIndexOr(VFromD no, svbool_t m, D d, + const TFromD* HWY_RESTRICT p, + VFromD> indices) { + return IfThenElse(m, MaskedGatherIndex(m, d, p, indices), no); +} + +template +HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT p, + VFromD> indices) { + return MaskedGatherIndex(detail::MakeMask(d), d, p, indices); +} + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. 
+#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +#define HWY_SVE_LOAD2(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1) { \ + const HWY_SVE_TUPLE(BASE, BITS, 2) tuple = sv##OP##_##CHAR##BITS( \ + detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ + v0 = svget2(tuple, 0); \ + v1 = svget2(tuple, 1); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD2, LoadInterleaved2, ld2) +HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD2, LoadInterleaved2, ld2) + +#undef HWY_SVE_LOAD2 + +// ------------------------------ LoadInterleaved3 + +#define HWY_SVE_LOAD3(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ + HWY_SVE_V(BASE, BITS) & v2) { \ + const HWY_SVE_TUPLE(BASE, BITS, 3) tuple = sv##OP##_##CHAR##BITS( \ + detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ + v0 = svget3(tuple, 0); \ + v1 = svget3(tuple, 1); \ + v2 = svget3(tuple, 2); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD3, LoadInterleaved3, ld3) +HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD3, LoadInterleaved3, ld3) + +#undef HWY_SVE_LOAD3 + +// ------------------------------ LoadInterleaved4 + +#define HWY_SVE_LOAD4(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned, \ + HWY_SVE_V(BASE, BITS) & v0, HWY_SVE_V(BASE, BITS) & v1, \ + HWY_SVE_V(BASE, BITS) & v2, HWY_SVE_V(BASE, BITS) & v3) { \ + const HWY_SVE_TUPLE(BASE, BITS, 4) tuple = sv##OP##_##CHAR##BITS( \ + detail::MakeMask(d), detail::NativeLanePointer(unaligned)); \ + v0 = svget4(tuple, 0); \ + v1 = svget4(tuple, 1); \ + v2 = svget4(tuple, 2); \ 
+ v3 = svget4(tuple, 3); \ + } +HWY_SVE_FOREACH(HWY_SVE_LOAD4, LoadInterleaved4, ld4) +HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD4, LoadInterleaved4, ld4) + +#undef HWY_SVE_LOAD4 + +// ------------------------------ StoreInterleaved2 + +#define HWY_SVE_STORE2(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ + detail::NativeLanePointer(unaligned), \ + Create2(d, v0, v1)); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE2, StoreInterleaved2, st2) +HWY_SVE_FOREACH_BF16(HWY_SVE_STORE2, StoreInterleaved2, st2) + +#undef HWY_SVE_STORE2 + +// ------------------------------ StoreInterleaved3 + +#define HWY_SVE_STORE3(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_V(BASE, BITS) v2, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ + detail::NativeLanePointer(unaligned), \ + Create3(d, v0, v1, v2)); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE3, StoreInterleaved3, st3) +HWY_SVE_FOREACH_BF16(HWY_SVE_STORE3, StoreInterleaved3, st3) + +#undef HWY_SVE_STORE3 + +// ------------------------------ StoreInterleaved4 + +#define HWY_SVE_STORE4(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1, \ + HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3, \ + HWY_SVE_D(BASE, BITS, N, kPow2) d, \ + HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) { \ + sv##OP##_##CHAR##BITS(detail::MakeMask(d), \ + detail::NativeLanePointer(unaligned), \ + Create4(d, v0, v1, v2, v3)); \ + } +HWY_SVE_FOREACH(HWY_SVE_STORE4, StoreInterleaved4, st4) +HWY_SVE_FOREACH_BF16(HWY_SVE_STORE4, StoreInterleaved4, st4) + +#undef HWY_SVE_STORE4 + +// Fall back on generic Load/StoreInterleaved[234] 
for any emulated types. +// Requires HWY_GENERIC_IF_EMULATED_D mirrors HWY_SVE_IF_EMULATED_D. + +// ================================================== CONVERT + +// ------------------------------ PromoteTo + +// Same sign +#define HWY_SVE_PROMOTE_TO(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) NAME( \ + HWY_SVE_D(BASE, BITS, N, kPow2) /* tag */, HWY_SVE_V(BASE, HALF) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } + +HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) +HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) +HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteTo, unpklo) + +// 2x +template +HWY_API svuint32_t PromoteTo(Simd dto, svuint8_t vfrom) { + const RepartitionToWide> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} +template +HWY_API svint32_t PromoteTo(Simd dto, svint8_t vfrom) { + const RepartitionToWide> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} +template +HWY_API svuint64_t PromoteTo(Simd dto, svuint16_t vfrom) { + const RepartitionToWide> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} +template +HWY_API svint64_t PromoteTo(Simd dto, svint16_t vfrom) { + const RepartitionToWide> d2; + return PromoteTo(dto, PromoteTo(d2, vfrom)); +} + +// 3x +template +HWY_API svuint64_t PromoteTo(Simd dto, svuint8_t vfrom) { + const RepartitionToNarrow d4; + const RepartitionToNarrow d2; + return PromoteTo(dto, PromoteTo(d4, PromoteTo(d2, vfrom))); +} +template +HWY_API svint64_t PromoteTo(Simd dto, svint8_t vfrom) { + const RepartitionToNarrow d4; + const RepartitionToNarrow d2; + return PromoteTo(dto, PromoteTo(d4, PromoteTo(d2, vfrom))); +} + +// Sign change +template ), sizeof(TFromV))> +HWY_API VFromD PromoteTo(D di, V v) { + const RebindToUnsigned du; + return BitCast(di, PromoteTo(du, v)); +} + +// ------------------------------ PromoteTo F + +// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. 
+#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +// Unlike Highway's ZipLower, this returns the same type. +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipLowerSame, zip1) +} // namespace detail + +template +HWY_API svfloat32_t PromoteTo(Simd /* d */, + const svfloat16_t v) { + // svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so + // first replicate each lane once. + const svfloat16_t vv = detail::ZipLowerSame(v, v); + return svcvt_f32_f16_x(detail::PTrue(Simd()), vv); +} + +#ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 +#undef HWY_NATIVE_PROMOTE_F16_TO_F64 +#else +#define HWY_NATIVE_PROMOTE_F16_TO_F64 +#endif + +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svfloat16_t v) { + // svcvt* expects inputs in even lanes, whereas Highway wants lower lanes, so + // first replicate each lane once. + const svfloat16_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_f16_x(detail::PTrue(Simd()), + detail::ZipLowerSame(vv, vv)); +} + +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svfloat32_t v) { + const svfloat32_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_f32_x(detail::PTrue(Simd()), vv); +} + +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svint32_t v) { + const svint32_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_s32_x(detail::PTrue(Simd()), vv); +} + +template +HWY_API svfloat64_t PromoteTo(Simd /* d */, + const svuint32_t v) { + const svuint32_t vv = detail::ZipLowerSame(v, v); + return svcvt_f64_u32_x(detail::PTrue(Simd()), vv); +} + +template +HWY_API svint64_t PromoteTo(Simd /* d */, + const svfloat32_t v) { + const svfloat32_t vv = detail::ZipLowerSame(v, v); + return svcvt_s64_f32_x(detail::PTrue(Simd()), vv); +} + +template +HWY_API svuint64_t PromoteTo(Simd /* d */, + const svfloat32_t v) { + const svfloat32_t vv = detail::ZipLowerSame(v, v); + return svcvt_u64_f32_x(detail::PTrue(Simd()), vv); +} + +// 
------------------------------ PromoteUpperTo + +namespace detail { +HWY_SVE_FOREACH_UI16(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) +HWY_SVE_FOREACH_UI32(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) +HWY_SVE_FOREACH_UI64(HWY_SVE_PROMOTE_TO, PromoteUpperTo, unpkhi) +#undef HWY_SVE_PROMOTE_TO +} // namespace detail + +#ifdef HWY_NATIVE_PROMOTE_UPPER_TO +#undef HWY_NATIVE_PROMOTE_UPPER_TO +#else +#define HWY_NATIVE_PROMOTE_UPPER_TO +#endif + +// Unsigned->Unsigned or Signed->Signed +template , typename TV = TFromV, + hwy::EnableIf() && IsInteger() && + (IsSigned() == IsSigned())>* = nullptr> +HWY_API VFromD PromoteUpperTo(D d, V v) { + if (detail::IsFull(d)) { + return detail::PromoteUpperTo(d, v); + } + const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +// Differing signs or either is float +template , typename TV = TFromV, + hwy::EnableIf() || !IsInteger() || + (IsSigned() != IsSigned())>* = nullptr> +HWY_API VFromD PromoteUpperTo(D d, V v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). + const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +// ------------------------------ DemoteTo U + +namespace detail { + +// Saturates unsigned vectors to half/quarter-width TN. +template +VU SaturateU(VU v) { + return detail::MinN(v, static_cast>(LimitsMax())); +} + +// Saturates unsigned vectors to half/quarter-width TN. +template +VI SaturateI(VI v) { + return detail::MinN(detail::MaxN(v, LimitsMin()), LimitsMax()); +} + +} // namespace detail + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svint16_t v) { +#if HWY_SVE_HAVE_2 + const svuint8_t vn = BitCast(dn, svqxtunb_s16(v)); +#else + const DFromV di; + const RebindToUnsigned du; + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. 
+ const svuint16_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and halve the width. + const svuint8_t vn = BitCast(dn, detail::SaturateU(clamped)); +#endif + return svuzp1_u8(vn, vn); +} + +template +HWY_API svuint16_t DemoteTo(Simd dn, const svint32_t v) { +#if HWY_SVE_HAVE_2 + const svuint16_t vn = BitCast(dn, svqxtunb_s32(v)); +#else + const DFromV di; + const RebindToUnsigned du; + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and halve the width. + const svuint16_t vn = BitCast(dn, detail::SaturateU(clamped)); +#endif + return svuzp1_u16(vn, vn); +} + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svint32_t v) { + const DFromV di; + const RebindToUnsigned du; + const RepartitionToNarrow d2; +#if HWY_SVE_HAVE_2 + const svuint16_t cast16 = BitCast(d2, svqxtnb_u16(svqxtunb_s32(v))); +#else + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint32_t clamped = BitCast(du, detail::MaxN(v, 0)); + // Saturate to unsigned-max and quarter the width. 
+ const svuint16_t cast16 = BitCast(d2, detail::SaturateU(clamped)); +#endif + const svuint8_t x2 = BitCast(dn, svuzp1_u16(cast16, cast16)); + return svuzp1_u8(x2, x2); +} + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svuint16_t v) { +#if HWY_SVE_HAVE_2 + const svuint8_t vn = BitCast(dn, svqxtnb_u16(v)); +#else + using TN = TFromD; + const svuint8_t vn = BitCast(dn, detail::SaturateU(v)); +#endif + return svuzp1_u8(vn, vn); +} + +template +HWY_API svuint16_t DemoteTo(Simd dn, const svuint32_t v) { +#if HWY_SVE_HAVE_2 + const svuint16_t vn = BitCast(dn, svqxtnb_u32(v)); +#else + using TN = TFromD; + const svuint16_t vn = BitCast(dn, detail::SaturateU(v)); +#endif + return svuzp1_u16(vn, vn); +} + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svuint32_t v) { + using TN = TFromD; + return U8FromU32(detail::SaturateU(v)); +} + +// ------------------------------ Truncations + +template +HWY_API svuint8_t TruncateTo(Simd /* tag */, + const svuint64_t v) { + const DFromV d; + const svuint8_t v1 = BitCast(d, v); + const svuint8_t v2 = svuzp1_u8(v1, v1); + const svuint8_t v3 = svuzp1_u8(v2, v2); + return svuzp1_u8(v3, v3); +} + +template +HWY_API svuint16_t TruncateTo(Simd /* tag */, + const svuint64_t v) { + const DFromV d; + const svuint16_t v1 = BitCast(d, v); + const svuint16_t v2 = svuzp1_u16(v1, v1); + return svuzp1_u16(v2, v2); +} + +template +HWY_API svuint32_t TruncateTo(Simd /* tag */, + const svuint64_t v) { + const DFromV d; + const svuint32_t v1 = BitCast(d, v); + return svuzp1_u32(v1, v1); +} + +template +HWY_API svuint8_t TruncateTo(Simd /* tag */, + const svuint32_t v) { + const DFromV d; + const svuint8_t v1 = BitCast(d, v); + const svuint8_t v2 = svuzp1_u8(v1, v1); + return svuzp1_u8(v2, v2); +} + +template +HWY_API svuint16_t TruncateTo(Simd /* tag */, + const svuint32_t v) { + const DFromV d; + const svuint16_t v1 = BitCast(d, v); + return svuzp1_u16(v1, v1); +} + +template +HWY_API svuint8_t TruncateTo(Simd /* tag */, + const svuint16_t 
v) { + const DFromV d; + const svuint8_t v1 = BitCast(d, v); + return svuzp1_u8(v1, v1); +} + +// ------------------------------ DemoteTo I + +template +HWY_API svint8_t DemoteTo(Simd dn, const svint16_t v) { +#if HWY_SVE_HAVE_2 + const svint8_t vn = BitCast(dn, svqxtnb_s16(v)); +#else + using TN = TFromD; + const svint8_t vn = BitCast(dn, detail::SaturateI(v)); +#endif + return svuzp1_s8(vn, vn); +} + +template +HWY_API svint16_t DemoteTo(Simd dn, const svint32_t v) { +#if HWY_SVE_HAVE_2 + const svint16_t vn = BitCast(dn, svqxtnb_s32(v)); +#else + using TN = TFromD; + const svint16_t vn = BitCast(dn, detail::SaturateI(v)); +#endif + return svuzp1_s16(vn, vn); +} + +template +HWY_API svint8_t DemoteTo(Simd dn, const svint32_t v) { + const RepartitionToWide d2; +#if HWY_SVE_HAVE_2 + const svint16_t cast16 = BitCast(d2, svqxtnb_s16(svqxtnb_s32(v))); +#else + using TN = TFromD; + const svint16_t cast16 = BitCast(d2, detail::SaturateI(v)); +#endif + const svint8_t v2 = BitCast(dn, svuzp1_s16(cast16, cast16)); + return BitCast(dn, svuzp1_s8(v2, v2)); +} + +// ------------------------------ I64/U64 DemoteTo + +template +HWY_API svint32_t DemoteTo(Simd dn, const svint64_t v) { + const Rebind du64; + const RebindToUnsigned dn_u; +#if HWY_SVE_HAVE_2 + const svuint64_t vn = BitCast(du64, svqxtnb_s64(v)); +#else + using TN = TFromD; + const svuint64_t vn = BitCast(du64, detail::SaturateI(v)); +#endif + return BitCast(dn, TruncateTo(dn_u, vn)); +} + +template +HWY_API svint16_t DemoteTo(Simd dn, const svint64_t v) { + const Rebind du64; + const RebindToUnsigned dn_u; +#if HWY_SVE_HAVE_2 + const svuint64_t vn = BitCast(du64, svqxtnb_s32(svqxtnb_s64(v))); +#else + using TN = TFromD; + const svuint64_t vn = BitCast(du64, detail::SaturateI(v)); +#endif + return BitCast(dn, TruncateTo(dn_u, vn)); +} + +template +HWY_API svint8_t DemoteTo(Simd dn, const svint64_t v) { + const Rebind du64; + const RebindToUnsigned dn_u; + using TN = TFromD; + const svuint64_t vn = BitCast(du64, 
detail::SaturateI(v)); + return BitCast(dn, TruncateTo(dn_u, vn)); +} + +template +HWY_API svuint32_t DemoteTo(Simd dn, const svint64_t v) { + const Rebind du64; +#if HWY_SVE_HAVE_2 + const svuint64_t vn = BitCast(du64, svqxtunb_s64(v)); +#else + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0)); + // Saturate to unsigned-max + const svuint64_t vn = detail::SaturateU(clamped); +#endif + return TruncateTo(dn, vn); +} + +template +HWY_API svuint16_t DemoteTo(Simd dn, const svint64_t v) { + const Rebind du64; +#if HWY_SVE_HAVE_2 + const svuint64_t vn = BitCast(du64, svqxtnb_u32(svqxtunb_s64(v))); +#else + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. + const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0)); + // Saturate to unsigned-max + const svuint64_t vn = detail::SaturateU(clamped); +#endif + return TruncateTo(dn, vn); +} + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svint64_t v) { + const Rebind du64; + using TN = TFromD; + // First clamp negative numbers to zero and cast to unsigned. 
+ const svuint64_t clamped = BitCast(du64, detail::MaxN(v, 0)); + // Saturate to unsigned-max + const svuint64_t vn = detail::SaturateU(clamped); + return TruncateTo(dn, vn); +} + +template +HWY_API svuint32_t DemoteTo(Simd dn, const svuint64_t v) { + const Rebind du64; +#if HWY_SVE_HAVE_2 + const svuint64_t vn = BitCast(du64, svqxtnb_u64(v)); +#else + using TN = TFromD; + const svuint64_t vn = BitCast(du64, detail::SaturateU(v)); +#endif + return TruncateTo(dn, vn); +} + +template +HWY_API svuint16_t DemoteTo(Simd dn, const svuint64_t v) { + const Rebind du64; +#if HWY_SVE_HAVE_2 + const svuint64_t vn = BitCast(du64, svqxtnb_u32(svqxtnb_u64(v))); +#else + using TN = TFromD; + const svuint64_t vn = BitCast(du64, detail::SaturateU(v)); +#endif + return TruncateTo(dn, vn); +} + +template +HWY_API svuint8_t DemoteTo(Simd dn, const svuint64_t v) { + const Rebind du64; + using TN = TFromD; + const svuint64_t vn = BitCast(du64, detail::SaturateU(v)); + return TruncateTo(dn, vn); +} + +// ------------------------------ Unsigned to signed demotions + +// Disable the default unsigned to signed DemoteTo/ReorderDemote2To +// implementations in generic_ops-inl.h on SVE/SVE2 as the SVE/SVE2 targets have +// target-specific implementations of the unsigned to signed DemoteTo and +// ReorderDemote2To ops + +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the V template +// argument +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) \ + hwy::EnableIf()>* = nullptr + +template ) - 1)> +HWY_API VFromD DemoteTo(D dn, V v) { + const RebindToUnsigned dn_u; + return BitCast(dn, TruncateTo(dn_u, detail::SaturateU>(v))); +} + +// ------------------------------ PromoteEvenTo/PromoteOddTo + +// Signed to signed PromoteEvenTo: 1 instruction instead of 2 
in generic-inl.h. +// Might as well also enable unsigned to unsigned, though it is just an And. +namespace detail { +HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, extb) +HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, exth) +HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGPV, NativePromoteEvenTo, extw) +} // namespace detail + +#include "hwy/ops/inside-inl.h" + +// ------------------------------ DemoteTo F + +// We already toggled HWY_NATIVE_F16C above. + +template +HWY_API svfloat16_t DemoteTo(Simd d, const svfloat32_t v) { + const svfloat16_t in_even = svcvt_f16_f32_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +template +HWY_API svfloat16_t DemoteTo(Simd d, const svfloat64_t v) { + const svfloat16_t in_lo16 = svcvt_f16_f64_x(detail::PTrue(d), v); + const svfloat16_t in_even = detail::ConcatEvenFull(in_lo16, in_lo16); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +#if !HWY_SVE_HAVE_F32_TO_BF16C +namespace detail { + +// Round a F32 value to the nearest BF16 value, with the result returned as the +// rounded F32 value bitcasted to an U32 + +// RoundF32ForDemoteToBF16 also converts NaN values to QNaN values to prevent +// NaN F32 values from being converted to an infinity +HWY_INLINE svuint32_t RoundF32ForDemoteToBF16(svfloat32_t v) { + const DFromV df32; + const RebindToUnsigned du32; + + const auto is_non_nan = Eq(v, v); + const auto bits32 = BitCast(du32, v); + + const auto round_incr = + detail::AddN(detail::AndN(ShiftRight<16>(bits32), 1u), 0x7FFFu); + return MaskedAddOr(detail::OrN(bits32, 0x00400000u), is_non_nan, bits32, + round_incr); +} + +} // namespace detail +#endif // !HWY_SVE_HAVE_F32_TO_BF16C + 
+template +HWY_API VBF16 DemoteTo(Simd dbf16, svfloat32_t v) { +#if HWY_SVE_HAVE_F32_TO_BF16C + const VBF16 in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), v); + return detail::ConcatEvenFull(in_even, in_even); +#else + const svuint16_t in_odd = + BitCast(ScalableTag(), detail::RoundF32ForDemoteToBF16(v)); + return BitCast(dbf16, detail::ConcatOddFull(in_odd, in_odd)); // lower half +#endif +} + +template +HWY_API svfloat32_t DemoteTo(Simd d, const svfloat64_t v) { + const svfloat32_t in_even = svcvt_f32_f64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template +HWY_API svint32_t DemoteTo(Simd d, const svfloat64_t v) { + const svint32_t in_even = svcvt_s32_f64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template +HWY_API svuint32_t DemoteTo(Simd d, const svfloat64_t v) { + const svuint32_t in_even = svcvt_u32_f64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template +HWY_API svfloat32_t DemoteTo(Simd d, const svint64_t v) { + const svfloat32_t in_even = svcvt_f32_s64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +template +HWY_API svfloat32_t DemoteTo(Simd d, const svuint64_t v) { + const svfloat32_t in_even = svcvt_f32_u64_x(detail::PTrue(d), v); + return detail::ConcatEvenFull(in_even, + in_even); // lower half +} + +// ------------------------------ ConvertTo F + +#define HWY_SVE_CONVERT(BASE, CHAR, BITS, HALF, NAME, OP) \ + /* Float from signed */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(int, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_s##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } \ + /* Float from unsigned */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(uint, BITS) v) { \ + return 
sv##OP##_##CHAR##BITS##_u##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } \ + /* Signed from float, rounding toward zero */ \ + template \ + HWY_API HWY_SVE_V(int, BITS) \ + NAME(HWY_SVE_D(int, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_s##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } \ + /* Unsigned from float, rounding toward zero */ \ + template \ + HWY_API HWY_SVE_V(uint, BITS) \ + NAME(HWY_SVE_D(uint, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_u##BITS##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } + +HWY_SVE_FOREACH_F(HWY_SVE_CONVERT, ConvertTo, cvt) +#undef HWY_SVE_CONVERT + +// ------------------------------ MaskedConvertTo F + +#define HWY_SVE_MASKED_CONVERT_TO_OR_ZERO(BASE, CHAR, BITS, HALF, NAME, OP) \ + /* Float from signed */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(int, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_s##BITS##_z(m, v); \ + } \ + /* Float from unsigned */ \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(uint, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_u##BITS##_z(m, v); \ + } \ + /* Signed from float, rounding toward zero */ \ + template \ + HWY_API HWY_SVE_V(int, BITS) \ + NAME(svbool_t m, HWY_SVE_D(int, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_s##BITS##_##CHAR##BITS##_z(m, v); \ + } \ + /* Unsigned from float, rounding toward zero */ \ + template \ + HWY_API HWY_SVE_V(uint, BITS) \ + NAME(svbool_t m, HWY_SVE_D(uint, BITS, N, kPow2) /* d */, \ + HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_u##BITS##_##CHAR##BITS##_z(m, v); \ + } + +HWY_SVE_FOREACH_F(HWY_SVE_MASKED_CONVERT_TO_OR_ZERO, MaskedConvertTo, cvt) +#undef HWY_SVE_MASKED_CONVERT_TO_OR_ZERO + +// ------------------------------ NearestInt (Round, ConvertTo) +template >> +HWY_API VFromD NearestInt(VF v) { + // No single 
instruction, round then truncate. + return ConvertTo(DI(), Round(v)); +} + +template +HWY_API VFromD DemoteToNearestInt(DI32 di32, + VFromD> v) { + // No single instruction, round then demote. + return DemoteTo(di32, Round(v)); +} + +// ------------------------------ Iota (AddN, ConvertTo) + +#define HWY_SVE_IOTA(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, T2 first) { \ + return sv##OP##_##CHAR##BITS( \ + ConvertScalarTo(first), 1); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_IOTA, Iota, index) +#undef HWY_SVE_IOTA + +template , typename T2, HWY_IF_FLOAT(T)> +HWY_API VFromD Iota(const D d, T2 first) { + const RebindToSigned di; + const T first_f = ConvertScalarTo(first); + const VFromD iota_f = ConvertTo(d, Iota(di, 0)); + return detail::AddN(iota_f, first_f); +} + +// ================================================== LANE ACCESS + +// ------------------------------ ExtractLane (GetLaneM, FirstN) +template +HWY_API TFromV ExtractLane(V v, size_t i) { + return detail::GetLaneM(v, FirstN(DFromV(), i)); +} + +// ------------------------------ InsertLane (IfThenElse, EqN) +template +HWY_API V InsertLane(const V v, size_t i, T t) { + static_assert(sizeof(TFromV) == sizeof(T), "Lane size mismatch"); + const DFromV d; + const RebindToSigned di; + using TI = TFromD; + const svbool_t is_i = detail::EqN(Iota(di, 0), static_cast(i)); + // The actual type may be int16_t for special floats; copy, not cast. 
+ TFromV t_bits; + hwy::CopySameSize(&t, &t_bits); + return IfThenElse(RebindMask(d, is_i), Set(d, t_bits), v); +} + +// ------------------------------ GetExponent + +#if HWY_SVE_HAVE_2 || HWY_IDE +#ifdef HWY_NATIVE_GET_EXPONENT +#undef HWY_NATIVE_GET_EXPONENT +#else +#define HWY_NATIVE_GET_EXPONENT +#endif + +namespace detail { +#define HWY_SVE_GET_EXP(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(int, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \ + } +HWY_SVE_FOREACH_F(HWY_SVE_GET_EXP, GetExponent, logb) +#undef HWY_SVE_GET_EXP +} // namespace detail + +template +HWY_API V GetExponent(V v) { + const DFromV d; + const RebindToSigned di; + const VFromD exponent_int = detail::GetExponent(v); + // convert integer to original type + return ConvertTo(d, exponent_int); +} +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ InterleaveLower + +template +HWY_API V InterleaveLower(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); +#if HWY_TARGET == HWY_SVE2_128 + (void)d; + return detail::ZipLowerSame(a, b); +#else + // Move lower halves of blocks to lower half of vector. + const Repartition d64; + const auto a64 = BitCast(d64, a); + const auto b64 = BitCast(d64, b); + const auto a_blocks = detail::ConcatEvenFull(a64, a64); // lower half + const auto b_blocks = detail::ConcatEvenFull(b64, b64); + return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); +#endif +} + +template +HWY_API V InterleaveLower(const V a, const V b) { + return InterleaveLower(DFromV(), a, b); +} + +// ------------------------------ InterleaveUpper + +// Only use zip2 if vector are a powers of two, otherwise getting the actual +// "upper half" requires MaskUpperHalf. +namespace detail { +// Unlike Highway's ZipUpper, this returns the same type. 
+HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, ZipUpperSame, zip2) +} // namespace detail + +// Full vector: guaranteed to have at least one block +template , + hwy::EnableIf* = nullptr> +HWY_API V InterleaveUpper(D d, const V a, const V b) { +#if HWY_TARGET == HWY_SVE2_128 + (void)d; + return detail::ZipUpperSame(a, b); +#else + // Move upper halves of blocks to lower half of vector. + const Repartition d64; + const auto a64 = BitCast(d64, a); + const auto b64 = BitCast(d64, b); + const auto a_blocks = detail::ConcatOddFull(a64, a64); // lower half + const auto b_blocks = detail::ConcatOddFull(b64, b64); + return detail::ZipLowerSame(BitCast(d, a_blocks), BitCast(d, b_blocks)); +#endif +} + +// Capped/fraction: need runtime check +template , + hwy::EnableIf* = nullptr> +HWY_API V InterleaveUpper(D d, const V a, const V b) { + // Less than one block: treat as capped + if (Lanes(d) * sizeof(TFromD) < 16) { + const Half d2; + return InterleaveLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); + } + return InterleaveUpper(DFromV(), a, b); +} + +// ------------------------------ InterleaveWholeLower +#ifdef HWY_TOGGLE_INTERLEAVE_WHOLE +#undef HWY_TOGGLE_INTERLEAVE_WHOLE +#else +#define HWY_TOGGLE_INTERLEAVE_WHOLE +#endif + +template +HWY_API VFromD InterleaveWholeLower(D /*d*/, VFromD a, VFromD b) { + return detail::ZipLowerSame(a, b); +} + +// ------------------------------ InterleaveWholeUpper + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + if (HWY_SVE_IS_POW2 && detail::IsFull(d)) { + return detail::ZipUpperSame(a, b); + } + + const Half d2; + return InterleaveWholeLower(d, UpperHalf(d2, a), UpperHalf(d2, b)); +} + +// ------------------------------ Per4LaneBlockShuffle + +namespace detail { + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x88> /*idx_3210_tag*/, + hwy::SizeTag /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + const auto evens = 
BitCast(dw, ConcatEvenFull(v, v)); + return BitCast(d, ZipLowerSame(evens, evens)); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xDD> /*idx_3210_tag*/, + hwy::SizeTag /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + const auto odds = BitCast(dw, ConcatOddFull(v, v)); + return BitCast(d, ZipLowerSame(odds, odds)); +} + +} // namespace detail + +// ================================================== COMBINE + +namespace detail { + +#if (HWY_TARGET == HWY_SVE_256 && HWY_HAVE_CONSTEXPR_LANES) || HWY_IDE +template +svbool_t MaskLowerHalf(D d) { + switch (MaxLanes(d)) { + case 32: + return svptrue_pat_b8(SV_VL16); + case 16: + return svptrue_pat_b8(SV_VL8); + case 8: + return svptrue_pat_b8(SV_VL4); + case 4: + return svptrue_pat_b8(SV_VL2); + default: + return svptrue_pat_b8(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (MaxLanes(d)) { + case 16: + return svptrue_pat_b16(SV_VL8); + case 8: + return svptrue_pat_b16(SV_VL4); + case 4: + return svptrue_pat_b16(SV_VL2); + default: + return svptrue_pat_b16(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (MaxLanes(d)) { + case 8: + return svptrue_pat_b32(SV_VL4); + case 4: + return svptrue_pat_b32(SV_VL2); + default: + return svptrue_pat_b32(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (MaxLanes(d)) { + case 4: + return svptrue_pat_b64(SV_VL2); + default: + return svptrue_pat_b64(SV_VL1); + } +} +#endif +#if (HWY_TARGET == HWY_SVE2_128 && HWY_HAVE_CONSTEXPR_LANES) || HWY_IDE +template +svbool_t MaskLowerHalf(D d) { + switch (MaxLanes(d)) { + case 16: + return svptrue_pat_b8(SV_VL8); + case 8: + return svptrue_pat_b8(SV_VL4); + case 4: + return svptrue_pat_b8(SV_VL2); + case 2: + case 1: + default: + return svptrue_pat_b8(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + switch (MaxLanes(d)) { + case 8: + return svptrue_pat_b16(SV_VL4); + case 4: + 
return svptrue_pat_b16(SV_VL2); + case 2: + case 1: + default: + return svptrue_pat_b16(SV_VL1); + } +} +template +svbool_t MaskLowerHalf(D d) { + return svptrue_pat_b32(MaxLanes(d) == 4 ? SV_VL2 : SV_VL1); +} +template +svbool_t MaskLowerHalf(D /*d*/) { + return svptrue_pat_b64(SV_VL1); +} +#endif // HWY_TARGET == HWY_SVE2_128 +#if (HWY_TARGET != HWY_SVE_256 && HWY_TARGET != HWY_SVE2_128) || \ + !HWY_HAVE_CONSTEXPR_LANES +template +svbool_t MaskLowerHalf(D d) { + return FirstN(d, Lanes(d) / 2); +} +#endif + +template +svbool_t MaskUpperHalf(D d) { + // TODO(janwas): WHILEGE on SVE2 + if (HWY_SVE_IS_POW2 && IsFull(d)) { + return Not(MaskLowerHalf(d)); + } + + // For Splice to work as intended, make sure bits above Lanes(d) are zero. + return AndNot(MaskLowerHalf(d), detail::MakeMask(d)); +} + +// Right-shift vector pair by constexpr; can be used to slide down (=N) or up +// (=Lanes()-N). +#define HWY_SVE_EXT(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) hi, HWY_SVE_V(BASE, BITS) lo) { \ + return sv##OP##_##CHAR##BITS(lo, hi, kIndex); \ + } +HWY_SVE_FOREACH(HWY_SVE_EXT, Ext, ext) +#undef HWY_SVE_EXT + +} // namespace detail + +// ------------------------------ ConcatUpperLower +template +HWY_API V ConcatUpperLower(const D d, const V hi, const V lo) { + return IfThenElse(detail::MaskLowerHalf(d), lo, hi); +} + +// ------------------------------ ConcatLowerLower +template +HWY_API V ConcatLowerLower(const D d, const V hi, const V lo) { + if (detail::IsFull(d)) { +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 + return detail::ConcatEvenBlocks(hi, lo); +#endif +#if HWY_TARGET == HWY_SVE2_128 + const Repartition du64; + const auto lo64 = BitCast(du64, lo); + return BitCast(d, InterleaveLower(du64, lo64, BitCast(du64, hi))); +#endif + } + return detail::Splice(hi, lo, detail::MaskLowerHalf(d)); +} + +// ------------------------------ ConcatLowerUpper +template +HWY_API V 
ConcatLowerUpper(const D d, const V hi, const V lo) { +#if HWY_HAVE_CONSTEXPR_LANES + if (detail::IsFull(d)) { + return detail::Ext(hi, lo); + } +#endif + return detail::Splice(hi, lo, detail::MaskUpperHalf(d)); +} + +// ------------------------------ ConcatUpperUpper +template +HWY_API V ConcatUpperUpper(const D d, const V hi, const V lo) { + if (detail::IsFull(d)) { +#if defined(__ARM_FEATURE_SVE_MATMUL_FP64) && HWY_TARGET == HWY_SVE_256 + return detail::ConcatOddBlocks(hi, lo); +#endif +#if HWY_TARGET == HWY_SVE2_128 + const Repartition du64; + const auto lo64 = BitCast(du64, lo); + return BitCast(d, InterleaveUpper(du64, lo64, BitCast(du64, hi))); +#endif + } + const svbool_t mask_upper = detail::MaskUpperHalf(d); + const V lo_upper = detail::Splice(lo, lo, mask_upper); + return IfThenElse(mask_upper, hi, lo_upper); +} + +// ------------------------------ Combine +template +HWY_API VFromD Combine(const D d, const V2 hi, const V2 lo) { + return ConcatLowerLower(d, hi, lo); +} + +// ------------------------------ ZeroExtendVector +template +HWY_API V ZeroExtendVector(const D d, const V lo) { + return Combine(d, Zero(Half()), lo); +} + +// ------------------------------ Lower/UpperHalf + +template +HWY_API V LowerHalf(D2 /* tag */, const V v) { + return v; +} + +template +HWY_API V LowerHalf(const V v) { + return v; +} + +template +HWY_API V UpperHalf(const DH dh, const V v) { + const Twice d; + // Cast so that we support bfloat16_t. 
+ const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); +#if HWY_HAVE_CONSTEXPR_LANES + return BitCast(d, detail::Ext(vu, vu)); +#else + const MFromD mask = detail::MaskUpperHalf(du); + return BitCast(d, detail::Splice(vu, vu, mask)); +#endif +} + +// ================================================== SWIZZLE + +// ------------------------------ DupEven + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveEven, trn1) +} // namespace detail + +template +HWY_API V DupEven(const V v) { + return detail::InterleaveEven(v, v); +} + +// ------------------------------ DupOdd + +namespace detail { +HWY_SVE_FOREACH(HWY_SVE_RETV_ARGVV, InterleaveOdd, trn2) +} // namespace detail + +template +HWY_API V DupOdd(const V v) { + return detail::InterleaveOdd(v, v); +} + +// ------------------------------ OddEven + +#if HWY_SVE_HAVE_2 + +#define HWY_SVE_ODD_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) odd, HWY_SVE_V(BASE, BITS) even) { \ + return sv##OP##_##CHAR##BITS(even, odd, /*xor=*/0); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ODD_EVEN, OddEven, eortb_n) +#undef HWY_SVE_ODD_EVEN + +template +HWY_API V OddEven(const V odd, const V even) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, OddEven(BitCast(du, odd), BitCast(du, even))); +} + +#else + +template +HWY_API V OddEven(const V odd, const V even) { + const auto odd_in_even = detail::Ext<1>(odd, odd); + return detail::InterleaveEven(even, odd_in_even); +} + +#endif // HWY_TARGET + +// ------------------------------ InterleaveEven +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return detail::InterleaveEven(a, b); +} + +// ------------------------------ InterleaveOdd +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return detail::InterleaveOdd(a, b); +} + +// ------------------------------ OddEvenBlocks +template +HWY_API V OddEvenBlocks(const V odd, const V even) { + const DFromV d; 
+#if HWY_TARGET == HWY_SVE_256 + return ConcatUpperLower(d, odd, even); +#elif HWY_TARGET == HWY_SVE2_128 + (void)odd; + (void)d; + return even; +#else + const RebindToUnsigned du; + using TU = TFromD; + constexpr size_t kShift = CeilLog2(16 / sizeof(TU)); + const auto idx_block = ShiftRight(Iota(du, 0)); + const auto lsb = detail::AndN(idx_block, static_cast(1)); + const svbool_t is_even = detail::EqN(lsb, static_cast(0)); + return IfThenElse(is_even, even, odd); +#endif +} + +// ------------------------------ TableLookupLanes + +template +HWY_API VFromD> IndicesFromVec(D d, VI vec) { + using TI = TFromV; + static_assert(sizeof(TFromD) == sizeof(TI), "Index/lane size mismatch"); + const RebindToUnsigned du; + const auto indices = BitCast(du, vec); +#if HWY_IS_DEBUG_BUILD + using TU = MakeUnsigned; + const size_t twice_max_lanes = Lanes(d) * 2; + HWY_DASSERT(AllTrue( + du, Eq(indices, + detail::AndN(indices, static_cast(twice_max_lanes - 1))))); +#else + (void)d; +#endif + return indices; +} + +template +HWY_API VFromD> SetTableIndices(D d, const TI* idx) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + return IndicesFromVec(d, LoadU(Rebind(), idx)); +} + +#define HWY_SVE_TABLE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(uint, BITS) idx) { \ + return sv##OP##_##CHAR##BITS(v, idx); \ + } + +HWY_SVE_FOREACH(HWY_SVE_TABLE, TableLookupLanes, tbl) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_TABLE, TableLookupLanes, tbl) +#endif +#undef HWY_SVE_TABLE + +#if HWY_SVE_HAVE_2 +namespace detail { +#define HWY_SVE_TABLE2(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_TUPLE(BASE, BITS, 2) tuple, HWY_SVE_V(uint, BITS) idx) { \ + return sv##OP##_##CHAR##BITS(tuple, idx); \ + } + +HWY_SVE_FOREACH(HWY_SVE_TABLE2, NativeTwoTableLookupLanes, tbl2) +#if HWY_SVE_HAVE_BF16_FEATURE || 
HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_TABLE2, NativeTwoTableLookupLanes, + tbl2) +#endif +#undef HWY_SVE_TABLE +} // namespace detail +#endif // HWY_SVE_HAVE_2 + +template +HWY_API VFromD TwoTablesLookupLanes(D d, VFromD a, VFromD b, + VFromD> idx) { + // SVE2 has an instruction for this, but it only works for full 2^n vectors. +#if HWY_SVE_HAVE_2 && HWY_SVE_IS_POW2 + if (detail::IsFull(d)) { + return detail::NativeTwoTableLookupLanes(Create2(d, a, b), idx); + } +#endif + const RebindToUnsigned du; + using TU = TFromD; + + const size_t num_of_lanes = Lanes(d); + const auto idx_mod = detail::AndN(idx, static_cast(num_of_lanes - 1)); + const auto sel_a_mask = Eq(idx, idx_mod); + + const auto a_lookup_result = TableLookupLanes(a, idx_mod); + const auto b_lookup_result = TableLookupLanes(b, idx_mod); + return IfThenElse(sel_a_mask, a_lookup_result, b_lookup_result); +} + +template +HWY_API V TwoTablesLookupLanes(V a, V b, + VFromD>> idx) { + const DFromV d; + return TwoTablesLookupLanes(d, a, b, idx); +} + +// ------------------------------ SlideUpLanes (FirstN) +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { + return detail::Splice(v, Zero(d), FirstN(d, amt)); +} + +// ------------------------------ Slide1Up + +#ifdef HWY_NATIVE_SLIDE1_UP_DOWN +#undef HWY_NATIVE_SLIDE1_UP_DOWN +#else +#define HWY_NATIVE_SLIDE1_UP_DOWN +#endif + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + return SlideUpLanes(d, v, 1); +} + +// ------------------------------ SlideDownLanes (TableLookupLanes) +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { + const RebindToUnsigned du; + using TU = TFromD; + const auto idx = Iota(du, static_cast(amt)); + return IfThenElseZero(FirstN(d, Lanes(d) - amt), TableLookupLanes(v, idx)); +} + +// ------------------------------ Slide1Down +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + return SlideDownLanes(d, v, 1); +} + +// ------------------------------ SwapAdjacentBlocks 
(TableLookupLanes) + +namespace detail { + +template +constexpr size_t LanesPerBlock(Simd d) { + // We might have a capped vector smaller than a block, so honor that. + return HWY_MIN(16 / sizeof(T), MaxLanes(d)); +} + +} // namespace detail + +template +HWY_API V SwapAdjacentBlocks(const V v) { + const DFromV d; +#if HWY_TARGET == HWY_SVE_256 + return ConcatLowerUpper(d, v, v); +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + return v; +#else + const RebindToUnsigned du; + constexpr auto kLanesPerBlock = + static_cast>(detail::LanesPerBlock(d)); + const VFromD idx = detail::XorN(Iota(du, 0), kLanesPerBlock); + return TableLookupLanes(v, idx); +#endif +} + +// ------------------------------ InterleaveEvenBlocks +// (ConcatLowerLower, SlideUpLanes, OddEvenBlocks) + +template > +HWY_API V InterleaveEvenBlocks(D d, V a, V b) { +#if HWY_TARGET == HWY_SVE_256 + return ConcatLowerLower(d, b, a); +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + (void)b; + return a; +#else + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + return OddEvenBlocks(SlideUpLanes(d, b, kLanesPerBlock), a); +#endif +} + +// ------------------------------ InterleaveOddBlocks +// (ConcatUpperUpper, SlideDownLanes, OddEvenBlocks) + +template > +HWY_API V InterleaveOddBlocks(D d, V a, V b) { +#if HWY_TARGET == HWY_SVE_256 + return ConcatUpperUpper(d, b, a); +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + (void)b; + return a; +#else + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + return OddEvenBlocks(b, SlideDownLanes(d, a, kLanesPerBlock)); +#endif +} + +// ------------------------------ InterleaveLowerBlocks +// (InterleaveEvenBlocks) + +template > +HWY_API V InterleaveLowerBlocks(D d, V a, V b) { +#if HWY_TARGET == HWY_SVE_256 + return InterleaveEvenBlocks(d, a, b); +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + (void)b; + return a; +#else + const Repartition du64; + const svuint64_t a64 = BitCast(du64, a); + const svuint64_t b64 = BitCast(du64, b); + svuint64_t even = 
detail::InterleaveEven(a64, b64); // a0 b0 a2 b2 + svuint64_t odd = detail::InterleaveOdd(a64, b64); // a1 b1 a3 b3 + return BitCast(d, detail::ZipLowerSame(even, odd)); // a10 b10 +#endif +} + +// ------------------------------ InterleaveUpperBlocks +// (ConcatUpperUpper, SlideDownLanes, OddEvenBlocks) + +template > +HWY_API V InterleaveUpperBlocks(D d, V a, V b) { +#if HWY_TARGET == HWY_SVE_256 + return InterleaveOddBlocks(d, a, b); +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + (void)b; + return a; +#else + const Repartition du64; + const svuint64_t a64 = BitCast(du64, a); + const svuint64_t b64 = BitCast(du64, b); + svuint64_t even = detail::InterleaveEven(a64, b64); // a0 b0 a2 b2 + svuint64_t odd = detail::InterleaveOdd(a64, b64); // a1 b1 a3 b3 + HWY_IF_CONSTEXPR(detail::IsFull(d)) { + return BitCast(d, detail::ZipUpperSame(even, odd)); // a32 b32 + } + // ZipUpperSame assumes full vectors; instead use UpperHalf to honor the + // capped/fractional tag. + const Half du64h; + even = ResizeBitCast(du64, UpperHalf(du64h, even)); + odd = ResizeBitCast(du64, UpperHalf(du64h, odd)); + return BitCast(d, detail::ZipLowerSame(even, odd)); +#endif // HWY_TARGET +} + +// ------------------------------ Reverse + +namespace detail { + +#define HWY_SVE_REVERSE(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v); \ + } + +HWY_SVE_FOREACH(HWY_SVE_REVERSE, ReverseFull, rev) +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC +HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SVE_REVERSE, ReverseFull, rev) +#endif +#undef HWY_SVE_REVERSE + +} // namespace detail + +template +HWY_API V Reverse(D d, V v) { + using T = TFromD; + const auto reversed = detail::ReverseFull(v); + if (HWY_SVE_IS_POW2 && detail::IsFull(d)) return reversed; + // Shift right to remove extra (non-pow2 and remainder) lanes. + // TODO(janwas): on SVE2, use WHILEGE. + // Avoids FirstN truncating to the return vector size. 
Must also avoid Not + // because that is limited to SV_POW2. + const ScalableTag dfull; + const svbool_t all_true = detail::AllPTrue(dfull); + const size_t all_lanes = detail::AllHardwareLanes(); + const size_t want_lanes = Lanes(d); + HWY_DASSERT(want_lanes <= all_lanes); + const svbool_t mask = + svnot_b_z(all_true, FirstN(dfull, all_lanes - want_lanes)); + return detail::Splice(reversed, reversed, mask); +} + +// ------------------------------ Reverse2 + +// Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. +#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWide dw; + return BitCast(d, svrevb_u16_x(detail::PTrue(d), BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWide dw; + return BitCast(d, svrevh_u32_x(detail::PTrue(d), BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWide dw; + return BitCast(d, svrevw_u64_x(detail::PTrue(d), BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { // 3210 +#if HWY_TARGET == HWY_SVE2_128 + if (detail::IsFull(d)) { + return detail::Ext<1>(v, v); + } +#endif + (void)d; + const auto odd_in_even = detail::Ext<1>(v, v); // x321 + return detail::InterleaveEven(odd_in_even, v); // 2301 +} + +// ------------------------------ Reverse4 (TableLookupLanes) + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWideX2 du32; + return BitCast(d, svrevb_u32_x(detail::PTrue(d), BitCast(du32, v))); +} + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToUnsigned du; + const RepartitionToWideX2 du64; + return BitCast(d, svrevh_u64_x(detail::PTrue(d), BitCast(du64, v))); +} + +template +HWY_API 
VFromD Reverse4(D d, const VFromD v) { + if (HWY_TARGET == HWY_SVE2_128 && detail::IsFull(d)) { + return detail::ReverseFull(v); + } + // TODO(janwas): is this approach faster than Shuffle0123? + const RebindToUnsigned du; + const auto idx = detail::XorN(Iota(du, 0), 3); + return TableLookupLanes(v, idx); +} + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + if (HWY_TARGET == HWY_SVE_256 && detail::IsFull(d)) { + return detail::ReverseFull(v); + } + // TODO(janwas): is this approach faster than Shuffle0123? + const RebindToUnsigned du; + const auto idx = detail::XorN(Iota(du, 0), 3); + return TableLookupLanes(v, idx); +} + +// ------------------------------ Reverse8 (TableLookupLanes) + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const Repartition du64; + return BitCast(d, svrevb_u64_x(detail::PTrue(d), BitCast(du64, v))); +} + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorN(Iota(du, 0), 7); + return TableLookupLanes(v, idx); +} + +// ------------------------------- ReverseBits + +#ifdef HWY_NATIVE_REVERSE_BITS_UI8 +#undef HWY_NATIVE_REVERSE_BITS_UI8 +#else +#define HWY_NATIVE_REVERSE_BITS_UI8 +#endif + +#ifdef HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#undef HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#else +#define HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#endif + +#define HWY_SVE_REVERSE_BITS(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + const DFromV d; \ + return sv##OP##_##CHAR##BITS##_x(detail::PTrue(d), v); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_REVERSE_BITS, ReverseBits, rbit) +#undef HWY_SVE_REVERSE_BITS + +// ------------------------------ Block insert/extract/broadcast ops +#if HWY_TARGET != HWY_SVE2_128 + +#ifdef HWY_NATIVE_BLK_INSERT_EXTRACT +#undef HWY_NATIVE_BLK_INSERT_EXTRACT +#else +#define HWY_NATIVE_BLK_INSERT_EXTRACT +#endif + +template +HWY_API V InsertBlock(V v, V blk_to_insert) { + const DFromV d; + 
static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + +#if HWY_TARGET == HWY_SVE_256 + return (kBlockIdx == 0) ? ConcatUpperLower(d, v, blk_to_insert) + : ConcatLowerLower(d, blk_to_insert, v); +#else + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + + constexpr size_t kBlockOffset = + static_cast(kBlockIdx) * kLanesPerBlock; + const auto splice_mask = FirstN(d, kBlockOffset); + const auto sel_lo_mask = FirstN(d, kBlockOffset + kLanesPerBlock); + + const auto splice_result = detail::Splice(blk_to_insert, v, splice_mask); + return IfThenElse(sel_lo_mask, splice_result, v); +#endif +} + +template +HWY_API V ExtractBlock(V v) { + const DFromV d; + static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + + if (kBlockIdx == 0) return v; + +#if HWY_TARGET == HWY_SVE_256 + return UpperHalf(Half(), v); +#else + const RebindToUnsigned du; + using TU = TFromD; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + constexpr size_t kBlockOffset = + static_cast(kBlockIdx) * kLanesPerBlock; + const auto splice_mask = + RebindMask(d, detail::LtN(Iota(du, static_cast(0u - kBlockOffset)), + static_cast(kLanesPerBlock))); + return detail::Splice(v, v, splice_mask); +#endif +} + +template +HWY_API V BroadcastBlock(V v) { + const DFromV d; + static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + + const RebindToUnsigned du; // for bfloat16_t + using VU = VFromD; + const VU vu = BitCast(du, v); + +#if HWY_TARGET == HWY_SVE_256 + return BitCast(d, (kBlockIdx == 0) ? 
ConcatLowerLower(du, vu, vu) + : ConcatUpperUpper(du, vu, vu)); +#else + using TU = TFromD; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + constexpr size_t kBlockOffset = + static_cast(kBlockIdx) * kLanesPerBlock; + + const VU idx = detail::AddN( + detail::AndN(Iota(du, TU{0}), static_cast(kLanesPerBlock - 1)), + static_cast(kBlockOffset)); + return BitCast(d, TableLookupLanes(vu, idx)); +#endif +} + +#endif // HWY_TARGET != HWY_SVE2_128 + +// ------------------------------ Compress (PromoteTo) + +template +struct CompressIsPartition { +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 + // Optimization for 64-bit lanes (could also be applied to 32-bit, but that + // requires a larger table). + enum { value = (sizeof(T) == 8) }; +#else + enum { value = 0 }; +#endif // HWY_TARGET == HWY_SVE_256 +}; + +#define HWY_SVE_COMPRESS(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v, svbool_t mask) { \ + return sv##OP##_##CHAR##BITS(mask, v); \ + } + +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +HWY_SVE_FOREACH_UI32(HWY_SVE_COMPRESS, Compress, compact) +HWY_SVE_FOREACH_F32(HWY_SVE_COMPRESS, Compress, compact) +#else +HWY_SVE_FOREACH_UIF3264(HWY_SVE_COMPRESS, Compress, compact) +#endif +#undef HWY_SVE_COMPRESS + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +template +HWY_API V Compress(V v, svbool_t mask) { + const DFromV d; + const RebindToUnsigned du64; + + // Convert mask into bitfield via horizontal sum (faster than ORV) of masked + // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for + // SetTableIndices. + const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + // See CompressIsPartition. 
+ alignas(16) static constexpr uint64_t table[4 * 16] = { + // PrintCompress64x4Tables + 0, 1, 2, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 2, 0, 1, 3, 0, 2, + 1, 3, 1, 2, 0, 3, 0, 1, 2, 3, 3, 0, 1, 2, 0, 3, 1, 2, 1, 3, 0, 2, + 0, 1, 3, 2, 2, 3, 0, 1, 0, 2, 3, 1, 1, 2, 3, 0, 0, 1, 2, 3}; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +} + +#endif // HWY_TARGET == HWY_SVE_256 +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE +template +HWY_API V Compress(V v, svbool_t mask) { + // If mask == 10: swap via splice. A mask of 00 or 11 leaves v unchanged, 10 + // swaps upper/lower (the lower half is set to the upper half, and the + // remaining upper half is filled from the lower half of the second v), and + // 01 is invalid because it would ConcatLowerLower. zip1 and AndNot keep 10 + // unchanged and map everything else to 00. + const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane + return detail::Splice(v, v, AndNot(maskLL, mask)); +} + +#endif // HWY_TARGET == HWY_SVE2_128 + +template +HWY_API V Compress(V v, svbool_t mask16) { + static_assert(!IsSame(), "Must use overload"); + const DFromV d16; + + // Promote vector and mask to 32-bit + const RepartitionToWide dw; + const auto v32L = PromoteTo(dw, v); + const auto v32H = detail::PromoteUpperTo(dw, v); + const svbool_t mask32L = svunpklo_b(mask16); + const svbool_t mask32H = svunpkhi_b(mask16); + + const auto compressedL = Compress(v32L, mask32L); + const auto compressedH = Compress(v32H, mask32H); + + // Demote to 16-bit (already in range) - separately so we can splice + const V evenL = BitCast(d16, compressedL); + const V evenH = BitCast(d16, compressedH); + const V v16L = detail::ConcatEvenFull(evenL, evenL); // lower half + const V v16H = detail::ConcatEvenFull(evenH, evenH); + + // We need to combine two vectors of non-constexpr length, so the only option + // is Splice, which requires us to synthesize a mask. 
NOTE: this function uses + // full vectors (SV_ALL instead of SV_POW2), hence we need unmasked svcnt. + const size_t countL = detail::CountTrueFull(dw, mask32L); + const auto compressed_maskL = FirstN(d16, countL); + return detail::Splice(v16H, v16L, compressed_maskL); +} + +// Must treat float16_t as integers so we can ConcatEven. +HWY_API svfloat16_t Compress(svfloat16_t v, svbool_t mask16) { + const DFromV df; + const RebindToSigned di; + return BitCast(df, Compress(BitCast(di, v), mask16)); +} + +// ------------------------------ CompressNot + +// 2 or 4 bytes +template +HWY_API V CompressNot(V v, const svbool_t mask) { + return Compress(v, Not(mask)); +} + +template +HWY_API V CompressNot(V v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE + // If mask == 01: swap via splice. A mask of 00 or 11 leaves v unchanged, 10 + // swaps upper/lower (the lower half is set to the upper half, and the + // remaining upper half is filled from the lower half of the second v), and + // 01 is invalid because it would ConcatLowerLower. zip1 and AndNot map + // 01 to 10, and everything else to 00. + const svbool_t maskLL = svzip1_b64(mask, mask); // broadcast lower lane + return detail::Splice(v, v, AndNot(mask, maskLL)); +#endif +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE + const DFromV d; + const RebindToUnsigned du64; + + // Convert mask into bitfield via horizontal sum (faster than ORV) of masked + // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for + // SetTableIndices. + const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + // See CompressIsPartition. 
+ alignas(16) static constexpr uint64_t table[4 * 16] = { + // PrintCompressNot64x4Tables + 0, 1, 2, 3, 1, 2, 3, 0, 0, 2, 3, 1, 2, 3, 0, 1, 0, 1, 3, 2, 1, 3, + 0, 2, 0, 3, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 1, 2, 0, 3, 0, 2, 1, 3, + 2, 0, 1, 3, 0, 1, 2, 3, 1, 0, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +#endif // HWY_TARGET == HWY_SVE_256 + + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API svuint64_t CompressBlocksNot(svuint64_t v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 + (void)mask; + return v; +#endif +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE + uint64_t bits = 0; // predicate reg is 32-bit + CopyBytes<4>(&mask, &bits); // not same size - 64-bit more efficient + // Concatenate LSB for upper and lower blocks, pre-scale by 4 for table idx. + const size_t offset = ((bits & 1) ? 4u : 0u) + ((bits & 0x10000) ? 8u : 0u); + // See CompressIsPartition. Manually generated; flip halves if mask = [0, 1]. 
+ alignas(16) static constexpr uint64_t table[4 * 4] = {0, 1, 2, 3, 2, 3, 0, 1, + 0, 1, 2, 3, 0, 1, 2, 3}; + const ScalableTag d; + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +#endif + + return CompressNot(v, mask); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(const V v, const svbool_t mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(const V v, const svbool_t mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, mask); + const svbool_t store_mask = FirstN(d, count); + BlendedStore(Compress(v, mask), store_mask, d, unaligned); + return count; +} + +// ================================================== MASK (2) + +// ------------------------------ FindKnownLastTrue +template +HWY_API size_t FindKnownLastTrue(D d, svbool_t m) { + const RebindToUnsigned du; + return static_cast(detail::ExtractLastMatchingLaneM( + Iota(du, 0), And(m, detail::MakeMask(d)))); +} + +// ------------------------------ FindLastTrue +template +HWY_API intptr_t FindLastTrue(D d, svbool_t m) { + return AllFalse(d, m) ? intptr_t{-1} + : static_cast(FindKnownLastTrue(d, m)); +} + +// ================================================== BLOCKWISE + +// ------------------------------ CombineShiftRightBytes + +// Prevent accidentally using these for 128-bit vectors - should not be +// necessary. +#if HWY_TARGET != HWY_SVE2_128 +namespace detail { + +// For x86-compatible behaviour mandated by Highway API: TableLookupBytes +// offsets are implicitly relative to the start of their 128-bit block. 
+template +HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { + using T = MakeUnsigned>; + return detail::AndNotN(static_cast(LanesPerBlock(d) - 1), iota0); +} + +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint8_t idx_mod = + svdupq_n_u8(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, + 6 % kLanesPerBlock, 7 % kLanesPerBlock, 8 % kLanesPerBlock, + 9 % kLanesPerBlock, 10 % kLanesPerBlock, 11 % kLanesPerBlock, + 12 % kLanesPerBlock, 13 % kLanesPerBlock, 14 % kLanesPerBlock, + 15 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint16_t idx_mod = + svdupq_n_u16(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock, 4 % kLanesPerBlock, 5 % kLanesPerBlock, + 6 % kLanesPerBlock, 7 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint32_t idx_mod = + svdupq_n_u32(0 % kLanesPerBlock, 1 % kLanesPerBlock, 2 % kLanesPerBlock, + 3 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} +template +svbool_t FirstNPerBlock(D d) { + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + const svuint64_t idx_mod = + svdupq_n_u64(0 % kLanesPerBlock, 1 % kLanesPerBlock); + return detail::LtN(BitCast(du, idx_mod), kLanes); +} + +} // namespace detail +#endif // HWY_TARGET != HWY_SVE2_128 + +template > +HWY_API V CombineShiftRightBytes(const D d, const V hi, const V lo) { + const Repartition d8; + const auto hi8 = BitCast(d8, hi); + const auto lo8 = BitCast(d8, lo); +#if HWY_TARGET == 
HWY_SVE2_128 + return BitCast(d, detail::Ext(hi8, lo8)); +#else + const auto hi_up = detail::Splice(hi8, hi8, FirstN(d8, 16 - kBytes)); + const auto lo_down = detail::Ext(lo8, lo8); + const svbool_t is_lo = detail::FirstNPerBlock<16 - kBytes>(d8); + return BitCast(d, IfThenElse(is_lo, lo_down, hi_up)); +#endif +} + +// ------------------------------ Shuffle2301 +template +HWY_API V Shuffle2301(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return Reverse2(d, v); +} + +// ------------------------------ Shuffle2103 +template +HWY_API V Shuffle2103(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<12>(d8, v8, v8)); +} + +// ------------------------------ Shuffle0321 +template +HWY_API V Shuffle0321(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<4>(d8, v8, v8)); +} + +// ------------------------------ Shuffle1032 +template +HWY_API V Shuffle1032(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); +} + +// ------------------------------ Shuffle01 +template +HWY_API V Shuffle01(const V v) { + const DFromV d; + const Repartition d8; + static_assert(sizeof(TFromD) == 8, "Defined for 64-bit types"); + const svuint8_t v8 = BitCast(d8, v); + return BitCast(d, CombineShiftRightBytes<8>(d8, v8, v8)); +} + +// ------------------------------ Shuffle0123 +template +HWY_API V Shuffle0123(const V v) { + return Shuffle2301(Shuffle1032(v)); +} + +// ------------------------------ ReverseBlocks (Reverse, Shuffle01) +template > +HWY_API V ReverseBlocks(D d, V v) { 
+#if HWY_TARGET == HWY_SVE_256 + if (detail::IsFull(d)) { + return SwapAdjacentBlocks(v); + } else if (detail::IsFull(Twice())) { + return v; + } +#elif HWY_TARGET == HWY_SVE2_128 + (void)d; + return v; +#endif + const Repartition du64; + return BitCast(d, Shuffle01(Reverse(du64, BitCast(du64, v)))); +} + +// ------------------------------ TableLookupBytes + +template +HWY_API VI TableLookupBytes(const V v, const VI idx) { + const DFromV d; + const Repartition du8; +#if HWY_TARGET == HWY_SVE2_128 + return BitCast(d, TableLookupLanes(BitCast(du8, v), BitCast(du8, idx))); +#else + const auto offsets128 = detail::OffsetsOf128BitBlocks(du8, Iota(du8, 0)); + const auto idx8 = Add(BitCast(du8, idx), offsets128); + return BitCast(d, TableLookupLanes(BitCast(du8, v), idx8)); +#endif +} + +template +HWY_API VI TableLookupBytesOr0(const V v, const VI idx) { + const DFromV d; + // Mask size must match vector type, so cast everything to this type. + const Repartition di8; + + auto idx8 = BitCast(di8, idx); + const auto msb = detail::LtN(idx8, 0); + + const auto lookup = TableLookupBytes(BitCast(di8, v), idx8); + return BitCast(d, IfThenZeroElse(msb, lookup)); +} + +// ------------------------------ Broadcast + +#ifdef HWY_NATIVE_BROADCASTLANE +#undef HWY_NATIVE_BROADCASTLANE +#else +#define HWY_NATIVE_BROADCASTLANE +#endif + +namespace detail { +#define HWY_SVE_BROADCAST(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_INLINE HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS(v, kLane); \ + } + +HWY_SVE_FOREACH(HWY_SVE_BROADCAST, BroadcastLane, dup_lane) +#undef HWY_SVE_BROADCAST +} // namespace detail + +template +HWY_API V Broadcast(const V v) { + const DFromV d; + const RebindToUnsigned du; + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(du); + static_assert(0 <= kLane && kLane < kLanesPerBlock, "Invalid lane"); +#if HWY_TARGET == HWY_SVE2_128 + return detail::BroadcastLane(v); +#else + auto idx = 
detail::OffsetsOf128BitBlocks(du, Iota(du, 0)); + if (kLane != 0) { + idx = detail::AddN(idx, kLane); + } + return TableLookupLanes(v, idx); +#endif +} + +template +HWY_API V BroadcastLane(const V v) { + static_assert(0 <= kLane && kLane < HWY_MAX_LANES_V(V), "Invalid lane"); + return detail::BroadcastLane(v); +} + +// ------------------------------ ShiftLeftLanes + +template > +HWY_API V ShiftLeftLanes(D d, const V v) { + const auto zero = Zero(d); + const auto shifted = detail::Splice(v, zero, FirstN(d, kLanes)); +#if HWY_TARGET == HWY_SVE2_128 + return shifted; +#else + // Match x86 semantics by zeroing lower lanes in 128-bit blocks + return IfThenElse(detail::FirstNPerBlock(d), zero, shifted); +#endif +} + +template +HWY_API V ShiftLeftLanes(const V v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightLanes +template > +HWY_API V ShiftRightLanes(D d, V v) { + // For capped/fractional vectors, clear upper lanes so we shift in zeros. + if (!detail::IsFull(d)) { + v = IfThenElseZero(detail::MakeMask(d), v); + } + +#if HWY_TARGET == HWY_SVE2_128 + return detail::Ext(Zero(d), v); +#else + const auto shifted = detail::Ext(v, v); + // Match x86 semantics by zeroing upper lanes in 128-bit blocks + constexpr size_t kLanesPerBlock = detail::LanesPerBlock(d); + const svbool_t mask = detail::FirstNPerBlock(d); + return IfThenElseZero(mask, shifted); +#endif +} + +// ------------------------------ ShiftLeftBytes + +template > +HWY_API V ShiftLeftBytes(const D d, const V v) { + const Repartition d8; + return BitCast(d, ShiftLeftLanes(BitCast(d8, v))); +} + +template +HWY_API V ShiftLeftBytes(const V v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template > +HWY_API V ShiftRightBytes(const D d, const V v) { + const Repartition d8; + return BitCast(d, ShiftRightLanes(d8, BitCast(d8, v))); +} + +// ------------------------------ ZipLower + +template >> +HWY_API VFromD ZipLower(DW dw, V 
a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveLower(dn, a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(const V a, const V b) { + return BitCast(DW(), InterleaveLower(D(), a, b)); +} + +// ------------------------------ ZipUpper +template >> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveUpper(dn, a, b)); +} + +// ================================================== Ops with dependencies + +// ------------------------------ AddSub (Reverse2) + +// NOTE: svcadd_f*_x(HWY_SVE_PTRUE(BITS), a, b, 90) computes a[i] - b[i + 1] in +// the even lanes and a[i] + b[i - 1] in the odd lanes. + +#define HWY_SVE_ADDSUB_F(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + const DFromV d; \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, Reverse2(d, b), \ + 90); \ + } + +HWY_SVE_FOREACH_F(HWY_SVE_ADDSUB_F, AddSub, cadd) + +#undef HWY_SVE_ADDSUB_F + +// NOTE: svcadd_s*(a, b, 90) and svcadd_u*(a, b, 90) compute a[i] - b[i + 1] in +// the even lanes and a[i] + b[i - 1] in the odd lanes. 
+ +#if HWY_SVE_HAVE_2 +#define HWY_SVE_ADDSUB_UI(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \ + const DFromV d; \ + return sv##OP##_##CHAR##BITS(a, Reverse2(d, b), 90); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_ADDSUB_UI, AddSub, cadd) + +#undef HWY_SVE_ADDSUB_UI + +// Disable the default implementation of AddSub in generic_ops-inl.h on SVE2 +#undef HWY_IF_ADDSUB_V +#define HWY_IF_ADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), \ + hwy::EnableIf()>* = nullptr + +#else // !HWY_SVE_HAVE_2 + +// Disable the default implementation of AddSub in generic_ops-inl.h for +// floating-point vectors on SVE, but enable the default implementation of +// AddSub in generic_ops-inl.h for integer vectors on SVE that do not support +// SVE2 +#undef HWY_IF_ADDSUB_V +#define HWY_IF_ADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V) + +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ MulAddSub (AddSub) + +template , 1), HWY_IF_FLOAT_V(V)> +HWY_API V MulAddSub(V mul, V x, V sub_or_add) { + using T = TFromV; + + const DFromV d; + const T neg_zero = ConvertScalarTo(-0.0f); + + return MulAdd(mul, x, AddSub(Set(d, neg_zero), sub_or_add)); +} + +#if HWY_SVE_HAVE_2 + +// Disable the default implementation of MulAddSub in generic_ops-inl.h on SVE2 +#undef HWY_IF_MULADDSUB_V +#define HWY_IF_MULADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), \ + hwy::EnableIf()>* = nullptr + +template , 1), + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)> +HWY_API V MulAddSub(V mul, V x, V sub_or_add) { + const DFromV d; + return MulAdd(mul, x, AddSub(Zero(d), sub_or_add)); +} + +#else // !HWY_SVE_HAVE_2 + +// Disable the default implementation of MulAddSub in generic_ops-inl.h for +// floating-point vectors on SVE, but enable the default implementation of +// AddSub in generic_ops-inl.h for integer vectors on SVE targets that do not +// support SVE2 +#undef HWY_IF_MULADDSUB_V +#define HWY_IF_MULADDSUB_V(V) \ 
+ HWY_IF_LANES_GT_D(DFromV, 1), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V) + +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ PromoteTo bfloat16 (ZipLower) +template +HWY_API svfloat32_t PromoteTo(Simd df32, VBF16 v) { + const ScalableTag du16; + return BitCast(df32, detail::ZipLowerSame(svdup_n_u16(0), BitCast(du16, v))); +} + +// ------------------------------ PromoteEvenTo/PromoteOddTo (ConcatOddFull) + +namespace detail { + +// Signed to signed PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<2> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + svint8_t v) { + return svextb_s16_x(detail::PTrue(d_to), BitCast(d_to, v)); +} + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + svint16_t v) { + return svexth_s32_x(detail::PTrue(d_to), BitCast(d_to, v)); +} + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + svint32_t v) { + return svextw_s64_x(detail::PTrue(d_to), BitCast(d_to, v)); +} + +// F16->F32 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + svfloat16_t v) { + const Repartition d_from; + return svcvt_f32_f16_x(detail::PTrue(d_from), v); +} + +// F32->F64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + svfloat32_t v) { + const Repartition d_from; + return svcvt_f64_f32_x(detail::PTrue(d_from), v); +} + +// I32->F64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + svint32_t v) { + const Repartition 
d_from; + return svcvt_f64_s32_x(detail::PTrue(d_from), v); +} + +// U32->F64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, + svuint32_t v) { + const Repartition d_from; + return svcvt_f64_u32_x(detail::PTrue(d_from), v); +} + +// F32->I64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + svfloat32_t v) { + const Repartition d_from; + return svcvt_s64_f32_x(detail::PTrue(d_from), v); +} + +// F32->U64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + svfloat32_t v) { + const Repartition d_from; + return svcvt_u64_f32_x(detail::PTrue(d_from), v); +} + +// F16->F32 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag to_type_tag, + hwy::SizeTag<4> to_lane_size_tag, + hwy::FloatTag from_type_tag, D d_to, + svfloat16_t v) { + return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, + DupOdd(v)); +} + +// I32/U32/F32->F64 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag to_type_tag, + hwy::SizeTag<8> to_lane_size_tag, + FromTypeTag from_type_tag, D d_to, V v) { + return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, + DupOdd(v)); +} + +// F32->I64/U64 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(ToTypeTag to_type_tag, + hwy::SizeTag<8> to_lane_size_tag, + hwy::FloatTag from_type_tag, D d_to, + svfloat32_t v) { + return PromoteEvenTo(to_type_tag, to_lane_size_tag, from_type_tag, d_to, + DupOdd(v)); +} + +} // namespace detail + +// ------------------------------ ReorderDemote2To (OddEven) + +template +HWY_API VBF16 ReorderDemote2To(Simd dbf16, svfloat32_t a, + svfloat32_t b) { +#if HWY_SVE_HAVE_F32_TO_BF16C + const VBF16 
b_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), b); + return svcvtnt_bf16_f32_x(b_in_even, detail::PTrue(dbf16), a); +#else + (void)dbf16; + const auto a_in_odd = + BitCast(ScalableTag(), detail::RoundF32ForDemoteToBF16(a)); + const auto b_in_odd = + BitCast(ScalableTag(), detail::RoundF32ForDemoteToBF16(b)); + return BitCast(dbf16, detail::InterleaveOdd(b_in_odd, a_in_odd)); +#endif +} + +template +HWY_API svint16_t ReorderDemote2To(Simd d16, svint32_t a, + svint32_t b) { +#if HWY_SVE_HAVE_2 + (void)d16; + const svint16_t a_in_even = svqxtnb_s32(a); + return svqxtnt_s32(a_in_even, b); +#else + const svint16_t a16 = BitCast(d16, detail::SaturateI(a)); + const svint16_t b16 = BitCast(d16, detail::SaturateI(b)); + return detail::InterleaveEven(a16, b16); +#endif +} + +template +HWY_API svuint16_t ReorderDemote2To(Simd d16, svint32_t a, + svint32_t b) { +#if HWY_SVE_HAVE_2 + (void)d16; + const svuint16_t a_in_even = svqxtunb_s32(a); + return svqxtunt_s32(a_in_even, b); +#else + const Repartition du32; + const svuint32_t clamped_a = BitCast(du32, detail::MaxN(a, 0)); + const svuint32_t clamped_b = BitCast(du32, detail::MaxN(b, 0)); + const svuint16_t a16 = BitCast(d16, detail::SaturateU(clamped_a)); + const svuint16_t b16 = BitCast(d16, detail::SaturateU(clamped_b)); + return detail::InterleaveEven(a16, b16); +#endif +} + +template +HWY_API svuint16_t ReorderDemote2To(Simd d16, svuint32_t a, + svuint32_t b) { +#if HWY_SVE_HAVE_2 + (void)d16; + const svuint16_t a_in_even = svqxtnb_u32(a); + return svqxtnt_u32(a_in_even, b); +#else + const svuint16_t a16 = BitCast(d16, detail::SaturateU(a)); + const svuint16_t b16 = BitCast(d16, detail::SaturateU(b)); + return detail::InterleaveEven(a16, b16); +#endif +} + +template +HWY_API svint8_t ReorderDemote2To(Simd d8, svint16_t a, + svint16_t b) { +#if HWY_SVE_HAVE_2 + (void)d8; + const svint8_t a_in_even = svqxtnb_s16(a); + return svqxtnt_s16(a_in_even, b); +#else + const svint8_t a8 = BitCast(d8, detail::SaturateI(a)); + 
const svint8_t b8 = BitCast(d8, detail::SaturateI(b)); + return detail::InterleaveEven(a8, b8); +#endif +} + +template +HWY_API svuint8_t ReorderDemote2To(Simd d8, svint16_t a, + svint16_t b) { +#if HWY_SVE_HAVE_2 + (void)d8; + const svuint8_t a_in_even = svqxtunb_s16(a); + return svqxtunt_s16(a_in_even, b); +#else + const Repartition du16; + const svuint16_t clamped_a = BitCast(du16, detail::MaxN(a, 0)); + const svuint16_t clamped_b = BitCast(du16, detail::MaxN(b, 0)); + const svuint8_t a8 = BitCast(d8, detail::SaturateU(clamped_a)); + const svuint8_t b8 = BitCast(d8, detail::SaturateU(clamped_b)); + return detail::InterleaveEven(a8, b8); +#endif +} + +template +HWY_API svuint8_t ReorderDemote2To(Simd d8, svuint16_t a, + svuint16_t b) { +#if HWY_SVE_HAVE_2 + (void)d8; + const svuint8_t a_in_even = svqxtnb_u16(a); + return svqxtnt_u16(a_in_even, b); +#else + const svuint8_t a8 = BitCast(d8, detail::SaturateU(a)); + const svuint8_t b8 = BitCast(d8, detail::SaturateU(b)); + return detail::InterleaveEven(a8, b8); +#endif +} + +template +HWY_API svint32_t ReorderDemote2To(Simd d32, svint64_t a, + svint64_t b) { +#if HWY_SVE_HAVE_2 + (void)d32; + const svint32_t a_in_even = svqxtnb_s64(a); + return svqxtnt_s64(a_in_even, b); +#else + const svint32_t a32 = BitCast(d32, detail::SaturateI(a)); + const svint32_t b32 = BitCast(d32, detail::SaturateI(b)); + return detail::InterleaveEven(a32, b32); +#endif +} + +template +HWY_API svuint32_t ReorderDemote2To(Simd d32, svint64_t a, + svint64_t b) { +#if HWY_SVE_HAVE_2 + (void)d32; + const svuint32_t a_in_even = svqxtunb_s64(a); + return svqxtunt_s64(a_in_even, b); +#else + const Repartition du64; + const svuint64_t clamped_a = BitCast(du64, detail::MaxN(a, 0)); + const svuint64_t clamped_b = BitCast(du64, detail::MaxN(b, 0)); + const svuint32_t a32 = BitCast(d32, detail::SaturateU(clamped_a)); + const svuint32_t b32 = BitCast(d32, detail::SaturateU(clamped_b)); + return detail::InterleaveEven(a32, b32); +#endif +} + +template 
+HWY_API svuint32_t ReorderDemote2To(Simd d32, svuint64_t a, + svuint64_t b) { +#if HWY_SVE_HAVE_2 + (void)d32; + const svuint32_t a_in_even = svqxtnb_u64(a); + return svqxtnt_u64(a_in_even, b); +#else + const svuint32_t a32 = BitCast(d32, detail::SaturateU(a)); + const svuint32_t b32 = BitCast(d32, detail::SaturateU(b)); + return detail::InterleaveEven(a32, b32); +#endif +} + +template ) / 2)> +HWY_API VFromD ReorderDemote2To(D dn, V a, V b) { + const auto clamped_a = BitCast(dn, detail::SaturateU>(a)); + const auto clamped_b = BitCast(dn, detail::SaturateU>(b)); + return detail::InterleaveEven(clamped_a, clamped_b); +} + +template ), + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2)> +HWY_API VFromD OrderedDemote2To(D dn, V a, V b) { + const Half dnh; + const auto demoted_a = DemoteTo(dnh, a); + const auto demoted_b = DemoteTo(dnh, b); + return Combine(dn, demoted_b, demoted_a); +} + +template +HWY_API VBF16 OrderedDemote2To(Simd dbf16, svfloat32_t a, + svfloat32_t b) { +#if HWY_SVE_HAVE_F32_TO_BF16C + (void)dbf16; + const VBF16 a_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), a); + const VBF16 b_in_even = svcvt_bf16_f32_x(detail::PTrue(dbf16), b); + return ConcatEven(dbf16, b_in_even, a_in_even); +#else + const RebindToUnsigned du16; + const svuint16_t a_in_odd = BitCast(du16, detail::RoundF32ForDemoteToBF16(a)); + const svuint16_t b_in_odd = BitCast(du16, detail::RoundF32ForDemoteToBF16(b)); + return BitCast(dbf16, ConcatOdd(du16, b_in_odd, a_in_odd)); // lower half +#endif +} + +// ------------------------------ I8/U8/I16/U16 Div + +template +HWY_API V Div(V a, V b) { + const DFromV d; + const Half dh; + const RepartitionToWide dw; + + const auto q_lo = + Div(PromoteTo(dw, LowerHalf(dh, a)), PromoteTo(dw, LowerHalf(dh, b))); + const auto q_hi = Div(PromoteUpperTo(dw, a), PromoteUpperTo(dw, b)); + + return OrderedDemote2To(d, q_lo, q_hi); +} + +// ------------------------------ I8/U8/I16/U16 MaskedDivOr +template +HWY_API V 
MaskedDivOr(V no, M m, V a, V b) { + return IfThenElse(m, Div(a, b), no); +} + +template +HWY_API V MaskedDiv(M m, V a, V b) { + return IfThenElseZero(m, Div(a, b)); +} + +// ------------------------------ Mod (Div, NegMulAdd) +template +HWY_API V Mod(V a, V b) { + return NegMulAdd(Div(a, b), b, a); +} + +// ------------------------------ MaskedModOr (Mod) +template +HWY_API V MaskedModOr(V no, M m, V a, V b) { + return IfThenElse(m, Mod(a, b), no); +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +template +HWY_API V IfNegativeThenElse(V v, V yes, V no) { + static_assert(IsSigned>(), "Only works for signed/float"); + return IfThenElse(IsNegative(v), yes, no); +} +// ------------------------------ IfNegativeThenNegOrUndefIfZero + +#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#else +#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#endif + +#define HWY_SVE_NEG_IF(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) v) { \ + return sv##OP##_##CHAR##BITS##_m(v, IsNegative(mask), v); \ + } + +HWY_SVE_FOREACH_IF(HWY_SVE_NEG_IF, IfNegativeThenNegOrUndefIfZero, neg) + +#undef HWY_SVE_NEG_IF + +// ------------------------------ AverageRound (ShiftRight) + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +#if HWY_SVE_HAVE_2 +HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd) +#else +template +HWY_API V AverageRound(const V a, const V b) { + return Sub(Or(a, b), ShiftRight<1>(Xor(a, b))); +} +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ LoadMaskBits (TestBit) + +// `p` points to at least 8 readable bytes, not all of which need be valid. 
+template +HWY_INLINE svbool_t LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { +#if HWY_COMPILER_CLANG >= 2200 || HWY_COMPILER_GCC_ACTUAL >= 1200 + typedef svbool_t UnalignedSveMaskT + __attribute__((__aligned__(1), __may_alias__)); + (void)d; + return *reinterpret_cast(bits); +#else + // TODO(janwas): with SVE2.1, load to vector, then PMOV + const RebindToUnsigned du; + const svuint8_t iota = Iota(du, 0); + + // Load correct number of bytes (bits/8) with 7 zeros after each. + const svuint8_t bytes = BitCast(du, svld1ub_u64(detail::PTrue(d), bits)); + // Replicate bytes 8x such that each byte contains the bit that governs it. + const svuint8_t rep8 = svtbl_u8(bytes, detail::AndNotN(7, iota)); + + const svuint8_t bit = + svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(rep8, bit); +#endif +} + +template +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + const Repartition du8; + + // There may be up to 128 bits; avoid reading past the end. + const svuint8_t bytes = svld1(FirstN(du8, (Lanes(du) + 7) / 8), bits); + + // Replicate bytes 16x such that each lane contains the bit that governs it. + const svuint8_t rep16 = svtbl_u8(bytes, ShiftRight<4>(Iota(du8, 0))); + + const svuint16_t bit = svdupq_n_u16(1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(BitCast(du, rep16), bit); +} + +template +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + const Repartition du8; + + // Upper bound = 2048 bits / 32 bit = 64 bits; at least 8 bytes are readable, + // so we can skip computing the actual length (Lanes(du)+7)/8. + const svuint8_t bytes = svld1(FirstN(du8, 8), bits); + + // Replicate bytes 32x such that each lane contains the bit that governs it. + const svuint8_t rep32 = svtbl_u8(bytes, ShiftRight<5>(Iota(du8, 0))); + + // 1, 2, 4, 8, 16, 32, 64, 128, 1, 2 .. 
+ const svuint32_t bit = Shl(Set(du, 1), detail::AndN(Iota(du, 0), 7)); + + return TestBit(BitCast(du, rep32), bit); +} + +template +HWY_INLINE svbool_t LoadMaskBits(D /* tag */, + const uint8_t* HWY_RESTRICT bits) { + const RebindToUnsigned du; + + // Max 2048 bits = 32 lanes = 32 input bits; replicate those into each lane. + // The "at least 8 byte" guarantee in quick_reference ensures this is safe. + uint32_t mask_bits; + CopyBytes<4>(bits, &mask_bits); // copy from bytes + const auto vbits = Set(du, mask_bits); + + // 2 ^ {0,1, .., 31}, will not have more lanes than that. + const svuint64_t bit = Shl(Set(du, 1), Iota(du, 0)); + + return TestBit(vbits, bit); +} + +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + + // Replicate the lower 8 bits of mask_bits to each u8 lane + const svuint8_t bytes = BitCast(du, Set(du, static_cast(mask_bits))); + + const svuint8_t bit = + svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(bytes, bit); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + const Repartition du16; + + // Replicate the lower 16 bits of mask_bits to each u16 lane of a u16 vector, + // and then bitcast the replicated mask_bits to a u8 vector + const svuint8_t bytes = + BitCast(du, Set(du16, static_cast(mask_bits))); + // Replicate bytes 8x such that each byte contains the bit that governs it. 
+ const svuint8_t rep8 = svtbl_u8(bytes, ShiftRight<3>(Iota(du, 0))); + + const svuint8_t bit = + svdupq_n_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(rep8, bit); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + const Repartition du8; + + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + + // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits + const svuint8_t bytes = Set(du8, static_cast(mask_bits)); + + const svuint16_t bit = svdupq_n_u16(1, 2, 4, 8, 16, 32, 64, 128); + return TestBit(BitCast(du, bytes), bit); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + const Repartition du8; + + constexpr size_t kN = MaxLanes(d); + if (kN < 4) mask_bits &= (1u << kN) - 1; + + // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits + const svuint8_t bytes = Set(du8, static_cast(mask_bits)); + + const svuint32_t bit = svdupq_n_u32(1, 2, 4, 8); + return TestBit(BitCast(du, bytes), bit); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const RebindToUnsigned du; + const Repartition du8; + + if (MaxLanes(d) < 2) mask_bits &= 1u; + + // Set all of the u8 lanes of bytes to the lower 8 bits of mask_bits + const svuint8_t bytes = Set(du8, static_cast(mask_bits)); + + const svuint64_t bit = svdupq_n_u64(1, 2); + return TestBit(BitCast(du, bytes), bit); +} + +// ------------------------------ StoreMaskBits (BitsFromMask) + +// `p` points to at least 8 writable bytes. 
+// TODO(janwas): with SVE2.1, use PMOV to store to vector, then StoreU +template +HWY_API size_t StoreMaskBits(D d, svbool_t m, uint8_t* bits) { +#if HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 + constexpr size_t N = MaxLanes(d); + const uint64_t bits64 = BitsFromMask(d, m); + HWY_IF_CONSTEXPR(N < 8) { + // BitsFromMask guarantees upper bits are zero, hence no masking. + bits[0] = static_cast(bits64); + return 1; + } + static_assert(N < 8 || N % 8 == 0, "N is pow2 >= 8, hence divisible"); + static_assert(HWY_IS_LITTLE_ENDIAN, ""); + hwy::CopyBytes(&bits64, bits); + constexpr size_t num_bytes = hwy::DivCeil(N, size_t{8}); + return num_bytes; +#else + svuint64_t bits_in_u64 = detail::BitsFromBool(detail::BoolFromMask(m)); + + const size_t num_bits = Lanes(d); + const size_t num_bytes = hwy::DivCeil(num_bits, size_t{8}); + + // Truncate each u64 to 8 bits and store to u8. + svst1b_u64(FirstN(ScalableTag(), num_bytes), bits, bits_in_u64); + + // Non-full byte, need to clear the undefined upper bits. Can happen for + // capped/fractional vectors or large T and small hardware vectors. + if (num_bits < 8) { + const int mask = static_cast((1ull << num_bits) - 1); + bits[0] = static_cast(bits[0] & mask); + } + // Else: we wrote full bytes because num_bits is a power of two >= 8. 
+ + return num_bytes; +#endif // HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128 +} + +// ------------------------------ CompressBits (LoadMaskBits) +template +HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(DFromV(), bits)); +} + +// ------------------------------ CompressBitsStore (LoadMaskBits) +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +// ------------------------------ Expand (StoreMaskBits) + +#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +namespace detail { + +HWY_INLINE svuint8_t IndicesForExpandFromBits(uint64_t mask_bits) { + const CappedTag du8; + alignas(16) static constexpr uint8_t table[8 * 256] = { + // PrintExpand8x8Tables + 128, 128, 128, 128, 128, 128, 128, 128, // + 0, 128, 128, 128, 128, 128, 128, 128, // + 128, 0, 128, 128, 128, 128, 128, 128, // + 0, 1, 128, 128, 128, 128, 128, 128, // + 128, 128, 0, 128, 128, 128, 128, 128, // + 0, 128, 1, 128, 128, 128, 128, 128, // + 128, 0, 1, 128, 128, 128, 128, 128, // + 0, 1, 2, 128, 128, 128, 128, 128, // + 128, 128, 128, 0, 128, 128, 128, 128, // + 0, 128, 128, 1, 128, 128, 128, 128, // + 128, 0, 128, 1, 128, 128, 128, 128, // + 0, 1, 128, 2, 128, 128, 128, 128, // + 128, 128, 0, 1, 128, 128, 128, 128, // + 0, 128, 1, 2, 128, 128, 128, 128, // + 128, 0, 1, 2, 128, 128, 128, 128, // + 0, 1, 2, 3, 128, 128, 128, 128, // + 128, 128, 128, 128, 0, 128, 128, 128, // + 0, 128, 128, 128, 1, 128, 128, 128, // + 128, 0, 128, 128, 1, 128, 128, 128, // + 0, 1, 128, 128, 2, 128, 128, 128, // + 128, 128, 0, 128, 1, 128, 128, 128, // + 0, 128, 1, 128, 2, 128, 128, 128, // + 128, 0, 1, 128, 2, 128, 128, 128, // + 0, 1, 2, 128, 3, 128, 128, 128, // + 128, 128, 128, 0, 1, 128, 128, 128, // + 0, 128, 128, 1, 2, 128, 128, 128, // + 128, 0, 128, 1, 2, 128, 128, 
128, // + 0, 1, 128, 2, 3, 128, 128, 128, // + 128, 128, 0, 1, 2, 128, 128, 128, // + 0, 128, 1, 2, 3, 128, 128, 128, // + 128, 0, 1, 2, 3, 128, 128, 128, // + 0, 1, 2, 3, 4, 128, 128, 128, // + 128, 128, 128, 128, 128, 0, 128, 128, // + 0, 128, 128, 128, 128, 1, 128, 128, // + 128, 0, 128, 128, 128, 1, 128, 128, // + 0, 1, 128, 128, 128, 2, 128, 128, // + 128, 128, 0, 128, 128, 1, 128, 128, // + 0, 128, 1, 128, 128, 2, 128, 128, // + 128, 0, 1, 128, 128, 2, 128, 128, // + 0, 1, 2, 128, 128, 3, 128, 128, // + 128, 128, 128, 0, 128, 1, 128, 128, // + 0, 128, 128, 1, 128, 2, 128, 128, // + 128, 0, 128, 1, 128, 2, 128, 128, // + 0, 1, 128, 2, 128, 3, 128, 128, // + 128, 128, 0, 1, 128, 2, 128, 128, // + 0, 128, 1, 2, 128, 3, 128, 128, // + 128, 0, 1, 2, 128, 3, 128, 128, // + 0, 1, 2, 3, 128, 4, 128, 128, // + 128, 128, 128, 128, 0, 1, 128, 128, // + 0, 128, 128, 128, 1, 2, 128, 128, // + 128, 0, 128, 128, 1, 2, 128, 128, // + 0, 1, 128, 128, 2, 3, 128, 128, // + 128, 128, 0, 128, 1, 2, 128, 128, // + 0, 128, 1, 128, 2, 3, 128, 128, // + 128, 0, 1, 128, 2, 3, 128, 128, // + 0, 1, 2, 128, 3, 4, 128, 128, // + 128, 128, 128, 0, 1, 2, 128, 128, // + 0, 128, 128, 1, 2, 3, 128, 128, // + 128, 0, 128, 1, 2, 3, 128, 128, // + 0, 1, 128, 2, 3, 4, 128, 128, // + 128, 128, 0, 1, 2, 3, 128, 128, // + 0, 128, 1, 2, 3, 4, 128, 128, // + 128, 0, 1, 2, 3, 4, 128, 128, // + 0, 1, 2, 3, 4, 5, 128, 128, // + 128, 128, 128, 128, 128, 128, 0, 128, // + 0, 128, 128, 128, 128, 128, 1, 128, // + 128, 0, 128, 128, 128, 128, 1, 128, // + 0, 1, 128, 128, 128, 128, 2, 128, // + 128, 128, 0, 128, 128, 128, 1, 128, // + 0, 128, 1, 128, 128, 128, 2, 128, // + 128, 0, 1, 128, 128, 128, 2, 128, // + 0, 1, 2, 128, 128, 128, 3, 128, // + 128, 128, 128, 0, 128, 128, 1, 128, // + 0, 128, 128, 1, 128, 128, 2, 128, // + 128, 0, 128, 1, 128, 128, 2, 128, // + 0, 1, 128, 2, 128, 128, 3, 128, // + 128, 128, 0, 1, 128, 128, 2, 128, // + 0, 128, 1, 2, 128, 128, 3, 128, // + 128, 0, 1, 2, 128, 128, 3, 128, // + 
0, 1, 2, 3, 128, 128, 4, 128, // + 128, 128, 128, 128, 0, 128, 1, 128, // + 0, 128, 128, 128, 1, 128, 2, 128, // + 128, 0, 128, 128, 1, 128, 2, 128, // + 0, 1, 128, 128, 2, 128, 3, 128, // + 128, 128, 0, 128, 1, 128, 2, 128, // + 0, 128, 1, 128, 2, 128, 3, 128, // + 128, 0, 1, 128, 2, 128, 3, 128, // + 0, 1, 2, 128, 3, 128, 4, 128, // + 128, 128, 128, 0, 1, 128, 2, 128, // + 0, 128, 128, 1, 2, 128, 3, 128, // + 128, 0, 128, 1, 2, 128, 3, 128, // + 0, 1, 128, 2, 3, 128, 4, 128, // + 128, 128, 0, 1, 2, 128, 3, 128, // + 0, 128, 1, 2, 3, 128, 4, 128, // + 128, 0, 1, 2, 3, 128, 4, 128, // + 0, 1, 2, 3, 4, 128, 5, 128, // + 128, 128, 128, 128, 128, 0, 1, 128, // + 0, 128, 128, 128, 128, 1, 2, 128, // + 128, 0, 128, 128, 128, 1, 2, 128, // + 0, 1, 128, 128, 128, 2, 3, 128, // + 128, 128, 0, 128, 128, 1, 2, 128, // + 0, 128, 1, 128, 128, 2, 3, 128, // + 128, 0, 1, 128, 128, 2, 3, 128, // + 0, 1, 2, 128, 128, 3, 4, 128, // + 128, 128, 128, 0, 128, 1, 2, 128, // + 0, 128, 128, 1, 128, 2, 3, 128, // + 128, 0, 128, 1, 128, 2, 3, 128, // + 0, 1, 128, 2, 128, 3, 4, 128, // + 128, 128, 0, 1, 128, 2, 3, 128, // + 0, 128, 1, 2, 128, 3, 4, 128, // + 128, 0, 1, 2, 128, 3, 4, 128, // + 0, 1, 2, 3, 128, 4, 5, 128, // + 128, 128, 128, 128, 0, 1, 2, 128, // + 0, 128, 128, 128, 1, 2, 3, 128, // + 128, 0, 128, 128, 1, 2, 3, 128, // + 0, 1, 128, 128, 2, 3, 4, 128, // + 128, 128, 0, 128, 1, 2, 3, 128, // + 0, 128, 1, 128, 2, 3, 4, 128, // + 128, 0, 1, 128, 2, 3, 4, 128, // + 0, 1, 2, 128, 3, 4, 5, 128, // + 128, 128, 128, 0, 1, 2, 3, 128, // + 0, 128, 128, 1, 2, 3, 4, 128, // + 128, 0, 128, 1, 2, 3, 4, 128, // + 0, 1, 128, 2, 3, 4, 5, 128, // + 128, 128, 0, 1, 2, 3, 4, 128, // + 0, 128, 1, 2, 3, 4, 5, 128, // + 128, 0, 1, 2, 3, 4, 5, 128, // + 0, 1, 2, 3, 4, 5, 6, 128, // + 128, 128, 128, 128, 128, 128, 128, 0, // + 0, 128, 128, 128, 128, 128, 128, 1, // + 128, 0, 128, 128, 128, 128, 128, 1, // + 0, 1, 128, 128, 128, 128, 128, 2, // + 128, 128, 0, 128, 128, 128, 128, 1, // + 0, 128, 1, 128, 
128, 128, 128, 2, // + 128, 0, 1, 128, 128, 128, 128, 2, // + 0, 1, 2, 128, 128, 128, 128, 3, // + 128, 128, 128, 0, 128, 128, 128, 1, // + 0, 128, 128, 1, 128, 128, 128, 2, // + 128, 0, 128, 1, 128, 128, 128, 2, // + 0, 1, 128, 2, 128, 128, 128, 3, // + 128, 128, 0, 1, 128, 128, 128, 2, // + 0, 128, 1, 2, 128, 128, 128, 3, // + 128, 0, 1, 2, 128, 128, 128, 3, // + 0, 1, 2, 3, 128, 128, 128, 4, // + 128, 128, 128, 128, 0, 128, 128, 1, // + 0, 128, 128, 128, 1, 128, 128, 2, // + 128, 0, 128, 128, 1, 128, 128, 2, // + 0, 1, 128, 128, 2, 128, 128, 3, // + 128, 128, 0, 128, 1, 128, 128, 2, // + 0, 128, 1, 128, 2, 128, 128, 3, // + 128, 0, 1, 128, 2, 128, 128, 3, // + 0, 1, 2, 128, 3, 128, 128, 4, // + 128, 128, 128, 0, 1, 128, 128, 2, // + 0, 128, 128, 1, 2, 128, 128, 3, // + 128, 0, 128, 1, 2, 128, 128, 3, // + 0, 1, 128, 2, 3, 128, 128, 4, // + 128, 128, 0, 1, 2, 128, 128, 3, // + 0, 128, 1, 2, 3, 128, 128, 4, // + 128, 0, 1, 2, 3, 128, 128, 4, // + 0, 1, 2, 3, 4, 128, 128, 5, // + 128, 128, 128, 128, 128, 0, 128, 1, // + 0, 128, 128, 128, 128, 1, 128, 2, // + 128, 0, 128, 128, 128, 1, 128, 2, // + 0, 1, 128, 128, 128, 2, 128, 3, // + 128, 128, 0, 128, 128, 1, 128, 2, // + 0, 128, 1, 128, 128, 2, 128, 3, // + 128, 0, 1, 128, 128, 2, 128, 3, // + 0, 1, 2, 128, 128, 3, 128, 4, // + 128, 128, 128, 0, 128, 1, 128, 2, // + 0, 128, 128, 1, 128, 2, 128, 3, // + 128, 0, 128, 1, 128, 2, 128, 3, // + 0, 1, 128, 2, 128, 3, 128, 4, // + 128, 128, 0, 1, 128, 2, 128, 3, // + 0, 128, 1, 2, 128, 3, 128, 4, // + 128, 0, 1, 2, 128, 3, 128, 4, // + 0, 1, 2, 3, 128, 4, 128, 5, // + 128, 128, 128, 128, 0, 1, 128, 2, // + 0, 128, 128, 128, 1, 2, 128, 3, // + 128, 0, 128, 128, 1, 2, 128, 3, // + 0, 1, 128, 128, 2, 3, 128, 4, // + 128, 128, 0, 128, 1, 2, 128, 3, // + 0, 128, 1, 128, 2, 3, 128, 4, // + 128, 0, 1, 128, 2, 3, 128, 4, // + 0, 1, 2, 128, 3, 4, 128, 5, // + 128, 128, 128, 0, 1, 2, 128, 3, // + 0, 128, 128, 1, 2, 3, 128, 4, // + 128, 0, 128, 1, 2, 3, 128, 4, // + 0, 1, 128, 2, 3, 
4, 128, 5, // + 128, 128, 0, 1, 2, 3, 128, 4, // + 0, 128, 1, 2, 3, 4, 128, 5, // + 128, 0, 1, 2, 3, 4, 128, 5, // + 0, 1, 2, 3, 4, 5, 128, 6, // + 128, 128, 128, 128, 128, 128, 0, 1, // + 0, 128, 128, 128, 128, 128, 1, 2, // + 128, 0, 128, 128, 128, 128, 1, 2, // + 0, 1, 128, 128, 128, 128, 2, 3, // + 128, 128, 0, 128, 128, 128, 1, 2, // + 0, 128, 1, 128, 128, 128, 2, 3, // + 128, 0, 1, 128, 128, 128, 2, 3, // + 0, 1, 2, 128, 128, 128, 3, 4, // + 128, 128, 128, 0, 128, 128, 1, 2, // + 0, 128, 128, 1, 128, 128, 2, 3, // + 128, 0, 128, 1, 128, 128, 2, 3, // + 0, 1, 128, 2, 128, 128, 3, 4, // + 128, 128, 0, 1, 128, 128, 2, 3, // + 0, 128, 1, 2, 128, 128, 3, 4, // + 128, 0, 1, 2, 128, 128, 3, 4, // + 0, 1, 2, 3, 128, 128, 4, 5, // + 128, 128, 128, 128, 0, 128, 1, 2, // + 0, 128, 128, 128, 1, 128, 2, 3, // + 128, 0, 128, 128, 1, 128, 2, 3, // + 0, 1, 128, 128, 2, 128, 3, 4, // + 128, 128, 0, 128, 1, 128, 2, 3, // + 0, 128, 1, 128, 2, 128, 3, 4, // + 128, 0, 1, 128, 2, 128, 3, 4, // + 0, 1, 2, 128, 3, 128, 4, 5, // + 128, 128, 128, 0, 1, 128, 2, 3, // + 0, 128, 128, 1, 2, 128, 3, 4, // + 128, 0, 128, 1, 2, 128, 3, 4, // + 0, 1, 128, 2, 3, 128, 4, 5, // + 128, 128, 0, 1, 2, 128, 3, 4, // + 0, 128, 1, 2, 3, 128, 4, 5, // + 128, 0, 1, 2, 3, 128, 4, 5, // + 0, 1, 2, 3, 4, 128, 5, 6, // + 128, 128, 128, 128, 128, 0, 1, 2, // + 0, 128, 128, 128, 128, 1, 2, 3, // + 128, 0, 128, 128, 128, 1, 2, 3, // + 0, 1, 128, 128, 128, 2, 3, 4, // + 128, 128, 0, 128, 128, 1, 2, 3, // + 0, 128, 1, 128, 128, 2, 3, 4, // + 128, 0, 1, 128, 128, 2, 3, 4, // + 0, 1, 2, 128, 128, 3, 4, 5, // + 128, 128, 128, 0, 128, 1, 2, 3, // + 0, 128, 128, 1, 128, 2, 3, 4, // + 128, 0, 128, 1, 128, 2, 3, 4, // + 0, 1, 128, 2, 128, 3, 4, 5, // + 128, 128, 0, 1, 128, 2, 3, 4, // + 0, 128, 1, 2, 128, 3, 4, 5, // + 128, 0, 1, 2, 128, 3, 4, 5, // + 0, 1, 2, 3, 128, 4, 5, 6, // + 128, 128, 128, 128, 0, 1, 2, 3, // + 0, 128, 128, 128, 1, 2, 3, 4, // + 128, 0, 128, 128, 1, 2, 3, 4, // + 0, 1, 128, 128, 2, 3, 4, 5, // + 
128, 128, 0, 128, 1, 2, 3, 4, // + 0, 128, 1, 128, 2, 3, 4, 5, // + 128, 0, 1, 128, 2, 3, 4, 5, // + 0, 1, 2, 128, 3, 4, 5, 6, // + 128, 128, 128, 0, 1, 2, 3, 4, // + 0, 128, 128, 1, 2, 3, 4, 5, // + 128, 0, 128, 1, 2, 3, 4, 5, // + 0, 1, 128, 2, 3, 4, 5, 6, // + 128, 128, 0, 1, 2, 3, 4, 5, // + 0, 128, 1, 2, 3, 4, 5, 6, // + 128, 0, 1, 2, 3, 4, 5, 6, // + 0, 1, 2, 3, 4, 5, 6, 7}; + return Load(du8, table + mask_bits * 8); +} + +template +HWY_INLINE svuint8_t LaneIndicesFromByteIndices(D, svuint8_t idx) { + return idx; +} +template , HWY_IF_NOT_T_SIZE_D(D, 1)> +HWY_INLINE VFromD LaneIndicesFromByteIndices(D, svuint8_t idx) { + return PromoteTo(DU(), idx); +} + +// General case when we don't know the vector size, 8 elements at a time. +template +HWY_INLINE V ExpandLoop(V v, svbool_t mask) { + const DFromV d; + using T = TFromV; + uint8_t mask_bytes[256 / 8]; + StoreMaskBits(d, mask, mask_bytes); + + // ShiftLeftLanes is expensive, so we're probably better off storing to memory + // and loading the final result. + alignas(16) T out[2 * MaxLanes(d)]; + + svbool_t next = svpfalse_b(); + size_t input_consumed = 0; + const V iota = Iota(d, 0); + for (size_t i = 0; i < Lanes(d); i += 8) { + uint64_t mask_bits = mask_bytes[i / 8]; + + // We want to skip past the v lanes already consumed. There is no + // instruction for variable-shift-reg, but we can splice. 
+ const V vH = detail::Splice(v, v, next); + input_consumed += PopCount(mask_bits); + next = detail::GeN(iota, ConvertScalarTo(input_consumed)); + + const auto idx = detail::LaneIndicesFromByteIndices( + d, detail::IndicesForExpandFromBits(mask_bits)); + const V expand = TableLookupLanes(vH, idx); + StoreU(expand, d, out + i); + } + return LoadU(d, out); +} + +} // namespace detail + +template +HWY_API V Expand(V v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE + const DFromV d; + uint8_t mask_bytes[256 / 8]; + StoreMaskBits(d, mask, mask_bytes); + const uint64_t maskL = mask_bytes[0]; + const uint64_t maskH = mask_bytes[1]; + + // We want to skip past the v bytes already consumed by expandL. There is no + // instruction for shift-reg by variable bytes, but we can splice. Instead of + // GeN, Not(FirstN()) would also work. + using T = TFromV; + const T countL = static_cast(PopCount(maskL)); + const V vH = detail::Splice(v, v, detail::GeN(Iota(d, 0), countL)); + + const svuint8_t idxL = detail::IndicesForExpandFromBits(maskL); + const svuint8_t idxH = detail::IndicesForExpandFromBits(maskH); + return Combine(d, TableLookupLanes(vH, idxH), TableLookupLanes(v, idxL)); +#else + return detail::ExpandLoop(v, mask); +#endif +} + +template +HWY_API V Expand(V v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE2_128 || HWY_IDE // 16x8 + const DFromV d; + const RebindToUnsigned du16; + const Rebind du8; + // Convert mask into bitfield via horizontal sum (faster than ORV) of 8 bits. + // Pre-multiply by N so we can use it as an offset for Load. + const svuint16_t bits = Shl(Set(du16, 1), Iota(du16, 3)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + // Storing as 8-bit reduces table size from 4 KiB to 2 KiB. We cannot apply + // the nibble trick used below because not all indices fit within one lane. 
+ alignas(16) static constexpr uint8_t table[8 * 256] = { + // PrintExpand16x8LaneTables + 255, 255, 255, 255, 255, 255, 255, 255, // + 0, 255, 255, 255, 255, 255, 255, 255, // + 255, 0, 255, 255, 255, 255, 255, 255, // + 0, 1, 255, 255, 255, 255, 255, 255, // + 255, 255, 0, 255, 255, 255, 255, 255, // + 0, 255, 1, 255, 255, 255, 255, 255, // + 255, 0, 1, 255, 255, 255, 255, 255, // + 0, 1, 2, 255, 255, 255, 255, 255, // + 255, 255, 255, 0, 255, 255, 255, 255, // + 0, 255, 255, 1, 255, 255, 255, 255, // + 255, 0, 255, 1, 255, 255, 255, 255, // + 0, 1, 255, 2, 255, 255, 255, 255, // + 255, 255, 0, 1, 255, 255, 255, 255, // + 0, 255, 1, 2, 255, 255, 255, 255, // + 255, 0, 1, 2, 255, 255, 255, 255, // + 0, 1, 2, 3, 255, 255, 255, 255, // + 255, 255, 255, 255, 0, 255, 255, 255, // + 0, 255, 255, 255, 1, 255, 255, 255, // + 255, 0, 255, 255, 1, 255, 255, 255, // + 0, 1, 255, 255, 2, 255, 255, 255, // + 255, 255, 0, 255, 1, 255, 255, 255, // + 0, 255, 1, 255, 2, 255, 255, 255, // + 255, 0, 1, 255, 2, 255, 255, 255, // + 0, 1, 2, 255, 3, 255, 255, 255, // + 255, 255, 255, 0, 1, 255, 255, 255, // + 0, 255, 255, 1, 2, 255, 255, 255, // + 255, 0, 255, 1, 2, 255, 255, 255, // + 0, 1, 255, 2, 3, 255, 255, 255, // + 255, 255, 0, 1, 2, 255, 255, 255, // + 0, 255, 1, 2, 3, 255, 255, 255, // + 255, 0, 1, 2, 3, 255, 255, 255, // + 0, 1, 2, 3, 4, 255, 255, 255, // + 255, 255, 255, 255, 255, 0, 255, 255, // + 0, 255, 255, 255, 255, 1, 255, 255, // + 255, 0, 255, 255, 255, 1, 255, 255, // + 0, 1, 255, 255, 255, 2, 255, 255, // + 255, 255, 0, 255, 255, 1, 255, 255, // + 0, 255, 1, 255, 255, 2, 255, 255, // + 255, 0, 1, 255, 255, 2, 255, 255, // + 0, 1, 2, 255, 255, 3, 255, 255, // + 255, 255, 255, 0, 255, 1, 255, 255, // + 0, 255, 255, 1, 255, 2, 255, 255, // + 255, 0, 255, 1, 255, 2, 255, 255, // + 0, 1, 255, 2, 255, 3, 255, 255, // + 255, 255, 0, 1, 255, 2, 255, 255, // + 0, 255, 1, 2, 255, 3, 255, 255, // + 255, 0, 1, 2, 255, 3, 255, 255, // + 0, 1, 2, 3, 255, 4, 255, 255, // + 255, 
255, 255, 255, 0, 1, 255, 255, // + 0, 255, 255, 255, 1, 2, 255, 255, // + 255, 0, 255, 255, 1, 2, 255, 255, // + 0, 1, 255, 255, 2, 3, 255, 255, // + 255, 255, 0, 255, 1, 2, 255, 255, // + 0, 255, 1, 255, 2, 3, 255, 255, // + 255, 0, 1, 255, 2, 3, 255, 255, // + 0, 1, 2, 255, 3, 4, 255, 255, // + 255, 255, 255, 0, 1, 2, 255, 255, // + 0, 255, 255, 1, 2, 3, 255, 255, // + 255, 0, 255, 1, 2, 3, 255, 255, // + 0, 1, 255, 2, 3, 4, 255, 255, // + 255, 255, 0, 1, 2, 3, 255, 255, // + 0, 255, 1, 2, 3, 4, 255, 255, // + 255, 0, 1, 2, 3, 4, 255, 255, // + 0, 1, 2, 3, 4, 5, 255, 255, // + 255, 255, 255, 255, 255, 255, 0, 255, // + 0, 255, 255, 255, 255, 255, 1, 255, // + 255, 0, 255, 255, 255, 255, 1, 255, // + 0, 1, 255, 255, 255, 255, 2, 255, // + 255, 255, 0, 255, 255, 255, 1, 255, // + 0, 255, 1, 255, 255, 255, 2, 255, // + 255, 0, 1, 255, 255, 255, 2, 255, // + 0, 1, 2, 255, 255, 255, 3, 255, // + 255, 255, 255, 0, 255, 255, 1, 255, // + 0, 255, 255, 1, 255, 255, 2, 255, // + 255, 0, 255, 1, 255, 255, 2, 255, // + 0, 1, 255, 2, 255, 255, 3, 255, // + 255, 255, 0, 1, 255, 255, 2, 255, // + 0, 255, 1, 2, 255, 255, 3, 255, // + 255, 0, 1, 2, 255, 255, 3, 255, // + 0, 1, 2, 3, 255, 255, 4, 255, // + 255, 255, 255, 255, 0, 255, 1, 255, // + 0, 255, 255, 255, 1, 255, 2, 255, // + 255, 0, 255, 255, 1, 255, 2, 255, // + 0, 1, 255, 255, 2, 255, 3, 255, // + 255, 255, 0, 255, 1, 255, 2, 255, // + 0, 255, 1, 255, 2, 255, 3, 255, // + 255, 0, 1, 255, 2, 255, 3, 255, // + 0, 1, 2, 255, 3, 255, 4, 255, // + 255, 255, 255, 0, 1, 255, 2, 255, // + 0, 255, 255, 1, 2, 255, 3, 255, // + 255, 0, 255, 1, 2, 255, 3, 255, // + 0, 1, 255, 2, 3, 255, 4, 255, // + 255, 255, 0, 1, 2, 255, 3, 255, // + 0, 255, 1, 2, 3, 255, 4, 255, // + 255, 0, 1, 2, 3, 255, 4, 255, // + 0, 1, 2, 3, 4, 255, 5, 255, // + 255, 255, 255, 255, 255, 0, 1, 255, // + 0, 255, 255, 255, 255, 1, 2, 255, // + 255, 0, 255, 255, 255, 1, 2, 255, // + 0, 1, 255, 255, 255, 2, 3, 255, // + 255, 255, 0, 255, 255, 1, 2, 255, // + 
0, 255, 1, 255, 255, 2, 3, 255, // + 255, 0, 1, 255, 255, 2, 3, 255, // + 0, 1, 2, 255, 255, 3, 4, 255, // + 255, 255, 255, 0, 255, 1, 2, 255, // + 0, 255, 255, 1, 255, 2, 3, 255, // + 255, 0, 255, 1, 255, 2, 3, 255, // + 0, 1, 255, 2, 255, 3, 4, 255, // + 255, 255, 0, 1, 255, 2, 3, 255, // + 0, 255, 1, 2, 255, 3, 4, 255, // + 255, 0, 1, 2, 255, 3, 4, 255, // + 0, 1, 2, 3, 255, 4, 5, 255, // + 255, 255, 255, 255, 0, 1, 2, 255, // + 0, 255, 255, 255, 1, 2, 3, 255, // + 255, 0, 255, 255, 1, 2, 3, 255, // + 0, 1, 255, 255, 2, 3, 4, 255, // + 255, 255, 0, 255, 1, 2, 3, 255, // + 0, 255, 1, 255, 2, 3, 4, 255, // + 255, 0, 1, 255, 2, 3, 4, 255, // + 0, 1, 2, 255, 3, 4, 5, 255, // + 255, 255, 255, 0, 1, 2, 3, 255, // + 0, 255, 255, 1, 2, 3, 4, 255, // + 255, 0, 255, 1, 2, 3, 4, 255, // + 0, 1, 255, 2, 3, 4, 5, 255, // + 255, 255, 0, 1, 2, 3, 4, 255, // + 0, 255, 1, 2, 3, 4, 5, 255, // + 255, 0, 1, 2, 3, 4, 5, 255, // + 0, 1, 2, 3, 4, 5, 6, 255, // + 255, 255, 255, 255, 255, 255, 255, 0, // + 0, 255, 255, 255, 255, 255, 255, 1, // + 255, 0, 255, 255, 255, 255, 255, 1, // + 0, 1, 255, 255, 255, 255, 255, 2, // + 255, 255, 0, 255, 255, 255, 255, 1, // + 0, 255, 1, 255, 255, 255, 255, 2, // + 255, 0, 1, 255, 255, 255, 255, 2, // + 0, 1, 2, 255, 255, 255, 255, 3, // + 255, 255, 255, 0, 255, 255, 255, 1, // + 0, 255, 255, 1, 255, 255, 255, 2, // + 255, 0, 255, 1, 255, 255, 255, 2, // + 0, 1, 255, 2, 255, 255, 255, 3, // + 255, 255, 0, 1, 255, 255, 255, 2, // + 0, 255, 1, 2, 255, 255, 255, 3, // + 255, 0, 1, 2, 255, 255, 255, 3, // + 0, 1, 2, 3, 255, 255, 255, 4, // + 255, 255, 255, 255, 0, 255, 255, 1, // + 0, 255, 255, 255, 1, 255, 255, 2, // + 255, 0, 255, 255, 1, 255, 255, 2, // + 0, 1, 255, 255, 2, 255, 255, 3, // + 255, 255, 0, 255, 1, 255, 255, 2, // + 0, 255, 1, 255, 2, 255, 255, 3, // + 255, 0, 1, 255, 2, 255, 255, 3, // + 0, 1, 2, 255, 3, 255, 255, 4, // + 255, 255, 255, 0, 1, 255, 255, 2, // + 0, 255, 255, 1, 2, 255, 255, 3, // + 255, 0, 255, 1, 2, 255, 255, 3, // + 
0, 1, 255, 2, 3, 255, 255, 4, // + 255, 255, 0, 1, 2, 255, 255, 3, // + 0, 255, 1, 2, 3, 255, 255, 4, // + 255, 0, 1, 2, 3, 255, 255, 4, // + 0, 1, 2, 3, 4, 255, 255, 5, // + 255, 255, 255, 255, 255, 0, 255, 1, // + 0, 255, 255, 255, 255, 1, 255, 2, // + 255, 0, 255, 255, 255, 1, 255, 2, // + 0, 1, 255, 255, 255, 2, 255, 3, // + 255, 255, 0, 255, 255, 1, 255, 2, // + 0, 255, 1, 255, 255, 2, 255, 3, // + 255, 0, 1, 255, 255, 2, 255, 3, // + 0, 1, 2, 255, 255, 3, 255, 4, // + 255, 255, 255, 0, 255, 1, 255, 2, // + 0, 255, 255, 1, 255, 2, 255, 3, // + 255, 0, 255, 1, 255, 2, 255, 3, // + 0, 1, 255, 2, 255, 3, 255, 4, // + 255, 255, 0, 1, 255, 2, 255, 3, // + 0, 255, 1, 2, 255, 3, 255, 4, // + 255, 0, 1, 2, 255, 3, 255, 4, // + 0, 1, 2, 3, 255, 4, 255, 5, // + 255, 255, 255, 255, 0, 1, 255, 2, // + 0, 255, 255, 255, 1, 2, 255, 3, // + 255, 0, 255, 255, 1, 2, 255, 3, // + 0, 1, 255, 255, 2, 3, 255, 4, // + 255, 255, 0, 255, 1, 2, 255, 3, // + 0, 255, 1, 255, 2, 3, 255, 4, // + 255, 0, 1, 255, 2, 3, 255, 4, // + 0, 1, 2, 255, 3, 4, 255, 5, // + 255, 255, 255, 0, 1, 2, 255, 3, // + 0, 255, 255, 1, 2, 3, 255, 4, // + 255, 0, 255, 1, 2, 3, 255, 4, // + 0, 1, 255, 2, 3, 4, 255, 5, // + 255, 255, 0, 1, 2, 3, 255, 4, // + 0, 255, 1, 2, 3, 4, 255, 5, // + 255, 0, 1, 2, 3, 4, 255, 5, // + 0, 1, 2, 3, 4, 5, 255, 6, // + 255, 255, 255, 255, 255, 255, 0, 1, // + 0, 255, 255, 255, 255, 255, 1, 2, // + 255, 0, 255, 255, 255, 255, 1, 2, // + 0, 1, 255, 255, 255, 255, 2, 3, // + 255, 255, 0, 255, 255, 255, 1, 2, // + 0, 255, 1, 255, 255, 255, 2, 3, // + 255, 0, 1, 255, 255, 255, 2, 3, // + 0, 1, 2, 255, 255, 255, 3, 4, // + 255, 255, 255, 0, 255, 255, 1, 2, // + 0, 255, 255, 1, 255, 255, 2, 3, // + 255, 0, 255, 1, 255, 255, 2, 3, // + 0, 1, 255, 2, 255, 255, 3, 4, // + 255, 255, 0, 1, 255, 255, 2, 3, // + 0, 255, 1, 2, 255, 255, 3, 4, // + 255, 0, 1, 2, 255, 255, 3, 4, // + 0, 1, 2, 3, 255, 255, 4, 5, // + 255, 255, 255, 255, 0, 255, 1, 2, // + 0, 255, 255, 255, 1, 255, 2, 3, // + 255, 
0, 255, 255, 1, 255, 2, 3, // + 0, 1, 255, 255, 2, 255, 3, 4, // + 255, 255, 0, 255, 1, 255, 2, 3, // + 0, 255, 1, 255, 2, 255, 3, 4, // + 255, 0, 1, 255, 2, 255, 3, 4, // + 0, 1, 2, 255, 3, 255, 4, 5, // + 255, 255, 255, 0, 1, 255, 2, 3, // + 0, 255, 255, 1, 2, 255, 3, 4, // + 255, 0, 255, 1, 2, 255, 3, 4, // + 0, 1, 255, 2, 3, 255, 4, 5, // + 255, 255, 0, 1, 2, 255, 3, 4, // + 0, 255, 1, 2, 3, 255, 4, 5, // + 255, 0, 1, 2, 3, 255, 4, 5, // + 0, 1, 2, 3, 4, 255, 5, 6, // + 255, 255, 255, 255, 255, 0, 1, 2, // + 0, 255, 255, 255, 255, 1, 2, 3, // + 255, 0, 255, 255, 255, 1, 2, 3, // + 0, 1, 255, 255, 255, 2, 3, 4, // + 255, 255, 0, 255, 255, 1, 2, 3, // + 0, 255, 1, 255, 255, 2, 3, 4, // + 255, 0, 1, 255, 255, 2, 3, 4, // + 0, 1, 2, 255, 255, 3, 4, 5, // + 255, 255, 255, 0, 255, 1, 2, 3, // + 0, 255, 255, 1, 255, 2, 3, 4, // + 255, 0, 255, 1, 255, 2, 3, 4, // + 0, 1, 255, 2, 255, 3, 4, 5, // + 255, 255, 0, 1, 255, 2, 3, 4, // + 0, 255, 1, 2, 255, 3, 4, 5, // + 255, 0, 1, 2, 255, 3, 4, 5, // + 0, 1, 2, 3, 255, 4, 5, 6, // + 255, 255, 255, 255, 0, 1, 2, 3, // + 0, 255, 255, 255, 1, 2, 3, 4, // + 255, 0, 255, 255, 1, 2, 3, 4, // + 0, 1, 255, 255, 2, 3, 4, 5, // + 255, 255, 0, 255, 1, 2, 3, 4, // + 0, 255, 1, 255, 2, 3, 4, 5, // + 255, 0, 1, 255, 2, 3, 4, 5, // + 0, 1, 2, 255, 3, 4, 5, 6, // + 255, 255, 255, 0, 1, 2, 3, 4, // + 0, 255, 255, 1, 2, 3, 4, 5, // + 255, 0, 255, 1, 2, 3, 4, 5, // + 0, 1, 255, 2, 3, 4, 5, 6, // + 255, 255, 0, 1, 2, 3, 4, 5, // + 0, 255, 1, 2, 3, 4, 5, 6, // + 255, 0, 1, 2, 3, 4, 5, 6, // + 0, 1, 2, 3, 4, 5, 6, 7}; + const svuint16_t indices = PromoteTo(du16, Load(du8, table + offset)); + return TableLookupLanes(v, indices); // already zeros mask=false lanes +#else + return detail::ExpandLoop(v, mask); +#endif +} + +template +HWY_API V Expand(V v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE // 32x8 + const DFromV d; + const RebindToUnsigned du32; + // Convert mask into bitfield via horizontal sum (faster than ORV). 
+ const svuint32_t bits = Shl(Set(du32, 1), Iota(du32, 0)); + const size_t code = detail::SumOfLanesM(mask, bits); + + alignas(16) constexpr uint32_t packed_array[256] = { + // PrintExpand32x8. + 0xffffffff, 0xfffffff0, 0xffffff0f, 0xffffff10, 0xfffff0ff, 0xfffff1f0, + 0xfffff10f, 0xfffff210, 0xffff0fff, 0xffff1ff0, 0xffff1f0f, 0xffff2f10, + 0xffff10ff, 0xffff21f0, 0xffff210f, 0xffff3210, 0xfff0ffff, 0xfff1fff0, + 0xfff1ff0f, 0xfff2ff10, 0xfff1f0ff, 0xfff2f1f0, 0xfff2f10f, 0xfff3f210, + 0xfff10fff, 0xfff21ff0, 0xfff21f0f, 0xfff32f10, 0xfff210ff, 0xfff321f0, + 0xfff3210f, 0xfff43210, 0xff0fffff, 0xff1ffff0, 0xff1fff0f, 0xff2fff10, + 0xff1ff0ff, 0xff2ff1f0, 0xff2ff10f, 0xff3ff210, 0xff1f0fff, 0xff2f1ff0, + 0xff2f1f0f, 0xff3f2f10, 0xff2f10ff, 0xff3f21f0, 0xff3f210f, 0xff4f3210, + 0xff10ffff, 0xff21fff0, 0xff21ff0f, 0xff32ff10, 0xff21f0ff, 0xff32f1f0, + 0xff32f10f, 0xff43f210, 0xff210fff, 0xff321ff0, 0xff321f0f, 0xff432f10, + 0xff3210ff, 0xff4321f0, 0xff43210f, 0xff543210, 0xf0ffffff, 0xf1fffff0, + 0xf1ffff0f, 0xf2ffff10, 0xf1fff0ff, 0xf2fff1f0, 0xf2fff10f, 0xf3fff210, + 0xf1ff0fff, 0xf2ff1ff0, 0xf2ff1f0f, 0xf3ff2f10, 0xf2ff10ff, 0xf3ff21f0, + 0xf3ff210f, 0xf4ff3210, 0xf1f0ffff, 0xf2f1fff0, 0xf2f1ff0f, 0xf3f2ff10, + 0xf2f1f0ff, 0xf3f2f1f0, 0xf3f2f10f, 0xf4f3f210, 0xf2f10fff, 0xf3f21ff0, + 0xf3f21f0f, 0xf4f32f10, 0xf3f210ff, 0xf4f321f0, 0xf4f3210f, 0xf5f43210, + 0xf10fffff, 0xf21ffff0, 0xf21fff0f, 0xf32fff10, 0xf21ff0ff, 0xf32ff1f0, + 0xf32ff10f, 0xf43ff210, 0xf21f0fff, 0xf32f1ff0, 0xf32f1f0f, 0xf43f2f10, + 0xf32f10ff, 0xf43f21f0, 0xf43f210f, 0xf54f3210, 0xf210ffff, 0xf321fff0, + 0xf321ff0f, 0xf432ff10, 0xf321f0ff, 0xf432f1f0, 0xf432f10f, 0xf543f210, + 0xf3210fff, 0xf4321ff0, 0xf4321f0f, 0xf5432f10, 0xf43210ff, 0xf54321f0, + 0xf543210f, 0xf6543210, 0x0fffffff, 0x1ffffff0, 0x1fffff0f, 0x2fffff10, + 0x1ffff0ff, 0x2ffff1f0, 0x2ffff10f, 0x3ffff210, 0x1fff0fff, 0x2fff1ff0, + 0x2fff1f0f, 0x3fff2f10, 0x2fff10ff, 0x3fff21f0, 0x3fff210f, 0x4fff3210, + 0x1ff0ffff, 0x2ff1fff0, 
0x2ff1ff0f, 0x3ff2ff10, 0x2ff1f0ff, 0x3ff2f1f0, + 0x3ff2f10f, 0x4ff3f210, 0x2ff10fff, 0x3ff21ff0, 0x3ff21f0f, 0x4ff32f10, + 0x3ff210ff, 0x4ff321f0, 0x4ff3210f, 0x5ff43210, 0x1f0fffff, 0x2f1ffff0, + 0x2f1fff0f, 0x3f2fff10, 0x2f1ff0ff, 0x3f2ff1f0, 0x3f2ff10f, 0x4f3ff210, + 0x2f1f0fff, 0x3f2f1ff0, 0x3f2f1f0f, 0x4f3f2f10, 0x3f2f10ff, 0x4f3f21f0, + 0x4f3f210f, 0x5f4f3210, 0x2f10ffff, 0x3f21fff0, 0x3f21ff0f, 0x4f32ff10, + 0x3f21f0ff, 0x4f32f1f0, 0x4f32f10f, 0x5f43f210, 0x3f210fff, 0x4f321ff0, + 0x4f321f0f, 0x5f432f10, 0x4f3210ff, 0x5f4321f0, 0x5f43210f, 0x6f543210, + 0x10ffffff, 0x21fffff0, 0x21ffff0f, 0x32ffff10, 0x21fff0ff, 0x32fff1f0, + 0x32fff10f, 0x43fff210, 0x21ff0fff, 0x32ff1ff0, 0x32ff1f0f, 0x43ff2f10, + 0x32ff10ff, 0x43ff21f0, 0x43ff210f, 0x54ff3210, 0x21f0ffff, 0x32f1fff0, + 0x32f1ff0f, 0x43f2ff10, 0x32f1f0ff, 0x43f2f1f0, 0x43f2f10f, 0x54f3f210, + 0x32f10fff, 0x43f21ff0, 0x43f21f0f, 0x54f32f10, 0x43f210ff, 0x54f321f0, + 0x54f3210f, 0x65f43210, 0x210fffff, 0x321ffff0, 0x321fff0f, 0x432fff10, + 0x321ff0ff, 0x432ff1f0, 0x432ff10f, 0x543ff210, 0x321f0fff, 0x432f1ff0, + 0x432f1f0f, 0x543f2f10, 0x432f10ff, 0x543f21f0, 0x543f210f, 0x654f3210, + 0x3210ffff, 0x4321fff0, 0x4321ff0f, 0x5432ff10, 0x4321f0ff, 0x5432f1f0, + 0x5432f10f, 0x6543f210, 0x43210fff, 0x54321ff0, 0x54321f0f, 0x65432f10, + 0x543210ff, 0x654321f0, 0x6543210f, 0x76543210}; + + // For lane i, shift the i-th 4-bit index down and mask with 0xF because + // svtbl zeros outputs if the index is out of bounds. + const svuint32_t packed = Set(du32, packed_array[code]); + const svuint32_t indices = detail::AndN(Shr(packed, svindex_u32(0, 4)), 0xF); + return TableLookupLanes(v, indices); // already zeros mask=false lanes +#elif HWY_TARGET == HWY_SVE2_128 // 32x4 + const DFromV d; + const RebindToUnsigned du32; + // Convert mask into bitfield via horizontal sum (faster than ORV). 
+ const svuint32_t bits = Shl(Set(du32, 1), Iota(du32, 0)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + alignas(16) constexpr uint32_t packed_array[16] = { + // PrintExpand64x4Nibble - same for 32x4. + 0x0000ffff, 0x0000fff0, 0x0000ff0f, 0x0000ff10, 0x0000f0ff, 0x0000f1f0, + 0x0000f10f, 0x0000f210, 0x00000fff, 0x00001ff0, 0x00001f0f, 0x00002f10, + 0x000010ff, 0x000021f0, 0x0000210f, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down and mask with 0xF because + // svtbl zeros outputs if the index is out of bounds. + const svuint32_t packed = Set(du32, packed_array[offset]); + const svuint32_t indices = detail::AndN(Shr(packed, svindex_u32(0, 4)), 0xF); + return TableLookupLanes(v, indices); // already zeros mask=false lanes +#else + return detail::ExpandLoop(v, mask); +#endif +} + +template +HWY_API V Expand(V v, svbool_t mask) { +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE // 64x4 + const DFromV d; + const RebindToUnsigned du64; + + // Convert mask into bitfield via horizontal sum (faster than ORV) of masked + // bits 1, 2, 4, 8. Pre-multiply by N so we can use it as an offset for + // SetTableIndices. + const svuint64_t bits = Shl(Set(du64, 1), Iota(du64, 2)); + const size_t offset = detail::SumOfLanesM(mask, bits); + + alignas(16) static constexpr uint64_t table[4 * 16] = { + // PrintExpand64x4Tables - small enough to store uncompressed. + 255, 255, 255, 255, 0, 255, 255, 255, 255, 0, 255, 255, 0, 1, 255, 255, + 255, 255, 0, 255, 0, 255, 1, 255, 255, 0, 1, 255, 0, 1, 2, 255, + 255, 255, 255, 0, 0, 255, 255, 1, 255, 0, 255, 1, 0, 1, 255, 2, + 255, 255, 0, 1, 0, 255, 1, 2, 255, 0, 1, 2, 0, 1, 2, 3}; + // This already zeros mask=false lanes. + return TableLookupLanes(v, SetTableIndices(d, table + offset)); +#elif HWY_TARGET == HWY_SVE2_128 // 64x2 + // Same as Compress, just zero out the mask=false lanes. 
+ return IfThenElseZero(mask, Compress(v, mask)); +#else + return detail::ExpandLoop(v, mask); +#endif +} + +// ------------------------------ LoadExpand + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return Expand(LoadU(d, unaligned), mask); +} + +// ------------------------------ MulEven (InterleaveEven) + +#if HWY_SVE_HAVE_2 +namespace detail { +#define HWY_SVE_MUL_EVEN(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_V(BASE, HALF) a, HWY_SVE_V(BASE, HALF) b) { \ + return sv##OP##_##CHAR##BITS(a, b); \ + } + +HWY_SVE_FOREACH_UI16(HWY_SVE_MUL_EVEN, MulEvenNative, mullb) +HWY_SVE_FOREACH_UI32(HWY_SVE_MUL_EVEN, MulEvenNative, mullb) +HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulEvenNative, mullb) +HWY_SVE_FOREACH_UI16(HWY_SVE_MUL_EVEN, MulOddNative, mullt) +HWY_SVE_FOREACH_UI32(HWY_SVE_MUL_EVEN, MulOddNative, mullt) +HWY_SVE_FOREACH_UI64(HWY_SVE_MUL_EVEN, MulOddNative, mullt) +#undef HWY_SVE_MUL_EVEN +} // namespace detail +#endif + +template >, + HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))> +HWY_API VFromD MulEven(const V a, const V b) { +#if HWY_SVE_HAVE_2 + return BitCast(DW(), detail::MulEvenNative(a, b)); +#else + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return BitCast(DW(), detail::InterleaveEven(lo, hi)); +#endif +} + +template >, + HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))> +HWY_API VFromD MulOdd(const V a, const V b) { +#if HWY_SVE_HAVE_2 + return BitCast(DW(), detail::MulOddNative(a, b)); +#else + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return BitCast(DW(), detail::InterleaveOdd(lo, hi)); +#endif +} + +HWY_API svint64_t MulEven(const svint64_t a, const svint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveEven(lo, hi); +} + +HWY_API svuint64_t MulEven(const svuint64_t a, const svuint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, 
b); + return detail::InterleaveEven(lo, hi); +} + +HWY_API svint64_t MulOdd(const svint64_t a, const svint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveOdd(lo, hi); +} + +HWY_API svuint64_t MulOdd(const svuint64_t a, const svuint64_t b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return detail::InterleaveOdd(lo, hi); +} + +// ------------------------------ PairwiseAdd/PairwiseSub +#if HWY_TARGET != HWY_SCALAR +#if HWY_SVE_HAVE_2 || HWY_IDE + +#ifdef HWY_NATIVE_PAIRWISE_ADD +#undef HWY_NATIVE_PAIRWISE_ADD +#else +#define HWY_NATIVE_PAIRWISE_ADD +#endif + +namespace detail { +#define HWY_SVE_SV_PAIRWISE_ADD(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, HWY_SVE_V(BASE, BITS) a, \ + HWY_SVE_V(BASE, BITS) b) { \ + return sv##OP##_##CHAR##BITS##_m(HWY_SVE_PTRUE(BITS), a, b); \ + } + +HWY_SVE_FOREACH(HWY_SVE_SV_PAIRWISE_ADD, PairwiseAdd, addp) +#undef HWY_SVE_SV_PAIRWISE_ADD +} // namespace detail + +// Pairwise add returning interleaved output of a and b +template +HWY_API V PairwiseAdd(D d, V a, V b) { + return detail::PairwiseAdd(d, a, b); +} + +#endif // HWY_SVE_HAVE_2 +#endif // HWY_TARGET != HWY_SCALAR + +// ------------------------------ MulEvenAdd/MulOddAdd (PromoteEvenTo) + +// Always implemented here because this op is used in WidenMulEven. 
+#ifdef HWY_NATIVE_MUL_EVEN_BF16 +#undef HWY_NATIVE_MUL_EVEN_BF16 +#else +#define HWY_NATIVE_MUL_EVEN_BF16 +#endif + +template +HWY_API svfloat32_t MulEvenAdd(Simd dw, VBF16 a, VBF16 b, + const svfloat32_t c) { +#if HWY_SVE_HAVE_BF16_FEATURE + return svbfmlalb_f32(c, a, b); +#else + return MulAdd(PromoteEvenTo(dw, a), PromoteEvenTo(dw, b), c); +#endif // HWY_SVE_HAVE_BF16_FEATURE +} + +template +HWY_API svfloat32_t MulOddAdd(Simd dw, VBF16 a, VBF16 b, + const svfloat32_t c) { +#if HWY_SVE_HAVE_BF16_FEATURE + (void)dw; + return svbfmlalt_f32(c, a, b); +#else + return MulAdd(PromoteOddTo(dw, a), PromoteOddTo(dw, b), c); +#endif // HWY_SVE_HAVE_BF16_FEATURE +} + +// Highway API only guarantees support for bf16*bf16+f32; also implement UI32 +// to allow reusing this op for integer WidenMulPairwiseAdd: + +template +HWY_API svint32_t MulEvenAdd(Simd dw, svint16_t a, + svint16_t b, const svint32_t c) { +#if HWY_SVE_HAVE_2 + (void)dw; + return svmlalb_s32(c, a, b); +#else + return MulAdd(PromoteEvenTo(dw, a), PromoteEvenTo(dw, b), c); +#endif // HWY_SVE_HAVE_2 +} + +template +HWY_API svuint32_t MulEvenAdd(Simd dw, svuint16_t a, + svuint16_t b, const svuint32_t c) { +#if HWY_SVE_HAVE_2 + (void)dw; + return svmlalb_u32(c, a, b); +#else + return MulAdd(PromoteEvenTo(dw, a), PromoteEvenTo(dw, b), c); +#endif // HWY_SVE_HAVE_2 +} + +template +HWY_API svint32_t MulOddAdd(Simd dw, svint16_t a, + svint16_t b, const svint32_t c) { +#if HWY_SVE_HAVE_2 + (void)dw; + return svmlalt_s32(c, a, b); +#else + return MulAdd(PromoteOddTo(dw, a), PromoteOddTo(dw, b), c); +#endif // HWY_SVE_HAVE_2 +} + +template +HWY_API svuint32_t MulOddAdd(Simd dw, svuint16_t a, + svuint16_t b, const svuint32_t c) { +#if HWY_SVE_HAVE_2 + (void)dw; + return svmlalt_u32(c, a, b); +#else + return MulAdd(PromoteOddTo(dw, a), PromoteOddTo(dw, b), c); +#endif // HWY_SVE_HAVE_2 +} + +// ------------------------------ WidenMulEven (MulEvenAdd, PromoteEvenTo) + +template +HWY_API svfloat32_t WidenMulEven(Simd dw, 
VBF16 a, VBF16 b) { +#if HWY_SVE_HAVE_BF16_FEATURE + // MulEvenAdd takes (dw, a, b, c): the zero accumulator must be passed last. + return MulEvenAdd(dw, a, b, Zero(dw)); +#else + // Same as MulEvenAdd, but without generating a zero argument. + return Mul(PromoteEvenTo(dw, a), PromoteEvenTo(dw, b)); +#endif // HWY_SVE_HAVE_BF16_FEATURE +} + +template +HWY_API svint32_t WidenMulEven(Simd dw, svint16_t a, + svint16_t b) { +#if HWY_SVE_HAVE_2 + (void)dw; + return svmullb_s32(a, b); +#else + // Same as MulEvenAdd, but without generating a zero argument. + return Mul(PromoteEvenTo(dw, a), PromoteEvenTo(dw, b)); +#endif // HWY_SVE_HAVE_2 +} + +template +HWY_API svuint32_t WidenMulEven(Simd dw, svuint16_t a, + svuint16_t b) { +#if HWY_SVE_HAVE_2 + (void)dw; + return svmullb_u32(a, b); +#else + // Same as MulEvenAdd, but without generating a zero argument. + return Mul(PromoteEvenTo(dw, a), PromoteEvenTo(dw, b)); +#endif // HWY_SVE_HAVE_2 +} + +// ------------------------------ WidenMulPairwiseAdd (WidenMulEven, MulOddAdd) +// Deduce from VN: RepartitionToNarrow could be either F16 or BF16 for float. 
+template , class DW = RepartitionToWide, + class VW = VFromD> +HWY_API VW WidenMulPairwiseAdd(DW dw, VN a, VN b) { + return MulOddAdd(dw, a, b, WidenMulEven(dw, a, b)); +} + +// ------------------------------ SatWidenMulPairwiseAccumulate +#if HWY_SVE_HAVE_2 +#define HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME(HWY_SVE_D(BASE, BITS, N, kPow2) dw, HWY_SVE_V(BASE, HALF) a, \ + HWY_SVE_V(BASE, HALF) b, HWY_SVE_V(BASE, BITS) sum) { \ + auto product = svmlalt_##CHAR##BITS(svmullb_##CHAR##BITS(a, b), a, b); \ + const auto mul_overflow = IfThenElseZero( \ + Eq(product, Set(dw, LimitsMin())), Set(dw, -1)); \ + return SaturatedAdd(Sub(sum, And(BroadcastSignBit(sum), mul_overflow)), \ + Add(product, mul_overflow)); \ + } +HWY_SVE_FOREACH_UI16(HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2, + SatWidenMulPairwiseAccumulate, _) +HWY_SVE_FOREACH_UI32(HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2, + SatWidenMulPairwiseAccumulate, _) +HWY_SVE_FOREACH_UI64(HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2, + SatWidenMulPairwiseAccumulate, _) + +#undef HWY_SVE_SAT_MUL_WIDEN_PW_ACC_SVE_2 +#endif + +// ------------------------------ SatWidenMulAccumFixedPoint + +#if HWY_SVE_HAVE_2 + +#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#else +#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#endif + +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return svqdmlalb_s32(sum, detail::ZipLowerSame(a, a), + detail::ZipLowerSame(b, b)); +} + +#endif // HWY_SVE_HAVE_2 + +// ------------------------------ ReorderWidenMulAccumulate (MulOddEven) + +#if HWY_SVE_HAVE_BF16_FEATURE + +#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#else +#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#endif + +// NOTE: svbfdot uses round to odd unless the additional FEAT_EBF16 feature is +// available and enabled. 
+template +HWY_API svfloat32_t ReorderWidenMulAccumulate(Simd d32, + VBF16 a, VBF16 b, + const svfloat32_t sum0, + svfloat32_t& sum1) { + (void)d32; + (void)sum1; + return svbfdot_f32(sum0, a, b); +} + +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW) { + // sum1 is unused and the invariant already holds. + return sum0; +} + +#endif // HWY_SVE_HAVE_BF16_FEATURE + +// F32 version is implemented above, or in generic_ops-inl.h. +template , class DW = RepartitionToWide, + class VW = VFromD, HWY_IF_NOT_FLOAT_D(DW)> +HWY_API VW ReorderWidenMulAccumulate(DW dw, VN a, VN b, const VW sum0, + VW& sum1) { + sum1 = MulOddAdd(dw, a, b, sum1); + return MulEvenAdd(dw, a, b, sum0); +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + // sum0 is the sum of bottom/even lanes and sum1 of top/odd lanes. + return Add(sum0, sum1); +} + +// ------------------------------ SumOfMulQuadAccumulate + +#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, svint8_t a, + svint8_t b, svint32_t sum) { + return svdot_s32(sum, a, b); +} + +#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DU32 /*du32*/, svuint8_t a, + svuint8_t b, svuint32_t sum) { + return svdot_u32(sum, a, b); +} + +#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, svuint8_t a_u, + svint8_t b_i, svint32_t sum) { +#if HWY_SVE_HAVE_2 && __ARM_FEATURE_MATMUL_INT8 + (void)di32; + return svusdot_s32(sum, a_u, b_i); +#else + const 
RebindToUnsigned du32; + const Repartition du8; + + const auto b_u = BitCast(du8, b_i); + const auto result_sum0 = svdot_u32(BitCast(du32, sum), a_u, b_u); + const auto result_sum1 = + ShiftLeft<8>(svdot_u32(Zero(du32), a_u, ShiftRight<7>(b_u))); + + return BitCast(di32, Sub(result_sum0, result_sum1)); +#endif +} + +#ifdef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI64 /*di64*/, svint16_t a, + svint16_t b, svint64_t sum) { + return svdot_s64(sum, a, b); +} + +#ifdef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DU64 /*du64*/, svuint16_t a, + svuint16_t b, svuint64_t sum) { + return svdot_u64(sum, a, b); +} + +// ------------------------------ MulComplex* / MaskedMulComplex* + +// Per-target flag to prevent generic_ops-inl.h from defining MulComplex*. 
+#ifdef HWY_NATIVE_CPLX +#undef HWY_NATIVE_CPLX +#else +#define HWY_NATIVE_CPLX +#endif + +template )> +HWY_API V ComplexConj(V a) { + return OddEven(Neg(a), a); +} + +namespace detail { +#define HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, ROT) \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME##ROT(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \ + HWY_SVE_V(BASE, BITS) c) { \ + return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b, c, ROT); \ + } \ + HWY_API HWY_SVE_V(BASE, BITS) \ + NAME##Z##ROT(svbool_t m, HWY_SVE_V(BASE, BITS) a, \ + HWY_SVE_V(BASE, BITS) b, HWY_SVE_V(BASE, BITS) c) { \ + return sv##OP##_##CHAR##BITS##_z(m, a, b, c, ROT); \ + } + +#define HWY_SVE_CPLX_FMA(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 0) \ + HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 90) \ + HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 180) \ + HWY_SVE_CPLX_FMA_ROT(BASE, CHAR, BITS, HALF, NAME, OP, 270) + +// Only SVE2 has complex multiply add for integer types +// and these do not include masked variants +HWY_SVE_FOREACH_F(HWY_SVE_CPLX_FMA, ComplexMulAdd, cmla) +#undef HWY_SVE_CPLX_FMA +#undef HWY_SVE_CPLX_FMA_ROT +} // namespace detail + +template +HWY_API V MaskedMulComplexConjAdd(M mask, V a, V b, V c) { + const V t = detail::ComplexMulAddZ0(mask, c, b, a); + return detail::ComplexMulAddZ270(mask, t, b, a); +} + +template +HWY_API V MaskedMulComplexConj(M mask, V a, V b) { + return MaskedMulComplexConjAdd(mask, a, b, Zero(DFromV())); +} + +template +HWY_API V MulComplexAdd(V a, V b, V c) { + return detail::ComplexMulAdd90(detail::ComplexMulAdd0(c, a, b), a, b); +} + +template +HWY_API V MulComplex(V a, V b) { + return MulComplexAdd(a, b, Zero(DFromV())); +} + +template +HWY_API V MaskedMulComplexOr(V no, M mask, V a, V b) { + return IfThenElse(mask, MulComplex(a, b), no); +} + +template +HWY_API V MulComplexConjAdd(V a, V b, V c) { + return detail::ComplexMulAdd270(detail::ComplexMulAdd0(c, b, 
a), b, a); +} + +template +HWY_API V MulComplexConj(V a, V b) { + return MulComplexConjAdd(a, b, Zero(DFromV())); +} + +// TODO SVE2 does have intrinsics for integers but not masked variants +template +HWY_API V MulComplex(V a, V b) { + // a = u + iv, b = x + iy + const auto u = DupEven(a); + const auto v = DupOdd(a); + const auto x = DupEven(b); + const auto y = DupOdd(b); + + return OddEven(MulAdd(u, y, Mul(v, x)), Sub(Mul(u, x), Mul(v, y))); +} + +template +HWY_API V MulComplexConj(V a, V b) { + // a = u + iv, b = x + iy + const auto u = DupEven(a); + const auto v = DupOdd(a); + const auto x = DupEven(b); + const auto y = DupOdd(b); + + return OddEven(Sub(Mul(v, x), Mul(u, y)), MulAdd(u, x, Mul(v, y))); +} + +template +HWY_API V MulComplexAdd(V a, V b, V c) { + return Add(MulComplex(a, b), c); +} + +template +HWY_API V MulComplexConjAdd(V a, V b, V c) { + return Add(MulComplexConj(a, b), c); +} + +template +HWY_API V MaskedMulComplexConjAdd(M mask, V a, V b, V c) { + return IfThenElseZero(mask, MulComplexConjAdd(a, b, c)); +} + +template +HWY_API V MaskedMulComplexConj(M mask, V a, V b) { + return IfThenElseZero(mask, MulComplexConj(a, b)); +} + +template +HWY_API V MaskedMulComplexOr(V no, M mask, V a, V b) { + return IfThenElse(mask, MulComplex(a, b), no); +} + +// ------------------------------ AESRound / CLMul + +// Static dispatch with -march=armv8-a+sve2+aes, or dynamic dispatch WITHOUT a +// baseline, in which case we check for AES support at runtime. +#if defined(__ARM_FEATURE_SVE2_AES) || \ + (HWY_SVE_HAVE_2 && HWY_HAVE_RUNTIME_DISPATCH && HWY_BASELINE_SVE2 == 0) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API svuint8_t AESRound(svuint8_t state, svuint8_t round_key) { + // It is not clear whether E and MC fuse like they did on NEON. 
+ return Xor(svaesmc_u8(svaese_u8(state, svdup_n_u8(0))), round_key); +} + +HWY_API svuint8_t AESLastRound(svuint8_t state, svuint8_t round_key) { + return Xor(svaese_u8(state, svdup_n_u8(0)), round_key); +} + +HWY_API svuint8_t AESInvMixColumns(svuint8_t state) { + return svaesimc_u8(state); +} + +HWY_API svuint8_t AESRoundInv(svuint8_t state, svuint8_t round_key) { + return Xor(svaesimc_u8(svaesd_u8(state, svdup_n_u8(0))), round_key); +} + +HWY_API svuint8_t AESLastRoundInv(svuint8_t state, svuint8_t round_key) { + return Xor(svaesd_u8(state, svdup_n_u8(0)), round_key); +} + +template +HWY_API svuint8_t AESKeyGenAssist(svuint8_t v) { + alignas(16) static constexpr uint8_t kRconXorMask[16] = { + 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0}; + alignas(16) static constexpr uint8_t kRotWordShuffle[16] = { + 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12}; + const DFromV d; + const Repartition du32; + const auto w13 = BitCast(d, DupOdd(BitCast(du32, v))); + const auto sub_word_result = AESLastRound(w13, LoadDup128(d, kRconXorMask)); + return TableLookupBytes(sub_word_result, LoadDup128(d, kRotWordShuffle)); +} + +HWY_API svuint64_t CLMulLower(const svuint64_t a, const svuint64_t b) { + return svpmullb_pair(a, b); +} + +HWY_API svuint64_t CLMulUpper(const svuint64_t a, const svuint64_t b) { + return svpmullt_pair(a, b); +} + +#endif // __ARM_FEATURE_SVE2_AES + +// ------------------------------ Lt128 + +namespace detail { +#define HWY_SVE_DUP(BASE, CHAR, BITS, HALF, NAME, OP) \ + template \ + HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /*d*/, svbool_t m) { \ + return sv##OP##_b##BITS(m, m); \ + } + +HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupEvenB, trn1) // actually for bool +HWY_SVE_FOREACH_U(HWY_SVE_DUP, DupOddB, trn2) // actually for bool +#undef HWY_SVE_DUP + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +template +HWY_INLINE svuint64_t Lt128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + 
const svbool_t eqHx = Eq(a, b); // only odd lanes used + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t ltHL = VecFromMask(d, Lt(a, b)); + // Move into upper lane: ltL if the upper half is equal, otherwise ltH. + // Requires an extra IfThenElse because INSR, EXT, TRN2 are unpredicated. + const svuint64_t ltHx = IfThenElse(eqHx, DupEven(ltHL), ltHL); + // Duplicate upper lane into lower. + return DupOdd(ltHx); +} +#endif +} // namespace detail + +template +HWY_INLINE svbool_t Lt128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Lt128Vec(d, a, b)); +#else + static_assert(IsSame, uint64_t>(), "D must be u64"); + const svbool_t eqHx = Eq(a, b); // only odd lanes used + const svbool_t ltHL = Lt(a, b); + // Move into upper lane: ltL if the upper half is equal, otherwise ltH. + const svbool_t ltHx = svsel_b(eqHx, detail::DupEvenB(d, ltHL), ltHL); + // Duplicate upper lane into lower. + return detail::DupOddB(d, ltHx); +#endif // HWY_TARGET != HWY_SVE_256 +} + +// ------------------------------ Lt128Upper + +template +HWY_INLINE svbool_t Lt128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const svbool_t ltHL = Lt(a, b); + return detail::DupOddB(d, ltHL); +} + +// ------------------------------ Eq128, Ne128 + +#if HWY_TARGET == HWY_SVE_256 || HWY_IDE +namespace detail { + +template +HWY_INLINE svuint64_t Eq128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t eqHL = VecFromMask(d, Eq(a, b)); + // Duplicate upper and lower. 
+ const svuint64_t eqHH = DupOdd(eqHL); + const svuint64_t eqLL = DupEven(eqHL); + return And(eqLL, eqHH); +} + +template +HWY_INLINE svuint64_t Ne128Vec(D d, const svuint64_t a, const svuint64_t b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + // Convert to vector: more pipelines can execute vector TRN* instructions + // than the predicate version. + const svuint64_t neHL = VecFromMask(d, Ne(a, b)); + // Duplicate upper and lower. + const svuint64_t neHH = DupOdd(neHL); + const svuint64_t neLL = DupEven(neHL); + return Or(neLL, neHH); +} + +} // namespace detail +#endif + +template +HWY_INLINE svbool_t Eq128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Eq128Vec(d, a, b)); +#else + static_assert(IsSame, uint64_t>(), "D must be u64"); + const svbool_t eqHL = Eq(a, b); + const svbool_t eqHH = detail::DupOddB(d, eqHL); + const svbool_t eqLL = detail::DupEvenB(d, eqHL); + return And(eqLL, eqHH); +#endif // HWY_TARGET != HWY_SVE_256 +} + +template +HWY_INLINE svbool_t Ne128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return MaskFromVec(detail::Ne128Vec(d, a, b)); +#else + static_assert(IsSame, uint64_t>(), "D must be u64"); + const svbool_t neHL = Ne(a, b); + const svbool_t neHH = detail::DupOddB(d, neHL); + const svbool_t neLL = detail::DupEvenB(d, neHL); + return Or(neLL, neHH); +#endif // HWY_TARGET != HWY_SVE_256 +} + +// ------------------------------ Eq128Upper, Ne128Upper + +template +HWY_INLINE svbool_t Eq128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const svbool_t eqHL = Eq(a, b); + return detail::DupOddB(d, eqHL); +} + +template +HWY_INLINE svbool_t Ne128Upper(D d, svuint64_t a, svuint64_t b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const svbool_t neHL = Ne(a, b); + return detail::DupOddB(d, neHL); +} + +// ------------------------------ Min128, Max128 (Lt128) + +template 
+HWY_INLINE svuint64_t Min128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +#else + return IfThenElse(Lt128(d, a, b), a, b); +#endif +} + +template +HWY_INLINE svuint64_t Max128(D d, const svuint64_t a, const svuint64_t b) { +#if HWY_TARGET == HWY_SVE_256 + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +#else + return IfThenElse(Lt128(d, b, a), a, b); +#endif +} + +template +HWY_INLINE svuint64_t Min128Upper(D d, const svuint64_t a, const svuint64_t b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE svuint64_t Max128Upper(D d, const svuint64_t a, const svuint64_t b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex + +#ifdef HWY_NATIVE_LEADING_ZERO_COUNT +#undef HWY_NATIVE_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_LEADING_ZERO_COUNT +#endif + +#define HWY_SVE_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \ + const DFromV d; \ + return BitCast(d, sv##OP##_##CHAR##BITS##_x(detail::PTrue(d), v)); \ + } + +HWY_SVE_FOREACH_UI(HWY_SVE_LEADING_ZERO_COUNT, LeadingZeroCount, clz) +#undef HWY_SVE_LEADING_ZERO_COUNT + +template +HWY_API V TrailingZeroCount(V v) { + return LeadingZeroCount(ReverseBits(v)); +} + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + using T = TFromD; + return BitCast(d, Sub(Set(d, T{sizeof(T) * 8 - 1}), LeadingZeroCount(v))); +} + +#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#endif + +#define HWY_SVE_MASKED_LEADING_ZERO_COUNT(BASE, CHAR, BITS, HALF, NAME, OP) \ + HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \ + const DFromV d; \ + return BitCast(d, sv##OP##_##CHAR##BITS##_z(m, v)); \ + } + 
+HWY_SVE_FOREACH_UI(HWY_SVE_MASKED_LEADING_ZERO_COUNT, MaskedLeadingZeroCount, + clz) +#undef HWY_SVE_MASKED_LEADING_ZERO_COUNT + +// ================================================== END MACROS +#undef HWY_SVE_ALL_PTRUE +#undef HWY_SVE_D +#undef HWY_SVE_FOREACH +#undef HWY_SVE_FOREACH_BF16 +#undef HWY_SVE_FOREACH_BF16_UNCONDITIONAL +#undef HWY_SVE_FOREACH_F +#undef HWY_SVE_FOREACH_F16 +#undef HWY_SVE_FOREACH_F32 +#undef HWY_SVE_FOREACH_F3264 +#undef HWY_SVE_FOREACH_F64 +#undef HWY_SVE_FOREACH_I +#undef HWY_SVE_FOREACH_I08 +#undef HWY_SVE_FOREACH_I16 +#undef HWY_SVE_FOREACH_I32 +#undef HWY_SVE_FOREACH_I64 +#undef HWY_SVE_FOREACH_IF +#undef HWY_SVE_FOREACH_U +#undef HWY_SVE_FOREACH_U08 +#undef HWY_SVE_FOREACH_U16 +#undef HWY_SVE_FOREACH_U32 +#undef HWY_SVE_FOREACH_U64 +#undef HWY_SVE_FOREACH_UI +#undef HWY_SVE_FOREACH_UI08 +#undef HWY_SVE_FOREACH_UI16 +#undef HWY_SVE_FOREACH_UI32 +#undef HWY_SVE_FOREACH_UI64 +#undef HWY_SVE_FOREACH_UIF3264 +#undef HWY_SVE_HAVE_2 +#undef HWY_SVE_IF_EMULATED_D +#undef HWY_SVE_IF_NOT_EMULATED_D +#undef HWY_SVE_PTRUE +#undef HWY_SVE_RETV_ARGMVV +#undef HWY_SVE_RETV_ARGMVV_Z +#undef HWY_SVE_RETV_ARGMV_Z +#undef HWY_SVE_RETV_ARGMV +#undef HWY_SVE_RETV_ARGMVV_Z +#undef HWY_SVE_RETV_ARGPV +#undef HWY_SVE_RETV_ARGPVN +#undef HWY_SVE_RETV_ARGPVV +#undef HWY_SVE_RETV_ARGV +#undef HWY_SVE_RETV_ARGVN +#undef HWY_SVE_RETV_ARGMV_M +#undef HWY_SVE_RETV_ARGVV +#undef HWY_SVE_RETV_ARGVVV +#undef HWY_SVE_RETV_ARGMVVV_Z +#undef HWY_SVE_RETV_ARGMVVV +#undef HWY_SVE_T +#undef HWY_SVE_UNDEFINED +#undef HWY_SVE_V + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/lib/highway/hwy/ops/emu128-inl.h b/lib/highway/hwy/ops/emu128-inl.h new file mode 100644 index 00000000000..1ef625f8d89 --- /dev/null +++ b/lib/highway/hwy/ops/emu128-inl.h @@ -0,0 +1,2990 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, 
Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Emulated 128-bit vectors (arrays of scalar lanes) and operations. +// External include guard in highway.h - see comment there. + +#include "hwy/base.h" + +#ifndef HWY_NO_LIBCXX +#include // sqrtf +#endif + +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +using Full128 = Simd; + +// (Wrapper class required for overloading comparison operators.) +template +struct Vec128 { + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + HWY_INLINE Vec128() = default; + Vec128(const Vec128&) = default; + Vec128& operator=(const Vec128&) = default; + + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + // Behave like wasm128 (vectors can always hold 128 bits). 
generic_ops-inl.h + // relies on this for LoadInterleaved*. CAVEAT: this method of padding + // prevents using range for, especially in SumOfLanes, where it would be + // incorrect. Moving padding to another field would require handling the case + // where N = 16 / sizeof(T) (i.e. there is no padding), which is also awkward. + T raw[16 / sizeof(T)] = {}; +}; + +// 0 or FF..FF, same size as Vec128. +template +struct Mask128 { + using Raw = hwy::MakeUnsigned; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = N; // only for DFromM + + static HWY_INLINE Raw FromBool(bool b) { + return b ? static_cast(~Raw{0}) : 0; + } + + // Must match the size of Vec128. + Raw bits[16 / sizeof(T)] = {}; +}; + +template +using DFromV = Simd; + +template +using DFromM = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ Zero + +// Use HWY_MAX_LANES_D here because VFromD is defined in terms of Zero. +template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + Vec128, HWY_MAX_LANES_D(D)> v; // zero-initialized + return v; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ BitCast + +template +HWY_API VFromD BitCast(D /* tag */, VFrom v) { + VFromD to; + CopySameSize(&v.raw, &to.raw); + return to; +} + +// ------------------------------ ResizeBitCast + +template +HWY_API VFromD ResizeBitCast(D d, VFrom v) { + using DFrom = DFromV; + using TFrom = TFromD; + using TTo = TFromD; + + constexpr size_t kFromByteLen = sizeof(TFrom) * HWY_MAX_LANES_D(DFrom); + constexpr size_t kToByteLen = sizeof(TTo) * HWY_MAX_LANES_D(D); + constexpr size_t kCopyByteLen = HWY_MIN(kFromByteLen, kToByteLen); + + VFromD to = Zero(d); + CopyBytes(&v.raw, &to.raw); + return to; +} + +namespace detail { + +// ResizeBitCast on the HWY_EMU128 target has zero-extending semantics if +// VFromD is a larger vector than FromV +template +HWY_INLINE VFromD ZeroExtendResizeBitCast(FromSizeTag /* from_size_tag 
*/, + ToSizeTag /* to_size_tag */, + DTo d_to, DFrom /* d_from */, + VFromD v) { + return ResizeBitCast(d_to, v); +} + +} // namespace detail + +// ------------------------------ Set +template +HWY_API VFromD Set(D d, const T2 t) { + VFromD v; + for (size_t i = 0; i < MaxLanes(d); ++i) { + v.raw[i] = ConvertScalarTo>(t); + } + return v; +} + +// ------------------------------ Undefined +template +HWY_API VFromD Undefined(D d) { + return Zero(d); +} + +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + VFromD result; + result.raw[0] = t0; + result.raw[1] = t1; + result.raw[2] = t2; + result.raw[3] = t3; + result.raw[4] = t4; + result.raw[5] = t5; + result.raw[6] = t6; + result.raw[7] = t7; + result.raw[8] = t8; + result.raw[9] = t9; + result.raw[10] = t10; + result.raw[11] = t11; + result.raw[12] = t12; + result.raw[13] = t13; + result.raw[14] = t14; + result.raw[15] = t15; + return result; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + VFromD result; + result.raw[0] = t0; + result.raw[1] = t1; + result.raw[2] = t2; + result.raw[3] = t3; + result.raw[4] = t4; + result.raw[5] = t5; + result.raw[6] = t6; + result.raw[7] = t7; + return result; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + VFromD result; + result.raw[0] = t0; + result.raw[1] = t1; + result.raw[2] = t2; + result.raw[3] = t3; + return result; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + VFromD result; + result.raw[0] = t0; + result.raw[1] = t1; + return result; +} + +// ------------------------------ Iota + +template , typename T2> +HWY_API 
VFromD Iota(D d, T2 first) { + VFromD v; + for (size_t i = 0; i < MaxLanes(d); ++i) { + v.raw[i] = AddWithWraparound(static_cast(first), i); + } + return v; +} + +// ================================================== LOGICAL + +// ------------------------------ Not +template +HWY_API Vec128 Not(Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + VFromD vu = BitCast(du, v); + for (size_t i = 0; i < N; ++i) { + vu.raw[i] = static_cast(~vu.raw[i]); + } + return BitCast(d, vu); +} + +// ------------------------------ And +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] &= bu.raw[i]; + } + return BitCast(d, au); +} +template +HWY_API Vec128 operator&(Vec128 a, Vec128 b) { + return And(a, b); +} + +// ------------------------------ AndNot +template +HWY_API Vec128 AndNot(Vec128 a, Vec128 b) { + return And(Not(a), b); +} + +// ------------------------------ Or +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] |= bu.raw[i]; + } + return BitCast(d, au); +} +template +HWY_API Vec128 operator|(Vec128 a, Vec128 b) { + return Or(a, b); +} + +// ------------------------------ Xor +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + auto au = BitCast(du, a); + auto bu = BitCast(du, b); + for (size_t i = 0; i < N; ++i) { + au.raw[i] ^= bu.raw[i]; + } + return BitCast(d, au); +} +template +HWY_API Vec128 operator^(Vec128 a, Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ Or3 +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { + 
return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return Or(And(mask, yes), AndNot(mask, no)); +} + +// ------------------------------ CopySign +template +HWY_API Vec128 CopySign(Vec128 magn, Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return BitwiseIfThenElse(SignBit(d), sign, magn); +} + +// ------------------------------ CopySignToAbs +template +HWY_API Vec128 CopySignToAbs(Vec128 abs, Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return OrAnd(abs, SignBit(d), sign); +} + +// ------------------------------ BroadcastSignBit +template +HWY_API Vec128 BroadcastSignBit(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = ScalarShr(v.raw[i], sizeof(T) * 8 - 1); + } + return v; +} + +// ------------------------------ Mask + +// v must be 0 or FF..FF. +template +HWY_API Mask128 MaskFromVec(Vec128 v) { + Mask128 mask; + CopySameSize(&v.raw, &mask.bits); + return mask; +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +template +HWY_API MFromD RebindMask(DTo /* tag */, MFrom mask) { + MFromD to; + CopySameSize(&mask.bits, &to.bits); + return to; +} + +template +VFromD VecFromMask(D /* tag */, MFromD mask) { + VFromD v; + CopySameSize(&mask.bits, &v.raw); + return v; +} + +template +uint64_t BitsFromMask(D d, MFromD mask) { + uint64_t bits = 0; + for (size_t i = 0; i < Lanes(d); ++i) { + bits |= mask.bits[i] ? (1ull << i) : 0; + } + return bits; +} + +template +HWY_API MFromD FirstN(D d, size_t n) { + MFromD m; + for (size_t i = 0; i < MaxLanes(d); ++i) { + m.bits[i] = MFromD::FromBool(i < n); + } + return m; +} + +// Returns mask ? yes : no. 
+template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + const DFromV d; + return IfVecThenElse(VecFromMask(d, mask), yes, no); +} + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + const DFromV d; + return IfVecThenElse(VecFromMask(d, mask), yes, Zero(d)); +} + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + const DFromV d; + return IfVecThenElse(VecFromMask(d, mask), Zero(d), no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + const DFromV d; + const RebindToSigned di; + const auto vi = BitCast(di, v); + + for (size_t i = 0; i < N; ++i) { + v.raw[i] = vi.raw[i] < 0 ? yes.raw[i] : no.raw[i]; + } + return v; +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(Mask128 m) { + const Simd d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask128 And(Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== SHIFTS + +// ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) + +template +HWY_API Vec128 ShiftLeft(Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + using TU = hwy::MakeUnsigned; + for (size_t i = 0; i < N; ++i) { + const TU raw_u = static_cast(v.raw[i]); + const 
auto shifted = raw_u << kBits; // separate line to avoid MSVC warning + v.raw[i] = static_cast(shifted); + } + return v; +} + +template +HWY_API Vec128 ShiftRight(Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + // Signed right shift is now guaranteed to be arithmetic (rounding toward + // negative infinity, i.e. shifting in the sign bit). + for (size_t i = 0; i < N; ++i) { + v.raw[i] = ScalarShr(v.raw[i], kBits); + } + + return v; +} + +// ------------------------------ RotateRight (ShiftRight) +template +HWY_API Vec128 RotateRight(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); +} + +// ------------------------------ ShiftLeftSame + +template +HWY_API Vec128 ShiftLeftSame(Vec128 v, int bits) { + for (size_t i = 0; i < N; ++i) { + const auto shifted = static_cast>(v.raw[i]) << bits; + v.raw[i] = static_cast(shifted); + } + return v; +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, int bits) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = ScalarShr(v.raw[i], bits); + } + + return v; +} + +// ------------------------------ Shl + +template +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + for (size_t i = 0; i < N; ++i) { + const auto shifted = static_cast>(v.raw[i]) + << bits.raw[i]; + v.raw[i] = static_cast(shifted); + } + return v; +} + +template +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = ScalarShr(v.raw[i], static_cast(bits.raw[i])); + } + + return v; +} + +// ================================================== ARITHMETIC + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Add(hwy::NonFloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + 
const uint64_t a64 = static_cast(a.raw[i]); + const uint64_t b64 = static_cast(b.raw[i]); + a.raw[i] = static_cast((a64 + b64) & static_cast(~T(0))); + } + return a; +} +template +HWY_INLINE Vec128 Sub(hwy::NonFloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + const uint64_t a64 = static_cast(a.raw[i]); + const uint64_t b64 = static_cast(b.raw[i]); + a.raw[i] = static_cast((a64 - b64) & static_cast(~T(0))); + } + return a; +} + +template +HWY_INLINE Vec128 Add(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] += b.raw[i]; + } + return a; +} + +template +HWY_INLINE Vec128 Sub(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] -= b.raw[i]; + } + return a; +} + +} // namespace detail + +template +HWY_API Vec128 operator-(Vec128 a, Vec128 b) { + return detail::Sub(hwy::IsFloatTag(), a, b); +} +template +HWY_API Vec128 operator+(Vec128 a, Vec128 b) { + return detail::Add(hwy::IsFloatTag(), a, b); +} + +// ------------------------------ SumsOf8 + +template +HWY_API Vec128 SumsOf8(Vec128 v) { + Vec128 sums; + for (size_t i = 0; i < N; ++i) { + sums.raw[i / 8] += v.raw[i]; + } + return sums; +} + +template +HWY_API Vec128 SumsOf8(Vec128 v) { + Vec128 sums; + for (size_t i = 0; i < N; ++i) { + sums.raw[i / 8] += v.raw[i]; + } + return sums; +} + +// ------------------------------ SaturatedAdd +template +HWY_API Vec128 SaturatedAdd(Vec128 a, Vec128 b) { + using TW = MakeSigned>; + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast(HWY_MIN( + HWY_MAX(hwy::LowestValue(), static_cast(a.raw[i]) + b.raw[i]), + hwy::HighestValue())); + } + return a; +} + +// ------------------------------ SaturatedSub +template +HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { + using TW = MakeSigned>; + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast(HWY_MIN( + HWY_MAX(hwy::LowestValue(), static_cast(a.raw[i]) - b.raw[i]), + hwy::HighestValue())); + } + return a; +} + 
+// ------------------------------ AverageRound + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +template +HWY_API Vec128 AverageRound(Vec128 a, Vec128 b) { + for (size_t i = 0; i < N; ++i) { + const T a_val = a.raw[i]; + const T b_val = b.raw[i]; + a.raw[i] = static_cast((a_val | b_val) - ScalarShr(a_val ^ b_val, 1)); + } + return a; +} + +// ------------------------------ Abs + +template +HWY_API Vec128 Abs(Vec128 a) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = ScalarAbs(a.raw[i]); + } + return a; +} + +// ------------------------------ Min/Max + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Min(hwy::NonFloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); + } + return a; +} +template +HWY_INLINE Vec128 Max(hwy::NonFloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); + } + return a; +} + +template +HWY_INLINE Vec128 Min(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + if (ScalarIsNaN(a.raw[i])) { + a.raw[i] = b.raw[i]; + } else if (ScalarIsNaN(b.raw[i])) { + // no change + } else { + a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); + } + } + return a; +} +template +HWY_INLINE Vec128 Max(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + if (ScalarIsNaN(a.raw[i])) { + a.raw[i] = b.raw[i]; + } else if (ScalarIsNaN(b.raw[i])) { + // no change + } else { + a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); + } + } + return a; +} + +} // namespace detail + +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return detail::Min(hwy::IsFloatTag(), a, b); +} + +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) 
{ + return detail::Max(hwy::IsFloatTag(), a, b); +} + +// ------------------------------ Neg + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_API Vec128 Neg(hwy::NonFloatTag /*tag*/, Vec128 v) { + const DFromV d; + return Zero(d) - v; +} + +template +HWY_API Vec128 Neg(hwy::FloatTag /*tag*/, Vec128 v) { + const DFromV d; + return Xor(v, SignBit(d)); +} + +template +HWY_API Vec128 Neg(hwy::SpecialTag /*tag*/, Vec128 v) { + const DFromV d; + return Xor(v, SignBit(d)); +} + +} // namespace detail + +template +HWY_API Vec128 Neg(Vec128 v) { + return detail::Neg(hwy::IsFloatTag(), v); +} + +// ------------------------------ Mul/Div + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Mul(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] *= b.raw[i]; + } + return a; +} + +template +HWY_INLINE Vec128 Mul(SignedTag /*tag*/, Vec128 a, Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast(static_cast(a.raw[i]) * + static_cast(b.raw[i])); + } + return a; +} + +template +HWY_INLINE Vec128 Mul(UnsignedTag /*tag*/, Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast(static_cast(a.raw[i]) * + static_cast(b.raw[i])); + } + return a; +} + +} // namespace detail + +// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*. +#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return detail::Mul(hwy::TypeTag(), a, b); +} + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = (b.raw[i] == T{0}) ? 0 : a.raw[i] / b.raw[i]; + } + return a; +} + +// Returns the upper sizeof(T)*8 bits of a * b in each lane. 
+template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + using TW = MakeWide; + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast( + (static_cast(a.raw[i]) * static_cast(b.raw[i])) >> + (sizeof(T) * 8)); + } + return a; +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + T hi; + Mul128(GetLane(a), GetLane(b), &hi); + return Set(Full64(), hi); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + T hi_0; + T hi_1; + + Mul128(GetLane(a), GetLane(b), &hi_0); + Mul128(ExtractLane(a, 1), ExtractLane(b, 1), &hi_1); + + return Dup128VecFromValues(Full128(), hi_0, hi_1); +} + +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + for (size_t i = 0; i < N; ++i) { + a.raw[i] = static_cast((a.raw[i] * b.raw[i] + 16384) >> 15); + } + return a; +} + +// Multiplies even lanes (0, 2, ..) and returns the double-wide result. +template +HWY_API Vec128, (N + 1) / 2> MulEven(Vec128 a, + Vec128 b) { + using TW = MakeWide; + Vec128 mul; + for (size_t i = 0; i < N; i += 2) { + const TW a_wide = a.raw[i]; + mul.raw[i / 2] = static_cast(a_wide * b.raw[i]); + } + return mul; +} + +// Multiplies odd lanes (1, 3, ..) and returns the double-wide result. +template +HWY_API Vec128, (N + 1) / 2> MulOdd(Vec128 a, + Vec128 b) { + using TW = MakeWide; + Vec128 mul; + for (size_t i = 0; i < N; i += 2) { + const TW a_wide = a.raw[i + 1]; + mul.raw[i / 2] = static_cast(a_wide * b.raw[i + 1]); + } + return mul; +} + +template +HWY_API Vec128 ApproximateReciprocal(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + // Zero inputs are allowed, but callers are responsible for replacing the + // return value with something else (typically using IfThenElse). This check + // avoids a ubsan error. The result is arbitrary. + v.raw[i] = (ScalarAbs(v.raw[i]) == 0.0f) ? 0.0f : 1.0f / v.raw[i]; + } + return v; +} + +// generic_ops takes care of integer T. 
+template +HWY_API Vec128 AbsDiff(Vec128 a, Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return mul * x + add; +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return add - mul * x; +} + +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return mul * x - sub; +} + +template +HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + const float half = v.raw[i] * 0.5f; + // Initial guess based on log2(f) + v.raw[i] = BitCastScalar(static_cast( + 0x5F3759DF - (BitCastScalar(v.raw[i]) >> 1))); + // One Newton-Raphson iteration + v.raw[i] = v.raw[i] * (1.5f - (half * v.raw[i] * v.raw[i])); + } + return v; +} + +namespace detail { + +static HWY_INLINE float ScalarSqrt(float v) { +#if defined(HWY_NO_LIBCXX) +#if HWY_COMPILER_GCC_ACTUAL + return __builtin_sqrt(v); +#else + uint32_t bits = BitCastScalar(v); + // Coarse approximation, letting the exponent LSB leak into the mantissa + bits = (1 << 29) + (bits >> 1) - (1 << 22); + return BitCastScalar(bits); +#endif // !HWY_COMPILER_GCC_ACTUAL +#else + return sqrtf(v); +#endif // !HWY_NO_LIBCXX +} +static HWY_INLINE double ScalarSqrt(double v) { +#if defined(HWY_NO_LIBCXX) +#if HWY_COMPILER_GCC_ACTUAL + return __builtin_sqrt(v); +#else + uint64_t bits = BitCastScalar(v); + // Coarse approximation, letting the exponent LSB leak into the mantissa + bits = (1ULL << 61) + (bits >> 1) - (1ULL << 51); + return BitCastScalar(bits); +#endif // !HWY_COMPILER_GCC_ACTUAL +#else + return sqrt(v); +#endif // HWY_NO_LIBCXX +} + +} // namespace detail + +template +HWY_API Vec128 Sqrt(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = 
detail::ScalarSqrt(v.raw[i]); + } + return v; +} + +// ------------------------------ Floating-point rounding + +template +HWY_API Vec128 Round(Vec128 v) { + using TI = MakeSigned; + const T k0 = ConvertScalarTo(0); + const Vec128 a = Abs(v); + for (size_t i = 0; i < N; ++i) { + if (!(a.raw[i] < MantissaEnd())) { // Huge or NaN + continue; + } + const T bias = ConvertScalarTo(v.raw[i] < k0 ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw[i] + bias); + if (rounded == 0) { + v.raw[i] = v.raw[i] < 0 ? ConvertScalarTo(-0) : k0; + continue; + } + const T rounded_f = ConvertScalarTo(rounded); + // Round to even + if ((rounded & 1) && + ScalarAbs(rounded_f - v.raw[i]) == ConvertScalarTo(0.5)) { + v.raw[i] = ConvertScalarTo(rounded - (v.raw[i] < k0 ? -1 : 1)); + continue; + } + v.raw[i] = rounded_f; + } + return v; +} + +// Round-to-nearest even. +template +HWY_API Vec128, N> NearestInt(Vec128 v) { + using TI = MakeSigned; + const T k0 = ConvertScalarTo(0); + + const Vec128 abs = Abs(v); + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + const bool signbit = ScalarSignBit(v.raw[i]); + + if (!(abs.raw[i] < MantissaEnd())) { // Huge or NaN + // Check if too large to cast or NaN + if (!(abs.raw[i] <= ConvertScalarTo(LimitsMax()))) { + ret.raw[i] = signbit ? LimitsMin() : LimitsMax(); + continue; + } + ret.raw[i] = static_cast(v.raw[i]); + continue; + } + const T bias = ConvertScalarTo(v.raw[i] < k0 ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw[i] + bias); + if (rounded == 0) { + ret.raw[i] = 0; + continue; + } + const T rounded_f = ConvertScalarTo(rounded); + // Round to even + if ((rounded & 1) && + ScalarAbs(rounded_f - v.raw[i]) == ConvertScalarTo(0.5)) { + ret.raw[i] = rounded - (signbit ? 
-1 : 1); + continue; + } + ret.raw[i] = rounded; + } + return ret; +} + +template +HWY_API VFromD DemoteToNearestInt(DI32 /*di32*/, + VFromD> v) { + using T = double; + using TI = int32_t; + const T k0 = ConvertScalarTo(0); + + constexpr size_t N = HWY_MAX_LANES_D(DI32); + + const VFromD> abs = Abs(v); + VFromD ret; + for (size_t i = 0; i < N; ++i) { + const bool signbit = ScalarSignBit(v.raw[i]); + + // Check if too large to cast or NaN + if (!(abs.raw[i] <= ConvertScalarTo(LimitsMax()))) { + ret.raw[i] = signbit ? LimitsMin() : LimitsMax(); + continue; + } + + const T bias = ConvertScalarTo(v.raw[i] < k0 ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw[i] + bias); + if (rounded == 0) { + ret.raw[i] = 0; + continue; + } + const T rounded_f = ConvertScalarTo(rounded); + // Round to even + if ((rounded & 1) && + ScalarAbs(rounded_f - v.raw[i]) == ConvertScalarTo(0.5)) { + ret.raw[i] = rounded - (signbit ? -1 : 1); + continue; + } + ret.raw[i] = rounded; + } + return ret; +} + +template +HWY_API Vec128 Trunc(Vec128 v) { + using TI = MakeSigned; + const Vec128 abs = Abs(v); + for (size_t i = 0; i < N; ++i) { + if (!(abs.raw[i] <= MantissaEnd())) { // Huge or NaN + continue; + } + const TI truncated = static_cast(v.raw[i]); + if (truncated == 0) { + v.raw[i] = v.raw[i] < 0 ? -T{0} : T{0}; + continue; + } + v.raw[i] = static_cast(truncated); + } + return v; +} + +// Toward +infinity, aka ceiling +template +Vec128 Ceil(Vec128 v) { + constexpr int kMantissaBits = MantissaBits(); + using Bits = MakeUnsigned; + const Bits kExponentMask = MaxExponentField(); + const Bits kMantissaMask = MantissaMask(); + const Bits kBias = kExponentMask / 2; + + for (size_t i = 0; i < N; ++i) { + const bool positive = v.raw[i] > Float(0.0); + + Bits bits = BitCastScalar(v.raw[i]); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) continue; + // |v| <= 1 => 0 or 1. 
+ if (exponent < 0) { + v.raw[i] = positive ? Float{1} : Float{-0.0}; + continue; + } + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) continue; + + // Clear fractional bits and round up + if (positive) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + v.raw[i] = BitCastScalar(bits); + } + return v; +} + +// Toward -infinity, aka floor +template +Vec128 Floor(Vec128 v) { + constexpr int kMantissaBits = MantissaBits(); + using Bits = MakeUnsigned; + const Bits kExponentMask = MaxExponentField(); + const Bits kMantissaMask = MantissaMask(); + const Bits kBias = kExponentMask / 2; + + for (size_t i = 0; i < N; ++i) { + const bool negative = v.raw[i] < Float(0.0); + + Bits bits = BitCastScalar(v.raw[i]); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) continue; + // |v| <= 1 => -1 or 0. + if (exponent < 0) { + v.raw[i] = negative ? Float(-1.0) : Float(0.0); + continue; + } + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) continue; + + // Clear fractional bits and round down + if (negative) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + v.raw[i] = BitCastScalar(bits); + } + return v; +} + +// ------------------------------ Floating-point classification + +template +HWY_API Mask128 IsNaN(Vec128 v) { + Mask128 ret; + for (size_t i = 0; i < N; ++i) { + // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. 
+ ret.bits[i] = Mask128::FromBool(ScalarIsNaN(v.raw[i])); + } + return ret; +} + +// ================================================== COMPARE + +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] == b.raw[i]); + } + return m; +} + +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] != b.raw[i]); + } + return m; +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_API Mask128 operator<(Vec128 a, Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] < b.raw[i]); + } + return m; +} +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] > b.raw[i]); + } + return m; +} + +template +HWY_API Mask128 operator<=(Vec128 a, Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] <= b.raw[i]); + } + return m; +} +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + Mask128 m; + for (size_t i = 0; i < N; ++i) { + m.bits[i] = Mask128::FromBool(a.raw[i] >= b.raw[i]); + } + return m; +} + +// ------------------------------ Lt128 + +// Only makes sense for full vectors of u64. 
+template +HWY_API MFromD Lt128(D /* tag */, Vec128 a, Vec128 b) { + const bool lt = + (a.raw[1] < b.raw[1]) || (a.raw[1] == b.raw[1] && a.raw[0] < b.raw[0]); + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(lt); + return ret; +} + +template +HWY_API MFromD Lt128Upper(D /* tag */, Vec128 a, + Vec128 b) { + const bool lt = a.raw[1] < b.raw[1]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(lt); + return ret; +} + +// ------------------------------ Eq128 + +// Only makes sense for full vectors of u64. +template +HWY_API MFromD Eq128(D /* tag */, Vec128 a, Vec128 b) { + const bool eq = a.raw[1] == b.raw[1] && a.raw[0] == b.raw[0]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(eq); + return ret; +} + +template +HWY_API Mask128 Ne128(D /* tag */, Vec128 a, + Vec128 b) { + const bool ne = a.raw[1] != b.raw[1] || a.raw[0] != b.raw[0]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(ne); + return ret; +} + +template +HWY_API MFromD Eq128Upper(D /* tag */, Vec128 a, + Vec128 b) { + const bool eq = a.raw[1] == b.raw[1]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(eq); + return ret; +} + +template +HWY_API MFromD Ne128Upper(D /* tag */, Vec128 a, + Vec128 b) { + const bool ne = a.raw[1] != b.raw[1]; + Mask128 ret; + ret.bits[0] = ret.bits[1] = Mask128::FromBool(ne); + return ret; +} + +// ------------------------------ Min128, Max128 (Lt128) + +template +HWY_API VFromD Min128(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template +HWY_API VFromD Max128(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template +HWY_API VFromD Min128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_API VFromD Max128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// ================================================== MEMORY + +// ------------------------------ Load + 
+template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT aligned) { + VFromD v; + CopyBytes(aligned, v.raw); // copy from array + return v; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElse(m, LoadU(d, p), v); +} + +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +// In some use cases, "load single lane" is sufficient; otherwise avoid this. +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT aligned) { + return Load(d, aligned); +} + +#ifdef HWY_NATIVE_LOAD_N +#undef HWY_NATIVE_LOAD_N +#else +#define HWY_NATIVE_LOAD_N +#endif + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t max_lanes_to_load) { + VFromD v = Zero(d); + const size_t N = Lanes(d); + const size_t num_of_lanes_to_load = HWY_MIN(max_lanes_to_load, N); + CopyBytes(p, v.raw, num_of_lanes_to_load * sizeof(TFromD)); + return v; +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t max_lanes_to_load) { + VFromD v = no; + const size_t N = Lanes(d); + const size_t num_of_lanes_to_load = HWY_MIN(max_lanes_to_load, N); + CopyBytes(p, v.raw, num_of_lanes_to_load * sizeof(TFromD)); + return v; +} + +// ------------------------------ Store + +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + CopyBytes(v.raw, aligned); // copy to array +} + +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + Store(v, d, p); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (m.bits[i]) p[i] = v.raw[i]; + } +} + +#ifdef HWY_NATIVE_STORE_N +#undef HWY_NATIVE_STORE_N +#else +#define HWY_NATIVE_STORE_N +#endif + +template +HWY_API void StoreN(VFromD v, D d, 
TFromD* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const size_t N = Lanes(d); + const size_t num_of_lanes_to_store = HWY_MIN(max_lanes_to_store, N); + CopyBytes(v.raw, p, num_of_lanes_to_store * sizeof(TFromD)); +} + +// ================================================== COMBINE + +template +HWY_API Vec128 LowerHalf(Vec128 v) { + Vec128 ret; + CopyBytes(v.raw, ret.raw); + return ret; +} + +template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return LowerHalf(v); +} + +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + VFromD ret; + CopyBytes(&v.raw[MaxLanes(d)], ret.raw); + return ret; +} + +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> v) { + const Half dh; + VFromD ret; // zero-initialized + CopyBytes(v.raw, ret.raw); + return ret; +} + +template >> +HWY_API VFromD Combine(D d, VH hi_half, VH lo_half) { + const Half dh; + VFromD ret; + CopyBytes(lo_half.raw, &ret.raw[0]); + CopyBytes(hi_half.raw, &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + CopyBytes(lo.raw, &ret.raw[0]); + CopyBytes(hi.raw, &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + CopyBytes(&lo.raw[MaxLanes(dh)], &ret.raw[0]); + CopyBytes(&hi.raw[MaxLanes(dh)], &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + CopyBytes(&lo.raw[MaxLanes(dh)], &ret.raw[0]); + CopyBytes(hi.raw, &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + CopyBytes(lo.raw, &ret.raw[0]); + CopyBytes(&hi.raw[MaxLanes(dh)], &ret.raw[MaxLanes(dh)]); + return ret; +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + for (size_t i = 0; i < MaxLanes(dh); ++i) { + 
ret.raw[i] = lo.raw[2 * i]; + } + for (size_t i = 0; i < MaxLanes(dh); ++i) { + ret.raw[MaxLanes(dh) + i] = hi.raw[2 * i]; + } + return ret; +} + +// 2023-11-23: workaround for incorrect codegen (reduction_test fails for +// SumsOf2 because PromoteOddTo, which uses ConcatOdd, returns zero). +#if HWY_ARCH_RISCV && HWY_TARGET == HWY_EMU128 && HWY_COMPILER_CLANG +#define HWY_EMU128_CONCAT_INLINE HWY_NOINLINE +#else +#define HWY_EMU128_CONCAT_INLINE HWY_API +#endif + +template +HWY_EMU128_CONCAT_INLINE VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const Half dh; + VFromD ret; + for (size_t i = 0; i < MaxLanes(dh); ++i) { + ret.raw[i] = lo.raw[2 * i + 1]; + } + for (size_t i = 0; i < MaxLanes(dh); ++i) { + ret.raw[MaxLanes(dh) + i] = hi.raw[2 * i + 1]; + } + return ret; +} + +// ------------------------------ CombineShiftRightBytes +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + VFromD ret; + const uint8_t* HWY_RESTRICT lo8 = + reinterpret_cast(lo.raw); + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + CopyBytes(lo8 + kBytes, ret8); + CopyBytes(hi.raw, ret8 + d.MaxBytes() - kBytes); + return ret; +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API VFromD ShiftLeftBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + VFromD ret; + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + ZeroBytes(ret8); + CopyBytes(v.raw, ret8 + kBytes); + return ret; +} + +template +HWY_API Vec128 ShiftLeftBytes(Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +template > +HWY_API VFromD ShiftLeftLanes(D d, VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API VFromD ShiftRightBytes(D d, VFromD v) { + static_assert(0 <= 
kBytes && kBytes <= 16, "Invalid kBytes"); + VFromD ret; + const uint8_t* HWY_RESTRICT v8 = + reinterpret_cast(v.raw); + uint8_t* HWY_RESTRICT ret8 = + reinterpret_cast(ret.raw); + CopyBytes(v8 + kBytes, ret8); + ZeroBytes(ret8 + d.MaxBytes() - kBytes); + return ret; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API VFromD ShiftRightLanes(D d, VFromD v) { + const Repartition d8; + constexpr size_t kBytes = kLanes * sizeof(TFromD); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ Tuples, PromoteEvenTo/PromoteOddTo +#include "hwy/ops/inside-inl.h" + +// ------------------------------ LoadInterleaved2/3/4 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. +// We implement those here because scalar code is likely faster than emulation +// via shuffles. +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +// Same for Load/StoreInterleaved of special floats. 
+#ifdef HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#endif + +template > +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + alignas(16) T buf0[MaxLanes(d)]; + alignas(16) T buf1[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); +} + +template > +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + alignas(16) T buf0[MaxLanes(d)]; + alignas(16) T buf1[MaxLanes(d)]; + alignas(16) T buf2[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + buf2[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); + v2 = Load(d, buf2); +} + +template > +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + alignas(16) T buf0[MaxLanes(d)]; + alignas(16) T buf1[MaxLanes(d)]; + alignas(16) T buf2[MaxLanes(d)]; + alignas(16) T buf3[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + buf0[i] = *unaligned++; + buf1[i] = *unaligned++; + buf2[i] = *unaligned++; + buf3[i] = *unaligned++; + } + v0 = Load(d, buf0); + v1 = Load(d, buf1); + v2 = Load(d, buf2); + v3 = Load(d, buf3); +} + +// ------------------------------ StoreInterleaved2/3/4 + +template +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + TFromD* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + } +} + +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + *unaligned++ = v2.raw[i]; + } +} + +template 
+HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, + TFromD* HWY_RESTRICT unaligned) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + *unaligned++ = v0.raw[i]; + *unaligned++ = v1.raw[i]; + *unaligned++ = v2.raw[i]; + *unaligned++ = v3.raw[i]; + } +} + +// ------------------------------ Stream +template +HWY_API void Stream(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ------------------------------ Scatter in generic_ops-inl.h +// ------------------------------ Gather in generic_ops-inl.h + +// ================================================== CONVERT + +// ConvertTo and DemoteTo with floating-point input and integer output truncate +// (rounding toward zero). + +namespace detail { + +template +HWY_INLINE ToT CastValueForF2IConv(FromT val) { + // Prevent ubsan errors when converting float to narrower integer + + using FromTU = MakeUnsigned; + using ToTU = MakeUnsigned; + + constexpr unsigned kMaxExpField = + static_cast(MaxExponentField()); + constexpr unsigned kExpBias = kMaxExpField >> 1; + constexpr unsigned kMinOutOfRangeExpField = static_cast(HWY_MIN( + kExpBias + sizeof(ToT) * 8 - static_cast(IsSigned()), + kMaxExpField)); + + // If ToT is signed, compare only the exponent bits of val against + // kMinOutOfRangeExpField. + // + // Otherwise, if ToT is unsigned, compare the sign bit plus exponent bits of + // val against kMinOutOfRangeExpField as a negative value is outside of the + // range of an unsigned integer type. + const FromT val_to_compare = + static_cast(IsSigned() ? ScalarAbs(val) : val); + + // val is within the range of ToT if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is less + // than kMinOutOfRangeExpField + // + // Otherwise, val is either outside of the range of ToT or equal to + // LimitsMin() if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is greater + // than or equal to kMinOutOfRangeExpField. 
+ + return (static_cast(BitCastScalar(val_to_compare) >> + MantissaBits()) < kMinOutOfRangeExpField) + ? static_cast(val) + : static_cast(static_cast(LimitsMax()) + + static_cast(ScalarSignBit(val))); +} + +template +HWY_INLINE ToT CastValueForPromoteTo(ToTypeTag /* to_type_tag */, FromT val) { + return ConvertScalarTo(val); +} + +template +HWY_INLINE ToT CastValueForPromoteTo(hwy::SignedTag /*to_type_tag*/, + float val) { + return CastValueForF2IConv(val); +} + +template +HWY_INLINE ToT CastValueForPromoteTo(hwy::UnsignedTag /*to_type_tag*/, + float val) { + return CastValueForF2IConv(val); +} +// If val is within the range of ToT, CastValueForInRangeF2IConv(val) +// returns static_cast(val) +// +// Otherwise, CastValueForInRangeF2IConv(val) returns an +// implementation-defined result if val is not within the range of ToT. +template +HWY_INLINE ToT CastValueForInRangeF2IConv(FromT val) { + // Prevent ubsan errors when converting float to narrower integer + + using FromTU = MakeUnsigned; + + constexpr unsigned kMaxExpField = + static_cast(MaxExponentField()); + constexpr unsigned kExpBias = kMaxExpField >> 1; + constexpr unsigned kMinOutOfRangeExpField = static_cast(HWY_MIN( + kExpBias + sizeof(ToT) * 8 - static_cast(IsSigned()), + kMaxExpField)); + + // If ToT is signed, compare only the exponent bits of val against + // kMinOutOfRangeExpField. + // + // Otherwise, if ToT is unsigned, compare the sign bit plus exponent bits of + // val against kMinOutOfRangeExpField as a negative value is outside of the + // range of an unsigned integer type. + const FromT val_to_compare = + static_cast(IsSigned() ? ScalarAbs(val) : val); + + // val is within the range of ToT if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is less + // than kMinOutOfRangeExpField + // + // Otherwise, val is either outside of the range of ToT or equal to + // LimitsMin() if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is greater + // than or equal to kMinOutOfRangeExpField. 
+ + return (static_cast(BitCastScalar(val_to_compare) >> + MantissaBits()) < kMinOutOfRangeExpField) + ? static_cast(val) + : static_cast(LimitsMin()); +} + +} // namespace detail + +template +HWY_API VFromD PromoteTo(DTo d, Vec128 from) { + static_assert(sizeof(TFromD) > sizeof(TFrom), "Not promoting"); + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); ++i) { + // For bits Y > X, floatX->floatY and intX->intY are always representable. + ret.raw[i] = detail::CastValueForPromoteTo>( + hwy::TypeTag>(), from.raw[i]); + } + return ret; +} + +#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD PromoteInRangeTo(D64 d64, VFromD> v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d64); ++i) { + ret.raw[i] = detail::CastValueForInRangeF2IConv>(v.raw[i]); + } + return ret; +} + +// MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(TFrom) is here, +// so we overload for TFrom=double and ToT={float,int32_t}. +template +HWY_API VFromD DemoteTo(D d, VFromD> from) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); ++i) { + // Prevent ubsan errors when converting float to narrower integer/float + if (ScalarIsInf(from.raw[i]) || + ScalarAbs(from.raw[i]) > static_cast(HighestValue())) { + ret.raw[i] = ScalarSignBit(from.raw[i]) ? 
LowestValue() + : HighestValue(); + continue; + } + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} +template +HWY_API VFromD DemoteTo(D d, VFromD> from) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); ++i) { + // Prevent ubsan errors when converting double to narrower integer/int32_t + ret.raw[i] = detail::CastValueForF2IConv>(from.raw[i]); + } + return ret; +} + +template )> +HWY_API VFromD DemoteTo(DTo /* tag */, Vec128 from) { + using TTo = TFromD; + static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); + + VFromD ret; + for (size_t i = 0; i < N; ++i) { + // Int to int: choose closest value in ToT to `from` (avoids UB) + from.raw[i] = + HWY_MIN(HWY_MAX(LimitsMin(), from.raw[i]), LimitsMax()); + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +// Disable the default unsigned to signed DemoteTo/ReorderDemote2To +// implementations in generic_ops-inl.h on EMU128 as the EMU128 target has +// target-specific implementations of the unsigned to signed DemoteTo and +// ReorderDemote2To ops + +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the V template +// argument +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) \ + hwy::EnableIf()>* = nullptr + +template +HWY_API VFromD DemoteTo(DTo /* tag */, Vec128 from) { + using TTo = TFromD; + static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); + + const auto max = static_cast>(LimitsMax()); + + VFromD ret; + for (size_t i = 0; i < N; ++i) { + // Int to int: choose closest value in ToT to `from` (avoids UB) + ret.raw[i] = static_cast(HWY_MIN(from.raw[i], max)); + } + return ret; +} + +template +HWY_API VFromD DemoteTo(DTo /* tag */, Vec128 from) { + using TTo = TFromD; + static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); + + VFromD 
ret; + for (size_t i = 0; i < N; ++i) { + // int64_t/uint64_t to float: okay to cast to float as an int64_t/uint64_t + // value is always within the range of a float + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +template +HWY_API VFromD ReorderDemote2To(DBF16 dbf16, VF32 a, VF32 b) { + const Repartition du32; + const VFromD b_in_lower = ShiftRight<16>(BitCast(du32, b)); + // Avoid OddEven - we want the upper half of `a` even on big-endian systems. + const VFromD a_mask = Set(du32, 0xFFFF0000); + return BitCast(dbf16, IfVecThenElse(a_mask, BitCast(du32, a), b_in_lower)); +} + +template ), class V, + HWY_IF_SIGNED_V(V), HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const RepartitionToWide dw; + const size_t NW = Lanes(dw); + using TN = TFromD; + const TN min = LimitsMin(); + const TN max = LimitsMax(); + VFromD ret; + for (size_t i = 0; i < NW; ++i) { + ret.raw[i] = static_cast(HWY_MIN(HWY_MAX(min, a.raw[i]), max)); + } + for (size_t i = 0; i < NW; ++i) { + ret.raw[NW + i] = static_cast(HWY_MIN(HWY_MAX(min, b.raw[i]), max)); + } + return ret; +} + +template ) * 2), + HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const RepartitionToWide dw; + const size_t NW = Lanes(dw); + using TN = TFromD; + using TN_U = MakeUnsigned; + const TN_U max = static_cast(LimitsMax()); + VFromD ret; + for (size_t i = 0; i < NW; ++i) { + ret.raw[i] = static_cast(HWY_MIN(a.raw[i], max)); + } + for (size_t i = 0; i < NW; ++i) { + ret.raw[NW + i] = static_cast(HWY_MIN(b.raw[i], max)); + } + return ret; +} + +template ), class V, + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { + return ReorderDemote2To(dn, a, b); +} + +template ), + HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API 
VFromD OrderedDemote2To(DN dn, V a, V b) { + const size_t NW = Lanes(dn) / 2; + using TN = TFromD; + VFromD ret; + for (size_t i = 0; i < NW; ++i) { + ret.raw[i] = ConvertScalarTo(a.raw[i]); + } + for (size_t i = 0; i < NW; ++i) { + ret.raw[NW + i] = ConvertScalarTo(b.raw[i]); + } + return ret; +} + +namespace detail { + +HWY_INLINE void StoreU16ToF16(const uint16_t val, + hwy::float16_t* HWY_RESTRICT to) { + CopySameSize(&val, to); +} + +HWY_INLINE uint16_t U16FromF16(const hwy::float16_t* HWY_RESTRICT from) { + uint16_t bits16; + CopySameSize(from, &bits16); + return bits16; +} + +} // namespace detail + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = F32FromBF16(v.raw[i]); + } + return ret; +} + +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = BF16FromF32(v.raw[i]); + } + return ret; +} + +#ifdef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD DemoteInRangeTo(D32 d32, VFromD> v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d32); ++i) { + ret.raw[i] = detail::CastValueForInRangeF2IConv>(v.raw[i]); + } + return ret; +} + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_API VFromD ConvertTo(hwy::FloatTag /*tag*/, DTo /*tag*/, + Vec128 from) { + using ToT = TFromD; + static_assert(sizeof(ToT) == sizeof(TFrom), "Should have same size"); + VFromD ret; + constexpr size_t N = HWY_MAX_LANES_D(DTo); + + for (size_t i = 0; i < N; ++i) { + // float## -> int##: return closest representable value + ret.raw[i] = CastValueForF2IConv(from.raw[i]); + } + return ret; +} + +template +HWY_API VFromD 
ConvertTo(hwy::NonFloatTag /*tag*/, DTo /* tag */, + Vec128 from) { + using ToT = TFromD; + static_assert(sizeof(ToT) == sizeof(TFrom), "Should have same size"); + VFromD ret; + constexpr size_t N = HWY_MAX_LANES_D(DTo); + for (size_t i = 0; i < N; ++i) { + // int## -> float##: no check needed + ret.raw[i] = static_cast(from.raw[i]); + } + return ret; +} + +} // namespace detail + +template +HWY_API VFromD ConvertTo(DTo d, Vec128 from) { + return detail::ConvertTo(hwy::IsFloatTag(), d, from); +} + +#ifdef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#undef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#else +#define HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#endif + +template +HWY_API VFromD ConvertInRangeTo(DI di, VFromD> v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(di); i++) { + ret.raw[i] = detail::CastValueForInRangeF2IConv>(v.raw[i]); + } + return ret; +} + +template +HWY_API Vec128 U8FromU32(Vec128 v) { + return DemoteTo(Simd(), v); +} + +// ------------------------------ Truncations + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFF); + } + return ret; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFFFF); + } + return ret; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFFFFFFFFu); + } + return ret; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFF); + } + return ret; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFFFF); + } + return ret; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + VFromD ret; + for (size_t i = 0; i < N; 
++i) { + ret.raw[i] = static_cast(v.raw[i] & 0xFF); + } + return ret; +} + +#ifdef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#undef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#else +#define HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#endif + +template ) * 2), + HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedTruncate2To(DN dn, V a, V b) { + const RepartitionToWide dw; + const size_t NW = Lanes(dw); + using TW = TFromD; + using TN = TFromD; + VFromD ret; + constexpr TW max_val{LimitsMax()}; + + for (size_t i = 0; i < NW; ++i) { + ret.raw[i] = static_cast(a.raw[i] & max_val); + } + for (size_t i = 0; i < NW; ++i) { + ret.raw[NW + i] = static_cast(b.raw[i] & max_val); + } + return ret; +} + +// ================================================== SWIZZLE + +template +HWY_API T GetLane(Vec128 v) { + return v.raw[0]; +} + +template +HWY_API Vec128 InsertLane(Vec128 v, size_t i, T t) { + v.raw[i] = t; + return v; +} + +template +HWY_API T ExtractLane(Vec128 v, size_t i) { + return v.raw[i]; +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + for (size_t i = 0; i < N; i += 2) { + v.raw[i + 1] = v.raw[i]; + } + return v; +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + for (size_t i = 0; i < N; i += 2) { + v.raw[i] = v.raw[i + 1]; + } + return v; +} + +template +HWY_API Vec128 OddEven(Vec128 odd, Vec128 even) { + for (size_t i = 0; i < N; i += 2) { + odd.raw[i] = even.raw[i]; + } + return odd; +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + constexpr size_t N = HWY_MAX_LANES_D(D); + for (size_t i = 1; i < N; i += 2) { + a.raw[i] = b.raw[i - 1]; + } + return a; +} + +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + constexpr size_t N = HWY_MAX_LANES_D(D); + for (size_t i = 1; i < N; i += 2) { + b.raw[i - 1] = a.raw[i]; + } + return b; +} + +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec128 
SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ InterleaveEvenBlocks +template > +HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveOddBlocks +template > +HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ InterleaveLowerBlocks +template > +HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveUpperBlocks +template > +HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. +template +struct Indices128 { + MakeSigned raw[N]; +}; + +template +HWY_API Indices128, N> IndicesFromVec(D d, Vec128 vec) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index/lane size must match"); + Indices128, N> ret; + CopyBytes(vec.raw, ret.raw); + return ret; +} + +template +HWY_API Indices128, HWY_MAX_LANES_D(D)> SetTableIndices( + D d, const TI* idx) { + return IndicesFromVec(d, LoadU(Rebind(), idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + ret.raw[i] = v.raw[idx.raw[i]]; + } + return ret; +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + using TI = MakeSigned; + Vec128 ret; + constexpr TI kVecLaneIdxMask = static_cast(N - 1); + for (size_t i = 0; i < N; ++i) { + const auto src_idx = idx.raw[i]; + const auto masked_src_lane_idx = src_idx & kVecLaneIdxMask; + ret.raw[i] = (src_idx < static_cast(N)) ? 
a.raw[masked_src_lane_idx] + : b.raw[masked_src_lane_idx]; + } + return ret; +} + +// ------------------------------ ReverseBlocks +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return v; // Single block: no change +} + +// ------------------------------ Reverse + +template +HWY_API VFromD Reverse(D d, VFromD v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); ++i) { + ret.raw[i] = v.raw[MaxLanes(d) - 1 - i]; + } + return ret; +} + +// Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. +#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +template +HWY_API VFromD Reverse2(D d, VFromD v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); i += 2) { + ret.raw[i + 0] = v.raw[i + 1]; + ret.raw[i + 1] = v.raw[i + 0]; + } + return ret; +} + +template +HWY_API VFromD Reverse4(D d, VFromD v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); i += 4) { + ret.raw[i + 0] = v.raw[i + 3]; + ret.raw[i + 1] = v.raw[i + 2]; + ret.raw[i + 2] = v.raw[i + 1]; + ret.raw[i + 3] = v.raw[i + 0]; + } + return ret; +} + +template +HWY_API VFromD Reverse8(D d, VFromD v) { + VFromD ret; + for (size_t i = 0; i < MaxLanes(d); i += 8) { + ret.raw[i + 0] = v.raw[i + 7]; + ret.raw[i + 1] = v.raw[i + 6]; + ret.raw[i + 2] = v.raw[i + 5]; + ret.raw[i + 3] = v.raw[i + 4]; + ret.raw[i + 4] = v.raw[i + 3]; + ret.raw[i + 5] = v.raw[i + 2]; + ret.raw[i + 6] = v.raw[i + 1]; + ret.raw[i + 7] = v.raw[i + 0]; + } + return ret; +} + +// ------------------------------ SlideUpLanes + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { + VFromD ret = Zero(d); + constexpr size_t N = HWY_MAX_LANES_D(D); + const size_t clamped_amt = HWY_MIN(amt, N); + CopyBytes(v.raw, ret.raw + clamped_amt, + (N - clamped_amt) * sizeof(TFromD)); + return ret; +} + +// ------------------------------ SlideDownLanes + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { + VFromD ret = Zero(d); + 
constexpr size_t N = HWY_MAX_LANES_D(D); + const size_t clamped_amt = HWY_MIN(amt, N); + CopyBytes(v.raw + clamped_amt, ret.raw, + (N - clamped_amt) * sizeof(TFromD)); + return ret; +} + +// ================================================== BLOCKWISE + +// ------------------------------ Shuffle* + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec128 Shuffle2301(Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Reverse2(DFromV(), v); +} + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle1032(Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit"); + Vec128 ret; + ret.raw[3] = v.raw[1]; + ret.raw[2] = v.raw[0]; + ret.raw[1] = v.raw[3]; + ret.raw[0] = v.raw[2]; + return ret; +} +template +HWY_API Vec128 Shuffle01(Vec128 v) { + static_assert(sizeof(T) == 8, "Only for 64-bit"); + return Reverse2(DFromV(), v); +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(Vec128 v) { + Vec128 ret; + ret.raw[3] = v.raw[0]; + ret.raw[2] = v.raw[3]; + ret.raw[1] = v.raw[2]; + ret.raw[0] = v.raw[1]; + return ret; +} + +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(Vec128 v) { + Vec128 ret; + ret.raw[3] = v.raw[2]; + ret.raw[2] = v.raw[1]; + ret.raw[1] = v.raw[0]; + ret.raw[0] = v.raw[3]; + return ret; +} + +template +HWY_API Vec128 Shuffle0123(Vec128 v) { + return Reverse4(DFromV(), v); +} + +// ------------------------------ Broadcast +template +HWY_API Vec128 Broadcast(Vec128 v) { + for (size_t i = 0; i < N; ++i) { + v.raw[i] = v.raw[kLane]; + } + return v; +} + +// ------------------------------ TableLookupBytes, TableLookupBytesOr0 + +template +HWY_API Vec128 TableLookupBytes(Vec128 v, + Vec128 indices) { + const uint8_t* HWY_RESTRICT v_bytes = + reinterpret_cast(v.raw); + const uint8_t* HWY_RESTRICT idx_bytes = + reinterpret_cast(indices.raw); + Vec128 ret; + uint8_t* HWY_RESTRICT ret_bytes = + reinterpret_cast(ret.raw); + for (size_t 
i = 0; i < NI * sizeof(TI); ++i) { + const size_t idx = idx_bytes[i]; + // Avoid out of bounds reads. + ret_bytes[i] = idx < sizeof(T) * N ? v_bytes[idx] : 0; + } + return ret; +} + +template +HWY_API Vec128 TableLookupBytesOr0(Vec128 v, + Vec128 indices) { + // Same as TableLookupBytes, which already returns 0 if out of bounds. + return TableLookupBytes(v, indices); +} + +// ------------------------------ InterleaveLower/InterleaveUpper + +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + Vec128 ret; + for (size_t i = 0; i < N / 2; ++i) { + ret.raw[2 * i + 0] = a.raw[i]; + ret.raw[2 * i + 1] = b.raw[i]; + } + return ret; +} + +// Additional overload for the optional tag. +template +HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const Half dh; + VFromD ret; + for (size_t i = 0; i < MaxLanes(dh); ++i) { + ret.raw[2 * i + 0] = a.raw[MaxLanes(dh) + i]; + ret.raw[2 * i + 1] = b.raw[MaxLanes(dh) + i]; + } + return ret; +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. 
+template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== MASK + +template +HWY_API bool AllFalse(D d, MFromD mask) { + typename MFromD::Raw or_sum = 0; + for (size_t i = 0; i < MaxLanes(d); ++i) { + or_sum |= mask.bits[i]; + } + return or_sum == 0; +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + constexpr uint64_t kAll = LimitsMax::Raw>(); + uint64_t and_sum = kAll; + for (size_t i = 0; i < MaxLanes(d); ++i) { + and_sum &= mask.bits[i]; + } + return and_sum == kAll; +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + MFromD m; + for (size_t i = 0; i < MaxLanes(d); ++i) { + const size_t bit = size_t{1} << (i & 7); + const size_t idx_byte = i >> 3; + m.bits[i] = MFromD::FromBool((bits[idx_byte] & bit) != 0); + } + return m; +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + MFromD m; + for (size_t i = 0; i < MaxLanes(d); ++i) { + m.bits[i] = MFromD::FromBool(((mask_bits >> i) & 1u) != 0); + } + return m; +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + bits[0] = 0; + if (MaxLanes(d) > 8) bits[1] = 0; // MaxLanes(d) <= 16, so max two bytes + for (size_t i = 0; i < MaxLanes(d); ++i) { + const size_t bit = size_t{1} << (i & 7); + const size_t idx_byte = i >> 3; + if (mask.bits[i]) { + bits[idx_byte] = static_cast(bits[idx_byte] | bit); + } + } + return MaxLanes(d) > 8 ? 
2 : 1; +} + +template +HWY_API size_t CountTrue(D d, MFromD mask) { + size_t count = 0; + for (size_t i = 0; i < MaxLanes(d); ++i) { + count += mask.bits[i] != 0; + } + return count; +} + +template +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (mask.bits[i] != 0) return i; + } + HWY_DASSERT(false); + return 0; +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (mask.bits[i] != 0) return static_cast(i); + } + return intptr_t{-1}; +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { + for (intptr_t i = static_cast(MaxLanes(d) - 1); i >= 0; i--) { + if (mask.bits[i] != 0) return static_cast(i); + } + HWY_DASSERT(false); + return 0; +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + for (intptr_t i = static_cast(MaxLanes(d) - 1); i >= 0; i--) { + if (mask.bits[i] != 0) return i; + } + return intptr_t{-1}; +} + +// ------------------------------ Compress + +template +struct CompressIsPartition { + enum { value = (sizeof(T) != 1) }; +}; + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + size_t count = 0; + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + for (size_t i = 0; i < N; ++i) { + if (!mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + HWY_DASSERT(count == N); + return ret; +} + +// ------------------------------ Expand + +// Could also just allow generic_ops-inl.h to implement these, but use our +// simple implementation below to ensure the test is correct. 
+#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +template +HWY_API Vec128 Expand(Vec128 v, const Mask128 mask) { + size_t in_pos = 0; + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + ret.raw[i] = v.raw[in_pos++]; + } else { + ret.raw[i] = ConvertScalarTo(0); + } + } + return ret; +} + +// ------------------------------ LoadExpand + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + size_t in_pos = 0; + VFromD ret; + for (size_t i = 0; i < Lanes(d); ++i) { + if (mask.bits[i]) { + ret.raw[i] = unaligned[in_pos++]; + } else { + ret.raw[i] = TFromD(); // zero, also works for float16_t + } + } + return ret; +} + +// ------------------------------ CompressNot +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + size_t count = 0; + Vec128 ret; + for (size_t i = 0; i < N; ++i) { + if (!mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + for (size_t i = 0; i < N; ++i) { + if (mask.bits[i]) { + ret.raw[count++] = v.raw[i]; + } + } + HWY_DASSERT(count == N); + return ret; +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressBits +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(Simd(), bits)); +} + +// ------------------------------ CompressStore + +// generic_ops-inl defines the 8-bit versions. 
+template +HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { + size_t count = 0; + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (mask.bits[i]) { + unaligned[count++] = v.raw[i]; + } + } + return count; +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, mask, d, unaligned); +} + +// ------------------------------ CompressBitsStore +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + const MFromD mask = LoadMaskBits(d, bits); + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ Additional mask logical operations +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + return mask; +} + +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + using TU = hwy::MakeUnsigned; + + Mask128 result; + TU result_lane_mask{0}; + for (size_t i = 0; i < N; i++) { + result_lane_mask = static_cast(result_lane_mask | mask.bits[i]); + result.bits[i] = result_lane_mask; + } + return result; +} + +template +HWY_API Mask128 SetBeforeFirst(Mask128 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + using TU = hwy::MakeUnsigned; + using TI = hwy::MakeSigned; + + Mask128 result; + TU result_lane_mask = static_cast(~TU{0}); + for (size_t i = 0; i < N; i++) { + const auto curr_lane_mask_bits = mask.bits[i]; + result.bits[i] = static_cast(curr_lane_mask_bits & result_lane_mask); + result_lane_mask = + static_cast(result_lane_mask & + static_cast(-static_cast(mask.bits[i] == 0))); + } + return result; +} + +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { + using TU = hwy::MakeUnsigned; + using TI = hwy::MakeSigned; + + Mask128 result; + TU result_lane_mask = static_cast(~TU{0}); + for 
(size_t i = 0; i < N; i++) { + result.bits[i] = result_lane_mask; + result_lane_mask = + static_cast(result_lane_mask & + static_cast(-static_cast(mask.bits[i] == 0))); + } + return result; +} + +// ------------------------------ WidenMulPairwiseAdd + +template +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); +} + +template +HWY_API VFromD WidenMulPairwiseAdd(D d32, V16 a, V16 b) { + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), + Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b))); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +template +HWY_API VFromD ReorderWidenMulAccumulate(D d32, V16 a, V16 b, + const VFromD sum0, + VFromD& sum1) { + sum1 = MulAdd(PromoteOddTo(d32, a), PromoteOddTo(d32, b), sum1); + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), sum0); +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API VW RearrangeToOddPlusEven(VW sum0, VW sum1) { + return Add(sum0, sum1); +} + +// ================================================== REDUCTIONS + +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceSum(D d, VFromD v) { + T sum = T{0}; + for (size_t i = 0; i < MaxLanes(d); ++i) { + sum += v.raw[i]; + } + return sum; +} + +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceMin(D d, VFromD v) { + T min = PositiveInfOrHighestValue(); + for (size_t i = 0; i < MaxLanes(d); ++i) { + min = HWY_MIN(min, v.raw[i]); + } + return min; +} +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceMax(D d, VFromD v) { + T max = NegativeInfOrLowestValue(); + for (size_t i = 0; i < MaxLanes(d); ++i) { + max = HWY_MAX(max, v.raw[i]); + } + return max; +} + +// ------------------------------ SumOfLanes + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return 
Set(d, ReduceSum(d, v)); +} +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + return Set(d, ReduceMin(d, v)); +} +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + return Set(d, ReduceMax(d, v)); +} + +// ================================================== OPS WITH DEPENDENCIES + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +template +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + alignas(16) T mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(Full128(), mul); +} + +template +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + alignas(16) T mul[2]; + const Half> d2; + mul[0] = + Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); + return Load(Full128(), mul); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/lib/highway/hwy/ops/generic_ops-inl.h b/lib/highway/hwy/ops/generic_ops-inl.h new file mode 100644 index 00000000000..8564b39b61a --- /dev/null +++ b/lib/highway/hwy/ops/generic_ops-inl.h @@ -0,0 +1,8347 @@ +// Copyright 2021 Google LLC +// Copyright 2023,2024 Arm Limited and/or +// its affiliates +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Target-independent types/functions defined after target-specific ops. 
+ +// The "include guards" in this file that check HWY_TARGET_TOGGLE serve to skip +// the generic implementation here if native ops are already defined. + +#include "hwy/base.h" +#include "hwy/detect_compiler_arch.h" + +// Define detail::Shuffle1230 etc, but only when viewing the current header; +// normally this is included via highway.h, which includes ops/*.h. +#if HWY_IDE && !defined(HWY_HIGHWAY_INCLUDED) +#include "hwy/detect_targets.h" +#include "hwy/ops/emu128-inl.h" +#endif // HWY_IDE + +// Relies on the external include guard in highway.h. +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// The lane type of a vector type, e.g. float for Vec>. +template +using LaneType = decltype(GetLane(V())); + +// Vector type, e.g. Vec128 for CappedTag. Useful as the return +// type of functions that do not take a vector argument, or as an argument type +// if the function only has a template argument for D, or for explicit type +// names instead of auto. This may be a built-in type. +template +using Vec = decltype(Zero(D())); + +// Mask type. Useful as the return type of functions that do not take a mask +// argument, or as an argument type if the function only has a template argument +// for D, or for explicit type names instead of auto. +template +using Mask = decltype(MaskFromVec(Zero(D()))); + +// Returns the closest value to v within [lo, hi]. +template +HWY_API V Clamp(const V v, const V lo, const V hi) { + return Min(Max(lo, v), hi); +} + +// CombineShiftRightBytes (and -Lanes) are not available for the scalar target, +// and RVV has its own implementation of -Lanes. 
+#if (HWY_TARGET != HWY_SCALAR && HWY_TARGET != HWY_RVV) || HWY_IDE + +template +HWY_API VFromD CombineShiftRightLanes(D d, VFromD hi, VFromD lo) { + constexpr size_t kBytes = kLanes * sizeof(TFromD); + static_assert(kBytes < 16, "Shift count is per-block"); + return CombineShiftRightBytes(d, hi, lo); +} + +#endif + +// Returns lanes with the most significant bit set and all other bits zero. +template +HWY_API Vec SignBit(D d) { + const RebindToUnsigned du; + return BitCast(d, Set(du, SignMask>())); +} + +// Returns quiet NaN. +template +HWY_API Vec NaN(D d) { + const RebindToSigned di; + // LimitsMax sets all exponent and mantissa bits to 1. The exponent plus + // mantissa MSB (to indicate quiet) would be sufficient. + return BitCast(d, Set(di, LimitsMax>())); +} + +// Returns positive infinity. +template +HWY_API Vec Inf(D d) { + const RebindToUnsigned du; + using T = TFromD; + using TU = TFromD; + const TU max_x2 = static_cast(MaxExponentTimes2()); + return BitCast(d, Set(du, max_x2 >> 1)); +} + +// ------------------------------ MaskedSetOr/MaskedSet + +template , typename D = DFromV, + typename M = MFromD> +HWY_API V MaskedSetOr(V no, M m, T a) { + D d; + return IfThenElse(m, Set(d, a), no); +} + +template , typename M = MFromD, + typename T = TFromD> +HWY_API V MaskedSet(D d, M m, T a) { + return IfThenElseZero(m, Set(d, a)); +} + +// ------------------------------ ZeroExtendResizeBitCast + +// The implementation of detail::ZeroExtendResizeBitCast for the HWY_EMU128 +// target is in emu128-inl.h, and the implementation of +// detail::ZeroExtendResizeBitCast for the HWY_SCALAR target is in scalar-inl.h +#if HWY_TARGET != HWY_EMU128 && HWY_TARGET != HWY_SCALAR +namespace detail { + +#if HWY_HAVE_SCALABLE +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag /* from_size_tag */, + hwy::SizeTag /* to_size_tag */, DTo d_to, DFrom d_from, + VFromD v) { + const Repartition d_to_u8; + const auto resized = ResizeBitCast(d_to_u8, v); + // Zero the upper 
bytes which were not present/valid in d_from. + const size_t num_bytes = Lanes(Repartition()); + return BitCast(d_to, IfThenElseZero(FirstN(d_to_u8, num_bytes), resized)); +} +#else // target that uses fixed-size vectors +// Truncating or same-size resizing cast: same as ResizeBitCast +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag /* from_size_tag */, + hwy::SizeTag /* to_size_tag */, DTo d_to, DFrom /*d_from*/, + VFromD v) { + return ResizeBitCast(d_to, v); +} + +// Resizing cast to vector that has twice the number of lanes of the source +// vector +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag /* from_size_tag */, + hwy::SizeTag /* to_size_tag */, DTo d_to, DFrom d_from, + VFromD v) { + const Twice dt_from; + return BitCast(d_to, ZeroExtendVector(dt_from, v)); +} + +// Resizing cast to vector that has more than twice the number of lanes of the +// source vector +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag /* from_size_tag */, + hwy::SizeTag /* to_size_tag */, DTo d_to, DFrom /*d_from*/, + VFromD v) { + using TFrom = TFromD; + constexpr size_t kNumOfFromLanes = kFromVectSize / sizeof(TFrom); + const Repartition d_resize_to; + return BitCast(d_to, IfThenElseZero(FirstN(d_resize_to, kNumOfFromLanes), + ResizeBitCast(d_resize_to, v))); +} +#endif // HWY_HAVE_SCALABLE + +} // namespace detail +#endif // HWY_TARGET != HWY_EMU128 && HWY_TARGET != HWY_SCALAR + +template +HWY_API VFromD ZeroExtendResizeBitCast(DTo d_to, DFrom d_from, + VFromD v) { + return detail::ZeroExtendResizeBitCast(hwy::SizeTag(), + hwy::SizeTag(), d_to, + d_from, v); +} + +// ------------------------------ SafeFillN + +template > +HWY_API void SafeFillN(const size_t num, const T value, D d, + T* HWY_RESTRICT to) { +#if HWY_MEM_OPS_MIGHT_FAULT + (void)d; + for (size_t i = 0; i < num; ++i) { + to[i] = value; + } +#else + BlendedStore(Set(d, value), FirstN(d, num), d, to); +#endif +} + +// ------------------------------ SafeCopyN + 
+template > +HWY_API void SafeCopyN(const size_t num, D d, const T* HWY_RESTRICT from, + T* HWY_RESTRICT to) { +#if HWY_MEM_OPS_MIGHT_FAULT + (void)d; + for (size_t i = 0; i < num; ++i) { + to[i] = from[i]; + } +#else + const Mask mask = FirstN(d, num); + BlendedStore(MaskedLoad(mask, d, from), mask, d, to); +#endif +} + +// ------------------------------ IsNegative +#if (defined(HWY_NATIVE_IS_NEGATIVE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_IS_NEGATIVE +#undef HWY_NATIVE_IS_NEGATIVE +#else +#define HWY_NATIVE_IS_NEGATIVE +#endif + +template +HWY_API Mask> IsNegative(V v) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, MaskFromVec(BroadcastSignBit(BitCast(di, v)))); +} + +#endif // HWY_NATIVE_IS_NEGATIVE + +// ------------------------------ MaskFalse +#if (defined(HWY_NATIVE_MASK_FALSE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASK_FALSE +#undef HWY_NATIVE_MASK_FALSE +#else +#define HWY_NATIVE_MASK_FALSE +#endif + +template +HWY_API Mask MaskFalse(D d) { + return MaskFromVec(Zero(d)); +} + +#endif // HWY_NATIVE_MASK_FALSE + +// ------------------------------ SetMask +#if (defined(HWY_NATIVE_SET_MASK) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SET_MASK +#undef HWY_NATIVE_SET_MASK +#else +#define HWY_NATIVE_SET_MASK +#endif + +template +HWY_API Mask SetMask(D d, bool val) { + const Repartition di32; + return MaskFromVec(ResizeBitCast(d, Set(di32, -static_cast(val)))); +} + +#endif // HWY_NATIVE_SET_MASK + +// ------------------------------ IfNegativeThenElseZero +#if (defined(HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#undef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#else +#define HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#endif + +template +HWY_API V IfNegativeThenElseZero(V v, V yes) { + return IfThenElseZero(IsNegative(v), yes); +} + +#endif // HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO + +// ------------------------------ IfNegativeThenZeroElse +#if 
(defined(HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#undef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#else +#define HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#endif + +template +HWY_API V IfNegativeThenZeroElse(V v, V no) { + return IfThenZeroElse(IsNegative(v), no); +} + +#endif // HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE + +// ------------------------------ ZeroIfNegative (IfNegativeThenZeroElse) + +// ZeroIfNegative is generic for all vector lengths +template +HWY_API V ZeroIfNegative(V v) { + return IfNegativeThenZeroElse(v, v); +} + +// ------------------------------ BitwiseIfThenElse +#if (defined(HWY_NATIVE_BITWISE_IF_THEN_ELSE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#else +#define HWY_NATIVE_BITWISE_IF_THEN_ELSE +#endif + +template +HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { + return Or(And(mask, yes), AndNot(mask, no)); +} + +#endif // HWY_NATIVE_BITWISE_IF_THEN_ELSE + +// ------------------------------ PromoteMaskTo + +#if (defined(HWY_NATIVE_PROMOTE_MASK_TO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PROMOTE_MASK_TO +#undef HWY_NATIVE_PROMOTE_MASK_TO +#else +#define HWY_NATIVE_PROMOTE_MASK_TO +#endif + +template +HWY_API Mask PromoteMaskTo(DTo d_to, DFrom d_from, Mask m) { + static_assert( + sizeof(TFromD) > sizeof(TFromD), + "sizeof(TFromD) must be greater than sizeof(TFromD)"); + static_assert( + IsSame, Mask, DTo>>>(), + "Mask must be the same type as Mask, DTo>>"); + + const RebindToSigned di_to; + const RebindToSigned di_from; + + return MaskFromVec(BitCast( + d_to, PromoteTo(di_to, BitCast(di_from, VecFromMask(d_from, m))))); +} + +#endif // HWY_NATIVE_PROMOTE_MASK_TO + +// ------------------------------ DemoteMaskTo + +#if (defined(HWY_NATIVE_DEMOTE_MASK_TO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_DEMOTE_MASK_TO +#undef HWY_NATIVE_DEMOTE_MASK_TO +#else +#define HWY_NATIVE_DEMOTE_MASK_TO +#endif + +template 
+HWY_API Mask DemoteMaskTo(DTo d_to, DFrom d_from, Mask m) { + static_assert(sizeof(TFromD) < sizeof(TFromD), + "sizeof(TFromD) must be less than sizeof(TFromD)"); + static_assert( + IsSame, Mask, DTo>>>(), + "Mask must be the same type as Mask, DTo>>"); + + const RebindToSigned di_to; + const RebindToSigned di_from; + + return MaskFromVec( + BitCast(d_to, DemoteTo(di_to, BitCast(di_from, VecFromMask(d_from, m))))); +} + +#endif // HWY_NATIVE_DEMOTE_MASK_TO + +// ------------------------------ InsertIntoUpper +#if (defined(HWY_NATIVE_LOAD_HIGHER) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LOAD_HIGHER +#undef HWY_NATIVE_LOAD_HIGHER +#else +#define HWY_NATIVE_LOAD_HIGHER +#endif +template (), HWY_IF_LANES_GT_D(D, 1), + HWY_IF_POW2_GT_D(D, -3)> +HWY_API V InsertIntoUpper(D d, T* p, V a) { + Half dh; + const VFromD b = LoadU(dh, p); + return Combine(d, b, LowerHalf(a)); +} +#endif // HWY_NATIVE_LOAD_HIGHER + +// ------------------------------ CombineMasks + +#if (defined(HWY_NATIVE_COMBINE_MASKS) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_COMBINE_MASKS +#undef HWY_NATIVE_COMBINE_MASKS +#else +#define HWY_NATIVE_COMBINE_MASKS +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_API Mask CombineMasks(D d, Mask> hi, Mask> lo) { + const Half dh; + return MaskFromVec(Combine(d, VecFromMask(dh, hi), VecFromMask(dh, lo))); +} +#endif + +#endif // HWY_NATIVE_COMBINE_MASKS + +// ------------------------------ LowerHalfOfMask + +#if (defined(HWY_NATIVE_LOWER_HALF_OF_MASK) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif + +template +HWY_API Mask LowerHalfOfMask(D d, Mask> m) { + const Twice dt; + return MaskFromVec(LowerHalf(d, VecFromMask(dt, m))); +} + +#endif // HWY_NATIVE_LOWER_HALF_OF_MASK + +// ------------------------------ UpperHalfOfMask + +#if (defined(HWY_NATIVE_UPPER_HALF_OF_MASK) == defined(HWY_TARGET_TOGGLE)) +#ifdef 
HWY_NATIVE_UPPER_HALF_OF_MASK +#undef HWY_NATIVE_UPPER_HALF_OF_MASK +#else +#define HWY_NATIVE_UPPER_HALF_OF_MASK +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_API Mask UpperHalfOfMask(D d, Mask> m) { + const Twice dt; + return MaskFromVec(UpperHalf(d, VecFromMask(dt, m))); +} +#endif + +#endif // HWY_NATIVE_UPPER_HALF_OF_MASK + +// ------------------------------ OrderedDemote2MasksTo + +#if (defined(HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#undef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#else +#define HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_API Mask OrderedDemote2MasksTo(DTo d_to, DFrom d_from, Mask a, + Mask b) { + static_assert( + sizeof(TFromD) == sizeof(TFromD) / 2, + "sizeof(TFromD) must be equal to sizeof(TFromD) / 2"); + static_assert(IsSame, Mask, DFrom>>>(), + "Mask must be the same type as " + "Mask, DFrom>>>()"); + + const RebindToSigned di_from; + const RebindToSigned di_to; + + const auto va = BitCast(di_from, VecFromMask(d_from, a)); + const auto vb = BitCast(di_from, VecFromMask(d_from, b)); + return MaskFromVec(BitCast(d_to, OrderedDemote2To(di_to, va, vb))); +} +#endif + +#endif // HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO + +// ------------------------------ RotateLeft +template +HWY_API V RotateLeft(V v) { + constexpr size_t kSizeInBits = sizeof(TFromV) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + + constexpr int kRotateRightAmt = + (kBits == 0) ? 
0 : static_cast(kSizeInBits) - kBits; + return RotateRight(v); +} + +// ------------------------------ InterleaveWholeLower/InterleaveWholeUpper +#if (defined(HWY_TOGGLE_INTERLEAVE_WHOLE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_TOGGLE_INTERLEAVE_WHOLE +#undef HWY_TOGGLE_INTERLEAVE_WHOLE +#else +#define HWY_TOGGLE_INTERLEAVE_WHOLE +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + // InterleaveWholeLower(d, a, b) is equivalent to InterleaveLower(a, b) if + // D().MaxBytes() <= 16 is true + return InterleaveLower(d, a, b); +} +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + // InterleaveWholeUpper(d, a, b) is equivalent to InterleaveUpper(a, b) if + // D().MaxBytes() <= 16 is true + return InterleaveUpper(d, a, b); +} + +// InterleaveWholeLower/InterleaveWholeUpper for 32-byte vectors on AVX2/AVX3 +// is implemented in x86_256-inl.h. + +// InterleaveWholeLower/InterleaveWholeUpper for 64-byte vectors on AVX3 is +// implemented in x86_512-inl.h. + +// InterleaveWholeLower/InterleaveWholeUpper for 32-byte vectors on WASM_EMU256 +// is implemented in wasm_256-inl.h. +#endif // HWY_TARGET != HWY_SCALAR + +#endif // HWY_TOGGLE_INTERLEAVE_WHOLE + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +// The InterleaveWholeLower without the optional D parameter is generic for all +// vector lengths. 
+template +HWY_API V InterleaveWholeLower(V a, V b) { + return InterleaveWholeLower(DFromV(), a, b); +} +#endif // HWY_TARGET != HWY_SCALAR + +// ------------------------------ InterleaveEven + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +// InterleaveEven without the optional D parameter is generic for all vector +// lengths +template +HWY_API V InterleaveEven(V a, V b) { + return InterleaveEven(DFromV(), a, b); +} +#endif + +// ------------------------------ MinNumber/MaxNumber + +#if (defined(HWY_NATIVE_FLOAT_MIN_MAX_NUMBER) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER +#undef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER +#else +#define HWY_NATIVE_FLOAT_MIN_MAX_NUMBER +#endif + +template +HWY_API V MinNumber(V a, V b) { + return Min(a, b); +} + +template +HWY_API V MaxNumber(V a, V b) { + return Max(a, b); +} + +#endif + +template +HWY_API V MinNumber(V a, V b) { + return Min(a, b); +} + +template +HWY_API V MaxNumber(V a, V b) { + return Max(a, b); +} + +// ------------------------------ MinMagnitude/MaxMagnitude + +#if (defined(HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE +#undef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE +#else +#define HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE +#endif + +template +HWY_API V MinMagnitude(V a, V b) { + const V abs_a = Abs(a); + const V abs_b = Abs(b); + const V min = Min(IfThenElse(Eq(abs_a, abs_b), a, b), b); + return IfThenElse(Lt(abs_a, abs_b), a, min); +} + +template +HWY_API V MaxMagnitude(V a, V b) { + const V abs_a = Abs(a); + const V abs_b = Abs(b); + // This lvalue appears to be necessary to avoid a clang bug on SVE. 
+ const V max = Max(IfThenElse(Eq(abs_a, abs_b), b, a), a); + return IfThenElse(Lt(abs_a, abs_b), b, max); +} + +#endif // HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE + +template +HWY_API V MinMagnitude(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const auto abs_a = BitCast(du, Abs(a)); + const auto abs_b = BitCast(du, Abs(b)); + return IfThenElse(RebindMask(d, Lt(abs_a, abs_b)), a, + Min(IfThenElse(RebindMask(d, Eq(abs_a, abs_b)), a, b), b)); +} + +template +HWY_API V MaxMagnitude(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const auto abs_a = BitCast(du, Abs(a)); + const auto abs_b = BitCast(du, Abs(b)); + return IfThenElse(RebindMask(d, Lt(abs_a, abs_b)), b, + Max(IfThenElse(RebindMask(d, Eq(abs_a, abs_b)), b, a), a)); +} + +template +HWY_API V MinMagnitude(V a, V b) { + return Min(a, b); +} + +template +HWY_API V MaxMagnitude(V a, V b) { + return Max(a, b); +} + +// ------------------------------ AddSub + +template , 1)> +HWY_API V AddSub(V a, V b) { + // AddSub(a, b) for a one-lane vector is equivalent to Sub(a, b) + return Sub(a, b); +} + +// AddSub for F32x2, F32x4, and F64x2 vectors is implemented in x86_128-inl.h on +// SSSE3/SSE4/AVX2/AVX3 + +// AddSub for F32x8 and F64x4 vectors is implemented in x86_256-inl.h on +// AVX2/AVX3 + +// AddSub for F16/F32/F64 vectors on SVE is implemented in arm_sve-inl.h + +// AddSub for integer vectors on SVE2 is implemented in arm_sve-inl.h +template +HWY_API V AddSub(V a, V b) { + using D = DFromV; + using T = TFromD; + using TNegate = If(), MakeSigned, T>; + + const D d; + const Rebind d_negate; + + // Negate the even lanes of b + const auto negated_even_b = OddEven(b, BitCast(d, Neg(BitCast(d_negate, b)))); + + return Add(a, negated_even_b); +} + +// ------------------------------ MaskedAddOr etc. 
+#if (defined(HWY_NATIVE_MASKED_ARITH) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASKED_ARITH +#undef HWY_NATIVE_MASKED_ARITH +#else +#define HWY_NATIVE_MASKED_ARITH +#endif + +template +HWY_API V MaskedMinOr(V no, M m, V a, V b) { + return IfThenElse(m, Min(a, b), no); +} + +template +HWY_API V MaskedMaxOr(V no, M m, V a, V b) { + return IfThenElse(m, Max(a, b), no); +} + +template +HWY_API V MaskedAddOr(V no, M m, V a, V b) { + return IfThenElse(m, Add(a, b), no); +} + +template +HWY_API V MaskedSubOr(V no, M m, V a, V b) { + return IfThenElse(m, Sub(a, b), no); +} + +template +HWY_API V MaskedMulOr(V no, M m, V a, V b) { + return IfThenElse(m, Mul(a, b), no); +} + +template +HWY_API V MaskedDivOr(V no, M m, V a, V b) { + const DFromV d; + // Avoid division by zero for masked-out lanes. + const V nonzero = Set(d, TFromD{1}); + return IfThenElse(m, Div(a, IfThenElse(m, b, nonzero)), no); +} + +template +HWY_API V MaskedModOr(V no, M m, V a, V b) { + const DFromV d; + // Avoid division by zero for masked-out lanes. 
+ const V nonzero = Set(d, TFromD{1}); + return IfThenElse(m, Mod(a, IfThenElse(m, b, nonzero)), no); +} + +template +HWY_API V MaskedSatAddOr(V no, M m, V a, V b) { + return IfThenElse(m, SaturatedAdd(a, b), no); +} + +template +HWY_API V MaskedSatSubOr(V no, M m, V a, V b) { + return IfThenElse(m, SaturatedSub(a, b), no); +} +#endif // HWY_NATIVE_MASKED_ARITH + +#if (defined(HWY_NATIVE_ZERO_MASKED_ARITH) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ZERO_MASKED_ARITH +#undef HWY_NATIVE_ZERO_MASKED_ARITH +#else +#define HWY_NATIVE_ZERO_MASKED_ARITH +#endif + +template +HWY_API V MaskedMax(M m, V a, V b) { + return IfThenElseZero(m, (Max(a, b))); +} + +template +HWY_API V MaskedAdd(M m, V a, V b) { + return IfThenElseZero(m, Add(a, b)); +} + +template +HWY_API V MaskedSub(M m, V a, V b) { + return IfThenElseZero(m, Sub(a, b)); +} + +template +HWY_API V MaskedMul(M m, V a, V b) { + return IfThenElseZero(m, Mul(a, b)); +} + +template +HWY_API V MaskedDiv(M m, V a, V b) { + return IfThenElseZero(m, Div(a, b)); +} + +template +HWY_API V MaskedSaturatedAdd(M m, V a, V b) { + return IfThenElseZero(m, SaturatedAdd(a, b)); +} + +template +HWY_API V MaskedSaturatedSub(M m, V a, V b) { + return IfThenElseZero(m, SaturatedSub(a, b)); +} + +template , HWY_IF_I16_D(D)> +HWY_API V MaskedMulFixedPoint15(M m, V a, V b) { + return IfThenElseZero(m, MulFixedPoint15(a, b)); +} + +template +HWY_API V MaskedMulAdd(M m, V mul, V x, V add) { + return IfThenElseZero(m, MulAdd(mul, x, add)); +} + +template +HWY_API V MaskedNegMulAdd(M m, V mul, V x, V add) { + return IfThenElseZero(m, NegMulAdd(mul, x, add)); +} + +template >> +HWY_API VFromD MaskedWidenMulPairwiseAdd(D d32, M m, V16 a, V16 b) { + return IfThenElseZero(m, WidenMulPairwiseAdd(d32, a, b)); +} + +template +HWY_API VFromD MaskedWidenMulPairwiseAdd(DF df, M m, VBF a, VBF b) { + return IfThenElseZero(m, WidenMulPairwiseAdd(df, a, b)); +} +#endif // HWY_NATIVE_ZERO_MASKED_ARITH + +// ------------------------------ MaskedShift 
+template +HWY_API V MaskedShiftLeft(M m, V a) { + return IfThenElseZero(m, ShiftLeft(a)); +} + +template +HWY_API V MaskedShiftRight(M m, V a) { + return IfThenElseZero(m, ShiftRight(a)); +} + +template +HWY_API V MaskedShiftRightOr(V no, M m, V a) { + return IfThenElse(m, ShiftRight(a), no); +} + +template +HWY_API V MaskedShrOr(V no, M m, V a, V shifts) { + return IfThenElse(m, Shr(a, shifts), no); +} + +// ------------------------------ MaskedEq etc. +#if (defined(HWY_NATIVE_MASKED_COMP) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASKED_COMP +#undef HWY_NATIVE_MASKED_COMP +#else +#define HWY_NATIVE_MASKED_COMP +#endif + +template +HWY_API auto MaskedEq(M m, V a, V b) -> decltype(a == b) { + return And(m, Eq(a, b)); +} + +template +HWY_API auto MaskedNe(M m, V a, V b) -> decltype(a == b) { + return And(m, Ne(a, b)); +} + +template +HWY_API auto MaskedLt(M m, V a, V b) -> decltype(a == b) { + return And(m, Lt(a, b)); +} + +template +HWY_API auto MaskedGt(M m, V a, V b) -> decltype(a == b) { + return And(m, Gt(a, b)); +} + +template +HWY_API auto MaskedLe(M m, V a, V b) -> decltype(a == b) { + return And(m, Le(a, b)); +} + +template +HWY_API auto MaskedGe(M m, V a, V b) -> decltype(a == b) { + return And(m, Ge(a, b)); +} + +template > +HWY_API MFromD MaskedIsNaN(const M m, const V v) { + return And(m, IsNaN(v)); +} +#endif // HWY_NATIVE_MASKED_COMP + +// ------------------------------ Xor3 + +#if (defined(HWY_NATIVE_XOR3) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_XOR3 +#undef HWY_NATIVE_XOR3 +#else +#define HWY_NATIVE_XOR3 +#endif + +template +HWY_API V Xor3(V x1, V x2, V x3) { + return Xor(x1, Xor(x2, x3)); +} + +#endif // HWY_NATIVE_XOR3 + +// ------------------------------ XorAndNot + +#if (defined(HWY_NATIVE_BCAX) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_BCAX +#undef HWY_NATIVE_BCAX +#else +#define HWY_NATIVE_BCAX +#endif + +template +HWY_API V XorAndNot(const V x, const V a1, const V a2) { + return Xor(x, AndNot(a1, a2)); +} + +#endif 
// HWY_NATIVE_BCAX + +// ------------------------------ IfNegativeThenNegOrUndefIfZero + +#if (defined(HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#else +#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#endif + +template +HWY_API V IfNegativeThenNegOrUndefIfZero(V mask, V v) { +#if HWY_HAVE_SCALABLE || HWY_TARGET_IS_SVE + // MaskedSubOr is more efficient than IfNegativeThenElse on RVV/SVE + const auto zero = Zero(DFromV()); + return MaskedSubOr(v, Lt(mask, zero), zero, v); +#else + return IfNegativeThenElse(mask, Neg(v), v); +#endif +} + +#endif // HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG + +template +HWY_API V IfNegativeThenNegOrUndefIfZero(V mask, V v) { + return CopySign(v, Xor(mask, v)); +} + +// ------------------------------ SaturatedNeg + +#if (defined(HWY_NATIVE_SATURATED_NEG_8_16_32) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SATURATED_NEG_8_16_32 +#undef HWY_NATIVE_SATURATED_NEG_8_16_32 +#else +#define HWY_NATIVE_SATURATED_NEG_8_16_32 +#endif + +template +HWY_API V SaturatedNeg(V v) { + const DFromV d; + return SaturatedSub(Zero(d), v); +} + +template )> +HWY_API V SaturatedNeg(V v) { + const DFromV d; + +#if HWY_TARGET == HWY_RVV || HWY_TARGET_IS_PPC || HWY_TARGET_IS_SVE || \ + HWY_TARGET_IS_NEON + // RVV/PPC/SVE/NEON have native I32 SaturatedSub instructions + return SaturatedSub(Zero(d), v); +#else + // ~v[i] - ((v[i] > LimitsMin()) ? -1 : 0) is equivalent to + // (v[i] > LimitsMin) ? (-v[i]) : LimitsMax() since + // -v[i] == ~v[i] + 1 == ~v[i] - (-1) and + // ~LimitsMin() == LimitsMax(). 
+ return Sub(Not(v), VecFromMask(d, Gt(v, Set(d, LimitsMin())))); +#endif +} +#endif // HWY_NATIVE_SATURATED_NEG_8_16_32 + +#if (defined(HWY_NATIVE_SATURATED_NEG_64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SATURATED_NEG_64 +#undef HWY_NATIVE_SATURATED_NEG_64 +#else +#define HWY_NATIVE_SATURATED_NEG_64 +#endif + +template )> +HWY_API V SaturatedNeg(V v) { +#if HWY_TARGET == HWY_RVV || HWY_TARGET_IS_SVE || HWY_TARGET_IS_NEON + // RVV/SVE/NEON have native I64 SaturatedSub instructions + const DFromV d; + return SaturatedSub(Zero(d), v); +#else + const auto neg_v = Neg(v); + return Add(neg_v, BroadcastSignBit(And(v, neg_v))); +#endif +} +#endif // HWY_NATIVE_SATURATED_NEG_64 + +// ------------------------------ SaturatedAbs + +#if (defined(HWY_NATIVE_SATURATED_ABS) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SATURATED_ABS +#undef HWY_NATIVE_SATURATED_ABS +#else +#define HWY_NATIVE_SATURATED_ABS +#endif + +template +HWY_API V SaturatedAbs(V v) { + return Max(v, SaturatedNeg(v)); +} + +#endif + +// ------------------------------ MaskedAbsOr +template +HWY_API V MaskedAbsOr(V no, M m, V v) { + return IfThenElse(m, Abs(v), no); +} + +// ------------------------------ MaskedAbs +template +HWY_API V MaskedAbs(M m, V v) { + return IfThenElseZero(m, Abs(v)); +} + +// ------------------------------ Reductions + +// Targets follow one of two strategies. If HWY_NATIVE_REDUCE_SCALAR is toggled, +// they (RVV/SVE/Armv8/Emu128) implement ReduceSum and SumOfLanes via Set. +// Otherwise, they (Armv7/PPC/scalar/WASM/x86) define zero to most of the +// SumOfLanes overloads. For the latter group, we here define the remaining +// overloads, plus ReduceSum which uses them plus GetLane. +#if (defined(HWY_NATIVE_REDUCE_SCALAR) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + +namespace detail { + +// Allows reusing the same shuffle code for SumOfLanes/MinOfLanes/MaxOfLanes. 
+struct AddFunc { + template + V operator()(V a, V b) const { + return Add(a, b); + } +}; + +struct MinFunc { + template + V operator()(V a, V b) const { + return Min(a, b); + } +}; + +struct MaxFunc { + template + V operator()(V a, V b) const { + return Max(a, b); + } +}; + +// No-op for vectors of at most one block. +template +HWY_INLINE VFromD ReduceAcrossBlocks(D, Func, VFromD v) { + return v; +} + +// Reduces a lane with its counterpart in other block(s). Shared by AVX2 and +// WASM_EMU256. AVX3 has its own overload. +template +HWY_INLINE VFromD ReduceAcrossBlocks(D /*d*/, Func f, VFromD v) { + return f(v, SwapAdjacentBlocks(v)); +} + +// These return the reduction result broadcasted across all lanes. They assume +// the caller has already reduced across blocks. + +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v10) { + return f(v10, Reverse2(d, v10)); +} + +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v3210) { + const VFromD v0123 = Reverse4(d, v3210); + const VFromD v03_12_12_03 = f(v3210, v0123); + const VFromD v12_03_03_12 = Reverse2(d, v03_12_12_03); + return f(v03_12_12_03, v12_03_03_12); +} + +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v76543210) { + // The upper half is reversed from the lower half; omit for brevity. + const VFromD v34_25_16_07 = f(v76543210, Reverse8(d, v76543210)); + const VFromD v0347_1625_1625_0347 = + f(v34_25_16_07, Reverse4(d, v34_25_16_07)); + return f(v0347_1625_1625_0347, Reverse2(d, v0347_1625_1625_0347)); +} + +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v) { + const RepartitionToWide dw; + using VW = VFromD; + const VW vw = BitCast(dw, v); + // f is commutative, so no need to adapt for HWY_IS_LITTLE_ENDIAN. 
+ const VW even = And(vw, Set(dw, 0xFF)); + const VW odd = ShiftRight<8>(vw); + const VW reduced = ReduceWithinBlocks(dw, f, f(even, odd)); +#if HWY_IS_LITTLE_ENDIAN + return DupEven(BitCast(d, reduced)); +#else + return DupOdd(BitCast(d, reduced)); +#endif +} + +template +HWY_INLINE VFromD ReduceWithinBlocks(D d, Func f, VFromD v) { + const RepartitionToWide dw; + using VW = VFromD; + const VW vw = BitCast(dw, v); + // Sign-extend + // f is commutative, so no need to adapt for HWY_IS_LITTLE_ENDIAN. + const VW even = ShiftRight<8>(ShiftLeft<8>(vw)); + const VW odd = ShiftRight<8>(vw); + const VW reduced = ReduceWithinBlocks(dw, f, f(even, odd)); +#if HWY_IS_LITTLE_ENDIAN + return DupEven(BitCast(d, reduced)); +#else + return DupOdd(BitCast(d, reduced)); +#endif +} + +} // namespace detail + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + const detail::AddFunc f; + v = detail::ReduceAcrossBlocks(d, f, v); + return detail::ReduceWithinBlocks(d, f, v); +} +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + const detail::MinFunc f; + v = detail::ReduceAcrossBlocks(d, f, v); + return detail::ReduceWithinBlocks(d, f, v); +} +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + const detail::MaxFunc f; + v = detail::ReduceAcrossBlocks(d, f, v); + return detail::ReduceWithinBlocks(d, f, v); +} + +template +HWY_API TFromD ReduceSum(D d, VFromD v) { + return GetLane(SumOfLanes(d, v)); +} +template +HWY_API TFromD ReduceMin(D d, VFromD v) { + return GetLane(MinOfLanes(d, v)); +} +template +HWY_API TFromD ReduceMax(D d, VFromD v) { + return GetLane(MaxOfLanes(d, v)); +} + +#endif // HWY_NATIVE_REDUCE_SCALAR + +// Corner cases for both generic and native implementations: +// N=1 (native covers N=2 e.g. 
for u64x2 and even u32x2 on Arm) +template +HWY_API TFromD ReduceSum(D /*d*/, VFromD v) { + return GetLane(v); +} +template +HWY_API TFromD ReduceMin(D /*d*/, VFromD v) { + return GetLane(v); +} +template +HWY_API TFromD ReduceMax(D /*d*/, VFromD v) { + return GetLane(v); +} + +template +HWY_API VFromD SumOfLanes(D /* tag */, VFromD v) { + return v; +} +template +HWY_API VFromD MinOfLanes(D /* tag */, VFromD v) { + return v; +} +template +HWY_API VFromD MaxOfLanes(D /* tag */, VFromD v) { + return v; +} + +// N=4 for 8-bit is still less than the minimum native size. + +// ARMv7 NEON/PPC/RVV/SVE have target-specific implementations of the N=4 I8/U8 +// ReduceSum operations +#if (defined(HWY_NATIVE_REDUCE_SUM_4_UI8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 +#endif + +template +HWY_API TFromD ReduceSum(D d, VFromD v) { + const Twice> dw; + return static_cast>(ReduceSum(dw, PromoteTo(dw, v))); +} +#endif // HWY_NATIVE_REDUCE_SUM_4_UI8 + +// RVV/SVE have target-specific implementations of the N=4 I8/U8 +// ReduceMin/ReduceMax operations +#if (defined(HWY_NATIVE_REDUCE_MINMAX_4_UI8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#undef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#else +#define HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#endif +template +HWY_API TFromD ReduceMin(D d, VFromD v) { + const Twice> dw; + return static_cast>(ReduceMin(dw, PromoteTo(dw, v))); +} +template +HWY_API TFromD ReduceMax(D d, VFromD v) { + const Twice> dw; + return static_cast>(ReduceMax(dw, PromoteTo(dw, v))); +} +#endif // HWY_NATIVE_REDUCE_MINMAX_4_UI8 + +#if (defined(HWY_NATIVE_MASKED_REDUCE_SCALAR) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR +#undef HWY_NATIVE_MASKED_REDUCE_SCALAR +#else +#define HWY_NATIVE_MASKED_REDUCE_SCALAR +#endif + +template +HWY_API TFromD MaskedReduceSum(D d, M m, VFromD v) { + return ReduceSum(d, IfThenElseZero(m, v)); +} 
+template +HWY_API TFromD MaskedReduceMin(D d, M m, VFromD v) { + return ReduceMin( + d, IfThenElse(m, v, Set(d, hwy::PositiveInfOrHighestValue>()))); +} +template +HWY_API TFromD MaskedReduceMax(D d, M m, VFromD v) { + return ReduceMax( + d, IfThenElse(m, v, Set(d, hwy::NegativeInfOrLowestValue>()))); +} + +#endif // HWY_NATIVE_MASKED_REDUCE_SCALAR + +// ------------------------------ IsEitherNaN +#if (defined(HWY_NATIVE_IS_EITHER_NAN) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_IS_EITHER_NAN +#undef HWY_NATIVE_IS_EITHER_NAN +#else +#define HWY_NATIVE_IS_EITHER_NAN +#endif + +template +HWY_API MFromD> IsEitherNaN(V a, V b) { + return Or(IsNaN(a), IsNaN(b)); +} + +#endif // HWY_NATIVE_IS_EITHER_NAN + +// ------------------------------ IsInf, IsFinite + +// AVX3 has target-specific implementations of these. +#if (defined(HWY_NATIVE_ISINF) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + +template > +HWY_API MFromD IsInf(const V v) { + using T = TFromD; + const D d; + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask( + d, + Eq(Add(vu, vu), + Set(du, static_cast>(hwy::MaxExponentTimes2())))); +} + +// Returns whether normal/subnormal/zero. +template > +HWY_API MFromD IsFinite(const V v) { + using T = TFromD; + const D d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); +// 'Shift left' to clear the sign bit. MSVC seems to generate incorrect code +// for AVX2 if we instead add vu + vu. +#if HWY_COMPILER_MSVC + const VFromD shl = ShiftLeft<1>(vu); +#else + const VFromD shl = Add(vu, vu); +#endif + + // Then shift right so we can compare with the max exponent (cannot compare + // with MaxExponentTimes2 directly because it is negative and non-negative + // floats would be greater). 
+ const VFromD exp = + BitCast(di, ShiftRight() + 1>(shl)); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +#endif // HWY_NATIVE_ISINF + +// ------------------------------ CeilInt/FloorInt +#if (defined(HWY_NATIVE_CEIL_FLOOR_INT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_CEIL_FLOOR_INT +#undef HWY_NATIVE_CEIL_FLOOR_INT +#else +#define HWY_NATIVE_CEIL_FLOOR_INT +#endif + +template +HWY_API VFromD>> CeilInt(V v) { + const DFromV d; + const RebindToSigned di; + return ConvertTo(di, Ceil(v)); +} + +template +HWY_API VFromD>> FloorInt(V v) { + const DFromV d; + const RebindToSigned di; + return ConvertTo(di, Floor(v)); +} + +#endif // HWY_NATIVE_CEIL_FLOOR_INT + +// ------------------------------ MulByPow2/MulByFloorPow2 + +#if (defined(HWY_NATIVE_MUL_BY_POW2) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MUL_BY_POW2 +#undef HWY_NATIVE_MUL_BY_POW2 +#else +#define HWY_NATIVE_MUL_BY_POW2 +#endif + +template +HWY_API V MulByPow2(V v, VFromD>> exp) { + const DFromV df; + const RebindToSigned di; + + using TF = TFromD; + using TI = TFromD; + + using VF = VFromD; + using VI = VFromD; + + constexpr TI kMaxBiasedExp = MaxExponentField(); + static_assert(kMaxBiasedExp > 0, "kMaxBiasedExp > 0 must be true"); + + constexpr TI kExpBias = static_cast(kMaxBiasedExp >> 1); + static_assert(kExpBias > 0, "kExpBias > 0 must be true"); + static_assert(kExpBias <= LimitsMax() / 3, + "kExpBias <= LimitsMax() / 3 must be true"); + +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE4 + using TExpMinMax = If<(sizeof(TI) <= 4), TI, int32_t>; +#elif (HWY_TARGET >= HWY_SSSE3 && HWY_TARGET <= HWY_SSE2) || \ + HWY_TARGET == HWY_WASM || HWY_TARGET == HWY_WASM_EMU256 + using TExpMinMax = int16_t; +#else + using TExpMinMax = TI; +#endif + + static_assert(kExpBias <= static_cast(LimitsMax() / 3), + "kExpBias <= LimitsMax() / 3 must be true"); + + const Repartition d_exp_min_max; + + constexpr int kNumOfMantBits = MantissaBits(); + + const VI exp_bias = Set(di, 
kExpBias); + + const VI clamped_exp = + Clamp(exp, Set(di, 3 - 3 * kExpBias), Set(di, 3 * kExpBias)); + + const auto min_scale_factor_exp = + BitCast(d_exp_min_max, Set(di, 1 - kExpBias)); + const auto max_scale_factor_exp = BitCast(d_exp_min_max, exp_bias); + + // If clamped_exp[i] < 0, ensure that 1 - kExpBias <= exp1[i] <= 0, + // 1 - kExpBias <= exp2[i] <= 0, and 1 - kExpBias <= exp3[i] <= 0 are + // true. + + // In addition, if clamped_exp[i] < 1 - kExpBias, ensure that + // exp3[i] == 1 - kExpBias to ensure results are correctly rounded if the + // exact value of |x[i] * factor1[i] * factor2[i] * factor3[i]| is less than + // the smallest positive normal value. + + // Otherwise, if clamped_exp[i] >= 0, ensure that 0 <= exp1[i] <= kExpBias, + // 0 <= exp2[i] <= kExpBias, and 0 <= exp3[i] <= kExpBias are all true. + + const VI exp3 = + BitCast(di, Clamp(BitCast(d_exp_min_max, clamped_exp), + min_scale_factor_exp, max_scale_factor_exp)); + + const VI clamped_exp_minus_exp3 = Sub(clamped_exp, exp3); + const VI exp2 = + BitCast(di, Clamp(BitCast(d_exp_min_max, clamped_exp_minus_exp3), + min_scale_factor_exp, max_scale_factor_exp)); + const VI exp1 = Sub(clamped_exp_minus_exp3, exp2); + + const VF factor1 = + BitCast(df, ShiftLeft(Add(exp1, exp_bias))); + const VF factor2 = + BitCast(df, ShiftLeft(Add(exp2, exp_bias))); + const VF factor3 = + BitCast(df, ShiftLeft(Add(exp3, exp_bias))); + + // If exp2[i] < 0, then clamped_exp[i] < 1 - kExpBias and + // exp3[i] == 1 - kExpBias will both be true. factor3[i] will be equal to the + // smallest positive normal value if exp3[i] == 1 - kExpBias. + + // If exp2[i] >= 0, then exp1[i] >= 0 and factor1[i] * factor2[i] >= 1. 
+ + // If exp2[i] < 0 and the exact value of |v[i] * factor1[i] * factor2[i]| is + // less than the smallest positive normal value, then the exact value of + // |v[i] * factor1[i] * factor2[i] * factor3[i]| will be much smaller than + // half of the smallest positive denormal value (since factor3[i] will be + // equal to the smallest positive normal value in this case), resulting in a + // correctly rounded result in this case. + + // If kExpBias >= kNumOfMantBits + 3 and exp3[i] == 1 - kExpBias are both + // true, then factor3[i] will be small enough such that + // v[i] * factor1[i] * factor2[i] * factor3[i] will be correctly rounded, + // even if the exact value of |v[i] * factor1[i] * factor2[i]| is smaller than + // the smallest positive normal value. + + // kExpBias >= kNumOfMantBits + 3 is true for the F16, F32, and F64 + // floating-point types. + + // Otherwise, either exp2[i] >= 0, the exact value of + // |v[i] * factor1[i] * factor2[i]| is greater than or equal to the smallest + // positive normal value, or v[i] is NaN. In these cases, + // v[i] * factor1[i] * factor2[i] will either be exact or overflow to + // infinity (if clamped_exp[i] > 0 and v[i] is a non-zero finite value), + // resulting in a correctly rounded result if the exact value of + // |v[i] * factor1[i] * factor2[i] * factor3[i]| is less than the smallest + // positive normal value. 
+ + return Mul(Mul(Mul(v, factor1), factor2), factor3); +} + +template +HWY_API V MulByFloorPow2(V v, V exp) { + const DFromV df; + + // MulByFloorPow2 special cases: + // MulByFloorPow2(v, NaN) => NaN + // MulByFloorPow2(0, inf) => NaN + // MulByFloorPow2(inf, -inf) => NaN + // MulByFloorPow2(-inf, -inf) => NaN + const auto is_special_case_with_nan_result = + Or(IsNaN(exp), + And(Eq(Abs(v), IfNegativeThenElseZero(exp, Inf(df))), IsInf(exp))); + + return IfThenElse(is_special_case_with_nan_result, NaN(df), + MulByPow2(v, FloorInt(exp))); +} + +#endif // HWY_NATIVE_MUL_BY_POW2 + +// ------------------------------ GetBiasedExponent +#if (defined(HWY_NATIVE_GET_BIASED_EXPONENT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_GET_BIASED_EXPONENT +#undef HWY_NATIVE_GET_BIASED_EXPONENT +#else +#define HWY_NATIVE_GET_BIASED_EXPONENT +#endif + +template +HWY_API VFromD>> GetBiasedExponent(V v) { + using T = TFromV; + + const DFromV d; + const RebindToUnsigned du; + + constexpr int kNumOfMantBits = MantissaBits(); + return ShiftRight(BitCast(du, Abs(v))); +} + +#endif + +// ------------------------------ GetExponent + +#if (defined(HWY_NATIVE_GET_EXPONENT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_GET_EXPONENT +#undef HWY_NATIVE_GET_EXPONENT +#else +#define HWY_NATIVE_GET_EXPONENT +#endif + +template +HWY_API V GetExponent(V v) { + const DFromV d; + using T = TFromV; + const RebindToSigned di; + + const auto exponent_offset = Set(di, MaxExponentField() >> 1); + + // extract exponent bits as integer + const auto encoded_exponent = GetBiasedExponent(v); + const auto exponent_int = Sub(BitCast(di, encoded_exponent), exponent_offset); + + // convert integer to original type + return ConvertTo(d, exponent_int); +} + +#endif // HWY_NATIVE_GET_EXPONENT +// ------------------------------ LoadInterleaved2 + +#if HWY_IDE || \ + (defined(HWY_NATIVE_LOAD_STORE_INTERLEAVED) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef 
HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +template +HWY_API void LoadInterleaved2(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + const VFromD A = LoadU(d, unaligned); // v1[1] v0[1] v1[0] v0[0] + const VFromD B = LoadU(d, unaligned + Lanes(d)); + v0 = ConcatEven(d, B, A); + v1 = ConcatOdd(d, B, A); +} + +template +HWY_API void LoadInterleaved2(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); +} + +// ------------------------------ LoadInterleaved3 (CombineShiftRightBytes) + +namespace detail { + +#if HWY_IDE +template +HWY_INLINE V ShuffleTwo1230(V a, V /* b */) { + return a; +} +template +HWY_INLINE V ShuffleTwo2301(V a, V /* b */) { + return a; +} +template +HWY_INLINE V ShuffleTwo3012(V a, V /* b */) { + return a; +} +#endif // HWY_IDE + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_INLINE void LoadTransposedBlocks3(D d, + const TFromD* HWY_RESTRICT unaligned, + VFromD& A, VFromD& B, + VFromD& C) { + constexpr size_t kN = MaxLanes(d); + A = LoadU(d, unaligned + 0 * kN); + B = LoadU(d, unaligned + 1 * kN); + C = LoadU(d, unaligned + 2 * kN); +} + +} // namespace detail + +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + const RebindToUnsigned du; + using V = VFromD; + using VU = VFromD; + // Compact notation so these fit on one line: 12 := v1[2]. + V A; // 05 24 14 04 23 13 03 22 12 02 21 11 01 20 10 00 + V B; // 1a 0a 29 19 09 28 18 08 27 17 07 26 16 06 25 15 + V C; // 2f 1f 0f 2e 1e 0e 2d 1d 0d 2c 1c 0c 2b 1b 0b 2a + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. 
+ constexpr uint8_t Z = 0x80; + const VU idx_v0A = + Dup128VecFromValues(du, 0, 3, 6, 9, 12, 15, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU idx_v0B = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, 2, 5, 8, 11, 14, Z, Z, Z, Z, Z); + const VU idx_v0C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 1, 4, 7, 10, 13); + const VU idx_v1A = + Dup128VecFromValues(du, 1, 4, 7, 10, 13, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU idx_v1B = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, 0, 3, 6, 9, 12, 15, Z, Z, Z, Z, Z); + const VU idx_v1C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 2, 5, 8, 11, 14); + const VU idx_v2A = + Dup128VecFromValues(du, 2, 5, 8, 11, 14, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU idx_v2B = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, 1, 4, 7, 10, 13, Z, Z, Z, Z, Z, Z); + const VU idx_v2C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, 0, 3, 6, 9, 12, 15); + const V v0L = BitCast(d, TableLookupBytesOr0(A, idx_v0A)); + const V v0M = BitCast(d, TableLookupBytesOr0(B, idx_v0B)); + const V v0U = BitCast(d, TableLookupBytesOr0(C, idx_v0C)); + const V v1L = BitCast(d, TableLookupBytesOr0(A, idx_v1A)); + const V v1M = BitCast(d, TableLookupBytesOr0(B, idx_v1B)); + const V v1U = BitCast(d, TableLookupBytesOr0(C, idx_v1C)); + const V v2L = BitCast(d, TableLookupBytesOr0(A, idx_v2A)); + const V v2M = BitCast(d, TableLookupBytesOr0(B, idx_v2B)); + const V v2U = BitCast(d, TableLookupBytesOr0(C, idx_v2C)); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +// 8-bit lanes x8 +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + const RebindToUnsigned du; + using V = VFromD; + using VU = VFromD; + V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] + V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] + V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] v1[5] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // 
Compress all lanes belonging to v0 into consecutive lanes. + constexpr uint8_t Z = 0x80; + const VU idx_v0A = + Dup128VecFromValues(du, 0, 3, 6, Z, Z, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v0B = + Dup128VecFromValues(du, Z, Z, Z, 1, 4, 7, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v0C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, Z, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v1A = + Dup128VecFromValues(du, 1, 4, 7, Z, Z, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v1B = + Dup128VecFromValues(du, Z, Z, Z, 2, 5, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v1C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, 0, 3, 6, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v2A = + Dup128VecFromValues(du, 2, 5, Z, Z, Z, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v2B = + Dup128VecFromValues(du, Z, Z, 0, 3, 6, Z, Z, Z, 0, 0, 0, 0, 0, 0, 0, 0); + const VU idx_v2C = + Dup128VecFromValues(du, Z, Z, Z, Z, Z, 1, 4, 7, 0, 0, 0, 0, 0, 0, 0, 0); + const V v0L = BitCast(d, TableLookupBytesOr0(A, idx_v0A)); + const V v0M = BitCast(d, TableLookupBytesOr0(B, idx_v0B)); + const V v0U = BitCast(d, TableLookupBytesOr0(C, idx_v0C)); + const V v1L = BitCast(d, TableLookupBytesOr0(A, idx_v1A)); + const V v1M = BitCast(d, TableLookupBytesOr0(B, idx_v1B)); + const V v1U = BitCast(d, TableLookupBytesOr0(C, idx_v1C)); + const V v2L = BitCast(d, TableLookupBytesOr0(A, idx_v2A)); + const V v2M = BitCast(d, TableLookupBytesOr0(B, idx_v2B)); + const V v2U = BitCast(d, TableLookupBytesOr0(C, idx_v2C)); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +// 16-bit lanes x8 +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + const RebindToUnsigned du; + const Repartition du8; + using V = VFromD; + using VU8 = VFromD; + V A; // v1[2] v0[2] v2[1] v1[1] v0[1] v2[0] v1[0] v0[0] + V B; // v0[5] v2[4] v1[4] v0[4] v2[3] v1[3] v0[3] v2[2] + V C; // v2[7] v1[7] v0[7] v2[6] v1[6] v0[6] v2[5] 
v1[5] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + // Compress all lanes belonging to v0 into consecutive lanes. Same as above, + // but each element of the array contains a byte index for a byte of a lane. + constexpr uint8_t Z = 0x80; + const VU8 idx_v0A = Dup128VecFromValues(du8, 0x00, 0x01, 0x06, 0x07, 0x0C, + 0x0D, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU8 idx_v0B = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, 0x02, 0x03, + 0x08, 0x09, 0x0E, 0x0F, Z, Z, Z, Z); + const VU8 idx_v0C = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, + Z, 0x04, 0x05, 0x0A, 0x0B); + const VU8 idx_v1A = Dup128VecFromValues(du8, 0x02, 0x03, 0x08, 0x09, 0x0E, + 0x0F, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU8 idx_v1B = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, 0x04, 0x05, + 0x0A, 0x0B, Z, Z, Z, Z, Z, Z); + const VU8 idx_v1C = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, + 0x00, 0x01, 0x06, 0x07, 0x0C, 0x0D); + const VU8 idx_v2A = Dup128VecFromValues(du8, 0x04, 0x05, 0x0A, 0x0B, Z, Z, Z, + Z, Z, Z, Z, Z, Z, Z, Z, Z); + const VU8 idx_v2B = Dup128VecFromValues(du8, Z, Z, Z, Z, 0x00, 0x01, 0x06, + 0x07, 0x0C, 0x0D, Z, Z, Z, Z, Z, Z); + const VU8 idx_v2C = Dup128VecFromValues(du8, Z, Z, Z, Z, Z, Z, Z, Z, Z, Z, + 0x02, 0x03, 0x08, 0x09, 0x0E, 0x0F); + const V v0L = TableLookupBytesOr0(A, BitCast(d, idx_v0A)); + const V v0M = TableLookupBytesOr0(B, BitCast(d, idx_v0B)); + const V v0U = TableLookupBytesOr0(C, BitCast(d, idx_v0C)); + const V v1L = TableLookupBytesOr0(A, BitCast(d, idx_v1A)); + const V v1M = TableLookupBytesOr0(B, BitCast(d, idx_v1B)); + const V v1U = TableLookupBytesOr0(C, BitCast(d, idx_v1C)); + const V v2L = TableLookupBytesOr0(A, BitCast(d, idx_v2A)); + const V v2M = TableLookupBytesOr0(B, BitCast(d, idx_v2B)); + const V v2U = TableLookupBytesOr0(C, BitCast(d, idx_v2C)); + v0 = Xor3(v0L, v0M, v0U); + v1 = Xor3(v1L, v1M, v1U); + v2 = Xor3(v2L, v2M, v2U); +} + +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + 
VFromD& v0, VFromD& v1, VFromD& v2) { + using V = VFromD; + V A; // v0[1] v2[0] v1[0] v0[0] + V B; // v1[2] v0[2] v2[1] v1[1] + V C; // v2[3] v1[3] v0[3] v2[2] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + + const V vxx_02_03_xx = OddEven(C, B); + v0 = detail::ShuffleTwo1230(A, vxx_02_03_xx); + + // Shuffle2301 takes the upper/lower halves of the output from one input, so + // we cannot just combine 13 and 10 with 12 and 11 (similar to v0/v2). Use + // OddEven because it may have higher throughput than Shuffle. + const V vxx_xx_10_11 = OddEven(A, B); + const V v12_13_xx_xx = OddEven(B, C); + v1 = detail::ShuffleTwo2301(vxx_xx_10_11, v12_13_xx_xx); + + const V vxx_20_21_xx = OddEven(B, A); + v2 = detail::ShuffleTwo3012(vxx_20_21_xx, C); +} + +template +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + VFromD A; // v1[0] v0[0] + VFromD B; // v0[1] v2[0] + VFromD C; // v2[1] v1[1] + detail::LoadTransposedBlocks3(d, unaligned, A, B, C); + v0 = OddEven(B, A); + v1 = CombineShiftRightBytes)>(d, C, A); + v2 = OddEven(C, B); +} + +template , HWY_IF_LANES_D(D, 1)> +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); +} + +// ------------------------------ LoadInterleaved4 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. 
+template +HWY_INLINE void LoadTransposedBlocks4(D d, + const TFromD* HWY_RESTRICT unaligned, + VFromD& vA, VFromD& vB, + VFromD& vC, VFromD& vD) { + constexpr size_t kN = MaxLanes(d); + vA = LoadU(d, unaligned + 0 * kN); + vB = LoadU(d, unaligned + 1 * kN); + vC = LoadU(d, unaligned + 2 * kN); + vD = LoadU(d, unaligned + 3 * kN); +} + +} // namespace detail + +template +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + const Repartition d64; + using V64 = VFromD; + using V = VFromD; + // 16 lanes per block; the lowest four blocks are at the bottom of vA..vD. + // Here int[i] means the four interleaved values of the i-th 4-tuple and + // int[3..0] indicates four consecutive 4-tuples (0 = least-significant). + V vA; // int[13..10] int[3..0] + V vB; // int[17..14] int[7..4] + V vC; // int[1b..18] int[b..8] + V vD; // int[1f..1c] int[f..c] + detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); + + // For brevity, the comments only list the lower block (upper = lower + 0x10) + const V v5140 = InterleaveLower(d, vA, vB); // int[5,1,4,0] + const V vd9c8 = InterleaveLower(d, vC, vD); // int[d,9,c,8] + const V v7362 = InterleaveUpper(d, vA, vB); // int[7,3,6,2] + const V vfbea = InterleaveUpper(d, vC, vD); // int[f,b,e,a] + + const V v6420 = InterleaveLower(d, v5140, v7362); // int[6,4,2,0] + const V veca8 = InterleaveLower(d, vd9c8, vfbea); // int[e,c,a,8] + const V v7531 = InterleaveUpper(d, v5140, v7362); // int[7,5,3,1] + const V vfdb9 = InterleaveUpper(d, vd9c8, vfbea); // int[f,d,b,9] + + const V64 v10L = BitCast(d64, InterleaveLower(d, v6420, v7531)); // v10[7..0] + const V64 v10U = BitCast(d64, InterleaveLower(d, veca8, vfdb9)); // v10[f..8] + const V64 v32L = BitCast(d64, InterleaveUpper(d, v6420, v7531)); // v32[7..0] + const V64 v32U = BitCast(d64, InterleaveUpper(d, veca8, vfdb9)); // v32[f..8] + + v0 = BitCast(d, InterleaveLower(d64, v10L, v10U)); + v1 = BitCast(d, 
InterleaveUpper(d64, v10L, v10U)); + v2 = BitCast(d, InterleaveLower(d64, v32L, v32U)); + v3 = BitCast(d, InterleaveUpper(d64, v32L, v32U)); +} + +template +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + // In the last step, we interleave by half of the block size, which is usually + // 8 bytes but half that for 8-bit x8 vectors. + using TW = hwy::UnsignedFromSize; + const Repartition dw; + using VW = VFromD; + + // (Comments are for 256-bit vectors.) + // 8 lanes per block; the lowest four blocks are at the bottom of vA..vD. + VFromD vA; // v3210[9]v3210[8] v3210[1]v3210[0] + VFromD vB; // v3210[b]v3210[a] v3210[3]v3210[2] + VFromD vC; // v3210[d]v3210[c] v3210[5]v3210[4] + VFromD vD; // v3210[f]v3210[e] v3210[7]v3210[6] + detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); + + const VFromD va820 = InterleaveLower(d, vA, vB); // v3210[a,8] v3210[2,0] + const VFromD vec64 = InterleaveLower(d, vC, vD); // v3210[e,c] v3210[6,4] + const VFromD vb931 = InterleaveUpper(d, vA, vB); // v3210[b,9] v3210[3,1] + const VFromD vfd75 = InterleaveUpper(d, vC, vD); // v3210[f,d] v3210[7,5] + + const VW v10_b830 = // v10[b..8] v10[3..0] + BitCast(dw, InterleaveLower(d, va820, vb931)); + const VW v10_fc74 = // v10[f..c] v10[7..4] + BitCast(dw, InterleaveLower(d, vec64, vfd75)); + const VW v32_b830 = // v32[b..8] v32[3..0] + BitCast(dw, InterleaveUpper(d, va820, vb931)); + const VW v32_fc74 = // v32[f..c] v32[7..4] + BitCast(dw, InterleaveUpper(d, vec64, vfd75)); + + v0 = BitCast(d, InterleaveLower(dw, v10_b830, v10_fc74)); + v1 = BitCast(d, InterleaveUpper(dw, v10_b830, v10_fc74)); + v2 = BitCast(d, InterleaveLower(dw, v32_b830, v32_fc74)); + v3 = BitCast(d, InterleaveUpper(dw, v32_b830, v32_fc74)); +} + +template +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + using V = VFromD; + V vA; // v3210[4] v3210[0] + V 
vB; // v3210[5] v3210[1] + V vC; // v3210[6] v3210[2] + V vD; // v3210[7] v3210[3] + detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); + const V v10e = InterleaveLower(d, vA, vC); // v1[6,4] v0[6,4] v1[2,0] v0[2,0] + const V v10o = InterleaveLower(d, vB, vD); // v1[7,5] v0[7,5] v1[3,1] v0[3,1] + const V v32e = InterleaveUpper(d, vA, vC); // v3[6,4] v2[6,4] v3[2,0] v2[2,0] + const V v32o = InterleaveUpper(d, vB, vD); // v3[7,5] v2[7,5] v3[3,1] v2[3,1] + + v0 = InterleaveLower(d, v10e, v10o); + v1 = InterleaveUpper(d, v10e, v10o); + v2 = InterleaveLower(d, v32e, v32o); + v3 = InterleaveUpper(d, v32e, v32o); +} + +template +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + VFromD vA, vB, vC, vD; + detail::LoadTransposedBlocks4(d, unaligned, vA, vB, vC, vD); + v0 = InterleaveLower(d, vA, vC); + v1 = InterleaveUpper(d, vA, vC); + v2 = InterleaveLower(d, vB, vD); + v3 = InterleaveUpper(d, vB, vD); +} + +// Any T x1 +template , HWY_IF_LANES_D(D, 1)> +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); + v3 = LoadU(d, unaligned + 3); +} + +// ------------------------------ StoreInterleaved2 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_INLINE void StoreTransposedBlocks2(VFromD A, VFromD B, D d, + TFromD* HWY_RESTRICT unaligned) { + constexpr size_t kN = MaxLanes(d); + StoreU(A, d, unaligned + 0 * kN); + StoreU(B, d, unaligned + 1 * kN); +} + +} // namespace detail + +// >= 128 bit vector +template +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + TFromD* HWY_RESTRICT unaligned) { + const auto v10L = InterleaveLower(d, v0, v1); // .. v1[0] v0[0] + const auto v10U = InterleaveUpper(d, v0, v1); // .. 
v1[kN/2] v0[kN/2] + detail::StoreTransposedBlocks2(v10L, v10U, d, unaligned); +} + +// <= 64 bits +template +HWY_API void StoreInterleaved2(V part0, V part1, D d, + TFromD* HWY_RESTRICT unaligned) { + const Twice d2; + const auto v0 = ZeroExtendVector(d2, part0); + const auto v1 = ZeroExtendVector(d2, part1); + const auto v10 = InterleaveLower(d2, v0, v1); + StoreU(v10, d2, unaligned); +} + +// ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, +// TableLookupBytes) + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_INLINE void StoreTransposedBlocks3(VFromD A, VFromD B, VFromD C, + D d, TFromD* HWY_RESTRICT unaligned) { + constexpr size_t kN = MaxLanes(d); + StoreU(A, d, unaligned + 0 * kN); + StoreU(B, d, unaligned + 1 * kN); + StoreU(C, d, unaligned + 2 * kN); +} + +} // namespace detail + +// >= 128-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + using TU = TFromD; + using VU = VFromD; + const VU k5 = Set(du, TU{5}); + const VU k6 = Set(du, TU{6}); + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v0[5], v2[4],v1[4],v0[4] .. v2[0],v1[0],v0[0]. We're expanding v0 lanes + // to their place, with 0x80 so lanes to be filled from other vectors are 0 + // to enable blending by ORing together. + const VFromD shuf_A0 = + Dup128VecFromValues(du, 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, 3, + 0x80, 0x80, 4, 0x80, 0x80, 5); + // Cannot reuse shuf_A0 because it contains 5. + const VFromD shuf_A1 = + Dup128VecFromValues(du, 0x80, 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, + 3, 0x80, 0x80, 4, 0x80, 0x80); + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. 
+ // cannot reuse shuf_A0 (has 5) + const VU shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); + const VU vA0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 + const VU vA1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. + const VU vA2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. + const VFromD A = BitCast(d, vA0 | vA1 | vA2); + + // B: v1[10],v0[10], v2[9],v1[9],v0[9] .. , v2[6],v1[6],v0[6], v2[5],v1[5] + const VU shuf_B0 = shuf_A2 + k6; // .A..9..8..7..6.. + const VU shuf_B1 = shuf_A0 + k5; // A..9..8..7..6..5 + const VU shuf_B2 = shuf_A1 + k5; // ..9..8..7..6..5. + const VU vB0 = TableLookupBytesOr0(v0, shuf_B0); + const VU vB1 = TableLookupBytesOr0(v1, shuf_B1); + const VU vB2 = TableLookupBytesOr0(v2, shuf_B2); + const VFromD B = BitCast(d, vB0 | vB1 | vB2); + + // C: v2[15],v1[15],v0[15], v2[11],v1[11],v0[11], v2[10] + const VU shuf_C0 = shuf_B2 + k6; // ..F..E..D..C..B. + const VU shuf_C1 = shuf_B0 + k5; // .F..E..D..C..B.. + const VU shuf_C2 = shuf_B1 + k5; // F..E..D..C..B..A + const VU vC0 = TableLookupBytesOr0(v0, shuf_C0); + const VU vC1 = TableLookupBytesOr0(v1, shuf_C1); + const VU vC2 = TableLookupBytesOr0(v2, shuf_C2); + const VFromD C = BitCast(d, vC0 | vC1 | vC2); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + const Repartition du8; + using VU8 = VFromD; + const VU8 k2 = Set(du8, uint8_t{2 * sizeof(TFromD)}); + const VU8 k3 = Set(du8, uint8_t{3 * sizeof(TFromD)}); + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be + // filled from other vectors are 0 for blending. Note that these are byte + // indices for 16-bit lanes. 
+ const VFromD shuf_A1 = + Dup128VecFromValues(du8, 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, 2, 3, + 0x80, 0x80, 0x80, 0x80, 4, 5); + const VFromD shuf_A2 = + Dup128VecFromValues(du8, 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, + 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80); + + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU8 shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); + + const VU8 A0 = TableLookupBytesOr0(v0, shuf_A0); + const VU8 A1 = TableLookupBytesOr0(v1, shuf_A1); + const VU8 A2 = TableLookupBytesOr0(v2, shuf_A2); + const VFromD A = BitCast(d, A0 | A1 | A2); + + // B: v0[5] v2[4],v1[4],v0[4], v2[3],v1[3],v0[3], v2[2] + const VU8 shuf_B0 = shuf_A1 + k3; // 5..4..3. + const VU8 shuf_B1 = shuf_A2 + k3; // ..4..3.. + const VU8 shuf_B2 = shuf_A0 + k2; // .4..3..2 + const VU8 vB0 = TableLookupBytesOr0(v0, shuf_B0); + const VU8 vB1 = TableLookupBytesOr0(v1, shuf_B1); + const VU8 vB2 = TableLookupBytesOr0(v2, shuf_B2); + const VFromD B = BitCast(d, vB0 | vB1 | vB2); + + // C: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] + const VU8 shuf_C0 = shuf_B1 + k3; // ..7..6.. + const VU8 shuf_C1 = shuf_B2 + k3; // .7..6..5 + const VU8 shuf_C2 = shuf_B0 + k2; // 7..6..5. 
+ const VU8 vC0 = TableLookupBytesOr0(v0, shuf_C0); + const VU8 vC1 = TableLookupBytesOr0(v1, shuf_C1); + const VU8 vC2 = TableLookupBytesOr0(v2, shuf_C2); + const VFromD C = BitCast(d, vC0 | vC1 | vC2); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 32-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + const RepartitionToWide dw; + + const VFromD v10_v00 = InterleaveLower(d, v0, v1); + const VFromD v01_v20 = OddEven(v0, v2); + // A: v0[1], v2[0],v1[0],v0[0] (<- lane 0) + const VFromD A = BitCast( + d, InterleaveLower(dw, BitCast(dw, v10_v00), BitCast(dw, v01_v20))); + + const VFromD v1_321 = ShiftRightLanes<1>(d, v1); + const VFromD v0_32 = ShiftRightLanes<2>(d, v0); + const VFromD v21_v11 = OddEven(v2, v1_321); + const VFromD v12_v02 = OddEven(v1_321, v0_32); + // B: v1[2],v0[2], v2[1],v1[1] + const VFromD B = BitCast( + d, InterleaveLower(dw, BitCast(dw, v21_v11), BitCast(dw, v12_v02))); + + // Notation refers to the upper 2 lanes of the vector for InterleaveUpper. + const VFromD v23_v13 = OddEven(v2, v1_321); + const VFromD v03_v22 = OddEven(v0, v2); + // C: v2[3],v1[3],v0[3], v2[2] + const VFromD C = BitCast( + d, InterleaveUpper(dw, BitCast(dw, v03_v22), BitCast(dw, v23_v13))); + + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// >= 128-bit vector, 64-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + const VFromD A = InterleaveLower(d, v0, v1); + const VFromD B = OddEven(v0, v2); + const VFromD C = InterleaveUpper(d, v1, v2); + detail::StoreTransposedBlocks3(A, B, C, d, unaligned); +} + +// 64-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, + VFromD part2, D d, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and first result. 
+ constexpr size_t kFullN = 16 / sizeof(TFromD); + const Full128 du; + using VU = VFromD; + const Full128> d_full; + const VU k5 = Set(du, uint8_t{5}); + const VU k6 = Set(du, uint8_t{6}); + + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + + // Interleave (v0,v1,v2) to (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. 0x80 so lanes to be + // filled from other vectors are 0 for blending. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // + 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0, 0x80, 0x80, 1, 0x80, // + 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; + // The interleaved vectors will be named A, B, C; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU shuf_A0 = Load(du, tbl_v0); + const VU shuf_A1 = Load(du, tbl_v1); // cannot reuse shuf_A0 (5 in MSB) + const VU shuf_A2 = CombineShiftRightBytes<15>(du, shuf_A1, shuf_A1); + const VU A0 = TableLookupBytesOr0(v0, shuf_A0); // 5..4..3..2..1..0 + const VU A1 = TableLookupBytesOr0(v1, shuf_A1); // ..4..3..2..1..0. + const VU A2 = TableLookupBytesOr0(v2, shuf_A2); // .4..3..2..1..0.. + const auto A = BitCast(d_full, A0 | A1 | A2); + StoreU(A, d_full, unaligned + 0 * kFullN); + + // Second (HALF) vector: v2[7],v1[7],v0[7], v2[6],v1[6],v0[6], v2[5],v1[5] + const VU shuf_B0 = shuf_A2 + k6; // ..7..6.. + const VU shuf_B1 = shuf_A0 + k5; // .7..6..5 + const VU shuf_B2 = shuf_A1 + k5; // 7..6..5. 
+ const VU vB0 = TableLookupBytesOr0(v0, shuf_B0); + const VU vB1 = TableLookupBytesOr0(v1, shuf_B1); + const VU vB2 = TableLookupBytesOr0(v2, shuf_B2); + const VFromD B{BitCast(d_full, vB0 | vB1 | vB2).raw}; + StoreU(B, d, unaligned + 1 * kFullN); +} + +// 64-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, + VFromD part2, D dh, + TFromD* HWY_RESTRICT unaligned) { + const Twice d_full; + const Full128 du8; + using VU8 = VFromD; + const VU8 k2 = Set(du8, uint8_t{2 * sizeof(TFromD)}); + const VU8 k3 = Set(du8, uint8_t{3 * sizeof(TFromD)}); + + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + + // Interleave part (v0,v1,v2) to full (MSB on left, lane 0 on right): + // v1[2],v0[2], v2[1],v1[1],v0[1], v2[0],v1[0],v0[0]. We're expanding v0 lanes + // to their place, with 0x80 so lanes to be filled from other vectors are 0 + // to enable blending by ORing together. + alignas(16) static constexpr uint8_t tbl_v1[16] = { + 0x80, 0x80, 0, 1, 0x80, 0x80, 0x80, 0x80, + 2, 3, 0x80, 0x80, 0x80, 0x80, 4, 5}; + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + + // The interleaved vectors will be named A, B; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU8 shuf_A1 = Load(du8, tbl_v1); // 2..1..0. + // .2..1..0 + const VU8 shuf_A0 = CombineShiftRightBytes<2>(du8, shuf_A1, shuf_A1); + const VU8 shuf_A2 = Load(du8, tbl_v2); // ..1..0.. + + const VU8 A0 = TableLookupBytesOr0(v0, shuf_A0); + const VU8 A1 = TableLookupBytesOr0(v1, shuf_A1); + const VU8 A2 = TableLookupBytesOr0(v2, shuf_A2); + const VFromD A = BitCast(d_full, A0 | A1 | A2); + StoreU(A, d_full, unaligned); + + // Second (HALF) vector: v2[3],v1[3],v0[3], v2[2] + const VU8 shuf_B0 = shuf_A1 + k3; // ..3. + const VU8 shuf_B1 = shuf_A2 + k3; // .3.. 
+ const VU8 shuf_B2 = shuf_A0 + k2; // 3..2 + const VU8 vB0 = TableLookupBytesOr0(v0, shuf_B0); + const VU8 vB1 = TableLookupBytesOr0(v1, shuf_B1); + const VU8 vB2 = TableLookupBytesOr0(v2, shuf_B2); + const VFromD B = BitCast(d_full, vB0 | vB1 | vB2); + StoreU(VFromD{B.raw}, dh, unaligned + MaxLanes(d_full)); +} + +// 64-bit vector, 32-bit lanes +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + // (same code as 128-bit vector, 64-bit lanes) + const VFromD v10_v00 = InterleaveLower(d, v0, v1); + const VFromD v01_v20 = OddEven(v0, v2); + const VFromD v21_v11 = InterleaveUpper(d, v1, v2); + constexpr size_t kN = MaxLanes(d); + StoreU(v10_v00, d, unaligned + 0 * kN); + StoreU(v01_v20, d, unaligned + 1 * kN); + StoreU(v21_v11, d, unaligned + 2 * kN); +} + +// 64-bit lanes are handled by the N=1 case below. + +// <= 32-bit vector, 8-bit lanes +template +HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, + VFromD part2, D d, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and result. + const Full128 du; + using VU = VFromD; + const Full128> d_full; + + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + + // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 + // so lanes to be filled from other vectors are 0 to enable blending by ORing + // together. + alignas(16) static constexpr uint8_t tbl_v0[16] = { + 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, + 0x80, 3, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}; + // The interleaved vector will be named A; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. 
+ const VU shuf_A0 = Load(du, tbl_v0); + const VU shuf_A1 = CombineShiftRightBytes<15>(du, shuf_A0, shuf_A0); + const VU shuf_A2 = CombineShiftRightBytes<14>(du, shuf_A0, shuf_A0); + const VU A0 = TableLookupBytesOr0(v0, shuf_A0); // ......3..2..1..0 + const VU A1 = TableLookupBytesOr0(v1, shuf_A1); // .....3..2..1..0. + const VU A2 = TableLookupBytesOr0(v2, shuf_A2); // ....3..2..1..0.. + const VFromD A = BitCast(d_full, A0 | A1 | A2); + alignas(16) TFromD buf[MaxLanes(d_full)]; + StoreU(A, d_full, buf); + CopyBytes(buf, unaligned); +} + +// 32-bit vector, 16-bit lanes +template +HWY_API void StoreInterleaved3(VFromD part0, VFromD part1, + VFromD part2, D d, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors for the shuffles and result. + const Full128 du8; + using VU8 = VFromD; + const Full128> d_full; + + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + + // Interleave (v0,v1,v2). We're expanding v0 lanes to their place, with 0x80 + // so lanes to be filled from other vectors are 0 to enable blending by ORing + // together. + alignas(16) static constexpr uint8_t tbl_v2[16] = { + 0x80, 0x80, 0x80, 0x80, 0, 1, 0x80, 0x80, + 0x80, 0x80, 2, 3, 0x80, 0x80, 0x80, 0x80}; + // The interleaved vector will be named A; temporaries with suffix + // 0..2 indicate which input vector's lanes they hold. + const VU8 shuf_A2 = Load(du8, tbl_v2); // ..1..0.. + const VU8 shuf_A1 = + CombineShiftRightBytes<2>(du8, shuf_A2, shuf_A2); // ...1..0. + const VU8 shuf_A0 = + CombineShiftRightBytes<4>(du8, shuf_A2, shuf_A2); // ....1..0 + const VU8 A0 = TableLookupBytesOr0(v0, shuf_A0); // ..1..0 + const VU8 A1 = TableLookupBytesOr0(v1, shuf_A1); // .1..0. + const VU8 A2 = TableLookupBytesOr0(v2, shuf_A2); // 1..0.. 
+ const auto A = BitCast(d_full, A0 | A1 | A2); + alignas(16) TFromD buf[MaxLanes(d_full)]; + StoreU(A, d_full, buf); + CopyBytes(buf, unaligned); +} + +// Single-element vector, any lane size: just store directly +template +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + TFromD* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); +} + +// ------------------------------ StoreInterleaved4 + +namespace detail { + +// Default for <= 128-bit vectors; x86_256 and x86_512 have their own overload. +template +HWY_INLINE void StoreTransposedBlocks4(VFromD vA, VFromD vB, VFromD vC, + VFromD vD, D d, + TFromD* HWY_RESTRICT unaligned) { + constexpr size_t kN = MaxLanes(d); + StoreU(vA, d, unaligned + 0 * kN); + StoreU(vB, d, unaligned + 1 * kN); + StoreU(vC, d, unaligned + 2 * kN); + StoreU(vD, d, unaligned + 3 * kN); +} + +} // namespace detail + +// >= 128-bit vector, 8..32-bit lanes +template +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, + TFromD* HWY_RESTRICT unaligned) { + const RepartitionToWide dw; + const auto v10L = ZipLower(dw, v0, v1); // .. v1[0] v0[0] + const auto v32L = ZipLower(dw, v2, v3); + const auto v10U = ZipUpper(dw, v0, v1); + const auto v32U = ZipUpper(dw, v2, v3); + // The interleaved vectors are vA, vB, vC, vD. + const VFromD vA = BitCast(d, InterleaveLower(dw, v10L, v32L)); // 3210 + const VFromD vB = BitCast(d, InterleaveUpper(dw, v10L, v32L)); + const VFromD vC = BitCast(d, InterleaveLower(dw, v10U, v32U)); + const VFromD vD = BitCast(d, InterleaveUpper(dw, v10U, v32U)); + detail::StoreTransposedBlocks4(vA, vB, vC, vD, d, unaligned); +} + +// >= 128-bit vector, 64-bit lanes +template +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, + TFromD* HWY_RESTRICT unaligned) { + // The interleaved vectors are vA, vB, vC, vD. 
+ const VFromD vA = InterleaveLower(d, v0, v1); // v1[0] v0[0] + const VFromD vB = InterleaveLower(d, v2, v3); + const VFromD vC = InterleaveUpper(d, v0, v1); + const VFromD vD = InterleaveUpper(d, v2, v3); + detail::StoreTransposedBlocks4(vA, vB, vC, vD, d, unaligned); +} + +// 64-bit vector, 8..32-bit lanes +template +HWY_API void StoreInterleaved4(VFromD part0, VFromD part1, + VFromD part2, VFromD part3, D /* tag */, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Full128> d_full; + const RepartitionToWide dw; + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + const VFromD v3{part3.raw}; + const auto v10 = ZipLower(dw, v0, v1); // v1[0] v0[0] + const auto v32 = ZipLower(dw, v2, v3); + const auto A = BitCast(d_full, InterleaveLower(dw, v10, v32)); + const auto B = BitCast(d_full, InterleaveUpper(dw, v10, v32)); + StoreU(A, d_full, unaligned); + StoreU(B, d_full, unaligned + MaxLanes(d_full)); +} + +// 64-bit vector, 64-bit lane +template +HWY_API void StoreInterleaved4(VFromD part0, VFromD part1, + VFromD part2, VFromD part3, D /* tag */, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. + const Full128> d_full; + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + const VFromD v3{part3.raw}; + const auto A = InterleaveLower(d_full, v0, v1); // v1[0] v0[0] + const auto B = InterleaveLower(d_full, v2, v3); + StoreU(A, d_full, unaligned); + StoreU(B, d_full, unaligned + MaxLanes(d_full)); +} + +// <= 32-bit vectors +template +HWY_API void StoreInterleaved4(VFromD part0, VFromD part1, + VFromD part2, VFromD part3, D d, + TFromD* HWY_RESTRICT unaligned) { + // Use full vectors to reduce the number of stores. 
+ const Full128> d_full; + const RepartitionToWide dw; + const VFromD v0{part0.raw}; + const VFromD v1{part1.raw}; + const VFromD v2{part2.raw}; + const VFromD v3{part3.raw}; + const auto v10 = ZipLower(dw, v0, v1); // .. v1[0] v0[0] + const auto v32 = ZipLower(dw, v2, v3); + const auto v3210 = BitCast(d_full, InterleaveLower(dw, v10, v32)); + alignas(16) TFromD buf[MaxLanes(d_full)]; + StoreU(v3210, d_full, buf); + CopyBytes(buf, unaligned); +} + +#endif // HWY_NATIVE_LOAD_STORE_INTERLEAVED + +// ------------------------------ PairwiseAdd/PairwiseSub +#if (defined(HWY_NATIVE_PAIRWISE_ADD) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PAIRWISE_ADD +#undef HWY_NATIVE_PAIRWISE_ADD +#else +#define HWY_NATIVE_PAIRWISE_ADD +#endif + +template (), HWY_IF_LANES_GT_D(D, 1)> +HWY_API V PairwiseAdd(D d, V a, V b) { + return Add(InterleaveEven(d, a, b), InterleaveOdd(d, a, b)); +} + +#endif + +#if (defined(HWY_NATIVE_PAIRWISE_SUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PAIRWISE_SUB +#undef HWY_NATIVE_PAIRWISE_SUB +#else +#define HWY_NATIVE_PAIRWISE_SUB +#endif + +template (), HWY_IF_LANES_GT_D(D, 1)> +HWY_API V PairwiseSub(D d, V a, V b) { + return Sub(InterleaveOdd(d, a, b), InterleaveEven(d, a, b)); +} + +#endif + +// Load/StoreInterleaved for special floats. Requires HWY_GENERIC_IF_EMULATED_D +// is defined such that it is true only for types that actually require these +// generic implementations. 
+#if HWY_IDE || (defined(HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED) == \ + defined(HWY_TARGET_TOGGLE) && \ + defined(HWY_GENERIC_IF_EMULATED_D)) +#ifdef HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED +#endif +#if HWY_IDE +#define HWY_GENERIC_IF_EMULATED_D(D) int +#endif + +template > +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + const RebindToUnsigned du; + VFromD vu0, vu1; + LoadInterleaved2(du, detail::U16LanePointer(unaligned), vu0, vu1); + v0 = BitCast(d, vu0); + v1 = BitCast(d, vu1); +} + +template > +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + const RebindToUnsigned du; + VFromD vu0, vu1, vu2; + LoadInterleaved3(du, detail::U16LanePointer(unaligned), vu0, vu1, vu2); + v0 = BitCast(d, vu0); + v1 = BitCast(d, vu1); + v2 = BitCast(d, vu2); +} + +template > +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + const RebindToUnsigned du; + VFromD vu0, vu1, vu2, vu3; + LoadInterleaved4(du, detail::U16LanePointer(unaligned), vu0, vu1, vu2, vu3); + v0 = BitCast(d, vu0); + v1 = BitCast(d, vu1); + v2 = BitCast(d, vu2); + v3 = BitCast(d, vu3); +} + +template > +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + StoreInterleaved2(BitCast(du, v0), BitCast(du, v1), du, + detail::U16LanePointer(unaligned)); +} + +template > +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + StoreInterleaved3(BitCast(du, v0), BitCast(du, v1), BitCast(du, v2), du, + detail::U16LanePointer(unaligned)); +} + +template > +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, T* HWY_RESTRICT unaligned) { + 
const RebindToUnsigned du; + StoreInterleaved4(BitCast(du, v0), BitCast(du, v1), BitCast(du, v2), + BitCast(du, v3), du, detail::U16LanePointer(unaligned)); +} + +#endif // HWY_NATIVE_LOAD_STORE_SPECIAL_FLOAT_INTERLEAVED + +// ------------------------------ LoadN + +#if (defined(HWY_NATIVE_LOAD_N) == defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_LOAD_N +#undef HWY_NATIVE_LOAD_N +#else +#define HWY_NATIVE_LOAD_N +#endif + +#if HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE +namespace detail { + +template +HWY_INLINE VFromD LoadNResizeBitCast(DTo d_to, DFrom d_from, + VFromD v) { +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3/SSE4, the LoadU operation will zero out any lanes of v.raw + // past the first (lowest-index) Lanes(d_from) lanes of v.raw if + // sizeof(decltype(v.raw)) > d_from.MaxBytes() is true + (void)d_from; + return ResizeBitCast(d_to, v); +#else + // On other targets such as PPC/NEON, the contents of any lanes past the first + // (lowest-index) Lanes(d_from) lanes of v.raw might be non-zero if + // sizeof(decltype(v.raw)) > d_from.MaxBytes() is true. + return ZeroExtendResizeBitCast(d_to, d_from, v); +#endif +} + +} // namespace detail + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + return (num_lanes > 0) ? LoadU(d, p) : Zero(d); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + return (num_lanes > 0) ? 
LoadU(d, p) : no; +} + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 1> d1; + + if (num_lanes >= 2) return LoadU(d, p); + if (num_lanes == 0) return Zero(d); + return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 1> d1; + + if (num_lanes >= 2) return LoadU(d, p); + if (num_lanes == 0) return no; + return InterleaveLower(ResizeBitCast(d, LoadU(d1, p)), no); +} + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 2> d2; + const Half d1; + + if (num_lanes >= 4) return LoadU(d, p); + if (num_lanes == 0) return Zero(d); + if (num_lanes == 1) return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); + + // Two or three lanes. + const VFromD v_lo = detail::LoadNResizeBitCast(d, d2, LoadU(d2, p)); + return (num_lanes == 2) ? v_lo : InsertLane(v_lo, 2, p[2]); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 2> d2; + + if (num_lanes >= 4) return LoadU(d, p); + if (num_lanes == 0) return no; + if (num_lanes == 1) return InsertLane(no, 0, p[0]); + + // Two or three lanes. + const VFromD v_lo = + ConcatUpperLower(d, no, ResizeBitCast(d, LoadU(d2, p))); + return (num_lanes == 2) ? 
v_lo : InsertLane(v_lo, 2, p[2]); +} + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 4> d4; + const Half d2; + const Half d1; + + if (num_lanes >= 8) return LoadU(d, p); + if (num_lanes == 0) return Zero(d); + if (num_lanes == 1) return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); + + const size_t leading_len = num_lanes & 4; + VFromD v_trailing = Zero(d4); + + if ((num_lanes & 2) != 0) { + const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); + if ((num_lanes & 1) != 0) { + v_trailing = Combine( + d4, + detail::LoadNResizeBitCast(d2, d1, LoadU(d1, p + leading_len + 2)), + v_trailing_lo2); + } else { + v_trailing = detail::LoadNResizeBitCast(d4, d2, v_trailing_lo2); + } + } else if ((num_lanes & 1) != 0) { + v_trailing = detail::LoadNResizeBitCast(d4, d1, LoadU(d1, p + leading_len)); + } + + if (leading_len != 0) { + return Combine(d, v_trailing, LoadU(d4, p)); + } else { + return detail::LoadNResizeBitCast(d, d4, v_trailing); + } +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 4> d4; + const Half d2; + const Half d1; + + if (num_lanes >= 8) return LoadU(d, p); + if (num_lanes == 0) return no; + if (num_lanes == 1) return InsertLane(no, 0, p[0]); + + const size_t leading_len = num_lanes & 4; + VFromD v_trailing = ResizeBitCast(d4, no); + + if ((num_lanes & 2) != 0) { + const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); + if ((num_lanes & 1) != 0) { + v_trailing = Combine( + d4, + InterleaveLower(ResizeBitCast(d2, LoadU(d1, p + leading_len + 2)), + ResizeBitCast(d2, no)), + v_trailing_lo2); + } else { + v_trailing = ConcatUpperLower(d4, ResizeBitCast(d4, no), + ResizeBitCast(d4, v_trailing_lo2)); + } + } else if ((num_lanes & 1) != 0) { + v_trailing = InsertLane(ResizeBitCast(d4, no), 0, p[leading_len]); + } + + if (leading_len != 0) { + return Combine(d, v_trailing, LoadU(d4, p)); + } else { + return 
ConcatUpperLower(d, no, ResizeBitCast(d, v_trailing)); + } +} + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 8> d8; + const Half d4; + const Half d2; + const Half d1; + + if (num_lanes >= 16) return LoadU(d, p); + if (num_lanes == 0) return Zero(d); + if (num_lanes == 1) return detail::LoadNResizeBitCast(d, d1, LoadU(d1, p)); + + const size_t leading_len = num_lanes & 12; + VFromD v_trailing = Zero(d4); + + if ((num_lanes & 2) != 0) { + const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); + if ((num_lanes & 1) != 0) { + v_trailing = Combine( + d4, + detail::LoadNResizeBitCast(d2, d1, LoadU(d1, p + leading_len + 2)), + v_trailing_lo2); + } else { + v_trailing = detail::LoadNResizeBitCast(d4, d2, v_trailing_lo2); + } + } else if ((num_lanes & 1) != 0) { + v_trailing = detail::LoadNResizeBitCast(d4, d1, LoadU(d1, p + leading_len)); + } + + if (leading_len != 0) { + if (leading_len >= 8) { + const VFromD v_hi7 = + ((leading_len & 4) != 0) + ? 
Combine(d8, v_trailing, LoadU(d4, p + 8)) + : detail::LoadNResizeBitCast(d8, d4, v_trailing); + return Combine(d, v_hi7, LoadU(d8, p)); + } else { + return detail::LoadNResizeBitCast(d, d8, + Combine(d8, v_trailing, LoadU(d4, p))); + } + } else { + return detail::LoadNResizeBitCast(d, d4, v_trailing); + } +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, 8> d8; + const Half d4; + const Half d2; + const Half d1; + + if (num_lanes >= 16) return LoadU(d, p); + if (num_lanes == 0) return no; + if (num_lanes == 1) return InsertLane(no, 0, p[0]); + + const size_t leading_len = num_lanes & 12; + VFromD v_trailing = ResizeBitCast(d4, no); + + if ((num_lanes & 2) != 0) { + const VFromD v_trailing_lo2 = LoadU(d2, p + leading_len); + if ((num_lanes & 1) != 0) { + v_trailing = Combine( + d4, + InterleaveLower(ResizeBitCast(d2, LoadU(d1, p + leading_len + 2)), + ResizeBitCast(d2, no)), + v_trailing_lo2); + } else { + v_trailing = ConcatUpperLower(d4, ResizeBitCast(d4, no), + ResizeBitCast(d4, v_trailing_lo2)); + } + } else if ((num_lanes & 1) != 0) { + v_trailing = InsertLane(ResizeBitCast(d4, no), 0, p[leading_len]); + } + + if (leading_len != 0) { + if (leading_len >= 8) { + const VFromD v_hi7 = + ((leading_len & 4) != 0) + ? Combine(d8, v_trailing, LoadU(d4, p + 8)) + : ConcatUpperLower(d8, ResizeBitCast(d8, no), + ResizeBitCast(d8, v_trailing)); + return Combine(d, v_hi7, LoadU(d8, p)); + } else { + return ConcatUpperLower( + d, ResizeBitCast(d, no), + ResizeBitCast(d, Combine(d8, v_trailing, LoadU(d4, p)))); + } + } else { + const Repartition du32; + // lowest 4 bytes from v_trailing, next 4 from no. 
+ const VFromD lo8 = + InterleaveLower(ResizeBitCast(du32, v_trailing), BitCast(du32, no)); + return ConcatUpperLower(d, ResizeBitCast(d, no), ResizeBitCast(d, lo8)); + } +} + +#if HWY_MAX_BYTES >= 32 + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if (num_lanes >= Lanes(d)) return LoadU(d, p); + + const Half dh; + const size_t half_N = Lanes(dh); + if (num_lanes <= half_N) { + return ZeroExtendVector(d, LoadN(dh, p, num_lanes)); + } else { + const VFromD v_lo = LoadU(dh, p); + const VFromD v_hi = LoadN(dh, p + half_N, num_lanes - half_N); + return Combine(d, v_hi, v_lo); + } +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if (num_lanes >= Lanes(d)) return LoadU(d, p); + + const Half dh; + const size_t half_N = Lanes(dh); + const VFromD no_h = LowerHalf(no); + if (num_lanes <= half_N) { + return ConcatUpperLower(d, no, + ResizeBitCast(d, LoadNOr(no_h, dh, p, num_lanes))); + } else { + const VFromD v_lo = LoadU(dh, p); + const VFromD v_hi = + LoadNOr(no_h, dh, p + half_N, num_lanes - half_N); + return Combine(d, v_hi, v_lo); + } +} + +#endif // HWY_MAX_BYTES >= 32 + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const RebindToUnsigned du; + return BitCast(d, LoadN(du, detail::U16LanePointer(p), num_lanes)); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const RebindToUnsigned du; + return BitCast( + d, LoadNOr(BitCast(du, no), du, detail::U16LanePointer(p), num_lanes)); +} + +#else // !HWY_MEM_OPS_MIGHT_FAULT || HWY_HAVE_SCALABLE + +// For SVE and non-sanitizer AVX-512; RVV has its own specialization. 
+template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { +#if HWY_MEM_OPS_MIGHT_FAULT + if (num_lanes <= 0) return Zero(d); +#endif + + return MaskedLoad(FirstN(d, num_lanes), d, p); +} + +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { +#if HWY_MEM_OPS_MIGHT_FAULT + if (num_lanes <= 0) return no; +#endif + + return MaskedLoadOr(no, FirstN(d, num_lanes), d, p); +} + +#endif // HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE +#endif // HWY_NATIVE_LOAD_N + +// ------------------------------ StoreN +#if (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_STORE_N +#undef HWY_NATIVE_STORE_N +#else +#define HWY_NATIVE_STORE_N +#endif + +#if HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE +namespace detail { + +template +HWY_INLINE VFromD StoreNGetUpperHalf(DH dh, VFromD> v) { + constexpr size_t kMinShrVectBytes = HWY_TARGET_IS_NEON ? 8 : 16; + const FixedTag d_shift; + return ResizeBitCast( + dh, ShiftRightBytes(d_shift, ResizeBitCast(d_shift, v))); +} + +template +HWY_INLINE VFromD StoreNGetUpperHalf(DH dh, VFromD> v) { + return UpperHalf(dh, v); +} + +} // namespace detail + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + if (max_lanes_to_store > 0) { + StoreU(v, d, p); + } +} + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + if (max_lanes_to_store > 1) { + StoreU(v, d, p); + } else if (max_lanes_to_store == 1) { + const FixedTag, 1> d1; + StoreU(LowerHalf(d1, v), d1, p); + } +} + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const FixedTag, 2> d2; + const Half d1; + + if (max_lanes_to_store > 1) { + if (max_lanes_to_store >= 4) { + StoreU(v, d, p); + } else { + StoreU(ResizeBitCast(d2, v), d2, p); + if (max_lanes_to_store == 3) { + StoreU(ResizeBitCast(d1, detail::StoreNGetUpperHalf(d2, v)), d1, p + 2); + } 
+ } + } else if (max_lanes_to_store == 1) { + StoreU(ResizeBitCast(d1, v), d1, p); + } +} + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const FixedTag, 4> d4; + const Half d2; + const Half d1; + + if (max_lanes_to_store <= 1) { + if (max_lanes_to_store == 1) { + StoreU(ResizeBitCast(d1, v), d1, p); + } + } else if (max_lanes_to_store >= 8) { + StoreU(v, d, p); + } else if (max_lanes_to_store >= 4) { + StoreU(LowerHalf(d4, v), d4, p); + StoreN(detail::StoreNGetUpperHalf(d4, v), d4, p + 4, + max_lanes_to_store - 4); + } else { + StoreN(LowerHalf(d4, v), d4, p, max_lanes_to_store); + } +} + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const FixedTag, 8> d8; + const Half d4; + const Half d2; + const Half d1; + + if (max_lanes_to_store <= 1) { + if (max_lanes_to_store == 1) { + StoreU(ResizeBitCast(d1, v), d1, p); + } + } else if (max_lanes_to_store >= 16) { + StoreU(v, d, p); + } else if (max_lanes_to_store >= 8) { + StoreU(LowerHalf(d8, v), d8, p); + StoreN(detail::StoreNGetUpperHalf(d8, v), d8, p + 8, + max_lanes_to_store - 8); + } else { + StoreN(LowerHalf(d8, v), d8, p, max_lanes_to_store); + } +} + +#if HWY_MAX_BYTES >= 32 +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const size_t N = Lanes(d); + if (max_lanes_to_store >= N) { + StoreU(v, d, p); + return; + } + + const Half dh; + const size_t half_N = Lanes(dh); + if (max_lanes_to_store <= half_N) { + StoreN(LowerHalf(dh, v), dh, p, max_lanes_to_store); + } else { + StoreU(LowerHalf(dh, v), dh, p); + StoreN(UpperHalf(dh, v), dh, p + half_N, max_lanes_to_store - half_N); + } +} +#endif // HWY_MAX_BYTES >= 32 + +#else // !HWY_MEM_OPS_MIGHT_FAULT || HWY_HAVE_SCALABLE +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const size_t N = Lanes(d); + const size_t clamped_max_lanes_to_store = 
HWY_MIN(max_lanes_to_store, N); +#if HWY_MEM_OPS_MIGHT_FAULT + if (clamped_max_lanes_to_store == 0) return; +#endif + + BlendedStore(v, FirstN(d, clamped_max_lanes_to_store), d, p); + + detail::MaybeUnpoison(p, clamped_max_lanes_to_store); +} +#endif // HWY_MEM_OPS_MIGHT_FAULT && !HWY_HAVE_SCALABLE + +#endif // (defined(HWY_NATIVE_STORE_N) == defined(HWY_TARGET_TOGGLE)) + +// ------------------------------ TruncateStore +#if (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_STORE_TRUNCATED +#undef HWY_NATIVE_STORE_TRUNCATED +#else +#define HWY_NATIVE_STORE_TRUNCATED +#endif + +template +HWY_API void TruncateStore(VFromD v, const D /*d*/, T* HWY_RESTRICT p) { + using DTo = Rebind; + DTo dsmall; + StoreU(TruncateTo(dsmall, v), dsmall, p); +} + +#endif // (defined(HWY_NATIVE_STORE_TRUNCATED) == defined(HWY_TARGET_TOGGLE)) + +// ------------------------------ Scatter + +#if (defined(HWY_NATIVE_SCATTER) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SCATTER +#undef HWY_NATIVE_SCATTER +#else +#define HWY_NATIVE_SCATTER +#endif + +template > +HWY_API void ScatterOffset(VFromD v, D d, T* HWY_RESTRICT base, + VFromD> offset) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN T lanes[MaxLanes(d)]; + Store(v, d, lanes); + + HWY_ALIGN TI offset_lanes[MaxLanes(d)]; + Store(offset, di, offset_lanes); + + uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < MaxLanes(d); ++i) { + CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); + } +} + +template > +HWY_API void ScatterIndex(VFromD v, D d, T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN T lanes[MaxLanes(d)]; + Store(v, d, lanes); + + HWY_ALIGN TI index_lanes[MaxLanes(d)]; + Store(index, di, index_lanes); + + for (size_t i = 0; i < MaxLanes(d); ++i) { + 
base[index_lanes[i]] = lanes[i]; + } +} + +template > +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D d, + T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN T lanes[MaxLanes(d)]; + Store(v, d, lanes); + + HWY_ALIGN TI index_lanes[MaxLanes(d)]; + Store(index, di, index_lanes); + + HWY_ALIGN TI mask_lanes[MaxLanes(di)]; + Store(BitCast(di, VecFromMask(d, m)), di, mask_lanes); + + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (mask_lanes[i]) base[index_lanes[i]] = lanes[i]; + } +} + +template > +HWY_API void ScatterIndexN(VFromD v, D d, T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_store) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (i < max_lanes_to_store) base[ExtractLane(index, i)] = ExtractLane(v, i); + } +} +#else +template > +HWY_API void ScatterIndexN(VFromD v, D d, T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_store) { + MaskedScatterIndex(v, FirstN(d, max_lanes_to_store), d, base, index); +} +#endif // (defined(HWY_NATIVE_SCATTER) == defined(HWY_TARGET_TOGGLE)) + +// ------------------------------ Gather + +#if (defined(HWY_NATIVE_GATHER) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_GATHER +#undef HWY_NATIVE_GATHER +#else +#define HWY_NATIVE_GATHER +#endif + +template > +HWY_API VFromD GatherOffset(D d, const T* HWY_RESTRICT base, + VFromD> offset) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN TI offset_lanes[MaxLanes(d)]; + Store(offset, di, offset_lanes); + + HWY_ALIGN T lanes[MaxLanes(d)]; + const uint8_t* base_bytes = reinterpret_cast(base); + for (size_t i = 0; i < MaxLanes(d); ++i) { + HWY_DASSERT(offset_lanes[i] >= 0); + CopyBytes(base_bytes + offset_lanes[i], 
&lanes[i]); + } + return Load(d, lanes); +} + +template > +HWY_API VFromD GatherIndex(D d, const T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN TI index_lanes[MaxLanes(d)]; + Store(index, di, index_lanes); + + HWY_ALIGN T lanes[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + HWY_DASSERT(index_lanes[i] >= 0); + lanes[i] = base[index_lanes[i]]; + } + return Load(d, lanes); +} + +template > +HWY_API VFromD MaskedGatherIndex(MFromD m, D d, + const T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN TI index_lanes[MaxLanes(di)]; + Store(index, di, index_lanes); + + HWY_ALIGN TI mask_lanes[MaxLanes(di)]; + Store(BitCast(di, VecFromMask(d, m)), di, mask_lanes); + + HWY_ALIGN T lanes[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + HWY_DASSERT(index_lanes[i] >= 0); + lanes[i] = mask_lanes[i] ? base[index_lanes[i]] : T{0}; + } + return Load(d, lanes); +} + +template > +HWY_API VFromD MaskedGatherIndexOr(VFromD no, MFromD m, D d, + const T* HWY_RESTRICT base, + VFromD> index) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + HWY_ALIGN TI index_lanes[MaxLanes(di)]; + Store(index, di, index_lanes); + + HWY_ALIGN TI mask_lanes[MaxLanes(di)]; + Store(BitCast(di, VecFromMask(d, m)), di, mask_lanes); + + HWY_ALIGN T no_lanes[MaxLanes(d)]; + Store(no, d, no_lanes); + + HWY_ALIGN T lanes[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + HWY_DASSERT(index_lanes[i] >= 0); + lanes[i] = mask_lanes[i] ? 
base[index_lanes[i]] : no_lanes[i]; + } + return Load(d, lanes); +} + +template > +HWY_API VFromD GatherIndexN(D d, const T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_load) { + return GatherIndexNOr(Zero(d), d, base, index, max_lanes_to_load); +} + +template > +HWY_API VFromD GatherIndexNOr(VFromD no, D d, const T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_load) { + const RebindToSigned di; + using TI = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + + VFromD v = no; + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (i < max_lanes_to_load) + v = InsertLane(v, i, base[ExtractLane(index, i)]); + } + return v; +} +#else +template > +HWY_API VFromD GatherIndexN(D d, const T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_load) { + return MaskedGatherIndex(FirstN(d, max_lanes_to_load), d, base, index); +} +template > +HWY_API VFromD GatherIndexNOr(VFromD no, D d, const T* HWY_RESTRICT base, + VFromD> index, + const size_t max_lanes_to_load) { + return MaskedGatherIndexOr(no, FirstN(d, max_lanes_to_load), d, base, index); +} +#endif // (defined(HWY_NATIVE_GATHER) == defined(HWY_TARGET_TOGGLE)) + +// ------------------------------ Integer AbsDiff and SumsOf8AbsDiff + +#if (defined(HWY_NATIVE_INTEGER_ABS_DIFF) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INTEGER_ABS_DIFF +#undef HWY_NATIVE_INTEGER_ABS_DIFF +#else +#define HWY_NATIVE_INTEGER_ABS_DIFF +#endif + +template +HWY_API V AbsDiff(V a, V b) { + return Sub(Max(a, b), Min(a, b)); +} + +#endif // HWY_NATIVE_INTEGER_ABS_DIFF + +#if (defined(HWY_NATIVE_SUMS_OF_8_ABS_DIFF) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#endif + +template ), + HWY_IF_V_SIZE_GT_D(DFromV, (HWY_TARGET == HWY_SCALAR ? 
0 : 4))> +HWY_API Vec>> SumsOf8AbsDiff(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWideX3 dw; + + return BitCast(dw, SumsOf8(BitCast(du, AbsDiff(a, b)))); +} + +#endif // HWY_NATIVE_SUMS_OF_8_ABS_DIFF + +// ------------------------------ SaturatedAdd/SaturatedSub for UI32/UI64 + +#if (defined(HWY_NATIVE_I32_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +template )> +HWY_API V SaturatedAdd(V a, V b) { + const DFromV d; + const auto sum = Add(a, b); + const auto overflow_mask = AndNot(Xor(a, b), Xor(a, sum)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, sum); +} + +template )> +HWY_API V SaturatedSub(V a, V b) { + const DFromV d; + const auto diff = Sub(a, b); + const auto overflow_mask = And(Xor(a, b), Xor(a, diff)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, diff); +} + +#endif // HWY_NATIVE_I32_SATURATED_ADDSUB + +#if (defined(HWY_NATIVE_I64_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +template )> +HWY_API V SaturatedAdd(V a, V b) { + const DFromV d; + const auto sum = Add(a, b); + const auto overflow_mask = AndNot(Xor(a, b), Xor(a, sum)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, sum); +} + +template )> +HWY_API V SaturatedSub(V a, V b) { + const DFromV d; + const auto diff = Sub(a, b); + const auto overflow_mask = And(Xor(a, b), Xor(a, diff)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, 
overflow_result, diff); +} + +#endif // HWY_NATIVE_I64_SATURATED_ADDSUB + +#if (defined(HWY_NATIVE_U32_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB +#undef HWY_NATIVE_U32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U32_SATURATED_ADDSUB +#endif + +template )> +HWY_API V SaturatedAdd(V a, V b) { + return Add(a, Min(b, Not(a))); +} + +template )> +HWY_API V SaturatedSub(V a, V b) { + return Sub(a, Min(a, b)); +} + +#endif // HWY_NATIVE_U32_SATURATED_ADDSUB + +#if (defined(HWY_NATIVE_U64_SATURATED_ADDSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB +#undef HWY_NATIVE_U64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U64_SATURATED_ADDSUB +#endif + +template )> +HWY_API V SaturatedAdd(V a, V b) { + return Add(a, Min(b, Not(a))); +} + +template )> +HWY_API V SaturatedSub(V a, V b) { + return Sub(a, Min(a, b)); +} + +#endif // HWY_NATIVE_U64_SATURATED_ADDSUB + +// ------------------------------ Unsigned to signed demotions + +template , DN>>, + hwy::EnableIf<(sizeof(TFromD) < sizeof(TFromV))>* = nullptr, + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_D(DFromV))> +HWY_API VFromD DemoteTo(DN dn, V v) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned dn_u; + + // First, do a signed to signed demotion. This will convert any values + // that are greater than hwy::HighestValue>>() to a + // negative value. + const auto i2i_demote_result = DemoteTo(dn, BitCast(di, v)); + + // Second, convert any negative values to hwy::HighestValue>() + // using an unsigned Min operation. 
+ const auto max_signed_val = Set(dn, hwy::HighestValue>()); + + return BitCast( + dn, Min(BitCast(dn_u, i2i_demote_result), BitCast(dn_u, max_signed_val))); +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template , DN>>, + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_D(DFromV))> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned dn_u; + + // First, do a signed to signed demotion. This will convert any values + // that are greater than hwy::HighestValue>>() to a + // negative value. + const auto i2i_demote_result = + ReorderDemote2To(dn, BitCast(di, a), BitCast(di, b)); + + // Second, convert any negative values to hwy::HighestValue>() + // using an unsigned Min operation. + const auto max_signed_val = Set(dn, hwy::HighestValue>()); + + return BitCast( + dn, Min(BitCast(dn_u, i2i_demote_result), BitCast(dn_u, max_signed_val))); +} +#endif + +// ------------------------------ PromoteLowerTo + +// There is no codegen advantage for a native version of this. It is provided +// only for convenience. +template +HWY_API VFromD PromoteLowerTo(D d, V v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). + const Rebind, decltype(d)> dh; + return PromoteTo(d, LowerHalf(dh, v)); +} + +// ------------------------------ PromoteUpperTo + +#if (defined(HWY_NATIVE_PROMOTE_UPPER_TO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PROMOTE_UPPER_TO +#undef HWY_NATIVE_PROMOTE_UPPER_TO +#else +#define HWY_NATIVE_PROMOTE_UPPER_TO +#endif + +// This requires UpperHalf. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +template +HWY_API VFromD PromoteUpperTo(D d, V v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). 
+ const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +#endif // HWY_TARGET != HWY_SCALAR +#endif // HWY_NATIVE_PROMOTE_UPPER_TO + +// ------------------------------ float16_t <-> float + +#if (defined(HWY_NATIVE_F16C) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +template +HWY_API VFromD PromoteTo(D df32, VFromD> v) { + const RebindToSigned di32; + const RebindToUnsigned du32; + const Rebind du16; + using VU32 = VFromD; + + const VU32 bits16 = PromoteTo(du32, BitCast(du16, v)); + const VU32 sign = ShiftRight<15>(bits16); + const VU32 biased_exp = And(ShiftRight<10>(bits16), Set(du32, 0x1F)); + const VU32 mantissa = And(bits16, Set(du32, 0x3FF)); + const VU32 subnormal = + BitCast(du32, Mul(ConvertTo(df32, BitCast(di32, mantissa)), + Set(df32, 1.0f / 16384 / 1024))); + + const VU32 biased_exp32 = Add(biased_exp, Set(du32, 127 - 15)); + const VU32 mantissa32 = ShiftLeft<23 - 10>(mantissa); + const VU32 normal = Or(ShiftLeft<23>(biased_exp32), mantissa32); + const VU32 bits32 = IfThenElse(Eq(biased_exp, Zero(du32)), subnormal, normal); + return BitCast(df32, Or(ShiftLeft<31>(sign), bits32)); +} + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const RebindToSigned di16; + const Rebind di32; + const RebindToFloat df32; + const RebindToUnsigned du32; + + // There are 23 fractional bits (plus the implied 1 bit) in the mantissa of + // a F32, and there are 10 fractional bits (plus the implied 1 bit) in the + // mantissa of a F16 + + // We want the unbiased exponent of round_incr[i] to be at least (-14) + 13 as + // 2^(-14) is the smallest positive normal F16 value and as we want 13 + // mantissa bits (including the implicit 1 bit) to the left of the + // F32 mantissa bits in rounded_val[i] since 23 - 10 is equal to 13 + + // The biased exponent of round_incr[i] needs to be at least 126 as + // (-14) + 13 + 127 is equal to 126 + + // We also want to biased 
exponent of round_incr[i] to be less than or equal + // to 255 (which is equal to MaxExponentField()) + + // The biased F32 exponent of round_incr is equal to + // HWY_MAX(HWY_MIN(((exp_bits[i] >> 23) & 255) + 13, 255), 126) + + // hi9_bits[i] is equal to the upper 9 bits of v[i] + const auto hi9_bits = ShiftRight<23>(BitCast(du32, v)); + + const auto k13 = Set(du32, uint32_t{13u}); + + // Minimum biased F32 exponent of round_incr + const auto k126 = Set(du32, uint32_t{126u}); + + // round_incr_hi9_bits[i] is equivalent to + // (hi9_bits[i] & 0x100) | + // HWY_MAX(HWY_MIN((hi9_bits[i] & 0xFF) + 13, 255), 126) + +#if HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128 + const auto k255 = Set(du32, uint32_t{255u}); + const auto round_incr_hi9_bits = BitwiseIfThenElse( + k255, Max(Min(Add(And(hi9_bits, k255), k13), k255), k126), hi9_bits); +#else + // On targets other than SCALAR and EMU128, the exponent bits of hi9_bits can + // be incremented by 13 and clamped to the [13, 255] range without overflowing + // into the sign bit of hi9_bits by using U8 SaturatedAdd as there are 8 + // exponent bits in an F32 + + // U8 Max can be used on targets other than SCALAR and EMU128 to clamp + // ((hi9_bits & 0xFF) + 13) to the [126, 255] range without affecting the sign + // bit + + const Repartition du32_as_u8; + const auto round_incr_hi9_bits = BitCast( + du32, + Max(SaturatedAdd(BitCast(du32_as_u8, hi9_bits), BitCast(du32_as_u8, k13)), + BitCast(du32_as_u8, k126))); +#endif + + // (round_incr_hi9_bits >> 8) is equal to (hi9_bits >> 8), and + // (round_incr_hi9_bits & 0xFF) is equal to + // HWY_MAX(HWY_MIN((round_incr_hi9_bits & 0xFF) + 13, 255), 126) + + const auto round_incr = BitCast(df32, ShiftLeft<23>(round_incr_hi9_bits)); + + // Add round_incr[i] to v[i] to round the mantissa to the nearest F16 mantissa + // and to move the fractional bits of the resulting non-NaN mantissa down to + // the lower 10 bits of rounded_val if (v[i] + round_incr[i]) is a non-NaN + // value + 
const auto rounded_val = Add(v, round_incr); + + // rounded_val_bits is the bits of rounded_val as a U32 + const auto rounded_val_bits = BitCast(du32, rounded_val); + + // rounded_val[i] is known to have the same biased exponent as round_incr[i] + // as |round_incr[i]| > 2^12*|v[i]| is true if round_incr[i] is a finite + // value, round_incr[i] and v[i] both have the same sign, and |round_incr[i]| + // is either a power of 2 that is greater than or equal to 2^-1 or infinity. + + // If rounded_val[i] is a finite F32 value, then + // (rounded_val_bits[i] & 0x00000FFF) is the bit representation of the + // rounded mantissa of rounded_val[i] as a UQ2.10 fixed point number that is + // in the range [0, 2]. + + // In other words, (rounded_val_bits[i] & 0x00000FFF) is between 0 and 0x0800, + // with (rounded_val_bits[i] & 0x000003FF) being the fractional bits of the + // resulting F16 mantissa, if rounded_v[i] is a finite F32 value. + + // (rounded_val_bits[i] & 0x007FF000) == 0 is guaranteed to be true if + // rounded_val[i] is a non-NaN value + + // The biased exponent of rounded_val[i] is guaranteed to be at least 126 as + // the biased exponent of round_incr[i] is at least 126 and as both v[i] and + // round_incr[i] have the same sign bit + + // The ULP of a F32 value with a biased exponent of 126 is equal to + // 2^(126 - 127 - 23), which is equal to 2^(-24) (which is also the ULP of a + // F16 value with a biased exponent of 0 or 1 as (1 - 15 - 10) is equal to + // -24) + + // The biased exponent (before subtracting by 126) needs to be clamped to the + // [126, 157] range as 126 + 31 is equal to 157 and as 31 is the largest + // biased exponent of a F16. 
+ + // The biased exponent of the resulting F16 value is equal to + // HWY_MIN((round_incr_hi9_bits[i] & 0xFF) + + // ((rounded_val_bits[i] >> 10) & 0xFF), 157) - 126 + +#if HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128 + const auto k157Shl10 = Set(du32, static_cast(uint32_t{157u} << 10)); + auto f16_exp_bits = + Min(Add(ShiftLeft<10>(And(round_incr_hi9_bits, k255)), + And(rounded_val_bits, + Set(du32, static_cast(uint32_t{0xFFu} << 10)))), + k157Shl10); + const auto f16_result_is_inf_mask = + RebindMask(df32, Eq(f16_exp_bits, k157Shl10)); +#else + const auto k157 = Set(du32, uint32_t{157}); + auto f16_exp_bits = BitCast( + du32, + Min(SaturatedAdd(BitCast(du32_as_u8, round_incr_hi9_bits), + BitCast(du32_as_u8, ShiftRight<10>(rounded_val_bits))), + BitCast(du32_as_u8, k157))); + const auto f16_result_is_inf_mask = RebindMask(df32, Eq(f16_exp_bits, k157)); + f16_exp_bits = ShiftLeft<10>(f16_exp_bits); +#endif + + f16_exp_bits = + Sub(f16_exp_bits, Set(du32, static_cast(uint32_t{126u} << 10))); + + const auto f16_unmasked_mant_bits = + BitCast(di32, Or(IfThenZeroElse(f16_result_is_inf_mask, rounded_val), + VecFromMask(df32, IsNaN(rounded_val)))); + + const auto f16_exp_mant_bits = + OrAnd(BitCast(di32, f16_exp_bits), f16_unmasked_mant_bits, + Set(di32, int32_t{0x03FF})); + + // f16_bits_as_i32 is the F16 bits sign-extended to an I32 (with the upper 17 + // bits of f16_bits_as_i32[i] set to the sign bit of rounded_val[i]) to allow + // efficient truncation of the F16 bits to an I16 using an I32->I16 DemoteTo + // operation + const auto f16_bits_as_i32 = + OrAnd(f16_exp_mant_bits, ShiftRight<16>(BitCast(di32, rounded_val_bits)), + Set(di32, static_cast(0xFFFF8000u))); + return BitCast(df16, DemoteTo(di16, f16_bits_as_i32)); +} + +#endif // HWY_NATIVE_F16C + +// ------------------------------ F64->F16 DemoteTo +#if (defined(HWY_NATIVE_DEMOTE_F64_TO_F16) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else 
+#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +#if HWY_HAVE_FLOAT64 +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const Rebind df64; + const Rebind du64; + const Rebind df32; + + // The mantissa bits of v[i] are first rounded using round-to-odd rounding to + // the nearest F64 value that has the lower 29 bits zeroed out to ensure that + // the result is correctly rounded to a F16. + + const auto vf64_rounded = OrAnd( + And(v, + BitCast(df64, Set(du64, static_cast(0xFFFFFFFFE0000000u)))), + BitCast(df64, Add(BitCast(du64, v), + Set(du64, static_cast(0x000000001FFFFFFFu)))), + BitCast(df64, Set(du64, static_cast(0x0000000020000000ULL)))); + + return DemoteTo(df16, DemoteTo(df32, vf64_rounded)); +} +#endif // HWY_HAVE_FLOAT64 + +#endif // HWY_NATIVE_DEMOTE_F64_TO_F16 + +// ------------------------------ F16->F64 PromoteTo +#if (defined(HWY_NATIVE_PROMOTE_F16_TO_F64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 +#undef HWY_NATIVE_PROMOTE_F16_TO_F64 +#else +#define HWY_NATIVE_PROMOTE_F16_TO_F64 +#endif + +#if HWY_HAVE_FLOAT64 +template +HWY_API VFromD PromoteTo(D df64, VFromD> v) { + return PromoteTo(df64, PromoteTo(Rebind(), v)); +} +#endif // HWY_HAVE_FLOAT64 + +#endif // HWY_NATIVE_PROMOTE_F16_TO_F64 + +// ------------------------------ F32 to BF16 DemoteTo +#if (defined(HWY_NATIVE_DEMOTE_F32_TO_BF16) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +namespace detail { + +// Round a F32 value to the nearest BF16 value, with the result returned as the +// rounded F32 value bitcasted to an U32 + +// RoundF32ForDemoteToBF16 also converts NaN values to QNaN values to prevent +// NaN F32 values from being converted to an infinity +template )> +HWY_INLINE VFromD>> RoundF32ForDemoteToBF16(V v) { + const DFromV d; + const RebindToUnsigned du32; + + const auto is_non_nan = Not(IsNaN(v)); + const auto bits32 = BitCast(du32, v); 
+ + const auto round_incr = + Add(And(ShiftRight<16>(bits32), Set(du32, uint32_t{1})), + Set(du32, uint32_t{0x7FFFu})); + return MaskedAddOr(Or(bits32, Set(du32, uint32_t{0x00400000u})), + RebindMask(du32, is_non_nan), bits32, round_incr); +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D dbf16, VFromD> v) { + const RebindToUnsigned du16; + const Twice dt_u16; + + const auto rounded_bits = BitCast(dt_u16, detail::RoundF32ForDemoteToBF16(v)); +#if HWY_IS_LITTLE_ENDIAN + return BitCast( + dbf16, LowerHalf(du16, ConcatOdd(dt_u16, rounded_bits, rounded_bits))); +#else + return BitCast( + dbf16, LowerHalf(du16, ConcatEven(dt_u16, rounded_bits, rounded_bits))); +#endif +} + +template +HWY_API VFromD OrderedDemote2To(D dbf16, VFromD> a, + VFromD> b) { + const RebindToUnsigned du16; + + const auto rounded_a_bits32 = + BitCast(du16, detail::RoundF32ForDemoteToBF16(a)); + const auto rounded_b_bits32 = + BitCast(du16, detail::RoundF32ForDemoteToBF16(b)); +#if HWY_IS_LITTLE_ENDIAN + return BitCast(dbf16, ConcatOdd(du16, BitCast(du16, rounded_b_bits32), + BitCast(du16, rounded_a_bits32))); +#else + return BitCast(dbf16, ConcatEven(du16, BitCast(du16, rounded_b_bits32), + BitCast(du16, rounded_a_bits32))); +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D dbf16, VFromD> a, + VFromD> b) { + const RebindToUnsigned du16; + +#if HWY_IS_LITTLE_ENDIAN + const auto a_in_odd = detail::RoundF32ForDemoteToBF16(a); + const auto b_in_even = ShiftRight<16>(detail::RoundF32ForDemoteToBF16(b)); +#else + const auto a_in_odd = ShiftRight<16>(detail::RoundF32ForDemoteToBF16(a)); + const auto b_in_even = detail::RoundF32ForDemoteToBF16(b); +#endif + + return BitCast(dbf16, + OddEven(BitCast(du16, a_in_odd), BitCast(du16, b_in_even))); +} + +#endif // HWY_NATIVE_DEMOTE_F32_TO_BF16 + +// ------------------------------ PromoteInRangeTo +#if (defined(HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef 
HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + +#if HWY_HAVE_INTEGER64 +template +HWY_API VFromD PromoteInRangeTo(D64 d64, VFromD> v) { + return PromoteTo(d64, v); +} +#endif + +#endif // HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO + +// ------------------------------ ConvertInRangeTo +#if (defined(HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#undef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#else +#define HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#endif + +template +HWY_API VFromD ConvertInRangeTo(DI di, VFromD> v) { + return ConvertTo(di, v); +} + +#endif // HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO + +// ------------------------------ DemoteInRangeTo +#if (defined(HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#endif + +#if HWY_HAVE_FLOAT64 +template +HWY_API VFromD DemoteInRangeTo(D32 d32, VFromD> v) { + return DemoteTo(d32, v); +} +#endif + +#endif // HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO + +// ------------------------------ PromoteInRangeLowerTo/PromoteInRangeUpperTo + +template )> +HWY_API VFromD PromoteInRangeLowerTo(D d, V v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). 
+ const Rebind, decltype(d)> dh; + return PromoteInRangeTo(d, LowerHalf(dh, v)); +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template )> +HWY_API VFromD PromoteInRangeUpperTo(D d, V v) { +#if (HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_EMU128 || \ + (HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64)) + // On targets that provide target-specific implementations of F32->UI64 + // PromoteInRangeTo, promote the upper half of v using PromoteInRangeTo + + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). + const Rebind, decltype(d)> dh; + return PromoteInRangeTo(d, UpperHalf(dh, v)); +#else + // Otherwise, on targets where F32->UI64 PromoteInRangeTo is simply a wrapper + // around F32->UI64 PromoteTo, promote the upper half of v to TFromD using + // PromoteUpperTo + return PromoteUpperTo(d, v); +#endif +} +#endif // HWY_TARGET != HWY_SCALAR + +// ------------------------------ PromoteInRangeEvenTo/PromoteInRangeOddTo + +template )> +HWY_API VFromD PromoteInRangeEvenTo(D d, V v) { +#if HWY_TARGET == HWY_SCALAR + return PromoteInRangeTo(d, v); +#elif (HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_EMU128 || \ + (HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64)) + // On targets that provide target-specific implementations of F32->UI64 + // PromoteInRangeTo, promote the even lanes of v using PromoteInRangeTo + + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). 
+ const DFromV d_from; + const Rebind, decltype(d)> dh; + return PromoteInRangeTo(d, LowerHalf(dh, ConcatEven(d_from, v, v))); +#else + // Otherwise, on targets where F32->UI64 PromoteInRangeTo is simply a wrapper + // around F32->UI64 PromoteTo, promote the even lanes of v to TFromD using + // PromoteEvenTo + return PromoteEvenTo(d, v); +#endif // HWY_TARGET == HWY_SCALAR +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template )> +HWY_API VFromD PromoteInRangeOddTo(D d, V v) { +#if (HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_EMU128 || \ + (HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64)) + // On targets that provide target-specific implementations of F32->UI64 + // PromoteInRangeTo, promote the odd lanes of v using PromoteInRangeTo + + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type from V + // because it cannot be deduced from D (could be either bf16 or f16). + const DFromV d_from; + const Rebind, decltype(d)> dh; + return PromoteInRangeTo(d, LowerHalf(dh, ConcatOdd(d_from, v, v))); +#else + // Otherwise, on targets where F32->UI64 PromoteInRangeTo is simply a wrapper + // around F32->UI64 PromoteTo, promote the odd lanes of v to TFromD using + // PromoteOddTo + return PromoteOddTo(d, v); +#endif +} +#endif // HWY_TARGET != HWY_SCALAR + +// ------------------------------ SumsOf2 + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +namespace detail { + +template +HWY_INLINE VFromD>> SumsOf2( + TypeTag /*type_tag*/, hwy::SizeTag /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + return Add(PromoteEvenTo(dw, v), PromoteOddTo(dw, v)); +} + +} // namespace detail + +template +HWY_API VFromD>> SumsOf2(V v) { + return detail::SumsOf2(hwy::TypeTag>(), + hwy::SizeTag)>(), v); +} +#endif // HWY_TARGET != HWY_SCALAR + +// ------------------------------ SumsOf4 + +namespace detail { + +template +HWY_INLINE VFromD>> SumsOf4( + TypeTag /*type_tag*/, hwy::SizeTag /*lane_size_tag*/, V v) { + using hwy::HWY_NAMESPACE::SumsOf2; + return 
SumsOf2(SumsOf2(v)); +} + +} // namespace detail + +template +HWY_API VFromD>> SumsOf4(V v) { + return detail::SumsOf4(hwy::TypeTag>(), + hwy::SizeTag)>(), v); +} + +// ------------------------------ OrderedTruncate2To + +#if HWY_IDE || \ + (defined(HWY_NATIVE_ORDERED_TRUNCATE_2_TO) == defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#undef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#else +#define HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#endif + +// (Must come after HWY_TARGET_TOGGLE, else we don't reset it for scalar) +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template ) * 2), + HWY_IF_LANES_D(DFromV>, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedTruncate2To(DN dn, V a, V b) { + return ConcatEven(dn, BitCast(dn, b), BitCast(dn, a)); +} +#endif // HWY_TARGET != HWY_SCALAR +#endif // HWY_NATIVE_ORDERED_TRUNCATE_2_TO + +// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex + +#if (defined(HWY_NATIVE_LEADING_ZERO_COUNT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_LEADING_ZERO_COUNT +#undef HWY_NATIVE_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_LEADING_ZERO_COUNT +#endif + +namespace detail { + +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const RebindToFloat df; +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE2 + const RebindToSigned di; + const Repartition di16; + + // On SSE2/SSSE3/SSE4/AVX2, do an int32_t to float conversion, followed + // by a unsigned right shift of the uint32_t bit representation of the + // floating point values by 23, followed by an int16_t Min + // operation as we are only interested in the biased exponent that would + // result from a uint32_t to float conversion. 
+ + // An int32_t to float vector conversion is also much more efficient on + // SSE2/SSSE3/SSE4/AVX2 than an uint32_t vector to float vector conversion + // as an uint32_t vector to float vector conversion on SSE2/SSSE3/SSE4/AVX2 + // requires multiple instructions whereas an int32_t to float vector + // conversion can be carried out using a single instruction on + // SSE2/SSSE3/SSE4/AVX2. + + const auto f32_bits = BitCast(d, ConvertTo(df, BitCast(di, v))); + return BitCast(d, Min(BitCast(di16, ShiftRight<23>(f32_bits)), + BitCast(di16, Set(d, 158)))); +#else + const auto f32_bits = BitCast(d, ConvertTo(df, v)); + return BitCast(d, ShiftRight<23>(f32_bits)); +#endif +} + +template )> +HWY_INLINE V I32RangeU32ToF32BiasedExp(V v) { + // I32RangeU32ToF32BiasedExp is similar to UIntToF32BiasedExp, but + // I32RangeU32ToF32BiasedExp assumes that v[i] is between 0 and 2147483647. + const DFromV d; + const RebindToFloat df; +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE2 + const RebindToSigned d_src; +#else + const RebindToUnsigned d_src; +#endif + const auto f32_bits = BitCast(d, ConvertTo(df, BitCast(d_src, v))); + return ShiftRight<23>(f32_bits); +} + +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Rebind du32; + const auto f32_biased_exp_as_u32 = + I32RangeU32ToF32BiasedExp(PromoteTo(du32, v)); + return TruncateTo(d, f32_biased_exp_as_u32); +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Half dh; + const Rebind du32; + + const auto lo_u32 = PromoteTo(du32, LowerHalf(dh, v)); + const auto hi_u32 = PromoteTo(du32, UpperHalf(dh, v)); + + const auto lo_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(lo_u32); + const auto hi_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(hi_u32); +#if HWY_TARGET <= HWY_SSE2 + const RebindToSigned di32; + const RebindToSigned di; + return BitCast(d, + OrderedDemote2To(di, BitCast(di32, lo_f32_biased_exp_as_u32), + BitCast(di32, 
hi_f32_biased_exp_as_u32))); +#else + return OrderedTruncate2To(d, lo_f32_biased_exp_as_u32, + hi_f32_biased_exp_as_u32); +#endif +} +#endif // HWY_TARGET != HWY_SCALAR + +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Rebind du32; + const auto f32_biased_exp_as_u32 = + I32RangeU32ToF32BiasedExp(PromoteTo(du32, v)); + return U8FromU32(f32_biased_exp_as_u32); +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Half dh; + const Rebind du32; + const Repartition du16; + + const auto lo_u32 = PromoteTo(du32, LowerHalf(dh, v)); + const auto hi_u32 = PromoteTo(du32, UpperHalf(dh, v)); + + const auto lo_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(lo_u32); + const auto hi_f32_biased_exp_as_u32 = I32RangeU32ToF32BiasedExp(hi_u32); + +#if HWY_TARGET <= HWY_SSE2 + const RebindToSigned di32; + const RebindToSigned di16; + const auto f32_biased_exp_as_i16 = + OrderedDemote2To(di16, BitCast(di32, lo_f32_biased_exp_as_u32), + BitCast(di32, hi_f32_biased_exp_as_u32)); + return DemoteTo(d, f32_biased_exp_as_i16); +#else + const auto f32_biased_exp_as_u16 = OrderedTruncate2To( + du16, lo_f32_biased_exp_as_u32, hi_f32_biased_exp_as_u32); + return TruncateTo(d, f32_biased_exp_as_u16); +#endif +} + +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { + const Half dh; + const Half dq; + const Rebind du32; + const Repartition du16; + + const auto lo_half = LowerHalf(dh, v); + const auto hi_half = UpperHalf(dh, v); + + const auto u32_q0 = PromoteTo(du32, LowerHalf(dq, lo_half)); + const auto u32_q1 = PromoteTo(du32, UpperHalf(dq, lo_half)); + const auto u32_q2 = PromoteTo(du32, LowerHalf(dq, hi_half)); + const auto u32_q3 = PromoteTo(du32, UpperHalf(dq, hi_half)); + + const auto f32_biased_exp_as_u32_q0 = I32RangeU32ToF32BiasedExp(u32_q0); + const auto f32_biased_exp_as_u32_q1 = I32RangeU32ToF32BiasedExp(u32_q1); + const auto f32_biased_exp_as_u32_q2 = 
I32RangeU32ToF32BiasedExp(u32_q2); + const auto f32_biased_exp_as_u32_q3 = I32RangeU32ToF32BiasedExp(u32_q3); + +#if HWY_TARGET <= HWY_SSE2 + const RebindToSigned di32; + const RebindToSigned di16; + + const auto lo_f32_biased_exp_as_i16 = + OrderedDemote2To(di16, BitCast(di32, f32_biased_exp_as_u32_q0), + BitCast(di32, f32_biased_exp_as_u32_q1)); + const auto hi_f32_biased_exp_as_i16 = + OrderedDemote2To(di16, BitCast(di32, f32_biased_exp_as_u32_q2), + BitCast(di32, f32_biased_exp_as_u32_q3)); + return OrderedDemote2To(d, lo_f32_biased_exp_as_i16, + hi_f32_biased_exp_as_i16); +#else + const auto lo_f32_biased_exp_as_u16 = OrderedTruncate2To( + du16, f32_biased_exp_as_u32_q0, f32_biased_exp_as_u32_q1); + const auto hi_f32_biased_exp_as_u16 = OrderedTruncate2To( + du16, f32_biased_exp_as_u32_q2, f32_biased_exp_as_u32_q3); + return OrderedTruncate2To(d, lo_f32_biased_exp_as_u16, + hi_f32_biased_exp_as_u16); +#endif +} +#endif // HWY_TARGET != HWY_SCALAR + +#if HWY_TARGET == HWY_SCALAR +template +using F32ExpLzcntMinMaxRepartition = RebindToUnsigned; +#elif HWY_TARGET >= HWY_SSSE3 && HWY_TARGET <= HWY_SSE2 +template +using F32ExpLzcntMinMaxRepartition = Repartition; +#else +template +using F32ExpLzcntMinMaxRepartition = + Repartition), 4)>, D>; +#endif + +template +using F32ExpLzcntMinMaxCmpV = VFromD>>; + +template +HWY_INLINE F32ExpLzcntMinMaxCmpV F32ExpLzcntMinMaxBitCast(V v) { + const DFromV d; + const F32ExpLzcntMinMaxRepartition d2; + return BitCast(d2, v); +} + +template +HWY_INLINE VFromD UIntToF32BiasedExp(D d, VFromD v) { +#if HWY_TARGET == HWY_SCALAR + const uint64_t u64_val = GetLane(v); + const float f32_val = static_cast(u64_val); + const uint32_t f32_bits = BitCastScalar(f32_val); + return Set(d, static_cast(f32_bits >> 23)); +#else + const Repartition du32; + const auto f32_biased_exp = UIntToF32BiasedExp(du32, BitCast(du32, v)); + const auto f32_biased_exp_adj = + IfThenZeroElse(Eq(f32_biased_exp, Zero(du32)), + BitCast(du32, Set(d, 
0x0000002000000000u))); + const auto adj_f32_biased_exp = Add(f32_biased_exp, f32_biased_exp_adj); + + return ShiftRight<32>(BitCast( + d, Max(F32ExpLzcntMinMaxBitCast(adj_f32_biased_exp), + F32ExpLzcntMinMaxBitCast(Reverse2(du32, adj_f32_biased_exp))))); +#endif +} + +template +HWY_INLINE V UIntToF32BiasedExp(V v) { + const DFromV d; + return UIntToF32BiasedExp(d, v); +} + +template +HWY_INLINE V NormalizeForUIntTruncConvToF32(V v) { + return v; +} + +template +HWY_INLINE V NormalizeForUIntTruncConvToF32(V v) { + // If v[i] >= 16777216 is true, make sure that the bit at + // HighestSetBitIndex(v[i]) - 24 is zeroed out to ensure that any inexact + // conversion to single-precision floating point is rounded down. + + // This zeroing-out can be accomplished through the AndNot operation below. + return AndNot(ShiftRight<24>(v), v); +} + +} // namespace detail + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + + const auto f32_biased_exp = detail::UIntToF32BiasedExp( + detail::NormalizeForUIntTruncConvToF32(BitCast(du, v))); + return BitCast(d, Sub(f32_biased_exp, Set(du, TU{127}))); +} + +template +HWY_API V LeadingZeroCount(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + + constexpr TU kNumOfBitsInT{sizeof(TU) * 8}; + const auto f32_biased_exp = detail::UIntToF32BiasedExp( + detail::NormalizeForUIntTruncConvToF32(BitCast(du, v))); + const auto lz_count = Sub(Set(du, TU{kNumOfBitsInT + 126}), f32_biased_exp); + + return BitCast(d, + Min(detail::F32ExpLzcntMinMaxBitCast(lz_count), + detail::F32ExpLzcntMinMaxBitCast(Set(du, kNumOfBitsInT)))); +} + +template +HWY_API V TrailingZeroCount(V v) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + using TU = TFromD; + + const auto vi = BitCast(di, v); + const auto lowest_bit = BitCast(du, And(vi, Neg(vi))); + + constexpr TU kNumOfBitsInT{sizeof(TU) * 8}; + const auto f32_biased_exp = 
detail::UIntToF32BiasedExp(lowest_bit); + const auto tz_count = Sub(f32_biased_exp, Set(du, TU{127})); + + return BitCast(d, + Min(detail::F32ExpLzcntMinMaxBitCast(tz_count), + detail::F32ExpLzcntMinMaxBitCast(Set(du, kNumOfBitsInT)))); +} +#endif // HWY_NATIVE_LEADING_ZERO_COUNT + +// ------------------------------ MaskedLeadingZeroCount +#if (defined(HWY_NATIVE_MASKED_LEADING_ZERO_COUNT) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#undef HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_MASKED_LEADING_ZERO_COUNT +#endif + +template +HWY_API V MaskedLeadingZeroCount(M m, V v) { + return IfThenElseZero(m, LeadingZeroCount(v)); +} +#endif // HWY_NATIVE_MASKED_LEADING_ZERO_COUNT + +// ------------------------------ AESRound + +// Cannot implement on scalar: need at least 16 bytes for TableLookupBytes. +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +// Define for white-box testing, even if native instructions are available. +namespace detail { + +// Constant-time: computes inverse in GF(2^4) based on "Accelerating AES with +// Vector Permute Instructions" and the accompanying assembly language +// implementation: https://crypto.stanford.edu/vpaes/vpaes.tgz. See also Botan: +// https://botan.randombit.net/doxygen/aes__vperm_8cpp_source.html . +// +// A brute-force 256 byte table lookup can also be made constant-time, and +// possibly competitive on NEON, but this is more performance-portable +// especially for x86 and large vectors. 
+ +template // u8 +HWY_INLINE V SubBytesMulInverseAndAffineLookup(V state, V affine_tblL, + V affine_tblU) { + const DFromV du; + const auto mask = Set(du, uint8_t{0xF}); + + // Change polynomial basis to GF(2^4) + { + const VFromD basisL = + Dup128VecFromValues(du, 0x00, 0x70, 0x2A, 0x5A, 0x98, 0xE8, 0xB2, 0xC2, + 0x08, 0x78, 0x22, 0x52, 0x90, 0xE0, 0xBA, 0xCA); + const VFromD basisU = + Dup128VecFromValues(du, 0x00, 0x4D, 0x7C, 0x31, 0x7D, 0x30, 0x01, 0x4C, + 0x81, 0xCC, 0xFD, 0xB0, 0xFC, 0xB1, 0x80, 0xCD); + const auto sL = And(state, mask); + const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero + const auto gf4L = TableLookupBytes(basisL, sL); + const auto gf4U = TableLookupBytes(basisU, sU); + state = Xor(gf4L, gf4U); + } + + // Inversion in GF(2^4). Elements 0 represent "infinity" (division by 0) and + // cause TableLookupBytesOr0 to return 0. + const VFromD zetaInv = Dup128VecFromValues( + du, 0x80, 7, 11, 15, 6, 10, 4, 1, 9, 8, 5, 2, 12, 14, 13, 3); + const VFromD tbl = Dup128VecFromValues( + du, 0x80, 1, 8, 13, 15, 6, 5, 14, 2, 12, 11, 10, 9, 3, 7, 4); + const auto sL = And(state, mask); // L=low nibble, U=upper + const auto sU = ShiftRight<4>(state); // byte shift => upper bits are zero + const auto sX = Xor(sU, sL); + const auto invL = TableLookupBytes(zetaInv, sL); + const auto invU = TableLookupBytes(tbl, sU); + const auto invX = TableLookupBytes(tbl, sX); + const auto outL = Xor(sX, TableLookupBytesOr0(tbl, Xor(invL, invU))); + const auto outU = Xor(sU, TableLookupBytesOr0(tbl, Xor(invL, invX))); + + const auto affL = TableLookupBytesOr0(affine_tblL, outL); + const auto affU = TableLookupBytesOr0(affine_tblU, outU); + return Xor(affL, affU); +} + +template // u8 +HWY_INLINE V SubBytes(V state) { + const DFromV du; + // Linear skew (cannot bake 0x63 bias into the table because out* indices + // may have the infinity flag set). 
+ const VFromD affineL = + Dup128VecFromValues(du, 0x00, 0xC7, 0xBD, 0x6F, 0x17, 0x6D, 0xD2, 0xD0, + 0x78, 0xA8, 0x02, 0xC5, 0x7A, 0xBF, 0xAA, 0x15); + const VFromD affineU = + Dup128VecFromValues(du, 0x00, 0x6A, 0xBB, 0x5F, 0xA5, 0x74, 0xE4, 0xCF, + 0xFA, 0x35, 0x2B, 0x41, 0xD1, 0x90, 0x1E, 0x8E); + return Xor(SubBytesMulInverseAndAffineLookup(state, affineL, affineU), + Set(du, uint8_t{0x63})); +} + +template // u8 +HWY_INLINE V InvSubBytes(V state) { + const DFromV du; + const VFromD gF2P4InvToGF2P8InvL = + Dup128VecFromValues(du, 0x00, 0x40, 0xF9, 0x7E, 0x53, 0xEA, 0x87, 0x13, + 0x2D, 0x3E, 0x94, 0xD4, 0xB9, 0x6D, 0xAA, 0xC7); + const VFromD gF2P4InvToGF2P8InvU = + Dup128VecFromValues(du, 0x00, 0x1D, 0x44, 0x93, 0x0F, 0x56, 0xD7, 0x12, + 0x9C, 0x8E, 0xC5, 0xD8, 0x59, 0x81, 0x4B, 0xCA); + + // Apply the inverse affine transformation + const auto b = Xor(Xor3(Or(ShiftLeft<1>(state), ShiftRight<7>(state)), + Or(ShiftLeft<3>(state), ShiftRight<5>(state)), + Or(ShiftLeft<6>(state), ShiftRight<2>(state))), + Set(du, uint8_t{0x05})); + + // The GF(2^8) multiplicative inverse is computed as follows: + // - Changing the polynomial basis to GF(2^4) + // - Computing the GF(2^4) multiplicative inverse + // - Converting the GF(2^4) multiplicative inverse to the GF(2^8) + // multiplicative inverse through table lookups using the + // kGF2P4InvToGF2P8InvL and kGF2P4InvToGF2P8InvU tables + return SubBytesMulInverseAndAffineLookup(b, gF2P4InvToGF2P8InvL, + gF2P4InvToGF2P8InvU); +} + +} // namespace detail + +#endif // HWY_TARGET != HWY_SCALAR + +#if (defined(HWY_NATIVE_AES) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +// (Must come after HWY_TARGET_TOGGLE, else we don't reset it for scalar) +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +namespace detail { + +template // u8 +HWY_INLINE V ShiftRows(const V state) { + const DFromV du; + // transposed: state is column major + const VFromD shift_row = 
Dup128VecFromValues( + du, 0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11); + return TableLookupBytes(state, shift_row); +} + +template // u8 +HWY_INLINE V InvShiftRows(const V state) { + const DFromV du; + // transposed: state is column major + const VFromD shift_row = Dup128VecFromValues( + du, 0, 13, 10, 7, 4, 1, 14, 11, 8, 5, 2, 15, 12, 9, 6, 3); + return TableLookupBytes(state, shift_row); +} + +template // u8 +HWY_INLINE V GF2P8Mod11BMulBy2(V v) { + const DFromV du; + const RebindToSigned di; // can only do signed comparisons + const auto msb = Lt(BitCast(di, v), Zero(di)); + const auto overflow = BitCast(du, IfThenElseZero(msb, Set(di, int8_t{0x1B}))); + return Xor(Add(v, v), overflow); // = v*2 in GF(2^8). +} + +template // u8 +HWY_INLINE V MixColumns(const V state) { + const DFromV du; + // For each column, the rows are the sum of GF(2^8) matrix multiplication by: + // 2 3 1 1 // Let s := state*1, d := state*2, t := state*3. + // 1 2 3 1 // d are on diagonal, no permutation needed. + // 1 1 2 3 // t1230 indicates column indices of threes for the 4 rows. + // 3 1 1 2 // We also need to compute s2301 and s3012 (=1230 o 2301). + const VFromD v2301 = Dup128VecFromValues( + du, 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13); + const VFromD v1230 = Dup128VecFromValues( + du, 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12); + const auto d = GF2P8Mod11BMulBy2(state); // = state*2 in GF(2^8). 
+ const auto s2301 = TableLookupBytes(state, v2301); + const auto d_s2301 = Xor(d, s2301); + const auto t_s2301 = Xor(state, d_s2301); // t(s*3) = XOR-sum {s, d(s*2)} + const auto t1230_s3012 = TableLookupBytes(t_s2301, v1230); + return Xor(d_s2301, t1230_s3012); // XOR-sum of 4 terms +} + +template // u8 +HWY_INLINE V InvMixColumns(const V state) { + const DFromV du; + // For each column, the rows are the sum of GF(2^8) matrix multiplication by: + // 14 11 13 9 + // 9 14 11 13 + // 13 9 14 11 + // 11 13 9 14 + const VFromD v2301 = Dup128VecFromValues( + du, 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13); + const VFromD v1230 = Dup128VecFromValues( + du, 1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12); + + const auto sx2 = GF2P8Mod11BMulBy2(state); /* = state*2 in GF(2^8) */ + const auto sx4 = GF2P8Mod11BMulBy2(sx2); /* = state*4 in GF(2^8) */ + const auto sx8 = GF2P8Mod11BMulBy2(sx4); /* = state*8 in GF(2^8) */ + const auto sx9 = Xor(sx8, state); /* = state*9 in GF(2^8) */ + const auto sx11 = Xor(sx9, sx2); /* = state*11 in GF(2^8) */ + const auto sx13 = Xor(sx9, sx4); /* = state*13 in GF(2^8) */ + const auto sx14 = Xor3(sx8, sx4, sx2); /* = state*14 in GF(2^8) */ + + const auto sx13_0123_sx9_1230 = Xor(sx13, TableLookupBytes(sx9, v1230)); + const auto sx14_0123_sx11_1230 = Xor(sx14, TableLookupBytes(sx11, v1230)); + const auto sx13_2301_sx9_3012 = TableLookupBytes(sx13_0123_sx9_1230, v2301); + return Xor(sx14_0123_sx11_1230, sx13_2301_sx9_3012); +} + +} // namespace detail + +template // u8 +HWY_API V AESRound(V state, const V round_key) { + // Intel docs swap the first two steps, but it does not matter because + // ShiftRows is a permutation and SubBytes is independent of lane index. 
+ state = detail::SubBytes(state); + state = detail::ShiftRows(state); + state = detail::MixColumns(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +template // u8 +HWY_API V AESLastRound(V state, const V round_key) { + // LIke AESRound, but without MixColumns. + state = detail::SubBytes(state); + state = detail::ShiftRows(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +template +HWY_API V AESInvMixColumns(V state) { + return detail::InvMixColumns(state); +} + +template // u8 +HWY_API V AESRoundInv(V state, const V round_key) { + state = detail::InvSubBytes(state); + state = detail::InvShiftRows(state); + state = detail::InvMixColumns(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +template // u8 +HWY_API V AESLastRoundInv(V state, const V round_key) { + // Like AESRoundInv, but without InvMixColumns. + state = detail::InvSubBytes(state); + state = detail::InvShiftRows(state); + state = Xor(state, round_key); // AddRoundKey + return state; +} + +template )> +HWY_API V AESKeyGenAssist(V v) { + const DFromV d; + const V rconXorMask = Dup128VecFromValues(d, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, + 0, 0, kRcon, 0, 0, 0); + const V rotWordShuffle = Dup128VecFromValues(d, 4, 5, 6, 7, 5, 6, 7, 4, 12, + 13, 14, 15, 13, 14, 15, 12); + const auto sub_word_result = detail::SubBytes(v); + const auto rot_word_result = + TableLookupBytes(sub_word_result, rotWordShuffle); + return Xor(rot_word_result, rconXorMask); +} + +// Constant-time implementation inspired by +// https://www.bearssl.org/constanttime.html, but about half the cost because we +// use 64x64 multiplies and 128-bit XORs. 
+template +HWY_API V CLMulLower(V a, V b) { + const DFromV d; + static_assert(IsSame, uint64_t>(), "V must be u64"); + const auto k1 = Set(d, 0x1111111111111111ULL); + const auto k2 = Set(d, 0x2222222222222222ULL); + const auto k4 = Set(d, 0x4444444444444444ULL); + const auto k8 = Set(d, 0x8888888888888888ULL); + const auto a0 = And(a, k1); + const auto a1 = And(a, k2); + const auto a2 = And(a, k4); + const auto a3 = And(a, k8); + const auto b0 = And(b, k1); + const auto b1 = And(b, k2); + const auto b2 = And(b, k4); + const auto b3 = And(b, k8); + + auto m0 = Xor(MulEven(a0, b0), MulEven(a1, b3)); + auto m1 = Xor(MulEven(a0, b1), MulEven(a1, b0)); + auto m2 = Xor(MulEven(a0, b2), MulEven(a1, b1)); + auto m3 = Xor(MulEven(a0, b3), MulEven(a1, b2)); + m0 = Xor(m0, Xor(MulEven(a2, b2), MulEven(a3, b1))); + m1 = Xor(m1, Xor(MulEven(a2, b3), MulEven(a3, b2))); + m2 = Xor(m2, Xor(MulEven(a2, b0), MulEven(a3, b3))); + m3 = Xor(m3, Xor(MulEven(a2, b1), MulEven(a3, b0))); + return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); +} + +template +HWY_API V CLMulUpper(V a, V b) { + const DFromV d; + static_assert(IsSame, uint64_t>(), "V must be u64"); + const auto k1 = Set(d, 0x1111111111111111ULL); + const auto k2 = Set(d, 0x2222222222222222ULL); + const auto k4 = Set(d, 0x4444444444444444ULL); + const auto k8 = Set(d, 0x8888888888888888ULL); + const auto a0 = And(a, k1); + const auto a1 = And(a, k2); + const auto a2 = And(a, k4); + const auto a3 = And(a, k8); + const auto b0 = And(b, k1); + const auto b1 = And(b, k2); + const auto b2 = And(b, k4); + const auto b3 = And(b, k8); + + auto m0 = Xor(MulOdd(a0, b0), MulOdd(a1, b3)); + auto m1 = Xor(MulOdd(a0, b1), MulOdd(a1, b0)); + auto m2 = Xor(MulOdd(a0, b2), MulOdd(a1, b1)); + auto m3 = Xor(MulOdd(a0, b3), MulOdd(a1, b2)); + m0 = Xor(m0, Xor(MulOdd(a2, b2), MulOdd(a3, b1))); + m1 = Xor(m1, Xor(MulOdd(a2, b3), MulOdd(a3, b2))); + m2 = Xor(m2, Xor(MulOdd(a2, b0), MulOdd(a3, b3))); + m3 = Xor(m3, Xor(MulOdd(a2, 
b1), MulOdd(a3, b0))); + return Or(Or(And(m0, k1), And(m1, k2)), Or(And(m2, k4), And(m3, k8))); +} + +#endif // HWY_NATIVE_AES +#endif // HWY_TARGET != HWY_SCALAR + +// ------------------------------ PopulationCount + +#if (defined(HWY_NATIVE_POPCNT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +template , HWY_IF_U8_D(D)> +HWY_API V PopulationCount(V v) { + const D d; + +#if HWY_TARGET == HWY_SSE2 + // TableLookupBytes is slow on SSE2 + + // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3 + const V k33 = Set(d, uint8_t{0x33}); + v = Sub(v, And(ShiftRight<1>(v), Set(d, uint8_t{0x55}))); + v = Add(And(ShiftRight<2>(v), k33), And(v, k33)); + return And(Add(v, ShiftRight<4>(v)), Set(d, uint8_t{0x0F})); +#else // HWY_TARGET != HWY_SSE2 + +#if HWY_TARGET == HWY_RVV + // Need at least LMUL=1 on RVV to ensure that Lanes(d_tbl) is at least 16 + const ScalableTag d_tbl; +#else + const FixedTag d_tbl; +#endif + + const auto lookup = Dup128VecFromValues(d_tbl, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, + 2, 3, 2, 3, 3, 4); + const auto lo = And(v, Set(d, uint8_t{0xF})); + const auto hi = ShiftRight<4>(v); + +#if HWY_TARGET == HWY_RVV + // On RVV, use TableLookupLanes to avoid unnecessary overhead + const auto hi_popcnt = + ResizeBitCast(d, TableLookupLanes(lookup, ResizeBitCast(d_tbl, hi))); + const auto lo_popcnt = + ResizeBitCast(d, TableLookupLanes(lookup, ResizeBitCast(d_tbl, lo))); +#else // HWY_TARGET != HWY_RVV + const auto hi_popcnt = TableLookupBytes(lookup, hi); + const auto lo_popcnt = TableLookupBytes(lookup, lo); +#endif // HWY_TARGET == HWY_RVV + + return Add(hi_popcnt, lo_popcnt); +#endif // HWY_TARGET == HWY_SSE2 +} + +template , HWY_IF_U16_D(D)> +HWY_API V PopulationCount(V v) { + const D d; + const Repartition d8; + const auto vals = BitCast(d, PopulationCount(BitCast(d8, v))); + return Add(ShiftRight<8>(vals), And(vals, Set(d, uint16_t{0xFF}))); +} + +template , HWY_IF_U32_D(D)> +HWY_API 
V PopulationCount(V v) { + const D d; + Repartition d16; + auto vals = BitCast(d, PopulationCount(BitCast(d16, v))); + return Add(ShiftRight<16>(vals), And(vals, Set(d, uint32_t{0xFF}))); +} + +#if HWY_HAVE_INTEGER64 +template , HWY_IF_U64_D(D)> +HWY_API V PopulationCount(V v) { + const D d; + Repartition d32; + auto vals = BitCast(d, PopulationCount(BitCast(d32, v))); + return Add(ShiftRight<32>(vals), And(vals, Set(d, 0xFFULL))); +} +#endif + +#endif // HWY_NATIVE_POPCNT + +// ------------------------------ 8-bit multiplication + +#if (defined(HWY_NATIVE_MUL_8) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE +#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif + +// 8 bit and fits in wider reg: promote +template +HWY_API V operator*(const V a, const V b) { + const DFromV d; + const Rebind>, decltype(d)> dw; + const RebindToUnsigned du; // TruncateTo result + const RebindToUnsigned dwu; // TruncateTo input + const VFromD mul = PromoteTo(dw, a) * PromoteTo(dw, b); + // TruncateTo is cheaper than ConcatEven. 
+ return BitCast(d, TruncateTo(du, BitCast(dwu, mul))); +} + +// 8 bit full reg: promote halves +template +HWY_API V operator*(const V a, const V b) { + const DFromV d; + const Half dh; + const Twice> dw; + const VFromD a0 = PromoteTo(dw, LowerHalf(dh, a)); + const VFromD a1 = PromoteTo(dw, UpperHalf(dh, a)); + const VFromD b0 = PromoteTo(dw, LowerHalf(dh, b)); + const VFromD b1 = PromoteTo(dw, UpperHalf(dh, b)); + const VFromD m0 = a0 * b0; + const VFromD m1 = a1 * b1; + return ConcatEven(d, BitCast(d, m1), BitCast(d, m0)); +} + +#endif // HWY_NATIVE_MUL_8 + +// ------------------------------ 64-bit multiplication + +#if (defined(HWY_NATIVE_MUL_64) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +// Single-lane i64 or u64 +template +HWY_API V operator*(V x, V y) { + const DFromV d; + using T = TFromD; + using TU = MakeUnsigned; + const TU xu = static_cast(GetLane(x)); + const TU yu = static_cast(GetLane(y)); + return Set(d, static_cast(xu * yu)); +} + +template , HWY_IF_U64_D(D64), + HWY_IF_V_SIZE_GT_D(D64, 8)> +HWY_API V operator*(V x, V y) { + RepartitionToNarrow d32; + auto x32 = BitCast(d32, x); + auto y32 = BitCast(d32, y); + auto lolo = BitCast(d32, MulEven(x32, y32)); + auto lohi = BitCast(d32, MulEven(x32, BitCast(d32, ShiftRight<32>(y)))); + auto hilo = BitCast(d32, MulEven(BitCast(d32, ShiftRight<32>(x)), y32)); + auto hi = BitCast(d32, ShiftLeft<32>(BitCast(D64{}, lohi + hilo))); + return BitCast(D64{}, lolo + hi); +} +template , HWY_IF_I64_D(DI64), + HWY_IF_V_SIZE_GT_D(DI64, 8)> +HWY_API V operator*(V x, V y) { + RebindToUnsigned du64; + return BitCast(DI64{}, BitCast(du64, x) * BitCast(du64, y)); +} + +#endif // HWY_NATIVE_MUL_64 + +// ------------------------------ MulRound +template +HWY_API V MulRound(V a, V b) { + return Round(Mul(a, b)); +} + +// ------------------------------ MulAdd / NegMulAdd + +#if (defined(HWY_NATIVE_INT_FMA) == 
defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INT_FMA +#undef HWY_NATIVE_INT_FMA +#else +#define HWY_NATIVE_INT_FMA +#endif + +#ifdef HWY_NATIVE_INT_FMSUB +#undef HWY_NATIVE_INT_FMSUB +#else +#define HWY_NATIVE_INT_FMSUB +#endif + +template +HWY_API V MulAdd(V mul, V x, V add) { + return Add(Mul(mul, x), add); +} + +template +HWY_API V NegMulAdd(V mul, V x, V add) { + return Sub(add, Mul(mul, x)); +} + +template +HWY_API V MulSub(V mul, V x, V sub) { + return Sub(Mul(mul, x), sub); +} +#endif // HWY_NATIVE_INT_FMA +// ------------------------------ MulComplex* / MaskedMulComplex* + +#if (defined(HWY_NATIVE_CPLX) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_CPLX +#undef HWY_NATIVE_CPLX +#else +#define HWY_NATIVE_CPLX +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE + +template )> +HWY_API V ComplexConj(V a) { + return OddEven(Neg(a), a); +} + +template +HWY_API V MulComplex(V a, V b) { + // a = u + iv, b = x + iy + const auto u = DupEven(a); + const auto v = DupOdd(a); + const auto x = DupEven(b); + const auto y = DupOdd(b); + + return OddEven(MulAdd(u, y, Mul(v, x)), Sub(Mul(u, x), Mul(v, y))); +} + +template +HWY_API V MulComplexConj(V a, V b) { + // a = u + iv, b = x + iy + const auto u = DupEven(a); + const auto v = DupOdd(a); + const auto x = DupEven(b); + const auto y = DupOdd(b); + + return OddEven(Sub(Mul(v, x), Mul(u, y)), MulAdd(u, x, Mul(v, y))); +} + +template +HWY_API V MulComplexAdd(V a, V b, V c) { + return Add(MulComplex(a, b), c); +} + +template +HWY_API V MulComplexConjAdd(V a, V b, V c) { + return Add(MulComplexConj(a, b), c); +} + +template +HWY_API V MaskedMulComplexConjAdd(M mask, V a, V b, V c) { + return IfThenElseZero(mask, MulComplexConjAdd(a, b, c)); +} + +template +HWY_API V MaskedMulComplexConj(M mask, V a, V b) { + return IfThenElseZero(mask, MulComplexConj(a, b)); +} + +template +HWY_API V MaskedMulComplexOr(V no, M mask, V a, V b) { + return IfThenElse(mask, MulComplex(a, b), no); +} +#endif // HWY_TARGET != HWY_SCALAR + 
+#endif // HWY_NATIVE_CPLX + +// ------------------------------ MaskedMulAddOr +#if (defined(HWY_NATIVE_MASKED_INT_FMA) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MASKED_INT_FMA +#undef HWY_NATIVE_MASKED_INT_FMA +#else +#define HWY_NATIVE_MASKED_INT_FMA +#endif + +template +HWY_API V MaskedMulAddOr(V no, M m, V mul, V x, V add) { + return IfThenElse(m, MulAdd(mul, x, add), no); +} + +#endif // HWY_NATIVE_MASKED_INT_FMA + +// ------------------------------ Integer MulSub / NegMulSub +#if (defined(HWY_NATIVE_INT_FMSUB) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INT_FMSUB +#undef HWY_NATIVE_INT_FMSUB +#else +#define HWY_NATIVE_INT_FMSUB +#endif + +template +HWY_API V MulSub(V mul, V x, V sub) { + const DFromV d; + const RebindToSigned di; + return MulAdd(mul, x, BitCast(d, Neg(BitCast(di, sub)))); +} + +#endif // HWY_NATIVE_INT_FMSUB + +template +HWY_API V NegMulSub(V mul, V x, V sub) { + const DFromV d; + const RebindToSigned di; + + return BitCast(d, Neg(BitCast(di, MulAdd(mul, x, sub)))); +} + +// ------------------------------ MulAddSub + +// MulAddSub(mul, x, sub_or_add) for a 1-lane vector is equivalent to +// MulSub(mul, x, sub_or_add) +template , 1)> +HWY_API V MulAddSub(V mul, V x, V sub_or_add) { + return MulSub(mul, x, sub_or_add); +} + +// MulAddSub for F16/F32/F64 vectors with 2 or more lanes on +// SSSE3/SSE4/AVX2/AVX3 is implemented in x86_128-inl.h, x86_256-inl.h, and +// x86_512-inl.h + +// MulAddSub for F16/F32/F64 vectors on SVE is implemented in arm_sve-inl.h + +// MulAddSub for integer vectors on SVE2 is implemented in arm_sve-inl.h +template +HWY_API V MulAddSub(V mul, V x, V sub_or_add) { + using D = DFromV; + using T = TFromD; + using TNegate = If(), MakeSigned, T>; + + const D d; + const Rebind d_negate; + + const auto add = + OddEven(sub_or_add, BitCast(d, Neg(BitCast(d_negate, sub_or_add)))); + return MulAdd(mul, x, add); +} +// ------------------------------ MulSubAdd + +template +HWY_API V MulSubAdd(V mul, V x, V sub_or_add) 
{ + using D = DFromV; + using T = TFromD; + using TNegate = If(), MakeSigned, T>; + + const D d; + const Rebind d_negate; + + return MulAddSub(mul, x, BitCast(d, Neg(BitCast(d_negate, sub_or_add)))); +} + +// ------------------------------ MaskedConvertTo +template +HWY_API VFromD MaskedConvertTo(M m, D d, V v) { + return IfThenElseZero(m, ConvertTo(d, v)); +} + +// ------------------------------ Integer division +#if (defined(HWY_NATIVE_INT_DIV) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_INT_DIV +#undef HWY_NATIVE_INT_DIV +#else +#define HWY_NATIVE_INT_DIV +#endif + +namespace detail { + +// DemoteInRangeTo, PromoteInRangeTo, and ConvertInRangeTo are okay to use in +// the implementation of detail::IntDiv in generic_ops-inl.h as the current +// implementations of DemoteInRangeTo, PromoteInRangeTo, and ConvertInRangeTo +// will convert values that are outside of the range of TFromD by either +// saturation, truncation, or converting values that are outside of the +// destination range to LimitsMin>() (which is equal to +// static_cast>(LimitsMax>() + 1)) + +template ))> +HWY_INLINE Vec IntDivConvFloatToInt(D di, V vf) { + return ConvertInRangeTo(di, vf); +} + +template ))> +HWY_INLINE Vec IntDivConvIntToFloat(D df, V vi) { + return ConvertTo(df, vi); +} + +#if !HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 +template )> +HWY_INLINE Vec IntDivConvFloatToInt(D df, V vi) { + return PromoteInRangeTo(df, vi); +} + +// If !HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 is true, then UI64->F32 +// IntDivConvIntToFloat(df, vi) returns an approximation of +// static_cast(v[i]) that is within 4 ULP of static_cast(v[i]) +template )> +HWY_INLINE Vec IntDivConvIntToFloat(D df32, V vi) { + const Twice dt_f32; + + auto vf32 = + ConvertTo(dt_f32, BitCast(RebindToSigned(), vi)); + +#if HWY_IS_LITTLE_ENDIAN + const auto lo_f32 = LowerHalf(df32, ConcatEven(dt_f32, vf32, vf32)); + auto hi_f32 = LowerHalf(df32, ConcatOdd(dt_f32, vf32, vf32)); +#else + const auto lo_f32 = LowerHalf(df32, 
ConcatOdd(dt_f32, vf32, vf32)); + auto hi_f32 = LowerHalf(df32, ConcatEven(dt_f32, vf32, vf32)); +#endif + + const RebindToSigned di32; + + hi_f32 = + Add(hi_f32, And(BitCast(df32, BroadcastSignBit(BitCast(di32, lo_f32))), + Set(df32, 1.0f))); + return hwy::HWY_NAMESPACE::MulAdd(hi_f32, Set(df32, 4294967296.0f), lo_f32); +} + +template )> +HWY_INLINE Vec IntDivConvIntToFloat(D df32, V vu) { + const Twice dt_f32; + + auto vf32 = + ConvertTo(dt_f32, BitCast(RebindToUnsigned(), vu)); + +#if HWY_IS_LITTLE_ENDIAN + const auto lo_f32 = LowerHalf(df32, ConcatEven(dt_f32, vf32, vf32)); + const auto hi_f32 = LowerHalf(df32, ConcatOdd(dt_f32, vf32, vf32)); +#else + const auto lo_f32 = LowerHalf(df32, ConcatOdd(dt_f32, vf32, vf32)); + const auto hi_f32 = LowerHalf(df32, ConcatEven(dt_f32, vf32, vf32)); +#endif + + return hwy::HWY_NAMESPACE::MulAdd(hi_f32, Set(df32, 4294967296.0f), lo_f32); +} +#endif // !HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 + +template , kOrigLaneSize)> +HWY_INLINE V IntDivUsingFloatDiv(V a, V b) { + const DFromV d; + const RebindToFloat df; + + // If kOrigLaneSize < sizeof(T) is true, then a[i] and b[i] are both in the + // [LimitsMin>(), + // LimitsMax>()] range. + + // floor(|a[i] / b[i]|) <= |flt_q| < floor(|a[i] / b[i]|) + 1 is also + // guaranteed to be true if MakeFloat has at least kOrigLaneSize*8 + 1 + // mantissa bits (including the implied one bit), where flt_q is equal to + // static_cast>(a[i]) / static_cast>(b[i]), + // even in the case where the magnitude of an inexact floating point division + // result is rounded up. + + // In other words, floor(flt_q) < flt_q < ceil(flt_q) is guaranteed to be true + // if (a[i] % b[i]) != 0 is true and MakeFloat has at least + // kOrigLaneSize*8 + 1 mantissa bits (including the implied one bit), even in + // the case where the magnitude of an inexact floating point division result + // is rounded up. 
+ + // It is okay to do conversions from MakeFloat> to TFromV using + // ConvertInRangeTo if sizeof(TFromV) > kOrigLaneSize as the result of the + // floating point division is always greater than LimitsMin>() and + // less than LimitsMax>() if sizeof(TFromV) > kOrigLaneSize and + // b[i] != 0. + +#if HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64 + // On Armv7, do division by multiplying by the ApproximateReciprocal + // to avoid unnecessary overhead as F32 Div refines the approximate + // reciprocal using 4 Newton-Raphson iterations + + const RebindToSigned di; + const RebindToUnsigned du; + + const auto flt_b = ConvertTo(df, b); + auto flt_recip_b = ApproximateReciprocal(flt_b); + if (kOrigLaneSize > 1) { + flt_recip_b = + Mul(flt_recip_b, ReciprocalNewtonRaphsonStep(flt_recip_b, flt_b)); + } + + auto q0 = ConvertInRangeTo(d, Mul(ConvertTo(df, a), flt_recip_b)); + const auto r0 = BitCast(di, hwy::HWY_NAMESPACE::NegMulAdd(q0, b, a)); + + auto r1 = r0; + + // Need to negate r1[i] if a[i] < 0 is true + if (IsSigned>()) { + r1 = IfNegativeThenNegOrUndefIfZero(BitCast(di, a), r1); + } + + // r1[i] is now equal to (a[i] < 0) ? (-r0[i]) : r0[i] + + auto abs_b = BitCast(du, b); + if (IsSigned>()) { + abs_b = BitCast(du, Abs(BitCast(di, abs_b))); + } + + // If (r1[i] < 0 || r1[i] >= abs_b[i]) is true, then set q1[i] to -1. + // Otherwise, set q1[i] to 0. + + // (r1[i] < 0 || r1[i] >= abs_b[i]) can be carried out using a single unsigned + // comparison as static_cast(r1[i]) >= TU(LimitsMax() + 1) >= abs_b[i] + // will be true if r1[i] < 0 is true. + auto q1 = BitCast(di, VecFromMask(du, Ge(BitCast(du, r1), abs_b))); + + // q1[i] is now equal to (r1[i] < 0 || r1[i] >= abs_b[i]) ? 
-1 : 0 + + // Need to negate q1[i] if r0[i] and b[i] do not have the same sign + auto q1_negate_mask = r0; + if (IsSigned>()) { + q1_negate_mask = Xor(q1_negate_mask, BitCast(di, b)); + } + q1 = IfNegativeThenElse(q1_negate_mask, Neg(q1), q1); + + // q1[i] is now equal to (r1[i] < 0 || r1[i] >= abs_b[i]) ? + // (((r0[i] ^ b[i]) < 0) ? 1 : -1) + + // Need to subtract q1[i] from q0[i] to get the final result + return Sub(q0, BitCast(d, q1)); +#else + // On targets other than Armv7 NEON, use F16 or F32 division as most targets + // other than Armv7 NEON have native F32 divide instructions + return ConvertInRangeTo(d, Div(ConvertTo(df, a), ConvertTo(df, b))); +#endif +} + +template , kOrigLaneSize), + HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 4) | (1 << 8))> +HWY_INLINE V IntDivUsingFloatDiv(V a, V b) { + // If kOrigLaneSize == sizeof(T) is true, at least two reciprocal + // multiplication steps are needed as the mantissa of MakeFloat has fewer + // than kOrigLaneSize*8 + 1 bits + + using T = TFromV; + +#if HWY_HAVE_FLOAT64 + using TF = MakeFloat; +#else + using TF = float; +#endif + + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + const Rebind df; + + if (!IsSigned()) { + // If T is unsigned, set a[i] to (a[i] >= b[i] ? 
1 : 0) and set b[i] to 1 if + // b[i] > LimitsMax>() is true + + const auto one = Set(di, MakeSigned{1}); + a = BitCast( + d, IfNegativeThenElse(BitCast(di, b), + IfThenElseZero(RebindMask(di, Ge(a, b)), one), + BitCast(di, a))); + b = BitCast(d, IfNegativeThenElse(BitCast(di, b), one, BitCast(di, b))); + } + + // LimitsMin() <= b[i] <= LimitsMax>() is now true + + const auto flt_b = IntDivConvIntToFloat(df, b); + +#if HWY_TARGET_IS_NEON && !HWY_HAVE_FLOAT64 + auto flt_recip_b = ApproximateReciprocal(flt_b); + flt_recip_b = + Mul(flt_recip_b, ReciprocalNewtonRaphsonStep(flt_recip_b, flt_b)); +#else + const auto flt_recip_b = Div(Set(df, TF(1.0)), flt_b); +#endif + + // It is okay if the conversion of a[i] * flt_recip_b[i] to T using + // IntDivConvFloatToInt returns incorrect results in any lanes where b[i] == 0 + // as the result of IntDivUsingFloatDiv(a, b) is implementation-defined in any + // lanes where b[i] == 0. + + // If ScalarAbs(b[i]) == 1 is true, then it is possible for + // a[i] * flt_recip_b[i] to be rounded up to a value that is outside of the + // range of T. If a[i] * flt_recip_b[i] is outside of the range of T, + // IntDivConvFloatToInt will convert any values that are out of the range of T + // by either saturation, truncation, or wrapping around to LimitsMin(). + + // It is okay if the conversion of a[i] * flt_recip_b[i] to T using + // IntDivConvFloatToInt wraps around if ScalarAbs(b[i]) == 1 as r0 will have + // the correct sign if ScalarAbs(b[i]) == 1, even in the cases where the + // conversion of a[i] * flt_recip_b[i] to T using IntDivConvFloatToInt is + // truncated or wraps around. + + // If ScalarAbs(b[i]) >= 2 is true, a[i] * flt_recip_b[i] will be within the + // range of T, even in the cases where the conversion of a[i] to TF is + // rounded up or the result of multiplying a[i] by flt_recip_b[i] is rounded + // up. 
+ + // ScalarAbs(r0[i]) will also always be less than (LimitsMax() / 2) if + // b[i] != 0, even in the cases where the conversion of a[i] * flt_recip_b[i] + // to T using IntDivConvFloatToInt is truncated or is wrapped around. + + auto q0 = + IntDivConvFloatToInt(d, Mul(IntDivConvIntToFloat(df, a), flt_recip_b)); + const auto r0 = BitCast(di, hwy::HWY_NAMESPACE::NegMulAdd(q0, b, a)); + + // If b[i] != 0 is true, r0[i] * flt_recip_b[i] is always within the range of + // T, even in the cases where the conversion of r0[i] to TF is rounded up or + // the multiplication of r0[i] by flt_recip_b[i] is rounded up. + + auto q1 = + IntDivConvFloatToInt(di, Mul(IntDivConvIntToFloat(df, r0), flt_recip_b)); + const auto r1 = hwy::HWY_NAMESPACE::NegMulAdd(q1, BitCast(di, b), r0); + + auto r3 = r1; + +#if !HWY_HAVE_FLOAT64 + // Need two additional reciprocal multiplication steps for I64/U64 vectors if + // HWY_HAVE_FLOAT64 is 0 + if (sizeof(T) == 8) { + const auto q2 = IntDivConvFloatToInt( + di, Mul(IntDivConvIntToFloat(df, r1), flt_recip_b)); + const auto r2 = hwy::HWY_NAMESPACE::NegMulAdd(q2, BitCast(di, b), r1); + + const auto q3 = IntDivConvFloatToInt( + di, Mul(IntDivConvIntToFloat(df, r2), flt_recip_b)); + r3 = hwy::HWY_NAMESPACE::NegMulAdd(q3, BitCast(di, b), r2); + + q0 = Add(q0, BitCast(d, q2)); + q1 = Add(q1, q3); + } +#endif // !HWY_HAVE_FLOAT64 + + auto r4 = r3; + + // Need to negate r4[i] if a[i] < 0 is true + if (IsSigned>()) { + r4 = IfNegativeThenNegOrUndefIfZero(BitCast(di, a), r4); + } + + // r4[i] is now equal to (a[i] < 0) ? (-r3[i]) : r3[i] + + auto abs_b = BitCast(du, b); + if (IsSigned>()) { + abs_b = BitCast(du, Abs(BitCast(di, abs_b))); + } + + // If (r4[i] < 0 || r4[i] >= abs_b[i]) is true, then set q4[i] to -1. + // Otherwise, set r4[i] to 0. + + // (r4[i] < 0 || r4[i] >= abs_b[i]) can be carried out using a single unsigned + // comparison as static_cast(r4[i]) >= TU(LimitsMax() + 1) >= abs_b[i] + // will be true if r4[i] < 0 is true. 
+ auto q4 = BitCast(di, VecFromMask(du, Ge(BitCast(du, r4), abs_b))); + + // q4[i] is now equal to (r4[i] < 0 || r4[i] >= abs_b[i]) ? -1 : 0 + + // Need to negate q4[i] if r3[i] and b[i] do not have the same sign + auto q4_negate_mask = r3; + if (IsSigned>()) { + q4_negate_mask = Xor(q4_negate_mask, BitCast(di, b)); + } + q4 = IfNegativeThenElse(q4_negate_mask, Neg(q4), q4); + + // q4[i] is now equal to (r4[i] < 0 || r4[i] >= abs_b[i]) ? + // (((r3[i] ^ b[i]) < 0) ? 1 : -1) + + // The final result is equal to q0[i] + q1[i] - q4[i] + return Sub(Add(q0, BitCast(d, q1)), BitCast(d, q4)); +} + +template ) == 1) ? 4 : 2))> +HWY_INLINE V IntDiv(V a, V b) { + using T = TFromV; + + // If HWY_HAVE_FLOAT16 is 0, need to promote I8 to I32 and U8 to U32 + using TW = MakeWide< + If<(!HWY_HAVE_FLOAT16 && sizeof(TFromV) == 1), MakeWide, T>>; + + const DFromV d; + const Rebind dw; + +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3/SSE4/AVX2/AVX3, promote to and from MakeSigned to avoid + // unnecessary overhead + const RebindToSigned dw_i; + + // On SSE2/SSSE3/SSE4/AVX2/AVX3, demote to MakeSigned if + // kOrigLaneSize < sizeof(T) to avoid unnecessary overhead + const If<(kOrigLaneSize < sizeof(T)), RebindToSigned, + decltype(d)> + d_demote_to; +#else + // On other targets, promote to TW and demote to T + const decltype(dw) dw_i; + const decltype(d) d_demote_to; +#endif + + return BitCast( + d, DemoteTo(d_demote_to, IntDivUsingFloatDiv( + PromoteTo(dw_i, a), PromoteTo(dw_i, b)))); +} + +template +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3/SSE4/AVX2/AVX3, promote to and from MakeSigned to avoid + // unnecessary overhead + const RebindToSigned dw_i; + + // On SSE2/SSSE3/SSE4/AVX2/AVX3, demote to MakeSigned> if + // kOrigLaneSize < sizeof(TFromV) to avoid unnecessary overhead + const If<(kOrigLaneSize < sizeof(TFromV)), RebindToSigned, + decltype(d)> + d_demote_to; +#else + // On other targets, promote 
to MakeWide> and demote to TFromV + const decltype(dw) dw_i; + const decltype(d) d_demote_to; +#endif + + return BitCast(d, OrderedDemote2To( + d_demote_to, + IntDivUsingFloatDiv( + PromoteLowerTo(dw_i, a), PromoteLowerTo(dw_i, b)), + IntDivUsingFloatDiv( + PromoteUpperTo(dw_i, a), PromoteUpperTo(dw_i, b)))); +} + +#if !HWY_HAVE_FLOAT16 +template ), + HWY_IF_V_SIZE_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const Rebind>, decltype(d)> dw; + +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3, demote from int16_t to TFromV to avoid unnecessary + // overhead + const RebindToSigned dw_i; +#else + // On other targets, demote from MakeWide> to TFromV + const decltype(dw) dw_i; +#endif + + return DemoteTo(d, + BitCast(dw_i, IntDiv<1>(PromoteTo(dw, a), PromoteTo(dw, b)))); +} +template ), + HWY_IF_V_SIZE_GT_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + +#if HWY_TARGET <= HWY_SSE2 + // On SSE2/SSSE3, demote from int16_t to TFromV to avoid unnecessary + // overhead + const RebindToSigned dw_i; +#else + // On other targets, demote from MakeWide> to TFromV + const decltype(dw) dw_i; +#endif + + return OrderedDemote2To( + d, BitCast(dw_i, IntDiv<1>(PromoteLowerTo(dw, a), PromoteLowerTo(dw, b))), + BitCast(dw_i, IntDiv<1>(PromoteUpperTo(dw, a), PromoteUpperTo(dw, b)))); +} +#endif // !HWY_HAVE_FLOAT16 + +template +HWY_INLINE V IntDiv(V a, V b) { + return IntDivUsingFloatDiv(a, b); +} + +#if HWY_HAVE_FLOAT64 +template ), + HWY_IF_V_SIZE_LE_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const Rebind df64; + + // It is okay to demote the F64 Div result to int32_t or uint32_t using + // DemoteInRangeTo as static_cast(a[i]) / static_cast(b[i]) + // will always be within the range of TFromV if b[i] != 0 and + // sizeof(TFromV) <= 4. 
+ + return DemoteInRangeTo(d, Div(PromoteTo(df64, a), PromoteTo(df64, b))); +} +template ), + HWY_IF_V_SIZE_GT_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntDiv(V a, V b) { + const DFromV d; + const Half dh; + const Repartition df64; + + // It is okay to demote the F64 Div result to int32_t or uint32_t using + // DemoteInRangeTo as static_cast(a[i]) / static_cast(b[i]) + // will always be within the range of TFromV if b[i] != 0 and + // sizeof(TFromV) <= 4. + + const VFromD div1 = + Div(PromoteUpperTo(df64, a), PromoteUpperTo(df64, b)); + const VFromD div0 = + Div(PromoteLowerTo(df64, a), PromoteLowerTo(df64, b)); + return Combine(d, DemoteInRangeTo(dh, div1), DemoteInRangeTo(dh, div0)); +} +#endif // HWY_HAVE_FLOAT64 + +template +HWY_INLINE V IntMod(V a, V b) { + return hwy::HWY_NAMESPACE::NegMulAdd(IntDiv(a, b), b, a); +} + +#if HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || \ + HWY_TARGET == HWY_WASM_EMU256 || HWY_TARGET == HWY_LSX || \ + HWY_TARGET == HWY_LASX +template ), + HWY_IF_V_SIZE_LE_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntMod(V a, V b) { + const DFromV d; + const Rebind>, decltype(d)> dw; + return DemoteTo(d, IntMod(PromoteTo(dw, a), PromoteTo(dw, b))); +} + +template ), + HWY_IF_V_SIZE_GT_V(V, HWY_MAX_BYTES / 2)> +HWY_INLINE V IntMod(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + return OrderedDemote2To( + d, IntMod(PromoteLowerTo(dw, a), PromoteLowerTo(dw, b)), + IntMod(PromoteUpperTo(dw, a), PromoteUpperTo(dw, b))); +} +#endif // HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || HWY_TARGET == + // HWY_WASM_EMU256 || HWY_TARGET == HWY_LSX || HWY_TARGET == HWY_LASX + +} // namespace detail + +#if HWY_TARGET == HWY_SCALAR + +template +HWY_API Vec1 operator/(Vec1 a, Vec1 b) { + return detail::IntDiv(a, b); +} +template +HWY_API Vec1 operator%(Vec1 a, Vec1 b) { + return detail::IntMod(a, b); +} + +#else // HWY_TARGET != HWY_SCALAR + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { + return detail::IntDiv(a, b); +} + +template 
+HWY_API Vec128 operator%(Vec128 a, Vec128 b) { + return detail::IntMod(a, b); +} + +#if HWY_CAP_GE256 +template +HWY_API Vec256 operator/(Vec256 a, Vec256 b) { + return detail::IntDiv(a, b); +} +template +HWY_API Vec256 operator%(Vec256 a, Vec256 b) { + return detail::IntMod(a, b); +} +#endif + +#if HWY_CAP_GE512 +template +HWY_API Vec512 operator/(Vec512 a, Vec512 b) { + return detail::IntDiv(a, b); +} +template +HWY_API Vec512 operator%(Vec512 a, Vec512 b) { + return detail::IntMod(a, b); +} +#endif + +#endif // HWY_TARGET == HWY_SCALAR + +#endif // HWY_NATIVE_INT_DIV + +// ------------------------------ AverageRound + +#if (defined(HWY_NATIVE_AVERAGE_ROUND_UI32) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +template )> +HWY_API V AverageRound(V a, V b) { + return Sub(Or(a, b), ShiftRight<1>(Xor(a, b))); +} + +#endif // HWY_NATIVE_AVERAGE_ROUND_UI64 + +#if (defined(HWY_NATIVE_AVERAGE_ROUND_UI64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +#if HWY_HAVE_INTEGER64 +template )> +HWY_API V AverageRound(V a, V b) { + return Sub(Or(a, b), ShiftRight<1>(Xor(a, b))); +} +#endif + +#endif // HWY_NATIVE_AVERAGE_ROUND_UI64 + +// ------------------------------ RoundingShiftRight (AverageRound) + +#if (defined(HWY_NATIVE_ROUNDING_SHR) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROUNDING_SHR +#undef HWY_NATIVE_ROUNDING_SHR +#else +#define HWY_NATIVE_ROUNDING_SHR +#endif + +template +HWY_API V RoundingShiftRight(V v) { + const DFromV d; + using T = TFromD; + + static_assert( + 0 <= kShiftAmt && kShiftAmt <= static_cast(sizeof(T) * 8 - 1), + "kShiftAmt is out of range"); + + constexpr int kScaleDownShrAmt = HWY_MAX(kShiftAmt - 1, 0); + + auto scaled_down_v = v; + HWY_IF_CONSTEXPR(kScaleDownShrAmt > 0) { + scaled_down_v = ShiftRight(v); + } 
+ + HWY_IF_CONSTEXPR(kShiftAmt == 0) { return scaled_down_v; } + + return AverageRound(scaled_down_v, Zero(d)); +} + +template +HWY_API V RoundingShiftRightSame(V v, int shift_amt) { + const DFromV d; + + const bool shift_amt_is_zero = (shift_amt == 0); + const auto scaled_down_v = ShiftRightSame( + v, static_cast(static_cast(shift_amt) + + static_cast(shift_amt_is_zero) - 1u)); + + return AverageRound( + scaled_down_v, + IfThenElseZero(SetMask(d, shift_amt_is_zero), scaled_down_v)); +} + +template +HWY_API V RoundingShr(V v, V amt) { + const DFromV d; + const RebindToUnsigned du; + using T = TFromD; + using TU = MakeUnsigned; + + const auto unsigned_amt = BitCast(du, amt); + const auto scale_down_shr_amt = + BitCast(d, SaturatedSub(unsigned_amt, Set(du, TU{1}))); + + const auto scaled_down_v = Shr(v, scale_down_shr_amt); + return AverageRound(scaled_down_v, + IfThenElseZero(Eq(amt, Zero(d)), scaled_down_v)); +} + +#endif // HWY_NATIVE_ROUNDING_SHR + +// ------------------------------ MulEvenAdd (PromoteEvenTo) + +// SVE with bf16 and NEON with bf16 override this. +#if (defined(HWY_NATIVE_MUL_EVEN_BF16) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MUL_EVEN_BF16 +#undef HWY_NATIVE_MUL_EVEN_BF16 +#else +#define HWY_NATIVE_MUL_EVEN_BF16 +#endif + +template >> +HWY_API VFromD MulEvenAdd(DF df, VBF a, VBF b, VFromD c) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), c); +} + +template >> +HWY_API VFromD MulOddAdd(DF df, VBF a, VBF b, VFromD c) { + return MulAdd(PromoteOddTo(df, a), PromoteOddTo(df, b), c); +} + +#endif // HWY_NATIVE_MUL_EVEN_BF16 + +// ------------------------------ ReorderWidenMulAccumulate (MulEvenAdd) + +// AVX3_SPR/ZEN4, NEON with bf16 and SVE override this. 
+#if (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#else +#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#endif + +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DF df, VBF a, VBF b, + VFromD sum0, + VFromD& sum1) { + // Lane order within sum0/1 is undefined, hence we can avoid the + // longer-latency lane-crossing PromoteTo by using PromoteEvenTo. + sum1 = MulOddAdd(df, a, b, sum1); + return MulEvenAdd(df, a, b, sum0); +} + +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + // sum1 contains the odd lanes and sum0 the even, hence their sum is the + // desired pairwise sum. + return Add(sum0, sum1); +} + +#endif // HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 + +// ------------------------------ WidenMulAccumulate + +#if (defined(HWY_NATIVE_WIDEN_MUL_ACCUMULATE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#undef HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#else +#define HWY_NATIVE_WIDEN_MUL_ACCUMULATE +#endif + +template ), class DN = RepartitionToNarrow> +HWY_API VFromD WidenMulAccumulate(D d, VFromD mul, VFromD x, + VFromD low, VFromD& high) { + high = MulAdd(PromoteUpperTo(d, mul), PromoteUpperTo(d, x), high); + return MulAdd(PromoteLowerTo(d, mul), PromoteLowerTo(d, x), low); +} + +#endif // HWY_NATIVE_WIDEN_MUL_ACCUMULATE + +#if 0 +#if (defined(HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16) == defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#undef HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#else +#define HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#endif + +#if HWY_HAVE_FLOAT16 + +template> +HWY_API VFromD WidenMulAccumulate(D d, VFromD mul, VFromD x, + VFromD low, VFromD& high) { + high = MulAdd(PromoteUpperTo(d, mul), PromoteUpperTo(d, x), high); + return MulAdd(PromoteLowerTo(d, mul), PromoteLowerTo(d, x), low); +} + +#endif // HWY_HAVE_FLOAT16 + +#endif // 
HWY_NATIVE_WIDEN_MUL_ACCUMULATE_F16 +#endif // #if 0 + +// ------------------------------ SatWidenMulPairwiseAdd + +#if (defined(HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#undef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#else +#define HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#endif + +template >, HWY_IF_I16_D(DI16), + HWY_IF_U8_D(DFromV), HWY_IF_I8_D(DFromV), + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(VI8)), + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(VU8_2))> +HWY_API Vec SatWidenMulPairwiseAdd(DI16 di16, VU8 a, VI8 b) { + const RebindToUnsigned du16; + + const auto a0 = BitCast(di16, PromoteEvenTo(du16, a)); + const auto b0 = PromoteEvenTo(di16, b); + + const auto a1 = BitCast(di16, PromoteOddTo(du16, a)); + const auto b1 = PromoteOddTo(di16, b); + + return SaturatedAdd(Mul(a0, b0), Mul(a1, b1)); +} + +#endif + +// ------------------------------ SatWidenMulPairwiseAccumulate + +#if (defined(HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#undef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#else +#define HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#endif + +template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 di32, VFromD> a, + VFromD> b, VFromD sum) { + // WidenMulPairwiseAdd(di32, a, b) is okay here as + // a[0]*b[0]+a[1]*b[1] is between -2147418112 and 2147483648 and as + // a[0]*b[0]+a[1]*b[1] can only overflow an int32_t if + // a[0], b[0], a[1], and b[1] are all equal to -32768. 
+ + const auto product = WidenMulPairwiseAdd(di32, a, b); + + const auto mul_overflow = + VecFromMask(di32, Eq(product, Set(di32, LimitsMin()))); + + return SaturatedAdd(Sub(sum, And(BroadcastSignBit(sum), mul_overflow)), + Add(product, mul_overflow)); +} + +#endif // HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM + +// ------------------------------ SatWidenMulAccumFixedPoint + +#if (defined(HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#else +#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#endif + +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + const Repartition dt_i16; + + const auto vt_a = ResizeBitCast(dt_i16, a); + const auto vt_b = ResizeBitCast(dt_i16, b); + + const auto dup_a = InterleaveWholeLower(dt_i16, vt_a, vt_a); + const auto dup_b = InterleaveWholeLower(dt_i16, vt_b, vt_b); + + return SatWidenMulPairwiseAccumulate(di32, dup_a, dup_b, sum); +} + +#endif // HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT + +// ------------------------------ MaskedSqrt + +#if (defined(HWY_NATIVE_MASKED_SQRT) == defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_MASKED_SQRT +#undef HWY_NATIVE_MASKED_SQRT +#else +#define HWY_NATIVE_MASKED_SQRT +#endif +template +HWY_API V MaskedSqrt(M m, V v) { + return IfThenElseZero(m, Sqrt(v)); +} + +template +HWY_API V MaskedSqrtOr(V no, M m, V v) { + return IfThenElse(m, Sqrt(v), no); +} +#endif + +// ------------------------------ SumOfMulQuadAccumulate + +#if (defined(HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + const Repartition di16; + + const auto a0 
= PromoteEvenTo(di16, a); + const auto b0 = PromoteEvenTo(di16, b); + + const auto a1 = PromoteOddTo(di16, a); + const auto b1 = PromoteOddTo(di16, b); + + return Add(sum, Add(WidenMulPairwiseAdd(di32, a0, b0), + WidenMulPairwiseAdd(di32, a1, b1))); +} + +#endif + +#if (defined(HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 du32, VFromD> a, + VFromD> b, VFromD sum) { + const Repartition du16; + const RebindToSigned di16; + const RebindToSigned di32; + + const auto lo8_mask = Set(di16, int16_t{0x00FF}); + const auto a0 = And(BitCast(di16, a), lo8_mask); + const auto b0 = And(BitCast(di16, b), lo8_mask); + + const auto a1 = BitCast(di16, ShiftRight<8>(BitCast(du16, a))); + const auto b1 = BitCast(di16, ShiftRight<8>(BitCast(du16, b))); + + return Add(sum, Add(BitCast(du32, WidenMulPairwiseAdd(di32, a0, b0)), + BitCast(du32, WidenMulPairwiseAdd(di32, a1, b1)))); +} + +#endif + +#if (defined(HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 di32, VFromD> a_u, + VFromD> b_i, VFromD sum) { + const Repartition di16; + const RebindToUnsigned du16; + + const auto a0 = And(BitCast(di16, a_u), Set(di16, int16_t{0x00FF})); + const auto b0 = ShiftRight<8>(ShiftLeft<8>(BitCast(di16, b_i))); + + const auto a1 = BitCast(di16, ShiftRight<8>(BitCast(du16, a_u))); + const auto b1 = ShiftRight<8>(BitCast(di16, b_i)); + + // NOTE: SatWidenMulPairwiseAdd(di16, a_u, b_i) cannot be used in + // SumOfMulQuadAccumulate as it is possible for + // a_u[0]*b_i[0]+a_u[1]*b_i[1] to overflow an int16_t if 
a_u[0], b_i[0], + // a_u[1], and b_i[1] are all non-zero and b_i[0] and b_i[1] have the same + // sign. + + return Add(sum, Add(WidenMulPairwiseAdd(di32, a0, b0), + WidenMulPairwiseAdd(di32, a1, b1))); +} + +#endif + +#if (defined(HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE +#endif + +#if HWY_HAVE_INTEGER64 +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI64 di64, VFromD> a, + VFromD> b, VFromD sum) { + const Repartition di32; + + // WidenMulPairwiseAdd(di32, a, b) is okay here as + // a[0]*b[0]+a[1]*b[1] is between -2147418112 and 2147483648 and as + // a[0]*b[0]+a[1]*b[1] can only overflow an int32_t if + // a[0], b[0], a[1], and b[1] are all equal to -32768. + + const auto i32_pairwise_sum = WidenMulPairwiseAdd(di32, a, b); + const auto i32_pairwise_sum_overflow = + VecFromMask(di32, Eq(i32_pairwise_sum, Set(di32, LimitsMin()))); + + // The upper 32 bits of sum0 and sum1 need to be zeroed out in the case of + // overflow. 
+ const auto hi32_mask = Set(di64, static_cast(~int64_t{0xFFFFFFFF})); + const auto p0_zero_out_mask = + ShiftLeft<32>(BitCast(di64, i32_pairwise_sum_overflow)); + const auto p1_zero_out_mask = + And(BitCast(di64, i32_pairwise_sum_overflow), hi32_mask); + + const auto p0 = + AndNot(p0_zero_out_mask, + ShiftRight<32>(ShiftLeft<32>(BitCast(di64, i32_pairwise_sum)))); + const auto p1 = + AndNot(p1_zero_out_mask, ShiftRight<32>(BitCast(di64, i32_pairwise_sum))); + + return Add(sum, Add(p0, p1)); +} +#endif // HWY_HAVE_INTEGER64 +#endif // HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE + +#if (defined(HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE) == \ + defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE +#endif + +#if HWY_HAVE_INTEGER64 +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU64 du64, VFromD> a, + VFromD> b, VFromD sum) { + const auto u32_even_prod = MulEven(a, b); + const auto u32_odd_prod = MulOdd(a, b); + + const auto p0 = Add(PromoteEvenTo(du64, u32_even_prod), + PromoteEvenTo(du64, u32_odd_prod)); + const auto p1 = + Add(PromoteOddTo(du64, u32_even_prod), PromoteOddTo(du64, u32_odd_prod)); + + return Add(sum, Add(p0, p1)); +} +#endif // HWY_HAVE_INTEGER64 +#endif // HWY_NATIVE_U16_U16_SUMOFMULQUADACCUMULATE + +// ------------------------------ F64 ApproximateReciprocal + +#if (defined(HWY_NATIVE_F64_APPROX_RECIP) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F64_APPROX_RECIP +#undef HWY_NATIVE_F64_APPROX_RECIP +#else +#define HWY_NATIVE_F64_APPROX_RECIP +#endif + +#if HWY_HAVE_FLOAT64 +template )> +HWY_API V ApproximateReciprocal(V v) { + const DFromV d; + return Div(Set(d, 1.0), v); +} +#endif // HWY_HAVE_FLOAT64 + +#endif // HWY_NATIVE_F64_APPROX_RECIP + +// ------------------------------ MaskedApproximateReciprocal +template +HWY_API V MaskedApproximateReciprocal(M m, V v) { + return IfThenElseZero(m, 
ApproximateReciprocal(v)); +} + +// ------------------------------ F64 ApproximateReciprocalSqrt + +#if (defined(HWY_NATIVE_F64_APPROX_RSQRT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_F64_APPROX_RSQRT +#undef HWY_NATIVE_F64_APPROX_RSQRT +#else +#define HWY_NATIVE_F64_APPROX_RSQRT +#endif + +#if HWY_HAVE_FLOAT64 +template )> +HWY_API V ApproximateReciprocalSqrt(V v) { + const DFromV d; + const RebindToUnsigned du; + const auto half = Mul(v, Set(d, 0.5)); + // Initial guess based on log2(f) + const auto guess = BitCast(d, Sub(Set(du, uint64_t{0x5FE6EB50C7B537A9u}), + ShiftRight<1>(BitCast(du, v)))); + // One Newton-Raphson iteration + return Mul(guess, NegMulAdd(Mul(half, guess), guess, Set(d, 1.5))); +} +#endif // HWY_HAVE_FLOAT64 + +#endif // HWY_NATIVE_F64_APPROX_RSQRT + +// ------------------------------ MaskedApproximateReciprocalSqrt +template +HWY_API V MaskedApproximateReciprocalSqrt(M m, V v) { + return IfThenElseZero(m, ApproximateReciprocalSqrt(v)); +} + +// ------------------------------ Compress* + +#if (defined(HWY_NATIVE_COMPRESS8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +template +HWY_API size_t CompressBitsStore(V v, const uint8_t* HWY_RESTRICT bits, D d, + T* unaligned) { + HWY_ALIGN T lanes[MaxLanes(d)]; + Store(v, d, lanes); + + const Simd d8; + T* pos = unaligned; + + HWY_ALIGN constexpr T table[2048] = { + 0, 1, 2, 3, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 1, 0, 2, 3, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 2, 0, 1, 3, 4, 5, 6, 7, /**/ 0, 2, 1, 3, 4, 5, 6, 7, // + 1, 2, 0, 3, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 3, 0, 1, 2, 4, 5, 6, 7, /**/ 0, 3, 1, 2, 4, 5, 6, 7, // + 1, 3, 0, 2, 4, 5, 6, 7, /**/ 0, 1, 3, 2, 4, 5, 6, 7, // + 2, 3, 0, 1, 4, 5, 6, 7, /**/ 0, 2, 3, 1, 4, 5, 6, 7, // + 1, 2, 3, 0, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 4, 0, 1, 2, 3, 5, 6, 7, /**/ 0, 4, 1, 2, 3, 5, 6, 7, // + 1, 4, 0, 2, 3, 5, 6, 7, 
/**/ 0, 1, 4, 2, 3, 5, 6, 7, // + 2, 4, 0, 1, 3, 5, 6, 7, /**/ 0, 2, 4, 1, 3, 5, 6, 7, // + 1, 2, 4, 0, 3, 5, 6, 7, /**/ 0, 1, 2, 4, 3, 5, 6, 7, // + 3, 4, 0, 1, 2, 5, 6, 7, /**/ 0, 3, 4, 1, 2, 5, 6, 7, // + 1, 3, 4, 0, 2, 5, 6, 7, /**/ 0, 1, 3, 4, 2, 5, 6, 7, // + 2, 3, 4, 0, 1, 5, 6, 7, /**/ 0, 2, 3, 4, 1, 5, 6, 7, // + 1, 2, 3, 4, 0, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 5, 0, 1, 2, 3, 4, 6, 7, /**/ 0, 5, 1, 2, 3, 4, 6, 7, // + 1, 5, 0, 2, 3, 4, 6, 7, /**/ 0, 1, 5, 2, 3, 4, 6, 7, // + 2, 5, 0, 1, 3, 4, 6, 7, /**/ 0, 2, 5, 1, 3, 4, 6, 7, // + 1, 2, 5, 0, 3, 4, 6, 7, /**/ 0, 1, 2, 5, 3, 4, 6, 7, // + 3, 5, 0, 1, 2, 4, 6, 7, /**/ 0, 3, 5, 1, 2, 4, 6, 7, // + 1, 3, 5, 0, 2, 4, 6, 7, /**/ 0, 1, 3, 5, 2, 4, 6, 7, // + 2, 3, 5, 0, 1, 4, 6, 7, /**/ 0, 2, 3, 5, 1, 4, 6, 7, // + 1, 2, 3, 5, 0, 4, 6, 7, /**/ 0, 1, 2, 3, 5, 4, 6, 7, // + 4, 5, 0, 1, 2, 3, 6, 7, /**/ 0, 4, 5, 1, 2, 3, 6, 7, // + 1, 4, 5, 0, 2, 3, 6, 7, /**/ 0, 1, 4, 5, 2, 3, 6, 7, // + 2, 4, 5, 0, 1, 3, 6, 7, /**/ 0, 2, 4, 5, 1, 3, 6, 7, // + 1, 2, 4, 5, 0, 3, 6, 7, /**/ 0, 1, 2, 4, 5, 3, 6, 7, // + 3, 4, 5, 0, 1, 2, 6, 7, /**/ 0, 3, 4, 5, 1, 2, 6, 7, // + 1, 3, 4, 5, 0, 2, 6, 7, /**/ 0, 1, 3, 4, 5, 2, 6, 7, // + 2, 3, 4, 5, 0, 1, 6, 7, /**/ 0, 2, 3, 4, 5, 1, 6, 7, // + 1, 2, 3, 4, 5, 0, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 6, 0, 1, 2, 3, 4, 5, 7, /**/ 0, 6, 1, 2, 3, 4, 5, 7, // + 1, 6, 0, 2, 3, 4, 5, 7, /**/ 0, 1, 6, 2, 3, 4, 5, 7, // + 2, 6, 0, 1, 3, 4, 5, 7, /**/ 0, 2, 6, 1, 3, 4, 5, 7, // + 1, 2, 6, 0, 3, 4, 5, 7, /**/ 0, 1, 2, 6, 3, 4, 5, 7, // + 3, 6, 0, 1, 2, 4, 5, 7, /**/ 0, 3, 6, 1, 2, 4, 5, 7, // + 1, 3, 6, 0, 2, 4, 5, 7, /**/ 0, 1, 3, 6, 2, 4, 5, 7, // + 2, 3, 6, 0, 1, 4, 5, 7, /**/ 0, 2, 3, 6, 1, 4, 5, 7, // + 1, 2, 3, 6, 0, 4, 5, 7, /**/ 0, 1, 2, 3, 6, 4, 5, 7, // + 4, 6, 0, 1, 2, 3, 5, 7, /**/ 0, 4, 6, 1, 2, 3, 5, 7, // + 1, 4, 6, 0, 2, 3, 5, 7, /**/ 0, 1, 4, 6, 2, 3, 5, 7, // + 2, 4, 6, 0, 1, 3, 5, 7, /**/ 0, 2, 4, 6, 1, 3, 5, 7, // + 1, 2, 4, 6, 0, 3, 5, 7, /**/ 0, 1, 2, 4, 6, 3, 5, 
7, // + 3, 4, 6, 0, 1, 2, 5, 7, /**/ 0, 3, 4, 6, 1, 2, 5, 7, // + 1, 3, 4, 6, 0, 2, 5, 7, /**/ 0, 1, 3, 4, 6, 2, 5, 7, // + 2, 3, 4, 6, 0, 1, 5, 7, /**/ 0, 2, 3, 4, 6, 1, 5, 7, // + 1, 2, 3, 4, 6, 0, 5, 7, /**/ 0, 1, 2, 3, 4, 6, 5, 7, // + 5, 6, 0, 1, 2, 3, 4, 7, /**/ 0, 5, 6, 1, 2, 3, 4, 7, // + 1, 5, 6, 0, 2, 3, 4, 7, /**/ 0, 1, 5, 6, 2, 3, 4, 7, // + 2, 5, 6, 0, 1, 3, 4, 7, /**/ 0, 2, 5, 6, 1, 3, 4, 7, // + 1, 2, 5, 6, 0, 3, 4, 7, /**/ 0, 1, 2, 5, 6, 3, 4, 7, // + 3, 5, 6, 0, 1, 2, 4, 7, /**/ 0, 3, 5, 6, 1, 2, 4, 7, // + 1, 3, 5, 6, 0, 2, 4, 7, /**/ 0, 1, 3, 5, 6, 2, 4, 7, // + 2, 3, 5, 6, 0, 1, 4, 7, /**/ 0, 2, 3, 5, 6, 1, 4, 7, // + 1, 2, 3, 5, 6, 0, 4, 7, /**/ 0, 1, 2, 3, 5, 6, 4, 7, // + 4, 5, 6, 0, 1, 2, 3, 7, /**/ 0, 4, 5, 6, 1, 2, 3, 7, // + 1, 4, 5, 6, 0, 2, 3, 7, /**/ 0, 1, 4, 5, 6, 2, 3, 7, // + 2, 4, 5, 6, 0, 1, 3, 7, /**/ 0, 2, 4, 5, 6, 1, 3, 7, // + 1, 2, 4, 5, 6, 0, 3, 7, /**/ 0, 1, 2, 4, 5, 6, 3, 7, // + 3, 4, 5, 6, 0, 1, 2, 7, /**/ 0, 3, 4, 5, 6, 1, 2, 7, // + 1, 3, 4, 5, 6, 0, 2, 7, /**/ 0, 1, 3, 4, 5, 6, 2, 7, // + 2, 3, 4, 5, 6, 0, 1, 7, /**/ 0, 2, 3, 4, 5, 6, 1, 7, // + 1, 2, 3, 4, 5, 6, 0, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, // + 7, 0, 1, 2, 3, 4, 5, 6, /**/ 0, 7, 1, 2, 3, 4, 5, 6, // + 1, 7, 0, 2, 3, 4, 5, 6, /**/ 0, 1, 7, 2, 3, 4, 5, 6, // + 2, 7, 0, 1, 3, 4, 5, 6, /**/ 0, 2, 7, 1, 3, 4, 5, 6, // + 1, 2, 7, 0, 3, 4, 5, 6, /**/ 0, 1, 2, 7, 3, 4, 5, 6, // + 3, 7, 0, 1, 2, 4, 5, 6, /**/ 0, 3, 7, 1, 2, 4, 5, 6, // + 1, 3, 7, 0, 2, 4, 5, 6, /**/ 0, 1, 3, 7, 2, 4, 5, 6, // + 2, 3, 7, 0, 1, 4, 5, 6, /**/ 0, 2, 3, 7, 1, 4, 5, 6, // + 1, 2, 3, 7, 0, 4, 5, 6, /**/ 0, 1, 2, 3, 7, 4, 5, 6, // + 4, 7, 0, 1, 2, 3, 5, 6, /**/ 0, 4, 7, 1, 2, 3, 5, 6, // + 1, 4, 7, 0, 2, 3, 5, 6, /**/ 0, 1, 4, 7, 2, 3, 5, 6, // + 2, 4, 7, 0, 1, 3, 5, 6, /**/ 0, 2, 4, 7, 1, 3, 5, 6, // + 1, 2, 4, 7, 0, 3, 5, 6, /**/ 0, 1, 2, 4, 7, 3, 5, 6, // + 3, 4, 7, 0, 1, 2, 5, 6, /**/ 0, 3, 4, 7, 1, 2, 5, 6, // + 1, 3, 4, 7, 0, 2, 5, 6, /**/ 0, 1, 3, 4, 7, 2, 5, 6, // + 2, 3, 4, 7, 0, 1, 
5, 6, /**/ 0, 2, 3, 4, 7, 1, 5, 6, // + 1, 2, 3, 4, 7, 0, 5, 6, /**/ 0, 1, 2, 3, 4, 7, 5, 6, // + 5, 7, 0, 1, 2, 3, 4, 6, /**/ 0, 5, 7, 1, 2, 3, 4, 6, // + 1, 5, 7, 0, 2, 3, 4, 6, /**/ 0, 1, 5, 7, 2, 3, 4, 6, // + 2, 5, 7, 0, 1, 3, 4, 6, /**/ 0, 2, 5, 7, 1, 3, 4, 6, // + 1, 2, 5, 7, 0, 3, 4, 6, /**/ 0, 1, 2, 5, 7, 3, 4, 6, // + 3, 5, 7, 0, 1, 2, 4, 6, /**/ 0, 3, 5, 7, 1, 2, 4, 6, // + 1, 3, 5, 7, 0, 2, 4, 6, /**/ 0, 1, 3, 5, 7, 2, 4, 6, // + 2, 3, 5, 7, 0, 1, 4, 6, /**/ 0, 2, 3, 5, 7, 1, 4, 6, // + 1, 2, 3, 5, 7, 0, 4, 6, /**/ 0, 1, 2, 3, 5, 7, 4, 6, // + 4, 5, 7, 0, 1, 2, 3, 6, /**/ 0, 4, 5, 7, 1, 2, 3, 6, // + 1, 4, 5, 7, 0, 2, 3, 6, /**/ 0, 1, 4, 5, 7, 2, 3, 6, // + 2, 4, 5, 7, 0, 1, 3, 6, /**/ 0, 2, 4, 5, 7, 1, 3, 6, // + 1, 2, 4, 5, 7, 0, 3, 6, /**/ 0, 1, 2, 4, 5, 7, 3, 6, // + 3, 4, 5, 7, 0, 1, 2, 6, /**/ 0, 3, 4, 5, 7, 1, 2, 6, // + 1, 3, 4, 5, 7, 0, 2, 6, /**/ 0, 1, 3, 4, 5, 7, 2, 6, // + 2, 3, 4, 5, 7, 0, 1, 6, /**/ 0, 2, 3, 4, 5, 7, 1, 6, // + 1, 2, 3, 4, 5, 7, 0, 6, /**/ 0, 1, 2, 3, 4, 5, 7, 6, // + 6, 7, 0, 1, 2, 3, 4, 5, /**/ 0, 6, 7, 1, 2, 3, 4, 5, // + 1, 6, 7, 0, 2, 3, 4, 5, /**/ 0, 1, 6, 7, 2, 3, 4, 5, // + 2, 6, 7, 0, 1, 3, 4, 5, /**/ 0, 2, 6, 7, 1, 3, 4, 5, // + 1, 2, 6, 7, 0, 3, 4, 5, /**/ 0, 1, 2, 6, 7, 3, 4, 5, // + 3, 6, 7, 0, 1, 2, 4, 5, /**/ 0, 3, 6, 7, 1, 2, 4, 5, // + 1, 3, 6, 7, 0, 2, 4, 5, /**/ 0, 1, 3, 6, 7, 2, 4, 5, // + 2, 3, 6, 7, 0, 1, 4, 5, /**/ 0, 2, 3, 6, 7, 1, 4, 5, // + 1, 2, 3, 6, 7, 0, 4, 5, /**/ 0, 1, 2, 3, 6, 7, 4, 5, // + 4, 6, 7, 0, 1, 2, 3, 5, /**/ 0, 4, 6, 7, 1, 2, 3, 5, // + 1, 4, 6, 7, 0, 2, 3, 5, /**/ 0, 1, 4, 6, 7, 2, 3, 5, // + 2, 4, 6, 7, 0, 1, 3, 5, /**/ 0, 2, 4, 6, 7, 1, 3, 5, // + 1, 2, 4, 6, 7, 0, 3, 5, /**/ 0, 1, 2, 4, 6, 7, 3, 5, // + 3, 4, 6, 7, 0, 1, 2, 5, /**/ 0, 3, 4, 6, 7, 1, 2, 5, // + 1, 3, 4, 6, 7, 0, 2, 5, /**/ 0, 1, 3, 4, 6, 7, 2, 5, // + 2, 3, 4, 6, 7, 0, 1, 5, /**/ 0, 2, 3, 4, 6, 7, 1, 5, // + 1, 2, 3, 4, 6, 7, 0, 5, /**/ 0, 1, 2, 3, 4, 6, 7, 5, // + 5, 6, 7, 0, 1, 2, 3, 4, /**/ 0, 5, 6, 7, 1, 
2, 3, 4, // + 1, 5, 6, 7, 0, 2, 3, 4, /**/ 0, 1, 5, 6, 7, 2, 3, 4, // + 2, 5, 6, 7, 0, 1, 3, 4, /**/ 0, 2, 5, 6, 7, 1, 3, 4, // + 1, 2, 5, 6, 7, 0, 3, 4, /**/ 0, 1, 2, 5, 6, 7, 3, 4, // + 3, 5, 6, 7, 0, 1, 2, 4, /**/ 0, 3, 5, 6, 7, 1, 2, 4, // + 1, 3, 5, 6, 7, 0, 2, 4, /**/ 0, 1, 3, 5, 6, 7, 2, 4, // + 2, 3, 5, 6, 7, 0, 1, 4, /**/ 0, 2, 3, 5, 6, 7, 1, 4, // + 1, 2, 3, 5, 6, 7, 0, 4, /**/ 0, 1, 2, 3, 5, 6, 7, 4, // + 4, 5, 6, 7, 0, 1, 2, 3, /**/ 0, 4, 5, 6, 7, 1, 2, 3, // + 1, 4, 5, 6, 7, 0, 2, 3, /**/ 0, 1, 4, 5, 6, 7, 2, 3, // + 2, 4, 5, 6, 7, 0, 1, 3, /**/ 0, 2, 4, 5, 6, 7, 1, 3, // + 1, 2, 4, 5, 6, 7, 0, 3, /**/ 0, 1, 2, 4, 5, 6, 7, 3, // + 3, 4, 5, 6, 7, 0, 1, 2, /**/ 0, 3, 4, 5, 6, 7, 1, 2, // + 1, 3, 4, 5, 6, 7, 0, 2, /**/ 0, 1, 3, 4, 5, 6, 7, 2, // + 2, 3, 4, 5, 6, 7, 0, 1, /**/ 0, 2, 3, 4, 5, 6, 7, 1, // + 1, 2, 3, 4, 5, 6, 7, 0, /**/ 0, 1, 2, 3, 4, 5, 6, 7}; + + size_t i = 0; + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + constexpr bool kMaybeLt128 = + (HWY_TARGET == HWY_SCALAR) || !detail::IsFull(D()); + // If less than 128 bit, we may not enter the main loop below, and even + // the remainder loop might not write anything if bits are not set. + // Ensure the output is initialized. GCC seems not to understand this is only + // necessary if kMaybeLt128. + HWY_IF_CONSTEXPR(kMaybeLt128 || HWY_COMPILER_GCC_ACTUAL) { + StoreU(v, d, unaligned); + } + HWY_ASSUME(N >= 8 || kMaybeLt128); + if (N >= 8) { + for (; i <= N - 8; i += 8) { + // Each byte worth of bits is the index of one of 256 8-byte ranges, and + // its population count determines how far to advance the write position. + const size_t bits8 = bits[i / 8]; + const auto indices = Load(d8, table + bits8 * 8); + const auto compressed = TableLookupBytes(LoadU(d8, lanes + i), indices); + StoreU(compressed, d8, pos); + pos += PopCount(bits8); + } + } + // Not required if we have full vectors of >= 128 bits, because they are + // multiples of 8 bytes. 
Inefficient loop is mainly required for safely + // handling compress_test). + HWY_IF_CONSTEXPR(kMaybeLt128) { + for (; i < N; ++i) { + if (bits[i / 8] & (1u << (i % 8))) { + *pos++ = lanes[i]; + } + } + } + + return static_cast(pos - unaligned); +} + +template +HWY_API size_t CompressStore(V v, M mask, D d, T* HWY_RESTRICT unaligned) { + uint8_t bits[HWY_MAX(size_t{8}, MaxLanes(d) / 8)]; + (void)StoreMaskBits(d, mask, bits); + return CompressBitsStore(v, bits, d, unaligned); +} + +template +HWY_API size_t CompressBlendedStore(V v, M mask, D d, + T* HWY_RESTRICT unaligned) { + HWY_ALIGN T buf[MaxLanes(d)]; + const size_t bytes = CompressStore(v, mask, d, buf); + BlendedStore(Load(d, buf), FirstN(d, bytes), d, unaligned); + return bytes; +} + +// For reasons unknown, HWY_IF_T_SIZE_V is a compile error in SVE. +template , HWY_IF_T_SIZE(T, 1)> +HWY_API V Compress(V v, const M mask) { + const DFromV d; + HWY_ALIGN T lanes[MaxLanes(d)]; + (void)CompressStore(v, mask, d, lanes); + return Load(d, lanes); +} + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + const DFromV d; + HWY_ALIGN T lanes[MaxLanes(d)]; + (void)CompressBitsStore(v, bits, d, lanes); + return Load(d, lanes); +} + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API V CompressNot(V v, M mask) { + return Compress(v, Not(mask)); +} + +#endif // HWY_NATIVE_COMPRESS8 + +// ------------------------------ Expand + +// Note that this generic implementation assumes <= 128 bit fixed vectors; +// the SVE and RVV targets provide their own native implementations. 
+#if (defined(HWY_NATIVE_EXPAND) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE +#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +namespace detail { + +template +HWY_INLINE Vec128 IndicesForExpandFromBits(uint64_t mask_bits) { + static_assert(N <= 8, "Should only be called for half-vectors"); + const Simd du8; + HWY_DASSERT(mask_bits < 0x100); + alignas(16) static constexpr uint8_t table[2048] = { + // PrintExpand8x8Tables + 128, 128, 128, 128, 128, 128, 128, 128, // + 0, 128, 128, 128, 128, 128, 128, 128, // + 128, 0, 128, 128, 128, 128, 128, 128, // + 0, 1, 128, 128, 128, 128, 128, 128, // + 128, 128, 0, 128, 128, 128, 128, 128, // + 0, 128, 1, 128, 128, 128, 128, 128, // + 128, 0, 1, 128, 128, 128, 128, 128, // + 0, 1, 2, 128, 128, 128, 128, 128, // + 128, 128, 128, 0, 128, 128, 128, 128, // + 0, 128, 128, 1, 128, 128, 128, 128, // + 128, 0, 128, 1, 128, 128, 128, 128, // + 0, 1, 128, 2, 128, 128, 128, 128, // + 128, 128, 0, 1, 128, 128, 128, 128, // + 0, 128, 1, 2, 128, 128, 128, 128, // + 128, 0, 1, 2, 128, 128, 128, 128, // + 0, 1, 2, 3, 128, 128, 128, 128, // + 128, 128, 128, 128, 0, 128, 128, 128, // + 0, 128, 128, 128, 1, 128, 128, 128, // + 128, 0, 128, 128, 1, 128, 128, 128, // + 0, 1, 128, 128, 2, 128, 128, 128, // + 128, 128, 0, 128, 1, 128, 128, 128, // + 0, 128, 1, 128, 2, 128, 128, 128, // + 128, 0, 1, 128, 2, 128, 128, 128, // + 0, 1, 2, 128, 3, 128, 128, 128, // + 128, 128, 128, 0, 1, 128, 128, 128, // + 0, 128, 128, 1, 2, 128, 128, 128, // + 128, 0, 128, 1, 2, 128, 128, 128, // + 0, 1, 128, 2, 3, 128, 128, 128, // + 128, 128, 0, 1, 2, 128, 128, 128, // + 0, 128, 1, 2, 3, 128, 128, 128, // + 128, 0, 1, 2, 3, 128, 128, 128, // + 0, 1, 2, 3, 4, 128, 128, 128, // + 128, 128, 128, 128, 128, 0, 128, 128, // + 0, 128, 128, 128, 128, 1, 128, 128, // + 128, 0, 128, 128, 128, 1, 128, 128, // + 0, 1, 128, 128, 128, 2, 128, 128, // + 128, 128, 0, 128, 128, 1, 128, 128, // + 0, 128, 1, 128, 128, 2, 128, 128, // + 128, 
0, 1, 128, 128, 2, 128, 128, // + 0, 1, 2, 128, 128, 3, 128, 128, // + 128, 128, 128, 0, 128, 1, 128, 128, // + 0, 128, 128, 1, 128, 2, 128, 128, // + 128, 0, 128, 1, 128, 2, 128, 128, // + 0, 1, 128, 2, 128, 3, 128, 128, // + 128, 128, 0, 1, 128, 2, 128, 128, // + 0, 128, 1, 2, 128, 3, 128, 128, // + 128, 0, 1, 2, 128, 3, 128, 128, // + 0, 1, 2, 3, 128, 4, 128, 128, // + 128, 128, 128, 128, 0, 1, 128, 128, // + 0, 128, 128, 128, 1, 2, 128, 128, // + 128, 0, 128, 128, 1, 2, 128, 128, // + 0, 1, 128, 128, 2, 3, 128, 128, // + 128, 128, 0, 128, 1, 2, 128, 128, // + 0, 128, 1, 128, 2, 3, 128, 128, // + 128, 0, 1, 128, 2, 3, 128, 128, // + 0, 1, 2, 128, 3, 4, 128, 128, // + 128, 128, 128, 0, 1, 2, 128, 128, // + 0, 128, 128, 1, 2, 3, 128, 128, // + 128, 0, 128, 1, 2, 3, 128, 128, // + 0, 1, 128, 2, 3, 4, 128, 128, // + 128, 128, 0, 1, 2, 3, 128, 128, // + 0, 128, 1, 2, 3, 4, 128, 128, // + 128, 0, 1, 2, 3, 4, 128, 128, // + 0, 1, 2, 3, 4, 5, 128, 128, // + 128, 128, 128, 128, 128, 128, 0, 128, // + 0, 128, 128, 128, 128, 128, 1, 128, // + 128, 0, 128, 128, 128, 128, 1, 128, // + 0, 1, 128, 128, 128, 128, 2, 128, // + 128, 128, 0, 128, 128, 128, 1, 128, // + 0, 128, 1, 128, 128, 128, 2, 128, // + 128, 0, 1, 128, 128, 128, 2, 128, // + 0, 1, 2, 128, 128, 128, 3, 128, // + 128, 128, 128, 0, 128, 128, 1, 128, // + 0, 128, 128, 1, 128, 128, 2, 128, // + 128, 0, 128, 1, 128, 128, 2, 128, // + 0, 1, 128, 2, 128, 128, 3, 128, // + 128, 128, 0, 1, 128, 128, 2, 128, // + 0, 128, 1, 2, 128, 128, 3, 128, // + 128, 0, 1, 2, 128, 128, 3, 128, // + 0, 1, 2, 3, 128, 128, 4, 128, // + 128, 128, 128, 128, 0, 128, 1, 128, // + 0, 128, 128, 128, 1, 128, 2, 128, // + 128, 0, 128, 128, 1, 128, 2, 128, // + 0, 1, 128, 128, 2, 128, 3, 128, // + 128, 128, 0, 128, 1, 128, 2, 128, // + 0, 128, 1, 128, 2, 128, 3, 128, // + 128, 0, 1, 128, 2, 128, 3, 128, // + 0, 1, 2, 128, 3, 128, 4, 128, // + 128, 128, 128, 0, 1, 128, 2, 128, // + 0, 128, 128, 1, 2, 128, 3, 128, // + 128, 0, 128, 1, 2, 128, 3, 
128, // + 0, 1, 128, 2, 3, 128, 4, 128, // + 128, 128, 0, 1, 2, 128, 3, 128, // + 0, 128, 1, 2, 3, 128, 4, 128, // + 128, 0, 1, 2, 3, 128, 4, 128, // + 0, 1, 2, 3, 4, 128, 5, 128, // + 128, 128, 128, 128, 128, 0, 1, 128, // + 0, 128, 128, 128, 128, 1, 2, 128, // + 128, 0, 128, 128, 128, 1, 2, 128, // + 0, 1, 128, 128, 128, 2, 3, 128, // + 128, 128, 0, 128, 128, 1, 2, 128, // + 0, 128, 1, 128, 128, 2, 3, 128, // + 128, 0, 1, 128, 128, 2, 3, 128, // + 0, 1, 2, 128, 128, 3, 4, 128, // + 128, 128, 128, 0, 128, 1, 2, 128, // + 0, 128, 128, 1, 128, 2, 3, 128, // + 128, 0, 128, 1, 128, 2, 3, 128, // + 0, 1, 128, 2, 128, 3, 4, 128, // + 128, 128, 0, 1, 128, 2, 3, 128, // + 0, 128, 1, 2, 128, 3, 4, 128, // + 128, 0, 1, 2, 128, 3, 4, 128, // + 0, 1, 2, 3, 128, 4, 5, 128, // + 128, 128, 128, 128, 0, 1, 2, 128, // + 0, 128, 128, 128, 1, 2, 3, 128, // + 128, 0, 128, 128, 1, 2, 3, 128, // + 0, 1, 128, 128, 2, 3, 4, 128, // + 128, 128, 0, 128, 1, 2, 3, 128, // + 0, 128, 1, 128, 2, 3, 4, 128, // + 128, 0, 1, 128, 2, 3, 4, 128, // + 0, 1, 2, 128, 3, 4, 5, 128, // + 128, 128, 128, 0, 1, 2, 3, 128, // + 0, 128, 128, 1, 2, 3, 4, 128, // + 128, 0, 128, 1, 2, 3, 4, 128, // + 0, 1, 128, 2, 3, 4, 5, 128, // + 128, 128, 0, 1, 2, 3, 4, 128, // + 0, 128, 1, 2, 3, 4, 5, 128, // + 128, 0, 1, 2, 3, 4, 5, 128, // + 0, 1, 2, 3, 4, 5, 6, 128, // + 128, 128, 128, 128, 128, 128, 128, 0, // + 0, 128, 128, 128, 128, 128, 128, 1, // + 128, 0, 128, 128, 128, 128, 128, 1, // + 0, 1, 128, 128, 128, 128, 128, 2, // + 128, 128, 0, 128, 128, 128, 128, 1, // + 0, 128, 1, 128, 128, 128, 128, 2, // + 128, 0, 1, 128, 128, 128, 128, 2, // + 0, 1, 2, 128, 128, 128, 128, 3, // + 128, 128, 128, 0, 128, 128, 128, 1, // + 0, 128, 128, 1, 128, 128, 128, 2, // + 128, 0, 128, 1, 128, 128, 128, 2, // + 0, 1, 128, 2, 128, 128, 128, 3, // + 128, 128, 0, 1, 128, 128, 128, 2, // + 0, 128, 1, 2, 128, 128, 128, 3, // + 128, 0, 1, 2, 128, 128, 128, 3, // + 0, 1, 2, 3, 128, 128, 128, 4, // + 128, 128, 128, 128, 0, 128, 128, 1, // 
+ 0, 128, 128, 128, 1, 128, 128, 2, // + 128, 0, 128, 128, 1, 128, 128, 2, // + 0, 1, 128, 128, 2, 128, 128, 3, // + 128, 128, 0, 128, 1, 128, 128, 2, // + 0, 128, 1, 128, 2, 128, 128, 3, // + 128, 0, 1, 128, 2, 128, 128, 3, // + 0, 1, 2, 128, 3, 128, 128, 4, // + 128, 128, 128, 0, 1, 128, 128, 2, // + 0, 128, 128, 1, 2, 128, 128, 3, // + 128, 0, 128, 1, 2, 128, 128, 3, // + 0, 1, 128, 2, 3, 128, 128, 4, // + 128, 128, 0, 1, 2, 128, 128, 3, // + 0, 128, 1, 2, 3, 128, 128, 4, // + 128, 0, 1, 2, 3, 128, 128, 4, // + 0, 1, 2, 3, 4, 128, 128, 5, // + 128, 128, 128, 128, 128, 0, 128, 1, // + 0, 128, 128, 128, 128, 1, 128, 2, // + 128, 0, 128, 128, 128, 1, 128, 2, // + 0, 1, 128, 128, 128, 2, 128, 3, // + 128, 128, 0, 128, 128, 1, 128, 2, // + 0, 128, 1, 128, 128, 2, 128, 3, // + 128, 0, 1, 128, 128, 2, 128, 3, // + 0, 1, 2, 128, 128, 3, 128, 4, // + 128, 128, 128, 0, 128, 1, 128, 2, // + 0, 128, 128, 1, 128, 2, 128, 3, // + 128, 0, 128, 1, 128, 2, 128, 3, // + 0, 1, 128, 2, 128, 3, 128, 4, // + 128, 128, 0, 1, 128, 2, 128, 3, // + 0, 128, 1, 2, 128, 3, 128, 4, // + 128, 0, 1, 2, 128, 3, 128, 4, // + 0, 1, 2, 3, 128, 4, 128, 5, // + 128, 128, 128, 128, 0, 1, 128, 2, // + 0, 128, 128, 128, 1, 2, 128, 3, // + 128, 0, 128, 128, 1, 2, 128, 3, // + 0, 1, 128, 128, 2, 3, 128, 4, // + 128, 128, 0, 128, 1, 2, 128, 3, // + 0, 128, 1, 128, 2, 3, 128, 4, // + 128, 0, 1, 128, 2, 3, 128, 4, // + 0, 1, 2, 128, 3, 4, 128, 5, // + 128, 128, 128, 0, 1, 2, 128, 3, // + 0, 128, 128, 1, 2, 3, 128, 4, // + 128, 0, 128, 1, 2, 3, 128, 4, // + 0, 1, 128, 2, 3, 4, 128, 5, // + 128, 128, 0, 1, 2, 3, 128, 4, // + 0, 128, 1, 2, 3, 4, 128, 5, // + 128, 0, 1, 2, 3, 4, 128, 5, // + 0, 1, 2, 3, 4, 5, 128, 6, // + 128, 128, 128, 128, 128, 128, 0, 1, // + 0, 128, 128, 128, 128, 128, 1, 2, // + 128, 0, 128, 128, 128, 128, 1, 2, // + 0, 1, 128, 128, 128, 128, 2, 3, // + 128, 128, 0, 128, 128, 128, 1, 2, // + 0, 128, 1, 128, 128, 128, 2, 3, // + 128, 0, 1, 128, 128, 128, 2, 3, // + 0, 1, 2, 128, 128, 128, 
3, 4, // + 128, 128, 128, 0, 128, 128, 1, 2, // + 0, 128, 128, 1, 128, 128, 2, 3, // + 128, 0, 128, 1, 128, 128, 2, 3, // + 0, 1, 128, 2, 128, 128, 3, 4, // + 128, 128, 0, 1, 128, 128, 2, 3, // + 0, 128, 1, 2, 128, 128, 3, 4, // + 128, 0, 1, 2, 128, 128, 3, 4, // + 0, 1, 2, 3, 128, 128, 4, 5, // + 128, 128, 128, 128, 0, 128, 1, 2, // + 0, 128, 128, 128, 1, 128, 2, 3, // + 128, 0, 128, 128, 1, 128, 2, 3, // + 0, 1, 128, 128, 2, 128, 3, 4, // + 128, 128, 0, 128, 1, 128, 2, 3, // + 0, 128, 1, 128, 2, 128, 3, 4, // + 128, 0, 1, 128, 2, 128, 3, 4, // + 0, 1, 2, 128, 3, 128, 4, 5, // + 128, 128, 128, 0, 1, 128, 2, 3, // + 0, 128, 128, 1, 2, 128, 3, 4, // + 128, 0, 128, 1, 2, 128, 3, 4, // + 0, 1, 128, 2, 3, 128, 4, 5, // + 128, 128, 0, 1, 2, 128, 3, 4, // + 0, 128, 1, 2, 3, 128, 4, 5, // + 128, 0, 1, 2, 3, 128, 4, 5, // + 0, 1, 2, 3, 4, 128, 5, 6, // + 128, 128, 128, 128, 128, 0, 1, 2, // + 0, 128, 128, 128, 128, 1, 2, 3, // + 128, 0, 128, 128, 128, 1, 2, 3, // + 0, 1, 128, 128, 128, 2, 3, 4, // + 128, 128, 0, 128, 128, 1, 2, 3, // + 0, 128, 1, 128, 128, 2, 3, 4, // + 128, 0, 1, 128, 128, 2, 3, 4, // + 0, 1, 2, 128, 128, 3, 4, 5, // + 128, 128, 128, 0, 128, 1, 2, 3, // + 0, 128, 128, 1, 128, 2, 3, 4, // + 128, 0, 128, 1, 128, 2, 3, 4, // + 0, 1, 128, 2, 128, 3, 4, 5, // + 128, 128, 0, 1, 128, 2, 3, 4, // + 0, 128, 1, 2, 128, 3, 4, 5, // + 128, 0, 1, 2, 128, 3, 4, 5, // + 0, 1, 2, 3, 128, 4, 5, 6, // + 128, 128, 128, 128, 0, 1, 2, 3, // + 0, 128, 128, 128, 1, 2, 3, 4, // + 128, 0, 128, 128, 1, 2, 3, 4, // + 0, 1, 128, 128, 2, 3, 4, 5, // + 128, 128, 0, 128, 1, 2, 3, 4, // + 0, 128, 1, 128, 2, 3, 4, 5, // + 128, 0, 1, 128, 2, 3, 4, 5, // + 0, 1, 2, 128, 3, 4, 5, 6, // + 128, 128, 128, 0, 1, 2, 3, 4, // + 0, 128, 128, 1, 2, 3, 4, 5, // + 128, 0, 128, 1, 2, 3, 4, 5, // + 0, 1, 128, 2, 3, 4, 5, 6, // + 128, 128, 0, 1, 2, 3, 4, 5, // + 0, 128, 1, 2, 3, 4, 5, 6, // + 128, 0, 1, 2, 3, 4, 5, 6, // + 0, 1, 2, 3, 4, 5, 6, 7}; + return LoadU(du8, table + mask_bits * 8); +} + +} // 
namespace detail + +// Half vector of bytes: one table lookup +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const DFromV d; + + const uint64_t mask_bits = BitsFromMask(d, mask); + const Vec128 indices = + detail::IndicesForExpandFromBits(mask_bits); + return BitCast(d, TableLookupBytesOr0(v, indices)); +} + +// Full vector of bytes: two table lookups +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const Full128 d; + const RebindToUnsigned du; + const Half duh; + const Vec128 vu = BitCast(du, v); + + const uint64_t mask_bits = BitsFromMask(d, mask); + const uint64_t maskL = mask_bits & 0xFF; + const uint64_t maskH = mask_bits >> 8; + + // We want to skip past the v bytes already consumed by idxL. There is no + // instruction for shift-reg by variable bytes. Storing v itself would work + // but would involve a store-load forwarding stall. We instead shuffle using + // loaded indices. + // TODO: MultiRotateRight would also help, but if we have that, we probably + // also have native 8-bit Expand? + alignas(16) static constexpr uint8_t iota[32] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128}; + const VFromD shift = LoadU(du, iota + PopCount(maskL)); + const VFromD vL = LowerHalf(duh, vu); + const VFromD vH = + LowerHalf(duh, TableLookupBytesOr0(vu, shift)); + + const VFromD idxL = detail::IndicesForExpandFromBits<8>(maskL); + const VFromD idxH = detail::IndicesForExpandFromBits<8>(maskH); + + const VFromD expandL = TableLookupBytesOr0(vL, idxL); + const VFromD expandH = TableLookupBytesOr0(vH, idxH); + return BitCast(d, Combine(du, expandH, expandL)); +} + +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const DFromV d; + const RebindToUnsigned du; + + const Rebind du8; + const uint64_t mask_bits = BitsFromMask(d, mask); + + // Storing as 8-bit reduces table size from 4 KiB to 2 KiB. 
We cannot apply + // the nibble trick used below because not all indices fit within one lane. + alignas(16) static constexpr uint8_t table[2048] = { + // PrintExpand16x8ByteTables + 128, 128, 128, 128, 128, 128, 128, 128, // + 0, 128, 128, 128, 128, 128, 128, 128, // + 128, 0, 128, 128, 128, 128, 128, 128, // + 0, 2, 128, 128, 128, 128, 128, 128, // + 128, 128, 0, 128, 128, 128, 128, 128, // + 0, 128, 2, 128, 128, 128, 128, 128, // + 128, 0, 2, 128, 128, 128, 128, 128, // + 0, 2, 4, 128, 128, 128, 128, 128, // + 128, 128, 128, 0, 128, 128, 128, 128, // + 0, 128, 128, 2, 128, 128, 128, 128, // + 128, 0, 128, 2, 128, 128, 128, 128, // + 0, 2, 128, 4, 128, 128, 128, 128, // + 128, 128, 0, 2, 128, 128, 128, 128, // + 0, 128, 2, 4, 128, 128, 128, 128, // + 128, 0, 2, 4, 128, 128, 128, 128, // + 0, 2, 4, 6, 128, 128, 128, 128, // + 128, 128, 128, 128, 0, 128, 128, 128, // + 0, 128, 128, 128, 2, 128, 128, 128, // + 128, 0, 128, 128, 2, 128, 128, 128, // + 0, 2, 128, 128, 4, 128, 128, 128, // + 128, 128, 0, 128, 2, 128, 128, 128, // + 0, 128, 2, 128, 4, 128, 128, 128, // + 128, 0, 2, 128, 4, 128, 128, 128, // + 0, 2, 4, 128, 6, 128, 128, 128, // + 128, 128, 128, 0, 2, 128, 128, 128, // + 0, 128, 128, 2, 4, 128, 128, 128, // + 128, 0, 128, 2, 4, 128, 128, 128, // + 0, 2, 128, 4, 6, 128, 128, 128, // + 128, 128, 0, 2, 4, 128, 128, 128, // + 0, 128, 2, 4, 6, 128, 128, 128, // + 128, 0, 2, 4, 6, 128, 128, 128, // + 0, 2, 4, 6, 8, 128, 128, 128, // + 128, 128, 128, 128, 128, 0, 128, 128, // + 0, 128, 128, 128, 128, 2, 128, 128, // + 128, 0, 128, 128, 128, 2, 128, 128, // + 0, 2, 128, 128, 128, 4, 128, 128, // + 128, 128, 0, 128, 128, 2, 128, 128, // + 0, 128, 2, 128, 128, 4, 128, 128, // + 128, 0, 2, 128, 128, 4, 128, 128, // + 0, 2, 4, 128, 128, 6, 128, 128, // + 128, 128, 128, 0, 128, 2, 128, 128, // + 0, 128, 128, 2, 128, 4, 128, 128, // + 128, 0, 128, 2, 128, 4, 128, 128, // + 0, 2, 128, 4, 128, 6, 128, 128, // + 128, 128, 0, 2, 128, 4, 128, 128, // + 0, 128, 2, 4, 128, 6, 
128, 128, // + 128, 0, 2, 4, 128, 6, 128, 128, // + 0, 2, 4, 6, 128, 8, 128, 128, // + 128, 128, 128, 128, 0, 2, 128, 128, // + 0, 128, 128, 128, 2, 4, 128, 128, // + 128, 0, 128, 128, 2, 4, 128, 128, // + 0, 2, 128, 128, 4, 6, 128, 128, // + 128, 128, 0, 128, 2, 4, 128, 128, // + 0, 128, 2, 128, 4, 6, 128, 128, // + 128, 0, 2, 128, 4, 6, 128, 128, // + 0, 2, 4, 128, 6, 8, 128, 128, // + 128, 128, 128, 0, 2, 4, 128, 128, // + 0, 128, 128, 2, 4, 6, 128, 128, // + 128, 0, 128, 2, 4, 6, 128, 128, // + 0, 2, 128, 4, 6, 8, 128, 128, // + 128, 128, 0, 2, 4, 6, 128, 128, // + 0, 128, 2, 4, 6, 8, 128, 128, // + 128, 0, 2, 4, 6, 8, 128, 128, // + 0, 2, 4, 6, 8, 10, 128, 128, // + 128, 128, 128, 128, 128, 128, 0, 128, // + 0, 128, 128, 128, 128, 128, 2, 128, // + 128, 0, 128, 128, 128, 128, 2, 128, // + 0, 2, 128, 128, 128, 128, 4, 128, // + 128, 128, 0, 128, 128, 128, 2, 128, // + 0, 128, 2, 128, 128, 128, 4, 128, // + 128, 0, 2, 128, 128, 128, 4, 128, // + 0, 2, 4, 128, 128, 128, 6, 128, // + 128, 128, 128, 0, 128, 128, 2, 128, // + 0, 128, 128, 2, 128, 128, 4, 128, // + 128, 0, 128, 2, 128, 128, 4, 128, // + 0, 2, 128, 4, 128, 128, 6, 128, // + 128, 128, 0, 2, 128, 128, 4, 128, // + 0, 128, 2, 4, 128, 128, 6, 128, // + 128, 0, 2, 4, 128, 128, 6, 128, // + 0, 2, 4, 6, 128, 128, 8, 128, // + 128, 128, 128, 128, 0, 128, 2, 128, // + 0, 128, 128, 128, 2, 128, 4, 128, // + 128, 0, 128, 128, 2, 128, 4, 128, // + 0, 2, 128, 128, 4, 128, 6, 128, // + 128, 128, 0, 128, 2, 128, 4, 128, // + 0, 128, 2, 128, 4, 128, 6, 128, // + 128, 0, 2, 128, 4, 128, 6, 128, // + 0, 2, 4, 128, 6, 128, 8, 128, // + 128, 128, 128, 0, 2, 128, 4, 128, // + 0, 128, 128, 2, 4, 128, 6, 128, // + 128, 0, 128, 2, 4, 128, 6, 128, // + 0, 2, 128, 4, 6, 128, 8, 128, // + 128, 128, 0, 2, 4, 128, 6, 128, // + 0, 128, 2, 4, 6, 128, 8, 128, // + 128, 0, 2, 4, 6, 128, 8, 128, // + 0, 2, 4, 6, 8, 128, 10, 128, // + 128, 128, 128, 128, 128, 0, 2, 128, // + 0, 128, 128, 128, 128, 2, 4, 128, // + 128, 0, 128, 128, 128, 
2, 4, 128, // + 0, 2, 128, 128, 128, 4, 6, 128, // + 128, 128, 0, 128, 128, 2, 4, 128, // + 0, 128, 2, 128, 128, 4, 6, 128, // + 128, 0, 2, 128, 128, 4, 6, 128, // + 0, 2, 4, 128, 128, 6, 8, 128, // + 128, 128, 128, 0, 128, 2, 4, 128, // + 0, 128, 128, 2, 128, 4, 6, 128, // + 128, 0, 128, 2, 128, 4, 6, 128, // + 0, 2, 128, 4, 128, 6, 8, 128, // + 128, 128, 0, 2, 128, 4, 6, 128, // + 0, 128, 2, 4, 128, 6, 8, 128, // + 128, 0, 2, 4, 128, 6, 8, 128, // + 0, 2, 4, 6, 128, 8, 10, 128, // + 128, 128, 128, 128, 0, 2, 4, 128, // + 0, 128, 128, 128, 2, 4, 6, 128, // + 128, 0, 128, 128, 2, 4, 6, 128, // + 0, 2, 128, 128, 4, 6, 8, 128, // + 128, 128, 0, 128, 2, 4, 6, 128, // + 0, 128, 2, 128, 4, 6, 8, 128, // + 128, 0, 2, 128, 4, 6, 8, 128, // + 0, 2, 4, 128, 6, 8, 10, 128, // + 128, 128, 128, 0, 2, 4, 6, 128, // + 0, 128, 128, 2, 4, 6, 8, 128, // + 128, 0, 128, 2, 4, 6, 8, 128, // + 0, 2, 128, 4, 6, 8, 10, 128, // + 128, 128, 0, 2, 4, 6, 8, 128, // + 0, 128, 2, 4, 6, 8, 10, 128, // + 128, 0, 2, 4, 6, 8, 10, 128, // + 0, 2, 4, 6, 8, 10, 12, 128, // + 128, 128, 128, 128, 128, 128, 128, 0, // + 0, 128, 128, 128, 128, 128, 128, 2, // + 128, 0, 128, 128, 128, 128, 128, 2, // + 0, 2, 128, 128, 128, 128, 128, 4, // + 128, 128, 0, 128, 128, 128, 128, 2, // + 0, 128, 2, 128, 128, 128, 128, 4, // + 128, 0, 2, 128, 128, 128, 128, 4, // + 0, 2, 4, 128, 128, 128, 128, 6, // + 128, 128, 128, 0, 128, 128, 128, 2, // + 0, 128, 128, 2, 128, 128, 128, 4, // + 128, 0, 128, 2, 128, 128, 128, 4, // + 0, 2, 128, 4, 128, 128, 128, 6, // + 128, 128, 0, 2, 128, 128, 128, 4, // + 0, 128, 2, 4, 128, 128, 128, 6, // + 128, 0, 2, 4, 128, 128, 128, 6, // + 0, 2, 4, 6, 128, 128, 128, 8, // + 128, 128, 128, 128, 0, 128, 128, 2, // + 0, 128, 128, 128, 2, 128, 128, 4, // + 128, 0, 128, 128, 2, 128, 128, 4, // + 0, 2, 128, 128, 4, 128, 128, 6, // + 128, 128, 0, 128, 2, 128, 128, 4, // + 0, 128, 2, 128, 4, 128, 128, 6, // + 128, 0, 2, 128, 4, 128, 128, 6, // + 0, 2, 4, 128, 6, 128, 128, 8, // + 128, 128, 128, 
0, 2, 128, 128, 4, // + 0, 128, 128, 2, 4, 128, 128, 6, // + 128, 0, 128, 2, 4, 128, 128, 6, // + 0, 2, 128, 4, 6, 128, 128, 8, // + 128, 128, 0, 2, 4, 128, 128, 6, // + 0, 128, 2, 4, 6, 128, 128, 8, // + 128, 0, 2, 4, 6, 128, 128, 8, // + 0, 2, 4, 6, 8, 128, 128, 10, // + 128, 128, 128, 128, 128, 0, 128, 2, // + 0, 128, 128, 128, 128, 2, 128, 4, // + 128, 0, 128, 128, 128, 2, 128, 4, // + 0, 2, 128, 128, 128, 4, 128, 6, // + 128, 128, 0, 128, 128, 2, 128, 4, // + 0, 128, 2, 128, 128, 4, 128, 6, // + 128, 0, 2, 128, 128, 4, 128, 6, // + 0, 2, 4, 128, 128, 6, 128, 8, // + 128, 128, 128, 0, 128, 2, 128, 4, // + 0, 128, 128, 2, 128, 4, 128, 6, // + 128, 0, 128, 2, 128, 4, 128, 6, // + 0, 2, 128, 4, 128, 6, 128, 8, // + 128, 128, 0, 2, 128, 4, 128, 6, // + 0, 128, 2, 4, 128, 6, 128, 8, // + 128, 0, 2, 4, 128, 6, 128, 8, // + 0, 2, 4, 6, 128, 8, 128, 10, // + 128, 128, 128, 128, 0, 2, 128, 4, // + 0, 128, 128, 128, 2, 4, 128, 6, // + 128, 0, 128, 128, 2, 4, 128, 6, // + 0, 2, 128, 128, 4, 6, 128, 8, // + 128, 128, 0, 128, 2, 4, 128, 6, // + 0, 128, 2, 128, 4, 6, 128, 8, // + 128, 0, 2, 128, 4, 6, 128, 8, // + 0, 2, 4, 128, 6, 8, 128, 10, // + 128, 128, 128, 0, 2, 4, 128, 6, // + 0, 128, 128, 2, 4, 6, 128, 8, // + 128, 0, 128, 2, 4, 6, 128, 8, // + 0, 2, 128, 4, 6, 8, 128, 10, // + 128, 128, 0, 2, 4, 6, 128, 8, // + 0, 128, 2, 4, 6, 8, 128, 10, // + 128, 0, 2, 4, 6, 8, 128, 10, // + 0, 2, 4, 6, 8, 10, 128, 12, // + 128, 128, 128, 128, 128, 128, 0, 2, // + 0, 128, 128, 128, 128, 128, 2, 4, // + 128, 0, 128, 128, 128, 128, 2, 4, // + 0, 2, 128, 128, 128, 128, 4, 6, // + 128, 128, 0, 128, 128, 128, 2, 4, // + 0, 128, 2, 128, 128, 128, 4, 6, // + 128, 0, 2, 128, 128, 128, 4, 6, // + 0, 2, 4, 128, 128, 128, 6, 8, // + 128, 128, 128, 0, 128, 128, 2, 4, // + 0, 128, 128, 2, 128, 128, 4, 6, // + 128, 0, 128, 2, 128, 128, 4, 6, // + 0, 2, 128, 4, 128, 128, 6, 8, // + 128, 128, 0, 2, 128, 128, 4, 6, // + 0, 128, 2, 4, 128, 128, 6, 8, // + 128, 0, 2, 4, 128, 128, 6, 8, // + 0, 2, 
4, 6, 128, 128, 8, 10, // + 128, 128, 128, 128, 0, 128, 2, 4, // + 0, 128, 128, 128, 2, 128, 4, 6, // + 128, 0, 128, 128, 2, 128, 4, 6, // + 0, 2, 128, 128, 4, 128, 6, 8, // + 128, 128, 0, 128, 2, 128, 4, 6, // + 0, 128, 2, 128, 4, 128, 6, 8, // + 128, 0, 2, 128, 4, 128, 6, 8, // + 0, 2, 4, 128, 6, 128, 8, 10, // + 128, 128, 128, 0, 2, 128, 4, 6, // + 0, 128, 128, 2, 4, 128, 6, 8, // + 128, 0, 128, 2, 4, 128, 6, 8, // + 0, 2, 128, 4, 6, 128, 8, 10, // + 128, 128, 0, 2, 4, 128, 6, 8, // + 0, 128, 2, 4, 6, 128, 8, 10, // + 128, 0, 2, 4, 6, 128, 8, 10, // + 0, 2, 4, 6, 8, 128, 10, 12, // + 128, 128, 128, 128, 128, 0, 2, 4, // + 0, 128, 128, 128, 128, 2, 4, 6, // + 128, 0, 128, 128, 128, 2, 4, 6, // + 0, 2, 128, 128, 128, 4, 6, 8, // + 128, 128, 0, 128, 128, 2, 4, 6, // + 0, 128, 2, 128, 128, 4, 6, 8, // + 128, 0, 2, 128, 128, 4, 6, 8, // + 0, 2, 4, 128, 128, 6, 8, 10, // + 128, 128, 128, 0, 128, 2, 4, 6, // + 0, 128, 128, 2, 128, 4, 6, 8, // + 128, 0, 128, 2, 128, 4, 6, 8, // + 0, 2, 128, 4, 128, 6, 8, 10, // + 128, 128, 0, 2, 128, 4, 6, 8, // + 0, 128, 2, 4, 128, 6, 8, 10, // + 128, 0, 2, 4, 128, 6, 8, 10, // + 0, 2, 4, 6, 128, 8, 10, 12, // + 128, 128, 128, 128, 0, 2, 4, 6, // + 0, 128, 128, 128, 2, 4, 6, 8, // + 128, 0, 128, 128, 2, 4, 6, 8, // + 0, 2, 128, 128, 4, 6, 8, 10, // + 128, 128, 0, 128, 2, 4, 6, 8, // + 0, 128, 2, 128, 4, 6, 8, 10, // + 128, 0, 2, 128, 4, 6, 8, 10, // + 0, 2, 4, 128, 6, 8, 10, 12, // + 128, 128, 128, 0, 2, 4, 6, 8, // + 0, 128, 128, 2, 4, 6, 8, 10, // + 128, 0, 128, 2, 4, 6, 8, 10, // + 0, 2, 128, 4, 6, 8, 10, 12, // + 128, 128, 0, 2, 4, 6, 8, 10, // + 0, 128, 2, 4, 6, 8, 10, 12, // + 128, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14}; + // Extend to double length because InterleaveLower will only use the (valid) + // lower half, and we want N u16. 
+ const Twice du8x2; + const Vec128 indices8 = + ZeroExtendVector(du8x2, Load(du8, table + mask_bits * 8)); + const Vec128 indices16 = + BitCast(du, InterleaveLower(du8x2, indices8, indices8)); + // TableLookupBytesOr0 operates on bytes. To convert u16 lane indices to byte + // indices, add 0 to even and 1 to odd byte lanes. + const Vec128 byte_indices = Add( + indices16, + Set(du, static_cast(HWY_IS_LITTLE_ENDIAN ? 0x0100 : 0x0001))); + return BitCast(d, TableLookupBytesOr0(v, byte_indices)); +} + +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const DFromV d; + const RebindToUnsigned du; + + const uint64_t mask_bits = BitsFromMask(d, mask); + + alignas(16) static constexpr uint32_t packed_array[16] = { + // PrintExpand64x4Nibble - same for 32x4. + 0x0000ffff, 0x0000fff0, 0x0000ff0f, 0x0000ff10, 0x0000f0ff, 0x0000f1f0, + 0x0000f10f, 0x0000f210, 0x00000fff, 0x00001ff0, 0x00001f0f, 0x00002f10, + 0x000010ff, 0x000021f0, 0x0000210f, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2). + const Vec128 packed = Set(du, packed_array[mask_bits]); + alignas(16) static constexpr uint32_t shifts[4] = {0, 4, 8, 12}; + Vec128 indices = packed >> Load(du, shifts); + // AVX2 _mm256_permutexvar_epi32 will ignore upper bits, but IndicesFromVec + // checks bounds, so clear the upper bits. + indices = And(indices, Set(du, N - 1)); + const Vec128 expand = + TableLookupLanes(BitCast(du, v), IndicesFromVec(du, indices)); + // TableLookupLanes cannot also zero masked-off lanes, so do that now. + return IfThenElseZero(mask, BitCast(d, expand)); +} + +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + // Same as Compress, just zero out the mask=false lanes. + return IfThenElseZero(mask, Compress(v, mask)); +} + +// For single-element vectors, this is at least as fast as native. 
+template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + return IfThenElseZero(mask, v); +} + +// ------------------------------ LoadExpand + +// #2957: clangd warning because x86_128-inl.h defines an overload with this +// condition, so negate it here. +#if !(HWY_TARGET <= HWY_AVX3 || HWY_IDE) + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return Expand(LoadU(d, unaligned), mask); +} + +#endif // !(HWY_TARGET <= HWY_AVX3 || HWY_IDE) + +#endif // HWY_NATIVE_EXPAND + +// ------------------------------ TwoTablesLookupLanes + +template +using IndicesFromD = decltype(IndicesFromVec(D(), Zero(RebindToUnsigned()))); + +// RVV/SVE have their own implementations of +// TwoTablesLookupLanes(D d, VFromD a, VFromD b, IndicesFromD idx) +#if HWY_TARGET != HWY_RVV && !HWY_TARGET_IS_SVE +template +HWY_API VFromD TwoTablesLookupLanes(D /*d*/, VFromD a, VFromD b, + IndicesFromD idx) { + return TwoTablesLookupLanes(a, b, idx); +} +#endif + +// ------------------------------ Lookup8 + +template , class VI> +HWY_INLINE Vec Lookup8(D d, const T* HWY_RESTRICT table, VI indices) { + // `di` describes the indices given - same bits per lane, but `d` determines + // the actual lane count of the result and also of the table vectors, which + // is relevant for adjusting the index values, see below. + DFromV di; + static_assert(sizeof(T) == sizeof(TFromD), + "Index/vector must have same lane size"); + HWY_IF_CONSTEXPR(HWY_IS_DEBUG_BUILD) { + // Asserting Lanes(di) >= 4 not needed since both d and di have the same + // number of Lanes() + HWY_DASSERT(Lanes(d) >= 4); + HWY_DASSERT(AllTrue(di, Lt(indices, Set(di, 8)))); + } + + HWY_IF_CONSTEXPR(!HWY_HAVE_SCALABLE) { + // Fixed-size vectors: we know they are >= 128 bit, so either one or two + // tables are sufficient. + HWY_IF_CONSTEXPR(MaxLanes(d) >= 8) { + const CappedTag d8; + // We want to perform one lookup per index, hence cast. 
This has no + // runtime cost; the upper lanes are unused. + const Vec t0 = ResizeBitCast(d, Load(d8, table)); + return TableLookupLanes(t0, IndicesFromVec(d, indices)); + } + HWY_IF_CONSTEXPR(MaxLanes(d) < 8) { + // Exactly 4 lanes, because we ensured >= 4 above. + const Vec t0 = Load(d, table); + const Vec t1 = Load(d, table + 4); + return TwoTablesLookupLanes(d, t0, t1, IndicesFromVec(d, indices)); + } + } + + HWY_IF_CONSTEXPR(HWY_HAVE_SCALABLE) { + // Scalable: first we must load two halves of the table into two vectors, + // regardless of vector size. We always use two-vector lookups to avoid + // runtime branching. Note that RVV can have U64x8 even with 128-bit + // vectors (LMUL=4), hence we must use the given LMUL, not FixedTag, but we + // still want to cap at 4 lanes to avoid overrunning the table. + const CappedTag d4; + + // We want to use native lookup instructions (more efficient on SVE than two + // lookups plus a blend), hence cast. This has no runtime cost. No LoadU + // required because + 4 is still aligned relative to `d4`. + const Vec t0 = ResizeBitCast(d, Load(d4, table)); + const Vec t1 = ResizeBitCast(d, Load(d4, table + 4)); + + // Now ensure indices for the second half of the table point to the second + // vector. Note that SVE2_128 and SVE_256 are handled by the fixed-size case + // above. The adjustment factor is 0 for 128-bit SIMD, which can happen with + // 128-bit SVE1 hardware, but we do not know that at compile time. 
+ using TI = TFromD; + const VI adjust = Set(di, static_cast(Lanes(d) - 4)); +#if HWY_TARGET_IS_SVE + const Mask ge_4 = detail::GeN(indices, 4); +#else + const Mask ge_4 = Ge(indices, Set(di, 4)); +#endif + indices = MaskedAddOr(indices, ge_4, indices, adjust); + + return TwoTablesLookupLanes(d, t0, t1, IndicesFromVec(d, indices)); + } +} + +// ------------------------------ Reverse2, Reverse4, Reverse8 (8-bit) + +#if (defined(HWY_NATIVE_REVERSE2_8) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE +#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +#undef HWY_PREFER_ROTATE +// Platforms on which RotateRight is likely faster than TableLookupBytes. +// RVV and SVE anyway have their own implementation of this. +#if HWY_TARGET == HWY_SSE2 || HWY_TARGET <= HWY_AVX3 || \ + HWY_TARGET == HWY_WASM || HWY_TARGET == HWY_PPC8 +#define HWY_PREFER_ROTATE 1 +#else +#define HWY_PREFER_ROTATE 0 +#endif + +template +HWY_API VFromD Reverse2(D d, VFromD v) { + // Exclude AVX3 because its 16-bit RotateRight is actually 3 instructions. 
+#if HWY_PREFER_ROTATE && HWY_TARGET > HWY_AVX3 + const Repartition du16; + return BitCast(d, RotateRight<8>(BitCast(du16, v))); +#else + const VFromD shuffle = Dup128VecFromValues(d, 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, + 11, 10, 13, 12, 15, 14); + return TableLookupBytes(v, shuffle); +#endif +} + +template +HWY_API VFromD Reverse4(D d, VFromD v) { +#if HWY_PREFER_ROTATE + const Repartition du16; + return BitCast(d, Reverse2(du16, BitCast(du16, Reverse2(d, v)))); +#else + const Repartition du8; + const VFromD shuffle = Dup128VecFromValues( + du8, 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12); + return TableLookupBytes(v, BitCast(d, shuffle)); +#endif +} + +template +HWY_API VFromD Reverse8(D d, VFromD v) { +#if HWY_PREFER_ROTATE + const Repartition du32; + return BitCast(d, Reverse2(du32, BitCast(du32, Reverse4(d, v)))); +#else + const Repartition du8; + const VFromD shuffle = Dup128VecFromValues( + du8, 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8); + return TableLookupBytes(v, BitCast(d, shuffle)); +#endif +} + +#endif // HWY_NATIVE_REVERSE2_8 + +// ------------------------------ ReverseLaneBytes + +#if (defined(HWY_NATIVE_REVERSE_LANE_BYTES) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REVERSE_LANE_BYTES +#undef HWY_NATIVE_REVERSE_LANE_BYTES +#else +#define HWY_NATIVE_REVERSE_LANE_BYTES +#endif + +template +HWY_API V ReverseLaneBytes(V v) { + const DFromV d; + const Repartition du8; + return BitCast(d, Reverse2(du8, BitCast(du8, v))); +} + +template +HWY_API V ReverseLaneBytes(V v) { + const DFromV d; + const Repartition du8; + return BitCast(d, Reverse4(du8, BitCast(du8, v))); +} + +template +HWY_API V ReverseLaneBytes(V v) { + const DFromV d; + const Repartition du8; + return BitCast(d, Reverse8(du8, BitCast(du8, v))); +} + +#endif // HWY_NATIVE_REVERSE_LANE_BYTES + +// ------------------------------ ReverseBits + +// On these targets, we emulate 8-bit shifts using 16-bit shifts and therefore +// require at least two lanes to BitCast to 16-bit. 
We avoid Highway's 8-bit +// shifts because those would add extra masking already taken care of by +// UI8ReverseBitsStep. Note that AVX3_DL/AVX3_ZEN4 support GFNI and use it to +// implement ReverseBits, so this code is not used there. +#undef HWY_REVERSE_BITS_MIN_BYTES +#if ((HWY_TARGET >= HWY_AVX3 && HWY_TARGET <= HWY_SSE2) || \ + HWY_TARGET == HWY_WASM || HWY_TARGET == HWY_WASM_EMU256) +#define HWY_REVERSE_BITS_MIN_BYTES 2 +#else +#define HWY_REVERSE_BITS_MIN_BYTES 1 +#endif + +#if (defined(HWY_NATIVE_REVERSE_BITS_UI8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REVERSE_BITS_UI8 +#undef HWY_NATIVE_REVERSE_BITS_UI8 +#else +#define HWY_NATIVE_REVERSE_BITS_UI8 +#endif + +namespace detail { + +template , HWY_REVERSE_BITS_MIN_BYTES - 1)> +HWY_INLINE V UI8ReverseBitsStep(V v) { + const DFromV d; + const RebindToUnsigned du; +#if HWY_REVERSE_BITS_MIN_BYTES == 2 + const Repartition d_shift; +#else + const RebindToUnsigned d_shift; +#endif + + const auto v_to_shift = BitCast(d_shift, v); + const auto shl_result = BitCast(d, ShiftLeft(v_to_shift)); + const auto shr_result = BitCast(d, ShiftRight(v_to_shift)); + const auto shr_result_mask = + BitCast(d, Set(du, static_cast(kShrResultMask))); + return Or(And(shr_result, shr_result_mask), + AndNot(shr_result_mask, shl_result)); +} + +#if HWY_REVERSE_BITS_MIN_BYTES == 2 +template , 1)> +HWY_INLINE V UI8ReverseBitsStep(V v) { + return V{UI8ReverseBitsStep(Vec128{v.raw}) + .raw}; +} +#endif + +} // namespace detail + +template +HWY_API V ReverseBits(V v) { + auto result = detail::UI8ReverseBitsStep<1, 0x55>(v); + result = detail::UI8ReverseBitsStep<2, 0x33>(result); + result = detail::UI8ReverseBitsStep<4, 0x0F>(result); + return result; +} + +#endif // HWY_NATIVE_REVERSE_BITS_UI8 + +#if (defined(HWY_NATIVE_REVERSE_BITS_UI16_32_64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#undef HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#else +#define HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#endif + +template 
+HWY_API V ReverseBits(V v) { + const DFromV d; + const Repartition du8; + return ReverseLaneBytes(BitCast(d, ReverseBits(BitCast(du8, v)))); +} +#endif // HWY_NATIVE_REVERSE_BITS_UI16_32_64 + +// ------------------------------ Per4LaneBlockShuffle + +#if (defined(HWY_NATIVE_PER4LANEBLKSHUF_DUP32) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#undef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#else +#define HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +namespace detail { + +template +HWY_INLINE Vec Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { +#if HWY_TARGET == HWY_RVV + constexpr int kPow2 = d.Pow2(); + constexpr int kLoadPow2 = HWY_MAX(kPow2, -1); + const ScalableTag d_load; +#else + constexpr size_t kMaxBytes = d.MaxBytes(); +#if HWY_TARGET_IS_NEON + constexpr size_t kMinLanesToLoad = 2; +#else + constexpr size_t kMinLanesToLoad = 4; +#endif + constexpr size_t kNumToLoad = + HWY_MAX(kMaxBytes / sizeof(uint32_t), kMinLanesToLoad); + const CappedTag d_load; +#endif + return ResizeBitCast(d, Dup128VecFromValues(d_load, x0, x1, x2, x3)); +} + +} // namespace detail +#endif + +#endif // HWY_NATIVE_PER4LANEBLKSHUF_DUP32 + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +namespace detail { + +template +HWY_INLINE V Per2LaneBlockShuffle(hwy::SizeTag<0> /*idx_10_tag*/, V v) { + return DupEven(v); +} + +template +HWY_INLINE V Per2LaneBlockShuffle(hwy::SizeTag<1> /*idx_10_tag*/, V v) { + const DFromV d; + return Reverse2(d, v); +} + +template +HWY_INLINE V Per2LaneBlockShuffle(hwy::SizeTag<2> /*idx_10_tag*/, V v) { + return v; +} + +template +HWY_INLINE V Per2LaneBlockShuffle(hwy::SizeTag<3> /*idx_10_tag*/, V v) { + return DupOdd(v); +} + +HWY_INLINE uint32_t U8x4Per4LaneBlkIndices(const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { +#if HWY_IS_LITTLE_ENDIAN + return static_cast((idx3 << 24) | (idx2 << 16) | (idx1 
<< 8) | + idx0); +#else + return static_cast(idx3 | (idx2 << 8) | (idx1 << 16) | + (idx0 << 24)); +#endif +} + +template +HWY_INLINE Vec TblLookupPer4LaneBlkU8IdxInBlk(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { +#if HWY_TARGET == HWY_RVV + const AdjustSimdTagToMinVecPow2> du32; +#else + const Repartition du32; +#endif + + return ResizeBitCast( + d, Set(du32, U8x4Per4LaneBlkIndices(idx3, idx2, idx1, idx0))); +} + +#if HWY_HAVE_SCALABLE || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_EMU128 +#define HWY_PER_4_BLK_TBL_LOOKUP_LANES_ENABLE(D) void* = nullptr +#else +#define HWY_PER_4_BLK_TBL_LOOKUP_LANES_ENABLE(D) HWY_IF_T_SIZE_D(D, 8) + +template +HWY_INLINE V Per4LaneBlkShufDoTblLookup(V v, V idx) { + const DFromV d; + const Repartition du8; + return BitCast(d, TableLookupBytes(BitCast(du8, v), BitCast(du8, idx))); +} + +template +HWY_INLINE Vec TblLookupPer4LaneBlkShufIdx(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + const Repartition du32; + const uint32_t idx3210 = U8x4Per4LaneBlkIndices(idx3, idx2, idx1, idx0); + const auto v_byte_idx = Per4LaneBlkShufDupSet4xU32( + du32, static_cast(idx3210 + 0x0C0C0C0C), + static_cast(idx3210 + 0x08080808), + static_cast(idx3210 + 0x04040404), + static_cast(idx3210)); + return ResizeBitCast(d, v_byte_idx); +} + +template +HWY_INLINE Vec TblLookupPer4LaneBlkShufIdx(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + const Repartition du32; +#if HWY_IS_LITTLE_ENDIAN + const uint32_t idx10 = static_cast((idx1 << 16) | idx0); + const uint32_t idx32 = static_cast((idx3 << 16) | idx2); + constexpr uint32_t kLaneByteOffsets{0x01000100}; +#else + const uint32_t idx10 = static_cast(idx1 | (idx0 << 16)); + const uint32_t idx32 = static_cast(idx3 | (idx2 << 16)); + constexpr uint32_t kLaneByteOffsets{0x00010001}; +#endif + constexpr uint32_t kHiLaneByteOffsets{kLaneByteOffsets + 0x08080808u}; + 
+ const auto v_byte_idx = Per4LaneBlkShufDupSet4xU32( + du32, static_cast(idx32 * 0x0202u + kHiLaneByteOffsets), + static_cast(idx10 * 0x0202u + kHiLaneByteOffsets), + static_cast(idx32 * 0x0202u + kLaneByteOffsets), + static_cast(idx10 * 0x0202u + kLaneByteOffsets)); + return ResizeBitCast(d, v_byte_idx); +} + +template +HWY_INLINE Vec TblLookupPer4LaneBlkShufIdx(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + const Repartition du32; +#if HWY_IS_LITTLE_ENDIAN + constexpr uint32_t kLaneByteOffsets{0x03020100}; +#else + constexpr uint32_t kLaneByteOffsets{0x00010203}; +#endif + + const auto v_byte_idx = Per4LaneBlkShufDupSet4xU32( + du32, static_cast(idx3 * 0x04040404u + kLaneByteOffsets), + static_cast(idx2 * 0x04040404u + kLaneByteOffsets), + static_cast(idx1 * 0x04040404u + kLaneByteOffsets), + static_cast(idx0 * 0x04040404u + kLaneByteOffsets)); + return ResizeBitCast(d, v_byte_idx); +} +#endif + +template +HWY_INLINE VFromD TblLookupPer4LaneBlkIdxInBlk(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + return TblLookupPer4LaneBlkU8IdxInBlk(d, idx3, idx2, idx1, idx0); +} + +#if HWY_TARGET == HWY_RVV +template +HWY_INLINE VFromD TblLookupPer4LaneBlkIdxInBlk(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + const Rebind du8; + return PromoteTo(d, + TblLookupPer4LaneBlkU8IdxInBlk(du8, idx3, idx2, idx1, idx0)); +} +#else +template +HWY_INLINE VFromD TblLookupPer4LaneBlkIdxInBlk(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + const uint16_t u16_idx0 = static_cast(idx0); + const uint16_t u16_idx1 = static_cast(idx1); + const uint16_t u16_idx2 = static_cast(idx2); + const uint16_t u16_idx3 = static_cast(idx3); +#if HWY_TARGET_IS_NEON + constexpr size_t kMinLanesToLoad = 4; +#else + constexpr size_t kMinLanesToLoad = 8; +#endif + constexpr size_t kNumToLoad = 
HWY_MAX(HWY_MAX_LANES_D(D), kMinLanesToLoad); + const CappedTag d_load; + return ResizeBitCast( + d, Dup128VecFromValues(d_load, u16_idx0, u16_idx1, u16_idx2, u16_idx3, + u16_idx0, u16_idx1, u16_idx2, u16_idx3)); +} + +template +HWY_INLINE VFromD TblLookupPer4LaneBlkIdxInBlk(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + return Per4LaneBlkShufDupSet4xU32(d, idx3, idx2, idx1, idx0); +} + +template +HWY_INLINE VFromD TblLookupPer4LaneBlkIdxInBlk(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + const RebindToUnsigned du; + const Rebind du32; + return BitCast(d, PromoteTo(du, Per4LaneBlkShufDupSet4xU32(du32, idx3, idx2, + idx1, idx0))); +} +#endif + +template +HWY_INLINE IndicesFromD TblLookupPer4LaneBlkShufIdx(D d, const uint32_t idx3, + const uint32_t idx2, + const uint32_t idx1, + const uint32_t idx0) { + const RebindToUnsigned du; + using TU = TFromD; + auto idx_in_blk = TblLookupPer4LaneBlkIdxInBlk(du, idx3, idx2, idx1, idx0); + + constexpr size_t kN = HWY_MAX_LANES_D(D); + if (kN < 4) { + idx_in_blk = And(idx_in_blk, Set(du, static_cast(kN - 1))); + } + +#if HWY_TARGET == HWY_RVV + const auto blk_offsets = AndS(Iota0(du), static_cast(~TU{3})); +#else + const auto blk_offsets = + And(Iota(du, TU{0}), Set(du, static_cast(~TU{3}))); +#endif + return IndicesFromVec(d, Add(idx_in_blk, blk_offsets)); +} + +template )> +HWY_INLINE V Per4LaneBlkShufDoTblLookup(V v, IndicesFromD> idx) { + return TableLookupLanes(v, idx); +} + +#undef HWY_PER_4_BLK_TBL_LOOKUP_LANES_ENABLE + +template +HWY_INLINE V TblLookupPer4LaneBlkShuf(V v, size_t idx3210) { + const DFromV d; + const uint32_t idx3 = static_cast((idx3210 >> 6) & 3); + const uint32_t idx2 = static_cast((idx3210 >> 4) & 3); + const uint32_t idx1 = static_cast((idx3210 >> 2) & 3); + const uint32_t idx0 = static_cast(idx3210 & 3); + const auto idx = TblLookupPer4LaneBlkShufIdx(d, idx3, idx2, idx1, idx0); + return 
Per4LaneBlkShufDoTblLookup(v, idx); +} + +// The detail::Per4LaneBlockShuffle overloads that have the extra lane_size_tag +// and vect_size_tag parameters are only called for vectors that have at +// least 4 lanes (or scalable vectors that might possibly have 4 or more lanes) +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + return TblLookupPer4LaneBlkShuf(v, kIdx3210); +} + +#if HWY_HAVE_FLOAT64 +template +HWY_INLINE VFromD>> Per4LaneBlockShufCastToWide( + hwy::FloatTag /* type_tag */, hwy::SizeTag<4> /* lane_size_tag */, V v) { + const DFromV d; + const RepartitionToWide dw; + return BitCast(dw, v); +} +#endif + +template +HWY_INLINE VFromD>>> +Per4LaneBlockShufCastToWide(hwy::FloatTag /* type_tag */, + hwy::SizeTag /* lane_size_tag */, V v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + return BitCast(dw, v); +} + +template +HWY_INLINE VFromD>> Per4LaneBlockShufCastToWide( + hwy::NonFloatTag /* type_tag */, + hwy::SizeTag /* lane_size_tag */, V v) { + const DFromV d; + const RepartitionToWide dw; + return BitCast(dw, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x1B> /*idx_3210_tag*/, V v) { + const DFromV d; + return Reverse4(d, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x44> /*idx_3210_tag*/, V v) { + const DFromV d; + const auto vw = Per4LaneBlockShufCastToWide( + hwy::IsFloatTag>(), hwy::SizeTag)>(), v); + return BitCast(d, DupEven(vw)); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x4E> /*idx_3210_tag*/, V v) { + const DFromV d; + const auto vw = Per4LaneBlockShufCastToWide( + hwy::IsFloatTag>(), hwy::SizeTag)>(), v); + const DFromV dw; + return BitCast(d, Reverse2(dw, vw)); +} + +#if HWY_MAX_BYTES >= 32 +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x4E> /*idx_3210_tag*/, V v) { + return SwapAdjacentBlocks(v); +} +#endif + +template , 4), + 
HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2))> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x50> /*idx_3210_tag*/, V v) { + const DFromV d; + return InterleaveLower(d, v, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x50> /*idx_3210_tag*/, V v) { + const DFromV d; + return InterleaveLower(d, v, v); +} + +template , 4)> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x88> /*idx_3210_tag*/, V v) { + const DFromV d; + return ConcatEven(d, v, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xA0> /*idx_3210_tag*/, V v) { + return DupEven(v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xB1> /*idx_3210_tag*/, V v) { + const DFromV d; + return Reverse2(d, v); +} + +template , 4)> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xDD> /*idx_3210_tag*/, V v) { + const DFromV d; + return ConcatOdd(d, v, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xE4> /*idx_3210_tag*/, V v) { + return v; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xEE> /*idx_3210_tag*/, V v) { + const DFromV d; + const auto vw = Per4LaneBlockShufCastToWide( + hwy::IsFloatTag>(), hwy::SizeTag)>(), v); + return BitCast(d, DupOdd(vw)); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xF5> /*idx_3210_tag*/, V v) { + return DupOdd(v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xFA> /*idx_3210_tag*/, V v) { + const DFromV d; + return InterleaveUpper(d, v, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag idx_3210_tag, V v) { + const DFromV d; + return Per4LaneBlockShuffle(idx_3210_tag, hwy::SizeTag)>(), + hwy::SizeTag(), v); +} + +} // namespace detail +#endif // HWY_TARGET != HWY_SCALAR + +template , 1)> +HWY_API V Per4LaneBlockShuffle(V v) { + static_assert(kIdx0 <= 3, "kIdx0 <= 3 must be true"); + static_assert(kIdx1 <= 3, "kIdx1 <= 3 must be true"); + static_assert(kIdx2 <= 3, "kIdx2 <= 3 must be true"); + static_assert(kIdx3 <= 3, "kIdx3 <= 3 must be 
true"); + + return v; +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template , 2)> +HWY_API V Per4LaneBlockShuffle(V v) { + static_assert(kIdx0 <= 3, "kIdx0 <= 3 must be true"); + static_assert(kIdx1 <= 3, "kIdx1 <= 3 must be true"); + static_assert(kIdx2 <= 3, "kIdx2 <= 3 must be true"); + static_assert(kIdx3 <= 3, "kIdx3 <= 3 must be true"); + + constexpr bool isReverse2 = (kIdx0 == 1 || kIdx1 == 0) && (kIdx0 != kIdx1); + constexpr size_t kPer2BlkIdx0 = (kIdx0 <= 1) ? kIdx0 : (isReverse2 ? 1 : 0); + constexpr size_t kPer2BlkIdx1 = (kIdx1 <= 1) ? kIdx1 : (isReverse2 ? 0 : 1); + + constexpr size_t kIdx10 = (kPer2BlkIdx1 << 1) | kPer2BlkIdx0; + static_assert(kIdx10 <= 3, "kIdx10 <= 3 must be true"); + return detail::Per2LaneBlockShuffle(hwy::SizeTag(), v); +} + +template , 2)> +HWY_API V Per4LaneBlockShuffle(V v) { + static_assert(kIdx0 <= 3, "kIdx0 <= 3 must be true"); + static_assert(kIdx1 <= 3, "kIdx1 <= 3 must be true"); + static_assert(kIdx2 <= 3, "kIdx2 <= 3 must be true"); + static_assert(kIdx3 <= 3, "kIdx3 <= 3 must be true"); + + constexpr size_t kIdx3210 = + (kIdx3 << 6) | (kIdx2 << 4) | (kIdx1 << 2) | kIdx0; + return detail::Per4LaneBlockShuffle(hwy::SizeTag(), v); +} +#endif + +// ------------------------------ PairwiseAdd128/PairwiseSub128 +// (Per4LaneBlockShuffle) +#if (defined(HWY_NATIVE_PAIRWISE_ADD_128) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_PAIRWISE_ADD_128 +#undef HWY_NATIVE_PAIRWISE_ADD_128 +#else +#define HWY_NATIVE_PAIRWISE_ADD_128 +#endif + +namespace detail { + +// detail::BlockwiseConcatOddEven(d, v) returns the even lanes of each block of +// v followed by the odd lanes of v +#if HWY_TARGET_IS_NEON || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_RVV || \ + HWY_TARGET == HWY_LSX || HWY_TARGET == HWY_LASX +template +static HWY_INLINE HWY_MAYBE_UNUSED Vec BlockwiseConcatOddEven(D d, + Vec v) { +#if HWY_TARGET == HWY_RVV + const ScalableTag du64; +#else + const Repartition> du64; +#endif + + const Repartition, decltype(du64)> d_concat; + 
const auto v_to_concat = ResizeBitCast(d_concat, v); + + const auto evens = ConcatEven(d, v_to_concat, v_to_concat); + const auto odds = ConcatOdd(d, v_to_concat, v_to_concat); + return ResizeBitCast( + d, InterleaveWholeLower(BitCast(du64, evens), BitCast(du64, odds))); +} + +#else // !(HWY_TARGET_IS_NEON || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_RVV) + +template +static HWY_INLINE HWY_MAYBE_UNUSED Vec BlockwiseConcatOddEven(D d, + Vec v) { +#if HWY_TARGET == HWY_SSE2 + const RebindToUnsigned du; + const RebindToSigned> dw; + + const auto vu = BitCast(du, v); + return BitCast( + d, OrderedDemote2To(du, PromoteEvenTo(dw, vu), PromoteOddTo(dw, vu))); +#else + const Repartition du8; + const auto idx = + BitCast(d, Dup128VecFromValues(du8, 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, + 9, 11, 13, 15)); + return TableLookupBytes(v, idx); +#endif +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED Vec BlockwiseConcatOddEven(D d, + Vec v) { +#if HWY_TARGET == HWY_SSE2 + const RebindToSigned di; + const RepartitionToWide dw; + const auto vi = BitCast(di, v); + return BitCast( + d, OrderedDemote2To(di, PromoteEvenTo(dw, vi), PromoteOddTo(dw, vi))); +#else + const Repartition du8; + const auto idx = BitCast(d, Dup128VecFromValues(du8, 0, 1, 4, 5, 8, 9, 12, 13, + 2, 3, 6, 7, 10, 11, 14, 15)); + return TableLookupBytes(v, idx); +#endif +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED Vec BlockwiseConcatOddEven(D /*d*/, + Vec v) { + return Per4LaneBlockShuffle<3, 1, 2, 0>(v); +} +#endif // HWY_TARGET_IS_NEON || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_RVV + +template +static HWY_INLINE HWY_MAYBE_UNUSED Vec BlockwiseConcatOddEven(D /*d*/, + Vec v) { + return v; +} + +} // namespace detail + +// Pairwise add with output in 128 bit blocks of a and b. +template +HWY_API Vec PairwiseAdd128(D d, Vec a, Vec b) { + return detail::BlockwiseConcatOddEven(d, PairwiseAdd(d, a, b)); +} + +// Pairwise sub with output in 128 bit blocks of a and b. 
+template +HWY_API Vec PairwiseSub128(D d, Vec a, Vec b) { + return detail::BlockwiseConcatOddEven(d, PairwiseSub(d, a, b)); +} + +#endif + +// ------------------------------ Blocks + +template +HWY_API size_t Blocks(D d) { + return (d.MaxBytes() <= 16) ? 1 : ((Lanes(d) * sizeof(TFromD) + 15) / 16); +} + +// ------------------------------ Block insert/extract/broadcast ops +#if (defined(HWY_NATIVE_BLK_INSERT_EXTRACT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_BLK_INSERT_EXTRACT +#undef HWY_NATIVE_BLK_INSERT_EXTRACT +#else +#define HWY_NATIVE_BLK_INSERT_EXTRACT +#endif + +template +HWY_API V InsertBlock(V /*v*/, V blk_to_insert) { + static_assert(kBlockIdx == 0, "Invalid block index"); + return blk_to_insert; +} + +template +HWY_API V ExtractBlock(V v) { + static_assert(kBlockIdx == 0, "Invalid block index"); + return v; +} + +template +HWY_API V BroadcastBlock(V v) { + static_assert(kBlockIdx == 0, "Invalid block index"); + return v; +} + +#endif // HWY_NATIVE_BLK_INSERT_EXTRACT + +// ------------------------------ BroadcastLane +#if (defined(HWY_NATIVE_BROADCASTLANE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_BROADCASTLANE +#undef HWY_NATIVE_BROADCASTLANE +#else +#define HWY_NATIVE_BROADCASTLANE +#endif + +template +HWY_API V BroadcastLane(V v) { + return Broadcast(v); +} + +#endif // HWY_NATIVE_BROADCASTLANE + +// ------------------------------ Slide1Up and Slide1Down +#if (defined(HWY_NATIVE_SLIDE1_UP_DOWN) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SLIDE1_UP_DOWN +#undef HWY_NATIVE_SLIDE1_UP_DOWN +#else +#define HWY_NATIVE_SLIDE1_UP_DOWN +#endif + +template +HWY_API VFromD Slide1Up(D d, VFromD /*v*/) { + return Zero(d); +} +template +HWY_API VFromD Slide1Down(D d, VFromD /*v*/) { + return Zero(d); +} + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + return ShiftLeftLanes<1>(d, v); +} +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + return ShiftRightLanes<1>(d, v); +} +#endif // 
HWY_TARGET != HWY_SCALAR + +#endif // HWY_NATIVE_SLIDE1_UP_DOWN + +// ------------------------------ SlideUpBlocks + +template +HWY_API VFromD SlideUpBlocks(D /*d*/, VFromD v) { + static_assert(kBlocks == 0, "kBlocks == 0 must be true"); + return v; +} + +#if HWY_HAVE_SCALABLE || HWY_TARGET == HWY_SVE_256 +template +HWY_API VFromD SlideUpBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && static_cast(kBlocks) < d.MaxBlocks(), + "kBlocks must be between 0 and d.MaxBlocks() - 1"); + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + return SlideUpLanes(d, v, static_cast(kBlocks) * kLanesPerBlock); +} +#endif + +// ------------------------------ SlideDownBlocks + +template +HWY_API VFromD SlideDownBlocks(D /*d*/, VFromD v) { + static_assert(kBlocks == 0, "kBlocks == 0 must be true"); + return v; +} + +#if HWY_HAVE_SCALABLE || HWY_TARGET == HWY_SVE_256 +template +HWY_API VFromD SlideDownBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && static_cast(kBlocks) < d.MaxBlocks(), + "kBlocks must be between 0 and d.MaxBlocks() - 1"); + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + return SlideDownLanes(d, v, static_cast(kBlocks) * kLanesPerBlock); +} +#endif + +// ------------------------------ Slide mask up/down +#if (defined(HWY_NATIVE_SLIDE_MASK) == defined(HWY_TARGET_TOGGLE)) + +#ifdef HWY_NATIVE_SLIDE_MASK +#undef HWY_NATIVE_SLIDE_MASK +#else +#define HWY_NATIVE_SLIDE_MASK +#endif + +template +HWY_API Mask SlideMask1Up(D d, Mask m) { + return MaskFromVec(Slide1Up(d, VecFromMask(d, m))); +} + +template +HWY_API Mask SlideMask1Down(D d, Mask m) { + return MaskFromVec(Slide1Down(d, VecFromMask(d, m))); +} + +template +HWY_API Mask SlideMaskUpLanes(D d, Mask m, size_t amt) { + return MaskFromVec(SlideUpLanes(d, VecFromMask(d, m), amt)); +} + +template +HWY_API Mask SlideMaskDownLanes(D d, Mask m, size_t amt) { + return MaskFromVec(SlideDownLanes(d, VecFromMask(d, m), amt)); +} + +#endif // HWY_NATIVE_SLIDE_MASK + +// ------------------------------ 
SumsOfAdjQuadAbsDiff + +#if (defined(HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template )> +HWY_API Vec>> SumsOfAdjQuadAbsDiff(V8 a, V8 b) { + static_assert(0 <= kAOffset && kAOffset <= 1, + "kAOffset must be between 0 and 1"); + static_assert(0 <= kBOffset && kBOffset <= 3, + "kBOffset must be between 0 and 3"); + using D8 = DFromV; + const D8 d8; + const RebindToUnsigned du8; + const RepartitionToWide d16; + const RepartitionToWide du16; + + // Ensure that a is resized to a vector that has at least + // HWY_MAX(Lanes(d8), size_t{8} << kAOffset) lanes for the interleave and + // CombineShiftRightBytes operations below. +#if HWY_TARGET == HWY_RVV + // On RVV targets, need to ensure that d8_interleave.Pow2() >= 0 is true + // to ensure that Lanes(d8_interleave) >= 16 is true. + + // Lanes(d8_interleave) >= Lanes(d8) is guaranteed to be true on RVV + // targets as d8_interleave.Pow2() >= d8.Pow2() is true. + constexpr int kInterleavePow2 = HWY_MAX(d8.Pow2(), 0); + const ScalableTag, kInterleavePow2> d8_interleave; +#elif HWY_HAVE_SCALABLE || HWY_TARGET_IS_SVE + // On SVE targets, Lanes(d8_interleave) >= 16 and + // Lanes(d8_interleave) >= Lanes(d8) are both already true as d8 is a SIMD + // tag for a full u8/i8 vector on SVE. + const D8 d8_interleave; +#else + // On targets that use non-scalable vector types, Lanes(d8_interleave) is + // equal to HWY_MAX(Lanes(d8), size_t{8} << kAOffset). 
+ constexpr size_t kInterleaveLanes = + HWY_MAX(HWY_MAX_LANES_D(D8), size_t{8} << kAOffset); + const FixedTag, kInterleaveLanes> d8_interleave; +#endif + + // The ResizeBitCast operation below will resize a to a vector that has + // at least HWY_MAX(Lanes(d8), size_t{8} << kAOffset) lanes for the + // InterleaveLower, InterleaveUpper, and CombineShiftRightBytes operations + // below. + const auto a_to_interleave = ResizeBitCast(d8_interleave, a); + + const auto a_interleaved_lo = + InterleaveLower(d8_interleave, a_to_interleave, a_to_interleave); + const auto a_interleaved_hi = + InterleaveUpper(d8_interleave, a_to_interleave, a_to_interleave); + + /* a01: { a[kAOffset*4+0], a[kAOffset*4+1], a[kAOffset*4+1], a[kAOffset*4+2], + a[kAOffset*4+2], a[kAOffset*4+3], a[kAOffset*4+3], a[kAOffset*4+4], + a[kAOffset*4+4], a[kAOffset*4+5], a[kAOffset*4+5], a[kAOffset*4+6], + a[kAOffset*4+6], a[kAOffset*4+7], a[kAOffset*4+7], a[kAOffset*4+8] } + */ + /* a23: { a[kAOffset*4+2], a[kAOffset*4+3], a[kAOffset*4+3], a[kAOffset*4+4], + a[kAOffset*4+4], a[kAOffset*4+5], a[kAOffset*4+5], a[kAOffset*4+6], + a[kAOffset*4+6], a[kAOffset*4+7], a[kAOffset*4+7], a[kAOffset*4+8], + a[kAOffset*4+8], a[kAOffset*4+9], a[kAOffset*4+9], a[kAOffset*4+10] + } */ + + // a01 and a23 are resized back to V8 as only the first Lanes(d8) lanes of + // the CombineShiftRightBytes are needed for the subsequent AbsDiff operations + // and as a01 and a23 need to be the same vector type as b01 and b23 for the + // AbsDiff operations below. 
+ const V8 a01 = + ResizeBitCast(d8, CombineShiftRightBytes( + d8_interleave, a_interleaved_hi, a_interleaved_lo)); + const V8 a23 = + ResizeBitCast(d8, CombineShiftRightBytes( + d8_interleave, a_interleaved_hi, a_interleaved_lo)); + + /* b01: { b[kBOffset*4+0], b[kBOffset*4+1], b[kBOffset*4+0], b[kBOffset*4+1], + b[kBOffset*4+0], b[kBOffset*4+1], b[kBOffset*4+0], b[kBOffset*4+1], + b[kBOffset*4+0], b[kBOffset*4+1], b[kBOffset*4+0], b[kBOffset*4+1], + b[kBOffset*4+0], b[kBOffset*4+1], b[kBOffset*4+0], b[kBOffset*4+1] } + */ + /* b23: { b[kBOffset*4+2], b[kBOffset*4+3], b[kBOffset*4+2], b[kBOffset*4+3], + b[kBOffset*4+2], b[kBOffset*4+3], b[kBOffset*4+2], b[kBOffset*4+3], + b[kBOffset*4+2], b[kBOffset*4+3], b[kBOffset*4+2], b[kBOffset*4+3], + b[kBOffset*4+2], b[kBOffset*4+3], b[kBOffset*4+2], b[kBOffset*4+3] } + */ + const V8 b01 = BitCast(d8, Broadcast(BitCast(d16, b))); + const V8 b23 = BitCast(d8, Broadcast(BitCast(d16, b))); + + const VFromD absdiff_sum_01 = + SumsOf2(BitCast(du8, AbsDiff(a01, b01))); + const VFromD absdiff_sum_23 = + SumsOf2(BitCast(du8, AbsDiff(a23, b23))); + return BitCast(d16, Add(absdiff_sum_01, absdiff_sum_23)); +} +#endif // HWY_TARGET != HWY_SCALAR + +#endif // HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF + +// ------------------------------ SumsOfShuffledQuadAbsDiff + +#if (defined(HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF) == \ + defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#endif + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +template )> +HWY_API Vec>> SumsOfShuffledQuadAbsDiff(V8 a, + V8 b) { + static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); + static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); + static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); + static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); + +#if 
HWY_TARGET == HWY_RVV + // On RVV, ensure that both vA and vB have a LMUL of at least 1/2 so that + // both vA and vB can be bitcasted to a u32 vector. + const detail::AdjustSimdTagToMinVecPow2< + RepartitionToWideX2>> + d32; + const RepartitionToNarrow d16; + const RepartitionToNarrow d8; + + const auto vA = ResizeBitCast(d8, a); + const auto vB = ResizeBitCast(d8, b); +#else + const DFromV d8; + const RepartitionToWide d16; + const RepartitionToWide d32; + + const auto vA = a; + const auto vB = b; +#endif + + const RebindToUnsigned du8; + + const auto a_shuf = + Per4LaneBlockShuffle(BitCast(d32, vA)); + /* a0123_2345: { a_shuf[0], a_shuf[1], a_shuf[2], a_shuf[3], + a_shuf[2], a_shuf[3], a_shuf[4], a_shuf[5], + a_shuf[8], a_shuf[9], a_shuf[10], a_shuf[11], + a_shuf[10], a_shuf[11], a_shuf[12], a_shuf[13] } */ + /* a1234_3456: { a_shuf[1], a_shuf[2], a_shuf[3], a_shuf[4], + a_shuf[3], a_shuf[4], a_shuf[5], a_shuf[6], + a_shuf[9], a_shuf[10], a_shuf[11], a_shuf[12], + a_shuf[11], a_shuf[12], a_shuf[13], a_shuf[14] } */ +#if HWY_HAVE_SCALABLE || HWY_TARGET_IS_SVE + // On RVV/SVE targets, use Slide1Up/Slide1Down instead of + // ShiftLeftBytes/ShiftRightBytes to avoid unnecessary zeroing out of any + // lanes that are shifted into an adjacent 16-byte block as any lanes that are + // shifted into an adjacent 16-byte block by Slide1Up/Slide1Down will be + // replaced by the OddEven operation. 
+ const auto a_0123_2345 = BitCast( + d8, OddEven(BitCast(d32, Slide1Up(d16, BitCast(d16, a_shuf))), a_shuf)); + const auto a_1234_3456 = + BitCast(d8, OddEven(BitCast(d32, Slide1Up(d8, BitCast(d8, a_shuf))), + BitCast(d32, Slide1Down(d8, BitCast(d8, a_shuf))))); +#else + const auto a_0123_2345 = + BitCast(d8, OddEven(ShiftLeftBytes<2>(d32, a_shuf), a_shuf)); + const auto a_1234_3456 = BitCast( + d8, + OddEven(ShiftLeftBytes<1>(d32, a_shuf), ShiftRightBytes<1>(d32, a_shuf))); +#endif + + auto even_sums = SumsOf4(BitCast(du8, AbsDiff(a_0123_2345, vB))); + auto odd_sums = SumsOf4(BitCast(du8, AbsDiff(a_1234_3456, vB))); + +#if HWY_IS_LITTLE_ENDIAN + odd_sums = ShiftLeft<16>(odd_sums); +#else + even_sums = ShiftLeft<16>(even_sums); +#endif + + const auto sums = OddEven(BitCast(d16, odd_sums), BitCast(d16, even_sums)); + +#if HWY_TARGET == HWY_RVV + return ResizeBitCast(RepartitionToWide>(), sums); +#else + return sums; +#endif +} +#endif // HWY_TARGET != HWY_SCALAR + +#endif // HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF + +// ------------------------------ BitShuffle (Rol) +#if (defined(HWY_NATIVE_BITSHUFFLE) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_BITSHUFFLE +#undef HWY_NATIVE_BITSHUFFLE +#else +#define HWY_NATIVE_BITSHUFFLE +#endif + +#if HWY_HAVE_INTEGER64 && HWY_TARGET != HWY_SCALAR +template ), HWY_IF_UI8(TFromV)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Repartition du8; + +#if HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || \ + HWY_TARGET == HWY_WASM_EMU256 + const Repartition d_idx_shr; +#else + const Repartition d_idx_shr; +#endif + +#if HWY_IS_LITTLE_ENDIAN + constexpr uint64_t kExtractedBitsMask = + static_cast(0x8040201008040201u); +#else + constexpr uint64_t kExtractedBitsMask = + static_cast(0x0102040810204080u); +#endif + + const auto k7 = Set(du8, uint8_t{0x07}); + + auto unmasked_byte_idx = BitCast(du8, ShiftRight<3>(BitCast(d_idx_shr, idx))); +#if HWY_IS_BIG_ENDIAN + // Need to invert 
the lower 3 bits of unmasked_byte_idx[i] on big-endian + // targets + unmasked_byte_idx = Xor(unmasked_byte_idx, k7); +#endif // HWY_IS_BIG_ENDIAN + + const auto byte_idx = BitwiseIfThenElse( + k7, unmasked_byte_idx, + BitCast(du8, Dup128VecFromValues(du64, uint64_t{0}, + uint64_t{0x0808080808080808u}))); + // We want to shift right by idx & 7 to extract the desired bit in `bytes`, + // and left by iota & 7 to put it in the correct output bit. To correctly + // handle shift counts from -7 to 7, we rotate. + const auto rotate_left_bits = Sub(Iota(du8, uint8_t{0}), BitCast(du8, idx)); + + const auto extracted_bits = + And(Rol(TableLookupBytes(v, byte_idx), rotate_left_bits), + BitCast(du8, Set(du64, kExtractedBitsMask))); + // Combine bit-sliced (one bit per byte) into one 64-bit sum. + return BitCast(d64, SumsOf8(extracted_bits)); +} +#endif // HWY_HAVE_INTEGER64 && HWY_TARGET != HWY_SCALAR + +#endif // HWY_NATIVE_BITSHUFFLE + +template +HWY_API V MaskedOr(M m, V a, V b) { + return IfThenElseZero(m, Or(a, b)); +} +// ------------------------------ AllBits1/AllBits0 +#if (defined(HWY_NATIVE_ALLONES) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ALLONES +#undef HWY_NATIVE_ALLONES +#else +#define HWY_NATIVE_ALLONES +#endif + +template > +HWY_API bool AllBits1(D d, V v) { + const RebindToUnsigned du; + using TU = TFromD; + return AllTrue(du, Eq(BitCast(du, v), Set(du, hwy::HighestValue()))); +} +#endif // HWY_NATIVE_ALLONES + +#if (defined(HWY_NATIVE_ALLZEROS) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ALLZEROS +#undef HWY_NATIVE_ALLZEROS +#else +#define HWY_NATIVE_ALLZEROS +#endif + +template > +HWY_API bool AllBits0(D d, V v) { + return AllTrue(d, Eq(v, Zero(d))); +} +#endif // HWY_NATIVE_ALLZEROS + +// ------------------------------ MultiRotateRight +#if (defined(HWY_NATIVE_MULTIROTATERIGHT) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_MULTIROTATERIGHT +#undef HWY_NATIVE_MULTIROTATERIGHT +#else +#define HWY_NATIVE_MULTIROTATERIGHT +#endif + +template 
), HWY_IF_UI8(TFromV), + class VI_2 = VFromD, DFromV>>, + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(VI_2)), + HWY_IF_V_SIZE_V(V, 8)> +HWY_API V MultiRotateRight(V v, VI idx) { + const DFromV d64; + const Twice dt64; + const Repartition du8; + const Repartition dt_u8; + const Repartition dt_u16; + const auto k7 = Set(du8, uint8_t{0x07}); + const auto k63 = Set(du8, uint8_t{0x3F}); + + const auto masked_idx = And(k63, BitCast(du8, idx)); + + auto byte_idx = ShiftRight<3>(masked_idx); +#if HWY_IS_LITTLE_ENDIAN + const auto hi_byte_idx = Add(byte_idx, Set(du8, uint8_t{1})); +#else + byte_idx = Xor(byte_idx, k7); + const auto hi_byte_idx = Add(byte_idx, k7); +#endif + + const auto idx_shift = And(k7, masked_idx); + + // Calculate even lanes + const auto even_src = DupEven(ResizeBitCast(dt64, v)); + // Expand indexes to pull out 16 bit segments of idx and idx + 1 +#if HWY_IS_LITTLE_ENDIAN + const auto even_idx = InterleaveLower(ResizeBitCast(dt_u8, byte_idx), + ResizeBitCast(dt_u8, hi_byte_idx)); +#else + const auto even_idx = InterleaveLower(ResizeBitCast(dt_u8, hi_byte_idx), + ResizeBitCast(dt_u8, byte_idx)); +#endif + // TableLookupBytes indexes select from within a 16 byte block + const auto even_segments = TableLookupBytes(even_src, even_idx); + // Extract unaligned bytes from 16 bit segments + const auto even_idx_shift = PromoteTo(dt_u16, idx_shift); + const auto extracted_even_bytes = + Shr(BitCast(dt_u16, even_segments), even_idx_shift); + + // Extract the even bytes of each 128 bit block and pack into lower 64 bits +#if HWY_IS_LITTLE_ENDIAN + const auto even_lanes = BitCast( + dt64, + ConcatEven(dt_u8, Zero(dt_u8), BitCast(dt_u8, extracted_even_bytes))); +#else + const auto even_lanes = BitCast( + dt64, + ConcatOdd(dt_u8, Zero(dt_u8), BitCast(dt_u8, extracted_even_bytes))); +#endif + + return LowerHalf(d64, even_lanes); +} + +template ), HWY_IF_UI8(TFromV), + class VI_2 = VFromD, DFromV>>, + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(VI_2)), + HWY_IF_V_SIZE_GT_V(V, 8)> 
+HWY_API V MultiRotateRight(V v, VI idx) { + const DFromV d64; + const Repartition du8; + const Repartition du16; + const auto k7 = Set(du8, uint8_t{0x07}); + const auto k63 = Set(du8, uint8_t{0x3F}); + + const auto masked_idx = And(k63, BitCast(du8, idx)); + + auto byte_idx = ShiftRight<3>(masked_idx); +#if HWY_IS_LITTLE_ENDIAN + const auto hi_byte_idx = Add(byte_idx, Set(du8, uint8_t{1})); +#else + byte_idx = Xor(byte_idx, k7); + const auto hi_byte_idx = Add(byte_idx, k7); +#endif + + const auto idx_shift = And(k7, masked_idx); + + // Calculate even lanes + const auto even_src = DupEven(v); + // Expand indexes to pull out 16 bit segments of idx and idx + 1 +#if HWY_IS_LITTLE_ENDIAN + const auto even_idx = InterleaveLower(byte_idx, hi_byte_idx); +#else + const auto even_idx = InterleaveLower(hi_byte_idx, byte_idx); +#endif + // TableLookupBytes indexes select from within a 16 byte block + const auto even_segments = TableLookupBytes(even_src, even_idx); + // Extract unaligned bytes from 16 bit segments +#if HWY_IS_LITTLE_ENDIAN + const auto even_idx_shift = ZipLower(idx_shift, Zero(du8)); +#else + const auto even_idx_shift = ZipLower(Zero(du8), idx_shift); +#endif + const auto extracted_even_bytes = + Shr(BitCast(du16, even_segments), even_idx_shift); + + // Calculate odd lanes + const auto odd_src = DupOdd(v); + // Expand indexes to pull out 16 bit segments of idx and idx + 1 +#if HWY_IS_LITTLE_ENDIAN + const auto odd_idx = InterleaveUpper(du8, byte_idx, hi_byte_idx); +#else + const auto odd_idx = InterleaveUpper(du8, hi_byte_idx, byte_idx); +#endif + // TableLookupBytes indexes select from within a 16 byte block + const auto odd_segments = TableLookupBytes(odd_src, odd_idx); + // Extract unaligned bytes from 16 bit segments +#if HWY_IS_LITTLE_ENDIAN + const auto odd_idx_shift = ZipUpper(du16, idx_shift, Zero(du8)); +#else + const auto odd_idx_shift = ZipUpper(du16, Zero(du8), idx_shift); +#endif + const auto extracted_odd_bytes = + Shr(BitCast(du16, 
odd_segments), odd_idx_shift); + + // Extract the even bytes of each 128 bit block and pack into lower 64 bits +#if HWY_IS_LITTLE_ENDIAN + const auto even_lanes = BitCast( + d64, ConcatEven(du8, Zero(du8), BitCast(du8, extracted_even_bytes))); + const auto odd_lanes = BitCast( + d64, ConcatEven(du8, Zero(du8), BitCast(du8, extracted_odd_bytes))); +#else + const auto even_lanes = BitCast( + d64, ConcatOdd(du8, Zero(du8), BitCast(du8, extracted_even_bytes))); + const auto odd_lanes = BitCast( + d64, ConcatOdd(du8, Zero(du8), BitCast(du8, extracted_odd_bytes))); +#endif + // Interleave at 64 bit level + return InterleaveWholeLower(even_lanes, odd_lanes); +} + +#if HWY_TARGET == HWY_RVV + +// MultiRotateRight for LMUL=1/2 case on RVV +template ), HWY_IF_UI8(TFromV), + class VI_2 = VFromD, DFromV>>, + HWY_IF_POW2_LE_D(DFromV, 0), + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(VI_2) / 2)> +HWY_API V MultiRotateRight(V v, VI idx) { + return MultiRotateRight(v, ResizeBitCast(Twice>(), idx)); +} + +#endif + +#endif + +// ================================================== Operator wrapper + +// SVE* and RVV currently cannot define operators and have already defined +// (only) the corresponding functions such as Add. 
+#if (defined(HWY_NATIVE_OPERATOR_REPLACEMENTS) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_OPERATOR_REPLACEMENTS +#undef HWY_NATIVE_OPERATOR_REPLACEMENTS +#else +#define HWY_NATIVE_OPERATOR_REPLACEMENTS +#endif + +template +HWY_API V Add(V a, V b) { + return a + b; +} +template +HWY_API V Sub(V a, V b) { + return a - b; +} + +template +HWY_API V Mul(V a, V b) { + return a * b; +} +template +HWY_API V Div(V a, V b) { + return a / b; +} +template +HWY_API V Mod(V a, V b) { + return a % b; +} + +template +V Shl(V a, V b) { + return a << b; +} +template +V Shr(V a, V b) { + return a >> b; +} + +template +HWY_API auto Eq(V a, V b) -> decltype(a == b) { + return a == b; +} +template +HWY_API auto Ne(V a, V b) -> decltype(a == b) { + return a != b; +} +template +HWY_API auto Lt(V a, V b) -> decltype(a == b) { + return a < b; +} + +template +HWY_API auto Gt(V a, V b) -> decltype(a == b) { + return a > b; +} +template +HWY_API auto Ge(V a, V b) -> decltype(a == b) { + return a >= b; +} + +template +HWY_API auto Le(V a, V b) -> decltype(a == b) { + return a <= b; +} + +#endif // HWY_NATIVE_OPERATOR_REPLACEMENTS + +#undef HWY_GENERIC_IF_EMULATED_D + +// TODO: remove once callers are updated. +// SVE and RVV do not support DFromM because their masks are loosely typed. 
+#if HWY_MAX_BYTES <= 64 && !HWY_TARGET_IS_SVE && HWY_TARGET != HWY_RVV +namespace detail { +template +uint64_t BitsFromMask(M m) { + const DFromM d; + return ::hwy::HWY_NAMESPACE::BitsFromMask(d, m); +} +} // namespace detail +#endif // !HWY_HAVE_SCALABLE && HWY_MAX_BYTES <= 64 + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/lib/highway/hwy/ops/inside-inl.h b/lib/highway/hwy/ops/inside-inl.h new file mode 100644 index 00000000000..07759afedee --- /dev/null +++ b/lib/highway/hwy/ops/inside-inl.h @@ -0,0 +1,691 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Must be included inside an existing include guard, with the following ops +// already defined: BitCast, And, Set, ShiftLeft, ShiftRight, PromoteLowerTo, +// ConcatEven, ConcatOdd, plus the optional detail::PromoteEvenTo and +// detail::PromoteOddTo (if implemented in the target-specific header). + +// This is normally set by set_macros-inl.h before this header is included; +// if not, we are viewing this header standalone. Reduce IDE errors by: +#if !defined(HWY_NAMESPACE) +// 1) Defining HWY_IDE so we get syntax highlighting rather than all-gray text. +#include "hwy/ops/shared-inl.h" +// 2) Entering the HWY_NAMESPACE to make definitions from shared-inl.h visible. 
+HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +#define HWY_INSIDE_END_NAMESPACE +// 3) Providing a dummy VFromD (usually done by the target-specific header). +template +using VFromD = int; +template +using TFromV = int; +template +struct DFromV {}; +#endif + +// ------------------------------ Vec/Create/Get/Set2..4 + +// On SVE and RVV, Vec2..4 are aliases to built-in types. Also exclude the +// fixed-size SVE targets. +#if HWY_IDE || (!HWY_HAVE_SCALABLE && !HWY_TARGET_IS_SVE) + +// NOTE: these are used inside arm_neon-inl.h, hence they cannot be defined in +// generic_ops-inl.h, which is included after that. +template +struct Vec2 { + VFromD v0; + VFromD v1; +}; + +template +struct Vec3 { + VFromD v0; + VFromD v1; + VFromD v2; +}; + +template +struct Vec4 { + VFromD v0; + VFromD v1; + VFromD v2; + VFromD v3; +}; + +// D arg is unused but allows deducing D. +template +HWY_API Vec2 Create2(D /* tag */, VFromD v0, VFromD v1) { + return Vec2{v0, v1}; +} + +template +HWY_API Vec3 Create3(D /* tag */, VFromD v0, VFromD v1, VFromD v2) { + return Vec3{v0, v1, v2}; +} + +template +HWY_API Vec4 Create4(D /* tag */, VFromD v0, VFromD v1, VFromD v2, + VFromD v3) { + return Vec4{v0, v1, v2, v3}; +} + +template +HWY_API VFromD Get2(Vec2 tuple) { + static_assert(kIndex < 2, "Tuple index out of bounds"); + return kIndex == 0 ? tuple.v0 : tuple.v1; +} + +template +HWY_API VFromD Get3(Vec3 tuple) { + static_assert(kIndex < 3, "Tuple index out of bounds"); + return kIndex == 0 ? tuple.v0 : kIndex == 1 ? tuple.v1 : tuple.v2; +} + +template +HWY_API VFromD Get4(Vec4 tuple) { + static_assert(kIndex < 4, "Tuple index out of bounds"); + return kIndex == 0 ? tuple.v0 + : kIndex == 1 ? tuple.v1 + : kIndex == 2 ? 
tuple.v2 + : tuple.v3; +} + +template +HWY_API Vec2 Set2(Vec2 tuple, VFromD val) { + static_assert(kIndex < 2, "Tuple index out of bounds"); + if (kIndex == 0) { + tuple.v0 = val; + } else { + tuple.v1 = val; + } + return tuple; +} + +template +HWY_API Vec3 Set3(Vec3 tuple, VFromD val) { + static_assert(kIndex < 3, "Tuple index out of bounds"); + if (kIndex == 0) { + tuple.v0 = val; + } else if (kIndex == 1) { + tuple.v1 = val; + } else { + tuple.v2 = val; + } + return tuple; +} + +template +HWY_API Vec4 Set4(Vec4 tuple, VFromD val) { + static_assert(kIndex < 4, "Tuple index out of bounds"); + if (kIndex == 0) { + tuple.v0 = val; + } else if (kIndex == 1) { + tuple.v1 = val; + } else if (kIndex == 2) { + tuple.v2 = val; + } else { + tuple.v3 = val; + } + return tuple; +} + +#endif // !HWY_HAVE_SCALABLE || HWY_IDE + +// ------------------------------ Rol/Ror (And, Or, Neg, Shl, Shr) +#if (defined(HWY_NATIVE_ROL_ROR_8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_8 +#undef HWY_NATIVE_ROL_ROR_8 +#else +#define HWY_NATIVE_ROL_ROR_8 +#endif + +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint8_t{7}); + const auto shl_amt = And(BitCast(du, b), shift_amt_mask); + const auto shr_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +template )> +HWY_API V Ror(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint8_t{7}); + const auto shr_amt = And(BitCast(du, b), shift_amt_mask); + const auto shl_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +#endif // HWY_NATIVE_ROL_ROR_8 + +#if (defined(HWY_NATIVE_ROL_ROR_16) == defined(HWY_TARGET_TOGGLE)) +#ifdef 
HWY_NATIVE_ROL_ROR_16 +#undef HWY_NATIVE_ROL_ROR_16 +#else +#define HWY_NATIVE_ROL_ROR_16 +#endif + +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint16_t{15}); + const auto shl_amt = And(BitCast(du, b), shift_amt_mask); + const auto shr_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +template )> +HWY_API V Ror(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint16_t{15}); + const auto shr_amt = And(BitCast(du, b), shift_amt_mask); + const auto shl_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +#endif // HWY_NATIVE_ROL_ROR_16 + +#if (defined(HWY_NATIVE_ROL_ROR_32_64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_32_64 +#undef HWY_NATIVE_ROL_ROR_32_64 +#else +#define HWY_NATIVE_ROL_ROR_32_64 +#endif + +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint32_t{31}); + const auto shl_amt = And(BitCast(du, b), shift_amt_mask); + const auto shr_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +template )> +HWY_API V Ror(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint32_t{31}); + const auto shr_amt = And(BitCast(du, b), shift_amt_mask); + const auto shl_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +#if HWY_HAVE_INTEGER64 +template )> 
+HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint64_t{63}); + const auto shl_amt = And(BitCast(du, b), shift_amt_mask); + const auto shr_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} + +template )> +HWY_API V Ror(V a, V b) { + const DFromV d; + const RebindToSigned di; + const RebindToUnsigned du; + + const auto shift_amt_mask = Set(du, uint64_t{63}); + const auto shr_amt = And(BitCast(du, b), shift_amt_mask); + const auto shl_amt = And(BitCast(du, Neg(BitCast(di, b))), shift_amt_mask); + + const auto vu = BitCast(du, a); + return BitCast(d, Or(Shl(vu, shl_amt), Shr(vu, shr_amt))); +} +#endif // HWY_HAVE_INTEGER64 + +#endif // HWY_NATIVE_ROL_ROR_32_64 + +// ------------------------------ RotateLeftSame/RotateRightSame + +#if (defined(HWY_NATIVE_ROL_ROR_SAME_8) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_SAME_8 +#undef HWY_NATIVE_ROL_ROR_SAME_8 +#else +#define HWY_NATIVE_ROL_ROR_SAME_8 +#endif + +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shl_amt = bits & 7; + const int shr_amt = static_cast((0u - static_cast(bits)) & 7u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shr_amt = bits & 7; + const int shl_amt = static_cast((0u - static_cast(bits)) & 7u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +#endif // HWY_NATIVE_ROL_ROR_SAME_8 + +#if (defined(HWY_NATIVE_ROL_ROR_SAME_16) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_SAME_16 +#undef HWY_NATIVE_ROL_ROR_SAME_16 +#else +#define 
HWY_NATIVE_ROL_ROR_SAME_16 +#endif + +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shl_amt = bits & 15; + const int shr_amt = + static_cast((0u - static_cast(bits)) & 15u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shr_amt = bits & 15; + const int shl_amt = + static_cast((0u - static_cast(bits)) & 15u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} +#endif // HWY_NATIVE_ROL_ROR_SAME_16 + +#if (defined(HWY_NATIVE_ROL_ROR_SAME_32_64) == defined(HWY_TARGET_TOGGLE)) +#ifdef HWY_NATIVE_ROL_ROR_SAME_32_64 +#undef HWY_NATIVE_ROL_ROR_SAME_32_64 +#else +#define HWY_NATIVE_ROL_ROR_SAME_32_64 +#endif + +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shl_amt = bits & 31; + const int shr_amt = + static_cast((0u - static_cast(bits)) & 31u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shr_amt = bits & 31; + const int shl_amt = + static_cast((0u - static_cast(bits)) & 31u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +#if HWY_HAVE_INTEGER64 +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shl_amt = bits & 63; + const int shr_amt = + static_cast((0u - static_cast(bits)) & 63u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} + +template )> +HWY_API V 
RotateRightSame(V v, int bits) { + const DFromV d; + const RebindToUnsigned du; + + const int shr_amt = bits & 63; + const int shl_amt = + static_cast((0u - static_cast(bits)) & 63u); + + const auto vu = BitCast(du, v); + return BitCast(d, + Or(ShiftLeftSame(vu, shl_amt), ShiftRightSame(vu, shr_amt))); +} +#endif // HWY_HAVE_INTEGER64 + +#endif // HWY_NATIVE_ROL_ROR_SAME_32_64 + +// ------------------------------ PromoteEvenTo/PromoteOddTo + +// These are used by target-specific headers for ReorderWidenMulAccumulate etc. + +#if HWY_TARGET != HWY_SCALAR || HWY_IDE +namespace detail { + +// Tag dispatch is used in detail::PromoteEvenTo and detail::PromoteOddTo as +// there are target-specific specializations for some of the +// detail::PromoteEvenTo and detail::PromoteOddTo cases on +// SVE/PPC/SSE2/SSSE3/SSE4/AVX2. + +// All targets except HWY_SCALAR use the implementations of +// detail::PromoteEvenTo and detail::PromoteOddTo in generic_ops-inl.h for at +// least some of the PromoteEvenTo and PromoteOddTo cases. + +// Signed to signed PromoteEvenTo/PromoteOddTo +template +HWY_INLINE VFromD PromoteEvenTo( + hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, V v) { +#if HWY_TARGET_IS_SVE + // The intrinsic expects the wide lane type. + return NativePromoteEvenTo(BitCast(d_to, v)); +#else +#if HWY_IS_LITTLE_ENDIAN + // On little-endian targets, need to shift each lane of the bitcasted + // vector left by kToLaneSize * 4 bits to get the bits of the even + // source lanes into the upper kToLaneSize * 4 bits of even_in_hi. + const auto even_in_hi = ShiftLeft(BitCast(d_to, v)); +#else + // On big-endian targets, the bits of the even source lanes are already + // in the upper kToLaneSize * 4 bits of the lanes of the bitcasted + // vector. 
+ const auto even_in_hi = BitCast(d_to, v); +#endif + + // Right-shift even_in_hi by kToLaneSize * 4 bits + return ShiftRight(even_in_hi); +#endif // HWY_TARGET_IS_SVE +} + +// Unsigned to unsigned PromoteEvenTo/PromoteOddTo +template +HWY_INLINE VFromD PromoteEvenTo( + hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, V v) { +#if HWY_TARGET_IS_SVE + // The intrinsic expects the wide lane type. + return NativePromoteEvenTo(BitCast(d_to, v)); +#else +#if HWY_IS_LITTLE_ENDIAN + // On little-endian targets, the bits of the even source lanes are already + // in the lower kToLaneSize * 4 bits of the lanes of the bitcasted vector. + + // Simply need to zero out the upper bits of each lane of the bitcasted + // vector. + return And(BitCast(d_to, v), + Set(d_to, static_cast>(LimitsMax>()))); +#else + // On big-endian targets, need to shift each lane of the bitcasted vector + // right by kToLaneSize * 4 bits to get the bits of the even source lanes into + // the lower kToLaneSize * 4 bits of the result. + + // The right shift below will zero out the upper kToLaneSize * 4 bits of the + // result. + return ShiftRight(BitCast(d_to, v)); +#endif +#endif // HWY_TARGET_IS_SVE +} + +template +HWY_INLINE VFromD PromoteOddTo( + hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, V v) { +#if HWY_IS_LITTLE_ENDIAN + // On little-endian targets, the bits of the odd source lanes are already in + // the upper kToLaneSize * 4 bits of the lanes of the bitcasted vector. + const auto odd_in_hi = BitCast(d_to, v); +#else + // On big-endian targets, need to shift each lane of the bitcasted vector + // left by kToLaneSize * 4 bits to get the bits of the odd source lanes into + // the upper kToLaneSize * 4 bits of odd_in_hi. 
+ const auto odd_in_hi = ShiftLeft(BitCast(d_to, v)); +#endif + + // Right-shift odd_in_hi by kToLaneSize * 4 bits + return ShiftRight(odd_in_hi); +} + +template +HWY_INLINE VFromD PromoteOddTo( + hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, V v) { +#if HWY_IS_LITTLE_ENDIAN + // On little-endian targets, need to shift each lane of the bitcasted vector + // right by kToLaneSize * 4 bits to get the bits of the odd source lanes into + // the lower kToLaneSize * 4 bits of the result. + + // The right shift below will zero out the upper kToLaneSize * 4 bits of the + // result. + return ShiftRight(BitCast(d_to, v)); +#else + // On big-endian targets, the bits of the even source lanes are already + // in the lower kToLaneSize * 4 bits of the lanes of the bitcasted vector. + + // Simply need to zero out the upper bits of each lane of the bitcasted + // vector. + return And(BitCast(d_to, v), + Set(d_to, static_cast>(LimitsMax>()))); +#endif +} + +// Unsigned to signed: Same as unsigned->unsigned PromoteEvenTo/PromoteOddTo +// followed by BitCast to signed +template +HWY_INLINE VFromD PromoteEvenTo( + hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, V v) { + const RebindToUnsigned du_to; + return BitCast(d_to, + PromoteEvenTo(hwy::UnsignedTag(), hwy::SizeTag(), + hwy::UnsignedTag(), du_to, v)); +} + +template +HWY_INLINE VFromD PromoteOddTo( + hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag /*to_lane_size_tag*/, + hwy::UnsignedTag /*from_type_tag*/, D d_to, V v) { + const RebindToUnsigned du_to; + return BitCast(d_to, + PromoteOddTo(hwy::UnsignedTag(), hwy::SizeTag(), + hwy::UnsignedTag(), du_to, v)); +} + +// BF16->F32 PromoteEvenTo + +// NOTE: It is possible for FromTypeTag to be hwy::SignedTag or hwy::UnsignedTag +// instead of hwy::FloatTag on targets that use scalable vectors. 
+ +// VBF16 is considered to be a bfloat16_t vector if TFromV is the same +// type as TFromV>> + +// The BF16->F32 PromoteEvenTo overload is only enabled if VBF16 is considered +// to be a bfloat16_t vector. +template >, + hwy::EnableIf, TFromV>()>* = nullptr> +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, DF32 d_to, + VBF16 v) { + const RebindToUnsigned du_to; +#if HWY_IS_LITTLE_ENDIAN + // On little-endian platforms, need to shift left each lane of the bitcasted + // vector by 16 bits. + return BitCast(d_to, ShiftLeft<16>(BitCast(du_to, v))); +#else + // On big-endian platforms, the even lanes of the source vector are already + // in the upper 16 bits of the lanes of the bitcasted vector. + + // Need to simply zero out the lower 16 bits of each lane of the bitcasted + // vector. + return BitCast(d_to, + And(BitCast(du_to, v), Set(du_to, uint32_t{0xFFFF0000u}))); +#endif +} + +// BF16->F32 PromoteOddTo + +// NOTE: It is possible for FromTypeTag to be hwy::SignedTag or hwy::UnsignedTag +// instead of hwy::FloatTag on targets that use scalable vectors. + +// VBF16 is considered to be a bfloat16_t vector if TFromV is the same +// type as TFromV>> + +// The BF16->F32 PromoteEvenTo overload is only enabled if VBF16 is considered +// to be a bfloat16_t vector. +template >, + hwy::EnableIf, TFromV>()>* = nullptr> +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, DF32 d_to, + VBF16 v) { + const RebindToUnsigned du_to; +#if HWY_IS_LITTLE_ENDIAN + // On little-endian platforms, the odd lanes of the source vector are already + // in the upper 16 bits of the lanes of the bitcasted vector. + + // Need to simply zero out the lower 16 bits of each lane of the bitcasted + // vector. 
+ return BitCast(d_to, + And(BitCast(du_to, v), Set(du_to, uint32_t{0xFFFF0000u}))); +#else + // On big-endian platforms, need to shift left each lane of the bitcasted + // vector by 16 bits. + return BitCast(d_to, ShiftLeft<16>(BitCast(du_to, v))); +#endif +} + +// Default PromoteEvenTo/PromoteOddTo implementations +template +HWY_INLINE VFromD PromoteEvenTo( + ToTypeTag /*to_type_tag*/, hwy::SizeTag /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + return PromoteLowerTo(d_to, v); +} + +template +HWY_INLINE VFromD PromoteEvenTo( + ToTypeTag /*to_type_tag*/, hwy::SizeTag /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + const DFromV d; + return PromoteLowerTo(d_to, ConcatEven(d, v, v)); +} + +template +HWY_INLINE VFromD PromoteOddTo( + ToTypeTag /*to_type_tag*/, hwy::SizeTag /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + const DFromV d; + return PromoteLowerTo(d_to, ConcatOdd(d, v, v)); +} + +} // namespace detail + +template )), + class V2 = VFromD, D>>, + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(V2))> +HWY_API VFromD PromoteEvenTo(D d, V v) { + return detail::PromoteEvenTo(hwy::TypeTag>(), + hwy::SizeTag)>(), + hwy::TypeTag>(), d, v); +} + +template )), + class V2 = VFromD, D>>, + HWY_IF_LANES_D(DFromV, HWY_MAX_LANES_V(V2))> +HWY_API VFromD PromoteOddTo(D d, V v) { + return detail::PromoteOddTo(hwy::TypeTag>(), + hwy::SizeTag)>(), + hwy::TypeTag>(), d, v); +} +#endif // HWY_TARGET != HWY_SCALAR + +#ifdef HWY_INSIDE_END_NAMESPACE +#undef HWY_INSIDE_END_NAMESPACE +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); +#endif diff --git a/lib/highway/hwy/ops/loongarch_lasx-inl.h b/lib/highway/hwy/ops/loongarch_lasx-inl.h new file mode 100644 index 00000000000..e947d611330 --- /dev/null +++ b/lib/highway/hwy/ops/loongarch_lasx-inl.h @@ -0,0 +1,4686 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, 
Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 256-bit LASX vectors and operations. +// External include guard in highway.h - see comment there. + +#include "hwy/ops/loongarch_lsx-inl.h" +#include "hwy/ops/shared-inl.h" + +#ifndef __loongarch_asx +// If LASX is to be runtime dispatched (instead of in baseline), we need +// to enable it *and* define __loongarch_asx or the intrinsic header will +// fail to compile. +// +// For consistency, the same pattern as the lsxintrin.h handling in +// loongarch_lsx-inl.h is used (instead of moving lasxintrin.h after +// HWY_BEFORE_NAMESPACE). +HWY_PUSH_ATTRIBUTES("lsx,lasx") +#define __loongarch_asx +#include +#undef __loongarch_asx +// Prevent "unused push_attribute" warning from Clang. +HWY_MAYBE_UNUSED static void HWY_CONCAT(hwy_lasx_dummy, __COUNTER__) () {} +HWY_POP_ATTRIBUTES +#else +#include +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +template +struct Raw256 { + using type = __m256i; +}; +template <> +struct Raw256 { + using type = __m256; +}; +template <> +struct Raw256 { + using type = __m256d; +}; + +} // namespace detail + +template +class Vec256 { + using Raw = typename detail::Raw256::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
+ HWY_INLINE Vec256& operator*=(const Vec256 other) { + return *this = (*this * other); + } + HWY_INLINE Vec256& operator/=(const Vec256 other) { + return *this = (*this / other); + } + HWY_INLINE Vec256& operator+=(const Vec256 other) { + return *this = (*this + other); + } + HWY_INLINE Vec256& operator-=(const Vec256 other) { + return *this = (*this - other); + } + HWY_INLINE Vec256& operator%=(const Vec256 other) { + return *this = (*this % other); + } + HWY_INLINE Vec256& operator&=(const Vec256 other) { + return *this = (*this & other); + } + HWY_INLINE Vec256& operator|=(const Vec256 other) { + return *this = (*this | other); + } + HWY_INLINE Vec256& operator^=(const Vec256 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +namespace detail { + +template +using RawMask256 = typename Raw256::type; + +} // namespace detail + +template +struct Mask256 { + using Raw = typename detail::RawMask256; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromM + + Raw raw; +}; + +template +using Full256 = Simd; + +// ------------------------------ Zero + +// Cannot use VFromD here because it is defined in terms of Zero. 
+template +HWY_API Vec256> Zero(D /* tag */) { + return Vec256>{__lasx_xvreplgr2vr_d(0)}; +} +template +HWY_API Vec256 Zero(D /* tag */) { + return Vec256{__lasx_xvreplgr2vr_d(0)}; +} +template +HWY_API Vec256 Zero(D /* tag */) { + return Vec256{__lasx_xvreplgr2vr_d(0)}; +} +template +HWY_API Vec256 Zero(D /* tag */) { + return Vec256{reinterpret_cast<__m256>(__lasx_xvreplgr2vr_d(0))}; +} +template +HWY_API Vec256 Zero(D /* tag */) { + return Vec256{reinterpret_cast<__m256d>(__lasx_xvreplgr2vr_d(0))}; +} + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m256i BitCastToInteger(__m256i v) { return v; } +HWY_INLINE __m256i BitCastToInteger(__m256 v) { + return reinterpret_cast<__m256i>(v); +} +HWY_INLINE __m256i BitCastToInteger(__m256d v) { + return reinterpret_cast<__m256i>(v); +} + +template +HWY_INLINE Vec256 BitCastToByte(Vec256 v) { + return Vec256{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger256 { + HWY_INLINE __m256i operator()(__m256i v) { return v; } +}; +template <> +struct BitCastFromInteger256 { + HWY_INLINE __m256 operator()(__m256i v) { + return reinterpret_cast<__m256>(v); + } +}; +template <> +struct BitCastFromInteger256 { + HWY_INLINE __m256d operator()(__m256i v) { + return reinterpret_cast<__m256d>(v); + } +}; + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, Vec256 v) { + return VFromD{BitCastFromInteger256>()(v.raw)}; +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, Vec256 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{__lasx_xvreplgr2vr_b(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{__lasx_xvreplgr2vr_h(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return 
VFromD{__lasx_xvreplgr2vr_w(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{__lasx_xvreplgr2vr_d(static_cast(t))}; // NOLINT +} +template +HWY_API Vec256 Set(D /* tag */, float t) { + return BitCast(D(), Vec256{__lasx_xvldrepl_w(&t, 0)}); +} +template +HWY_API Vec256 Set(D /* tag */, double t) { + return BitCast(D(), Vec256{__lasx_xvldrepl_d(&t, 0)}); +} + +// ------------------------------ ResizeBitCast + +// 32-byte vector to 32-byte vector +template ))> +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, v); +} + +// 32-byte vector to 16-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const DFromV d_from; + const Half dh_from; + return BitCast(d, LowerHalf(dh_from, v)); +} + +// 32-byte vector to <= 8-byte vector +template +HWY_API VFromD ResizeBitCast(D /*d*/, FromV v) { + return VFromD{ResizeBitCast(Full128>(), v).raw}; +} + +// <= 16-byte vector to 32-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + typedef uint64_t GccRawU64M128Vec __attribute__((__vector_size__(16))); + + const GccRawU64M128Vec raw_v0 = reinterpret_cast(v.raw); +#if HWY_COMPILER_CLANG && HWY_HAS_BUILTIN(__builtin_nondeterministic_value) + const GccRawU64M128Vec raw_v1 = __builtin_nondeterministic_value(raw_v0); +#else + const GccRawU64M128Vec raw_v1 = raw_v0; +#endif + + const Repartition du64; + const Half dh_u64; + return BitCast( + d, + Combine(du64, VFromD{reinterpret_cast<__m128i>(raw_v1)}, + VFromD{reinterpret_cast<__m128i>(raw_v0)})); +} + +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + typedef int8_t GccI8RawVectType __attribute__((__vector_size__(32))); + GccI8RawVectType raw_i8_vec = { + static_cast(t0), static_cast(t1), 
static_cast(t2), + static_cast(t3), static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), static_cast(t8), + static_cast(t9), static_cast(t10), static_cast(t11), + static_cast(t12), static_cast(t13), static_cast(t14), + static_cast(t15), static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), static_cast(t4), + static_cast(t5), static_cast(t6), static_cast(t7), + static_cast(t8), static_cast(t9), static_cast(t10), + static_cast(t11), static_cast(t12), static_cast(t13), + static_cast(t14), static_cast(t15)}; + return VFromD{reinterpret_cast<__m256i>(raw_i8_vec)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + typedef int16_t GccI16RawVectType __attribute__((__vector_size__(32))); + GccI16RawVectType raw_i16_vec = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7)}; + return VFromD{reinterpret_cast<__m256i>(raw_i16_vec)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + typedef int32_t GccI32RawVectType __attribute__((__vector_size__(32))); + GccI32RawVectType raw_i32_vec = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3)}; + return VFromD{reinterpret_cast<__m256i>(raw_i32_vec)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + GccF32RawVectType raw_f32_vec = {t0, t1, t2, t3, t0, t1, t2, t3}; + return Vec256{reinterpret_cast<__m256>(raw_f32_vec)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, 
TFromD t0, TFromD t1) { + typedef int64_t GccI64RawVectType __attribute__((__vector_size__(32))); + const GccI64RawVectType raw_i64_vec = { + static_cast(t0), static_cast(t1), + static_cast(t0), static_cast(t1)}; + return VFromD{reinterpret_cast<__m256i>(raw_i64_vec)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const GccF64RawVectType raw_f64_vec = {t0, t1, t0, t1}; + return VFromD{reinterpret_cast<__m256d>(raw_f64_vec)}; +} + +// ------------------------------ And + +template +HWY_API Vec256 And(Vec256 a, Vec256 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{__lasx_xvand_v(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{__lasx_xvandn_v( + BitCast(du, not_mask).raw, BitCast(du, mask).raw)}); +} + +// ------------------------------ Or + +template +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{ + __lasx_xvor_v(BitCast(du, a).raw, BitCast(du, b).raw)}); +} + +// ------------------------------ Xor + +template +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{__lasx_xvxor_v(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +// ------------------------------ Not +template +HWY_API Vec256 Not(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{__lasx_xvnor_v(BitCast(du, v).raw, + BitCast(du, v).raw)}); +} + +// ------------------------------ Or3 +template +HWY_API Vec256 Or3(Vec256 o1, Vec256 o2, Vec256 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ 
OrAnd +template +HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { + return And(a, b); +} + +template +HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { + return Or(a, b); +} + +template +HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +namespace detail { + +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<1> /* tag */, Vec256 v) { + return Vec256{__lasx_xvpcnt_b(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<2> /* tag */, Vec256 v) { + return Vec256{__lasx_xvpcnt_h(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<4> /* tag */, Vec256 v) { + return Vec256{__lasx_xvpcnt_w(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<8> /* tag */, Vec256 v) { + return Vec256{__lasx_xvpcnt_d(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 PopulationCount(Vec256 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return Mask256{v.raw}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{v.raw}; +} + +// ------------------------------ IfThenElse + +// mask ? yes : no +template +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { + const DFromV d; + RebindToSigned di; + return BitCast(d, VFromD{__lasx_xvbitsel_v( + BitCast(di, no).raw, BitCast(di, yes).raw, + RebindMask(di, mask).raw)}); +} + +// mask ? 
yes : 0 +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return yes & VecFromMask(mask); +} + +// mask ? 0 : no +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return AndNot(VecFromMask(mask), no); +} + +template +HWY_API Vec256 ZeroIfNegative(Vec256 v) { + static_assert(IsSigned(), "Only for float"); + const DFromV d; + const auto zero = Zero(d); + return IfThenElse(v < zero, zero, v); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask256 Not(const Mask256 m) { + const Full256 d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. 
+ +template +HWY_API MFromD RebindMask(DTo d_to, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + const Full256 dfrom; + return MaskFromVec(BitCast(d_to, VecFromMask(dfrom, m))); +} + +template +HWY_API Mask256 TestBit(const Vec256 v, const Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +template +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{__lasx_xvseq_b(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{__lasx_xvseq_h(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{__lasx_xvseq_w(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{__lasx_xvseq_d(a.raw, b.raw)}; +} + +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, MFromD{__lasx_xvfcmp_ceq_s(a.raw, b.raw)}); +} + +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, MFromD{__lasx_xvfcmp_ceq_d(a.raw, b.raw)}); +} + +// ------------------------------ Inequality + +template +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Not(a == b); +} +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, MFromD{__lasx_xvfcmp_cne_s(a.raw, b.raw)}); +} +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, MFromD{__lasx_xvfcmp_cne_d(a.raw, b.raw)}); +} + +// ------------------------------ Strict inequality + +namespace detail { + +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{__lasx_xvslt_b(b.raw, a.raw)}; +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{__lasx_xvslt_h(b.raw, 
a.raw)}; +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{__lasx_xvslt_w(b.raw, a.raw)}; +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{__lasx_xvslt_d(b.raw, a.raw)}; +} + +HWY_API Mask256 Gt(hwy::UnsignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{__lasx_xvslt_bu(b.raw, a.raw)}; +} +HWY_API Mask256 Gt(hwy::UnsignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{__lasx_xvslt_hu(b.raw, a.raw)}; +} +HWY_API Mask256 Gt(hwy::UnsignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{__lasx_xvslt_wu(b.raw, a.raw)}; +} +HWY_API Mask256 Gt(hwy::UnsignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{__lasx_xvslt_du(b.raw, a.raw)}; +} + +HWY_API Mask256 Gt(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, MFromD{__lasx_xvfcmp_clt_s(b.raw, a.raw)}); +} +HWY_API Mask256 Gt(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, MFromD{__lasx_xvfcmp_clt_d(b.raw, a.raw)}); +} + +} // namespace detail + +template +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return detail::Gt(hwy::TypeTag(), a, b); +} + +// ------------------------------ Weak inequality + +namespace detail { + +template +HWY_INLINE Mask256 Ge(hwy::SignedTag /*tag*/, Vec256 a, Vec256 b) { + return Not(b > a); +} + +template +HWY_INLINE Mask256 Ge(hwy::UnsignedTag /*tag*/, Vec256 a, Vec256 b) { + return Not(b > a); +} + +HWY_INLINE Mask256 Ge(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, MFromD{__lasx_xvfcmp_cle_s(b.raw, a.raw)}); +} +HWY_INLINE Mask256 Ge(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, MFromD{__lasx_xvfcmp_cle_d(b.raw, a.raw)}); +} + +} // namespace detail + +template +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return 
detail::Ge(hwy::TypeTag(), a, b); +} + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask256 operator<(const Vec256 a, const Vec256 b) { + return b > a; +} + +template +HWY_API Mask256 operator<=(const Vec256 a, const Vec256 b) { + return b >= a; +} + +// ------------------------------ Min (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{__lasx_xvmin_bu(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{__lasx_xvmin_hu(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{__lasx_xvmin_wu(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{__lasx_xvmin_du(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{__lasx_xvmin_b(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{__lasx_xvmin_h(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{__lasx_xvmin_w(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{__lasx_xvmin_d(a.raw, b.raw)}; +} + +// Float +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{__lasx_xvfmin_s(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{__lasx_xvfmin_d(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{__lasx_xvmax_bu(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{__lasx_xvmax_hu(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{__lasx_xvmax_wu(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{__lasx_xvmax_du(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + 
return Vec256{__lasx_xvmax_b(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{__lasx_xvmax_h(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{__lasx_xvmax_w(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{__lasx_xvmax_d(a.raw, b.raw)}; +} + +// Float +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{__lasx_xvfmax_s(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{__lasx_xvfmax_d(a.raw, b.raw)}; +} + +// ------------------------------ MinMagnitude and MaxMagnitude + +HWY_API Vec256 MinMagnitude(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvfmina_s(a.raw, b.raw)}; +} +HWY_API Vec256 MinMagnitude(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvfmina_d(a.raw, b.raw)}; +} + +HWY_API Vec256 MaxMagnitude(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvfmaxa_s(a.raw, b.raw)}; +} +HWY_API Vec256 MaxMagnitude(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvfmaxa_d(a.raw, b.raw)}; +} + +// ------------------------------ Iota + +namespace detail { + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + typedef int8_t GccI8RawVectType __attribute__((__vector_size__(32))); + const GccI8RawVectType raw_i8_vec = { + static_cast(0), static_cast(1), static_cast(2), + static_cast(3), static_cast(4), static_cast(5), + static_cast(6), static_cast(7), static_cast(8), + static_cast(9), static_cast(10), static_cast(11), + static_cast(12), static_cast(13), static_cast(14), + static_cast(15), static_cast(16), static_cast(17), + static_cast(18), static_cast(19), static_cast(20), + static_cast(21), static_cast(22), static_cast(23), + static_cast(24), static_cast(25), static_cast(26), + static_cast(27), static_cast(28), static_cast(29), + static_cast(30), static_cast(31)}; + return VFromD{reinterpret_cast<__m256i>(raw_i8_vec)}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + typedef int16_t GccI16RawVectType 
__attribute__((__vector_size__(32))); + const GccI16RawVectType raw_i16_vec = { + static_cast(0), static_cast(1), + static_cast(2), static_cast(3), + static_cast(4), static_cast(5), + static_cast(6), static_cast(7), + static_cast(8), static_cast(9), + static_cast(10), static_cast(11), + static_cast(12), static_cast(13), + static_cast(14), static_cast(15)}; + return VFromD{reinterpret_cast<__m256i>(raw_i16_vec)}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + typedef int32_t GccI32RawVectType __attribute__((__vector_size__(32))); + const GccI32RawVectType raw_i32_vec = { + static_cast(0), static_cast(1), static_cast(2), + static_cast(3), static_cast(4), static_cast(5), + static_cast(6), static_cast(7)}; + return VFromD{reinterpret_cast<__m256i>(raw_i32_vec)}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + typedef int64_t GccI64RawVectType __attribute__((__vector_size__(32))); + const GccI64RawVectType raw_i64_vec = { + static_cast(0), static_cast(1), static_cast(2), + static_cast(3)}; + return VFromD{reinterpret_cast<__m256i>(raw_i64_vec)}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const GccF32RawVectType raw_f32_vec = {0.0f, 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, 7.0f}; + return VFromD{reinterpret_cast<__m256>(raw_f32_vec)}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const GccF64RawVectType raw_f64_vec = {0.0, 1.0, 2.0, 3.0}; + return VFromD{reinterpret_cast<__m256d>(raw_f64_vec)}; +} + +} // namespace detail + +template +HWY_API VFromD Iota(D d, const T2 first) { + return detail::Iota0(d) + Set(d, ConvertScalarTo>(first)); +} + +// ------------------------------ FirstN (Iota, Lt) + +template > +HWY_API M FirstN(const D d, size_t n) { + constexpr size_t kN = MaxLanes(d); + n = HWY_MIN(n, kN); + const RebindToSigned di; // Signed comparisons are cheaper. 
+ using TI = TFromD; + return RebindMask(d, detail::Iota0(di) < Set(di, static_cast(n))); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvadd_b(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvadd_h(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvadd_w(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvadd_d(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvadd_b(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvadd_h(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvadd_w(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvadd_d(a.raw, b.raw)}; +} + +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvfadd_s(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvfadd_d(a.raw, b.raw)}; +} + +template +HWY_API Vec256 Add(Vec256 a, Vec256 b) { + return a + b; +} + +// ------------------------------ Subtraction + +// Unsigne +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsub_b(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsub_h(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsub_w(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsub_d(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsub_b(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsub_h(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 
b) { + return Vec256{__lasx_xvsub_w(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsub_d(a.raw, b.raw)}; +} + +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvfsub_s(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvfsub_d(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +HWY_API Vec256 SumsOf8(Vec256 v) { + v.raw = __lasx_xvhaddw_hu_bu(v.raw, v.raw); + v.raw = __lasx_xvhaddw_wu_hu(v.raw, v.raw); + return Vec256{__lasx_xvhaddw_du_wu(v.raw, v.raw)}; +} +HWY_API Vec256 SumsOf8(Vec256 v) { + v.raw = __lasx_xvhaddw_h_b(v.raw, v.raw); + v.raw = __lasx_xvhaddw_w_h(v.raw, v.raw); + return Vec256{__lasx_xvhaddw_d_w(v.raw, v.raw)}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsadd_bu(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsadd_hu(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsadd_wu(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsadd_du(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsadd_b(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsadd_h(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsadd_w(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvsadd_d(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. 
+ +// Unsigned +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvssub_bu(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvssub_hu(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvssub_wu(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvssub_du(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvssub_b(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvssub_h(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvssub_w(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvssub_d(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +HWY_API Vec256 AverageRound(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvavgr_b(a.raw, b.raw)}; +} +HWY_API Vec256 AverageRound(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvavgr_bu(a.raw, b.raw)}; +} +HWY_API Vec256 AverageRound(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvavgr_h(a.raw, b.raw)}; +} +HWY_API Vec256 AverageRound(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvavgr_hu(a.raw, b.raw)}; +} +HWY_API Vec256 AverageRound(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvavgr_w(a.raw, b.raw)}; +} +HWY_API Vec256 AverageRound(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvavgr_wu(a.raw, b.raw)}; +} +HWY_API Vec256 AverageRound(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvavgr_d(a.raw, b.raw)}; +} +HWY_API Vec256 AverageRound(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvavgr_du(a.raw, b.raw)}; +} + +// ------------------------------ Abs (Sub) + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. 
+HWY_API Vec256 Abs(Vec256 v) { + return Vec256{__lasx_xvabsd_b(v.raw, __lasx_xvreplgr2vr_b(0))}; +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{__lasx_xvabsd_h(v.raw, __lasx_xvreplgr2vr_h(0))}; +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{__lasx_xvabsd_w(v.raw, __lasx_xvreplgr2vr_w(0))}; +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{__lasx_xvabsd_d(v.raw, __lasx_xvreplgr2vr_d(0))}; +} + +// ------------------------------ Integer AbsDiff +HWY_API Vec256 AbsDiff(const Vec256 a, Vec256 b) { + return Vec256{__lasx_xvabsd_b(a.raw, b.raw)}; +} +HWY_API Vec256 AbsDiff(const Vec256 a, Vec256 b) { + return Vec256{__lasx_xvabsd_h(a.raw, b.raw)}; +} +HWY_API Vec256 AbsDiff(const Vec256 a, Vec256 b) { + return Vec256{__lasx_xvabsd_w(a.raw, b.raw)}; +} +HWY_API Vec256 AbsDiff(const Vec256 a, Vec256 b) { + return Vec256{__lasx_xvabsd_d(a.raw, b.raw)}; +} + +HWY_API Vec256 AbsDiff(const Vec256 a, Vec256 b) { + return Vec256{__lasx_xvabsd_bu(a.raw, b.raw)}; +} +HWY_API Vec256 AbsDiff(const Vec256 a, Vec256 b) { + return Vec256{__lasx_xvabsd_hu(a.raw, b.raw)}; +} +HWY_API Vec256 AbsDiff(const Vec256 a, Vec256 b) { + return Vec256{__lasx_xvabsd_wu(a.raw, b.raw)}; +} +HWY_API Vec256 AbsDiff(const Vec256 a, Vec256 b) { + return Vec256{__lasx_xvabsd_du(a.raw, b.raw)}; +} + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmul_b(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmul_h(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmul_w(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmul_d(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmul_b(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmul_h(a.raw, b.raw)}; +} +HWY_API Vec256 
operator*(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmul_w(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmul_d(a.raw, b.raw)}; +} + +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmuh_bu(a.raw, b.raw)}; +} +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmuh_b(a.raw, b.raw)}; +} +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmuh_hu(a.raw, b.raw)}; +} +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmuh_h(a.raw, b.raw)}; +} +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmuh_wu(a.raw, b.raw)}; +} +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmuh_w(a.raw, b.raw)}; +} +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmuh_du(a.raw, b.raw)}; +} +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmuh_d(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. 
+HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwev_h_b(a.raw, b.raw)}; +} +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwev_h_bu(a.raw, b.raw)}; +} +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwev_w_h(a.raw, b.raw)}; +} +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwev_w_hu(a.raw, b.raw)}; +} +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwev_d_w(a.raw, b.raw)}; +} +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwev_d_wu(a.raw, b.raw)}; +} +template +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwev_q_d(a.raw, b.raw)}; +} +template +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwev_q_du(a.raw, b.raw)}; +} + +HWY_API Vec256 MulOdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwod_h_b(a.raw, b.raw)}; +} +HWY_API Vec256 MulOdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwod_h_bu(a.raw, b.raw)}; +} +HWY_API Vec256 MulOdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwod_w_h(a.raw, b.raw)}; +} +HWY_API Vec256 MulOdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwod_w_hu(a.raw, b.raw)}; +} +HWY_API Vec256 MulOdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwod_d_w(a.raw, b.raw)}; +} +HWY_API Vec256 MulOdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwod_d_wu(a.raw, b.raw)}; +} +template +HWY_API Vec256 MulOdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwod_q_d(a.raw, b.raw)}; +} +template +HWY_API Vec256 MulOdd(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvmulwod_q_du(a.raw, b.raw)}; +} + +template +HWY_API Vec256 MulFixedPoint15(Vec256 a, Vec256 b) { + const auto i32_ev = MulEven(a, b); + const auto i32_od = MulOdd(a, b); + const auto i64_lo = InterleaveLower(i32_ev, i32_od); + const auto i64_hi = InterleaveUpper(Full256(), i32_ev, i32_od); + return Vec256{__lasx_xvssrarni_h_w(i64_hi.raw, i64_lo.raw, 15)}; +} + +// 
------------------------------ Integer division + +HWY_API Vec256 operator/(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m256i raw_result; + __asm__("xvdiv.b %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +HWY_API Vec256 operator/(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m256i raw_result; + __asm__("xvdiv.bu %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +HWY_API Vec256 operator/(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m256i raw_result; + __asm__("xvdiv.h %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +HWY_API Vec256 operator/(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m256i raw_result; + __asm__("xvdiv.hu %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +HWY_API Vec256 operator/(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m256i raw_result; + __asm__("xvdiv.w %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +HWY_API Vec256 operator/(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m256i raw_result; + __asm__("xvdiv.wu %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +HWY_API Vec256 operator/(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or 
a[i] == LimitsMin() && b[i] == -1 + __m256i raw_result; + __asm__("xvdiv.d %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +HWY_API Vec256 operator/(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m256i raw_result; + __asm__("xvdiv.du %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +// ------------------------------ Integer modulo + +HWY_API Vec256 operator%(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m256i raw_result; + __asm__("xvmod.b %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +HWY_API Vec256 operator%(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m256i raw_result; + __asm__("xvmod.bu %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +HWY_API Vec256 operator%(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m256i raw_result; + __asm__("xvmod.h %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +HWY_API Vec256 operator%(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m256i raw_result; + __asm__("xvmod.hu %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +HWY_API Vec256 operator%(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m256i raw_result; + __asm__("xvmod.w %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return 
Vec256{raw_result}; +} + +HWY_API Vec256 operator%(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m256i raw_result; + __asm__("xvmod.wu %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +HWY_API Vec256 operator%(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m256i raw_result; + __asm__("xvmod.d %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +HWY_API Vec256 operator%(const Vec256 a, + const Vec256 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m256i raw_result; + __asm__("xvmod.du %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec256{raw_result}; +} + +// ------------------------------ ShiftLeft (Compile-time constant shifts) + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{__lasx_xvslli_b(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{__lasx_xvslli_h(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{__lasx_xvslli_w(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{__lasx_xvslli_d(v.raw, kBits)}; +} + +// ------------------------------ ShiftRight (Compile-time constant shifts) + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrli_b(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrli_h(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrli_w(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrli_d(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrai_b(v.raw, kBits)}; +} + +template +HWY_API Vec256 
ShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrai_h(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrai_w(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrai_d(v.raw, kBits)}; +} + +// ------------------------------ RoundingShiftRight + +template +HWY_API Vec256 RoundingShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrari_b(v.raw, kBits)}; +} +template +HWY_API Vec256 RoundingShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrari_h(v.raw, kBits)}; +} +template +HWY_API Vec256 RoundingShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrari_w(v.raw, kBits)}; +} +template +HWY_API Vec256 RoundingShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrari_d(v.raw, kBits)}; +} + +template +HWY_API Vec256 RoundingShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrlri_b(v.raw, kBits)}; +} +template +HWY_API Vec256 RoundingShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrlri_h(v.raw, kBits)}; +} +template +HWY_API Vec256 RoundingShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrlri_w(v.raw, kBits)}; +} +template +HWY_API Vec256 RoundingShiftRight(Vec256 v) { + return Vec256{__lasx_xvsrlri_d(v.raw, kBits)}; +} +// ------------------------------ RoundingShr + +HWY_API Vec256 RoundingShr(Vec256 v, Vec256 bits) { + return Vec256{__lasx_xvsrlr_b(v.raw, bits.raw)}; +} +HWY_API Vec256 RoundingShr(Vec256 v, + Vec256 bits) { + return Vec256{__lasx_xvsrlr_h(v.raw, bits.raw)}; +} +HWY_API Vec256 RoundingShr(Vec256 v, + Vec256 bits) { + return Vec256{__lasx_xvsrlr_w(v.raw, bits.raw)}; +} +HWY_API Vec256 RoundingShr(Vec256 v, + Vec256 bits) { + return Vec256{__lasx_xvsrlr_d(v.raw, bits.raw)}; +} + +HWY_API Vec256 RoundingShr(Vec256 v, Vec256 bits) { + return Vec256{__lasx_xvsrar_b(v.raw, bits.raw)}; +} +HWY_API Vec256 RoundingShr(Vec256 v, Vec256 bits) { + return Vec256{__lasx_xvsrar_h(v.raw, bits.raw)}; +} +HWY_API Vec256 RoundingShr(Vec256 v, Vec256 bits) { + return 
Vec256{__lasx_xvsrar_w(v.raw, bits.raw)}; +} +HWY_API Vec256 RoundingShr(Vec256 v, Vec256 bits) { + return Vec256{__lasx_xvsrar_d(v.raw, bits.raw)}; +} + +// ------------------------------ RoundingShiftRightSame (RoundingShr) + +template +HWY_API Vec256 RoundingShiftRightSame(const Vec256 v, int bits) { + return RoundingShr(v, Set(DFromV(), static_cast(bits))); +} + +// ------------------------------ RotateRight (Compile-time constant shifts) + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); + if (kBits == 0) return v; + return Vec256{__lasx_xvrotri_b(v.raw, kBits)}; +} + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 16, "Invalid shift count"); + if (kBits == 0) return v; + return Vec256{__lasx_xvrotri_h(v.raw, kBits)}; +} + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); + if (kBits == 0) return v; + return Vec256{__lasx_xvrotri_w(v.raw, kBits)}; +} + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); + if (kBits == 0) return v; + return Vec256{__lasx_xvrotri_d(v.raw, kBits)}; +} + +// ------------------------------ Rol/Ror +template +HWY_API Vec256 Ror(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvrotr_b(a.raw, b.raw)}; +} + +template +HWY_API Vec256 Ror(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvrotr_h(a.raw, b.raw)}; +} + +template +HWY_API Vec256 Ror(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvrotr_w(a.raw, b.raw)}; +} + +template +HWY_API Vec256 Ror(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvrotr_d(a.raw, b.raw)}; +} + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return Vec256{__lasx_xvsrai_b(v.raw, 7)}; +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return Vec256{__lasx_xvsrai_h(v.raw, 
15)}; +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return Vec256{__lasx_xvsrai_w(v.raw, 31)}; +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return Vec256{__lasx_xvsrai_d(v.raw, 63)}; +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(mask, yes, no); +} + +// ------------------------------ IfNegativeThenNegOrUndefIfZero + +HWY_API Vec256 IfNegativeThenNegOrUndefIfZero(Vec256 mask, + Vec256 v) { + return Vec256{__lasx_xvsigncov_b(mask.raw, v.raw)}; +} + +HWY_API Vec256 IfNegativeThenNegOrUndefIfZero(Vec256 mask, + Vec256 v) { + return Vec256{__lasx_xvsigncov_h(mask.raw, v.raw)}; +} + +HWY_API Vec256 IfNegativeThenNegOrUndefIfZero(Vec256 mask, + Vec256 v) { + return Vec256{__lasx_xvsigncov_w(mask.raw, v.raw)}; +} + +HWY_API Vec256 IfNegativeThenNegOrUndefIfZero(Vec256 mask, + Vec256 v) { + return Vec256{__lasx_xvsigncov_d(mask.raw, v.raw)}; +} + +// ------------------------------ ShiftLeftSame + +template +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + return Shl(v, Set(DFromV(), static_cast(bits))); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{__lasx_xvsrl_b(v.raw, __lasx_xvreplgr2vr_b(bits))}; +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{__lasx_xvsrl_h(v.raw, __lasx_xvreplgr2vr_h(bits))}; +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{__lasx_xvsrl_w(v.raw, __lasx_xvreplgr2vr_w(bits))}; +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{__lasx_xvsrl_d(v.raw, __lasx_xvreplgr2vr_d(bits))}; 
+} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { + return Vec256{__lasx_xvsra_b(v.raw, __lasx_xvreplgr2vr_b(bits))}; +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{__lasx_xvsra_h(v.raw, __lasx_xvreplgr2vr_h(bits))}; +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{__lasx_xvsra_w(v.raw, __lasx_xvreplgr2vr_w(bits))}; +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { + return Vec256{__lasx_xvsra_d(v.raw, __lasx_xvreplgr2vr_d(bits))}; +} + +// ------------------------------ Neg (Xor, Sub) + +namespace detail { + +template +HWY_INLINE Vec256 Neg(hwy::FloatTag /*tag*/, const Vec256 v) { + const DFromV d; + return Xor(v, SignBit(d)); +} + +template +HWY_INLINE Vec256 Neg(hwy::SpecialTag /*tag*/, const Vec256 v) { + const DFromV d; + return Xor(v, SignBit(d)); +} + +// Not floating-point +template +HWY_INLINE Vec256 Neg(hwy::SignedTag /*tag*/, const Vec256 v) { + return Vec256{__lasx_xvneg_b(v.raw)}; +} + +template +HWY_INLINE Vec256 Neg(hwy::SignedTag /*tag*/, const Vec256 v) { + return Vec256{__lasx_xvneg_h(v.raw)}; +} + +template +HWY_INLINE Vec256 Neg(hwy::SignedTag /*tag*/, const Vec256 v) { + return Vec256{__lasx_xvneg_w(v.raw)}; +} + +template +HWY_INLINE Vec256 Neg(hwy::SignedTag /*tag*/, const Vec256 v) { + return Vec256{__lasx_xvneg_d(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 Neg(const Vec256 v) { + return detail::Neg(hwy::TypeTag(), v); +} + +// ------------------------------ Floating-point mul / div + +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvfmul_s(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvfmul_d(a.raw, b.raw)}; +} + +HWY_API Vec256 operator/(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvfdiv_s(a.raw, b.raw)}; +} +HWY_API Vec256 operator/(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvfdiv_d(a.raw, b.raw)}; +} + +// Approximate 
reciprocal + +HWY_API Vec256 ApproximateReciprocal(Vec256 v) { + return Vec256{__lasx_xvfrecip_s(v.raw)}; +} + +HWY_API Vec256 ApproximateReciprocal(Vec256 v) { + return Vec256{__lasx_xvfrecip_d(v.raw)}; +} + +// Integer multiply-add variants + +// signed +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmadd_b(add.raw, mul.raw, x.raw)}; +} +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmadd_h(add.raw, mul.raw, x.raw)}; +} +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmadd_w(add.raw, mul.raw, x.raw)}; +} +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmadd_d(add.raw, mul.raw, x.raw)}; +} + +// unsigend +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmadd_b(add.raw, mul.raw, x.raw)}; +} +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmadd_h(add.raw, mul.raw, x.raw)}; +} +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmadd_w(add.raw, mul.raw, x.raw)}; +} +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmadd_d(add.raw, mul.raw, x.raw)}; +} + +// signed +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmsub_b(add.raw, mul.raw, x.raw)}; +} +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmsub_h(add.raw, mul.raw, x.raw)}; +} +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmsub_w(add.raw, mul.raw, x.raw)}; +} +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmsub_d(add.raw, mul.raw, x.raw)}; +} + +// unsigned +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmsub_b(add.raw, mul.raw, x.raw)}; +} +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return 
Vec256{__lasx_xvmsub_h(add.raw, mul.raw, x.raw)}; +} +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmsub_w(add.raw, mul.raw, x.raw)}; +} +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvmsub_d(add.raw, mul.raw, x.raw)}; +} + +// ------------------------------ Floating-point multiply-add variants + +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvfmadd_s(mul.raw, x.raw, add.raw)}; +} +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{__lasx_xvfmadd_d(mul.raw, x.raw, add.raw)}; +} + +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return add - mul * x; +} +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return add - mul * x; +} + +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, + Vec256 sub) { + return Vec256{__lasx_xvfmsub_s(mul.raw, x.raw, sub.raw)}; +} +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, + Vec256 sub) { + return Vec256{__lasx_xvfmsub_d(mul.raw, x.raw, sub.raw)}; +} + +HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, + Vec256 sub) { + return Vec256{__lasx_xvfnmadd_s(mul.raw, x.raw, sub.raw)}; +} +HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, + Vec256 sub) { + return Vec256{__lasx_xvfnmadd_d(mul.raw, x.raw, sub.raw)}; +} + +// ------------------------------ MulAddSub(Float) + +template +HWY_API Vec256 MulAddSub(Vec256 mul, Vec256 x, Vec256 sub_or_add) { + return OddEven(MulAdd(mul, x, sub_or_add), MulSub(mul, x, sub_or_add)); +} + +// ------------------------------ Floating-point square root + +// Full precision square root +HWY_API Vec256 Sqrt(Vec256 v) { + return Vec256{__lasx_xvfsqrt_s(v.raw)}; +} + +HWY_API Vec256 Sqrt(Vec256 v) { + return Vec256{__lasx_xvfsqrt_d(v.raw)}; +} + +// Approximate reciprocal square root +HWY_API Vec256 ApproximateReciprocalSqrt(Vec256 v) { + return Vec256{__lasx_xvfrsqrt_s(v.raw)}; +} + +HWY_API Vec256 ApproximateReciprocalSqrt(Vec256 v) { + return 
Vec256{__lasx_xvfrsqrt_d(v.raw)}; +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, tie to even +HWY_API Vec256 Round(Vec256 v) { + return Vec256{__lasx_xvfrintrne_s(v.raw)}; +} + +HWY_API Vec256 Round(Vec256 v) { + return Vec256{__lasx_xvfrintrne_d(v.raw)}; +} + +// Toward zero, aka truncate +HWY_API Vec256 Trunc(Vec256 v) { + return Vec256{__lasx_xvfrintrz_s(v.raw)}; +} + +HWY_API Vec256 Trunc(Vec256 v) { + return Vec256{__lasx_xvfrintrz_d(v.raw)}; +} + +// Toward +infinity, aka ceiling +HWY_API Vec256 Ceil(Vec256 v) { + return Vec256{__lasx_xvfrintrp_s(v.raw)}; +} + +HWY_API Vec256 Ceil(Vec256 v) { + return Vec256{__lasx_xvfrintrp_d(v.raw)}; +} + +// Toward -infinity, aka floor +HWY_API Vec256 Floor(Vec256 v) { + return Vec256{__lasx_xvfrintrm_s(v.raw)}; +} + +HWY_API Vec256 Floor(Vec256 v) { + return Vec256{__lasx_xvfrintrm_d(v.raw)}; +} + +// ------------------------------ Floating-point classification + +// FIXME: disable gcc-14 tree-based loop optimizations to prevent +// 'HighwayTestGroup/HighwayTest.TestAllIsNaN/LASX' failures +#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG +#pragma GCC push_options +#pragma GCC optimize("-fno-tree-loop-optimize") +#endif + +HWY_API Mask256 IsNaN(Vec256 v) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, + MFromD{__lasx_xvfcmp_cune_s(v.raw, v.raw)}); +} + +HWY_API Mask256 IsNaN(Vec256 v) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, + MFromD{__lasx_xvfcmp_cune_d(v.raw, v.raw)}); +} + +#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG +#pragma GCC pop_options +#endif + +HWY_API Mask256 IsEitherNaN(Vec256 a, Vec256 b) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, MFromD{__lasx_xvfcmp_cun_s(a.raw, b.raw)}); +} + +HWY_API Mask256 IsEitherNaN(Vec256 a, Vec256 b) { + const DFromV d; + const RebindToSigned di; + return RebindMask(d, MFromD{__lasx_xvfcmp_cun_d(a.raw, b.raw)}); +} + +// 
================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API VFromD Load(D /* tag */, const TFromD* HWY_RESTRICT aligned) { + const RebindToSigned di; + return BitCast(D(), VFromD{__lasx_xvld(aligned, 0)}); +} + +template +HWY_API VFromD LoadU(D /* tag */, const TFromD* HWY_RESTRICT p) { + const RebindToSigned di; + return BitCast(D(), VFromD{__lasx_xvld(p, 0)}); +} + +// ------------------------------ MaskedLoad + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +// ------------------------------ LoadDup128 + +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + VFromD> vec_tmp; + vec_tmp = Load(Half(), p); + return Combine(d, vec_tmp, vec_tmp); +} + +// ------------------------------ Store + +template +HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { + __lasx_xvst(v.raw, aligned, 0); +} + +template +HWY_API void StoreU(VFromD v, D /* tag */, TFromD* HWY_RESTRICT p) { + __lasx_xvst(v.raw, p, 0); +} + +// ------------------------------ BlendedStore +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + const auto blended = + IfThenElse(RebindMask(du, m), BitCast(du, v), BitCast(du, LoadU(d, p))); + StoreU(BitCast(d, blended), d, p); +} + +// ================================================== SWIZZLE +// ------------------------------ LowerHalf + +template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { +#if HWY_HAS_BUILTIN(__builtin_shufflevector) + typedef uint32_t U32RawVectType __attribute__((__vector_size__(32))); + return VFromD{reinterpret_cast>::type>( + __builtin_shufflevector(reinterpret_cast(v.raw), + reinterpret_cast(v.raw), 0, 1, 2, + 3))}; +#else + const RebindToUnsigned du; + const Twice dut; + alignas(32) __m128i vec_tmp[2]; + __m256i vec_result = BitCast(dut, v).raw; + 
CopyBytes<32>(&vec_result, vec_tmp); + return BitCast(D(), VFromD{vec_tmp[0]}); +#endif +} + +template +HWY_API Vec128 LowerHalf(Vec256 v) { + const Full128 dh; + return LowerHalf(dh, v); +} + +// ------------------------------ UpperHalf + +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { +#if HWY_HAS_BUILTIN(__builtin_shufflevector) + (void)d; + typedef uint32_t U32RawVectType __attribute__((__vector_size__(32))); + return VFromD{reinterpret_cast>::type>( + __builtin_shufflevector(reinterpret_cast(v.raw), + reinterpret_cast(v.raw), 4, 5, 6, + 7))}; +#else + const RebindToUnsigned du; + const Twice dut; + alignas(32) __m128i vec_tmp[2]; + __m256i vec_result = BitCast(dut, v).raw; + CopyBytes<32>(&vec_result, vec_tmp); + return BitCast(d, VFromD{vec_tmp[1]}); +#endif +} + +// ------------------------------ ExtractLane (Store) +template +HWY_API T ExtractLane(const Vec256 v, size_t i) { + const DFromV d; + HWY_DASSERT(i < Lanes(d)); + +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + if (__builtin_constant_p(i < kLanesPerBlock) && (i < kLanesPerBlock)) { + return ExtractLane(LowerHalf(Half(), v), i); + } +#endif + + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (Store) +template +HWY_API Vec256 InsertLane(const Vec256 v, size_t i, T t) { + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +// ------------------------------ GetLane (LowerHalf) +template +HWY_API T GetLane(const Vec256 v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ExtractBlock (LowerHalf, UpperHalf) + +template +HWY_API Vec128 ExtractBlock(Vec256 v) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + const Half> dh; + return (kBlockIdx == 0) ? 
LowerHalf(dh, v) : UpperHalf(dh, v); +} + +// ------------------------------ ZeroExtendVector + +template +HWY_API VFromD ZeroExtendVector(D /* tag */, VFromD> lo) { +#if HWY_HAS_BUILTIN(__builtin_shufflevector) + typedef uint32_t U32RawVectType __attribute__((__vector_size__(16))); + U32RawVectType zero = {0, 0, 0, 0}; + return VFromD{reinterpret_cast>::type>( + __builtin_shufflevector(reinterpret_cast(lo.raw), zero, 0, + 1, 2, 3, 4, 5, 6, 7))}; +#else + return Combine(D(), Zero(Half()), lo); +#endif +} + +// ------------------------------ ZeroExtendResizeBitCast + +namespace detail { + +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag<8> /* from_size_tag */, hwy::SizeTag<32> /* to_size_tag */, + DTo d_to, DFrom d_from, VFromD v) { + const Twice dt_from; + const Twice dq_from; + return BitCast(d_to, ZeroExtendVector(dq_from, ZeroExtendVector(dt_from, v))); +} + +} // namespace detail + +// ------------------------------ Combine + +template +HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { +#if HWY_HAS_BUILTIN(__builtin_shufflevector) + (void)d; + typedef uint32_t U32RawVectType __attribute__((__vector_size__(16))); + return VFromD{reinterpret_cast>::type>( + __builtin_shufflevector(reinterpret_cast(lo.raw), + reinterpret_cast(hi.raw), 0, 1, 2, + 3, 4, 5, 6, 7))}; +#else + const RebindToUnsigned du; + const Half du128; + alignas(32) __m128i vec_tmp[2]; + __m256i vec_result; + vec_tmp[0] = BitCast(du128, lo).raw; + vec_tmp[1] = BitCast(du128, hi).raw; + CopyBytes<32>(vec_tmp, &vec_result); + return BitCast(d, VFromD{vec_result}); +#endif +} + +// ------------------------------ ShiftLeftBytes +template +HWY_API VFromD ShiftLeftBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + if (kBytes == 0) return v; + const RebindToUnsigned du; + return BitCast( + d, VFromD{__lasx_xvbsll_v(BitCast(du, v).raw, kBytes)}); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API VFromD ShiftRightBytes(D 
d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + if (kBytes == 0) return v; + const RebindToUnsigned du; + return BitCast( + d, VFromD{__lasx_xvbsrl_v(BitCast(du, v).raw, kBytes)}); +} + +// ------------------------------ CombineShiftRightBytes +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + return Or(ShiftRightBytes(d, lo), ShiftLeftBytes<16 - kBytes>(d, hi)); +} + +// ------------------------------ Broadcast + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 16, "Invalid lane"); + return Vec256{__lasx_xvreplve_b(v.raw, kLane)}; +} + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, VFromD{__lasx_xvreplve_h(BitCast(du, v).raw, kLane)}); +} + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{__lasx_xvreplve_w(v.raw, kLane)}; +} + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{__lasx_xvreplve_d(v.raw, kLane)}; +} + +template +HWY_API Vec256 Broadcast(Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, VFromD{__lasx_xvreplve_w(BitCast(du, v).raw, kLane)}); +} + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, VFromD{__lasx_xvreplve_d(BitCast(du, v).raw, kLane)}); +} + +// ------------------------------ BroadcastBlock + +template +HWY_API Vec256 BroadcastBlock(Vec256 v) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + const DFromV d; + return (kBlockIdx == 0) ? 
ConcatLowerLower(d, v, v) + : ConcatUpperUpper(d, v, v); +} + +// ------------------------------ BroadcastLane + +namespace detail { + +template +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + return Vec256{__lasx_xvreplve0_b(v.raw)}; +} + +template +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{__lasx_xvreplve0_h(BitCast(du, v).raw)}); +} + +template +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + return Vec256{__lasx_xvreplve0_w(v.raw)}; +} + +template +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + return Vec256{__lasx_xvreplve0_d(v.raw)}; +} + +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, + VFromD{__lasx_xvreplve0_w(BitCast(du, v).raw)}); +} + +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, + VFromD{__lasx_xvreplve0_d(BitCast(du, v).raw)}); +} + +template * = nullptr> +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag /* lane_idx_tag */, + Vec256 v) { + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + constexpr int kBlockIdx = static_cast(kLaneIdx / kLanesPerBlock); + constexpr int kLaneInBlkIdx = + static_cast(kLaneIdx) & (kLanesPerBlock - 1); + return Broadcast(BroadcastBlock(v)); +} +} // namespace detail + +template +HWY_API Vec256 BroadcastLane(Vec256 v) { + static_assert(kLaneIdx >= 0, "Invalid lane"); + return detail::BroadcastLane(hwy::SizeTag(kLaneIdx)>(), + v); +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec256 have lanes 7,6,5,4,3,2,1,0 (0 is +// least-significant). 
Shuffle0321 rotates four-lane blocks one lane to the +// right (the previous least-significant lane is now most-significant => +// 47650321). These could also be implemented via CombineShiftRightBytes but +// the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec256 Shuffle2301(const Vec256 v) { + return Vec256{__lasx_xvshuf4i_w(v.raw, 0xb1)}; +} +HWY_API Vec256 Shuffle2301(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, VFromD{__lasx_xvshuf4i_w(BitCast(du, v).raw, 0xb1)}); +} + +// Used by generic_ops-inl.h +namespace detail { + +template +HWY_API Vec256 ShuffleTwo2301(const Vec256 a, const Vec256 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{__lasx_xvpermi_w( + BitCast(du, b).raw, BitCast(du, a).raw, 0xb1)}); +} +template +HWY_API Vec256 ShuffleTwo1230(const Vec256 a, const Vec256 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{__lasx_xvpermi_w( + BitCast(du, b).raw, BitCast(du, a).raw, 0x6c)}); +} +template +HWY_API Vec256 ShuffleTwo3012(const Vec256 a, const Vec256 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{__lasx_xvpermi_w( + BitCast(du, b).raw, BitCast(du, a).raw, 0xc6)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec256 Shuffle1032(const Vec256 v) { + return Vec256{__lasx_xvshuf4i_w(v.raw, 0x4e)}; +} +HWY_API Vec256 Shuffle1032(const Vec256 v) { + return Vec256{__lasx_xvshuf4i_w(v.raw, 0x4e)}; +} +HWY_API Vec256 Shuffle1032(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, VFromD{__lasx_xvshuf4i_w(BitCast(du, v).raw, 0x4e)}); +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + return Vec256{__lasx_xvshuf4i_w(v.raw, 0x4e)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + return Vec256{__lasx_xvshuf4i_w(v.raw, 0x4e)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + 
return BitCast( + d, VFromD{__lasx_xvshuf4i_w(BitCast(du, v).raw, 0x4e)}); +} + +// Rotate right 32 bits +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{__lasx_xvshuf4i_w(v.raw, 0x39)}; +} +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{__lasx_xvshuf4i_w(v.raw, 0x39)}; +} +HWY_API Vec256 Shuffle0321(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, VFromD{__lasx_xvshuf4i_w(BitCast(du, v).raw, 0x39)}); +} +// Rotate left 32 bits +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{__lasx_xvshuf4i_w(v.raw, 0x93)}; +} +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{__lasx_xvshuf4i_w(v.raw, 0x93)}; +} +HWY_API Vec256 Shuffle2103(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, VFromD{__lasx_xvshuf4i_w(BitCast(du, v).raw, 0x93)}); +} + +// Reverse +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{__lasx_xvshuf4i_w(v.raw, 0x1B)}; +} +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{__lasx_xvshuf4i_w(v.raw, 0x1B)}; +} +HWY_API Vec256 Shuffle0123(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, VFromD{__lasx_xvshuf4i_w(BitCast(du, v).raw, 0x1b)}); +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. 
+template +struct Indices256 { + __m256i raw; +}; + +template +HWY_API Indices256> IndicesFromVec(D /* tag */, Vec256 vec) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Full256 di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(2 * Lanes(di)))))); +#endif + return Indices256>{vec.raw}; +} + +template +HWY_API Indices256> SetTableIndices(D d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { + const DFromV d; + const auto a = ConcatLowerLower(d, v, v); + const auto b = ConcatUpperUpper(d, v, v); + return Vec256{__lasx_xvshuf_b(b.raw, a.raw, idx.raw)}; +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { + const DFromV d; + const RebindToUnsigned du; + const auto a = ConcatLowerLower(d, v, v); + const auto b = ConcatUpperUpper(d, v, v); + return BitCast(d, VFromD{__lasx_xvshuf_h( + idx.raw, BitCast(du, b).raw, BitCast(du, a).raw)}); +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { + const DFromV d; + const RebindToSigned di; + return BitCast(d, + Vec256{__lasx_xvperm_w(BitCast(di, v).raw, idx.raw)}); +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { + using TI = MakeSigned; + const DFromV d; + const RebindToSigned di64; + const Repartition di32; + // Replicate 64-bit index into upper 32 bits + const Vec256 dup{__lasx_xvpackev_w(idx.raw, idx.raw)}; + // For each idx64 i, idx32 are 2*i and 2*i+1. 
+ const Vec256 idx32 = dup + dup + Set(di64, int64_t{1} << 32); + return BitCast( + d, TableLookupLanes(BitCast(di32, v), Indices256{idx32.raw})); +} + +template +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { + const auto idx2 = Indices256{__lasx_xvandi_b(idx.raw, 31)}; + const Vec256 idx_vec{idx.raw}; + const auto sel_hi_mask = ShiftLeft<2>(idx_vec); + const auto mask0or1 = __lasx_xvslti_b(sel_hi_mask.raw, 0); + const auto lo_lookup_result = TableLookupLanes(a, idx); + const auto hi_lookup_result = TableLookupLanes(b, idx2); + return IfThenElse(Mask256{mask0or1}, hi_lookup_result, lo_lookup_result); +} + +template +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { + const DFromV d; + const RebindToSigned di; + const Vec256> idx_vec{idx.raw}; + constexpr int shift_count = 8 * sizeof(T) - 6 + CeilLog2(sizeof(T)); + const auto sel_hi_mask = BitCast(di, ShiftLeft(idx_vec)); + const auto lo_lookup_result = BitCast(di, TableLookupLanes(a, idx)); + const auto hi_lookup_result = BitCast(di, TableLookupLanes(b, idx)); + return BitCast( + d, IfNegativeThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result)); +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Vec256{__lasx_xvpermi_q( + BitCast(du, v).raw, BitCast(du, v).raw, 0x01)}); +} + +// ------------------------------ InterleaveEvenBlocks (ConcatLowerLower) +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveEvenBlocks(D d, V a, V b) { + return ConcatLowerLower(d, b, a); +} + +// ------------------------------ InterleaveOddBlocks (ConcatUpperUpper) +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveOddBlocks(D d, V a, V b) { + return ConcatUpperUpper(d, b, a); +} + +// ------------------------------ InterleaveLowerBlocks +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveLowerBlocks(D d, V a, V b) { + return 
InterleaveEvenBlocks(d, a, b); +} +// ------------------------------ InterleaveUpperBlocks +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveUpperBlocks(D d, V a, V b) { + return InterleaveOddBlocks(d, a, b); +} + +// ------------------------------ Reverse (RotateRight) + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + alignas(32) static constexpr int32_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + const RebindToUnsigned du; + return BitCast( + d, VFromD{__lasx_xvpermi_d(BitCast(du, v).raw, 0x1b)}); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + alignas(32) static constexpr int16_t kReverse[16] = { + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + alignas(32) static constexpr TFromD kReverse[32] = { + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +// ------------------------------ Reverse4 (SwapAdjacentBlocks) + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToUnsigned du; + return BitCast( + d, VFromD{__lasx_xvshuf4i_h(BitCast(du, v).raw, 0x1b)}); +} + +template +HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { + const RebindToUnsigned du; + return BitCast( + D(), VFromD{__lasx_xvpermi_d(BitCast(du, v).raw, 0x1b)}); +} + +// ------------------------------ Reverse8 + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + return BitCast(d, TableLookupBytes(v, shuffle)); +} + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + return Reverse(d, v); +} + +template 
+HWY_API VFromD Reverse8(D /* tag */, const VFromD /* v */) { + HWY_ASSERT(0); +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvilvl_b(b.raw, a.raw)}; +} +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast(d, + VU{__lasx_xvilvl_h(BitCast(du, b).raw, BitCast(du, a).raw)}); +} +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvilvl_w(b.raw, a.raw)}; +} +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvilvl_d(b.raw, a.raw)}; +} + +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + const Full256 du; + const Full256 df; + return BitCast(df, Vec256{__lasx_xvilvl_w(BitCast(du, b).raw, + BitCast(du, a).raw)}); +} +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + const Full256 du; + const Full256 df; + return BitCast(df, Vec256{__lasx_xvilvl_d(BitCast(du, b).raw, + BitCast(du, a).raw)}); +} + +// ------------------------------ InterleaveUpper + +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{__lasx_xvilvh_b(b.raw, a.raw)}; +} + +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast(d, + VU{__lasx_xvilvh_h(BitCast(du, b).raw, BitCast(du, a).raw)}); +} + +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{__lasx_xvilvh_w(b.raw, a.raw)}; +} + +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{__lasx_xvilvh_d(b.raw, a.raw)}; +} + +template 
+HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + const RebindToUnsigned du; + return BitCast(D(), VFromD{__lasx_xvilvh_w( + BitCast(du, b).raw, BitCast(du, a).raw)}); +} + +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + const RebindToUnsigned du; + return BitCast(D(), VFromD{__lasx_xvilvh_d( + BitCast(du, b).raw, BitCast(du, a).raw)}); +} + +// ------------------------------ Blocks (LowerHalf, ZeroExtendVector) + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + return BitCast(d, VFromD{__lasx_xvpermi_q( + BitCast(du, hi).raw, BitCast(du, lo).raw, 0x20)}); +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + return BitCast(d, VFromD{__lasx_xvpermi_q( + BitCast(du, hi).raw, BitCast(du, lo).raw, 0x21)}); +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{__lasx_xvpermi_q( + BitCast(du, hi).raw, BitCast(du, lo).raw, 0x30)}); +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{__lasx_xvpermi_q( + BitCast(du, hi).raw, BitCast(du, lo).raw, 0x31)}); +} + +// ---------------------------- InsertBlock (ConcatLowerLower, ConcatUpperLower) +template +HWY_API Vec256 InsertBlock(Vec256 v, Vec128 blk_to_insert) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + + const DFromV d; + const auto vec_to_insert = ResizeBitCast(d, blk_to_insert); + return (kBlockIdx == 0) ? 
ConcatUpperLower(d, v, vec_to_insert) + : ConcatLowerLower(d, vec_to_insert, v); +} + +// ------------------------------ ConcatOdd + +template +HWY_API VFromD ConcatOdd(D /* tag */, VFromD hi, VFromD lo) { + __m256i od = __lasx_xvpickod_b(hi.raw, lo.raw); + return VFromD{__lasx_xvpermi_d(od, 0xd8)}; +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + __m256i od = __lasx_xvpickod_h(BitCast(du, hi).raw, BitCast(du, lo).raw); + return BitCast(d, VFromD{__lasx_xvpermi_d(od, 0xd8)}); +} + +template +HWY_API VFromD ConcatOdd(D /* tag */, VFromD hi, VFromD lo) { + __m256i od = __lasx_xvpickod_w(hi.raw, lo.raw); + return VFromD{__lasx_xvpermi_d(od, 0xd8)}; +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + __m256i od = __lasx_xvpickod_w(BitCast(du, hi).raw, BitCast(du, lo).raw); + return BitCast(d, VFromD{__lasx_xvpermi_d(od, 0xd8)}); +} + +template +HWY_API VFromD ConcatOdd(D /* tag */, VFromD hi, VFromD lo) { + __m256i od = __lasx_xvpickod_d(hi.raw, lo.raw); + return VFromD{__lasx_xvpermi_d(od, 0xd8)}; +} + +template +HWY_API Vec256 ConcatOdd(D d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; + __m256i od = __lasx_xvpickod_d(BitCast(du, hi).raw, BitCast(du, lo).raw); + return BitCast(d, VFromD{__lasx_xvpermi_d(od, 0xd8)}); +} + +// ------------------------------ ConcatEven + +template +HWY_API VFromD ConcatEven(D /* tag */, VFromD hi, VFromD lo) { + __m256i ev = __lasx_xvpickev_b(hi.raw, lo.raw); + return VFromD{__lasx_xvpermi_d(ev, 0xd8)}; +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + __m256i ev = __lasx_xvpickev_h(BitCast(du, hi).raw, BitCast(du, lo).raw); + return BitCast(d, VFromD{__lasx_xvpermi_d(ev, 0xd8)}); +} + +template +HWY_API VFromD ConcatEven(D /* tag */, VFromD hi, VFromD lo) { + __m256i ev = __lasx_xvpickev_w(hi.raw, lo.raw); + return VFromD{__lasx_xvpermi_d(ev, 0xd8)}; +} + +template +HWY_API 
VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + __m256i ev = __lasx_xvpickev_w(BitCast(du, hi).raw, BitCast(du, lo).raw); + return BitCast(d, VFromD{__lasx_xvpermi_d(ev, 0xd8)}); +} + +template +HWY_API VFromD ConcatEven(D /* tag */, VFromD hi, VFromD lo) { + __m256i ev = __lasx_xvpickev_d(hi.raw, lo.raw); + return VFromD{__lasx_xvpermi_d(ev, 0xd8)}; +} + +template +HWY_API Vec256 ConcatEven(D d, Vec256 hi, Vec256 lo) { + const RebindToUnsigned du; + __m256i ev = __lasx_xvpickev_d(BitCast(du, hi).raw, BitCast(du, lo).raw); + return BitCast(d, VFromD{__lasx_xvpermi_d(ev, 0xd8)}); +} + +// ------------------------------ InterleaveWholeLower + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + return ConcatLowerLower(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); +} + +// ------------------------------ InterleaveWholeUpper + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + return ConcatUpperUpper(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec256 DupEven(Vec256 v) { + return Vec256{__lasx_xvpackev_b(v.raw, v.raw)}; +} + +template +HWY_API Vec256 DupEven(Vec256 v) { + return Vec256{__lasx_xvpackev_h(v.raw, v.raw)}; +} + +template +HWY_API Vec256 DupEven(Vec256 v) { + return Vec256{__lasx_xvpackev_w(v.raw, v.raw)}; +} + +HWY_API Vec256 DupEven(Vec256 v) { + const Full256 du; + const DFromV d; + return BitCast(d, Vec256{__lasx_xvpackev_w(BitCast(du, v).raw, + BitCast(du, v).raw)}); +} + +template +HWY_API Vec256 DupEven(const Vec256 v) { + const DFromV d; + return InterleaveLower(d, v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec256 DupOdd(Vec256 v) { + return Vec256{__lasx_xvpackod_b(v.raw, v.raw)}; +} + +template +HWY_API Vec256 DupOdd(Vec256 v) { + return Vec256{__lasx_xvpackod_h(v.raw, v.raw)}; +} + +template +HWY_API Vec256 DupOdd(Vec256 
v) { + return Vec256{__lasx_xvpackod_w(v.raw, v.raw)}; +} + +HWY_API Vec256 DupOdd(Vec256 v) { + const Full256 du; + const DFromV d; + return BitCast(d, Vec256{__lasx_xvpackod_w(BitCast(du, v).raw, + BitCast(du, v).raw)}); +} + +template +HWY_API Vec256 DupOdd(const Vec256 v) { + const DFromV d; + return InterleaveUpper(d, v, v); +} + +// ------------------------------ OddEven + +template +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + __m256i c = __lasx_xvpackod_b(a.raw, a.raw); + return Vec256{__lasx_xvpackev_b(c, b.raw)}; +} + +template +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + __m256i c = __lasx_xvpackod_h(a.raw, a.raw); + return Vec256{__lasx_xvpackev_h(c, b.raw)}; +} + +template +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + __m256i c = __lasx_xvpackod_w(a.raw, a.raw); + return Vec256{__lasx_xvpackev_w(c, b.raw)}; +} + +template +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + return Vec256{__lasx_xvextrins_d(b.raw, a.raw, 0x11)}; +} + +HWY_API Vec256 OddEven(Vec256 a, Vec256 b) { + const DFromV d; + const RebindToUnsigned du; + __m256i c = __lasx_xvpackod_w(BitCast(du, a).raw, BitCast(du, a).raw); + return BitCast( + d, VFromD{__lasx_xvpackev_w(c, BitCast(du, b).raw)}); +} + +HWY_API Vec256 OddEven(Vec256 a, Vec256 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{__lasx_xvextrins_d( + BitCast(du, b).raw, BitCast(du, a).raw, 0x11)}); +} + +// -------------------------- InterleaveEven +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{__lasx_xvpackev_b(b.raw, a.raw)}; +} +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{__lasx_xvpackev_h(b.raw, a.raw)}; +} +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + return BitCast(d, VFromD{__lasx_xvpackev_w( + BitCast(du, b).raw, BitCast(du, a).raw)}); +} +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return 
InterleaveLower(a, b); +} + +// -------------------------- InterleaveOdd +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{__lasx_xvpackod_b(b.raw, a.raw)}; +} +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{__lasx_xvpackod_h(b.raw, a.raw)}; +} +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + return BitCast(d, VFromD{__lasx_xvpackod_w( + BitCast(du, b).raw, BitCast(du, a).raw)}); +} +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return InterleaveUpper(d, a, b); +} + +// ------------------------------ OddEvenBlocks + +template +Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{__lasx_xvpermi_q( + BitCast(du, odd).raw, BitCast(du, even).raw, 0x30)}); +} + +// ------------------------------ ReverseBlocks (SwapAdjacentBlocks) + +template +HWY_API VFromD ReverseBlocks(D /*d*/, VFromD v) { + return SwapAdjacentBlocks(v); +} + +// ------------------------------ TableLookupBytes (ZeroExtendVector) + +// Both full +template +HWY_API Vec256 TableLookupBytes(Vec256 bytes, Vec256 from) { + const DFromV d; + return BitCast(d, Vec256{__lasx_xvshuf_b( + BitCast(Full256(), bytes).raw, + BitCast(Full256(), bytes).raw, + BitCast(Full256(), from).raw)}); +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(Vec256 bytes, Vec128 from) { + const Full256 di; + const Half dih; + // First expand to full 128, then 256. + const auto from_256 = ZeroExtendVector(di, Vec128{from.raw}); + const auto tbl_full = TableLookupBytes(bytes, from_256); + // Shrink to 128, then partial. + return Vec128{LowerHalf(dih, tbl_full).raw}; +} + +// Partial table vector +template +HWY_API Vec256 TableLookupBytes(Vec128 bytes, Vec256 from) { + const Full256 d; + // First expand to full 128, then 256. 
+ const auto bytes_256 = ZeroExtendVector(d, Vec128{bytes.raw}); + return TableLookupBytes(bytes_256, from); +} + +// ------------------------------ Per4LaneBlockShuffle + +namespace detail { + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { + alignas(32) uint32_t rawU32[8] = {x0, x1, x2, x3, x0, x1, x2, x3}; + return BitCast(d, Vec256{__lasx_xvld(rawU32, 0)}); +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + V idx = + Per4LaneBlkShufDupSet4xU32(d, (kIdx3210 >> 6) & 3, (kIdx3210 >> 4) & 3, + (kIdx3210 >> 2) & 3, kIdx3210 & 3); + return V{__lasx_xvshuf_w(idx.raw, v.raw, v.raw)}; +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + const RebindToUnsigned du; + const auto idx = + Per4LaneBlkShufDupSet4xU32(du, (kIdx3210 >> 6) & 3, (kIdx3210 >> 4) & 3, + (kIdx3210 >> 2) & 3, kIdx3210 & 3); + return BitCast(d, VFromD{__lasx_xvshuf_w( + idx.raw, BitCast(du, v).raw, BitCast(du, v).raw)}); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x44> /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + return ConcatLowerLower(d, v, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xEE> /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + return ConcatUpperUpper(d, v, v); +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + + const VU vu = BitCast(du, v); + return BitCast( + d, 
VU{__lasx_xvpermi_d(vu.raw, static_cast(kIdx3210 & 0xFF))}); +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +namespace detail { + +template +HWY_INLINE VFromD TableLookupSlideUpLanes(D d, VFromD v, size_t amt) { + const RebindToUnsigned du; + using TU = TFromD; + const auto idx = Iota(du, static_cast(size_t{0} - amt)); + const auto masked_idx = And(idx, Set(du, static_cast(MaxLanes(d) - 1))); + return BitCast( + d, IfThenElseZero( + idx == masked_idx, + TableLookupLanes(BitCast(du, v), IndicesFromVec(du, masked_idx)))); +} + +} // namespace detail + +template +HWY_API VFromD SlideUpBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 1, + "kBlocks must be between 0 and 1"); + return (kBlocks == 1) ? ConcatLowerLower(d, v, Zero(d)) : v; +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + if (__builtin_constant_p(amt)) { + const auto v_lo = ConcatLowerLower(d, v, Zero(d)); + switch (amt * sizeof(TFromD)) { + case 0: + return v; + case 1: + return CombineShiftRightBytes<15>(d, v, v_lo); + case 2: + return CombineShiftRightBytes<14>(d, v, v_lo); + case 3: + return CombineShiftRightBytes<13>(d, v, v_lo); + case 4: + return CombineShiftRightBytes<12>(d, v, v_lo); + case 5: + return CombineShiftRightBytes<11>(d, v, v_lo); + case 6: + return CombineShiftRightBytes<10>(d, v, v_lo); + case 7: + return CombineShiftRightBytes<9>(d, v, v_lo); + case 8: + return CombineShiftRightBytes<8>(d, v, v_lo); + case 9: + return CombineShiftRightBytes<7>(d, v, v_lo); + case 10: + return CombineShiftRightBytes<6>(d, v, v_lo); + case 11: + return CombineShiftRightBytes<5>(d, v, v_lo); + case 12: + return CombineShiftRightBytes<4>(d, v, v_lo); + case 13: + return CombineShiftRightBytes<3>(d, v, v_lo); + case 14: + return CombineShiftRightBytes<2>(d, v, v_lo); + case 15: + return CombineShiftRightBytes<1>(d, v, v_lo); + } + } + + if 
(__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) { + const Half dh; + return Combine(d, SlideUpLanes(dh, LowerHalf(dh, v), amt - kLanesPerBlock), + Zero(dh)); + } +#endif + + return detail::TableLookupSlideUpLanes(d, v, amt); +} + +// ------------------------------ Slide1Up + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + const auto v_lo = ConcatLowerLower(d, v, Zero(d)); + return CombineShiftRightBytes<15>(d, v, v_lo); +} + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + const auto v_lo = ConcatLowerLower(d, v, Zero(d)); + return CombineShiftRightBytes<14>(d, v, v_lo); +} + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + const auto v_lo = ConcatLowerLower(d, v, Zero(d)); + return CombineShiftRightBytes<12>(d, v, v_lo); +} + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + const auto v_lo = ConcatLowerLower(d, v, Zero(d)); + return CombineShiftRightBytes<8>(d, v, v_lo); +} + +// ------------------------------ SlideDownLanes + +namespace detail { + +template +HWY_INLINE VFromD TableLookupSlideDownLanes(D d, VFromD v, size_t amt) { + const RebindToUnsigned du; + using TU = TFromD; + const auto idx = Iota(du, static_cast(amt)); + const auto masked_idx = And(idx, Set(du, static_cast(MaxLanes(d) - 1))); + return IfThenElseZero(RebindMask(d, idx == masked_idx), + TableLookupLanes(v, IndicesFromVec(d, masked_idx))); +} + +} // namespace detail + +template +HWY_API VFromD SlideDownBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 1, + "kBlocks must be between 0 and 1"); + const Half dh; + return (kBlocks == 1) ? 
ZeroExtendVector(d, UpperHalf(dh, v)) : v; +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + const Half dh; + if (__builtin_constant_p(amt)) { + const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); + switch (amt * sizeof(TFromD)) { + case 0: + return v; + case 1: + return CombineShiftRightBytes<1>(d, v_hi, v); + case 2: + return CombineShiftRightBytes<2>(d, v_hi, v); + case 3: + return CombineShiftRightBytes<3>(d, v_hi, v); + case 4: + return CombineShiftRightBytes<4>(d, v_hi, v); + case 5: + return CombineShiftRightBytes<5>(d, v_hi, v); + case 6: + return CombineShiftRightBytes<6>(d, v_hi, v); + case 7: + return CombineShiftRightBytes<7>(d, v_hi, v); + case 8: + return CombineShiftRightBytes<8>(d, v_hi, v); + case 9: + return CombineShiftRightBytes<9>(d, v_hi, v); + case 10: + return CombineShiftRightBytes<10>(d, v_hi, v); + case 11: + return CombineShiftRightBytes<11>(d, v_hi, v); + case 12: + return CombineShiftRightBytes<12>(d, v_hi, v); + case 13: + return CombineShiftRightBytes<13>(d, v_hi, v); + case 14: + return CombineShiftRightBytes<14>(d, v_hi, v); + case 15: + return CombineShiftRightBytes<15>(d, v_hi, v); + } + } + + if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) { + return ZeroExtendVector( + d, SlideDownLanes(dh, UpperHalf(dh, v), amt - kLanesPerBlock)); + } +#endif + + return detail::TableLookupSlideDownLanes(d, v, amt); +} + +// ------------------------------ Slide1Down + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + const Half dh; + const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); + return CombineShiftRightBytes<1>(d, v_hi, v); +} + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + const Half dh; + const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); + return CombineShiftRightBytes<2>(d, v_hi, v); +} + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + const Half dh; + const 
auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); + return CombineShiftRightBytes<4>(d, v_hi, v); +} + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + const Half dh; + const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); + return CombineShiftRightBytes<8>(d, v_hi, v); +} + +// ------------------------------ Shl (Mul, ZipLower) +namespace detail { + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { + return Vec256{__lasx_xvsll_b(v.raw, bits.raw)}; +} + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { + return Vec256{__lasx_xvsll_h(v.raw, bits.raw)}; +} + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { + return Vec256{__lasx_xvsll_w(v.raw, bits.raw)}; +} + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { + return Vec256{__lasx_xvsll_d(v.raw, bits.raw)}; +} + +template +HWY_INLINE Vec256 Shl(hwy::SignedTag /*tag*/, Vec256 v, Vec256 bits) { + // Signed left shifts are the same as unsigned. 
+ const Full256 di; + const Full256> du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template +HWY_API Vec256 operator<<(Vec256 v, Vec256 bits) { + return detail::Shl(hwy::TypeTag(), v, bits); +} + +// ------------------------------ Shr (MulHigh, IfThenElse, Not) + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{__lasx_xvsrl_b(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{__lasx_xvsrl_h(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{__lasx_xvsrl_w(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{__lasx_xvsrl_d(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{__lasx_xvsra_b(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{__lasx_xvsra_h(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{__lasx_xvsra_w(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{__lasx_xvsra_d(v.raw, bits.raw)}; +} + +// ------------------------------ WidenMulPairwiseAdd + +template +HWY_API VFromD WidenMulPairwiseAdd(D /*d32*/, Vec256 a, + Vec256 b) { + __m256i ev = __lasx_xvmulwev_w_h(b.raw, a.raw); + return VFromD{__lasx_xvmaddwod_w_h(ev, b.raw, a.raw)}; +} + +template +HWY_API VFromD WidenMulPairwiseAdd(D /*d32*/, Vec256 a, + Vec256 b) { + __m256i ev = __lasx_xvmulwev_w_hu(b.raw, a.raw); + return VFromD{__lasx_xvmaddwod_w_hu(ev, b.raw, a.raw)}; +} + +// ------------------------------ ReorderWidenMulAccumulate + +template +HWY_API VFromD ReorderWidenMulAccumulate(D /*tag*/, Vec256 a, + Vec256 b, + const VFromD sum0, + VFromD& /*sum1*/) { + return VFromD{__lasx_xvmaddwev_w_h( + __lasx_xvmaddwod_w_h(sum0.raw, a.raw, b.raw), a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderWidenMulAccumulate(D /*tag*/, 
Vec256 a, + Vec256 b, + const VFromD sum0, + VFromD& /*sum1*/) { + return VFromD{__lasx_xvmaddwev_w_hu( + __lasx_xvmaddwod_w_hu(sum0.raw, a.raw, b.raw), a.raw, b.raw)}; +} + +// ------------------------------ RearrangeToOddPlusEven +HWY_API Vec256 RearrangeToOddPlusEven(const Vec256 sum0, + Vec256 /*sum1*/) { + return sum0; // invariant already holds +} + +HWY_API Vec256 RearrangeToOddPlusEven(const Vec256 sum0, + Vec256 /*sum1*/) { + return sum0; // invariant already holds +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + const Repartition df16; + const auto from_128 = ZeroExtendVector(df16, v); + const VFromD f16_concat{__lasx_xvpermi_d(from_128.raw, 0xd8)}; + return VFromD{__lasx_xvfcvtl_s_h(f16_concat.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + const Repartition df; + const RebindToSigned di; + const auto from_128 = ZeroExtendVector(df, v); + const auto f32_concat = BitCast( + df, Vec256{__lasx_xvpermi_d(BitCast(di, from_128).raw, 0xd8)}); + return VFromD{__lasx_xvfcvtl_d_s(f32_concat.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /*di64*/, Vec128 v) { + const Repartition df; + const RebindToSigned di; + const auto from_128 = ZeroExtendVector(df, v); + const auto f32_concat = BitCast( + df, Vec256{__lasx_xvpermi_d(BitCast(di, from_128).raw, 0xd8)}); + return VFromD{__lasx_xvftintrzl_l_s(f32_concat.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8); + vec_temp = __lasx_xvsllwil_d_w(vec_temp, 0); + return VFromD{__lasx_xvffint_d_l(vec_temp)}; +} + +template +HWY_API Vec256 PromoteTo(D /* tag */, Vec128 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + 
CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8); + vec_temp = __lasx_xvsllwil_du_wu(vec_temp, 0); + return VFromD{__lasx_xvffint_d_lu(vec_temp)}; +} + +// Unsigned: zero-extend. +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8); + return VFromD{__lasx_xvsllwil_hu_bu(vec_temp, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvsllwil_hu_bu(vec_temp, 0); + vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8); + return VFromD{__lasx_xvsllwil_wu_hu(vec_temp, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8); + return VFromD{__lasx_xvsllwil_wu_hu(vec_temp, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8); + return VFromD{__lasx_xvsllwil_du_wu(vec_temp, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec64 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvsllwil_wu_hu(vec_temp, 0); + vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8); + return VFromD{__lasx_xvsllwil_du_wu(vec_temp, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec32 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvsllwil_hu_bu(vec_temp, 0); + vec_temp = __lasx_xvsllwil_wu_hu(vec_temp, 0); + vec_temp = 
__lasx_xvpermi_d(vec_temp, 0xd8); + return VFromD{__lasx_xvsllwil_du_wu(vec_temp, 0)}; +} + +// Signed: replicate sign bit. +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8); + return VFromD{__lasx_xvsllwil_h_b(vec_temp, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvsllwil_h_b(vec_temp, 0); + vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8); + return VFromD{__lasx_xvsllwil_w_h(vec_temp, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8); + return VFromD{__lasx_xvsllwil_w_h(vec_temp, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8); + return VFromD{__lasx_xvsllwil_d_w(vec_temp, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec64 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvsllwil_w_h(vec_temp, 0); + vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8); + return VFromD{__lasx_xvsllwil_d_w(vec_temp, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec32 v) { + alignas(32) __m128i vec_tmp[2]; + __m256i vec_temp; + vec_tmp[0] = v.raw; + CopyBytes<32>(vec_tmp, &vec_temp); + vec_temp = __lasx_xvsllwil_h_b(vec_temp, 0); + vec_temp = __lasx_xvsllwil_w_h(vec_temp, 0); + vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8); + return VFromD{__lasx_xvsllwil_d_w(vec_temp, 0)}; +} + +// ------------------------------ 
PromoteEvenTo/PromoteOddTo +namespace detail { + +// I32->I64 PromoteEvenTo/PromoteOddTo + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec256 v) { + return BitCast(d_to, OddEven(DupEven(BroadcastSignBit(v)), v)); +} + +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec256 v) { + return BitCast(d_to, OddEven(BroadcastSignBit(v), DupOdd(v))); +} + +} // namespace detail + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec256 a, + Vec256 b) { + return VFromD{__lasx_xvssrani_b_h(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec256 a, + Vec256 b) { + return VFromD{__lasx_xvssrani_bu_h(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec256 a, + Vec256 b) { + return VFromD{__lasx_xvssrlni_b_h(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec256 a, + Vec256 b) { + return VFromD{__lasx_xvssrlni_bu_h(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec256 a, + Vec256 b) { + return VFromD{__lasx_xvssrani_h_w(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec256 a, + Vec256 b) { + return VFromD{__lasx_xvssrani_hu_w(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec256 a, + Vec256 b) { + return VFromD{__lasx_xvssrlni_h_w(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec256 a, + Vec256 b) { + return VFromD{__lasx_xvssrlni_hu_w(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec256 a, + Vec256 b) { + return VFromD{__lasx_xvssrani_w_d(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, 
Vec256 a, + Vec256 b) { + return VFromD{__lasx_xvssrani_wu_d(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec256 a, + Vec256 b) { + return VFromD{__lasx_xvssrlni_w_d(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec256 a, + Vec256 b) { + return VFromD{__lasx_xvssrlni_wu_d(b.raw, a.raw, 0)}; +} + +template ), + HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return VFromD{__lasx_xvpermi_d(ReorderDemote2To(d, a, b).raw, 0xd8)}; +} + +template ) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV))> +HWY_API VFromD DemoteTo(D d, V v) { + return LowerHalf(OrderedDemote2To(Twice(), v, v)); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const Full256 di; + const Vec256 f16_blocks{__lasx_xvfcvt_h_s(v.raw, v.raw)}; + const auto f16_concat = + BitCast(Twice(), VFromD{__lasx_xvpermi_d( + BitCast(di, f16_blocks).raw, 0xd8)}); + return LowerHalf(f16_concat); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const Full256 di; + const Vec256 f32_blocks{__lasx_xvfcvt_s_d(v.raw, v.raw)}; + const auto f32_concat = + BitCast(Twice(), VFromD{__lasx_xvpermi_d( + BitCast(di, f32_blocks).raw, 0xd8)}); + return LowerHalf(f32_concat); +} + +template +HWY_API VFromD DemoteTo(D dn, Vec256 v) { + const __m256i i32_blocks = __lasx_xvftintrz_w_d(v.raw, v.raw); + return LowerHalf(dn, VFromD>{__lasx_xvpermi_d(i32_blocks, 0xd8)}); +} + +// For already range-limited input [0, 255]. +HWY_API Vec128 U8FromU32(const Vec256 v) { + const Full256 d32; + const Full64 d8; + alignas(32) static constexpr uint32_t k8From32[8] = { + 0x0C080400u, 0x13121110u, 0, 0, 0x13121110u, 0x0C080400u, 0, 0}; + // Place first four bytes in lo[0], remaining 4 in hi[1]. 
+ const auto quad = VFromD{ + __lasx_xvshuf_b(Zero(d32).raw, v.raw, Load(d32, k8From32).raw)}; + // Interleave both quadruplets - OR instead of unpack reduces port5 pressure. + const auto lo = LowerHalf(quad); + const auto hi = UpperHalf(Half(), quad); + return BitCast(d8, LowerHalf(lo | hi)); +} + +// ------------------------------ Truncations + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const Full256 d8; + alignas(32) static constexpr uint8_t kMap[32] = {0, 8, 16, 24}; + const auto i8 = TableLookupLanes(BitCast(d8, v), SetTableIndices(d8, kMap)); + return LowerHalf(LowerHalf(LowerHalf(Vec256{i8.raw}))); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const __m256i i32_blocks = __lasx_xvpickev_w(v.raw, v.raw); + const __m256i i32_concat = __lasx_xvpermi_d(i32_blocks, 0xd8); + const __m256i i16 = __lasx_xvpickev_h(i32_concat, i32_concat); + return LowerHalf(LowerHalf(Vec256{i16})); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const Full256 d32; + alignas(32) static constexpr uint32_t kEven[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto v32 = + TableLookupLanes(BitCast(d32, v), SetTableIndices(d32, kEven)); + return LowerHalf(Vec256{v32.raw}); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const Full256 d8; + alignas(32) static constexpr uint8_t kEven[32] = {0, 4, 8, 12, + 16, 20, 24, 28}; + const auto i8 = TableLookupLanes(BitCast(d8, v), SetTableIndices(d8, kEven)); + return LowerHalf(LowerHalf(Vec256{i8.raw})); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const __m256i i16_blocks = __lasx_xvpickev_h(v.raw, v.raw); + const __m256i i16_concat = __lasx_xvpermi_d(i16_blocks, 0xd8); + return LowerHalf(Vec256{i16_concat}); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const __m256i i8_blocks = __lasx_xvpickev_b(v.raw, v.raw); + const __m256i i8_concat = __lasx_xvpermi_d(i8_blocks, 0xd8); + return LowerHalf(Vec256{i8_concat}); +} + +// 
------------------------------ Integer <=> fp (ShiftRight, OddEven) + +template +HWY_API VFromD ConvertTo(D /* tag */, Vec256 v) { + return VFromD{__lasx_xvffint_s_w(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /*df*/, Vec256 v) { + return VFromD{__lasx_xvffint_s_wu(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /*dd*/, Vec256 v) { + return VFromD{__lasx_xvffint_d_l(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /*dd*/, Vec256 v) { + return VFromD{__lasx_xvffint_d_lu(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /*d*/, Vec256 v) { + return VFromD{__lasx_xvftintrz_w_s(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /*di*/, Vec256 v) { + return VFromD{__lasx_xvftintrz_l_d(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(DU /*du*/, VFromD> v) { + return VFromD{__lasx_xvftintrz_wu_s(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(DU /*du*/, VFromD> v) { + return VFromD{__lasx_xvftintrz_lu_d(v.raw)}; +} + +template +HWY_API Vec256> NearestInt(const Vec256 v) { + return ConvertTo(Full256>(), Round(v)); +} + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +// LoadMaskBits256 overloads: build a Mask256 from the integer `mask_bits` by +// broadcasting it across the vector (Set) and applying TestBit against a +// per-lane single-bit constant. This first overload additionally replicates +// the broadcast bytes with TableLookupBytes (kRep8) before testing. +template +HWY_INLINE Mask256 LoadMaskBits256(uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + const Repartition du32; + const auto vbits = BitCast(du, Set(du32, static_cast(mask_bits))); + + // Replicate bytes 8x such that each byte contains the bit that governs it. 
+ const Repartition du64; + alignas(32) static constexpr uint64_t kRep8[4] = { + 0x0000000000000000ull, 0x0101010101010101ull, 0x0202020202020202ull, + 0x0303030303030303ull}; + const auto rep8 = TableLookupBytes(vbits, BitCast(du, Load(du64, kRep8))); + + const VFromD bit = Dup128VecFromValues( + du, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return RebindMask(d, TestBit(rep8, bit)); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + alignas(32) static constexpr uint16_t kBit[16] = { + 1, 2, 4, 8, 16, 32, 64, 128, + 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + alignas(32) static constexpr uint64_t kBit[8] = {1, 2, 4, 8}; /* NOTE(review): 8 slots are declared but only 4 are initialized (the rest are zero); presumably only four lanes are loaded here - confirm. */ + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail + +// `bits` points to at least 8 readable bytes, not all of which need be valid. 
+template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumBytes = (kN + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + + return detail::LoadMaskBits256>(mask_bits); +} + +// ------------------------------ BitsFromMask + +// This overload applies __lasx_xvmskltz_b directly to mask.raw and then merges +// 32-bit element 0 with 32-bit element 4 shifted left by 16 into one integer. +template +HWY_API uint64_t BitsFromMask(D /*tag*/, MFromD mask) { + const auto sign_bits = __lasx_xvmskltz_b(mask.raw); + return static_cast(__lasx_xvpickve2gr_w(sign_bits, 0) | + (__lasx_xvpickve2gr_w(sign_bits, 4) << 16)); +} + +// The remaining three overloads share one pattern: narrow the mask vector with +// __lasx_xvpickod_{b,h,w} (same vector passed as both operands), reorder the +// 64-bit elements with __lasx_xvpermi_d(.., 0xd8), apply +// __lasx_xvmskltz_{b,h,w}, and return 32-bit element 0 of the result. +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToSigned di; + const auto vec_mask = VecFromMask(mask); + const auto sign_bits = + __lasx_xvpickod_b(BitCast(di, vec_mask).raw, BitCast(di, vec_mask).raw); + const auto sign_shuf = __lasx_xvpermi_d(sign_bits, 0xd8); + const auto sign_last = __lasx_xvmskltz_b(sign_shuf); + return static_cast(__lasx_xvpickve2gr_w(sign_last, 0)); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToSigned di; + const auto vec_mask = VecFromMask(mask); + const auto sign_bits = + __lasx_xvpickod_h(BitCast(di, vec_mask).raw, BitCast(di, vec_mask).raw); + const auto sign_shuf = __lasx_xvpermi_d(sign_bits, 0xd8); + const auto sign_last = __lasx_xvmskltz_h(sign_shuf); + return static_cast(__lasx_xvpickve2gr_w(sign_last, 0)); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToSigned di; + const auto vec_mask = VecFromMask(mask); + const auto sign_bits = + __lasx_xvpickod_w(BitCast(di, vec_mask).raw, BitCast(di, vec_mask).raw); + const auto sign_shuf = __lasx_xvpermi_d(sign_bits, 0xd8); + const auto sign_last = __lasx_xvmskltz_w(sign_shuf); + return static_cast(__lasx_xvpickve2gr_w(sign_last, 0)); +} + +// ------------------------------ StoreMaskBits +// `bits` points to at least 8 writable bytes. 
+template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + constexpr size_t N = MaxLanes(d); + constexpr size_t kNumBytes = (N + 7) / 8; + + const uint64_t mask_bits = BitsFromMask(d, mask); + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +// ------------------------------ Mask testing + +template +HWY_API bool AllFalse(D d, MFromD mask) { + return BitsFromMask(d, mask) == 0; +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + constexpr uint64_t kAllBits = (1ull << kN) - 1; + return BitsFromMask(d, mask) == kAllBits; +} + +template +HWY_API size_t CountTrue(D d, MFromD mask) { + return PopCount(BitsFromMask(d, mask)); +} + +template +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return Num0BitsBelowLS1Bit_Nonzero32(mask_bits); +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1; +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits); +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return mask_bits ? intptr_t(31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits)) + : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +template +HWY_INLINE Vec256 IndicesFromBits256(uint64_t mask_bits) { + const Full256 d32; + // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT + // of SetTableIndices would require 8 KiB, a large part of L1D. We instead + // compress each index into 4 bits, for a total of 1 KiB. 
+ alignas(16) static constexpr uint32_t packed_array[256] = { + // PrintCompress32x8Tables + 0x76543210, 0x76543218, 0x76543209, 0x76543298, 0x7654310a, 0x765431a8, + 0x765430a9, 0x76543a98, 0x7654210b, 0x765421b8, 0x765420b9, 0x76542b98, + 0x765410ba, 0x76541ba8, 0x76540ba9, 0x7654ba98, 0x7653210c, 0x765321c8, + 0x765320c9, 0x76532c98, 0x765310ca, 0x76531ca8, 0x76530ca9, 0x7653ca98, + 0x765210cb, 0x76521cb8, 0x76520cb9, 0x7652cb98, 0x76510cba, 0x7651cba8, + 0x7650cba9, 0x765cba98, 0x7643210d, 0x764321d8, 0x764320d9, 0x76432d98, + 0x764310da, 0x76431da8, 0x76430da9, 0x7643da98, 0x764210db, 0x76421db8, + 0x76420db9, 0x7642db98, 0x76410dba, 0x7641dba8, 0x7640dba9, 0x764dba98, + 0x763210dc, 0x76321dc8, 0x76320dc9, 0x7632dc98, 0x76310dca, 0x7631dca8, + 0x7630dca9, 0x763dca98, 0x76210dcb, 0x7621dcb8, 0x7620dcb9, 0x762dcb98, + 0x7610dcba, 0x761dcba8, 0x760dcba9, 0x76dcba98, 0x7543210e, 0x754321e8, + 0x754320e9, 0x75432e98, 0x754310ea, 0x75431ea8, 0x75430ea9, 0x7543ea98, + 0x754210eb, 0x75421eb8, 0x75420eb9, 0x7542eb98, 0x75410eba, 0x7541eba8, + 0x7540eba9, 0x754eba98, 0x753210ec, 0x75321ec8, 0x75320ec9, 0x7532ec98, + 0x75310eca, 0x7531eca8, 0x7530eca9, 0x753eca98, 0x75210ecb, 0x7521ecb8, + 0x7520ecb9, 0x752ecb98, 0x7510ecba, 0x751ecba8, 0x750ecba9, 0x75ecba98, + 0x743210ed, 0x74321ed8, 0x74320ed9, 0x7432ed98, 0x74310eda, 0x7431eda8, + 0x7430eda9, 0x743eda98, 0x74210edb, 0x7421edb8, 0x7420edb9, 0x742edb98, + 0x7410edba, 0x741edba8, 0x740edba9, 0x74edba98, 0x73210edc, 0x7321edc8, + 0x7320edc9, 0x732edc98, 0x7310edca, 0x731edca8, 0x730edca9, 0x73edca98, + 0x7210edcb, 0x721edcb8, 0x720edcb9, 0x72edcb98, 0x710edcba, 0x71edcba8, + 0x70edcba9, 0x7edcba98, 0x6543210f, 0x654321f8, 0x654320f9, 0x65432f98, + 0x654310fa, 0x65431fa8, 0x65430fa9, 0x6543fa98, 0x654210fb, 0x65421fb8, + 0x65420fb9, 0x6542fb98, 0x65410fba, 0x6541fba8, 0x6540fba9, 0x654fba98, + 0x653210fc, 0x65321fc8, 0x65320fc9, 0x6532fc98, 0x65310fca, 0x6531fca8, + 0x6530fca9, 0x653fca98, 0x65210fcb, 0x6521fcb8, 
0x6520fcb9, 0x652fcb98, + 0x6510fcba, 0x651fcba8, 0x650fcba9, 0x65fcba98, 0x643210fd, 0x64321fd8, + 0x64320fd9, 0x6432fd98, 0x64310fda, 0x6431fda8, 0x6430fda9, 0x643fda98, + 0x64210fdb, 0x6421fdb8, 0x6420fdb9, 0x642fdb98, 0x6410fdba, 0x641fdba8, + 0x640fdba9, 0x64fdba98, 0x63210fdc, 0x6321fdc8, 0x6320fdc9, 0x632fdc98, + 0x6310fdca, 0x631fdca8, 0x630fdca9, 0x63fdca98, 0x6210fdcb, 0x621fdcb8, + 0x620fdcb9, 0x62fdcb98, 0x610fdcba, 0x61fdcba8, 0x60fdcba9, 0x6fdcba98, + 0x543210fe, 0x54321fe8, 0x54320fe9, 0x5432fe98, 0x54310fea, 0x5431fea8, + 0x5430fea9, 0x543fea98, 0x54210feb, 0x5421feb8, 0x5420feb9, 0x542feb98, + 0x5410feba, 0x541feba8, 0x540feba9, 0x54feba98, 0x53210fec, 0x5321fec8, + 0x5320fec9, 0x532fec98, 0x5310feca, 0x531feca8, 0x530feca9, 0x53feca98, + 0x5210fecb, 0x521fecb8, 0x520fecb9, 0x52fecb98, 0x510fecba, 0x51fecba8, + 0x50fecba9, 0x5fecba98, 0x43210fed, 0x4321fed8, 0x4320fed9, 0x432fed98, + 0x4310feda, 0x431feda8, 0x430feda9, 0x43feda98, 0x4210fedb, 0x421fedb8, + 0x420fedb9, 0x42fedb98, 0x410fedba, 0x41fedba8, 0x40fedba9, 0x4fedba98, + 0x3210fedc, 0x321fedc8, 0x320fedc9, 0x32fedc98, 0x310fedca, 0x31fedca8, + 0x30fedca9, 0x3fedca98, 0x210fedcb, 0x21fedcb8, 0x20fedcb9, 0x2fedcb98, + 0x10fedcba, 0x1fedcba8, 0x0fedcba9, 0xfedcba98}; + + // No need to mask because __lasx_xvperm_w ignores bits 3..31. + // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. + const auto packed = Set(d32, packed_array[mask_bits]); + alignas(32) static constexpr uint32_t shifts[8] = {0, 4, 8, 12, + 16, 20, 24, 28}; + return packed >> Load(d32, shifts); +} + +template +HWY_INLINE Vec256 IndicesFromBits256(uint64_t mask_bits) { + const Full256 d64; + + // For 64-bit, there are only 4 lanes, so we can afford to load the + // entire index vector directly. 
+ alignas(32) static constexpr uint64_t u64_indices[64] = { + // PrintCompress64x4PairTables + 0, 1, 2, 3, 8, 1, 2, 3, 9, 0, 2, 3, 8, 9, 2, 3, + 10, 0, 1, 3, 8, 10, 1, 3, 9, 10, 0, 3, 8, 9, 10, 3, + 11, 0, 1, 2, 8, 11, 1, 2, 9, 11, 0, 2, 8, 9, 11, 2, + 10, 11, 0, 1, 8, 10, 11, 1, 9, 10, 11, 0, 8, 9, 10, 11}; + return Load(d64, u64_indices + 4 * mask_bits); +} + +template +HWY_INLINE Vec256 IndicesFromNotBits256(uint64_t mask_bits) { + const Full256 d32; + // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT + // of SetTableIndices would require 8 KiB, a large part of L1D. We instead + // compress each index into 4 bits, for a total of 1 KiB. + alignas(16) static constexpr uint32_t packed_array[256] = { + // PrintCompressNot32x8Tables + 0xfedcba98, 0x8fedcba9, 0x9fedcba8, 0x98fedcba, 0xafedcb98, 0xa8fedcb9, + 0xa9fedcb8, 0xa98fedcb, 0xbfedca98, 0xb8fedca9, 0xb9fedca8, 0xb98fedca, + 0xbafedc98, 0xba8fedc9, 0xba9fedc8, 0xba98fedc, 0xcfedba98, 0xc8fedba9, + 0xc9fedba8, 0xc98fedba, 0xcafedb98, 0xca8fedb9, 0xca9fedb8, 0xca98fedb, + 0xcbfeda98, 0xcb8feda9, 0xcb9feda8, 0xcb98feda, 0xcbafed98, 0xcba8fed9, + 0xcba9fed8, 0xcba98fed, 0xdfecba98, 0xd8fecba9, 0xd9fecba8, 0xd98fecba, + 0xdafecb98, 0xda8fecb9, 0xda9fecb8, 0xda98fecb, 0xdbfeca98, 0xdb8feca9, + 0xdb9feca8, 0xdb98feca, 0xdbafec98, 0xdba8fec9, 0xdba9fec8, 0xdba98fec, + 0xdcfeba98, 0xdc8feba9, 0xdc9feba8, 0xdc98feba, 0xdcafeb98, 0xdca8feb9, + 0xdca9feb8, 0xdca98feb, 0xdcbfea98, 0xdcb8fea9, 0xdcb9fea8, 0xdcb98fea, + 0xdcbafe98, 0xdcba8fe9, 0xdcba9fe8, 0xdcba98fe, 0xefdcba98, 0xe8fdcba9, + 0xe9fdcba8, 0xe98fdcba, 0xeafdcb98, 0xea8fdcb9, 0xea9fdcb8, 0xea98fdcb, + 0xebfdca98, 0xeb8fdca9, 0xeb9fdca8, 0xeb98fdca, 0xebafdc98, 0xeba8fdc9, + 0xeba9fdc8, 0xeba98fdc, 0xecfdba98, 0xec8fdba9, 0xec9fdba8, 0xec98fdba, + 0xecafdb98, 0xeca8fdb9, 0xeca9fdb8, 0xeca98fdb, 0xecbfda98, 0xecb8fda9, + 0xecb9fda8, 0xecb98fda, 0xecbafd98, 0xecba8fd9, 0xecba9fd8, 0xecba98fd, + 0xedfcba98, 0xed8fcba9, 0xed9fcba8, 
0xed98fcba, 0xedafcb98, 0xeda8fcb9, + 0xeda9fcb8, 0xeda98fcb, 0xedbfca98, 0xedb8fca9, 0xedb9fca8, 0xedb98fca, + 0xedbafc98, 0xedba8fc9, 0xedba9fc8, 0xedba98fc, 0xedcfba98, 0xedc8fba9, + 0xedc9fba8, 0xedc98fba, 0xedcafb98, 0xedca8fb9, 0xedca9fb8, 0xedca98fb, + 0xedcbfa98, 0xedcb8fa9, 0xedcb9fa8, 0xedcb98fa, 0xedcbaf98, 0xedcba8f9, + 0xedcba9f8, 0xedcba98f, 0xfedcba98, 0xf8edcba9, 0xf9edcba8, 0xf98edcba, + 0xfaedcb98, 0xfa8edcb9, 0xfa9edcb8, 0xfa98edcb, 0xfbedca98, 0xfb8edca9, + 0xfb9edca8, 0xfb98edca, 0xfbaedc98, 0xfba8edc9, 0xfba9edc8, 0xfba98edc, + 0xfcedba98, 0xfc8edba9, 0xfc9edba8, 0xfc98edba, 0xfcaedb98, 0xfca8edb9, + 0xfca9edb8, 0xfca98edb, 0xfcbeda98, 0xfcb8eda9, 0xfcb9eda8, 0xfcb98eda, + 0xfcbaed98, 0xfcba8ed9, 0xfcba9ed8, 0xfcba98ed, 0xfdecba98, 0xfd8ecba9, + 0xfd9ecba8, 0xfd98ecba, 0xfdaecb98, 0xfda8ecb9, 0xfda9ecb8, 0xfda98ecb, + 0xfdbeca98, 0xfdb8eca9, 0xfdb9eca8, 0xfdb98eca, 0xfdbaec98, 0xfdba8ec9, + 0xfdba9ec8, 0xfdba98ec, 0xfdceba98, 0xfdc8eba9, 0xfdc9eba8, 0xfdc98eba, + 0xfdcaeb98, 0xfdca8eb9, 0xfdca9eb8, 0xfdca98eb, 0xfdcbea98, 0xfdcb8ea9, + 0xfdcb9ea8, 0xfdcb98ea, 0xfdcbae98, 0xfdcba8e9, 0xfdcba9e8, 0xfdcba98e, + 0xfedcba98, 0xfe8dcba9, 0xfe9dcba8, 0xfe98dcba, 0xfeadcb98, 0xfea8dcb9, + 0xfea9dcb8, 0xfea98dcb, 0xfebdca98, 0xfeb8dca9, 0xfeb9dca8, 0xfeb98dca, + 0xfebadc98, 0xfeba8dc9, 0xfeba9dc8, 0xfeba98dc, 0xfecdba98, 0xfec8dba9, + 0xfec9dba8, 0xfec98dba, 0xfecadb98, 0xfeca8db9, 0xfeca9db8, 0xfeca98db, + 0xfecbda98, 0xfecb8da9, 0xfecb9da8, 0xfecb98da, 0xfecbad98, 0xfecba8d9, + 0xfecba9d8, 0xfecba98d, 0xfedcba98, 0xfed8cba9, 0xfed9cba8, 0xfed98cba, + 0xfedacb98, 0xfeda8cb9, 0xfeda9cb8, 0xfeda98cb, 0xfedbca98, 0xfedb8ca9, + 0xfedb9ca8, 0xfedb98ca, 0xfedbac98, 0xfedba8c9, 0xfedba9c8, 0xfedba98c, + 0xfedcba98, 0xfedc8ba9, 0xfedc9ba8, 0xfedc98ba, 0xfedcab98, 0xfedca8b9, + 0xfedca9b8, 0xfedca98b, 0xfedcba98, 0xfedcb8a9, 0xfedcb9a8, 0xfedcb98a, + 0xfedcba98, 0xfedcba89, 0xfedcba98, 0xfedcba98}; + + // No need to mask because <__lasx_xvperm_w> ignores bits 
3..31. + // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. + const Vec256 packed = Set(d32, packed_array[mask_bits]); + alignas(32) static constexpr uint32_t shifts[8] = {0, 4, 8, 12, + 16, 20, 24, 28}; + return packed >> Load(d32, shifts); +} + +template +HWY_INLINE Vec256 IndicesFromNotBits256(uint64_t mask_bits) { + const Full256 d64; + + // For 64-bit, there are only 4 lanes, so we can afford to load + // the entire index vector directly. + alignas(32) static constexpr uint64_t u64_indices[64] = { + // PrintCompressNot64x4PairTables + 8, 9, 10, 11, 9, 10, 11, 0, 8, 10, 11, 1, 10, 11, 0, 1, + 8, 9, 11, 2, 9, 11, 0, 2, 8, 11, 1, 2, 11, 0, 1, 2, + 8, 9, 10, 3, 9, 10, 0, 3, 8, 10, 1, 3, 10, 0, 1, 3, + 8, 9, 2, 3, 9, 0, 2, 3, 8, 1, 2, 3, 0, 1, 2, 3}; + return Load(d64, u64_indices + 4 * mask_bits); +} + +template +HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { + const DFromV d; + const RebindToSigned di; + + HWY_DASSERT(mask_bits < (1ull << Lanes(d))); + const Indices256> indices{ + IndicesFromBits256(mask_bits).raw}; + return BitCast(d, TableLookupLanes(BitCast(di, v), indices)); +} + +// LUTs are infeasible for 2^16 possible masks, so splice together two +// half-vector Compress. +template +HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { + const DFromV d; + const RebindToUnsigned du; + const auto vu16 = BitCast(du, v); // (required for float16_t inputs) + const Half duh; + const auto half0 = LowerHalf(duh, vu16); + const auto half1 = UpperHalf(duh, vu16); + + const uint64_t mask_bits0 = mask_bits & 0xFF; + const uint64_t mask_bits1 = mask_bits >> 8; + const auto compressed0 = detail::CompressBits(half0, mask_bits0); + const auto compressed1 = detail::CompressBits(half1, mask_bits1); + + alignas(32) uint16_t all_true[16] = {}; + // Store mask=true lanes, left to right. 
+ const size_t num_true0 = PopCount(mask_bits0); + Store(compressed0, duh, all_true); + StoreU(compressed1, duh, all_true + num_true0); + + if (hwy::HWY_NAMESPACE::CompressIsPartition::value) { + // Store mask=false lanes, right to left. The second vector fills the upper + // half with right-aligned false lanes. The first vector is shifted + // rightwards to overwrite the true lanes of the second. + alignas(32) uint16_t all_false[16] = {}; + const size_t num_true1 = PopCount(mask_bits1); + Store(compressed1, duh, all_false + 8); + StoreU(compressed0, duh, all_false + num_true1); + + const auto mask = FirstN(du, num_true0 + num_true1); + return BitCast(d, + IfThenElse(mask, Load(du, all_true), Load(du, all_false))); + } else { + // Only care about the mask=true lanes. + return BitCast(d, Load(du, all_true)); + } +} + +template +HWY_INLINE Vec256 CompressNot(Vec256 v, const uint64_t mask_bits) { + const DFromV d; + const RebindToSigned di; + + HWY_DASSERT(mask_bits < (1ull << Lanes(d))); + const Indices256> indices{ + IndicesFromNotBits256(mask_bits).raw}; + return BitCast(d, TableLookupLanes(BitCast(di, v), indices)); +} + +// LUTs are infeasible for 2^16 possible masks, so splice together two +// half-vector Compress. +template +HWY_INLINE Vec256 CompressNot(Vec256 v, const uint64_t mask_bits) { + // Compress ensures only the lower 16 bits are set, so flip those. 
+ return Compress(v, mask_bits ^ 0xFFFF); +} + +} // namespace detail + +template +HWY_API Vec256 Compress(Vec256 v, Mask256 m) { + const DFromV d; + return detail::Compress(v, BitsFromMask(d, m)); +} + +template +HWY_API Vec256 CompressNot(Vec256 v, Mask256 m) { + const DFromV d; + return detail::CompressNot(v, BitsFromMask(d, m)); +} + +HWY_API Vec256 CompressBlocksNot(Vec256 v, + Mask256 mask) { + return CompressNot(v, mask); +} + +template +HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +template +HWY_API size_t CompressStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = BitsFromMask(d, m); + const size_t count = PopCount(mask_bits); + StoreU(detail::Compress(v, mask_bits), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = BitsFromMask(d, m); + const size_t count = PopCount(mask_bits); + using TU = MakeUnsigned>; + + const RebindToUnsigned du; + HWY_DASSERT(mask_bits < (1ull << Lanes(d))); + const Vec256 idx_mask = detail::IndicesFromBits256>(mask_bits); + // Shift nibble MSB into MSB + const auto shiftVal = sizeof(TU) == 4 ? 28 : 60; + const Mask256 mask32or64 = MaskFromVec(ShiftLeft(idx_mask)); + const Mask256 masku{sizeof(TU) == 4 ? 
__lasx_xvslti_w(mask32or64.raw, 0) + : __lasx_xvslti_d(mask32or64.raw, 0)}; + const MFromD mask = RebindMask(d, masku); + const VFromD compressed = BitCast( + d, TableLookupLanes(BitCast(du, v), Indices256{idx_mask.raw})); + + BlendedStore(compressed, mask, d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = BitsFromMask(d, m); + const size_t count = PopCount(mask_bits); + const VFromD compressed = detail::Compress(v, mask_bits); + BlendedStore(compressed, FirstN(d, count), d, unaligned); + return count; +} + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + constexpr size_t N = MaxLanes(d); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + const size_t count = PopCount(mask_bits); + + StoreU(detail::Compress(v, mask_bits), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +// ------------------------------ Dup128MaskFromMaskBits + +// Generic for all vector lengths >= 32 bytes +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const Half dh; + const auto mh = Dup128MaskFromMaskBits(dh, mask_bits); + return CombineMasks(d, mh, mh); +} + +// ------------------------------ Expand + +template +HWY_API Vec256 Expand(Vec256 v, Mask256 mask) { + const DFromV d; + // LUTs are infeasible for so many mask combinations, so Combine two + // half-vector Expand. 
+ const Half dh; + const uint64_t mask_bits = BitsFromMask(d, mask); + constexpr size_t N = 32 / sizeof(T); + const size_t countL = PopCount(mask_bits & ((1 << (N / 2)) - 1)); + const Mask128 maskL = MaskFromVec(LowerHalf(VecFromMask(d, mask))); + const Vec128 expandL = Expand(LowerHalf(v), maskL); + + alignas(32) T lanes[N]; + Store(v, d, lanes); + const Mask128 maskH = MaskFromVec(UpperHalf(dh, VecFromMask(d, mask))); + const Vec128 expandH = Expand(LoadU(dh, lanes + countL), maskH); + return Combine(d, expandH, expandL); +} + +template +HWY_API Vec256 Expand(Vec256 v, Mask256 mask) { + const Full256 d; + // LUTs are infeasible for 2^16 possible masks, so splice together two + // half-vector Expand. + const Half dh; + const Mask128 maskL = MaskFromVec(LowerHalf(VecFromMask(d, mask))); + const Vec128 expandL = Expand(LowerHalf(v), maskL); + + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + const Vec128 vH = LoadU(dh, lanes + CountTrue(dh, maskL)); + const Mask128 maskH = MaskFromVec(UpperHalf(dh, VecFromMask(d, mask))); + const Vec128 expandH = Expand(vH, maskH); + return Combine(d, expandH, expandL); +} + +template +HWY_API Vec256 Expand(Vec256 v, Mask256 mask) { + const Full256 d; + const RebindToUnsigned du; + const uint64_t mask_bits = BitsFromMask(d, mask); + alignas(16) constexpr uint32_t packed_array[256] = { + // PrintExpand32x8Nibble. 
+ 0xffffffff, 0xfffffff0, 0xffffff0f, 0xffffff10, 0xfffff0ff, 0xfffff1f0, + 0xfffff10f, 0xfffff210, 0xffff0fff, 0xffff1ff0, 0xffff1f0f, 0xffff2f10, + 0xffff10ff, 0xffff21f0, 0xffff210f, 0xffff3210, 0xfff0ffff, 0xfff1fff0, + 0xfff1ff0f, 0xfff2ff10, 0xfff1f0ff, 0xfff2f1f0, 0xfff2f10f, 0xfff3f210, + 0xfff10fff, 0xfff21ff0, 0xfff21f0f, 0xfff32f10, 0xfff210ff, 0xfff321f0, + 0xfff3210f, 0xfff43210, 0xff0fffff, 0xff1ffff0, 0xff1fff0f, 0xff2fff10, + 0xff1ff0ff, 0xff2ff1f0, 0xff2ff10f, 0xff3ff210, 0xff1f0fff, 0xff2f1ff0, + 0xff2f1f0f, 0xff3f2f10, 0xff2f10ff, 0xff3f21f0, 0xff3f210f, 0xff4f3210, + 0xff10ffff, 0xff21fff0, 0xff21ff0f, 0xff32ff10, 0xff21f0ff, 0xff32f1f0, + 0xff32f10f, 0xff43f210, 0xff210fff, 0xff321ff0, 0xff321f0f, 0xff432f10, + 0xff3210ff, 0xff4321f0, 0xff43210f, 0xff543210, 0xf0ffffff, 0xf1fffff0, + 0xf1ffff0f, 0xf2ffff10, 0xf1fff0ff, 0xf2fff1f0, 0xf2fff10f, 0xf3fff210, + 0xf1ff0fff, 0xf2ff1ff0, 0xf2ff1f0f, 0xf3ff2f10, 0xf2ff10ff, 0xf3ff21f0, + 0xf3ff210f, 0xf4ff3210, 0xf1f0ffff, 0xf2f1fff0, 0xf2f1ff0f, 0xf3f2ff10, + 0xf2f1f0ff, 0xf3f2f1f0, 0xf3f2f10f, 0xf4f3f210, 0xf2f10fff, 0xf3f21ff0, + 0xf3f21f0f, 0xf4f32f10, 0xf3f210ff, 0xf4f321f0, 0xf4f3210f, 0xf5f43210, + 0xf10fffff, 0xf21ffff0, 0xf21fff0f, 0xf32fff10, 0xf21ff0ff, 0xf32ff1f0, + 0xf32ff10f, 0xf43ff210, 0xf21f0fff, 0xf32f1ff0, 0xf32f1f0f, 0xf43f2f10, + 0xf32f10ff, 0xf43f21f0, 0xf43f210f, 0xf54f3210, 0xf210ffff, 0xf321fff0, + 0xf321ff0f, 0xf432ff10, 0xf321f0ff, 0xf432f1f0, 0xf432f10f, 0xf543f210, + 0xf3210fff, 0xf4321ff0, 0xf4321f0f, 0xf5432f10, 0xf43210ff, 0xf54321f0, + 0xf543210f, 0xf6543210, 0x0fffffff, 0x1ffffff0, 0x1fffff0f, 0x2fffff10, + 0x1ffff0ff, 0x2ffff1f0, 0x2ffff10f, 0x3ffff210, 0x1fff0fff, 0x2fff1ff0, + 0x2fff1f0f, 0x3fff2f10, 0x2fff10ff, 0x3fff21f0, 0x3fff210f, 0x4fff3210, + 0x1ff0ffff, 0x2ff1fff0, 0x2ff1ff0f, 0x3ff2ff10, 0x2ff1f0ff, 0x3ff2f1f0, + 0x3ff2f10f, 0x4ff3f210, 0x2ff10fff, 0x3ff21ff0, 0x3ff21f0f, 0x4ff32f10, + 0x3ff210ff, 0x4ff321f0, 0x4ff3210f, 0x5ff43210, 0x1f0fffff, 0x2f1ffff0, + 
0x2f1fff0f, 0x3f2fff10, 0x2f1ff0ff, 0x3f2ff1f0, 0x3f2ff10f, 0x4f3ff210, + 0x2f1f0fff, 0x3f2f1ff0, 0x3f2f1f0f, 0x4f3f2f10, 0x3f2f10ff, 0x4f3f21f0, + 0x4f3f210f, 0x5f4f3210, 0x2f10ffff, 0x3f21fff0, 0x3f21ff0f, 0x4f32ff10, + 0x3f21f0ff, 0x4f32f1f0, 0x4f32f10f, 0x5f43f210, 0x3f210fff, 0x4f321ff0, + 0x4f321f0f, 0x5f432f10, 0x4f3210ff, 0x5f4321f0, 0x5f43210f, 0x6f543210, + 0x10ffffff, 0x21fffff0, 0x21ffff0f, 0x32ffff10, 0x21fff0ff, 0x32fff1f0, + 0x32fff10f, 0x43fff210, 0x21ff0fff, 0x32ff1ff0, 0x32ff1f0f, 0x43ff2f10, + 0x32ff10ff, 0x43ff21f0, 0x43ff210f, 0x54ff3210, 0x21f0ffff, 0x32f1fff0, + 0x32f1ff0f, 0x43f2ff10, 0x32f1f0ff, 0x43f2f1f0, 0x43f2f10f, 0x54f3f210, + 0x32f10fff, 0x43f21ff0, 0x43f21f0f, 0x54f32f10, 0x43f210ff, 0x54f321f0, + 0x54f3210f, 0x65f43210, 0x210fffff, 0x321ffff0, 0x321fff0f, 0x432fff10, + 0x321ff0ff, 0x432ff1f0, 0x432ff10f, 0x543ff210, 0x321f0fff, 0x432f1ff0, + 0x432f1f0f, 0x543f2f10, 0x432f10ff, 0x543f21f0, 0x543f210f, 0x654f3210, + 0x3210ffff, 0x4321fff0, 0x4321ff0f, 0x5432ff10, 0x4321f0ff, 0x5432f1f0, + 0x5432f10f, 0x6543f210, 0x43210fff, 0x54321ff0, 0x54321f0f, 0x65432f10, + 0x543210ff, 0x654321f0, 0x6543210f, 0x76543210, + }; + + // For lane i, shift the i-th 4-bit index down to bits [0, 3). + const Vec256 packed = Set(du, packed_array[mask_bits]); + alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + // TableLookupLanes ignores upper bits; avoid bounds-check in IndicesFromVec. + const Indices256 indices{(packed >> Load(du, shifts)).raw}; + const Vec256 expand = TableLookupLanes(BitCast(du, v), indices); + // TableLookupLanes cannot also zero masked-off lanes, so do that now. + return IfThenElseZero(mask, BitCast(d, expand)); +} + +template +HWY_API Vec256 Expand(Vec256 v, Mask256 mask) { + const Full256 d; + const RebindToUnsigned du; + const uint64_t mask_bits = BitsFromMask(d, mask); + + alignas(16) constexpr uint64_t packed_array[16] = { + // PrintExpand64x4Nibble. 
+ 0x0000ffff, 0x0000fff0, 0x0000ff0f, 0x0000ff10, 0x0000f0ff, 0x0000f1f0, + 0x0000f10f, 0x0000f210, 0x00000fff, 0x00001ff0, 0x00001f0f, 0x00002f10, + 0x000010ff, 0x000021f0, 0x0000210f, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2). + const Vec256 packed = Set(du, packed_array[mask_bits]); + alignas(32) constexpr uint64_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + // 64-bit TableLookupLanes on LASX requires IndicesFromVec, which checks + // bounds, so clear the upper bits. + const Vec256 masked = And(packed >> Load(du, shifts), Set(du, 3)); + const Indices256 indices = IndicesFromVec(du, masked); + const Vec256 expand = TableLookupLanes(BitCast(du, v), indices); + // TableLookupLanes cannot also zero masked-off lanes, so do that now. + return IfThenElseZero(mask, BitCast(d, expand)); +} + +// ------------------------------ LoadExpand + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return Expand(LoadU(d, unaligned), mask); +} + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return Expand(LoadU(d, unaligned), mask); +} + +// ------------------------------ LoadInterleaved3/4 + +// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4. 
+ +namespace detail { +// Loads three consecutive vectors from `unaligned` and regroups their halves +// into A, B and C as shown below. +// Input (128-bit blocks): +// 1 0 (<- first block of unaligned) +// 3 2 +// 5 4 +// Output: +// 3 0 +// 4 1 +// 5 2 +template +HWY_API void LoadTransposedBlocks3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& A, VFromD& B, VFromD& C) { + constexpr size_t N = MaxLanes(d); + const VFromD v10 = LoadU(d, unaligned + 0 * N); // 1 0 + const VFromD v32 = LoadU(d, unaligned + 1 * N); // 3 2 + const VFromD v54 = LoadU(d, unaligned + 2 * N); // 5 4 + + A = ConcatUpperLower(d, v32, v10); // 3 0 + B = ConcatLowerUpper(d, v54, v10); // 4 1 + C = ConcatUpperLower(d, v54, v32); // 5 2 +} + +// Loads four consecutive vectors from `unaligned` and regroups their halves +// into vA..vD as shown below. +// Input (128-bit blocks): +// 1 0 (first block of unaligned) +// 3 2 +// 5 4 +// 7 6 +// Output: +// 4 0 (LSB of vA) +// 5 1 +// 6 2 +// 7 3 +template +HWY_API void LoadTransposedBlocks4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& vA, VFromD& vB, VFromD& vC, + VFromD& vD) { + constexpr size_t N = MaxLanes(d); + const VFromD v10 = LoadU(d, unaligned + 0 * N); // 1 0 + const VFromD v32 = LoadU(d, unaligned + 1 * N); // 3 2 + const VFromD v54 = LoadU(d, unaligned + 2 * N); // 5 4 + const VFromD v76 = LoadU(d, unaligned + 3 * N); // 7 6 + + vA = ConcatLowerLower(d, v54, v10); // 4 0 + vB = ConcatUpperUpper(d, v54, v10); // 5 1 + vC = ConcatLowerLower(d, v76, v32); // 6 2 + vD = ConcatUpperUpper(d, v76, v32); // 7 3 +} +} // namespace detail + +// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower) + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. 
+ +namespace detail { +// Input (128-bit blocks): +// 2 0 (LSB of i) +// 3 1 +// Output: +// 1 0 +// 3 2 +template +HWY_API void StoreTransposedBlocks2(VFromD i, VFromD j, D d, + TFromD* HWY_RESTRICT unaligned) { + constexpr size_t N = MaxLanes(d); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperUpper(d, j, i); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 3 0 (LSB of i) +// 4 1 +// 5 2 +// Output: +// 1 0 +// 3 2 +// 5 4 +template +HWY_API void StoreTransposedBlocks3(VFromD i, VFromD j, VFromD k, D d, + TFromD* HWY_RESTRICT unaligned) { + constexpr size_t N = MaxLanes(d); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperLower(d, i, k); + const auto out2 = ConcatUpperUpper(d, k, j); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// 4 0 (LSB of i) +// 5 1 +// 6 2 +// 7 3 +// Output: +// 1 0 +// 3 2 +// 5 4 +// 7 6 +template +HWY_API void StoreTransposedBlocks4(VFromD i, VFromD j, VFromD k, + VFromD l, D d, + TFromD* HWY_RESTRICT unaligned) { + constexpr size_t N = MaxLanes(d); + // Write lower halves, then upper. 
+ const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatLowerLower(d, l, k); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + const auto out2 = ConcatUpperUpper(d, j, i); + const auto out3 = ConcatUpperUpper(d, l, k); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} +} // namespace detail + +// ------------------------------ Additional mask logical operations + +namespace detail { + +template +static HWY_INLINE HWY_MAYBE_UNUSED Vec256 LasxI256Neg(Vec256 v) { + const Full256 d; + const Repartition du64; + + const auto vu64 = BitCast(du64, v); + const auto vu64_zero = Zero(du64); + const auto i128_ne_zero = VecFromMask(du64, Ne128(du64, vu64, vu64_zero)); + const VFromD i128_neg_result{ + __lasx_xvsub_q(vu64_zero.raw, vu64.raw)}; + const VFromD i256_neg_result_as_u64{ + __lasx_xvadd_q(i128_neg_result.raw, + ConcatLowerLower(du64, i128_ne_zero, vu64_zero).raw)}; + + return BitCast(d, i256_neg_result_as_u64); +} + +} // namespace detail + +template +HWY_API Mask256 SetAtOrAfterFirst(Mask256 mask) { + const Full256 d; + return Or(mask, MaskFromVec(detail::LasxI256Neg(VecFromMask(d, mask)))); +} + +template +HWY_API Mask256 SetBeforeFirst(Mask256 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask256 SetOnlyFirst(Mask256 mask) { + const Full256 d; + const RebindToSigned di; + + const auto vmask = BitCast(di, VecFromMask(d, mask)); + const auto neg_vmask = detail::LasxI256Neg(vmask); + + return MaskFromVec(BitCast(d, Neg(And(vmask, neg_vmask)))); +} + +template +HWY_API Mask256 SetAtOrBeforeFirst(Mask256 mask) { + const Full256 d; + constexpr size_t kLanesPerBlock = MaxLanes(d) / 2; + + const auto vmask = VecFromMask(d, mask); + const auto vmask_lo = ConcatLowerLower(d, vmask, Zero(d)); + return SetBeforeFirst( + MaskFromVec(CombineShiftRightBytes<(kLanesPerBlock - 1) * sizeof(T)>( + d, vmask, vmask_lo))); +} + +// ------------------------------ LeadingZeroCount + +template 
), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V LeadingZeroCount(V v) { + return V{__lasx_xvclz_b(v.raw)}; +} +template ), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V LeadingZeroCount(V v) { + return V{__lasx_xvclz_h(v.raw)}; +} +template ), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V LeadingZeroCount(V v) { + return V{__lasx_xvclz_w(v.raw)}; +} +template ), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V LeadingZeroCount(V v) { + return V{__lasx_xvclz_d(v.raw)}; +} + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + using T = TFromD; + return BitCast(d, Set(d, T{sizeof(T) * 8 - 1}) - LeadingZeroCount(v)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/lib/highway/hwy/ops/loongarch_lsx-inl.h b/lib/highway/hwy/ops/loongarch_lsx-inl.h new file mode 100644 index 00000000000..40642743a93 --- /dev/null +++ b/lib/highway/hwy/ops/loongarch_lsx-inl.h @@ -0,0 +1,5954 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#ifndef __loongarch_sx +// If LSX is to be runtime dispatched (instead of in baseline), we need +// to enable it *and* define __loongarch_sx or the intrinsic header will +// fail to compile. 
+// +// We cannot simply move lsxintrin.h after HWY_BEFORE_NAMESPACE because +// doing so may cause the first (the only effective) inclusion of +// lsxintrin.h to be compiled with both LSX and LASX enabled. Then when +// we call the inline functions in the header with only LSX enabled, +// we'll get an "always_inline function requires lasx but would be inlined +// into a function that is compiled without suport for lasx" error. +HWY_PUSH_ATTRIBUTES("lsx") +#define __loongarch_sx +#include +#undef __loongarch_sx +// Prevent "unused push_attribute" warning from Clang. +HWY_MAYBE_UNUSED static void HWY_CONCAT(hwy_lsx_dummy, __COUNTER__) () {} +HWY_POP_ATTRIBUTES +#else +#include +#endif + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +// Enable generic functions for whichever of (f16, bf16) are not supported. +#define HWY_LSX_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) + +template +struct Raw128 { + using type = __m128i; +}; +template <> +struct Raw128 { + using type = __m128; +}; +template <> +struct Raw128 { + using type = __m128d; +}; + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
+ HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +template +using Vec16 = Vec128; + +namespace detail { + +template +using RawMask128 = typename Raw128::type; + +} // namespace detail + +template +struct Mask128 { + using Raw = typename detail::RawMask128; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = N; // only for DFromM + + Raw raw; +}; + +template +using DFromV = Simd; + +template +using DFromM = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m128i BitCastToInteger(__m128i v) { return v; } +HWY_INLINE __m128i BitCastToInteger(__m128 v) { + return reinterpret_cast<__m128i>(v); +} +HWY_INLINE __m128i BitCastToInteger(__m128d v) { + return reinterpret_cast<__m128i>(v); +} + +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return Vec128{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. 
+template +struct BitCastFromInteger128 { + HWY_INLINE __m128i operator()(__m128i v) { return v; } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128 operator()(__m128i v) { + return reinterpret_cast<__m128>(v); + } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128d operator()(__m128i v) { + return reinterpret_cast<__m128d>(v); + } +}; + +} // namespace detail + +// ------------------------------ Zero + +// Use HWY_MAX_LANES_D here because VFromD is defined in terms of Zero. +template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + return Vec128, HWY_MAX_LANES_D(D)>{(__lsx_vreplgr2vr_w(0))}; +} + +template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + return Vec128, HWY_MAX_LANES_D(D)>{ + detail::BitCastFromInteger128>()(__lsx_vreplgr2vr_w(0))}; +} + +template +using VFromD = decltype(Zero(D())); + +namespace detail { + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + Vec128 v) { + return VFromD{BitCastFromInteger128>()(v.raw)}; +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, + Vec128().MaxLanes()> v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{__lsx_vreplgr2vr_b(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{__lsx_vreplgr2vr_h(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{__lsx_vreplgr2vr_w(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{__lsx_vreplgr2vr_d(static_cast(t))}; +} + +template +HWY_API VFromD Set(D d, float t) { + const RebindToSigned di; + return BitCast(d, VFromD{__lsx_vldrepl_w(&t, 0)}); +} + +template +HWY_API VFromD Set(D d, double t) { + const RebindToSigned di; + return BitCast(d, VFromD{__lsx_vldrepl_d(&t, 0)}); +} + +// Generic for all vector lengths. 
+template +HWY_API VFromD Set(D df, TFromD t) { + const RebindToUnsigned du; + static_assert(sizeof(TFromD) == 2, "Expecting [b]f16"); + uint16_t bits; + CopyBytes<2>(&t, &bits); + return BitCast(df, Set(du, bits)); +} + +// ------------------------------ Undefined + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API VFromD Undefined(D /* tag */) { + VFromD v; + return v; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ GetLane + +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(__lsx_vpickve2gr_b(v.raw, 0)); +} +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(__lsx_vpickve2gr_h(v.raw, 0)); +} +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(__lsx_vpickve2gr_w(v.raw, 0)); +} +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(__lsx_vpickve2gr_d(v.raw, 0)); +} +template +HWY_API float GetLane(const Vec128 v) { + float f32; + int32_t i32 = __lsx_vpickve2gr_w(reinterpret_cast<__m128i>(v.raw), 0); + CopyBytes<4>(&i32, &f32); + return f32; +} +template +HWY_API double GetLane(const Vec128 v) { + double f64; + int64_t i64 = __lsx_vpickve2gr_d(reinterpret_cast<__m128i>(v.raw), 0); + CopyBytes<8>(&i64, &f64); + return f64; +} + +// ------------------------------ ResizeBitCast + +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const Repartition du8; + return BitCast(d, VFromD{detail::BitCastToInteger(v.raw)}); +} + +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + typedef int8_t GccI8RawVectType __attribute__((__vector_size__(16))); + const GccI8RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), 
static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), + static_cast(t8), static_cast(t9), + static_cast(t10), static_cast(t11), + static_cast(t12), static_cast(t13), + static_cast(t14), static_cast(t15)}; + return VFromD{reinterpret_cast<__m128i>(raw)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + typedef int16_t GccI16RawVectType __attribute__((__vector_size__(16))); + const GccI16RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7)}; + return VFromD{reinterpret_cast<__m128i>(raw)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + typedef int32_t GccI32RawVectType __attribute__((__vector_size__(16))); + const GccI32RawVectType raw = { + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3)}; + return VFromD{reinterpret_cast<__m128i>(raw)}; +} +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + typedef int64_t GccI64RawVectType __attribute__((__vector_size__(16))); + const GccI64RawVectType raw = {static_cast(t0), + static_cast(t1)}; + return VFromD{reinterpret_cast<__m128i>(raw)}; +} +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const GccF32RawVectType raw = {t0, t1, t2, t3}; + return 
VFromD{reinterpret_cast<__m128>(raw)}; +} +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const GccF64RawVectType raw = {t0, t1}; + return VFromD{reinterpret_cast<__m128d>(raw)}; +} + +// ================================================== LOGICAL + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{ + __lsx_vand_v(BitCast(du, a).raw, BitCast(du, b).raw)}); +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{__lsx_vandn_v( + BitCast(du, not_mask).raw, BitCast(du, mask).raw)}); +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{ + __lsx_vor_v(BitCast(du, a).raw, BitCast(du, b).raw)}); +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{ + __lsx_vxor_v(BitCast(du, a).raw, BitCast(du, b).raw)}); +} + +// ------------------------------ Not +template +HWY_API Vec128 Not(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{ + __lsx_vnor_v(BitCast(du, v).raw, BitCast(du, v).raw)}); +} + +// ------------------------------ Or3 +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). 
+template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return Mask128{v.raw}; +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{v.raw}; +} + +// Generic for all vector lengths. +template +HWY_API VFromD VecFromMask(D /* tag */, MFromD v) { + return VecFromMask(v); +} + +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + const DFromV d; + RebindToSigned di; + return BitCast(d, VFromD{__lsx_vbitsel_v( + BitCast(di, no).raw, BitCast(di, yes).raw, + RebindMask(di, mask).raw)}); +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ BitwiseIfThenElse + +#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#else +#define HWY_NATIVE_BITWISE_IF_THEN_ELSE +#endif + +template +HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { + return IfVecThenElse(mask, yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, + Vec128 v) { + return Vec128{__lsx_vpcnt_b(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, + Vec128 v) { + return Vec128{__lsx_vpcnt_h(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, + Vec128 v) { + return 
Vec128{__lsx_vpcnt_w(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, + Vec128 v) { + return Vec128{__lsx_vpcnt_d(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 PopulationCount(Vec128 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +// ================================================== SIGN + +// ------------------------------ Neg + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Xor(v, SignBit(DFromV())); +} + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{__lsx_vneg_b(v.raw)}; +} + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{__lsx_vneg_h(v.raw)}; +} + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{__lsx_vneg_w(v.raw)}; +} + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{__lsx_vneg_d(v.raw)}; +} + +// ------------------------------ Floating-point Abs +// Generic for all vector lengths +template )> +HWY_API V Abs(V v) { + const DFromV d; + const RebindToSigned di; + using TI = TFromD; + return v & BitCast(d, Set(di, static_cast(~SignMask()))); +} + +// ------------------------------ CopySign +// Generic for all vector lengths. +template +HWY_API V CopySign(const V magn, const V sign) { + static_assert(IsFloat>(), "Only makes sense for floating-point"); + + const DFromV d; + const auto msb = SignBit(d); + return BitwiseIfThenElse(msb, sign, magn); +} + +// ------------------------------ CopySignToAbs +// Generic for all vector lengths. 
+template +HWY_API V CopySignToAbs(const V abs, const V sign) { + const DFromV d; + return OrAnd(abs, SignBit(d), sign); +} + +// ------------------------------ IfThenElseZero + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + const Simd d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +// ------------------------------ ExclusiveNeither + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ------------------------------ ShiftLeft + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{__lsx_vslli_b(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{__lsx_vslli_h(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{__lsx_vslli_w(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{__lsx_vslli_d(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{__lsx_vslli_b(v.raw, kBits)}; +} +template +HWY_API Vec128 
ShiftLeft(const Vec128 v) { + return Vec128{__lsx_vslli_h(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{__lsx_vslli_w(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{__lsx_vslli_d(v.raw, kBits)}; +} + +// ------------------------------ ShiftRight + +template +HWY_API Vec128 ShiftRight(Vec128 v) { + return Vec128{__lsx_vsrli_b(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(Vec128 v) { + return Vec128{__lsx_vsrli_h(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(Vec128 v) { + return Vec128{__lsx_vsrli_w(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(Vec128 v) { + return Vec128{__lsx_vsrli_d(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftRight(Vec128 v) { + return Vec128{__lsx_vsrai_b(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(Vec128 v) { + return Vec128{__lsx_vsrai_h(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(Vec128 v) { + return Vec128{__lsx_vsrai_w(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(Vec128 v) { + return Vec128{__lsx_vsrai_d(v.raw, kBits)}; +} + +// ------------------------------ RoundingShiftRight + +#ifdef HWY_NATIVE_ROUNDING_SHR +#undef HWY_NATIVE_ROUNDING_SHR +#else +#define HWY_NATIVE_ROUNDING_SHR +#endif + +template +HWY_API Vec128 RoundingShiftRight(Vec128 v) { + return Vec128{__lsx_vsrari_b(v.raw, kBits)}; +} +template +HWY_API Vec128 RoundingShiftRight(Vec128 v) { + return Vec128{__lsx_vsrari_h(v.raw, kBits)}; +} +template +HWY_API Vec128 RoundingShiftRight(Vec128 v) { + return Vec128{__lsx_vsrari_w(v.raw, kBits)}; +} +template +HWY_API Vec128 RoundingShiftRight(Vec128 v) { + return Vec128{__lsx_vsrari_d(v.raw, kBits)}; +} + +template +HWY_API Vec128 RoundingShiftRight(Vec128 v) { + return Vec128{__lsx_vsrlri_b(v.raw, kBits)}; +} +template +HWY_API Vec128 RoundingShiftRight(Vec128 v) { + return Vec128{__lsx_vsrlri_h(v.raw, kBits)}; +} +template +HWY_API Vec128 RoundingShiftRight(Vec128 v) { + 
return Vec128{__lsx_vsrlri_w(v.raw, kBits)}; +} +template +HWY_API Vec128 RoundingShiftRight(Vec128 v) { + return Vec128{__lsx_vsrlri_d(v.raw, kBits)}; +} + +// ------------------------------ RoundingShr + +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128{__lsx_vsrar_b(v.raw, bits.raw)}; +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128{__lsx_vsrar_h(v.raw, bits.raw)}; +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128{__lsx_vsrar_w(v.raw, bits.raw)}; +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128{__lsx_vsrar_d(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128{__lsx_vsrlr_b(v.raw, bits.raw)}; +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128{__lsx_vsrlr_h(v.raw, bits.raw)}; +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128{__lsx_vsrlr_w(v.raw, bits.raw)}; +} +template +HWY_API Vec128 RoundingShr(Vec128 v, + Vec128 bits) { + return Vec128{__lsx_vsrlr_d(v.raw, bits.raw)}; +} + +// ------------------------------ RoundingShiftRightSame (RoundingShr) + +template +HWY_API Vec128 RoundingShiftRightSame(const Vec128 v, int bits) { + return RoundingShr(v, Set(DFromV(), static_cast(bits))); +} + +// ================================================== MEMORY (1) + +// ------------------------------ Load 128 + +template > +HWY_API Vec128 Load(D d, const T* HWY_RESTRICT aligned) { + const RebindToUnsigned du; + return BitCast(d, VFromD{__lsx_vld(aligned, 0)}); +} + +// Partial +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + VFromD v; + CopyBytes(p, &v); + return v; +} + +// LoadU == Load +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +// ------------------------------ MaskedLoad + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* 
HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +// ------------------------------ MaskedLoadOr + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElse(m, LoadU(d, p), v); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +// ------------------------------ Store 128 + +template +HWY_API void Store(VFromD v, D /* tag */, void* HWY_RESTRICT aligned) { + __lsx_vst(v.raw, aligned, 0); +} + +// ------------------------------ Store 64 + +template +HWY_API void Store(VFromD v, D /* tag */, void* HWY_RESTRICT aligned) { + __lsx_vstelm_d(v.raw, aligned, 0, 0); +} + +// ------------------------------ Store 32 + +template +HWY_API void Store(VFromD v, D /* tag */, void* HWY_RESTRICT aligned) { + __lsx_vstelm_w(v.raw, aligned, 0, 0); +} + +// ------------------------------ Store 16 + +template +HWY_API void Store(VFromD v, D /* tag */, void* HWY_RESTRICT aligned) { + __lsx_vstelm_h(v.raw, aligned, 0, 0); +} + +// ------------------------------ Store 8 + +template +HWY_API void Store(VFromD v, D /* tag */, void* HWY_RESTRICT aligned) { + __lsx_vstelm_b(v.raw, aligned, 0, 0); +} + +template +HWY_API void StoreU(VFromD v, D d, void* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ================================================== SWIZZLE (1) + +// ------------------------------ TableLookupBytes +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + const DFromV d; + const Repartition du8; + const DFromV d_bytes; + const Repartition du8_bytes; + return BitCast( + d, VFromD{__lsx_vshuf_b(BitCast(du8_bytes, bytes).raw, + BitCast(du8_bytes, bytes).raw, + (BitCast(du8, from).raw))}); +} + +// ------------------------------ TableLookupBytesOr0 +template +HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { + const DFromV d; + const Repartition 
di8; + return BitCast(d, + IfThenZeroElse(Lt(BitCast(di8, from), Zero(di8)), + BitCast(di8, TableLookupBytes(bytes, from)))); +} + +// ------------------------------ Shuffles (ShiftRight, TableLookupBytes) + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{__lsx_vshuf4i_w( + detail::BitCastToInteger(v.raw), 0xB1)}); +} + +namespace detail { + +template +HWY_API Vec32 ShuffleTwo2301(const Vec32 a, const Vec32 b) { + const int8_t _data_idx[] = {1, 0, 19, 18}; + __m128i shuffle_idx = __lsx_vld(_data_idx, 0); + return Vec32{__lsx_vshuf_b(b.raw, a.raw, shuffle_idx)}; +} +template +HWY_API Vec64 ShuffleTwo2301(const Vec64 a, const Vec64 b) { + const int16_t _data_idx[] = {9, 8, 3, 2}; + __m128i shuffle_idx = __lsx_vld(_data_idx, 0); + return Vec64{__lsx_vshuf_h(shuffle_idx, a.raw, b.raw)}; +} +template +HWY_API Vec128 ShuffleTwo2301(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToSigned di; + return BitCast(d, Vec128{__lsx_vpermi_w(BitCast(di, b).raw, + BitCast(di, a).raw, 0xB1)}); +} + +template +HWY_API Vec32 ShuffleTwo1230(const Vec32 a, const Vec32 b) { + const int8_t _data_idx[] = {0, 3, 18, 17}; + __m128i shuffle_idx = __lsx_vld(_data_idx, 0); + return Vec32{__lsx_vshuf_b(b.raw, a.raw, shuffle_idx)}; +} +template +HWY_API Vec64 ShuffleTwo1230(const Vec64 a, const Vec64 b) { + const int16_t _data_idx[] = {10, 11, 2, 1}; + __m128i shuffle_idx = __lsx_vld(_data_idx, 0); + auto t0 = __lsx_vshuf_h(shuffle_idx, a.raw, 
b.raw); + return Vec64{t0}; +} +template +HWY_API Vec128 ShuffleTwo1230(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToSigned di; + return BitCast(d, Vec128{__lsx_vpermi_w(BitCast(di, b).raw, + BitCast(di, a).raw, 0x6C)}); +} + +template +HWY_API Vec32 ShuffleTwo3012(const Vec32 a, const Vec32 b) { + const int8_t _data_idx[] = {2, 1, 16, 19}; + __m128i shuffle_idx = __lsx_vld(_data_idx, 0); + return Vec32{__lsx_vshuf_b(b.raw, a.raw, shuffle_idx)}; +} +template +HWY_API Vec64 ShuffleTwo3012(const Vec64 a, const Vec64 b) { + const int16_t _data_idx[] = {8, 9, 0, 3}; + __m128i shuffle_idx = __lsx_vld(_data_idx, 0); + return Vec64{__lsx_vshuf_h(shuffle_idx, a.raw, b.raw)}; +} +template +HWY_API Vec128 ShuffleTwo3012(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToSigned di; + return BitCast(d, Vec128{__lsx_vpermi_w(BitCast(di, b).raw, + BitCast(di, a).raw, 0xC6)}); +} + +} // namespace detail + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle1032(const Vec128 v) { + const DFromV d; + return BitCast(d, Vec128{__lsx_vshuf4i_w( + reinterpret_cast<__m128i>(v.raw), 0x4E)}); +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{__lsx_vshuf4i_w(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{__lsx_vshuf4i_w(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + const DFromV d; + return BitCast(d, Vec128{__lsx_vshuf4i_d( + reinterpret_cast<__m128i>(v.raw), + reinterpret_cast<__m128i>(v.raw), 0x1)}); +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(const Vec128 v) { + const DFromV d; + return BitCast(d, Vec128{__lsx_vshuf4i_w( + reinterpret_cast<__m128i>(v.raw), 0x39)}); +} +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(const Vec128 v) { + const DFromV d; + return BitCast(d, Vec128{__lsx_vshuf4i_w( + reinterpret_cast<__m128i>(v.raw), 0x93)}); +} +// Reverse +template +HWY_API Vec128 Shuffle0123(const Vec128 v) { + const DFromV d; + return BitCast(d, 
Vec128{__lsx_vshuf4i_w( + reinterpret_cast<__m128i>(v.raw), 0x1B)}); +} + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API MFromD RebindMask(DTo dto, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + const Simd d; + return MaskFromVec(BitCast(dto, VecFromMask(d, m))); +} + +// ================================================== COMPARE + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{__lsx_vseq_b(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{__lsx_vseq_h(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{__lsx_vseq_w(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{__lsx_vseq_d(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{__lsx_vseq_b(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{__lsx_vseq_h(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{__lsx_vseq_w(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{__lsx_vseq_d(a.raw, b.raw)}; +} + +// Float +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + return Mask128{ + reinterpret_cast<__m128>(__lsx_vfcmp_ceq_s(a.raw, b.raw))}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{ + reinterpret_cast<__m128d>(__lsx_vfcmp_ceq_d(a.raw, b.raw))}; +} + +// ------------------------------ Inequality + +// This cannot have T as a template argument, otherwise it is not more +// 
specialized than rewritten operator== in C++20, leading to compile +// errors: https://gcc.godbolt.org/z/xsrPhPvPT. +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} + +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + return Mask128{ + reinterpret_cast<__m128>(__lsx_vfcmp_cune_s(a.raw, b.raw))}; +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Mask128{ + reinterpret_cast<__m128d>(__lsx_vfcmp_cune_d(a.raw, b.raw))}; +} + +// ------------------------------ Strict inequality + +namespace detail { + +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{__lsx_vslt_b(b.raw, a.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{__lsx_vslt_h(b.raw, a.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{__lsx_vslt_w(b.raw, a.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, + const Vec128 a, + const Vec128 b) { + return Mask128{__lsx_vslt_d(b.raw, a.raw)}; +} + +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{__lsx_vslt_b(b.raw, a.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, + Vec128 a, + Vec128 b) { + return Mask128{__lsx_vslt_h(b.raw, a.raw)}; +} +template 
+HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, + Vec128 a, + Vec128 b) { + return Mask128{__lsx_vslt_w(b.raw, a.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, + const Vec128 a, + const Vec128 b) { + return Mask128{__lsx_vslt_d(b.raw, a.raw)}; +} + +template +HWY_INLINE Mask128 Gt(hwy::UnsignedTag /*tag*/, Vec128 a, + Vec128 b) { + const DFromV du; + const RebindToSigned di; + const Vec128 msb = Set(du, (LimitsMax() >> 1) + 1); + const auto sa = BitCast(di, Xor(a, msb)); + const auto sb = BitCast(di, Xor(b, msb)); + return RebindMask(du, Gt(hwy::SignedTag(), sa, sb)); +} + +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{ + reinterpret_cast<__m128>(__lsx_vfcmp_clt_s(b.raw, a.raw))}; +} +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{ + reinterpret_cast<__m128d>(__lsx_vfcmp_clt_d(b.raw, a.raw))}; +} + +} // namespace detail + +template +HWY_INLINE Mask128 operator>(Vec128 a, Vec128 b) { + return detail::Gt(hwy::TypeTag(), a, b); +} + +// ------------------------------ Weak inequality + +namespace detail { +template +HWY_INLINE Mask128 Ge(hwy::SignedTag tag, Vec128 a, + Vec128 b) { + return Not(Gt(tag, b, a)); +} + +template +HWY_INLINE Mask128 Ge(hwy::UnsignedTag tag, Vec128 a, + Vec128 b) { + return Not(Gt(tag, b, a)); +} + +template +HWY_INLINE Mask128 Ge(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{ + reinterpret_cast<__m128>(__lsx_vfcmp_cle_s(b.raw, a.raw))}; +} +template +HWY_INLINE Mask128 Ge(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{ + reinterpret_cast<__m128d>(__lsx_vfcmp_cle_d(b.raw, a.raw))}; +} + +} // namespace detail + +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return detail::Ge(hwy::TypeTag(), a, b); +} + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask128 operator<(Vec128 a, Vec128 b) { + return b > a; +} + +template +HWY_API Mask128 
operator<=(Vec128 a, Vec128 b) { + return b >= a; +} + +// ------------------------------ Iota (Load) + +namespace detail { + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues( + d, TFromD{0}, TFromD{1}, TFromD{2}, TFromD{3}, TFromD{4}, + TFromD{5}, TFromD{6}, TFromD{7}, TFromD{8}, TFromD{9}, + TFromD{10}, TFromD{11}, TFromD{12}, TFromD{13}, TFromD{14}, + TFromD{15}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues(d, TFromD{0}, TFromD{1}, TFromD{2}, + TFromD{3}, TFromD{4}, TFromD{5}, + TFromD{6}, TFromD{7}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues( + d, static_cast>(0), static_cast>(1), + static_cast>(2), static_cast>(3)); +} + +template +HWY_INLINE VFromD Iota0(D d) { + return Dup128VecFromValues(d, static_cast>(0), + static_cast>(1)); +} + +} // namespace detail + +template +HWY_API VFromD Iota(D d, const T2 first) { + const auto result_iota = + detail::Iota0(d) + Set(d, static_cast>(first)); + return result_iota; +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API MFromD FirstN(D d, size_t num) { + const RebindToSigned di; // Signed comparisons are cheaper. + using TI = TFromD; + return RebindMask(d, detail::Iota0(di) < Set(di, static_cast(num))); +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). 
+ +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{__lsx_vilvl_b(b.raw, a.raw)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{__lsx_vilvl_h(b.raw, a.raw)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{__lsx_vilvl_w(b.raw, a.raw)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{__lsx_vilvl_d(b.raw, a.raw)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{reinterpret_cast<__m128>(__lsx_vilvl_w( + reinterpret_cast<__m128i>(b.raw), reinterpret_cast<__m128i>(a.raw)))}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{reinterpret_cast<__m128d>(__lsx_vilvl_d( + reinterpret_cast<__m128i>(b.raw), reinterpret_cast<__m128i>(a.raw)))}; +} + +// Generic for all vector lengths. +template +HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ BlendedStore + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + StoreU(IfThenElse(m, v, LoadU(d, p)), d, p); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vadd_b(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vadd_h(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vadd_w(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vadd_d(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vadd_b(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const 
Vec128 b) { + return Vec128{__lsx_vadd_h(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vadd_w(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vadd_d(a.raw, b.raw)}; +} + +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vfadd_s(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vfadd_d(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsub_b(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vsub_h(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsub_w(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsub_d(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsub_b(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsub_h(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsub_w(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsub_d(a.raw, b.raw)}; +} + +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vfsub_s(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vfsub_d(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf2 +namespace detail { + +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return 
VFromD>>{__lsx_vhaddw_h_b(v.raw, v.raw)}; +} +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + return VFromD>>{__lsx_vhaddw_hu_bu(v.raw, v.raw)}; +} +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>{__lsx_vhaddw_w_h(v.raw, v.raw)}; +} +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + return VFromD>>{__lsx_vhaddw_wu_hu(v.raw, v.raw)}; +} +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return VFromD>>{__lsx_vhaddw_d_w(v.raw, v.raw)}; +} +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + return VFromD>>{__lsx_vhaddw_du_wu(v.raw, v.raw)}; +} + +} // namespace detail + +// ------------------------------ SumsOf8 +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + __m128i temp = __lsx_vhaddw_hu_bu(v.raw, v.raw); + temp = __lsx_vhaddw_wu_hu(temp, temp); + return Vec128{__lsx_vhaddw_du_wu(temp, temp)}; +} +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + __m128i temp = __lsx_vhaddw_h_b(v.raw, v.raw); + temp = __lsx_vhaddw_w_h(temp, temp); + return Vec128{__lsx_vhaddw_d_w(temp, temp)}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. 
+ +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB +#undef HWY_NATIVE_U32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB +#undef HWY_NATIVE_U64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U64_SATURATED_ADDSUB +#endif + +// Unsigned +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsadd_bu(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsadd_hu(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsadd_wu(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsadd_du(a.raw, b.raw)}; +} + +// signed +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsadd_b(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsadd_h(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsadd_w(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vsadd_d(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. 
+ +// Unsigned +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vssub_bu(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vssub_hu(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vssub_wu(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vssub_du(a.raw, b.raw)}; +} + +// signed +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vssub_b(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vssub_h(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vssub_w(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vssub_d(a.raw, b.raw)}; +} + +// ------------------------------ AverageRound + +// Returns (a + b + 1) / 2 + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +// Unsigned +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vavgr_bu(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vavgr_hu(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vavgr_wu(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vavgr_du(a.raw, b.raw)}; +} + +// signed +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return 
Vec128{__lsx_vavgr_b(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vavgr_h(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vavgr_w(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vavgr_d(a.raw, b.raw)}; +} + +// ------------------------------ Integer/Float multiplication + +// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*. +#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +template +HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { + return Vec128{__lsx_vmul_b(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { + return Vec128{__lsx_vmul_h(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { + return Vec128{__lsx_vmul_w(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, const Vec128 b) { + return Vec128{__lsx_vmul_d(a.raw, b.raw)}; +} + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vfmul_s(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vfmul_d(a.raw, b.raw)}; +} + +// ------------------------------ MulHigh + +// Unsigned +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vmuh_bu(a.raw, b.raw)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vmuh_hu(a.raw, b.raw)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vmuh_wu(a.raw, b.raw)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vmuh_du(a.raw, b.raw)}; +} +
+// signed +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vmuh_b(a.raw, b.raw)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vmuh_h(a.raw, b.raw)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vmuh_w(a.raw, b.raw)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vmuh_d(a.raw, b.raw)}; +} + +// ------------------------------ MulEven + +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vmulwev_h_b(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vmulwev_h_bu(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vmulwev_w_h(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vmulwev_w_hu(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vmulwev_d_w(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulEven(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vmulwev_d_wu(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmulwev_q_d(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmulwev_q_du(a.raw, b.raw)}; +} + +// ------------------------------ MulOdd + +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vmulwod_h_b(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vmulwod_h_bu(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vmulwod_w_h(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vmulwod_w_hu(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + return 
Vec128{__lsx_vmulwod_d_w(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulOdd(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vmulwod_d_wu(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmulwod_q_d(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmulwod_q_du(a.raw, b.raw)}; +} + +// ------------------------------ RotateRight (ShiftRight, Or) + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + return Vec128{__lsx_vrotri_b(v.raw, kBits)}; +} +template +HWY_API Vec128 RotateRight(const Vec128 v) { + return Vec128{__lsx_vrotri_h(v.raw, kBits)}; +} +template +HWY_API Vec128 RotateRight(const Vec128 v) { + return Vec128{__lsx_vrotri_w(v.raw, kBits)}; +} +template +HWY_API Vec128 RotateRight(const Vec128 v) { + return Vec128{__lsx_vrotri_d(v.raw, kBits)}; +} + +// ------------------------------ Ror +#ifdef HWY_NATIVE_ROL_ROR_8 +#undef HWY_NATIVE_ROL_ROR_8 +#else +#define HWY_NATIVE_ROL_ROR_8 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_16 +#undef HWY_NATIVE_ROL_ROR_16 +#else +#define HWY_NATIVE_ROL_ROR_16 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_32_64 +#undef HWY_NATIVE_ROL_ROR_32_64 +#else +#define HWY_NATIVE_ROL_ROR_32_64 +#endif + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + return Vec128{__lsx_vrotr_b(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + return Vec128{__lsx_vrotr_h(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + return Vec128{__lsx_vrotr_w(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + return Vec128{__lsx_vrotr_d(a.raw, b.raw)}; +} + +// Rol is generic for all vector lengths +template +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + + return Ror(a, BitCast(d, Neg(BitCast(di, b)))); +} + +// ------------------------------ RotateLeftSame/RotateRightSame + +#ifdef HWY_NATIVE_ROL_ROR_SAME_8 +#undef HWY_NATIVE_ROL_ROR_SAME_8 +#else +#define 
HWY_NATIVE_ROL_ROR_SAME_8 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_SAME_16 +#undef HWY_NATIVE_ROL_ROR_SAME_16 +#else +#define HWY_NATIVE_ROL_ROR_SAME_16 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_SAME_32_64 +#undef HWY_NATIVE_ROL_ROR_SAME_32_64 +#else +#define HWY_NATIVE_ROL_ROR_SAME_32_64 +#endif + +// RotateLeftSame/RotateRightSame are generic for all vector lengths +template +HWY_API V RotateLeftSame(V v, int bits) { + using T = TFromV; + const DFromV d; + return Rol(v, Set(d, static_cast(bits))); +} + +template +HWY_API V RotateRightSame(V v, int bits) { + using T = TFromV; + const DFromV d; + return Ror(v, Set(d, static_cast(bits))); +} + +// ------------------------------ BroadcastSignBit + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight(v); +} + +// ------------------------------ Integer Abs + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{__lsx_vabsd_b(v.raw, __lsx_vreplgr2vr_b(0))}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{__lsx_vabsd_h(v.raw, __lsx_vreplgr2vr_b(0))}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{__lsx_vabsd_w(v.raw, __lsx_vreplgr2vr_b(0))}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{__lsx_vabsd_d(v.raw, __lsx_vreplgr2vr_b(0))}; +} + +// ------------------------------ SaturatedAbs + +#ifdef HWY_NATIVE_SATURATED_ABS +#undef HWY_NATIVE_SATURATED_ABS +#else +#define HWY_NATIVE_SATURATED_ABS +#endif + +template )> +HWY_API V SaturatedAbs(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Min(BitCast(du, v), BitCast(du, SaturatedSub(Zero(d), v)))); +} +template )> +HWY_API V SaturatedAbs(V v) { + return Max(v, SaturatedSub(Zero(DFromV()), v)); +} +template )> +HWY_API V SaturatedAbs(V v) { + const auto abs_v = Abs(v); + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Min(BitCast(du, abs_v), + Set(du, 
static_cast(LimitsMax())))); +} +template )> +HWY_API V SaturatedAbs(V v) { + const auto abs_v = Abs(v); + return Add(abs_v, BroadcastSignBit(abs_v)); +} + +// ------------------------------ IfNegativeThenElse +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + Mask128 m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); + return IfThenElse(m, yes, no); +} + +// ------------------------------ IfNegativeThenNegOrUndefIfZero + +#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#else +#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#endif + +template +HWY_API Vec128 IfNegativeThenNegOrUndefIfZero(Vec128 mask, + Vec128 v) { + return Vec128{__lsx_vsigncov_b(mask.raw, v.raw)}; +} + +template +HWY_API Vec128 IfNegativeThenNegOrUndefIfZero( + Vec128 mask, Vec128 v) { + return Vec128{__lsx_vsigncov_h(mask.raw, v.raw)}; +} + +template +HWY_API Vec128 IfNegativeThenNegOrUndefIfZero( + Vec128 mask, Vec128 v) { + return Vec128{__lsx_vsigncov_w(mask.raw, v.raw)}; +} + +template +HWY_API Vec128 IfNegativeThenNegOrUndefIfZero( + Vec128 mask, Vec128 v) { + return Vec128{__lsx_vsigncov_d(mask.raw, v.raw)}; +} + +// ------------------------------ ShiftLeftSame/ShiftRightSame + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, int bits) { + return v << Set(DFromV(), static_cast(bits)); +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, int bits) { + return v >> Set(DFromV(), static_cast(bits)); +} + +// ------------------------------ Integer/Float Div + +#ifdef HWY_NATIVE_INT_DIV +#undef HWY_NATIVE_INT_DIV +#else +#define HWY_NATIVE_INT_DIV +#endif + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m128i raw_result; + 
__asm__("vdiv.b %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m128i raw_result; + __asm__("vdiv.bu %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m128i raw_result; + __asm__("vdiv.h %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m128i raw_result; + __asm__("vdiv.hu %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m128i raw_result; + __asm__("vdiv.w %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m128i raw_result; + __asm__("vdiv.wu %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m128i raw_result; + __asm__("vdiv.d %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 
operator/(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m128i raw_result; + __asm__("vdiv.du %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vfdiv_s(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{__lsx_vfdiv_d(a.raw, b.raw)}; +} + +// ------------------------------ Integer Mod + +template +HWY_API Vec128 operator%(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m128i raw_result; + __asm__("vmod.b %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m128i raw_result; + __asm__("vmod.bu %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m128i raw_result; + __asm__("vmod.h %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m128i raw_result; + __asm__("vmod.hu %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m128i 
raw_result; + __asm__("vmod.w %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m128i raw_result; + __asm__("vmod.wu %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + // or a[i] == LimitsMin() && b[i] == -1 + __m128i raw_result; + __asm__("vmod.d %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(const Vec128 a, + const Vec128 b) { + // Use inline assembly to avoid undefined behavior if any lanes of b are zero + __m128i raw_result; + __asm__("vmod.du %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :); + return Vec128{raw_result}; +} + +// ------------------------------ ApproximateReciprocal + +#ifdef HWY_NATIVE_F64_APPROX_RECIP +#undef HWY_NATIVE_F64_APPROX_RECIP +#else +#define HWY_NATIVE_F64_APPROX_RECIP +#endif + +template +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128{__lsx_vfrecip_s(v.raw)}; +} +template +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128{__lsx_vfrecip_d(v.raw)}; +} + +// ------------------------------ Absolute value of difference + +#ifdef HWY_NATIVE_INTEGER_ABS_DIFF +#undef HWY_NATIVE_INTEGER_ABS_DIFF +#else +#define HWY_NATIVE_INTEGER_ABS_DIFF +#endif + +template +HWY_API Vec128 AbsDiff(const Vec128 a, + Vec128 b) { + return Vec128{__lsx_vabsd_b(a.raw, b.raw)}; +} +template +HWY_API Vec128 AbsDiff(const Vec128 a, + Vec128 b) { + return Vec128{__lsx_vabsd_h(a.raw, b.raw)}; +} +template +HWY_API Vec128 AbsDiff(const Vec128 a, + Vec128 b) { + return Vec128{__lsx_vabsd_w(a.raw, b.raw)}; +} +template +HWY_API Vec128 
AbsDiff(const Vec128 a, + Vec128 b) { + return Vec128{__lsx_vabsd_d(a.raw, b.raw)}; +} + +template +HWY_API Vec128 AbsDiff(const Vec128 a, + Vec128 b) { + return Vec128{__lsx_vabsd_bu(a.raw, b.raw)}; +} +template +HWY_API Vec128 AbsDiff(const Vec128 a, + Vec128 b) { + return Vec128{__lsx_vabsd_hu(a.raw, b.raw)}; +} +template +HWY_API Vec128 AbsDiff(const Vec128 a, + Vec128 b) { + return Vec128{__lsx_vabsd_wu(a.raw, b.raw)}; +} +template +HWY_API Vec128 AbsDiff(const Vec128 a, + Vec128 b) { + return Vec128{__lsx_vabsd_du(a.raw, b.raw)}; +} + +// Generic for all vector lengths. +template +HWY_API V AbsDiff(V a, V b) { + return Abs(a - b); +} + +// ------------------------------ Integer/Float multiply-add + +#ifdef HWY_NATIVE_INT_FMA +#undef HWY_NATIVE_INT_FMA +#else +#define HWY_NATIVE_INT_FMA +#endif + +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return Vec128{__lsx_vmadd_b(add.raw, mul.raw, x.raw)}; +} +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return Vec128{__lsx_vmadd_h(add.raw, mul.raw, x.raw)}; +} +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return Vec128{__lsx_vmadd_w(add.raw, mul.raw, x.raw)}; +} +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return Vec128{__lsx_vmadd_d(add.raw, mul.raw, x.raw)}; +} + +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return Vec128{__lsx_vfmadd_s(mul.raw, x.raw, add.raw)}; +} +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return Vec128{__lsx_vfmadd_d(mul.raw, x.raw, add.raw)}; +} + +// Unsigned +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return mul * x + add; +} + +// ------------------------------ Integer/Float NegMulAdd + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return Vec128{__lsx_vmsub_b(add.raw, mul.raw, x.raw)}; +} +template +HWY_API Vec128 NegMulAdd(Vec128 mul, + Vec128 x, + Vec128 add) { + return
Vec128{__lsx_vmsub_h(add.raw, mul.raw, x.raw)}; +} +template +HWY_API Vec128 NegMulAdd(Vec128 mul, + Vec128 x, + Vec128 sub) { + return Vec128{__lsx_vmsub_w(sub.raw, mul.raw, x.raw)}; +} +template +HWY_API Vec128 NegMulAdd(Vec128 mul, + Vec128 x, + Vec128 sub) { + return Vec128{__lsx_vmsub_d(sub.raw, mul.raw, x.raw)}; +} + +// Float/unsigned +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return add - mul * x; +} + +// ------------------------------ Float MulSub + +// float +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return Vec128{__lsx_vfmsub_s(x.raw, mul.raw, sub.raw)}; +} +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return Vec128{__lsx_vfmsub_d(x.raw, mul.raw, sub.raw)}; +} + +// unsigned +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return mul * x - sub; +} + +// ------------------------------ Float NegMulSub + +// float/unsigned +template +HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +template +HWY_API Vec128 Sqrt(Vec128 v) { + return Vec128{__lsx_vfsqrt_s(v.raw)}; +} +template +HWY_API Vec128 Sqrt(Vec128 v) { + return Vec128{__lsx_vfsqrt_d(v.raw)}; +} + +// ------------------------------ ApproximateReciprocalSqrt +#ifdef HWY_NATIVE_F64_APPROX_RSQRT +#undef HWY_NATIVE_F64_APPROX_RSQRT +#else +#define HWY_NATIVE_F64_APPROX_RSQRT +#endif + +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + return Vec128{__lsx_vfrsqrt_s(v.raw)}; +} +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + return Vec128{__lsx_vfrsqrt_d(v.raw)}; +} + +// ------------------------------ Min + +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmin_bu(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmin_hu(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 
b) { + return Vec128{__lsx_vmin_wu(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmin_du(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmin_b(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmin_h(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmin_w(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmin_d(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{__lsx_vfmin_s(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{__lsx_vfmin_d(a.raw, b.raw)}; +} + +// ------------------------------ Max + +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmax_bu(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmax_hu(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmax_wu(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmax_du(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmax_b(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmax_h(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmax_w(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{__lsx_vmax_d(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{__lsx_vfmax_s(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{__lsx_vfmax_d(a.raw, b.raw)}; +} + +// ------------------------------ MinMagnitude and MaxMagnitude + +#ifdef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE +#undef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE +#else +#define 
HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE +#endif + +template +HWY_API Vec128 MinMagnitude(Vec128 a, Vec128 b) { + return Vec128{__lsx_vfmina_s(a.raw, b.raw)}; +} +template +HWY_API Vec128 MinMagnitude(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vfmina_d(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaxMagnitude(Vec128 a, Vec128 b) { + return Vec128{__lsx_vfmaxa_s(a.raw, b.raw)}; +} +template +HWY_API Vec128 MaxMagnitude(Vec128 a, + Vec128 b) { + return Vec128{__lsx_vfmaxa_d(a.raw, b.raw)}; +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. + +template +HWY_API void Stream(const VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + __builtin_prefetch(aligned, 1, 0); + Store(v, d, aligned); +} + +// ------------------------------ Scatter in generic_ops-inl.h +// ------------------------------ Gather in generic_ops-inl.h + +// ================================================== SWIZZLE (2) + +// ------------------------------ LowerHalf + +template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return VFromD{v.raw}; +} +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return Vec128{v.raw}; +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API VFromD ShiftLeftBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + if (kBytes == 0) return v; + const RebindToUnsigned du; + return BitCast( + d, VFromD{__lsx_vbsll_v(BitCast(du, v).raw, kBytes)}); +} + +// Generic for all vector lengths. +template +HWY_API V ShiftLeftBytes(const V v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +// Generic for all vector lengths. +template +HWY_API VFromD ShiftLeftLanes(D d, const VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes)>(BitCast(d8, v))); +} + +// Generic for all vector lengths. 
+template +HWY_API V ShiftLeftLanes(const V v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API VFromD ShiftRightBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + if (kBytes == 0) return v; + const RebindToUnsigned du; + // For partial vectors, clear upper lanes so we shift in zeros. + if (d.MaxBytes() != 16) { + const Full128> dfull; + const VFromD vfull{v.raw}; + v = VFromD{IfThenElseZero(FirstN(dfull, MaxLanes(d)), vfull).raw}; + } + return BitCast( + d, VFromD{__lsx_vbsrl_v(BitCast(du, v).raw, kBytes)}); +} + +// ------------------------------ ShiftRightLanes +// Generic for all vector lengths. +template +HWY_API VFromD ShiftRightLanes(D d, const VFromD v) { + const Repartition d8; + constexpr size_t kBytes = kLanes * sizeof(TFromD); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + const Twice> dut; + using VUT = VFromD; // for float16_t + const VUT vut = BitCast(dut, v); + return BitCast(d, LowerHalf(VUT{__lsx_vilvh_d(vut.raw, vut.raw)})); +} + +// Partial +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + return LowerHalf(d, ShiftRightBytes(Twice(), v)); +} + +// ------------------------------ ExtractLane (UpperHalf) + +namespace detail { + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); + return static_cast(__lsx_vpickve2gr_b(v.raw, kLane) & 0xFF); +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); + const DFromV d; + const RebindToUnsigned du; + const uint16_t lane = static_cast( + __lsx_vpickve2gr_hu(BitCast(du, v).raw, kLane) & 0xFFFF); + return BitCastScalar(lane); +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); + return 
static_cast(__lsx_vpickve2gr_w(v.raw, kLane)); +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); + return static_cast(__lsx_vpickve2gr_d(v.raw, kLane)); +} + +template +HWY_INLINE float ExtractLane(const Vec128 v) { + float f32; + int32_t i32 = __lsx_vpickve2gr_w(reinterpret_cast<__m128i>(v.raw), kLane); + CopyBytes<4>(&i32, &f32); + return f32; +} +template +HWY_INLINE double ExtractLane(const Vec128 v) { + double f64; + int64_t i64 = __lsx_vpickve2gr_d(reinterpret_cast<__m128i>(v.raw), kLane); + CopyBytes<8>(&i64, &f64); + return f64; +} + +} // namespace detail + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return GetLane(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return 
detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + case 8: + return detail::ExtractLane<8>(v); + case 9: + return detail::ExtractLane<9>(v); + case 10: + return detail::ExtractLane<10>(v); + case 11: + return detail::ExtractLane<11>(v); + case 12: + return detail::ExtractLane<12>(v); + case 13: + return detail::ExtractLane<13>(v); + case 14: + return detail::ExtractLane<14>(v); + case 15: + return detail::ExtractLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (UpperHalf) + +namespace detail { + +template +HWY_INLINE V InsertLaneUsingBroadcastAndBlend(V v, size_t i, TFromV t) { + const DFromV d; + +#if HWY_TARGET <= HWY_AVX3 + using RawMask = decltype(MaskFromVec(VFromD()).raw); + const auto mask = MFromD{static_cast(uint64_t{1} << i)}; +#else + const RebindToUnsigned du; + using TU = TFromD; + const auto mask = RebindMask(d, Iota(du, 0) == Set(du, static_cast(i))); +#endif + + return IfThenElse(mask, Set(d, t), v); +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{__lsx_vinsgr2vr_b(v.raw, t, kLane)}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + 
static_assert(kLane < N, "Lane index out of bounds"); + const DFromV d; + const RebindToUnsigned du; + const uint16_t bits = BitCastScalar(t); + return BitCast(d, VFromD{ + __lsx_vinsgr2vr_h(BitCast(du, v).raw, bits, kLane)}); +} +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{__lsx_vinsgr2vr_w(v.raw, t, kLane)}; +} +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{__lsx_vinsgr2vr_d(v.raw, t, kLane)}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, float t) { + static_assert(kLane < N, "Lane index out of bounds"); + const DFromV d; + int ti = BitCastScalar(t); + RebindToUnsigned du; + return BitCast(d, VFromD{__lsx_vinsgr2vr_w( + reinterpret_cast<__m128i>(v.raw), ti, kLane)}); +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane < 2, "Lane index out of bounds"); + const DFromV d; + long int ti = BitCastScalar(t); + RebindToUnsigned du; + return BitCast(d, VFromD{__lsx_vinsgr2vr_d( + reinterpret_cast<__m128i>(v.raw), ti, kLane)}); +} + +} // namespace detail + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return 
detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +// ------------------------------ CombineShiftRightBytes +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + static_assert(0 
< kBytes && kBytes < 16, "kBytes invalid"); + return Or(ShiftRightBytes(d, lo), ShiftLeftBytes<16 - kBytes>(d, hi)); +} +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + constexpr size_t kSize = d.MaxBytes(); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + + const Twice dt; + return VFromD{ShiftRightBytes(dt, Combine(dt, hi, lo)).raw}; +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{__lsx_vreplvei_b(v.raw, kLane)}; +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{__lsx_vreplvei_h(v.raw, kLane)}; +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + const DFromV d; + return BitCast(d, Vec128{__lsx_vreplvei_w( + reinterpret_cast<__m128i>(v.raw), kLane)}); +} +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + const DFromV d; + return BitCast(d, Vec128{__lsx_vreplvei_d( + reinterpret_cast<__m128i>(v.raw), kLane)}); +} + +// ------------------------------ TableLookupLanes (Shuffle01) + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. 
+template +struct Indices128 { + __m128i raw; +}; + +namespace detail { + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + return Iota(d8, 0); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + return Zero(d8); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}; + return Load(d8, kByteOffsets); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + return Load(d8, kByteOffsets); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7}; + return Load(d8, kByteOffsets); +} + +} // namespace detail + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match 
lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + (void)d; + return Indices128, MaxLanes(D())>{BitCast(d, vec).raw}; +} + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + const Repartition d8; + using V8 = VFromD; + + // Broadcast each lane index to all bytes of T and shift to bytes + const V8 lane_indices = TableLookupBytes( + BitCast(d8, vec), detail::IndicesFromVecBroadcastLaneBytes(d)); + constexpr int kIndexShiftAmt = static_cast(FloorLog2(sizeof(T))); + const V8 byte_indices = ShiftLeft(lane_indices); + const V8 sum = Add(byte_indices, detail::IndicesFromVecByteOffsets(d)); + return Indices128, MaxLanes(D())>{sum.raw}; +} + +template +HWY_API Indices128, MaxLanes(D())> SetTableIndices(D d, + const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + using TI = MakeSigned; + const DFromV d; + const Rebind di; + auto t1 = TableLookupBytes(BitCast(di, v), Vec128{idx.raw}); + return BitCast(d, t1); +} + +// Single lane: no change +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 /* idx */) { + return v; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return v; +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301) + +// Single lane: no change +template +HWY_API VFromD Reverse(D /* tag */, VFromD v) { + return v; +} +// 32-bit x2: shuffle +template +HWY_API VFromD Reverse(D /* tag */, const VFromD 
v) { + return VFromD{Shuffle2301(Vec128>{v.raw}).raw}; +} +// 64-bit x2: shuffle +template +HWY_API VFromD Reverse(D /* tag */, const VFromD v) { + return Shuffle01(v); +} +// 32-bit x4: shuffle +template +HWY_API VFromD Reverse(D /* tag */, const VFromD v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API VFromD Reverse(D d, const VFromD v) { + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); + constexpr size_t kN = MaxLanes(d); + if (kN == 1) return v; + if (kN == 2) { + return BitCast(d, VU{__lsx_vshuf4i_h(vu.raw, 0x11)}); + } + if (kN == 4) { + return BitCast(d, VU{__lsx_vshuf4i_h(vu.raw, 0x1B)}); + } + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + return BitCast(d, TableLookupBytes(v, shuffle)); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + static constexpr int kN = static_cast(MaxLanes(d)); + if (kN == 1) return v; + alignas(16) static constexpr int8_t _tmp_data[] = { + kN - 1, kN - 2, kN - 3, kN - 4, kN - 5, kN - 6, kN - 7, kN - 8, + kN - 9, kN - 10, kN - 11, kN - 12, kN - 13, kN - 14, kN - 15, kN - 16}; + return VFromD{__lsx_vshuf_b(v.raw, v.raw, __lsx_vld(_tmp_data, 0))}; +} + +// ------------------------------ Reverse2 + +// Single lane: no change +template +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return v; +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RepartitionToWide> dw; + return BitCast(d, RotateRight<16>(BitCast(dw, v))); +} + +// Generic for all vector lengths. +template +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return Shuffle2301(v); +} + +// Generic for all vector lengths. 
+template +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API VFromD Reverse4(D /* tag */, VFromD v) { + return VFromD{__lsx_vshuf4i_h(v.raw, 0x1B)}; +} + +// Generic for all vector lengths. +template +HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { + return Shuffle0123(v); +} + +template +HWY_API VFromD Reverse4(D /* tag */, VFromD /* v */) { + HWY_ASSERT(0); // don't have 4 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); +} + +template +HWY_API VFromD Reverse8(D /* tag */, VFromD /* v */) { + HWY_ASSERT(0); // don't have 8 lanes if larger than 16-bit +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// Full +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{__lsx_vilvh_b(b.raw, a.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{__lsx_vilvh_h(b.raw, a.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const RebindToSigned df; + return BitCast(d, VFromD{ + __lsx_vilvh_w(BitCast(df, b).raw, BitCast(df, a).raw)}); +} +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const RebindToSigned dd; + return BitCast(d, VFromD{ + __lsx_vilvh_d(BitCast(dd, b).raw, BitCast(dd, a).raw)}); +} + +// Partial +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const Half d2; + return InterleaveLower(d, VFromD{UpperHalf(d2, a).raw}, + VFromD{UpperHalf(d2, b).raw}); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. 
+template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== CONVERT (1) + +// ------------------------------ PromoteTo unsigned +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vsllwil_hu_bu(v.raw, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vsllwil_wu_hu(v.raw, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vsllwil_du_wu(v.raw, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + const __m128i u16 = __lsx_vsllwil_hu_bu(v.raw, 0); + return VFromD{__lsx_vsllwil_wu_hu(u16, 0)}; +} +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + const Rebind du32; + return PromoteTo(d, PromoteTo(du32, v)); +} +template +HWY_API VFromD PromoteTo(D /*tag*/, VFromD> v) { + const __m128i u32 = __lsx_vsllwil_wu_hu(v.raw, 0); + return VFromD{__lsx_vsllwil_du_wu(u32, 0)}; +} + +// Unsigned to signed: same plus cast. 
+template ), sizeof(TFromV)), + HWY_IF_LANES_D(D, HWY_MAX_LANES_V(V))> +HWY_API VFromD PromoteTo(D di, V v) { + const RebindToUnsigned du; + return BitCast(di, PromoteTo(du, v)); +} + +// signed +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vsllwil_h_b(v.raw, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vsllwil_w_h(v.raw, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vsllwil_d_w(v.raw, 0)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + const __m128i i16 = __lsx_vsllwil_h_b(v.raw, 0); + return VFromD{__lsx_vsllwil_w_h(i16, 0)}; +} +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + const Rebind di32; + return PromoteTo(d, PromoteTo(di32, v)); +} +template +HWY_API VFromD PromoteTo(D /*tag*/, VFromD> v) { + const __m128i i32 = __lsx_vsllwil_w_h(v.raw, 0); + return VFromD{__lsx_vsllwil_d_w(i32, 0)}; +} + +// -------------------- PromoteTo float + +#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vfcvtl_s_h(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vfcvtl_d_s(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vffintl_d_w(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D df64, VFromD> v) { + const Rebind di32; + const auto i32_to_f64_result = PromoteTo(df64, BitCast(di32, v)); + return i32_to_f64_result + IfNegativeThenElse(i32_to_f64_result, + Set(df64, 4294967296.0), + Zero(df64)); +} + +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + const RebindToSigned di32; + const Rebind du16; + return BitCast(d, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ Per4LaneBlockShuffle + +namespace detail { + +#ifdef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#undef 
HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#else +#define HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#endif + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { + typedef uint32_t GccU32RawVectType __attribute__((__vector_size__(16))); + const GccU32RawVectType raw = {x0, x1, x2, x3}; + return ResizeBitCast(d, Vec128{reinterpret_cast<__m128i>(raw)}); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + constexpr int kShuffle = static_cast(kIdx3210 & 0xFF); + return V{__lsx_vshuf4i_b(v.raw, kShuffle)}; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + constexpr int kShuffle = static_cast(kIdx3210 & 0xFF); + return BitCast( + d, VFromD{__lsx_vshuf4i_h(BitCast(du, v).raw, kShuffle)}); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<16> /*vect_size_tag*/, V v) { + const DFromV d; + constexpr int kShuffle = static_cast(kIdx3210 & 0xFF); + const RebindToUnsigned du; + return BitCast(d, VFromD{__lsx_vshuf4i_w( + reinterpret_cast<__m128i>(v.raw), kShuffle)}); +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +namespace detail { + +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Full64 du64; + const auto vu64 = ResizeBitCast(du64, v); + return ResizeBitCast( + d, ShiftLeftSame(vu64, static_cast(amt * sizeof(TFromV) * 8))); +} + +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Repartition du8; + const auto idx = + Iota(du8, static_cast(size_t{0} - amt * sizeof(TFromV))); + return BitCast(d, TableLookupBytesOr0(BitCast(du8, v), idx)); +} + +} 
// namespace detail + +template +HWY_API VFromD SlideUpLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + case 8: + return 
ShiftLeftLanes<8>(d, v); + case 9: + return ShiftLeftLanes<9>(d, v); + case 10: + return ShiftLeftLanes<10>(d, v); + case 11: + return ShiftLeftLanes<11>(d, v); + case 12: + return ShiftLeftLanes<12>(d, v); + case 13: + return ShiftLeftLanes<13>(d, v); + case 14: + return ShiftLeftLanes<14>(d, v); + case 15: + return ShiftLeftLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +// ------------------------------ SlideDownLanes + +namespace detail { + +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition, decltype(d)> dv; + return BitCast(d, + ShiftRightSame(BitCast(dv, v), + static_cast(amt * sizeof(TFromV) * 8))); +} + +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition di8; + auto idx = Iota(di8, static_cast(amt * sizeof(TFromV))); + idx = Or(idx, VecFromMask(di8, idx > Set(di8, int8_t{15}))); + return BitCast(d, TableLookupBytesOr0(BitCast(di8, v), idx)); +} + +} // namespace detail + +template +HWY_API VFromD SlideDownLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC 
// includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + case 8: + return ShiftRightLanes<8>(d, v); + case 9: + return ShiftRightLanes<9>(d, v); + case 10: + return ShiftRightLanes<10>(d, v); + case 11: + return ShiftRightLanes<11>(d, v); + case 12: + return ShiftRightLanes<12>(d, v); + case 13: + return ShiftRightLanes<13>(d, v); + case 14: + return ShiftRightLanes<14>(d, v); + case 15: + return ShiftRightLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template >> +HWY_API VFromD Combine(D d, VH hi_half, VH lo_half) { + const Half dh; + const RebindToUnsigned duh; + // Treat half-width input as one lane, and expand to two lanes. 
+ using VU = Vec128, 2>; + const VU lo{BitCast(duh, lo_half).raw}; + const VU hi{BitCast(duh, hi_half).raw}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (Combine) +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + return Combine(d, Zero(Half()), lo); +} + +// ------------------------------ Concat full (InterleaveLower) + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const Repartition d64; + return BitCast(d, InterleaveLower(BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const Repartition d64; + return BitCast(d, InterleaveUpper(d64, BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves) +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + return BitCast(d, Vec128{__lsx_vshuf4i_d( + reinterpret_cast<__m128i>(lo.raw), + reinterpret_cast<__m128i>(hi.raw), 0xC)}); +} + +// ------------------------------ Concat partial (Combine, LowerHalf) + +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatLowerUpper(D d, const VFromD hi, + const VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), LowerHalf(d2, lo)); +} + +// 
------------------------------ ConcatOdd + +// 8-bit full +template +HWY_API VFromD ConcatOdd(D /* tag */, VFromD hi, VFromD lo) { + return VFromD{__lsx_vpickod_b(hi.raw, lo.raw)}; +} +// 8-bit x8 +template +HWY_API VFromD ConcatOdd(D /* tag */, VFromD hi, VFromD lo) { + __m128i _tmp = __lsx_vpickod_b(hi.raw, lo.raw); + return VFromD{__lsx_vextrins_w(_tmp, _tmp, 0x12)}; +} +// 8-bit x4 +template +HWY_API VFromD ConcatOdd(D /* tag */, VFromD hi, VFromD lo) { + __m128i _tmp = __lsx_vpickod_b(hi.raw, lo.raw); + return VFromD{__lsx_vextrins_h(_tmp, _tmp, 0x14)}; +} + +// 16-bit full +template +HWY_API VFromD ConcatOdd(D /* tag */, VFromD hi, VFromD lo) { + return VFromD{__lsx_vpickod_h(hi.raw, lo.raw)}; +} +// 16-bit x4 +template +HWY_API VFromD ConcatOdd(D /* tag */, VFromD hi, VFromD lo) { + __m128i _tmp = __lsx_vpickod_h(hi.raw, lo.raw); + return VFromD{__lsx_vextrins_w(_tmp, _tmp, 0x12)}; +} + +// 32-bit full +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + return BitCast( + d, Vec128{__lsx_vpickod_w(reinterpret_cast<__m128i>(hi.raw), + reinterpret_cast<__m128i>(lo.raw))}); +} + +// Any T x2 +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven + +// 8-bit full +template +HWY_API VFromD ConcatEven(D /* tag */, VFromD hi, VFromD lo) { + return VFromD{__lsx_vpickev_b(hi.raw, lo.raw)}; +} +// 8-bit x8 +template +HWY_API VFromD ConcatEven(D /* tag */, VFromD hi, VFromD lo) { + __m128i _tmp = __lsx_vpickev_b(hi.raw, lo.raw); + return VFromD{__lsx_vextrins_w(_tmp, _tmp, 0x12)}; +} +// 8-bit x4 +template +HWY_API VFromD ConcatEven(D /* tag */, VFromD hi, VFromD lo) { + __m128i _tmp = __lsx_vpickev_b(hi.raw, lo.raw); + return VFromD{__lsx_vextrins_h(_tmp, _tmp, 0x14)}; +} + +// 16-bit full +template +HWY_API VFromD ConcatEven(D /* tag */, VFromD hi, VFromD lo) { + return VFromD{__lsx_vpickev_h(hi.raw, lo.raw)}; +} +// 16-bit x4 +template +HWY_API VFromD 
ConcatEven(D /* tag */, VFromD hi, VFromD lo) { + __m128i _tmp = __lsx_vpickev_h(hi.raw, lo.raw); + return VFromD{__lsx_vextrins_w(_tmp, _tmp, 0x12)}; +} + +// 32-bit full +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + return BitCast( + d, Vec128{__lsx_vpickev_w(reinterpret_cast<__m128i>(hi.raw), + reinterpret_cast<__m128i>(lo.raw))}); +} + +// Any T x2 +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + return InterleaveLower(d, lo, hi); +} + +template +HWY_INLINE Vec128 ConcatEven(Vec128 hi, + Vec128 lo) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, ConcatEven(BitCast(du, hi), BitCast(du, lo))); +} +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return v; +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + __m128i _tmp = __lsx_vpickev_b(v.raw, v.raw); + return Vec128{__lsx_vilvl_b(_tmp, _tmp)}; +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + __m128i _tmp = __lsx_vpickev_h(BitCast(du, v).raw, BitCast(du, v).raw); + return BitCast(d, VFromD{__lsx_vilvl_h(_tmp, _tmp)}); +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + const DFromV d; + __m128i _tmp = detail::BitCastToInteger(v.raw); + __m128i _tmp1 = __lsx_vpickev_w(_tmp, _tmp); + return BitCast(d, Vec128{__lsx_vilvl_w(_tmp1, _tmp1)}); +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return v; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + __m128i _tmp = __lsx_vpickod_b(v.raw, v.raw); + return Vec128{__lsx_vilvl_b(_tmp, _tmp)}; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + __m128i _tmp = __lsx_vpickod_h(v.raw, v.raw); + return Vec128{__lsx_vilvl_h(_tmp, _tmp)}; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + 
const DFromV d; + __m128i _tmp = detail::BitCastToInteger(v.raw); + __m128i _tmp1 = __lsx_vpickod_w(_tmp, _tmp); + return BitCast(d, Vec128{__lsx_vilvl_w(_tmp1, _tmp1)}); +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ TwoTablesLookupLanes (DupEven) + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + const Twice dt; + const Repartition dt_u8; +// TableLookupLanes currently requires table and index vectors to be the same +// size, though a half-length index vector would be sufficient here. +#if HWY_IS_MSAN + const Vec128 idx_vec{idx.raw}; + const Indices128 idx2{Combine(dt, idx_vec, idx_vec).raw}; +#else + // We only keep LowerHalf of the result, which is valid in idx. + const Indices128 idx2{idx.raw}; +#endif + return LowerHalf( + d, TableLookupBytes(Combine(dt, b, a), + BitCast(dt, VFromD{idx2.raw}))); +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + return Vec128{__lsx_vshuf_b(b.raw, a.raw, idx.raw)}; +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + const Repartition du8; + return BitCast(d, TwoTablesLookupLanes(BitCast(du8, a), BitCast(du8, b), + Indices128{idx.raw})); +} + +// ------------------------------ OddEven + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + __m128i t0 = __lsx_vpackod_b(a.raw, a.raw); + return Vec128{__lsx_vpackev_b(t0, b.raw)}; +} +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + __m128i t0 = __lsx_vpackod_h(a.raw, a.raw); + return Vec128{__lsx_vpackev_h(t0, b.raw)}; +} +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + __m128i t0 = __lsx_vpackod_w(BitCast(du, a).raw, BitCast(du, a).raw); + return BitCast(d, + VFromD{__lsx_vpackev_w(t0, BitCast(du, b).raw)}); +} +template 
+HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{__lsx_vextrins_d( + BitCast(du, b).raw, BitCast(du, a).raw, 0x11)}); +} + +// -------------------------- InterleaveEven + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{__lsx_vpackev_b(b.raw, a.raw)}; +} +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{__lsx_vpackev_h(b.raw, a.raw)}; +} +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const RebindToSigned di; + return BitCast(d, VFromD{__lsx_vpackev_w(BitCast(di, b).raw, + BitCast(di, a).raw)}); +} + +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const RebindToSigned di; + return BitCast(d, VFromD{__lsx_vpackev_d(BitCast(di, b).raw, + BitCast(di, a).raw)}); +} + +// -------------------------- InterleaveOdd + +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{__lsx_vpackod_b(b.raw, a.raw)}; +} +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{__lsx_vpackod_h(b.raw, a.raw)}; +} +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const RebindToSigned di; + return BitCast(d, VFromD{__lsx_vpackod_w(BitCast(di, b).raw, + BitCast(di, a).raw)}); +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const RebindToSigned di; + return BitCast(d, VFromD{__lsx_vpackod_d(BitCast(di, b).raw, + BitCast(di, a).raw)}); +} + +// ------------------------------ OddEvenBlocks + +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ InterleaveEvenBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) { + return a; +} +// 
------------------------------ InterleaveOddBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ InterleaveLowerBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveUpperBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ Shl + +template +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128{__lsx_vsll_b(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128{__lsx_vsll_h(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128{__lsx_vsll_w(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return Vec128{__lsx_vsll_d(v.raw, bits.raw)}; +} + +// ------------------------------ Shr + +namespace detail { + +template +HWY_API Vec128 Shr(Vec128 v, Vec128 bits) { + return Vec128{__lsx_vsrl_b(v.raw, bits.raw)}; +} +template +HWY_API Vec128 Shr(Vec128 v, + Vec128 bits) { + return Vec128{__lsx_vsrl_h(v.raw, bits.raw)}; +} +template +HWY_API Vec128 Shr(Vec128 v, + Vec128 bits) { + return Vec128{__lsx_vsrl_w(v.raw, bits.raw)}; +} +template +HWY_API Vec128 Shr(Vec128 v, + Vec128 bits) { + return Vec128{__lsx_vsrl_d(v.raw, bits.raw)}; +} + +template +HWY_API Vec128 Shr(Vec128 v, Vec128 bits) { + return Vec128{__lsx_vsra_b(v.raw, bits.raw)}; +} +template +HWY_API Vec128 Shr(Vec128 v, Vec128 bits) { + return Vec128{__lsx_vsra_h(v.raw, bits.raw)}; +} +template +HWY_API Vec128 Shr(Vec128 v, Vec128 bits) { + return Vec128{__lsx_vsra_w(v.raw, bits.raw)}; +} +template +HWY_API Vec128 Shr(Vec128 v, Vec128 bits) { + return Vec128{__lsx_vsra_d(v.raw, bits.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + 
return detail::Shr(v, bits); +} + +// ================================================== CONVERT (2) + +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "hwy/ops/inside-inl.h" + +// Generic for all vector lengths. +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); +} + +template >> +HWY_API VFromD WidenMulPairwiseAdd(D32 /* tag */, V16 a, V16 b) { + __m128i _tmp = __lsx_vmulwev_w_h(a.raw, b.raw); + return VFromD{__lsx_vmaddwod_w_h(_tmp, a.raw, b.raw)}; +} + +template >> +HWY_API VFromD WidenMulPairwiseAdd(DU32 /* tag */, VU16 a, VU16 b) { + __m128i _tmp = __lsx_vmulwev_w_hu(a.raw, b.raw); + return VFromD{__lsx_vmaddwod_w_hu(_tmp, a.raw, b.raw)}; +} + +// ------------------------------ ReorderWidenMulAccumulate + +template >> +HWY_API VFromD ReorderWidenMulAccumulate(D32 /* tag */, V16 a, V16 b, + const VFromD sum0, + VFromD& /* sum1 */) { + return VFromD{__lsx_vmaddwev_w_h( + __lsx_vmaddwod_w_h(sum0.raw, a.raw, b.raw), a.raw, b.raw)}; +} + +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DU32 /* tag */, VU16 a, VU16 b, + const VFromD sum0, + VFromD& /* sum1 */) { + return VFromD{__lsx_vmaddwev_w_hu( + __lsx_vmaddwod_w_hu(sum0.raw, a.raw, b.raw), a.raw, b.raw)}; +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec128 RearrangeToOddPlusEven(const Vec128 sum0, + Vec128 /*sum1*/) { + return sum0; // invariant already holds +} + +template +HWY_API Vec128 RearrangeToOddPlusEven( + const Vec128 sum0, Vec128 /*sum1*/) { + return sum0; // invariant already holds +} + +// ------------------------------ Demotions + +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the V template +// 
argument +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) \ + hwy::EnableIf()>* = nullptr + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vssrani_b_h(v.raw, v.raw, 0)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vssrani_bu_h(v.raw, v.raw, 0)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vssrlni_b_h(v.raw, v.raw, 0)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vssrlni_bu_h(v.raw, v.raw, 0)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vssrani_h_w(v.raw, v.raw, 0)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vssrani_hu_w(v.raw, v.raw, 0)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vssrlni_h_w(v.raw, v.raw, 0)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vssrlni_hu_w(v.raw, v.raw, 0)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vssrani_w_d(v.raw, v.raw, 0)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vssrani_wu_d(v.raw, v.raw, 0)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vssrlni_w_d(v.raw, v.raw, 0)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vssrlni_wu_d(v.raw, v.raw, 0)}; +} + +// UI->UI DemoteTo for the case where +// sizeof(TFromD) <= sizeof(TFromV) / 4 is generic for all vector lengths +template ) / 4)> +HWY_API VFromD DemoteTo(DN dn, V v) { + using T = TFromV; + using TN = TFromD; + + using TDemoteTo = + MakeNarrow() && IsSigned(), T, MakeUnsigned>>; + return DemoteTo(dn, DemoteTo(Rebind(), v)); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return 
VFromD{__lsx_vfcvt_h_s(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vfcvt_s_d(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vftintrz_w_d( + reinterpret_cast<__m128d>(__lsx_vreplgr2vr_w(0)), v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D du32, VFromD> v) { + const Rebind du64; + return DemoteTo(du32, ConvertTo(du64, v)); +} + +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { + const Rebind df64; + const RebindToUnsigned du64; + const RebindToSigned di32; + const RebindToUnsigned du32; + + const auto k2p64_63 = Set(df64, 27670116110564327424.0); + const auto f64_hi52 = + Xor(BitCast(df64, ShiftRight<12>(BitCast(du64, v))), k2p64_63) - k2p64_63; + const auto f64_lo12 = + PromoteTo(df64, BitCast(di32, And(TruncateTo(du32, BitCast(du64, v)), + Set(du32, uint32_t{0x00000FFF})))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + + const auto f64_sum_is_inexact = + ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); + const auto f64_bits_decrement = + And(ShiftRight<63>(BitCast(du64, Xor(f64_sum, f64_carry))), + f64_sum_is_inexact); + + const auto adj_f64_val = BitCast( + df64, + Or(BitCast(du64, f64_sum) - f64_bits_decrement, f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +} + +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { + const Rebind df64; + const RebindToUnsigned du64; + const RebindToSigned di32; + const RebindToUnsigned du32; + + const auto k2p64 = Set(df64, 18446744073709551616.0); + const auto f64_hi52 = Or(BitCast(df64, ShiftRight<12>(v)), k2p64) - k2p64; + const auto f64_lo12 = + PromoteTo(df64, BitCast(di32, And(TruncateTo(du32, BitCast(du64, v)), + Set(du32, uint32_t{0x00000FFF})))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + const auto f64_sum_is_inexact = + 
ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); + + const auto adj_f64_val = BitCast( + df64, + Or(BitCast(du64, f64_sum) - ShiftRight<63>(BitCast(du64, f64_carry)), + f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +} + +// ------------------------------ ReorderDemote2To + +// ReorderDemote2To for 8-byte UI64->UI32, <= 4-byte UI32->UI16, +// and <= 4-byte UI16->UI8 +template ) <= 2 ? 4 : 8))), + HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(DN), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{__lsx_vssrani_b_h(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{__lsx_vssrani_bu_h(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{__lsx_vssrlni_b_h(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{__lsx_vssrlni_bu_h(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{__lsx_vssrani_h_w(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{__lsx_vssrani_hu_w(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{__lsx_vssrlni_h_w(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{__lsx_vssrlni_hu_w(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{__lsx_vssrani_w_d(b.raw, a.raw, 0)}; +} + +template 
+HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{__lsx_vssrani_wu_d(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{__lsx_vssrlni_w_d(b.raw, a.raw, 0)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{__lsx_vssrlni_wu_d(b.raw, a.raw, 0)}; +} + +// 8-byte UI32->UI16 and UI16->UI8 ReorderDemote2To +template ) * 2), + HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Twice> dt; + const Twice dt_n; + + const auto demote2_result = + ReorderDemote2To(dt_n, ResizeBitCast(dt, a), ResizeBitCast(dt, b)); + return VFromD{__lsx_vshuf4i_w(demote2_result.raw, 0x88)}; +} + +template ), + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); +} + +template +HWY_API Vec128 U8FromU32(const Vec128 v) { + const DFromV du32; + const Rebind du8; + return DemoteTo(du8, BitCast(du32, v)); +} + +// ------------------------------ F32->UI64 PromoteTo + +// f32 ->i64 +template +HWY_API VFromD PromoteTo(D /*di64*/, VFromD> v) { + return VFromD{__lsx_vftintrzl_l_s(v.raw)}; +} + +// F32->U64 PromoteTo generic for all vector lengths +template +HWY_API VFromD PromoteTo(D du64, VFromD> v) { + const RebindToFloat df64; + return ConvertTo(du64, PromoteTo(df64, v)); +} + +// ------------------------------ MulFixedPoint15 + +template +HWY_API Vec128 MulFixedPoint15(const Vec128 a, + const Vec128 b) { + __m128i temp_ev = __lsx_vmulwev_w_h(a.raw, b.raw); + __m128i temp_od = __lsx_vmulwod_w_h(a.raw, b.raw); + __m128i temp1 = __lsx_vilvl_w(temp_od, temp_ev); + __m128i temp2 = __lsx_vilvh_w(temp_od, temp_ev); + return Vec128{__lsx_vssrarni_h_w(temp2, temp1, 15)}; +} + +// ------------------------------ Truncations + +template +HWY_API 
VFromD TruncateTo(DTo /* tag */, Vec128 v) { + const Repartition, DFromV> dto; + return VFromD{BitCast(dto, v).raw}; +} + +template +HWY_API Vec16 TruncateTo(D /* tag */, Vec128 v) { + return Vec16{__lsx_vextrins_b(v.raw, v.raw, 0x18)}; +} + +template +HWY_API Vec32 TruncateTo(D /* tag */, Vec128 v) { + return Vec32{__lsx_vextrins_h(v.raw, v.raw, 0x14)}; +} + +template +HWY_API Vec64 TruncateTo(D /* tag */, Vec128 v) { + return Vec64{__lsx_vpickev_w(v.raw, v.raw)}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + __m128i v_ev = __lsx_vpickev_b(v.raw, v.raw); + return VFromD{__lsx_vpickev_b(v_ev, v_ev)}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vpickev_h(v.raw, v.raw)}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vpickev_b(v.raw, v.raw)}; +} + +// ------------------------------ int -> float ConvertTo + +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vffint_s_w(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vffint_s_wu(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vffint_d_l(v.raw)}; +} + +// ------------------------------ float -> int ConvertTo + +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vffint_d_lu(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vftintrz_w_s(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vftintrz_wu_s(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vftintrz_l_d(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{__lsx_vftintrz_lu_d(v.raw)}; +} + +// ------------------------------ NearestInt (Round) + +template +HWY_API Vec128 NearestInt(const Vec128 v) { + return 
Vec128{__lsx_vftintrne_w_s(v.raw)}; +} + +template +HWY_API Vec128 NearestInt(const Vec128 v) { + return Vec128{__lsx_vftintrne_l_d(v.raw)}; +} + +template +HWY_API VFromD DemoteToNearestInt(DI32 di32, + VFromD> v) { + return DemoteTo(di32, NearestInt(v)); +} + +// ------------------------------ Floating-point rounding + +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{__lsx_vfrintrne_s(v.raw)}; +} +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{__lsx_vfrintrne_d(v.raw)}; +} +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{__lsx_vfrintrz_s(v.raw)}; +} +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{__lsx_vfrintrz_d(v.raw)}; +} +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{__lsx_vfrintrp_s(v.raw)}; +} +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{__lsx_vfrintrp_d(v.raw)}; +} +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{__lsx_vfrintrm_s(v.raw)}; +} +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{__lsx_vfrintrm_d(v.raw)}; +} + +// ------------------------------ Floating-point classification + +// FIXME: disable gcc-14 tree-based loop optimizations to prevent +// 'HighwayTestGroup/HighwayTest.TestAllIsNaN/LSX' failures +#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG +#pragma GCC push_options +#pragma GCC optimize("-fno-tree-loop-optimize") +#endif + +template +HWY_API Mask128 IsNaN(const Vec128 v) { + return Mask128{ + reinterpret_cast<__m128>(__lsx_vfcmp_cune_s(v.raw, v.raw))}; +} + +template +HWY_API Mask128 IsNaN(const Vec128 v) { + return Mask128{ + reinterpret_cast<__m128d>(__lsx_vfcmp_cune_d(v.raw, v.raw))}; +} + +#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG +#pragma GCC pop_options +#endif + +#ifdef HWY_NATIVE_IS_EITHER_NAN +#undef HWY_NATIVE_IS_EITHER_NAN +#else +#define HWY_NATIVE_IS_EITHER_NAN +#endif + +template +HWY_API Mask128 IsEitherNaN(Vec128 a, Vec128 b) { + return Mask128{ + 
reinterpret_cast<__m128>(__lsx_vfcmp_cun_s(a.raw, b.raw))}; +} + +template +HWY_API Mask128 IsEitherNaN(Vec128 a, + Vec128 b) { + __m128i _tmp = __lsx_vor_v(__lsx_vfcmp_cune_d(a.raw, a.raw), + __lsx_vfcmp_cune_d(b.raw, b.raw)); + return Mask128{reinterpret_cast<__m128d>(_tmp)}; +} + +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + +template +HWY_API MFromD> IsInf(V v) { + using T = TFromV; + + static_assert(IsFloat(), "Only for float"); + using TU = MakeUnsigned; + const DFromV d; + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and + // mantissa=0. + return RebindMask( + d, + Eq(Add(vu, vu), Set(du, static_cast(hwy::MaxExponentTimes2())))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API MFromD> IsFinite(V v) { + using T = TFromV; + + static_assert(IsFloat(), "Only for float"); + using TU = MakeUnsigned; + const DFromV d; + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent(hwy::MaxExponentTimes2())))); +} + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const VFromD vbits{__lsx_vreplgr2vr_w(static_cast(bits))}; + + // Replicate bytes 8x such that each byte contains the bit that governs it. 
+ alignas(16) static constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); + + alignas(16) static constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask( + d, TestBit(Set(du, static_cast(bits)), Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + return RebindMask( + d, TestBit(Set(du, static_cast(bits)), Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); +} + +} // namespace detail + +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + CopyBytes<(d.MaxLanes() + 7) / 8>(bits, &mask_bits); + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + return detail::LoadMaskBits(d, mask_bits); +} + +template +struct CompressIsPartition { + enum { value = (sizeof(T) != 1) }; +}; + +// ------------------------------ BitsFromMask + +namespace detail { + +template +constexpr uint64_t OnlyActive(D d, uint64_t mask_bits) { + return (d.MaxBytes() >= 16) ? 
mask_bits + : mask_bits & ((1ull << d.MaxLanes()) - 1); +} + +constexpr HWY_INLINE uint64_t U64FromInt(int mask_bits) { + return static_cast(static_cast(mask_bits)); +} + +} // namespace detail + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + return detail::OnlyActive( + d, detail::U64FromInt(__lsx_vpickve2gr_w(__lsx_vmskltz_b(mask.raw), 0))); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + return detail::OnlyActive( + d, detail::U64FromInt(__lsx_vpickve2gr_w(__lsx_vmskltz_h(mask.raw), 0))); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + return detail::OnlyActive( + d, detail::U64FromInt(__lsx_vpickve2gr_w( + __lsx_vmskltz_w(reinterpret_cast<__m128i>(mask.raw)), 0))); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + return detail::OnlyActive( + d, detail::U64FromInt(__lsx_vpickve2gr_w( + __lsx_vmskltz_d(reinterpret_cast<__m128i>(mask.raw)), 0))); +} + +// ------------------------------ StoreMaskBits +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + constexpr size_t kNumBytes = (MaxLanes(d) + 7) / 8; + const uint64_t mask_bits = BitsFromMask(d, mask); + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +template +HWY_API bool AllFalse(D d, MFromD mask) { + return BitsFromMask(d, mask) == 0; +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + constexpr uint64_t kAllBits = (1ull << kN) - 1; + return BitsFromMask(d, mask) == kAllBits; +} + +template +HWY_API size_t CountTrue(D d, MFromD mask) { + return PopCount(BitsFromMask(d, mask)); +} + +template +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { + return Num0BitsBelowLS1Bit_Nonzero64(BitsFromMask(d, mask)); +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + const uint64_t mask_bits = BitsFromMask(d, mask); + return mask_bits ? 
intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1; +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { + return 31 - Num0BitsAboveMS1Bit_Nonzero32( + static_cast(BitsFromMask(d, mask))); +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return mask_bits ? intptr_t(31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits)) + : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +// Also works for N < 8 because the first 16 4-tuples only reference bytes 0-6. +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Twice d8t; + const RebindToUnsigned du; + + alignas(16) static constexpr uint8_t table[2048] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 
14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 
12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 
12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 
4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const VFromD byte_idx{Load(d8, table + mask_bits * 8).raw}; + const VFromD pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Twice d8t; + const RebindToUnsigned du; + + alignas(16) static constexpr uint8_t table[2048] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 
4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 
6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 
8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 
12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const VFromD byte_idx{Load(d8, table + mask_bits * 8).raw}; + const VFromD pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to 
load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[256] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
+ alignas(16) static constexpr uint8_t u8_indices[256] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. 
+ alignas(16) static constexpr uint8_t u8_indices[64] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_API Vec128 CompressBits(Vec128 v, uint64_t mask_bits) { + const DFromV d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +template +HWY_API Vec128 CompressNotBits(Vec128 v, uint64_t mask_bits) { + const DFromV d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromNotBits128(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +} // namespace detail + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
+ const DFromV d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 bytes +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + const DFromV d; + return detail::CompressBits(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressNot + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const DFromV d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + const DFromV d; + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. 
+ if (N < 16 / sizeof(T)) { + return detail::CompressBits(v, BitsFromMask(d, Not(mask))); + } + return detail::CompressNotBits(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::CompressBits(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +template +HWY_API size_t CompressStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = BitsFromMask(d, m); + HWY_DASSERT(mask_bits < (1ull << MaxLanes(d))); + const size_t count = PopCount(mask_bits); + + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = BitsFromMask(d, m); + HWY_DASSERT(mask_bits < (1ull << MaxLanes(d))); + const size_t count = PopCount(mask_bits); + + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + BlendedStore(compressed, FirstN(d, count), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + uint64_t mask_bits = 0; + constexpr size_t kN = MaxLanes(d); + constexpr size_t 
kNumBytes = (kN + 7) / 8; + CopyBytes(bits, &mask_bits); + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + const size_t count = PopCount(mask_bits); + + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + + detail::MaybeUnpoison(unaligned, count); + return count; +} + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. + +// ------------------------------ Additional mask logical operations + +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const FixedTag d; + const auto vmask = VecFromMask(d, mask); + return MaskFromVec(Or(vmask, InterleaveLower(vmask, vmask))); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Simd d; + const auto vmask = VecFromMask(d, mask); + const auto neg_vmask = + ResizeBitCast(d, Neg(ResizeBitCast(Full64(), vmask))); + return MaskFromVec(Or(vmask, neg_vmask)); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Full128 d; + const Repartition di64; + + auto vmask = BitCast(di64, VecFromMask(d, mask)); + VFromD neg_vmask{__lsx_vsub_q(Zero(di64).raw, vmask.raw)}; + + return MaskFromVec(BitCast(d, Or(vmask, neg_vmask))); +} + +template +HWY_API Mask128 SetBeforeFirst(Mask128 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const FixedTag d; + const RebindToSigned di; + + const auto vmask = BitCast(di, VecFromMask(d, mask)); + const auto zero = Zero(di); + const auto vmask2 = VecFromMask(di, InterleaveLower(zero, vmask) == zero); + return MaskFromVec(BitCast(d, And(vmask, vmask2))); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + 
const Simd d; + const RebindToSigned di; + + const auto vmask = ResizeBitCast(Full64(), VecFromMask(d, mask)); + const auto only_first_vmask = + BitCast(d, Neg(ResizeBitCast(di, And(vmask, Neg(vmask))))); + return MaskFromVec(only_first_vmask); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Full128 d; + const RebindToSigned di; + + auto vmask = BitCast(di, VecFromMask(d, mask)); + VFromD neg_vmask{__lsx_vsub_q(Zero(di).raw, vmask.raw)}; + + return MaskFromVec(BitCast(d, Neg(And(vmask, neg_vmask)))); +} + +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 /*mask*/) { + const FixedTag d; + const RebindToSigned di; + using TI = MakeSigned; + + return RebindMask(d, MaskFromVec(Set(di, TI(-1)))); +} +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { + const Simd d; + return SetBeforeFirst(MaskFromVec(ShiftLeftLanes<1>(VecFromMask(d, mask)))); +} + +// ------------------------------ Reductions +#undef HWY_IF_SUM_OF_LANES_D +#define HWY_IF_SUM_OF_LANES_D(D) \ + HWY_IF_LANES_GT_D(D, 1), \ + hwy::EnableIf, uint8_t>() || \ + (HWY_V_SIZE_D(D) != 8 && HWY_V_SIZE_D(D) != 16)>* = \ + nullptr +// ------------------------------ SumOfLanes + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, static_cast(GetLane(SumsOf8(v)) & 0xFF)); +} +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + const Repartition d64; + VFromD sums = SumsOf8(v); + sums = SumOfLanes(d64, sums); + return Broadcast<0>(BitCast(d, sums)); +} + +// ------------------------------ Lt128 + +namespace detail { + +// Returns vector-mask for Lt128. Generic for all vector lengths. +template +HWY_INLINE VFromD Lt128Vec(const D d, VFromD a, VFromD b) { + // Truth table of Eq and Lt for Hi and Lo u64. 
+ // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const auto eqHL = Eq(a, b); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + const VFromD ltLX = ShiftLeftLanes<1>(ltHL); + const VFromD vecHx = IfThenElse(eqHL, ltLX, ltHL); + return InterleaveUpper(d, vecHx, vecHx); +} + +// Returns vector-mask for Eq128. Generic for all vector lengths. +template +HWY_INLINE VFromD Eq128Vec(D d, VFromD a, VFromD b) { + const auto eqHL = VecFromMask(d, Eq(a, b)); + const auto eqLH = Reverse2(d, eqHL); + return And(eqHL, eqLH); +} + +template +HWY_INLINE VFromD Ne128Vec(D d, VFromD a, VFromD b) { + const auto neHL = VecFromMask(d, Ne(a, b)); + const auto neLH = Reverse2(d, neHL); + return Or(neHL, neLH); +} + +template +HWY_INLINE VFromD Lt128UpperVec(D d, VFromD a, VFromD b) { + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + return InterleaveUpper(d, ltHL, ltHL); +} + +template +HWY_INLINE VFromD Eq128UpperVec(D d, VFromD a, VFromD b) { + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + return InterleaveUpper(d, eqHL, eqHL); +} + +template +HWY_INLINE VFromD Ne128UpperVec(D d, VFromD a, VFromD b) { + const VFromD neHL = VecFromMask(d, Ne(a, b)); + return InterleaveUpper(d, neHL, neHL); +} + +} // namespace detail + +template +HWY_API MFromD Lt128(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Lt128Vec(d, a, b)); +} + +template +HWY_API MFromD Eq128(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Eq128Vec(d, a, b)); +} + +template +HWY_API MFromD Ne128(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Ne128Vec(d, a, b)); +} + +template +HWY_API MFromD Lt128Upper(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Lt128UpperVec(d, a, b)); +} + +template +HWY_API MFromD Eq128Upper(D d, VFromD a, VFromD b) { + return 
MaskFromVec(detail::Eq128UpperVec(d, a, b)); +} + +template +HWY_API MFromD Ne128Upper(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Ne128UpperVec(d, a, b)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Avoids the extra MaskFromVec in Lt128. +template +HWY_API VFromD Min128(D d, VFromD a, VFromD b) { + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +} + +template +HWY_API VFromD Max128(D d, VFromD a, VFromD b) { + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +} + +template +HWY_API VFromD Min128Upper(D d, VFromD a, VFromD b) { + return IfVecThenElse(detail::Lt128UpperVec(d, a, b), a, b); +} + +template +HWY_API VFromD Max128Upper(D d, VFromD a, VFromD b) { + return IfVecThenElse(detail::Lt128UpperVec(d, b, a), a, b); +} + +// -------------------- LeadingZeroCount, TrailingZeroCount, +// HighestSetBitIndex + +#ifdef HWY_NATIVE_LEADING_ZERO_COUNT +#undef HWY_NATIVE_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_LEADING_ZERO_COUNT +#endif + +template ), HWY_IF_V_SIZE_LE_D(DFromV, 16)> +HWY_API V LeadingZeroCount(V v) { + return V{__lsx_vclz_b(v.raw)}; +} + +template ), HWY_IF_V_SIZE_LE_D(DFromV, 16)> +HWY_API V LeadingZeroCount(V v) { + return V{__lsx_vclz_h(v.raw)}; +} + +template ), HWY_IF_V_SIZE_LE_D(DFromV, 16)> +HWY_API V LeadingZeroCount(V v) { + return V{__lsx_vclz_w(v.raw)}; +} + +template ), HWY_IF_V_SIZE_LE_D(DFromV, 16)> +HWY_API V LeadingZeroCount(V v) { + return V{__lsx_vclz_d(v.raw)}; +} + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + using T = TFromD; + return BitCast(d, Set(d, T{sizeof(T) * 8 - 1}) - LeadingZeroCount(v)); +} + +template +HWY_API V TrailingZeroCount(V v) { + const DFromV d; + const RebindToSigned di; + using T = TFromD; + + const auto lsb = And(v, BitCast(d, Neg(BitCast(di, v)))); + return IfThenElse(Eq(v, Zero(d)), Set(d, T{sizeof(T) * 8}), + HighestSetBitIndex(lsb)); +} + +} // namespace HWY_NAMESPACE +} // namespace hwy + +HWY_AFTER_NAMESPACE(); + +#undef 
HWY_LSX_IF_EMULATED_D diff --git a/lib/highway/hwy/ops/ppc_vsx-inl.h b/lib/highway/hwy/ops/ppc_vsx-inl.h new file mode 100644 index 00000000000..cb175019683 --- /dev/null +++ b/lib/highway/hwy/ops/ppc_vsx-inl.h @@ -0,0 +1,7489 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit vectors for VSX/Z14 +// External include guard in highway.h - see comment there. + +#if HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15 +#define HWY_S390X_HAVE_Z14 1 +#else +#define HWY_S390X_HAVE_Z14 0 +#endif + +#pragma push_macro("vector") +#pragma push_macro("pixel") +#pragma push_macro("bool") + +#undef vector +#undef pixel +#undef bool + +#if HWY_S390X_HAVE_Z14 +#include +#else +#include +#endif + +#pragma pop_macro("vector") +#pragma pop_macro("pixel") +#pragma pop_macro("bool") + +#include "hwy/ops/shared-inl.h" + +// clang's altivec.h gates some intrinsics behind #ifdef __POWER10_VECTOR__, and +// some GCC do the same for _ARCH_PWR10. +// This means we can only use POWER10-specific intrinsics in static dispatch +// mode (where the -mpower10-vector compiler flag is passed). Same for PPC9. +// On other compilers, the usual target check is sufficient. 
+#if !HWY_S390X_HAVE_Z14 && HWY_TARGET <= HWY_PPC9 && \ + (defined(_ARCH_PWR9) || defined(__POWER9_VECTOR__)) +#define HWY_PPC_HAVE_9 1 +#else +#define HWY_PPC_HAVE_9 0 +#endif + +#if !HWY_S390X_HAVE_Z14 && HWY_TARGET <= HWY_PPC10 && \ + (defined(_ARCH_PWR10) || defined(__POWER10_VECTOR__)) +#define HWY_PPC_HAVE_10 1 +#else +#define HWY_PPC_HAVE_10 0 +#endif + +#if HWY_S390X_HAVE_Z14 && HWY_TARGET <= HWY_Z15 && __ARCH__ >= 13 +#define HWY_S390X_HAVE_Z15 1 +#else +#define HWY_S390X_HAVE_Z15 0 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +template +struct Raw128; + +// Each Raw128 specialization defines the following typedefs: +// - type: +// the backing Altivec/VSX raw vector type of the Vec128 type +// - RawBoolVec: +// the backing Altivec/VSX raw __bool vector type of the Mask128 type +// - RawT: +// the lane type for intrinsics, in particular vec_splat +// - AlignedRawVec: +// the 128-bit GCC/Clang vector type for aligned loads/stores +// - UnalignedRawVec: +// the 128-bit GCC/Clang vector type for unaligned loads/stores +#define HWY_VSX_RAW128(LANE_TYPE, RAW_VECT_LANE_TYPE, RAW_BOOL_VECT_LANE_TYPE) \ + template <> \ + struct Raw128 { \ + using type = __vector RAW_VECT_LANE_TYPE; \ + using RawBoolVec = __vector __bool RAW_BOOL_VECT_LANE_TYPE; \ + using RawT = RAW_VECT_LANE_TYPE; \ + typedef LANE_TYPE AlignedRawVec \ + __attribute__((__vector_size__(16), __aligned__(16), __may_alias__)); \ + typedef LANE_TYPE UnalignedRawVec __attribute__(( \ + __vector_size__(16), __aligned__(alignof(LANE_TYPE)), __may_alias__)); \ + }; + +HWY_VSX_RAW128(int8_t, signed char, char) +HWY_VSX_RAW128(uint8_t, unsigned char, char) +HWY_VSX_RAW128(int16_t, signed short, short) // NOLINT(runtime/int) +HWY_VSX_RAW128(uint16_t, unsigned short, short) // NOLINT(runtime/int) +HWY_VSX_RAW128(int32_t, signed int, int) +HWY_VSX_RAW128(uint32_t, unsigned int, int) +HWY_VSX_RAW128(int64_t, signed long long, long long) // NOLINT(runtime/int) 
+HWY_VSX_RAW128(uint64_t, unsigned long long, long long) // NOLINT(runtime/int) +HWY_VSX_RAW128(float, float, int) +HWY_VSX_RAW128(double, double, long long) // NOLINT(runtime/int) + +template <> +struct Raw128 : public Raw128 {}; + +template <> +struct Raw128 : public Raw128 {}; + +#undef HWY_VSX_RAW128 + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +template +using Vec16 = Vec128; + +// FF..FF or 0. +template +struct Mask128 { + typename detail::Raw128::RawBoolVec raw; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = N; // only for DFromM +}; + +template +using DFromV = Simd; + +template +using DFromM = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ Zero + +// Returns an all-zero vector/part. 
+template > +HWY_API Vec128 Zero(D /* tag */) { + // There is no vec_splats for 64-bit, so we cannot rely on casting the 0 + // argument in order to select the correct overload. We instead cast the + // return vector type; see also the comment in BitCast. + return Vec128{ + reinterpret_cast::type>(vec_splats(0))}; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ BitCast + +template +HWY_API VFromD BitCast(D /*d*/, + Vec128().MaxLanes()> v) { + // C-style casts are not sufficient when compiling with + // -fno-lax-vector-conversions, which will be the future default in Clang, + // but reinterpret_cast is. + return VFromD{ + reinterpret_cast>::type>(v.raw)}; +} + +// ------------------------------ ResizeBitCast + +template +HWY_API VFromD ResizeBitCast(D /*d*/, FromV v) { + // C-style casts are not sufficient when compiling with + // -fno-lax-vector-conversions, which will be the future default in Clang, + // but reinterpret_cast is. + return VFromD{ + reinterpret_cast>::type>(v.raw)}; +} + +// ------------------------------ Set + +// Returns a vector/part with all lanes set to "t". +template )> +HWY_API VFromD Set(D /* tag */, TFromD t) { + using RawLane = typename detail::Raw128>::RawT; + return VFromD{vec_splats(static_cast(t))}; +} + +template )> +HWY_API VFromD Set(D d, TFromD t) { + const RebindToUnsigned du; + return BitCast(d, Set(du, BitCastScalar>(t))); +} + +// Returns a vector with uninitialized elements. +template +HWY_API VFromD Undefined(D d) { +#if HWY_COMPILER_GCC_ACTUAL + // Suppressing maybe-uninitialized both here and at the caller does not work, + // so initialize. 
+ return Zero(d); +#elif HWY_HAS_BUILTIN(__builtin_nondeterministic_value) + return VFromD{__builtin_nondeterministic_value(Zero(d).raw)}; +#else + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + typename detail::Raw128>::type raw; + return VFromD{raw}; + HWY_DIAGNOSTICS(pop) +#endif +} + +// ------------------------------ GetLane + +// Gets the single value stored in a vector/part. + +template +HWY_API T GetLane(Vec128 v) { + return static_cast(v.raw[0]); +} + +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + const typename detail::Raw128>::type raw = { + t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15}; + return VFromD{raw}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const typename detail::Raw128>::type raw = {t0, t1, t2, t3, + t4, t5, t6, t7}; + return VFromD{raw}; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToUnsigned du; + return BitCast( + d, Dup128VecFromValues( + du, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + const typename detail::Raw128>::type raw = {t0, t1, t2, t3}; + return VFromD{raw}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + const typename detail::Raw128>::type raw = {t0, t1}; + return VFromD{raw}; +} + +// 
================================================== LOGICAL + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU{BitCast(du, a).raw & BitCast(du, b).raw}); +#else + return BitCast(d, VU{vec_and(BitCast(du, a).raw, BitCast(du, b).raw)}); +#endif +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast( + d, VU{vec_andc(BitCast(du, mask).raw, BitCast(du, not_mask).raw)}); +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU{BitCast(du, a).raw | BitCast(du, b).raw}); +#else + return BitCast(d, VU{vec_or(BitCast(du, a).raw, BitCast(du, b).raw)}); +#endif +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU{BitCast(du, a).raw ^ BitCast(du, b).raw}); +#else + return BitCast(d, VU{vec_xor(BitCast(du, a).raw, BitCast(du, b).raw)}); +#endif +} + +// ------------------------------ Not +template +HWY_API Vec128 Not(Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast(d, VU{vec_nor(BitCast(du, v).raw, BitCast(du, v).raw)}); +} + +// ------------------------------ IsConstantRawAltivecVect +namespace detail { + +template +static HWY_INLINE bool IsConstantRawAltivecVect( + hwy::SizeTag<1> /* lane_size_tag */, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]) && + __builtin_constant_p(v[4]) && 
__builtin_constant_p(v[5]) && + __builtin_constant_p(v[6]) && __builtin_constant_p(v[7]) && + __builtin_constant_p(v[8]) && __builtin_constant_p(v[9]) && + __builtin_constant_p(v[10]) && __builtin_constant_p(v[11]) && + __builtin_constant_p(v[12]) && __builtin_constant_p(v[13]) && + __builtin_constant_p(v[14]) && __builtin_constant_p(v[15]); +} + +template +static HWY_INLINE bool IsConstantRawAltivecVect( + hwy::SizeTag<2> /* lane_size_tag */, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]) && + __builtin_constant_p(v[4]) && __builtin_constant_p(v[5]) && + __builtin_constant_p(v[6]) && __builtin_constant_p(v[7]); +} + +template +static HWY_INLINE bool IsConstantRawAltivecVect( + hwy::SizeTag<4> /* lane_size_tag */, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]); +} + +template +static HWY_INLINE bool IsConstantRawAltivecVect( + hwy::SizeTag<8> /* lane_size_tag */, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]); +} + +template +static HWY_INLINE bool IsConstantRawAltivecVect(RawV v) { + return IsConstantRawAltivecVect(hwy::SizeTag(), v); +} + +} // namespace detail + +// ------------------------------ TernaryLogic +#if HWY_PPC_HAVE_10 +namespace detail { + +// NOTE: the kTernLogOp bits of the PPC10 TernaryLogic operation are in reverse +// order of the kTernLogOp bits of AVX3 +// _mm_ternarylogic_epi64(a, b, c, kTernLogOp) +template +HWY_INLINE V TernaryLogic(V a, V b, V c) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const auto a_raw = BitCast(du, a).raw; + const auto b_raw = BitCast(du, b).raw; + const auto c_raw = BitCast(du, c).raw; + +#if HWY_COMPILER_GCC_ACTUAL + // Use inline assembly on GCC to work around GCC compiler bug + typename detail::Raw128>::type raw_ternlog_result; + __asm__("xxeval %x0,%x1,%x2,%x3,%4" + : 
"=wa"(raw_ternlog_result) + : "wa"(a_raw), "wa"(b_raw), "wa"(c_raw), + "n"(static_cast(kTernLogOp)) + :); +#else + const auto raw_ternlog_result = + vec_ternarylogic(a_raw, b_raw, c_raw, kTernLogOp); +#endif + + return BitCast(d, VU{raw_ternlog_result}); +} + +} // namespace detail + +// ------------------------------ Xor3 + +#ifdef HWY_NATIVE_XOR3 +#undef HWY_NATIVE_XOR3 +#else +#define HWY_NATIVE_XOR3 +#endif + +template +HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { +#if defined(__OPTIMIZE__) + if (static_cast(detail::IsConstantRawAltivecVect(x1.raw)) + + static_cast(detail::IsConstantRawAltivecVect(x2.raw)) + + static_cast(detail::IsConstantRawAltivecVect(x3.raw)) >= + 2) { + return Xor(x1, Xor(x2, x3)); + } else // NOLINT +#endif + { + return detail::TernaryLogic<0x69>(x1, x2, x3); + } +} + +// ------------------------------ XorAndNot + +#ifdef HWY_NATIVE_BCAX +#undef HWY_NATIVE_BCAX +#else +#define HWY_NATIVE_BCAX +#endif + +template +HWY_API Vec128 XorAndNot(Vec128 x, Vec128 a1, + Vec128 a2) { +#if defined(__OPTIMIZE__) + if (static_cast(detail::IsConstantRawAltivecVect(x.raw)) + + static_cast(detail::IsConstantRawAltivecVect(a1.raw)) + + static_cast(detail::IsConstantRawAltivecVect(a2.raw)) >= + 2) { + return Xor(x, AndNot(a1, a2)); + } else // NOLINT +#endif + { + return detail::TernaryLogic<0x4B>(x, a1, a2); + } +} + +#endif // HWY_PPC_HAVE_10 + +// ------------------------------ Or3 +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { +#if HWY_PPC_HAVE_10 +#if defined(__OPTIMIZE__) + if (static_cast(detail::IsConstantRawAltivecVect(o1.raw)) + + static_cast(detail::IsConstantRawAltivecVect(o2.raw)) + + static_cast(detail::IsConstantRawAltivecVect(o3.raw)) >= + 2) { + return Or(o1, Or(o2, o3)); + } else // NOLINT +#endif + { + return detail::TernaryLogic<0x7F>(o1, o2, o3); + } +#else + return Or(o1, Or(o2, o3)); +#endif +} + +// ------------------------------ OrAnd +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { +#if 
HWY_PPC_HAVE_10 +#if defined(__OPTIMIZE__) + if (detail::IsConstantRawAltivecVect(a1.raw) && + detail::IsConstantRawAltivecVect(a2.raw)) { + return Or(o, And(a1, a2)); + } else // NOLINT +#endif + { + return detail::TernaryLogic<0x1F>(o, a1, a2); + } +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, VFromD{vec_sel(BitCast(du, no).raw, BitCast(du, yes).raw, + BitCast(du, mask).raw)}); +} + +// ------------------------------ BitwiseIfThenElse + +#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#else +#define HWY_NATIVE_BITWISE_IF_THEN_ELSE +#endif + +template +HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { + return IfVecThenElse(mask, yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(Vec128 a, Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(Vec128 a, Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(Vec128 a, Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +template +HWY_API Vec128 PopulationCount(Vec128 v) { + return Vec128{vec_popcnt(v.raw)}; +} + +// ================================================== SIGN + +// ------------------------------ Neg + +template +HWY_API Vec128 Neg(Vec128 v) { + // If T is an signed integer type, use Zero(d) - v instead of vec_neg to + // avoid undefined behavior in the case where v[i] == LimitsMin() + const DFromV d; + return Zero(d) - v; +} + +template +HWY_API Vec128 Neg(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + return Xor(v, SignBit(DFromV())); +#else + return Vec128{vec_neg(v.raw)}; +#endif +} + +template +HWY_API Vec128 Neg(const Vec128 v) { + 
return Xor(v, SignBit(DFromV())); +} + +// ------------------------------ Abs + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template +HWY_API Vec128 Abs(Vec128 v) { + // If T is a signed integer type, use Max(v, Neg(v)) instead of vec_abs to + // avoid undefined behavior in the case where v[i] == LimitsMin(). + return Max(v, Neg(v)); +} + +template +HWY_API Vec128 Abs(Vec128 v) { + return Vec128{vec_abs(v.raw)}; +} + +// ------------------------------ CopySign + +#if HWY_S390X_HAVE_Z14 +template +HWY_API V CopySign(const V magn, const V sign) { + static_assert(IsFloat>(), "Only makes sense for floating-point"); + + const DFromV d; + const auto msb = SignBit(d); + + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + return BitwiseIfThenElse(msb, sign, magn); +} +#else // VSX +template +HWY_API Vec128 CopySign(Vec128 magn, + Vec128 sign) { + // Work around compiler bugs that are there with vec_cpsgn on older versions + // of GCC/Clang +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1200 + return Vec128{__builtin_vec_copysign(magn.raw, sign.raw)}; +#elif HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 && \ + HWY_HAS_BUILTIN(__builtin_vsx_xvcpsgnsp) + return Vec128{__builtin_vsx_xvcpsgnsp(magn.raw, sign.raw)}; +#else + return Vec128{vec_cpsgn(sign.raw, magn.raw)}; +#endif +} + +template +HWY_API Vec128 CopySign(Vec128 magn, + Vec128 sign) { + // Work around compiler bugs that are there with vec_cpsgn on older versions + // of GCC/Clang +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1200 + return Vec128{__builtin_vec_copysign(magn.raw, sign.raw)}; +#elif HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1200 && \ + HWY_HAS_BUILTIN(__builtin_vsx_xvcpsgndp) + return Vec128{__builtin_vsx_xvcpsgndp(magn.raw, sign.raw)}; +#else + return Vec128{vec_cpsgn(sign.raw, magn.raw)}; +#endif +} +#endif // 
HWY_S390X_HAVE_Z14 + +template +HWY_API Vec128 CopySignToAbs(Vec128 abs, Vec128 sign) { + // PPC8 can also handle abs < 0, so no extra action needed. + static_assert(IsFloat(), "Only makes sense for floating-point"); + return CopySign(abs, sign); +} + +// ================================================== MEMORY (1) + +// Note: type punning is safe because the types are tagged with may_alias. +// (https://godbolt.org/z/fqrWjfjsP) + +// ------------------------------ Load + +template > +HWY_API Vec128 Load(D /* tag */, const T* HWY_RESTRICT aligned) { +// Suppress the ignoring attributes warning that is generated by +// HWY_RCAST_ALIGNED(const LoadRaw*, aligned) with GCC +#if HWY_COMPILER_GCC + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4649, ignored "-Wignored-attributes") +#endif + + using LoadRaw = typename detail::Raw128::AlignedRawVec; + const LoadRaw* HWY_RESTRICT p = HWY_RCAST_ALIGNED(const LoadRaw*, aligned); + using ResultRaw = typename detail::Raw128::type; + return Vec128{reinterpret_cast(*p)}; + +#if HWY_COMPILER_GCC + HWY_DIAGNOSTICS(pop) +#endif +} + +// Any <= 64 bit +template > +HWY_API VFromD Load(D d, const T* HWY_RESTRICT p) { + using BitsT = UnsignedFromSize; + + BitsT bits; + const Repartition d_bits; + CopyBytes(p, &bits); + return BitCast(d, Set(d_bits, bits)); +} + +// ================================================== MASK + +// ------------------------------ Mask + +// Mask and Vec are both backed by vector types (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(Vec128 v) { + using Raw = typename detail::Raw128::RawBoolVec; + return Mask128{reinterpret_cast(v.raw)}; +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +template +HWY_API Vec128 VecFromMask(Mask128 v) { + return Vec128{ + reinterpret_cast::type>(v.raw)}; +} + +template +HWY_API VFromD VecFromMask(D /* tag */, MFromD v) { + return VFromD{ + reinterpret_cast>::type>(v.raw)}; +} + +// mask ? 
yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{vec_sel( + BitCast(du, no).raw, BitCast(du, yes).raw, mask.raw)}); +} + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +// mask ? 0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(Mask128 m) { + return Mask128{vec_nor(m.raw, m.raw)}; +} + +template +HWY_API Mask128 And(Mask128 a, Mask128 b) { +#if HWY_S390X_HAVE_Z14 + return Mask128{a.raw & b.raw}; +#else + return Mask128{vec_and(a.raw, b.raw)}; +#endif +} + +template +HWY_API Mask128 AndNot(Mask128 a, Mask128 b) { + return Mask128{vec_andc(b.raw, a.raw)}; +} + +template +HWY_API Mask128 Or(Mask128 a, Mask128 b) { +#if HWY_S390X_HAVE_Z14 + return Mask128{a.raw | b.raw}; +#else + return Mask128{vec_or(a.raw, b.raw)}; +#endif +} + +template +HWY_API Mask128 Xor(Mask128 a, Mask128 b) { +#if HWY_S390X_HAVE_Z14 + return Mask128{a.raw ^ b.raw}; +#else + return Mask128{vec_xor(a.raw, b.raw)}; +#endif +} + +template +HWY_API Mask128 ExclusiveNeither(Mask128 a, Mask128 b) { + return Mask128{vec_nor(a.raw, b.raw)}; +} + +// ------------------------------ ShiftLeftSame + +template +HWY_API Vec128 ShiftLeftSame(Vec128 v, const int bits) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + +#if HWY_S390X_HAVE_Z14 + return BitCast(d, + VFromD{BitCast(du, v).raw + << Set(du, static_cast(bits)).raw}); +#else + // Do an unsigned vec_sl operation to avoid undefined behavior + return BitCast( + d, VFromD{ + vec_sl(BitCast(du, v).raw, Set(du, static_cast(bits)).raw)}); +#endif +} + +// ------------------------------ ShiftRightSame + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { + using TU 
= typename detail::Raw128>::RawT; +#if HWY_S390X_HAVE_Z14 + return Vec128{v.raw >> vec_splats(static_cast(bits))}; +#else + return Vec128{vec_sr(v.raw, vec_splats(static_cast(bits)))}; +#endif +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { +#if HWY_S390X_HAVE_Z14 + using TI = typename detail::Raw128::RawT; + return Vec128{v.raw >> vec_splats(static_cast(bits))}; +#else + using TU = typename detail::Raw128>::RawT; + return Vec128{vec_sra(v.raw, vec_splats(static_cast(bits)))}; +#endif +} + +// ------------------------------ ShiftLeft + +template +HWY_API Vec128 ShiftLeft(Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return ShiftLeftSame(v, kBits); +} + +// ------------------------------ ShiftRight + +template +HWY_API Vec128 ShiftRight(Vec128 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return ShiftRightSame(v, kBits); +} + +// ------------------------------ BroadcastSignBit + +template +HWY_API Vec128 BroadcastSignBit(Vec128 v) { + return ShiftRightSame(v, static_cast(sizeof(T) * 8 - 1)); +} + +// ================================================== SWIZZLE (1) + +// ------------------------------ TableLookupBytes +template +HWY_API Vec128 TableLookupBytes(Vec128 bytes, + Vec128 from) { + const Repartition> du8_from; + return Vec128{reinterpret_cast::type>( + vec_perm(bytes.raw, bytes.raw, BitCast(du8_from, from).raw))}; +} + +// ------------------------------ TableLookupBytesOr0 +// For all vector widths; Altivec/VSX needs zero out +template +HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { + const DFromV di; + Repartition di8; + const VI zeroOutMask = BitCast(di, BroadcastSignBit(BitCast(di8, from))); + return AndNot(zeroOutMask, TableLookupBytes(bytes, from)); +} + +// ------------------------------ Reverse +#if HWY_S390X_HAVE_Z14 && HWY_COMPILER_GCC_ACTUAL && \ + HWY_COMPILER_GCC_ACTUAL < 900 +// Workaround for missing vec_reve on Z14 with GCC 8 or 
earlier +template , HWY_IF_LANES_GT_D(D, 1), + HWY_IF_T_SIZE_D(D, 1)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + const Repartition du8; + return TableLookupBytes( + v, BitCast(d, Dup128VecFromValues(du8, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, + 5, 4, 3, 2, 1, 0))); +} + +template , HWY_IF_LANES_GT_D(D, 1), + HWY_IF_T_SIZE_D(D, 2)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + const Repartition du8; + return TableLookupBytes( + v, BitCast(d, Dup128VecFromValues(du8, 14, 15, 12, 13, 10, 11, 8, 9, 6, 7, + 4, 5, 2, 3, 0, 1))); +} + +template , HWY_IF_LANES_GT_D(D, 1), + HWY_IF_T_SIZE_D(D, 4)> +HWY_API Vec128 Reverse(D d, Vec128 v) { + const Repartition du8; + return TableLookupBytes( + v, BitCast(d, Dup128VecFromValues(du8, 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, + 6, 7, 0, 1, 2, 3))); +} + +template , HWY_IF_LANES_GT_D(D, 1), + HWY_IF_T_SIZE_D(D, 8)> +HWY_API Vec128 Reverse(D /* tag */, Vec128 v) { + return Vec128{vec_sld(v.raw, v.raw, 8)}; +} +#else +template , HWY_IF_LANES_GT_D(D, 1)> +HWY_API Vec128 Reverse(D /* tag */, Vec128 v) { + return Vec128{vec_reve(v.raw)}; +} +#endif + +// ------------------------------ Shuffles (Reverse) + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec128 Shuffle2301(Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + const __vector unsigned char kShuffle = {4, 5, 6, 7, 0, 1, 2, 3, + 12, 13, 14, 15, 8, 9, 10, 11}; + return Vec128{vec_perm(v.raw, v.raw, kShuffle)}; +} + +// These are used by generic_ops-inl to implement LoadInterleaved3. 
As with +// Intel's shuffle* intrinsics and InterleaveLower, the lower half of the output +// comes from the first argument. +namespace detail { + +template +HWY_API Vec32 ShuffleTwo2301(Vec32 a, Vec32 b) { + const __vector unsigned char kShuffle16 = {1, 0, 19, 18}; + return Vec32{vec_perm(a.raw, b.raw, kShuffle16)}; +} +template +HWY_API Vec64 ShuffleTwo2301(Vec64 a, Vec64 b) { + const __vector unsigned char kShuffle = {2, 3, 0, 1, 22, 23, 20, 21}; + return Vec64{vec_perm(a.raw, b.raw, kShuffle)}; +} +template +HWY_API Vec128 ShuffleTwo2301(Vec128 a, Vec128 b) { + const __vector unsigned char kShuffle = {4, 5, 6, 7, 0, 1, 2, 3, + 28, 29, 30, 31, 24, 25, 26, 27}; + return Vec128{vec_perm(a.raw, b.raw, kShuffle)}; +} + +template +HWY_API Vec32 ShuffleTwo1230(Vec32 a, Vec32 b) { + const __vector unsigned char kShuffle = {0, 3, 18, 17}; + return Vec32{vec_perm(a.raw, b.raw, kShuffle)}; +} +template +HWY_API Vec64 ShuffleTwo1230(Vec64 a, Vec64 b) { + const __vector unsigned char kShuffle = {0, 1, 6, 7, 20, 21, 18, 19}; + return Vec64{vec_perm(a.raw, b.raw, kShuffle)}; +} +template +HWY_API Vec128 ShuffleTwo1230(Vec128 a, Vec128 b) { + const __vector unsigned char kShuffle = {0, 1, 2, 3, 12, 13, 14, 15, + 24, 25, 26, 27, 20, 21, 22, 23}; + return Vec128{vec_perm(a.raw, b.raw, kShuffle)}; +} + +template +HWY_API Vec32 ShuffleTwo3012(Vec32 a, Vec32 b) { + const __vector unsigned char kShuffle = {2, 1, 16, 19}; + return Vec32{vec_perm(a.raw, b.raw, kShuffle)}; +} +template +HWY_API Vec64 ShuffleTwo3012(Vec64 a, Vec64 b) { + const __vector unsigned char kShuffle = {4, 5, 2, 3, 16, 17, 22, 23}; + return Vec64{vec_perm(a.raw, b.raw, kShuffle)}; +} +template +HWY_API Vec128 ShuffleTwo3012(Vec128 a, Vec128 b) { + const __vector unsigned char kShuffle = {8, 9, 10, 11, 4, 5, 6, 7, + 16, 17, 18, 19, 28, 29, 30, 31}; + return Vec128{vec_perm(a.raw, b.raw, kShuffle)}; +} + +} // namespace detail + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle1032(Vec128 v) { + const 
Full128 d; + const Full128 du64; + return BitCast(d, Reverse(du64, BitCast(du64, v))); +} +template +HWY_API Vec128 Shuffle01(Vec128 v) { + return Reverse(Full128(), v); +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(Vec128 v) { +#if HWY_IS_LITTLE_ENDIAN + return Vec128{vec_sld(v.raw, v.raw, 12)}; +#else + return Vec128{vec_sld(v.raw, v.raw, 4)}; +#endif +} +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(Vec128 v) { +#if HWY_IS_LITTLE_ENDIAN + return Vec128{vec_sld(v.raw, v.raw, 4)}; +#else + return Vec128{vec_sld(v.raw, v.raw, 12)}; +#endif +} + +template +HWY_API Vec128 Shuffle0123(Vec128 v) { + return Reverse(Full128(), v); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API MFromD RebindMask(DTo /*dto*/, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + return MFromD{m.raw}; +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + return Mask128{vec_cmpeq(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// This cannot have T as a template argument, otherwise it is not more +// specialized than rewritten operator== in C++20, leading to compile +// errors: https://gcc.godbolt.org/z/xsrPhPvPT. 
+template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { +#if HWY_PPC_HAVE_9 + return Mask128{vec_cmpne(a.raw, b.raw)}; +#else + return Not(a == b); +#endif +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { +#if HWY_PPC_HAVE_9 + return Mask128{vec_cmpne(a.raw, b.raw)}; +#else + return Not(a == b); +#endif +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { +#if HWY_PPC_HAVE_9 + return Mask128{vec_cmpne(a.raw, b.raw)}; +#else + return Not(a == b); +#endif +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { +#if HWY_PPC_HAVE_9 + return Mask128{vec_cmpne(a.raw, b.raw)}; +#else + return Not(a == b); +#endif +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { +#if HWY_PPC_HAVE_9 + return Mask128{vec_cmpne(a.raw, b.raw)}; +#else + return Not(a == b); +#endif +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { +#if HWY_PPC_HAVE_9 + return Mask128{vec_cmpne(a.raw, b.raw)}; +#else + return Not(a == b); +#endif +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} + +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + return Not(a == b); +} + +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} + +// ------------------------------ Strict inequality + +template +HWY_INLINE Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{vec_cmpgt(a.raw, b.raw)}; +} + +// ------------------------------ Weak inequality + +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return Mask128{vec_cmpge(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return Not(b > a); +} + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask128 operator<(Vec128 a, Vec128 b) { + return b > a; +} + +template +HWY_API Mask128 operator<=(Vec128 a, Vec128 b) { + return b >= a; +} + +// 
================================================== MEMORY (2) + +// ------------------------------ Load +template > +HWY_API Vec128 LoadU(D /* tag */, const T* HWY_RESTRICT p) { + using LoadRaw = typename detail::Raw128::UnalignedRawVec; + const LoadRaw* HWY_RESTRICT praw = reinterpret_cast(p); + using ResultRaw = typename detail::Raw128::type; + return Vec128{reinterpret_cast(*praw)}; +} + +// For < 128 bit, LoadU == Load. +template > +HWY_API VFromD LoadU(D d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template > +HWY_API VFromD LoadDup128(D d, const T* HWY_RESTRICT p) { + return LoadU(d, p); +} + +#if (HWY_PPC_HAVE_9 && HWY_ARCH_PPC_64) || HWY_S390X_HAVE_Z14 +#ifdef HWY_NATIVE_LOAD_N +#undef HWY_NATIVE_LOAD_N +#else +#define HWY_NATIVE_LOAD_N +#endif + +template > +HWY_API VFromD LoadN(D d, const T* HWY_RESTRICT p, + size_t max_lanes_to_load) { +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(max_lanes_to_load) && max_lanes_to_load == 0) { + return Zero(d); + } + + if (__builtin_constant_p(max_lanes_to_load >= HWY_MAX_LANES_D(D)) && + max_lanes_to_load >= HWY_MAX_LANES_D(D)) { + return LoadU(d, p); + } +#endif + + const size_t num_of_bytes_to_load = + HWY_MIN(max_lanes_to_load, HWY_MAX_LANES_D(D)) * sizeof(TFromD); + const Repartition du8; +#if HWY_S390X_HAVE_Z14 + return (num_of_bytes_to_load > 0) + ? 
BitCast(d, VFromD{vec_load_len( + const_cast( + reinterpret_cast(p)), + static_cast(num_of_bytes_to_load - 1))}) + : Zero(d); +#else + return BitCast( + d, + VFromD{vec_xl_len( + const_cast(reinterpret_cast(p)), + num_of_bytes_to_load)}); +#endif +} + +template > +HWY_API VFromD LoadNOr(VFromD no, D d, const T* HWY_RESTRICT p, + size_t max_lanes_to_load) { +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(max_lanes_to_load) && max_lanes_to_load == 0) { + return no; + } + + if (__builtin_constant_p(max_lanes_to_load >= HWY_MAX_LANES_D(D)) && + max_lanes_to_load >= HWY_MAX_LANES_D(D)) { + return LoadU(d, p); + } +#endif + + return IfThenElse(FirstN(d, max_lanes_to_load), + LoadN(d, p, max_lanes_to_load), no); +} + +#endif // HWY_PPC_HAVE_9 || HWY_S390X_HAVE_Z14 + +// Returns a vector with lane i=[0, N) set to "first" + i. +namespace detail { + +template +HWY_INLINE VFromD Iota0(D d) { + constexpr __vector unsigned char kU8Iota0 = {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15}; + return BitCast(d, VFromD>{kU8Iota0}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + constexpr __vector unsigned short kU16Iota0 = {0, 1, 2, 3, 4, 5, 6, 7}; + return BitCast(d, VFromD>{kU16Iota0}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + constexpr __vector unsigned int kU32Iota0 = {0, 1, 2, 3}; + return BitCast(d, VFromD>{kU32Iota0}); +} + +template +HWY_INLINE VFromD Iota0(D d) { + constexpr __vector unsigned long long kU64Iota0 = {0, 1}; + return BitCast(d, VFromD>{kU64Iota0}); +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + constexpr __vector float kF32Iota0 = {0.0f, 1.0f, 2.0f, 3.0f}; + return VFromD{kF32Iota0}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + constexpr __vector double kF64Iota0 = {0.0, 1.0}; + return VFromD{kF64Iota0}; +} + +} // namespace detail + +template +HWY_API VFromD Iota(D d, const T2 first) { + return detail::Iota0(d) + Set(d, static_cast>(first)); +} + +// ------------------------------ FirstN (Iota, Lt) + 
+template +HWY_API MFromD FirstN(D d, size_t num) { + const RebindToUnsigned du; + using TU = TFromD; + return RebindMask(d, Iota(du, 0) < Set(du, static_cast(num))); +} + +// ------------------------------ MaskedLoad +template > +HWY_API VFromD MaskedLoad(MFromD m, D d, const T* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +// ------------------------------ MaskedLoadOr +template > +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const T* HWY_RESTRICT p) { + return IfThenElse(m, LoadU(d, p), v); +} + +// ------------------------------ Store + +template > +HWY_API void Store(Vec128 v, D /* tag */, T* HWY_RESTRICT aligned) { +// Suppress the ignoring attributes warning that is generated by +// HWY_RCAST_ALIGNED(StoreRaw*, aligned) with GCC +#if HWY_COMPILER_GCC + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4649, ignored "-Wignored-attributes") +#endif + + using StoreRaw = typename detail::Raw128::AlignedRawVec; + *HWY_RCAST_ALIGNED(StoreRaw*, aligned) = reinterpret_cast(v.raw); + +#if HWY_COMPILER_GCC + HWY_DIAGNOSTICS(pop) +#endif +} + +template > +HWY_API void StoreU(Vec128 v, D /* tag */, T* HWY_RESTRICT p) { + using StoreRaw = typename detail::Raw128::UnalignedRawVec; + *reinterpret_cast(p) = reinterpret_cast(v.raw); +} + +template > +HWY_API void Store(VFromD v, D d, T* HWY_RESTRICT p) { + using BitsT = UnsignedFromSize; + + const Repartition d_bits; + const BitsT bits = GetLane(BitCast(d_bits, v)); + CopyBytes(&bits, p); +} + +// For < 128 bit, StoreU == Store. 
+template > +HWY_API void StoreU(VFromD v, D d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +#if (HWY_PPC_HAVE_9 && HWY_ARCH_PPC_64) || HWY_S390X_HAVE_Z14 + +#ifdef HWY_NATIVE_STORE_N +#undef HWY_NATIVE_STORE_N +#else +#define HWY_NATIVE_STORE_N +#endif + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(max_lanes_to_store) && max_lanes_to_store == 0) { + return; + } + + if (__builtin_constant_p(max_lanes_to_store >= HWY_MAX_LANES_D(D)) && + max_lanes_to_store >= HWY_MAX_LANES_D(D)) { + StoreU(v, d, p); + return; + } +#endif + + const size_t num_of_bytes_to_store = + HWY_MIN(max_lanes_to_store, HWY_MAX_LANES_D(D)) * sizeof(TFromD); + const Repartition du8; +#if HWY_S390X_HAVE_Z14 + if (num_of_bytes_to_store > 0) { + vec_store_len(BitCast(du8, v).raw, reinterpret_cast(p), + static_cast(num_of_bytes_to_store - 1)); + } +#else + vec_xst_len(BitCast(du8, v).raw, reinterpret_cast(p), + num_of_bytes_to_store); +#endif +} +#endif + +// ------------------------------ BlendedStore + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const VFromD old = LoadU(d, p); + StoreU(IfThenElse(RebindMask(d, m), v, old), d, p); +} + +// ================================================== ARITHMETIC + +namespace detail { +// If TFromD is an integer type, detail::RebindToUnsignedIfNotFloat +// rebinds D to MakeUnsigned>. + +// Otherwise, if TFromD is a floating-point type (including F16 and BF16), +// detail::RebindToUnsignedIfNotFloat is the same as D. 
+template +using RebindToUnsignedIfNotFloat = + hwy::If<(!hwy::IsFloat>() && !hwy::IsSpecialFloat>()), + RebindToUnsigned, D>; +} // namespace detail + +// ------------------------------ Addition + +template +HWY_API Vec128 operator+(Vec128 a, Vec128 b) { + const DFromV d; + const detail::RebindToUnsignedIfNotFloat d_arith; + + // If T is an integer type, do an unsigned vec_add to avoid undefined behavior +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VFromD{BitCast(d_arith, a).raw + + BitCast(d_arith, b).raw}); +#else + return BitCast(d, VFromD{vec_add( + BitCast(d_arith, a).raw, BitCast(d_arith, b).raw)}); +#endif +} + +// ------------------------------ Subtraction + +template +HWY_API Vec128 operator-(Vec128 a, Vec128 b) { + const DFromV d; + const detail::RebindToUnsignedIfNotFloat d_arith; + + // If T is an integer type, do an unsigned vec_sub to avoid undefined behavior +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VFromD{BitCast(d_arith, a).raw - + BitCast(d_arith, b).raw}); +#else + return BitCast(d, VFromD{vec_sub( + BitCast(d_arith, a).raw, BitCast(d_arith, b).raw)}); +#endif +} + +// ------------------------------ SumsOf8 +template )> +HWY_API VFromD>> SumsOf8(V v) { + return SumsOf2(SumsOf4(v)); +} + +template )> +HWY_API VFromD>> SumsOf8(V v) { +#if HWY_S390X_HAVE_Z14 + const DFromV di8; + const RebindToUnsigned du8; + const RepartitionToWideX3 di64; + + return BitCast(di64, SumsOf8(BitCast(du8, Xor(v, SignBit(di8))))) + + Set(di64, int64_t{-1024}); +#else + return SumsOf2(SumsOf4(v)); +#endif +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. 
+ +#if HWY_S390X_HAVE_Z14 +// Z14/Z15/Z16 does not have I8/U8/I16/U16 SaturatedAdd instructions unlike most +// other integer SIMD instruction sets + +template +HWY_API Vec128 SaturatedAdd(Vec128 a, Vec128 b) { + return Add(a, Min(b, Not(a))); +} + +template +HWY_API Vec128 SaturatedAdd(Vec128 a, Vec128 b) { + const DFromV d; + const auto sum = Add(a, b); + const auto overflow_mask = AndNot(Xor(a, b), Xor(a, sum)); + const auto overflow_result = Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, sum); +} + +#else // VSX + +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB +#undef HWY_NATIVE_U32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U32_SATURATED_ADDSUB +#endif + +template +HWY_API Vec128 SaturatedAdd(Vec128 a, Vec128 b) { + return Vec128{vec_adds(a.raw, b.raw)}; +} +#endif // HWY_S390X_HAVE_Z14 + +#if HWY_PPC_HAVE_10 + +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +template )> +HWY_API V SaturatedAdd(V a, V b) { + const DFromV d; + const auto sum = Add(a, b); + const auto overflow_mask = + BroadcastSignBit(detail::TernaryLogic<0x42>(a, b, sum)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, sum); +} + +#endif // HWY_PPC_HAVE_10 + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. 
+ +#if HWY_S390X_HAVE_Z14 +// Z14/Z15/Z16 does not have I8/U8/I16/U16 SaturatedSub instructions unlike most +// other integer SIMD instruction sets + +template +HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { + return Sub(a, Min(a, b)); +} + +template +HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { + const DFromV d; + const auto diff = Sub(a, b); + const auto overflow_mask = And(Xor(a, b), Xor(a, diff)); + const auto overflow_result = Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, diff); +} + +#else // VSX + +template +HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { + return Vec128{vec_subs(a.raw, b.raw)}; +} +#endif // HWY_S390X_HAVE_Z14 + +#if HWY_PPC_HAVE_10 + +template )> +HWY_API V SaturatedSub(V a, V b) { + const DFromV d; + const auto diff = Sub(a, b); + const auto overflow_mask = + BroadcastSignBit(detail::TernaryLogic<0x18>(a, b, diff)); + const auto overflow_result = + Xor(BroadcastSignBit(a), Set(d, LimitsMax())); + return IfNegativeThenElse(overflow_mask, overflow_result, diff); +} + +#endif // HWY_PPC_HAVE_10 + +// ------------------------------ AverageRound + +// Returns (a + b + 1) / 2 + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#if HWY_S390X_HAVE_Z14 +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +#define HWY_PPC_IF_AVERAGE_ROUND_T(T) void* = nullptr +#else // !HWY_S390X_HAVE_Z14 +#define HWY_PPC_IF_AVERAGE_ROUND_T(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4)) +#endif // HWY_S390X_HAVE_Z14 + +template +HWY_API Vec128 AverageRound(Vec128 a, Vec128 b) { + return Vec128{vec_avg(a.raw, b.raw)}; +} + +#undef HWY_PPC_IF_AVERAGE_ROUND_T + +// ------------------------------ Multiplication + +// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*. 
+#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + const DFromV d; + const detail::RebindToUnsignedIfNotFloat d_arith; + + // If T is an integer type, do an unsigned vec_mul to avoid undefined behavior +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VFromD{BitCast(d_arith, a).raw * + BitCast(d_arith, b).raw}); +#else + return BitCast(d, VFromD{vec_mul( + BitCast(d_arith, a).raw, BitCast(d_arith, b).raw)}); +#endif +} + +// Returns the upper sizeof(T)*8 bits of a * b in each lane. + +#if HWY_S390X_HAVE_Z14 +#define HWY_PPC_IF_MULHIGH_USING_VEC_MULH(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4)) +#define HWY_PPC_IF_MULHIGH_8_16_32_NOT_USING_VEC_MULH(T) \ + hwy::EnableIf()>* = nullptr +#elif HWY_PPC_HAVE_10 +#define HWY_PPC_IF_MULHIGH_USING_VEC_MULH(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 4) | (1 << 8)) +#define HWY_PPC_IF_MULHIGH_8_16_32_NOT_USING_VEC_MULH(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2)) +#else +#define HWY_PPC_IF_MULHIGH_USING_VEC_MULH(T) \ + hwy::EnableIf()>* = nullptr +#define HWY_PPC_IF_MULHIGH_8_16_32_NOT_USING_VEC_MULH(T) \ + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4)) +#endif + +#if HWY_S390X_HAVE_Z14 || HWY_PPC_HAVE_10 +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + return Vec128{vec_mulh(a.raw, b.raw)}; +} +#endif + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + const auto p_even = MulEven(a, b); + +#if HWY_IS_LITTLE_ENDIAN + const auto p_even_full = ResizeBitCast(Full128(), p_even); + return Vec128{ + vec_sld(p_even_full.raw, p_even_full.raw, 16 - sizeof(T))}; +#else + const DFromV d; + return ResizeBitCast(d, p_even); +#endif +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + const DFromV d; + + const auto p_even = BitCast(d, MulEven(a, b)); + const auto p_odd = BitCast(d, MulOdd(a, b)); + 
+#if HWY_IS_LITTLE_ENDIAN + return InterleaveOdd(d, p_even, p_odd); +#else + return InterleaveEven(d, p_even, p_odd); +#endif +} + +#if !HWY_PPC_HAVE_10 +template +HWY_API Vec64 MulHigh(Vec64 a, Vec64 b) { + T p_hi; + Mul128(GetLane(a), GetLane(b), &p_hi); + return Set(Full64(), p_hi); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + const DFromV d; + const Half dh; + return Combine(d, MulHigh(UpperHalf(dh, a), UpperHalf(dh, b)), + MulHigh(LowerHalf(dh, a), LowerHalf(dh, b))); +} +#endif // !HWY_PPC_HAVE_10 + +#undef HWY_PPC_IF_MULHIGH_USING_VEC_MULH +#undef HWY_PPC_IF_MULHIGH_8_16_32_NOT_USING_VEC_MULH + +// Multiplies even lanes (0, 2, ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +template +HWY_API Vec128, (N + 1) / 2> MulEven(Vec128 a, + Vec128 b) { + return Vec128, (N + 1) / 2>{vec_mule(a.raw, b.raw)}; +} + +// Multiplies odd lanes (1, 3, ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. 
+template +HWY_API Vec128, (N + 1) / 2> MulOdd(Vec128 a, + Vec128 b) { + return Vec128, (N + 1) / 2>{vec_mulo(a.raw, b.raw)}; +} + +// ------------------------------ Rol/Ror + +#ifdef HWY_NATIVE_ROL_ROR_8 +#undef HWY_NATIVE_ROL_ROR_8 +#else +#define HWY_NATIVE_ROL_ROR_8 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_16 +#undef HWY_NATIVE_ROL_ROR_16 +#else +#define HWY_NATIVE_ROL_ROR_16 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_32_64 +#undef HWY_NATIVE_ROL_ROR_32_64 +#else +#define HWY_NATIVE_ROL_ROR_32_64 +#endif + +template +HWY_API Vec128 Rol(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast( + d, VFromD{vec_rl(BitCast(du, a).raw, BitCast(du, b).raw)}); +} + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToSigned di; + return Rol(a, BitCast(d, Neg(BitCast(di, b)))); +} + +// ------------------------------ RotateRight +template +HWY_API Vec128 RotateRight(const Vec128 v) { + const DFromV d; + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + + return (kBits == 0) + ? 
v + : Rol(v, Set(d, static_cast(static_cast(kSizeInBits) - + kBits))); +} + +// ------------------------------ RotateLeftSame/RotateRightSame +#ifdef HWY_NATIVE_ROL_ROR_SAME_8 +#undef HWY_NATIVE_ROL_ROR_SAME_8 +#else +#define HWY_NATIVE_ROL_ROR_SAME_8 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_SAME_16 +#undef HWY_NATIVE_ROL_ROR_SAME_16 +#else +#define HWY_NATIVE_ROL_ROR_SAME_16 +#endif + +#ifdef HWY_NATIVE_ROL_ROR_SAME_32_64 +#undef HWY_NATIVE_ROL_ROR_SAME_32_64 +#else +#define HWY_NATIVE_ROL_ROR_SAME_32_64 +#endif + +template +HWY_API Vec128 RotateLeftSame(Vec128 v, int bits) { + const DFromV d; + return Rol(v, Set(d, static_cast(static_cast(bits)))); +} + +template +HWY_API Vec128 RotateRightSame(Vec128 v, int bits) { + const DFromV d; + return Rol(v, Set(d, static_cast(0u - static_cast(bits)))); +} + +// ------------------------------ IfNegativeThenElse + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + + const DFromV d; +#if HWY_PPC_HAVE_10 + const RebindToUnsigned du; + return BitCast( + d, VFromD{vec_blendv( + BitCast(du, no).raw, BitCast(du, yes).raw, BitCast(du, v).raw)}); +#else + const RebindToSigned di; + return IfVecThenElse(BitCast(d, BroadcastSignBit(BitCast(di, v))), yes, no); +#endif +} + +#if HWY_PPC_HAVE_10 +#ifdef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#undef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#else +#define HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#endif + +#ifdef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#undef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#else +#define HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#endif + +template +HWY_API V IfNegativeThenElseZero(V v, V yes) { + const DFromV d; + return IfNegativeThenElse(v, yes, Zero(d)); +} + +template +HWY_API V IfNegativeThenZeroElse(V v, V no) { + const DFromV d; + return IfNegativeThenElse(v, Zero(d), no); +} +#endif + +// generic_ops takes care of integer T. 
+template +HWY_API Vec128 AbsDiff(Vec128 a, Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +// Returns mul * x + add +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return Vec128{vec_madd(mul.raw, x.raw, add.raw)}; +} + +// Returns add - mul * x +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + // NOTE: the vec_nmsub operation below computes -(mul * x - add), + // which is equivalent to add - mul * x in the round-to-nearest + // and round-towards-zero rounding modes + return Vec128{vec_nmsub(mul.raw, x.raw, add.raw)}; +} + +// Returns mul * x - sub +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return Vec128{vec_msub(mul.raw, x.raw, sub.raw)}; +} + +// Returns -mul * x - sub +template +HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + // NOTE: The vec_nmadd operation below computes -(mul * x + sub), + // which is equivalent to -mul * x - sub in the round-to-nearest + // and round-towards-zero rounding modes + return Vec128{vec_nmadd(mul.raw, x.raw, sub.raw)}; +} + +// ------------------------------ Floating-point div +// Approximate reciprocal + +#ifdef HWY_NATIVE_F64_APPROX_RECIP +#undef HWY_NATIVE_F64_APPROX_RECIP +#else +#define HWY_NATIVE_F64_APPROX_RECIP +#endif + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { +#if HWY_S390X_HAVE_Z14 + return Vec128{a.raw / b.raw}; +#else + return Vec128{vec_div(a.raw, b.raw)}; +#endif +} + +template +HWY_API Vec128 ApproximateReciprocal(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const DFromV d; + return Set(d, T(1.0)) / v; +#else + return Vec128{vec_re(v.raw)}; +#endif +} + +// ------------------------------ Floating-point square root + +#if HWY_S390X_HAVE_Z14 +// Approximate reciprocal square root +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + + const auto half = v * Set(d, 0.5f); + // Initial guess 
based on log2(f) + const auto guess = BitCast( + d, Set(du, uint32_t{0x5F3759DFu}) - ShiftRight<1>(BitCast(du, v))); + // One Newton-Raphson iteration + return guess * NegMulAdd(half * guess, guess, Set(d, 1.5f)); +} +#else // VSX + +#ifdef HWY_NATIVE_F64_APPROX_RSQRT +#undef HWY_NATIVE_F64_APPROX_RSQRT +#else +#define HWY_NATIVE_F64_APPROX_RSQRT +#endif + +// Approximate reciprocal square root +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + return Vec128{vec_rsqrte(v.raw)}; +} +#endif // HWY_S390X_HAVE_Z14 + +// Full precision square root +template +HWY_API Vec128 Sqrt(Vec128 v) { + return Vec128{vec_sqrt(v.raw)}; +} + +// ------------------------------ GetBiasedExponent + +#if HWY_PPC_HAVE_9 + +#ifdef HWY_NATIVE_GET_BIASED_EXPONENT +#undef HWY_NATIVE_GET_BIASED_EXPONENT +#else +#define HWY_NATIVE_GET_BIASED_EXPONENT +#endif + +template +HWY_API VFromD>> GetBiasedExponent(V v) { + return VFromD>>{vec_extract_exp(v.raw)}; +} + +#endif // HWY_PPC_HAVE_9 + +// ------------------------------ Min (Gt, IfThenElse) + +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{vec_min(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{vec_max(a.raw, b.raw)}; +} + +// ------------------------------- Integer AbsDiff for PPC9/PPC10 + +#if HWY_PPC_HAVE_9 +#ifdef HWY_NATIVE_INTEGER_ABS_DIFF +#undef HWY_NATIVE_INTEGER_ABS_DIFF +#else +#define HWY_NATIVE_INTEGER_ABS_DIFF +#endif + +template +HWY_API V AbsDiff(const V a, const V b) { + return V{vec_absd(a.raw, b.raw)}; +} + +template )> +HWY_API V AbsDiff(const V a, const V b) { + return Sub(Max(a, b), Min(a, b)); +} + +template +HWY_API V AbsDiff(const V a, const V b) { + return Sub(Max(a, b), Min(a, b)); +} + +#endif // HWY_PPC_HAVE_9 + +// ------------------------------ Integer Div for PPC10 +#if HWY_PPC_HAVE_10 +#ifdef HWY_NATIVE_INT_DIV +#undef HWY_NATIVE_INT_DIV +#else +#define HWY_NATIVE_INT_DIV +#endif 
+ +template +HWY_API Vec128 operator/(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_div for I32 Div on PPC10 to avoid + // undefined behavior if b[i] == 0 or + // (a[i] == LimitsMin() && b[i] == -1) + + // Clang will also optimize out I32 vec_div on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector signed int raw_result; + __asm__("vdivsw %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_div for U32 Div on PPC10 to avoid + // undefined behavior if b[i] == 0 + + // Clang will also optimize out U32 vec_div on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector unsigned int raw_result; + __asm__("vdivuw %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_div for I64 Div on PPC10 to avoid + // undefined behavior if b[i] == 0 or + // (a[i] == LimitsMin() && b[i] == -1) + + // Clang will also optimize out I64 vec_div on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector signed long long raw_result; + __asm__("vdivsd %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_div for U64 Div on PPC10 to avoid + // undefined behavior if b[i] == 0 + + // Clang will also optimize out U64 vec_div on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector 
unsigned long long raw_result; + __asm__("vdivud %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + return OrderedDemote2To(d, PromoteLowerTo(dw, a) / PromoteLowerTo(dw, b), + PromoteUpperTo(dw, a) / PromoteUpperTo(dw, b)); +} + +template +HWY_API Vec128 operator/(Vec128 a, Vec128 b) { + const DFromV d; + const Rebind, decltype(d)> dw; + return DemoteTo(d, PromoteTo(dw, a) / PromoteTo(dw, b)); +} + +template +HWY_API Vec128 operator%(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_mod for I32 Mod on PPC10 to avoid + // undefined behavior if b[i] == 0 or + // (a[i] == LimitsMin() && b[i] == -1) + + // Clang will also optimize out I32 vec_mod on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector signed int raw_result; + __asm__("vmodsw %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_mod for U32 Mod on PPC10 to avoid + // undefined behavior if b[i] == 0 + + // Clang will also optimize out U32 vec_mod on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector unsigned int raw_result; + __asm__("vmoduw %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_mod for I64 Mod on PPC10 to avoid + // undefined behavior if b[i] == 0 or + // (a[i] == LimitsMin() && b[i] == -1) + + // Clang will also optimize out I64 vec_mod on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of 
a partial vector) + __vector signed long long raw_result; + __asm__("vmodsd %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(Vec128 a, + Vec128 b) { + // Inline assembly is used instead of vec_mod for U64 Mod on PPC10 to avoid + // undefined behavior if b[i] == 0 + + // Clang will also optimize out U64 vec_mod on PPC10 if optimizations are + // enabled and any of the lanes of b are known to be zero (even in the unused + // lanes of a partial vector) + __vector unsigned long long raw_result; + __asm__("vmodud %0,%1,%2" : "=v"(raw_result) : "v"(a.raw), "v"(b.raw)); + return Vec128{raw_result}; +} + +template +HWY_API Vec128 operator%(Vec128 a, Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + return OrderedDemote2To(d, PromoteLowerTo(dw, a) % PromoteLowerTo(dw, b), + PromoteUpperTo(dw, a) % PromoteUpperTo(dw, b)); +} + +template +HWY_API Vec128 operator%(Vec128 a, Vec128 b) { + const DFromV d; + const Rebind, decltype(d)> dw; + return DemoteTo(d, PromoteTo(dw, a) % PromoteTo(dw, b)); +} +#endif + +// ================================================== MEMORY (3) + +// ------------------------------ Non-temporal stores + +template +HWY_API void Stream(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + __builtin_prefetch(aligned, 1, 0); + Store(v, d, aligned); +} + +// ------------------------------ Scatter in generic_ops-inl.h +// ------------------------------ Gather in generic_ops-inl.h + +// ================================================== SWIZZLE (2) + +// ------------------------------ LowerHalf + +// Returns upper/lower half of a vector. 
+template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return VFromD{v.raw}; +} +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return Vec128{v.raw}; +} + +// ------------------------------ ShiftLeftBytes + +// NOTE: The ShiftLeftBytes operation moves the elements of v to the right +// by kBytes bytes and zeroes out the first kBytes bytes of v on both +// little-endian and big-endian PPC targets +// (same behavior as the HWY_EMU128 ShiftLeftBytes operation on both +// little-endian and big-endian targets) + +template +HWY_API VFromD ShiftLeftBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + if (kBytes == 0) return v; + const auto zeros = Zero(d); +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_sld(v.raw, zeros.raw, kBytes)}; +#else + return VFromD{vec_sld(zeros.raw, v.raw, (-kBytes) & 15)}; +#endif +} + +template +HWY_API Vec128 ShiftLeftBytes(Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +// NOTE: The ShiftLeftLanes operation moves the elements of v to the right +// by kLanes lanes and zeroes out the first kLanes lanes of v on both +// little-endian and big-endian PPC targets +// (same behavior as the HWY_EMU128 ShiftLeftLanes operation on both +// little-endian and big-endian targets) + +template > +HWY_API VFromD ShiftLeftLanes(D d, VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes + +// NOTE: The ShiftRightBytes operation moves the elements of v to the left +// by kBytes bytes and zeroes out the last kBytes bytes of v on both +// little-endian and big-endian PPC targets +// (same behavior as the HWY_EMU128 ShiftRightBytes operation on both +// little-endian and big-endian targets) + +template +HWY_API VFromD ShiftRightBytes(D d, VFromD v) { + static_assert(0 <= 
kBytes && kBytes <= 16, "Invalid kBytes"); + if (kBytes == 0) return v; + + // For partial vectors, clear upper lanes so we shift in zeros. + if (d.MaxBytes() != 16) { + const Full128> dfull; + VFromD vfull{v.raw}; + v = VFromD{IfThenElseZero(FirstN(dfull, MaxLanes(d)), vfull).raw}; + } + + const auto zeros = Zero(d); +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_sld(zeros.raw, v.raw, (-kBytes) & 15)}; +#else + return VFromD{vec_sld(v.raw, zeros.raw, kBytes)}; +#endif +} + +// ------------------------------ ShiftRightLanes + +// NOTE: The ShiftRightLanes operation moves the elements of v to the left +// by kLanes lanes and zeroes out the last kLanes lanes of v on both +// little-endian and big-endian PPC targets +// (same behavior as the HWY_EMU128 ShiftRightLanes operation on both +// little-endian and big-endian targets) + +template +HWY_API VFromD ShiftRightLanes(D d, VFromD v) { + const Repartition d8; + constexpr size_t kBytes = kLanes * sizeof(TFromD); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + return LowerHalf(d, ShiftRightBytes(Twice(), v)); +} + +// ------------------------------ ExtractLane +template +HWY_API T ExtractLane(Vec128 v, size_t i) { + return static_cast(v.raw[i]); +} + +// ------------------------------ InsertLane +template +HWY_API Vec128 InsertLane(Vec128 v, size_t i, T t) { +#if HWY_IS_LITTLE_ENDIAN + typename detail::Raw128::type raw_result = v.raw; + raw_result[i] = BitCastScalar::RawT>(t); + return Vec128{raw_result}; +#else + // On ppc64be without this, mul_test fails, but swizzle_test passes. 
+ DFromV d; + alignas(16) T lanes[16 / sizeof(T)]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +#endif +} + +// ------------------------------ CombineShiftRightBytes + +// NOTE: The CombineShiftRightBytes operation below moves the elements of lo to +// the left by kBytes bytes and moves the elements of hi right by (d.MaxBytes() +// - kBytes) bytes on both little-endian and big-endian PPC targets. + +template > +HWY_API Vec128 CombineShiftRightBytes(D /*d*/, Vec128 hi, Vec128 lo) { + constexpr size_t kSize = 16; + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); +#if HWY_IS_LITTLE_ENDIAN + return Vec128{vec_sld(hi.raw, lo.raw, (-kBytes) & 15)}; +#else + return Vec128{vec_sld(lo.raw, hi.raw, kBytes)}; +#endif +} + +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + constexpr size_t kSize = d.MaxBytes(); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + using V8 = Vec128; + const DFromV dfull8; + const Repartition, decltype(dfull8)> dfull; + const V8 hi8{BitCast(d8, hi).raw}; + // Move into most-significant bytes + const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); + const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(dfull8, hi8, lo8); + return VFromD{BitCast(dfull, r).raw}; +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec128 Broadcast(Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{vec_splat(v.raw, kLane)}; +} + +// ------------------------------ TableLookupLanes (Shuffle01) + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. 
+template +struct Indices128 { + __vector unsigned char raw; +}; + +namespace detail { + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + return Iota(d8, 0); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + constexpr __vector unsigned char kBroadcastLaneBytes = { + 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}; +#else + constexpr __vector unsigned char kBroadcastLaneBytes = { + 1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}; +#endif + return VFromD{kBroadcastLaneBytes}; +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + constexpr __vector unsigned char kBroadcastLaneBytes = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; +#else + constexpr __vector unsigned char kBroadcastLaneBytes = { + 3, 3, 3, 3, 7, 7, 7, 7, 11, 11, 11, 11, 15, 15, 15, 15}; +#endif + return VFromD{kBroadcastLaneBytes}; +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + constexpr __vector unsigned char kBroadcastLaneBytes = { + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; +#else + constexpr __vector unsigned char kBroadcastLaneBytes = { + 7, 7, 7, 7, 7, 7, 7, 7, 15, 15, 15, 15, 15, 15, 15, 15}; +#endif + return VFromD{kBroadcastLaneBytes}; +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + return Zero(d8); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + constexpr __vector unsigned char kByteOffsets = {0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1}; + return VFromD{kByteOffsets}; +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + constexpr __vector unsigned char kByteOffsets = {0, 1, 2, 3, 0, 
1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3}; + return VFromD{kByteOffsets}; +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + constexpr __vector unsigned char kByteOffsets = {0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7}; + return VFromD{kByteOffsets}; +} + +} // namespace detail + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + const Repartition d8; + return Indices128, MaxLanes(D())>{BitCast(d8, vec).raw}; +} + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + const Repartition d8; + using V8 = VFromD; + + // Broadcast each lane index to all bytes of T and shift to bytes + const V8 lane_indices = TableLookupBytes( + BitCast(d8, vec), detail::IndicesFromVecBroadcastLaneBytes(d)); + constexpr int kIndexShiftAmt = static_cast(FloorLog2(sizeof(T))); + const V8 byte_indices = ShiftLeft(lane_indices); + const V8 sum = Add(byte_indices, detail::IndicesFromVecByteOffsets(d)); + return Indices128, MaxLanes(D())>{sum.raw}; +} + +template +HWY_API Indices128, HWY_MAX_LANES_D(D)> SetTableIndices( + D d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + const DFromV d; + const Repartition d8; + return BitCast(d, TableLookupBytes(v, VFromD{idx.raw})); +} + +// Single lane: no change +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + 
Indices128 /* idx */) { + return v; +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + const Twice dt; + const Repartition dt_u8; +// TableLookupLanes currently requires table and index vectors to be the same +// size, though a half-length index vector would be sufficient here. +#if HWY_IS_MSAN + const Vec128 idx_vec{idx.raw}; + const Indices128 idx2{Combine(dt, idx_vec, idx_vec).raw}; +#else + // We only keep LowerHalf of the result, which is valid in idx. + const Indices128 idx2{idx.raw}; +#endif + return LowerHalf( + d, TableLookupBytes(Combine(dt, b, a), + BitCast(dt, VFromD{idx2.raw}))); +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + return Vec128{vec_perm(a.raw, b.raw, idx.raw)}; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return v; +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301) + +// Single lane: no change +template , HWY_IF_LANES_D(D, 1)> +HWY_API Vec128 Reverse(D /* tag */, Vec128 v) { + return v; +} + +// 32-bit x2: shuffle +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec64 Reverse(D /* tag */, Vec64 v) { + return Vec64{Shuffle2301(Vec128{v.raw}).raw}; +} + +// 16-bit x4: shuffle +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec64 Reverse(D /* tag */, Vec64 v) { + const __vector unsigned char kShuffle = {6, 7, 4, 5, 2, 3, 0, 1, + 14, 15, 12, 13, 10, 11, 8, 9}; + return Vec64{vec_perm(v.raw, v.raw, kShuffle)}; +} + +// 16-bit x2: rotate bytes +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec32 Reverse(D d, Vec32 v) { + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +} + +// ------------------------------- ReverseLaneBytes + +#if (HWY_PPC_HAVE_9 || HWY_S390X_HAVE_Z14) && \ + ((!HWY_S390X_HAVE_Z14 && HWY_COMPILER_GCC_ACTUAL >= 710) || \ + (HWY_S390X_HAVE_Z14 && 
HWY_COMPILER_GCC_ACTUAL >= 900) || \ + HWY_COMPILER_CLANG >= 400) + +// Per-target flag to prevent generic_ops-inl.h defining 8-bit ReverseLaneBytes. +#ifdef HWY_NATIVE_REVERSE_LANE_BYTES +#undef HWY_NATIVE_REVERSE_LANE_BYTES +#else +#define HWY_NATIVE_REVERSE_LANE_BYTES +#endif + +template +HWY_API V ReverseLaneBytes(V v) { + return V{vec_revb(v.raw)}; +} + +// Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. +#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API VFromD Reverse2(D d, VFromD v) { + const Repartition du16; + return BitCast(d, ReverseLaneBytes(BitCast(du16, v))); +} + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API VFromD Reverse4(D d, VFromD v) { + const Repartition du32; + return BitCast(d, ReverseLaneBytes(BitCast(du32, v))); +} + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API VFromD Reverse8(D d, VFromD v) { + const Repartition du64; + return BitCast(d, ReverseLaneBytes(BitCast(du64, v))); +} + +#endif // HWY_PPC_HAVE_9 || HWY_S390X_HAVE_Z14 + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec16 Reverse(D d, Vec16 v) { + return Reverse2(d, v); +} + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 Reverse(D d, Vec32 v) { + return Reverse4(d, v); +} + +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec64 Reverse(D d, Vec64 v) { + return Reverse8(d, v); +} + +// ------------------------------ Reverse2 + +// Single lane: no change +template , HWY_IF_LANES_D(D, 1)> +HWY_API Vec128 Reverse2(D /* tag */, Vec128 v) { + return v; +} + +template , HWY_IF_T_SIZE(T, 2)> +HWY_API VFromD Reverse2(D d, VFromD v) { + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +} + +template , HWY_IF_T_SIZE(T, 4)> +HWY_API VFromD Reverse2(D d, VFromD v) { + const Repartition du64; + return BitCast(d, RotateRight<32>(BitCast(du64, v))); +} + +template , HWY_IF_T_SIZE(T, 8)> +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return Shuffle01(v); +} 
+ +// ------------------------------ Reverse4 + +template +HWY_API VFromD Reverse4(D /*d*/, VFromD v) { + const __vector unsigned char kShuffle = {6, 7, 4, 5, 2, 3, 0, 1, + 14, 15, 12, 13, 10, 11, 8, 9}; + return VFromD{vec_perm(v.raw, v.raw, kShuffle)}; +} + +template +HWY_API VFromD Reverse4(D d, VFromD v) { + return Reverse(d, v); +} + +template +HWY_API VFromD Reverse4(D /* tag */, VFromD /* v */) { + HWY_ASSERT(0); // don't have 4 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API VFromD Reverse8(D d, VFromD v) { + return Reverse(d, v); +} + +template +HWY_API VFromD Reverse8(D /* tag */, VFromD /* v */) { + HWY_ASSERT(0); // don't have 8 lanes if larger than 16-bit +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{vec_mergeh(a.raw, b.raw)}; +} + +// Additional overload for the optional tag +template +HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// Full +template > +HWY_API Vec128 InterleaveUpper(D /* tag */, Vec128 a, Vec128 b) { + return Vec128{vec_mergel(a.raw, b.raw)}; +} + +// Partial +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const Half d2; + return InterleaveLower(d, VFromD{UpperHalf(d2, a).raw}, + VFromD{UpperHalf(d2, b).raw}); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. 
+template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ------------------------------ Per4LaneBlkShufDupSet4xU32 + +// Used by hwy/ops/generic_ops-inl.h to implement Per4LaneBlockShuffle +namespace detail { + +#ifdef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#undef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#else +#define HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#endif + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { + const __vector unsigned int raw = {x0, x1, x2, x3}; + return ResizeBitCast(d, Vec128{raw}); +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { + const Repartition du8; + using VU8 = VFromD; + const auto v_shift_amt = + BitCast(Full128(), + Set(Full128(), + static_cast(amt * sizeof(TFromD) * 8))); + +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU8{vec_srb(BitCast(du8, v).raw, v_shift_amt.raw)}); +#else // VSX +#if HWY_IS_LITTLE_ENDIAN + return BitCast(d, VU8{vec_slo(BitCast(du8, v).raw, v_shift_amt.raw)}); +#else + return BitCast(d, VU8{vec_sro(BitCast(du8, v).raw, v_shift_amt.raw)}); +#endif // HWY_IS_LITTLE_ENDIAN +#endif // HWY_S390X_HAVE_Z14 +} + +// ------------------------------ SlideDownLanes + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { + using TU = UnsignedFromSize; + const Repartition du; + const auto v_shift_amt = + Set(du, static_cast(amt * sizeof(TFromD) * 8)); + +#if HWY_IS_LITTLE_ENDIAN + return BitCast(d, BitCast(du, v) >> v_shift_amt); +#else + return BitCast(d, BitCast(du, v) << v_shift_amt); +#endif +} + +template 
+HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { + const Repartition du8; + using VU8 = VFromD; + const auto v_shift_amt = + BitCast(Full128(), + Set(Full128(), + static_cast(amt * sizeof(TFromD) * 8))); + +#if HWY_S390X_HAVE_Z14 + return BitCast(d, VU8{vec_slb(BitCast(du8, v).raw, v_shift_amt.raw)}); +#else // VSX +#if HWY_IS_LITTLE_ENDIAN + return BitCast(d, VU8{vec_sro(BitCast(du8, v).raw, v_shift_amt.raw)}); +#else + return BitCast(d, VU8{vec_slo(BitCast(du8, v).raw, v_shift_amt.raw)}); +#endif // HWY_IS_LITTLE_ENDIAN +#endif // HWY_S390X_HAVE_Z14 +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template >> +HWY_API VFromD Combine(D d, VH hi_half, VH lo_half) { + const Half dh; + // Treat half-width input as one lane, and expand to two lanes. + using VU = Vec128, 2>; + using Raw = typename detail::Raw128>::type; + const VU lo{reinterpret_cast(lo_half.raw)}; + const VU hi{reinterpret_cast(hi_half.raw)}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) + +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const Half dh; + return IfThenElseZero(FirstN(d, MaxLanes(dh)), VFromD{lo.raw}); +} + +// ------------------------------ Concat full (InterleaveLower) + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template > +HWY_API Vec128 ConcatLowerLower(D d, Vec128 hi, Vec128 lo) { + const Repartition d64; + return BitCast(d, InterleaveLower(BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template > +HWY_API Vec128 ConcatUpperUpper(D d, Vec128 hi, Vec128 lo) { + const Repartition d64; + return BitCast(d, InterleaveUpper(d64, BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves) +template > +HWY_API Vec128 ConcatLowerUpper(D d, Vec128 hi, Vec128 lo) { + return 
CombineShiftRightBytes<8>(d, hi, lo); +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template > +HWY_API Vec128 ConcatUpperLower(D /*d*/, Vec128 hi, Vec128 lo) { + const __vector unsigned char kShuffle = {0, 1, 2, 3, 4, 5, 6, 7, + 24, 25, 26, 27, 28, 29, 30, 31}; + return Vec128{vec_perm(lo.raw, hi.raw, kShuffle)}; +} + +// ------------------------------ Concat partial (Combine, LowerHalf) + +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), LowerHalf(d2, lo)); +} + +// ------------------------------ TruncateTo + +template = sizeof(TFromD) * 2)>* = nullptr, + HWY_IF_LANES_D(D, 1)> +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + using Raw = typename detail::Raw128>::type; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{reinterpret_cast(v.raw)}; +#else + return VFromD{reinterpret_cast( + vec_sld(v.raw, v.raw, sizeof(FromT) - sizeof(TFromD)))}; +#endif +} + +namespace detail { + +template ) * 2), HWY_IF_LANES_GT_D(D, 1)> +HWY_API VFromD Truncate2To( + D /* tag */, Vec128().MaxLanes()> lo, + Vec128().MaxLanes()> hi) { + return VFromD{vec_pack(lo.raw, hi.raw)}; +} + +} // namespace detail + +template ) * 2), HWY_IF_LANES_GT_D(D, 1)> +HWY_API VFromD TruncateTo(D /* d */, + Vec128().MaxLanes()> v) { + return VFromD{vec_pack(v.raw, v.raw)}; +} + +template = sizeof(TFromD) * 4)>* = nullptr, + HWY_IF_LANES_GT_D(D, 1)> +HWY_API VFromD TruncateTo(D d, + Vec128().MaxLanes()> v) { + const Rebind, decltype(d)> d2; + return TruncateTo(d, 
TruncateTo(d2, v)); +} + +// ------------------------------ ConcatOdd (TruncateTo) + +// 8-bit full +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 ConcatOdd(D d, Vec128 hi, Vec128 lo) { + const Repartition dw; + const RebindToUnsigned du; +#if HWY_IS_LITTLE_ENDIAN + // Right-shift 8 bits per u16 so we can pack. + const Vec128 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<8>(BitCast(dw, lo)); +#else + const Vec128 uH = BitCast(dw, hi); + const Vec128 uL = BitCast(dw, lo); +#endif + return BitCast(d, detail::Truncate2To(du, uL, uH)); +} + +// 8-bit x8 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec64 ConcatOdd(D /*d*/, Vec64 hi, Vec64 lo) { + // Don't care about upper half, no need to zero. + const __vector unsigned char kCompactOddU8 = {1, 3, 5, 7, 17, 19, 21, 23}; + return Vec64{vec_perm(lo.raw, hi.raw, kCompactOddU8)}; +} + +// 8-bit x4 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 ConcatOdd(D /*d*/, Vec32 hi, Vec32 lo) { + // Don't care about upper half, no need to zero. + const __vector unsigned char kCompactOddU8 = {1, 3, 17, 19}; + return Vec32{vec_perm(lo.raw, hi.raw, kCompactOddU8)}; +} + +// 16-bit full +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec128 ConcatOdd(D d, Vec128 hi, Vec128 lo) { + const Repartition dw; + const RebindToUnsigned du; +#if HWY_IS_LITTLE_ENDIAN + const Vec128 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<16>(BitCast(dw, lo)); +#else + const Vec128 uH = BitCast(dw, hi); + const Vec128 uL = BitCast(dw, lo); +#endif + return BitCast(d, detail::Truncate2To(du, uL, uH)); +} + +// 16-bit x4 +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec64 ConcatOdd(D /*d*/, Vec64 hi, Vec64 lo) { + // Don't care about upper half, no need to zero. 
+ const __vector unsigned char kCompactOddU16 = {2, 3, 6, 7, 18, 19, 22, 23}; + return Vec64{vec_perm(lo.raw, hi.raw, kCompactOddU16)}; +} + +// 32-bit full +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec128 ConcatOdd(D d, Vec128 hi, Vec128 lo) { +#if HWY_IS_LITTLE_ENDIAN + (void)d; + const __vector unsigned char kShuffle = {4, 5, 6, 7, 12, 13, 14, 15, + 20, 21, 22, 23, 28, 29, 30, 31}; + return Vec128{vec_perm(lo.raw, hi.raw, kShuffle)}; +#else + const RebindToUnsigned du; + const Repartition dw; + return BitCast(d, detail::Truncate2To(du, BitCast(dw, lo), BitCast(dw, hi))); +#endif +} + +// Any type x2 +template , HWY_IF_LANES_D(D, 2)> +HWY_API Vec128 ConcatOdd(D d, Vec128 hi, Vec128 lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (TruncateTo) + +// 8-bit full +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 ConcatEven(D d, Vec128 hi, Vec128 lo) { + const Repartition dw; + const RebindToUnsigned du; +#if HWY_IS_LITTLE_ENDIAN + const Vec128 uH = BitCast(dw, hi); + const Vec128 uL = BitCast(dw, lo); +#else + // Right-shift 8 bits per u16 so we can pack. + const Vec128 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<8>(BitCast(dw, lo)); +#endif + return BitCast(d, detail::Truncate2To(du, uL, uH)); +} + +// 8-bit x8 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec64 ConcatEven(D /*d*/, Vec64 hi, Vec64 lo) { + // Don't care about upper half, no need to zero. + const __vector unsigned char kCompactEvenU8 = {0, 2, 4, 6, 16, 18, 20, 22}; + return Vec64{vec_perm(lo.raw, hi.raw, kCompactEvenU8)}; +} + +// 8-bit x4 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 ConcatEven(D /*d*/, Vec32 hi, Vec32 lo) { + // Don't care about upper half, no need to zero. 
+ const __vector unsigned char kCompactEvenU8 = {0, 2, 16, 18}; + return Vec32{vec_perm(lo.raw, hi.raw, kCompactEvenU8)}; +} + +// 16-bit full +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec128 ConcatEven(D d, Vec128 hi, Vec128 lo) { + // Isolate lower 16 bits per u32 so we can pack. + const Repartition dw; + const RebindToUnsigned du; +#if HWY_IS_LITTLE_ENDIAN + const Vec128 uH = BitCast(dw, hi); + const Vec128 uL = BitCast(dw, lo); +#else + const Vec128 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<16>(BitCast(dw, lo)); +#endif + return BitCast(d, detail::Truncate2To(du, uL, uH)); +} + +// 16-bit x4 +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec64 ConcatEven(D /*d*/, Vec64 hi, Vec64 lo) { + // Don't care about upper half, no need to zero. + const __vector unsigned char kCompactEvenU16 = {0, 1, 4, 5, 16, 17, 20, 21}; + return Vec64{vec_perm(lo.raw, hi.raw, kCompactEvenU16)}; +} + +// 32-bit full +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec128 ConcatEven(D d, Vec128 hi, Vec128 lo) { +#if HWY_IS_LITTLE_ENDIAN + const Repartition dw; + const RebindToUnsigned du; + return BitCast(d, detail::Truncate2To(du, BitCast(dw, lo), BitCast(dw, hi))); +#else + (void)d; + constexpr __vector unsigned char kShuffle = {0, 1, 2, 3, 8, 9, 10, 11, + 16, 17, 18, 19, 24, 25, 26, 27}; + return Vec128{vec_perm(lo.raw, hi.raw, kShuffle)}; +#endif +} + +// Any T x2 +template , HWY_IF_LANES_D(D, 2)> +HWY_API Vec128 ConcatEven(D d, Vec128 hi, Vec128 lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ OrderedTruncate2To (ConcatEven, ConcatOdd) +#ifdef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#undef HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#else +#define HWY_NATIVE_ORDERED_TRUNCATE_2_TO +#endif + +template ) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedTruncate2To(D d, V a, V b) { +#if HWY_IS_LITTLE_ENDIAN + return ConcatEven(d, BitCast(d, b), BitCast(d, a)); +#else + return ConcatOdd(d, BitCast(d, b), BitCast(d, a)); 
+#endif +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { + return v; +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + const DFromV d; + const Repartition du8; + constexpr __vector unsigned char kShuffle = {0, 0, 2, 2, 4, 4, 6, 6, + 8, 8, 10, 10, 12, 12, 14, 14}; + return TableLookupBytes(v, BitCast(d, VFromD{kShuffle})); +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + const DFromV d; + const Repartition du8; + constexpr __vector unsigned char kShuffle = {0, 1, 0, 1, 4, 5, 4, 5, + 8, 9, 8, 9, 12, 13, 12, 13}; + return TableLookupBytes(v, BitCast(d, VFromD{kShuffle})); +} + +template +HWY_API Vec128 DupEven(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const DFromV d; + const Repartition du8; + return TableLookupBytes( + v, BitCast(d, Dup128VecFromValues(du8, 0, 1, 2, 3, 0, 1, 2, 3, 8, 9, 10, + 11, 8, 9, 10, 11))); +#else + return Vec128{vec_mergee(v.raw, v.raw)}; +#endif +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + const DFromV d; + const Repartition du8; + constexpr __vector unsigned char kShuffle = {1, 1, 3, 3, 5, 5, 7, 7, + 9, 9, 11, 11, 13, 13, 15, 15}; + return TableLookupBytes(v, BitCast(d, VFromD{kShuffle})); +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + const DFromV d; + const Repartition du8; + constexpr __vector unsigned char kShuffle = {2, 3, 2, 3, 6, 7, 6, 7, + 10, 11, 10, 11, 14, 15, 14, 15}; + return TableLookupBytes(v, BitCast(d, VFromD{kShuffle})); +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const DFromV d; + const Repartition du8; + return TableLookupBytes( + v, BitCast(d, Dup128VecFromValues(du8, 4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, + 15, 12, 13, 14, 15))); +#else + return Vec128{vec_mergeo(v.raw, v.raw)}; +#endif +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return 
InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ OddEven (IfThenElse) + +template +HWY_INLINE Vec128 OddEven(Vec128 a, Vec128 b) { + const DFromV d; + const __vector unsigned char mask = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfVecThenElse(BitCast(d, Vec128{mask}), b, a); +} + +template +HWY_INLINE Vec128 OddEven(Vec128 a, Vec128 b) { + const DFromV d; + const __vector unsigned char mask = {0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, + 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0}; + return IfVecThenElse(BitCast(d, Vec128{mask}), b, a); +} + +template +HWY_INLINE Vec128 OddEven(Vec128 a, Vec128 b) { + const DFromV d; + const __vector unsigned char mask = {0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0, + 0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0}; + return IfVecThenElse(BitCast(d, Vec128{mask}), b, a); +} + +template +HWY_INLINE Vec128 OddEven(Vec128 a, Vec128 b) { + // Same as ConcatUpperLower for full vectors; do not call that because this + // is more efficient for 64x1 vectors. 
+ const DFromV d; + const __vector unsigned char mask = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0, 0, 0, 0, 0}; + return IfVecThenElse(BitCast(d, Vec128{mask}), b, a); +} + +// ------------------------------ InterleaveEven + +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Full128> d_full; + const Indices128> idx{ + Dup128VecFromValues(Full128(), 0, 16, 2, 18, 4, 20, 6, 22, 8, 24, + 10, 26, 12, 28, 14, 30) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +} + +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Full128> d_full; + const Indices128> idx{Dup128VecFromValues(Full128(), 0, 1, + 16, 17, 4, 5, 20, 21, 8, + 9, 24, 25, 12, 13, 28, 29) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +} + +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { +#if HWY_S390X_HAVE_Z14 + const Full128> d_full; + const Indices128> idx{Dup128VecFromValues(Full128(), 0, 1, + 2, 3, 16, 17, 18, 19, 8, + 9, 10, 11, 24, 25, 26, 27) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +#else + (void)d; + return VFromD{vec_mergee(a.raw, b.raw)}; +#endif +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveOdd + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Full128> d_full; + const Indices128> idx{ + Dup128VecFromValues(Full128(), 1, 17, 3, 19, 5, 21, 7, 23, 9, 25, + 11, 27, 13, 29, 15, 31) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Full128> d_full; + const Indices128> idx{ + Dup128VecFromValues(Full128(), 2, 3, 18, 19, 6, 7, 
22, 23, 10, + 11, 26, 27, 14, 15, 30, 31) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { +#if HWY_S390X_HAVE_Z14 + const Full128> d_full; + const Indices128> idx{ + Dup128VecFromValues(Full128(), 4, 5, 6, 7, 20, 21, 22, 23, 12, + 13, 14, 15, 28, 29, 30, 31) + .raw}; + return ResizeBitCast(d, TwoTablesLookupLanes(ResizeBitCast(d_full, a), + ResizeBitCast(d_full, b), idx)); +#else + (void)d; + return VFromD{vec_mergeo(a.raw, b.raw)}; +#endif +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return InterleaveUpper(d, a, b); +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ InterleaveEvenBlocks +template > +HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveOddBlocks +template > +HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ InterleaveLowerBlocks +template > +HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveUpperBlocks +template > +HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ MulFixedPoint15 (OddEven) + +#if HWY_S390X_HAVE_Z14 +HWY_API Vec16 MulFixedPoint15(Vec16 a, Vec16 b) { + const DFromV di16; + const RepartitionToWide di32; + + const auto round_up_incr = Set(di32, 0x4000); + const auto i32_product = MulEven(a, b) + round_up_incr; + + return ResizeBitCast(di16, ShiftLeft<1>(i32_product)); +} +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + const DFromV di16; + const RepartitionToWide di32; + + const 
auto round_up_incr = Set(di32, 0x4000); + const auto even_product = MulEven(a, b) + round_up_incr; + const auto odd_product = MulOdd(a, b) + round_up_incr; + + return OddEven(BitCast(di16, ShiftRight<15>(odd_product)), + BitCast(di16, ShiftLeft<1>(even_product))); +} +#else +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + const Vec128 zero = Zero(Full128()); + return Vec128{vec_mradds(a.raw, b.raw, zero.raw)}; +} +#endif + +// ------------------------------ Shl + +namespace detail { +template +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_S390X_HAVE_Z14 + return Vec128{v.raw << bits.raw}; +#else + return Vec128{vec_sl(v.raw, bits.raw)}; +#endif +} + +// Signed left shift is the same as unsigned. +template +HWY_API Vec128 Shl(hwy::SignedTag /*tag*/, Vec128 v, + Vec128 bits) { + const DFromV di; + const RebindToUnsigned du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return detail::Shl(hwy::TypeTag(), v, bits); +} + +// ------------------------------ Shr + +namespace detail { +template +HWY_API Vec128 Shr(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_S390X_HAVE_Z14 + return Vec128{v.raw >> bits.raw}; +#else + return Vec128{vec_sr(v.raw, bits.raw)}; +#endif +} + +template +HWY_API Vec128 Shr(hwy::SignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_S390X_HAVE_Z14 + return Vec128{v.raw >> bits.raw}; +#else + const DFromV di; + const RebindToUnsigned du; + return Vec128{vec_sra(v.raw, BitCast(du, bits).raw)}; +#endif +} + +} // namespace detail + +template +HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { + return detail::Shr(hwy::TypeTag(), v, bits); +} + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +template +HWY_INLINE Vec128 MulEven(Vec128 a, Vec128 b) { +#if HWY_PPC_HAVE_10 && defined(__SIZEOF_INT128__) + using V64 = typename 
detail::Raw128::type; + const V64 mul128_result = reinterpret_cast(vec_mule(a.raw, b.raw)); +#if HWY_IS_LITTLE_ENDIAN + return Vec128{mul128_result}; +#else + // Need to swap the two halves of mul128_result on big-endian targets as + // the upper 64 bits of the product are in lane 0 of mul128_result and + // the lower 64 bits of the product are in lane 1 of mul128_result + return Vec128{vec_sld(mul128_result, mul128_result, 8)}; +#endif +#else + alignas(16) T mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(Full128(), mul); +#endif +} + +template +HWY_INLINE Vec128 MulOdd(Vec128 a, Vec128 b) { +#if HWY_PPC_HAVE_10 && defined(__SIZEOF_INT128__) + using V64 = typename detail::Raw128::type; + const V64 mul128_result = reinterpret_cast(vec_mulo(a.raw, b.raw)); +#if HWY_IS_LITTLE_ENDIAN + return Vec128{mul128_result}; +#else + // Need to swap the two halves of mul128_result on big-endian targets as + // the upper 64 bits of the product are in lane 0 of mul128_result and + // the lower 64 bits of the product are in lane 1 of mul128_result + return Vec128{vec_sld(mul128_result, mul128_result, 8)}; +#endif +#else + alignas(16) T mul[2]; + const Full64 d2; + mul[0] = + Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]); + return Load(Full128(), mul); +#endif +} + +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "hwy/ops/inside-inl.h" + +// ------------------------------ WidenMulPairwiseAdd + +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); +} + +// Even if N=1, the input is always at least 2 lanes, hence vec_msum is safe. 
+template >> +HWY_API VFromD WidenMulPairwiseAdd(D32 d32, V16 a, V16 b) { +#if HWY_S390X_HAVE_Z14 + (void)d32; + return MulEven(a, b) + MulOdd(a, b); +#else + return VFromD{vec_msum(a.raw, b.raw, Zero(d32).raw)}; +#endif +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +// Even if N=1, the input is always at least 2 lanes, hence vec_msum is safe. +template >> +HWY_API VFromD ReorderWidenMulAccumulate(D32 /*d32*/, V16 a, V16 b, + VFromD sum0, + VFromD& /*sum1*/) { +#if HWY_S390X_HAVE_Z14 + return MulEven(a, b) + MulOdd(a, b) + sum0; +#else + return VFromD{vec_msum(a.raw, b.raw, sum0.raw)}; +#endif +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW) { + return sum0; // invariant already holds +} + +// ------------------------------ SatWidenMulPairwiseAccumulate +#if !HWY_S390X_HAVE_Z14 + +#ifdef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#undef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#else +#define HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#endif + +template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 /* tag */, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{vec_msums(a.raw, b.raw, sum.raw)}; +} + +#endif // !HWY_S390X_HAVE_Z14 + +// ------------------------------ SumOfMulQuadAccumulate +#if !HWY_S390X_HAVE_Z14 + +#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#endif +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{vec_msum(a.raw, b.raw, sum.raw)}; +} + +#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 /*di32*/, VFromD> a_u, + VFromD> b_i, VFromD sum) { + // NOTE(review): operands are passed signed-first (b_i, a_u), presumably + // because the mixed-sign vec_msum overload expects that order - confirm. + return
VFromD{vec_msum(b_i.raw, a_u.raw, sum.raw)}; +} + +#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#endif +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + const Repartition du8; + + // Reinterpreting a negative a[i] as u8 adds 256 to it, which inflates the + // U8xI8 result by 256 * b[i]. result_sum_1 is 256 * (sum of the b[i] whose + // a[i] is negative) per group of four lanes; subtracting it undoes that. + const auto result_sum_0 = + SumOfMulQuadAccumulate(di32, BitCast(du8, a), b, sum); + const auto result_sum_1 = ShiftLeft<8>(SumsOf4(And(b, BroadcastSignBit(a)))); + return result_sum_0 - result_sum_1; +} + +#endif // !HWY_S390X_HAVE_Z14 + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned to signed/unsigned: zero-extend. +template +HWY_API VFromD PromoteTo(D /* d */, + Vec128().MaxLanes()> v) { + // First pretend the input has twice the lanes - the upper half will be + // ignored by ZipLower. + const Rebind> d2; + const VFromD twice{v.raw}; + // Then cast to narrow as expected by ZipLower, in case the sign of FromT + // differs from that of D. + const RepartitionToNarrow dn; + +#if HWY_IS_LITTLE_ENDIAN + return ZipLower(BitCast(dn, twice), Zero(dn)); +#else + return ZipLower(Zero(dn), BitCast(dn, twice)); +#endif +} + +// Signed: replicate sign bit. +template +HWY_API VFromD PromoteTo(D /* d */, + Vec128().MaxLanes()> v) { + using Raw = typename detail::Raw128>::type; + return VFromD{reinterpret_cast(vec_unpackh(v.raw))}; +} + +// 8-bit to 32-bit: First, promote to 16-bit, and then convert to 32-bit. +template +HWY_API VFromD PromoteTo(D d32, + Vec128().MaxLanes()> v) { + const DFromV d8; + const Rebind, decltype(d8)> d16; + return PromoteTo(d32, PromoteTo(d16, v)); +} + +// 8-bit or 16-bit to 64-bit: First, promote to MakeWide, and then +// convert to 64-bit.
+template +HWY_API VFromD PromoteTo(D d64, + Vec128().MaxLanes()> v) { + const Rebind, decltype(d64)> dw; + return PromoteTo(d64, PromoteTo(dw, v)); +} + +#if HWY_PPC_HAVE_9 + +// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. +#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +template +HWY_INLINE VFromD PromoteTo(D /*tag*/, VFromD> v) { + return VFromD{vec_extract_fp32_from_shorth(v.raw)}; +} + +#endif // HWY_PPC_HAVE_9 + +template +HWY_API VFromD PromoteTo(D df32, VFromD> v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + const __vector float raw_v = InterleaveLower(v, v).raw; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_doubleo(raw_v)}; +#elif HWY_S390X_HAVE_Z14 && HWY_COMPILER_GCC_ACTUAL && \ + HWY_COMPILER_GCC_ACTUAL < 1000 + // Workaround for compiler errors with GCC 9 or earlier on Z14 + return VFromD{__builtin_s390_vflls(raw_v)}; +#else + return VFromD{vec_doublee(raw_v)}; +#endif +} + +template +HWY_API VFromD PromoteTo(D df64, VFromD> v) { +#if HWY_S390X_HAVE_Z14 + const RebindToSigned di64; + return ConvertTo(df64, PromoteTo(di64, v)); +#else // VSX + (void)df64; + const __vector signed int raw_v = InterleaveLower(v, v).raw; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_doubleo(raw_v)}; +#else + return VFromD{vec_doublee(raw_v)}; +#endif +#endif // HWY_S390X_HAVE_Z14 +} + +template +HWY_API VFromD PromoteTo(D df64, VFromD> v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du64; + return ConvertTo(df64, PromoteTo(du64, v)); +#else // VSX + (void)df64; + const __vector unsigned int raw_v = InterleaveLower(v, v).raw; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_doubleo(raw_v)}; +#else + return VFromD{vec_doublee(raw_v)}; +#endif +#endif // HWY_S390X_HAVE_Z14 +} + +#if !HWY_S390X_HAVE_Z14 +namespace detail { + +template +static HWY_INLINE V
VsxF2INormalizeSrcVals(V v) { +#if !defined(HWY_DISABLE_PPC_VSX_QEMU_F2I_WORKAROUND) + // Workaround for QEMU 7/8 VSX float to int conversion bug + // For floating-point lanes, v == v is false only for NaN, so NaN inputs are + // replaced with zero before the conversion. + return IfThenElseZero(v == v, v); +#else + return v; +#endif +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED VFromD>> +VsxXvcvspsxds(VF32 vf32) { + using VI64 = VFromD>>; +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1500) || \ + HWY_HAS_BUILTIN(__builtin_vsx_xvcvspsxds) + // Use __builtin_vsx_xvcvspsxds if it is available (which is the case with + // GCC 4.8 through GCC 14 or Clang 13 or later on PPC8/PPC9/PPC10) + return VI64{__builtin_vsx_xvcvspsxds(vf32.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL >= 1500 && HWY_IS_LITTLE_ENDIAN + // On little-endian PPC8/PPC9/PPC10 with GCC 15 or later, use the F32->I64 + // vec_signedo intrinsic as the __builtin_vsx_xvcvspsxds intrinsic has been + // removed from GCC in GCC 15 + return VI64{vec_signedo(vf32.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL >= 1500 && HWY_IS_BIG_ENDIAN + // On big-endian PPC8/PPC9/PPC10 with GCC 15 or later, use the F32->I64 + // vec_signede intrinsic as the __builtin_vsx_xvcvspsxds intrinsic has been + // removed from GCC in GCC 15 + return VI64{vec_signede(vf32.raw)}; +#else + // Inline assembly fallback for older versions of Clang that do not have the + // __builtin_vsx_xvcvspsxds intrinsic + __vector signed long long raw_result; + __asm__("xvcvspsxds %x0, %x1" : "=wa"(raw_result) : "wa"(vf32.raw) :); + return VI64{raw_result}; +#endif +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED VFromD>> +VsxXvcvspuxds(VF32 vf32) { + using VU64 = VFromD>>; +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1500) || \ + HWY_HAS_BUILTIN(__builtin_vsx_xvcvspuxds) + // Use __builtin_vsx_xvcvspuxds if it is available (which is the case with + // GCC 4.8 through GCC 14 or Clang 13 or later on PPC8/PPC9/PPC10) + return VU64{reinterpret_cast<__vector unsigned long long>( + __builtin_vsx_xvcvspuxds(vf32.raw))}; +#elif HWY_COMPILER_GCC_ACTUAL >= 1500 &&
HWY_IS_LITTLE_ENDIAN + // On little-endian PPC8/PPC9/PPC10 with GCC 15 or later, use the F32->U64 + // vec_unsignedo intrinsic as the __builtin_vsx_xvcvspuxds intrinsic has been + // removed from GCC in GCC 15 + return VU64{vec_unsignedo(vf32.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL >= 1500 && HWY_IS_BIG_ENDIAN + // On big-endian PPC8/PPC9/PPC10 with GCC 15 or later, use the F32->U64 + // vec_unsignede intrinsic as the __builtin_vsx_xvcvspuxds intrinsic has been + // removed from GCC in GCC 15 + return VU64{vec_unsignede(vf32.raw)}; +#else + // Inline assembly fallback for older versions of Clang that do not have the + // __builtin_vsx_xvcvspuxds intrinsic + __vector unsigned long long raw_result; + __asm__("xvcvspuxds %x0, %x1" : "=wa"(raw_result) : "wa"(vf32.raw) :); + return VU64{raw_result}; +#endif +} + +} // namespace detail +#endif // !HWY_S390X_HAVE_Z14 + +template +HWY_API VFromD PromoteTo(D di64, VFromD> v) { +#if !HWY_S390X_HAVE_Z14 + const Repartition dt_f32; + const auto vt_f32 = ResizeBitCast(dt_f32, v); + return detail::VsxXvcvspsxds( + detail::VsxF2INormalizeSrcVals(InterleaveLower(vt_f32, vt_f32))); +#else + const RebindToFloat df64; + return ConvertTo(di64, PromoteTo(df64, v)); +#endif +} + +template +HWY_API VFromD PromoteTo(D du64, VFromD> v) { +#if !HWY_S390X_HAVE_Z14 + const Repartition dt_f32; + const auto vt_f32 = ResizeBitCast(dt_f32, v); + return detail::VsxXvcvspuxds( + detail::VsxF2INormalizeSrcVals(InterleaveLower(vt_f32, vt_f32))); +#else + const RebindToFloat df64; + return ConvertTo(du64, PromoteTo(df64, v)); +#endif +} + +// ------------------------------ PromoteUpperTo + +#ifdef HWY_NATIVE_PROMOTE_UPPER_TO +#undef HWY_NATIVE_PROMOTE_UPPER_TO +#else +#define HWY_NATIVE_PROMOTE_UPPER_TO +#endif + +// Unsigned to signed/unsigned: zero-extend.
+template +HWY_API VFromD PromoteUpperTo(D d, Vec128 v) { + const RebindToUnsigned du; + const RepartitionToNarrow dn; + +#if HWY_IS_LITTLE_ENDIAN + return BitCast(d, ZipUpper(du, v, Zero(dn))); +#else + return BitCast(d, ZipUpper(du, Zero(dn), v)); +#endif +} + +// Signed: replicate sign bit. +template +HWY_API VFromD PromoteUpperTo(D /* d */, Vec128 v) { + using Raw = typename detail::Raw128>::type; + return VFromD{reinterpret_cast(vec_unpackl(v.raw))}; +} + +// F16 to F32 +template +HWY_API VFromD PromoteUpperTo(D df32, Vec128 v) { +#if HWY_PPC_HAVE_9 + (void)df32; + return VFromD{vec_extract_fp32_from_shortl(v.raw)}; +#else + const Rebind dh; + return PromoteTo(df32, UpperHalf(dh, v)); +#endif +} + +// BF16 to F32 +template +HWY_API VFromD PromoteUpperTo(D df32, Vec128 v) { + const Repartition du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteUpperTo(di32, BitCast(du16, v)))); +} + +template +HWY_API VFromD PromoteUpperTo(D /*tag*/, Vec128 v) { + const __vector float raw_v = InterleaveUpper(Full128(), v, v).raw; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_doubleo(raw_v)}; +#elif HWY_S390X_HAVE_Z14 && HWY_COMPILER_GCC_ACTUAL && \ + HWY_COMPILER_GCC_ACTUAL < 1000 + // Workaround for compiler error with GCC 9 or earlier on Z14 + return VFromD{__builtin_s390_vflls(raw_v)}; +#else + return VFromD{vec_doublee(raw_v)}; +#endif +} + +template +HWY_API VFromD PromoteUpperTo(D df64, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToSigned di64; + return ConvertTo(df64, PromoteUpperTo(di64, v)); +#else // VSX + (void)df64; + const __vector signed int raw_v = + InterleaveUpper(Full128(), v, v).raw; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_doubleo(raw_v)}; +#else + return VFromD{vec_doublee(raw_v)}; +#endif +#endif // HWY_S390X_HAVE_Z14 +} + +template +HWY_API VFromD PromoteUpperTo(D df64, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du64; + return ConvertTo(df64, PromoteUpperTo(du64, v)); +#else // VSX + (void)df64; + 
const __vector unsigned int raw_v = + InterleaveUpper(Full128(), v, v).raw; +#if HWY_IS_LITTLE_ENDIAN + return VFromD{vec_doubleo(raw_v)}; +#else + return VFromD{vec_doublee(raw_v)}; +#endif +#endif // HWY_S390X_HAVE_Z14 +} + +template +HWY_API VFromD PromoteUpperTo(D di64, Vec128 v) { +#if !HWY_S390X_HAVE_Z14 + (void)di64; + return detail::VsxXvcvspsxds( + detail::VsxF2INormalizeSrcVals(InterleaveUpper(Full128(), v, v))); +#else + const RebindToFloat df64; + return ConvertTo(di64, PromoteUpperTo(df64, v)); +#endif +} + +template +HWY_API VFromD PromoteUpperTo(D du64, Vec128 v) { +#if !HWY_S390X_HAVE_Z14 + (void)du64; + return detail::VsxXvcvspuxds( + detail::VsxF2INormalizeSrcVals(InterleaveUpper(Full128(), v, v))); +#else + const RebindToFloat df64; + return ConvertTo(du64, PromoteUpperTo(df64, v)); +#endif +} + +// Generic version for <=64 bit input/output +template +HWY_API VFromD PromoteUpperTo(D d, V v) { + const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +// ------------------------------ PromoteEvenTo/PromoteOddTo + +namespace detail { + +// Signed to Signed PromoteEvenTo/PromoteOddTo for PPC9/PPC10 +#if HWY_PPC_HAVE_9 && \ + (HWY_COMPILER_GCC_ACTUAL >= 1200 || HWY_COMPILER_CLANG >= 1200) + +#if HWY_IS_LITTLE_ENDIAN +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_signexti(v.raw)}; +} +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_signextll(v.raw)}; +} +#else +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<4> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_signexti(v.raw)}; +} +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + 
hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_signextll(v.raw)}; +} +#endif // HWY_IS_LITTLE_ENDIAN + +#endif // HWY_PPC_HAVE_9 + +// I32/U32/F32->F64 PromoteEvenTo +#if HWY_S390X_HAVE_Z14 +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_doublee(v.raw)}; +} +template )> +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + const Rebind>, decltype(d_to)> dw; + return ConvertTo(d_to, PromoteEvenTo(dw, v)); +} +#else // VSX +template +HWY_INLINE VFromD PromoteEvenTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_doublee(v.raw)}; +} +#endif // HWY_S390X_HAVE_Z14 + +// F32->I64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { +#if !HWY_S390X_HAVE_Z14 + (void)d_to; + const auto normalized_v = detail::VsxF2INormalizeSrcVals(v); +#if HWY_IS_LITTLE_ENDIAN + // VsxXvcvspsxds expects the source values to be in the odd lanes on + // little-endian PPC, and the Shuffle2103 operation below will shift the even + // lanes of normalized_v into the odd lanes. + return VsxXvcvspsxds(Shuffle2103(normalized_v)); +#else + // VsxXvcvspsxds expects the source values to be in the even lanes on + // big-endian PPC. 
+ return VsxXvcvspsxds(normalized_v); +#endif +#else + const RebindToFloat df64; + return ConvertTo(d_to, PromoteEvenTo(hwy::FloatTag(), hwy::SizeTag<8>(), + hwy::FloatTag(), df64, v)); +#endif +} + +// F32->U64 PromoteEvenTo +template +HWY_INLINE VFromD PromoteEvenTo(hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { +#if !HWY_S390X_HAVE_Z14 + (void)d_to; + const auto normalized_v = detail::VsxF2INormalizeSrcVals(v); +#if HWY_IS_LITTLE_ENDIAN + // VsxXvcvspuxds expects the source values to be in the odd lanes + // on little-endian PPC, and the Shuffle2103 operation below will shift the + // even lanes of normalized_v into the odd lanes. + return VsxXvcvspuxds(Shuffle2103(normalized_v)); +#else + // VsxXvcvspuxds expects the source values to be in the even lanes + // on big-endian PPC. + return VsxXvcvspuxds(normalized_v); +#endif +#else + const RebindToFloat df64; + return ConvertTo(d_to, PromoteEvenTo(hwy::FloatTag(), hwy::SizeTag<8>(), + hwy::FloatTag(), df64, v)); +#endif +} + +// I32/U32/F32->F64 PromoteOddTo +#if HWY_S390X_HAVE_Z14 +template +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { + return PromoteEvenTo(hwy::FloatTag(), hwy::SizeTag<8>(), hwy::FloatTag(), + d_to, V{vec_sld(v.raw, v.raw, 4)}); +} +template )> +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D d_to, V v) { + const Rebind>, decltype(d_to)> dw; + return ConvertTo(d_to, PromoteOddTo(dw, v)); +} +#else +template +HWY_INLINE VFromD PromoteOddTo(hwy::FloatTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + FromTypeTag /*from_type_tag*/, D /*d_to*/, + V v) { + return VFromD{vec_doubleo(v.raw)}; +} +#endif + +// F32->I64 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + 
hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { +#if !HWY_S390X_HAVE_Z14 + (void)d_to; + const auto normalized_v = detail::VsxF2INormalizeSrcVals(v); +#if HWY_IS_LITTLE_ENDIAN + // VsxXvcvspsxds expects the source values to be in the odd lanes + // on little-endian PPC + return VsxXvcvspsxds(normalized_v); +#else + // VsxXvcvspsxds expects the source values to be in the even lanes + // on big-endian PPC, and the Shuffle0321 operation below will shift the odd + // lanes of normalized_v into the even lanes. + return VsxXvcvspsxds(Shuffle0321(normalized_v)); +#endif +#else + const RebindToFloat df64; + return ConvertTo(d_to, PromoteOddTo(hwy::FloatTag(), hwy::SizeTag<8>(), + hwy::FloatTag(), df64, v)); +#endif +} + +// F32->U64 PromoteOddTo +template +HWY_INLINE VFromD PromoteOddTo(hwy::UnsignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::FloatTag /*from_type_tag*/, D d_to, + V v) { +#if !HWY_S390X_HAVE_Z14 + (void)d_to; + const auto normalized_v = detail::VsxF2INormalizeSrcVals(v); +#if HWY_IS_LITTLE_ENDIAN + // VsxXvcvspuxds expects the source values to be in the odd lanes + // on little-endian PPC + return VsxXvcvspuxds(normalized_v); +#else + // VsxXvcvspuxds expects the source values to be in the even lanes + // on big-endian PPC, and the Shuffle0321 operation below will shift the odd + // lanes of normalized_v into the even lanes. 
+ return VsxXvcvspuxds(Shuffle0321(normalized_v)); +#endif +#else + const RebindToFloat df64; + return ConvertTo(d_to, PromoteOddTo(hwy::FloatTag(), hwy::SizeTag<8>(), + hwy::FloatTag(), df64, v)); +#endif +} + +} // namespace detail + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template ) * 2)> +HWY_API VFromD DemoteTo(D /* tag */, + Vec128().MaxLanes()> v) { + return VFromD{vec_packsu(v.raw, v.raw)}; +} + +template ) * 2)> +HWY_API VFromD DemoteTo(D /* tag */, + Vec128().MaxLanes()> v) { + return VFromD{vec_packs(v.raw, v.raw)}; +} + +template ) * 2)> +HWY_API VFromD DemoteTo(D /* tag */, + Vec128().MaxLanes()> v) { + return VFromD{vec_packs(v.raw, v.raw)}; +} + +template = sizeof(TFromD) * 4)>* = nullptr> +HWY_API VFromD DemoteTo(D d, + Vec128().MaxLanes()> v) { + const Rebind, D> d2; + return DemoteTo(d, DemoteTo(d2, v)); +} + +template = sizeof(TFromD) * 4)>* = nullptr> +HWY_API VFromD DemoteTo(D d, + Vec128().MaxLanes()> v) { + const Rebind, D> d2; + return DemoteTo(d, DemoteTo(d2, v)); +} + +template = sizeof(TFromD) * 4)>* = nullptr> +HWY_API VFromD DemoteTo(D d, + Vec128().MaxLanes()> v) { + const Rebind>, D> d2; + return DemoteTo(d, DemoteTo(d2, v)); +} + +#if HWY_PPC_HAVE_9 && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_HAS_BUILTIN(__builtin_vsx_xvcvsphp)) + +// We already toggled HWY_NATIVE_F16C above. + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { +// Avoid vec_pack_to_short_fp32 on Clang because its implementation is buggy. 
+#if HWY_COMPILER_GCC_ACTUAL + (void)df16; + return VFromD{vec_pack_to_short_fp32(v.raw, v.raw)}; +#elif HWY_HAS_BUILTIN(__builtin_vsx_xvcvsphp) + // Work around bug in the clang implementation of vec_pack_to_short_fp32 + // by using the __builtin_vsx_xvcvsphp builtin on PPC9/PPC10 targets + // if the __builtin_vsx_xvcvsphp intrinsic is available + const RebindToUnsigned du16; + const Rebind du; + const VFromD bits16{ + reinterpret_cast<__vector unsigned int>(__builtin_vsx_xvcvsphp(v.raw))}; + return BitCast(df16, TruncateTo(du16, bits16)); +#else +#error "Only define the function if we have a native implementation" +#endif +} + +#endif // HWY_PPC_HAVE_9 + +#if HWY_PPC_HAVE_9 + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +namespace detail { + +// On big-endian PPC9, VsxXscvdphp converts vf64[0] to a F16, returned as an U64 +// vector with the resulting F16 bits in the lower 16 bits of U64 lane 0 + +// On little-endian PPC9, VsxXscvdphp converts vf64[1] to a F16, returned as +// an U64 vector with the resulting F16 bits in the lower 16 bits of U64 lane 1 +static HWY_INLINE Vec128 VsxXscvdphp(Vec128 vf64) { + // Inline assembly is needed for the PPC9 xscvdphp instruction as there is + // currently no intrinsic available for the PPC9 xscvdphp instruction + __vector unsigned long long raw_result; + __asm__("xscvdphp %x0, %x1" : "=wa"(raw_result) : "wa"(vf64.raw)); + return Vec128{raw_result}; +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const RebindToUnsigned du16; + const Rebind du64; + + const Full128 df64_full; +#if HWY_IS_LITTLE_ENDIAN + const auto bits16_as_u64 = + UpperHalf(du64, detail::VsxXscvdphp(Combine(df64_full, v, v))); +#else + const auto bits16_as_u64 = + LowerHalf(du64, detail::VsxXscvdphp(ResizeBitCast(df64_full, v))); +#endif + + return BitCast(df16, TruncateTo(du16, bits16_as_u64)); +} + +template +HWY_API VFromD DemoteTo(D 
df16, VFromD> v) { + const RebindToUnsigned du16; + const Rebind du64; + const Rebind df64; + +#if HWY_IS_LITTLE_ENDIAN + const auto bits64_as_u64_0 = detail::VsxXscvdphp(InterleaveLower(df64, v, v)); + const auto bits64_as_u64_1 = detail::VsxXscvdphp(v); + const auto bits64_as_u64 = + InterleaveUpper(du64, bits64_as_u64_0, bits64_as_u64_1); +#else + const auto bits64_as_u64_0 = detail::VsxXscvdphp(v); + const auto bits64_as_u64_1 = detail::VsxXscvdphp(InterleaveUpper(df64, v, v)); + const auto bits64_as_u64 = + InterleaveLower(du64, bits64_as_u64_0, bits64_as_u64_1); +#endif + + return BitCast(df16, TruncateTo(du16, bits64_as_u64)); +} + +#elif HWY_S390X_HAVE_Z14 + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +namespace detail { + +template +static HWY_INLINE VFromD DemoteToF32WithRoundToOdd( + DF32 df32, VFromD> v) { + const Twice dt_f32; + + __vector float raw_f32_in_even; + __asm__("vledb %0,%1,0,3" : "=v"(raw_f32_in_even) : "v"(v.raw)); + + const VFromD f32_in_even{raw_f32_in_even}; + return LowerHalf(df32, ConcatEven(dt_f32, f32_in_even, f32_in_even)); +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const Rebind df32; + return DemoteTo(df16, detail::DemoteToF32WithRoundToOdd(df32, v)); +} + +#endif // HWY_PPC_HAVE_9 + +#if HWY_PPC_HAVE_10 && HWY_HAS_BUILTIN(__builtin_vsx_xvcvspbf16) + +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +namespace detail { + +// VsxXvcvspbf16 converts a F32 vector to a BF16 vector, bitcasted to an U32 +// vector with the resulting BF16 bits in the lower 16 bits of each U32 lane +template +static HWY_INLINE VFromD> VsxXvcvspbf16( + D dbf16, VFromD> v) { + const Rebind du32; + const Repartition du32_as_du8; + + using VU32 = __vector unsigned int; + + // Even though the __builtin_vsx_xvcvspbf16 builtin performs a F32 to BF16 + // 
conversion, the __builtin_vsx_xvcvspbf16 intrinsic expects a + // __vector unsigned char argument (at least as of GCC 13 and Clang 17) + return VFromD>{reinterpret_cast( + __builtin_vsx_xvcvspbf16(BitCast(du32_as_du8, v).raw))}; +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D dbf16, VFromD> v) { + const RebindToUnsigned du16; + return BitCast(dbf16, TruncateTo(du16, detail::VsxXvcvspbf16(dbf16, v))); +} + +#endif // HWY_PPC_HAVE_10 && HWY_HAS_BUILTIN(__builtin_vsx_xvcvspbf16) + +// Specializations for partial vectors because vec_packs sets lanes above 2*N. +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Twice dn_full; + const Repartition du32_full; + + const VFromD v_full{vec_packs(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN /*dn*/, V a, V b) { + return VFromD{vec_packs(a.raw, b.raw)}; +} + +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Twice dn_full; + const Repartition du32_full; + + const VFromD v_full{vec_packsu(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN /*dn*/, V a, V b) { + return VFromD{vec_packsu(a.raw, b.raw)}; +} + +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Twice 
dn_full; + const Repartition du32_full; + + const VFromD v_full{vec_packs(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN /*dn*/, V a, V b) { + return VFromD{vec_packs(a.raw, b.raw)}; +} + +#if HWY_PPC_HAVE_10 && HWY_HAS_BUILTIN(__builtin_vsx_xvcvspbf16) +template ), + HWY_IF_LANES_D(D, HWY_MAX_LANES_V(V) * 2)> +HWY_API VFromD ReorderDemote2To(D dbf16, V a, V b) { + const RebindToUnsigned du16; + const Half dh_bf16; + return BitCast(dbf16, + OrderedTruncate2To(du16, detail::VsxXvcvspbf16(dh_bf16, a), + detail::VsxXvcvspbf16(dh_bf16, b))); +} +#endif + +template ), class V, + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); +} + +#if HWY_PPC_HAVE_10 && HWY_HAS_BUILTIN(__builtin_vsx_xvcvspbf16) +template ), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); +} +#endif + +template +HWY_API Vec32 DemoteTo(D /* tag */, Vec64 v) { +#if HWY_S390X_HAVE_Z14 && HWY_COMPILER_GCC_ACTUAL && \ + HWY_COMPILER_GCC_ACTUAL < 1000 + // Workaround for compiler error with GCC 9 or earlier on Z14 + return Vec32{__builtin_s390_vflrd(v.raw, 0, 0)}; +#else + return Vec32{vec_floate(v.raw)}; +#endif +} + +template +HWY_API Vec64 DemoteTo(D d, Vec128 v) { +#if HWY_S390X_HAVE_Z14 && HWY_COMPILER_GCC_ACTUAL && \ + HWY_COMPILER_GCC_ACTUAL < 1000 + // Workaround for compiler error with GCC 9 or earlier on Z14 + const Vec128 f64_to_f32{__builtin_s390_vflrd(v.raw, 0, 0)}; +#elif HWY_S390X_HAVE_Z14 || HWY_IS_LITTLE_ENDIAN + const Vec128 f64_to_f32{vec_floate(v.raw)}; +#else + const Vec128 f64_to_f32{vec_floato(v.raw)}; +#endif + +#if HWY_S390X_HAVE_Z14 + const Twice dt; + return LowerHalf(d, 
ConcatEven(dt, f64_to_f32, f64_to_f32)); +#else + const RebindToUnsigned du; + const Rebind du64; + return Vec64{ + BitCast(d, TruncateTo(du, BitCast(du64, f64_to_f32))).raw}; +#endif +} + +template +HWY_API Vec32 DemoteTo(D di32, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + const Rebind di64; + return DemoteTo(di32, ConvertTo(di64, v)); +#else + (void)di32; + return Vec32{vec_signede(detail::VsxF2INormalizeSrcVals(v).raw)}; +#endif +} + +template +HWY_API Vec64 DemoteTo(D di32, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const Rebind di64; + return DemoteTo(di32, ConvertTo(di64, v)); +#else + (void)di32; + +#if HWY_IS_LITTLE_ENDIAN + const Vec128 f64_to_i32{ + vec_signede(detail::VsxF2INormalizeSrcVals(v).raw)}; +#else + const Vec128 f64_to_i32{ + vec_signedo(detail::VsxF2INormalizeSrcVals(v).raw)}; +#endif + + const Rebind di64; + const Vec128 vi64 = BitCast(di64, f64_to_i32); + return Vec64{vec_pack(vi64.raw, vi64.raw)}; +#endif +} + +template +HWY_API Vec32 DemoteTo(D du32, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + const Rebind du64; + return DemoteTo(du32, ConvertTo(du64, v)); +#else + (void)du32; + return Vec32{vec_unsignede(detail::VsxF2INormalizeSrcVals(v).raw)}; +#endif +} + +template +HWY_API Vec64 DemoteTo(D du32, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const Rebind du64; + return DemoteTo(du32, ConvertTo(du64, v)); +#else + (void)du32; +#if HWY_IS_LITTLE_ENDIAN + const Vec128 f64_to_u32{ + vec_unsignede(detail::VsxF2INormalizeSrcVals(v).raw)}; +#else + const Vec128 f64_to_u32{ + vec_unsignedo(detail::VsxF2INormalizeSrcVals(v).raw)}; +#endif + + const Rebind du64; + const Vec128 vu64 = BitCast(du64, f64_to_u32); + return Vec64{vec_pack(vu64.raw, vu64.raw)}; +#endif +} + +#if HWY_S390X_HAVE_Z14 +namespace detail { + +template )> +HWY_INLINE VFromD>> ConvToF64WithRoundToOdd(V v) { + __vector double raw_result; + // Use inline assembly to do a round-to-odd I64->F64 conversion on Z14 + __asm__("vcdgb %0,%1,0,3" : "=v"(raw_result) : "v"(v.raw)); + return VFromD>>{raw_result}; +} 
+ +template )> +HWY_INLINE VFromD>> ConvToF64WithRoundToOdd(V v) { + __vector double raw_result; + // Use inline assembly to do a round-to-odd U64->F64 conversion on Z14 + __asm__("vcdlgb %0,%1,0,3" : "=v"(raw_result) : "v"(v.raw)); + return VFromD>>{raw_result}; +} + +} // namespace detail +#endif // HWY_S390X_HAVE_Z14 + +template +HWY_API Vec32 DemoteTo(D df32, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + return DemoteTo(df32, detail::ConvToF64WithRoundToOdd(v)); +#else // VSX + (void)df32; + return Vec32{vec_floate(v.raw)}; +#endif +} + +template +HWY_API Vec64 DemoteTo(D df32, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + return DemoteTo(df32, detail::ConvToF64WithRoundToOdd(v)); +#else // VSX +#if HWY_IS_LITTLE_ENDIAN + const Vec128 i64_to_f32{vec_floate(v.raw)}; +#else + const Vec128 i64_to_f32{vec_floato(v.raw)}; +#endif + + const RebindToUnsigned du32; + const Rebind du64; + return Vec64{ + BitCast(df32, TruncateTo(du32, BitCast(du64, i64_to_f32))).raw}; +#endif +} + +template +HWY_API Vec32 DemoteTo(D df32, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + return DemoteTo(df32, detail::ConvToF64WithRoundToOdd(v)); +#else // VSX + (void)df32; + return Vec32{vec_floate(v.raw)}; +#endif +} + +template +HWY_API Vec64 DemoteTo(D df32, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + return DemoteTo(df32, detail::ConvToF64WithRoundToOdd(v)); +#else // VSX +#if HWY_IS_LITTLE_ENDIAN + const Vec128 u64_to_f32{vec_floate(v.raw)}; +#else + const Vec128 u64_to_f32{vec_floato(v.raw)}; +#endif + + const RebindToUnsigned du; + const Rebind du64; + return Vec64{ + BitCast(df32, TruncateTo(du, BitCast(du64, u64_to_f32))).raw}; +#endif +} + +// For already range-limited input [0, 255]. 
+template +HWY_API Vec128 U8FromU32(Vec128 v) { + const Rebind> du16; + const Rebind du8; + return TruncateTo(du8, TruncateTo(du16, v)); +} +// ------------------------------ Integer <=> fp (ShiftRight, OddEven) + +// Note: altivec.h vec_ct* currently contain C casts which triggers +// -Wdeprecate-lax-vec-conv-all warnings, so disable them. + +#if HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 +template +HWY_API VFromD ConvertTo(D df32, + Vec128().MaxLanes()> v) { + const Rebind df64; + return DemoteTo(df32, PromoteTo(df64, v)); +} +template +HWY_API VFromD ConvertTo(D df32, Vec128 v) { + const RepartitionToWide df64; + +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000 + // Workaround for compiler error with GCC 9 or earlier on Z14 + const VFromD vf32_lo{ + __builtin_s390_vflrd(PromoteLowerTo(df64, v).raw, 0, 0)}; + const VFromD vf32_hi{ + __builtin_s390_vflrd(PromoteUpperTo(df64, v).raw, 0, 0)}; +#else + const VFromD vf32_lo{vec_floate(PromoteLowerTo(df64, v).raw)}; + const VFromD vf32_hi{vec_floate(PromoteUpperTo(df64, v).raw)}; +#endif + return ConcatEven(df32, vf32_hi, vf32_lo); +} +#else // Z15 or PPC +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { + HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_CLANG + HWY_DIAGNOSTICS_OFF(disable : 5219, ignored "-Wdeprecate-lax-vec-conv-all") +#endif +#if HWY_S390X_HAVE_Z15 + return VFromD{vec_float(v.raw)}; +#else + return VFromD{vec_ctf(v.raw, 0)}; +#endif + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_TARGET == HWY_Z14 + +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { + return VFromD{vec_double(v.raw)}; +} + +// Truncates (rounds toward zero). 
+#if HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 +template +HWY_API VFromD ConvertTo(D di32, + Vec128().MaxLanes()> v) { + const Rebind di64; + return DemoteTo(di32, PromoteTo(di64, v)); +} +template +HWY_API VFromD ConvertTo(D di32, + Vec128().MaxLanes()> v) { + const RepartitionToWide di64; + return OrderedDemote2To(di32, PromoteLowerTo(di64, v), + PromoteUpperTo(di64, v)); +} +#else // Z15 or PPC +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { +#if defined(__OPTIMIZE__) + if (detail::IsConstantRawAltivecVect(v.raw)) { + constexpr int32_t kMinI32 = LimitsMin(); + constexpr int32_t kMaxI32 = LimitsMax(); + return Dup128VecFromValues( + D(), + (v.raw[0] >= -2147483648.0f) + ? ((v.raw[0] < 2147483648.0f) ? static_cast(v.raw[0]) + : kMaxI32) + : ((v.raw[0] < 0) ? kMinI32 : 0), + (v.raw[1] >= -2147483648.0f) + ? ((v.raw[1] < 2147483648.0f) ? static_cast(v.raw[1]) + : kMaxI32) + : ((v.raw[1] < 0) ? kMinI32 : 0), + (v.raw[2] >= -2147483648.0f) + ? ((v.raw[2] < 2147483648.0f) ? static_cast(v.raw[2]) + : kMaxI32) + : ((v.raw[2] < 0) ? kMinI32 : 0), + (v.raw[3] >= -2147483648.0f) + ? ((v.raw[3] < 2147483648.0f) ? static_cast(v.raw[3]) + : kMaxI32) + : ((v.raw[3] < 0) ? 
kMinI32 : 0)); + } +#endif + +#if HWY_S390X_HAVE_Z15 + // Use inline assembly on Z15 to avoid undefined behavior if v[i] is not in + // the range of an int32_t + __vector signed int raw_result; + __asm__("vcfeb %0,%1,0,5" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else + HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_CLANG + HWY_DIAGNOSTICS_OFF(disable : 5219, ignored "-Wdeprecate-lax-vec-conv-all") +#endif + return VFromD{vec_cts(v.raw, 0)}; + HWY_DIAGNOSTICS(pop) +#endif // HWY_S390X_HAVE_Z15 +} +#endif // HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 + +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { +#if defined(__OPTIMIZE__) && (!HWY_COMPILER_CLANG || !HWY_S390X_HAVE_Z14) + if (detail::IsConstantRawAltivecVect(v.raw)) { + constexpr int64_t kMinI64 = LimitsMin(); + constexpr int64_t kMaxI64 = LimitsMax(); + return Dup128VecFromValues(D(), + (v.raw[0] >= -9223372036854775808.0) + ? ((v.raw[0] < 9223372036854775808.0) + ? static_cast(v.raw[0]) + : kMaxI64) + : ((v.raw[0] < 0) ? kMinI64 : 0LL), + (v.raw[1] >= -9223372036854775808.0) + ? ((v.raw[1] < 9223372036854775808.0) + ? static_cast(v.raw[1]) + : kMaxI64) + : ((v.raw[1] < 0) ? 
kMinI64 : 0LL)); + } +#endif + + // Use inline assembly to avoid undefined behavior if v[i] is not within the + // range of an int64_t + __vector signed long long raw_result; +#if HWY_S390X_HAVE_Z14 + __asm__("vcgdb %0,%1,0,5" : "=v"(raw_result) : "v"(v.raw)); +#else + __asm__("xvcvdpsxds %x0,%x1" + : "=wa"(raw_result) + : "wa"(detail::VsxF2INormalizeSrcVals(v).raw)); +#endif + return VFromD{raw_result}; +} + +#if HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 +template +HWY_API VFromD ConvertTo(D du32, + Vec128().MaxLanes()> v) { + const Rebind du64; + return DemoteTo(du32, PromoteTo(du64, v)); +} +template +HWY_API VFromD ConvertTo(D du32, + Vec128().MaxLanes()> v) { + const RepartitionToWide du64; + return OrderedDemote2To(du32, PromoteLowerTo(du64, v), + PromoteUpperTo(du64, v)); +} +#else // Z15 or VSX +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { +#if defined(__OPTIMIZE__) + if (detail::IsConstantRawAltivecVect(v.raw)) { + constexpr uint32_t kMaxU32 = LimitsMax(); + return Dup128VecFromValues( + D(), + (v.raw[0] >= 0.0f) + ? ((v.raw[0] < 4294967296.0f) ? static_cast(v.raw[0]) + : kMaxU32) + : 0, + (v.raw[1] >= 0.0f) + ? ((v.raw[1] < 4294967296.0f) ? static_cast(v.raw[1]) + : kMaxU32) + : 0, + (v.raw[2] >= 0.0f) + ? ((v.raw[2] < 4294967296.0f) ? static_cast(v.raw[2]) + : kMaxU32) + : 0, + (v.raw[3] >= 0.0f) + ? ((v.raw[3] < 4294967296.0f) ? 
static_cast(v.raw[3]) + : kMaxU32) + : 0); + } +#endif + +#if HWY_S390X_HAVE_Z15 + // Use inline assembly on Z15 to avoid undefined behavior if v[i] is not in + // the range of an uint32_t + __vector unsigned int raw_result; + __asm__("vclfeb %0,%1,0,5" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else // VSX + HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_CLANG + HWY_DIAGNOSTICS_OFF(disable : 5219, ignored "-Wdeprecate-lax-vec-conv-all") +#endif + VFromD result{vec_ctu(v.raw, 0)}; + HWY_DIAGNOSTICS(pop) + return result; +#endif // HWY_S390X_HAVE_Z15 +} +#endif // HWY_S390X_HAVE_Z14 && !HWY_S390X_HAVE_Z15 + +template +HWY_API VFromD ConvertTo(D /* tag */, + Vec128().MaxLanes()> v) { + HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_CLANG + HWY_DIAGNOSTICS_OFF(disable : 5219, ignored "-Wdeprecate-lax-vec-conv-all") +#endif + +#if defined(__OPTIMIZE__) && (!HWY_COMPILER_CLANG || !HWY_S390X_HAVE_Z14) + if (detail::IsConstantRawAltivecVect(v.raw)) { + constexpr uint64_t kMaxU64 = LimitsMax(); + return Dup128VecFromValues( + D(), + (v.raw[0] >= 0.0) ? ((v.raw[0] < 18446744073709551616.0) + ? static_cast(v.raw[0]) + : kMaxU64) + : 0, + (v.raw[1] >= 0.0) ? ((v.raw[1] < 18446744073709551616.0) + ? 
static_cast(v.raw[1]) + : kMaxU64) + : 0); + } +#endif + + // Use inline assembly to avoid undefined behavior if v[i] is not within the + // range of an uint64_t + __vector unsigned long long raw_result; +#if HWY_S390X_HAVE_Z14 + __asm__("vclgdb %0,%1,0,5" : "=v"(raw_result) : "v"(v.raw)); +#else // VSX + __asm__("xvcvdpuxds %x0,%x1" + : "=wa"(raw_result) + : "wa"(detail::VsxF2INormalizeSrcVals(v).raw)); +#endif + return VFromD{raw_result}; +} + +// ------------------------------ Floating-point rounding (ConvertTo) + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(Vec128 v) { + return Vec128{vec_round(v.raw)}; +} + +template +HWY_API Vec128 Round(Vec128 v) { +#if HWY_S390X_HAVE_Z14 + return Vec128{vec_round(v.raw)}; +#else + return Vec128{vec_rint(v.raw)}; +#endif +} + +template +HWY_API Vec128, N> NearestInt(Vec128 v) { + const DFromV d; + const RebindToSigned di; + return ConvertTo(di, Round(v)); +} + +template +HWY_API VFromD DemoteToNearestInt(DI32 di32, + VFromD> v) { + return DemoteTo(di32, Round(v)); +} + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(Vec128 v) { + return Vec128{vec_trunc(v.raw)}; +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(Vec128 v) { + return Vec128{vec_ceil(v.raw)}; +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(Vec128 v) { + return Vec128{vec_floor(v.raw)}; +} + +// ------------------------------ Floating-point classification + +template +HWY_API Mask128 IsNaN(Vec128 v) { + static_assert(IsFloat(), "Only for float"); + return v != v; +} + +template +HWY_API Mask128 IsInf(Vec128 v) { + static_assert(IsFloat(), "Only for float"); + using TU = MakeUnsigned; + const DFromV d; + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. 
+ return RebindMask( + d, + Eq(Add(vu, vu), Set(du, static_cast(hwy::MaxExponentTimes2())))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(Vec128 v) { + static_assert(IsFloat(), "Only for float"); + using TU = MakeUnsigned; + const DFromV d; + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent(hwy::MaxExponentTimes2())))); +} + +// ================================================== CRYPTO + +#if !HWY_S390X_HAVE_Z14 && !defined(HWY_DISABLE_PPC8_CRYPTO) + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. +#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +namespace detail { +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1600 +using CipherTag = Full128; +#else +using CipherTag = Full128; +#endif // !HWY_COMPILER_CLANG +using CipherVec = VFromD; +} // namespace detail + +HWY_API Vec128 AESRound(Vec128 state, + Vec128 round_key) { + const detail::CipherTag dc; + const Full128 du8; +#if HWY_IS_LITTLE_ENDIAN + return Reverse(du8, + BitCast(du8, detail::CipherVec{vec_cipher_be( + BitCast(dc, Reverse(du8, state)).raw, + BitCast(dc, Reverse(du8, round_key)).raw)})); +#else + return BitCast(du8, detail::CipherVec{vec_cipher_be( + BitCast(dc, state).raw, BitCast(dc, round_key).raw)}); +#endif +} + +HWY_API Vec128 AESLastRound(Vec128 state, + Vec128 round_key) { + const detail::CipherTag dc; + const Full128 du8; +#if HWY_IS_LITTLE_ENDIAN + return Reverse(du8, + BitCast(du8, detail::CipherVec{vec_cipherlast_be( + BitCast(dc, Reverse(du8, state)).raw, + BitCast(dc, Reverse(du8, round_key)).raw)})); +#else + return BitCast(du8, detail::CipherVec{vec_cipherlast_be( + BitCast(dc, state).raw, BitCast(dc, round_key).raw)}); +#endif +} + +HWY_API Vec128 AESRoundInv(Vec128 state, + Vec128 round_key) { + const detail::CipherTag dc; + const Full128 du8; +#if HWY_IS_LITTLE_ENDIAN + return Xor(Reverse(du8, BitCast(du8, 
detail::CipherVec{vec_ncipher_be( + BitCast(dc, Reverse(du8, state)).raw, + Zero(dc).raw)})), + round_key); +#else + return Xor(BitCast(du8, detail::CipherVec{vec_ncipher_be( + BitCast(dc, state).raw, Zero(dc).raw)}), + round_key); +#endif +} + +HWY_API Vec128 AESLastRoundInv(Vec128 state, + Vec128 round_key) { + const detail::CipherTag dc; + const Full128 du8; +#if HWY_IS_LITTLE_ENDIAN + return Reverse(du8, + BitCast(du8, detail::CipherVec{vec_ncipherlast_be( + BitCast(dc, Reverse(du8, state)).raw, + BitCast(dc, Reverse(du8, round_key)).raw)})); +#else + return BitCast(du8, detail::CipherVec{vec_ncipherlast_be( + BitCast(dc, state).raw, BitCast(dc, round_key).raw)}); +#endif +} + +HWY_API Vec128 AESInvMixColumns(Vec128 state) { + const Full128 du8; + const auto zero = Zero(du8); + + // PPC8/PPC9/PPC10 does not have a single instruction for the AES + // InvMixColumns operation like ARM Crypto, SVE2 Crypto, or AES-NI do. + + // The AESInvMixColumns operation can be carried out on PPC8/PPC9/PPC10 + // by doing an AESLastRound operation with a zero round_key followed by an + // AESRoundInv operation with a zero round_key. 
+ return AESRoundInv(AESLastRound(state, zero), zero); +} + +template +HWY_API Vec128 AESKeyGenAssist(Vec128 v) { + constexpr __vector unsigned char kRconXorMask = {0, 0, 0, 0, kRcon, 0, 0, 0, + 0, 0, 0, 0, kRcon, 0, 0, 0}; + constexpr __vector unsigned char kRotWordShuffle = { + 4, 5, 6, 7, 5, 6, 7, 4, 12, 13, 14, 15, 13, 14, 15, 12}; + const detail::CipherTag dc; + const Full128 du8; + const auto sub_word_result = + BitCast(du8, detail::CipherVec{vec_sbox_be(BitCast(dc, v).raw)}); + const auto rot_word_result = + TableLookupBytes(sub_word_result, Vec128{kRotWordShuffle}); + return Xor(rot_word_result, Vec128{kRconXorMask}); +} + +template +HWY_API Vec128 CLMulLower(Vec128 a, + Vec128 b) { + // NOTE: Lane 1 of both a and b need to be zeroed out for the + // vec_pmsum_be operation below as the vec_pmsum_be operation + // does a carryless multiplication of each 64-bit half and then + // adds the two halves using an bitwise XOR operation. + + const DFromV d; + const auto zero = Zero(d); + + using VU64 = __vector unsigned long long; + const VU64 pmsum_result = reinterpret_cast( + vec_pmsum_be(InterleaveLower(a, zero).raw, InterleaveLower(b, zero).raw)); + +#if HWY_IS_LITTLE_ENDIAN + return Vec128{pmsum_result}; +#else + // Need to swap the two halves of pmsum_result on big-endian targets as + // the upper 64 bits of the carryless multiplication result are in lane 0 of + // pmsum_result and the lower 64 bits of the carryless multiplication result + // are in lane 1 of mul128_result + return Vec128{vec_sld(pmsum_result, pmsum_result, 8)}; +#endif +} + +template +HWY_API Vec128 CLMulUpper(Vec128 a, + Vec128 b) { + // NOTE: Lane 0 of both a and b need to be zeroed out for the + // vec_pmsum_be operation below as the vec_pmsum_be operation + // does a carryless multiplication of each 64-bit half and then + // adds the two halves using an bitwise XOR operation. 
+ + const DFromV d; + const auto zero = Zero(d); + + using VU64 = __vector unsigned long long; + const VU64 pmsum_result = reinterpret_cast( + vec_pmsum_be(vec_mergel(zero.raw, a.raw), vec_mergel(zero.raw, b.raw))); + +#if HWY_IS_LITTLE_ENDIAN + return Vec128{pmsum_result}; +#else + // Need to swap the two halves of pmsum_result on big-endian targets as + // the upper 64 bits of the carryless multiplication result are in lane 0 of + // pmsum_result and the lower 64 bits of the carryless multiplication result + // are in lane 1 of mul128_result + return Vec128{vec_sld(pmsum_result, pmsum_result, 8)}; +#endif +} + +#endif // !defined(HWY_DISABLE_PPC8_CRYPTO) + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +template +HWY_INLINE MFromD LoadMaskBits128(D /*d*/, uint64_t mask_bits) { +#if HWY_PPC_HAVE_10 + const Vec128 mask_vec{vec_genbm(mask_bits)}; + +#if HWY_IS_LITTLE_ENDIAN + return MFromD{MaskFromVec(mask_vec).raw}; +#else + return MFromD{MaskFromVec(Reverse(Full128(), mask_vec)).raw}; +#endif // HWY_IS_LITTLE_ENDIAN + +#else // PPC9 or earlier + const Full128 du8; + const Full128 du16; + const Vec128 vbits = + BitCast(du8, Set(du16, static_cast(mask_bits))); + + // Replicate bytes 8x such that each byte contains the bit that governs it. 
+#if HWY_IS_LITTLE_ENDIAN + const __vector unsigned char kRep8 = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; +#else + const __vector unsigned char kRep8 = {1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 0, 0}; +#endif // HWY_IS_LITTLE_ENDIAN + + const Vec128 rep8{vec_perm(vbits.raw, vbits.raw, kRep8)}; + const __vector unsigned char kBit = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return MFromD{TestBit(rep8, Vec128{kBit}).raw}; +#endif // HWY_PPC_HAVE_10 +} + +template +HWY_INLINE MFromD LoadMaskBits128(D /*d*/, uint64_t mask_bits) { +#if HWY_PPC_HAVE_10 + const Vec128 mask_vec{vec_genhm(mask_bits)}; + +#if HWY_IS_LITTLE_ENDIAN + return MFromD{MaskFromVec(mask_vec).raw}; +#else + return MFromD{MaskFromVec(Reverse(Full128(), mask_vec)).raw}; +#endif // HWY_IS_LITTLE_ENDIAN + +#else // PPC9 or earlier + const __vector unsigned short kBit = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = + Set(Full128(), static_cast(mask_bits)); + return MFromD{TestBit(vmask_bits, Vec128{kBit}).raw}; +#endif // HWY_PPC_HAVE_10 +} + +template +HWY_INLINE MFromD LoadMaskBits128(D /*d*/, uint64_t mask_bits) { +#if HWY_PPC_HAVE_10 + const Vec128 mask_vec{vec_genwm(mask_bits)}; + +#if HWY_IS_LITTLE_ENDIAN + return MFromD{MaskFromVec(mask_vec).raw}; +#else + return MFromD{MaskFromVec(Reverse(Full128(), mask_vec)).raw}; +#endif // HWY_IS_LITTLE_ENDIAN + +#else // PPC9 or earlier + const __vector unsigned int kBit = {1, 2, 4, 8}; + const auto vmask_bits = + Set(Full128(), static_cast(mask_bits)); + return MFromD{TestBit(vmask_bits, Vec128{kBit}).raw}; +#endif // HWY_PPC_HAVE_10 +} + +template +HWY_INLINE MFromD LoadMaskBits128(D /*d*/, uint64_t mask_bits) { +#if HWY_PPC_HAVE_10 + const Vec128 mask_vec{vec_gendm(mask_bits)}; + +#if HWY_IS_LITTLE_ENDIAN + return MFromD{MaskFromVec(mask_vec).raw}; +#else + return MFromD{MaskFromVec(Reverse(Full128(), mask_vec)).raw}; +#endif // HWY_IS_LITTLE_ENDIAN + +#else // PPC9 or earlier + const __vector unsigned long 
long kBit = {1, 2}; + const auto vmask_bits = + Set(Full128(), static_cast(mask_bits)); + return MFromD{TestBit(vmask_bits, Vec128{kBit}).raw}; +#endif // HWY_PPC_HAVE_10 +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + // If there are 8 or fewer lanes, simply convert bits[0] to a uint64_t + uint64_t mask_bits = bits[0]; + + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + + return detail::LoadMaskBits128(d, mask_bits); +} + +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + // First, copy the mask bits to a uint16_t as there as there are at most + // 16 lanes in a vector. + + // Copying the mask bits to a uint16_t first will also ensure that the + // mask bits are loaded into the lower 16 bits on big-endian PPC targets. + uint16_t u16_mask_bits; + CopyBytes(bits, &u16_mask_bits); + +#if HWY_IS_LITTLE_ENDIAN + return detail::LoadMaskBits128(d, u16_mask_bits); +#else + // On big-endian targets, u16_mask_bits need to be byte swapped as bits + // contains the mask bits in little-endian byte order + + // GCC/Clang will optimize the load of u16_mask_bits and byte swap to a + // single lhbrx instruction on big-endian PPC targets when optimizations + // are enabled. +#if HWY_HAS_BUILTIN(__builtin_bswap16) + return detail::LoadMaskBits128(d, __builtin_bswap16(u16_mask_bits)); +#else + return detail::LoadMaskBits128( + d, static_cast((u16_mask_bits << 8) | (u16_mask_bits >> 8))); +#endif +#endif +} + +template +struct CompressIsPartition { + // generic_ops-inl does not guarantee IsPartition for 8-bit. 
+ enum { value = (sizeof(T) != 1) }; +}; + +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + return detail::LoadMaskBits128(d, mask_bits); +} + +// ------------------------------ StoreMaskBits + +namespace detail { + +// Returns the lowest N of the mask bits. +template +constexpr uint64_t OnlyActive(D d, uint64_t mask_bits) { + return (d.MaxBytes() == 16) ? mask_bits + : mask_bits & ((1ull << d.MaxLanes()) - 1); +} + +#if !HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN +// fallback for missing vec_extractm +template +HWY_INLINE uint64_t ExtractSignBits(Vec128 sign_bits, + __vector unsigned char bit_shuffle) { + // clang POWER8 and 9 targets appear to differ in their return type of + // vec_vbpermq: unsigned or signed, so cast to avoid a warning. + using VU64 = detail::Raw128::type; +#if HWY_S390X_HAVE_Z14 + +#if HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2100 + // GCC 15 and Clang 20 have added the vec_bperm intrinsic + + // Need to use vec_bperm instead of vec_bperm_u128 with GCC 15 and later to + // avoid compiler warning + using VU128 = __vector unsigned __int128; + const Vec128 extracted{reinterpret_cast( + vec_bperm(reinterpret_cast(sign_bits.raw), bit_shuffle))}; +#else // !(HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2100) + const Vec128 extracted{ + reinterpret_cast(vec_bperm_u128(sign_bits.raw, bit_shuffle))}; +#endif // HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2100 + +#else // !HWY_S390X_HAVE_Z14 + const Vec128 extracted{ + reinterpret_cast(vec_vbpermq(sign_bits.raw, bit_shuffle))}; +#endif + return extracted.raw[HWY_IS_LITTLE_ENDIAN]; +} + +#endif // !HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN + +} // namespace detail + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const Repartition du8; + const VFromD sign_bits = BitCast(du8, VecFromMask(d, 
mask)); + +#if HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN + return detail::OnlyActive(d, + static_cast(vec_extractm(sign_bits.raw))); +#else // Z14, Z15, PPC8, PPC9, or big-endian PPC10 + const __vector unsigned char kBitShuffle = {120, 112, 104, 96, 88, 80, 72, 64, + 56, 48, 40, 32, 24, 16, 8, 0}; + return detail::OnlyActive(d, detail::ExtractSignBits(sign_bits, kBitShuffle)); +#endif // HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToUnsigned du; + + const Repartition du8; + const VFromD sign_bits = BitCast(du8, VecFromMask(d, mask)); + +#if HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN + return detail::OnlyActive( + d, static_cast(vec_extractm(BitCast(du, sign_bits).raw))); +#else // Z14, Z15, PPC8, PPC9, or big-endian PPC10 + (void)du; +#if HWY_IS_LITTLE_ENDIAN + const __vector unsigned char kBitShuffle = { + 112, 96, 80, 64, 48, 32, 16, 0, 128, 128, 128, 128, 128, 128, 128, 128}; +#else + const __vector unsigned char kBitShuffle = { + 128, 128, 128, 128, 128, 128, 128, 128, 112, 96, 80, 64, 48, 32, 16, 0}; +#endif + return detail::OnlyActive(d, detail::ExtractSignBits(sign_bits, kBitShuffle)); +#endif // HWY_PPC_HAVE_10 +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToUnsigned du; + + const Repartition du8; + const VFromD sign_bits = BitCast(du8, VecFromMask(d, mask)); + +#if HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN + return detail::OnlyActive( + d, static_cast(vec_extractm(BitCast(du, sign_bits).raw))); +#else // Z14, Z15, PPC8, PPC9, or big-endian PPC10 + (void)du; +#if HWY_IS_LITTLE_ENDIAN + const __vector unsigned char kBitShuffle = {96, 64, 32, 0, 128, 128, + 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128}; +#else + const __vector unsigned char kBitShuffle = {128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, + 96, 64, 32, 0}; +#endif + return detail::OnlyActive(d, detail::ExtractSignBits(sign_bits, kBitShuffle)); +#endif // HWY_PPC_HAVE_10 +} + 
+template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToUnsigned du; + + const Repartition du8; + const VFromD sign_bits = BitCast(du8, VecFromMask(d, mask)); + +#if HWY_PPC_HAVE_10 && HWY_IS_LITTLE_ENDIAN + return detail::OnlyActive( + d, static_cast(vec_extractm(BitCast(du, sign_bits).raw))); +#else // Z14, Z15, PPC8, PPC9, or big-endian PPC10 + (void)du; +#if HWY_IS_LITTLE_ENDIAN + const __vector unsigned char kBitShuffle = {64, 0, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128}; +#else + const __vector unsigned char kBitShuffle = {128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, + 128, 128, 64, 0}; +#endif + return detail::OnlyActive(d, detail::ExtractSignBits(sign_bits, kBitShuffle)); +#endif // HWY_PPC_HAVE_10 +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + // For vectors with 8 or fewer lanes, simply cast the result of BitsFromMask + // to an uint8_t and store the result in bits[0]. + bits[0] = static_cast(BitsFromMask(d, mask)); + return sizeof(uint8_t); +} + +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + const auto mask_bits = BitsFromMask(d, mask); + + // First convert mask_bits to a uint16_t as we only want to store + // the lower 16 bits of mask_bits as there are 16 lanes in mask. + + // Converting mask_bits to a uint16_t first will also ensure that + // the lower 16 bits of mask_bits are stored instead of the upper 16 bits + // of mask_bits on big-endian PPC targets. +#if HWY_IS_LITTLE_ENDIAN + const uint16_t u16_mask_bits = static_cast(mask_bits); +#else + // On big-endian targets, the bytes of mask_bits need to be swapped + // as StoreMaskBits expects the mask bits to be stored in little-endian + // byte order. 
+ + // GCC will also optimize the byte swap and CopyBytes operations below + // to a single sthbrx instruction when optimizations are enabled on + // big-endian PPC targets +#if HWY_HAS_BUILTIN(__builtin_bswap16) + const uint16_t u16_mask_bits = + __builtin_bswap16(static_cast(mask_bits)); +#else + const uint16_t u16_mask_bits = static_cast( + (mask_bits << 8) | (static_cast(mask_bits) >> 8)); +#endif +#endif + + CopyBytes(&u16_mask_bits, bits); + return sizeof(uint16_t); +} + +// ------------------------------ Mask testing + +template +HWY_API bool AllFalse(D d, MFromD mask) { + const RebindToUnsigned du; + return static_cast( + vec_all_eq(VecFromMask(du, RebindMask(du, mask)).raw, Zero(du).raw)); +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + const RebindToUnsigned du; + using TU = TFromD; + return static_cast(vec_all_eq(VecFromMask(du, RebindMask(du, mask)).raw, + Set(du, hwy::LimitsMax()).raw)); +} + +template +HWY_API bool AllFalse(D d, MFromD mask) { + const Full128> d_full; + constexpr size_t kN = MaxLanes(d); + return AllFalse(d_full, + And(MFromD{mask.raw}, FirstN(d_full, kN))); +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + const Full128> d_full; + constexpr size_t kN = MaxLanes(d); + return AllTrue( + d_full, Or(MFromD{mask.raw}, Not(FirstN(d_full, kN)))); +} + +template +HWY_API size_t CountTrue(D d, MFromD mask) { + return PopCount(BitsFromMask(d, mask)); +} + +#if HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) +namespace detail { + +template +static HWY_INLINE size_t VsxCntlzLsbb(V v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1200 && \ + HWY_IS_LITTLE_ENDIAN + // Use inline assembly to work around bug in GCC 11 and earlier on + // little-endian PPC9 + int idx; + __asm__("vctzlsbb %0,%1" : "=r"(idx) : "v"(v.raw)); + return static_cast(idx); +#else + return static_cast(vec_cntlz_lsbb(v.raw)); +#endif +} + +template +static HWY_INLINE size_t VsxCnttzLsbb(V v) { +#if HWY_COMPILER_GCC_ACTUAL && 
HWY_COMPILER_GCC_ACTUAL < 1200 && \ + HWY_IS_LITTLE_ENDIAN + // Use inline assembly to work around bug in GCC 11 and earlier on + // little-endian PPC9 + int idx; + __asm__("vclzlsbb %0,%1" : "=r"(idx) : "v"(v.raw)); + return static_cast(idx); +#else + return static_cast(vec_cnttz_lsbb(v.raw)); +#endif +} + +} // namespace detail +#endif + +template > +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { +// For little-endian PPC10, BitsFromMask is already efficient. +#if HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + if (detail::IsFull(d)) { + const Repartition d8; + const auto bytes = BitCast(d8, VecFromMask(d, mask)); + return detail::VsxCntlzLsbb(bytes) / sizeof(T); + } +#endif // HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + return Num0BitsBelowLS1Bit_Nonzero64(BitsFromMask(d, mask)); +} + +template > +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { +// For little-endian PPC10, BitsFromMask is already efficient. +#if HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + constexpr size_t kN = 16 / sizeof(T); + if (detail::IsFull(d)) { + const Repartition d8; + const auto bytes = BitCast(d8, VecFromMask(d, mask)); + const size_t idx = detail::VsxCntlzLsbb(bytes) / sizeof(T); + return idx == kN ? -1 : static_cast(idx); + } +#endif // HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + const uint64_t mask_bits = BitsFromMask(d, mask); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1; +} + +template > +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { +// For little-endian PPC10, BitsFromMask is already efficient. 
+#if HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + if (detail::IsFull(d)) { + const Repartition d8; + const auto bytes = BitCast(d8, VecFromMask(d, mask)); + const size_t idx = detail::VsxCnttzLsbb(bytes) / sizeof(T); + return 16 / sizeof(T) - 1 - idx; + } +#endif // HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + return 63 - Num0BitsAboveMS1Bit_Nonzero64(BitsFromMask(d, mask)); +} + +template > +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { +// For little-endian PPC10, BitsFromMask is already efficient. +#if HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + constexpr size_t kN = 16 / sizeof(T); + if (detail::IsFull(d)) { + const Repartition d8; + const auto bytes = BitCast(d8, VecFromMask(d, mask)); + const size_t idx = detail::VsxCnttzLsbb(bytes) / sizeof(T); + return idx == kN ? -1 : static_cast(kN - 1 - idx); + } +#endif // HWY_PPC_HAVE_9 && (!HWY_PPC_HAVE_10 || HWY_IS_BIG_ENDIAN) + const uint64_t mask_bits = BitsFromMask(d, mask); + return mask_bits ? intptr_t(63 - Num0BitsAboveMS1Bit_Nonzero64(mask_bits)) + : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +#if HWY_PPC_HAVE_10 +template +HWY_INLINE VFromD CompressOrExpandIndicesFromMask(D d, MFromD mask) { + constexpr unsigned kGenPcvmMode = + (kIsCompress ? 1u : 0u) | (HWY_IS_LITTLE_ENDIAN ? 2u : 0u); + + // Inline assembly is used instead of the vec_genpcvm intrinsic to work around + // compiler bugs on little-endian PPC10 + typename detail::Raw128>::type idx; + __asm__("xxgenpcvbm %x0, %1, %2" + : "=wa"(idx) + : "v"(mask.raw), "i"(kGenPcvmMode)); + return VFromD{idx}; +} +template +HWY_INLINE VFromD CompressOrExpandIndicesFromMask(D d, MFromD mask) { + constexpr unsigned kGenPcvmMode = + (kIsCompress ? 1u : 0u) | (HWY_IS_LITTLE_ENDIAN ? 
2u : 0u); + + // Inline assembly is used instead of the vec_genpcvm intrinsic to work around + // compiler bugs on little-endian PPC10 + typename detail::Raw128>::type idx; + __asm__("xxgenpcvhm %x0, %1, %2" + : "=wa"(idx) + : "v"(mask.raw), "i"(kGenPcvmMode)); + return VFromD{idx}; +} +template +HWY_INLINE VFromD CompressOrExpandIndicesFromMask(D d, MFromD mask) { + constexpr unsigned kGenPcvmMode = + (kIsCompress ? 1u : 0u) | (HWY_IS_LITTLE_ENDIAN ? 2u : 0u); + + // Inline assembly is used instead of the vec_genpcvm intrinsic to work around + // compiler bugs on little-endian PPC10 + typename detail::Raw128>::type idx; + __asm__("xxgenpcvwm %x0, %1, %2" + : "=wa"(idx) + : "v"(mask.raw), "i"(kGenPcvmMode)); + return VFromD{idx}; +} +#endif + +// Also works for N < 8 because the first 16 4-tuples only reference bytes 0-6. +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Twice d8t; + const RebindToUnsigned du; + + // To reduce cache footprint, store lane indices and convert to byte indices + // (2*lane + 0..1), with the doubling baked into the table. It's not clear + // that the additional cost of unpacking nibbles is worthwhile. 
+ alignas(16) static constexpr uint8_t table[2048] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 
14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 
8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 
4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 
2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const VFromD byte_idx{Load(d8, table + mask_bits * 8).raw}; + const VFromD pairs = ZipLower(byte_idx, byte_idx); + constexpr uint16_t kPairIndexIncrement = + HWY_IS_LITTLE_ENDIAN ? 0x0100 : 0x0001; + + return BitCast(d, pairs + Set(du, kPairIndexIncrement)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Twice d8t; + const RebindToUnsigned du; + + // To reduce cache footprint, store lane indices and convert to byte indices + // (2*lane + 0..1), with the doubling baked into the table. It's not clear + // that the additional cost of unpacking nibbles is worthwhile. + alignas(16) static constexpr uint8_t table[2048] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 
0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 
4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, 
/**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 
0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const VFromD byte_idx{Load(d8, table + mask_bits * 8).raw}; + const VFromD pairs = ZipLower(byte_idx, byte_idx); + constexpr uint16_t kPairIndexIncrement = + HWY_IS_LITTLE_ENDIAN ? 0x0100 : 0x0001; + + return BitCast(d, pairs + Set(du, kPairIndexIncrement)); +} + +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
+ alignas(16) static constexpr uint8_t u8_indices[256] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
+ alignas(16) static constexpr uint8_t u8_indices[256] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. 
+ alignas(16) static constexpr uint8_t u8_indices[64] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_API Vec128 CompressBits(Vec128 v, uint64_t mask_bits) { + const DFromV d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +template +HWY_API Vec128 CompressNotBits(Vec128 v, uint64_t mask_bits) { + const DFromV d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromNotBits128(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +} // namespace detail + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
+ const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +#if HWY_PPC_HAVE_10 +#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +// General case, 1 byte +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + const DFromV d; + return TableLookupBytes( + v, detail::CompressOrExpandIndicesFromMask(d, mask)); +} +#endif + +// General case, 2 or 4 bytes +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + const DFromV d; + return detail::CompressBits(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressNot + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +#if HWY_PPC_HAVE_10 +// General case, 1 byte +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + const DFromV d; + return TableLookupBytes( + v, detail::CompressOrExpandIndicesFromMask(d, Not(mask))); +} +#endif + +// General case, 2 or 4 bytes +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + const DFromV d; + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. 
+ if (N < 16 / sizeof(T)) { + return detail::CompressBits(v, BitsFromMask(d, Not(mask))); + } + return detail::CompressNotBits(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +#if HWY_PPC_HAVE_10 +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + const DFromV d; + return Compress(v, LoadMaskBits(d, bits)); +} +#endif + +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + // As there are at most 8 lanes in v if sizeof(TFromD) > 1, simply + // convert bits[0] to a uint64_t + uint64_t mask_bits = bits[0]; + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::CompressBits(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +#if HWY_PPC_HAVE_10 +template +HWY_API size_t CompressStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, m); + const auto indices = detail::CompressOrExpandIndicesFromMask(d, m); + const auto compressed = TableLookupBytes(v, indices); + StoreU(compressed, d, unaligned); + return count; +} +#endif + +template +HWY_API size_t CompressStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = BitsFromMask(d, m); + HWY_DASSERT(mask_bits < (1ull << MaxLanes(d))); + const size_t count = PopCount(mask_bits); + + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + return count; +} + +#if HWY_PPC_HAVE_10 +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, m); + const auto indices = detail::CompressOrExpandIndicesFromMask(d, m); + const auto compressed = 
TableLookupBytes(v, indices); + StoreN(compressed, d, unaligned, count); + return count; +} +#endif + +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = BitsFromMask(d, m); + HWY_DASSERT(mask_bits < (1ull << MaxLanes(d))); + const size_t count = PopCount(mask_bits); + + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +#if (HWY_PPC_HAVE_9 && HWY_ARCH_PPC_64) || HWY_S390X_HAVE_Z14 + StoreN(compressed, d, unaligned, count); +#else + BlendedStore(compressed, FirstN(d, count), d, unaligned); +#endif + return count; +} + +#if HWY_PPC_HAVE_10 +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} +#endif + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + // As there are at most 8 lanes in v if sizeof(TFromD) > 1, simply + // convert bits[0] to a uint64_t + uint64_t mask_bits = bits[0]; + constexpr size_t kN = MaxLanes(d); + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + const size_t count = PopCount(mask_bits); + + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + + return count; +} + +// ------------------------------ Expand +#if HWY_PPC_HAVE_10 +#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const DFromV d; + const auto idx = detail::CompressOrExpandIndicesFromMask(d, mask); + return IfThenElseZero(mask, TableLookupBytes(v, idx)); +} + 
+template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + // Same as Compress, just zero out the mask=false lanes. + return IfThenElseZero(mask, Compress(v, mask)); +} + +// For single-element vectors, this is at least as fast as native. +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + return IfThenElseZero(mask, v); +} + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return Expand(LoadU(d, unaligned), mask); +} +#endif // HWY_PPC_HAVE_10 + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. + +// ------------------------------ Additional mask logical operations +namespace detail { + +#if HWY_IS_LITTLE_ENDIAN +template +HWY_INLINE V Per64BitBlkRevLanesOnBe(V v) { + return v; +} +template +HWY_INLINE V Per128BitBlkRevLanesOnBe(V v) { + return v; +} +#else +template +HWY_INLINE V Per64BitBlkRevLanesOnBe(V v) { + const DFromV d; + return Reverse8(d, v); +} +template +HWY_INLINE V Per64BitBlkRevLanesOnBe(V v) { + const DFromV d; + return Reverse4(d, v); +} +template +HWY_INLINE V Per64BitBlkRevLanesOnBe(V v) { + const DFromV d; + return Reverse2(d, v); +} +template +HWY_INLINE V Per64BitBlkRevLanesOnBe(V v) { + return v; +} +template +HWY_INLINE V Per128BitBlkRevLanesOnBe(V v) { + const DFromV d; + return Reverse(d, v); +} +#endif + +template +HWY_INLINE V I128Subtract(V a, V b) { +#if HWY_S390X_HAVE_Z14 +#if HWY_COMPILER_CLANG || HWY_COMPILER_GCC_ACTUAL >= 1500 + // Workaround for bug in vec_sub_u128 in Clang vecintrin.h + + // The vec_sub_u128 intrinsic is also now deprecated in GCC 15 and later. + // The built-in U128x1 vector subtraction operator should be used instead of + // vec_sub_u128 with GCC 15 and later to avoid compiler warnings. 
+ + typedef __uint128_t VU128 __attribute__((__vector_size__(16))); + const V diff_i128{reinterpret_cast>::type>( + reinterpret_cast(a.raw) - reinterpret_cast(b.raw))}; +#else // !HWY_COMPILER_CLANG + const V diff_i128{reinterpret_cast>::type>( + vec_sub_u128(reinterpret_cast<__vector unsigned char>(a.raw), + reinterpret_cast<__vector unsigned char>(b.raw)))}; +#endif // HWY_COMPILER_CLANG +#elif defined(__SIZEOF_INT128__) + using VU128 = __vector unsigned __int128; + const V diff_i128{reinterpret_cast>::type>( + vec_sub(reinterpret_cast(a.raw), reinterpret_cast(b.raw)))}; +#else + const DFromV d; + const Repartition du64; + + const auto u64_a = BitCast(du64, a); + const auto u64_b = BitCast(du64, b); + + const auto diff_u64 = u64_a - u64_b; + const auto borrow_u64 = VecFromMask(du64, u64_a < u64_b); + +#if HWY_IS_LITTLE_ENDIAN + const auto borrow_u64_shifted = ShiftLeftBytes<8>(du64, borrow_u64); +#else + const auto borrow_u64_shifted = ShiftRightBytes<8>(du64, borrow_u64); +#endif + + const auto diff_i128 = BitCast(d, diff_u64 + borrow_u64_shifted); +#endif + + return diff_i128; +} + +} // namespace detail + +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const FixedTag d; + const auto vmask = VecFromMask(d, mask); + return MaskFromVec(Or(vmask, InterleaveLower(vmask, vmask))); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Simd d; + const Full64 d_full64; + + const auto vmask = VecFromMask(d, mask); + const auto vmask_le64 = + BitCast(Full64(), + detail::Per64BitBlkRevLanesOnBe(ResizeBitCast(d_full64, vmask))); + const auto neg_vmask_le64 = Neg(vmask_le64); + const auto neg_vmask = ResizeBitCast( + d, detail::Per64BitBlkRevLanesOnBe(BitCast(d_full64, neg_vmask_le64))); + + return MaskFromVec(Or(vmask, neg_vmask)); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Full128 d; + auto vmask = VecFromMask(d, mask); + + const 
auto vmask_le128 = detail::Per128BitBlkRevLanesOnBe(vmask); + const auto neg_vmask_le128 = detail::I128Subtract(Zero(d), vmask_le128); + const auto neg_vmask = detail::Per128BitBlkRevLanesOnBe(neg_vmask_le128); + + return MaskFromVec(BitCast(d, Or(vmask, neg_vmask))); +} + +template +HWY_API Mask128 SetBeforeFirst(Mask128 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const FixedTag d; + const RebindToSigned di; + + const auto vmask = BitCast(di, VecFromMask(d, mask)); + const auto zero = Zero(di); + const auto vmask2 = VecFromMask(di, InterleaveLower(zero, vmask) == zero); + return MaskFromVec(BitCast(d, And(vmask, vmask2))); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Simd d; + const Full64 d_full64; + const RebindToSigned di; + + const auto vmask = VecFromMask(d, mask); + const auto vmask_le64 = + BitCast(Full64(), + detail::Per64BitBlkRevLanesOnBe(ResizeBitCast(d_full64, vmask))); + const auto neg_vmask_le64 = Neg(vmask_le64); + const auto neg_vmask = ResizeBitCast( + d, detail::Per64BitBlkRevLanesOnBe(BitCast(d_full64, neg_vmask_le64))); + + const auto first_vmask = BitCast(di, And(vmask, neg_vmask)); + return MaskFromVec(BitCast(d, Or(first_vmask, Neg(first_vmask)))); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Full128 d; + const RebindToSigned di; + + const auto vmask = VecFromMask(d, mask); + const auto vmask_le128 = detail::Per128BitBlkRevLanesOnBe(vmask); + const auto neg_vmask_le128 = detail::I128Subtract(Zero(d), vmask_le128); + const auto neg_vmask = detail::Per128BitBlkRevLanesOnBe(neg_vmask_le128); + + return MaskFromVec(BitCast(d, Neg(BitCast(di, And(vmask, neg_vmask))))); +} + +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 /*mask*/) { + const FixedTag d; + const RebindToSigned di; + using TI = MakeSigned; + + return RebindMask(d, MaskFromVec(Set(di, TI(-1)))); 
+} +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { + const Simd d; + return SetBeforeFirst(MaskFromVec(ShiftLeftLanes<1>(VecFromMask(d, mask)))); +} + +// ------------------------------ SumsOf2 and SumsOf4 +namespace detail { + +#if !HWY_S390X_HAVE_Z14 +// Casts nominally int32_t result to D. +template +HWY_INLINE VFromD AltivecVsum4sbs(D d, __vector signed char a, + __vector signed int b) { + const Repartition di32; +#ifdef __OPTIMIZE__ + if (IsConstantRawAltivecVect(a) && IsConstantRawAltivecVect(b)) { + const int64_t sum0 = + static_cast(a[0]) + static_cast(a[1]) + + static_cast(a[2]) + static_cast(a[3]) + + static_cast(b[0]); + const int64_t sum1 = + static_cast(a[4]) + static_cast(a[5]) + + static_cast(a[6]) + static_cast(a[7]) + + static_cast(b[1]); + const int64_t sum2 = + static_cast(a[8]) + static_cast(a[9]) + + static_cast(a[10]) + static_cast(a[11]) + + static_cast(b[2]); + const int64_t sum3 = + static_cast(a[12]) + static_cast(a[13]) + + static_cast(a[14]) + static_cast(a[15]) + + static_cast(b[3]); + const int32_t sign0 = static_cast(sum0 >> 63); + const int32_t sign1 = static_cast(sum1 >> 63); + const int32_t sign2 = static_cast(sum2 >> 63); + const int32_t sign3 = static_cast(sum3 >> 63); + using Raw = typename detail::Raw128::type; + return BitCast( + d, + VFromD{Raw{ + (sign0 == (sum0 >> 31)) ? static_cast(sum0) + : static_cast(sign0 ^ 0x7FFFFFFF), + (sign1 == (sum1 >> 31)) ? static_cast(sum1) + : static_cast(sign1 ^ 0x7FFFFFFF), + (sign2 == (sum2 >> 31)) ? static_cast(sum2) + : static_cast(sign2 ^ 0x7FFFFFFF), + (sign3 == (sum3 >> 31)) + ? static_cast(sum3) + : static_cast(sign3 ^ 0x7FFFFFFF)}}); + } else // NOLINT +#endif + { + return BitCast(d, VFromD{vec_vsum4sbs(a, b)}); + } +} + +// Casts nominally uint32_t result to D. 
+template +HWY_INLINE VFromD AltivecVsum4ubs(D d, __vector unsigned char a, + __vector unsigned int b) { + const Repartition du32; +#ifdef __OPTIMIZE__ + if (IsConstantRawAltivecVect(a) && IsConstantRawAltivecVect(b)) { + const uint64_t sum0 = + static_cast(a[0]) + static_cast(a[1]) + + static_cast(a[2]) + static_cast(a[3]) + + static_cast(b[0]); + const uint64_t sum1 = + static_cast(a[4]) + static_cast(a[5]) + + static_cast(a[6]) + static_cast(a[7]) + + static_cast(b[1]); + const uint64_t sum2 = + static_cast(a[8]) + static_cast(a[9]) + + static_cast(a[10]) + static_cast(a[11]) + + static_cast(b[2]); + const uint64_t sum3 = + static_cast(a[12]) + static_cast(a[13]) + + static_cast(a[14]) + static_cast(a[15]) + + static_cast(b[3]); + return BitCast( + d, + VFromD{(__vector unsigned int){ + static_cast(sum0 <= 0xFFFFFFFFu ? sum0 : 0xFFFFFFFFu), + static_cast(sum1 <= 0xFFFFFFFFu ? sum1 : 0xFFFFFFFFu), + static_cast(sum2 <= 0xFFFFFFFFu ? sum2 : 0xFFFFFFFFu), + static_cast(sum3 <= 0xFFFFFFFFu ? sum3 + : 0xFFFFFFFFu)}}); + } else // NOLINT +#endif + { + return BitCast(d, VFromD{vec_vsum4ubs(a, b)}); + } +} + +// Casts nominally int32_t result to D. +template +HWY_INLINE VFromD AltivecVsum2sws(D d, __vector signed int a, + __vector signed int b) { + const Repartition di32; +#ifdef __OPTIMIZE__ + const Repartition du64; + constexpr int kDestLaneOffset = HWY_IS_BIG_ENDIAN; + if (IsConstantRawAltivecVect(a) && __builtin_constant_p(b[kDestLaneOffset]) && + __builtin_constant_p(b[kDestLaneOffset + 2])) { + const int64_t sum0 = static_cast(a[0]) + + static_cast(a[1]) + + static_cast(b[kDestLaneOffset]); + const int64_t sum1 = static_cast(a[2]) + + static_cast(a[3]) + + static_cast(b[kDestLaneOffset + 2]); + const int32_t sign0 = static_cast(sum0 >> 63); + const int32_t sign1 = static_cast(sum1 >> 63); + return BitCast(d, VFromD{(__vector unsigned long long){ + (sign0 == (sum0 >> 31)) + ? static_cast(sum0) + : static_cast(sign0 ^ 0x7FFFFFFF), + (sign1 == (sum1 >> 31)) + ? 
static_cast(sum1) + : static_cast(sign1 ^ 0x7FFFFFFF)}}); + } else // NOLINT +#endif + { + __vector signed int sum; + + // Inline assembly is used for vsum2sws to avoid unnecessary shuffling + // on little-endian PowerPC targets as the result of the vsum2sws + // instruction will already be in the correct lanes on little-endian + // PowerPC targets. + __asm__("vsum2sws %0,%1,%2" : "=v"(sum) : "v"(a), "v"(b)); + + return BitCast(d, VFromD{sum}); + } +} + +// Casts nominally int32_t result to D. +template +HWY_INLINE VFromD AltivecVsum4shs(D d, __vector signed short a, + __vector signed int b) { + const Repartition di32; +#ifdef __OPTIMIZE__ + if (IsConstantRawAltivecVect(a) && IsConstantRawAltivecVect(b)) { + const int64_t sum0 = static_cast(a[0]) + + static_cast(a[1]) + + static_cast(b[0]); + const int64_t sum1 = static_cast(a[2]) + + static_cast(a[3]) + + static_cast(b[1]); + const int64_t sum2 = static_cast(a[4]) + + static_cast(a[5]) + + static_cast(b[2]); + const int64_t sum3 = static_cast(a[6]) + + static_cast(a[7]) + + static_cast(b[3]); + const int32_t sign0 = static_cast(sum0 >> 63); + const int32_t sign1 = static_cast(sum1 >> 63); + const int32_t sign2 = static_cast(sum2 >> 63); + const int32_t sign3 = static_cast(sum3 >> 63); + using Raw = typename detail::Raw128::type; + return BitCast( + d, + VFromD{Raw{ + (sign0 == (sum0 >> 31)) ? static_cast(sum0) + : static_cast(sign0 ^ 0x7FFFFFFF), + (sign1 == (sum1 >> 31)) ? static_cast(sum1) + : static_cast(sign1 ^ 0x7FFFFFFF), + (sign2 == (sum2 >> 31)) ? static_cast(sum2) + : static_cast(sign2 ^ 0x7FFFFFFF), + (sign3 == (sum3 >> 31)) + ? static_cast(sum3) + : static_cast(sign3 ^ 0x7FFFFFFF)}}); + } else // NOLINT +#endif + { + return BitCast(d, VFromD{vec_vsum4shs(a, b)}); + } +} + +// Casts nominally int32_t result to D. 
+template +HWY_INLINE VFromD AltivecVsumsws(D d, __vector signed int a, + __vector signed int b) { + const Repartition di32; +#ifdef __OPTIMIZE__ + constexpr int kDestLaneOffset = HWY_IS_LITTLE_ENDIAN ? 0 : 3; + if (IsConstantRawAltivecVect(a) && __builtin_constant_p(b[kDestLaneOffset])) { + const int64_t sum = + static_cast(a[0]) + static_cast(a[1]) + + static_cast(a[2]) + static_cast(a[3]) + + static_cast(b[kDestLaneOffset]); + const int32_t sign = static_cast(sum >> 63); +#if HWY_IS_LITTLE_ENDIAN + return BitCast( + d, VFromD{(__vector signed int){ + (sign == (sum >> 31)) ? static_cast(sum) + : static_cast(sign ^ 0x7FFFFFFF), + 0, 0, 0}}); +#else + return BitCast(d, VFromD{(__vector signed int){ + 0, 0, 0, + (sign == (sum >> 31)) + ? static_cast(sum) + : static_cast(sign ^ 0x7FFFFFFF)}}); +#endif + } else // NOLINT +#endif + { + __vector signed int sum; + + // Inline assembly is used for vsumsws to avoid unnecessary shuffling + // on little-endian PowerPC targets as the result of the vsumsws + // instruction will already be in the correct lanes on little-endian + // PowerPC targets. 
+ __asm__("vsumsws %0,%1,%2" : "=v"(sum) : "v"(a), "v"(b)); + + return BitCast(d, VFromD{sum}); + } +} + +template +HWY_INLINE Vec128 AltivecU16SumsOf2(Vec128 v) { + const RebindToSigned> di16; + const RepartitionToWide di32; + return AltivecVsum4shs(di32, Xor(BitCast(di16, v), Set(di16, -32768)).raw, + Set(di32, 65536).raw); +} +#endif // !HWY_S390X_HAVE_Z14 + +// U16->U32 SumsOf2 +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + +#if HWY_S390X_HAVE_Z14 + return VFromD{vec_sum4(v.raw, Zero(d).raw)}; +#else + return BitCast(dw, AltivecU16SumsOf2(v)); +#endif +} + +// I16->I32 SumsOf2 +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du; + return BitCast(dw, SumsOf2(hwy::UnsignedTag(), hwy::SizeTag<2>(), + BitCast(du, Xor(v, SignBit(d))))) + + Set(dw, int32_t{-65536}); +#else + return AltivecVsum4shs(dw, v.raw, Zero(dw).raw); +#endif +} + +#if HWY_S390X_HAVE_Z14 +// U32->U64 SumsOf2 +template +HWY_INLINE VFromD>> SumsOf2( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + return VFromD{vec_sum2(v.raw, Zero(d).raw)}; +} + +// I32->I64 SumsOf2 +template +HWY_INLINE VFromD>> SumsOf2( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<4> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + const RebindToUnsigned du; + + return BitCast(dw, SumsOf2(hwy::UnsignedTag(), hwy::SizeTag<4>(), + BitCast(du, Xor(v, SignBit(d))))) + + Set(dw, int64_t{-4294967296LL}); +} +#endif + +// U8->U32 SumsOf4 +template +HWY_INLINE VFromD>> SumsOf4( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWideX2 dw2; + +#if HWY_S390X_HAVE_Z14 + return VFromD{vec_sum4(v.raw, 
Zero(d).raw)}; +#else + return AltivecVsum4ubs(dw2, v.raw, Zero(dw2).raw); +#endif +} + +// I8->I32 SumsOf4 +template +HWY_INLINE VFromD>> SumsOf4( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWideX2 dw2; + +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du; + return BitCast(dw2, SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), + BitCast(du, Xor(v, SignBit(d))))) + + Set(dw2, int32_t{-512}); +#else + return AltivecVsum4sbs(dw2, v.raw, Zero(dw2).raw); +#endif +} + +// U16->U64 SumsOf4 +template +HWY_INLINE VFromD>> SumsOf4( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + const RepartitionToWide dw2; + +#if HWY_S390X_HAVE_Z14 + return VFromD{vec_sum2(v.raw, Zero(d).raw)}; +#else + const RebindToSigned dw_i; + return AltivecVsum2sws(dw2, BitCast(dw_i, SumsOf2(v)).raw, Zero(dw_i).raw); +#endif +} + +// I16->I64 SumsOf4 +template +HWY_INLINE VFromD>> SumsOf4( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<2> /*lane_size_tag*/, V v) { + const DFromV d; + const RepartitionToWide dw; + const RepartitionToWide dw2; + +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du; + return BitCast(dw2, SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<2>(), + BitCast(du, Xor(v, SignBit(d))))) + + Set(dw2, int64_t{-131072}); +#else // VSX + const auto sums_of_4_in_lo32 = + AltivecVsum2sws(dw, SumsOf2(v).raw, Zero(dw).raw); + +#if HWY_IS_LITTLE_ENDIAN + return PromoteEvenTo(dw2, sums_of_4_in_lo32); +#else + return PromoteOddTo(dw2, sums_of_4_in_lo32); +#endif // HWY_IS_LITTLE_ENDIAN +#endif // HWY_S390X_HAVE_Z14 +} + +} // namespace detail + +// ------------------------------ SumOfLanes + +// We define SumOfLanes for 8/16-bit types (and I32/U32/I64/U64 on Z14/Z15/Z16); +// enable generic for the rest. 
+#undef HWY_IF_SUM_OF_LANES_D +#if HWY_S390X_HAVE_Z14 +#define HWY_IF_SUM_OF_LANES_D(D) HWY_IF_LANES_GT_D(D, 1), HWY_IF_FLOAT3264_D(D) +#else +#define HWY_IF_SUM_OF_LANES_D(D) \ + HWY_IF_LANES_GT_D(D, 1), HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8)) +#endif + +#if HWY_S390X_HAVE_Z14 +namespace detail { + +#if HWY_COMPILER_CLANG && HWY_HAS_BUILTIN(__builtin_s390_vsumqf) && \ + HWY_HAS_BUILTIN(__builtin_s390_vsumqg) +// Workaround for bug in vec_sum_u128 in Clang vecintrin.h +template +HWY_INLINE Vec128 SumOfU32OrU64LanesAsU128(Vec128 v) { + typedef __uint128_t VU128 __attribute__((__vector_size__(16))); + const DFromV d; + const RebindToUnsigned du; + const VU128 sum = {__builtin_s390_vsumqf(BitCast(du, v).raw, Zero(du).raw)}; + return Vec128{reinterpret_cast::type>(sum)}; +} +template +HWY_INLINE Vec128 SumOfU32OrU64LanesAsU128(Vec128 v) { + typedef __uint128_t VU128 __attribute__((__vector_size__(16))); + const DFromV d; + const RebindToUnsigned du; + const VU128 sum = {__builtin_s390_vsumqg(BitCast(du, v).raw, Zero(du).raw)}; + return Vec128{reinterpret_cast::type>(sum)}; +} +#else +template +HWY_INLINE Vec128 SumOfU32OrU64LanesAsU128(Vec128 v) { + const DFromV d; + const RebindToUnsigned du; +#if HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2100 + // GCC 15 and Clang 20 have new vec_sum intrinsics that replaced the + // vec_sum_u128 intrinsic + + // vec_sum needs to be used instead of vec_sum_u128 with GCC 15 or later to + // avoid compiler warnings + return Vec128{reinterpret_cast::type>( + vec_sum(BitCast(du, v).raw, Zero(du).raw))}; +#else + return BitCast( + d, Vec128{vec_sum_u128(BitCast(du, v).raw, Zero(du).raw)}); +#endif +} +#endif + +} // namespace detail + +template +HWY_API VFromD SumOfLanes(D /*d64*/, VFromD v) { + return Broadcast<1>(detail::SumOfU32OrU64LanesAsU128(v)); +} +#endif + +template +HWY_API Vec32 SumOfLanes(D du16, Vec32 v) { + constexpr int kSumLaneIdx = HWY_IS_BIG_ENDIAN; + return Broadcast( + BitCast(du16, 
detail::SumsOf2(hwy::UnsignedTag(), hwy::SizeTag<2>(), v))); +} + +template +HWY_API Vec64 SumOfLanes(D du16, Vec64 v) { + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 3; + return Broadcast( + BitCast(du16, detail::SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<2>(), v))); +} + +template +HWY_API Vec128 SumOfLanes(D du16, Vec128 v) { + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 7; +#if HWY_S390X_HAVE_Z14 + return Broadcast( + BitCast(du16, detail::SumOfU32OrU64LanesAsU128(detail::SumsOf4( + hwy::UnsignedTag(), hwy::SizeTag<2>(), v)))); +#else // VSX + const auto zero = Zero(Full128()); + return Broadcast( + detail::AltivecVsumsws(du16, detail::AltivecU16SumsOf2(v).raw, zero.raw)); +#endif +} + +template +HWY_API Vec32 SumOfLanes(D di16, Vec32 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du16; + return BitCast(di16, SumOfLanes(du16, BitCast(du16, v))); +#else + constexpr int kSumLaneIdx = HWY_IS_BIG_ENDIAN; + return Broadcast( + BitCast(di16, detail::SumsOf2(hwy::SignedTag(), hwy::SizeTag<2>(), v))); +#endif +} + +template +HWY_API Vec64 SumOfLanes(D di16, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du16; + return BitCast(di16, SumOfLanes(du16, BitCast(du16, v))); +#else + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 3; + return Broadcast( + BitCast(di16, detail::SumsOf4(hwy::SignedTag(), hwy::SizeTag<2>(), v))); +#endif +} + +template +HWY_API Vec128 SumOfLanes(D di16, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du16; + return BitCast(di16, SumOfLanes(du16, BitCast(du16, v))); +#else + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 7; + const Full128 di32; + const auto zero = Zero(di32); + return Broadcast(detail::AltivecVsumsws( + di16, detail::AltivecVsum4shs(di32, v.raw, zero.raw).raw, zero.raw)); +#endif +} + +template +HWY_API Vec32 SumOfLanes(D du8, Vec32 v) { + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 
0 : 3; + return Broadcast( + BitCast(du8, detail::SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v))); +} + +template +HWY_API Vec16 SumOfLanes(D du8, Vec16 v) { + const Twice dt_u8; + return LowerHalf(du8, SumOfLanes(dt_u8, Combine(dt_u8, Zero(du8), v))); +} + +template +HWY_API Vec64 SumOfLanes(D du8, Vec64 v) { + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 7; + return Broadcast(BitCast(du8, SumsOf8(v))); +} + +template +HWY_API Vec128 SumOfLanes(D du8, Vec128 v) { + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 15; + +#if HWY_S390X_HAVE_Z14 + return Broadcast( + BitCast(du8, detail::SumOfU32OrU64LanesAsU128(detail::SumsOf4( + hwy::UnsignedTag(), hwy::SizeTag<1>(), v)))); +#else + const Full128 du32; + const RebindToSigned di32; + const Vec128 zero = Zero(du32); + return Broadcast(detail::AltivecVsumsws( + du8, detail::AltivecVsum4ubs(di32, v.raw, zero.raw).raw, + BitCast(di32, zero).raw)); +#endif +} + +template +HWY_API Vec32 SumOfLanes(D di8, Vec32 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du8; + return BitCast(di8, SumOfLanes(du8, BitCast(du8, v))); +#else + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 3; + return Broadcast( + BitCast(di8, detail::SumsOf4(hwy::SignedTag(), hwy::SizeTag<1>(), v))); +#endif +} + +template +HWY_API Vec16 SumOfLanes(D di8, Vec16 v) { + const Twice dt_i8; + return LowerHalf(di8, SumOfLanes(dt_i8, Combine(dt_i8, Zero(di8), v))); +} + +template +HWY_API Vec64 SumOfLanes(D di8, Vec64 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du8; + return BitCast(di8, SumOfLanes(du8, BitCast(du8, v))); +#else + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 0 : 7; + return Broadcast(BitCast(di8, SumsOf8(v))); +#endif +} + +template +HWY_API Vec128 SumOfLanes(D di8, Vec128 v) { +#if HWY_S390X_HAVE_Z14 + const RebindToUnsigned du8; + return BitCast(di8, SumOfLanes(du8, BitCast(du8, v))); +#else + constexpr int kSumLaneIdx = HWY_IS_LITTLE_ENDIAN ? 
0 : 15; + const Full128 di32; + const Vec128 zero = Zero(di32); + return Broadcast(detail::AltivecVsumsws( + di8, detail::AltivecVsum4sbs(di32, v.raw, zero.raw).raw, zero.raw)); +#endif +} + +#if HWY_S390X_HAVE_Z14 +template +HWY_API VFromD SumOfLanes(D d32, VFromD v) { + const RebindToUnsigned du32; + return Broadcast<1>( + BitCast(d32, detail::SumsOf2(hwy::UnsignedTag(), hwy::SizeTag<4>(), + BitCast(du32, v)))); +} + +template +HWY_API VFromD SumOfLanes(D /*d32*/, VFromD v) { + return Broadcast<3>(detail::SumOfU32OrU64LanesAsU128(v)); +} +#endif + +// generic_ops defines MinOfLanes and MaxOfLanes. + +// ------------------------------ ReduceSum for N=4 I8/U8 + +// GetLane(SumsOf4(v)) is more efficient on PPC/Z14 than the default N=4 +// I8/U8 ReduceSum implementation in generic_ops-inl.h +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 +#endif + +template +HWY_API TFromD ReduceSum(D /*d*/, VFromD v) { + return static_cast>(GetLane(SumsOf4(v))); +} + +// ------------------------------ BitShuffle + +#ifdef HWY_NATIVE_BITSHUFFLE +#undef HWY_NATIVE_BITSHUFFLE +#else +#define HWY_NATIVE_BITSHUFFLE +#endif + +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Repartition du8; + + const Full128> d_full_u64; + const Full128> d_full_u8; + + using RawVU64 = __vector unsigned long long; + +#if HWY_PPC_HAVE_9 + +#if HWY_IS_LITTLE_ENDIAN + (void)d_full_u64; + auto bit_idx = ResizeBitCast(d_full_u8, idx); +#else + auto bit_idx = + BitCast(d_full_u8, ReverseLaneBytes(ResizeBitCast(d_full_u64, idx))); +#endif + + bit_idx = Xor(bit_idx, Set(d_full_u8, uint8_t{0x3F})); + + return BitCast(d64, VFromD{reinterpret_cast( + vec_bperm(BitCast(du64, v).raw, bit_idx.raw))}); +#else // !HWY_PPC_HAVE_9 + +#if HWY_IS_LITTLE_ENDIAN + const auto bit_idx_xor_mask = BitCast( + d_full_u8, 
Dup128VecFromValues(d_full_u64, uint64_t{0x7F7F7F7F7F7F7F7Fu}, + uint64_t{0x3F3F3F3F3F3F3F3Fu})); + const auto bit_idx = Xor(ResizeBitCast(d_full_u8, idx), bit_idx_xor_mask); + constexpr int kBitShufResultByteShrAmt = 8; +#else + const auto bit_idx_xor_mask = BitCast( + d_full_u8, Dup128VecFromValues(d_full_u64, uint64_t{0x3F3F3F3F3F3F3F3Fu}, + uint64_t{0x7F7F7F7F7F7F7F7Fu})); + const auto bit_idx = + Xor(BitCast(d_full_u8, ReverseLaneBytes(ResizeBitCast(d_full_u64, idx))), + bit_idx_xor_mask); + constexpr int kBitShufResultByteShrAmt = 6; +#endif + +#if HWY_S390X_HAVE_Z14 + +#if HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2100 + // GCC 15 and Clang 20 have added the vec_bperm intrinsic + + // Need to use vec_bperm instead of vec_bperm_u128 with GCC 15 and later to + // avoid compiler warning + using RawVU128 = __vector unsigned __int128; + + const VFromD bit_shuf_result{reinterpret_cast( + vec_bperm(reinterpret_cast(v.raw), bit_idx.raw))}; +#else + const VFromD bit_shuf_result{reinterpret_cast( + vec_bperm_u128(BitCast(du8, v).raw, bit_idx.raw))}; +#endif // !(HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2100) + +#elif defined(__SIZEOF_INT128__) + using RawVU128 = __vector unsigned __int128; + const VFromD bit_shuf_result{reinterpret_cast( + vec_vbpermq(reinterpret_cast(v.raw), bit_idx.raw))}; +#else + using RawVU128 = __vector unsigned char; + const VFromD bit_shuf_result{reinterpret_cast( + vec_vbpermq(reinterpret_cast(v.raw), bit_idx.raw))}; +#endif + + return ResizeBitCast( + d64, PromoteTo(d_full_u64, + ResizeBitCast( + Rebind(), + CombineShiftRightBytes( + d_full_u64, bit_shuf_result, bit_shuf_result)))); +#endif // HWY_PPC_HAVE_9 +} + +// ------------------------------ Lt128 + +namespace detail { + +// Returns vector-mask for Lt128. 
+template > +HWY_INLINE V Lt128Vec(D d, V a, V b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); +#if HWY_PPC_HAVE_10 && defined(__SIZEOF_INT128__) + (void)d; + using VU64 = __vector unsigned long long; + using VU128 = __vector unsigned __int128; +#if HWY_IS_LITTLE_ENDIAN + const VU128 a_u128 = reinterpret_cast(a.raw); + const VU128 b_u128 = reinterpret_cast(b.raw); +#else + // NOTE: Need to swap the halves of both a and b on big-endian targets + // as the upper 64 bits of a and b are in lane 1 and the lower 64 bits + // of a and b are in lane 0 whereas the vec_cmplt operation below expects + // the upper 64 bits in lane 0 and the lower 64 bits in lane 1 on + // big-endian PPC targets. + const VU128 a_u128 = reinterpret_cast(vec_sld(a.raw, a.raw, 8)); + const VU128 b_u128 = reinterpret_cast(vec_sld(b.raw, b.raw, 8)); +#endif + return V{reinterpret_cast(vec_cmplt(a_u128, b_u128))}; +#else // !HWY_PPC_HAVE_10 + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const auto eqHL = Eq(a, b); + const V ltHL = VecFromMask(d, Lt(a, b)); + const V ltLX = ShiftLeftLanes<1>(ltHL); + const V vecHx = IfThenElse(eqHL, ltLX, ltHL); + return InterleaveUpper(d, vecHx, vecHx); +#endif +} + +// Returns vector-mask for Eq128. 
+template > +HWY_INLINE V Eq128Vec(D d, V a, V b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); +#if HWY_PPC_HAVE_10 && defined(__SIZEOF_INT128__) + (void)d; + using VU64 = __vector unsigned long long; + using VU128 = __vector unsigned __int128; + return V{reinterpret_cast(vec_cmpeq(reinterpret_cast(a.raw), + reinterpret_cast(b.raw)))}; +#else + const auto eqHL = VecFromMask(d, Eq(a, b)); + const auto eqLH = Reverse2(d, eqHL); + return And(eqHL, eqLH); +#endif +} + +template > +HWY_INLINE V Ne128Vec(D d, V a, V b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); +#if HWY_PPC_HAVE_10 && defined(__SIZEOF_INT128__) + (void)d; + using VU64 = __vector unsigned long long; + using VU128 = __vector unsigned __int128; + return V{reinterpret_cast(vec_cmpne(reinterpret_cast(a.raw), + reinterpret_cast(b.raw)))}; +#else + const auto neHL = VecFromMask(d, Ne(a, b)); + const auto neLH = Reverse2(d, neHL); + return Or(neHL, neLH); +#endif +} + +template > +HWY_INLINE V Lt128UpperVec(D d, V a, V b) { + const V ltHL = VecFromMask(d, Lt(a, b)); + return InterleaveUpper(d, ltHL, ltHL); +} + +template > +HWY_INLINE V Eq128UpperVec(D d, V a, V b) { + const V eqHL = VecFromMask(d, Eq(a, b)); + return InterleaveUpper(d, eqHL, eqHL); +} + +template > +HWY_INLINE V Ne128UpperVec(D d, V a, V b) { + const V neHL = VecFromMask(d, Ne(a, b)); + return InterleaveUpper(d, neHL, neHL); +} + +} // namespace detail + +template > +HWY_API MFromD Lt128(D d, V a, V b) { + return MaskFromVec(detail::Lt128Vec(d, a, b)); +} + +template > +HWY_API MFromD Eq128(D d, V a, V b) { + return MaskFromVec(detail::Eq128Vec(d, a, b)); +} + +template > +HWY_API MFromD Ne128(D d, V a, V b) { + return MaskFromVec(detail::Ne128Vec(d, a, b)); +} + +template > +HWY_API MFromD Lt128Upper(D d, V a, V b) { + return MaskFromVec(detail::Lt128UpperVec(d, a, b)); +} + +template > +HWY_API MFromD Eq128Upper(D d, V a, V b) { + return MaskFromVec(detail::Eq128UpperVec(d, a, b)); +} + +template > +HWY_API MFromD 
Ne128Upper(D d, V a, V b) { + return MaskFromVec(detail::Ne128UpperVec(d, a, b)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Avoids the extra MaskFromVec in Lt128. +template > +HWY_API V Min128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +} + +template > +HWY_API V Max128(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +} + +template > +HWY_API V Min128Upper(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128UpperVec(d, a, b), a, b); +} + +template > +HWY_API V Max128Upper(D d, const V a, const V b) { + return IfVecThenElse(detail::Lt128UpperVec(d, b, a), a, b); +} + +// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex + +#ifdef HWY_NATIVE_LEADING_ZERO_COUNT +#undef HWY_NATIVE_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_LEADING_ZERO_COUNT +#endif + +template +HWY_API V LeadingZeroCount(V v) { +#if HWY_S390X_HAVE_Z14 + const DFromV d; + const RebindToUnsigned du; + +#if HWY_COMPILER_GCC_ACTUAL && defined(__OPTIMIZE__) + // Work around for GCC compiler bug in vec_cnttz on Z14/Z15 if v[i] is a + // constant + __asm__("" : "+v"(v.raw)); +#endif + + return BitCast(d, VFromD{vec_cntlz(BitCast(du, v).raw)}); +#else + return V{vec_cntlz(v.raw)}; +#endif +} + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + using T = TFromD; + return BitCast(d, Set(d, T{sizeof(T) * 8 - 1}) - LeadingZeroCount(v)); +} + +#if HWY_PPC_HAVE_9 || HWY_S390X_HAVE_Z14 +template +HWY_API V TrailingZeroCount(V v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 + return V{vec_vctz(v.raw)}; +#else +#if HWY_S390X_HAVE_Z14 + const DFromV d; + const RebindToUnsigned du; + +#if HWY_COMPILER_GCC_ACTUAL && defined(__OPTIMIZE__) + // Work around for GCC compiler bug in vec_cnttz on Z14/Z15 if v[i] is a + // constant + __asm__("" : "+v"(v.raw)); +#endif + + return BitCast(d, VFromD{vec_cnttz(BitCast(du, v).raw)}); +#else + return 
V{vec_cnttz(v.raw)}; +#endif // HWY_S390X_HAVE_Z14 +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 700 +} +#else +template +HWY_API V TrailingZeroCount(V v) { + const DFromV d; + const RebindToSigned di; + using TI = TFromD; + + const auto vi = BitCast(di, v); + const auto lowest_bit = And(vi, Neg(vi)); + constexpr TI kNumOfBitsInT{sizeof(TI) * 8}; + const auto bit_idx = HighestSetBitIndex(lowest_bit); + return BitCast(d, IfThenElse(MaskFromVec(BroadcastSignBit(bit_idx)), + Set(di, kNumOfBitsInT), bit_idx)); +} +#endif + +#undef HWY_PPC_HAVE_9 +#undef HWY_PPC_HAVE_10 +#undef HWY_S390X_HAVE_Z14 +#undef HWY_S390X_HAVE_Z15 + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/lib/highway/hwy/ops/rvv-inl.h b/lib/highway/hwy/ops/rvv-inl.h new file mode 100644 index 00000000000..2f530d3bae0 --- /dev/null +++ b/lib/highway/hwy/ops/rvv-inl.h @@ -0,0 +1,6694 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RISC-V V vectors (length not known at compile time). +// External include guard in highway.h - see comment there. 
+ +#pragma push_macro("__riscv_v_elen") + +// Workaround that ensures that all of the __riscv_vsetvl_* and +// __riscv_vsetvlmax_* macros in riscv_vector.h are defined when compiling with +// Clang 20 with dynamic dispatch and a baseline target of SCALAR or EMU128 +#if HWY_COMPILER_CLANG >= 2000 && HWY_COMPILER_CLANG < 2100 && \ + (!defined(__riscv_v_elen) || __riscv_v_elen < 64) +#undef __riscv_v_elen +#define __riscv_v_elen 64 +#endif + +#include + +#pragma pop_macro("__riscv_v_elen") + +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Support for vfloat16m*_t and PromoteTo/DemoteTo. +#ifdef __riscv_zvfhmin +#define HWY_RVV_HAVE_F16C 1 +#else +#define HWY_RVV_HAVE_F16C 0 +#endif + +template +struct DFromV_t {}; // specialized in macros +template +using DFromV = typename DFromV_t>::type; + +template +using TFromV = TFromD>; + +template +constexpr size_t MLenFromD(Simd /* tag */) { + // Returns divisor = type bits / LMUL. Folding *8 into the ScaleByPower + // argument enables fractional LMUL < 1. Limit to 64 because that is the + // largest value for which vbool##_t are defined. + return HWY_MIN(64, sizeof(T) * 8 * 8 / detail::ScaleByPower(8, kPow2)); +} + +namespace detail { + +template +class AdjustSimdTagToMinVecPow2_t {}; + +template +class AdjustSimdTagToMinVecPow2_t> { + private: + using D = Simd; + static constexpr int kMinVecPow2 = + -3 + static_cast(FloorLog2(sizeof(T))); + static constexpr size_t kNumMaxLanes = HWY_MAX_LANES_D(D); + static constexpr int kNewPow2 = HWY_MAX(kPow2, kMinVecPow2); + static constexpr size_t kNewN = D::template NewN(); + + public: + using type = Simd; +}; + +template +using AdjustSimdTagToMinVecPow2 = + typename AdjustSimdTagToMinVecPow2_t>::type; + +} // namespace detail + +// ================================================== MACROS + +// Generate specializations and function definitions using X macros. 
Although +// harder to read and debug, writing everything manually is too bulky. + +namespace detail { // for code folding + +// For all mask sizes MLEN: (1/Nth of a register, one bit per lane) +// The first three arguments are SEW, SHIFT, MLEN, with SEW and SHIFT chosen +// arbitrarily such that SEW >> SHIFT = MLEN. +#define HWY_RVV_FOREACH_B(X_MACRO, NAME, OP) \ + X_MACRO(64, 0, 64, NAME, OP) \ + X_MACRO(32, 0, 32, NAME, OP) \ + X_MACRO(16, 0, 16, NAME, OP) \ + X_MACRO(8, 0, 8, NAME, OP) \ + X_MACRO(8, 1, 4, NAME, OP) \ + X_MACRO(8, 2, 2, NAME, OP) \ + X_MACRO(8, 3, 1, NAME, OP) + +// For given SEW, iterate over one of LMULS: _TRUNC, _EXT, _ALL. This allows +// reusing type lists such as HWY_RVV_FOREACH_U for _ALL (the usual case) or +// _EXT (for Combine). To achieve this, we HWY_CONCAT with the LMULS suffix. +// +// Precompute SEW/LMUL => MLEN to allow token-pasting the result. For the same +// reason, also pass the double-width and half SEW and LMUL (suffixed D and H, +// respectively). "__" means there is no corresponding LMUL (e.g. LMULD for m8).
+// Args: BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, MLEN, NAME, OP + +// LMULS = _TRUNC: truncatable (not the smallest LMUL) +#define HWY_RVV_FOREACH_08_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_TRUNC(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_08_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, 
/*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_GET_SET(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// LMULS = _DEMOTE: can demote from SEW*LMUL to SEWH*LMULH. +#define HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define 
HWY_RVV_FOREACH_32_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// LMULS = _LE2: <= 2 +#define HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf8, mf4, __, -3, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m2, m4, m1, 1, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m2, m4, m1, 1, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m2, m4, m1, 1, /*MLEN=*/16, NAME, OP) + +#define HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 
64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m2, m4, m1, 1, /*MLEN=*/32, NAME, OP) + +// LMULS = _EXT: not the largest LMUL +#define HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m4, m8, m2, 2, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m4, m8, m2, 2, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m4, m8, m2, 2, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m4, m8, m2, 2, /*MLEN=*/16, NAME, OP) + +// LMULS = _ALL (2^MinPow2() <= LMUL <= 8) +#define HWY_RVV_FOREACH_08_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m8, __, m4, 3, /*MLEN=*/1, NAME, OP) + +#define HWY_RVV_FOREACH_16_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m8, __, m4, 3, /*MLEN=*/2, NAME, OP) + +#define HWY_RVV_FOREACH_32_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m8, __, m4, 3, /*MLEN=*/4, NAME, OP) + +#define HWY_RVV_FOREACH_64_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m8, __, m4, 3, /*MLEN=*/8, NAME, OP) + +// 'Virtual' LMUL. 
This upholds the Highway guarantee that vectors are at least +// 128 bit and LowerHalf is defined whenever there are at least 2 lanes, even +// though RISC-V LMUL must be at least SEW/64 (notice that this rules out +// LMUL=1/2 for SEW=64). To bridge the gap, we add overloads for kPow2 equal to +// one less than should be supported, with all other parameters (vector type +// etc.) unchanged. For D with the lowest kPow2 ('virtual LMUL'), Lanes() +// returns half of what it usually would. +// +// Notice that we can only add overloads whenever there is a D argument: those +// are unique with respect to non-virtual-LMUL overloads because their kPow2 +// template argument differs. Otherwise, there is no actual vuint64mf2_t, and +// defining another overload with the same LMUL would be an error. Thus we have +// a separate _VIRT category for HWY_RVV_FOREACH*, and the common case is +// _ALL_VIRT (meaning the regular LMUL plus the VIRT overloads), used in most +// functions that take a D. + +#define HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -3, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -2, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, -1, /*MLEN=*/64, NAME, OP) + +// ALL + VIRT +#define HWY_RVV_FOREACH_08_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_ALL(X_MACRO, BASE, CHAR, 
NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_ALL_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_ALL(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// LE2 + VIRT +#define HWY_RVV_FOREACH_08_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_LE2_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_LE2(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// GET/SET + VIRT +#define HWY_RVV_FOREACH_08_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf4, mf2, mf8, -2, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf2, m1, mf4, -1, /*MLEN=*/16, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, m1, m2, mf2, 0, /*MLEN=*/8, NAME, OP) + +#define HWY_RVV_FOREACH_16_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf2, m1, mf4, -1, /*MLEN=*/32, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, m1, m2, mf2, 0, /*MLEN=*/16, NAME, OP) + +#define HWY_RVV_FOREACH_32_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, m1, m2, mf2, 0, /*MLEN=*/32, NAME, OP) + +#define HWY_RVV_FOREACH_64_GET_SET_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// For the smallest LMUL for each SEW, similar to the LowerHalf operator, we +// provide the Get and Set operator that returns the same vector type. 
+#define HWY_RVV_FOREACH_08_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 8, 16, __, mf8, mf4, __, -3, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_16_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 16, 32, 8, mf4, mf2, mf8, -2, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_32_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 32, 64, 16, mf2, m1, mf4, -1, /*MLEN=*/64, NAME, OP) + +#define HWY_RVV_FOREACH_64_GET_SET_SMALLEST(X_MACRO, BASE, CHAR, NAME, OP) \ + X_MACRO(BASE, CHAR, 64, __, 32, m1, m2, mf2, 0, /*MLEN=*/64, NAME, OP) + +// EXT + VIRT +#define HWY_RVV_FOREACH_08_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_64_EXT_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_EXT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// DEMOTE + VIRT +#define HWY_RVV_FOREACH_08_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_08_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_16_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_16_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define HWY_RVV_FOREACH_32_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_32_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +#define 
HWY_RVV_FOREACH_64_DEMOTE_VIRT(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_DEMOTE(X_MACRO, BASE, CHAR, NAME, OP) \ + HWY_RVV_FOREACH_64_VIRT(X_MACRO, BASE, CHAR, NAME, OP) + +// SEW for unsigned: +#define HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_08, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, uint, u, NAME, OP) +#define HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, uint, u, NAME, OP) + +// SEW for signed: +#define HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_08, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, int, i, NAME, OP) +#define HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, int, i, NAME, OP) + +// SEW for float: + +// Used for conversion instructions if HWY_RVV_HAVE_F16C. +#define HWY_RVV_FOREACH_F16_UNCONDITIONAL(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_16, LMULS)(X_MACRO, float, f, NAME, OP) + +#if HWY_HAVE_FLOAT16 +// Full support for f16 in all ops +#define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F16_UNCONDITIONAL(X_MACRO, NAME, OP, LMULS) +// Only BF16 is emulated. 
+#define HWY_RVV_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#define HWY_RVV_IF_NOT_EMULATED_D(D) HWY_IF_NOT_BF16_D(D) +#else +#define HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) +#define HWY_RVV_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#define HWY_GENERIC_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#define HWY_RVV_IF_NOT_EMULATED_D(D) HWY_IF_NOT_SPECIAL_FLOAT_D(D) +#endif +#define HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_32, LMULS)(X_MACRO, float, f, NAME, OP) +#define HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP, LMULS) \ + HWY_CONCAT(HWY_RVV_FOREACH_64, LMULS)(X_MACRO, float, f, NAME, OP) + +// Commonly used type/SEW groups: +#define HWY_RVV_FOREACH_UI08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI3264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_U163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_I163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I64(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_UI163264(X_MACRO, NAME, OP, LMULS) \ + 
HWY_RVV_FOREACH_U163264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I163264(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_F3264(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F32(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F64(X_MACRO, NAME, OP, LMULS) + +// For all combinations of SEW: +#define HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U163264(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I08(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I163264(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F16(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F3264(X_MACRO, NAME, OP, LMULS) + +// Commonly used type categories: +#define HWY_RVV_FOREACH_UI(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_U(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_I(X_MACRO, NAME, OP, LMULS) + +#define HWY_RVV_FOREACH(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_UI(X_MACRO, NAME, OP, LMULS) \ + HWY_RVV_FOREACH_F(X_MACRO, NAME, OP, LMULS) + +// Assemble types for use in x-macros +#define HWY_RVV_T(BASE, SEW) BASE##SEW##_t +#define HWY_RVV_D(BASE, SEW, N, SHIFT) Simd +#define HWY_RVV_V(BASE, SEW, LMUL) v##BASE##SEW##LMUL##_t +#define HWY_RVV_TUP(BASE, SEW, LMUL, TUP) v##BASE##SEW##LMUL##x##TUP##_t +#define HWY_RVV_M(MLEN) vbool##MLEN##_t + +} // namespace detail + +// Until we have full intrinsic support for fractional LMUL, mixed-precision +// code can use LMUL 1..8 (adequate unless they need many registers). +#define HWY_SPECIALIZE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template <> \ + struct DFromV_t { \ + using Lane = HWY_RVV_T(BASE, SEW); \ + using type = ScalableTag; \ + }; + +HWY_RVV_FOREACH(HWY_SPECIALIZE, _, _, _ALL) +#undef HWY_SPECIALIZE + +// ------------------------------ Lanes + +// WARNING: we want to query VLMAX/sizeof(T), but this may actually change VL! 
+ +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD +// HWY_RVV_CAPPED_LANES_SPECIAL_CASES provides some additional optimizations +// to CappedLanes in non-debug builds +#define HWY_RVV_CAPPED_LANES_SPECIAL_CASES(BASE, SEW, LMUL) \ + if (__builtin_constant_p(cap >= kMaxLanes) && (cap >= kMaxLanes)) { \ + /* If cap is known to be greater than or equal to MaxLanes(d), */ \ + /* HWY_MIN(cap, Lanes(d)) will be equal to Lanes(d) */ \ + return Lanes(d); \ + } \ + \ + if ((__builtin_constant_p((cap & (cap - 1)) == 0) && \ + ((cap & (cap - 1)) == 0)) || \ + (__builtin_constant_p(cap <= HWY_MAX(kMinLanesPerFullVec, 4)) && \ + (cap <= HWY_MAX(kMinLanesPerFullVec, 4)))) { \ + /* If cap is known to be a power of 2, then */ \ + /* vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return the same */ \ + /* result as HWY_MIN(cap, Lanes(d)) as kMaxLanes is a power of 2 and */ \ + /* as (cap > VLMAX && cap < 2 * VLMAX) can only be true if cap is not a */ \ + /* power of 2 since VLMAX is always a power of 2 */ \ + \ + /* If cap is known to be less than or equal to 4, then */ \ + /* vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return the same */ \ + /* result as HWY_MIN(cap, Lanes(d)) as HWY_MIN(cap, kMaxLanes) <= 4 is */ \ + /* true if cap <= 4 and as vsetvl(HWY_MIN(cap, kMaxLanes)) is */ \ + /* guaranteed to return the same result as HWY_MIN(cap, Lanes(d)) */ \ + /* if HWY_MIN(cap, kMaxLanes) <= 4 is true */ \ + \ + /* If cap is known to be less than or equal to kMinLanesPerFullVec, */ \ + /* then vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return the */ \ + /* same result as HWY_MIN(cap, Lanes(d)) as */ \ + /* HWY_MIN(cap, kMaxLanes) <= kMinLanesPerFullVec is true if */ \ + /* cap <= kMinLanesPerFullVec is true */ \ + \ + /* If cap <= HWY_MAX(kMinLanesPerFullVec, 4) is true, then either */ \ + /* cap <= 4 or cap <= kMinLanesPerFullVec must be true */ \ + \ + /* If cap <= HWY_MAX(kMinLanesPerFullVec, 4) is known to be true, */ \ + /* then vsetvl(HWY_MIN(cap, kMaxLanes)) is 
guaranteed to return the */ \ + /* same result as HWY_MIN(cap, Lanes(d)) */ \ + \ + /* If no cap, avoid the HWY_MIN. */ \ + return detail::IsFull(d) \ + ? __riscv_vsetvl_e##SEW##LMUL(cap) \ + : __riscv_vsetvl_e##SEW##LMUL(HWY_MIN(cap, kMaxLanes)); \ + } +#else +#define HWY_RVV_CAPPED_LANES_SPECIAL_CASES(BASE, SEW, LMUL) +#endif + +#define HWY_RVV_LANES(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + constexpr size_t kFull = HWY_LANES(HWY_RVV_T(BASE, SEW)); \ + constexpr size_t kCap = MaxLanes(d); \ + /* If no cap, avoid generating a constant by using VLMAX. */ \ + return N == kFull ? __riscv_vsetvlmax_e##SEW##LMUL() \ + : __riscv_vsetvl_e##SEW##LMUL(kCap); \ + } \ + template \ + HWY_API size_t Capped##NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, size_t cap) { \ + /* NOTE: Section 6.3 of the RVV specification, which can be found at */ \ + /* https://github.com/riscv/riscv-v-spec/blob/master/v-spec.adoc, */ \ + /* allows vsetvl to return a result less than Lanes(d) but greater than */ \ + /* or equal to ((cap + 1) / 2) if */ \ + /* (Lanes(d) > 2 && cap > HWY_MAX(Lanes(d), 4) && cap < (2 * Lanes(d))) */ \ + /* is true */ \ + \ + /* VLMAX is the number of lanes in a vector of type */ \ + /* VFromD, which is returned by */ \ + /* Lanes(DFromV>()) */ \ + \ + /* VLMAX is guaranteed to be a power of 2 under Section 2 of the RVV */ \ + /* specification */ \ + \ + /* The VLMAX of a vector of type VFromD is at least 2 as */ \ + /* the HWY_RVV target requires support for the RVV Zvl128b extension, */ \ + /* which guarantees that vectors with LMUL=1 are at least 16 bytes */ \ + \ + /* If VLMAX == 2 is true, then vsetvl(cap) is equal to HWY_MIN(cap, 2) */ \ + /* as cap == 3 is the only value such that */ \ + /* (cap > VLMAX && cap < 2 * VLMAX) if VLMAX == 2 and as */ \ + /* ((3 + 1) / 2) is equal to 2 */ \ + \ + /* If cap <= 4 is true, then vsetvl(cap) must be equal to */ \ + /* 
HWY_MIN(cap, VLMAX) as cap <= VLMAX is true if VLMAX >= 4 is true */ \ + /* and as vsetvl(cap) is guaranteed to be equal to HWY_MIN(cap, VLMAX) */ \ + /* if VLMAX == 2 */ \ + \ + /* We want CappedLanes(d, cap) to return Lanes(d) if cap > Lanes(d) as */ \ + /* LoadN(d, p, cap) expects to load exactly HWY_MIN(cap, Lanes(d)) */ \ + /* lanes and StoreN(v, d, p, cap) expects to store exactly */ \ + /* HWY_MIN(cap, Lanes(d)) lanes, even in the case where vsetvl returns */ \ + /* a result that is less than HWY_MIN(cap, Lanes(d)) */ \ + \ + /* kMinLanesPerFullVec is the minimum value of VLMAX for a vector of */ \ + /* type VFromD */ \ + constexpr size_t kMinLanesPerFullVec = \ + detail::ScaleByPower(16 / (SEW / 8), SHIFT); \ + /* kMaxLanes is the maximum number of lanes returned by Lanes(d) */ \ + constexpr size_t kMaxLanes = MaxLanes(d); \ + \ + HWY_RVV_CAPPED_LANES_SPECIAL_CASES(BASE, SEW, LMUL) \ + \ + if (kMaxLanes <= HWY_MAX(kMinLanesPerFullVec, 4)) { \ + /* If kMaxLanes <= kMinLanesPerFullVec is true, then */ \ + /* vsetvl(HWY_MIN(cap, kMaxLanes)) is guaranteed to return */ \ + /* HWY_MIN(cap, Lanes(d)) as */ \ + /* HWY_MIN(cap, kMaxLanes) <= kMaxLanes <= VLMAX is true if */ \ + /* kMaxLanes <= kMinLanesPerFullVec is true */ \ + \ + /* If kMaxLanes <= 4 is true, then vsetvl(HWY_MIN(cap, kMaxLanes)) is */ \ + /* guaranteed to return the same result as HWY_MIN(cap, Lanes(d)) as */ \ + /* HWY_MIN(cap, kMaxLanes) <= 4 is true if kMaxLanes <= 4 is true */ \ + \ + /* If kMaxLanes <= HWY_MAX(kMinLanesPerFullVec, 4) is true, then */ \ + /* either kMaxLanes <= 4 or kMaxLanes <= kMinLanesPerFullVec must be */ \ + /* true */ \ + \ + return __riscv_vsetvl_e##SEW##LMUL(HWY_MIN(cap, kMaxLanes)); \ + } else { \ + /* If kMaxLanes > HWY_MAX(kMinLanesPerFullVec, 4) is true, need to */ \ + /* obtain the actual number of lanes using Lanes(d) and clamp cap to */ \ + /* the result of Lanes(d) */ \ + const size_t actual = Lanes(d); \ + return HWY_MIN(actual, cap); \ + } \ + } + +#define 
HWY_RVV_LANES_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API size_t NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + constexpr size_t kCap = MaxLanes(d); \ + /* In case of virtual LMUL (intrinsics do not provide "uint16mf8_t") */ \ + /* vsetvl may or may not be correct, so do it ourselves. */ \ + const size_t actual = \ + detail::ScaleByPower(__riscv_vlenb() / (SEW / 8), SHIFT); \ + return HWY_MIN(actual, kCap); \ + } \ + template \ + HWY_API size_t Capped##NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, size_t cap) { \ + /* In case of virtual LMUL (intrinsics do not provide "uint16mf8_t") */ \ + /* vsetvl may or may not be correct, so do it ourselves. */ \ + const size_t actual = \ + detail::ScaleByPower(__riscv_vlenb() / (SEW / 8), SHIFT); \ + /* If no cap, avoid an extra HWY_MIN. */ \ + return detail::IsFull(d) ? HWY_MIN(actual, cap) \ + : HWY_MIN(HWY_MIN(actual, cap), MaxLanes(d)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_LANES, Lanes, setvlmax_e, _ALL) +HWY_RVV_FOREACH(HWY_RVV_LANES_VIRT, Lanes, lenb, _VIRT) +#undef HWY_RVV_LANES +#undef HWY_RVV_LANES_VIRT +#undef HWY_RVV_CAPPED_LANES_SPECIAL_CASES + +template +HWY_API size_t Lanes(D /* tag*/) { + return Lanes(RebindToUnsigned()); +} + +template +HWY_API size_t CappedLanes(D /* tag*/, size_t cap) { + return CappedLanes(RebindToUnsigned(), cap); +} + +// ------------------------------ Common x-macros + +// Last argument to most intrinsics. Use when the op has no d arg of its own, +// which means there is no user-specified cap. +#define HWY_RVV_AVL(SEW, SHIFT) \ + Lanes(ScalableTag()) + +// vector = f(vector), e.g. Not +#define HWY_RVV_RETV_ARGV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL(v, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// vector = f(vector, scalar), e.g. 
detail::AddS +#define HWY_RVV_RETV_ARGVS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL(a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// vector = f(vector, vector), e.g. Add +#define HWY_RVV_RETV_ARGVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(a, b, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// vector = f(vector, mask, vector, vector), e.g. MaskedAddOr +#define HWY_RVV_RETV_ARGMVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) no, HWY_RVV_M(MLEN) m, \ + HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL##_mu(m, no, a, b, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// mask = f(mask) +#define HWY_RVV_RETM_ARGM(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) m) { \ + return __riscv_vm##OP##_m_b##MLEN(m, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// ================================================== INIT + +// ------------------------------ Set + +#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_T(BASE, SEW) arg) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL(arg, Lanes(d)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SET, Set, mv_v_x, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_SET, Set, fmv_v_f, _ALL_VIRT) +#undef HWY_RVV_SET + +// Treat bfloat16_t as int16_t (using the previously defined Set overloads); +// required for Zero and VFromD. 
+template +decltype(Set(RebindToSigned(), 0)) Set(D d, hwy::bfloat16_t arg) { + return Set(RebindToSigned(), BitCastScalar(arg)); +} +#if !HWY_HAVE_FLOAT16 // Otherwise already defined above. +// WARNING: returns a different type than emulated bfloat16_t so that we can +// implement PromoteTo overloads for both bfloat16_t and float16_t, and also +// provide a Neg(hwy::float16_t) overload that coexists with Neg(int16_t). +template +decltype(Set(RebindToUnsigned(), 0)) Set(D d, hwy::float16_t arg) { + return Set(RebindToUnsigned(), BitCastScalar(arg)); +} +#endif + +template +using VFromD = decltype(Set(D(), TFromD())); + +// ------------------------------ Zero + +template +HWY_API VFromD Zero(D d) { + // Cast to support bfloat16_t. + const RebindToUnsigned du; + return BitCast(d, Set(du, 0)); +} + +// ------------------------------ Undefined + +// RVV vundefined is 'poisoned' such that even XORing a _variable_ initialized +// by it gives unpredictable results. It should only be used for maskoff, so +// keep it internal. For the Highway op, just use Zero (single instruction). +namespace detail { +#define HWY_RVV_UNDEFINED(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) /* tag */) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL(); /* no AVL */ \ + } + +HWY_RVV_FOREACH(HWY_RVV_UNDEFINED, Undefined, undefined, _ALL) +#undef HWY_RVV_UNDEFINED +} // namespace detail + +template +HWY_API VFromD Undefined(D d) { + return Zero(d); +} + +// ------------------------------ BitCast + +namespace detail { + +// Halves LMUL. (Use LMUL arg for the source so we can use _TRUNC.) 
+#define HWY_RVV_TRUNC(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULH( \ + v); /* no AVL */ \ + } +HWY_RVV_FOREACH(HWY_RVV_TRUNC, Trunc, lmul_trunc, _TRUNC) +#undef HWY_RVV_TRUNC + +// Doubles LMUL to `d2` (the arg is only necessary for _VIRT). +#define HWY_RVV_EXT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMULD) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT + 1) /* d2 */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULD( \ + v); /* no AVL */ \ + } +HWY_RVV_FOREACH(HWY_RVV_EXT, Ext, lmul_ext, _EXT) +#undef HWY_RVV_EXT + +// For virtual LMUL e.g. 'uint32mf4_t', the return type should be mf2, which is +// the same as the actual input type. +#define HWY_RVV_EXT_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT + 1) /* d2 */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return v; \ + } +HWY_RVV_FOREACH(HWY_RVV_EXT_VIRT, Ext, lmul_ext, _VIRT) +#undef HWY_RVV_EXT_VIRT + +template +VFromD Ext(D d, VFromD> v) { + const RebindToUnsigned du; + const Half duh; + return BitCast(d, Ext(du, BitCast(duh, v))); +} + +// For BitCastToByte, the D arg is only to prevent duplicate definitions caused +// by _ALL_VIRT. + +// There is no reinterpret from u8 <-> u8, so just return. 
+#define HWY_RVV_CAST_U8(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + vuint8##LMUL##_t v) { \ + return v; \ + } \ + template \ + HWY_API vuint8##LMUL##_t BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return v; \ + } + +// For i8, need a single reinterpret (HWY_RVV_CAST_IF does two). +#define HWY_RVV_CAST_I8(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + vint8##LMUL##_t v) { \ + return __riscv_vreinterpret_v_i8##LMUL##_u8##LMUL(v); \ + } \ + template \ + HWY_API vint8##LMUL##_t BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return __riscv_vreinterpret_v_u8##LMUL##_i8##LMUL(v); \ + } + +// Separate u/i because clang only provides signed <-> unsigned reinterpret for +// the same SEW. +#define HWY_RVV_CAST_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_u8##LMUL(v); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return __riscv_v##OP##_v_u8##LMUL##_##CHAR##SEW##LMUL(v); \ + } + +// Signed/Float: first cast to/from unsigned +#define HWY_RVV_CAST_IF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMUL##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_u##SEW##LMUL##_u8##LMUL( \ + __riscv_v##OP##_v_##CHAR##SEW##LMUL##_u##SEW##LMUL(v)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMUL##_t v) { \ + return 
__riscv_v##OP##_v_u##SEW##LMUL##_##CHAR##SEW##LMUL( \ + __riscv_v##OP##_v_u8##LMUL##_u##SEW##LMUL(v)); \ + } + +// Additional versions for virtual LMUL using LMULH for byte vectors. +#define HWY_RVV_CAST_VIRT_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMULH##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return detail::Trunc(__riscv_v##OP##_v_##CHAR##SEW##LMUL##_u8##LMUL(v)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMULH##_t v) { \ + HWY_RVV_D(uint, 8, N, SHIFT + 1) d2; \ + const vuint8##LMUL##_t v2 = detail::Ext(d2, v); \ + return __riscv_v##OP##_v_u8##LMUL##_##CHAR##SEW##LMUL(v2); \ + } + +// Signed/Float: first cast to/from unsigned +#define HWY_RVV_CAST_VIRT_IF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API vuint8##LMULH##_t BitCastToByte(Simd /* d */, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return detail::Trunc(__riscv_v##OP##_v_u##SEW##LMUL##_u8##LMUL( \ + __riscv_v##OP##_v_##CHAR##SEW##LMUL##_u##SEW##LMUL(v))); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) BitCastFromByte( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, vuint8##LMULH##_t v) { \ + HWY_RVV_D(uint, 8, N, SHIFT + 1) d2; \ + const vuint8##LMUL##_t v2 = detail::Ext(d2, v); \ + return __riscv_v##OP##_v_u##SEW##LMUL##_##CHAR##SEW##LMUL( \ + __riscv_v##OP##_v_u8##LMUL##_u##SEW##LMUL(v2)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_CAST_U8, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I08(HWY_RVV_CAST_I8, _, reinterpret, _ALL) +HWY_RVV_FOREACH_U163264(HWY_RVV_CAST_U, _, reinterpret, _ALL) +HWY_RVV_FOREACH_I163264(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_U163264(HWY_RVV_CAST_VIRT_U, _, reinterpret, _VIRT) +HWY_RVV_FOREACH_I163264(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_CAST_IF, _, reinterpret, _ALL) 
+HWY_RVV_FOREACH_F(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) +#if HWY_HAVE_FLOAT16 // HWY_RVV_FOREACH_F already covered float16_ +#elif HWY_RVV_HAVE_F16C // zvfhmin provides reinterpret* intrinsics: +HWY_RVV_FOREACH_F16_UNCONDITIONAL(HWY_RVV_CAST_IF, _, reinterpret, _ALL) +HWY_RVV_FOREACH_F16_UNCONDITIONAL(HWY_RVV_CAST_VIRT_IF, _, reinterpret, _VIRT) +#else +template +HWY_INLINE VFromD> BitCastFromByte( + D /* d */, VFromD> v) { + return BitCastFromByte(RebindToUnsigned(), v); +} +#endif + +#undef HWY_RVV_CAST_U8 +#undef HWY_RVV_CAST_I8 +#undef HWY_RVV_CAST_U +#undef HWY_RVV_CAST_IF +#undef HWY_RVV_CAST_VIRT_U +#undef HWY_RVV_CAST_VIRT_IF + +template +HWY_INLINE VFromD> BitCastFromByte( + D d, VFromD> v) { + return BitCastFromByte(RebindToSigned(), v); +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, FromV v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(d, v)); +} + +// ------------------------------ Iota + +namespace detail { + +#define HWY_RVV_IOTA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL(Lanes(d)); \ + } + +// For i8 lanes, this may well wrap around. Unsigned only is less error-prone. +HWY_RVV_FOREACH_U(HWY_RVV_IOTA, Iota0, id_v, _ALL_VIRT) +#undef HWY_RVV_IOTA + +// Used by Expand. 
+#define HWY_RVV_MASKED_IOTA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_M(MLEN) mask) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL(mask, Lanes(d)); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_MASKED_IOTA, MaskedIota, iota_m, _ALL_VIRT) +#undef HWY_RVV_MASKED_IOTA + +} // namespace detail + +// ================================================== LOGICAL + +// ------------------------------ Not + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGV, Not, not, _ALL) + +template +HWY_API V Not(const V v) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Not(BitCast(DU(), v))); +} + +// ------------------------------ And + +// Non-vector version (ideally immediate) for use with Iota0 +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AndS, and_vx, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, And, and, _ALL) + +template +HWY_API V And(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), And(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ Or + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Or, or, _ALL) + +template +HWY_API V Or(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Or(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ Xor + +// Non-vector version (ideally immediate) for use with Iota0 +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, XorS, xor_vx, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Xor, xor, _ALL) + +template +HWY_API V Xor(const V a, const V b) { + using DF = DFromV; + using DU = RebindToUnsigned; + return BitCast(DF(), Xor(BitCast(DU(), a), BitCast(DU(), b))); +} + +// ------------------------------ AndNot +template +HWY_API V AndNot(const V not_a, const V b) { + return And(Not(not_a), b); +} + +// 
------------------------------ Or3 +template +HWY_API V Or3(V o1, V o2, V o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd +template +HWY_API V OrAnd(const V o, const V a1, const V a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ CopySign + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, CopySign, fsgnj, _ALL) + +template +HWY_API V CopySignToAbs(const V abs, const V sign) { + // RVV can also handle abs < 0, so no extra action needed. + return CopySign(abs, sign); +} + +// ================================================== ARITHMETIC + +// Per-target flags to prevent generic_ops-inl.h defining Add etc. +#ifdef HWY_NATIVE_OPERATOR_REPLACEMENTS +#undef HWY_NATIVE_OPERATOR_REPLACEMENTS +#else +#define HWY_NATIVE_OPERATOR_REPLACEMENTS +#endif + +// ------------------------------ Add + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, AddS, add_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, AddS, fadd_vf, _ALL) +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, ReverseSubS, rsub_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, ReverseSubS, frsub_vf, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Add, add, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Add, fadd, _ALL) + +// ------------------------------ Sub +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVS, SubS, sub_vx, _ALL) +} // namespace detail + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Sub, sub, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Sub, fsub, _ALL) + +// ------------------------------ Neg (ReverseSubS, Xor) + +template +HWY_API V Neg(const V v) { + return detail::ReverseSubS(v, 0); +} + +// vector = f(vector), but argument is repeated +#define HWY_RVV_RETV_ARGV2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(v, v, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + 
+HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Neg, fsgnjn, _ALL) + +#if !HWY_HAVE_FLOAT16 + +template )> // hwy::float16_t +HWY_API V Neg(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + return BitCast(d, Xor(BitCast(du, v), Set(du, SignMask()))); +} + +#endif // !HWY_HAVE_FLOAT16 + +// ------------------------------ SaturatedAdd + +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB +#undef HWY_NATIVE_U32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB +#undef HWY_NATIVE_U64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_U64_SATURATED_ADDSUB +#endif + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, SaturatedAdd, saddu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, SaturatedAdd, sadd, _ALL) + +// ------------------------------ SaturatedSub + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, SaturatedSub, ssubu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL) + +// ------------------------------ AverageRound + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +// Define this to opt-out of the default behavior, which is AVOID on certain +// compiler versions. You can define only this to use VXRM, or define both this +// and HWY_RVV_AVOID_VXRM to always avoid VXRM. +#ifndef HWY_RVV_CHOOSE_VXRM + +// Assume that GCC-13 defaults to 'avoid VXRM'. Tested with GCC 13.1.0. 
+#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400 +#define HWY_RVV_AVOID_VXRM +// Clang 16 with __riscv_v_intrinsic == 11000 may either require VXRM or avoid. +// Assume that Clang 16 and earlier avoid VXRM. +#elif HWY_COMPILER_CLANG && \ + (HWY_COMPILER_CLANG < 1700 || __riscv_v_intrinsic < 11000) +#define HWY_RVV_AVOID_VXRM +#endif + +#endif // HWY_RVV_CHOOSE_VXRM + +// Adding __RISCV_VXRM_* was a backwards-incompatible change and it is not clear +// how to detect whether it is supported or required. #ifdef __RISCV_VXRM_RDN +// does not work because it seems to be a compiler built-in, but neither does +// __has_builtin(__RISCV_VXRM_RDN). The intrinsics version was also not updated, +// so we require a macro to opt out of the new intrinsics. +#ifdef HWY_RVV_AVOID_VXRM +#define HWY_RVV_INSERT_VXRM(vxrm, avl) avl +#define __RISCV_VXRM_RNU +#define __RISCV_VXRM_RDN +#else // default: use new vxrm arguments +#define HWY_RVV_INSERT_VXRM(vxrm, avl) vxrm, avl +#endif + +// Extra rounding mode = up argument. +#define HWY_RVV_RETV_AVERAGE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL( \ + a, b, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } + +HWY_RVV_FOREACH_I(HWY_RVV_RETV_AVERAGE, AverageRound, aadd, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETV_AVERAGE, AverageRound, aaddu, _ALL) + +#undef HWY_RVV_RETV_AVERAGE + +// ------------------------------ ShiftLeft[Same] + +// Intrinsics do not define .vi forms, so use .vx instead. 
+#define HWY_RVV_SHIFT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL(v, kBits, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##Same(HWY_RVV_V(BASE, SEW, LMUL) v, int bits) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL(v, static_cast(bits), \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SHIFT, ShiftLeft, sll, _ALL) + +// ------------------------------ ShiftRight[Same] + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT, ShiftRight, srl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT, ShiftRight, sra, _ALL) + +#undef HWY_RVV_SHIFT + +// ------------------------------ RoundingShiftRight[Same] + +#ifdef HWY_NATIVE_ROUNDING_SHR +#undef HWY_NATIVE_ROUNDING_SHR +#else +#define HWY_NATIVE_ROUNDING_SHR +#endif + +// Intrinsics do not define .vi forms, so use .vx instead. +#define HWY_RVV_ROUNDING_SHR(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL( \ + v, kBits, \ + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##Same(HWY_RVV_V(BASE, SEW, LMUL) v, int bits) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL( \ + v, static_cast(bits), \ + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_ROUNDING_SHR, RoundingShiftRight, ssrl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_ROUNDING_SHR, RoundingShiftRight, ssra, _ALL) + +#undef HWY_RVV_ROUNDING_SHR + +// ------------------------------ SumsOf8 (ShiftRight, Add) +template )> +HWY_API VFromD>> SumsOf8(const VU8 v) { + const DFromV du8; + const RepartitionToWide du16; + const RepartitionToWide du32; + const RepartitionToWide du64; + using VU16 = VFromD; + + const 
VU16 vFDB97531 = ShiftRight<8>(BitCast(du16, v)); + const VU16 vECA86420 = detail::AndS(BitCast(du16, v), 0xFF); + const VU16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VU16 szz_FE_zz_BA_zz_76_zz_32 = + BitCast(du16, ShiftRight<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VU16 sxx_FC_xx_B8_xx_74_xx_30 = + Add(sFE_DC_BA_98_76_54_32_10, szz_FE_zz_BA_zz_76_zz_32); + const VU16 szz_zz_xx_FC_zz_zz_xx_74 = + BitCast(du16, ShiftRight<32>(BitCast(du64, sxx_FC_xx_B8_xx_74_xx_30))); + const VU16 sxx_xx_xx_F8_xx_xx_xx_70 = + Add(sxx_FC_xx_B8_xx_74_xx_30, szz_zz_xx_FC_zz_zz_xx_74); + return detail::AndS(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), 0xFFFFull); +} + +template )> +HWY_API VFromD>> SumsOf8(const VI8 v) { + const DFromV di8; + const RepartitionToWide di16; + const RepartitionToWide di32; + const RepartitionToWide di64; + const RebindToUnsigned du32; + const RebindToUnsigned du64; + using VI16 = VFromD; + + const VI16 vFDB97531 = ShiftRight<8>(BitCast(di16, v)); + const VI16 vECA86420 = ShiftRight<8>(ShiftLeft<8>(BitCast(di16, v))); + const VI16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VI16 sDC_zz_98_zz_54_zz_10_zz = + BitCast(di16, ShiftLeft<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VI16 sFC_xx_B8_xx_74_xx_30_xx = + Add(sFE_DC_BA_98_76_54_32_10, sDC_zz_98_zz_54_zz_10_zz); + const VI16 sB8_xx_zz_zz_30_xx_zz_zz = + BitCast(di16, ShiftLeft<32>(BitCast(du64, sFC_xx_B8_xx_74_xx_30_xx))); + const VI16 sF8_xx_xx_xx_70_xx_xx_xx = + Add(sFC_xx_B8_xx_74_xx_30_xx, sB8_xx_zz_zz_30_xx_zz_zz); + return ShiftRight<48>(BitCast(di64, sF8_xx_xx_xx_70_xx_xx_xx)); +} + +// ------------------------------ RotateRight +template +HWY_API V RotateRight(const V v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(TFromV) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + 
ShiftLeft(v)); +} + +// ------------------------------ Shl +#define HWY_RVV_SHIFT_VV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(v, bits, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shl, sll, _ALL) + +#define HWY_RVV_SHIFT_II(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + const HWY_RVV_D(uint, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT) du; \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(v, BitCast(du, bits), \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shl, sll, _ALL) + +// ------------------------------ Shr + +HWY_RVV_FOREACH_U(HWY_RVV_SHIFT_VV, Shr, srl, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_SHIFT_II, Shr, sra, _ALL) + +#undef HWY_RVV_SHIFT_II +#undef HWY_RVV_SHIFT_VV + +// ------------------------------ RoundingShr +#define HWY_RVV_ROUNDING_SHR_VV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL( \ + v, bits, \ + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } + +HWY_RVV_FOREACH_U(HWY_RVV_ROUNDING_SHR_VV, RoundingShr, ssrl, _ALL) + +#define HWY_RVV_ROUNDING_SHR_II(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(BASE, SEW, LMUL) bits) { \ + const HWY_RVV_D(uint, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT) du; \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL( \ + v, BitCast(du, bits), \ + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ 
+ } + +HWY_RVV_FOREACH_I(HWY_RVV_ROUNDING_SHR_II, RoundingShr, ssra, _ALL) + +#undef HWY_RVV_ROUNDING_SHR_VV +#undef HWY_RVV_ROUNDING_SHR_II + +// ------------------------------ Min + +namespace detail { + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVS, MinS, minu_vx, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVS, MinS, min_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, MinS, fmin_vf, _ALL) + +} // namespace detail + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Min, minu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Min, min, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Min, fmin, _ALL) + +// ------------------------------ Max + +namespace detail { + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVS, MaxS, maxu_vx, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVS, MaxS, max_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVS, MaxS, fmax_vf, _ALL) + +} // namespace detail + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Max, maxu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Max, max, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Max, fmax, _ALL) + +// ------------------------------ Mul + +// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*. +#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGVV, Mul, mul, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Mul, fmul, _ALL) + +// ------------------------------ MulHigh + +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, MulHigh, mulh, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, MulHigh, mulhu, _ALL) + +// ------------------------------ MulFixedPoint15 + +// Extra rounding mode = up argument. 
+#define HWY_RVV_MUL15(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL( \ + a, b, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \ + } + +HWY_RVV_FOREACH_I16(HWY_RVV_MUL15, MulFixedPoint15, smul, _ALL) + +#undef HWY_RVV_MUL15 + +// ------------------------------ Div +#ifdef HWY_NATIVE_INT_DIV +#undef HWY_NATIVE_INT_DIV +#else +#define HWY_NATIVE_INT_DIV +#endif + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Div, divu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Div, div, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGVV, Div, fdiv, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGVV, Mod, remu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, Mod, rem, _ALL) + +// ------------------------------ MaskedAddOr etc. + +#ifdef HWY_NATIVE_MASKED_ARITH +#undef HWY_NATIVE_MASKED_ARITH +#else +#define HWY_NATIVE_MASKED_ARITH +#endif + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedMinOr, minu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedMinOr, min, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedMinOr, fmin, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedMaxOr, maxu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedMaxOr, max, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedMaxOr, fmax, _ALL) + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGMVV, MaskedAddOr, add, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedAddOr, fadd, _ALL) + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGMVV, MaskedSubOr, sub, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedSubOr, fsub, _ALL) + +HWY_RVV_FOREACH_UI(HWY_RVV_RETV_ARGMVV, MaskedMulOr, mul, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedMulOr, fmul, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedDivOr, divu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedDivOr, div, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGMVV, MaskedDivOr, fdiv, 
_ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedModOr, remu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedModOr, rem, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedSatAddOr, saddu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedSatAddOr, sadd, _ALL) + +HWY_RVV_FOREACH_U(HWY_RVV_RETV_ARGMVV, MaskedSatSubOr, ssubu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGMVV, MaskedSatSubOr, ssub, _ALL) + +// ------------------------------ ApproximateReciprocal +#ifdef HWY_NATIVE_F64_APPROX_RECIP +#undef HWY_NATIVE_F64_APPROX_RECIP +#else +#define HWY_NATIVE_F64_APPROX_RECIP +#endif + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV, ApproximateReciprocal, frec7, _ALL) + +// ------------------------------ Sqrt +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV, Sqrt, fsqrt, _ALL) + +// ------------------------------ ApproximateReciprocalSqrt +#ifdef HWY_NATIVE_F64_APPROX_RSQRT +#undef HWY_NATIVE_F64_APPROX_RSQRT +#else +#define HWY_NATIVE_F64_APPROX_RSQRT +#endif + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV, ApproximateReciprocalSqrt, frsqrt7, _ALL) + +// ------------------------------ MulAdd + +// Per-target flag to prevent generic_ops-inl.h from defining int MulAdd. +#ifdef HWY_NATIVE_INT_FMA +#undef HWY_NATIVE_INT_FMA +#else +#define HWY_NATIVE_INT_FMA +#endif + +// Note: op is still named vv, not vvv. 
+#define HWY_RVV_FMA(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) mul, HWY_RVV_V(BASE, SEW, LMUL) x, \ + HWY_RVV_V(BASE, SEW, LMUL) add) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(add, mul, x, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_FMA, MulAdd, macc, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulAdd, fmacc, _ALL) + +// ------------------------------ NegMulAdd +HWY_RVV_FOREACH_UI(HWY_RVV_FMA, NegMulAdd, nmsac, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulAdd, fnmsac, _ALL) + +// ------------------------------ MulSub +HWY_RVV_FOREACH_F(HWY_RVV_FMA, MulSub, fmsac, _ALL) + +// ------------------------------ NegMulSub +HWY_RVV_FOREACH_F(HWY_RVV_FMA, NegMulSub, fnmacc, _ALL) + +#undef HWY_RVV_FMA + +// ================================================== COMPARE + +// ------------------------------ MClear + +// mask = f() +#define HWY_RVV_RETM(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME##MLEN() { \ + return __riscv_vm##OP##_m_b##MLEN(HWY_RVV_AVL(SEW, SHIFT)); \ + } + +namespace detail { +HWY_RVV_FOREACH_B(HWY_RVV_RETM, MClear, clr) // with ##MLEN suffix +} // namespace detail + +#undef HWY_RVV_RETM + +// Comparisons set a mask bit to 1 if the condition is true, else 0. The XX in +// vboolXX_t is a power of two divisor for vector bits. SEW=8 / LMUL=1 = 1/8th +// of all bits; SEW=8 / LMUL=4 = half of all bits. 
+ +// mask = f(vector, vector) +#define HWY_RVV_RETM_ARGVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL##_b##MLEN( \ + a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// mask = f(mask, vector, vector) +#define HWY_RVV_RETM_ARGMVV(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) a, \ + HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL##_b##MLEN##_mu( \ + m, detail::MClear##MLEN(), a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// mask = f(vector, scalar) +#define HWY_RVV_RETM_ARGVS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_T(BASE, SEW) b) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL##_b##MLEN( \ + a, b, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +#ifdef HWY_NATIVE_MASKED_COMP +#undef HWY_NATIVE_MASKED_COMP +#else +#define HWY_NATIVE_MASKED_COMP +#endif + +// ------------------------------ Eq +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Eq, mseq, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Eq, mfeq, _ALL) +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGMVV, MaskedEq, mseq, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGMVV, MaskedEq, mfeq, _ALL) + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, EqS, mseq_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, EqS, mfeq_vf, _ALL) +} // namespace detail + +// ------------------------------ Ne +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVV, Ne, msne, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Ne, mfne, _ALL) +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGMVV, MaskedNe, msne, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGMVV, MaskedNe, mfne, _ALL) + +namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_RETM_ARGVS, NeS, msne_vx, _ALL) 
+HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, NeS, mfne_vf, _ALL) +} // namespace detail + +// ------------------------------ Lt +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVV, Lt, msltu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVV, Lt, mslt, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Lt, mflt, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGMVV, MaskedLt, msltu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGMVV, MaskedLt, mslt, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGMVV, MaskedLt, mflt, _ALL) + +namespace detail { +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVS, LtS, mslt_vx, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVS, LtS, msltu_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, LtS, mflt_vf, _ALL) +} // namespace detail + +// ------------------------------ Le +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVV, Le, msleu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVV, Le, msle, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVV, Le, mfle, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGMVV, MaskedLe, msleu, _ALL) +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGMVV, MaskedLe, msle, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGMVV, MaskedLe, mfle, _ALL) + +template +using MFromD = decltype(Eq(Zero(D()), Zero(D()))); + +template > +HWY_API MFromD MaskedIsNaN(const M m, const V v) { + return MaskedNe(m, v, v); +} + +#undef HWY_RVV_RETM_ARGMVV +#undef HWY_RVV_RETM_ARGVV + +// ------------------------------ Gt/Ge (Lt, Le) + +namespace detail { +HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVS, GtS, msgt_vx, _ALL) +HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVS, GtS, msgtu_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, GtS, mfgt_vf, _ALL) +} // namespace detail + +#undef HWY_RVV_RETM_ARGVS + +// Swap args to reverse comparisons: +template +HWY_API auto Gt(const V a, const V b) -> decltype(Lt(a, b)) { + return Lt(b, a); +} + +template +HWY_API auto Ge(const V a, const V b) -> decltype(Le(a, b)) { + return Le(b, a); +} + +template > +HWY_API MFromD MaskedGt(M m, V a, V b) { + return MaskedLt(m, b, a); +} + +template > +HWY_API MFromD MaskedGe(M m, V a, V b) { + return 
MaskedLe(m, b, a); +} + +// ------------------------------ TestBit +template +HWY_API auto TestBit(const V a, const V bit) -> decltype(Eq(a, bit)) { + return detail::NeS(And(a, bit), 0); +} + +// ------------------------------ Not +// NOLINTNEXTLINE +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, Not, not ) + +// ------------------------------ And + +// mask = f(mask_a, mask_b) (note arg2,arg1 order!) +#define HWY_RVV_RETM_ARGMM(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) NAME(HWY_RVV_M(MLEN) a, HWY_RVV_M(MLEN) b) { \ + return __riscv_vm##OP##_mm_b##MLEN(b, a, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, And, and) + +// ------------------------------ AndNot +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, AndNot, andn) + +// ------------------------------ Or +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, Or, or) + +// ------------------------------ Xor +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, Xor, xor) + +// ------------------------------ ExclusiveNeither +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGMM, ExclusiveNeither, xnor) + +#undef HWY_RVV_RETM_ARGMM + +// ------------------------------ IfThenElse + +#define HWY_RVV_IF_THEN_ELSE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) yes, \ + HWY_RVV_V(BASE, SEW, LMUL) no) { \ + return __riscv_v##OP##_vvm_##CHAR##SEW##LMUL(no, yes, m, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_IF_THEN_ELSE, IfThenElse, merge, _ALL) + +#undef HWY_RVV_IF_THEN_ELSE + +// ------------------------------ IfThenElseZero +template +HWY_API V IfThenElseZero(const M mask, const V yes) { + return IfThenElse(mask, yes, Zero(DFromV())); +} + +// ------------------------------ IfThenZeroElse + +#define HWY_RVV_IF_THEN_ZERO_ELSE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_V(BASE, SEW, LMUL) no) { \ + 
return __riscv_v##OP##_##CHAR##SEW##LMUL(no, 0, m, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, merge_vxm, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_IF_THEN_ZERO_ELSE, IfThenZeroElse, fmerge_vfm, _ALL) + +#undef HWY_RVV_IF_THEN_ZERO_ELSE + +// ------------------------------ MaskFromVec +template +HWY_API MFromD> MaskFromVec(const V v) { + return detail::NeS(v, 0); +} + +// ------------------------------ IsNegative (MFromD) +#ifdef HWY_NATIVE_IS_NEGATIVE +#undef HWY_NATIVE_IS_NEGATIVE +#else +#define HWY_NATIVE_IS_NEGATIVE +#endif + +// Generic for all vector lengths +template +HWY_API MFromD> IsNegative(V v) { + const DFromV d; + const RebindToSigned di; + using TI = TFromD; + + return detail::LtS(BitCast(di, v), static_cast(0)); +} + +// ------------------------------ MaskFalse + +// For mask ops including vmclr, elements past VL are tail-agnostic and cannot +// be relied upon, so define a variant of the generic_ops-inl implementation of +// MaskFalse that ensures all bits are zero as required by mask_test. +#ifdef HWY_NATIVE_MASK_FALSE +#undef HWY_NATIVE_MASK_FALSE +#else +#define HWY_NATIVE_MASK_FALSE +#endif + +template +HWY_API MFromD MaskFalse(D d) { + const DFromV> d_full; + return MaskFromVec(Zero(d_full)); +} + +// ------------------------------ RebindMask +template +HWY_API MFromD RebindMask(const D /*d*/, const MFrom mask) { + // No need to check lane size/LMUL are the same: if not, casting MFrom to + // MFromD would fail. + return mask; +} + +// ------------------------------ VecFromMask + +// Returns mask ? ~0 : 0. No longer use sub.vx(Zero(), 1, mask) because per the +// default mask-agnostic policy, the result of inactive lanes may also be ~0. 
+#define HWY_RVV_VEC_FROM_MASK(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_M(MLEN) m) { \ + /* MaskFalse requires we set all lanes for capped d and virtual LMUL. */ \ + const DFromV> d_full; \ + const RebindToSigned di; \ + using TI = TFromD; \ + return BitCast(d_full, __riscv_v##OP##_i##SEW##LMUL(Zero(di), TI{-1}, m, \ + Lanes(d_full))); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_VEC_FROM_MASK, VecFromMask, merge_vxm, _ALL_VIRT) + +#undef HWY_RVV_VEC_FROM_MASK + +template +HWY_API VFromD VecFromMask(const D d, MFromD mask) { + return BitCast(d, VecFromMask(RebindToUnsigned(), mask)); +} + +// ------------------------------ IfVecThenElse (MaskFromVec) +template +HWY_API V IfVecThenElse(const V mask, const V yes, const V no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ BroadcastSignBit +template +HWY_API V BroadcastSignBit(const V v) { + return ShiftRight) * 8 - 1>(v); +} + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +template +HWY_API V IfNegativeThenElse(V v, V yes, V no) { + static_assert(IsSigned>(), "Only works for signed/float"); + return IfThenElse(IsNegative(v), yes, no); +} + +// ------------------------------ FindFirstTrue + +#define HWY_RVV_FIND_FIRST_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API intptr_t FindFirstTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return __riscv_vfirst_m_b##MLEN(m, Lanes(d)); \ + } \ + template \ + HWY_API size_t FindKnownFirstTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return static_cast(__riscv_vfirst_m_b##MLEN(m, Lanes(d))); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_FIND_FIRST_TRUE, , _) +#undef HWY_RVV_FIND_FIRST_TRUE + +// ------------------------------ AllFalse +template +HWY_API bool AllFalse(D d, MFromD m) { + return 
FindFirstTrue(d, m) < 0; +} + +// ------------------------------ AllTrue + +#define HWY_RVV_ALL_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API bool AllTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return AllFalse(d, __riscv_vmnot_m_b##MLEN(m, Lanes(d))); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_ALL_TRUE, _, _) +#undef HWY_RVV_ALL_TRUE + +// ------------------------------ CountTrue + +#define HWY_RVV_COUNT_TRUE(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API size_t CountTrue(D d, HWY_RVV_M(MLEN) m) { \ + static_assert(MLenFromD(d) == MLEN, "Type mismatch"); \ + return __riscv_vcpop_m_b##MLEN(m, Lanes(d)); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_COUNT_TRUE, _, _) +#undef HWY_RVV_COUNT_TRUE + +// ------------------------------ PromoteMaskTo + +#ifdef HWY_NATIVE_PROMOTE_MASK_TO +#undef HWY_NATIVE_PROMOTE_MASK_TO +#else +#define HWY_NATIVE_PROMOTE_MASK_TO +#endif + +template )), + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD PromoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, + MFromD m) { + return m; +} + +// ------------------------------ DemoteMaskTo + +#ifdef HWY_NATIVE_DEMOTE_MASK_TO +#undef HWY_NATIVE_DEMOTE_MASK_TO +#else +#define HWY_NATIVE_DEMOTE_MASK_TO +#endif + +template ) - 1), + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, + MFromD m) { + return m; +} + +// ================================================== MEMORY + +// ------------------------------ Load + +#define HWY_RVV_LOAD(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL( \ + detail::NativeLanePointer(p), Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_LOAD, Load, le, _ALL_VIRT) +#undef HWY_RVV_LOAD + +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + const 
RebindToUnsigned du; + return BitCast(d, Load(du, detail::U16LanePointer(p))); +} + +// ------------------------------ LoadU +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + // RVV only requires element alignment, not vector alignment. + return Load(d, p); +} + +// ------------------------------ MaskedLoad + +#define HWY_RVV_MASKED_LOAD(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) m, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_mu( \ + m, Zero(d), detail::NativeLanePointer(p), Lanes(d)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##Or(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) m, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_mu( \ + m, v, detail::NativeLanePointer(p), Lanes(d)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_MASKED_LOAD, MaskedLoad, le, _ALL_VIRT) +#undef HWY_RVV_MASKED_LOAD + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, + MaskedLoad(RebindMask(du, m), du, detail::U16LanePointer(p))); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD no, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, MaskedLoadOr(BitCast(du, no), RebindMask(du, m), du, + detail::U16LanePointer(p))); +} + +// ------------------------------ LoadN + +// Native with avl is faster than the generic_ops using FirstN. 
+#ifdef HWY_NATIVE_LOAD_N +#undef HWY_NATIVE_LOAD_N +#else +#define HWY_NATIVE_LOAD_N +#endif + +#define HWY_RVV_LOADN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p, size_t num_lanes) { \ + /* Use a tail-undisturbed load in LoadN as the tail-undisturbed load */ \ + /* operation below will leave any lanes past the first */ \ + /* (lowest-indexed) HWY_MIN(num_lanes, Lanes(d)) lanes unchanged */ \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_tu( \ + Zero(d), detail::NativeLanePointer(p), CappedLanes(d, num_lanes)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME##Or( \ + HWY_RVV_V(BASE, SEW, LMUL) no, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p, size_t num_lanes) { \ + /* Use a tail-undisturbed load in LoadNOr as the tail-undisturbed load */ \ + /* operation below will set any lanes past the first */ \ + /* (lowest-indexed) HWY_MIN(num_lanes, Lanes(d)) lanes to the */ \ + /* corresponding lanes in no */ \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_tu( \ + no, detail::NativeLanePointer(p), CappedLanes(d, num_lanes)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_LOADN, LoadN, le, _ALL_VIRT) +#undef HWY_RVV_LOADN + +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const RebindToUnsigned du; + return BitCast(d, LoadN(du, detail::U16LanePointer(p), num_lanes)); +} +template +HWY_API VFromD LoadNOr(VFromD v, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const RebindToUnsigned du; + return BitCast( + d, LoadNOr(BitCast(du, v), du, detail::U16LanePointer(p), num_lanes)); +} + +// ------------------------------ Store + +#define HWY_RVV_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + 
HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL( \ + detail::NativeLanePointer(p), v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_STORE, Store, se, _ALL_VIRT) +#undef HWY_RVV_STORE + +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + Store(BitCast(du, v), du, detail::U16LanePointer(p)); +} + +// ------------------------------ BlendedStore + +#define HWY_RVV_BLENDED_STORE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) m, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL##_m( \ + m, detail::NativeLanePointer(p), v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_BLENDED_STORE, BlendedStore, se, _ALL_VIRT) +#undef HWY_RVV_BLENDED_STORE + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + BlendedStore(BitCast(du, v), RebindMask(du, m), du, + detail::U16LanePointer(p)); +} + +// ------------------------------ StoreN + +namespace detail { + +#define HWY_RVV_STOREN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(size_t count, HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /* d */, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL( \ + detail::NativeLanePointer(p), v, count); \ + } +HWY_RVV_FOREACH(HWY_RVV_STOREN, StoreN, se, _ALL_VIRT) +#undef HWY_RVV_STOREN + +template +HWY_API void StoreN(size_t count, VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + StoreN(count, BitCast(du, v), du, detail::U16LanePointer(p)); +} + +} // namespace detail + +#ifdef HWY_NATIVE_STORE_N +#undef HWY_NATIVE_STORE_N +#else +#define HWY_NATIVE_STORE_N +#endif + 
+template +HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, + size_t max_lanes_to_store) { + // NOTE: Need to clamp max_lanes_to_store to Lanes(d), even if + // MaxLanes(d) >= MaxLanes(DFromV>()) is true, as it is possible for + // detail::StoreN(max_lanes_to_store, v, d, p) to store fewer than + // Lanes(DFromV>()) lanes to p if + // max_lanes_to_store > Lanes(DFromV>()) and + // max_lanes_to_store < 2 * Lanes(DFromV>()) are both true. + + // Also need to make sure that no more than Lanes(d) lanes are stored to p + // if Lanes(d) < Lanes(DFromV>()) is true, which is possible if + // MaxLanes(d) < MaxLanes(DFromV>()) or + // d.Pow2() < DFromV>().Pow2() is true. + detail::StoreN(CappedLanes(d, max_lanes_to_store), v, d, p); +} + +// ------------------------------ StoreU +template +HWY_API void StoreU(const V v, D d, TFromD* HWY_RESTRICT p) { + // RVV only requires element alignment, not vector alignment. + Store(v, d, p); +} + +// ------------------------------ Stream +template +HWY_API void Stream(const V v, D d, T* HWY_RESTRICT aligned) { + Store(v, d, aligned); +} + +// ------------------------------ ScatterOffset + +#ifdef HWY_NATIVE_SCATTER +#undef HWY_NATIVE_SCATTER +#else +#define HWY_NATIVE_SCATTER +#endif + +#define HWY_RVV_SCATTER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + const RebindToUnsigned du; \ + return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ + detail::NativeLanePointer(base), BitCast(du, offset), v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_SCATTER, ScatterOffset, sux, _ALL_VIRT) +#undef HWY_RVV_SCATTER + +// ------------------------------ ScatterIndex +template +HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT base, + VFromD> indices) { + constexpr size_t kBits = CeilLog2(sizeof(TFromD)); + 
return ScatterOffset(v, d, base, ShiftLeft(indices)); +} + +// ------------------------------ MaskedScatterIndex + +#define HWY_RVV_MASKED_SCATTER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) m, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) indices) { \ + const RebindToUnsigned du; \ + constexpr size_t kBits = CeilLog2(sizeof(TFromD)); \ + return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL##_m( \ + m, detail::NativeLanePointer(base), \ + ShiftLeft(BitCast(du, indices)), v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_MASKED_SCATTER, MaskedScatterIndex, sux, _ALL_VIRT) +#undef HWY_RVV_MASKED_SCATTER + +// ------------------------------ GatherOffset + +#ifdef HWY_NATIVE_GATHER +#undef HWY_NATIVE_GATHER +#else +#define HWY_NATIVE_GATHER +#endif + +#define HWY_RVV_GATHER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) offset) { \ + const RebindToUnsigned du; \ + return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL( \ + detail::NativeLanePointer(base), BitCast(du, offset), Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_GATHER, GatherOffset, lux, _ALL_VIRT) +#undef HWY_RVV_GATHER + +// ------------------------------ GatherIndex + +template +HWY_API VFromD GatherIndex(D d, const TFromD* HWY_RESTRICT base, + const VFromD> index) { + constexpr size_t kBits = CeilLog2(sizeof(TFromD)); + return GatherOffset(d, base, ShiftLeft(index)); +} + +// ------------------------------ MaskedGatherIndexOr + +#define HWY_RVV_MASKED_GATHER(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) no, HWY_RVV_M(MLEN) 
m, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT base, \ + HWY_RVV_V(int, SEW, LMUL) indices) { \ + const RebindToUnsigned du; \ + const RebindToSigned di; \ + (void)di; /* for HWY_DASSERT */ \ + constexpr size_t kBits = CeilLog2(SEW / 8); \ + HWY_DASSERT(AllFalse(di, Lt(indices, Zero(di)))); \ + return __riscv_v##OP##ei##SEW##_v_##CHAR##SEW##LMUL##_mu( \ + m, no, detail::NativeLanePointer(base), \ + ShiftLeft(BitCast(du, indices)), Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_MASKED_GATHER, MaskedGatherIndexOr, lux, _ALL_VIRT) +#undef HWY_RVV_MASKED_GATHER + +template +HWY_API VFromD MaskedGatherIndex(MFromD m, D d, const TFromD* base, + VFromD> indices) { + return MaskedGatherIndexOr(Zero(d), m, d, base, indices); +} + +// ================================================== CONVERT + +// ------------------------------ PromoteTo + +// SEW is for the input. +#define HWY_RVV_PROMOTE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWD, LMULD) NAME( \ + HWY_RVV_D(BASE, SEWD, N, SHIFT + 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##CHAR##SEWD##LMULD(v, Lanes(d)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_PROMOTE, PromoteTo, zext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_U16(HWY_RVV_PROMOTE, PromoteTo, zext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_U32(HWY_RVV_PROMOTE, PromoteTo, zext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I08(HWY_RVV_PROMOTE, PromoteTo, sext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I16(HWY_RVV_PROMOTE, PromoteTo, sext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_I32(HWY_RVV_PROMOTE, PromoteTo, sext_vf2_, _EXT_VIRT) +HWY_RVV_FOREACH_F32(HWY_RVV_PROMOTE, PromoteTo, fwcvt_f_f_v_, _EXT_VIRT) + +#if HWY_HAVE_FLOAT16 || HWY_RVV_HAVE_F16C + +HWY_RVV_FOREACH_F16_UNCONDITIONAL(HWY_RVV_PROMOTE, PromoteTo, fwcvt_f_f_v_, + _EXT_VIRT) + +// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. 
+#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif +#endif // HWY_HAVE_FLOAT16 || HWY_RVV_HAVE_F16C + +#undef HWY_RVV_PROMOTE + +// The above X-macro cannot handle 4x promotion nor type switching. +// TODO(janwas): use BASE2 arg to allow the latter. +#define HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, LMUL, LMUL_IN, \ + SHIFT, ADD) \ + template \ + HWY_API HWY_RVV_V(BASE, BITS, LMUL) \ + PromoteTo(HWY_RVV_D(BASE, BITS, N, SHIFT + ADD) d, \ + HWY_RVV_V(BASE_IN, BITS_IN, LMUL_IN) v) { \ + return __riscv_v##OP##CHAR##BITS##LMUL(v, Lanes(d)); \ + } + +#define HWY_RVV_PROMOTE_X2(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf2, -2, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf2, -1, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, m1, 0, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m2, 1, 1) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m4, 2, 1) + +#define HWY_RVV_PROMOTE_X4(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf4, -2, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, mf2, -1, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, m1, 0, 2) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m2, 1, 2) + +#define HWY_RVV_PROMOTE_X4_FROM_U8(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, mf2, mf8, -3, 2) \ + HWY_RVV_PROMOTE_X4(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) + +#define HWY_RVV_PROMOTE_X8(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m1, mf8, -3, 3) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m2, mf4, -2, 3) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m4, mf2, -1, 3) \ + HWY_RVV_PROMOTE(OP, BASE, CHAR, BITS, BASE_IN, BITS_IN, m8, m1, 
0, 3) + +HWY_RVV_PROMOTE_X8(zext_vf8_, uint, u, 64, uint, 8) +HWY_RVV_PROMOTE_X8(sext_vf8_, int, i, 64, int, 8) + +HWY_RVV_PROMOTE_X4_FROM_U8(zext_vf4_, uint, u, 32, uint, 8) +HWY_RVV_PROMOTE_X4_FROM_U8(sext_vf4_, int, i, 32, int, 8) +HWY_RVV_PROMOTE_X4(zext_vf4_, uint, u, 64, uint, 16) +HWY_RVV_PROMOTE_X4(sext_vf4_, int, i, 64, int, 16) + +// i32 to f64 +HWY_RVV_PROMOTE_X2(fwcvt_f_x_v_, float, f, 64, int, 32) + +// u32 to f64 +HWY_RVV_PROMOTE_X2(fwcvt_f_xu_v_, float, f, 64, uint, 32) + +// f32 to i64 +HWY_RVV_PROMOTE_X2(fwcvt_rtz_x_f_v_, int, i, 64, float, 32) + +// f32 to u64 +HWY_RVV_PROMOTE_X2(fwcvt_rtz_xu_f_v_, uint, u, 64, float, 32) + +#undef HWY_RVV_PROMOTE_X8 +#undef HWY_RVV_PROMOTE_X4_FROM_U8 +#undef HWY_RVV_PROMOTE_X4 +#undef HWY_RVV_PROMOTE_X2 +#undef HWY_RVV_PROMOTE + +// I16->I64 or U16->U64 PromoteTo with virtual LMUL +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) + -> VFromD { + return PromoteTo(ScalableTag(), v); +} + +template +HWY_API auto PromoteTo(Simd d, + VFromD> v) + -> VFromD { + return PromoteTo(ScalableTag(), v); +} + +// Unsigned to signed: cast for unsigned promote. 
+template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + return BitCast(d, PromoteTo(RebindToUnsigned(), v)); +} + +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { + const RebindToSigned di32; + const Rebind du16; + return BitCast(d, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ------------------------------ DemoteTo U + +HWY_INLINE_VAR constexpr size_t kClipShift = 0; + +// SEW is for the source so we can use _DEMOTE_VIRT. +#define HWY_RVV_DEMOTE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##CHAR##SEWH##LMULH( \ + v, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); \ + } + +// Unsigned -> unsigned +HWY_RVV_FOREACH_U16(HWY_RVV_DEMOTE, DemoteTo, nclipu_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_U32(HWY_RVV_DEMOTE, DemoteTo, nclipu_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_U64(HWY_RVV_DEMOTE, DemoteTo, nclipu_wx_, _DEMOTE_VIRT) + +// SEW is for the source so we can use _DEMOTE_VIRT. 
+#define HWY_RVV_DEMOTE_I_TO_U(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(uint, SEWH, LMULH) NAME( \ + HWY_RVV_D(uint, SEWH, N, SHIFT - 1) dn, HWY_RVV_V(int, SEW, LMUL) v) { \ + const HWY_RVV_D(uint, SEW, N, SHIFT) du; \ + /* First clamp negative numbers to zero to match x86 packus. */ \ + return DemoteTo(dn, BitCast(du, detail::MaxS(v, 0))); \ + } +HWY_RVV_FOREACH_I64(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I32(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I16(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE_VIRT) +#undef HWY_RVV_DEMOTE_I_TO_U + +template +HWY_API vuint8mf8_t DemoteTo(Simd d, const vint32mf2_t v) { + return __riscv_vnclipu_wx_u8mf8( + DemoteTo(Simd(), v), kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8mf4_t DemoteTo(Simd d, const vint32m1_t v) { + return __riscv_vnclipu_wx_u8mf4( + DemoteTo(Simd(), v), kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8mf2_t DemoteTo(Simd d, const vint32m2_t v) { + return __riscv_vnclipu_wx_u8mf2( + DemoteTo(Simd(), v), kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8m1_t DemoteTo(Simd d, const vint32m4_t v) { + return __riscv_vnclipu_wx_u8m1( + DemoteTo(Simd(), v), kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8m2_t DemoteTo(Simd d, const vint32m8_t v) { + return __riscv_vnclipu_wx_u8m2( + DemoteTo(Simd(), v), kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} + +template +HWY_API vuint8mf8_t DemoteTo(Simd d, const vuint32mf2_t v) { + return __riscv_vnclipu_wx_u8mf8( + DemoteTo(Simd(), v), kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8mf4_t DemoteTo(Simd d, const vuint32m1_t v) { + return __riscv_vnclipu_wx_u8mf4( + DemoteTo(Simd(), v), kClipShift, + 
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8mf2_t DemoteTo(Simd d, const vuint32m2_t v) { + return __riscv_vnclipu_wx_u8mf2( + DemoteTo(Simd(), v), kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8m1_t DemoteTo(Simd d, const vuint32m4_t v) { + return __riscv_vnclipu_wx_u8m1( + DemoteTo(Simd(), v), kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} +template +HWY_API vuint8m2_t DemoteTo(Simd d, const vuint32m8_t v) { + return __riscv_vnclipu_wx_u8m2( + DemoteTo(Simd(), v), kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); +} + +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); +} + +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); +} + +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); +} + +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); +} + +HWY_API vuint8mf8_t U8FromU32(const vuint32mf2_t v) { + const size_t avl = Lanes(ScalableTag()); + return __riscv_vnclipu_wx_u8mf8( + __riscv_vnclipu_wx_u16mf4(v, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)), + kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} +HWY_API vuint8mf4_t U8FromU32(const vuint32m1_t v) { + const size_t avl = Lanes(ScalableTag()); + return __riscv_vnclipu_wx_u8mf4( + __riscv_vnclipu_wx_u16mf2(v, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)), + kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} +HWY_API vuint8mf2_t U8FromU32(const vuint32m2_t v) { + const size_t avl = Lanes(ScalableTag()); + return __riscv_vnclipu_wx_u8mf2( + __riscv_vnclipu_wx_u16m1(v, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)), + kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} +HWY_API vuint8m1_t U8FromU32(const vuint32m4_t v) { + const size_t avl = Lanes(ScalableTag()); 
+ return __riscv_vnclipu_wx_u8m1( + __riscv_vnclipu_wx_u16m2(v, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)), + kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} +HWY_API vuint8m2_t U8FromU32(const vuint32m8_t v) { + const size_t avl = Lanes(ScalableTag()); + return __riscv_vnclipu_wx_u8m2( + __riscv_vnclipu_wx_u16m4(v, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)), + kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +// ------------------------------ Truncations + +template +HWY_API vuint8mf8_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint32mf2_t v2 = __riscv_vnclipu_wx_u32mf2( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + const vuint16mf4_t v3 = __riscv_vnclipu_wx_u16mf4( + v2, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8mf8(v3, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint32m1_t v2 = __riscv_vnclipu_wx_u32m1( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + const vuint16mf2_t v3 = __riscv_vnclipu_wx_u16mf2( + v2, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8mf4(v3, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint32m2_t v2 = __riscv_vnclipu_wx_u32m2( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + const vuint16m1_t v3 = __riscv_vnclipu_wx_u16m1( + v2, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8mf2(v3, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API 
vuint8m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint32m4_t v2 = __riscv_vnclipu_wx_u32m4( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + const vuint16m2_t v3 = __riscv_vnclipu_wx_u16m2( + v2, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8m1(v3, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = __riscv_vand(v, 0xFFFF, avl); + const vuint32mf2_t v2 = __riscv_vnclipu_wx_u32mf2( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u16mf4(v2, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = __riscv_vand(v, 0xFFFF, avl); + const vuint32mf2_t v2 = __riscv_vnclipu_wx_u32mf2( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u16mf4(v2, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = __riscv_vand(v, 0xFFFF, avl); + const vuint32m1_t v2 = __riscv_vnclipu_wx_u32m1( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u16mf2(v2, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = __riscv_vand(v, 0xFFFF, avl); + const vuint32m2_t v2 = __riscv_vnclipu_wx_u32m2( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u16m1(v2, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16m2_t 
TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = __riscv_vand(v, 0xFFFF, avl); + const vuint32m4_t v2 = __riscv_vnclipu_wx_u32m4( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u16m2(v2, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint32mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl); + return __riscv_vnclipu_wx_u32mf2(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint32mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m1_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl); + return __riscv_vnclipu_wx_u32mf2(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint32m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m2_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl); + return __riscv_vnclipu_wx_u32m1(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint32m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m4_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl); + return __riscv_vnclipu_wx_u32m2(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint32m4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint64m8_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl); + return __riscv_vnclipu_wx_u32m4(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf8_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32mf2_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint16mf4_t v2 = __riscv_vnclipu_wx_u16mf4( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8mf8(v2, kClipShift, + 
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m1_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint16mf2_t v2 = __riscv_vnclipu_wx_u16mf2( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8mf4(v2, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m2_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint16m1_t v2 = __riscv_vnclipu_wx_u16m1( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8mf2(v2, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m4_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint16m2_t v2 = __riscv_vnclipu_wx_u16m2( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8m1(v2, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m8_t v1 = __riscv_vand(v, 0xFF, avl); + const vuint16m4_t v2 = __riscv_vnclipu_wx_u16m4( + v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); + return __riscv_vnclipu_wx_u8m2(v2, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32mf2_t v1 = __riscv_vand(v, 0xFFFF, avl); + return __riscv_vnclipu_wx_u16mf4(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32mf2_t v1 = __riscv_vand(v, 0xFFFF, avl); + return __riscv_vnclipu_wx_u16mf4(v1, 
kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m1_t v1 = __riscv_vand(v, 0xFFFF, avl); + return __riscv_vnclipu_wx_u16mf2(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m2_t v1 = __riscv_vand(v, 0xFFFF, avl); + return __riscv_vnclipu_wx_u16m1(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m4_t v1 = __riscv_vand(v, 0xFFFF, avl); + return __riscv_vnclipu_wx_u16m2(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint16m4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint32m8_t v1 = __riscv_vand(v, 0xFFFF, avl); + return __riscv_vnclipu_wx_u16m4(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf8_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16mf4_t v1 = __riscv_vand(v, 0xFF, avl); + return __riscv_vnclipu_wx_u8mf8(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16mf2_t v1 = __riscv_vand(v, 0xFF, avl); + return __riscv_vnclipu_wx_u8mf4(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8mf2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m1_t v1 = __riscv_vand(v, 0xFF, avl); + return __riscv_vnclipu_wx_u8mf2(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8m1_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m2_t v1 
= __riscv_vand(v, 0xFF, avl); + return __riscv_vnclipu_wx_u8m1(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8m2_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m4_t v1 = __riscv_vand(v, 0xFF, avl); + return __riscv_vnclipu_wx_u8m2(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +template +HWY_API vuint8m4_t TruncateTo(Simd d, + const VFromD> v) { + const size_t avl = Lanes(d); + const vuint16m8_t v1 = __riscv_vand(v, 0xFF, avl); + return __riscv_vnclipu_wx_u8m4(v1, kClipShift, + HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)); +} + +// ------------------------------ DemoteTo I + +HWY_RVV_FOREACH_I16(HWY_RVV_DEMOTE, DemoteTo, nclip_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I32(HWY_RVV_DEMOTE, DemoteTo, nclip_wx_, _DEMOTE_VIRT) +HWY_RVV_FOREACH_I64(HWY_RVV_DEMOTE, DemoteTo, nclip_wx_, _DEMOTE_VIRT) + +template +HWY_API vint8mf8_t DemoteTo(Simd d, const vint32mf2_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8mf4_t DemoteTo(Simd d, const vint32m1_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8mf2_t DemoteTo(Simd d, const vint32m2_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8m1_t DemoteTo(Simd d, const vint32m4_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} +template +HWY_API vint8m2_t DemoteTo(Simd d, const vint32m8_t v) { + return DemoteTo(d, DemoteTo(Simd(), v)); +} + +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); +} + +template +HWY_API VFromD DemoteTo(D d, VFromD> v) { + return DemoteTo(d, DemoteTo(Rebind(), v)); +} + +#undef HWY_RVV_DEMOTE + +// ------------------------------ DemoteTo F + +// SEW is for the source so we can use _DEMOTE_VIRT. 
+#define HWY_RVV_DEMOTE_F(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##SEWH##LMULH(v, Lanes(d)); \ + } + +#if HWY_HAVE_FLOAT16 || HWY_RVV_HAVE_F16C +HWY_RVV_FOREACH_F32(HWY_RVV_DEMOTE_F, DemoteTo, fncvt_f_f_w_f, _DEMOTE_VIRT) +#endif +HWY_RVV_FOREACH_F64(HWY_RVV_DEMOTE_F, DemoteTo, fncvt_f_f_w_f, _DEMOTE_VIRT) + +namespace detail { +HWY_RVV_FOREACH_F64(HWY_RVV_DEMOTE_F, DemoteToF32WithRoundToOdd, + fncvt_rod_f_f_w_f, _DEMOTE_VIRT) +} // namespace detail + +#undef HWY_RVV_DEMOTE_F + +// TODO(janwas): add BASE2 arg to allow generating this via DEMOTE_F. +template +HWY_API vint32mf2_t DemoteTo(Simd d, const vfloat64m1_t v) { + return __riscv_vfncvt_rtz_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32mf2_t DemoteTo(Simd d, const vfloat64m1_t v) { + return __riscv_vfncvt_rtz_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32m1_t DemoteTo(Simd d, const vfloat64m2_t v) { + return __riscv_vfncvt_rtz_x_f_w_i32m1(v, Lanes(d)); +} +template +HWY_API vint32m2_t DemoteTo(Simd d, const vfloat64m4_t v) { + return __riscv_vfncvt_rtz_x_f_w_i32m2(v, Lanes(d)); +} +template +HWY_API vint32m4_t DemoteTo(Simd d, const vfloat64m8_t v) { + return __riscv_vfncvt_rtz_x_f_w_i32m4(v, Lanes(d)); +} + +template +HWY_API vuint32mf2_t DemoteTo(Simd d, const vfloat64m1_t v) { + return __riscv_vfncvt_rtz_xu_f_w_u32mf2(v, Lanes(d)); +} +template +HWY_API vuint32mf2_t DemoteTo(Simd d, const vfloat64m1_t v) { + return __riscv_vfncvt_rtz_xu_f_w_u32mf2(v, Lanes(d)); +} +template +HWY_API vuint32m1_t DemoteTo(Simd d, const vfloat64m2_t v) { + return __riscv_vfncvt_rtz_xu_f_w_u32m1(v, Lanes(d)); +} +template +HWY_API vuint32m2_t DemoteTo(Simd d, const vfloat64m4_t v) { + return __riscv_vfncvt_rtz_xu_f_w_u32m2(v, Lanes(d)); +} +template +HWY_API vuint32m4_t DemoteTo(Simd d, const vfloat64m8_t v) { + 
return __riscv_vfncvt_rtz_xu_f_w_u32m4(v, Lanes(d)); +} + +template +HWY_API vfloat32mf2_t DemoteTo(Simd d, const vint64m1_t v) { + return __riscv_vfncvt_f_x_w_f32mf2(v, Lanes(d)); +} +template +HWY_API vfloat32mf2_t DemoteTo(Simd d, const vint64m1_t v) { + return __riscv_vfncvt_f_x_w_f32mf2(v, Lanes(d)); +} +template +HWY_API vfloat32m1_t DemoteTo(Simd d, const vint64m2_t v) { + return __riscv_vfncvt_f_x_w_f32m1(v, Lanes(d)); +} +template +HWY_API vfloat32m2_t DemoteTo(Simd d, const vint64m4_t v) { + return __riscv_vfncvt_f_x_w_f32m2(v, Lanes(d)); +} +template +HWY_API vfloat32m4_t DemoteTo(Simd d, const vint64m8_t v) { + return __riscv_vfncvt_f_x_w_f32m4(v, Lanes(d)); +} + +template +HWY_API vfloat32mf2_t DemoteTo(Simd d, const vuint64m1_t v) { + return __riscv_vfncvt_f_xu_w_f32mf2(v, Lanes(d)); +} +template +HWY_API vfloat32mf2_t DemoteTo(Simd d, const vuint64m1_t v) { + return __riscv_vfncvt_f_xu_w_f32mf2(v, Lanes(d)); +} +template +HWY_API vfloat32m1_t DemoteTo(Simd d, const vuint64m2_t v) { + return __riscv_vfncvt_f_xu_w_f32m1(v, Lanes(d)); +} +template +HWY_API vfloat32m2_t DemoteTo(Simd d, const vuint64m4_t v) { + return __riscv_vfncvt_f_xu_w_f32m2(v, Lanes(d)); +} +template +HWY_API vfloat32m4_t DemoteTo(Simd d, const vuint64m8_t v) { + return __riscv_vfncvt_f_xu_w_f32m4(v, Lanes(d)); +} + +// Narrows f32 bits to bf16 using round to even. +// SEW is for the source so we can use _DEMOTE_VIRT. +#ifdef HWY_RVV_AVOID_VXRM +#define HWY_RVV_DEMOTE_16_NEAREST_EVEN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, \ + LMULD, LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + const auto round = \ + detail::AddS(detail::AndS(ShiftRight<16>(v), 1u), 0x7FFFu); \ + v = Add(v, round); \ + /* The default rounding mode appears to be RNU=0, which adds the LSB. */ \ + /* Prevent further rounding by clearing the bits we want to truncate. 
*/ \ + v = detail::AndS(v, 0xFFFF0000u); \ + return __riscv_v##OP##CHAR##SEWH##LMULH(v, 16, Lanes(d)); \ + } + +#else +#define HWY_RVV_DEMOTE_16_NEAREST_EVEN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, \ + LMULD, LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \ + HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##CHAR##SEWH##LMULH( \ + v, 16, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNE, Lanes(d))); \ + } +#endif // HWY_RVV_AVOID_VXRM +namespace detail { +HWY_RVV_FOREACH_U32(HWY_RVV_DEMOTE_16_NEAREST_EVEN, DemoteTo16NearestEven, + nclipu_wx_, _DEMOTE_VIRT) +} +#undef HWY_RVV_DEMOTE_16_NEAREST_EVEN + +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +template +HWY_API VFromD DemoteTo(DBF16 d, VFromD> v) { + const DFromV df; + const RebindToUnsigned du32; + const RebindToUnsigned du16; + // Consider an f32 mantissa with the upper 7 bits set, followed by a 1-bit + // and at least one other bit set. This will round to 0 and increment the + // exponent. If the exponent was already 0xFF (NaN), then the result is -inf; + // there no wraparound because nclipu saturates. Note that in this case, the + // input cannot have been inf because its mantissa bits are zero. To avoid + // converting NaN to inf, we canonicalize the NaN to prevent the rounding. 
+ const decltype(v) canonicalized = + IfThenElse(Eq(v, v), v, BitCast(df, Set(du32, 0x7F800000))); + return BitCast( + d, detail::DemoteTo16NearestEven(du16, BitCast(du32, canonicalized))); +} + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const Rebind df32; + return DemoteTo(df16, detail::DemoteToF32WithRoundToOdd(df32, v)); +} + +// ------------------------------ ConvertTo F + +#define HWY_RVV_CONVERT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(int, SEW, LMUL) v) { \ + return __riscv_vfcvt_f_x_v_f##SEW##LMUL(v, Lanes(d)); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(uint, SEW, LMUL) v) { \ + return __riscv_vfcvt_f_xu_v_f##SEW##LMUL(v, Lanes(d)); \ + } \ + /* Truncates (rounds toward zero). */ \ + template \ + HWY_API HWY_RVV_V(int, SEW, LMUL) ConvertTo(HWY_RVV_D(int, SEW, N, SHIFT) d, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_vfcvt_rtz_x_f_v_i##SEW##LMUL(v, Lanes(d)); \ + } \ + template \ + HWY_API HWY_RVV_V(uint, SEW, LMUL) ConvertTo( \ + HWY_RVV_D(uint, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_vfcvt_rtz_xu_f_v_u##SEW##LMUL(v, Lanes(d)); \ + } + +HWY_RVV_FOREACH_F(HWY_RVV_CONVERT, _, _, _ALL_VIRT) +#undef HWY_RVV_CONVERT + +// Uses default rounding mode. Must be separate because there is no D arg. 
+#define HWY_RVV_NEAREST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(int, SEW, LMUL) NearestInt(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_vfcvt_x_f_v_i##SEW##LMUL(v, HWY_RVV_AVL(SEW, SHIFT)); \ + } +HWY_RVV_FOREACH_F(HWY_RVV_NEAREST, _, _, _ALL) +#undef HWY_RVV_NEAREST + +template +HWY_API vint32mf2_t DemoteToNearestInt(Simd d, + const vfloat64m1_t v) { + return __riscv_vfncvt_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32mf2_t DemoteToNearestInt(Simd d, + const vfloat64m1_t v) { + return __riscv_vfncvt_x_f_w_i32mf2(v, Lanes(d)); +} +template +HWY_API vint32m1_t DemoteToNearestInt(Simd d, + const vfloat64m2_t v) { + return __riscv_vfncvt_x_f_w_i32m1(v, Lanes(d)); +} +template +HWY_API vint32m2_t DemoteToNearestInt(Simd d, + const vfloat64m4_t v) { + return __riscv_vfncvt_x_f_w_i32m2(v, Lanes(d)); +} +template +HWY_API vint32m4_t DemoteToNearestInt(Simd d, + const vfloat64m8_t v) { + return __riscv_vfncvt_x_f_w_i32m4(v, Lanes(d)); +} + +// ================================================== COMBINE + +namespace detail { + +// For x86-compatible behaviour mandated by Highway API: TableLookupBytes +// offsets are implicitly relative to the start of their 128-bit block. +template +HWY_LANES_CONSTEXPR HWY_INLINE size_t LanesPerBlock(Simd d) { + // kMinVecBytes is the minimum size of VFromD in bytes + constexpr size_t kMinVecBytes = + ScaleByPower(16, HWY_MAX(HWY_MIN(kPow2, 3), -3)); + // kMinVecLanes is the minimum number of lanes in VFromD + constexpr size_t kMinVecLanes = (kMinVecBytes + sizeof(T) - 1) / sizeof(T); + // kMaxLpb is the maximum number of lanes per block + constexpr size_t kMaxLpb = HWY_MIN(16 / sizeof(T), MaxLanes(d)); + + // If kMaxLpb <= kMinVecLanes is true, then kMaxLpb <= Lanes(d) is true + if (kMaxLpb <= kMinVecLanes) return kMaxLpb; + + // Fractional LMUL: Lanes(d) may be smaller than kMaxLpb, so honor that. 
+ const size_t lanes_per_vec = Lanes(d); + return HWY_MIN(lanes_per_vec, kMaxLpb); +} + +template +HWY_INLINE V OffsetsOf128BitBlocks(const D d, const V iota0) { + using T = MakeUnsigned>; + return AndS(iota0, static_cast(~(LanesPerBlock(d) - 1))); +} + +template +HWY_INLINE MFromD FirstNPerBlock(D /* tag */) { + const RebindToUnsigned du; + const RebindToSigned di; + using TU = TFromD; + const auto idx_mod = AndS(Iota0(du), static_cast(LanesPerBlock(du) - 1)); + return LtS(BitCast(di, idx_mod), static_cast>(kLanes)); +} + +#define HWY_RVV_SLIDE_UP(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) dst, HWY_RVV_V(BASE, SEW, LMUL) src, \ + size_t lanes) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL(dst, src, lanes, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +#define HWY_RVV_SLIDE_DOWN(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) src, size_t lanes) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL(src, lanes, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_SLIDE_UP, SlideUp, slideup, _ALL) +HWY_RVV_FOREACH(HWY_RVV_SLIDE_DOWN, SlideDown, slidedown, _ALL) + +#undef HWY_RVV_SLIDE_UP +#undef HWY_RVV_SLIDE_DOWN + +#define HWY_RVV_GET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULH( \ + v, kIndex); /* no AVL */ \ + } +#define HWY_RVV_GET_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \ + HWY_IF_CONSTEXPR(kIndex == 0) { return Trunc(v); } \ + HWY_IF_CONSTEXPR(kIndex != 0) { \ + 
return Trunc(SlideDown( \ + v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), \ + SHIFT - 1){}))); \ + } \ + } +#define HWY_RVV_GET_SMALLEST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \ + HWY_IF_CONSTEXPR(kIndex == 0) { return v; } \ + HWY_IF_CONSTEXPR(kIndex != 0) { \ + return SlideDown( \ + v, Lanes(HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), \ + SHIFT){}) / \ + 2); \ + } \ + } +HWY_RVV_FOREACH(HWY_RVV_GET, Get, get, _GET_SET) +HWY_RVV_FOREACH(HWY_RVV_GET_VIRT, Get, get, _GET_SET_VIRT) +HWY_RVV_FOREACH(HWY_RVV_GET_SMALLEST, Get, get, _GET_SET_SMALLEST) +#undef HWY_RVV_GET +#undef HWY_RVV_GET_VIRT +#undef HWY_RVV_GET_SMALLEST + +template +static HWY_INLINE HWY_MAYBE_UNUSED VFromD>> +Get(D d, VFromD v) { + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); + HWY_IF_CONSTEXPR(kIndex == 0 || detail::IsFull(d)) { return Get(v); } + HWY_IF_CONSTEXPR(kIndex != 0 && !detail::IsFull(d)) { + const AdjustSimdTagToMinVecPow2> dh; + const size_t slide_down_amt = + (dh.Pow2() < DFromV().Pow2()) ? 
Lanes(dh) : (Lanes(d) / 2); + return ResizeBitCast(dh, SlideDown(v, slide_down_amt)); + } +} + +#define HWY_RVV_PARTIAL_VEC_SET_HALF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMULH) v, \ + size_t half_N) { \ + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \ + const DFromV d; \ + HWY_IF_CONSTEXPR(kIndex == 0) { \ + return __riscv_v##OP##_v_v_##CHAR##SEW##LMUL##_tu(dest, Ext(d, v), \ + half_N); \ + } \ + HWY_IF_CONSTEXPR(kIndex != 0) { return SlideUp(dest, Ext(d, v), half_N); } \ + } +#define HWY_RVV_PARTIAL_VEC_SET_HALF_SMALLEST( \ + BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMUL) v, \ + size_t half_N) { \ + static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \ + HWY_IF_CONSTEXPR(kIndex == 0) { \ + return __riscv_v##OP##_v_v_##CHAR##SEW##LMUL##_tu(dest, v, half_N); \ + } \ + HWY_IF_CONSTEXPR(kIndex != 0) { return SlideUp(dest, v, half_N); } \ + } +HWY_RVV_FOREACH(HWY_RVV_PARTIAL_VEC_SET_HALF, PartialVecSetHalf, mv, _GET_SET) +HWY_RVV_FOREACH(HWY_RVV_PARTIAL_VEC_SET_HALF, PartialVecSetHalf, mv, + _GET_SET_VIRT) +HWY_RVV_FOREACH(HWY_RVV_PARTIAL_VEC_SET_HALF_SMALLEST, PartialVecSetHalf, mv, + _GET_SET_SMALLEST) +#undef HWY_RVV_PARTIAL_VEC_SET_HALF +#undef HWY_RVV_PARTIAL_VEC_SET_HALF_SMALLEST + +#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) dest, \ + HWY_RVV_V(BASE, SEW, LMULH) v) { \ + HWY_IF_CONSTEXPR(detail::IsFull(d)) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMULH##_##CHAR##SEW##LMUL( \ + dest, kIndex, v); /* no AVL */ \ + } \ + HWY_IF_CONSTEXPR(!detail::IsFull(d)) { 
\ + const Half dh; \ + return PartialVecSetHalf(dest, v, Lanes(dh)); \ + } \ + } +#define HWY_RVV_SET_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) dest, \ + HWY_RVV_V(BASE, SEW, LMULH) v) { \ + const Half dh; \ + return PartialVecSetHalf(dest, v, Lanes(dh)); \ + } +#define HWY_RVV_SET_SMALLEST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) dest, \ + HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return PartialVecSetHalf(dest, v, Lanes(d) / 2); \ + } +#define HWY_RVV_SET_SMALLEST_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT - 1) d, \ + HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return PartialVecSetHalf(dest, v, Lanes(d) / 2); \ + } +HWY_RVV_FOREACH(HWY_RVV_SET, Set, set, _GET_SET) +HWY_RVV_FOREACH(HWY_RVV_SET_VIRT, Set, set, _GET_SET_VIRT) +HWY_RVV_FOREACH(HWY_RVV_SET_SMALLEST, Set, set, _GET_SET_SMALLEST) +HWY_RVV_FOREACH_UI163264(HWY_RVV_SET_SMALLEST_VIRT, Set, set, _GET_SET_SMALLEST) +HWY_RVV_FOREACH_F(HWY_RVV_SET_SMALLEST_VIRT, Set, set, _GET_SET_SMALLEST) +#undef HWY_RVV_SET +#undef HWY_RVV_SET_VIRT +#undef HWY_RVV_SET_SMALLEST +#undef HWY_RVV_SET_SMALLEST_VIRT + +template +static HWY_INLINE HWY_MAYBE_UNUSED VFromD Set( + D d, VFromD dest, VFromD>> v) { + const RebindToUnsigned du; + return BitCast( + d, Set(du, BitCast(du, dest), + BitCast(RebindToUnsigned>(), v))); +} + +} // namespace detail + +// ------------------------------ SlideUpLanes +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { + return detail::SlideUp(Zero(d), v, amt); +} + +// ------------------------------ SlideDownLanes +template 
+HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { + v = detail::SlideDown(v, amt); + // Zero out upper lanes if v is a partial vector + if (MaxLanes(d) < MaxLanes(DFromV())) { + v = detail::SlideUp(v, Zero(d), Lanes(d) - amt); + } + return v; +} + +// ------------------------------ ConcatUpperLower +template +HWY_API V ConcatUpperLower(D d, const V hi, const V lo) { + const auto lo_lower = detail::Get<0>(d, lo); + return detail::Set<0>(d, hi, lo_lower); +} + +// ------------------------------ ConcatLowerLower +template +HWY_API V ConcatLowerLower(D d, const V hi, const V lo) { + const auto hi_lower = detail::Get<0>(d, hi); + return detail::Set<1>(d, lo, hi_lower); +} + +// ------------------------------ ConcatUpperUpper +template +HWY_API V ConcatUpperUpper(D d, const V hi, const V lo) { + const auto lo_upper = detail::Get<1>(d, lo); + return detail::Set<0>(d, hi, lo_upper); +} + +// ------------------------------ ConcatLowerUpper +template +HWY_API V ConcatLowerUpper(D d, const V hi, const V lo) { + const auto lo_upper = detail::Get<1>(d, lo); + const auto hi_lower = detail::Get<0>(d, hi); + return detail::Set<1>(d, ResizeBitCast(d, lo_upper), hi_lower); +} + +// ------------------------------ Combine +template +HWY_API VFromD Combine(D2 d2, const V hi, const V lo) { + return detail::Set<1>(d2, ResizeBitCast(d2, lo), hi); +} + +// ------------------------------ ZeroExtendVector +template +HWY_API VFromD ZeroExtendVector(D2 d2, const V lo) { + return Combine(d2, Xor(lo, lo), lo); +} + +// ------------------------------ Lower/UpperHalf + +namespace detail { + +// RVV may only support LMUL >= SEW/64; returns whether that holds for D. Note +// that SEW = sizeof(T)*8 and LMUL = 1 << d.Pow2(). Add 3 to Pow2 to avoid +// negative shift counts. +template +constexpr bool IsSupportedLMUL(D d) { + return (size_t{1} << (d.Pow2() + 3)) >= sizeof(TFromD); +} + +} // namespace detail + +// If IsSupportedLMUL, just 'truncate' i.e. halve LMUL. 
+template * = nullptr> +HWY_API VFromD LowerHalf(const DH /* tag */, const VFromD> v) { + return detail::Trunc(v); +} + +// Otherwise, there is no corresponding intrinsic type (e.g. vuint64mf2_t), and +// the hardware may set "vill" if we attempt such an LMUL. However, the V +// extension on application processors requires Zvl128b, i.e. VLEN >= 128, so it +// still makes sense to have half of an SEW=64 vector. We instead just return +// the vector, and rely on the kPow2 in DH to halve the return value of Lanes(). +template * = nullptr> +HWY_API V LowerHalf(const DH /* tag */, const V v) { + return v; +} + +// Same, but without D arg +template +HWY_API VFromD>> LowerHalf(const V v) { + return LowerHalf(Half>(), v); +} + +template +HWY_API VFromD UpperHalf(const DH /*d2*/, const VFromD> v) { + const Twice d; + return detail::Get<1>(d, v); +} + +// ================================================== SWIZZLE + +namespace detail { +// Special instruction for 1 lane is presumably faster? +#define HWY_RVV_SLIDE1(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_##CHAR##SEW##LMUL(v, 0, HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_SLIDE1, Slide1Up, slide1up_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_SLIDE1, Slide1Up, fslide1up_vf, _ALL) +HWY_RVV_FOREACH_UI(HWY_RVV_SLIDE1, Slide1Down, slide1down_vx, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_SLIDE1, Slide1Down, fslide1down_vf, _ALL) +#undef HWY_RVV_SLIDE1 +} // namespace detail + +// ------------------------------ Slide1Up and Slide1Down +#ifdef HWY_NATIVE_SLIDE1_UP_DOWN +#undef HWY_NATIVE_SLIDE1_UP_DOWN +#else +#define HWY_NATIVE_SLIDE1_UP_DOWN +#endif + +template +HWY_API VFromD Slide1Up(D /*d*/, VFromD v) { + return detail::Slide1Up(v); +} + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + v = detail::Slide1Down(v); + // Zero out upper lanes if v is a partial vector + if (MaxLanes(d) < 
MaxLanes(DFromV())) { + v = detail::SlideUp(v, Zero(d), Lanes(d) - 1); + } + return v; +} + +// ------------------------------ GetLane + +#define HWY_RVV_GET_LANE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_T(BASE, SEW) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_s_##CHAR##SEW##LMUL##_##CHAR##SEW(v); /* no AVL */ \ + } + +HWY_RVV_FOREACH_UI(HWY_RVV_GET_LANE, GetLane, mv_x, _ALL) +HWY_RVV_FOREACH_F(HWY_RVV_GET_LANE, GetLane, fmv_f, _ALL) +#undef HWY_RVV_GET_LANE + +// ------------------------------ ExtractLane +template +HWY_API TFromV ExtractLane(const V v, size_t i) { + return GetLane(detail::SlideDown(v, i)); +} + +// ------------------------------ Additional mask logical operations + +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, SetOnlyFirst, sof) +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, SetBeforeFirst, sbf) +HWY_RVV_FOREACH_B(HWY_RVV_RETM_ARGM, SetAtOrBeforeFirst, sif) + +#define HWY_RVV_SET_AT_OR_AFTER_FIRST(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_M(MLEN) SetAtOrAfterFirst(HWY_RVV_M(MLEN) m) { \ + return Not(SetBeforeFirst(m)); \ + } + +HWY_RVV_FOREACH_B(HWY_RVV_SET_AT_OR_AFTER_FIRST, _, _) +#undef HWY_RVV_SET_AT_OR_AFTER_FIRST + +// ------------------------------ InsertLane + +// T template arg because TFromV might not match the hwy::float16_t argument. +template +HWY_API V InsertLane(const V v, size_t i, T t) { + const Rebind> d; + const RebindToUnsigned du; // Iota0 is unsigned only + using TU = TFromD; + const auto is_i = detail::EqS(detail::Iota0(du), static_cast(i)); + return IfThenElse(RebindMask(d, is_i), Set(d, t), v); +} + +// For 8-bit lanes, Iota0 might overflow. 
+template +HWY_API V InsertLane(const V v, size_t i, T t) { + const Rebind> d; + const auto zero = Zero(d); + const auto one = Set(d, 1); + const auto ge_i = Eq(detail::SlideUp(zero, one, i), one); + const auto is_i = SetOnlyFirst(ge_i); + return IfThenElse(RebindMask(d, is_i), Set(d, t), v); +} + +// ------------------------------ OddEven + +namespace detail { + +// Faster version using a wide constant instead of Iota0 + AndS. +template +HWY_INLINE MFromD IsEven(D d) { + const RebindToUnsigned du; + const RepartitionToWide duw; + return RebindMask(d, detail::NeS(BitCast(du, Set(duw, 1)), 0u)); +} + +template +HWY_INLINE MFromD IsEven(D d) { + const RebindToUnsigned du; // Iota0 is unsigned only + return detail::EqS(detail::AndS(detail::Iota0(du), 1), 0); +} + +// Also provide the negated form because there is no native CompressNot. +template +HWY_INLINE MFromD IsOdd(D d) { + const RebindToUnsigned du; + const RepartitionToWide duw; + return RebindMask(d, detail::EqS(BitCast(du, Set(duw, 1)), 0u)); +} + +template +HWY_INLINE MFromD IsOdd(D d) { + const RebindToUnsigned du; // Iota0 is unsigned only + return detail::NeS(detail::AndS(detail::Iota0(du), 1), 0); +} + +} // namespace detail + +template +HWY_API V OddEven(const V a, const V b) { + return IfThenElse(detail::IsEven(DFromV()), b, a); +} + +// ------------------------------ DupEven (OddEven) +template +HWY_API V DupEven(const V v) { + const V up = detail::Slide1Up(v); + return OddEven(up, v); +} + +// ------------------------------ DupOdd (OddEven) +template +HWY_API V DupOdd(const V v) { + const V down = detail::Slide1Down(v); + return OddEven(v, down); +} + +// ------------------------------ InterleaveEven (OddEven) +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return OddEven(detail::Slide1Up(b), a); +} + +// ------------------------------ InterleaveOdd (OddEven) +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return OddEven(b, detail::Slide1Down(a)); +} 
+ +// ------------------------------ OddEvenBlocks +template +HWY_API V OddEvenBlocks(const V a, const V b) { + const RebindToUnsigned> du; // Iota0 is unsigned only + constexpr size_t kShift = CeilLog2(16 / sizeof(TFromV)); + const auto idx_block = ShiftRight(detail::Iota0(du)); + const auto is_even = detail::EqS(detail::AndS(idx_block, 1), 0); + return IfThenElse(is_even, b, a); +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API V SwapAdjacentBlocks(const V v) { + const DFromV d; + const size_t lpb = detail::LanesPerBlock(d); + const V down = detail::SlideDown(v, lpb); + const V up = detail::SlideUp(v, v, lpb); + return OddEvenBlocks(up, down); +} + +// ------------------------------ InterleaveEvenBlocks +// (SlideUpLanes, OddEvenBlocks) + +template > +HWY_API V InterleaveEvenBlocks(D d, V a, V b) { + HWY_LANES_CONSTEXPR size_t lpb = detail::LanesPerBlock(d); + return OddEvenBlocks(SlideUpLanes(d, b, lpb), a); +} + +// ------------------------------ InterleaveOddBlocks +// (SlideDownLanes, OddEvenBlocks) + +template > +HWY_API V InterleaveOddBlocks(D d, V a, V b) { + HWY_LANES_CONSTEXPR size_t lpb = detail::LanesPerBlock(d); + return OddEvenBlocks(b, SlideDownLanes(d, a, lpb)); +} + +// ------------------------------ TableLookupLanes + +template +HWY_API VFromD> IndicesFromVec(D d, VI vec) { + static_assert(sizeof(TFromD) == sizeof(TFromV), "Index != lane"); + const RebindToUnsigned du; // instead of : avoids unused d. 
+ const auto indices = BitCast(du, vec); +#if HWY_IS_DEBUG_BUILD + using TU = TFromD; + const size_t twice_num_of_lanes = Lanes(d) * 2; + HWY_DASSERT(AllTrue( + du, Eq(indices, + detail::AndS(indices, static_cast(twice_num_of_lanes - 1))))); +#endif + return indices; +} + +template +HWY_API VFromD> SetTableIndices(D d, const TI* idx) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + return IndicesFromVec(d, LoadU(Rebind(), idx)); +} + +#define HWY_RVV_TABLE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(uint, SEW, LMUL) idx) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(v, idx, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +// TableLookupLanes is supported for all types, but beware that indices are +// likely to wrap around for 8-bit lanes. When using TableLookupLanes inside +// this file, ensure that it is safe or use TableLookupLanes16 instead. +HWY_RVV_FOREACH(HWY_RVV_TABLE, TableLookupLanes, rgather, _ALL) +#undef HWY_RVV_TABLE + +namespace detail { + +#define HWY_RVV_TABLE16(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(uint, SEWD, LMULD) idx) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL(v, idx, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI08(HWY_RVV_TABLE16, TableLookupLanes16, rgatherei16, _EXT) +#undef HWY_RVV_TABLE16 + +// Used by Expand. 
+#define HWY_RVV_MASKED_TABLE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) mask, HWY_RVV_V(BASE, SEW, LMUL) maskedoff, \ + HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(uint, SEW, LMUL) idx) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL##_mu(mask, maskedoff, v, idx, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_MASKED_TABLE, MaskedTableLookupLanes, rgather, _ALL) +#undef HWY_RVV_MASKED_TABLE + +#define HWY_RVV_MASKED_TABLE16(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_M(MLEN) mask, HWY_RVV_V(BASE, SEW, LMUL) maskedoff, \ + HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_V(uint, SEWD, LMULD) idx) { \ + return __riscv_v##OP##_vv_##CHAR##SEW##LMUL##_mu(mask, maskedoff, v, idx, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH_UI08(HWY_RVV_MASKED_TABLE16, MaskedTableLookupLanes16, + rgatherei16, _EXT) +#undef HWY_RVV_MASKED_TABLE16 + +} // namespace detail + +// ------------------------------ InterleaveLowerBlocks +// (ConcatLowerLower, OddEvenBlocks, TableLookupLanes) + +namespace detail { + +// Given a concatenated vector of (either upper or lower) half of `b` and `a`, +// permutes the vector to return interleaved blocks from both. +template , HWY_IF_NOT_T_SIZE_D(D, 1)> +HWY_INLINE V InterleaveBlocks(D d, const V ba) { + const RebindToUnsigned du; + using VU = VFromD; + using TU = TFromD; + HWY_LANES_CONSTEXPR size_t lpb = detail::LanesPerBlock(d); + const VU iota = detail::Iota0(du); + // Divide the block index by 2, without affecting the within-block lane. + const VU idx_blocks = + detail::AndS(ShiftRight<1>(iota), static_cast(~(lpb - 1))); + VU idx = Or(idx_blocks, detail::AndS(iota, static_cast(lpb - 1))); + // Odd blocks come from `b`, i.e. the upper half of `ba`. 
+ idx = OddEvenBlocks(Add(idx, Set(du, static_cast(Lanes(d) / 2))), idx); + + return TableLookupLanes(ba, IndicesFromVec(d, idx)); +} + +// As above, but uses 16-bit indices because 8-bit lanes might overflow. +// Identical except for `du` and the `TableLookupLanes16` call. +template , HWY_IF_T_SIZE_D(D, 1)> +HWY_INLINE V InterleaveBlocks(D d, const V ba) { + const Rebind du; // Increases LMUL + using VU = VFromD; + using TU = TFromD; + HWY_LANES_CONSTEXPR size_t lpb = detail::LanesPerBlock(d); + const VU iota = detail::Iota0(du); + // Divide the block index by 2, without affecting the within-block lane. + const VU idx_blocks = + detail::AndS(ShiftRight<1>(iota), static_cast(~(lpb - 1))); + VU idx = Or(idx_blocks, detail::AndS(iota, static_cast(lpb - 1))); + // Odd blocks come from `b`, i.e. the upper half of `ba`. + idx = OddEvenBlocks(Add(idx, Set(du, static_cast(Lanes(d) / 2))), idx); + + return detail::TableLookupLanes16(ba, idx); +} + +} // namespace detail + +template , HWY_IF_POW2_GT_D(DFromV, -3)> +HWY_INLINE V InterleaveLowerBlocks(D d, const V a, const V b) { + // Blocks are independent of type, hence cast to a type where indices will not + // overflow. + const Repartition du; + return BitCast( + d, detail::InterleaveBlocks(du, BitCast(du, ConcatLowerLower(d, b, a)))); +} + +template , HWY_IF_POW2_LE_D(DFromV, -3)> +HWY_INLINE V InterleaveLowerBlocks(D d, const V a, const V b) { + // Might be bytes; if so, the second overload will call TableLookupLanes16. + return detail::InterleaveBlocks(d, ConcatLowerLower(d, b, a)); +} + +// ------------------------------ InterleaveUpperBlocks + +template , HWY_IF_POW2_GT_D(DFromV, -3)> +HWY_INLINE V InterleaveUpperBlocks(D d, const V a, const V b) { + // Blocks are independent of type, hence cast to a type where indices will not + // overflow. 
+ const Repartition du; + return BitCast( + d, detail::InterleaveBlocks(du, BitCast(du, ConcatUpperUpper(d, b, a)))); +} + +template , HWY_IF_POW2_LE_D(DFromV, -3)> +HWY_INLINE V InterleaveUpperBlocks(D d, const V a, const V b) { + // Might be bytes; if so, the second overload will call TableLookupLanes16. + return detail::InterleaveBlocks(d, ConcatUpperUpper(d, b, a)); +} + +// ------------------------------ Reverse (TableLookupLanes) +template +HWY_API VFromD Reverse(D d, VFromD v) { + const Rebind du16; + const size_t N = Lanes(d); + const auto idx = + detail::ReverseSubS(detail::Iota0(du16), static_cast(N - 1)); + return detail::TableLookupLanes16(v, idx); +} + +// Reverses each half using 16-bit indices; the reversed lower half becomes the +// upper half of the result. +template +HWY_API VFromD Reverse(D d, VFromD v) { + const Half dh; + const Rebind du16; + const size_t half_n = Lanes(dh); + const auto idx = detail::ReverseSubS(detail::Iota0(du16), + static_cast(half_n - 1)); + const auto reversed_lo = detail::TableLookupLanes16(LowerHalf(dh, v), idx); + const auto reversed_hi = detail::TableLookupLanes16(UpperHalf(dh, v), idx); + return Combine(d, reversed_lo, reversed_hi); +} + +template +HWY_API VFromD Reverse(D /* tag */, VFromD v) { + const RebindToUnsigned du; + using TU = TFromD; + const size_t N = Lanes(du); + const auto idx = + detail::ReverseSubS(detail::Iota0(du), static_cast(N - 1)); + return TableLookupLanes(v, idx); +} + +// ------------------------------ ResizeBitCast + +// Extends or truncates a vector to match the given d. +namespace detail { + +template +HWY_INLINE VFromD ChangeLMUL(D /* d */, VFromD v) { + return v; +} + +// Sanity check: when calling ChangeLMUL, the caller (ResizeBitCast) already +// BitCast to the same lane type. Note that V may use the native lane type for +// f16, so convert D to that before checking. 
+#define HWY_RVV_IF_SAME_T_DV(D, V) \ + hwy::EnableIf>, TFromV>()>* = nullptr + +// LMUL of VFromD < LMUL of V: need to truncate v +template >, DFromV().Pow2() - 1)> +HWY_INLINE VFromD ChangeLMUL(D d, V v) { + const DFromV d_from; + const Half dh_from; + static_assert( + DFromV>().Pow2() < DFromV().Pow2(), + "The LMUL of VFromD must be less than the LMUL of V"); + static_assert( + DFromV>().Pow2() <= DFromV>().Pow2(), + "The LMUL of VFromD must be less than or equal to the LMUL of " + "VFromD"); + return ChangeLMUL(d, Trunc(v)); +} + +// LMUL of VFromD > LMUL of V: need to extend v +template >, DFromV().Pow2())> +HWY_INLINE VFromD ChangeLMUL(D d, V v) { + const DFromV d_from; + const Twice dt_from; + static_assert(DFromV>().Pow2() > DFromV().Pow2(), + "The LMUL of VFromD must be greater than " + "the LMUL of V"); + static_assert( + DFromV>().Pow2() >= DFromV>().Pow2(), + "The LMUL of VFromD must be greater than or equal to the LMUL of " + "VFromD"); + return ChangeLMUL(d, Ext(dt_from, v)); +} + +#undef HWY_RVV_IF_SAME_T_DV + +} // namespace detail + +template +HWY_API VFromD ResizeBitCast(DTo /*dto*/, VFrom v) { + const DFromV d_from; + const Repartition du8_from; + const DFromV> d_to; + const Repartition du8_to; + return BitCast(d_to, detail::ChangeLMUL(du8_to, BitCast(du8_from, v))); +} + +// ------------------------------ Reverse2 (RotateRight, OddEven) + +// Per-target flags to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. +#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +// Shifting and adding requires fewer instructions than blending, but casting to +// u32 only works for LMUL in [1/2, 8]. 
+ +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const detail::AdjustSimdTagToMinVecPow2> du16; + return ResizeBitCast(d, RotateRight<8>(ResizeBitCast(du16, v))); +} + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const detail::AdjustSimdTagToMinVecPow2> du32; + return ResizeBitCast(d, RotateRight<16>(ResizeBitCast(du32, v))); +} + +// Shifting and adding requires fewer instructions than blending, but casting to +// u64 does not work for LMUL < 1. +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const detail::AdjustSimdTagToMinVecPow2> du64; + return ResizeBitCast(d, RotateRight<32>(ResizeBitCast(du64, v))); +} + +template , HWY_IF_T_SIZE_D(D, 8)> +HWY_API V Reverse2(D /* tag */, const V v) { + const V up = detail::Slide1Up(v); + const V down = detail::Slide1Down(v); + return OddEven(up, down); +} + +// ------------------------------ Reverse4 (TableLookupLanes) + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const detail::AdjustSimdTagToMinVecPow2> du16; + return ResizeBitCast(d, Reverse2(du16, ResizeBitCast(du16, Reverse2(d, v)))); +} + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorS(detail::Iota0(du), 3); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +// ------------------------------ Reverse8 (TableLookupLanes) + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const detail::AdjustSimdTagToMinVecPow2> du32; + return ResizeBitCast(d, Reverse2(du32, ResizeBitCast(du32, Reverse4(d, v)))); +} + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToUnsigned du; + const auto idx = detail::XorS(detail::Iota0(du), 7); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +// ------------------------------ ReverseBlocks (Reverse, Shuffle01) +template > +HWY_API V ReverseBlocks(D d, V v) { + const detail::AdjustSimdTagToMinVecPow2> du64; + const size_t N = Lanes(du64); + const auto rev = + 
detail::ReverseSubS(detail::Iota0(du64), static_cast(N - 1)); + // Swap lo/hi u64 within each block + const auto idx = detail::XorS(rev, 1); + return ResizeBitCast(d, TableLookupLanes(ResizeBitCast(du64, v), idx)); +} + +// ------------------------------ Compress + +// RVV supports all lane types natively. +#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +template +struct CompressIsPartition { + enum { value = 0 }; +}; + +#define HWY_RVV_COMPRESS(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, HWY_RVV_M(MLEN) mask) { \ + return __riscv_v##OP##_vm_##CHAR##SEW##LMUL(v, mask, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_COMPRESS, Compress, compress, _ALL) +#undef HWY_RVV_COMPRESS + +// ------------------------------ Expand + +#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +// >= 2-byte lanes: idx lanes will not overflow. +template +HWY_API V Expand(V v, const M mask) { + const DFromV d; + const RebindToUnsigned du; + const auto idx = detail::MaskedIota(du, RebindMask(du, mask)); + const V zero = Zero(d); + return detail::MaskedTableLookupLanes(mask, zero, v, idx); +} + +// 1-byte lanes, LMUL < 8: promote idx to u16. +template , + HWY_IF_POW2_LE_D(D, 2)> +HWY_API V Expand(V v, const M mask) { + const D d; + const Rebind du16; + const auto idx = detail::MaskedIota(du16, RebindMask(du16, mask)); + const V zero = Zero(d); + return detail::MaskedTableLookupLanes16(mask, zero, v, idx); +} + +// 1-byte lanes, max LMUL: unroll 2x. +template , + HWY_IF_POW2_GT_D(DFromV, 2)> +HWY_API V Expand(V v, const M mask) { + const D d; + const Half dh; + const auto v0 = LowerHalf(dh, v); + // TODO(janwas): skip vec<->mask if we can cast masks. 
+ const V vmask = VecFromMask(d, mask); + const auto m0 = MaskFromVec(LowerHalf(dh, vmask)); + + // Cannot just use UpperHalf, must shift by the number of inputs consumed. + const size_t count = CountTrue(dh, m0); + const auto v1 = detail::Trunc(detail::SlideDown(v, count)); + const auto m1 = MaskFromVec(UpperHalf(dh, vmask)); + return Combine(d, Expand(v1, m1), Expand(v0, m0)); +} + +// ------------------------------ LoadExpand +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return Expand(LoadU(d, unaligned), mask); +} + +// ------------------------------ CompressNot +template +HWY_API V CompressNot(V v, const M mask) { + return Compress(v, Not(mask)); +} + +// ------------------------------ CompressBlocksNot +template +HWY_API V CompressBlocksNot(V v, const M mask) { + return CompressNot(v, mask); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(const V v, const M mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(const V v, const M mask, const D d, + TFromD* HWY_RESTRICT unaligned) { + const size_t count = CountTrue(d, mask); + StoreN(Compress(v, mask), d, unaligned, count); + return count; +} + +// ================================================== COMPARE (2) + +// ------------------------------ FindLastTrue + +template +HWY_API intptr_t FindLastTrue(D d, MFromD m) { + const RebindToSigned di; + const intptr_t fft_rev_idx = + FindFirstTrue(d, MaskFromVec(Reverse(di, VecFromMask(di, m)))); + return (fft_rev_idx >= 0) + ? 
(static_cast(Lanes(d) - 1) - fft_rev_idx) + : intptr_t{-1}; +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD m) { + const RebindToSigned di; + const size_t fft_rev_idx = + FindKnownFirstTrue(d, MaskFromVec(Reverse(di, VecFromMask(di, m)))); + return Lanes(d) - 1 - fft_rev_idx; +} + +// ------------------------------ ConcatOdd (Compress) + +namespace detail { + +#define HWY_RVV_NARROW(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) NAME(HWY_RVV_V(BASE, SEWD, LMULD) v) { \ + return __riscv_v##OP##_wx_##CHAR##SEW##LMUL(v, kShift, \ + HWY_RVV_AVL(SEWD, SHIFT + 1)); \ + } + +HWY_RVV_FOREACH_U08(HWY_RVV_NARROW, Narrow, nsrl, _EXT) +HWY_RVV_FOREACH_U16(HWY_RVV_NARROW, Narrow, nsrl, _EXT) +HWY_RVV_FOREACH_U32(HWY_RVV_NARROW, Narrow, nsrl, _EXT) +#undef HWY_RVV_NARROW + +} // namespace detail + +// Casting to wider and narrowing is the fastest for < 64-bit lanes. +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + constexpr size_t kBits = sizeof(TFromD) * 8; + const Twice dt; + const RepartitionToWide> dtuw; + const VFromD hl = BitCast(dtuw, Combine(dt, hi, lo)); + return BitCast(d, detail::Narrow(hl)); +} + +// 64-bit: Combine+Compress. +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const Twice dt; + const VFromD hl = Combine(dt, hi, lo); + return LowerHalf(d, Compress(hl, detail::IsOdd(dt))); +} + +// Any type, max LMUL: Compress both, then Combine. +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const Half dh; + const MFromD is_odd = detail::IsOdd(d); + const VFromD hi_odd = Compress(hi, is_odd); + const VFromD lo_odd = Compress(lo, is_odd); + return Combine(d, LowerHalf(dh, hi_odd), LowerHalf(dh, lo_odd)); +} + +// ------------------------------ ConcatEven (Compress) + +// Casting to wider and narrowing is the fastest for < 64-bit lanes. 
+template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const Twice dt; + const RepartitionToWide> dtuw; + const VFromD hl = BitCast(dtuw, Combine(dt, hi, lo)); + return BitCast(d, detail::Narrow<0>(hl)); +} + +// 64-bit: Combine+Compress. +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const Twice dt; + const VFromD hl = Combine(dt, hi, lo); + return LowerHalf(d, Compress(hl, detail::IsEven(dt))); +} + +// Any type, max LMUL: Compress both, then Combine. +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const Half dh; + const MFromD is_even = detail::IsEven(d); + const VFromD hi_even = Compress(hi, is_even); + const VFromD lo_even = Compress(lo, is_even); + return Combine(d, LowerHalf(dh, hi_even), LowerHalf(dh, lo_even)); +} + +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "hwy/ops/inside-inl.h" + +// ================================================== BLOCKWISE + +// ------------------------------ CombineShiftRightBytes +template > +HWY_API V CombineShiftRightBytes(const D d, const V hi, V lo) { + const Repartition d8; + const auto hi8 = BitCast(d8, hi); + const auto lo8 = BitCast(d8, lo); + const auto hi_up = detail::SlideUp(hi8, hi8, 16 - kBytes); + const auto lo_down = detail::SlideDown(lo8, kBytes); + const auto is_lo = detail::FirstNPerBlock<16 - kBytes>(d8); + return BitCast(d, IfThenElse(is_lo, lo_down, hi_up)); +} + +// ------------------------------ CombineShiftRightLanes +template > +HWY_API V CombineShiftRightLanes(const D d, const V hi, V lo) { + constexpr size_t kLanesUp = 16 / sizeof(TFromV) - kLanes; + const auto hi_up = detail::SlideUp(hi, hi, kLanesUp); + const auto lo_down = detail::SlideDown(lo, kLanes); + const auto is_lo = detail::FirstNPerBlock(d); + return IfThenElse(is_lo, lo_down, hi_up); +} + +// ------------------------------ Shuffle2301 (ShiftLeft) +template +HWY_API V Shuffle2301(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 
32-bit types"); + const Repartition du64; + const auto v64 = BitCast(du64, v); + return BitCast(d, Or(ShiftRight<32>(v64), ShiftLeft<32>(v64))); +} + +// ------------------------------ Shuffle2103 +template +HWY_API V Shuffle2103(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<3>(d, v, v); +} + +// ------------------------------ Shuffle0321 +template +HWY_API V Shuffle0321(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<1>(d, v, v); +} + +// ------------------------------ Shuffle1032 +template +HWY_API V Shuffle1032(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 4, "Defined for 32-bit types"); + return CombineShiftRightLanes<2>(d, v, v); +} + +// ------------------------------ Shuffle01 +template +HWY_API V Shuffle01(const V v) { + const DFromV d; + static_assert(sizeof(TFromD) == 8, "Defined for 64-bit types"); + return CombineShiftRightLanes<1>(d, v, v); +} + +// ------------------------------ Shuffle0123 +template +HWY_API V Shuffle0123(const V v) { + return Shuffle2301(Shuffle1032(v)); +} + +// ------------------------------ TableLookupBytes + +template +HWY_API VI TableLookupBytes(const VT vt, const VI vi) { + const DFromV dt; // T=table, I=index. + const DFromV di; + const Repartition dt8; + const Repartition di8; + // Required for producing half-vectors with table lookups from a full vector. + // If we instead run at the LMUL of the index vector, lookups into the table + // would be truncated. Thus we run at the larger of the two LMULs and truncate + // the result vector to the original index LMUL. 
+ constexpr int kPow2T = dt8.Pow2(); + constexpr int kPow2I = di8.Pow2(); + const Simd dm8; // m=max + const auto vmt = detail::ChangeLMUL(dm8, BitCast(dt8, vt)); + const auto vmi = detail::ChangeLMUL(dm8, BitCast(di8, vi)); + auto offsets = detail::OffsetsOf128BitBlocks(dm8, detail::Iota0(dm8)); + // If the table is shorter, wrap around offsets so they do not reference + // undefined lanes in the newly extended vmt. + if (kPow2T < kPow2I) { + offsets = detail::AndS(offsets, static_cast(Lanes(dt8) - 1)); + } + const auto out = TableLookupLanes(vmt, Add(vmi, offsets)); + return BitCast(di, detail::ChangeLMUL(di8, out)); +} + +template +HWY_API VI TableLookupBytesOr0(const VT vt, const VI idx) { + const DFromV di; + const Repartition di8; + const auto idx8 = BitCast(di8, idx); + const auto lookup = TableLookupBytes(vt, idx8); + return BitCast(di, IfThenZeroElse(detail::LtS(idx8, 0), lookup)); +} + +// ------------------------------ TwoTablesLookupLanes + +// WARNING: 8-bit lanes may lead to unexpected results because idx is the same +// size and may overflow. 
+template +HWY_API VFromD TwoTablesLookupLanes(D d, VFromD a, VFromD b, + VFromD> idx) { + const Twice dt; + const RebindToUnsigned dt_u; + const auto combined_tbl = Combine(dt, b, a); + const auto combined_idx = Combine(dt_u, idx, idx); + return LowerHalf(d, TableLookupLanes(combined_tbl, combined_idx)); +} + +template +HWY_API VFromD TwoTablesLookupLanes(D d, VFromD a, VFromD b, + VFromD> idx) { + const RebindToUnsigned du; + using TU = TFromD; + + const size_t num_of_lanes = Lanes(d); + const auto idx_mod = detail::AndS(idx, static_cast(num_of_lanes - 1)); + const auto sel_a_mask = Ne(idx, idx_mod); // FALSE if a + + const auto a_lookup_result = TableLookupLanes(a, idx_mod); + return detail::MaskedTableLookupLanes(sel_a_mask, a_lookup_result, b, + idx_mod); +} + +template +HWY_API V TwoTablesLookupLanes(V a, V b, + VFromD>> idx) { + const DFromV d; + return TwoTablesLookupLanes(d, a, b, idx); +} + +// ------------------------------ Broadcast + +// 8-bit requires 16-bit tables. +template , HWY_IF_T_SIZE_D(D, 1), + HWY_IF_POW2_LE_D(D, 2)> +HWY_API V Broadcast(const V v) { + const D d; + HWY_DASSERT(0 <= kLane && kLane < detail::LanesPerBlock(d)); + + const Rebind du16; + VFromD idx = + detail::OffsetsOf128BitBlocks(d, detail::Iota0(du16)); + if (kLane != 0) { + idx = detail::AddS(idx, kLane); + } + return detail::TableLookupLanes16(v, idx); +} + +// 8-bit and max LMUL: split into halves. 
+template , HWY_IF_T_SIZE_D(D, 1), + HWY_IF_POW2_GT_D(D, 2)> +HWY_API V Broadcast(const V v) { + const D d; + HWY_DASSERT(0 <= kLane && kLane < detail::LanesPerBlock(d)); + + const Half dh; + using VH = VFromD; + const Rebind du16; + VFromD idx = + detail::OffsetsOf128BitBlocks(d, detail::Iota0(du16)); + if (kLane != 0) { + idx = detail::AddS(idx, kLane); + } + const VH lo = detail::TableLookupLanes16(LowerHalf(dh, v), idx); + const VH hi = detail::TableLookupLanes16(UpperHalf(dh, v), idx); + return Combine(d, hi, lo); +} + +template , + HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 2) | (1 << 4) | (1 << 8))> +HWY_API V Broadcast(const V v) { + const D d; + HWY_DASSERT(0 <= kLane && kLane < detail::LanesPerBlock(d)); + + const RebindToUnsigned du; + auto idx = detail::OffsetsOf128BitBlocks(d, detail::Iota0(du)); + if (kLane != 0) { + idx = detail::AddS(idx, kLane); + } + return TableLookupLanes(v, idx); +} + +// ------------------------------ BroadcastLane +#ifdef HWY_NATIVE_BROADCASTLANE +#undef HWY_NATIVE_BROADCASTLANE +#else +#define HWY_NATIVE_BROADCASTLANE +#endif + +namespace detail { + +#define HWY_RVV_BROADCAST_LANE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \ + LMULH, SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_V(BASE, SEW, LMUL) v, size_t idx) { \ + return __riscv_v##OP##_vx_##CHAR##SEW##LMUL(v, idx, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } + +HWY_RVV_FOREACH(HWY_RVV_BROADCAST_LANE, BroadcastLane, rgather, _ALL) +#undef HWY_RVV_BROADCAST_LANE + +} // namespace detail + +template +HWY_API V BroadcastLane(V v) { + static_assert(0 <= kLane && kLane < HWY_MAX_LANES_V(V), "Invalid lane"); + return detail::BroadcastLane(v, static_cast(kLane)); +} + +// ------------------------------ InsertBlock +#ifdef HWY_NATIVE_BLK_INSERT_EXTRACT +#undef HWY_NATIVE_BLK_INSERT_EXTRACT +#else +#define HWY_NATIVE_BLK_INSERT_EXTRACT +#endif + +template +HWY_API V InsertBlock(V v, VFromD>> blk_to_insert) { + const DFromV d; + using TU = If<(sizeof(TFromV) == 1 && 
DFromV().Pow2() >= -2), uint16_t, + MakeUnsigned>>; + using TIdx = If; + + const Repartition du; + const Rebind d_idx; + static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + constexpr size_t kMaxLanesPerBlock = 16 / sizeof(TU); + + constexpr size_t kBlkByteOffset = + static_cast(kBlockIdx) * kMaxLanesPerBlock; + const auto vu = BitCast(du, v); + const auto vblk = ResizeBitCast(du, blk_to_insert); + const auto vblk_shifted = detail::SlideUp(vblk, vblk, kBlkByteOffset); + const auto insert_mask = RebindMask( + du, detail::LtS(detail::SubS(detail::Iota0(d_idx), + static_cast(kBlkByteOffset)), + static_cast(kMaxLanesPerBlock))); + + return BitCast(d, IfThenElse(insert_mask, vblk_shifted, vu)); +} + +// ------------------------------ BroadcastBlock +template , -3)> +HWY_API V BroadcastBlock(V v) { + const DFromV d; + const Repartition du8; + const Rebind du16; + + static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + + const auto idx = detail::AddS(detail::AndS(detail::Iota0(du16), uint16_t{15}), + static_cast(kBlockIdx * 16)); + return BitCast(d, detail::TableLookupLanes16(BitCast(du8, v), idx)); +} + +namespace detail { + +// Called for at least 16-bit lanes to ensure indices do not overflow. +template +HWY_API V BroadcastBlock(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + + static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + constexpr size_t kMaxLanesPerBlock = 16 / sizeof(TU); + + const auto idx = detail::AddS( + detail::AndS(detail::Iota0(du), static_cast(kMaxLanesPerBlock - 1)), + static_cast(static_cast(kBlockIdx) * kMaxLanesPerBlock)); + return BitCast(d, TableLookupLanes(BitCast(du, v), idx)); +} + +} // namespace detail + +template , -3)> +HWY_API V BroadcastBlock(V v) { + // Because we are broadcasting 128-bit blocks, the type does not matter. + // We can cast to uint16_t to ensure indices do not overflow. 
+ const DFromV d; + const Repartition du16; + return BitCast(d, detail::BroadcastBlock(BitCast(du16, v))); +} + +// ------------------------------ ExtractBlock +template +HWY_API VFromD>> ExtractBlock(V v) { + const DFromV d; + const BlockDFromD d_block; + + static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(), + "Invalid block index"); + constexpr size_t kMaxLanesPerBlock = 16 / sizeof(TFromD); + constexpr size_t kBlkByteOffset = + static_cast(kBlockIdx) * kMaxLanesPerBlock; + + return ResizeBitCast(d_block, detail::SlideDown(v, kBlkByteOffset)); +} + +// ------------------------------ ShiftLeftLanes + +template > +HWY_API V ShiftLeftLanes(const D d, const V v) { + const RebindToSigned di; + const RebindToUnsigned du; + using TI = TFromD; + const auto shifted = detail::SlideUp(v, v, kLanes); + // Match x86 semantics by zeroing lower lanes in 128-bit blocks + const auto idx_mod = + detail::AndS(BitCast(di, detail::Iota0(du)), + static_cast(detail::LanesPerBlock(di) - 1)); + const auto clear = detail::LtS(idx_mod, static_cast(kLanes)); + return IfThenZeroElse(clear, shifted); +} + +template +HWY_API V ShiftLeftLanes(const V v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API VFromD ShiftLeftBytes(D d, const VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftLanes(BitCast(d8, v))); +} + +template +HWY_API V ShiftLeftBytes(const V v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftRightLanes +template >> +HWY_API V ShiftRightLanes(const Simd d, V v) { + const RebindToSigned di; + const RebindToUnsigned du; + using TI = TFromD; + // For partial vectors, clear upper lanes so we shift in zeros. 
+ if (N <= 16 / sizeof(T)) { + v = detail::SlideUp(v, Zero(d), N); + } + + const auto shifted = detail::SlideDown(v, kLanes); + // Match x86 semantics by zeroing upper lanes in 128-bit blocks + HWY_LANES_CONSTEXPR size_t lpb = detail::LanesPerBlock(di); + const auto idx_mod = + detail::AndS(BitCast(di, detail::Iota0(du)), static_cast(lpb - 1)); + const auto keep = detail::LtS(idx_mod, static_cast(lpb - kLanes)); + return IfThenElseZero(keep, shifted); +} + +// ------------------------------ ShiftRightBytes +template > +HWY_API V ShiftRightBytes(const D d, const V v) { + const Repartition d8; + return BitCast(d, ShiftRightLanes(d8, BitCast(d8, v))); +} + +// ------------------------------ InterleaveWholeLower +#ifdef HWY_TOGGLE_INTERLEAVE_WHOLE +#undef HWY_TOGGLE_INTERLEAVE_WHOLE +#else +#define HWY_TOGGLE_INTERLEAVE_WHOLE +#endif + +namespace detail { +// Returns double-length vector with interleaved lanes. +template +HWY_API VFromD InterleaveWhole(D d, VFromD> a, VFromD> b) { + const RebindToUnsigned du; + using TW = MakeWide>; + const Rebind> dw; + const Half duh; // cast inputs to unsigned so we zero-extend + + const VFromD aw = PromoteTo(dw, BitCast(duh, a)); + const VFromD bw = PromoteTo(dw, BitCast(duh, b)); + return BitCast(d, Or(aw, BitCast(dw, detail::Slide1Up(BitCast(du, bw))))); +} +// 64-bit: cannot PromoteTo, but can Ext. 
+template +HWY_API VFromD InterleaveWhole(D d, VFromD> a, VFromD> b) { + const RebindToUnsigned du; + const auto idx = ShiftRight<1>(detail::Iota0(du)); + return OddEven(TableLookupLanes(detail::Ext(d, b), idx), + TableLookupLanes(detail::Ext(d, a), idx)); +} +// Interleaves the lower and upper halves of the inputs separately, then +// combines the two results. +template +HWY_API VFromD InterleaveWhole(D d, VFromD> a, VFromD> b) { + const Half dh; + const Half dq; + const VFromD i0 = + InterleaveWhole(dh, LowerHalf(dq, a), LowerHalf(dq, b)); + const VFromD i1 = + InterleaveWhole(dh, UpperHalf(dq, a), UpperHalf(dq, b)); + return Combine(d, i1, i0); +} + +} // namespace detail + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + const detail::AdjustSimdTagToMinVecPow2> dw; + const RepartitionToNarrow du_src; + + const VFromD aw = + ResizeBitCast(d, PromoteLowerTo(dw, ResizeBitCast(du_src, a))); + const VFromD bw = + ResizeBitCast(d, PromoteLowerTo(dw, ResizeBitCast(du_src, b))); + return Or(aw, detail::Slide1Up(bw)); +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + const auto idx = ShiftRight<1>(detail::Iota0(du)); + return OddEven(TableLookupLanes(b, idx), TableLookupLanes(a, idx)); +} + +// ------------------------------ InterleaveWholeUpper + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + // Use Lanes(d) / 2 instead of Lanes(Half()) as Lanes(Half()) can only + // be called if (d.Pow2() >= -2 && d.Pow2() == DFromV>().Pow2()) is + // true and as the results of InterleaveWholeUpper are + // implementation-defined if Lanes(d) is less than 2. 
+ const size_t half_N = Lanes(d) / 2; + return InterleaveWholeLower(d, detail::SlideDown(a, half_N), + detail::SlideDown(b, half_N)); +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + // Use Lanes(d) / 2 instead of Lanes(Half()) as Lanes(Half()) can only + // be called if (d.Pow2() >= -2 && d.Pow2() == DFromV>().Pow2()) is + // true and as the results of InterleaveWholeUpper are implementation-defined + // if Lanes(d) is less than 2. + const size_t half_N = Lanes(d) / 2; + const RebindToUnsigned du; + const auto idx = detail::AddS(ShiftRight<1>(detail::Iota0(du)), + static_cast(half_N)); + return OddEven(TableLookupLanes(b, idx), TableLookupLanes(a, idx)); +} + +// ------------------------------ InterleaveLower (InterleaveWholeLower) + +namespace detail { + +// Definitely at least 128 bit: match x86 semantics (independent blocks). Using +// InterleaveWhole and 64-bit Compress avoids 8-bit overflow. +template +HWY_INLINE V InterleaveLowerImpl(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + const Twice dt; + const RebindToUnsigned dt_u; + const VFromD interleaved = detail::InterleaveWhole(dt, a, b); + // Keep only even 128-bit blocks. This is faster than u64 ConcatEven + // because we only have a single vector. + constexpr size_t kShift = CeilLog2(16 / sizeof(TFromD)); + const VFromD idx_block = + ShiftRight(detail::Iota0(dt_u)); + const MFromD is_even = + detail::EqS(detail::AndS(idx_block, 1), 0); + return BitCast(d, LowerHalf(Compress(BitCast(dt_u, interleaved), is_even))); +} +template +HWY_INLINE V InterleaveLowerImpl(D d, const V a, const V b) { + const Half dh; + const VFromD i0 = + InterleaveLowerImpl(dh, LowerHalf(dh, a), LowerHalf(dh, b)); + const VFromD i1 = + InterleaveLowerImpl(dh, UpperHalf(dh, a), UpperHalf(dh, b)); + return Combine(d, i1, i0); +} + +// As above, for the upper half of blocks. 
+template +HWY_INLINE V InterleaveUpperImpl(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + const Twice dt; + const RebindToUnsigned dt_u; + const VFromD interleaved = detail::InterleaveWhole(dt, a, b); + // Keep only odd 128-bit blocks. This is faster than u64 ConcatEven + // because we only have a single vector. + constexpr size_t kShift = CeilLog2(16 / sizeof(TFromD)); + const VFromD idx_block = + ShiftRight(detail::Iota0(dt_u)); + const MFromD is_odd = + detail::EqS(detail::AndS(idx_block, 1), 1); + return BitCast(d, LowerHalf(Compress(BitCast(dt_u, interleaved), is_odd))); +} +template +HWY_INLINE V InterleaveUpperImpl(D d, const V a, const V b) { + const Half dh; + const VFromD i0 = + InterleaveUpperImpl(dh, LowerHalf(dh, a), LowerHalf(dh, b)); + const VFromD i1 = + InterleaveUpperImpl(dh, UpperHalf(dh, a), UpperHalf(dh, b)); + return Combine(d, i1, i0); +} + +// RVV vectors are at least 128 bit when there is no fractional LMUL nor cap. +// Used by functions with per-block behavior such as InterleaveLower. +template +constexpr bool IsGE128(Simd /* d */) { + return N * sizeof(T) >= 16 && kPow2 >= 0; +} + +// Definitely less than 128-bit only if there is a small cap; fractional LMUL +// might not be enough if vectors are large. +template +constexpr bool IsLT128(Simd /* d */) { + return N * sizeof(T) < 16; +} + +} // namespace detail + +#define HWY_RVV_IF_GE128_D(D) hwy::EnableIf* = nullptr +#define HWY_RVV_IF_LT128_D(D) hwy::EnableIf* = nullptr +#define HWY_RVV_IF_CAN128_D(D) \ + hwy::EnableIf* = nullptr + +template +HWY_API V InterleaveLower(D d, const V a, const V b) { + return detail::InterleaveLowerImpl(d, a, b); +} + +// Single block: interleave without extra Compress. +template +HWY_API V InterleaveLower(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return InterleaveWholeLower(d, a, b); +} + +// Could be either; branch at runtime. 
+template +HWY_API V InterleaveLower(D d, const V a, const V b) { + if (Lanes(d) * sizeof(TFromD) <= 16) { + return InterleaveWholeLower(d, a, b); + } + // Fractional LMUL: use LMUL=1 to ensure we can cast to u64. + const ScalableTag, HWY_MAX(d.Pow2(), 0)> d1; + return ResizeBitCast(d, detail::InterleaveLowerImpl(d1, ResizeBitCast(d1, a), + ResizeBitCast(d1, b))); +} + +template +HWY_API V InterleaveLower(const V a, const V b) { + return InterleaveLower(DFromV(), a, b); +} + +// ------------------------------ InterleaveUpper (Compress) + +template +HWY_API V InterleaveUpper(D d, const V a, const V b) { + return detail::InterleaveUpperImpl(d, a, b); +} + +// Single block: interleave without extra Compress. +template +HWY_API V InterleaveUpper(D d, const V a, const V b) { + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return InterleaveWholeUpper(d, a, b); +} + +// Could be either; branch at runtime. +template +HWY_API V InterleaveUpper(D d, const V a, const V b) { + if (Lanes(d) * sizeof(TFromD) <= 16) { + return InterleaveWholeUpper(d, a, b); + } + // Fractional LMUL: use LMUL=1 to ensure we can cast to u64. 
+ const ScalableTag, HWY_MAX(d.Pow2(), 0)> d1; + return ResizeBitCast(d, detail::InterleaveUpperImpl(d1, ResizeBitCast(d1, a), + ResizeBitCast(d1, b))); +} + +// ------------------------------ ZipLower + +template >> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveLower(dn, a, b)); +} + +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} + +// ------------------------------ ZipUpper +template +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + const RepartitionToNarrow dn; + static_assert(IsSame, TFromV>(), "D/V mismatch"); + return BitCast(dw, InterleaveUpper(dn, a, b)); +} + +// ================================================== REDUCE + +// We have ReduceSum, generic_ops-inl.h defines SumOfLanes via Set. +#ifdef HWY_NATIVE_REDUCE_SCALAR +#undef HWY_NATIVE_REDUCE_SCALAR +#else +#define HWY_NATIVE_REDUCE_SCALAR +#endif + +// scalar = f(vector, zero_m1) +#define HWY_RVV_REDUCE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_T(BASE, SEW) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_V(BASE, SEW, m1) v0) { \ + return GetLane(__riscv_v##OP##_vs_##CHAR##SEW##LMUL##_##CHAR##SEW##m1( \ + v, v0, Lanes(d))); \ + } + +// detail::RedSum, detail::RedMin, and detail::RedMax are more efficient +// for N=4 I8/U8 reductions on RVV than the default implementations of the +// N=4 I8/U8 ReduceSum/ReduceMin/ReduceMax operations in generic_ops-inl.h +#undef HWY_IF_REDUCE_D +#define HWY_IF_REDUCE_D(D) hwy::EnableIf* = nullptr + +#ifdef HWY_NATIVE_REDUCE_SUM_4_UI8 +#undef HWY_NATIVE_REDUCE_SUM_4_UI8 +#else +#define HWY_NATIVE_REDUCE_SUM_4_UI8 +#endif + +#ifdef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#undef HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#else +#define HWY_NATIVE_REDUCE_MINMAX_4_UI8 +#endif + +// ------------------------------ ReduceSum +
+namespace detail { +HWY_RVV_FOREACH_UI(HWY_RVV_REDUCE, RedSum, redsum, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedSum, fredusum, _ALL_VIRT) +} // namespace detail + +template +HWY_API TFromD ReduceSum(D d, const VFromD v) { + const auto v0 = Zero(ScalableTag>()); // always m1 + return detail::RedSum(d, v, v0); +} + +// ------------------------------ ReduceMin +namespace detail { +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMin, redminu, _ALL_VIRT) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMin, redmin, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMin, fredmin, _ALL_VIRT) +} // namespace detail + +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceMin(D d, const VFromD v) { + const ScalableTag d1; // always m1 + return detail::RedMin(d, v, Set(d1, HighestValue())); +} + +// ------------------------------ ReduceMax +namespace detail { +HWY_RVV_FOREACH_U(HWY_RVV_REDUCE, RedMax, redmaxu, _ALL_VIRT) +HWY_RVV_FOREACH_I(HWY_RVV_REDUCE, RedMax, redmax, _ALL_VIRT) +HWY_RVV_FOREACH_F(HWY_RVV_REDUCE, RedMax, fredmax, _ALL_VIRT) +} // namespace detail + +template , HWY_IF_REDUCE_D(D)> +HWY_API T ReduceMax(D d, const VFromD v) { + const ScalableTag d1; // always m1 + return detail::RedMax(d, v, Set(d1, LowestValue())); +} + +#undef HWY_RVV_REDUCE + +// TODO: add MaskedReduceSum/Min/Max + +// ------------------------------ SumOfLanes + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, ReduceSum(d, v)); +} +template +HWY_API VFromD MinOfLanes(D d, VFromD v) { + return Set(d, ReduceMin(d, v)); +} +template +HWY_API VFromD MaxOfLanes(D d, VFromD v) { + return Set(d, ReduceMax(d, v)); +} + +// ================================================== Ops with dependencies + +// ------------------------------ LoadInterleaved2 + +// Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. 
+#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +// Requires Clang 16+, GCC 14+; otherwise emulated in generic_ops-inl.h. +#if HWY_HAVE_TUPLE + +#define HWY_RVV_GET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##2(HWY_RVV_TUP(BASE, SEW, LMUL, 2) tup) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##x2_##CHAR##SEW##LMUL(tup, \ + kIndex); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##3(HWY_RVV_TUP(BASE, SEW, LMUL, 3) tup) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##x3_##CHAR##SEW##LMUL(tup, \ + kIndex); \ + } \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME##4(HWY_RVV_TUP(BASE, SEW, LMUL, 4) tup) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##x4_##CHAR##SEW##LMUL(tup, \ + kIndex); \ + } + +HWY_RVV_FOREACH(HWY_RVV_GET, Get, get, _LE2) +#undef HWY_RVV_GET + +#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_TUP(BASE, SEW, LMUL, 2) NAME##2( \ + HWY_RVV_TUP(BASE, SEW, LMUL, 2) tup, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMUL##x2( \ + tup, kIndex, v); \ + } \ + template \ + HWY_API HWY_RVV_TUP(BASE, SEW, LMUL, 3) NAME##3( \ + HWY_RVV_TUP(BASE, SEW, LMUL, 3) tup, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMUL##x3( \ + tup, kIndex, v); \ + } \ + template \ + HWY_API HWY_RVV_TUP(BASE, SEW, LMUL, 4) NAME##4( \ + HWY_RVV_TUP(BASE, SEW, LMUL, 4) tup, HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMUL##x4( \ + tup, kIndex, v); \ + } + +HWY_RVV_FOREACH(HWY_RVV_SET, Set, set, _LE2) +#undef HWY_RVV_SET + +// RVV does not provide vcreate, so implement using Set. 
+#define HWY_RVV_CREATE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_TUP(BASE, SEW, LMUL, 2) \ + NAME##2(HWY_RVV_D(BASE, SEW, N, SHIFT) /*d*/, \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1) { \ + HWY_RVV_TUP(BASE, SEW, LMUL, 2) tup{}; \ + tup = Set2<0>(tup, v0); \ + tup = Set2<1>(tup, v1); \ + return tup; \ + } \ + template \ + HWY_API HWY_RVV_TUP(BASE, SEW, LMUL, 3) NAME##3( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) /*d*/, HWY_RVV_V(BASE, SEW, LMUL) v0, \ + HWY_RVV_V(BASE, SEW, LMUL) v1, HWY_RVV_V(BASE, SEW, LMUL) v2) { \ + HWY_RVV_TUP(BASE, SEW, LMUL, 3) tup{}; \ + tup = Set3<0>(tup, v0); \ + tup = Set3<1>(tup, v1); \ + tup = Set3<2>(tup, v2); \ + return tup; \ + } \ + template \ + HWY_API HWY_RVV_TUP(BASE, SEW, LMUL, 4) \ + NAME##4(HWY_RVV_D(BASE, SEW, N, SHIFT) /*d*/, \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_V(BASE, SEW, LMUL) v3) { \ + HWY_RVV_TUP(BASE, SEW, LMUL, 4) tup{}; \ + tup = Set4<0>(tup, v0); \ + tup = Set4<1>(tup, v1); \ + tup = Set4<2>(tup, v2); \ + tup = Set4<3>(tup, v3); \ + return tup; \ + } + +HWY_RVV_FOREACH(HWY_RVV_CREATE, Create, xx, _LE2_VIRT) +#undef HWY_RVV_CREATE + +template +using Vec2 = decltype(Create2(D(), Zero(D()), Zero(D()))); +template +using Vec3 = decltype(Create3(D(), Zero(D()), Zero(D()), Zero(D()))); +template +using Vec4 = decltype(Create4(D(), Zero(D()), Zero(D()), Zero(D()), Zero(D()))); + +#define HWY_RVV_LOAD2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, \ + HWY_RVV_V(BASE, SEW, LMUL) & v1) { \ + const HWY_RVV_TUP(BASE, SEW, LMUL, 2) tup = \ + __riscv_v##OP##e##SEW##_v_##CHAR##SEW##LMUL##x2(unaligned, Lanes(d)); \ + v0 = Get2<0>(tup); \ + v1 = Get2<1>(tup); \ + } +// Segments are limited 
to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_LOAD2, LoadInterleaved2, lseg2, _LE2_VIRT) +#undef HWY_RVV_LOAD2 + +// ------------------------------ LoadInterleaved3 + +#define HWY_RVV_LOAD3(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, \ + HWY_RVV_V(BASE, SEW, LMUL) & v1, \ + HWY_RVV_V(BASE, SEW, LMUL) & v2) { \ + const HWY_RVV_TUP(BASE, SEW, LMUL, 3) tup = \ + __riscv_v##OP##e##SEW##_v_##CHAR##SEW##LMUL##x3(unaligned, Lanes(d)); \ + v0 = Get3<0>(tup); \ + v1 = Get3<1>(tup); \ + v2 = Get3<2>(tup); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_LOAD3, LoadInterleaved3, lseg3, _LE2_VIRT) +#undef HWY_RVV_LOAD3 + +// ------------------------------ LoadInterleaved4 + +#define HWY_RVV_LOAD4(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned, \ + HWY_RVV_V(BASE, SEW, LMUL) & v0, HWY_RVV_V(BASE, SEW, LMUL) & v1, \ + HWY_RVV_V(BASE, SEW, LMUL) & v2, HWY_RVV_V(BASE, SEW, LMUL) & v3) { \ + const HWY_RVV_TUP(BASE, SEW, LMUL, 4) tup = \ + __riscv_v##OP##e##SEW##_v_##CHAR##SEW##LMUL##x4(unaligned, Lanes(d)); \ + v0 = Get4<0>(tup); \ + v1 = Get4<1>(tup); \ + v2 = Get4<2>(tup); \ + v3 = Get4<3>(tup); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. 
+HWY_RVV_FOREACH(HWY_RVV_LOAD4, LoadInterleaved4, lseg4, _LE2_VIRT) +#undef HWY_RVV_LOAD4 + +// ------------------------------ StoreInterleaved2 + +#define HWY_RVV_STORE2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v0, \ + HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ + const HWY_RVV_TUP(BASE, SEW, LMUL, 2) tup = Create2(d, v0, v1); \ + __riscv_v##OP##e##SEW##_v_##CHAR##SEW##LMUL##x2(unaligned, tup, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_STORE2, StoreInterleaved2, sseg2, _LE2_VIRT) +#undef HWY_RVV_STORE2 + +// ------------------------------ StoreInterleaved3 + +#define HWY_RVV_STORE3(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ + const HWY_RVV_TUP(BASE, SEW, LMUL, 3) tup = Create3(d, v0, v1, v2); \ + __riscv_v##OP##e##SEW##_v_##CHAR##SEW##LMUL##x3(unaligned, tup, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. 
+HWY_RVV_FOREACH(HWY_RVV_STORE3, StoreInterleaved3, sseg3, _LE2_VIRT) +#undef HWY_RVV_STORE3 + +// ------------------------------ StoreInterleaved4 + +#define HWY_RVV_STORE4(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \ + MLEN, NAME, OP) \ + template \ + HWY_API void NAME( \ + HWY_RVV_V(BASE, SEW, LMUL) v0, HWY_RVV_V(BASE, SEW, LMUL) v1, \ + HWY_RVV_V(BASE, SEW, LMUL) v2, HWY_RVV_V(BASE, SEW, LMUL) v3, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) { \ + const HWY_RVV_TUP(BASE, SEW, LMUL, 4) tup = Create4(d, v0, v1, v2, v3); \ + __riscv_v##OP##e##SEW##_v_##CHAR##SEW##LMUL##x4(unaligned, tup, Lanes(d)); \ + } +// Segments are limited to 8 registers, so we can only go up to LMUL=2. +HWY_RVV_FOREACH(HWY_RVV_STORE4, StoreInterleaved4, sseg4, _LE2_VIRT) +#undef HWY_RVV_STORE4 + +#else // !HWY_HAVE_TUPLE + +template , HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1) { + const VFromD A = LoadU(d, unaligned); // v1[1] v0[1] v1[0] v0[0] + const VFromD B = LoadU(d, unaligned + Lanes(d)); + v0 = ConcatEven(d, B, A); + v1 = ConcatOdd(d, B, A); +} + +namespace detail { +#define HWY_RVV_LOAD_STRIDED(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEW, LMUL) \ + NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + const HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p, size_t stride) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL( \ + p, static_cast(stride), Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_LOAD_STRIDED, LoadStrided, lse, _ALL_VIRT) +#undef HWY_RVV_LOAD_STRIDED +} // namespace detail + +template , HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void LoadInterleaved3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2) { + // Offsets are bytes, and this is not documented. 
+ v0 = detail::LoadStrided(d, unaligned + 0, 3 * sizeof(T)); + v1 = detail::LoadStrided(d, unaligned + 1, 3 * sizeof(T)); + v2 = detail::LoadStrided(d, unaligned + 2, 3 * sizeof(T)); +} + +template , HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void LoadInterleaved4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& v0, VFromD& v1, VFromD& v2, + VFromD& v3) { + // Offsets are bytes, and this is not documented. + v0 = detail::LoadStrided(d, unaligned + 0, 4 * sizeof(T)); + v1 = detail::LoadStrided(d, unaligned + 1, 4 * sizeof(T)); + v2 = detail::LoadStrided(d, unaligned + 2, 4 * sizeof(T)); + v3 = detail::LoadStrided(d, unaligned + 3, 4 * sizeof(T)); +} + +// Not 64-bit / max LMUL: interleave via promote, slide, OddEven. +template , HWY_IF_NOT_T_SIZE_D(D, 8), + HWY_IF_POW2_LE_D(D, 2), HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + T* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + const Twice> duw; + const Twice dt; + // Interleave with zero by promoting to wider (unsigned) type. + const VFromD w0 = BitCast(dt, PromoteTo(duw, BitCast(du, v0))); + const VFromD w1 = BitCast(dt, PromoteTo(duw, BitCast(du, v1))); + // OR second vector into the zero-valued lanes (faster than OddEven). 
+ StoreU(Or(w0, detail::Slide1Up(w1)), dt, unaligned); +} + +// Can promote, max LMUL: two half-length +template , HWY_IF_NOT_T_SIZE_D(D, 8), + HWY_IF_POW2_GT_D(D, 2), HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + T* HWY_RESTRICT unaligned) { + const Half dh; + // Recurse with the half-length tag dh so that the tag matches the half + // vectors. Each half call stores Lanes(d) elements, hence the second offset. + StoreInterleaved2(LowerHalf(dh, v0), LowerHalf(dh, v1), dh, unaligned); + StoreInterleaved2(UpperHalf(dh, v0), UpperHalf(dh, v1), dh, + unaligned + Lanes(d)); +} + +namespace detail { +#define HWY_RVV_STORE_STRIDED(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v, \ + HWY_RVV_D(BASE, SEW, N, SHIFT) d, \ + HWY_RVV_T(BASE, SEW) * HWY_RESTRICT p, size_t stride) { \ + return __riscv_v##OP##SEW##_v_##CHAR##SEW##LMUL( \ + p, static_cast(stride), v, Lanes(d)); \ + } +HWY_RVV_FOREACH(HWY_RVV_STORE_STRIDED, StoreStrided, sse, _ALL_VIRT) +#undef HWY_RVV_STORE_STRIDED +} // namespace detail + +// 64-bit: strided +template , HWY_IF_T_SIZE_D(D, 8), + HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, + T* HWY_RESTRICT unaligned) { + // Offsets are bytes, and this is not documented. + detail::StoreStrided(v0, d, unaligned + 0, 2 * sizeof(T)); + detail::StoreStrided(v1, d, unaligned + 1, 2 * sizeof(T)); +} + +template , HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, + T* HWY_RESTRICT unaligned) { + // Offsets are bytes, and this is not documented. + detail::StoreStrided(v0, d, unaligned + 0, 3 * sizeof(T)); + detail::StoreStrided(v1, d, unaligned + 1, 3 * sizeof(T)); + detail::StoreStrided(v2, d, unaligned + 2, 3 * sizeof(T)); +} + +template , HWY_RVV_IF_NOT_EMULATED_D(D)> +HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, + VFromD v3, D d, T* HWY_RESTRICT unaligned) { + // Offsets are bytes, and this is not documented.
+ detail::StoreStrided(v0, d, unaligned + 0, 4 * sizeof(T)); + detail::StoreStrided(v1, d, unaligned + 1, 4 * sizeof(T)); + detail::StoreStrided(v2, d, unaligned + 2, 4 * sizeof(T)); + detail::StoreStrided(v3, d, unaligned + 3, 4 * sizeof(T)); +} + +#endif // HWY_HAVE_TUPLE + +// Rely on generic Load/StoreInterleaved[234] for any emulated types. +// Requires HWY_GENERIC_IF_EMULATED_D mirrors HWY_RVV_IF_EMULATED_D. + +// ------------------------------ Dup128VecFromValues (ResizeBitCast) + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD /*t1*/) { + return Set(d, t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1) { + const auto even_lanes = Set(d, t0); +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(BitCastScalar(t0) == + BitCastScalar(t1)) && + (BitCastScalar(t0) == BitCastScalar(t1))) { + return even_lanes; + } +#endif + + const auto odd_lanes = Set(d, t1); + return OddEven(odd_lanes, even_lanes); +} + +namespace detail { + +#pragma pack(push, 1) + +template +struct alignas(8) Vec64ValsWrapper { + static_assert(sizeof(T) >= 1, "sizeof(T) >= 1 must be true"); + static_assert(sizeof(T) <= 8, "sizeof(T) <= 8 must be true"); + T vals[8 / sizeof(T)]; +}; + +#pragma pack(pop) + +} // namespace detail + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + const detail::AdjustSimdTagToMinVecPow2> du64; + return ResizeBitCast( + d, Dup128VecFromValues( + du64, + BitCastScalar(detail::Vec64ValsWrapper>{ + {t0, t1, t2, t3, t4, t5, t6, t7}}), + BitCastScalar(detail::Vec64ValsWrapper>{ + {t8, t9, t10, t11, t12, t13, t14, t15}}))); +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const 
detail::AdjustSimdTagToMinVecPow2> du64; + return ResizeBitCast( + d, Dup128VecFromValues( + du64, + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1, t2, t3}}), + BitCastScalar( + detail::Vec64ValsWrapper>{{t4, t5, t6, t7}}))); +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + const detail::AdjustSimdTagToMinVecPow2> du64; + return ResizeBitCast( + d, + Dup128VecFromValues(du64, + BitCastScalar( + detail::Vec64ValsWrapper>{{t0, t1}}), + BitCastScalar( + detail::Vec64ValsWrapper>{{t2, t3}}))); +} + +// ------------------------------ LoadDup128 + +template +HWY_API VFromD LoadDup128(D d, const TFromD* const HWY_RESTRICT p) { + const RebindToUnsigned du; + + // Make sure that no more than 16 bytes are loaded from p + constexpr int kLoadPow2 = d.Pow2(); + constexpr size_t kMaxLanesToLoad = + HWY_MIN(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD)); + constexpr size_t kLoadN = D::template NewN(); + const Simd, kLoadN, kLoadPow2> d_load; + static_assert(d_load.MaxBytes() <= 16, + "d_load.MaxBytes() <= 16 must be true"); + static_assert((d.MaxBytes() < 16) || (d_load.MaxBytes() == 16), + "d_load.MaxBytes() == 16 must be true if d.MaxBytes() >= 16 is " + "true"); + static_assert((d.MaxBytes() >= 16) || (d_load.MaxBytes() == d.MaxBytes()), + "d_load.MaxBytes() == d.MaxBytes() must be true if " + "d.MaxBytes() < 16 is true"); + + const VFromD loaded = Load(d_load, p); + if (d.MaxBytes() <= 16) return loaded; + + // idx must be unsigned for TableLookupLanes. + using TU = TFromD; + HWY_LANES_CONSTEXPR TU mask = static_cast(detail::LanesPerBlock(d) - 1); + // Broadcast the first block. + const VFromD> idx = detail::AndS(detail::Iota0(du), mask); + // Safe even for 8-bit lanes because indices never exceed 15. + return TableLookupLanes(loaded, idx); +} + +// ------------------------------ LoadMaskBits + +// Support all combinations of T and SHIFT(LMUL) without explicit overloads for +// each. First overload for MLEN=1..64. 
+namespace detail { + +// Maps D to MLEN (wrapped in SizeTag), such that #mask_bits = VLEN/MLEN. MLEN +// increases with lane size and decreases for increasing LMUL. Cap at 64, the +// largest supported by HWY_RVV_FOREACH_B (and intrinsics), for virtual LMUL +// e.g. vuint16mf8_t: (8*2 << 3) == 128. +template +using MaskTag = hwy::SizeTag), -D().Pow2()))>; + +#define HWY_RVV_LOAD_MASK_BITS(SEW, SHIFT, MLEN, NAME, OP) \ + HWY_INLINE HWY_RVV_M(MLEN) \ + NAME(hwy::SizeTag /* tag */, const uint8_t* bits, size_t N) { \ + return __riscv_v##OP##_v_b##MLEN(bits, N); \ + } +HWY_RVV_FOREACH_B(HWY_RVV_LOAD_MASK_BITS, LoadMaskBits, lm) +#undef HWY_RVV_LOAD_MASK_BITS +} // namespace detail + +template > +HWY_API auto LoadMaskBits(D d, const uint8_t* bits) + -> decltype(detail::LoadMaskBits(MT(), bits, Lanes(d))) { + return detail::LoadMaskBits(MT(), bits, Lanes(d)); +} + +// ------------------------------ StoreMaskBits +#define HWY_RVV_STORE_MASK_BITS(SEW, SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API size_t NAME(D d, HWY_RVV_M(MLEN) m, uint8_t* bits) { \ + const size_t N = Lanes(d); \ + __riscv_v##OP##_v_b##MLEN(bits, m, N); \ + /* Non-full byte, need to clear the undefined upper bits. */ \ + /* Use MaxLanes and sizeof(T) to move some checks to compile-time. 
*/ \ + constexpr bool kLessThan8 = \ + detail::ScaleByPower(16 / sizeof(TFromD), d.Pow2()) < 8; \ + if (MaxLanes(d) < 8 || (kLessThan8 && N < 8)) { \ + const int mask = (1 << N) - 1; \ + bits[0] = static_cast(bits[0] & mask); \ + } \ + return (N + 7) / 8; \ + } +HWY_RVV_FOREACH_B(HWY_RVV_STORE_MASK_BITS, StoreMaskBits, sm) +#undef HWY_RVV_STORE_MASK_BITS + +// ------------------------------ CompressBits, CompressBitsStore (LoadMaskBits) + +template +HWY_INLINE V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(DFromV(), bits)); +} + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +// ------------------------------ FirstN (Iota0, Lt, RebindMask, SlideUp) + +// NOTE: do not use this as a building block within rvv-inl - it is likely more +// efficient to use avl or detail::SlideUp. + +// Disallow for 8-bit because Iota is likely to overflow. +template +HWY_API MFromD FirstN(const D d, const size_t n) { + const RebindToUnsigned du; + using TU = TFromD; + return RebindMask(d, detail::LtS(detail::Iota0(du), static_cast(n))); +} + +template +HWY_API MFromD FirstN(const D d, const size_t n) { + const auto zero = Zero(d); + const auto one = Set(d, 1); + return Eq(detail::SlideUp(one, zero, n), one); +} + +// ------------------------------ LowerHalfOfMask/UpperHalfOfMask + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +// Target-specific implementations of LowerHalfOfMask, UpperHalfOfMask, +// CombineMasks, OrderedDemote2MasksTo, and Dup128MaskFromMaskBits are possible +// on RVV if the __riscv_vreinterpret_v_b*_u8m1 and +// __riscv_vreinterpret_v_u8m1_b* intrinsics are available. + +// The __riscv_vreinterpret_v_b*_u8m1 and __riscv_vreinterpret_v_u8m1_b* +// intrinsics are available with Clang 17 and later and GCC 14 and later.
+ +namespace detail { + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool1_t m) { + return __riscv_vreinterpret_v_b1_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool2_t m) { + return __riscv_vreinterpret_v_b2_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool4_t m) { + return __riscv_vreinterpret_v_b4_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool8_t m) { + return __riscv_vreinterpret_v_b8_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool16_t m) { + return __riscv_vreinterpret_v_b16_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool32_t m) { + return __riscv_vreinterpret_v_b32_u8m1(m); +} + +HWY_INLINE vuint8m1_t MaskToU8MaskBitsVec(vbool64_t m) { + return __riscv_vreinterpret_v_b64_u8m1(m); +} + +template , vbool1_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b1(v); +} + +template , vbool2_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b2(v); +} + +template , vbool4_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b4(v); +} + +template , vbool8_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b8(v); +} + +template , vbool16_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b16(v); +} + +template , vbool32_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b32(v); +} + +template , vbool64_t>()>* = nullptr> +HWY_INLINE MFromD U8MaskBitsVecToMask(D /*d*/, vuint8m1_t v) { + return __riscv_vreinterpret_v_u8m1_b64(v); +} + +} // namespace detail + +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif + +template 
+HWY_API MFromD LowerHalfOfMask(D d, MFromD> m) { + return detail::U8MaskBitsVecToMask(d, detail::MaskToU8MaskBitsVec(m)); +} + +#ifdef HWY_NATIVE_UPPER_HALF_OF_MASK +#undef HWY_NATIVE_UPPER_HALF_OF_MASK +#else +#define HWY_NATIVE_UPPER_HALF_OF_MASK +#endif + +template +HWY_API MFromD UpperHalfOfMask(D d, MFromD> m) { + const size_t N = Lanes(d); + + vuint8m1_t mask_bits = detail::MaskToU8MaskBitsVec(m); + mask_bits = ShiftRightSame(mask_bits, static_cast(N & 7)); + if (HWY_MAX_LANES_D(D) >= 8) { + mask_bits = SlideDownLanes(ScalableTag(), mask_bits, N / 8); + } + + return detail::U8MaskBitsVecToMask(d, mask_bits); +} + +// ------------------------------ CombineMasks + +#ifdef HWY_NATIVE_COMBINE_MASKS +#undef HWY_NATIVE_COMBINE_MASKS +#else +#define HWY_NATIVE_COMBINE_MASKS +#endif + +template +HWY_API MFromD CombineMasks(D d, MFromD> hi, MFromD> lo) { + const Half dh; + const size_t half_N = Lanes(dh); + + const auto ext_lo_mask = + And(detail::U8MaskBitsVecToMask(d, detail::MaskToU8MaskBitsVec(lo)), + FirstN(d, half_N)); + vuint8m1_t hi_mask_bits = detail::MaskToU8MaskBitsVec(hi); + hi_mask_bits = ShiftLeftSame(hi_mask_bits, static_cast(half_N & 7)); + if (HWY_MAX_LANES_D(D) >= 8) { + hi_mask_bits = + SlideUpLanes(ScalableTag(), hi_mask_bits, half_N / 8); + } + + return Or(ext_lo_mask, detail::U8MaskBitsVecToMask(d, hi_mask_bits)); +} + +// ------------------------------ OrderedDemote2MasksTo + +#ifdef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#undef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#else +#define HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#endif + +template ) / 2), + class DTo_2 = Repartition, DFrom>, + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD OrderedDemote2MasksTo(DTo d_to, DFrom /*d_from*/, + MFromD a, MFromD b) { + return CombineMasks(d_to, b, a); +} + +#endif // HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +// ------------------------------ Dup128MaskFromMaskBits + +namespace detail { +// Even though this is only used after 
checking if (kN < X), this helper +// function prevents "shift count exceeded" errors. +template +constexpr unsigned MaxMaskBits() { + return (1u << kN) - 1; +} +template +constexpr unsigned MaxMaskBits() { + return ~0u; +} + +template +constexpr int SufficientPow2ForMask() { + return HWY_MAX( + D().Pow2() - 3 - static_cast(FloorLog2(sizeof(TFromD))), -3); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED M RvvVmmv(M mask) { + // The below And operation is equivalent to the RVV vmmv instruction and + // ensures that mask is not in the same register as a vector operand when used + // in RVV instructions that take both a vector operand and a mask operand. + return And(mask, mask); +} + +} // namespace detail + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= detail::MaxMaskBits(); + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + return detail::RvvVmmv(detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(ScalableTag(), + Set(du8, static_cast(mask_bits))))); +#else + const RebindToUnsigned du8; + const detail::AdjustSimdTagToMinVecPow2> + du64; + + const auto bytes = ResizeBitCast( + du8, detail::AndS( + ResizeBitCast(du64, Set(du8, static_cast(mask_bits))), + uint64_t{0x8040201008040201u})); + return detail::NeS(bytes, uint8_t{0}); +#endif +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + const ScalableTag()> du16; + // There are exactly 16 mask bits for 128 vector bits of 8-bit lanes. + return detail::RvvVmmv(detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL( + ScalableTag(), + BitCast(du8, Set(du16, static_cast(mask_bits)))))); +#else + // Slow fallback for completeness; the above bits to mask cast is preferred. 
+ const RebindToUnsigned du8; + const Repartition du16; + const detail::AdjustSimdTagToMinVecPow2> + du64; + + // Replicate the lower 16 bits of mask_bits to each u16 lane of a u16 vector, + // and then bitcast the replicated mask_bits to a u8 vector + const auto bytes = BitCast(du8, Set(du16, static_cast(mask_bits))); + // Replicate bytes 8x such that each byte contains the bit that governs it. + const auto rep8 = TableLookupLanes(bytes, ShiftRight<3>(detail::Iota0(du8))); + + const auto masked_out_rep8 = ResizeBitCast( + du8, + detail::AndS(ResizeBitCast(du64, rep8), uint64_t{0x8040201008040201u})); + return detail::NeS(masked_out_rep8, uint8_t{0}); +#endif +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= detail::MaxMaskBits(); + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + // There are exactly 8 mask bits for 128 vector bits of 16-bit lanes. + return detail::RvvVmmv(detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(ScalableTag(), + Set(du8, static_cast(mask_bits))))); +#else + // Slow fallback for completeness; the above bits to mask cast is preferred. + const RebindToUnsigned du; + const VFromD bits = + Shl(Set(du, uint16_t{1}), detail::AndS(detail::Iota0(du), 7)); + return TestBit(Set(du, static_cast(mask_bits)), bits); +#endif +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 4) mask_bits &= detail::MaxMaskBits(); + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + return detail::RvvVmmv(detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(ScalableTag(), + Set(du8, static_cast(mask_bits * 0x11))))); +#else + // Slow fallback for completeness; the above bits to mask cast is preferred. 
+ const RebindToUnsigned du; + const VFromD bits = Dup128VecFromValues(du, 1, 2, 4, 8); + return TestBit(Set(du, static_cast(mask_bits)), bits); +#endif +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 2) mask_bits &= detail::MaxMaskBits(); + +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + return detail::RvvVmmv(detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(ScalableTag(), + Set(du8, static_cast(mask_bits * 0x55))))); +#else + // Slow fallback for completeness; the above bits to mask cast is preferred. + const RebindToUnsigned du; + const VFromD bits = Dup128VecFromValues(du, 1, 2); + return TestBit(Set(du, static_cast(mask_bits)), bits); +#endif +} + +// ------------------------------ SetMask + +#ifdef HWY_NATIVE_SET_MASK +#undef HWY_NATIVE_SET_MASK +#else +#define HWY_NATIVE_SET_MASK +#endif + +template +HWY_API MFromD SetMask(D d, bool val) { + const uint8_t u8_mask_val = static_cast(-static_cast(val)); +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + const ScalableTag()> du8; + return detail::RvvVmmv(detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(ScalableTag(), Set(du8, u8_mask_val)))); +#else + const Rebind>> du8; + return MaskFromVec(Set(du8, u8_mask_val)); +#endif +} + +// ------------------------------ Abs (Max, Neg) + +template +HWY_API V Abs(const V v) { + return Max(v, Neg(v)); +} + +HWY_RVV_FOREACH_F(HWY_RVV_RETV_ARGV2, Abs, fsgnjx, _ALL) + +#undef HWY_RVV_RETV_ARGV2 + +// ------------------------------ AbsDiff (Abs, Sub) +template +HWY_API V AbsDiff(const V a, const V b) { + return Abs(Sub(a, b)); +} + +// ------------------------------ Round (NearestInt, ConvertTo, CopySign) + +// IEEE-754 roundToIntegralTiesToEven returns floating-point, but we do not have +// a dedicated instruction for that. 
Rounding to integer and converting back to +// float is correct except when the input magnitude is large, in which case the +// input was already an integer (because mantissa >> exponent is zero). + +namespace detail { +enum RoundingModes { kNear, kTrunc, kDown, kUp }; + +template +HWY_INLINE auto UseInt(const V v) -> decltype(MaskFromVec(v)) { + return detail::LtS(Abs(v), MantissaEnd>()); +} + +} // namespace detail + +template +HWY_API V Round(const V v) { + const DFromV df; + + const auto integer = NearestInt(v); // round using current mode + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// ------------------------------ Trunc (ConvertTo) +template +HWY_API V Trunc(const V v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// ------------------------------ Ceil +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1400) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1700) +namespace detail { +#define HWY_RVV_CEIL_INT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(int, SEW, LMUL) CeilInt(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_vfcvt_x_f_v_i##SEW##LMUL##_rm(v, __RISCV_FRM_RUP, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } +HWY_RVV_FOREACH_F(HWY_RVV_CEIL_INT, _, _, _ALL) +#undef HWY_RVV_CEIL_INT + +} // namespace detail + +template +HWY_API V Ceil(const V v) { + const DFromV df; + + const auto integer = detail::CeilInt(v); + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +#else // GCC 13 or earlier or Clang 16 or earlier + +template +HWY_API V Ceil(const V v) { + const DFromV df; + const RebindToSigned di; + + using T = TFromD; + + const auto integer = ConvertTo(di, v); // round toward 0 + const 
auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto pos1 = + IfThenElseZero(Lt(int_f, v), Set(df, ConvertScalarTo(1.0))); + + return IfThenElse(detail::UseInt(v), Add(int_f, pos1), v); +} + +#endif // (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1400) || + // (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1700) + +// ------------------------------ Floor +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1400) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1700) +namespace detail { +#define HWY_RVV_FLOOR_INT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + HWY_API HWY_RVV_V(int, SEW, LMUL) FloorInt(HWY_RVV_V(BASE, SEW, LMUL) v) { \ + return __riscv_vfcvt_x_f_v_i##SEW##LMUL##_rm(v, __RISCV_FRM_RDN, \ + HWY_RVV_AVL(SEW, SHIFT)); \ + } +HWY_RVV_FOREACH_F(HWY_RVV_FLOOR_INT, _, _, _ALL) +#undef HWY_RVV_FLOOR_INT + +} // namespace detail + +template +HWY_API V Floor(const V v) { + const DFromV df; + + const auto integer = detail::FloorInt(v); + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +#else // GCC 13 or earlier or Clang 16 or earlier + +template +HWY_API V Floor(const V v) { + const DFromV df; + const RebindToSigned di; + + using T = TFromD; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. 
+ const auto neg1 = + IfThenElseZero(Gt(int_f, v), Set(df, ConvertScalarTo(-1.0))); + + return IfThenElse(detail::UseInt(v), Add(int_f, neg1), v); +} + +#endif // (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1400) || + // (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1700) + +// ------------------------------ Floating-point classification (Ne) + +// vfclass does not help because it would require 3 instructions (to AND and +// then compare the bits), whereas these are just 1-3 integer instructions. + +template +HWY_API MFromD> IsNaN(const V v) { + return Ne(v, v); +} + +// Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite. +// We use a fused Set/comparison for IsFinite. +#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + +template > +HWY_API MFromD IsInf(const V v) { + const D d; + const RebindToSigned di; + using T = TFromD; + const VFromD vi = BitCast(di, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, detail::EqS(Add(vi, vi), hwy::MaxExponentTimes2())); +} + +// Returns whether normal/subnormal/zero. +template > +HWY_API MFromD IsFinite(const V v) { + const D d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + using T = TFromD; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). 
+ const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, detail::LtS(exp, hwy::MaxExponentField())); +} + +// ------------------------------ Iota (ConvertTo) + +template +HWY_API VFromD Iota(const D d, T2 first) { + return detail::AddS(detail::Iota0(d), static_cast>(first)); +} + +template +HWY_API VFromD Iota(const D d, T2 first) { + const RebindToUnsigned du; + return detail::AddS(BitCast(d, detail::Iota0(du)), + static_cast>(first)); +} + +template +HWY_API VFromD Iota(const D d, T2 first) { + const RebindToUnsigned du; + const RebindToSigned di; + return detail::AddS(ConvertTo(d, BitCast(di, detail::Iota0(du))), + ConvertScalarTo>(first)); +} + +// ------------------------------ BitShuffle (PromoteTo, Rol, SumsOf8) + +// Native implementation required to avoid 8-bit wraparound on long vectors. +#ifdef HWY_NATIVE_BITSHUFFLE +#undef HWY_NATIVE_BITSHUFFLE +#else +#define HWY_NATIVE_BITSHUFFLE +#endif + +// Cannot handle LMUL=8 because we promote indices. +template ), class D64 = DFromV, + HWY_IF_UI64_D(D64), HWY_IF_POW2_LE_D(D64, 2)> +HWY_API V64 BitShuffle(V64 values, VI idx) { + const RebindToUnsigned du64; + const Repartition du8; + const Rebind du16; + using VU8 = VFromD; + using VU16 = VFromD; + // For each 16-bit (to avoid wraparound for long vectors) index of an output + // byte: offset of the u64 lane to which it belongs. + const VU16 byte_offsets = + detail::AndS(detail::Iota0(du16), static_cast(~7u)); + // idx is for a bit; shifting makes that bytes. Promote so we can add + // byte_offsets, then we have the u8 lane index within the whole vector. + const VU16 idx16 = + Add(byte_offsets, PromoteTo(du16, ShiftRight<3>(BitCast(du8, idx)))); + const VU8 bytes = detail::TableLookupLanes16(BitCast(du8, values), idx16); + + // We want to shift right by idx & 7 to extract the desired bit in `bytes`, + // and left by iota & 7 to put it in the correct output bit. 
To correctly + // handle shift counts from -7 to 7, we rotate (unfortunately not natively + // supported on RVV). + const VU8 rotate_left_bits = Sub(detail::Iota0(du8), BitCast(du8, idx)); + const VU8 extracted_bits_mask = + BitCast(du8, Set(du64, static_cast(0x8040201008040201u))); + const VU8 extracted_bits = + And(Rol(bytes, rotate_left_bits), extracted_bits_mask); + // Combine bit-sliced (one bit per byte) into one 64-bit sum. + return BitCast(D64(), SumsOf8(extracted_bits)); +} + +template ), class D64 = DFromV, + HWY_IF_UI64_D(D64), HWY_IF_POW2_GT_D(D64, 2)> +HWY_API V64 BitShuffle(V64 values, VI idx) { + const Half dh; + const Half> dih; + using V64H = VFromD; + const V64H r0 = BitShuffle(LowerHalf(dh, values), LowerHalf(dih, idx)); + const V64H r1 = BitShuffle(UpperHalf(dh, values), UpperHalf(dih, idx)); + return Combine(D64(), r1, r0); +} + +// ------------------------------ MulEven/Odd (Mul, OddEven) + +template , class DW = RepartitionToWide> +HWY_API VFromD MulEven(const V a, const V b) { + constexpr int maskVal = sizeof(TFromD) == 4 ? 5 + : sizeof(TFromD) == 2 ? 0x55 + : 0x5555; + const auto mask = Dup128MaskFromMaskBits(D(), maskVal); + const auto hi = Slide1Up(D(), MulHigh(a, b)); + const auto res = MaskedMulOr(hi, mask, a, b); + return BitCast(DW(), res); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD MulOdd(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return BitCast(DW(), OddEven(hi, detail::Slide1Down(lo))); +} + +// There is no 64x64 vwmul. 
+template +HWY_INLINE V MulEven(const V a, const V b) { + const auto mask = Dup128MaskFromMaskBits(DFromV(), 1); + const auto hi = Slide1Up(DFromV(), MulHigh(a, b)); + return MaskedMulOr(hi, mask, a, b); +} + +template +HWY_INLINE V MulOdd(const V a, const V b) { + const auto lo = Mul(a, b); + const auto hi = MulHigh(a, b); + return OddEven(hi, detail::Slide1Down(lo)); +} + +// ------------------------------ ReorderDemote2To (OddEven, Combine) + +template +HWY_API VFromD ReorderDemote2To(D dbf16, VFromD> a, + VFromD> b) { + const RebindToUnsigned du16; + const Half du16_half; + const RebindToUnsigned> du32; + const VFromD a_in_even = PromoteTo( + du32, detail::DemoteTo16NearestEven(du16_half, BitCast(du32, a))); + const VFromD b_in_even = PromoteTo( + du32, detail::DemoteTo16NearestEven(du16_half, BitCast(du32, b))); + // Equivalent to InterleaveEven, but because the upper 16 bits are zero, we + // can OR instead of OddEven. + const VFromD a_in_odd = + detail::Slide1Up(BitCast(du16, a_in_even)); + return BitCast(dbf16, Or(a_in_odd, BitCast(du16, b_in_even))); +} + +// If LMUL is not the max, Combine first to avoid another DemoteTo. +template ), + HWY_IF_POW2_LE_D(DN, 2), class V, HWY_IF_SIGNED_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Rebind, DN> dt; + const VFromD ab = Combine(dt, b, a); + return DemoteTo(dn, ab); +} + +template ) * 2), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Rebind, DN> dt; + const VFromD ab = Combine(dt, b, a); + return DemoteTo(dn, ab); +} + +// Max LMUL: must DemoteTo first, then Combine. 
+template ), + HWY_IF_POW2_GT_D(DN, 2), class V, HWY_IF_SIGNED_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Half dnh; + const VFromD demoted_a = DemoteTo(dnh, a); + const VFromD demoted_b = DemoteTo(dnh, b); + return Combine(dn, demoted_b, demoted_a); +} + +template ) * 2), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Half dnh; + const VFromD demoted_a = DemoteTo(dnh, a); + const VFromD demoted_b = DemoteTo(dnh, b); + return Combine(dn, demoted_b, demoted_a); +} + +// If LMUL is not the max, Combine first to avoid another DemoteTo. +template ), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { + const Rebind, DN> dt; + const VFromD ab = Combine(dt, b, a); + return DemoteTo(dn, ab); +} + +// Max LMUL: must DemoteTo first, then Combine. 
+template ), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { + const Half dnh; + const RebindToUnsigned dn_u; + const RebindToUnsigned dnh_u; + const auto demoted_a = BitCast(dnh_u, DemoteTo(dnh, a)); + const auto demoted_b = BitCast(dnh_u, DemoteTo(dnh, b)); + return BitCast(dn, Combine(dn_u, demoted_b, demoted_a)); +} + +template ), class V, + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + class V2 = VFromD, DN>>, + hwy::EnableIf().Pow2() == DFromV().Pow2()>* = nullptr> +HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { + return ReorderDemote2To(dn, a, b); +} + +// ------------------------------ WidenMulPairwiseAdd + +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + const VFromD ae = PromoteEvenTo(df, a); + const VFromD be = PromoteEvenTo(df, b); + const VFromD ao = PromoteOddTo(df, a); + const VFromD bo = PromoteOddTo(df, b); + return MulAdd(ae, be, Mul(ao, bo)); +} + +template >> +HWY_API VFromD WidenMulPairwiseAdd(D d32, V16 a, V16 b) { + return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), + Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b))); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +namespace detail { + +#define HWY_RVV_WIDEN_MACC(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \ + SHIFT, MLEN, NAME, OP) \ + template \ + HWY_API HWY_RVV_V(BASE, SEWD, LMULD) NAME( \ + HWY_RVV_D(BASE, SEWD, N, SHIFT + 1) d, HWY_RVV_V(BASE, SEWD, LMULD) sum, \ + HWY_RVV_V(BASE, SEW, LMUL) a, HWY_RVV_V(BASE, SEW, LMUL) b) { \ + return __riscv_v##OP##CHAR##SEWD##LMULD(sum, a, b, Lanes(d)); \ + } + +HWY_RVV_FOREACH_I16(HWY_RVV_WIDEN_MACC, WidenMulAcc, wmacc_vv_, _EXT_VIRT) +HWY_RVV_FOREACH_U16(HWY_RVV_WIDEN_MACC, WidenMulAcc, wmaccu_vv_, _EXT_VIRT) +#undef HWY_RVV_WIDEN_MACC + +// If LMUL is not the max, we can WidenMul first (3 instructions). 
+template , + class D16 = RepartitionToNarrow> +HWY_API VFromD ReorderWidenMulAccumulateI16(D32 d32, VFromD a, + VFromD b, const V32 sum0, + V32& sum1) { + const Twice d32t; + using V32T = VFromD; + V32T sum = Combine(d32t, sum1, sum0); + sum = detail::WidenMulAcc(d32t, sum, a, b); + sum1 = UpperHalf(d32, sum); + return LowerHalf(d32, sum); +} + +// Max LMUL: must LowerHalf first (4 instructions). +template , + class D16 = RepartitionToNarrow> +HWY_API VFromD ReorderWidenMulAccumulateI16(D32 d32, VFromD a, + VFromD b, const V32 sum0, + V32& sum1) { + const Half d16h; + using V16H = VFromD; + const V16H a0 = LowerHalf(d16h, a); + const V16H a1 = UpperHalf(d16h, a); + const V16H b0 = LowerHalf(d16h, b); + const V16H b1 = UpperHalf(d16h, b); + sum1 = detail::WidenMulAcc(d32, sum1, a1, b1); + return detail::WidenMulAcc(d32, sum0, a0, b0); +} + +// If LMUL is not the max, we can WidenMul first (3 instructions). +template , + class D16 = RepartitionToNarrow> +HWY_API VFromD ReorderWidenMulAccumulateU16(D32 d32, VFromD a, + VFromD b, const V32 sum0, + V32& sum1) { + const Twice d32t; + using V32T = VFromD; + V32T sum = Combine(d32t, sum1, sum0); + sum = detail::WidenMulAcc(d32t, sum, a, b); + sum1 = UpperHalf(d32, sum); + return LowerHalf(d32, sum); +} + +// Max LMUL: must LowerHalf first (4 instructions). 
+template , + class D16 = RepartitionToNarrow> +HWY_API VFromD ReorderWidenMulAccumulateU16(D32 d32, VFromD a, + VFromD b, const V32 sum0, + V32& sum1) { + const Half d16h; + using V16H = VFromD; + const V16H a0 = LowerHalf(d16h, a); + const V16H a1 = UpperHalf(d16h, a); + const V16H b0 = LowerHalf(d16h, b); + const V16H b1 = UpperHalf(d16h, b); + sum1 = detail::WidenMulAcc(d32, sum1, a1, b1); + return detail::WidenMulAcc(d32, sum0, a0, b0); +} + +} // namespace detail + +template +HWY_API VW ReorderWidenMulAccumulate(D d32, VN a, VN b, const VW sum0, + VW& sum1) { + return detail::ReorderWidenMulAccumulateI16(d32, a, b, sum0, sum1); +} + +template +HWY_API VW ReorderWidenMulAccumulate(D d32, VN a, VN b, const VW sum0, + VW& sum1) { + return detail::ReorderWidenMulAccumulateU16(d32, a, b, sum0, sum1); +} + +// ------------------------------ RearrangeToOddPlusEven + +template // vint32_t* +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + // vwmacc doubles LMUL, so we require a pairwise sum here. This op is + // expected to be less frequent than ReorderWidenMulAccumulate, hence it's + // preferable to do the extra work here rather than do manual odd/even + // extraction there. + const DFromV di32; + const RebindToUnsigned du32; + const Twice di32x2; + const RepartitionToWide di64x2; + const RebindToUnsigned du64x2; + const auto combined = BitCast(di64x2, Combine(di32x2, sum1, sum0)); + // Isolate odd/even int32 in int64 lanes. + const auto even = ShiftRight<32>(ShiftLeft<32>(combined)); // sign extend + const auto odd = ShiftRight<32>(combined); + return BitCast(di32, TruncateTo(du32, BitCast(du64x2, Add(even, odd)))); +} + +// For max LMUL, we cannot Combine again and instead manually unroll. 
+HWY_API vint32m8_t RearrangeToOddPlusEven(vint32m8_t sum0, vint32m8_t sum1) { + const DFromV d; + const Half dh; + const vint32m4_t lo = + RearrangeToOddPlusEven(LowerHalf(sum0), UpperHalf(dh, sum0)); + const vint32m4_t hi = + RearrangeToOddPlusEven(LowerHalf(sum1), UpperHalf(dh, sum1)); + return Combine(d, hi, lo); +} + +template // vuint32_t* +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) { + // vwmacc doubles LMUL, so we require a pairwise sum here. This op is + // expected to be less frequent than ReorderWidenMulAccumulate, hence it's + // preferable to do the extra work here rather than do manual odd/even + // extraction there. + const DFromV du32; + const Twice du32x2; + const RepartitionToWide du64x2; + const auto combined = BitCast(du64x2, Combine(du32x2, sum1, sum0)); + // Isolate odd/even int32 in int64 lanes. + const auto even = detail::AndS(combined, uint64_t{0xFFFFFFFFu}); + const auto odd = ShiftRight<32>(combined); + return TruncateTo(du32, Add(even, odd)); +} + +// For max LMUL, we cannot Combine again and instead manually unroll. +HWY_API vuint32m8_t RearrangeToOddPlusEven(vuint32m8_t sum0, vuint32m8_t sum1) { + const DFromV d; + const Half dh; + const vuint32m4_t lo = + RearrangeToOddPlusEven(LowerHalf(sum0), UpperHalf(dh, sum0)); + const vuint32m4_t hi = + RearrangeToOddPlusEven(LowerHalf(sum1), UpperHalf(dh, sum1)); + return Combine(d, hi, lo); +} + +// ------------------------------ Lt128 +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Lt128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + // The subsequent computations are performed using e8mf8 (8-bit elements with + // a fractional LMUL of 1/8) for the following reasons: + // 1. It is correct for the possible input vector types e64m<1,2,4,8>. This is + // because the resulting mask can occupy at most 1/8 of a full vector when + // using e64m8. + // 2. 
It can be more efficient than using a full vector or a vector group. + // + // The algorithm computes the result as follows: + // 1. Compute cH | (=H & cL) in the high bits, where cH and cL represent the + // comparison results for the high and low 64-bit elements, respectively. + // 2. Shift the result right by 1 to duplicate the comparison results for the + // low bits. + // 3. Obtain the final result by performing a bitwise OR on the high and low + // bits. + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t ltHL0 = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Lt(a, b))); + const vuint8mf8_t eqHL0 = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Eq(a, b))); + const vuint8mf8_t ltLx0 = Add(ltHL0, ltHL0); + const vuint8mf8_t resultHx = detail::AndS(OrAnd(ltHL0, ltLx0, eqHL0), 0xaa); + const vuint8mf8_t resultxL = ShiftRight<1>(resultHx); + const vuint8mf8_t result = Or(resultHx, resultxL); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask(d, detail::ChangeLMUL(du8m1, result)); +} + +#else + +template +HWY_INLINE MFromD Lt128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + // Truth table of Eq and Compare for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + // Shift leftward so L can influence H. + const VFromD ltLx = detail::Slide1Up(ltHL); + const VFromD vecHx = OrAnd(ltHL, eqHL, ltLx); + // Replicate H to its neighbor. 
+ return MaskFromVec(OddEven(vecHx, detail::Slide1Down(vecHx))); +} + +#endif // HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +// ------------------------------ Lt128Upper +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Lt128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t ltHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Lt(a, b))); + const vuint8mf8_t ltHx = detail::AndS(ltHL, 0xaa); + const vuint8mf8_t ltxL = ShiftRight<1>(ltHx); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask(d, + detail::ChangeLMUL(du8m1, Or(ltHx, ltxL))); +} + +#else + +template +HWY_INLINE MFromD Lt128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + const VFromD down = detail::Slide1Down(ltHL); + // b(267743505): Clang compiler bug, workaround is DoNotOptimize + asm volatile("" : : "r,m"(GetLane(down)) : "memory"); + // Replicate H to its neighbor. 
+ return MaskFromVec(OddEven(ltHL, down)); +} + +#endif // HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +// ------------------------------ Eq128 +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Eq128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t eqHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Eq(a, b))); + const vuint8mf8_t eqxH = ShiftRight<1>(eqHL); + const vuint8mf8_t result0L = detail::AndS(And(eqHL, eqxH), 0x55); + const vuint8mf8_t resultH0 = Add(result0L, result0L); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(du8m1, Or(result0L, resultH0))); +} + +#else + +template +HWY_INLINE MFromD Eq128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + const VFromD eqLH = Reverse2(d, eqHL); + const VFromD eq = And(eqHL, eqLH); + // b(267743505): Clang compiler bug, workaround is DoNotOptimize + asm volatile("" : : "r,m"(GetLane(eq)) : "memory"); + return MaskFromVec(eq); +} + +#endif + +// ------------------------------ Eq128Upper +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Eq128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t eqHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Eq(a, b))); + const vuint8mf8_t eqHx = detail::AndS(eqHL, 0xaa); + const vuint8mf8_t eqxL = ShiftRight<1>(eqHx); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask(d, + detail::ChangeLMUL(du8m1, Or(eqHx, eqxL))); +} + +#else + +template +HWY_INLINE MFromD Eq128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const VFromD eqHL = VecFromMask(d, 
Eq(a, b)); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(eqHL, detail::Slide1Down(eqHL))); +} + +#endif + +// ------------------------------ Ne128 +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Ne128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t neHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Ne(a, b))); + const vuint8mf8_t nexH = ShiftRight<1>(neHL); + const vuint8mf8_t result0L = detail::AndS(Or(neHL, nexH), 0x55); + const vuint8mf8_t resultH0 = Add(result0L, result0L); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask( + d, detail::ChangeLMUL(du8m1, Or(result0L, resultH0))); +} + +#else + +template +HWY_INLINE MFromD Ne128(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const VFromD neHL = VecFromMask(d, Ne(a, b)); + const VFromD neLH = Reverse2(d, neHL); + // b(267743505): Clang compiler bug, workaround is DoNotOptimize + asm volatile("" : : "r,m"(GetLane(neLH)) : "memory"); + return MaskFromVec(Or(neHL, neLH)); +} + +#endif + +// ------------------------------ Ne128Upper +#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400 + +template +HWY_INLINE MFromD Ne128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + auto du8mf8 = ScalableTag{}; + const vuint8mf8_t neHL = + detail::ChangeLMUL(du8mf8, detail::MaskToU8MaskBitsVec(Ne(a, b))); + const vuint8mf8_t neHx = detail::AndS(neHL, 0xaa); + const vuint8mf8_t nexL = ShiftRight<1>(neHx); + auto du8m1 = ScalableTag{}; + return detail::U8MaskBitsVecToMask(d, + detail::ChangeLMUL(du8m1, Or(neHx, nexL))); +} + +#else + +template +HWY_INLINE MFromD Ne128Upper(D d, const VFromD a, const VFromD b) { + static_assert(IsSame, uint64_t>(), "D must be u64"); + const VFromD neHL = VecFromMask(d, Ne(a, b)); + const VFromD 
down = detail::Slide1Down(neHL); + // b(267743505): Clang compiler bug, workaround is DoNotOptimize + asm volatile("" : : "r,m"(GetLane(down)) : "memory"); + // Replicate H to its neighbor. + return MaskFromVec(OddEven(neHL, down)); +} + +#endif + +// ------------------------------ Min128, Max128 (Lt128) + +template +HWY_INLINE VFromD Min128(D /* tag */, const VFromD a, const VFromD b) { + const VFromD aXH = detail::Slide1Down(a); + const VFromD bXH = detail::Slide1Down(b); + const VFromD minHL = Min(a, b); + const MFromD ltXH = Lt(aXH, bXH); + const MFromD eqXH = Eq(aXH, bXH); + // If the upper lane is the decider, take lo from the same reg. + const VFromD lo = IfThenElse(ltXH, a, b); + // The upper lane is just minHL; if they are equal, we also need to use the + // actual min of the lower lanes. + return OddEven(minHL, IfThenElse(eqXH, minHL, lo)); +} + +template +HWY_INLINE VFromD Max128(D /* tag */, const VFromD a, const VFromD b) { + const VFromD aXH = detail::Slide1Down(a); + const VFromD bXH = detail::Slide1Down(b); + const VFromD maxHL = Max(a, b); + const MFromD ltXH = Lt(aXH, bXH); + const MFromD eqXH = Eq(aXH, bXH); + // If the upper lane is the decider, take lo from the same reg. + const VFromD lo = IfThenElse(ltXH, b, a); + // The upper lane is just maxHL; if they are equal, we also need to use the + // actual min of the lower lanes. 
+ return OddEven(maxHL, IfThenElse(eqXH, maxHL, lo)); +} + +template +HWY_INLINE VFromD Min128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128Upper(D d, VFromD a, VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// ================================================== END MACROS +#undef HWY_RVV_AVL +#undef HWY_RVV_D +#undef HWY_RVV_FOREACH +#undef HWY_RVV_FOREACH_08_ALL +#undef HWY_RVV_FOREACH_08_ALL_VIRT +#undef HWY_RVV_FOREACH_08_DEMOTE +#undef HWY_RVV_FOREACH_08_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_08_EXT +#undef HWY_RVV_FOREACH_08_EXT_VIRT +#undef HWY_RVV_FOREACH_08_TRUNC +#undef HWY_RVV_FOREACH_08_VIRT +#undef HWY_RVV_FOREACH_16_ALL +#undef HWY_RVV_FOREACH_16_ALL_VIRT +#undef HWY_RVV_FOREACH_16_DEMOTE +#undef HWY_RVV_FOREACH_16_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_16_EXT +#undef HWY_RVV_FOREACH_16_EXT_VIRT +#undef HWY_RVV_FOREACH_16_TRUNC +#undef HWY_RVV_FOREACH_16_VIRT +#undef HWY_RVV_FOREACH_32_ALL +#undef HWY_RVV_FOREACH_32_ALL_VIRT +#undef HWY_RVV_FOREACH_32_DEMOTE +#undef HWY_RVV_FOREACH_32_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_32_EXT +#undef HWY_RVV_FOREACH_32_EXT_VIRT +#undef HWY_RVV_FOREACH_32_TRUNC +#undef HWY_RVV_FOREACH_32_VIRT +#undef HWY_RVV_FOREACH_64_ALL +#undef HWY_RVV_FOREACH_64_ALL_VIRT +#undef HWY_RVV_FOREACH_64_DEMOTE +#undef HWY_RVV_FOREACH_64_DEMOTE_VIRT +#undef HWY_RVV_FOREACH_64_EXT +#undef HWY_RVV_FOREACH_64_EXT_VIRT +#undef HWY_RVV_FOREACH_64_TRUNC +#undef HWY_RVV_FOREACH_64_VIRT +#undef HWY_RVV_FOREACH_B +#undef HWY_RVV_FOREACH_F +#undef HWY_RVV_FOREACH_F16 +#undef HWY_RVV_FOREACH_F32 +#undef HWY_RVV_FOREACH_F3264 +#undef HWY_RVV_FOREACH_F64 +#undef HWY_RVV_FOREACH_I +#undef HWY_RVV_FOREACH_I08 +#undef HWY_RVV_FOREACH_I16 +#undef HWY_RVV_FOREACH_I163264 +#undef HWY_RVV_FOREACH_I32 +#undef HWY_RVV_FOREACH_I64 +#undef HWY_RVV_FOREACH_U +#undef HWY_RVV_FOREACH_U08 +#undef HWY_RVV_FOREACH_U16 +#undef HWY_RVV_FOREACH_U163264 +#undef HWY_RVV_FOREACH_U32 
+#undef HWY_RVV_FOREACH_U64 +#undef HWY_RVV_FOREACH_UI +#undef HWY_RVV_FOREACH_UI08 +#undef HWY_RVV_FOREACH_UI16 +#undef HWY_RVV_FOREACH_UI163264 +#undef HWY_RVV_FOREACH_UI32 +#undef HWY_RVV_FOREACH_UI3264 +#undef HWY_RVV_FOREACH_UI64 +#undef HWY_RVV_IF_EMULATED_D +#undef HWY_RVV_IF_CAN128_D +#undef HWY_RVV_IF_GE128_D +#undef HWY_RVV_IF_LT128_D +#undef HWY_RVV_INSERT_VXRM +#undef HWY_RVV_M +#undef HWY_RVV_RETM_ARGM +#undef HWY_RVV_RETV_ARGMVV +#undef HWY_RVV_RETV_ARGV +#undef HWY_RVV_RETV_ARGVS +#undef HWY_RVV_RETV_ARGVV +#undef HWY_RVV_T +#undef HWY_RVV_V +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/lib/highway/hwy/ops/scalar-inl.h b/lib/highway/hwy/ops/scalar-inl.h new file mode 100644 index 00000000000..12ce6e42054 --- /dev/null +++ b/lib/highway/hwy/ops/scalar-inl.h @@ -0,0 +1,2185 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Single-element vectors and operations. +// External include guard in highway.h - see comment there. + +#include +#ifndef HWY_NO_LIBCXX +#include // sqrtf +#endif + +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Single instruction, single data. +template +using Sisd = Simd; + +// (Wrapper class required for overloading comparison operators.) 
+template +struct Vec1 { + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 1; // only for DFromV + + HWY_INLINE Vec1() = default; + Vec1(const Vec1&) = default; + Vec1& operator=(const Vec1&) = default; + HWY_INLINE explicit Vec1(const T t) : raw(t) {} + + HWY_INLINE Vec1& operator*=(const Vec1 other) { + return *this = (*this * other); + } + HWY_INLINE Vec1& operator/=(const Vec1 other) { + return *this = (*this / other); + } + HWY_INLINE Vec1& operator+=(const Vec1 other) { + return *this = (*this + other); + } + HWY_INLINE Vec1& operator-=(const Vec1 other) { + return *this = (*this - other); + } + HWY_INLINE Vec1& operator%=(const Vec1 other) { + return *this = (*this % other); + } + HWY_INLINE Vec1& operator&=(const Vec1 other) { + return *this = (*this & other); + } + HWY_INLINE Vec1& operator|=(const Vec1 other) { + return *this = (*this | other); + } + HWY_INLINE Vec1& operator^=(const Vec1 other) { + return *this = (*this ^ other); + } + + T raw; +}; + +// 0 or FF..FF, same size as Vec1. +template +struct Mask1 { + using Raw = hwy::MakeUnsigned; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = 1; // only for DFromM + + static HWY_INLINE Mask1 FromBool(bool b) { + Mask1 mask; + mask.bits = b ? 
static_cast(~Raw{0}) : 0; + return mask; + } + + Raw bits; +}; + +template +using DFromV = Simd; + +template +using DFromM = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ BitCast + +template , typename TFrom> +HWY_API Vec1 BitCast(DTo /* tag */, Vec1 v) { + static_assert(sizeof(TTo) <= sizeof(TFrom), "Promoting is undefined"); + TTo to; + CopyBytes(&v.raw, &to); // not same size - ok to shrink + return Vec1(to); +} + +// ------------------------------ Zero + +template > +HWY_API Vec1 Zero(D /* tag */) { + return Vec1(ConvertScalarTo(0)); +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ Set +template , typename T2> +HWY_API Vec1 Set(D /* tag */, const T2 t) { + return Vec1(static_cast(t)); +} + +// ------------------------------ Undefined +template > +HWY_API Vec1 Undefined(D d) { + return Zero(d); +} + +// ------------------------------ Iota +template , typename T2> +HWY_API Vec1 Iota(const D /* tag */, const T2 first) { + return Vec1(static_cast(first)); +} + +// ------------------------------ ResizeBitCast + +template +HWY_API VFromD ResizeBitCast(D /* tag */, FromV v) { + using TFrom = TFromV; + using TTo = TFromD; + constexpr size_t kCopyLen = HWY_MIN(sizeof(TFrom), sizeof(TTo)); + TTo to{}; + CopyBytes(&v.raw, &to); + return VFromD(to); +} + +namespace detail { + +// ResizeBitCast on the HWY_SCALAR target has zero-extending semantics if +// sizeof(TFromD) is greater than sizeof(TFromV) +template +HWY_INLINE VFromD ZeroExtendResizeBitCast(FromSizeTag /* from_size_tag */, + ToSizeTag /* to_size_tag */, + DTo d_to, DFrom /*d_from*/, + VFromD v) { + return ResizeBitCast(d_to, v); +} + +} // namespace detail + +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD /*t1*/, + TFromD /*t2*/, TFromD /*t3*/, + TFromD /*t4*/, TFromD /*t5*/, + TFromD /*t6*/, TFromD /*t7*/, + TFromD /*t8*/, TFromD /*t9*/, + TFromD 
/*t10*/, TFromD /*t11*/, + TFromD /*t12*/, TFromD /*t13*/, + TFromD /*t14*/, TFromD /*t15*/) { + return VFromD(t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD /*t1*/, + TFromD /*t2*/, TFromD /*t3*/, + TFromD /*t4*/, TFromD /*t5*/, + TFromD /*t6*/, TFromD /*t7*/) { + return VFromD(t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD /*t1*/, + TFromD /*t2*/, TFromD /*t3*/) { + return VFromD(t0); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD /*t1*/) { + return VFromD(t0); +} + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec1 Not(const Vec1 v) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(static_cast(~BitCast(du, v).raw))); +} + +// ------------------------------ And + +template +HWY_API Vec1 And(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw & BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator&(const Vec1 a, const Vec1 b) { + return And(a, b); +} + +// ------------------------------ AndNot + +template +HWY_API Vec1 AndNot(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(static_cast(~BitCast(du, a).raw & + BitCast(du, b).raw))); +} + +// ------------------------------ Or + +template +HWY_API Vec1 Or(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw | BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator|(const Vec1 a, const Vec1 b) { + return Or(a, b); +} + +// ------------------------------ Xor + +template +HWY_API Vec1 Xor(const Vec1 a, const Vec1 b) { + using TU = MakeUnsigned; + const Sisd du; + return BitCast(Sisd(), Vec1(BitCast(du, a).raw ^ BitCast(du, b).raw)); +} +template +HWY_API Vec1 operator^(const Vec1 a, const Vec1 b) { + return Xor(a, b); 
+} + +// ------------------------------ Or3 + +template +HWY_API Vec1 Or3(Vec1 o1, Vec1 o2, Vec1 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec1 OrAnd(const Vec1 o, const Vec1 a1, const Vec1 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ Mask + +template , typename TFrom> +HWY_API Mask1 RebindMask(DTo /*tag*/, Mask1 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return Mask1{m.bits}; +} + +// v must be 0 or FF..FF. +template +HWY_API Mask1 MaskFromVec(const Vec1 v) { + Mask1 mask; + CopySameSize(&v, &mask); + return mask; +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +template > +Vec1 VecFromMask(D /* tag */, const Mask1 mask) { + Vec1 v; + CopySameSize(&mask, &v); + return v; +} + +template +uint64_t BitsFromMask(D, MFromD mask) { + return mask.bits ? 1 : 0; +} + +template > +HWY_API Mask1 FirstN(D /*tag*/, size_t n) { + return Mask1::FromBool(n != 0); +} + +#ifdef HWY_NATIVE_SET_MASK +#undef HWY_NATIVE_SET_MASK +#else +#define HWY_NATIVE_SET_MASK +#endif + +template +HWY_API MFromD SetMask(D /*d*/, bool val) { + return MFromD::FromBool(val); +} + +// ------------------------------ IfVecThenElse +template +HWY_API Vec1 IfVecThenElse(Vec1 mask, Vec1 yes, Vec1 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ CopySign +template +HWY_API Vec1 CopySign(const Vec1 magn, const Vec1 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return BitwiseIfThenElse(SignBit(d), sign, magn); +} + +// ------------------------------ CopySignToAbs +template +HWY_API Vec1 CopySignToAbs(const Vec1 abs, const Vec1 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const Sisd d; + return OrAnd(abs, SignBit(d), sign); +} + +// ------------------------------ BroadcastSignBit +template +HWY_API Vec1 BroadcastSignBit(const Vec1 v) { + return 
Vec1(ScalarShr(v.raw, sizeof(T) * 8 - 1)); +} + +// ------------------------------ PopulationCount + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +template +HWY_API Vec1 PopulationCount(Vec1 v) { + return Vec1(static_cast(PopCount(v.raw))); +} + +// ------------------------------ IfThenElse + +// Returns mask ? yes : no. +template +HWY_API Vec1 IfThenElse(const Mask1 mask, const Vec1 yes, + const Vec1 no) { + return mask.bits ? yes : no; +} + +template +HWY_API Vec1 IfThenElseZero(const Mask1 mask, const Vec1 yes) { + return mask.bits ? yes : Vec1(ConvertScalarTo(0)); +} + +template +HWY_API Vec1 IfThenZeroElse(const Mask1 mask, const Vec1 no) { + return mask.bits ? Vec1(ConvertScalarTo(0)) : no; +} + +template +HWY_API Vec1 IfNegativeThenElse(Vec1 v, Vec1 yes, Vec1 no) { + const DFromV d; + const RebindToSigned di; + const auto vi = BitCast(di, v); + + return vi.raw < 0 ? yes : no; +} + +// ------------------------------ Mask logical + +template +HWY_API Mask1 Not(const Mask1 m) { + return MaskFromVec(Not(VecFromMask(Sisd(), m))); +} + +template +HWY_API Mask1 And(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 AndNot(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 Or(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 Xor(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask1 ExclusiveNeither(const Mask1 a, Mask1 b) { + const Sisd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +template +HWY_API Mask1 SetAtOrAfterFirst(Mask1 mask) { + return mask; +} + +template +HWY_API Mask1 SetBeforeFirst(Mask1 mask) { + return Not(mask); +} + +template 
+HWY_API Mask1 SetOnlyFirst(Mask1 mask) { + return mask; +} + +template +HWY_API Mask1 SetAtOrBeforeFirst(Mask1 /*mask*/) { + return Mask1::FromBool(true); +} + +// ------------------------------ LowerHalfOfMask + +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif + +template +HWY_API MFromD LowerHalfOfMask(D /*d*/, MFromD m) { + return m; +} + +// ================================================== SHIFTS + +// ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) + +template +HWY_API Vec1 ShiftLeft(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return Vec1( + static_cast(static_cast>(v.raw) << kBits)); +} + +template +HWY_API Vec1 ShiftRight(const Vec1 v) { + static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); + return Vec1(ScalarShr(v.raw, kBits)); +} + +// ------------------------------ RotateRight (ShiftRight) +template +HWY_API Vec1 RotateRight(const Vec1 v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); +} + +// ------------------------------ ShiftLeftSame (BroadcastSignBit) + +template +HWY_API Vec1 ShiftLeftSame(const Vec1 v, int bits) { + return Vec1( + static_cast(static_cast>(v.raw) << bits)); +} + +template +HWY_API Vec1 ShiftRightSame(const Vec1 v, int bits) { + return Vec1(ScalarShr(v.raw, bits)); +} + +// ------------------------------ Shl + +// Single-lane => same as ShiftLeftSame except for the argument type. 
+template +HWY_API Vec1 operator<<(const Vec1 v, const Vec1 bits) { + return ShiftLeftSame(v, static_cast(bits.raw)); +} + +template +HWY_API Vec1 operator>>(const Vec1 v, const Vec1 bits) { + return ShiftRightSame(v, static_cast(bits.raw)); +} + +// ================================================== ARITHMETIC + +template +HWY_API Vec1 operator+(Vec1 a, Vec1 b) { + const uint64_t a64 = static_cast(a.raw); + const uint64_t b64 = static_cast(b.raw); + return Vec1(static_cast((a64 + b64) & static_cast(~T(0)))); +} +HWY_API Vec1 operator+(const Vec1 a, const Vec1 b) { + return Vec1(a.raw + b.raw); +} +HWY_API Vec1 operator+(const Vec1 a, const Vec1 b) { + return Vec1(a.raw + b.raw); +} + +template +HWY_API Vec1 operator-(Vec1 a, Vec1 b) { + const uint64_t a64 = static_cast(a.raw); + const uint64_t b64 = static_cast(b.raw); + return Vec1(static_cast((a64 - b64) & static_cast(~T(0)))); +} +HWY_API Vec1 operator-(const Vec1 a, const Vec1 b) { + return Vec1(a.raw - b.raw); +} +HWY_API Vec1 operator-(const Vec1 a, const Vec1 b) { + return Vec1(a.raw - b.raw); +} + +// ------------------------------ SumsOf8 + +HWY_API Vec1 SumsOf8(const Vec1 v) { + return Vec1(v.raw); +} +HWY_API Vec1 SumsOf8(const Vec1 v) { + return Vec1(v.raw); +} + +// ------------------------------ SumsOf2 + +template +HWY_API Vec1> SumsOf2(const Vec1 v) { + const DFromV d; + const Rebind, decltype(d)> dw; + return PromoteTo(dw, v); +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. 
+ +// Unsigned +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 255))); +} +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast( + HWY_MIN(HWY_MAX(0, static_cast(a.raw) + b.raw), 65535))); +} + +// Signed +HWY_API Vec1 SaturatedAdd(const Vec1 a, const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-128, a.raw + b.raw), 127))); +} +HWY_API Vec1 SaturatedAdd(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast( + HWY_MIN(HWY_MAX(-32768, static_cast(a.raw) + b.raw), 32767))); +} + +// ------------------------------ Saturating subtraction + +// Returns a - b clamped to the destination range. + +// Unsigned +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 255))); +} +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast( + HWY_MIN(HWY_MAX(0, static_cast(a.raw) - b.raw), 65535))); +} + +// Signed +HWY_API Vec1 SaturatedSub(const Vec1 a, const Vec1 b) { + return Vec1( + static_cast(HWY_MIN(HWY_MAX(-128, a.raw - b.raw), 127))); +} +HWY_API Vec1 SaturatedSub(const Vec1 a, + const Vec1 b) { + return Vec1(static_cast( + HWY_MIN(HWY_MAX(-32768, static_cast(a.raw) - b.raw), 32767))); +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32 +#undef HWY_NATIVE_AVERAGE_ROUND_UI32 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI32 +#endif + +#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64 +#undef HWY_NATIVE_AVERAGE_ROUND_UI64 +#else +#define HWY_NATIVE_AVERAGE_ROUND_UI64 +#endif + +template +HWY_API Vec1 AverageRound(const Vec1 a, const Vec1 b) { + const T a_val = a.raw; + const T b_val = b.raw; + return Vec1(static_cast((a_val | b_val) - ScalarShr(a_val ^ b_val, 1))); +} + +// ------------------------------ Absolute value + +template +HWY_API Vec1 Abs(const Vec1 a) { + return Vec1(ScalarAbs(a.raw)); +} + +// 
------------------------------ Min/Max + +// may be unavailable, so implement our own. + +template +HWY_API Vec1 Min(const Vec1 a, const Vec1 b) { + return Vec1(HWY_MIN(a.raw, b.raw)); +} + +template +HWY_API Vec1 Min(const Vec1 a, const Vec1 b) { + if (ScalarIsNaN(a.raw)) return b; + if (ScalarIsNaN(b.raw)) return a; + return Vec1(HWY_MIN(a.raw, b.raw)); +} + +template +HWY_API Vec1 Max(const Vec1 a, const Vec1 b) { + return Vec1(HWY_MAX(a.raw, b.raw)); +} + +template +HWY_API Vec1 Max(const Vec1 a, const Vec1 b) { + if (ScalarIsNaN(a.raw)) return b; + if (ScalarIsNaN(b.raw)) return a; + return Vec1(HWY_MAX(a.raw, b.raw)); +} + +// ------------------------------ Floating-point negate + +template +HWY_API Vec1 Neg(const Vec1 v) { + return Xor(v, SignBit(Sisd())); +} + +template +HWY_API Vec1 Neg(const Vec1 v) { + return Zero(Sisd()) - v; +} + +// ------------------------------ mul/div + +// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*. +#ifdef HWY_NATIVE_MUL_8 +#undef HWY_NATIVE_MUL_8 +#else +#define HWY_NATIVE_MUL_8 +#endif +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +template +HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { + return Vec1(static_cast(double{a.raw} * b.raw)); +} + +template +HWY_API Vec1 operator*(const Vec1 a, const Vec1 b) { + return Vec1(static_cast(static_cast(a.raw) * + static_cast(b.raw))); +} + +template +HWY_API Vec1 operator/(const Vec1 a, const Vec1 b) { + return Vec1(a.raw / b.raw); +} + +// Returns the upper sizeof(T)*8 bits of a * b in each lane. 
+template +HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { + using TW = MakeWide; + return Vec1(static_cast( + (static_cast(a.raw) * static_cast(b.raw)) >> (sizeof(T) * 8))); +} +template +HWY_API Vec1 MulHigh(const Vec1 a, const Vec1 b) { + T hi; + Mul128(a.raw, b.raw, &hi); + return Vec1(hi); +} + +HWY_API Vec1 MulFixedPoint15(Vec1 a, Vec1 b) { + return Vec1(static_cast((a.raw * b.raw + 16384) >> 15)); +} + +// Multiplies even lanes (0, 2 ..) and returns the double-wide result. +template +HWY_API Vec1> MulEven(const Vec1 a, const Vec1 b) { + using TW = MakeWide; + const TW a_wide = a.raw; + return Vec1(static_cast(a_wide * b.raw)); +} + +template +HWY_API Vec1> MulOdd(const Vec1, const Vec1) { + static_assert(sizeof(T) == 0, "There are no odd lanes"); +} + +// Approximate reciprocal +HWY_API Vec1 ApproximateReciprocal(const Vec1 v) { + // Zero inputs are allowed, but callers are responsible for replacing the + // return value with something else (typically using IfThenElse). This check + // avoids a ubsan error. The return value is arbitrary. + if (v.raw == 0.0f) return Vec1(0.0f); + return Vec1(1.0f / v.raw); +} + +// generic_ops takes care of integer T. 
+template +HWY_API Vec1 AbsDiff(const Vec1 a, const Vec1 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_API Vec1 MulAdd(const Vec1 mul, const Vec1 x, const Vec1 add) { + return mul * x + add; +} + +template +HWY_API Vec1 NegMulAdd(const Vec1 mul, const Vec1 x, + const Vec1 add) { + return add - mul * x; +} + +template +HWY_API Vec1 MulSub(const Vec1 mul, const Vec1 x, const Vec1 sub) { + return mul * x - sub; +} + +template +HWY_API Vec1 NegMulSub(const Vec1 mul, const Vec1 x, + const Vec1 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Approximate reciprocal square root +HWY_API Vec1 ApproximateReciprocalSqrt(const Vec1 v) { + float f = v.raw; + const float half = f * 0.5f; + uint32_t bits; + CopySameSize(&f, &bits); + // Initial guess based on log2(f) + bits = 0x5F3759DF - (bits >> 1); + CopySameSize(&bits, &f); + // One Newton-Raphson iteration + return Vec1(f * (1.5f - (half * f * f))); +} + +// Square root +HWY_API Vec1 Sqrt(Vec1 v) { +#if defined(HWY_NO_LIBCXX) +#if HWY_COMPILER_GCC_ACTUAL + return Vec1(__builtin_sqrt(v.raw)); +#else + uint32_t bits; + CopyBytes(&v, &bits); + // Coarse approximation, letting the exponent LSB leak into the mantissa + bits = (1 << 29) + (bits >> 1) - (1 << 22); + CopyBytes(&bits, &v); + return v; +#endif // !HWY_COMPILER_GCC_ACTUAL +#else + return Vec1(sqrtf(v.raw)); +#endif // !HWY_NO_LIBCXX +} +HWY_API Vec1 Sqrt(Vec1 v) { +#if defined(HWY_NO_LIBCXX) +#if HWY_COMPILER_GCC_ACTUAL + return Vec1(__builtin_sqrt(v.raw)); +#else + uint64_t bits; + CopyBytes(&v, &bits); + // Coarse approximation, letting the exponent LSB leak into the mantissa + bits = (1ULL << 61) + (bits >> 1) - (1ULL << 51); + CopyBytes(&bits, &v); + return v; +#endif // !HWY_COMPILER_GCC_ACTUAL +#else + return Vec1(sqrt(v.raw)); +#endif // HWY_NO_LIBCXX +} + +// ------------------------------ Floating-point rounding + +template 
+HWY_API Vec1 Round(const Vec1 v) { + using TI = MakeSigned; + if (!(Abs(v).raw < MantissaEnd())) { // Huge or NaN + return v; + } + const T k0 = ConvertScalarTo(0); + const T bias = ConvertScalarTo(v.raw < k0 ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw + bias); + if (rounded == 0) return CopySignToAbs(Vec1(k0), v); + TI offset = 0; + // Round to even + if ((rounded & 1) && ScalarAbs(ConvertScalarTo(rounded) - v.raw) == + ConvertScalarTo(0.5)) { + offset = v.raw < k0 ? -1 : 1; + } + return Vec1(ConvertScalarTo(rounded - offset)); +} + +// Round-to-nearest even. +template +HWY_API Vec1> NearestInt(const Vec1 v) { + using TI = MakeSigned; + + const T abs = Abs(v).raw; + const bool is_sign = ScalarSignBit(v.raw); + + if (!(abs < MantissaEnd())) { // Huge or NaN + // Check if too large to cast or NaN + if (!(abs <= ConvertScalarTo(LimitsMax()))) { + return Vec1(is_sign ? LimitsMin() : LimitsMax()); + } + return Vec1(ConvertScalarTo(v.raw)); + } + const T bias = + ConvertScalarTo(v.raw < ConvertScalarTo(0.0) ? -0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw + bias); + if (rounded == 0) return Vec1(0); + TI offset = 0; + // Round to even + if ((rounded & 1) && ScalarAbs(ConvertScalarTo(rounded) - v.raw) == + ConvertScalarTo(0.5)) { + offset = is_sign ? -1 : 1; + } + return Vec1(rounded - offset); +} + +// Round-to-nearest even. +template +HWY_API VFromD DemoteToNearestInt(DI32 /*di32*/, const Vec1 v) { + using T = double; + using TI = int32_t; + + const T abs = Abs(v).raw; + const bool is_sign = ScalarSignBit(v.raw); + + // Check if too large to cast or NaN + if (!(abs <= ConvertScalarTo(LimitsMax()))) { + return Vec1(is_sign ? LimitsMin() : LimitsMax()); + } + + const T bias = + ConvertScalarTo(v.raw < ConvertScalarTo(0.0) ? 
-0.5 : 0.5); + const TI rounded = ConvertScalarTo(v.raw + bias); + if (rounded == 0) return Vec1(0); + TI offset = 0; + // Round to even + if ((rounded & 1) && ScalarAbs(ConvertScalarTo(rounded) - v.raw) == + ConvertScalarTo(0.5)) { + offset = is_sign ? -1 : 1; + } + return Vec1(rounded - offset); +} + +template +HWY_API Vec1 Trunc(const Vec1 v) { + using TI = MakeSigned; + if (!(Abs(v).raw <= MantissaEnd())) { // Huge or NaN + return v; + } + const TI truncated = ConvertScalarTo(v.raw); + if (truncated == 0) return CopySignToAbs(Vec1(0), v); + return Vec1(ConvertScalarTo(truncated)); +} + +template +V Ceiling(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool positive = f > Float(0.0); + + Bits bits; + CopySameSize(&v, &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => 0 or 1. + if (exponent < 0) return positive ? V(1) : V(-0.0); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round up + if (positive) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &f); + return V(f); +} + +template +V Floor(const V v) { + const Bits kExponentMask = (1ull << kExponentBits) - 1; + const Bits kMantissaMask = (1ull << kMantissaBits) - 1; + const Bits kBias = kExponentMask / 2; + + Float f = v.raw; + const bool negative = f < Float(0.0); + + Bits bits; + CopySameSize(&v, &bits); + + const int exponent = + static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); + // Already an integer. + if (exponent >= kMantissaBits) return v; + // |v| <= 1 => -1 or 0. + if (exponent < 0) return V(negative ? 
Float(-1.0) : Float(0.0)); + + const Bits mantissa_mask = kMantissaMask >> exponent; + // Already an integer + if ((bits & mantissa_mask) == 0) return v; + + // Clear fractional bits and round down + if (negative) bits += (kMantissaMask + 1) >> exponent; + bits &= ~mantissa_mask; + + CopySameSize(&bits, &f); + return V(f); +} + +// Toward +infinity, aka ceiling +HWY_API Vec1 Ceil(const Vec1 v) { + return Ceiling(v); +} +HWY_API Vec1 Ceil(const Vec1 v) { + return Ceiling(v); +} + +// Toward -infinity, aka floor +HWY_API Vec1 Floor(const Vec1 v) { + return Floor(v); +} +HWY_API Vec1 Floor(const Vec1 v) { + return Floor(v); +} + +// ================================================== COMPARE + +template +HWY_API Mask1 operator==(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw == b.raw); +} + +template +HWY_API Mask1 operator!=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw != b.raw); +} + +template +HWY_API Mask1 TestBit(const Vec1 v, const Vec1 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_API Mask1 operator<(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw < b.raw); +} +template +HWY_API Mask1 operator>(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw > b.raw); +} + +template +HWY_API Mask1 operator<=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw <= b.raw); +} +template +HWY_API Mask1 operator>=(const Vec1 a, const Vec1 b) { + return Mask1::FromBool(a.raw >= b.raw); +} + +// ------------------------------ Floating-point classification (==) + +template +HWY_API Mask1 IsNaN(const Vec1 v) { + // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. + return Mask1::FromBool(ScalarIsNaN(v.raw)); +} + +// Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite. 
+#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + +HWY_API Mask1 IsInf(const Vec1 v) { + const Sisd d; + const RebindToUnsigned du; + const Vec1 vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, (vu + vu) == Set(du, 0xFF000000u)); +} +HWY_API Mask1 IsInf(const Vec1 v) { + const Sisd d; + const RebindToUnsigned du; + const Vec1 vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, (vu + vu) == Set(du, 0xFFE0000000000000ull)); +} + +HWY_API Mask1 IsFinite(const Vec1 v) { + const Vec1 vu = BitCast(Sisd(), v); + // Shift left to clear the sign bit, check whether exponent != max value. + return Mask1::FromBool((vu.raw << 1) < 0xFF000000u); +} +HWY_API Mask1 IsFinite(const Vec1 v) { + const Vec1 vu = BitCast(Sisd(), v); + // Shift left to clear the sign bit, check whether exponent != max value. + return Mask1::FromBool((vu.raw << 1) < 0xFFE0000000000000ull); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template > +HWY_API Vec1 Load(D /* tag */, const T* HWY_RESTRICT aligned) { + T t; + CopySameSize(aligned, &t); + return Vec1(t); +} + +template > +HWY_API Vec1 MaskedLoad(Mask1 m, D d, const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template > +HWY_API Vec1 MaskedLoadOr(Vec1 v, Mask1 m, D d, + const T* HWY_RESTRICT aligned) { + return IfThenElse(m, Load(d, aligned), v); +} + +template > +HWY_API Vec1 LoadU(D d, const T* HWY_RESTRICT p) { + return Load(d, p); +} + +// In some use cases, "load single lane" is sufficient; otherwise avoid this. 
+template > +HWY_API Vec1 LoadDup128(D d, const T* HWY_RESTRICT aligned) { + return Load(d, aligned); +} + +#ifdef HWY_NATIVE_LOAD_N +#undef HWY_NATIVE_LOAD_N +#else +#define HWY_NATIVE_LOAD_N +#endif + +template > +HWY_API VFromD LoadN(D d, const T* HWY_RESTRICT p, + size_t max_lanes_to_load) { + return (max_lanes_to_load > 0) ? Load(d, p) : Zero(d); +} + +template > +HWY_API VFromD LoadNOr(VFromD no, D d, const T* HWY_RESTRICT p, + size_t max_lanes_to_load) { + return (max_lanes_to_load > 0) ? Load(d, p) : no; +} + +// ------------------------------ Store + +template > +HWY_API void Store(const Vec1 v, D /* tag */, T* HWY_RESTRICT aligned) { + CopySameSize(&v.raw, aligned); +} + +template > +HWY_API void StoreU(const Vec1 v, D d, T* HWY_RESTRICT p) { + return Store(v, d, p); +} + +template > +HWY_API void BlendedStore(const Vec1 v, Mask1 m, D d, T* HWY_RESTRICT p) { + if (!m.bits) return; + StoreU(v, d, p); +} + +#ifdef HWY_NATIVE_STORE_N +#undef HWY_NATIVE_STORE_N +#else +#define HWY_NATIVE_STORE_N +#endif + +template > +HWY_API void StoreN(VFromD v, D d, T* HWY_RESTRICT p, + size_t max_lanes_to_store) { + if (max_lanes_to_store > 0) { + Store(v, d, p); + } +} + +// ------------------------------ Tuples +#include "hwy/ops/inside-inl.h" + +// ------------------------------ LoadInterleaved2/3/4 + +// Per-target flag to prevent generic_ops-inl.h from defining StoreInterleaved2. 
+#ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#undef HWY_NATIVE_LOAD_STORE_INTERLEAVED +#else +#define HWY_NATIVE_LOAD_STORE_INTERLEAVED +#endif + +template > +HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, Vec1& v0, + Vec1& v1) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); +} + +template > +HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, Vec1& v0, + Vec1& v1, Vec1& v2) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); +} + +template > +HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, Vec1& v0, + Vec1& v1, Vec1& v2, Vec1& v3) { + v0 = LoadU(d, unaligned + 0); + v1 = LoadU(d, unaligned + 1); + v2 = LoadU(d, unaligned + 2); + v3 = LoadU(d, unaligned + 3); +} + +// ------------------------------ StoreInterleaved2/3/4 + +template > +HWY_API void StoreInterleaved2(const Vec1 v0, const Vec1 v1, D d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); +} + +template > +HWY_API void StoreInterleaved3(const Vec1 v0, const Vec1 v1, + const Vec1 v2, D d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); +} + +template > +HWY_API void StoreInterleaved4(const Vec1 v0, const Vec1 v1, + const Vec1 v2, const Vec1 v3, D d, + T* HWY_RESTRICT unaligned) { + StoreU(v0, d, unaligned + 0); + StoreU(v1, d, unaligned + 1); + StoreU(v2, d, unaligned + 2); + StoreU(v3, d, unaligned + 3); +} + +// ------------------------------ Stream + +template > +HWY_API void Stream(const Vec1 v, D d, T* HWY_RESTRICT aligned) { + return Store(v, d, aligned); +} + +// ------------------------------ Scatter + +#ifdef HWY_NATIVE_SCATTER +#undef HWY_NATIVE_SCATTER +#else +#define HWY_NATIVE_SCATTER +#endif + +template , typename TI> +HWY_API void ScatterOffset(Vec1 v, D d, T* base, Vec1 offset) { + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + 
const intptr_t addr = + reinterpret_cast(base) + static_cast(offset.raw); + Store(v, d, reinterpret_cast(addr)); +} + +template , typename TI> +HWY_API void ScatterIndex(Vec1 v, D d, T* HWY_RESTRICT base, + Vec1 index) { + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + Store(v, d, base + index.raw); +} + +template , typename TI> +HWY_API void MaskedScatterIndex(Vec1 v, Mask1 m, D d, + T* HWY_RESTRICT base, Vec1 index) { + static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match"); + if (m.bits) Store(v, d, base + index.raw); +} + +// ------------------------------ Gather + +#ifdef HWY_NATIVE_GATHER +#undef HWY_NATIVE_GATHER +#else +#define HWY_NATIVE_GATHER +#endif + +template > +HWY_API Vec1 GatherOffset(D d, const T* base, Vec1> offset) { + HWY_DASSERT(offset.raw >= 0); + const intptr_t addr = + reinterpret_cast(base) + static_cast(offset.raw); + return Load(d, reinterpret_cast(addr)); +} + +template > +HWY_API Vec1 GatherIndex(D d, const T* HWY_RESTRICT base, + Vec1> index) { + HWY_DASSERT(index.raw >= 0); + return Load(d, base + index.raw); +} + +template > +HWY_API Vec1 MaskedGatherIndex(Mask1 m, D d, const T* HWY_RESTRICT base, + Vec1> index) { + HWY_DASSERT(index.raw >= 0); + return MaskedLoad(m, d, base + index.raw); +} + +template > +HWY_API Vec1 MaskedGatherIndexOr(Vec1 no, Mask1 m, D d, + const T* HWY_RESTRICT base, + Vec1> index) { + HWY_DASSERT(index.raw >= 0); + return MaskedLoadOr(no, m, d, base + index.raw); +} + +// ================================================== CONVERT + +// ConvertTo and DemoteTo with floating-point input and integer output truncate +// (rounding toward zero). 
+ +namespace detail { + +template +HWY_INLINE ToT CastValueForF2IConv(FromT val) { + // Prevent ubsan errors when converting float to narrower integer + + using FromTU = MakeUnsigned; + using ToTU = MakeUnsigned; + + constexpr unsigned kMaxExpField = + static_cast(MaxExponentField()); + constexpr unsigned kExpBias = kMaxExpField >> 1; + constexpr unsigned kMinOutOfRangeExpField = static_cast(HWY_MIN( + kExpBias + sizeof(ToT) * 8 - static_cast(IsSigned()), + kMaxExpField)); + + // If ToT is signed, compare only the exponent bits of val against + // kMinOutOfRangeExpField. + // + // Otherwise, if ToT is unsigned, compare the sign bit plus exponent bits of + // val against kMinOutOfRangeExpField as a negative value is outside of the + // range of an unsigned integer type. + const FromT val_to_compare = + static_cast(IsSigned() ? ScalarAbs(val) : val); + + // val is within the range of ToT if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is less + // than kMinOutOfRangeExpField + // + // Otherwise, val is either outside of the range of ToT or equal to + // LimitsMin() if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is greater + // than or equal to kMinOutOfRangeExpField. + + return (static_cast(BitCastScalar(val_to_compare) >> + MantissaBits()) < kMinOutOfRangeExpField) + ? 
static_cast(val) + : static_cast(static_cast(LimitsMax()) + + static_cast(ScalarSignBit(val))); +} + +template +HWY_INLINE ToT CastValueForPromoteTo(ToTypeTag /* to_type_tag */, FromT val) { + return ConvertScalarTo(val); +} + +template +HWY_INLINE ToT CastValueForPromoteTo(hwy::SignedTag /*to_type_tag*/, + float val) { + return CastValueForF2IConv(val); +} + +template +HWY_INLINE ToT CastValueForPromoteTo(hwy::UnsignedTag /*to_type_tag*/, + float val) { + return CastValueForF2IConv(val); +} + +// If val is within the range of ToT, CastValueForInRangeF2IConv(val) +// returns static_cast(val) +// +// Otherwise, CastValueForInRangeF2IConv(val) returns an +// implementation-defined result if val is not within the range of ToT. +template +HWY_INLINE ToT CastValueForInRangeF2IConv(FromT val) { + // Prevent ubsan errors when converting float to narrower integer + + using FromTU = MakeUnsigned; + + constexpr unsigned kMaxExpField = + static_cast(MaxExponentField()); + constexpr unsigned kExpBias = kMaxExpField >> 1; + constexpr unsigned kMinOutOfRangeExpField = static_cast(HWY_MIN( + kExpBias + sizeof(ToT) * 8 - static_cast(IsSigned()), + kMaxExpField)); + + // If ToT is signed, compare only the exponent bits of val against + // kMinOutOfRangeExpField. + // + // Otherwise, if ToT is unsigned, compare the sign bit plus exponent bits of + // val against kMinOutOfRangeExpField as a negative value is outside of the + // range of an unsigned integer type. + const FromT val_to_compare = + static_cast(IsSigned() ? ScalarAbs(val) : val); + + // val is within the range of ToT if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is less + // than kMinOutOfRangeExpField + // + // Otherwise, val is either outside of the range of ToT or equal to + // LimitsMin() if + // (BitCastScalar(val_to_compare) >> MantissaBits()) is greater + // than or equal to kMinOutOfRangeExpField. + + return (static_cast(BitCastScalar(val_to_compare) >> + MantissaBits()) < kMinOutOfRangeExpField) + ? 
static_cast(val) + : static_cast(LimitsMin()); +} + +} // namespace detail + +#ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 +#undef HWY_NATIVE_PROMOTE_F16_TO_F64 +#else +#define HWY_NATIVE_PROMOTE_F16_TO_F64 +#endif + +template , typename TFrom> +HWY_API Vec1 PromoteTo(DTo /* tag */, Vec1 from) { + static_assert(sizeof(TTo) > sizeof(TFrom), "Not promoting"); + // For bits Y > X, floatX->floatY and intX->intY are always representable. + return Vec1( + detail::CastValueForPromoteTo(hwy::TypeTag(), from.raw)); +} + +#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD PromoteInRangeTo(DTo /* tag */, Vec1 from) { + using TTo = TFromD; + return Vec1(detail::CastValueForInRangeF2IConv(from.raw)); +} + +// MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(TFrom) is here, +// so we overload for TFrom=double and TTo={float,int32_t}. +template +HWY_API Vec1 DemoteTo(D /* tag */, Vec1 from) { + // Prevent ubsan errors when converting float to narrower integer/float + if (IsInf(from).bits || + Abs(from).raw > static_cast(HighestValue())) { + return Vec1(ScalarSignBit(from.raw) ? 
LowestValue() + : HighestValue()); + } + return Vec1(static_cast(from.raw)); +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec1 from) { + // Prevent ubsan errors when converting int32_t to narrower integer/int32_t + return Vec1>(detail::CastValueForF2IConv>(from.raw)); +} + +template , typename TFrom, + HWY_IF_SIGNED(TFrom), HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromD)> +HWY_API Vec1 DemoteTo(DTo /* tag */, Vec1 from) { + static_assert(!IsFloat(), "TFrom=double are handled above"); + static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); + + // Int to int: choose closest value in TTo to `from` (avoids UB) + from.raw = HWY_MIN(HWY_MAX(LimitsMin(), from.raw), LimitsMax()); + return Vec1(static_cast(from.raw)); +} + +// Disable the default unsigned to signed DemoteTo implementation in +// generic_ops-inl.h on SCALAR as the SCALAR target has a target-specific +// implementation of the unsigned to signed DemoteTo op and as ReorderDemote2To +// is not supported on the SCALAR target + +// NOTE: hwy::EnableIf()>* = nullptr is used instead of +// hwy::EnableIf* = nullptr to avoid compiler errors since +// !hwy::IsSame() is always false and as !hwy::IsSame() will cause +// SFINAE to occur instead of a hard error due to a dependency on the V template +// argument +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) \ + hwy::EnableIf()>* = nullptr + +template , typename TFrom, + HWY_IF_UNSIGNED(TFrom), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(DTo)> +HWY_API Vec1 DemoteTo(DTo /* tag */, Vec1 from) { + static_assert(!IsFloat(), "TFrom=double are handled above"); + static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); + + const auto max = static_cast>(LimitsMax()); + + // Int to int: choose closest value in TTo to `from` (avoids UB) + return Vec1(static_cast(HWY_MIN(from.raw, max))); +} + +template , typename TFrom, + HWY_IF_UI64(TFrom), HWY_IF_F32_D(DTo)> +HWY_API Vec1 DemoteTo(DTo /* tag */, Vec1 from) { + // int64_t/uint64_t to float: simply 
cast to TTo + return Vec1(static_cast(from.raw)); +} + +#ifdef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD DemoteInRangeTo(D32 /*d32*/, + VFromD> v) { + using TTo = TFromD; + return Vec1(detail::CastValueForInRangeF2IConv(v.raw)); +} + +// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions; +// use this scalar version to verify the vector implementation. +#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +template +HWY_API Vec1 PromoteTo(D /* tag */, const Vec1 v) { + return Vec1(F32FromF16(v.raw)); +} + +template +HWY_API Vec1 PromoteTo(D d, const Vec1 v) { + return Set(d, F32FromBF16(v.raw)); +} + +template +HWY_API VFromD PromoteEvenTo(DTo d_to, Vec1 v) { + return PromoteTo(d_to, v); +} + +template +HWY_API Vec1 DemoteTo(D /* tag */, const Vec1 v) { + return Vec1(F16FromF32(v.raw)); +} + +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +template +HWY_API Vec1 DemoteTo(D d, const Vec1 v) { + return Set(d, BF16FromF32(v.raw)); +} + +template , typename TFrom, + HWY_IF_FLOAT(TFrom)> +HWY_API Vec1 ConvertTo(DTo /* tag */, Vec1 from) { + static_assert(sizeof(TTo) == sizeof(TFrom), "Should have same size"); + // float## -> int##: return closest representable value. 
+ return Vec1(detail::CastValueForF2IConv(from.raw)); +} + +template , typename TFrom, + HWY_IF_NOT_FLOAT(TFrom)> +HWY_API Vec1 ConvertTo(DTo /* tag */, Vec1 from) { + static_assert(sizeof(TTo) == sizeof(TFrom), "Should have same size"); + // int## -> float##: no check needed + return Vec1(static_cast(from.raw)); +} + +#ifdef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#undef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#else +#define HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#endif + +template +HWY_API VFromD ConvertInRangeTo(DI /*di*/, VFromD> v) { + using TTo = TFromD; + return VFromD(detail::CastValueForInRangeF2IConv(v.raw)); +} + +HWY_API Vec1 U8FromU32(const Vec1 v) { + return DemoteTo(Sisd(), v); +} + +// ------------------------------ TruncateTo + +template +HWY_API Vec1 TruncateTo(D /* tag */, Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +template +HWY_API Vec1 TruncateTo(D /* tag */, Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFF)}; +} + +template +HWY_API Vec1 TruncateTo(D /* tag */, Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFFFFFFu)}; +} + +template +HWY_API Vec1 TruncateTo(D /* tag */, Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +template +HWY_API Vec1 TruncateTo(D /* tag */, Vec1 v) { + return Vec1{static_cast(v.raw & 0xFFFF)}; +} + +template +HWY_API Vec1 TruncateTo(D /* tag */, Vec1 v) { + return Vec1{static_cast(v.raw & 0xFF)}; +} + +// ================================================== COMBINE +// UpperHalf, ZeroExtendVector, Combine, Concat* are unsupported. 
+ +template +HWY_API Vec1 LowerHalf(Vec1 v) { + return v; +} + +template > +HWY_API Vec1 LowerHalf(D /* tag */, Vec1 v) { + return v; +} + +// ================================================== SWIZZLE + +template +HWY_API T GetLane(const Vec1 v) { + return v.raw; +} + +template +HWY_API T ExtractLane(const Vec1 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return v.raw; +} + +template +HWY_API Vec1 InsertLane(Vec1 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + v.raw = t; + return v; +} + +template +HWY_API Vec1 DupEven(Vec1 v) { + return v; +} +// DupOdd is unsupported. + +template +HWY_API Vec1 OddEven(Vec1 /* odd */, Vec1 even) { + return even; +} + +template +HWY_API Vec1 OddEvenBlocks(Vec1 /* odd */, Vec1 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec1 SwapAdjacentBlocks(Vec1 v) { + return v; +} + +// ------------------------------ InterleaveEvenBlocks +template > +HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveOddBlocks +template > +HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ InterleaveLowerBlocks +template > +HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveUpperBlocks +template > +HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. 
+template +struct Indices1 { + MakeSigned raw; +}; + +template , typename TI> +HWY_API Indices1 IndicesFromVec(D, Vec1 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane size"); + HWY_DASSERT(vec.raw <= 1); + return Indices1{static_cast>(vec.raw)}; +} + +template , typename TI> +HWY_API Indices1 SetTableIndices(D d, const TI* idx) { + return IndicesFromVec(d, LoadU(Sisd(), idx)); +} + +template +HWY_API Vec1 TableLookupLanes(const Vec1 v, const Indices1 /* idx */) { + return v; +} + +template +HWY_API Vec1 TwoTablesLookupLanes(const Vec1 a, const Vec1 b, + const Indices1 idx) { + return (idx.raw == 0) ? a : b; +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template > +HWY_API Vec1 ReverseBlocks(D /* tag */, const Vec1 v) { + return v; +} + +// ------------------------------ Reverse + +template > +HWY_API Vec1 Reverse(D /* tag */, const Vec1 v) { + return v; +} + +// Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. 
+#ifdef HWY_NATIVE_REVERSE2_8 +#undef HWY_NATIVE_REVERSE2_8 +#else +#define HWY_NATIVE_REVERSE2_8 +#endif + +// Must not be called: +template > +HWY_API Vec1 Reverse2(D /* tag */, const Vec1 v) { + return v; +} + +template > +HWY_API Vec1 Reverse4(D /* tag */, const Vec1 v) { + return v; +} + +template > +HWY_API Vec1 Reverse8(D /* tag */, const Vec1 v) { + return v; +} + +// ------------------------------ ReverseLaneBytes + +#ifdef HWY_NATIVE_REVERSE_LANE_BYTES +#undef HWY_NATIVE_REVERSE_LANE_BYTES +#else +#define HWY_NATIVE_REVERSE_LANE_BYTES +#endif + +HWY_API Vec1 ReverseLaneBytes(Vec1 v) { + const uint32_t val{v.raw}; + return Vec1( + static_cast(((val << 8) & 0xFF00u) | ((val >> 8) & 0x00FFu))); +} + +HWY_API Vec1 ReverseLaneBytes(Vec1 v) { + const uint32_t val = v.raw; + return Vec1(static_cast( + ((val << 24) & 0xFF000000u) | ((val << 8) & 0x00FF0000u) | + ((val >> 8) & 0x0000FF00u) | ((val >> 24) & 0x000000FFu))); +} + +HWY_API Vec1 ReverseLaneBytes(Vec1 v) { + const uint64_t val = v.raw; + return Vec1(static_cast( + ((val << 56) & 0xFF00000000000000u) | + ((val << 40) & 0x00FF000000000000u) | + ((val << 24) & 0x0000FF0000000000u) | ((val << 8) & 0x000000FF00000000u) | + ((val >> 8) & 0x00000000FF000000u) | ((val >> 24) & 0x0000000000FF0000u) | + ((val >> 40) & 0x000000000000FF00u) | + ((val >> 56) & 0x00000000000000FFu))); +} + +template +HWY_API V ReverseLaneBytes(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, ReverseLaneBytes(BitCast(du, v))); +} + +// ------------------------------ ReverseBits +#ifdef HWY_NATIVE_REVERSE_BITS_UI8 +#undef HWY_NATIVE_REVERSE_BITS_UI8 +#else +#define HWY_NATIVE_REVERSE_BITS_UI8 +#endif + +#ifdef HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#undef HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#else +#define HWY_NATIVE_REVERSE_BITS_UI16_32_64 +#endif + +namespace detail { + +template +HWY_INLINE T ReverseBitsOfEachByte(T val) { + using TU = MakeUnsigned; + constexpr TU kMaxUnsignedVal{LimitsMax()}; + constexpr TU 
kShrMask1 = + static_cast(0x5555555555555555u & kMaxUnsignedVal); + constexpr TU kShrMask2 = + static_cast(0x3333333333333333u & kMaxUnsignedVal); + constexpr TU kShrMask3 = + static_cast(0x0F0F0F0F0F0F0F0Fu & kMaxUnsignedVal); + + constexpr TU kShlMask1 = static_cast(~kShrMask1); + constexpr TU kShlMask2 = static_cast(~kShrMask2); + constexpr TU kShlMask3 = static_cast(~kShrMask3); + + TU result = static_cast(val); + result = static_cast(((result << 1) & kShlMask1) | + ((result >> 1) & kShrMask1)); + result = static_cast(((result << 2) & kShlMask2) | + ((result >> 2) & kShrMask2)); + result = static_cast(((result << 4) & kShlMask3) | + ((result >> 4) & kShrMask3)); + return static_cast(result); +} + +} // namespace detail + +template +HWY_API V ReverseBits(V v) { + return V(detail::ReverseBitsOfEachByte(v.raw)); +} + +template +HWY_API V ReverseBits(V v) { + return ReverseLaneBytes(V(detail::ReverseBitsOfEachByte(v.raw))); +} + +template +HWY_API V ReverseBits(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, ReverseBits(BitCast(du, v))); +} + +// ------------------------------ SlideUpLanes + +template +HWY_API VFromD SlideUpLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +// ------------------------------ SlideDownLanes + +template +HWY_API VFromD SlideDownLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +// ================================================== BLOCKWISE +// Shift*Bytes, CombineShiftRightBytes, Interleave*, Shuffle* are unsupported. 
+ +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec1 Broadcast(const Vec1 v) { + static_assert(kLane == 0, "Scalar only has one lane"); + return v; +} + +// ------------------------------ TableLookupBytes, TableLookupBytesOr0 + +template +HWY_API Vec1 TableLookupBytes(const Vec1 in, const Vec1 indices) { + uint8_t in_bytes[sizeof(T)]; + uint8_t idx_bytes[sizeof(T)]; + uint8_t out_bytes[sizeof(T)]; + CopyBytes(&in, &in_bytes); // copy to bytes + CopyBytes(&indices, &idx_bytes); + for (size_t i = 0; i < sizeof(T); ++i) { + out_bytes[i] = in_bytes[idx_bytes[i]]; + } + TI out; + CopyBytes(&out_bytes, &out); + return Vec1{out}; +} + +template +HWY_API Vec1 TableLookupBytesOr0(const Vec1 in, const Vec1 indices) { + uint8_t in_bytes[sizeof(T)]; + uint8_t idx_bytes[sizeof(T)]; + uint8_t out_bytes[sizeof(T)]; + CopyBytes(&in, &in_bytes); // copy to bytes + CopyBytes(&indices, &idx_bytes); + for (size_t i = 0; i < sizeof(T); ++i) { + out_bytes[i] = idx_bytes[i] & 0x80 ? 
0 : in_bytes[idx_bytes[i]]; + } + TI out; + CopyBytes(&out_bytes, &out); + return Vec1{out}; +} + +// ------------------------------ ZipLower + +HWY_API Vec1 ZipLower(Vec1 a, Vec1 b) { + return Vec1(static_cast((uint32_t{b.raw} << 8) + a.raw)); +} +HWY_API Vec1 ZipLower(Vec1 a, Vec1 b) { + return Vec1((uint32_t{b.raw} << 16) + a.raw); +} +HWY_API Vec1 ZipLower(Vec1 a, Vec1 b) { + return Vec1((uint64_t{b.raw} << 32) + a.raw); +} +HWY_API Vec1 ZipLower(Vec1 a, Vec1 b) { + return Vec1(static_cast((int32_t{b.raw} << 8) + a.raw)); +} +HWY_API Vec1 ZipLower(Vec1 a, Vec1 b) { + return Vec1((int32_t{b.raw} << 16) + a.raw); +} +HWY_API Vec1 ZipLower(Vec1 a, Vec1 b) { + return Vec1((int64_t{b.raw} << 32) + a.raw); +} + +template , typename TN = MakeNarrow> +HWY_API Vec1 ZipLower(DW /* tag */, Vec1 a, Vec1 b) { + return Vec1(static_cast((TW{b.raw} << (sizeof(TN) * 8)) + a.raw)); +} + +// ================================================== MASK + +template > +HWY_API bool AllFalse(D /* tag */, const Mask1 mask) { + return mask.bits == 0; +} + +template > +HWY_API bool AllTrue(D /* tag */, const Mask1 mask) { + return mask.bits != 0; +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template > +HWY_API Mask1 LoadMaskBits(D /* tag */, const uint8_t* HWY_RESTRICT bits) { + return Mask1::FromBool((bits[0] & 1) != 0); +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D /*d*/, unsigned mask_bits) { + return MFromD::FromBool((mask_bits & 1) != 0); +} + +// `p` points to at least 8 writable bytes. +template > +HWY_API size_t StoreMaskBits(D d, const Mask1 mask, uint8_t* bits) { + *bits = AllTrue(d, mask); + return 1; +} + +template > +HWY_API size_t CountTrue(D /* tag */, const Mask1 mask) { + return mask.bits == 0 ? 0 : 1; +} + +template > +HWY_API intptr_t FindFirstTrue(D /* tag */, const Mask1 mask) { + return mask.bits == 0 ? 
-1 : 0; +} + +template > +HWY_API size_t FindKnownFirstTrue(D /* tag */, const Mask1 /* m */) { + return 0; // There is only one lane and we know it is true. +} + +template > +HWY_API intptr_t FindLastTrue(D /* tag */, const Mask1 mask) { + return mask.bits == 0 ? -1 : 0; +} + +template > +HWY_API size_t FindKnownLastTrue(D /* tag */, const Mask1 /* m */) { + return 0; // There is only one lane and we know it is true. +} + +// ------------------------------ Compress, CompressBits + +template +struct CompressIsPartition { + enum { value = 1 }; +}; + +template +HWY_API Vec1 Compress(Vec1 v, const Mask1 /* mask */) { + // A single lane is already partitioned by definition. + return v; +} + +template +HWY_API Vec1 CompressNot(Vec1 v, const Mask1 /* mask */) { + // A single lane is already partitioned by definition. + return v; +} + +// ------------------------------ CompressStore +template > +HWY_API size_t CompressStore(Vec1 v, const Mask1 mask, D d, + T* HWY_RESTRICT unaligned) { + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ CompressBlendedStore +template > +HWY_API size_t CompressBlendedStore(Vec1 v, const Mask1 mask, D d, + T* HWY_RESTRICT unaligned) { + if (!mask.bits) return 0; + StoreU(v, d, unaligned); + return 1; +} + +// ------------------------------ CompressBits +template +HWY_API Vec1 CompressBits(Vec1 v, const uint8_t* HWY_RESTRICT /*bits*/) { + return v; +} + +// ------------------------------ CompressBitsStore +template > +HWY_API size_t CompressBitsStore(Vec1 v, const uint8_t* HWY_RESTRICT bits, + D d, T* HWY_RESTRICT unaligned) { + const Mask1 mask = LoadMaskBits(d, bits); + StoreU(Compress(v, mask), d, unaligned); + return CountTrue(d, mask); +} + +// ------------------------------ Expand + +// generic_ops-inl.h requires Vec64/128, so implement [Load]Expand here. 
+#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +template +HWY_API Vec1 Expand(Vec1 v, const Mask1 mask) { + return IfThenElseZero(mask, v); +} + +// ------------------------------ LoadExpand +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return MaskedLoad(mask, d, unaligned); +} + +// ------------------------------ WidenMulPairwiseAdd + +template +HWY_API Vec1 WidenMulPairwiseAdd(D32 /* tag */, Vec1 a, + Vec1 b) { + return Vec1(F32FromBF16(a.raw)) * Vec1(F32FromBF16(b.raw)); +} + +template +HWY_API Vec1 WidenMulPairwiseAdd(D32 /* tag */, Vec1 a, + Vec1 b) { + return Vec1(a.raw * b.raw); +} + +// ------------------------------ SatWidenMulAccumFixedPoint +#ifdef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#undef HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#else +#define HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT +#endif + +template +HWY_API VFromD SatWidenMulAccumFixedPoint(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + // Multiplying static_cast(a.raw) by static_cast(b.raw) + // followed by an addition of the product is okay as + // (a.raw * b.raw * 2) is between -2147418112 and 2147483648 and as + // a.raw * b.raw * 2 can only overflow an int32_t if both a.raw and b.raw are + // equal to -32768. 
+ + const VFromD product(static_cast(a.raw) * + static_cast(b.raw)); + const VFromD product2 = Add(product, product); + + const auto mul_overflow = + VecFromMask(di32, Eq(product2, Set(di32, LimitsMin()))); + + return SaturatedAdd(Sub(sum, And(BroadcastSignBit(sum), mul_overflow)), + Add(product2, mul_overflow)); +} + +// ------------------------------ SatWidenMulPairwiseAdd + +#ifdef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#undef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#else +#define HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#endif + +template +HWY_API Vec1 SatWidenMulPairwiseAdd(DI16 /* tag */, Vec1 a, + Vec1 b) { + // Saturation of a.raw * b.raw is not needed on the HWY_SCALAR target as the + // input vectors only have 1 lane on the HWY_SCALAR target and as + // a.raw * b.raw is between -32640 and 32385, which is already within the + // range of an int16_t. + + // On other targets, a saturated addition of a[0]*b[0] + a[1]*b[1] is needed + // as it is possible for the addition of a[0]*b[0] + a[1]*b[1] to overflow if + // a[0], a[1], b[0], and b[1] are all non-zero and b[0] and b[1] both have the + // same sign. 
+ + return Vec1(static_cast(a.raw) * + static_cast(b.raw)); +} + +// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) + +#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#else +#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#endif + +template +HWY_API Vec1 ReorderWidenMulAccumulate(D32 /* tag */, Vec1 a, + Vec1 b, + const Vec1 sum0, + Vec1& /* sum1 */) { + return MulAdd(Vec1(F32FromBF16(a.raw)), + Vec1(F32FromBF16(b.raw)), sum0); +} + +template +HWY_API Vec1 ReorderWidenMulAccumulate(D32 /* tag */, Vec1 a, + Vec1 b, + const Vec1 sum0, + Vec1& /* sum1 */) { + return Vec1(a.raw * b.raw + sum0.raw); +} + +template +HWY_API Vec1 ReorderWidenMulAccumulate(DU32 /* tag */, + Vec1 a, + Vec1 b, + const Vec1 sum0, + Vec1& /* sum1 */) { + return Vec1(static_cast(a.raw) * b.raw + sum0.raw); +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API Vec1 RearrangeToOddPlusEven(Vec1 sum0, Vec1 /* sum1 */) { + return sum0; // invariant already holds +} + +// ================================================== REDUCTIONS + +// Nothing native, generic_ops-inl defines SumOfLanes and ReduceSum. + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/lib/highway/hwy/ops/set_macros-inl.h b/lib/highway/hwy/ops/set_macros-inl.h new file mode 100644 index 00000000000..4fe0e9d35f6 --- /dev/null +++ b/lib/highway/hwy/ops/set_macros-inl.h @@ -0,0 +1,898 @@ +// Copyright 2020 Google LLC +// Copyright 2024-2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// SPDX-License-Identifier: BSD-3-Clause +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Sets macros based on HWY_TARGET. + +// This include guard is toggled by foreach_target, so avoid the usual _H_ +// suffix to prevent copybara from renaming it. +#if defined(HWY_SET_MACROS_PER_TARGET) == defined(HWY_TARGET_TOGGLE) +#ifdef HWY_SET_MACROS_PER_TARGET +#undef HWY_SET_MACROS_PER_TARGET +#else +#define HWY_SET_MACROS_PER_TARGET +#endif + +#endif // HWY_SET_MACROS_PER_TARGET + +#include "hwy/detect_compiler_arch.h" // IWYU: export +#include "hwy/detect_targets.h" // IWYU: export + +#undef HWY_NAMESPACE +#undef HWY_ALIGN +#undef HWY_MAX_BYTES +#undef HWY_MIN_BYTES + +#undef HWY_HAVE_SCALABLE +#undef HWY_HAVE_TUPLE +#undef HWY_REGISTERS +#undef HWY_HAVE_INTEGER64 +#undef HWY_HAVE_FLOAT16 +#undef HWY_HAVE_FLOAT64 +#undef HWY_MEM_OPS_MIGHT_FAULT +#undef HWY_NATIVE_FMA +#undef HWY_NATIVE_DOT_BF16 +#undef HWY_NATIVE_MASK +#undef HWY_NATIVE_INTERLEAVE_WHOLE + +#ifndef HWY_CAP_GE256 +#define HWY_CAP_GE256 (HWY_MIN_BYTES >= 32) +#endif +#ifndef HWY_CAP_GE512 +#define HWY_CAP_GE512 (HWY_MIN_BYTES >= 64) +#endif + +// Almost all targets (except RVV and SCALAR) use this definition. +#undef HWY_LANES +#define HWY_LANES(T) (HWY_MAX_BYTES / sizeof(T)) + +// If 1, both __bf16 and a limited set of *_bf16 SVE intrinsics are available: +// create/get/set/dup, ld/st, sel, rev, trn, uzp, zip. +// Consulted below, hence define here rather than in arm_sve-inl.h. 
+#if HWY_ARM_HAVE_SCALAR_BF16_TYPE && defined(__ARM_FEATURE_SVE_BF16) +#define HWY_SVE_HAVE_BF16_FEATURE 1 +#else +#define HWY_SVE_HAVE_BF16_FEATURE 0 +#endif + +#undef HWY_TARGET_IS_SVE +#if HWY_TARGET & HWY_ALL_SVE +#define HWY_TARGET_IS_SVE 1 +#else +#define HWY_TARGET_IS_SVE 0 +#endif + +#undef HWY_TARGET_IS_NEON +#if HWY_TARGET & HWY_ALL_NEON +#define HWY_TARGET_IS_NEON 1 +#else +#define HWY_TARGET_IS_NEON 0 +#endif + +#undef HWY_TARGET_IS_PPC +#if HWY_TARGET & HWY_ALL_PPC +#define HWY_TARGET_IS_PPC 1 +#else +#define HWY_TARGET_IS_PPC 0 +#endif + +#undef HWY_TARGET_IS_AVX10_2 +#if HWY_TARGET == HWY_AVX10_2 +#define HWY_TARGET_IS_AVX10_2 1 +#else +#define HWY_TARGET_IS_AVX10_2 0 +#endif + +// Supported on all targets except RVV (requires GCC 14 or upcoming Clang) +#if HWY_TARGET == HWY_RVV && \ + ((HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1700)) +#define HWY_HAVE_TUPLE 0 +#else +#define HWY_HAVE_TUPLE 1 +#endif + +// Target-specific number of architectural vector registers available. +#if !HWY_ARCH_X86 || (HWY_TARGET <= HWY_AVX3) +#define HWY_REGISTERS 32 +#else +#define HWY_REGISTERS 16 +#endif + +// For internal use (clamping/validating N for Simd<>) +#undef HWY_MAX_N +#if HWY_TARGET == HWY_SCALAR +#define HWY_MAX_N 1 +#else +#define HWY_MAX_N 65536 +#endif + +// For internal use (clamping kPow2 for Simd<>) +#undef HWY_MAX_POW2 +// For HWY_TARGET == HWY_RVV, LMUL <= 8. Even on other targets, we want to +// support say Rebind> d; whose kPow2 is also 3. +// However, those other targets do not actually support multiple vectors, and +// thus Lanes(d) must not exceed Lanes(ScalableTag()). +#define HWY_MAX_POW2 3 + +// User-visible. Loose lower bound that guarantees HWY_MAX_BYTES >> +// (-HWY_MIN_POW2) <= 1. Useful for terminating compile-time recursions. 
+#undef HWY_MIN_POW2 +#if HWY_TARGET == HWY_RVV +#define HWY_MIN_POW2 -16 +#else +// Tighter bound for other targets, whose vectors are smaller, to potentially +// save compile time. +#define HWY_MIN_POW2 -8 +#endif // HWY_TARGET == HWY_RVV + +#undef HWY_TARGET_STR + +#if defined(HWY_DISABLE_PCLMUL_AES) +#define HWY_TARGET_STR_PCLMUL_AES "" +#else +#define HWY_TARGET_STR_PCLMUL_AES ",pclmul,aes" +#endif + +#if defined(HWY_DISABLE_BMI2_FMA) +#define HWY_TARGET_STR_BMI2_FMA "" +#else +#define HWY_TARGET_STR_BMI2_FMA ",bmi,bmi2,fma" +#endif + +#if defined(HWY_DISABLE_F16C) +#define HWY_TARGET_STR_F16C "" +#else +#define HWY_TARGET_STR_F16C ",f16c" +#endif + +#define HWY_TARGET_STR_SSE2 "sse2" + +#define HWY_TARGET_STR_SSSE3 "sse2,ssse3" + +#define HWY_TARGET_STR_SSE4 \ + HWY_TARGET_STR_SSSE3 ",sse4.1,sse4.2" HWY_TARGET_STR_PCLMUL_AES +// Include previous targets, which are the half-vectors of the next target. +#define HWY_TARGET_STR_AVX2 \ + HWY_TARGET_STR_SSE4 ",avx,avx2" HWY_TARGET_STR_BMI2_FMA HWY_TARGET_STR_F16C + +#ifndef HWY_HAVE_EVEX512 // allow override +// evex512 has been removed from clang 22, see +// https://github.com/llvm/llvm-project/pull/157034 +#if (1400 <= HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1600) || \ + (1800 <= HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2200) +#define HWY_HAVE_EVEX512 1 +#else +#define HWY_HAVE_EVEX512 0 +#endif +#endif + +#if (HWY_HAVE_EVEX512 == 1) +#define HWY_TARGET_STR_AVX3_VL512 ",evex512" +#else +#define HWY_TARGET_STR_AVX3_VL512 +#endif + +#define HWY_TARGET_STR_AVX3 \ + HWY_TARGET_STR_AVX2 \ + ",avx512f,avx512cd,avx512vl,avx512dq,avx512bw" HWY_TARGET_STR_AVX3_VL512 + +#define HWY_TARGET_STR_AVX3_DL \ + HWY_TARGET_STR_AVX3 \ + ",vpclmulqdq,avx512vbmi,avx512vbmi2,vaes,avx512vnni,avx512bitalg," \ + "avx512vpopcntdq,gfni" + +// Opt-out for compilers that do not properly support avx512bf16. 
+#ifndef HWY_AVX3_ENABLE_AVX512BF16 // allow override +// Default is to disable if the DISABLE macro is defined, or if old compiler. +// clang-cl 21.1.4 reportedly works; feel free to define this to 1 there. +#if defined(HWY_AVX3_DISABLE_AVX512BF16) || \ + (HWY_COMPILER_CLANGCL || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 900)) +#define HWY_AVX3_ENABLE_AVX512BF16 0 +#else +#define HWY_AVX3_ENABLE_AVX512BF16 1 +#endif +#endif // HWY_AVX3_ENABLE_AVX512BF16 + +#if HWY_AVX3_ENABLE_AVX512BF16 +#define HWY_TARGET_STR_AVX3_ZEN4 HWY_TARGET_STR_AVX3_DL ",avx512bf16" +#else +#define HWY_TARGET_STR_AVX3_ZEN4 HWY_TARGET_STR_AVX3_DL +#endif + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 || HWY_COMPILER_CLANG >= 1400 +#define HWY_TARGET_STR_AVX3_SPR HWY_TARGET_STR_AVX3_ZEN4 ",avx512fp16" +#else +#define HWY_TARGET_STR_AVX3_SPR HWY_TARGET_STR_AVX3_ZEN4 +#endif + +// Support for avx10.2-512 was removed between clang 22 and 23 without a +// feature test macro. +#if HWY_COMPILER_CLANG >= 2200 && HWY_HAVE_EVEX512 +#define HWY_TARGET_STR_AVX10_2 HWY_TARGET_STR_AVX3_SPR ",avx10.2-512" +// Recent compilers drop the -512 suffix because 512 bits are always available. +#elif HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2200 +#define HWY_TARGET_STR_AVX10_2 HWY_TARGET_STR_AVX3_SPR ",avx10.2" +#else +#define HWY_TARGET_STR_AVX10_2 HWY_TARGET_STR_AVX3_SPR +#endif + +#if defined(HWY_DISABLE_PPC8_CRYPTO) +#define HWY_TARGET_STR_PPC8_CRYPTO "" +#else +#define HWY_TARGET_STR_PPC8_CRYPTO ",crypto" +#endif + +#define HWY_TARGET_STR_PPC8 \ + "altivec,vsx,power8-vector" HWY_TARGET_STR_PPC8_CRYPTO +#define HWY_TARGET_STR_PPC9 HWY_TARGET_STR_PPC8 ",power9-vector" + +#if HWY_COMPILER_CLANG +#define HWY_TARGET_STR_PPC10 HWY_TARGET_STR_PPC9 ",power10-vector" +#else +// See #1707 and https://gcc.gnu.org/bugzilla/show_bug.cgi?id=102059#c35. 
+// When the baseline is PPC 8 or 9, inlining functions such as PreventElision +// into PPC10 code fails because PPC10 defaults to no-htm and is thus worse than +// the baseline, which has htm. We cannot have pragma target on functions +// outside HWY_NAMESPACE such as those in base.h. It would be possible for users +// to set -mno-htm globally, but we can also work around this at the library +// level by claiming that PPC10 still has HTM, thus avoiding the mismatch. This +// seems to be safe because HTM uses builtins rather than modifying codegen, see +// https://gcc.gnu.org/legacy-ml/gcc-patches/2013-07/msg00167.html. +#define HWY_TARGET_STR_PPC10 HWY_TARGET_STR_PPC9 ",cpu=power10,htm" +#endif + +#define HWY_TARGET_STR_Z14 "arch=z14" +#define HWY_TARGET_STR_Z15 "arch=z15" + +// Before include guard so we redefine HWY_TARGET_STR on each include, +// governed by the current HWY_TARGET. + +//----------------------------------------------------------------------------- +// SSE2 +#if HWY_TARGET == HWY_SSE2 + +#define HWY_NAMESPACE N_SSE2 +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_MIN_BYTES 16 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_NATIVE_MASK 0 // a few actually are + +#define HWY_TARGET_STR HWY_TARGET_STR_SSE2 +//----------------------------------------------------------------------------- +// SSSE3 +#elif HWY_TARGET == HWY_SSSE3 + +#define HWY_NAMESPACE N_SSSE3 +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_MIN_BYTES 16 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_NATIVE_MASK 0 // a few actually are + +#define HWY_TARGET_STR HWY_TARGET_STR_SSSE3 + 
+//----------------------------------------------------------------------------- +// SSE4 +#elif HWY_TARGET == HWY_SSE4 + +#define HWY_NAMESPACE N_SSE4 +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_MIN_BYTES 16 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_NATIVE_MASK 0 // a few actually are + +#define HWY_TARGET_STR HWY_TARGET_STR_SSE4 + +//----------------------------------------------------------------------------- +// AVX2 +#elif HWY_TARGET == HWY_AVX2 + +#define HWY_NAMESPACE N_AVX2 +#define HWY_ALIGN alignas(32) +#define HWY_MAX_BYTES 32 +#define HWY_MIN_BYTES 32 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 + +#ifdef HWY_DISABLE_BMI2_FMA +#define HWY_NATIVE_FMA 0 +#else +#define HWY_NATIVE_FMA 1 +#endif +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_NATIVE_MASK 0 // a few actually are + +#define HWY_TARGET_STR HWY_TARGET_STR_AVX2 + +//----------------------------------------------------------------------------- +// AVX3[_DL/ZEN4/SPR]/AVX10 +#elif HWY_TARGET <= HWY_AVX3 + +#define HWY_ALIGN alignas(64) +#define HWY_MAX_BYTES 64 +#define HWY_MIN_BYTES 64 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#if HWY_TARGET <= HWY_AVX3_SPR && \ + (HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 2200) && \ + HWY_HAVE_SCALAR_F16_TYPE +#define HWY_HAVE_FLOAT16 1 +#else +#define HWY_HAVE_FLOAT16 0 +#endif +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#if (HWY_TARGET <= HWY_AVX3_ZEN4) && HWY_AVX3_ENABLE_AVX512BF16 +#define HWY_NATIVE_DOT_BF16 1 +#else +#define HWY_NATIVE_DOT_BF16 0 +#endif +#define HWY_NATIVE_MASK 1 + +#if HWY_TARGET == HWY_AVX3 + +#define HWY_NAMESPACE N_AVX3 +#define HWY_TARGET_STR 
HWY_TARGET_STR_AVX3 + +#elif HWY_TARGET == HWY_AVX3_DL + +#define HWY_NAMESPACE N_AVX3_DL +#define HWY_TARGET_STR HWY_TARGET_STR_AVX3_DL + +#elif HWY_TARGET == HWY_AVX3_ZEN4 + +#define HWY_NAMESPACE N_AVX3_ZEN4 +#define HWY_TARGET_STR HWY_TARGET_STR_AVX3_ZEN4 + +#elif HWY_TARGET == HWY_AVX3_SPR + +#define HWY_NAMESPACE N_AVX3_SPR +#define HWY_TARGET_STR HWY_TARGET_STR_AVX3_SPR + +#elif HWY_TARGET == HWY_AVX10_2 + +#define HWY_NAMESPACE N_AVX10_2 +#define HWY_TARGET_STR HWY_TARGET_STR_AVX10_2 + +#else +#error "Logic error" +#endif // HWY_TARGET + +//----------------------------------------------------------------------------- +// PPC8, PPC9, PPC10 +#elif HWY_TARGET_IS_PPC + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_MIN_BYTES 16 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 1 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_NATIVE_MASK 0 + +#if HWY_TARGET == HWY_PPC8 + +#define HWY_NAMESPACE N_PPC8 +#define HWY_TARGET_STR HWY_TARGET_STR_PPC8 + +#elif HWY_TARGET == HWY_PPC9 + +#define HWY_NAMESPACE N_PPC9 +#define HWY_TARGET_STR HWY_TARGET_STR_PPC9 + +#elif HWY_TARGET == HWY_PPC10 + +#define HWY_NAMESPACE N_PPC10 +#define HWY_TARGET_STR HWY_TARGET_STR_PPC10 + +#else +#error "Logic error" +#endif // HWY_TARGET + +//----------------------------------------------------------------------------- +// Z14, Z15 +#elif HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15 + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_MIN_BYTES 16 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 1 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_NATIVE_MASK 0 + +#if HWY_TARGET == HWY_Z14 + +#define HWY_NAMESPACE N_Z14 +#define HWY_TARGET_STR HWY_TARGET_STR_Z14 + +#elif HWY_TARGET == HWY_Z15 + +#define 
HWY_NAMESPACE N_Z15 +#define HWY_TARGET_STR HWY_TARGET_STR_Z15 + +#else +#error "Logic error" +#endif // HWY_TARGET == HWY_Z15 + +//----------------------------------------------------------------------------- +// NEON +#elif HWY_TARGET_IS_NEON + +// Clang 17 crashes with bf16, see github.com/llvm/llvm-project/issues/64179. +#undef HWY_NEON_HAVE_BFLOAT16 +#if HWY_HAVE_SCALAR_BF16_TYPE && \ + ((HWY_TARGET == HWY_NEON_BF16 && \ + (!HWY_COMPILER_CLANG || HWY_COMPILER_CLANG >= 1800)) || \ + defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)) +#define HWY_NEON_HAVE_BFLOAT16 1 +#else +#define HWY_NEON_HAVE_BFLOAT16 0 +#endif + +// HWY_NEON_HAVE_F32_TO_BF16C is defined if NEON vcvt_bf16_f32 and +// vbfdot_f32 are available, even if the __bf16 type is disabled due to +// GCC/Clang bugs. +#undef HWY_NEON_HAVE_F32_TO_BF16C +#if HWY_NEON_HAVE_BFLOAT16 || HWY_TARGET == HWY_NEON_BF16 || \ + (defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) && \ + (HWY_COMPILER_GCC_ACTUAL >= 1000 || HWY_COMPILER_CLANG >= 1100)) +#define HWY_NEON_HAVE_F32_TO_BF16C 1 +#else +#define HWY_NEON_HAVE_F32_TO_BF16C 0 +#endif + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_MIN_BYTES 16 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || HWY_TARGET == HWY_NEON_BF16 +#define HWY_HAVE_FLOAT16 1 +#else +#define HWY_HAVE_FLOAT16 0 +#endif + +#if HWY_ARCH_ARM_A64 +#define HWY_HAVE_FLOAT64 1 +#else +#define HWY_HAVE_FLOAT64 0 +#endif + +#define HWY_MEM_OPS_MIGHT_FAULT 1 + +#if defined(__ARM_FEATURE_FMA) || defined(__ARM_VFPV4__) || HWY_ARCH_ARM_A64 +#define HWY_NATIVE_FMA 1 +#else +#define HWY_NATIVE_FMA 0 +#endif + +#if HWY_NEON_HAVE_F32_TO_BF16C +#define HWY_NATIVE_DOT_BF16 1 +#else +#define HWY_NATIVE_DOT_BF16 0 +#endif + +#define HWY_NATIVE_MASK 0 + +#if HWY_TARGET == HWY_NEON_WITHOUT_AES +#define HWY_NAMESPACE N_NEON_WITHOUT_AES +#elif HWY_TARGET == HWY_NEON +#define HWY_NAMESPACE N_NEON +#elif HWY_TARGET == HWY_NEON_BF16 
+#define HWY_NAMESPACE N_NEON_BF16 +#else +#error "Logic error, missing case" +#endif // HWY_TARGET + +// Can use pragmas instead of -march compiler flag +#if HWY_HAVE_RUNTIME_DISPATCH +#if HWY_ARCH_ARM_V7 + +// The __attribute__((target("+neon-vfpv4"))) was introduced in gcc >= 8. +#if HWY_COMPILER_GCC_ACTUAL >= 800 +#define HWY_TARGET_STR "+neon-vfpv4" +#else // GCC < 8 +// Do not define HWY_TARGET_STR (no pragma). +#endif // HWY_COMPILER_GCC_ACTUAL + +#else // !HWY_ARCH_ARM_V7 + +#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1300) || \ + (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1300) +// GCC 12 or earlier and Clang 12 or earlier require +crypto be added to the +// target string to enable AArch64 AES intrinsics +#define HWY_TARGET_STR_NEON "+crypto" +#else +#define HWY_TARGET_STR_NEON "+aes" +#endif + +// Clang >= 16 requires +fullfp16 instead of fp16, but Apple Clang 15 = 1600 +// fails to parse unless the string starts with armv8, whereas 1700 refuses it. +#if HWY_COMPILER_CLANG >= 1700 +#define HWY_TARGET_STR_FP16 "+fullfp16" +#elif HWY_COMPILER_CLANG >= 1600 && defined(__apple_build_version__) +#define HWY_TARGET_STR_FP16 "armv8.4-a+fullfp16" +#else +#define HWY_TARGET_STR_FP16 "+fp16" +#endif + +#if HWY_OS_APPLE +// Enable i8mm for the NEON_BF16 target if compiling for macOS, iOS, or iPadOS +// as all Apple Silicon CPUs that support BF16 have support for I8MM. +#define HWY_TARGET_STR_NEON_BF16_EXTRA "+i8mm" +#else +#define HWY_TARGET_STR_NEON_BF16_EXTRA "" +#endif + +#if HWY_TARGET == HWY_NEON_WITHOUT_AES +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400 +// Prevents inadvertent use of SVE by GCC 13.4 and earlier, see #2689. +#define HWY_TARGET_STR "+nosve" +#else +// Do not define HWY_TARGET_STR (no pragma).
+#endif // HWY_COMPILER_GCC_ACTUAL +#elif HWY_TARGET == HWY_NEON +#define HWY_TARGET_STR HWY_TARGET_STR_NEON +#elif HWY_TARGET == HWY_NEON_BF16 +#define HWY_TARGET_STR \ + HWY_TARGET_STR_FP16 \ + "+bf16+dotprod" HWY_TARGET_STR_NEON_BF16_EXTRA HWY_TARGET_STR_NEON +#else +#error "Logic error, missing case" +#endif // HWY_TARGET + +#endif // !HWY_ARCH_ARM_V7 +#else // !HWY_HAVE_RUNTIME_DISPATCH +// HWY_TARGET_STR remains undefined +#endif + +//----------------------------------------------------------------------------- +// SVE[2] +#elif HWY_TARGET_IS_SVE + +// SVE only requires lane alignment, not natural alignment of the entire vector. +#define HWY_ALIGN alignas(8) + +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#if HWY_SVE_HAVE_BF16_FEATURE || HWY_TARGET == HWY_SVE2_128 +#define HWY_NATIVE_DOT_BF16 1 +#else +#define HWY_NATIVE_DOT_BF16 0 +#endif +#define HWY_NATIVE_MASK 1 + +#if HWY_TARGET == HWY_SVE2 +#define HWY_NAMESPACE N_SVE2 +#define HWY_MAX_BYTES 256 +#define HWY_MIN_BYTES 16 +#define HWY_HAVE_SCALABLE 1 +#elif HWY_TARGET == HWY_SVE_256 +#define HWY_NAMESPACE N_SVE_256 +#define HWY_MAX_BYTES 32 +#define HWY_MIN_BYTES 32 +#define HWY_HAVE_SCALABLE 0 +#elif HWY_TARGET == HWY_SVE2_128 +#define HWY_NAMESPACE N_SVE2_128 +#define HWY_MAX_BYTES 16 +#define HWY_MIN_BYTES 16 +#define HWY_HAVE_SCALABLE 0 +#else +#define HWY_NAMESPACE N_SVE +#define HWY_MAX_BYTES 256 +#define HWY_MIN_BYTES 16 +#define HWY_HAVE_SCALABLE 1 +#endif + +// Note: -march strings are delimited by + and GCC actually requires + before +// each pragma target, which are also comma-separated. + +#undef HWY_TARGET_STR_SVE2_AES +// Static dispatch with -march=armv8-a+sve2+aes, or no baseline, hence dynamic +// dispatch, which checks for AES support at runtime. 
+#if defined(__ARM_FEATURE_SVE2_AES) || (HWY_BASELINE_SVE2 == 0) +#define HWY_TARGET_STR_SVE2_AES ",+sve2-aes" +#else // SVE2 without AES +#define HWY_TARGET_STR_SVE2_AES "" +#endif + +#undef HWY_TARGET_STR_SVE2_128 +// SVE2_128 implies/requires I8MM and BF16, see #2973. +#if HWY_TARGET == HWY_SVE2_128 +#define HWY_TARGET_STR_SVE2_128 ",+i8mm,+bf16" +#else +#define HWY_TARGET_STR_SVE2_128 "" +#endif + +// Can use pragmas instead of -march compiler flag +#if HWY_HAVE_RUNTIME_DISPATCH +#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128 +#define HWY_TARGET_STR \ + "+sve,+sve2" HWY_TARGET_STR_SVE2_AES HWY_TARGET_STR_SVE2_128 +#else // not SVE2 target +#define HWY_TARGET_STR "+sve" +#endif +#else // !HWY_HAVE_RUNTIME_DISPATCH +// HWY_TARGET_STR remains undefined +#endif + +//----------------------------------------------------------------------------- +// WASM +#elif HWY_TARGET == HWY_WASM + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_MIN_BYTES 16 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_NATIVE_MASK 0 + +#define HWY_NAMESPACE N_WASM + +#define HWY_TARGET_STR "simd128" + +//----------------------------------------------------------------------------- +// WASM_EMU256 +#elif HWY_TARGET == HWY_WASM_EMU256 + +#define HWY_ALIGN alignas(32) +#define HWY_MAX_BYTES 32 +#define HWY_MIN_BYTES 32 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_NATIVE_MASK 0 + +#define HWY_NAMESPACE N_WASM_EMU256 + +#define HWY_TARGET_STR "simd128" + +//----------------------------------------------------------------------------- +// RVV +#elif HWY_TARGET == HWY_RVV + +// RVV only requires lane alignment, not natural 
alignment of the entire vector, +// and the compiler already aligns builtin types, so nothing to do here. +#define HWY_ALIGN + +// The spec requires VLEN <= 2^16 bits, so the limit is 2^16 bytes (LMUL=8). +#define HWY_MAX_BYTES 65536 +#define HWY_MIN_BYTES 16 + +// = HWY_MAX_BYTES divided by max LMUL=8 because MaxLanes includes the actual +// LMUL. This is the tightest possible upper bound. +#undef HWY_LANES +#define HWY_LANES(T) (8192 / sizeof(T)) + +#define HWY_HAVE_SCALABLE 1 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 1 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_NATIVE_MASK 1 + +#if HWY_RVV_HAVE_F16_VEC +#define HWY_HAVE_FLOAT16 1 +#else +#define HWY_HAVE_FLOAT16 0 +#endif + +#define HWY_NAMESPACE N_RVV + +#if HWY_COMPILER_CLANG >= 1900 +// https://github.com/riscv/riscv-v-spec/blob/master/v-spec.adoc#181-zvl-minimum-vector-length-standard-extensions +#define HWY_TARGET_STR "arch=+v" +#else +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. +#endif + +//----------------------------------------------------------------------------- +// LSX/LASX +#elif HWY_TARGET == HWY_LSX || HWY_TARGET == HWY_LASX + +#if HWY_TARGET == HWY_LSX +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_MIN_BYTES 16 +#ifndef __loongarch_sx +#define HWY_TARGET_STR "lsx" +#endif +#else +#define HWY_ALIGN alignas(32) +#define HWY_MAX_BYTES 32 +#define HWY_MIN_BYTES 32 +#ifndef __loongarch_asx +#define HWY_TARGET_STR "lsx,lasx" +#endif +#endif + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 1 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_NATIVE_MASK 0 + +#if HWY_TARGET == HWY_LSX +#define HWY_NAMESPACE N_LSX +#else +#define HWY_NAMESPACE N_LASX +#endif + +// HWY_TARGET_STR is only defined above if the baseline lacks LSX/LASX.
+ +//----------------------------------------------------------------------------- +// EMU128 +#elif HWY_TARGET == HWY_EMU128 + +#define HWY_ALIGN alignas(16) +#define HWY_MAX_BYTES 16 +#define HWY_MIN_BYTES 16 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_NATIVE_MASK 0 + +#define HWY_NAMESPACE N_EMU128 + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. + +//----------------------------------------------------------------------------- +// SCALAR +#elif HWY_TARGET == HWY_SCALAR + +#define HWY_ALIGN +#define HWY_MAX_BYTES 8 +#define HWY_MIN_BYTES 8 +#undef HWY_LANES +#define HWY_LANES(T) 1 + +#define HWY_HAVE_SCALABLE 0 +#define HWY_HAVE_INTEGER64 1 +#define HWY_HAVE_FLOAT16 0 +#define HWY_HAVE_FLOAT64 1 +#define HWY_MEM_OPS_MIGHT_FAULT 0 +#define HWY_NATIVE_FMA 0 +#define HWY_NATIVE_DOT_BF16 0 +#define HWY_NATIVE_MASK 0 + +#define HWY_NAMESPACE N_SCALAR + +// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op. + +#else +#pragma message("HWY_TARGET does not match any known target") +#endif // HWY_TARGET + +//----------------------------------------------------------------------------- + +// Sanity check: if we have f16 vector support, then base.h should also be +// using a built-in type for f16 scalars. +#if HWY_HAVE_FLOAT16 && !HWY_HAVE_SCALAR_F16_TYPE +#error "Logic error: f16 vectors but no scalars" +#endif + +// Override this to 1 in asan/msan builds, which will still fault. +#if HWY_IS_ASAN || HWY_IS_MSAN +#undef HWY_MEM_OPS_MIGHT_FAULT +#define HWY_MEM_OPS_MIGHT_FAULT 1 +#endif + +// Clang <9 requires this be invoked at file scope, before any namespace. 
+#undef HWY_BEFORE_NAMESPACE +#if defined(HWY_TARGET_STR) && !defined(HWY_DISABLE_ATTR) +#define HWY_BEFORE_NAMESPACE() \ + HWY_PUSH_ATTRIBUTES(HWY_TARGET_STR) \ + static_assert(true, "For requiring trailing semicolon") +#else +// avoids compiler warning if no HWY_TARGET_STR +#define HWY_BEFORE_NAMESPACE() \ + static_assert(true, "For requiring trailing semicolon") +#endif + +// Clang <9 requires any namespaces be closed before this macro. +#undef HWY_AFTER_NAMESPACE +#if defined(HWY_TARGET_STR) && !defined(HWY_DISABLE_ATTR) +#define HWY_AFTER_NAMESPACE() \ + HWY_POP_ATTRIBUTES \ + static_assert(true, "For requiring trailing semicolon") +#else +// avoids compiler warning if no HWY_TARGET_STR +#define HWY_AFTER_NAMESPACE() \ + static_assert(true, "For requiring trailing semicolon") +#endif + +#undef HWY_ATTR +#if defined(HWY_TARGET_STR) && HWY_HAS_ATTRIBUTE(target) && \ + !defined(HWY_DISABLE_ATTR) +#define HWY_ATTR __attribute__((target(HWY_TARGET_STR))) +#else +#define HWY_ATTR +#endif + +#if (HWY_MAX_BYTES <= 16) || HWY_TARGET_IS_SVE || (HWY_TARGET == HWY_RVV) || \ + (HWY_TARGET == HWY_WASM_EMU256) +#define HWY_NATIVE_INTERLEAVE_WHOLE 1 +#else +#define HWY_NATIVE_INTERLEAVE_WHOLE 0 +#endif diff --git a/lib/highway/hwy/ops/shared-inl.h b/lib/highway/hwy/ops/shared-inl.h new file mode 100644 index 00000000000..5a11f92490c --- /dev/null +++ b/lib/highway/hwy/ops/shared-inl.h @@ -0,0 +1,721 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Per-target definitions shared by ops/*.h and user code. + +// IWYU pragma: begin_exports +// Export does not seem to be recursive, so re-export these (also in base.h) +#include + +#include "hwy/base.h" +// "IWYU pragma: keep" does not work for this include, so hide it from the IDE. +#if !HWY_IDE +#include +#endif + +#include "hwy/detect_compiler_arch.h" +#include "hwy/detect_targets.h" + +// Separate header because foreach_target.h re-enables its include guard. +#include "hwy/ops/set_macros-inl.h" + +// IWYU pragma: end_exports + +#if HWY_IS_MSAN +#include +#endif + +// We are covered by the highway.h include guard, but generic_ops-inl.h +// includes this again #if HWY_IDE. +// clang-format off +#if defined(HIGHWAY_HWY_OPS_SHARED_TOGGLE) == defined(HWY_TARGET_TOGGLE) // NOLINT +// clang-format on +#ifdef HIGHWAY_HWY_OPS_SHARED_TOGGLE +#undef HIGHWAY_HWY_OPS_SHARED_TOGGLE +#else +#define HIGHWAY_HWY_OPS_SHARED_TOGGLE +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// NOTE: GCC generates incorrect code for vector arguments to non-inlined +// functions in two situations: +// - on Windows and GCC 10.3, passing by value crashes due to unaligned loads: +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54412. +// - on aarch64 and GCC 9.3.0 or 11.2.1, passing by value causes many (but not +// all) tests to fail. +// +// We therefore pass by const& only on GCC and (Windows or aarch64). This alias +// must be used for all vector/mask parameters of functions marked HWY_NOINLINE, +// and possibly also other functions that are not inlined. +// +// Even better is to avoid passing vector arguments to non-inlined functions, +// because the SVE and RISC-V ABIs are still works in progress and may lead to +// incorrect codegen. 
+#if HWY_COMPILER_GCC_ACTUAL && (HWY_OS_WIN || HWY_ARCH_ARM_A64) +template +using VecArg = const V&; +#else +template +using VecArg = V; +#endif + +namespace detail { + +template +struct NativeLaneTypeT { + using type = T; +}; +template <> +struct NativeLaneTypeT { +#if HWY_HAVE_SCALAR_F16_TYPE + using type = hwy::float16_t::Native; +#else + using type = uint16_t; +#endif +}; +template <> +struct NativeLaneTypeT { +#if HWY_HAVE_SCALAR_BF16_TYPE + using type = hwy::bfloat16_t::Native; +#else + using type = uint16_t; +#endif +}; + +// The type expected by intrinsics for the given Highway lane type T. This +// usually matches T, but differs for our wrapper types [b]float16_t. Use this +// only when defining intrinsic wrappers, and NOT for casting, which is UB. +template +using NativeLaneType = typename NativeLaneTypeT::type; + +// Returns the same pointer after changing type to NativeLaneType. Use this only +// for wrapper functions that call intrinsics (e.g. load/store) where some of +// the overloads expect _Float16* or __bf16* arguments. For non-special floats, +// this returns the same pointer and type. +// +// This makes use of the fact that a wrapper struct is pointer-interconvertible +// with its first member (a union), thus also with the union members. Do NOT +// call both this and U16LanePointer on the same object - they access different +// union members, and this is not guaranteed to be safe. +template +HWY_INLINE T* NativeLanePointer(T* p) { + return p; +} +template >, + HWY_IF_F16(T)> +HWY_INLINE constexpr If(), const NT*, NT*> NativeLanePointer(T* p) { +#if HWY_HAVE_SCALAR_F16_TYPE + return &p->native; +#else + return &p->bits; +#endif +} +template >, + HWY_IF_BF16(T)> +HWY_INLINE constexpr If(), const NT*, NT*> NativeLanePointer(T* p) { +#if HWY_HAVE_SCALAR_BF16_TYPE + return &p->native; +#else + return &p->bits; +#endif +} + +// Returns a pointer to the u16 member of our [b]float16_t wrapper structs. 
+// Use this in Highway targets that lack __bf16 intrinsics; for storing to +// memory, we BitCast vectors to u16 and write to the pointer returned here. +// Do NOT call both this and NativeLanePointer on the same object - they access +// different union members, and this is not guaranteed to be safe. +template +HWY_INLINE If(), const uint16_t*, uint16_t*> U16LanePointer(T* p) { + return &p->bits; +} + +// Returns N * 2^pow2. N is the number of lanes in a full vector and pow2 the +// desired fraction or multiple of it, see Simd<>. `pow2` is most often in +// [-3, 3] but can also be lower for user-specified fractions. +constexpr size_t ScaleByPower(size_t N, int pow2) { + return pow2 >= 0 ? (N << pow2) : (N >> (-pow2)); +} + +template +HWY_INLINE void MaybePoison(T* HWY_RESTRICT unaligned, size_t count) { +#if HWY_IS_MSAN + __msan_poison(unaligned, count * sizeof(T)); +#else + (void)unaligned; + (void)count; +#endif +} + +// This can be useful for working around MSAN limitations. For example, prior +// to Clang 16, it did not understand AVX-512 CompressStore. +template +HWY_INLINE void MaybeUnpoison(T* HWY_RESTRICT unaligned, size_t count) { +#if HWY_IS_MSAN + __msan_unpoison(unaligned, count * sizeof(T)); +#else + (void)unaligned; + (void)count; +#endif +} + +} // namespace detail + +// Highway operations are implemented as overloaded functions selected using a +// zero-sized tag type D := Simd. T denotes the lane type. +// +// N defines how many lanes are in a 'full' vector, typically equal to +// HWY_LANES(T) (which is the actual count on targets with vectors of known +// size, and an upper bound in case of scalable vectors), otherwise a +// user-specified limit at most that large. +// +// 2^kPow2 is a _subsequently_ applied scaling factor that indicates the +// desired fraction of a 'full' vector: 0 means full, -1 means half; 1,2,3 +// means two/four/eight full vectors ganged together.
The largest supported +// kPow2 is `HWY_MAX_POW2` and the aliases below take care of clamping +// user-specified values to that. Note that `Simd` and `Simd` +// have the same `MaxLanes` and `Lanes`. +// +// We can theoretically keep halving Lanes(), but recursive instantiations of +// kPow2 - 1 will eventually fail e.g. because -64 is not a valid shift count. +// Users must terminate such compile-time recursions at or above HWY_MIN_POW2. +// +// WARNING: do not use N directly because it may be a special representation of +// a fractional MaxLanes. This arises when we Rebind Simd to +// Simd. RVV requires that the last argument (kPow2) be two, +// but we want MaxLanes to be the same in both cases. Hence ?? is a +// fixed-point encoding of 1/4. +// +// Instead of referring to Simd<> directly, users create D via aliases: +// - ScalableTag for a full vector; +// - ScalableTag() for a fraction/group, where `kPow2` is +// interpreted as `HWY_MIN(kPow2, HWY_MAX_POW2)`; +// - CappedTag for a vector with up to kLimit lanes; or +// - FixedTag for a vector with exactly kNumLanes lanes. +// +// Instead of N, use Lanes(D()) for the actual number of lanes at runtime and +// D().MaxLanes() for a constexpr upper bound. Both are powers of two. +template +struct Simd { + constexpr Simd() = default; + using T = Lane; + + private: + static_assert(sizeof(Lane) <= 8, "Lanes are up to 64-bit"); + static_assert(IsSame>(), + "Lane must not be a reference type, const-qualified type, or " + "volatile-qualified type"); + static_assert(IsIntegerLaneType() || IsFloat() || + IsSpecialFloat(), + "IsIntegerLaneType(), IsFloat(), or IsSpecialFloat() " + "must be true"); + // 20 bits are sufficient for any HWY_MAX_BYTES. This is the 'normal' value of + // N when kFrac == 0, otherwise it is one (see FracN). + static constexpr size_t kWhole = N & 0xFFFFF; + // Fractional part is in the bits above kWhole. 
+ static constexpr int kFrac = static_cast(N >> 20); + // Can be 8x larger because kPow2 may be as low as -3 (Rebind of a larger + // type to u8 results in fractions). + static_assert(kWhole <= 8 * HWY_MAX_N && kFrac <= 3, "Out of range"); + static_assert(kFrac == 0 || kWhole == 1, "If frac, whole must be 1"); + static_assert((kWhole & (kWhole - 1)) == 0 && kWhole != 0, "Not 2^x"); + // Important to check this here because kPow2 <= -64 causes confusing + // compile errors (invalid shift count). + static_assert(kPow2 >= HWY_MIN_POW2, "Forgot kPow2 recursion terminator?"); + // However, do NOT verify kPow2 <= HWY_MAX_POW2 - users should be able to + // Rebind> in order to discover that its + // kPow2 is out of bounds. + + public: + // Upper bound on the number of lanes (tight if !HWY_HAVE_SCALABLE). In the + // common case, N == kWhole, but if kFrac is nonzero, we deduct it from kPow2. + // E.g. Rebind> is Simd. + // The resulting number of lanes is still 1 because this N represents 1/4 + // (the ratio of the sizes). Note that RVV requires kPow2 to be the ratio of + // the sizes so that the correct LMUL overloads are chosen, even if N is + // small enough that it would fit in an LMUL=1 vector. + // + // Cannot be an enum because GCC warns when using enums and non-enums in the + // same expression. Cannot be a static constexpr function (MSVC limitation). + // Rounded up to one so this is a valid array length. + // + // Do not use this directly - only 'public' so it is visible from the accessor + // macro required by MSVC. + static constexpr size_t kPrivateLanes = + HWY_MAX(size_t{1}, detail::ScaleByPower(kWhole, kPow2 - kFrac)); + // Do not use this directly - only 'public' so it is visible from the accessor + // macro required by MSVC. 
+ static constexpr int kPrivatePow2 = kPow2; + + constexpr size_t MaxLanes() const { return kPrivateLanes; } + constexpr size_t MaxBytes() const { return kPrivateLanes * sizeof(Lane); } + constexpr size_t MaxBlocks() const { return (MaxBytes() + 15) / 16; } + // For SFINAE (HWY_IF_POW2_GT_D). + constexpr int Pow2() const { return kPow2; } + + // ------------------------------ Changing lane type or count + // Do not use any of these directly. Anything used from member typedefs cannot + // be made private, but functions only used within other functions can. + + // Returns number of NewT lanes that fit within MaxBytes(). + template + static constexpr size_t RepartitionLanes() { + // Round up to correctly handle larger NewT. + return (kPrivateLanes * sizeof(T) + sizeof(NewT) - 1) / sizeof(NewT); + } + + // Returns the new kPow2 required for lanes of type NewT. + template + static constexpr int RebindPow2() { + return kPow2 + + ((sizeof(NewT) >= sizeof(T)) + ? static_cast(CeilLog2(sizeof(NewT) / sizeof(T))) + : -static_cast(CeilLog2(sizeof(T) / sizeof(NewT)))); + } + + private: + // Returns 0 or whole NewN such that kNewMaxLanes = NewN * 2^kNewPow2. + template + static constexpr size_t WholeN() { + return detail::ScaleByPower(kNewMaxLanes, -kNewPow2); + } + + // Returns fractional NewN such that kNewMaxLanes = NewN * 2^kNewPow2. + template + static constexpr size_t FracN() { + // Only reached if kNewPow2 > CeilLog2(kNewMaxLanes) >= 0 (else WholeN + // would not have been zero), but clamp to zero to avoid warnings. kFrac is + // the difference, stored in the upper bits of N, and we also set kWhole = + // 1 so that the new kPrivateLanes = kNewMaxLanes. + static_assert(HWY_MAX_N <= (size_t{1} << 20), "Change bit shift"); + return static_cast( + 1 + (HWY_MAX(0, kNewPow2 - static_cast(CeilLog2(kNewMaxLanes))) + << 20)); + } + + public: + // Returns (whole or fractional) NewN, see above. 
+ template + static constexpr size_t NewN() { + // We require a fraction if inverting kNewPow2 results in 0. + return WholeN() == 0 + ? FracN() + : WholeN(); + } + + // PromoteTo/DemoteTo() with another lane type, but same number of lanes. + template + using Rebind = + Simd(), kPrivateLanes>(), RebindPow2()>; + + // Change lane type while keeping the same vector size, e.g. for MulEven. + template + using Repartition = + Simd()>(), kPow2>; + + // Half the lanes while keeping the same lane type, e.g. for LowerHalf. + using Half = Simd; + + // Twice the lanes while keeping the same lane type, e.g. for Combine. + using Twice = Simd; +}; + +namespace detail { + +template +constexpr bool IsFull(Simd /* d */) { + return N == HWY_LANES(T) && kPow2 == 0; +} + +// Struct wrappers enable validation of arguments via static_assert. +template +struct ClampNAndPow2 { + using type = Simd; +}; + +template +struct ScalableTagChecker { + using type = typename ClampNAndPow2::type; +}; + +template +struct CappedTagChecker { + static_assert(kLimit != 0, "Does not make sense to have zero lanes"); + // Safely handle non-power-of-two inputs by rounding down, which is allowed by + // CappedTag. Otherwise, Simd would static_assert. + static constexpr size_t kLimitPow2 = size_t{1} << hwy::FloorLog2(kLimit); + static constexpr size_t N = HWY_MIN(kLimitPow2, HWY_LANES(T)); + using type = typename ClampNAndPow2::type; +}; + +template +struct FixedTagChecker { + static_assert(kNumLanes != 0, "Does not make sense to have zero lanes"); + static_assert(kNumLanes <= HWY_LANES(T), "Too many lanes"); + using type = Simd; +}; + +} // namespace detail + +// ------------------------------ Aliases for Simd<> + +// Tag describing a full vector (kPow2 == 0: the most common usage, e.g. 1D +// loops where the application does not care about the vector size) or a +// fraction/multiple of one. Fractions (kPow2 < 0) are useful for arguments or +// return values of type promotion and demotion. 
User-specified kPow2 is +// interpreted as `HWY_MIN(kPow2, HWY_MAX_POW2)`. +template +using ScalableTag = typename detail::ScalableTagChecker::type; + +// Tag describing a vector with *up to* kLimit active lanes, even on targets +// with scalable vectors and HWY_SCALAR. The runtime lane count `Lanes(tag)` may +// be less than kLimit, and is 1 on HWY_SCALAR. This alias is typically used for +// 1D loops with a relatively low application-defined upper bound, e.g. for 8x8 +// DCTs. However, it is better if data structures are designed to be +// vector-length-agnostic (e.g. a hybrid SoA where there are chunks of `M >= +// MaxLanes(d)` DC components followed by M AC1, .., and M AC63; this would +// enable vector-length-agnostic loops using ScalableTag). User-specified kPow2 +// is interpreted as `HWY_MIN(kPow2, HWY_MAX_POW2)`. +template +using CappedTag = typename detail::CappedTagChecker::type; + +#if !HWY_HAVE_SCALABLE +// If the vector size is known, and the app knows it does not want more than +// kLimit lanes, then capping can be beneficial. For example, AVX-512 has lower +// IPC and potentially higher costs for unaligned load/store vs. 256-bit AVX2. +template +using CappedTagIfFixed = CappedTag; +#else // HWY_HAVE_SCALABLE +// .. whereas on RVV/SVE, the cost of clamping Lanes() may exceed the benefit. +template +using CappedTagIfFixed = ScalableTag; +#endif + +// Alias for a tag describing a vector with *exactly* kNumLanes active lanes, +// even on targets with scalable vectors. Requires `kNumLanes` to be a power of +// two not exceeding `HWY_LANES(T)`. +// +// NOTE: if the application does not need to support HWY_SCALAR (+), use this +// instead of CappedTag to emphasize that there will be exactly kNumLanes lanes. +// This is useful for data structures that rely on exactly 128-bit SIMD, but +// these are discouraged because they cannot benefit from wider vectors. 
+// Instead, applications would ideally define a larger problem size and loop +// over it with the (unknown size) vectors from ScalableTag. +// +// + e.g. if the baseline is known to support SIMD, or the application requires +// ops such as TableLookupBytes not supported by HWY_SCALAR. +template +using FixedTag = typename detail::FixedTagChecker::type; + +// Convenience form for fixed sizes. +template +using Full16 = Simd; + +template +using Full32 = Simd; + +template +using Full64 = Simd; + +template +using Full128 = Simd; + +// ------------------------------ Accessors for Simd<> + +// Lane type. +template +using TFromD = typename D::T; + +// Upper bound on the number of lanes, typically used for SFINAE conditions and +// to allocate storage for targets with known vector sizes. Note: this may be a +// loose bound, instead use Lanes() as the actual size for AllocateAligned. +// MSVC workaround: use static constant directly instead of a function. +#define HWY_MAX_LANES_D(D) D::kPrivateLanes + +// Same as D().Pow2(), but this is too complex for SFINAE with MSVC, so we use a +// static constant directly. +#define HWY_POW2_D(D) D::kPrivatePow2 + +// Non-macro form of HWY_MAX_LANES_D in case that is preferable. WARNING: the +// macro form may be required for MSVC, which has limitations on deducing +// arguments. +template +HWY_INLINE HWY_MAYBE_UNUSED constexpr size_t MaxLanes(D) { + return HWY_MAX_LANES_D(D); +} + +#undef HWY_HAVE_CONSTEXPR_LANES +#undef HWY_LANES_CONSTEXPR + +#if HWY_HAVE_SCALABLE +#define HWY_HAVE_CONSTEXPR_LANES 0 +#define HWY_LANES_CONSTEXPR +#else + +// We want Lanes() to be constexpr where possible, so that compilers are able to +// precompute offsets. However, user code must not depend on the constexpr, +// because that will fail for RISC-V V and Arm SVE. To achieve both, we mark it +// as non-constexpr in debug builds, but not sanitizers, because we typically +// want them to see the same code. 
+#if HWY_IS_DEBUG_BUILD && !HWY_IS_SANITIZER +#define HWY_HAVE_CONSTEXPR_LANES 0 +#define HWY_LANES_CONSTEXPR +#else +#define HWY_HAVE_CONSTEXPR_LANES 1 +#define HWY_LANES_CONSTEXPR constexpr +#endif + +// Returns actual vector length, used when advancing loop counters. The +// non-constexpr implementations are defined in their target's header. For a +// guaranteed-constexpr upper bound, use `MaxLanes(d)`. +template +HWY_INLINE HWY_MAYBE_UNUSED HWY_LANES_CONSTEXPR size_t Lanes(D) { + return HWY_MAX_LANES_D(D); +} + +#endif // !HWY_HAVE_SCALABLE + +// Tag for the same number of lanes as D, but with the LaneType T. +template +using Rebind = typename D::template Rebind; + +template +using RebindToSigned = Rebind>, D>; +template +using RebindToUnsigned = Rebind>, D>; +template +using RebindToFloat = Rebind>, D>; + +// Tag for the same total size as D, but with the LaneType T. +template +using Repartition = typename D::template Repartition; + +template +using RepartitionToWide = Repartition>, D>; +template +using RepartitionToNarrow = Repartition>, D>; + +// Shorthand for applying RepartitionToWide twice (for 8/16-bit types). +template +using RepartitionToWideX2 = RepartitionToWide>; +// Shorthand for applying RepartitionToWide three times (for 8-bit types). +template +using RepartitionToWideX3 = RepartitionToWide>; + +// Tag for the same lane type as D, but half the lanes. +template +using Half = typename D::Half; + +// Tag for the same lane type as D, but twice the lanes. 
+template +using Twice = typename D::Twice; + +// Tag for a 16-byte block with the same lane type as D +#if HWY_HAVE_SCALABLE +namespace detail { + +template +class BlockDFromD_t {}; + +template +class BlockDFromD_t> { + using D = Simd; + static constexpr int kNewPow2 = HWY_MIN(kPow2, 0); + static constexpr size_t kMaxLpb = HWY_MIN(16 / sizeof(T), HWY_MAX_LANES_D(D)); + static constexpr size_t kNewN = D::template NewN(); + + public: + using type = Simd; +}; + +} // namespace detail + +template +using BlockDFromD = typename detail::BlockDFromD_t>::type; +#else +template +using BlockDFromD = + Simd, HWY_MIN(16 / sizeof(TFromD), HWY_MAX_LANES_D(D)), 0>; +#endif + +// Returns whether `ptr` is a multiple of `Lanes(d)` elements. +template +HWY_API bool IsAligned(D d, T* ptr) { + const size_t N = Lanes(d); + return reinterpret_cast(ptr) % (N * sizeof(T)) == 0; +} + +// ------------------------------ Choosing overloads (SFINAE) + +// Same as base.h macros but with a Simd argument instead of T. +#define HWY_IF_UNSIGNED_D(D) HWY_IF_UNSIGNED(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_UNSIGNED_D(D) \ + HWY_IF_NOT_UNSIGNED(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_SIGNED_D(D) HWY_IF_SIGNED(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_FLOAT_D(D) HWY_IF_FLOAT(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_FLOAT_D(D) HWY_IF_NOT_FLOAT(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_FLOAT3264_D(D) HWY_IF_FLOAT3264(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_FLOAT3264_D(D) \ + HWY_IF_NOT_FLOAT3264(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_SPECIAL_FLOAT_D(D) \ + HWY_IF_SPECIAL_FLOAT(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_SPECIAL_FLOAT_D(D) \ + HWY_IF_NOT_SPECIAL_FLOAT(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_FLOAT_OR_SPECIAL_D(D) \ + HWY_IF_FLOAT_OR_SPECIAL(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D) \ + HWY_IF_NOT_FLOAT_NOR_SPECIAL(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_IF_T_SIZE_D(D, bytes) \ + 
HWY_IF_T_SIZE(hwy::HWY_NAMESPACE::TFromD, bytes) +#define HWY_IF_NOT_T_SIZE_D(D, bytes) \ + HWY_IF_NOT_T_SIZE(hwy::HWY_NAMESPACE::TFromD, bytes) +#define HWY_IF_T_SIZE_ONE_OF_D(D, bit_array) \ + HWY_IF_T_SIZE_ONE_OF(hwy::HWY_NAMESPACE::TFromD, bit_array) +#define HWY_IF_T_SIZE_LE_D(D, bytes) \ + HWY_IF_T_SIZE_LE(hwy::HWY_NAMESPACE::TFromD, bytes) +#define HWY_IF_T_SIZE_GT_D(D, bytes) \ + HWY_IF_T_SIZE_GT(hwy::HWY_NAMESPACE::TFromD, bytes) + +#define HWY_IF_LANES_D(D, lanes) HWY_IF_LANES(HWY_MAX_LANES_D(D), lanes) +#define HWY_IF_LANES_LE_D(D, lanes) HWY_IF_LANES_LE(HWY_MAX_LANES_D(D), lanes) +#define HWY_IF_LANES_GT_D(D, lanes) HWY_IF_LANES_GT(HWY_MAX_LANES_D(D), lanes) +#define HWY_IF_LANES_PER_BLOCK_D(D, lanes) \ + HWY_IF_LANES_PER_BLOCK(hwy::HWY_NAMESPACE::TFromD, HWY_MAX_LANES_D(D), \ + lanes) + +#if HWY_COMPILER_MSVC +#define HWY_IF_POW2_LE_D(D, pow2) \ + hwy::EnableIf* = nullptr +#define HWY_IF_POW2_GT_D(D, pow2) \ + hwy::EnableIf<(HWY_POW2_D(D) > pow2)>* = nullptr +#else +#define HWY_IF_POW2_LE_D(D, pow2) hwy::EnableIf* = nullptr +#define HWY_IF_POW2_GT_D(D, pow2) hwy::EnableIf<(D().Pow2() > pow2)>* = nullptr +#endif // HWY_COMPILER_MSVC + +#define HWY_IF_U8_D(D) HWY_IF_U8(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_U16_D(D) HWY_IF_U16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_U32_D(D) HWY_IF_U32(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_U64_D(D) HWY_IF_U64(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_IF_I8_D(D) HWY_IF_I8(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_I16_D(D) HWY_IF_I16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_I32_D(D) HWY_IF_I32(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_I64_D(D) HWY_IF_I64(hwy::HWY_NAMESPACE::TFromD) + +// Use instead of HWY_IF_T_SIZE_D to avoid ambiguity with float16_t/float/double +// overloads. 
+#define HWY_IF_UI8_D(D) HWY_IF_UI8(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_UI16_D(D) HWY_IF_UI16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_UI32_D(D) HWY_IF_UI32(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_UI64_D(D) HWY_IF_UI64(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_IF_BF16_D(D) HWY_IF_BF16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_BF16_D(D) HWY_IF_NOT_BF16(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_IF_F16_D(D) HWY_IF_F16(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_NOT_F16_D(D) HWY_IF_NOT_F16(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_IF_F32_D(D) HWY_IF_F32(hwy::HWY_NAMESPACE::TFromD) +#define HWY_IF_F64_D(D) HWY_IF_F64(hwy::HWY_NAMESPACE::TFromD) + +#define HWY_V_SIZE_D(D) \ + (HWY_MAX_LANES_D(D) * sizeof(hwy::HWY_NAMESPACE::TFromD)) +#define HWY_IF_V_SIZE_D(D, bytes) \ + HWY_IF_V_SIZE(hwy::HWY_NAMESPACE::TFromD, HWY_MAX_LANES_D(D), bytes) +#define HWY_IF_V_SIZE_LE_D(D, bytes) \ + HWY_IF_V_SIZE_LE(hwy::HWY_NAMESPACE::TFromD, HWY_MAX_LANES_D(D), bytes) +#define HWY_IF_V_SIZE_GT_D(D, bytes) \ + HWY_IF_V_SIZE_GT(hwy::HWY_NAMESPACE::TFromD, HWY_MAX_LANES_D(D), bytes) + +// Same, but with a vector argument. ops/*-inl.h define their own TFromV. 
+#define HWY_IF_UNSIGNED_V(V) HWY_IF_UNSIGNED(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_NOT_UNSIGNED_V(V) \ + HWY_IF_NOT_UNSIGNED(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_SIGNED_V(V) HWY_IF_SIGNED(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_FLOAT_V(V) HWY_IF_FLOAT(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_NOT_FLOAT_V(V) HWY_IF_NOT_FLOAT(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_FLOAT3264_V(V) HWY_IF_FLOAT3264(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_SPECIAL_FLOAT_V(V) \ + HWY_IF_SPECIAL_FLOAT(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_FLOAT_OR_SPECIAL_V(V) \ + HWY_IF_FLOAT_OR_SPECIAL(hwy::HWY_NAMESPACE::TFromV) +#define HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V) \ + HWY_IF_NOT_FLOAT_NOR_SPECIAL(hwy::HWY_NAMESPACE::TFromV) + +#define HWY_IF_T_SIZE_V(V, bytes) \ + HWY_IF_T_SIZE(hwy::HWY_NAMESPACE::TFromV, bytes) +#define HWY_IF_NOT_T_SIZE_V(V, bytes) \ + HWY_IF_NOT_T_SIZE(hwy::HWY_NAMESPACE::TFromV, bytes) +#define HWY_IF_T_SIZE_ONE_OF_V(V, bit_array) \ + HWY_IF_T_SIZE_ONE_OF(hwy::HWY_NAMESPACE::TFromV, bit_array) + +#define HWY_MAX_LANES_V(V) HWY_MAX_LANES_D(hwy::HWY_NAMESPACE::DFromV) +#define HWY_IF_V_SIZE_V(V, bytes) \ + HWY_IF_V_SIZE(hwy::HWY_NAMESPACE::TFromV, HWY_MAX_LANES_V(V), bytes) +#define HWY_IF_V_SIZE_LE_V(V, bytes) \ + HWY_IF_V_SIZE_LE(hwy::HWY_NAMESPACE::TFromV, HWY_MAX_LANES_V(V), bytes) +#define HWY_IF_V_SIZE_GT_V(V, bytes) \ + HWY_IF_V_SIZE_GT(hwy::HWY_NAMESPACE::TFromV, HWY_MAX_LANES_V(V), bytes) + +// Use in implementations of ReduceSum etc. to avoid conflicts with the N=1 and +// N=4 8-bit specializations in generic_ops-inl. 
+#undef HWY_IF_REDUCE_D +#define HWY_IF_REDUCE_D(D) \ + hwy::EnableIf) != 1)>* = nullptr + +#undef HWY_IF_SUM_OF_LANES_D +#define HWY_IF_SUM_OF_LANES_D(D) HWY_IF_LANES_GT_D(D, 1) + +#undef HWY_IF_MINMAX_OF_LANES_D +#define HWY_IF_MINMAX_OF_LANES_D(D) HWY_IF_LANES_GT_D(D, 1) + +#undef HWY_IF_ADDSUB_V +#define HWY_IF_ADDSUB_V(V) HWY_IF_LANES_GT_D(hwy::HWY_NAMESPACE::DFromV, 1) + +#undef HWY_IF_MULADDSUB_V +#define HWY_IF_MULADDSUB_V(V) \ + HWY_IF_LANES_GT_D(hwy::HWY_NAMESPACE::DFromV, 1) + +#undef HWY_IF_PAIRWISE_ADD_128_D +#define HWY_IF_PAIRWISE_ADD_128_D(D) HWY_IF_V_SIZE_GT_D(D, 8) + +#undef HWY_IF_PAIRWISE_SUB_128_D +#define HWY_IF_PAIRWISE_SUB_128_D(D) HWY_IF_V_SIZE_GT_D(D, 8) + +// HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V is used to disable the default +// implementation of unsigned to signed DemoteTo/ReorderDemote2To in +// generic_ops-inl.h for at least some of the unsigned to signed demotions on +// SCALAR/EMU128/SSE2/SSSE3/SSE4/AVX2/SVE/SVE2/LSX/LASX + +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) void* = nullptr + +// Old names (deprecated) +#define HWY_IF_LANE_SIZE_D(D, bytes) HWY_IF_T_SIZE_D(D, bytes) +#define HWY_IF_NOT_LANE_SIZE_D(D, bytes) HWY_IF_NOT_T_SIZE_D(D, bytes) + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // HIGHWAY_HWY_OPS_SHARED_TOGGLE diff --git a/lib/highway/hwy/ops/wasm_128-inl.h b/lib/highway/hwy/ops/wasm_128-inl.h new file mode 100644 index 00000000000..65b0e174efe --- /dev/null +++ b/lib/highway/hwy/ops/wasm_128-inl.h @@ -0,0 +1,5992 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit WASM vectors and operations. +// External include guard in highway.h - see comment there. + +#include + +#include "hwy/base.h" +#include "hwy/ops/shared-inl.h" + +#ifdef HWY_WASM_OLD_NAMES  // alias the intrinsic names used below to their older spellings +#define wasm_i8x16_shuffle wasm_v8x16_shuffle +#define wasm_i16x8_shuffle wasm_v16x8_shuffle +#define wasm_i32x4_shuffle wasm_v32x4_shuffle +#define wasm_i64x2_shuffle wasm_v64x2_shuffle +#define wasm_u16x8_extend_low_u8x16 wasm_i16x8_widen_low_u8x16 +#define wasm_u32x4_extend_low_u16x8 wasm_i32x4_widen_low_u16x8 +#define wasm_i32x4_extend_low_i16x8 wasm_i32x4_widen_low_i16x8 +#define wasm_i16x8_extend_low_i8x16 wasm_i16x8_widen_low_i8x16 +#define wasm_u32x4_extend_high_u16x8 wasm_i32x4_widen_high_u16x8 +#define wasm_i32x4_extend_high_i16x8 wasm_i32x4_widen_high_i16x8 +#define wasm_i32x4_trunc_sat_f32x4 wasm_i32x4_trunc_saturate_f32x4 +#define wasm_i64x2_trunc_sat_f64x2 wasm_i64x2_trunc_saturate_f64x2 +#define wasm_u8x16_add_sat wasm_u8x16_add_saturate +#define wasm_u8x16_sub_sat wasm_u8x16_sub_saturate +#define wasm_u16x8_add_sat wasm_u16x8_add_saturate +#define wasm_u16x8_sub_sat wasm_u16x8_sub_saturate +#define wasm_i8x16_add_sat wasm_i8x16_add_saturate +#define wasm_i8x16_sub_sat wasm_i8x16_sub_saturate +#define wasm_i16x8_add_sat wasm_i16x8_add_saturate +#define wasm_i16x8_sub_sat wasm_i16x8_sub_saturate +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +#if HWY_TARGET == HWY_WASM_EMU256 +template +using Full256 = Simd; +#endif + +namespace detail { + +template +struct Raw128 { + using type = __v128_u; +};
+template <> +struct Raw128 { + using type = __f32x4; +}; +template <> +struct Raw128 { + using type = __f64x2; +}; + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +template +using Vec16 = Vec128; + +// FF..FF or 0. +template +struct Mask128 { + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = N; // only for DFromM + + typename detail::Raw128::type raw; +}; + +template +using DFromV = Simd; + +template +using DFromM = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ Zero + +// Use HWY_MAX_LANES_D here because VFromD is defined in terms of Zero. 
+template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + return Vec128, HWY_MAX_LANES_D(D)>{wasm_i32x4_splat(0)}; +} +template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + return Vec128, HWY_MAX_LANES_D(D)>{wasm_f32x4_splat(0.0f)}; +} +template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + return Vec128, HWY_MAX_LANES_D(D)>{wasm_f64x2_splat(0.0)}; +} + +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __v128_u BitCastToInteger(__v128_u v) { return v; } +HWY_INLINE __v128_u BitCastToInteger(__f32x4 v) { + return static_cast<__v128_u>(v); +} +HWY_INLINE __v128_u BitCastToInteger(__f64x2 v) { + return static_cast<__v128_u>(v); +} + +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return Vec128{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger128 { + HWY_INLINE __v128_u operator()(__v128_u v) { return v; } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __f32x4 operator()(__v128_u v) { return static_cast<__f32x4>(v); } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __f64x2 operator()(__v128_u v) { return static_cast<__f64x2>(v); } +}; + +template +HWY_INLINE VFromD BitCastFromByte(D d, Vec128 v) { + return VFromD{BitCastFromInteger128>()(v.raw)}; +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, + Vec128().MaxLanes()> v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ ResizeBitCast + +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const Repartition du8_to; + return BitCast(d, VFromD{detail::BitCastToInteger(v.raw)}); +} + +// ------------------------------ Set + +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_i8x16_splat(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return 
VFromD{wasm_i16x8_splat(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_i32x4_splat(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_i64x2_splat(static_cast(t))}; +} + +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_i16x8_splat(BitCastScalar(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_f32x4_splat(t)}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{wasm_f64x2_splat(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// For all vector sizes. +template +HWY_API VFromD Undefined(D d) { + return Zero(d); +} + +HWY_DIAGNOSTICS(pop) + +// For all vector sizes. +template , typename T2> +HWY_API VFromD Iota(D d, const T2 first) { + HWY_ALIGN T lanes[MaxLanes(d)]; + for (size_t i = 0; i < MaxLanes(d); ++i) { + lanes[i] = AddWithWraparound(static_cast(first), i); + } + return Load(d, lanes); +} + +// ------------------------------ Dup128VecFromValues +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return VFromD{wasm_i8x16_make(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, + t11, t12, t13, t14, t15)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return VFromD{wasm_u8x16_make(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, + t11, t12, t13, t14, t15)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return 
VFromD{wasm_i16x8_make(t0, t1, t2, t3, t4, t5, t6, t7)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{wasm_u16x8_make(t0, t1, t2, t3, t4, t5, t6, t7)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{wasm_i32x4_make(t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{wasm_u32x4_make(t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{wasm_f32x4_make(t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{wasm_i64x2_make(t0, t1)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{wasm_u64x2_make(t0, t1)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{wasm_f64x2_make(t0, t1)}; +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return 
Vec128{wasm_i32x4_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_add(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_add(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_add(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f64x2_add(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + return Vec128{wasm_i16x8_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_sub(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_sub(a.raw, b.raw)}; +} + +// Float +template +HWY_API Vec128 
operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_sub(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f64x2_sub(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_add_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_add_sat(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_add_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_add_sat(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_sub_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_sub_sat(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i8x16_sub_sat(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_sub_sat(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u8x16_avgr(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_u16x8_avgr(a.raw, b.raw)}; +} + +template +HWY_API V AverageRound(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const V sign_bit = SignBit(d); + return 
Xor(BitCast(d, AverageRound(BitCast(du, Xor(a, sign_bit)), + BitCast(du, Xor(b, sign_bit)))), + sign_bit); +} + +// ------------------------------ Absolute value + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i8x16_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i16x8_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i32x4_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_i64x2_abs(v.raw)}; +} + +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_f32x4_abs(v.raw)}; +} +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{wasm_f64x2_abs(v.raw)}; +} + +// ------------------------------ Shift lanes by constant #bits + +// Unsigned +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i16x8_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u16x8_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i32x4_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i64x2_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u32x4_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_u64x2_shr(v.raw, kBits)}; +} + +// Signed +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i16x8_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i16x8_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i32x4_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{wasm_i64x2_shl(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return 
Vec128{wasm_i32x4_shr(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{wasm_i64x2_shr(v.raw, kBits)}; +} + +// 8-bit +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRight(Vec128{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ RotateRight (ShiftRight, Or) +template +HWY_API Vec128 RotateRight(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + + if (kBits == 0) return v; + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); +} + +// ------------------------------ Shift lanes by same variable #bits + +// After https://reviews.llvm.org/D108415 shift argument became unsigned. 
+HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Unsigned +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u16x8_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u32x4_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_u64x2_shr(v.raw, bits)}; +} + +// Signed +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i16x8_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i32x4_shr(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shl(v.raw, bits)}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { + return Vec128{wasm_i64x2_shr(v.raw, bits)}; +} + +// 8-bit +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. 
+ const Vec128 shifted{ + ShiftLeftSame(Vec128>{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, + const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRightSame(Vec128{v.raw}, bits).raw}; + return shifted & Set(d8, 0xFF >> bits); +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ignore Wsign-conversion +HWY_DIAGNOSTICS(pop) + +// ------------------------------ Minimum + +// Unsigned +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_u8x16_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_u16x8_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_u32x4_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + // Avoid wasm_u64x2_extract_lane - not all implementations have it yet. 
+ const uint64_t a0 = static_cast(wasm_i64x2_extract_lane(a.raw, 0)); + const uint64_t b0 = static_cast(wasm_i64x2_extract_lane(b.raw, 0)); + const uint64_t a1 = static_cast(wasm_i64x2_extract_lane(a.raw, 1)); + const uint64_t b1 = static_cast(wasm_i64x2_extract_lane(b.raw, 1)); + alignas(16) uint64_t min[2] = {HWY_MIN(a0, b0), HWY_MIN(a1, b1)}; + return Vec128{wasm_v128_load(min)}; +} + +// Signed +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_i8x16_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_i16x8_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{wasm_i32x4_min(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + alignas(16) int64_t min[4]; + min[0] = HWY_MIN(wasm_i64x2_extract_lane(a.raw, 0), + wasm_i64x2_extract_lane(b.raw, 0)); + min[1] = HWY_MIN(wasm_i64x2_extract_lane(a.raw, 1), + wasm_i64x2_extract_lane(b.raw, 1)); + return Vec128{wasm_v128_load(min)}; +} + +// Float +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + // Equivalent to a < b ? a : b (taking into account our swapped arg order, + // so that Min(NaN, x) is x to match x86). + return Vec128{wasm_f32x4_pmin(b.raw, a.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + // Equivalent to a < b ? a : b (taking into account our swapped arg order, + // so that Min(NaN, x) is x to match x86). + return Vec128{wasm_f64x2_pmin(b.raw, a.raw)}; +} + +// ------------------------------ Maximum + +// Unsigned +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_u8x16_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_u16x8_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_u32x4_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + // Avoid wasm_u64x2_extract_lane - not all implementations have it yet. 
+ const uint64_t a0 = static_cast(wasm_i64x2_extract_lane(a.raw, 0)); + const uint64_t b0 = static_cast(wasm_i64x2_extract_lane(b.raw, 0)); + const uint64_t a1 = static_cast(wasm_i64x2_extract_lane(a.raw, 1)); + const uint64_t b1 = static_cast(wasm_i64x2_extract_lane(b.raw, 1)); + alignas(16) uint64_t max[2] = {HWY_MAX(a0, b0), HWY_MAX(a1, b1)}; + return Vec128{wasm_v128_load(max)}; +} + +// Signed +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_i8x16_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_i16x8_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{wasm_i32x4_max(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + alignas(16) int64_t max[2]; + max[0] = HWY_MAX(wasm_i64x2_extract_lane(a.raw, 0), + wasm_i64x2_extract_lane(b.raw, 0)); + max[1] = HWY_MAX(wasm_i64x2_extract_lane(a.raw, 1), + wasm_i64x2_extract_lane(b.raw, 1)); + return Vec128{wasm_v128_load(max)}; +} + +// Float +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + // Equivalent to b < a ? a : b (taking into account our swapped arg order, + // so that Max(NaN, x) is x to match x86). + return Vec128{wasm_f32x4_pmax(b.raw, a.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + // Equivalent to b < a ? a : b (taking into account our swapped arg order, + // so that Max(NaN, x) is x to match x86). 
+ return Vec128{wasm_f64x2_pmax(b.raw, a.raw)}; +} + +// ------------------------------ MinNumber and MaxNumber + +#ifdef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER +#undef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER +#else +#define HWY_NATIVE_FLOAT_MIN_MAX_NUMBER +#endif + +template +HWY_API V MinNumber(V a, V b) { + return Min(a, IfThenElse(IsNaN(b), a, b)); +} + +template +HWY_API V MaxNumber(V a, V b) { + return Max(a, IfThenElse(IsNaN(b), a, b)); +} + +// ------------------------------ Integer multiplication + +// Unsigned +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_mul(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i16x8_mul(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_mul(a.raw, b.raw)}; +} + +// Returns the upper sizeof(T)*8 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_u16x8_extmul_low_u8x16(a.raw, b.raw); + const auto h = wasm_u16x8_extmul_high_u8x16(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{wasm_i8x16_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_i16x8_extmul_low_i8x16(a.raw, b.raw); + const auto h = wasm_i16x8_extmul_high_i8x16(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{wasm_i8x16_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_u32x4_extmul_low_u16x8(a.raw, b.raw); + const auto h = wasm_u32x4_extmul_high_u16x8(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? 
+ return Vec128{ + wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_i32x4_extmul_low_i16x8(a.raw, b.raw); + const auto h = wasm_i32x4_extmul_high_i16x8(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{ + wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_u64x2_extmul_low_u32x4(a.raw, b.raw); + const auto h = wasm_u64x2_extmul_high_u32x4(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{wasm_i32x4_shuffle(l, h, 1, 3, 5, 7)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + const auto l = wasm_i64x2_extmul_low_i32x4(a.raw, b.raw); + const auto h = wasm_i64x2_extmul_high_i32x4(a.raw, b.raw); + // TODO(eustas): shift-right + narrow? + return Vec128{wasm_i32x4_shuffle(l, h, 1, 3, 5, 7)}; +} + +template +HWY_API Vec128 MulFixedPoint15(Vec128 a, + Vec128 b) { + return Vec128{wasm_i16x8_q15mulr_sat(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and returns the double-width result. 
+template +HWY_API Vec128, (N + 1) / 2> MulEven(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + constexpr int kSrcBits = sizeof(T) * 8; + + const auto ae = + ShiftRight(ShiftLeft(ResizeBitCast(dw, a))); + const auto be = + ShiftRight(ShiftLeft(ResizeBitCast(dw, b))); + return ae * be; +} +template +HWY_API Vec128, (N + 1) / 2> MulEven(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + const auto kEvenMask = Set(dw, LimitsMax()); + + const auto ae = And(ResizeBitCast(dw, a), kEvenMask); + const auto be = And(ResizeBitCast(dw, b), kEvenMask); + return ae * be; +} +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + const auto ae = ShiftRight<32>(ShiftLeft<32>(ResizeBitCast(dw, a))).raw; + const auto be = ShiftRight<32>(ShiftLeft<32>(ResizeBitCast(dw, b))).raw; + return Vec128{wasm_i64x2_mul(ae, be)}; +} +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); + const auto ae = wasm_v128_and(a.raw, kEvenMask); + const auto be = wasm_v128_and(b.raw, kEvenMask); + return Vec128{wasm_i64x2_mul(ae, be)}; +} + +// Multiplies odd lanes (1, 3 ..) and returns the double-width result. 
+template +HWY_API Vec128, (N + 1) / 2> MulOdd(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + constexpr int kSrcBits = sizeof(T) * 8; + + const auto ao = ShiftRight(BitCast(dw, a)); + const auto bo = ShiftRight(BitCast(dw, b)); + return ao * bo; +} +template +HWY_API Vec128, (N + 1) / 2> MulOdd(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RepartitionToWide dw; + + const auto ao = ShiftRight<32>(BitCast(dw, a)); + const auto bo = ShiftRight<32>(BitCast(dw, b)); + return Vec128, (N + 1) / 2>{wasm_i64x2_mul(ao.raw, bo.raw)}; +} + +// ------------------------------ Negate + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Xor(v, SignBit(DFromV())); +} + +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i8x16_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i16x8_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i32x4_neg(v.raw)}; +} +template +HWY_API Vec128 Neg(const Vec128 v) { + return Vec128{wasm_i64x2_neg(v.raw)}; +} + +// ------------------------------ Floating-point mul / div + +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{wasm_f32x4_mul(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{wasm_f64x2_mul(a.raw, b.raw)}; +} + +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f32x4_div(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_f64x2_div(a.raw, b.raw)}; +} + +template )> +HWY_API V ApproximateReciprocal(const V v) { + return Set(DFromV(), 1.0f) / v; +} + +// Integer overload defined in generic_ops-inl.h. 
+template +HWY_API Vec128 AbsDiff(const Vec128 a, const Vec128 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return mul * x + add; +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { + return add - mul * x; +} + +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return mul * x - sub; +} + +template +HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, + Vec128 sub) { + return Neg(mul) * x - sub; +} + +// ------------------------------ Floating-point square root + +// Full precision square root +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{wasm_f32x4_sqrt(v.raw)}; +} +template +HWY_API Vec128 Sqrt(const Vec128 v) { + return Vec128{wasm_f64x2_sqrt(v.raw)}; +} + +// Approximate reciprocal square root +template )> +HWY_API V ApproximateReciprocalSqrt(V v) { + // TODO(eustas): find cheaper a way to calculate this. 
+ return Set(DFromV(), static_cast>(1.0)) / Sqrt(v); +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{wasm_f32x4_nearest(v.raw)}; +} +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{wasm_f64x2_nearest(v.raw)}; +} + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{wasm_f32x4_trunc(v.raw)}; +} +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{wasm_f64x2_trunc(v.raw)}; +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{wasm_f32x4_ceil(v.raw)}; +} +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{wasm_f64x2_ceil(v.raw)}; +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{wasm_f32x4_floor(v.raw)}; +} +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{wasm_f64x2_floor(v.raw)}; +} + +// ------------------------------ Floating-point classification +template +HWY_API Mask128 IsNaN(const Vec128 v) { + return v != v; +} + +template +HWY_API Mask128 IsInf(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vu, vu), Set(du, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). 
+ const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return Mask128{v.raw}; +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +template +HWY_API MFromD RebindMask(DTo /* tag */, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + return MFromD{m.raw}; +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_eq(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{wasm_i16x8_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_eq(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_eq(a.raw, b.raw)}; +} + +// Float +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_eq(a.raw, b.raw)}; +} +template 
+HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f64x2_eq(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// Unsigned +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_ne(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_ne(a.raw, b.raw)}; +} + +// Float +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_ne(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f64x2_ne(a.raw, b.raw)}; +} + +// ------------------------------ Strict inequality + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i32x4_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_gt(a.raw, b.raw)}; +} + 
+template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u8x16_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u16x8_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u32x4_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + const DFromV d; + const Repartition d32; + const auto a32 = BitCast(d32, a); + const auto b32 = BitCast(d32, b); + // If the upper halves are not equal, this is the answer. + const auto m_gt = a32 > b32; + + // Otherwise, the lower half decides. + const auto m_eq = a32 == b32; + const auto lo_in_hi = wasm_i32x4_shuffle(m_gt.raw, m_gt.raw, 0, 0, 2, 2); + const auto lo_gt = And(m_eq, MaskFromVec(VFromD{lo_in_hi})); + + const auto gt = Or(lo_gt, m_gt); + // Copy result in upper 32 bits to lower 32 bits. + return Mask128{wasm_i32x4_shuffle(gt.raw, gt.raw, 1, 1, 3, 3)}; +} + +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_gt(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f64x2_gt(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator<(const Vec128 a, const Vec128 b) { + return operator>(b, a); +} + +// ------------------------------ Weak inequality + +// Float >= +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f32x4_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_f64x2_ge(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i8x16_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i16x8_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const 
Vec128 b) { + return Mask128{wasm_i32x4_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_i64x2_ge(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u8x16_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u16x8_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Mask128{wasm_u32x4_ge(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(const Vec128 a, + const Vec128 b) { + return Not(b > a); +} + +template +HWY_API Mask128 operator<=(const Vec128 a, const Vec128 b) { + return operator>=(b, a); +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API MFromD FirstN(D d, size_t num) { + const RebindToSigned di; // Signed comparisons may be cheaper. + using TI = TFromD; + return RebindMask(d, Iota(di, 0) < Set(di, static_cast(num))); +} + +// ================================================== LOGICAL + +// ------------------------------ Not + +template +HWY_API Vec128 Not(Vec128 v) { + return Vec128{wasm_v128_not(v.raw)}; +} + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_and(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
+template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + return Vec128{wasm_v128_andnot(mask.raw, not_mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_or(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + return Vec128{wasm_v128_xor(a.raw, b.raw)}; +} + +// ------------------------------ Or3 + +template +HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { + return Or(o1, Or(o2, o3)); +} + +// ------------------------------ OrAnd + +template +HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { + return Or(o, And(a1, a2)); +} + +// ------------------------------ IfVecThenElse + +template +HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, + Vec128 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign +template +HWY_API Vec128 CopySign(const Vec128 magn, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return BitwiseIfThenElse(SignBit(d), sign, magn); +} + +// ------------------------------ CopySignToAbs +template +HWY_API Vec128 CopySignToAbs(const Vec128 abs, + const Vec128 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return OrAnd(abs, SignBit(d), sign); +} + +// ------------------------------ BroadcastSignBit (compare) + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight(v); +} +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; 
+ return VecFromMask(d, v < Zero(d)); +} + +// ------------------------------ Mask + +template +HWY_API VFromD VecFromMask(D /* tag */, MFromD v) { + return VFromD{v.raw}; +} + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{wasm_v128_bitselect(yes.raw, no.raw, mask.raw)}; +} + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +// mask ? 0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + const RebindToSigned di; + + v = BitCast(d, BroadcastSignBit(BitCast(di, v))); + return IfThenElse(MaskFromVec(v), yes, no); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + const DFromM d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const DFromM d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ------------------------------ Shl (BroadcastSignBit, IfThenElse) + +// The x86 multiply-by-Pow2() trick will not work because WASM saturates 
+// float->int correctly to 2^31-1 (not 2^31). Because WASM's shifts take a +// scalar count operand, per-lane shift instructions would require extract_lane +// for each lane, and hoping that shuffle is correctly mapped to a native +// instruction. Using non-vector shifts would incur a store-load forwarding +// stall when loading the result vector. We instead test bits of the shift +// count to "predicate" a shift of the entire vector by a constant. + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<5>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. 
+ test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftLeft<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftLeft<1>(v), v); +} + +template +HWY_API Vec128 operator<<(Vec128 v, const Vec128 bits) { + const DFromV d; + const RebindToUnsigned du; + using TU = MakeUnsigned; + alignas(16) TU lanes[2] = {}; + alignas(16) TU bits_lanes[2] = {}; + Store(BitCast(du, v), du, lanes); + Store(BitCast(du, bits), 
du, bits_lanes); + lanes[0] <<= (bits_lanes[0] & 63); + lanes[1] <<= (bits_lanes[1] & 63); + return BitCast(d, Load(du, lanes)); +} + +// ------------------------------ Shr (BroadcastSignBit, IfThenElse) + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<5>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. 
+ test = ShiftLeft<12>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const DFromV d; + Mask128 mask; + // Need a signed type for BroadcastSignBit. + auto test = BitCast(RebindToSigned(), bits); + // Move the highest valid bit of the shift count into the sign bit. + test = ShiftLeft<27>(test); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<16>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<8>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<4>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + test = ShiftLeft<1>(test); // next bit (descending order) + v = IfThenElse(mask, ShiftRight<2>(v), v); + + mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); + return IfThenElse(mask, ShiftRight<1>(v), v); +} + +template +HWY_API Vec128 operator>>(Vec128 v, const Vec128 bits) { + const DFromV d; + alignas(16) T lanes[2] = {}; + alignas(16) T bits_lanes[2] = {}; + Store(v, d, lanes); + Store(bits, d, bits_lanes); + lanes[0] >>= (bits_lanes[0] & 63); + lanes[1] >>= 
(bits_lanes[1] & 63); + return Load(d, lanes); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template > +HWY_API Vec128 Load(D /* tag */, const T* HWY_RESTRICT aligned) { + return Vec128{wasm_v128_load(aligned)}; +} + +// Partial +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + VFromD v; + CopyBytes(p, &v); + return v; +} + +// LoadU == Load. +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +template > +HWY_API VFromD MaskedLoad(MFromD m, D d, const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template > +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const T* HWY_RESTRICT aligned) { + return IfThenElse(m, Load(d, aligned), v); +} + +// ------------------------------ Store + +namespace detail { + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i8x16_extract_lane(v.raw, kLane)); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + const int16_t lane = wasm_i16x8_extract_lane(v.raw, kLane); + return static_cast(lane); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + + const uint16_t bits = ExtractLane(BitCast(du, v)); + return BitCastScalar(bits); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i32x4_extract_lane(v.raw, kLane)); +} +template +HWY_INLINE T ExtractLane(const Vec128 v) { + return static_cast(wasm_i64x2_extract_lane(v.raw, kLane)); +} + +template +HWY_INLINE float ExtractLane(const Vec128 v) { + return wasm_f32x4_extract_lane(v.raw, kLane); +} +template +HWY_INLINE double ExtractLane(const Vec128 v) { + return wasm_f64x2_extract_lane(v.raw, kLane); +} + +} // namespace detail + +template 
+HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// Partial +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT p) { + CopyBytes(&v, p); +} + +template +HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT p) { + *p = detail::ExtractLane<0>(v); +} + +// StoreU == Store. +template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + Store(v, d, p); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + StoreU(IfThenElse(m, v, LoadU(d, p)), d, p); +} + +// ------------------------------ Non-temporal stores + +// Same as aligned stores on non-x86. + +template +HWY_API void Stream(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { + wasm_v128_store(aligned, v.raw); +} + +// ------------------------------ Scatter in generic_ops-inl.h +// ------------------------------ Gather in generic_ops-inl.h + +// ================================================== SWIZZLE + +// ------------------------------ ExtractLane + +// One overload per vector length just in case *_extract_lane raise compile +// errors if their argument is out of bounds (even if that would never be +// reached at runtime). 
+template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return detail::ExtractLane<0>(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + 
return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + case 8: + return detail::ExtractLane<8>(v); + case 9: + return detail::ExtractLane<9>(v); + case 10: + return detail::ExtractLane<10>(v); + case 11: + return detail::ExtractLane<11>(v); + case 12: + return detail::ExtractLane<12>(v); + case 13: + return detail::ExtractLane<13>(v); + case 14: + return detail::ExtractLane<14>(v); + case 15: + return detail::ExtractLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ GetLane +template +HWY_API T GetLane(const Vec128 v) { + return detail::ExtractLane<0>(v); +} + +// ------------------------------ InsertLane + +namespace detail { + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i8x16_replace_lane(v.raw, kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i16x8_replace_lane(v.raw, kLane, BitCastScalar(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i32x4_replace_lane(v.raw, kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{ + wasm_i64x2_replace_lane(v.raw, kLane, static_cast(t))}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, float t) { + static_assert(kLane < N, "Lane index out of bounds"); + return Vec128{wasm_f32x4_replace_lane(v.raw, kLane, t)}; +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane < 2, "Lane index out of bounds"); + return Vec128{wasm_f64x2_replace_lane(v.raw, kLane, t)}; +} + +} // namespace detail 
+ +// Requires one overload per vector length because InsertLane<3> may be a +// compile error if it calls wasm_f64x2_replace_lane. + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[4]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[8]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes 
clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + const DFromV d; + alignas(16) T lanes[16]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ LowerHalf + +template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return VFromD{v.raw}; +} +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return Vec128{v.raw}; +} + +// ------------------------------ ShiftLeftBytes + +// 0x01..0F, kBytes = 1 => 0x02..0F00 +template +HWY_API VFromD ShiftLeftBytes(D /* tag */, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + switch (kBytes) { + case 0: + return v; + + case 1: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14)}; + + case 2: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13)}; + + case 3: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 0, 1, 2, 3, + 4, 5, 6, 7, 8, 9, 10, 11, 12)}; + + case 4: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 0, 1, 2, + 3, 4, 5, 6, 7, 8, 9, 10, 11)}; + + case 5: + return 
VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 0, 1, + 2, 3, 4, 5, 6, 7, 8, 9, 10)}; + + case 6: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)}; + + case 7: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 0, 1, 2, 3, 4, 5, 6, 7, 8)}; + + case 8: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 0, 1, 2, 3, 4, 5, 6, 7)}; + + case 9: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 0, 1, 2, 3, 4, 5, 6)}; + + case 10: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 0, 1, 2, 3, 4, 5)}; + + case 11: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 0, 1, 2, 3, 4)}; + + case 12: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 0, 1, 2, 3)}; + + case 13: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 0, 1, 2)}; + + case 14: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 0, + 1)}; + + case 15: + return VFromD{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, + 0)}; + } + return VFromD{zero}; +} + +template +HWY_API Vec128 ShiftLeftBytes(Vec128 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +template +HWY_API VFromD ShiftLeftLanes(D d, const VFromD v) { + const Repartition d8; + constexpr size_t kBytes = kLanes * sizeof(TFromD); + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec128 ShiftLeftLanes(const Vec128 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +namespace detail { + +// Helper function allows zeroing invalid lanes in caller. 
+template +HWY_API __i8x16 ShrBytes(const Vec128 v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const __i8x16 zero = wasm_i8x16_splat(0); + + switch (kBytes) { + case 0: + return v.raw; + + case 1: + return wasm_i8x16_shuffle(v.raw, zero, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16); + + case 2: + return wasm_i8x16_shuffle(v.raw, zero, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 16); + + case 3: + return wasm_i8x16_shuffle(v.raw, zero, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 16, 16); + + case 4: + return wasm_i8x16_shuffle(v.raw, zero, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 16, 16, 16); + + case 5: + return wasm_i8x16_shuffle(v.raw, zero, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 16, 16, 16, 16); + + case 6: + return wasm_i8x16_shuffle(v.raw, zero, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 16, 16, 16, 16, 16); + + case 7: + return wasm_i8x16_shuffle(v.raw, zero, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 16, 16, 16, 16, 16, 16); + + case 8: + return wasm_i8x16_shuffle(v.raw, zero, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 9: + return wasm_i8x16_shuffle(v.raw, zero, 9, 10, 11, 12, 13, 14, 15, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 10: + return wasm_i8x16_shuffle(v.raw, zero, 10, 11, 12, 13, 14, 15, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 11: + return wasm_i8x16_shuffle(v.raw, zero, 11, 12, 13, 14, 15, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 12: + return wasm_i8x16_shuffle(v.raw, zero, 12, 13, 14, 15, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 13: + return wasm_i8x16_shuffle(v.raw, zero, 13, 14, 15, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 14: + return wasm_i8x16_shuffle(v.raw, zero, 14, 15, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + + case 15: + return wasm_i8x16_shuffle(v.raw, zero, 15, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16); + case 16: + return zero; + 
} +} + +} // namespace detail + +// 0x01..0F, kBytes = 1 => 0x0001..0E +template +HWY_API VFromD ShiftRightBytes(D d, VFromD v) { + // For partial vectors, clear upper lanes so we shift in zeros. + if (d.MaxBytes() != 16) { + const Full128> dfull; + const VFromD vfull{v.raw}; + v = VFromD{IfThenElseZero(FirstN(dfull, MaxLanes(d)), vfull).raw}; + } + return VFromD{detail::ShrBytes(v)}; +} + +// ------------------------------ ShiftRightLanes +template +HWY_API VFromD ShiftRightLanes(D d, const VFromD v) { + const Repartition d8; + constexpr size_t kBytes = kLanes * sizeof(TFromD); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +template > +HWY_API Vec64 UpperHalf(D /* tag */, const Vec128 v) { + return Vec64{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; +} + +// Partial +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + return LowerHalf(d, ShiftRightBytes(Twice(), v)); +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API Vec128 CombineShiftRightBytes(D /* tag */, Vec128 hi, + Vec128 lo) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + switch (kBytes) { + case 0: + return lo; + + case 1: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16)}; + + case 2: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, 17)}; + + case 3: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18)}; + + case 4: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19)}; + + case 5: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20)}; + + case 6: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21)}; + + case 7: + return 
Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22)}; + + case 8: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23)}; + + case 9: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24)}; + + case 10: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25)}; + + case 11: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26)}; + + case 12: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27)}; + + case 13: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, + 27, 28)}; + + case 14: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, + 28, 29)}; + + case 15: + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30)}; + } + return hi; +} + +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + constexpr size_t kSize = d.MaxBytes(); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + using V8 = Vec128; + const DFromV dfull8; + const Repartition, decltype(dfull8)> dfull; + const V8 hi8{BitCast(d8, hi).raw}; + // Move into most-significant bytes + const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); + const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(dfull8, hi8, lo8); + return VFromD{BitCast(dfull, r).raw}; +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{wasm_i8x16_shuffle( + v.raw, v.raw, kLane, 
kLane, kLane, kLane, kLane, kLane, kLane, kLane, + kLane, kLane, kLane, kLane, kLane, kLane, kLane, kLane)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{wasm_i16x8_shuffle(v.raw, v.raw, kLane, kLane, kLane, + kLane, kLane, kLane, kLane, kLane)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{ + wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, kLane, kLane)}; +} + +// ------------------------------ TableLookupBytes + +// Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. +// lane indices in [0, 16). +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + return Vec128{wasm_i8x16_swizzle(bytes.raw, from.raw)}; +} + +template +HWY_API Vec128 TableLookupBytesOr0(const Vec128 bytes, + const Vec128 from) { + const DFromV d; + // Mask size must match vector type, so cast everything to this type. + Repartition di8; + Repartition> d_bytes8; + const auto msb = BitCast(di8, from) < Zero(di8); + const auto lookup = + TableLookupBytes(BitCast(d_bytes8, bytes), BitCast(di8, from)); + return BitCast(d, IfThenZeroElse(msb, lookup)); +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). +// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. 
+template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; +} + +// These are used by generic_ops-inl to implement LoadInterleaved3. +namespace detail { + +template +HWY_API Vec128 ShuffleTwo2301(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 1, 0, 3 + 16, 2 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template +HWY_API Vec128 ShuffleTwo2301(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i16x8_shuffle(a.raw, b.raw, 1, 0, 3 + 8, 2 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template +HWY_API Vec128 ShuffleTwo2301(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 1, 0, 3 + 4, 2 + 4)}; +} + +template +HWY_API Vec128 ShuffleTwo1230(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 0, 3, 2 + 16, 1 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template +HWY_API Vec128 ShuffleTwo1230(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i16x8_shuffle(a.raw, b.raw, 0, 3, 2 + 8, 1 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template +HWY_API Vec128 ShuffleTwo1230(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 3, 2 + 4, 1 + 4)}; +} + +template +HWY_API Vec128 ShuffleTwo3012(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make 
sense for N=1"); + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 2, 1, 0 + 16, 3 + 16, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, + 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F)}; +} +template +HWY_API Vec128 ShuffleTwo3012(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i16x8_shuffle(a.raw, b.raw, 2, 1, 0 + 8, 3 + 8, + 0x7FFF, 0x7FFF, 0x7FFF, 0x7FFF)}; +} +template +HWY_API Vec128 ShuffleTwo3012(const Vec128 a, + const Vec128 b) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 1, 0 + 4, 3 + 4)}; +} + +} // namespace detail + +// Swap 64-bit halves +template +HWY_API Vec128 Shuffle01(const Vec128 v) { + static_assert(sizeof(T) == 8, "Only for 64-bit lanes"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +} +template +HWY_API Vec128 Shuffle1032(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; +} + +// Rotate right 32 bits +template +HWY_API Vec128 Shuffle0321(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; +} + +// Rotate left 32 bits +template +HWY_API Vec128 Shuffle2103(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; +} + +// Reverse +template +HWY_API Vec128 Shuffle0123(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. 
+template +struct Indices128 { + __v128_u raw; +}; + +namespace detail { + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + return Iota(d8, 0); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecBroadcastLaneBytes( + D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8}; + return Load(d8, kBroadcastLaneBytes); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + return Zero(d8); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}; + return Load(d8, kByteOffsets); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + return Load(d8, kByteOffsets); +} + +template +HWY_INLINE VFromD> IndicesFromVecByteOffsets(D d) { + const Repartition d8; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7}; + return Load(d8, kByteOffsets); +} + +} // namespace detail + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match 
lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + (void)d; + return Indices128, MaxLanes(D())>{vec.raw}; +} + +template +HWY_API Indices128, MaxLanes(D())> IndicesFromVec( + D d, Vec128 vec) { + using T = TFromD; + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const RebindToUnsigned du; + using TU = TFromD; + HWY_DASSERT(AllTrue( + du, Lt(BitCast(du, vec), Set(du, static_cast(MaxLanes(d) * 2))))); +#endif + + const Repartition d8; + using V8 = VFromD; + + // Broadcast each lane index to all bytes of T and shift to bytes + const V8 lane_indices = TableLookupBytes( + BitCast(d8, vec), detail::IndicesFromVecBroadcastLaneBytes(d)); + constexpr int kIndexShiftAmt = static_cast(FloorLog2(sizeof(T))); + const V8 byte_indices = ShiftLeft(lane_indices); + const V8 sum = Add(byte_indices, detail::IndicesFromVecByteOffsets(d)); + return Indices128, MaxLanes(D())>{BitCast(d, sum).raw}; +} + +template +HWY_API Indices128, HWY_MAX_LANES_D(D)> SetTableIndices( + D d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + using TI = MakeSigned; + const DFromV d; + const Rebind di; + return BitCast(d, TableLookupBytes(BitCast(di, v), Vec128{idx.raw})); +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + const Twice dt; +// TableLookupLanes currently requires table and index vectors to be the same +// size, though a half-length index vector would be sufficient here. +#if HWY_IS_MSAN + const Vec128 idx_vec{idx.raw}; + const Indices128 idx2{Combine(dt, idx_vec, idx_vec).raw}; +#else + // We only keep LowerHalf of the result, which is valid in idx. 
+ const Indices128 idx2{idx.raw}; +#endif + return LowerHalf(d, TableLookupLanes(Combine(dt, b, a), idx2)); +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + const Repartition du8; + + const VFromD byte_idx{idx.raw}; + const auto byte_idx_mod = byte_idx & Set(du8, uint8_t{0x0F}); + // If ANDing did not change the index, it is for the lower half. + const auto is_lo = (byte_idx == byte_idx_mod); + + return BitCast(d, IfThenElse(is_lo, TableLookupBytes(a, byte_idx_mod), + TableLookupBytes(b, byte_idx_mod))); +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301, Shuffle01) + +// Single lane: no change +template , HWY_IF_LANES_D(D, 1)> +HWY_API Vec128 Reverse(D /* tag */, Vec128 v) { + return v; +} + +// 32-bit x2: shuffle +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec64 Reverse(D /* tag */, const Vec64 v) { + return Vec64{Shuffle2301(Vec128{v.raw}).raw}; +} + +// 64-bit x2: shuffle +template , HWY_IF_T_SIZE(T, 8)> +HWY_API Vec128 Reverse(D /* tag */, const Vec128 v) { + return Shuffle01(v); +} + +// 32-bit x2: shuffle +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec128 Reverse(D /* tag */, const Vec128 v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API VFromD Reverse(D d, const VFromD v) { + const RepartitionToWide> du32; + return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + static constexpr int kN = 16 + Lanes(d); + return VFromD{wasm_i8x16_shuffle( + v.raw, v.raw, + // kN is adjusted to ensure we have valid indices for all lengths. 
+ kN - 1, kN - 2, kN - 3, kN - 4, kN - 5, kN - 6, kN - 7, kN - 8, kN - 9, + kN - 10, kN - 11, kN - 12, kN - 13, kN - 14, kN - 15, kN - 16)}; +} + +// ------------------------------ Reverse2 + +template +HWY_API VFromD Reverse2(D d, const VFromD v) { + const RepartitionToWide> dw; + return BitCast(d, RotateRight<16>(BitCast(dw, v))); +} + +template +HWY_API VFromD Reverse2(D /* tag */, const VFromD v) { + return Shuffle2301(v); +} + +template +HWY_API VFromD Reverse2(D /* tag */, const VFromD v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { + return VFromD{wasm_i16x8_shuffle(v.raw, v.raw, 3, 2, 1, 0, 7, 6, 5, 4)}; +} + +template +HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { + return Shuffle0123(v); +} + +template +HWY_API VFromD Reverse4(D /* tag */, const VFromD) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + return Reverse(d, v); +} + +template +HWY_API VFromD Reverse8(D /* tag */, const VFromD) { + HWY_ASSERT(0); // don't have 8 lanes for > 16-bit lanes +} + +// ------------------------------ InterleaveLower + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle( + a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle( + a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 
7, 23)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, InterleaveLower(BitCast(du, a), BitCast(du, b))); +} + +// Additional overload for the optional tag (all vector lengths). +template +HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveUpper (UpperHalf) + +// All functions inside detail lack the required D parameter. 
+namespace detail { + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, + 26, 11, 27, 12, 28, 13, 29, 14, + 30, 15, 31)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, + 26, 11, 27, 12, 28, 13, 29, 14, + 30, 15, 31)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; +} + +template +HWY_API Vec128 InterleaveUpper(Vec128 a, + Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; +} + +} // namespace detail + +// Full +template > +HWY_API Vec128 InterleaveUpper(D /* tag */, Vec128 a, Vec128 b) { + return detail::InterleaveUpper(a, b); +} + +// Partial +template +HWY_API VFromD InterleaveUpper(D d, 
VFromD a, VFromD b) { + const Half d2; + return InterleaveLower(d, VFromD{UpperHalf(d2, a).raw}, + VFromD{UpperHalf(d2, b).raw}); +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ------------------------------ Per4LaneBlockShuffle +namespace detail { + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) & 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + return V{wasm_i8x16_shuffle(v.raw, v.raw, kIdx0, kIdx1, kIdx2, kIdx3, + kIdx0 + 4, kIdx1 + 4, kIdx2 + 4, kIdx3 + 4, + kIdx0 + 8, kIdx1 + 8, kIdx2 + 8, kIdx3 + 8, + kIdx0 + 12, kIdx1 + 12, kIdx2 + 12, kIdx3 + 12)}; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) & 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + return V{wasm_i16x8_shuffle(v.raw, v.raw, kIdx0, kIdx1, kIdx2, kIdx3, + kIdx0 + 4, kIdx1 + 4, kIdx2 + 4, kIdx3 + 4)}; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag 
/*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) & 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + return V{wasm_i32x4_shuffle(v.raw, v.raw, kIdx0, kIdx1, kIdx2, kIdx3)}; +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +namespace detail { + +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Full64 du64; + const auto vu64 = ResizeBitCast(du64, v); + return ResizeBitCast( + d, ShiftLeftSame(vu64, static_cast(amt * sizeof(TFromV) * 8))); +} + +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Repartition du8; + const auto idx = + Iota(du8, static_cast(size_t{0} - amt * sizeof(TFromV))); + return BitCast(d, TableLookupBytesOr0(BitCast(du8, v), idx)); +} + +} // namespace detail + +template +HWY_API VFromD SlideUpLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return 
ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + case 8: + return ShiftLeftLanes<8>(d, v); + case 9: + return ShiftLeftLanes<9>(d, v); + case 10: + return ShiftLeftLanes<10>(d, v); + case 11: + return ShiftLeftLanes<11>(d, v); + case 12: + return ShiftLeftLanes<12>(d, v); + case 13: + return ShiftLeftLanes<13>(d, v); + case 14: + return ShiftLeftLanes<14>(d, v); + case 15: + return ShiftLeftLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +// ------------------------------ SlideDownLanes + +namespace detail { + +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition, decltype(d)> dv; + return BitCast(d, + ShiftRightSame(BitCast(dv, v), + static_cast(amt * sizeof(TFromV) * 8))); +} + +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition di8; + auto idx = Iota(di8, static_cast(amt * sizeof(TFromV))); + idx = Or(idx, VecFromMask(di8, idx > Set(di8, int8_t{15}))); + return BitCast(d, TableLookupBytesOr0(BitCast(di8, v), idx)); +} + +} // namespace detail + +template +HWY_API VFromD SlideDownLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD 
v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + case 8: + return ShiftRightLanes<8>(d, v); + case 9: + return ShiftRightLanes<9>(d, v); + case 10: + return ShiftRightLanes<10>(d, v); + case 11: + return ShiftRightLanes<11>(d, v); + case 12: + return ShiftRightLanes<12>(d, v); + case 13: + return ShiftRightLanes<13>(d, v); + case 14: + return ShiftRightLanes<14>(d, 
v); + case 15: + return ShiftRightLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template >> +HWY_API VFromD Combine(D d, VH hi_half, VH lo_half) { + const Half dh; + const RebindToUnsigned duh; + // Treat half-width input as one lane, and expand to two lanes. + using VU = Vec128, 2>; + const VU lo{BitCast(duh, lo_half).raw}; + const VU hi{BitCast(duh, hi_half).raw}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (IfThenElseZero) +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const Half dh; + return IfThenElseZero(FirstN(d, MaxLanes(dh)), VFromD{lo.raw}); +} + +// ------------------------------ ConcatLowerLower +template > +HWY_API Vec128 ConcatLowerLower(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i64x2_shuffle(lo.raw, hi.raw, 0, 2)}; +} + +// ------------------------------ ConcatUpperUpper +template > +HWY_API Vec128 ConcatUpperUpper(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i64x2_shuffle(lo.raw, hi.raw, 1, 3)}; +} + +// ------------------------------ ConcatLowerUpper +template > +HWY_API Vec128 ConcatLowerUpper(D d, Vec128 hi, Vec128 lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} + +// ------------------------------ ConcatUpperLower +template > +HWY_API Vec128 ConcatUpperLower(D d, Vec128 hi, Vec128 lo) { + return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); +} + +// ------------------------------ Concat partial (Combine, LowerHalf) + +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + 
+template +HWY_API VFromD ConcatLowerUpper(D d, const VFromD hi, + const VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), LowerHalf(d2, lo)); +} + +// ------------------------------ ConcatOdd + +// 8-bit full +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 ConcatOdd(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 9, 11, 13, 15, + 17, 19, 21, 23, 25, 27, 29, 31)}; +} + +// 8-bit x8 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec64 ConcatOdd(D /* tag */, Vec64 hi, Vec64 lo) { + // Don't care about upper half. + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 17, 19, 21, + 23, 1, 3, 5, 7, 17, 19, 21, 23)}; +} + +// 8-bit x4 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 ConcatOdd(D /* tag */, Vec32 hi, Vec32 lo) { + // Don't care about upper 3/4. + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 3, 17, 19, 1, 3, 17, + 19, 1, 3, 17, 19, 1, 3, 17, 19)}; +} + +// 16-bit full +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec128 ConcatOdd(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 1, 3, 5, 7, 9, 11, 13, 15)}; +} + +// 16-bit x4 +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec64 ConcatOdd(D /* tag */, Vec64 hi, Vec64 lo) { + // Don't care about upper half. 
+ return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 1, 3, 9, 11, 1, 3, 9, 11)}; +} + +// 32-bit full +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec128 ConcatOdd(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i32x4_shuffle(lo.raw, hi.raw, 1, 3, 5, 7)}; +} + +// Any T x2 +template , HWY_IF_LANES_D(D, 2)> +HWY_API Vec128 ConcatOdd(D d, Vec128 hi, Vec128 lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// 8-bit full +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec128 ConcatEven(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30)}; +} + +// 8-bit x8 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec64 ConcatEven(D /* tag */, Vec64 hi, Vec64 lo) { + // Don't care about upper half. + return Vec64{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 16, 18, 20, 22, + 0, 2, 4, 6, 16, 18, 20, 22)}; +} + +// 8-bit x4 +template , HWY_IF_T_SIZE(T, 1)> +HWY_API Vec32 ConcatEven(D /* tag */, Vec32 hi, Vec32 lo) { + // Don't care about upper 3/4. + return Vec32{wasm_i8x16_shuffle(lo.raw, hi.raw, 0, 2, 16, 18, 0, 2, 16, 18, + 0, 2, 16, 18, 0, 2, 16, 18)}; +} + +// 16-bit full +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec128 ConcatEven(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{ + wasm_i16x8_shuffle(lo.raw, hi.raw, 0, 2, 4, 6, 8, 10, 12, 14)}; +} + +// 16-bit x4 +template , HWY_IF_T_SIZE(T, 2)> +HWY_API Vec64 ConcatEven(D /* tag */, Vec64 hi, Vec64 lo) { + // Don't care about upper half. 
+ return Vec64{wasm_i16x8_shuffle(lo.raw, hi.raw, 0, 2, 8, 10, 0, 2, 8, 10)}; +} + +// 32-bit full +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec128 ConcatEven(D /* tag */, Vec128 hi, Vec128 lo) { + return Vec128{wasm_i32x4_shuffle(lo.raw, hi.raw, 0, 2, 4, 6)}; +} + +// Any T x2 +template , HWY_IF_LANES_D(D, 2)> +HWY_API Vec128 ConcatEven(D d, Vec128 hi, Vec128 lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{wasm_i8x16_shuffle(v.raw, v.raw, 0, 0, 2, 2, 4, 4, 6, 6, + 8, 8, 10, 10, 12, 12, 14, 14)}; +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{wasm_i16x8_shuffle(v.raw, v.raw, 0, 0, 2, 2, 4, 4, 6, 6)}; +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 0, 0, 2, 2)}; +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{wasm_i8x16_shuffle(v.raw, v.raw, 1, 1, 3, 3, 5, 5, 7, 7, + 9, 9, 11, 11, 13, 13, 15, 15)}; +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{wasm_i16x8_shuffle(v.raw, v.raw, 1, 1, 3, 3, 5, 5, 7, 7)}; +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 1, 3, 3)}; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ OddEven + +namespace detail { + +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<1> /* tag */, const Vec128 a, + const Vec128 b) { + const DFromV d; + const Repartition d8; + alignas(16) static constexpr uint8_t mask[16] = { + 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<2> /* tag */, 
const Vec128 a, + const Vec128 b) { + return Vec128{ + wasm_i16x8_shuffle(a.raw, b.raw, 8, 1, 10, 3, 12, 5, 14, 7)}; +} +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<4> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +} +template +HWY_INLINE Vec128 OddEven(hwy::SizeTag<8> /* tag */, const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i64x2_shuffle(a.raw, b.raw, 2, 1)}; +} + +} // namespace detail + +template +HWY_API Vec128 OddEven(const Vec128 a, const Vec128 b) { + return detail::OddEven(hwy::SizeTag(), a, b); +} +template +HWY_API Vec128 OddEven(const Vec128 a, + const Vec128 b) { + return Vec128{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; +} + +// ------------------------------ InterleaveEven +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i8x16_shuffle(a.raw, b.raw, 0, 16, 2, 18, 4, 20, 6, 22, + 8, 24, 10, 26, 12, 28, 14, 30)}; +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 2, 10, 4, 12, 6, 14)}; +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 2, 6)}; +} + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ------------------------------ InterleaveOdd +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i8x16_shuffle(a.raw, b.raw, 1, 17, 3, 19, 5, 21, 7, 23, + 9, 25, 11, 27, 13, 29, 15, 31)}; +} + +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i16x8_shuffle(a.raw, b.raw, 1, 9, 3, 11, 5, 13, 7, 15)}; +} + +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{wasm_i32x4_shuffle(a.raw, b.raw, 1, 5, 3, 7)}; +} + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return InterleaveUpper(d, 
a, b); +} + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ InterleaveEvenBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveOddBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ InterleaveLowerBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveUpperBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ ReverseBlocks +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return v; // Single block: no change +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. 
+template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u16x8_extend_low_u8x16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u32x4_extend_low_u16x8(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u64x2_extend_low_u32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{ + wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u16x8_extend_low_u8x16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u32x4_extend_low_u16x8(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u64x2_extend_low_u32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{ + wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; +} + +// U8/U16 to U64/I64: First, zero-extend to U32, and then zero-extend to +// TFromD +template +HWY_API VFromD PromoteTo(D d, V v) { + const Rebind du32; + return PromoteTo(d, PromoteTo(du32, v)); +} + +// Signed: replicate sign bit. 
+template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i16x8_extend_low_i8x16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i32x4_extend_low_i16x8(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i64x2_extend_low_i32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{ + wasm_i32x4_extend_low_i16x8(wasm_i16x8_extend_low_i8x16(v.raw))}; +} + +// I8/I16 to I64: First, promote to I32, and then promote to I64 +template +HWY_API VFromD PromoteTo(D d, V v) { + const Rebind di32; + return PromoteTo(d, PromoteTo(di32, v)); +} + +template +HWY_API VFromD PromoteTo(D df32, VFromD> v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_f64x2_convert_low_i32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_f64x2_convert_low_u32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_f64x2_promote_low_f32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D di64, VFromD> v) { + const Rebind di32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + const auto exponent_adj = BitCast( + du32, + Min(SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, uint32_t{157}))), + BitCast(du32_as_du8, Set(du32, uint32_t{32})))); + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + + const auto f32_to_i32_result = ConvertTo(di32, adj_v); + const auto lo64_or_mask = PromoteTo( + di64, + BitCast(du32, VecFromMask(di32, Eq(f32_to_i32_result, + Set(di32, LimitsMax()))))); + + return Or(PromoteTo(di64, BitCast(di32, f32_to_i32_result)) + << PromoteTo(di64, 
exponent_adj), + lo64_or_mask); +} + +template +HWY_API VFromD PromoteTo(D du64, VFromD> v) { + const Rebind du32; + const RebindToFloat df32; + const Repartition du32_as_du8; + + const auto exponent_adj = BitCast( + du32, + Min(SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, uint32_t{158}))), + BitCast(du32_as_du8, Set(du32, uint32_t{32})))); + + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + const auto f32_to_u32_result = ConvertTo(du32, adj_v); + const auto lo32_or_mask = PromoteTo( + du64, + VecFromMask(du32, f32_to_u32_result == Set(du32, LimitsMax()))); + + return Or(PromoteTo(du64, f32_to_u32_result) << PromoteTo(du64, exponent_adj), + lo32_or_mask); +} + +// ------------------------------ PromoteUpperTo + +// Per-target flag to prevent generic_ops-inl.h from defining PromoteUpperTo. +#ifdef HWY_NATIVE_PROMOTE_UPPER_TO +#undef HWY_NATIVE_PROMOTE_UPPER_TO +#else +#define HWY_NATIVE_PROMOTE_UPPER_TO +#endif + +// Unsigned: zero-extend. +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_u16x8_extend_high_u8x16(v.raw)}; +} +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_u32x4_extend_high_u16x8(v.raw)}; +} +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_u64x2_extend_high_u32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_u16x8_extend_high_u8x16(v.raw)}; +} +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_u32x4_extend_high_u16x8(v.raw)}; +} +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_u64x2_extend_high_u32x4(v.raw)}; +} + +// Signed: replicate sign bit. 
+template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_i16x8_extend_high_i8x16(v.raw)}; +} +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_i32x4_extend_high_i16x8(v.raw)}; +} +template +HWY_API VFromD PromoteUpperTo(D /* tag */, + VFromD> v) { + return VFromD{wasm_i64x2_extend_high_i32x4(v.raw)}; +} + +template +HWY_API VFromD PromoteUpperTo(D df32, VFromD> v) { + const Rebind dh; + return PromoteTo(df32, UpperHalf(dh, v)); +} + +template +HWY_API VFromD PromoteUpperTo(D df32, VFromD> v) { + const Repartition du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteUpperTo(di32, BitCast(du16, v)))); +} + +template +HWY_API VFromD PromoteUpperTo(D dd, VFromD> v) { + // There is no wasm_f64x2_convert_high_i32x4. + return PromoteTo(dd, UpperHalf(Rebind(), v)); +} + +template +HWY_API VFromD PromoteUpperTo(D dd, VFromD> v) { + // There is no wasm_f64x2_convert_high_u32x4. + return PromoteTo(dd, UpperHalf(Rebind(), v)); +} + +template +HWY_API VFromD PromoteUpperTo(D dd, VFromD> v) { + // There is no wasm_f64x2_promote_high_f32x4. + return PromoteTo(dd, UpperHalf(Rebind(), v)); +} + +template +HWY_API VFromD PromoteUpperTo(D d64, VFromD> v) { + return PromoteTo(d64, UpperHalf(Rebind(), v)); +} + +// Generic version for <=64 bit input/output (_high is only for full vectors). 
+template +HWY_API VFromD PromoteUpperTo(D d, V v) { + const Rebind, decltype(d)> dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "hwy/ops/inside-inl.h" + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u16x8_narrow_i32x4(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i16x8_narrow_i32x4(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return VFromD{wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u8x16_narrow_i16x8(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return VFromD{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i8x16_narrow_i16x8(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const DFromV du32; + const RebindToSigned di32; + return DemoteTo(dn, BitCast(di32, Min(v, Set(du32, 0x7FFFFFFF)))); +} + +template +HWY_API VFromD DemoteTo(D du8, VFromD> v) { + const DFromV du16; + const RebindToSigned di16; + return DemoteTo(du8, BitCast(di16, Min(v, Set(du16, 0x7FFF)))); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i32x4_trunc_sat_f64x2_zero(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u32x4_trunc_sat_f64x2_zero(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{wasm_f32x4_demote_f64x2_zero(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { + const Rebind df64; + 
const RebindToUnsigned du64; + const RebindToSigned di32; + const RebindToUnsigned du32; + + const auto k2p64_63 = Set(df64, 27670116110564327424.0); + const auto f64_hi52 = + Xor(BitCast(df64, ShiftRight<12>(BitCast(du64, v))), k2p64_63) - k2p64_63; + const auto f64_lo12 = + PromoteTo(df64, BitCast(di32, And(TruncateTo(du32, BitCast(du64, v)), + Set(du32, uint32_t{0x00000FFF})))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + + const auto f64_sum_is_inexact = + ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); + const auto f64_bits_decrement = + And(ShiftRight<63>(BitCast(du64, Xor(f64_sum, f64_carry))), + f64_sum_is_inexact); + + const auto adj_f64_val = BitCast( + df64, + Or(BitCast(du64, f64_sum) - f64_bits_decrement, f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +} +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { + const Rebind df64; + const RebindToUnsigned du64; + const RebindToSigned di32; + const RebindToUnsigned du32; + + const auto k2p64 = Set(df64, 18446744073709551616.0); + const auto f64_hi52 = Or(BitCast(df64, ShiftRight<12>(v)), k2p64) - k2p64; + const auto f64_lo12 = + PromoteTo(df64, BitCast(di32, And(TruncateTo(du32, BitCast(du64, v)), + Set(du32, uint32_t{0x00000FFF})))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + const auto f64_sum_is_inexact = + ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); + + const auto adj_f64_val = BitCast( + df64, + Or(BitCast(du64, f64_sum) - ShiftRight<63>(BitCast(du64, f64_carry)), + f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +} + +// Specializations for partial vectors because i16x8_narrow_i32x4 sets lanes +// above 2*N. 
+template +HWY_API Vec32 ReorderDemote2To(D dn, Vec32 a, + Vec32 b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API Vec64 ReorderDemote2To(D dn, Vec64 a, + Vec64 b) { + const Twice dn_full; + const Repartition du32_full; + + const Vec128 v_full{wasm_i16x8_narrow_i32x4(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template +HWY_API Vec128 ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return Vec128{wasm_i16x8_narrow_i32x4(a.raw, b.raw)}; +} + +template +HWY_API Vec32 ReorderDemote2To(D dn, Vec32 a, + Vec32 b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API Vec64 ReorderDemote2To(D dn, Vec64 a, + Vec64 b) { + const Twice dn_full; + const Repartition du32_full; + + const Vec128 v_full{wasm_u16x8_narrow_i32x4(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template +HWY_API Vec128 ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return Vec128{wasm_u16x8_narrow_i32x4(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV du32; + const RebindToSigned di32; + const auto max_i32 = Set(du32, 0x7FFFFFFFu); + + const auto clamped_a = BitCast(di32, Min(a, max_i32)); + const auto clamped_b = BitCast(di32, Min(b, max_i32)); + return ReorderDemote2To(dn, clamped_a, clamped_b); +} +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +// Specializations for partial vectors because i8x16_narrow_i16x8 sets lanes +// above 2*N. 
+template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API Vec64 ReorderDemote2To(D dn, Vec64 a, + Vec64 b) { + const Twice dn_full; + const Repartition du32_full; + + const Vec128 v_full{wasm_i8x16_narrow_i16x8(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template +HWY_API Vec128 ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return Vec128{wasm_i8x16_narrow_i16x8(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API Vec64 ReorderDemote2To(D dn, Vec64 a, + Vec64 b) { + const Twice dn_full; + const Repartition du32_full; + + const Vec128 v_full{wasm_u8x16_narrow_i16x8(a.raw, b.raw)}; + const auto vu32_full = BitCast(du32_full, v_full); + return LowerHalf( + BitCast(dn_full, ConcatEven(du32_full, vu32_full, vu32_full))); +} +template +HWY_API Vec128 ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return Vec128{wasm_u8x16_narrow_i16x8(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV du16; + const RebindToSigned di16; + const auto max_i16 = Set(du16, 0x7FFFu); + + const auto clamped_a = BitCast(di16, Min(a, max_i16)); + const auto clamped_b = BitCast(di16, Min(b, max_i16)); + return ReorderDemote2To(dn, clamped_a, clamped_b); +} +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +// For already range-limited input [0, 255]. 
+template +HWY_API Vec128 U8FromU32(const Vec128 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); + return Vec128{ + wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +// ------------------------------ Truncations + +template +HWY_API VFromD TruncateTo(DTo /* tag */, Vec128 v) { + // BitCast requires the same size; DTo might be u8x1 and v u16x1. + const Repartition, DFromV> dto; + return VFromD{BitCast(dto, v).raw}; +} + +template +HWY_API Vec16 TruncateTo(D /* tag */, Vec128 v) { + const Full128 d; + const auto v1 = BitCast(d, v); + const auto v2 = ConcatEven(d, v1, v1); + const auto v4 = ConcatEven(d, v2, v2); + return LowerHalf(LowerHalf(LowerHalf(ConcatEven(d, v4, v4)))); +} + +template +HWY_API Vec32 TruncateTo(D /* tag */, Vec128 v) { + const Full128 d; + const auto v1 = BitCast(d, v); + const auto v2 = ConcatEven(d, v1, v1); + return LowerHalf(LowerHalf(ConcatEven(d, v2, v2))); +} + +template +HWY_API Vec64 TruncateTo(D /* tag */, Vec128 v) { + const Full128 d; + const auto v1 = BitCast(d, v); + return LowerHalf(ConcatEven(d, v1, v1)); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const Repartition> d; + const auto v1 = Vec128{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + const auto v3 = ConcatEven(d, v2, v2); + return VFromD{v3.raw}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const Repartition> d; + const auto v1 = Vec128{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + return VFromD{v2.raw}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const Repartition> d; + const auto v1 = Vec128{v.raw}; + const auto v2 = ConcatEven(d, v1, v1); + return VFromD{v2.raw}; +} + +// ------------------------------ Demotions to/from i64 + +namespace detail { +template +HWY_INLINE VFromD> DemoteFromU64MaskOutResult( + D /*dn*/, VFromD> v) { + return v; +} + +template +HWY_INLINE VFromD> DemoteFromU64MaskOutResult( + D /*dn*/, VFromD> v) { + const DFromV du64; + 
return And(v, + Set(du64, static_cast(hwy::HighestValue>()))); +} + +template +HWY_INLINE VFromD> DemoteFromU64Saturate( + D dn, VFromD> v) { + const Rebind du64; + const RebindToSigned di64; + constexpr int kShiftAmt = static_cast(sizeof(TFromD) * 8) - + static_cast(hwy::IsSigned>()); + + const auto too_big = BitCast( + du64, VecFromMask( + di64, Gt(BitCast(di64, ShiftRight(v)), Zero(di64)))); + return DemoteFromU64MaskOutResult(dn, Or(v, too_big)); +} + +template +HWY_INLINE VFromD ReorderDemote2From64To32Combine(D dn, V a, V b) { + return ConcatEven(dn, BitCast(dn, b), BitCast(dn, a)); +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const DFromV di64; + const RebindToUnsigned du64; + const RebindToUnsigned dn_u; + + // Negative values are saturated by first saturating their bitwise inverse + // and then inverting the saturation result + const auto invert_mask = BitCast(du64, BroadcastSignBit(v)); + const auto saturated_vals = Xor( + invert_mask, + detail::DemoteFromU64Saturate(dn, Xor(invert_mask, BitCast(du64, v)))); + return BitCast(dn, TruncateTo(dn_u, saturated_vals)); +} + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const DFromV di64; + const RebindToUnsigned du64; + + const auto non_neg_vals = BitCast(du64, AndNot(BroadcastSignBit(v), v)); + return TruncateTo(dn, detail::DemoteFromU64Saturate(dn, non_neg_vals)); +} + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + return TruncateTo(dn, detail::DemoteFromU64Saturate(dn, v)); +} + +template )> +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV di64; + const RebindToUnsigned du64; + const Half dnh; + + // Negative 
values are saturated by first saturating their bitwise inverse + // and then inverting the saturation result + const auto invert_mask_a = BitCast(du64, BroadcastSignBit(a)); + const auto invert_mask_b = BitCast(du64, BroadcastSignBit(b)); + const auto saturated_a = Xor( + invert_mask_a, + detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_a, BitCast(du64, a)))); + const auto saturated_b = Xor( + invert_mask_b, + detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_b, BitCast(du64, b)))); + + return ConcatEven(dn, BitCast(dn, saturated_b), BitCast(dn, saturated_a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV di64; + const RebindToUnsigned du64; + const Half dnh; + + const auto saturated_a = detail::DemoteFromU64Saturate( + dnh, BitCast(du64, AndNot(BroadcastSignBit(a), a))); + const auto saturated_b = detail::DemoteFromU64Saturate( + dnh, BitCast(du64, AndNot(BroadcastSignBit(b), b))); + + return ConcatEven(dn, BitCast(dn, saturated_b), BitCast(dn, saturated_a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const Half dnh; + + const auto saturated_a = detail::DemoteFromU64Saturate(dnh, a); + const auto saturated_b = detail::DemoteFromU64Saturate(dnh, b); + + return ConcatEven(dn, BitCast(dn, saturated_b), BitCast(dn, saturated_a)); +} + +template ), class V, + HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); +} + +// ------------------------------ ConvertTo + +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{wasm_f32x4_convert_i32x4(v.raw)}; +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{wasm_f32x4_convert_u32x4(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D dd, VFromD> v) { + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const 
Repartition d32; + const Repartition d64; + + // Toggle MSB of lower 32-bits and insert exponent for 2^84 + 2^63 + const auto k84_63 = Set(d64, 0x4530000080000000ULL); + const auto v_upper = BitCast(dd, ShiftRight<32>(BitCast(d64, v)) ^ k84_63); + + // Exponent is 2^52, lower 32 bits from v (=> 32-bit OddEven) + const auto k52 = Set(d32, 0x43300000); + const auto v_lower = BitCast(dd, OddEven(k52, BitCast(d32, v))); + + const auto k84_63_52 = BitCast(dd, Set(d64, 0x4530000080100000ULL)); + return (v_upper - k84_63_52) + v_lower; // order matters! +} + +namespace detail { +template +HWY_INLINE VFromD>> U64ToF64VecFast(VW w) { + const DFromV d64; + const RebindToFloat dd; + const auto cnst2_52_dbl = Set(dd, 0x0010000000000000); // 2^52 + return BitCast(dd, Or(w, BitCast(d64, cnst2_52_dbl))) - cnst2_52_dbl; +} +} // namespace detail + +template +HWY_API VFromD ConvertTo(D dd, VFromD> v) { + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const RebindToUnsigned d64; + using VU = VFromD; + + const VU msk_lo = Set(d64, 0xFFFFFFFF); + const auto cnst2_32_dbl = Set(dd, 4294967296.0); // 2^32 + + // Extract the 32 lowest/highest significant bits of v + const VU v_lo = And(v, msk_lo); + const VU v_hi = ShiftRight<32>(v); + + const auto v_lo_dbl = detail::U64ToF64VecFast(v_lo); + return MulAdd(cnst2_32_dbl, detail::U64ToF64VecFast(v_hi), v_lo_dbl); +} + +// Truncates (rounds toward zero). +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{wasm_i32x4_trunc_sat_f32x4(v.raw)}; +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{wasm_u32x4_trunc_sat_f32x4(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(DI di, VFromD> v) { + using VI = VFromD; + using MI = MFromD; + const RebindToUnsigned du; + using VU = VFromD; + const Repartition du16; + const VI k1075 = Set(di, 1075); // biased exponent of 2^52 + + // Exponent indicates whether the number can be represented as int64_t. 
+ const VU biased_exp = ShiftRight<52>(BitCast(du, v)) & Set(du, 0x7FF); + const MI in_range = BitCast(di, biased_exp) < Set(di, 1086); + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate). + // Use 16-bit saturated unsigned subtraction to compute shift_mnt and + // shift_int since biased_exp[i] is a non-negative integer that is less than + // or equal to 2047. + // The upper 48 bits of both shift_mnt and shift_int are guaranteed to be + // zero as the upper 48 bits of both k1075 and biased_exp are zero. + + const VU shift_mnt = BitCast( + du, SaturatedSub(BitCast(du16, k1075), BitCast(du16, biased_exp))); + const VU shift_int = BitCast( + du, SaturatedSub(BitCast(du16, biased_exp), BitCast(du16, k1075))); + const VU mantissa = BitCast(du, v) & Set(du, (1ULL << 52) - 1); + // Include implicit 1-bit + VU int53 = (mantissa | Set(du, 1ULL << 52)) >> shift_mnt; + // WASM clamps shift count; zero if greater. + const MI tiny = BitCast(di, shift_mnt) > Set(di, 63); + int53 = IfThenZeroElse(RebindMask(du, tiny), int53); + + // For inputs larger than 2^53 - 1, insert zeros at the bottom. + // For inputs less than 2^63, the implicit 1-bit is guaranteed not to be + // shifted out of the left shift result below as shift_int[i] <= 10 is true + // for any inputs that are less than 2^63. + const VU shifted = int53 << shift_int; + + // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + const VI limit = Set(di, LimitsMax()) - sign_mask; + const VI magnitude = IfThenElse(in_range, BitCast(di, shifted), limit); + + // If the input was negative, negate the integer (two's complement). 
+ return (magnitude ^ sign_mask) - sign_mask; +} + +template +HWY_API VFromD ConvertTo(DU du, VFromD> v) { + const RebindToSigned di; + using MI = MFromD; + using VU = VFromD; + const Repartition du16; + const VU k1075 = Set(du, 1075); /* biased exponent of 2^52 */ + + const auto non_neg_v = ZeroIfNegative(v); + + // Exponent indicates whether the number can be represented as int64_t. + const VU biased_exp = ShiftRight<52>(BitCast(du, non_neg_v)); + const VU out_of_range = + BitCast(du, VecFromMask(di, BitCast(di, biased_exp) > Set(di, 1086))); + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate), but changing rounding mode in MXCSR hits a + // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead + // manually shift the mantissa into place (we already have many of the + // inputs anyway). + + // Use 16-bit saturated unsigned subtraction to compute shift_mnt and + // shift_int since biased_exp[i] is a non-negative integer that is less than + // or equal to 2047. + + // 16-bit saturated unsigned subtraction is also more efficient than a + // 64-bit subtraction followed by a 64-bit signed Max operation on + // WASM. + + // The upper 48 bits of both shift_mnt and shift_int are guaranteed to be + // zero as the upper 48 bits of both k1075 and biased_exp are zero. + + const VU shift_mnt = BitCast( + du, SaturatedSub(BitCast(du16, k1075), BitCast(du16, biased_exp))); + const VU shift_int = BitCast( + du, SaturatedSub(BitCast(du16, biased_exp), BitCast(du16, k1075))); + const VU mantissa = BitCast(du, non_neg_v) & Set(du, (1ULL << 52) - 1); + // Include implicit 1-bit. + VU int53 = (mantissa | Set(du, 1ULL << 52)) >> shift_mnt; + // WASM clamps shift count; zero if greater. 
+ const MI tiny = BitCast(di, shift_mnt) > Set(di, 63); + int53 = IfThenZeroElse(RebindMask(du, tiny), int53); + + // For inputs larger than 2^53 - 1, insert zeros at the bottom. + + // For inputs less than 2^64, the implicit 1-bit is guaranteed not to be + // shifted out of the left shift result below as shift_int[i] <= 11 is true + // for any inputs that are less than 2^64. + + const VU shifted = int53 << shift_int; + return (shifted | out_of_range); +} + +// ------------------------------ NearestInt (Round) +template +HWY_API Vec128, N> NearestInt(const Vec128 v) { + return ConvertTo(RebindToSigned>(), Round(v)); +} + +// ------------------------------ DemoteToNearestInt (Round) +template +HWY_API VFromD DemoteToNearestInt(DI32 di32, + VFromD> v) { + // No single instruction, round then demote. + return DemoteTo(di32, Round(v)); +} + +// ================================================== MISC + +// ------------------------------ SumsOf8 (ShiftRight, Add) +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + const DFromV du8; + const RepartitionToWide du16; + const RepartitionToWide du32; + const RepartitionToWide du64; + using VU16 = VFromD; + + const VU16 vFDB97531 = ShiftRight<8>(BitCast(du16, v)); + const VU16 vECA86420 = And(BitCast(du16, v), Set(du16, 0xFF)); + const VU16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VU16 szz_FE_zz_BA_zz_76_zz_32 = + BitCast(du16, ShiftRight<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VU16 sxx_FC_xx_B8_xx_74_xx_30 = + Add(sFE_DC_BA_98_76_54_32_10, szz_FE_zz_BA_zz_76_zz_32); + const VU16 szz_zz_xx_FC_zz_zz_xx_74 = + BitCast(du16, ShiftRight<32>(BitCast(du64, sxx_FC_xx_B8_xx_74_xx_30))); + const VU16 sxx_xx_xx_F8_xx_xx_xx_70 = + Add(sxx_FC_xx_B8_xx_74_xx_30, szz_zz_xx_FC_zz_zz_xx_74); + return And(BitCast(du64, sxx_xx_xx_F8_xx_xx_xx_70), Set(du64, 0xFFFF)); +} + +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + const DFromV di8; + const RepartitionToWide di16; + const RepartitionToWide di32; + 
const RepartitionToWide di64; + const RebindToUnsigned du32; + const RebindToUnsigned du64; + using VI16 = VFromD; + + const VI16 vFDB97531 = ShiftRight<8>(BitCast(di16, v)); + const VI16 vECA86420 = ShiftRight<8>(ShiftLeft<8>(BitCast(di16, v))); + const VI16 sFE_DC_BA_98_76_54_32_10 = Add(vFDB97531, vECA86420); + + const VI16 sDC_zz_98_zz_54_zz_10_zz = + BitCast(di16, ShiftLeft<16>(BitCast(du32, sFE_DC_BA_98_76_54_32_10))); + const VI16 sFC_xx_B8_xx_74_xx_30_xx = + Add(sFE_DC_BA_98_76_54_32_10, sDC_zz_98_zz_54_zz_10_zz); + const VI16 sB8_xx_zz_zz_30_xx_zz_zz = + BitCast(di16, ShiftLeft<32>(BitCast(du64, sFC_xx_B8_xx_74_xx_30_xx))); + const VI16 sF8_xx_xx_xx_70_xx_xx_xx = + Add(sFC_xx_B8_xx_74_xx_30_xx, sB8_xx_zz_zz_30_xx_zz_zz); + return ShiftRight<48>(BitCast(di64, sF8_xx_xx_xx_70_xx_xx_xx)); +} + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, N=1. + const VFromD vbits{wasm_i32x4_splat(static_cast(bits))}; + + // Replicate bytes 8x such that each byte contains the bit that governs it. 
+ alignas(16) static constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); + + alignas(16) static constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, + 1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + return RebindMask( + d, TestBit(Set(du, static_cast(bits)), Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + return RebindMask( + d, TestBit(Set(du, static_cast(bits)), Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits(D d, uint64_t bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + CopyBytes<(MaxLanes(d) + 7) / 8>(bits, &mask_bits); + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + return detail::LoadMaskBits(d, mask_bits); +} + +// ------------------------------ Mask + +namespace detail { + +// Returns the lowest N bits for the BitsFromMask result. +template +constexpr uint64_t OnlyActive(D d, uint64_t bits) { + return (d.MaxBytes() == 16) ? 
bits : bits & ((1ull << d.MaxLanes()) - 1); +} + +} // namespace detail + +template +HWY_API uint64_t BitsFromMask(D /*d*/, const MFromD mask) { + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, mask.raw); + + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + const uint64_t lo = ((lanes[0] * kMagic) >> 56); + const uint64_t hi = ((lanes[1] * kMagic) >> 48) & 0xFF00; + return hi + lo; // exactly 16 bits, no OnlyActive required +} + +template +HWY_API uint64_t BitsFromMask(D /*d*/, const MFromD mask) { + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + const uint64_t bytes = + static_cast(wasm_i64x2_extract_lane(mask.raw, 0)); + return (bytes * kMagic) >> 56; // exactly 8 bits, no OnlyActive required +} + +// 32-bit or less: need masking +template +HWY_API uint64_t BitsFromMask(D d, const MFromD mask) { + uint64_t bytes = static_cast(wasm_i64x2_extract_lane(mask.raw, 0)); + // Clear potentially undefined bytes. + bytes &= (1ULL << (Lanes(d) * 8)) - 1; + constexpr uint64_t kMagic = 0x103070F1F3F80ULL; + return detail::OnlyActive(d, (bytes * kMagic) >> 56); +} + +template +HWY_API uint64_t BitsFromMask(D /*d*/, const MFromD mask) { + // Remove useless lower half of each u16 while preserving the sign bit. 
+ const Rebind d8; + using M8 = MFromD; + const __i16x8 zero = wasm_i16x8_splat(0); + const M8 mask8{wasm_i8x16_narrow_i16x8(mask.raw, zero)}; + return detail::OnlyActive(d8, BitsFromMask(d8, mask8)); +} + +template +HWY_API uint64_t BitsFromMask(D d, const MFromD mask) { + const __i32x4 mask_i = static_cast<__i32x4>(mask.raw); + const __i32x4 slice = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 sliced_mask = wasm_v128_and(mask_i, slice); + alignas(16) uint32_t lanes[4]; + wasm_v128_store(lanes, sliced_mask); + return detail::OnlyActive(d, lanes[0] | lanes[1] | lanes[2] | lanes[3]); +} + +template +HWY_API uint64_t BitsFromMask(D d, const MFromD mask) { + const __i64x2 mask_i = static_cast<__i64x2>(mask.raw); + const __i64x2 slice = wasm_i64x2_make(1, 2); + const __i64x2 sliced_mask = wasm_v128_and(mask_i, slice); + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, sliced_mask); + return detail::OnlyActive(d, lanes[0] | lanes[1]); +} + +namespace detail { + +// Returns 0xFF for bytes with index >= N, otherwise 0. +template +constexpr __i8x16 BytesAbove() { + return /**/ + (N == 0) ? wasm_i32x4_make(-1, -1, -1, -1) + : (N == 4) ? wasm_i32x4_make(0, -1, -1, -1) + : (N == 8) ? wasm_i32x4_make(0, 0, -1, -1) + : (N == 12) ? wasm_i32x4_make(0, 0, 0, -1) + : (N == 16) ? wasm_i32x4_make(0, 0, 0, 0) + : (N == 2) ? wasm_i16x8_make(0, -1, -1, -1, -1, -1, -1, -1) + : (N == 6) ? wasm_i16x8_make(0, 0, 0, -1, -1, -1, -1, -1) + : (N == 10) ? wasm_i16x8_make(0, 0, 0, 0, 0, -1, -1, -1) + : (N == 14) ? wasm_i16x8_make(0, 0, 0, 0, 0, 0, 0, -1) + : (N == 1) ? wasm_i8x16_make(0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1) + : (N == 3) ? wasm_i8x16_make(0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 5) ? wasm_i8x16_make(0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1) + : (N == 7) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, + -1, -1, -1) + : (N == 9) ? 
wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, + -1, -1, -1) + : (N == 11) + ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1) + : (N == 13) + ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1) + : wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1); +} + +} // namespace detail + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, const MFromD mask, uint8_t* bits) { + const uint64_t mask_bits = BitsFromMask(d, mask); + const size_t kNumBytes = (d.MaxLanes() + 7) / 8; + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +template +HWY_API size_t CountTrue(D d, const MFromD m) { + return PopCount(BitsFromMask(d, m)); +} +template +HWY_API size_t CountTrue(D d, const MFromD m) { + return PopCount(BitsFromMask(d, m)); +} +template +HWY_API size_t CountTrue(D /*d*/, const MFromD m) { + const __i32x4 var_shift = wasm_i32x4_make(1, 2, 4, 8); + const __i32x4 shifted_bits = wasm_v128_and(m.raw, var_shift); + alignas(16) uint64_t lanes[2]; + wasm_v128_store(lanes, shifted_bits); + return PopCount(lanes[0] | lanes[1]); +} +template +HWY_API size_t CountTrue(D /*d*/, const MFromD m) { + alignas(16) int64_t lanes[2]; + wasm_v128_store(lanes, m.raw); + return static_cast(-(lanes[0] + lanes[1])); +} + +// Partial +template , HWY_IF_V_SIZE_LE_D(D, 8)> +HWY_API size_t CountTrue(D d, MFromD m) { + // Ensure all undefined bytes are 0. 
+ const MFromD mask{detail::BytesAbove()}; + const Full128 dfull; + return CountTrue(dfull, Mask128{AndNot(mask, m).raw}); +} + +// Full vector +template +HWY_API bool AllFalse(D d, const MFromD m) { + const auto v8 = BitCast(Full128(), VecFromMask(d, m)); + return !wasm_v128_any_true(v8.raw); +} + +// Full vector +namespace detail { +template +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask128 m) { + return wasm_i8x16_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask128 m) { + return wasm_i16x8_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { + return wasm_i32x4_all_true(m.raw); +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + return wasm_i64x2_all_true(m.raw); +} + +} // namespace detail + +template > +HWY_API bool AllTrue(D /* tag */, const Mask128 m) { + return detail::AllTrue(hwy::SizeTag(), m); +} + +// Partial vectors + +template , HWY_IF_V_SIZE_LE_D(D, 8)> +HWY_API bool AllFalse(D d, const MFromD m) { + // Ensure all undefined bytes are 0. + const MFromD mask{detail::BytesAbove()}; + return AllFalse(Full128(), Mask128{AndNot(mask, m).raw}); +} + +template , HWY_IF_V_SIZE_LE_D(D, 8)> +HWY_API bool AllTrue(D d, const MFromD m) { + // Ensure all undefined bytes are FF. + const MFromD mask{detail::BytesAbove()}; + return AllTrue(Full128(), Mask128{Or(mask, m).raw}); +} + +template +HWY_API size_t FindKnownFirstTrue(D d, const MFromD mask) { + const uint32_t bits = static_cast(BitsFromMask(d, mask)); + return Num0BitsBelowLS1Bit_Nonzero32(bits); +} + +template +HWY_API intptr_t FindFirstTrue(D d, const MFromD mask) { + const uint32_t bits = static_cast(BitsFromMask(d, mask)); + return bits ? 
static_cast(Num0BitsBelowLS1Bit_Nonzero32(bits)) : -1; +} + +template +HWY_API size_t FindKnownLastTrue(D d, const MFromD mask) { + const uint32_t bits = static_cast(BitsFromMask(d, mask)); + return 31 - Num0BitsAboveMS1Bit_Nonzero32(bits); +} + +template +HWY_API intptr_t FindLastTrue(D d, const MFromD mask) { + const uint32_t bits = static_cast(BitsFromMask(d, mask)); + return bits + ? (31 - static_cast(Num0BitsAboveMS1Bit_Nonzero32(bits))) + : -1; +} + +// ------------------------------ Compress + +namespace detail { + +template +HWY_INLINE Vec128 IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Rebind d8; + const Simd du; + + // We need byte indices for TableLookupBytes (one vector's worth for each of + // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We + // can instead store lane indices and convert to byte indices (2*lane + 0..1), + // with the doubling baked into the table. Unpacking nibbles is likely more + // costly than the higher cache footprint from storing bytes. 
+ alignas(16) static constexpr uint8_t table[256 * 8] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 
14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 
8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 
4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 
2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Simd d; + const Rebind d8; + const Simd du; + + // We need byte indices for TableLookupBytes (one vector's worth for each of + // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We + // can instead store lane indices and convert to byte indices (2*lane + 0..1), + // with the doubling baked into the table. Unpacking nibbles is likely more + // costly than the higher cache footprint from storing bytes. + alignas(16) static constexpr uint8_t table[256 * 8] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 
6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 
10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 
6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 
0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const Vec128 byte_idx{Load(d8, table + mask_bits * 8).raw}; + const Vec128 pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE Vec128 IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
+ alignas(16) static constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
+ alignas(16) static constexpr uint8_t u8_indices[16 * 16] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE Vec128 IdxFromNotBits(const uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. 
+ alignas(16) static constexpr uint8_t u8_indices[4 * 16] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Simd d; + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +// Helper functions called by both Compress and CompressStore - avoids a +// redundant BitsFromMask in the latter. + +template +HWY_INLINE Vec128 Compress(Vec128 v, const uint64_t mask_bits) { + const auto idx = detail::IdxFromBits(mask_bits); + const DFromV d; + const RebindToSigned di; + return BitCast(d, TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +template +HWY_INLINE Vec128 CompressNot(Vec128 v, const uint64_t mask_bits) { + const auto idx = detail::IdxFromNotBits(mask_bits); + const DFromV d; + const RebindToSigned di; + return BitCast(d, TableLookupBytes(BitCast(di, v), BitCast(di, idx))); +} + +} // namespace detail + +template +struct CompressIsPartition { +#if HWY_TARGET == HWY_WASM_EMU256 + enum { value = 0 }; +#else + enum { value = (sizeof(T) != 1) }; +#endif +}; + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
+ const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + const DFromV d; + return detail::Compress(v, BitsFromMask(d, mask)); +} + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const Full128 d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 byte lanes +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + const DFromV d; + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. 
+ if (N < 16 / sizeof(T)) { + return detail::Compress(v, BitsFromMask(d, Not(mask))); + } + return detail::CompressNot(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressBits +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore +template +HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = BitsFromMask(d, mask); + const auto c = detail::Compress(v, mask_bits); + StoreU(c, d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; // so we can support fp16/bf16 + const uint64_t mask_bits = BitsFromMask(d, m); + const size_t count = PopCount(mask_bits); + const VFromD compressed = + detail::Compress(BitCast(du, v), mask_bits); + const MFromD store_mask = RebindMask(d, FirstN(du, count)); + BlendedStore(BitCast(d, compressed), store_mask, d, unaligned); + return count; +} + +// ------------------------------ CompressBitsStore + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + uint64_t mask_bits = 0; + constexpr size_t kN = MaxLanes(d); + CopyBytes<(kN + 7) / 8>(bits, &mask_bits); + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + + const auto c = detail::Compress(v, mask_bits); + StoreU(c, d, unaligned); + return PopCount(mask_bits); +} + +// ------------------------------ StoreInterleaved2/3/4 + +// 
HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. + +// ------------------------------ Additional mask logical operations +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const FixedTag d; + const auto vmask = VecFromMask(d, mask); + return MaskFromVec(Or(vmask, InterleaveLower(vmask, vmask))); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Simd d; + const auto vmask = VecFromMask(d, mask); + const auto neg_vmask = + ResizeBitCast(d, Neg(ResizeBitCast(Full64(), vmask))); + return MaskFromVec(Or(vmask, neg_vmask)); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Full128 d; + const Repartition di64; + + auto vmask = BitCast(di64, VecFromMask(d, mask)); + vmask = Or(vmask, Neg(vmask)); + + // Copy the sign bit of the first int64_t lane to the second int64_t lane + const auto vmask2 = BroadcastSignBit(InterleaveLower(Zero(di64), vmask)); + return MaskFromVec(BitCast(d, Or(vmask, vmask2))); +} + +template +HWY_API Mask128 SetBeforeFirst(Mask128 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const FixedTag d; + const RebindToSigned di; + + const auto vmask = BitCast(di, VecFromMask(d, mask)); + const auto zero = Zero(di); + const auto vmask2 = VecFromMask(di, InterleaveLower(zero, vmask) == zero); + return MaskFromVec(BitCast(d, And(vmask, vmask2))); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Simd d; + const RebindToSigned di; + + const auto vmask = ResizeBitCast(Full64(), VecFromMask(d, mask)); + const auto only_first_vmask = + BitCast(d, Neg(ResizeBitCast(di, And(vmask, Neg(vmask))))); + return MaskFromVec(only_first_vmask); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Full128 d; + const RebindToSigned di; + const 
Repartition di64; + + const auto zero = Zero(di64); + const auto vmask = BitCast(di64, VecFromMask(d, mask)); + const auto vmask2 = VecFromMask(di64, InterleaveLower(zero, vmask) == zero); + const auto only_first_vmask = Neg(BitCast(di, And(vmask, Neg(vmask)))); + return MaskFromVec(BitCast(d, And(only_first_vmask, BitCast(di, vmask2)))); +} + +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 /*mask*/) { + const FixedTag d; + const RebindToSigned di; + using TI = MakeSigned; + + return RebindMask(d, MaskFromVec(Set(di, TI(-1)))); +} +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { + const Simd d; + return SetBeforeFirst(MaskFromVec(ShiftLeftLanes<1>(VecFromMask(d, mask)))); +} + +// ------------------------------ MulEven/Odd (Load) + +template +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + alignas(16) T mul[2]; + mul[0] = Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 0)), + static_cast(wasm_i64x2_extract_lane(b.raw, 0)), &mul[1]); + return Load(Full128(), mul); +} + +template +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + alignas(16) T mul[2]; + mul[0] = Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 1)), + static_cast(wasm_i64x2_extract_lane(b.raw, 1)), &mul[1]); + return Load(Full128(), mul); +} + +// ------------------------------ I64/U64 MulHigh (GetLane) +template +HWY_API Vec64 MulHigh(Vec64 a, Vec64 b) { + T hi; + Mul128(GetLane(a), GetLane(b), &hi); + return Set(Full64(), hi); +} + +template +HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { + T hi_0; + T hi_1; + Mul128(GetLane(a), GetLane(b), &hi_0); + Mul128(detail::ExtractLane<1>(a), detail::ExtractLane<1>(b), &hi_1); + return Dup128VecFromValues(Full128(), hi_0, hi_1); +} + +// ------------------------------ WidenMulPairwiseAdd (MulAdd, PromoteEvenTo) + +// Generic for all vector lengths. 
+template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); +} + +// Even if N=1, the input is always at least 2 lanes, hence i32x4_dot_i16x8 is +// safe. +template >> +HWY_API VFromD WidenMulPairwiseAdd(D32 /* tag */, V16 a, V16 b) { + return VFromD{wasm_i32x4_dot_i16x8(a.raw, b.raw)}; +} + +template >> +HWY_API VFromD WidenMulPairwiseAdd(DU32 du32, VU16 a, VU16 b) { + return MulAdd(PromoteEvenTo(du32, a), PromoteEvenTo(du32, b), + Mul(PromoteOddTo(du32, a), PromoteOddTo(du32, b))); +} + +// ------------------------------ ReorderWidenMulAccumulate + +template >> +HWY_API VFromD ReorderWidenMulAccumulate(D32 d32, V16 a, V16 b, + const VFromD sum0, + VFromD& /*sum1*/) { + return sum0 + WidenMulPairwiseAdd(d32, a, b); +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW) { + return sum0; // invariant already holds +} + +// ------------------------------ Reductions + +// Nothing native, generic_ops-inl defines SumOfLanes and ReduceSum. + +// ------------------------------ Lt128 + +template +HWY_INLINE MFromD Lt128(D d, VFromD a, VFromD b) { + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const MFromD eqHL = Eq(a, b); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + // We need to bring cL to the upper lane/bit corresponding to cH. Comparing + // the result of InterleaveUpper/Lower requires 9 ops, whereas shifting the + // comparison result leftwards requires only 4. IfThenElse compiles to the + // same code as OrAnd(). 
+ const VFromD ltLx = DupEven(ltHL); + const VFromD outHx = IfThenElse(eqHL, ltLx, ltHL); + return MaskFromVec(DupOdd(outHx)); +} + +template +HWY_INLINE MFromD Lt128Upper(D d, VFromD a, VFromD b) { + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + return MaskFromVec(InterleaveUpper(d, ltHL, ltHL)); +} + +// ------------------------------ Eq128 + +template +HWY_INLINE MFromD Eq128(D d, VFromD a, VFromD b) { + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(And(Reverse2(d, eqHL), eqHL)); +} + +template +HWY_INLINE MFromD Eq128Upper(D d, VFromD a, VFromD b) { + const VFromD eqHL = VecFromMask(d, Eq(a, b)); + return MaskFromVec(InterleaveUpper(d, eqHL, eqHL)); +} + +// ------------------------------ Ne128 + +template +HWY_INLINE MFromD Ne128(D d, VFromD a, VFromD b) { + const VFromD neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(Or(Reverse2(d, neHL), neHL)); +} + +template +HWY_INLINE MFromD Ne128Upper(D d, VFromD a, VFromD b) { + const VFromD neHL = VecFromMask(d, Ne(a, b)); + return MaskFromVec(InterleaveUpper(d, neHL, neHL)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Without a native OddEven, it seems infeasible to go faster than Lt128. 
+template +HWY_INLINE VFromD Min128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128(d, b, a), a, b); +} + +template +HWY_INLINE VFromD Min128Upper(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128Upper(d, a, b), a, b); +} + +template +HWY_INLINE VFromD Max128Upper(D d, const VFromD a, const VFromD b) { + return IfThenElse(Lt128Upper(d, b, a), a, b); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/lib/highway/hwy/ops/wasm_256-inl.h b/lib/highway/hwy/ops/wasm_256-inl.h new file mode 100644 index 00000000000..3c7e9b83804 --- /dev/null +++ b/lib/highway/hwy/ops/wasm_256-inl.h @@ -0,0 +1,2517 @@ +// Copyright 2021 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 256-bit WASM vectors and operations. Experimental. +// External include guard in highway.h - see comment there. + +// For half-width vectors. Already includes base.h and shared-inl.h. +#include "hwy/ops/wasm_128-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +template +class Vec256 { + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromV + + // Compound assignment. 
Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec256& operator*=(const Vec256 other) { + return *this = (*this * other); + } + HWY_INLINE Vec256& operator/=(const Vec256 other) { + return *this = (*this / other); + } + HWY_INLINE Vec256& operator+=(const Vec256 other) { + return *this = (*this + other); + } + HWY_INLINE Vec256& operator-=(const Vec256 other) { + return *this = (*this - other); + } + HWY_INLINE Vec256& operator%=(const Vec256 other) { + return *this = (*this % other); + } + HWY_INLINE Vec256& operator&=(const Vec256 other) { + return *this = (*this & other); + } + HWY_INLINE Vec256& operator|=(const Vec256 other) { + return *this = (*this | other); + } + HWY_INLINE Vec256& operator^=(const Vec256 other) { + return *this = (*this ^ other); + } + + Vec128 v0; + Vec128 v1; +}; + +template +struct Mask256 { + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromM + + Mask128 m0; + Mask128 m1; +}; + +// ------------------------------ Zero + +// Avoid VFromD here because it is defined in terms of Zero. 
+template +HWY_API Vec256> Zero(D d) { + const Half dh; + Vec256> ret; + ret.v0 = ret.v1 = Zero(dh); + return ret; +} + +// ------------------------------ BitCast +template +HWY_API VFromD BitCast(D d, Vec256 v) { + const Half dh; + VFromD ret; + ret.v0 = BitCast(dh, v.v0); + ret.v1 = BitCast(dh, v.v1); + return ret; +} + +// ------------------------------ ResizeBitCast + +// 32-byte vector to 32-byte vector: Same as BitCast +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, v); +} + +// <= 16-byte vector to 32-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const Half dh; + VFromD ret; + ret.v0 = ResizeBitCast(dh, v); + ret.v1 = Zero(dh); + return ret; +} + +// 32-byte vector to <= 16-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return ResizeBitCast(d, v.v0); +} + +// ------------------------------ Set +template +HWY_API VFromD Set(D d, const T2 t) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Set(dh, static_cast>(t)); + return ret; +} + +// Undefined, Iota defined in wasm_128. 
+ +// ------------------------------ Dup128VecFromValues +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1, t2, t3, t4, t5, t6, t7, t8, + t9, t10, t11, t12, t13, t14, t15); + return ret; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1, t2, t3, t4, t5, t6, t7); + return ret; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1, t2, t3); + return ret; +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1); + return ret; +} + +// ================================================== ARITHMETIC + +template +HWY_API Vec256 operator+(Vec256 a, const Vec256 b) { + a.v0 += b.v0; + a.v1 += b.v1; + return a; +} + +template +HWY_API Vec256 operator-(Vec256 a, const Vec256 b) { + a.v0 -= b.v0; + a.v1 -= b.v1; + return a; +} + +// ------------------------------ SumsOf8 +HWY_API Vec256 SumsOf8(const Vec256 v) { + Vec256 ret; + ret.v0 = SumsOf8(v.v0); + ret.v1 = SumsOf8(v.v1); + return ret; +} + +HWY_API Vec256 SumsOf8(const Vec256 v) { + Vec256 ret; + ret.v0 = SumsOf8(v.v0); + ret.v1 = SumsOf8(v.v1); + return ret; +} + +template +HWY_API Vec256 SaturatedAdd(Vec256 a, const Vec256 b) { + a.v0 = SaturatedAdd(a.v0, b.v0); + a.v1 = SaturatedAdd(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 SaturatedSub(Vec256 a, const Vec256 b) { + a.v0 = SaturatedSub(a.v0, b.v0); + a.v1 = 
SaturatedSub(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 AverageRound(Vec256 a, const Vec256 b) { + a.v0 = AverageRound(a.v0, b.v0); + a.v1 = AverageRound(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Abs(Vec256 v) { + v.v0 = Abs(v.v0); + v.v1 = Abs(v.v1); + return v; +} + +// ------------------------------ Shift lanes by constant #bits + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + v.v0 = ShiftLeft(v.v0); + v.v1 = ShiftLeft(v.v1); + return v; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + v.v0 = ShiftRight(v.v0); + v.v1 = ShiftRight(v.v1); + return v; +} + +// ------------------------------ RotateRight (ShiftRight, Or) +template +HWY_API Vec256 RotateRight(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + + constexpr size_t kSizeInBits = sizeof(T) * 8; + static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); + if (kBits == 0) return v; + + return Or(BitCast(d, ShiftRight(BitCast(du, v))), + ShiftLeft(v)); +} + +// ------------------------------ Shift lanes by same variable #bits + +template +HWY_API Vec256 ShiftLeftSame(Vec256 v, const int bits) { + v.v0 = ShiftLeftSame(v.v0, bits); + v.v1 = ShiftLeftSame(v.v1, bits); + return v; +} + +template +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + v.v0 = ShiftRightSame(v.v0, bits); + v.v1 = ShiftRightSame(v.v1, bits); + return v; +} + +// ------------------------------ Min, Max +template +HWY_API Vec256 Min(Vec256 a, const Vec256 b) { + a.v0 = Min(a.v0, b.v0); + a.v1 = Min(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Max(Vec256 a, const Vec256 b) { + a.v0 = Max(a.v0, b.v0); + a.v1 = Max(a.v1, b.v1); + return a; +} +// ------------------------------ Integer multiplication + +template +HWY_API Vec256 operator*(Vec256 a, const Vec256 b) { + a.v0 *= b.v0; + a.v1 *= b.v1; + return a; +} + +template +HWY_API Vec256 MulHigh(Vec256 a, const Vec256 b) { + a.v0 = MulHigh(a.v0, b.v0); + a.v1 = MulHigh(a.v1, b.v1); + return a; +} + 
+template +HWY_API Vec256 MulFixedPoint15(Vec256 a, const Vec256 b) { + a.v0 = MulFixedPoint15(a.v0, b.v0); + a.v1 = MulFixedPoint15(a.v1, b.v1); + return a; +} + +// Cannot use MakeWide because that returns uint128_t for uint64_t, but we want +// uint64_t. +template +HWY_API Vec256> MulEven(Vec256 a, const Vec256 b) { + Vec256> ret; + ret.v0 = MulEven(a.v0, b.v0); + ret.v1 = MulEven(a.v1, b.v1); + return ret; +} +template +HWY_API Vec256 MulEven(Vec256 a, const Vec256 b) { + Vec256 ret; + ret.v0 = MulEven(a.v0, b.v0); + ret.v1 = MulEven(a.v1, b.v1); + return ret; +} + +template +HWY_API Vec256> MulOdd(Vec256 a, const Vec256 b) { + Vec256> ret; + ret.v0 = MulOdd(a.v0, b.v0); + ret.v1 = MulOdd(a.v1, b.v1); + return ret; +} +template +HWY_API Vec256 MulOdd(Vec256 a, const Vec256 b) { + Vec256 ret; + ret.v0 = MulOdd(a.v0, b.v0); + ret.v1 = MulOdd(a.v1, b.v1); + return ret; +} + +// ------------------------------ Negate +template +HWY_API Vec256 Neg(Vec256 v) { + v.v0 = Neg(v.v0); + v.v1 = Neg(v.v1); + return v; +} + +// ------------------------------ AbsDiff +// generic_ops takes care of integer T. +template +HWY_API Vec256 AbsDiff(const Vec256 a, const Vec256 b) { + return Abs(a - b); +} + +// ------------------------------ Floating-point division +// generic_ops takes care of integer T. 
+template +HWY_API Vec256 operator/(Vec256 a, const Vec256 b) { + a.v0 /= b.v0; + a.v1 /= b.v1; + return a; +} + +// ------------------------------ Floating-point multiply-add variants + +template +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, Vec256 add) { + mul.v0 = MulAdd(mul.v0, x.v0, add.v0); + mul.v1 = MulAdd(mul.v1, x.v1, add.v1); + return mul; +} + +template +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, Vec256 add) { + mul.v0 = NegMulAdd(mul.v0, x.v0, add.v0); + mul.v1 = NegMulAdd(mul.v1, x.v1, add.v1); + return mul; +} + +template +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, Vec256 sub) { + mul.v0 = MulSub(mul.v0, x.v0, sub.v0); + mul.v1 = MulSub(mul.v1, x.v1, sub.v1); + return mul; +} + +template +HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, Vec256 sub) { + mul.v0 = NegMulSub(mul.v0, x.v0, sub.v0); + mul.v1 = NegMulSub(mul.v1, x.v1, sub.v1); + return mul; +} + +// ------------------------------ Floating-point square root + +template +HWY_API Vec256 Sqrt(Vec256 v) { + v.v0 = Sqrt(v.v0); + v.v1 = Sqrt(v.v1); + return v; +} + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, ties to even +template +HWY_API Vec256 Round(Vec256 v) { + v.v0 = Round(v.v0); + v.v1 = Round(v.v1); + return v; +} + +// Toward zero, aka truncate +template +HWY_API Vec256 Trunc(Vec256 v) { + v.v0 = Trunc(v.v0); + v.v1 = Trunc(v.v1); + return v; +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec256 Ceil(Vec256 v) { + v.v0 = Ceil(v.v0); + v.v1 = Ceil(v.v1); + return v; +} + +// Toward -infinity, aka floor +template +HWY_API Vec256 Floor(Vec256 v) { + v.v0 = Floor(v.v0); + v.v1 = Floor(v.v1); + return v; +} + +// ------------------------------ Floating-point classification + +template +HWY_API Mask256 IsNaN(const Vec256 v) { + return v != v; +} + +template +HWY_API Mask256 IsInf(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, check for 
exponent=max and mantissa=0. + return RebindMask(d, Eq(Add(vu, vu), Set(du, hwy::MaxExponentTimes2()))); +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask256 IsFinite(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; // cheaper than unsigned comparison + const VFromD vu = BitCast(du, v); + // 'Shift left' to clear the sign bit, then right so we can compare with the + // max exponent (cannot compare with MaxExponentTimes2 directly because it is + // negative and non-negative floats would be greater). + const VFromD exp = + BitCast(di, ShiftRight() + 1>(Add(vu, vu))); + return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField()))); +} + +// ================================================== COMPARE + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template > +HWY_API MFromD RebindMask(DTo /*tag*/, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); + return MFromD{Mask128{m.m0.raw}, Mask128{m.m1.raw}}; +} + +template +HWY_API Mask256 TestBit(Vec256 v, Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +template +HWY_API Mask256 operator==(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator==(a.v0, b.v0); + m.m1 = operator==(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator!=(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator!=(a.v0, b.v0); + m.m1 = operator!=(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator<(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator<(a.v0, b.v0); + m.m1 = operator<(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator>(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator>(a.v0, b.v0); + m.m1 = operator>(a.v1, b.v1); + return m; +} + +template +HWY_API Mask256 operator<=(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator<=(a.v0, b.v0); + m.m1 = operator<=(a.v1, b.v1); + return 
m; +} + +template +HWY_API Mask256 operator>=(Vec256 a, const Vec256 b) { + Mask256 m; + m.m0 = operator>=(a.v0, b.v0); + m.m1 = operator>=(a.v1, b.v1); + return m; +} + +// ------------------------------ FirstN (Iota, Lt) + +template +HWY_API MFromD FirstN(const D d, size_t num) { + const RebindToSigned di; // Signed comparisons may be cheaper. + using TI = TFromD; + return RebindMask(d, Iota(di, 0) < Set(di, static_cast(num))); +} + +// ================================================== LOGICAL + +template +HWY_API Vec256 Not(Vec256 v) { + v.v0 = Not(v.v0); + v.v1 = Not(v.v1); + return v; +} + +template +HWY_API Vec256 And(Vec256 a, Vec256 b) { + a.v0 = And(a.v0, b.v0); + a.v1 = And(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + not_mask.v0 = AndNot(not_mask.v0, mask.v0); + not_mask.v1 = AndNot(not_mask.v1, mask.v1); + return not_mask; +} + +template +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + a.v0 = Or(a.v0, b.v0); + a.v1 = Or(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + a.v0 = Xor(a.v0, b.v0); + a.v1 = Xor(a.v1, b.v1); + return a; +} + +template +HWY_API Vec256 Or3(Vec256 o1, Vec256 o2, Vec256 o3) { + return Or(o1, Or(o2, o3)); +} + +template +HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { + return Or(o, And(a1, a2)); +} + +template +HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { + return IfThenElse(MaskFromVec(mask), yes, no); +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { + return And(a, b); +} + +template +HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { + return Or(a, b); +} + +template +HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { + return Xor(a, b); +} + +// ------------------------------ CopySign +template +HWY_API Vec256 CopySign(const Vec256 magn, const Vec256 sign) { + static_assert(IsFloat(), "Only makes 
sense for floating-point"); + const DFromV d; + return BitwiseIfThenElse(SignBit(d), sign, magn); +} + +// ------------------------------ CopySignToAbs +template +HWY_API Vec256 CopySignToAbs(const Vec256 abs, const Vec256 sign) { + static_assert(IsFloat(), "Only makes sense for floating-point"); + const DFromV d; + return OrAnd(abs, SignBit(d), sign); +} + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + Mask256 m; + m.m0 = MaskFromVec(v.v0); + m.m1 = MaskFromVec(v.v1); + return m; +} + +template > +HWY_API Vec256 VecFromMask(D d, Mask256 m) { + const Half dh; + Vec256 v; + v.v0 = VecFromMask(dh, m.m0); + v.v1 = VecFromMask(dh, m.m1); + return v; +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD m) { + const Half dh; + const uint64_t lo = BitsFromMask(dh, m.m0); + const uint64_t hi = BitsFromMask(dh, m.m1); + return (hi << Lanes(dh)) | lo; +} + +// mask ? yes : no +template +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { + yes.v0 = IfThenElse(mask.m0, yes.v0, no.v0); + yes.v1 = IfThenElse(mask.m1, yes.v1, no.v1); + return yes; +} + +// mask ? yes : 0 +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +// mask ? 
0 : no +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + v.v0 = IfNegativeThenElse(v.v0, yes.v0, no.v0); + v.v1 = IfNegativeThenElse(v.v1, yes.v1, no.v1); + return v; +} + +// ------------------------------ Mask logical + +template +HWY_API Mask256 Not(const Mask256 m) { + return MaskFromVec(Not(VecFromMask(Full256(), m))); +} + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +// ------------------------------ Shl (BroadcastSignBit, IfThenElse) +template +HWY_API Vec256 operator<<(Vec256 v, const Vec256 bits) { + v.v0 = operator<<(v.v0, bits.v0); + v.v1 = operator<<(v.v1, bits.v1); + return v; +} + +// ------------------------------ Shr (BroadcastSignBit, IfThenElse) +template +HWY_API Vec256 operator>>(Vec256 v, const Vec256 bits) { + v.v0 = operator>>(v.v0, bits.v0); + v.v1 = operator>>(v.v1, bits.v1); + return v; +} + +// ------------------------------ BroadcastSignBit (compare, VecFromMask) + +template +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight(v); +} +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + const DFromV d; + return VecFromMask(d, v < Zero(d)); +} + +// 
================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT aligned) { + const Half dh; + VFromD ret; + ret.v0 = Load(dh, aligned); + ret.v1 = Load(dh, aligned + Lanes(dh)); + return ret; +} + +template > +HWY_API Vec256 MaskedLoad(Mask256 m, D d, const T* HWY_RESTRICT aligned) { + return IfThenElseZero(m, Load(d, aligned)); +} + +template > +HWY_API Vec256 MaskedLoadOr(Vec256 v, Mask256 m, D d, + const T* HWY_RESTRICT aligned) { + return IfThenElse(m, Load(d, aligned), v); +} + +// LoadU == Load. +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + const Half dh; + VFromD ret; + ret.v0 = ret.v1 = Load(dh, p); + return ret; +} + +// ------------------------------ Store + +template > +HWY_API void Store(Vec256 v, D d, T* HWY_RESTRICT aligned) { + const Half dh; + Store(v.v0, dh, aligned); + Store(v.v1, dh, aligned + Lanes(dh)); +} + +// StoreU == Store. +template > +HWY_API void StoreU(Vec256 v, D d, T* HWY_RESTRICT p) { + Store(v, d, p); +} + +template > +HWY_API void BlendedStore(Vec256 v, Mask256 m, D d, T* HWY_RESTRICT p) { + StoreU(IfThenElse(m, v, LoadU(d, p)), d, p); +} + +// ------------------------------ Stream +template > +HWY_API void Stream(Vec256 v, D d, T* HWY_RESTRICT aligned) { + // Same as aligned stores. 
+ Store(v, d, aligned); +} + +// ------------------------------ Scatter, Gather defined in wasm_128 + +// ================================================== SWIZZLE + +// ------------------------------ ExtractLane +template +HWY_API T ExtractLane(const Vec256 v, size_t i) { + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane +template +HWY_API Vec256 InsertLane(const Vec256 v, size_t i, T t) { + DFromV d; + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + lanes[i] = t; + return Load(d, lanes); +} + +// ------------------------------ ExtractBlock +template +HWY_API Vec128 ExtractBlock(Vec256 v) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + return (kBlockIdx == 0) ? v.v0 : v.v1; +} + +// ------------------------------ InsertBlock +template +HWY_API Vec256 InsertBlock(Vec256 v, Vec128 blk_to_insert) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + Vec256 result; + if (kBlockIdx == 0) { + result.v0 = blk_to_insert; + result.v1 = v.v1; + } else { + result.v0 = v.v0; + result.v1 = blk_to_insert; + } + return result; +} + +// ------------------------------ BroadcastBlock +template +HWY_API Vec256 BroadcastBlock(Vec256 v) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + Vec256 result; + result.v0 = result.v1 = (kBlockIdx == 0 ? 
v.v0 : v.v1); + return result; +} + +// ------------------------------ LowerHalf + +template > +HWY_API Vec128 LowerHalf(D /* tag */, Vec256 v) { + return v.v0; +} + +template +HWY_API Vec128 LowerHalf(Vec256 v) { + return v.v0; +} + +// ------------------------------ GetLane (LowerHalf) +template +HWY_API T GetLane(const Vec256 v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ShiftLeftBytes + +template > +HWY_API Vec256 ShiftLeftBytes(D d, Vec256 v) { + const Half dh; + v.v0 = ShiftLeftBytes(dh, v.v0); + v.v1 = ShiftLeftBytes(dh, v.v1); + return v; +} + +template +HWY_API Vec256 ShiftLeftBytes(Vec256 v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +template > +HWY_API Vec256 ShiftLeftLanes(D d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); +} + +template +HWY_API Vec256 ShiftLeftLanes(const Vec256 v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template > +HWY_API Vec256 ShiftRightBytes(D d, Vec256 v) { + const Half dh; + v.v0 = ShiftRightBytes(dh, v.v0); + v.v1 = ShiftRightBytes(dh, v.v1); + return v; +} + +// ------------------------------ ShiftRightLanes +template > +HWY_API Vec256 ShiftRightLanes(D d, const Vec256 v) { + const Repartition d8; + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) +template > +HWY_API Vec128 UpperHalf(D /* tag */, const Vec256 v) { + return v.v1; +} + +// ------------------------------ CombineShiftRightBytes + +template > +HWY_API Vec256 CombineShiftRightBytes(D d, Vec256 hi, Vec256 lo) { + const Half dh; + hi.v0 = CombineShiftRightBytes(dh, hi.v0, lo.v0); + hi.v1 = CombineShiftRightBytes(dh, hi.v1, lo.v1); + return hi; +} + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + Vec256 ret; + ret.v0 = Broadcast(v.v0); 
+ ret.v1 = Broadcast(v.v1); + return ret; +} + +template +HWY_API Vec256 BroadcastLane(const Vec256 v) { + constexpr int kLanesPerBlock = static_cast(16 / sizeof(T)); + static_assert(0 <= kLane && kLane < kLanesPerBlock * 2, "Invalid lane"); + constexpr int kLaneInBlkIdx = kLane & (kLanesPerBlock - 1); + Vec256 ret; + ret.v0 = ret.v1 = + Broadcast(kLane >= kLanesPerBlock ? v.v1 : v.v0); + return ret; +} + +// ------------------------------ TableLookupBytes + +// Both full +template +HWY_API Vec256 TableLookupBytes(const Vec256 bytes, Vec256 from) { + from.v0 = TableLookupBytes(bytes.v0, from.v0); + from.v1 = TableLookupBytes(bytes.v1, from.v1); + return from; +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(Vec256 bytes, + const Vec128 from) { + // First expand to full 128, then 256. + const auto from_256 = ZeroExtendVector(Full256(), Vec128{from.raw}); + const auto tbl_full = TableLookupBytes(bytes, from_256); + // Shrink to 128, then partial. + return Vec128{LowerHalf(Full128(), tbl_full).raw}; +} + +// Partial table vector +template +HWY_API Vec256 TableLookupBytes(Vec128 bytes, const Vec256 from) { + // First expand to full 128, then 256. + const auto bytes_256 = ZeroExtendVector(Full256(), Vec128{bytes.raw}); + return TableLookupBytes(bytes_256, from); +} + +// Partial both are handled by wasm_128. + +template +HWY_API VI TableLookupBytesOr0(V bytes, VI from) { + // wasm out-of-bounds policy already zeros, so TableLookupBytes is fine. 
+ return TableLookupBytes(bytes, from); +} + +// ------------------------------ Hard-coded shuffles + +template +HWY_API Vec256 Shuffle01(Vec256 v) { + v.v0 = Shuffle01(v.v0); + v.v1 = Shuffle01(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle2301(Vec256 v) { + v.v0 = Shuffle2301(v.v0); + v.v1 = Shuffle2301(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle1032(Vec256 v) { + v.v0 = Shuffle1032(v.v0); + v.v1 = Shuffle1032(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle0321(Vec256 v) { + v.v0 = Shuffle0321(v.v0); + v.v1 = Shuffle0321(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle2103(Vec256 v) { + v.v0 = Shuffle2103(v.v0); + v.v1 = Shuffle2103(v.v1); + return v; +} + +template +HWY_API Vec256 Shuffle0123(Vec256 v) { + v.v0 = Shuffle0123(v.v0); + v.v1 = Shuffle0123(v.v1); + return v; +} + +// Used by generic_ops-inl.h +namespace detail { + +template +HWY_API Vec256 ShuffleTwo2301(Vec256 a, const Vec256 b) { + a.v0 = ShuffleTwo2301(a.v0, b.v0); + a.v1 = ShuffleTwo2301(a.v1, b.v1); + return a; +} +template +HWY_API Vec256 ShuffleTwo1230(Vec256 a, const Vec256 b) { + a.v0 = ShuffleTwo1230(a.v0, b.v0); + a.v1 = ShuffleTwo1230(a.v1, b.v1); + return a; +} +template +HWY_API Vec256 ShuffleTwo3012(Vec256 a, const Vec256 b) { + a.v0 = ShuffleTwo3012(a.v0, b.v0); + a.v1 = ShuffleTwo3012(a.v1, b.v1); + return a; +} + +} // namespace detail + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices for use by TableLookupLanes. 
+template +struct Indices256 { + __v128_u i0; + __v128_u i1; +}; + +template , typename TI> +HWY_API Indices256 IndicesFromVec(D /* tag */, Vec256 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); + Indices256 ret; + ret.i0 = vec.v0.raw; + ret.i1 = vec.v1.raw; + return ret; +} + +template +HWY_API Indices256> SetTableIndices(D d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec256 TableLookupLanes(const Vec256 v, Indices256 idx) { + const DFromV d; + const Half dh; + const auto idx_i0 = IndicesFromVec(dh, Vec128{idx.i0}); + const auto idx_i1 = IndicesFromVec(dh, Vec128{idx.i1}); + + Vec256 result; + result.v0 = TwoTablesLookupLanes(v.v0, v.v1, idx_i0); + result.v1 = TwoTablesLookupLanes(v.v0, v.v1, idx_i1); + return result; +} + +template +HWY_API Vec256 TableLookupLanesOr0(Vec256 v, Indices256 idx) { + // Forward to TableLookupLanes (never to this function itself, which would recurse endlessly). Assumes its out-of-bounds handling already zeroes lanes - TODO confirm. + return TableLookupLanes(v, idx); +} + +template +HWY_API Vec256 TwoTablesLookupLanes(const Vec256 a, const Vec256 b, + Indices256 idx) { + const DFromV d; + const Half dh; + const RebindToUnsigned du; + using TU = MakeUnsigned; + constexpr size_t kLanesPerVect = 32 / sizeof(TU); + + Vec256 vi; + vi.v0 = Vec128{idx.i0}; + vi.v1 = Vec128{idx.i1}; + const auto vmod = vi & Set(du, TU{kLanesPerVect - 1}); + const auto is_lo = RebindMask(d, vi == vmod); + + const auto idx_i0 = IndicesFromVec(dh, vmod.v0); + const auto idx_i1 = IndicesFromVec(dh, vmod.v1); + + Vec256 result_lo; + Vec256 result_hi; + result_lo.v0 = TwoTablesLookupLanes(a.v0, a.v1, idx_i0); + result_lo.v1 = TwoTablesLookupLanes(a.v0, a.v1, idx_i1); + result_hi.v0 = TwoTablesLookupLanes(b.v0, b.v1, idx_i0); + result_hi.v1 = TwoTablesLookupLanes(b.v0, b.v1, idx_i1); + return IfThenElse(is_lo, result_lo, result_hi); +} + +// ------------------------------ Reverse +template > +HWY_API Vec256 Reverse(D d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v1 =
Reverse(dh, v.v0); // note reversed v1 member order + ret.v0 = Reverse(dh, v.v1); + return ret; +} + +// ------------------------------ Reverse2 +template > +HWY_API Vec256 Reverse2(D d, Vec256 v) { + const Half dh; + v.v0 = Reverse2(dh, v.v0); + v.v1 = Reverse2(dh, v.v1); + return v; +} + +// ------------------------------ Reverse4 + +// Each block has only 2 lanes, so swap blocks and their lanes. +template , HWY_IF_T_SIZE(T, 8)> +HWY_API Vec256 Reverse4(D d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = Reverse2(dh, v.v1); // swapped + ret.v1 = Reverse2(dh, v.v0); + return ret; +} + +template , HWY_IF_NOT_T_SIZE(T, 8)> +HWY_API Vec256 Reverse4(D d, Vec256 v) { + const Half dh; + v.v0 = Reverse4(dh, v.v0); + v.v1 = Reverse4(dh, v.v1); + return v; +} + +// ------------------------------ Reverse8 + +template , HWY_IF_T_SIZE(T, 8)> +HWY_API Vec256 Reverse8(D /* tag */, Vec256 /* v */) { + HWY_ASSERT(0); // don't have 8 u64 lanes +} + +// Each block has only 4 lanes, so swap blocks and their lanes. +template , HWY_IF_T_SIZE(T, 4)> +HWY_API Vec256 Reverse8(D d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = Reverse4(dh, v.v1); // swapped + ret.v1 = Reverse4(dh, v.v0); + return ret; +} + +template , + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2))> +HWY_API Vec256 Reverse8(D d, Vec256 v) { + const Half dh; + v.v0 = Reverse8(dh, v.v0); + v.v1 = Reverse8(dh, v.v1); + return v; +} + +// ------------------------------ InterleaveLower + +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + a.v0 = InterleaveLower(a.v0, b.v0); + a.v1 = InterleaveLower(a.v1, b.v1); + return a; +} + +// wasm_128 already defines a template with D, V, V args.
+ +// ------------------------------ InterleaveUpper (UpperHalf) + +template > +HWY_API Vec256 InterleaveUpper(D d, Vec256 a, Vec256 b) { + const Half dh; + a.v0 = InterleaveUpper(dh, a.v0, b.v0); + a.v1 = InterleaveUpper(dh, a.v1, b.v1); + return a; +} + +// ------------------------------ InterleaveWholeLower +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const Half dh; + VFromD ret; + ret.v0 = InterleaveLower(a.v0, b.v0); + ret.v1 = InterleaveUpper(dh, a.v0, b.v0); + return ret; +} + +// ------------------------------ InterleaveWholeUpper +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const Half dh; + VFromD ret; + ret.v0 = InterleaveLower(a.v1, b.v1); + ret.v1 = InterleaveUpper(dh, a.v1, b.v1); + return ret; +} + +// ------------------------------ ZipLower/ZipUpper defined in wasm_128 + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) +template > +HWY_API Vec256 Combine(D /* d */, Vec128 hi, Vec128 lo) { + Vec256 ret; + ret.v1 = hi; + ret.v0 = lo; + return ret; +} + +// ------------------------------ ZeroExtendVector (Combine) +template > +HWY_API Vec256 ZeroExtendVector(D d, Vec128 lo) { + const Half dh; + return Combine(d, Zero(dh), lo); +} + +// ------------------------------ ZeroExtendResizeBitCast + +namespace detail { + +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag /* from_size_tag */, + hwy::SizeTag<32> /* to_size_tag */, DTo d_to, DFrom d_from, + VFromD v) { + const Half dh_to; + return ZeroExtendVector(d_to, ZeroExtendResizeBitCast(dh_to, d_from, v)); +} + +} // namespace detail + +// ------------------------------ ConcatLowerLower +template > +HWY_API Vec256 ConcatLowerLower(D /* tag */, Vec256 hi, Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v0; + ret.v0 = lo.v0; + return ret; +} + +// ------------------------------ ConcatUpperUpper +template > +HWY_API Vec256 ConcatUpperUpper(D /* tag */, Vec256 hi, 
Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v1; + ret.v0 = lo.v1; + return ret; +} + +// ------------------------------ ConcatLowerUpper +template > +HWY_API Vec256 ConcatLowerUpper(D /* tag */, Vec256 hi, Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v0; + ret.v0 = lo.v1; + return ret; +} + +// ------------------------------ ConcatUpperLower +template > +HWY_API Vec256 ConcatUpperLower(D /* tag */, Vec256 hi, Vec256 lo) { + Vec256 ret; + ret.v1 = hi.v1; + ret.v0 = lo.v0; + return ret; +} + +// ------------------------------ ConcatOdd +template > +HWY_API Vec256 ConcatOdd(D d, Vec256 hi, Vec256 lo) { + const Half dh; + Vec256 ret; + ret.v0 = ConcatOdd(dh, lo.v1, lo.v0); + ret.v1 = ConcatOdd(dh, hi.v1, hi.v0); + return ret; +} + +// ------------------------------ ConcatEven +template > +HWY_API Vec256 ConcatEven(D d, Vec256 hi, Vec256 lo) { + const Half dh; + Vec256 ret; + ret.v0 = ConcatEven(dh, lo.v1, lo.v0); + ret.v1 = ConcatEven(dh, hi.v1, hi.v0); + return ret; +} + +// ------------------------------ DupEven +template +HWY_API Vec256 DupEven(Vec256 v) { + v.v0 = DupEven(v.v0); + v.v1 = DupEven(v.v1); + return v; +} + +// ------------------------------ DupOdd +template +HWY_API Vec256 DupOdd(Vec256 v) { + v.v0 = DupOdd(v.v0); + v.v1 = DupOdd(v.v1); + return v; +} + +// ------------------------------ OddEven +template +HWY_API Vec256 OddEven(Vec256 a, const Vec256 b) { + a.v0 = OddEven(a.v0, b.v0); + a.v1 = OddEven(a.v1, b.v1); + return a; +} + +// ------------------------------ InterleaveEven +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Half dh; + a.v0 = InterleaveEven(dh, a.v0, b.v0); + a.v1 = InterleaveEven(dh, a.v1, b.v1); + return a; +} + +// ------------------------------ InterleaveOdd +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Half dh; + a.v0 = InterleaveOdd(dh, a.v0, b.v0); + a.v1 = InterleaveOdd(dh, a.v1, b.v1); + return a; +} + +// ------------------------------ OddEvenBlocks +template +HWY_API 
Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + odd.v0 = even.v0; + return odd; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + Vec256 ret; + ret.v0 = v.v1; // swapped order + ret.v1 = v.v0; + return ret; +} + +// ------------------------------ InterleaveEvenBlocks +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveEvenBlocks(D, V a, V b) { + V ret; + ret.v0 = a.v0; + ret.v1 = b.v0; + return ret; +} +// ------------------------------ InterleaveOddBlocks +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveOddBlocks(D, V a, V b) { + V ret; + ret.v0 = a.v1; + ret.v1 = b.v1; + return ret; +} + +// ------------------------------ InterleaveLowerBlocks +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveLowerBlocks(D d, V a, V b) { + return InterleaveEvenBlocks(d, a, b); +} +// ------------------------------ InterleaveUpperBlocks +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveUpperBlocks(D d, V a, V b) { + return InterleaveOddBlocks(d, a, b); +} + +// ------------------------------ ReverseBlocks +template > +HWY_API Vec256 ReverseBlocks(D /* tag */, const Vec256 v) { + return SwapAdjacentBlocks(v); // 2 blocks, so Swap = Reverse +} + +// ------------------------------ Per4LaneBlockShuffle +namespace detail { + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + const Half dh; + using VH = VFromD; + + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) & 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + + V ret; + ret.v0 = VH{wasm_i8x16_shuffle( + v.v0.raw, v.v0.raw, kIdx0, kIdx1, kIdx2, kIdx3, kIdx0 + 4, kIdx1 + 4, + kIdx2 + 4, kIdx3 + 4, kIdx0 + 8, kIdx1 + 8, kIdx2 + 8, kIdx3 + 8, + kIdx0 + 12, kIdx1 + 12, kIdx2 + 12, kIdx3 + 12)}; 
+ ret.v1 = VH{wasm_i8x16_shuffle( + v.v1.raw, v.v1.raw, kIdx0, kIdx1, kIdx2, kIdx3, kIdx0 + 4, kIdx1 + 4, + kIdx2 + 4, kIdx3 + 4, kIdx0 + 8, kIdx1 + 8, kIdx2 + 8, kIdx3 + 8, + kIdx0 + 12, kIdx1 + 12, kIdx2 + 12, kIdx3 + 12)}; + return ret; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + const Half dh; + using VH = VFromD; + + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) & 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + + V ret; + ret.v0 = VH{wasm_i16x8_shuffle(v.v0.raw, v.v0.raw, kIdx0, kIdx1, kIdx2, kIdx3, + kIdx0 + 4, kIdx1 + 4, kIdx2 + 4, kIdx3 + 4)}; + ret.v1 = VH{wasm_i16x8_shuffle(v.v1.raw, v.v1.raw, kIdx0, kIdx1, kIdx2, kIdx3, + kIdx0 + 4, kIdx1 + 4, kIdx2 + 4, kIdx3 + 4)}; + return ret; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + const Half dh; + using VH = VFromD; + + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) & 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + + V ret; + ret.v0 = + VH{wasm_i32x4_shuffle(v.v0.raw, v.v0.raw, kIdx0, kIdx1, kIdx2, kIdx3)}; + ret.v1 = + VH{wasm_i32x4_shuffle(v.v1.raw, v.v1.raw, kIdx0, kIdx1, kIdx2, kIdx3)}; + return ret; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + const Half dh; + using VH = VFromD; + + constexpr int kIdx3 = static_cast((kIdx3210 >> 6) & 3); + constexpr int kIdx2 = static_cast((kIdx3210 >> 4) & 3); + constexpr int kIdx1 = static_cast((kIdx3210 >> 2) 
& 3); + constexpr int kIdx0 = static_cast(kIdx3210 & 3); + + V ret; + ret.v0 = VH{wasm_i64x2_shuffle(v.v0.raw, v.v1.raw, kIdx0, kIdx1)}; + ret.v1 = VH{wasm_i64x2_shuffle(v.v0.raw, v.v1.raw, kIdx2, kIdx3)}; + return ret; +} + +} // namespace detail + +// ------------------------------ SlideUpBlocks +template +HWY_API VFromD SlideUpBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 1, + "kBlocks must be between 0 and 1"); + return (kBlocks == 1) ? ConcatLowerLower(d, v, Zero(d)) : v; +} + +// ------------------------------ SlideDownBlocks +template +HWY_API VFromD SlideDownBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 1, + "kBlocks must be between 0 and 1"); + const Half dh; + return (kBlocks == 1) ? ZeroExtendVector(d, UpperHalf(dh, v)) : v; +} + +// ------------------------------ SlideUpLanes + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { + const Half dh; + const RebindToUnsigned du; + const RebindToUnsigned dh_u; + const auto vu = BitCast(du, v); + VFromD ret; + +#if !HWY_IS_DEBUG_BUILD + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + if (__builtin_constant_p(amt) && amt < kLanesPerBlock) { + switch (amt * sizeof(TFromD)) { + case 0: + return v; + case 1: + ret.v0 = BitCast(dh, ShiftLeftBytes<1>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<15>(dh_u, vu.v1, vu.v0)); + return ret; + case 2: + ret.v0 = BitCast(dh, ShiftLeftBytes<2>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<14>(dh_u, vu.v1, vu.v0)); + return ret; + case 3: + ret.v0 = BitCast(dh, ShiftLeftBytes<3>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<13>(dh_u, vu.v1, vu.v0)); + return ret; + case 4: + ret.v0 = BitCast(dh, ShiftLeftBytes<4>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<12>(dh_u, vu.v1, vu.v0)); + return ret; + case 5: + ret.v0 = BitCast(dh, ShiftLeftBytes<5>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<11>(dh_u, vu.v1, vu.v0)); + return ret; + 
case 6: + ret.v0 = BitCast(dh, ShiftLeftBytes<6>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<10>(dh_u, vu.v1, vu.v0)); + return ret; + case 7: + ret.v0 = BitCast(dh, ShiftLeftBytes<7>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<9>(dh_u, vu.v1, vu.v0)); + return ret; + case 8: + ret.v0 = BitCast(dh, ShiftLeftBytes<8>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<8>(dh_u, vu.v1, vu.v0)); + return ret; + case 9: + ret.v0 = BitCast(dh, ShiftLeftBytes<9>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<7>(dh_u, vu.v1, vu.v0)); + return ret; + case 10: + ret.v0 = BitCast(dh, ShiftLeftBytes<10>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<6>(dh_u, vu.v1, vu.v0)); + return ret; + case 11: + ret.v0 = BitCast(dh, ShiftLeftBytes<11>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<5>(dh_u, vu.v1, vu.v0)); + return ret; + case 12: + ret.v0 = BitCast(dh, ShiftLeftBytes<12>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<4>(dh_u, vu.v1, vu.v0)); + return ret; + case 13: + ret.v0 = BitCast(dh, ShiftLeftBytes<13>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<3>(dh_u, vu.v1, vu.v0)); + return ret; + case 14: + ret.v0 = BitCast(dh, ShiftLeftBytes<14>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<2>(dh_u, vu.v1, vu.v0)); + return ret; + case 15: + ret.v0 = BitCast(dh, ShiftLeftBytes<15>(dh_u, vu.v0)); + ret.v1 = BitCast(dh, CombineShiftRightBytes<1>(dh_u, vu.v1, vu.v0)); + return ret; + } + } + + if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) { + ret.v0 = Zero(dh); + ret.v1 = SlideUpLanes(dh, LowerHalf(dh, v), amt - kLanesPerBlock); + return ret; + } +#endif + + const Repartition du8; + const RebindToSigned di8; + const Half dh_i8; + + const auto lo_byte_idx = BitCast( + di8, + Iota(du8, static_cast(size_t{0} - amt * sizeof(TFromD)))); + + const auto hi_byte_idx = + UpperHalf(dh_i8, lo_byte_idx) - Set(dh_i8, int8_t{16}); + 
const auto hi_sel_mask = + UpperHalf(dh_i8, lo_byte_idx) > Set(dh_i8, int8_t{15}); + + ret = BitCast(d, + TableLookupBytesOr0(ConcatLowerLower(du, vu, vu), lo_byte_idx)); + ret.v1 = + BitCast(dh, IfThenElse(hi_sel_mask, + TableLookupBytes(UpperHalf(dh_u, vu), hi_byte_idx), + BitCast(dh_i8, ret.v1))); + return ret; +} + +// ------------------------------ Slide1Up +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + VFromD ret; + const Half dh; + constexpr int kShrByteAmt = static_cast(16 - sizeof(TFromD)); + ret.v0 = ShiftLeftLanes<1>(dh, v.v0); + ret.v1 = CombineShiftRightBytes(dh, v.v1, v.v0); + return ret; +} + +// ------------------------------ SlideDownLanes + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { + const Half dh; + const RebindToUnsigned du; + const RebindToUnsigned dh_u; + VFromD ret; + + const auto vu = BitCast(du, v); + +#if !HWY_IS_DEBUG_BUILD + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + if (__builtin_constant_p(amt) && amt < kLanesPerBlock) { + switch (amt * sizeof(TFromD)) { + case 0: + return v; + case 1: + ret.v0 = BitCast(dh, CombineShiftRightBytes<1>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<1>(dh_u, vu.v1)); + return ret; + case 2: + ret.v0 = BitCast(dh, CombineShiftRightBytes<2>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<2>(dh_u, vu.v1)); + return ret; + case 3: + ret.v0 = BitCast(dh, CombineShiftRightBytes<3>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<3>(dh_u, vu.v1)); + return ret; + case 4: + ret.v0 = BitCast(dh, CombineShiftRightBytes<4>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<4>(dh_u, vu.v1)); + return ret; + case 5: + ret.v0 = BitCast(dh, CombineShiftRightBytes<5>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<5>(dh_u, vu.v1)); + return ret; + case 6: + ret.v0 = BitCast(dh, CombineShiftRightBytes<6>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<6>(dh_u, vu.v1)); + return ret; + case 7: + 
ret.v0 = BitCast(dh, CombineShiftRightBytes<7>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<7>(dh_u, vu.v1)); + return ret; + case 8: + ret.v0 = BitCast(dh, CombineShiftRightBytes<8>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<8>(dh_u, vu.v1)); + return ret; + case 9: + ret.v0 = BitCast(dh, CombineShiftRightBytes<9>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<9>(dh_u, vu.v1)); + return ret; + case 10: + ret.v0 = BitCast(dh, CombineShiftRightBytes<10>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<10>(dh_u, vu.v1)); + return ret; + case 11: + ret.v0 = BitCast(dh, CombineShiftRightBytes<11>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<11>(dh_u, vu.v1)); + return ret; + case 12: + ret.v0 = BitCast(dh, CombineShiftRightBytes<12>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<12>(dh_u, vu.v1)); + return ret; + case 13: + ret.v0 = BitCast(dh, CombineShiftRightBytes<13>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<13>(dh_u, vu.v1)); + return ret; + case 14: + ret.v0 = BitCast(dh, CombineShiftRightBytes<14>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<14>(dh_u, vu.v1)); + return ret; + case 15: + ret.v0 = BitCast(dh, CombineShiftRightBytes<15>(dh_u, vu.v1, vu.v0)); + ret.v1 = BitCast(dh, ShiftRightBytes<15>(dh_u, vu.v1)); + return ret; + } + } + + if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) { + ret.v0 = SlideDownLanes(dh, UpperHalf(dh, v), amt - kLanesPerBlock); + ret.v1 = Zero(dh); + return ret; + } +#endif + + const Repartition du8; + const Half dh_u8; + + const auto lo_byte_idx = + Iota(du8, static_cast(amt * sizeof(TFromD))); + const auto u8_16 = Set(du8, uint8_t{16}); + const auto hi_byte_idx = lo_byte_idx - u8_16; + + const auto lo_sel_mask = + LowerHalf(dh_u8, lo_byte_idx) < LowerHalf(dh_u8, u8_16); + ret = BitCast(d, IfThenElseZero(hi_byte_idx < u8_16, + TableLookupBytes(ConcatUpperUpper(du, vu, vu), + 
hi_byte_idx))); + ret.v0 = + BitCast(dh, IfThenElse(lo_sel_mask, + TableLookupBytes(LowerHalf(dh_u, vu), + LowerHalf(dh_u8, lo_byte_idx)), + BitCast(dh_u8, LowerHalf(dh, ret)))); + return ret; +} + +// ------------------------------ Slide1Down +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + VFromD ret; + const Half dh; + constexpr int kShrByteAmt = static_cast(sizeof(TFromD)); + ret.v0 = CombineShiftRightBytes(dh, v.v1, v.v0); + ret.v1 = ShiftRightBytes(dh, v.v1); + return ret; +} + +// ================================================== CONVERT + +// ------------------------------ PromoteTo + +template +HWY_API VFromD PromoteTo(D d, Vec128 v) { + const Half dh; + VFromD ret; + // PromoteLowerTo is defined later in generic_ops-inl.h. + ret.v0 = PromoteTo(dh, LowerHalf(v)); + ret.v1 = PromoteUpperTo(dh, v); + return ret; +} + +// 4x promotion: 8-bit to 32-bit or 16-bit to 64-bit +template +HWY_API Vec256> PromoteTo(DW d, Vec64 v) { + const Half dh; + // 16-bit lanes for UI8->UI32, 32-bit lanes for UI16->UI64 + const Rebind, decltype(d)> d2; + const auto v_2x = PromoteTo(d2, v); + Vec256> ret; + // PromoteLowerTo is defined later in generic_ops-inl.h. + ret.v0 = PromoteTo(dh, LowerHalf(v_2x)); + ret.v1 = PromoteUpperTo(dh, v_2x); + return ret; +} + +// 8x promotion: 8-bit to 64-bit +template +HWY_API Vec256> PromoteTo(DW d, Vec32 v) { + const Half dh; + const Repartition>, decltype(dh)> d4; // 32-bit lanes + const auto v32 = PromoteTo(d4, v); + Vec256> ret; + // PromoteLowerTo is defined later in generic_ops-inl.h. + ret.v0 = PromoteTo(dh, LowerHalf(v32)); + ret.v1 = PromoteUpperTo(dh, v32); + return ret; +} + +// ------------------------------ PromoteUpperTo + +// Not native, but still define this here because wasm_128 toggles +// HWY_NATIVE_PROMOTE_UPPER_TO. +template +HWY_API VFromD PromoteUpperTo(D d, Vec256 v) { + // Lanes(d) may differ from Lanes(DFromV()). Use the lane type + // from v because it cannot be deduced from D (could be either bf16 or f16). 
+ const Rebind dh; + return PromoteTo(d, UpperHalf(dh, v)); +} + +// ------------------------------ DemoteTo + +template +HWY_API Vec128 DemoteTo(D /* tag */, Vec256 v) { + return Vec128{wasm_u16x8_narrow_i32x4(v.v0.raw, v.v1.raw)}; +} + +template +HWY_API Vec128 DemoteTo(D /* tag */, Vec256 v) { + return Vec128{wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw)}; +} + +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec256 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw); + return Vec64{wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API Vec128 DemoteTo(D /* tag */, Vec256 v) { + return Vec128{wasm_u8x16_narrow_i16x8(v.v0.raw, v.v1.raw)}; +} + +template +HWY_API Vec64 DemoteTo(D /* tag */, Vec256 v) { + const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw); + return Vec64{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; +} + +template +HWY_API Vec128 DemoteTo(D /* tag */, Vec256 v) { + return Vec128{wasm_i8x16_narrow_i16x8(v.v0.raw, v.v1.raw)}; +} + +template +HWY_API Vec128 DemoteTo(D di, Vec256 v) { + const Vec64 lo{wasm_i32x4_trunc_sat_f64x2_zero(v.v0.raw)}; + const Vec64 hi{wasm_i32x4_trunc_sat_f64x2_zero(v.v1.raw)}; + return Combine(di, hi, lo); +} + +template +HWY_API Vec128 DemoteTo(D di, Vec256 v) { + const Vec64 lo{wasm_u32x4_trunc_sat_f64x2_zero(v.v0.raw)}; + const Vec64 hi{wasm_u32x4_trunc_sat_f64x2_zero(v.v1.raw)}; + return Combine(di, hi, lo); +} + +template +HWY_API Vec128 DemoteTo(D df, Vec256 v) { + const Vec64 lo = DemoteTo(Full64(), v.v0); + const Vec64 hi = DemoteTo(Full64(), v.v1); + return Combine(df, hi, lo); +} + +template +HWY_API Vec128 DemoteTo(D df, Vec256 v) { + const Vec64 lo = DemoteTo(Full64(), v.v0); + const Vec64 hi = DemoteTo(Full64(), v.v1); + return Combine(df, hi, lo); +} + +template +HWY_API Vec128 DemoteTo(D d16, Vec256 v) { + const Half d16h; + const Vec64 lo = DemoteTo(d16h, v.v0); + const Vec64 hi = DemoteTo(d16h, v.v1); + return Combine(d16, hi, lo); +} + 
+template +HWY_API Vec128 DemoteTo(D df32, Vec256 v) { + const Half df32h; + const Vec64 lo = DemoteTo(df32h, v.v0); + const Vec64 hi = DemoteTo(df32h, v.v1); + return Combine(df32, hi, lo); +} + +// For already range-limited input [0, 255]. +HWY_API Vec64 U8FromU32(Vec256 v) { + const Full64 du8; + const Full256 di32; // no unsigned DemoteTo + return DemoteTo(du8, BitCast(di32, v)); +} + +// ------------------------------ Truncations + +template +HWY_API Vec32 TruncateTo(D /* tag */, Vec256 v) { + return Vec32{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 8, 16, 24, 0, + 8, 16, 24, 0, 8, 16, 24, 0, 8, 16, + 24)}; +} + +template +HWY_API Vec64 TruncateTo(D /* tag */, Vec256 v) { + return Vec64{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 8, 9, 16, + 17, 24, 25, 0, 1, 8, 9, 16, 17, 24, + 25)}; +} + +template +HWY_API Vec128 TruncateTo(D /* tag */, Vec256 v) { + return Vec128{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 2, 3, 8, + 9, 10, 11, 16, 17, 18, 19, 24, 25, + 26, 27)}; +} + +template +HWY_API Vec64 TruncateTo(D /* tag */, Vec256 v) { + return Vec64{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 4, 8, 12, 16, + 20, 24, 28, 0, 4, 8, 12, 16, 20, 24, + 28)}; +} + +template +HWY_API Vec128 TruncateTo(D /* tag */, Vec256 v) { + return Vec128{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 4, 5, 8, + 9, 12, 13, 16, 17, 20, 21, 24, 25, + 28, 29)}; +} + +template +HWY_API Vec128 TruncateTo(D /* tag */, Vec256 v) { + return Vec128{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 2, 4, 6, 8, + 10, 12, 14, 16, 18, 20, 22, 24, 26, + 28, 30)}; +} + +// ------------------------------ ReorderDemote2To +template ), HWY_IF_SIGNED_V(V), + HWY_IF_T_SIZE_ONE_OF_D(DN, (1 << 1) | (1 << 2) | (1 << 4)), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Half dnh; + VFromD demoted; + demoted.v0 = DemoteTo(dnh, a); + demoted.v1 = DemoteTo(dnh, b); + return demoted; +} + +template ) * 2)> +HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { + const Half 
dnh; + VFromD demoted; + demoted.v0 = DemoteTo(dnh, a); + demoted.v1 = DemoteTo(dnh, b); + return demoted; +} + +// ------------------------------ Convert i32 <=> f32 (Round) + +template > +HWY_API Vec256 ConvertTo(DTo d, const Vec256 v) { + const Half dh; + Vec256 ret; + ret.v0 = ConvertTo(dh, v.v0); + ret.v1 = ConvertTo(dh, v.v1); + return ret; +} + +template +HWY_API Vec256> NearestInt(const Vec256 v) { + return ConvertTo(Full256>(), Round(v)); +} + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + const Half dh; + MFromD ret; + ret.m0 = LoadMaskBits(dh, bits); + // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8. + // Both halves fit in one byte's worth of mask bits. + constexpr size_t kBitsPerHalf = 16 / sizeof(TFromD); + const uint8_t bits_upper[8] = {static_cast(bits[0] >> kBitsPerHalf)}; + ret.m1 = LoadMaskBits(dh, bits_upper); + return ret; +} + +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + const Half dh; + MFromD ret; + ret.m0 = LoadMaskBits(dh, bits); + constexpr size_t kLanesPerHalf = 16 / sizeof(TFromD); + constexpr size_t kBytesPerHalf = kLanesPerHalf / 8; + static_assert(kBytesPerHalf != 0, "Lane size <= 16 bits => at least 8 lanes"); + ret.m1 = LoadMaskBits(dh, bits + kBytesPerHalf); + return ret; +} + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const Half dh; + MFromD ret; + ret.m0 = ret.m1 = Dup128MaskFromMaskBits(dh, mask_bits); + return ret; +} + +// ------------------------------ Mask + +// `p` points to at least 8 writable bytes. 
+template , + HWY_IF_T_SIZE_ONE_OF(T, (1 << 4) | (1 << 8))> +HWY_API size_t StoreMaskBits(D d, const Mask256 mask, uint8_t* bits) { + const Half dh; + StoreMaskBits(dh, mask.m0, bits); + const uint8_t lo = bits[0]; + StoreMaskBits(dh, mask.m1, bits); + // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8. + // Both halves fit in one byte's worth of mask bits. + constexpr size_t kBitsPerHalf = 16 / sizeof(T); + bits[0] = static_cast(lo | (bits[0] << kBitsPerHalf)); + return (kBitsPerHalf * 2 + 7) / 8; +} + +template , + HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2))> +HWY_API size_t StoreMaskBits(D d, const Mask256 mask, uint8_t* bits) { + const Half dh; + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + constexpr size_t kBytesPerHalf = kLanesPerHalf / 8; + static_assert(kBytesPerHalf != 0, "Lane size <= 16 bits => at least 8 lanes"); + StoreMaskBits(dh, mask.m0, bits); + StoreMaskBits(dh, mask.m1, bits + kBytesPerHalf); + return kBytesPerHalf * 2; +} + +template > +HWY_API size_t CountTrue(D d, const Mask256 m) { + const Half dh; + return CountTrue(dh, m.m0) + CountTrue(dh, m.m1); +} + +template > +HWY_API bool AllFalse(D d, const Mask256 m) { + const Half dh; + return AllFalse(dh, m.m0) && AllFalse(dh, m.m1); +} + +template > +HWY_API bool AllTrue(D d, const Mask256 m) { + const Half dh; + return AllTrue(dh, m.m0) && AllTrue(dh, m.m1); +} + +template > +HWY_API size_t FindKnownFirstTrue(D d, const Mask256 mask) { + const Half dh; + const intptr_t lo = FindFirstTrue(dh, mask.m0); // not known + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + return lo >= 0 ? static_cast(lo) + : kLanesPerHalf + FindKnownFirstTrue(dh, mask.m1); +} + +template > +HWY_API intptr_t FindFirstTrue(D d, const Mask256 mask) { + const Half dh; + const intptr_t lo = FindFirstTrue(dh, mask.m0); + constexpr int kLanesPerHalf = 16 / sizeof(T); + if (lo >= 0) return lo; + + const intptr_t hi = FindFirstTrue(dh, mask.m1); + return hi + (hi >= 0 ? 
kLanesPerHalf : 0); +} + +template > +HWY_API size_t FindKnownLastTrue(D d, const Mask256 mask) { + const Half dh; + const intptr_t hi = FindLastTrue(dh, mask.m1); // not known + constexpr size_t kLanesPerHalf = 16 / sizeof(T); + return hi >= 0 ? kLanesPerHalf + static_cast(hi) + : FindKnownLastTrue(dh, mask.m0); +} + +template > +HWY_API intptr_t FindLastTrue(D d, const Mask256 mask) { + const Half dh; + constexpr int kLanesPerHalf = 16 / sizeof(T); + const intptr_t hi = FindLastTrue(dh, mask.m1); + return hi >= 0 ? kLanesPerHalf + hi : FindLastTrue(dh, mask.m0); +} + +// ------------------------------ CompressStore +template > +HWY_API size_t CompressStore(Vec256 v, const Mask256 mask, D d, + T* HWY_RESTRICT unaligned) { + const Half dh; + const size_t count = CompressStore(v.v0, mask.m0, dh, unaligned); + const size_t count2 = CompressStore(v.v1, mask.m1, dh, unaligned + count); + return count + count2; +} + +// ------------------------------ CompressBlendedStore +template > +HWY_API size_t CompressBlendedStore(Vec256 v, const Mask256 m, D d, + T* HWY_RESTRICT unaligned) { + const Half dh; + const size_t count = CompressBlendedStore(v.v0, m.m0, dh, unaligned); + const size_t count2 = CompressBlendedStore(v.v1, m.m1, dh, unaligned + count); + return count + count2; +} + +// ------------------------------ CompressBitsStore + +template > +HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits, + D d, T* HWY_RESTRICT unaligned) { + const Mask256 m = LoadMaskBits(d, bits); + return CompressStore(v, m, d, unaligned); +} + +// ------------------------------ Compress +template +HWY_API Vec256 Compress(const Vec256 v, const Mask256 mask) { + const DFromV d; + alignas(32) T lanes[32 / sizeof(T)] = {}; + (void)CompressStore(v, mask, d, lanes); + return Load(d, lanes); +} + +// ------------------------------ CompressNot +template +HWY_API Vec256 CompressNot(Vec256 v, const Mask256 mask) { + return Compress(v, Not(mask)); +} + +// 
------------------------------ CompressBlocksNot +HWY_API Vec256 CompressBlocksNot(Vec256 v, + Mask256 mask) { + const Full128 dh; + // Because the non-selected (mask=1) blocks are undefined, we can return the + // input unless mask = 01, in which case we must bring down the upper block. + return AllTrue(dh, AndNot(mask.m1, mask.m0)) ? SwapAdjacentBlocks(v) : v; +} + +// ------------------------------ CompressBits +template +HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { + const Mask256 m = LoadMaskBits(DFromV(), bits); + return Compress(v, m); +} + +// ------------------------------ Expand +template +HWY_API Vec256 Expand(const Vec256 v, const Mask256 mask) { + Vec256 ret; + const Full256 d; + const Half dh; + alignas(32) T lanes[32 / sizeof(T)] = {}; + Store(v, d, lanes); + ret.v0 = Expand(v.v0, mask.m0); + ret.v1 = Expand(LoadU(dh, lanes + CountTrue(dh, mask.m0)), mask.m1); + return ret; +} + +// ------------------------------ LoadExpand +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + return Expand(LoadU(d, unaligned), mask); +} + +// ------------------------------ LoadInterleaved3/4 + +// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4. 
+ +namespace detail { + +// Input: +// 1 0 (<- first block of unaligned) +// 3 2 +// 5 4 +// Output: +// 3 0 +// 4 1 +// 5 2 +template > +HWY_API void LoadTransposedBlocks3(D d, const T* HWY_RESTRICT unaligned, + Vec256& A, Vec256& B, Vec256& C) { + const Vec256 v10 = LoadU(d, unaligned + 0 * MaxLanes(d)); + const Vec256 v32 = LoadU(d, unaligned + 1 * MaxLanes(d)); + const Vec256 v54 = LoadU(d, unaligned + 2 * MaxLanes(d)); + + A = ConcatUpperLower(d, v32, v10); + B = ConcatLowerUpper(d, v54, v10); + C = ConcatUpperLower(d, v54, v32); +} + +// Input (128-bit blocks): +// 1 0 (first block of unaligned) +// 3 2 +// 5 4 +// 7 6 +// Output: +// 4 0 (LSB of A) +// 5 1 +// 6 2 +// 7 3 +template > +HWY_API void LoadTransposedBlocks4(D d, const T* HWY_RESTRICT unaligned, + Vec256& vA, Vec256& vB, Vec256& vC, + Vec256& vD) { + const Vec256 v10 = LoadU(d, unaligned + 0 * MaxLanes(d)); + const Vec256 v32 = LoadU(d, unaligned + 1 * MaxLanes(d)); + const Vec256 v54 = LoadU(d, unaligned + 2 * MaxLanes(d)); + const Vec256 v76 = LoadU(d, unaligned + 3 * MaxLanes(d)); + + vA = ConcatLowerLower(d, v54, v10); + vB = ConcatUpperUpper(d, v54, v10); + vC = ConcatLowerLower(d, v76, v32); + vD = ConcatUpperUpper(d, v76, v32); +} + +} // namespace detail + +// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower) + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. 
+ +namespace detail { + +// Input (128-bit blocks): +// 2 0 (LSB of i) +// 3 1 +// Output: +// 1 0 +// 3 2 +template > +HWY_API void StoreTransposedBlocks2(Vec256 i, Vec256 j, D d, + T* HWY_RESTRICT unaligned) { + const Vec256 out0 = ConcatLowerLower(d, j, i); + const Vec256 out1 = ConcatUpperUpper(d, j, i); + StoreU(out0, d, unaligned + 0 * MaxLanes(d)); + StoreU(out1, d, unaligned + 1 * MaxLanes(d)); +} + +// Input (128-bit blocks): +// 3 0 (LSB of i) +// 4 1 +// 5 2 +// Output: +// 1 0 +// 3 2 +// 5 4 +template > +HWY_API void StoreTransposedBlocks3(Vec256 i, Vec256 j, Vec256 k, D d, + T* HWY_RESTRICT unaligned) { + const Vec256 out0 = ConcatLowerLower(d, j, i); + const Vec256 out1 = ConcatUpperLower(d, i, k); + const Vec256 out2 = ConcatUpperUpper(d, k, j); + StoreU(out0, d, unaligned + 0 * MaxLanes(d)); + StoreU(out1, d, unaligned + 1 * MaxLanes(d)); + StoreU(out2, d, unaligned + 2 * MaxLanes(d)); +} + +// Input (128-bit blocks): +// 4 0 (LSB of i) +// 5 1 +// 6 2 +// 7 3 +// Output: +// 1 0 +// 3 2 +// 5 4 +// 7 6 +template > +HWY_API void StoreTransposedBlocks4(Vec256 i, Vec256 j, Vec256 k, + Vec256 l, D d, + T* HWY_RESTRICT unaligned) { + // Write lower halves, then upper. 
+ const Vec256 out0 = ConcatLowerLower(d, j, i); + const Vec256 out1 = ConcatLowerLower(d, l, k); + StoreU(out0, d, unaligned + 0 * MaxLanes(d)); + StoreU(out1, d, unaligned + 1 * MaxLanes(d)); + const Vec256 out2 = ConcatUpperUpper(d, j, i); + const Vec256 out3 = ConcatUpperUpper(d, l, k); + StoreU(out2, d, unaligned + 2 * MaxLanes(d)); + StoreU(out3, d, unaligned + 3 * MaxLanes(d)); +} + +} // namespace detail + +// ------------------------------ Additional mask logical operations + +template +HWY_API Mask256 SetAtOrAfterFirst(Mask256 mask) { + const Full256 d; + const Half dh; + const Repartition dh_i64; + + Mask256 result; + result.m0 = SetAtOrAfterFirst(mask.m0); + result.m1 = SetAtOrAfterFirst(mask.m1); + + // Copy the sign bit of the lower 128-bit half to the upper 128-bit half + const auto vmask_lo = BitCast(dh_i64, VecFromMask(dh, result.m0)); + result.m1 = + Or(result.m1, MaskFromVec(BitCast(dh, BroadcastSignBit(InterleaveUpper( + dh_i64, vmask_lo, vmask_lo))))); + + return result; +} + +template +HWY_API Mask256 SetBeforeFirst(Mask256 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask256 SetOnlyFirst(Mask256 mask) { + const Full256 d; + const RebindToSigned di; + const Repartition di64; + const Half dh_i64; + + const auto zero = Zero(di64); + const auto vmask = BitCast(di64, VecFromMask(d, mask)); + + const auto vmask_eq_0 = VecFromMask(di64, vmask == zero); + auto vmask2_lo = LowerHalf(dh_i64, vmask_eq_0); + auto vmask2_hi = UpperHalf(dh_i64, vmask_eq_0); + + vmask2_lo = And(vmask2_lo, InterleaveLower(vmask2_lo, vmask2_lo)); + vmask2_hi = And(ConcatLowerUpper(dh_i64, vmask2_hi, vmask2_lo), + InterleaveUpper(dh_i64, vmask2_lo, vmask2_lo)); + vmask2_lo = InterleaveLower(Set(dh_i64, int64_t{-1}), vmask2_lo); + + const auto vmask2 = Combine(di64, vmask2_hi, vmask2_lo); + const auto only_first_vmask = Neg(BitCast(di, And(vmask, Neg(vmask)))); + return MaskFromVec(BitCast(d, And(only_first_vmask, BitCast(di, vmask2)))); +} + +template 
+HWY_API Mask256 SetAtOrBeforeFirst(Mask256 mask) { + const Full256 d; + constexpr size_t kLanesPerBlock = MaxLanes(d) / 2; + + const auto vmask = VecFromMask(d, mask); + const auto vmask_lo = ConcatLowerLower(d, vmask, Zero(d)); + return SetBeforeFirst( + MaskFromVec(CombineShiftRightBytes<(kLanesPerBlock - 1) * sizeof(T)>( + d, vmask, vmask_lo))); +} + +// ------------------------------ WidenMulPairwiseAdd +template > +HWY_API Vec256 WidenMulPairwiseAdd(D32 d32, Vec256 a, Vec256 b) { + const Half d32h; + Vec256 result; + result.v0 = WidenMulPairwiseAdd(d32h, a.v0, b.v0); + result.v1 = WidenMulPairwiseAdd(d32h, a.v1, b.v1); + return result; +} + +// ------------------------------ ReorderWidenMulAccumulate +template > +HWY_API Vec256 ReorderWidenMulAccumulate(D32 d32, Vec256 a, + Vec256 b, Vec256 sum0, + Vec256& sum1) { + const Half d32h; + sum0.v0 = ReorderWidenMulAccumulate(d32h, a.v0, b.v0, sum0.v0, sum1.v0); + sum0.v1 = ReorderWidenMulAccumulate(d32h, a.v1, b.v1, sum0.v1, sum1.v1); + return sum0; +} + +// ------------------------------ Reductions in generic_ops + +// ------------------------------ Lt128 + +template > +HWY_INLINE Mask256 Lt128(D d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Lt128(dh, a.v0, b.v0); + ret.m1 = Lt128(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Mask256 Lt128Upper(D d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Lt128Upper(dh, a.v0, b.v0); + ret.m1 = Lt128Upper(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Mask256 Eq128(D d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Eq128(dh, a.v0, b.v0); + ret.m1 = Eq128(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Mask256 Eq128Upper(D d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Eq128Upper(dh, a.v0, b.v0); + ret.m1 = Eq128Upper(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Mask256 Ne128(D d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 
= Ne128(dh, a.v0, b.v0); + ret.m1 = Ne128(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Mask256 Ne128Upper(D d, Vec256 a, Vec256 b) { + const Half dh; + Mask256 ret; + ret.m0 = Ne128Upper(dh, a.v0, b.v0); + ret.m1 = Ne128Upper(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Vec256 Min128(D d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Min128(dh, a.v0, b.v0); + ret.v1 = Min128(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Vec256 Max128(D d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Max128(dh, a.v0, b.v0); + ret.v1 = Max128(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Vec256 Min128Upper(D d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Min128Upper(dh, a.v0, b.v0); + ret.v1 = Min128Upper(dh, a.v1, b.v1); + return ret; +} + +template > +HWY_INLINE Vec256 Max128Upper(D d, Vec256 a, Vec256 b) { + const Half dh; + Vec256 ret; + ret.v0 = Max128Upper(dh, a.v0, b.v0); + ret.v1 = Max128Upper(dh, a.v1, b.v1); + return ret; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); diff --git a/lib/highway/hwy/ops/x86_128-inl.h b/lib/highway/hwy/ops/x86_128-inl.h new file mode 100644 index 00000000000..f29b6e30907 --- /dev/null +++ b/lib/highway/hwy/ops/x86_128-inl.h @@ -0,0 +1,14151 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// 128-bit vectors and SSE4 instructions, plus some AVX2 and AVX512-VL +// operations when compiling for those targets. +// External include guard in highway.h - see comment there. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_GCC_ACTUAL +#include "hwy/base.h" + +// Avoid uninitialized warnings in GCC's emmintrin.h - see +// https://github.com/google/highway/issues/710 and pull/902 +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494, + ignored "-Wmaybe-uninitialized") +#endif + +#include +#include +#if HWY_TARGET == HWY_SSSE3 +#include // SSSE3 +#elif HWY_TARGET <= HWY_SSE4 +#include // SSE4 +#ifndef HWY_DISABLE_PCLMUL_AES +#include // CLMUL +#endif +#endif + +#include "hwy/ops/shared-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +// Enable generic functions for whichever of (f16, bf16) are not supported. 
+#if !HWY_HAVE_FLOAT16 +#define HWY_X86_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D) +#else +#define HWY_X86_IF_EMULATED_D(D) HWY_IF_BF16_D(D) +#endif + +#undef HWY_AVX3_HAVE_F32_TO_BF16C +#if HWY_TARGET <= HWY_AVX3_ZEN4 && !HWY_COMPILER_CLANGCL && \ + (HWY_COMPILER_GCC_ACTUAL >= 1000 || HWY_COMPILER_CLANG >= 900) && \ + HWY_AVX3_ENABLE_AVX512BF16 +#define HWY_AVX3_HAVE_F32_TO_BF16C 1 +#else +#define HWY_AVX3_HAVE_F32_TO_BF16C 0 +#endif + +#undef HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT +#if HWY_TARGET <= HWY_AVX3 && HWY_ARCH_X86_64 +#define HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT "v" +#else +#define HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT "x" +#endif + +#undef HWY_X86_HAVE_AVX10_2_OPS +#if HWY_TARGET_IS_AVX10_2 && \ + (HWY_COMPILER_GCC_ACTUAL >= 1501 || \ + (HWY_COMPILER3_CLANG >= 200103 && HWY_COMPILER_CLANG != 2100)) +#define HWY_X86_HAVE_AVX10_2_OPS 1 +#else +#define HWY_X86_HAVE_AVX10_2_OPS 0 +#endif + +template +struct Raw128 { + using type = __m128i; +}; +#if HWY_HAVE_FLOAT16 +template <> +struct Raw128 { + using type = __m128h; +}; +#endif // HWY_HAVE_FLOAT16 +template <> +struct Raw128 { + using type = __m128; +}; +template <> +struct Raw128 { + using type = __m128d; +}; + +} // namespace detail + +template +class Vec128 { + using Raw = typename detail::Raw128::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = N; // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
+ HWY_INLINE Vec128& operator*=(const Vec128 other) { + return *this = (*this * other); + } + HWY_INLINE Vec128& operator/=(const Vec128 other) { + return *this = (*this / other); + } + HWY_INLINE Vec128& operator+=(const Vec128 other) { + return *this = (*this + other); + } + HWY_INLINE Vec128& operator-=(const Vec128 other) { + return *this = (*this - other); + } + HWY_INLINE Vec128& operator%=(const Vec128 other) { + return *this = (*this % other); + } + HWY_INLINE Vec128& operator&=(const Vec128 other) { + return *this = (*this & other); + } + HWY_INLINE Vec128& operator|=(const Vec128 other) { + return *this = (*this | other); + } + HWY_INLINE Vec128& operator^=(const Vec128 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +template +using Vec64 = Vec128; + +template +using Vec32 = Vec128; + +template +using Vec16 = Vec128; + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3 + +// Template arg: sizeof(lane type) +template +struct RawMask128T {}; +template <> +struct RawMask128T<1> { + using type = __mmask16; +}; +template <> +struct RawMask128T<2> { + using type = __mmask8; +}; +template <> +struct RawMask128T<4> { + using type = __mmask8; +}; +template <> +struct RawMask128T<8> { + using type = __mmask8; +}; + +template +using RawMask128 = typename RawMask128T::type; + +#else // AVX2 or earlier + +template +using RawMask128 = typename Raw128::type; + +#endif // HWY_TARGET <= HWY_AVX3 + +} // namespace detail + +template +struct Mask128 { + using Raw = typename detail::RawMask128; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = N; // only for DFromM + +#if HWY_TARGET <= HWY_AVX3 + static Mask128 FromBits(uint64_t mask_bits) { + return Mask128{static_cast(mask_bits)}; + } +#else +// Lanes are either FF..FF or 0. 
+#endif + + Raw raw; +}; + +template +using DFromV = Simd; + +template +using DFromM = Simd; + +template +using TFromV = typename V::PrivateT; + +// ------------------------------ Zero + +// Use HWY_MAX_LANES_D here because VFromD is defined in terms of Zero. +template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + return Vec128, HWY_MAX_LANES_D(D)>{_mm_setzero_si128()}; +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Zero(D /* tag */) { + return Vec128{_mm_setzero_ph()}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Zero(D /* tag */) { + return Vec128{_mm_setzero_ps()}; +} +template +HWY_API Vec128 Zero(D /* tag */) { + return Vec128{_mm_setzero_pd()}; +} +template +HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { + return Vec128, HWY_MAX_LANES_D(D)>{_mm_setzero_si128()}; +} + +// Using the existing Zero function instead of a dedicated function for +// deduction avoids having to forward-declare Vec256 here. +template +using VFromD = decltype(Zero(D())); + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m128i BitCastToInteger(__m128i v) { return v; } +#if HWY_HAVE_FLOAT16 +HWY_INLINE __m128i BitCastToInteger(__m128h v) { return _mm_castph_si128(v); } +#endif // HWY_HAVE_FLOAT16 +HWY_INLINE __m128i BitCastToInteger(__m128 v) { return _mm_castps_si128(v); } +HWY_INLINE __m128i BitCastToInteger(__m128d v) { return _mm_castpd_si128(v); } + +#if HWY_AVX3_HAVE_F32_TO_BF16C +HWY_INLINE __m128i BitCastToInteger(__m128bh v) { + // Need to use reinterpret_cast on GCC/Clang or BitCastScalar on MSVC to + // bit cast a __m128bh to a __m128i as there is currently no intrinsic + // available (as of GCC 13 and Clang 17) that can bit cast a __m128bh vector + // to a __m128i vector + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + // On GCC or Clang, use reinterpret_cast to bit cast a __m128bh to a __m128i + return reinterpret_cast<__m128i>(v); +#else + // On MSVC, use BitCastScalar to bit cast a __m128bh to a __m128i 
as MSVC does + // not allow reinterpret_cast, static_cast, or a C-style cast to be used to + // bit cast from one SSE/AVX vector type to a different SSE/AVX vector type + return BitCastScalar<__m128i>(v); +#endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +template +HWY_INLINE Vec128 BitCastToByte(Vec128 v) { + return Vec128{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ. +template +struct BitCastFromInteger128 { + HWY_INLINE __m128i operator()(__m128i v) { return v; } +}; +#if HWY_HAVE_FLOAT16 +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128h operator()(__m128i v) { return _mm_castsi128_ph(v); } +}; +#endif // HWY_HAVE_FLOAT16 +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128 operator()(__m128i v) { return _mm_castsi128_ps(v); } +}; +template <> +struct BitCastFromInteger128 { + HWY_INLINE __m128d operator()(__m128i v) { return _mm_castsi128_pd(v); } +}; + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, + Vec128 v) { + return VFromD{BitCastFromInteger128>()(v.raw)}; +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, + Vec128().MaxLanes()> v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm_set1_epi8(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm_set1_epi16(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm_set1_epi32(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm_set1_epi64x(static_cast(t))}; // NOLINT +} +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD Set(D /* tag */, float16_t t) { + return VFromD{_mm_set1_ph(t)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API VFromD Set(D /* tag */, float t) { + return 
VFromD{_mm_set1_ps(t)}; +} +template +HWY_API VFromD Set(D /* tag */, double t) { + return VFromD{_mm_set1_pd(t)}; +} + +// Generic for all vector lengths. +template +HWY_API VFromD Set(D df, TFromD t) { + const RebindToUnsigned du; + static_assert(sizeof(TFromD) == 2, "Expecting [b]f16"); + uint16_t bits; + CopyBytes<2>(&t, &bits); + return BitCast(df, Set(du, bits)); +} + +// ------------------------------ Undefined + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API VFromD Undefined(D /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction. + return VFromD{_mm_undefined_si128()}; +} +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD Undefined(D /* tag */) { + return VFromD{_mm_undefined_ph()}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API VFromD Undefined(D /* tag */) { + return VFromD{_mm_undefined_ps()}; +} +template +HWY_API VFromD Undefined(D /* tag */) { + return VFromD{_mm_undefined_pd()}; +} +template +HWY_API VFromD Undefined(D /* tag */) { + return VFromD{_mm_undefined_si128()}; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ GetLane + +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(_mm_cvtsi128_si32(v.raw) & 0xFF); +} +template +HWY_API T GetLane(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const uint16_t bits = + static_cast(_mm_cvtsi128_si32(BitCast(du, v).raw) & 0xFFFF); + return BitCastScalar(bits); +} +template +HWY_API T GetLane(const Vec128 v) { + return static_cast(_mm_cvtsi128_si32(v.raw)); +} +template +HWY_API float GetLane(const Vec128 v) { + return _mm_cvtss_f32(v.raw); +} +template +HWY_API T GetLane(const Vec128 v) { +#if HWY_ARCH_X86_32 + const DFromV d; + alignas(16) T lanes[2]; + Store(v, d, lanes); + return lanes[0]; +#else + return static_cast(_mm_cvtsi128_si64(v.raw)); +#endif +} +template +HWY_API double 
GetLane(const Vec128 v) { + return _mm_cvtsd_f64(v.raw); +} + +// ------------------------------ ResizeBitCast + +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const Repartition du8; + return BitCast(d, VFromD{detail::BitCastToInteger(v.raw)}); +} + +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return VFromD{_mm_setr_epi8( + static_cast(t0), static_cast(t1), static_cast(t2), + static_cast(t3), static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), static_cast(t8), + static_cast(t9), static_cast(t10), static_cast(t11), + static_cast(t12), static_cast(t13), static_cast(t14), + static_cast(t15))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{ + _mm_setr_epi16(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7))}; +} + +// Generic for all vector lengths +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{_mm_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7)}; +} +#else +// Generic for all vector lengths if HWY_HAVE_FLOAT16 is not true +template +HWY_API VFromD 
Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + const RebindToSigned di; + return BitCast(d, + Dup128VecFromValues( + di, BitCastScalar(t0), BitCastScalar(t1), + BitCastScalar(t2), BitCastScalar(t3), + BitCastScalar(t4), BitCastScalar(t5), + BitCastScalar(t6), BitCastScalar(t7))); +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{ + _mm_setr_epi32(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{_mm_setr_ps(t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + // Need to use _mm_set_epi64x as there is no _mm_setr_epi64x intrinsic + // available + return VFromD{ + _mm_set_epi64x(static_cast(t1), static_cast(t0))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{_mm_setr_pd(t0, t1)}; +} + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD +namespace detail { + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<1> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<2> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<4> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<8> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && 
__builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]) && + __builtin_constant_p(v[4]) && __builtin_constant_p(v[5]) && + __builtin_constant_p(v[6]) && __builtin_constant_p(v[7]); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<16> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]) && + __builtin_constant_p(v[4]) && __builtin_constant_p(v[5]) && + __builtin_constant_p(v[6]) && __builtin_constant_p(v[7]) && + __builtin_constant_p(v[8]) && __builtin_constant_p(v[9]) && + __builtin_constant_p(v[10]) && __builtin_constant_p(v[11]) && + __builtin_constant_p(v[12]) && __builtin_constant_p(v[13]) && + __builtin_constant_p(v[14]) && __builtin_constant_p(v[15]); +} + +#if HWY_TARGET <= HWY_AVX2 +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantRawX86Vec( + hwy::SizeTag<32> /* num_of_lanes_tag*/, RawV v) { + return __builtin_constant_p(v[0]) && __builtin_constant_p(v[1]) && + __builtin_constant_p(v[2]) && __builtin_constant_p(v[3]) && + __builtin_constant_p(v[4]) && __builtin_constant_p(v[5]) && + __builtin_constant_p(v[6]) && __builtin_constant_p(v[7]) && + __builtin_constant_p(v[8]) && __builtin_constant_p(v[9]) && + __builtin_constant_p(v[10]) && __builtin_constant_p(v[11]) && + __builtin_constant_p(v[12]) && __builtin_constant_p(v[13]) && + __builtin_constant_p(v[14]) && __builtin_constant_p(v[15]) && + __builtin_constant_p(v[16]) && __builtin_constant_p(v[17]) && + __builtin_constant_p(v[18]) && __builtin_constant_p(v[19]) && + __builtin_constant_p(v[20]) && __builtin_constant_p(v[21]) && + __builtin_constant_p(v[22]) && __builtin_constant_p(v[23]) && + __builtin_constant_p(v[24]) && __builtin_constant_p(v[25]) && + __builtin_constant_p(v[26]) && __builtin_constant_p(v[27]) && + __builtin_constant_p(v[28]) && __builtin_constant_p(v[29]) && + __builtin_constant_p(v[30]) && 
__builtin_constant_p(v[31]); +} +#endif + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantX86Vec( + hwy::SizeTag num_of_lanes_tag, V v) { + using T = TFromV; +#if HWY_HAVE_FLOAT16 && HWY_HAVE_SCALAR_F16_TYPE + using F16VecLaneT = hwy::float16_t::Native; +#else + using F16VecLaneT = uint16_t; +#endif + using RawVecLaneT = If(), F16VecLaneT, + If(), uint16_t, T>>; + + // Suppress the -Wignored-attributes warning that is emitted by + // RemoveCvRef with GCC + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4649, ignored "-Wignored-attributes") + typedef RawVecLaneT GccRawVec + __attribute__((__vector_size__(sizeof(RemoveCvRef)))); + HWY_DIAGNOSTICS(pop) + + return IsConstantRawX86Vec(num_of_lanes_tag, + reinterpret_cast(v.raw)); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED bool IsConstantX86VecForF2IConv(V v) { + constexpr size_t kNumOfLanesInRawSrcVec = + HWY_MAX(HWY_MAX_LANES_V(V), 16 / sizeof(TFromV)); + constexpr size_t kNumOfLanesInRawResultVec = + HWY_MAX(HWY_MAX_LANES_V(V), 16 / sizeof(TTo)); + constexpr size_t kNumOfLanesToCheck = + HWY_MIN(kNumOfLanesInRawSrcVec, kNumOfLanesInRawResultVec); + + return IsConstantX86Vec(hwy::SizeTag(), v); +} + +} // namespace detail +#endif // HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + +// ================================================== LOGICAL + +// ------------------------------ And + +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{ + _mm_and_si128(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + return Vec128{_mm_and_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 And(Vec128 a, Vec128 b) { + return Vec128{_mm_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. 
+template +HWY_API Vec128 AndNot(Vec128 not_mask, Vec128 mask) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm_andnot_si128( + BitCast(du, not_mask).raw, BitCast(du, mask).raw)}); +} +template +HWY_API Vec128 AndNot(Vec128 not_mask, + Vec128 mask) { + return Vec128{_mm_andnot_ps(not_mask.raw, mask.raw)}; +} +template +HWY_API Vec128 AndNot(Vec128 not_mask, + Vec128 mask) { + return Vec128{_mm_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{ + _mm_or_si128(BitCast(du, a).raw, BitCast(du, b).raw)}); +} + +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + return Vec128{_mm_or_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Or(Vec128 a, Vec128 b) { + return Vec128{_mm_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{ + _mm_xor_si128(BitCast(du, a).raw, BitCast(du, b).raw)}); +} + +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + return Vec128{_mm_xor_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Xor(Vec128 a, Vec128 b) { + return Vec128{_mm_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ TernaryLogic + +#undef HWY_X86_HAVE_TERNARY_LOGIC +#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN +#define HWY_X86_HAVE_TERNARY_LOGIC 1 +#else +#define HWY_X86_HAVE_TERNARY_LOGIC 0 +#endif + +#if HWY_X86_HAVE_TERNARY_LOGIC +namespace detail { + +// Forward-declare the per-target implementations. +template +struct TernaryLogicImpl; + +// Interface called from all targets. Without this, the compiler would only +// examine one of the overloads, because each is templated on kTernLogOp. 
+template +HWY_INLINE V TernaryLogic(V a, V b, V c) { + return TernaryLogicImpl()(a, b, c); +} + +// Per-target partial specialization. +template +struct TernaryLogicImpl { + template + HWY_INLINE V operator()(V a, V b, V c) const { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = _mm_ternarylogic_epi64( + BitCast(du, a).raw, BitCast(du, b).raw, BitCast(du, c).raw, kTernLogOp); + return BitCast(d, VU{ret}); + } +}; + +} // namespace detail +#endif // HWY_X86_HAVE_TERNARY_LOGIC + +// ------------------------------ Not +template // generic for all vector lengths +HWY_API V Not(const V v) { +#if HWY_X86_HAVE_TERNARY_LOGIC + return detail::TernaryLogic<0x55>(v, v, v); +#else + const DFromV d; + const RebindToSigned di; + return Xor(v, BitCast(d, Set(di, -1))); +#endif +} + +#if HWY_X86_HAVE_TERNARY_LOGIC + +// ------------------------------ Xor3 + +#ifdef HWY_NATIVE_XOR3 +#undef HWY_NATIVE_XOR3 +#else +#define HWY_NATIVE_XOR3 +#endif + +template // generic for all vector lengths +HWY_API V Xor3(V x1, V x2, V x3) { + return detail::TernaryLogic<0x96>(x1, x2, x3); +} + +// ------------------------------ XorAndNot + +#ifdef HWY_NATIVE_BCAX +#undef HWY_NATIVE_BCAX +#else +#define HWY_NATIVE_BCAX +#endif + +template // generic for all vector lengths +HWY_API V XorAndNot(V x, V a1, V a2) { + return detail::TernaryLogic<0xD2>(x, a1, a2); +} + +#endif // HWY_X86_HAVE_TERNARY_LOGIC + +// ------------------------------ Or3 +template // generic for all vector lengths +HWY_API V Or3(V o1, V o2, V o3) { +#if HWY_X86_HAVE_TERNARY_LOGIC + return detail::TernaryLogic<0xFE>(o1, o2, o3); +#else + return Or(o1, Or(o2, o3)); +#endif +} + +// ------------------------------ OrAnd +template // generic for all vector lengths +HWY_API V OrAnd(V o, V a1, V a2) { +#if HWY_X86_HAVE_TERNARY_LOGIC + return detail::TernaryLogic<0xF8>(o, a1, a2); +#else + return Or(o, And(a1, a2)); +#endif +} + +// ------------------------------ IfVecThenElse +template // 
generic for all vector lengths +HWY_API V IfVecThenElse(V mask, V yes, V no) { +#if HWY_X86_HAVE_TERNARY_LOGIC + return detail::TernaryLogic<0xCA>(mask, yes, no); +#else + return IfThenElse(MaskFromVec(mask), yes, no); +#endif +} + +// ------------------------------ BitwiseIfThenElse +#if HWY_X86_HAVE_TERNARY_LOGIC + +#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE +#else +#define HWY_NATIVE_BITWISE_IF_THEN_ELSE +#endif + +template // generic for all vector lengths +HWY_API V BitwiseIfThenElse(V mask, V yes, V no) { + return IfVecThenElse(mask, yes, no); +} + +#endif // HWY_X86_HAVE_TERNARY_LOGIC + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec128 operator&(const Vec128 a, const Vec128 b) { + return And(a, b); +} + +template +HWY_API Vec128 operator|(const Vec128 a, const Vec128 b) { + return Or(a, b); +} + +template +HWY_API Vec128 operator^(const Vec128 a, const Vec128 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require BITALG, 32/64 require VPOPCNTDQ. 
+#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<1> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi8(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<2> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi16(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<4> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi32(v.raw)}; +} +template +HWY_INLINE Vec128 PopulationCount(hwy::SizeTag<8> /* tag */, + Vec128 v) { + return Vec128{_mm_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 PopulationCount(Vec128 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ================================================== SIGN + +// ------------------------------ Neg + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec128 Neg(hwy::FloatTag /*tag*/, const Vec128 v) { + return Xor(v, SignBit(DFromV())); +} + +template +HWY_INLINE Vec128 Neg(hwy::SpecialTag /*tag*/, const Vec128 v) { + return Xor(v, SignBit(DFromV())); +} + +template +HWY_INLINE Vec128 Neg(hwy::SignedTag /*tag*/, const Vec128 v) { + return Zero(DFromV()) - v; +} + +} // namespace detail + +template +HWY_INLINE Vec128 Neg(const Vec128 v) { + return detail::Neg(hwy::TypeTag(), v); +} + +// ------------------------------ Floating-point Abs +// Generic for all vector lengths +template )> +HWY_API V Abs(V v) { + const DFromV d; + const RebindToSigned di; + using TI = TFromD; + return v & BitCast(d, Set(di, static_cast(~SignMask()))); +} + +// ------------------------------ CopySign +// Generic for all vector lengths. 
+template +HWY_API V CopySign(const V magn, const V sign) { + static_assert(IsFloat>(), "Only makes sense for floating-point"); + + const DFromV d; + const auto msb = SignBit(d); + + // Truth table for msb, magn, sign | bitwise msb ? sign : mag + // 0 0 0 | 0 + // 0 0 1 | 0 + // 0 1 0 | 1 + // 0 1 1 | 1 + // 1 0 0 | 0 + // 1 0 1 | 1 + // 1 1 0 | 0 + // 1 1 1 | 1 + return BitwiseIfThenElse(msb, sign, magn); +} + +// ------------------------------ CopySignToAbs +// Generic for all vector lengths. +template +HWY_API V CopySignToAbs(const V abs, const V sign) { + const DFromV d; + return OrAnd(abs, SignBit(d), sign); +} + +// ================================================== MASK + +#if HWY_TARGET <= HWY_AVX3 +// ------------------------------ MaskFromVec + +namespace detail { + +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<1> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi8_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<2> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi16_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<4> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi32_mask(v.raw)}; +} +template +HWY_INLINE Mask128 MaskFromVec(hwy::SizeTag<8> /*tag*/, + const Vec128 v) { + return Mask128{_mm_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return detail::MaskFromVec(hwy::SizeTag(), v); +} +// There do not seem to be native floating-point versions of these instructions. 
+#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; +} +#endif +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; +} +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + const RebindToSigned> di; + return Mask128{MaskFromVec(BitCast(di, v)).raw}; +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +// ------------------------------ MaskFalse (MFromD) + +#ifdef HWY_NATIVE_MASK_FALSE +#undef HWY_NATIVE_MASK_FALSE +#else +#define HWY_NATIVE_MASK_FALSE +#endif + +// Generic for all vector lengths +template +HWY_API MFromD MaskFalse(D /*d*/) { + return MFromD{static_cast().raw)>(0)}; +} + +// ------------------------------ SetMask +#ifdef HWY_NATIVE_SET_MASK +#undef HWY_NATIVE_SET_MASK +#else +#define HWY_NATIVE_SET_MASK +#endif + +template +HWY_API MFromD SetMask(D /*d*/, bool val) { + constexpr uint64_t kMask = (HWY_MAX_LANES_D(D) < 64) + ? 
((1ULL << (HWY_MAX_LANES_D(D) & 63)) - 1ULL) + : LimitsMax(); + + return MFromD{static_cast().raw)>( + static_cast(-static_cast(val)) & kMask)}; +} + +// ------------------------------ IsNegative (MFromD) +#ifdef HWY_NATIVE_IS_NEGATIVE +#undef HWY_NATIVE_IS_NEGATIVE +#else +#define HWY_NATIVE_IS_NEGATIVE +#endif + +// Generic for all vector lengths +template +HWY_API MFromD> IsNegative(V v) { + return MaskFromVec(v); +} + +// ------------------------------ PromoteMaskTo (MFromD) + +#ifdef HWY_NATIVE_PROMOTE_MASK_TO +#undef HWY_NATIVE_PROMOTE_MASK_TO +#else +#define HWY_NATIVE_PROMOTE_MASK_TO +#endif + +// AVX3 PromoteMaskTo is generic for all vector lengths +template )), + class DFrom_2 = Rebind, DTo>, + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD PromoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, + MFromD m) { + return MFromD{static_cast().raw)>(m.raw)}; +} + +// ------------------------------ DemoteMaskTo (MFromD) + +#ifdef HWY_NATIVE_DEMOTE_MASK_TO +#undef HWY_NATIVE_DEMOTE_MASK_TO +#else +#define HWY_NATIVE_DEMOTE_MASK_TO +#endif + +// AVX3 DemoteMaskTo is generic for all vector lengths +template ) - 1), + class DFrom_2 = Rebind, DTo>, + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, + MFromD m) { + return MFromD{static_cast().raw)>(m.raw)}; +} + +// ------------------------------ CombineMasks (MFromD) + +#ifdef HWY_NATIVE_COMBINE_MASKS +#undef HWY_NATIVE_COMBINE_MASKS +#else +#define HWY_NATIVE_COMBINE_MASKS +#endif + +// For Clang and GCC, mask intrinsics (KORTEST) weren't added until recently. 
+#if !defined(HWY_COMPILER_HAS_MASK_INTRINSICS) +#if HWY_COMPILER_MSVC != 0 || HWY_COMPILER_GCC_ACTUAL >= 700 || \ + HWY_COMPILER_CLANG >= 800 +#define HWY_COMPILER_HAS_MASK_INTRINSICS 1 +#else +#define HWY_COMPILER_HAS_MASK_INTRINSICS 0 +#endif +#endif // HWY_COMPILER_HAS_MASK_INTRINSICS + +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask8 combined_mask = _kor_mask8( + _kshiftli_mask8(static_cast<__mmask8>(hi.raw), 1), + _kand_mask8(static_cast<__mmask8>(lo.raw), static_cast<__mmask8>(1))); +#else + const auto combined_mask = + (static_cast(hi.raw) << 1) | (lo.raw & 1); +#endif + + return MFromD{static_cast().raw)>(combined_mask)}; +} + +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask8 combined_mask = _kor_mask8( + _kshiftli_mask8(static_cast<__mmask8>(hi.raw), 2), + _kand_mask8(static_cast<__mmask8>(lo.raw), static_cast<__mmask8>(3))); +#else + const auto combined_mask = + (static_cast(hi.raw) << 2) | (lo.raw & 3); +#endif + + return MFromD{static_cast().raw)>(combined_mask)}; +} + +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask8 combined_mask = _kor_mask8( + _kshiftli_mask8(static_cast<__mmask8>(hi.raw), 4), + _kand_mask8(static_cast<__mmask8>(lo.raw), static_cast<__mmask8>(15))); +#else + const auto combined_mask = + (static_cast(hi.raw) << 4) | (lo.raw & 15u); +#endif + + return MFromD{static_cast().raw)>(combined_mask)}; +} + +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask16 combined_mask = _mm512_kunpackb( + static_cast<__mmask16>(hi.raw), static_cast<__mmask16>(lo.raw)); +#else + const auto combined_mask = + ((static_cast(hi.raw) << 8) | (lo.raw & 0xFFu)); +#endif + + return MFromD{static_cast().raw)>(combined_mask)}; +} + +// 
------------------------------ LowerHalfOfMask (MFromD) + +#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK +#undef HWY_NATIVE_LOWER_HALF_OF_MASK +#else +#define HWY_NATIVE_LOWER_HALF_OF_MASK +#endif + +// Generic for all vector lengths +template +HWY_API MFromD LowerHalfOfMask(D d, MFromD> m) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumOfBitsInRawMask = sizeof(RawM) * 8; + + MFromD result_mask{static_cast(m.raw)}; + + if (kN < kNumOfBitsInRawMask) { + result_mask = + And(result_mask, MFromD{static_cast((1ULL << kN) - 1)}); + } + + return result_mask; +} + +// ------------------------------ UpperHalfOfMask (MFromD) + +#ifdef HWY_NATIVE_UPPER_HALF_OF_MASK +#undef HWY_NATIVE_UPPER_HALF_OF_MASK +#else +#define HWY_NATIVE_UPPER_HALF_OF_MASK +#endif + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = _kshiftri_mask8(static_cast<__mmask8>(m.raw), 1); +#else + const auto shifted_mask = static_cast(m.raw) >> 1; +#endif + + return MFromD{static_cast().raw)>(shifted_mask)}; +} + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = _kshiftri_mask8(static_cast<__mmask8>(m.raw), 2); +#else + const auto shifted_mask = static_cast(m.raw) >> 2; +#endif + + return MFromD{static_cast().raw)>(shifted_mask)}; +} + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = _kshiftri_mask8(static_cast<__mmask8>(m.raw), 4); +#else + const auto shifted_mask = static_cast(m.raw) >> 4; +#endif + + return MFromD{static_cast().raw)>(shifted_mask)}; +} + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = _kshiftri_mask16(static_cast<__mmask16>(m.raw), 8); +#else + const auto shifted_mask = static_cast(m.raw) >> 8; +#endif + + return 
MFromD{static_cast().raw)>(shifted_mask)}; +} + +// ------------------------------ OrderedDemote2MasksTo (MFromD, CombineMasks) + +#ifdef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#undef HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#else +#define HWY_NATIVE_ORDERED_DEMOTE_2_MASKS_TO +#endif + +// Generic for all vector lengths +template ) / 2), + class DTo_2 = Repartition, DFrom>, + hwy::EnableIf, MFromD>()>* = nullptr> +HWY_API MFromD OrderedDemote2MasksTo(DTo d_to, DFrom /*d_from*/, + MFromD a, MFromD b) { + using MH = MFromD>; + using RawMH = decltype(MH().raw); + + return CombineMasks(d_to, MH{static_cast(b.raw)}, + MH{static_cast(a.raw)}); +} + +// ------------------------------ Slide mask up/down +#ifdef HWY_NATIVE_SLIDE_MASK +#undef HWY_NATIVE_SLIDE_MASK +#else +#define HWY_NATIVE_SLIDE_MASK +#endif + +template +HWY_API MFromD SlideMask1Up(D d, MFromD m) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr unsigned kValidLanesMask = (1u << kN) - 1u; + +#if HWY_COMPILER_HAS_MASK_INTRINSICS + MFromD result_mask{ + static_cast(_kshiftli_mask8(static_cast<__mmask8>(m.raw), 1))}; + + if (kN < 8) { + result_mask = + And(result_mask, MFromD{static_cast(kValidLanesMask)}); + } +#else + MFromD result_mask{ + static_cast((static_cast(m.raw) << 1) & kValidLanesMask)}; +#endif + + return result_mask; +} + +template +HWY_API MFromD SlideMask1Up(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftli_mask16(static_cast<__mmask16>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) << 1)}; +#endif +} + +template +HWY_API MFromD SlideMask1Down(D d, MFromD m) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr unsigned kValidLanesMask = (1u << kN) - 1u; + +#if HWY_COMPILER_HAS_MASK_INTRINSICS + if (kN < 8) { + m = And(m, MFromD{static_cast(kValidLanesMask)}); + } + + return MFromD{ + 
static_cast(_kshiftri_mask8(static_cast<__mmask8>(m.raw), 1))}; +#else + return MFromD{ + static_cast((static_cast(m.raw) & kValidLanesMask) >> 1)}; +#endif +} + +template +HWY_API MFromD SlideMask1Down(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftri_mask16(static_cast<__mmask16>(m.raw), 1))}; +#else + return MFromD{ + static_cast((static_cast(m.raw) & 0xFFFFu) >> 1)}; +#endif +} + +// Generic for all vector lengths +template +HWY_API MFromD SlideMaskUpLanes(D d, MFromD m, size_t amt) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr uint64_t kValidLanesMask = + static_cast(((kN < 64) ? (1ULL << kN) : 0ULL) - 1ULL); + + return MFromD{static_cast( + (static_cast(m.raw) << (amt & 63)) & kValidLanesMask)}; +} + +// Generic for all vector lengths +template +HWY_API MFromD SlideMaskDownLanes(D d, MFromD m, size_t amt) { + using RawM = decltype(MFromD().raw); + constexpr size_t kN = MaxLanes(d); + constexpr uint64_t kValidLanesMask = + static_cast(((kN < 64) ? 
(1ULL << kN) : 0ULL) - 1ULL); + + return MFromD{static_cast( + (static_cast(m.raw) & kValidLanesMask) >> (amt & 63))}; +} + +// ------------------------------ VecFromMask + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi8(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi16(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi32(v.raw)}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_movm_epi64(v.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_ph(_mm_movm_epi16(v.raw))}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_ps(_mm_movm_epi32(v.raw))}; +} + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{_mm_castsi128_pd(_mm_movm_epi64(v.raw))}; +} + +// Generic for all vector lengths. 
+template +HWY_API VFromD VecFromMask(D /* tag */, MFromD v) { + return VecFromMask(v); +} + +// ------------------------------ RebindMask (MaskFromVec) + +template +HWY_API MFromD RebindMask(DTo /* tag */, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + return MFromD{m.raw}; +} + +// ------------------------------ IfThenElse + +namespace detail { + +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_epi8(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_epi16(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_epi32(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElse(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_epi64(mask.raw, no.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_ph(mask.raw, no.raw, yes.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// Generic for all vector lengths. 
+template , HWY_X86_IF_EMULATED_D(D)> +HWY_API V IfThenElse(MFromD mask, V yes, V no) { + const RebindToUnsigned du; + return BitCast( + D(), IfThenElse(RebindMask(du, mask), BitCast(du, yes), BitCast(du, no))); +} + +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, Vec128 no) { + return Vec128{_mm_mask_blend_ps(mask.raw, no.raw, yes.raw)}; +} + +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, + Vec128 no) { + return Vec128{_mm_mask_blend_pd(mask.raw, no.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec128 IfThenElseZero(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 yes) { + return Vec128{_mm_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +} + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, + Vec128 yes) { + return Vec128{_mm_maskz_mov_ps(mask.raw, yes.raw)}; +} + +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, + Vec128 yes) { + return Vec128{_mm_maskz_mov_pd(mask.raw, yes.raw)}; +} + +// Generic for all vector lengths. 
+template , HWY_IF_SPECIAL_FLOAT_D(D)> +HWY_API V IfThenElseZero(MFromD mask, V yes) { + const RebindToUnsigned du; + return BitCast(D(), IfThenElseZero(RebindMask(du, mask), BitCast(du, yes))); +} + +namespace detail { + +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<1> /* tag */, + Mask128 mask, Vec128 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. + return Vec128{_mm_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<2> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<4> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec128 IfThenZeroElse(hwy::SizeTag<8> /* tag */, + Mask128 mask, Vec128 no) { + return Vec128{_mm_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); +} + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, + Vec128 no) { + return Vec128{_mm_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} + +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, + Vec128 no) { + return Vec128{_mm_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +// Generic for all vector lengths. 
+template , HWY_IF_SPECIAL_FLOAT_D(D)> +HWY_API V IfThenZeroElse(MFromD mask, V no) { + const RebindToUnsigned du; + return BitCast(D(), IfThenZeroElse(RebindMask(du, mask), BitCast(du, no))); +} + +// ------------------------------ Mask logical + +namespace detail { + +template +HWY_INLINE Mask128 And(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 And(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kand_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; 
+#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask128 AndNot(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 Or(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Or(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<1> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<2> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<4> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + 
return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask128 Xor(hwy::SizeTag<8> /*tag*/, const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} + +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; +#endif +} +template +HWY_INLINE Mask128 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask128 a, + const Mask128 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0x3)}; +#else + return Mask128{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0x3)}; +#endif +} + +// UnmaskedNot returns ~m.raw without zeroing out any invalid bits +template +HWY_INLINE Mask128 UnmaskedNot(const Mask128 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask128{static_cast<__mmask16>(_knot_mask16(m.raw))}; +#else + return Mask128{static_cast<__mmask16>(~m.raw)}; +#endif +} + +template +HWY_INLINE Mask128 UnmaskedNot(const Mask128 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return 
Mask128{static_cast<__mmask8>(_knot_mask8(m.raw))}; +#else + return Mask128{static_cast<__mmask8>(~m.raw)}; +#endif +} + +template +HWY_INLINE Mask128 Not(hwy::SizeTag<1> /*tag*/, const Mask128 m) { + // sizeof(T) == 1 and N == 16: simply return ~m as all 16 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE Mask128 Not(hwy::SizeTag<1> /*tag*/, const Mask128 m) { + // sizeof(T) == 1 and N <= 8: need to zero out the upper bits of ~m as there + // are fewer than 16 valid bits in m + + // Return (~m) & ((1ull << N) - 1) + return AndNot(hwy::SizeTag<1>(), m, Mask128::FromBits((1ull << N) - 1)); +} +template +HWY_INLINE Mask128 Not(hwy::SizeTag<2> /*tag*/, const Mask128 m) { + // sizeof(T) == 2 and N == 8: simply return ~m as all 8 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE Mask128 Not(hwy::SizeTag<2> /*tag*/, const Mask128 m) { + // sizeof(T) == 2 and N <= 4: need to zero out the upper bits of ~m as there + // are fewer than 8 valid bits in m + + // Return (~m) & ((1ull << N) - 1) + return AndNot(hwy::SizeTag<2>(), m, Mask128::FromBits((1ull << N) - 1)); +} +template +HWY_INLINE Mask128 Not(hwy::SizeTag<4> /*tag*/, const Mask128 m) { + // sizeof(T) == 4: need to zero out the upper bits of ~m as there are at most + // 4 valid bits in m + + // Return (~m) & ((1ull << N) - 1) + return AndNot(hwy::SizeTag<4>(), m, Mask128::FromBits((1ull << N) - 1)); +} +template +HWY_INLINE Mask128 Not(hwy::SizeTag<8> /*tag*/, const Mask128 m) { + // sizeof(T) == 8: need to zero out the upper bits of ~m as there are at most + // 2 valid bits in m + + // Return (~m) & ((1ull << N) - 1) + return AndNot(hwy::SizeTag<8>(), m, Mask128::FromBits((1ull << N) - 1)); +} + +} // namespace detail + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + return detail::And(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + return detail::AndNot(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 
Or(const Mask128 a, Mask128 b) { + return detail::Or(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + return detail::Xor(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask128 Not(const Mask128 m) { + // Flip only the valid bits + return detail::Not(hwy::SizeTag(), m); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +} + +#else // AVX2 or below + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask128 MaskFromVec(const Vec128 v) { + return Mask128{v.raw}; +} + +template +using MFromD = decltype(MaskFromVec(VFromD())); + +template +HWY_API Vec128 VecFromMask(const Mask128 v) { + return Vec128{v.raw}; +} + +// Generic for all vector lengths. +template +HWY_API VFromD VecFromMask(D /* tag */, MFromD v) { + return VecFromMask(v); +} + +#if HWY_TARGET >= HWY_SSSE3 + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + const auto vmask = VecFromMask(DFromV(), mask); + return Or(And(vmask, yes), AndNot(vmask, no)); +} + +#else // HWY_TARGET < HWY_SSSE3 + +// mask ? yes : no +template +HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, + Vec128 no) { + return Vec128{_mm_blendv_epi8(no.raw, yes.raw, mask.raw)}; +} +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, Vec128 no) { + return Vec128{_mm_blendv_ps(no.raw, yes.raw, mask.raw)}; +} +template +HWY_API Vec128 IfThenElse(Mask128 mask, + Vec128 yes, + Vec128 no) { + return Vec128{_mm_blendv_pd(no.raw, yes.raw, mask.raw)}; +} + +#endif // HWY_TARGET >= HWY_SSSE3 + +// mask ? yes : 0 +template +HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { + return yes & VecFromMask(DFromV(), mask); +} + +// mask ? 
0 : no +template +HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { + return AndNot(VecFromMask(DFromV(), mask), no); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask128 Not(const Mask128 m) { + const Simd d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask128 And(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 AndNot(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Or(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 Xor(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask128 ExclusiveNeither(const Mask128 a, Mask128 b) { + const Simd d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// MaskedTernaryLogic depends on MFromD. +#if HWY_X86_HAVE_TERNARY_LOGIC +namespace detail { + +// Forward-declare implementation. +template +struct MaskedTernaryLogicImpl; + +// Same as TernaryLogic, but with writemask. If !mask, returns a. 
+template +HWY_INLINE V MaskedTernaryLogic(MFromD> mask, V a, V b, V c) { + return MaskedTernaryLogicImpl()(mask, a, b, c); +} + +template +struct MaskedTernaryLogicImpl { + template , HWY_IF_T_SIZE_D(D, 4)> + HWY_INLINE V operator()(MFromD mask, V a, V b, V c) const { + const D d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = + _mm_mask_ternarylogic_epi32(a.raw, mask.raw, b.raw, c.raw, kTernLogOp); + return BitCast(d, VU{ret}); + } + template , HWY_IF_T_SIZE_D(D, 8)> + HWY_INLINE V operator()(MFromD mask, V a, V b, V c) const { + const D d; + const RebindToUnsigned du; + using VU = VFromD; + const __m128i ret = + _mm_mask_ternarylogic_epi64(a.raw, mask.raw, b.raw, c.raw, kTernLogOp); + return BitCast(d, VU{ret}); + } +}; + +} // namespace detail +#endif // HWY_X86_HAVE_TERNARY_LOGIC + +// ------------------------------ ShiftLeft + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi32(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + return Vec128{_mm_slli_epi64(v.raw, kBits)}; +} + +#if HWY_TARGET <= HWY_AVX3_DL + +namespace detail { +template +HWY_API Vec128 GaloisAffine( + Vec128 v, VFromD>> matrix) { + return Vec128{_mm_gf2p8affine_epi64_epi8(v.raw, matrix.raw, 0)}; +} +} // namespace detail + +#else // HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec128 ShiftLeft(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ShiftLeft(Vec128>{v.raw}).raw}; + return kBits == 1 + ? 
(v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// ------------------------------ ShiftRight + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi32(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi16(v.raw, kBits)}; +} +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + return Vec128{_mm_srai_epi32(v.raw, kBits)}; +} + +#if HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRight(Vec128{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// i64 is implemented after BroadcastSignBit. + +// ================================================== MEMORY (1) + +// Clang static analysis claims the memory immediately after a partial vector +// store is uninitialized, and also flags the input to partial loads (at least +// for loadl_pd) as "garbage". Since 2025-07, MSAN began raising errors. We +// work around this by using CopyBytes instead of intrinsics, but only for MSAN +// and static analyzer builds to avoid potentially bad code generation. +// Unfortunately __clang_analyzer__ was not defined for clang-tidy prior to v7. 
+#ifndef HWY_SAFE_PARTIAL_LOAD_STORE +#if HWY_IS_MSAN || (defined(__clang_analyzer__) || \ + (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700)) +#define HWY_SAFE_PARTIAL_LOAD_STORE 1 +#else +#define HWY_SAFE_PARTIAL_LOAD_STORE 0 +#endif +#endif // HWY_SAFE_PARTIAL_LOAD_STORE + +// ------------------------------ Load + +template +HWY_API VFromD Load(D /* tag */, const TFromD* HWY_RESTRICT aligned) { + return VFromD{_mm_load_si128(reinterpret_cast(aligned))}; +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Load(D, const float16_t* HWY_RESTRICT aligned) { + return Vec128{_mm_load_ph(aligned)}; +} +#endif // HWY_HAVE_FLOAT16 +// Generic for all vector lengths greater than or equal to 16 bytes. +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT aligned) { + const RebindToUnsigned du; + return BitCast(d, Load(du, detail::U16LanePointer(aligned))); +} +template +HWY_API Vec128 Load(D /* tag */, const float* HWY_RESTRICT aligned) { + return Vec128{_mm_load_ps(aligned)}; +} +template +HWY_API Vec128 Load(D /* tag */, const double* HWY_RESTRICT aligned) { + return Vec128{_mm_load_pd(aligned)}; +} + +template +HWY_API VFromD LoadU(D /* tag */, const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_loadu_si128(reinterpret_cast(p))}; +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 LoadU(D, const float16_t* HWY_RESTRICT p) { + return Vec128{_mm_loadu_ph(p)}; +} +#endif // HWY_HAVE_FLOAT16 +// Generic for all vector lengths greater than or equal to 16 bytes. 
+template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + return BitCast(d, LoadU(du, detail::U16LanePointer(p))); +} +template +HWY_API Vec128 LoadU(D /* tag */, const float* HWY_RESTRICT p) { + return Vec128{_mm_loadu_ps(p)}; +} +template +HWY_API Vec128 LoadU(D /* tag */, const double* HWY_RESTRICT p) { + return Vec128{_mm_loadu_pd(p)}; +} + +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128i v = _mm_setzero_si128(); + CopyBytes<8>(p, &v); // not same size +#else + const __m128i v = _mm_loadl_epi64(reinterpret_cast(p)); +#endif + return BitCast(d, VFromD{v}); +} + +template +HWY_API Vec64 Load(D /* tag */, const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<8>(p, &v); // not same size + return Vec64{v}; +#else + const __m128 hi = _mm_setzero_ps(); + return Vec64{_mm_loadl_pi(hi, reinterpret_cast(p))}; +#endif +} + +template +HWY_API Vec64 Load(D /* tag */, const double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128d v = _mm_setzero_pd(); + CopyBytes<8>(p, &v); // not same size + return Vec64{v}; +#else + return Vec64{_mm_load_sd(p)}; +#endif +} + +template +HWY_API Vec32 Load(D /* tag */, const float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128 v = _mm_setzero_ps(); + CopyBytes<4>(p, &v); // not same size + return Vec32{v}; +#else + return Vec32{_mm_load_ss(p)}; +#endif +} + +// Any <= 32 bit except +template +HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + // Clang ArgumentPromotionPass seems to break this code. We can unpoison + // before SetTableIndices -> LoadU -> Load and the memory is poisoned again. 
+ detail::MaybeUnpoison(p, Lanes(d)); + +#if HWY_SAFE_PARTIAL_LOAD_STORE + __m128i v = Zero(Full128>()).raw; + CopyBytes(p, &v); // not same size as VFromD +#else + int32_t bits = 0; + CopyBytes(p, &bits); // not same size as VFromD + const __m128i v = _mm_cvtsi32_si128(bits); +#endif + return BitCast(d, VFromD{v}); +} + +// For < 128 bit, LoadU == Load. +template +HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { + return Load(d, p); +} + +// 128-bit SIMD => nothing to duplicate, same as an unaligned load. +template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + return LoadU(d, p); +} + +// ------------------------------ Store + +template +HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { + _mm_store_si128(reinterpret_cast<__m128i*>(aligned), v.raw); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void Store(Vec128 v, D, float16_t* HWY_RESTRICT aligned) { + _mm_store_ph(aligned, v.raw); +} +#endif // HWY_HAVE_FLOAT16 +// Generic for all vector lengths greater than or equal to 16 bytes. +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + const RebindToUnsigned du; + Store(BitCast(du, v), du, reinterpret_cast(aligned)); +} +template +HWY_API void Store(Vec128 v, D /* tag */, float* HWY_RESTRICT aligned) { + _mm_store_ps(aligned, v.raw); +} +template +HWY_API void Store(Vec128 v, D /* tag */, + double* HWY_RESTRICT aligned) { + _mm_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(VFromD v, D /* tag */, TFromD* HWY_RESTRICT p) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(p), v.raw); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec128 v, D, float16_t* HWY_RESTRICT p) { + _mm_storeu_ph(p, v.raw); +} +#endif // HWY_HAVE_FLOAT16 +// Generic for all vector lengths greater than or equal to 16 bytes. 
+template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + StoreU(BitCast(du, v), du, reinterpret_cast(p)); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, float* HWY_RESTRICT p) { + _mm_storeu_ps(p, v.raw); +} +template +HWY_API void StoreU(Vec128 v, D /* tag */, double* HWY_RESTRICT p) { + _mm_storeu_pd(p, v.raw); +} + +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + (void)d; + CopyBytes<8>(&v, p); // not same size +#else + const RebindToUnsigned du; // for float16_t + _mm_storel_epi64(reinterpret_cast<__m128i*>(p), BitCast(du, v).raw); +#endif +} +template +HWY_API void Store(Vec64 v, D /* tag */, float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pi(reinterpret_cast<__m64*>(p), v.raw); +#endif +} +template +HWY_API void Store(Vec64 v, D /* tag */, double* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<8>(&v, p); // not same size +#else + _mm_storel_pd(p, v.raw); +#endif +} + +// Any <= 32 bit except +template +HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT p) { + CopyBytes(&v, p); // not same size +} +template +HWY_API void Store(Vec32 v, D /* tag */, float* HWY_RESTRICT p) { +#if HWY_SAFE_PARTIAL_LOAD_STORE + CopyBytes<4>(&v, p); // not same size +#else + _mm_store_ss(p, v.raw); +#endif +} + +// For < 128 bit, StoreU == Store. 
+template +HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { + Store(v, d, p); +} + +// ================================================== SWIZZLE (1) + +// ------------------------------ TableLookupBytes +template +HWY_API Vec128 TableLookupBytes(const Vec128 bytes, + const Vec128 from) { + const DFromV d; + const Repartition du8; + + const DFromV d_bytes; + const Repartition du8_bytes; +#if HWY_TARGET == HWY_SSE2 +#if HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) + typedef uint8_t GccU8RawVectType __attribute__((__vector_size__(16))); + (void)d; + (void)du8; + (void)d_bytes; + (void)du8_bytes; + return Vec128{reinterpret_cast::type>( + __builtin_shuffle(reinterpret_cast(bytes.raw), + reinterpret_cast(from.raw)))}; +#else + const Full128 du8_full; + + alignas(16) uint8_t result_bytes[16]; + alignas(16) uint8_t u8_bytes[16]; + alignas(16) uint8_t from_bytes[16]; + + Store(Vec128{BitCast(du8_bytes, bytes).raw}, du8_full, u8_bytes); + Store(Vec128{BitCast(du8, from).raw}, du8_full, from_bytes); + + for (int i = 0; i < 16; i++) { + result_bytes[i] = u8_bytes[from_bytes[i] & 15]; + } + + return BitCast(d, VFromD{Load(du8_full, result_bytes).raw}); +#endif +#else // SSSE3 or newer + return BitCast( + d, VFromD{_mm_shuffle_epi8(BitCast(du8_bytes, bytes).raw, + BitCast(du8, from).raw)}); +#endif +} + +// ------------------------------ TableLookupBytesOr0 +// For all vector widths; x86 anyway zeroes if >= 0x80 on SSSE3/SSE4/AVX2/AVX3 +template +HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) { +#if HWY_TARGET == HWY_SSE2 + const DFromV d; + const Repartition di8; + + const auto di8_from = BitCast(di8, from); + return BitCast(d, IfThenZeroElse(di8_from < Zero(di8), + TableLookupBytes(bytes, di8_from))); +#else + return TableLookupBytes(bytes, from); +#endif +} + +// ------------------------------ Shuffles (ShiftRight, TableLookupBytes) + +// Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). 
+// Shuffle0321 rotates one lane to the right (the previous least-significant +// lane is now most-significant). These could also be implemented via +// CombineShiftRightBytes but the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(sizeof(T) == 4, "Only for 32-bit lanes"); + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{_mm_shuffle_epi32(v.raw, 0xB1)}; +} +template +HWY_API Vec128 Shuffle2301(const Vec128 v) { + static_assert(N == 2 || N == 4, "Does not make sense for N=1"); + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0xB1)}; +} + +// These are used by generic_ops-inl to implement LoadInterleaved3. As with +// Intel's shuffle* intrinsics and InterleaveLower, the lower half of the output +// comes from the first argument. +namespace detail { + +template +HWY_API Vec32 ShuffleTwo2301(const Vec32 a, const Vec32 b) { + const DFromV d; + const Twice d2; + const auto ba = Combine(d2, b, a); +#if HWY_TARGET == HWY_SSE2 + Vec32 ba_shuffled{ + _mm_shufflelo_epi16(ba.raw, _MM_SHUFFLE(3, 0, 3, 0))}; + return BitCast(d, Or(ShiftLeft<8>(ba_shuffled), ShiftRight<8>(ba_shuffled))); +#else + const RebindToUnsigned d2_u; + const auto shuffle_idx = + BitCast(d2, Dup128VecFromValues(d2_u, 1, 0, 7, 6, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0)); + return Vec32{TableLookupBytes(ba, shuffle_idx).raw}; +#endif +} +template +HWY_API Vec64 ShuffleTwo2301(const Vec64 a, const Vec64 b) { + const DFromV d; + const Twice d2; + const auto ba = Combine(d2, b, a); +#if HWY_TARGET == HWY_SSE2 + Vec64 ba_shuffled{ + _mm_shuffle_epi32(ba.raw, _MM_SHUFFLE(3, 0, 3, 0))}; + return Vec64{ + _mm_shufflelo_epi16(ba_shuffled.raw, _MM_SHUFFLE(2, 3, 0, 1))}; +#else + const RebindToUnsigned d2_u; + const auto shuffle_idx = BitCast( + d2, + Dup128VecFromValues(d2_u, 0x0302, 0x0100, 0x0f0e, 0x0d0c, 0, 0, 0, 0)); + return Vec64{TableLookupBytes(ba, shuffle_idx).raw}; +#endif 
+} +template +HWY_API Vec128 ShuffleTwo2301(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +template +HWY_API Vec32 ShuffleTwo1230(const Vec32 a, const Vec32 b) { + const DFromV d; +#if HWY_TARGET == HWY_SSE2 + const auto zero = Zero(d); + const Rebind di16; + const Vec32 a_shuffled{_mm_shufflelo_epi16( + _mm_unpacklo_epi8(a.raw, zero.raw), _MM_SHUFFLE(3, 0, 3, 0))}; + const Vec32 b_shuffled{_mm_shufflelo_epi16( + _mm_unpacklo_epi8(b.raw, zero.raw), _MM_SHUFFLE(1, 2, 1, 2))}; + const auto ba_shuffled = Combine(di16, b_shuffled, a_shuffled); + return Vec32{_mm_packus_epi16(ba_shuffled.raw, ba_shuffled.raw)}; +#else + const Twice d2; + const auto ba = Combine(d2, b, a); + const RebindToUnsigned d2_u; + const auto shuffle_idx = + BitCast(d2, Dup128VecFromValues(d2_u, 0, 3, 6, 5, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0)); + return Vec32{TableLookupBytes(ba, shuffle_idx).raw}; +#endif +} +template +HWY_API Vec64 ShuffleTwo1230(const Vec64 a, const Vec64 b) { + const DFromV d; +#if HWY_TARGET == HWY_SSE2 + const Vec32 a_shuffled{ + _mm_shufflelo_epi16(a.raw, _MM_SHUFFLE(3, 0, 3, 0))}; + const Vec32 b_shuffled{ + _mm_shufflelo_epi16(b.raw, _MM_SHUFFLE(1, 2, 1, 2))}; + return Combine(d, b_shuffled, a_shuffled); +#else + const Twice d2; + const auto ba = Combine(d2, b, a); + const RebindToUnsigned d2_u; + const auto shuffle_idx = BitCast( + d2, + Dup128VecFromValues(d2_u, 0x0100, 0x0706, 0x0d0c, 0x0b0a, 0, 0, 0, 0)); + return Vec64{TableLookupBytes(ba, shuffle_idx).raw}; +#endif +} +template +HWY_API Vec128 ShuffleTwo1230(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(1, 2, 3, 0); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +template +HWY_API Vec32 ShuffleTwo3012(const Vec32 a, const Vec32 b) { + 
const DFromV d; +#if HWY_TARGET == HWY_SSE2 + const auto zero = Zero(d); + const Rebind di16; + const Vec32 a_shuffled{_mm_shufflelo_epi16( + _mm_unpacklo_epi8(a.raw, zero.raw), _MM_SHUFFLE(1, 2, 1, 2))}; + const Vec32 b_shuffled{_mm_shufflelo_epi16( + _mm_unpacklo_epi8(b.raw, zero.raw), _MM_SHUFFLE(3, 0, 3, 0))}; + const auto ba_shuffled = Combine(di16, b_shuffled, a_shuffled); + return Vec32{_mm_packus_epi16(ba_shuffled.raw, ba_shuffled.raw)}; +#else + const Twice d2; + const auto ba = Combine(d2, b, a); + const RebindToUnsigned d2_u; + const auto shuffle_idx = + BitCast(d2, Dup128VecFromValues(d2_u, 2, 1, 4, 7, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0)); + return Vec32{TableLookupBytes(ba, shuffle_idx).raw}; +#endif +} +template +HWY_API Vec64 ShuffleTwo3012(const Vec64 a, const Vec64 b) { + const DFromV d; +#if HWY_TARGET == HWY_SSE2 + const Vec32 a_shuffled{ + _mm_shufflelo_epi16(a.raw, _MM_SHUFFLE(1, 2, 1, 2))}; + const Vec32 b_shuffled{ + _mm_shufflelo_epi16(b.raw, _MM_SHUFFLE(3, 0, 3, 0))}; + return Combine(d, b_shuffled, a_shuffled); +#else + const Twice d2; + const auto ba = Combine(d2, b, a); + const RebindToUnsigned d2_u; + const auto shuffle_idx = BitCast( + d2, + Dup128VecFromValues(d2_u, 0x0504, 0x0302, 0x0908, 0x0f0e, 0, 0, 0, 0)); + return Vec64{TableLookupBytes(ba, shuffle_idx).raw}; +#endif +} +template +HWY_API Vec128 ShuffleTwo3012(const Vec128 a, const Vec128 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); + return BitCast(d, Vec128{_mm_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle1032(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return 
Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec128 Shuffle01(const Vec128 v) { + return Vec128{_mm_shuffle_pd(v.raw, v.raw, 1)}; +} + +// Rotate right 32 bits +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec128 Shuffle0321(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec128 Shuffle2103(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x93)}; +} + +// Reverse +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec128 Shuffle0123(const Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x1B)}; +} + +// ================================================== COMPARE + +#if HWY_TARGET <= HWY_AVX3 + +// Comparisons set a mask bit to 1 if the condition is true, else 0. 
+ +// ------------------------------ TestBit + +namespace detail { + +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<1> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi8_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<2> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<4> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask128 TestBit(hwy::SizeTag<8> /*tag*/, const Vec128 v, + const Vec128 bit) { + return Mask128{_mm_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template +HWY_API Mask128 TestBit(const Vec128 v, const Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); +} + +// ------------------------------ Equality + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi8_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi16_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi32_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator==(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi8_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi16_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi32_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator!=(const Vec128 a, const Vec128 b) { + return Mask128{_mm_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +// Signed/float < +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmpgt_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator>(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +template +HWY_API Mask128 operator>(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +#if HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epi64_mask(a.raw, b.raw)}; +} + +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epu8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epu16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epu32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask128 
operator>=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_epu64_mask(a.raw, b.raw)}; +} + +#else // AVX2 or below + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. + +template +HWY_API MFromD RebindMask(DTo dto, Mask128 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + const Simd d; + return MaskFromVec(BitCast(dto, VecFromMask(d, m))); +} + +template +HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +// Unsigned +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + const DFromV d64; + const RepartitionToNarrow d32; + const auto cmp32 = VecFromMask(d32, Eq(BitCast(d32, a), BitCast(d32, b))); + const auto cmp64 = cmp32 & Shuffle2301(cmp32); + return MaskFromVec(BitCast(d64, cmp64)); +#else + return Mask128{_mm_cmpeq_epi64(a.raw, b.raw)}; +#endif +} + +// Signed +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi8(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi16(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_epi32(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(const Vec128 a, + const Vec128 b) { + // Same as signed ==; avoid duplicating the SSSE3 version. 
+ const DFromV d; + RebindToUnsigned du; + return RebindMask(d, BitCast(du, a) == BitCast(du, b)); +} + +// Float +template +HWY_API Mask128 operator==(Vec128 a, Vec128 b) { + return Mask128{_mm_cmpeq_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator==(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpeq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Inequality + +// This cannot have T as a template argument, otherwise it is not more +// specialized than rewritten operator== in C++20, leading to compile +// errors: https://gcc.godbolt.org/z/xsrPhPvPT. +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Not(a == b); +} + +template +HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { + return Mask128{_mm_cmpneq_ps(a.raw, b.raw)}; +} +template +HWY_API Mask128 operator!=(Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpneq_pd(a.raw, b.raw)}; +} + +// ------------------------------ Strict inequality + +namespace detail { + +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi8(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi16(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::SignedTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_epi32(a.raw, b.raw)}; +} + +template +HWY_INLINE 
Mask128 Gt(hwy::SignedTag /*tag*/, + const Vec128 a, + const Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + // See https://stackoverflow.com/questions/65166174/: + const DFromV d; + const RepartitionToNarrow d32; + const Vec128 m_eq32{Eq(BitCast(d32, a), BitCast(d32, b)).raw}; + const Vec128 m_gt32{Gt(BitCast(d32, a), BitCast(d32, b)).raw}; + // If a.upper is greater, upper := true. Otherwise, if a.upper == b.upper: + // upper := b-a (unsigned comparison result of lower). Otherwise: upper := 0. + const __m128i upper = OrAnd(m_gt32, m_eq32, Sub(b, a)).raw; + // Duplicate upper to lower half. + return Mask128{_mm_shuffle_epi32(upper, _MM_SHUFFLE(3, 3, 1, 1))}; +#else + return Mask128{_mm_cmpgt_epi64(a.raw, b.raw)}; // SSE4.2 +#endif +} + +template +HWY_INLINE Mask128 Gt(hwy::UnsignedTag /*tag*/, Vec128 a, + Vec128 b) { + const DFromV du; + const RebindToSigned di; + const Vec128 msb = Set(du, (LimitsMax() >> 1) + 1); + const auto sa = BitCast(di, Xor(a, msb)); + const auto sb = BitCast(di, Xor(b, msb)); + return RebindMask(du, Gt(hwy::SignedTag(), sa, sb)); +} + +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_ps(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Gt(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpgt_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template +HWY_INLINE Mask128 operator>(Vec128 a, Vec128 b) { + return detail::Gt(hwy::TypeTag(), a, b); +} + +// ------------------------------ Weak inequality + +namespace detail { +template +HWY_INLINE Mask128 Ge(hwy::SignedTag tag, Vec128 a, + Vec128 b) { + return Not(Gt(tag, b, a)); +} + +template +HWY_INLINE Mask128 Ge(hwy::UnsignedTag tag, Vec128 a, + Vec128 b) { + return Not(Gt(tag, b, a)); +} + +template +HWY_INLINE Mask128 Ge(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return Mask128{_mm_cmpge_ps(a.raw, b.raw)}; +} +template +HWY_INLINE Mask128 Ge(hwy::FloatTag /*tag*/, Vec128 a, + Vec128 b) { + return 
Mask128{_mm_cmpge_pd(a.raw, b.raw)}; +} + +} // namespace detail + +template +HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { + return detail::Ge(hwy::TypeTag(), a, b); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask128 operator<(Vec128 a, Vec128 b) { + return b > a; +} + +template +HWY_API Mask128 operator<=(Vec128 a, Vec128 b) { + return b >= a; +} + +// ------------------------------ Iota (Load) + +namespace detail { + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_epi8( + static_cast(15), static_cast(14), static_cast(13), + static_cast(12), static_cast(11), static_cast(10), + static_cast(9), static_cast(8), static_cast(7), + static_cast(6), static_cast(5), static_cast(4), + static_cast(3), static_cast(2), static_cast(1), + static_cast(0))}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_epi16(int16_t{7}, int16_t{6}, int16_t{5}, int16_t{4}, + int16_t{3}, int16_t{2}, int16_t{1}, + int16_t{0})}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_ph(float16_t{7}, float16_t{6}, float16_t{5}, + float16_t{4}, float16_t{3}, float16_t{2}, + float16_t{1}, float16_t{0})}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{ + _mm_set_epi32(int32_t{3}, int32_t{2}, int32_t{1}, int32_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_epi64x(int64_t{1}, int64_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_ps(3.0f, 2.0f, 1.0f, 0.0f)}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm_set_pd(1.0, 0.0)}; +} + +#if HWY_COMPILER_MSVC +template +static HWY_INLINE V MaskOutVec128Iota(V v) { + const V mask_out_mask{_mm_set_epi32(0, 0, 0, 0xFF)}; + return v & mask_out_mask; +} +template +static HWY_INLINE V MaskOutVec128Iota(V v) { +#if HWY_TARGET <= HWY_SSE4 + return 
V{_mm_blend_epi16(v.raw, _mm_setzero_si128(), 0xFE)}; +#else + const V mask_out_mask{_mm_set_epi32(0, 0, 0, 0xFFFF)}; + return v & mask_out_mask; +#endif +} +template +static HWY_INLINE V MaskOutVec128Iota(V v) { + const DFromV d; + const Repartition df; + using VF = VFromD; + return BitCast(d, VF{_mm_move_ss(_mm_setzero_ps(), BitCast(df, v).raw)}); +} +template +static HWY_INLINE V MaskOutVec128Iota(V v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + return BitCast(d, VU{_mm_move_epi64(BitCast(du, v).raw)}); +} +template +static HWY_INLINE V MaskOutVec128Iota(V v) { + return v; +} +#endif + +} // namespace detail + +template +HWY_API VFromD Iota(D d, const T2 first) { + const auto result_iota = + detail::Iota0(d) + Set(d, ConvertScalarTo>(first)); +#if HWY_COMPILER_MSVC + return detail::MaskOutVec128Iota(result_iota); +#else + return result_iota; +#endif +} + +// ------------------------------ FirstN (Iota, Lt) + +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API M FirstN(D d, size_t num) { + constexpr size_t kN = MaxLanes(d); + // For AVX3, this ensures `num` <= 255 as required by bzhi, which only looks + // at the lower 8 bits; for AVX2 and below, this ensures `num` fits in TI. + num = HWY_MIN(num, kN); +#if HWY_TARGET <= HWY_AVX3 +#if HWY_ARCH_X86_64 + const uint64_t all = (1ull << kN) - 1; + return M::FromBits(_bzhi_u64(all, num)); +#else + const uint32_t all = static_cast((1ull << kN) - 1); + return M::FromBits(_bzhi_u32(all, static_cast(num))); +#endif // HWY_ARCH_X86_64 +#else // HWY_TARGET > HWY_AVX3 + const RebindToSigned di; // Signed comparisons are cheaper. + using TI = TFromD; + return RebindMask(d, detail::Iota0(di) < Set(di, static_cast(num))); +#endif // HWY_TARGET <= HWY_AVX3 +} + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". 
To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{_mm_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm_unpacklo_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{_mm_unpacklo_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { + return Vec128{_mm_unpacklo_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{_mm_unpacklo_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 InterleaveLower(Vec128 a, + Vec128 b) { + return Vec128{_mm_unpacklo_pd(a.raw, b.raw)}; +} + +// Generic for all vector lengths. +template +HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// ================================================== MEMORY (2) + +// ------------------------------ MaskedLoad + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm_maskz_loadu_epi16(m.raw, p)}); +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_epi64(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const float* HWY_RESTRICT p) { + return 
VFromD{_mm_maskz_loadu_ps(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const double* HWY_RESTRICT p) { + return VFromD{_mm_maskz_loadu_pd(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_epi8(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{ + _mm_mask_loadu_epi16(BitCast(du, v).raw, m.raw, p)}); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_epi32(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_epi64(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const float* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_ps(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const double* HWY_RESTRICT p) { + return VFromD{_mm_mask_loadu_pd(v.raw, m.raw, p)}; +} + +#elif HWY_TARGET == HWY_AVX2 + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + auto p_p = reinterpret_cast(p); // NOLINT + return VFromD{_mm_maskload_epi32(p_p, m.raw)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + auto p_p = reinterpret_cast(p); // NOLINT + return VFromD{_mm_maskload_epi64(p_p, m.raw)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, const float* HWY_RESTRICT p) { + const RebindToSigned di; + return VFromD{_mm_maskload_ps(p, BitCast(di, VecFromMask(d, m)).raw)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, const double* HWY_RESTRICT p) { + const RebindToSigned di; + return VFromD{_mm_maskload_pd(p, 
BitCast(di, VecFromMask(d, m)).raw)}; +} + +// There is no maskload_epi8/16, so blend instead. +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +#else // <= SSE4 + +// Avoid maskmov* - its nontemporal 'hint' causes it to bypass caches (slow). +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +#endif + +// ------------------------------ MaskedLoadOr + +#if HWY_TARGET > HWY_AVX3 // else: native + +// Generic for all vector lengths. +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElse(m, LoadU(d, p), v); +} + +#endif // HWY_TARGET > HWY_AVX3 + +// ------------------------------ LoadN (InterleaveLower) + +#if HWY_TARGET <= HWY_AVX2 && !HWY_MEM_OPS_MIGHT_FAULT + +#ifdef HWY_NATIVE_LOAD_N +#undef HWY_NATIVE_LOAD_N +#else +#define HWY_NATIVE_LOAD_N +#endif + +// Generic for all vector lengths. +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, HWY_MAX(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD))> + d_full; + return ResizeBitCast(d, MaskedLoad(FirstN(d_full, num_lanes), d_full, p)); +} + +// Generic for all vector lengths. +template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const FixedTag, HWY_MAX(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD))> + d_full; + return ResizeBitCast(d, MaskedLoadOr(ResizeBitCast(d_full, no), + FirstN(d_full, num_lanes), d_full, p)); +} + +#if HWY_TARGET > HWY_AVX3 +namespace detail { + +// 'Leading' means the part that fits in 32-bit lanes. With 2-byte vectors, +// there are none, so return the remainder (v_trailing). 
+template +HWY_INLINE VFromD AVX2UIF8Or16LoadLeadingN( + VFromD /*load_mask*/, D /*d*/, const TFromD* HWY_RESTRICT /*p*/, + VFromD v_trailing) { + return v_trailing; +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadLeadingNOr( + VFromD /*no*/, VFromD /*load_mask*/, D /*d*/, + const TFromD* HWY_RESTRICT /*p*/, VFromD v_trailing) { + return v_trailing; +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadLeadingN(VFromD load_mask, D d, + const TFromD* HWY_RESTRICT p, + VFromD v_trailing) { + using DI32 = Repartition; + const FixedTag di32_full; + + // ResizeBitCast of load_mask to di32 is okay below if + // d.MaxBytes() < di32.MaxBytes() is true as any lanes of load_mask.raw past + // the first (lowest-index) lanes of load_mask.raw will have already been + // zeroed out by FirstN. + return ResizeBitCast( + d, IfNegativeThenElse( + ResizeBitCast(di32_full, load_mask), + MaskedLoad(MaskFromVec(ResizeBitCast(di32_full, load_mask)), + di32_full, reinterpret_cast(p)), + ResizeBitCast(di32_full, v_trailing))); +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadLeadingNOr(VFromD no, + VFromD load_mask, D d, + const TFromD* HWY_RESTRICT p, + VFromD v_trailing) { + using DI32 = Repartition; + const FixedTag di32_full; + + // ResizeBitCast of load_mask to di32 is okay below if + // d.MaxBytes() < di32.MaxBytes() is true as any lanes of load_mask.raw past + // the first (lowest-index) lanes of load_mask.raw will have already been + // zeroed out by FirstN. + return ResizeBitCast( + d, IfNegativeThenElse( + ResizeBitCast(di32_full, load_mask), + MaskedLoadOr(ResizeBitCast(di32_full, no), + MaskFromVec(ResizeBitCast(di32_full, load_mask)), + di32_full, reinterpret_cast(p)), + ResizeBitCast(di32_full, v_trailing))); +} + +// Single lane: load or default value. +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingN(VFromD /*load_mask*/, D d, + const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + return (num_lanes > 0) ? 
LoadU(d, p) : Zero(d); +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingNOr( + VFromD no, VFromD /*load_mask*/, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + return (num_lanes > 0) ? LoadU(d, p) : no; +} + +// Two lanes: load 1, 2, or default. +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingN(VFromD /*load_mask*/, D d, + const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if (num_lanes > 1) { + return LoadU(d, p); + } else { + const FixedTag, 1> d1; + return (num_lanes == 1) ? ResizeBitCast(d, LoadU(d1, p)) : Zero(d); + } +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingNOr( + VFromD no, VFromD /*load_mask*/, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if (num_lanes > 1) { + return LoadU(d, p); + } else { + if (num_lanes == 0) return no; + // Load one, upper lane is default. + const FixedTag, 1> d1; + return InterleaveLower(ResizeBitCast(d, LoadU(d1, p)), no); + } +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingN(VFromD load_mask, D d, + const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const size_t trailing_n = num_lanes & 3; + if (trailing_n == 0) return Zero(d); + + VFromD v_trailing = And(load_mask, Set(d, p[num_lanes - 1])); + + if ((trailing_n & 2) != 0) { + const Repartition di16; + int16_t i16_bits; + CopyBytes(p + num_lanes - trailing_n, &i16_bits); + v_trailing = BitCast( + d, IfNegativeThenElse(BitCast(di16, load_mask), Set(di16, i16_bits), + BitCast(di16, v_trailing))); + } + + return v_trailing; +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingNOr( + VFromD no, VFromD load_mask, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + const size_t trailing_n = num_lanes & 3; + if (trailing_n == 0) return no; + + VFromD v_trailing = IfVecThenElse(load_mask, Set(d, p[num_lanes - 1]), no); + + if ((trailing_n & 2) != 0) { + const Repartition di16; + int16_t i16_bits; + CopyBytes(p + num_lanes - trailing_n, &i16_bits); + v_trailing = BitCast( + d, IfNegativeThenElse(BitCast(di16, 
load_mask), Set(di16, i16_bits), + BitCast(di16, v_trailing))); + } + + return v_trailing; +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingN(VFromD load_mask, D d, + const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if ((num_lanes & 1) != 0) { + return And(load_mask, Set(d, p[num_lanes - 1])); + } else { + return Zero(d); + } +} + +template +HWY_INLINE VFromD AVX2UIF8Or16LoadTrailingNOr( + VFromD no, VFromD load_mask, D d, const TFromD* HWY_RESTRICT p, + size_t num_lanes) { + if ((num_lanes & 1) != 0) { + return IfVecThenElse(load_mask, Set(d, p[num_lanes - 1]), no); + } else { + return no; + } +} + +} // namespace detail + +// Generic for all vector lengths. +template +HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, size_t N) { + const FixedTag, HWY_MAX(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD))> + d_full; + + const VFromD load_mask = + ResizeBitCast(d, VecFromMask(d_full, FirstN(d_full, N))); + const size_t num_lanes = HWY_MIN(N, HWY_MAX_LANES_D(D)); + const VFromD v_trailing = + detail::AVX2UIF8Or16LoadTrailingN(load_mask, d, p, num_lanes); + +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(num_lanes < (4 / sizeof(TFromD))) && + num_lanes < (4 / sizeof(TFromD))) { + return v_trailing; + } +#endif + + return detail::AVX2UIF8Or16LoadLeadingN(load_mask, d, p, v_trailing); +} + +// Generic for all vector lengths. 
+template +HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, + size_t N) { + const FixedTag, HWY_MAX(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD))> + d_full; + + const VFromD load_mask = + ResizeBitCast(d, VecFromMask(d_full, FirstN(d_full, N))); + const size_t num_lanes = HWY_MIN(N, HWY_MAX_LANES_D(D)); + const VFromD v_trailing = + detail::AVX2UIF8Or16LoadTrailingNOr(no, load_mask, d, p, num_lanes); + +#if HWY_COMPILER_GCC && !HWY_IS_DEBUG_BUILD + if (__builtin_constant_p(num_lanes < (4 / sizeof(TFromD))) && + num_lanes < (4 / sizeof(TFromD))) { + return v_trailing; + } +#endif + + return detail::AVX2UIF8Or16LoadLeadingNOr(no, load_mask, d, p, v_trailing); +} + +#endif // HWY_TARGET > HWY_AVX3 +#endif // HWY_TARGET <= HWY_AVX2 && !HWY_MEM_OPS_MIGHT_FAULT + +// ------------------------------ BlendedStore + +namespace detail { + +// There is no maskload_epi8/16 with which we could safely implement +// BlendedStore. Manual blending is also unsafe because loading a full vector +// that crosses the array end causes asan faults. Resort to scalar code; the +// caller should instead use memcpy, assuming m is FirstN(d, n). +template +HWY_API void ScalarMaskedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToSigned di; // for testing mask if T=bfloat16_t. 
+ using TI = TFromD; + alignas(16) TI buf[MaxLanes(d)]; + alignas(16) TI mask[MaxLanes(d)]; + Store(BitCast(di, v), di, buf); + Store(BitCast(di, VecFromMask(d, m)), di, mask); + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (mask[i]) { + CopySameSize(buf + i, p + i); + } + } +} +} // namespace detail + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm_mask_storeu_epi8(p, m.raw, v.raw); +} +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + _mm_mask_storeu_epi16(reinterpret_cast(p), RebindMask(du, m).raw, + BitCast(du, v).raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_mask_storeu_epi32(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_mask_storeu_epi64(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D, float* HWY_RESTRICT p) { + _mm_mask_storeu_ps(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D, double* HWY_RESTRICT p) { + _mm_mask_storeu_pd(p, m.raw, v.raw); +} + +#elif HWY_TARGET == HWY_AVX2 + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + detail::ScalarMaskedStore(v, m, d, p); +} + +namespace detail { + +template +HWY_INLINE void NativeBlendedStore(V v, M m, TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_maskstore_epi32(pi, m.raw, v.raw); +} + +template +HWY_INLINE void NativeBlendedStore(V v, M m, TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm_maskstore_epi64(pi, m.raw, v.raw); +} + +template +HWY_INLINE void NativeBlendedStore(V v, M m, float* HWY_RESTRICT p) { + _mm_maskstore_ps(p, m.raw, v.raw); 
+} + +template +HWY_INLINE void NativeBlendedStore(V v, M m, double* HWY_RESTRICT p) { + _mm_maskstore_pd(p, m.raw, v.raw); +} + +} // namespace detail + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToSigned di; + // For partial vectors, avoid writing other lanes by zeroing their mask. + if (d.MaxBytes() < 16) { + const Full128> dfull; + const Mask128> mfull{m.raw}; + m = MFromD{And(mfull, FirstN(dfull, MaxLanes(d))).raw}; + } + + // Float/double require, and unsigned ints tolerate, signed int masks. + detail::NativeBlendedStore(v, RebindMask(di, m), p); +} + +#else // <= SSE4 + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + // Avoid maskmov* - its nontemporal 'hint' causes it to bypass caches (slow). + detail::ScalarMaskedStore(v, m, d, p); +} + +#endif // SSE4 + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi64(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_epi64(a.raw, b.raw)}; +} + +// Float 
+#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator+(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(Vec128 a, + Vec128 b) { + return Vec128{_mm_sub_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator-(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ AddSub + +#if HWY_TARGET <= HWY_SSSE3 + +#undef HWY_IF_ADDSUB_V +#define HWY_IF_ADDSUB_V(V) \ + 
HWY_IF_V_SIZE_GT_V( \ + V, ((hwy::IsFloat3264>()) ? 32 : sizeof(TFromV))) + +template +HWY_API Vec128 AddSub(Vec128 a, Vec128 b) { + return Vec128{_mm_addsub_ps(a.raw, b.raw)}; +} +HWY_API Vec128 AddSub(Vec128 a, Vec128 b) { + return Vec128{_mm_addsub_pd(a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_SSSE3 + +// ------------------------------ PairwiseAdd128/PairwiseSub128 + +// Need to use the default implementation of PairwiseAdd128/PairwiseSub128 in +// generic_ops-inl.h for U8/I8/F16/I64/U64 vectors and 64-byte vectors + +#if HWY_TARGET <= HWY_SSSE3 + +#undef HWY_IF_PAIRWISE_ADD_128_D +#undef HWY_IF_PAIRWISE_SUB_128_D +#define HWY_IF_PAIRWISE_ADD_128_D(D) \ + hwy::EnableIf<( \ + HWY_MAX_LANES_D(D) > (32 / sizeof(hwy::HWY_NAMESPACE::TFromD)) || \ + (HWY_MAX_LANES_D(D) > (8 / sizeof(hwy::HWY_NAMESPACE::TFromD)) && \ + !(hwy::IsSameEither, int16_t, \ + uint16_t>() || \ + sizeof(hwy::HWY_NAMESPACE::TFromD) == 4 || \ + hwy::IsSame, double>())))>* = nullptr +#define HWY_IF_PAIRWISE_SUB_128_D(D) HWY_IF_PAIRWISE_ADD_128_D(D) + +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_hadd_epi16(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + const DFromV d; + const RebindToSigned di; + return BitCast(d, Neg(BitCast(di, VFromD{_mm_hsub_epi16(a.raw, b.raw)}))); +} +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_hadd_epi32(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + const DFromV d; + const RebindToSigned di; + return BitCast(d, Neg(BitCast(di, VFromD{_mm_hsub_epi32(a.raw, b.raw)}))); +} +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_hadd_ps(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + return Neg(VFromD{_mm_hsub_ps(a.raw, b.raw)}); +} +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + 
return VFromD{_mm_hadd_pd(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + return Neg(VFromD{_mm_hsub_pd(a.raw, b.raw)}); +} + +#endif // HWY_TARGET <= HWY_SSSE3 + +// ------------------------------ SumsOf8 +template +HWY_API Vec128 SumsOf8(const Vec128 v) { + return Vec128{_mm_sad_epu8(v.raw, _mm_setzero_si128())}; +} + +// Generic for all vector lengths +template )> +HWY_API VFromD>> SumsOf8(V v) { + const DFromV d; + const RebindToUnsigned du; + const Repartition di64; + + // Adjust the values of v to be in the 0..255 range by adding 128 to each lane + // of v (which is the same as an bitwise XOR of each i8 lane by 128) and then + // bitcasting the Xor result to an u8 vector. + const auto v_adj = BitCast(du, Xor(v, SignBit(d))); + + // Need to add -1024 to each i64 lane of the result of the SumsOf8(v_adj) + // operation to account for the adjustment made above. + return BitCast(di64, SumsOf8(v_adj)) + Set(di64, int64_t{-1024}); +} + +#ifdef HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_8_ABS_DIFF +#endif + +template +HWY_API Vec128 SumsOf8AbsDiff(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_sad_epu8(a.raw, b.raw)}; +} + +// Generic for all vector lengths +template )> +HWY_API VFromD>> SumsOf8AbsDiff(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWideX3 di64; + + // Adjust the values of a and b to be in the 0..255 range by adding 128 to + // each lane of a and b (which is the same as an bitwise XOR of each i8 lane + // by 128) and then bitcasting the results of the Xor operations to u8 + // vectors. 
+ const auto i8_msb = SignBit(d); + const auto a_adj = BitCast(du, Xor(a, i8_msb)); + const auto b_adj = BitCast(du, Xor(b, i8_msb)); + + // The result of SumsOf8AbsDiff(a_adj, b_adj) can simply be bitcasted to an + // i64 vector as |(a[i] + 128) - (b[i] + 128)| == |a[i] - b[i]| is true + return BitCast(di64, SumsOf8AbsDiff(a_adj, b_adj)); +} + +// ------------------------------ SumsOf4 +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +template +HWY_INLINE Vec128 SumsOf4( + hwy::UnsignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, + Vec128 v) { + const DFromV d; + + // _mm_maskz_dbsad_epu8 is used below as the odd uint16_t lanes need to be + // zeroed out and the sums of the 4 consecutive lanes are already in the + // even uint16_t lanes of the _mm_maskz_dbsad_epu8 result. + return Vec128{ + _mm_maskz_dbsad_epu8(static_cast<__mmask8>(0x55), v.raw, Zero(d).raw, 0)}; +} + +// detail::SumsOf4 for Vec128 on AVX3 is implemented in x86_512-inl.h + +} // namespace detail +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ SumsOfAdjQuadAbsDiff + +#if HWY_TARGET <= HWY_SSE4 +#ifdef HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_ADJ_QUAD_ABS_DIFF +#endif + +template +HWY_API Vec128 SumsOfAdjQuadAbsDiff( + Vec128 a, Vec128 b) { + static_assert(0 <= kAOffset && kAOffset <= 1, + "kAOffset must be between 0 and 1"); + static_assert(0 <= kBOffset && kBOffset <= 3, + "kBOffset must be between 0 and 3"); + return Vec128{ + _mm_mpsadbw_epu8(a.raw, b.raw, (kAOffset << 2) | kBOffset)}; +} + +// Generic for all vector lengths +template )> +HWY_API VFromD>> SumsOfAdjQuadAbsDiff(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + // Adjust the values of a and b to be in the 0..255 range by adding 128 to + // each lane of a and b (which is the same as an bitwise XOR of each i8 lane + // by 128) and then bitcasting the results of the Xor operations to 
u8 + // vectors. + const auto i8_msb = SignBit(d); + const auto a_adj = BitCast(du, Xor(a, i8_msb)); + const auto b_adj = BitCast(du, Xor(b, i8_msb)); + + // The result of SumsOfAdjQuadAbsDiff(a_adj, b_adj) can + // simply be bitcasted to an i16 vector as + // |(a[i] + 128) - (b[i] + 128)| == |a[i] - b[i]| is true. + return BitCast(dw, SumsOfAdjQuadAbsDiff(a_adj, b_adj)); +} +#endif + +// ------------------------------ SumsOfShuffledQuadAbsDiff + +#if HWY_TARGET <= HWY_AVX3 +#ifdef HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#undef HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#else +#define HWY_NATIVE_SUMS_OF_SHUFFLED_QUAD_ABS_DIFF +#endif + +template +HWY_API Vec128 SumsOfShuffledQuadAbsDiff( + Vec128 a, Vec128 b) { + static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); + static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); + static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); + static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); + return Vec128{ + _mm_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))}; +} + +// Generic for all vector lengths +template )> +HWY_API VFromD>> SumsOfShuffledQuadAbsDiff(V a, + V b) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + + // Adjust the values of a and b to be in the 0..255 range by adding 128 to + // each lane of a and b (which is the same as an bitwise XOR of each i8 lane + // by 128) and then bitcasting the results of the Xor operations to u8 + // vectors. + const auto i8_msb = SignBit(d); + const auto a_adj = BitCast(du, Xor(a, i8_msb)); + const auto b_adj = BitCast(du, Xor(b, i8_msb)); + + // The result of + // SumsOfShuffledQuadAbsDiff(a_adj, b_adj) can + // simply be bitcasted to an i16 vector as + // |(a[i] + 128) - (b[i] + 128)| == |a[i] - b[i]| is true. 
+ return BitCast( + dw, SumsOfShuffledQuadAbsDiff(a_adj, b_adj)); +} +#endif + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epu16(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedAdd(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_adds_epi16(a.raw, b.raw)}; +} + +#if HWY_X86_HAVE_TERNARY_LOGIC +#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB +#undef HWY_NATIVE_I32_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I32_SATURATED_ADDSUB +#endif + +#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB +#undef HWY_NATIVE_I64_SATURATED_ADDSUB +#else +#define HWY_NATIVE_I64_SATURATED_ADDSUB +#endif + +// Generic for all vector lengths. +template , HWY_IF_SIGNED_D(D), + HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8))> +HWY_API V SaturatedAdd(V a, V b) { + const D d; + const V sum = a + b; + const MFromD overflow_mask = + MaskFromVec(detail::TernaryLogic<0x42>(a, b, sum)); + const V max = Set(d, LimitsMax>()); + const V overflow_result = + detail::MaskedTernaryLogic<0x55>(MaskFromVec(a), max, max, max); + return IfThenElse(overflow_mask, overflow_result, sum); +} + +#endif // HWY_X86_HAVE_TERNARY_LOGIC + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. 
+ +// Unsigned +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epu16(a.raw, b.raw)}; +} + +// Signed +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec128 SaturatedSub(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_subs_epi16(a.raw, b.raw)}; +} + +#if HWY_X86_HAVE_TERNARY_LOGIC +// Generic for all vector lengths. +template , HWY_IF_SIGNED_D(D), + HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8))> +HWY_API V SaturatedSub(V a, V b) { + const D d; + const V diff = a - b; + const MFromD overflow_mask = + MaskFromVec(detail::TernaryLogic<0x18>(a, b, diff)); + const V max = Set(d, LimitsMax>()); + const V overflow_result = + detail::MaskedTernaryLogic<0x55>(MaskFromVec(a), max, max, max); + return IfThenElse(overflow_mask, overflow_result, diff); +} + +#endif // HWY_X86_HAVE_TERNARY_LOGIC + +// ------------------------------ AverageRound + +// Returns (a + b + 1) / 2 + +// Unsigned +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_avg_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 AverageRound(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_avg_epu16(a.raw, b.raw)}; +} + +// I8/I16 AverageRound is generic for all vector lengths +template +HWY_API V AverageRound(V a, V b) { + const DFromV d; + const RebindToUnsigned du; + const V sign_bit = SignBit(d); + return Xor(BitCast(d, AverageRound(BitCast(du, Xor(a, sign_bit)), + BitCast(du, Xor(b, sign_bit)))), + sign_bit); +} + +// ------------------------------ Integer multiplication + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mullo_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return 
Vec128{_mm_mullo_epi16(a.raw, b.raw)}; +} + +// Returns the upper sizeof(T)*8 bits of a * b in each lane. +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhi_epu16(a.raw, b.raw)}; +} +template +HWY_API Vec128 MulHigh(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhi_epi16(a.raw, b.raw)}; +} + +template , 1)> +HWY_API V MulHigh(V a, V b) { + const DFromV d; + const Full128> d_full; + return ResizeBitCast( + d, Slide1Down(d_full, ResizeBitCast(d_full, MulEven(a, b)))); +} + +// I8/U8/I32/U32 MulHigh is generic for all vector lengths >= 2 lanes +template , 1)> +HWY_API V MulHigh(V a, V b) { + const DFromV d; + + const auto p_even = BitCast(d, MulEven(a, b)); + const auto p_odd = BitCast(d, MulOdd(a, b)); + return InterleaveOdd(d, p_even, p_odd); +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +template )> +HWY_API VFromD>> MulEven(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + const auto lo8_mask = Set(dw, uint16_t{0x00FF}); + return And(ResizeBitCast(dw, a), lo8_mask) * + And(ResizeBitCast(dw, b), lo8_mask); +} + +template )> +HWY_API VFromD>> MulEven(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + return ShiftRight<8>(ShiftLeft<8>(ResizeBitCast(dw, a))) * + ShiftRight<8>(ShiftLeft<8>(ResizeBitCast(dw, b))); +} + +template )> +HWY_API VFromD>> MulEven(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + const RepartitionToNarrow dw_as_d16; + + const auto lo = ResizeBitCast(dw, a * b); + const auto hi = ShiftLeft<16>(ResizeBitCast(dw, MulHigh(a, b))); + return BitCast(dw, OddEven(BitCast(dw_as_d16, hi), BitCast(dw_as_d16, lo))); +} + +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_epu32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulEven(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + const DFromV d; + const 
RepartitionToWide dw; + const RebindToUnsigned du; + + // p[i] = (((a[i] >> 31) * (a[i] >> 31)) << 64) + + // (((a[i] >> 31) * b[i]) << 32) + + // (((b[i] >> 31) * a[i]) << 32) + + // ((a[i] & int64_t{0xFFFFFFFF}) * (b[i] & int64_t{0xFFFFFFFF})) + + // ((a[i] >> 31) * (a[i] >> 31)) << 64 does not need to be computed as the + // lower 64 bits of ((a[i] >> 31) * (a[i] >> 31)) << 64 is zero. + + // (((a[i] >> 31) * b[i]) << 32) + (((b[i] >> 31) * a[i]) << 32) == + // -((((a[i] >> 31) & b[i]) + ((b[i] >> 31) & a[i])) << 32) + + // ((a[i] & int64_t{0xFFFFFFFF}) * (b[i] & int64_t{0xFFFFFFFF})) can be + // computed using MulEven(BitCast(du, a), BitCast(du, b)) + + const auto neg_p_hi = ShiftLeft<32>( + ResizeBitCast(dw, And(ShiftRight<31>(a), b) + And(ShiftRight<31>(b), a))); + const auto p_lo = BitCast(dw, MulEven(BitCast(du, a), BitCast(du, b))); + return p_lo - neg_p_hi; +#else + return Vec128{_mm_mul_epi32(a.raw, b.raw)}; +#endif +} + +template +HWY_API VFromD>> MulOdd(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + return ShiftRight<8>(ResizeBitCast(dw, a)) * + ShiftRight<8>(ResizeBitCast(dw, b)); +} + +template )> +HWY_API VFromD>> MulOdd(V a, V b) { + const DFromV d; + const RepartitionToWide dw; + const RebindToUnsigned dw_u; + const RepartitionToNarrow dw_as_d16; + + const auto lo = ShiftRight<16>(BitCast(dw_u, ResizeBitCast(dw, a * b))); + const auto hi = ResizeBitCast(dw, MulHigh(a, b)); + return BitCast(dw, OddEven(BitCast(dw_as_d16, hi), BitCast(dw_as_d16, lo))); +} + +template )> +HWY_API VFromD>> MulOdd(V a, V b) { + return MulEven(DupOdd(a), DupOdd(b)); +} + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + // Not as inefficient as it looks: _mm_mullo_epi32 has 10 cycle latency. + // 64-bit right shift would also work but also needs port 5, so no benefit. + // Notation: x=don't care, z=0. 
+ const __m128i a_x3x1 = _mm_shuffle_epi32(a.raw, _MM_SHUFFLE(3, 3, 1, 1)); + const auto mullo_x2x0 = MulEven(a, b); + const __m128i b_x3x1 = _mm_shuffle_epi32(b.raw, _MM_SHUFFLE(3, 3, 1, 1)); + const auto mullo_x3x1 = + MulEven(Vec128{a_x3x1}, Vec128{b_x3x1}); + // We could _mm_slli_epi64 by 32 to get 3z1z and OR with z2z0, but generating + // the latter requires one more instruction or a constant. + const __m128i mul_20 = + _mm_shuffle_epi32(mullo_x2x0.raw, _MM_SHUFFLE(2, 0, 2, 0)); + const __m128i mul_31 = + _mm_shuffle_epi32(mullo_x3x1.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_epi32(mul_20, mul_31)}; +#else + return Vec128{_mm_mullo_epi32(a.raw, b.raw)}; +#endif +} + +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + // Same as unsigned; avoid duplicating the SSSE3 code. + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, BitCast(du, a) * BitCast(du, b)); +} + +#if HWY_TARGET <= HWY_AVX3 +// Per-target flag to prevent generic_ops-inl.h from defining 64-bit operator*. +#ifdef HWY_NATIVE_MUL_64 +#undef HWY_NATIVE_MUL_64 +#else +#define HWY_NATIVE_MUL_64 +#endif + +template +HWY_API Vec128 operator*(Vec128 a, + Vec128 b) { + return Vec128{_mm_mullo_epi64(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(Vec128 a, + Vec128 b) { + return Vec128{_mm_mullo_epi64(a.raw, b.raw)}; +} +#endif + +// ------------------------------ RotateRight (ShiftRight, Or) + +// U8 RotateRight implementation on AVX3_DL is now in x86_avx3-inl.h as U8 +// RotateRight uses detail::GaloisAffine on AVX3_DL + +#if HWY_TARGET > HWY_AVX3_DL +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); + if (kBits == 0) return v; + // AVX3 does not support 8-bit. 
+ return Or(ShiftRight(v), ShiftLeft(v)); +} +#endif + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 16, "Invalid shift count"); + if (kBits == 0) return v; +#if HWY_TARGET <= HWY_AVX3_DL + return Vec128{_mm_shrdi_epi16(v.raw, v.raw, kBits)}; +#else + // AVX3 does not support 16-bit. + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_ror_epi32(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec128 RotateRight(const Vec128 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_ror_epi64(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +// I8/I16/I32/I64 RotateRight is generic for all vector lengths +template +HWY_API V RotateRight(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, RotateRight(BitCast(du, v))); +} + +// ------------------------------ Rol/Ror +#if HWY_TARGET <= HWY_AVX3_DL +#ifdef HWY_NATIVE_ROL_ROR_16 +#undef HWY_NATIVE_ROL_ROR_16 +#else +#define HWY_NATIVE_ROL_ROR_16 +#endif + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + return Vec128{_mm_shrdv_epi16(a.raw, a.raw, b.raw)}; +} + +// U16/I16 Rol is generic for all vector lengths on AVX3_DL +template )> +HWY_API V Rol(V a, V b) { + const DFromV d; + const RebindToSigned di; + return Ror(a, BitCast(d, Neg(BitCast(di, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_ROL_ROR_32_64 +#undef HWY_NATIVE_ROL_ROR_32_64 +#else +#define HWY_NATIVE_ROL_ROR_32_64 +#endif + +template +HWY_API Vec128 Rol(Vec128 a, Vec128 b) { + return Vec128{_mm_rolv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 
b) { + return Vec128{_mm_rorv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Rol(Vec128 a, Vec128 b) { + return Vec128{_mm_rolv_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec128 Ror(Vec128 a, Vec128 b) { + return Vec128{_mm_rorv_epi64(a.raw, b.raw)}; +} + +#endif + +// ------------------------------ RotateLeftSame/RotateRightSame + +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_ROL_ROR_SAME_16 +#undef HWY_NATIVE_ROL_ROR_SAME_16 +#else +#define HWY_NATIVE_ROL_ROR_SAME_16 +#endif + +// Generic for all vector lengths +template )> +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + return Ror(v, + Set(d, static_cast>(0u - static_cast(bits)))); +} + +template )> +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + return Ror(v, Set(d, static_cast>(bits))); +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_ROL_ROR_SAME_32_64 +#undef HWY_NATIVE_ROL_ROR_SAME_32_64 +#else +#define HWY_NATIVE_ROL_ROR_SAME_32_64 +#endif + +// Generic for all vector lengths +template +HWY_API V RotateLeftSame(V v, int bits) { + const DFromV d; + return Rol(v, Set(d, static_cast>(static_cast(bits)))); +} + +template +HWY_API V RotateRightSame(V v, int bits) { + const DFromV d; + return Ror(v, Set(d, static_cast>(static_cast(bits)))); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; + return VecFromMask(v < Zero(d)); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight<15>(v); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + return ShiftRight<31>(v); +} + +template +HWY_API Vec128 BroadcastSignBit(const Vec128 v) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3 + (void)d; + return Vec128{_mm_srai_epi64(v.raw, 63)}; +#elif HWY_TARGET == HWY_AVX2 || HWY_TARGET == HWY_SSE4 + return VecFromMask(v < Zero(d)); +#else + // 
Efficient Lt() requires SSE4.2 and BLENDVPD requires SSE4.1. 32-bit shift + // avoids generating a zero. + const RepartitionToNarrow d32; + const auto sign = ShiftRight<31>(BitCast(d32, v)); + return Vec128{ + _mm_shuffle_epi32(sign.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +#endif +} + +// ------------------------------ Integer Abs + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +template +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_COMPILER_MSVC || HWY_TARGET == HWY_SSE2 + const DFromV d; + const RebindToUnsigned du; + const auto zero = Zero(du); + const auto v_as_u8 = BitCast(du, v); + return BitCast(d, Min(v_as_u8, zero - v_as_u8)); +#else + return Vec128{_mm_abs_epi8(v.raw)}; +#endif +} + +template +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_TARGET == HWY_SSE2 + const auto zero = Zero(DFromV()); + return Max(v, zero - v); +#else + return Vec128{_mm_abs_epi16(v.raw)}; +#endif +} + +template +HWY_API Vec128 Abs(const Vec128 v) { +#if HWY_TARGET <= HWY_SSSE3 + return Vec128{_mm_abs_epi32(v.raw)}; +#else + const auto zero = Zero(DFromV()); + return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); +#endif +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API Vec128 Abs(const Vec128 v) { + return Vec128{_mm_abs_epi64(v.raw)}; +} +#else +// I64 Abs is generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template )> +HWY_API V Abs(V v) { + const auto zero = Zero(DFromV()); + return IfNegativeThenElse(v, zero - v, v); +} +#endif + +#ifdef HWY_NATIVE_SATURATED_ABS +#undef HWY_NATIVE_SATURATED_ABS +#else +#define HWY_NATIVE_SATURATED_ABS +#endif + +// Generic for all vector lengths +template )> +HWY_API V SaturatedAbs(V v) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Min(BitCast(du, v), BitCast(du, SaturatedSub(Zero(d), v)))); +} + +// Generic for all vector lengths +template )> +HWY_API V SaturatedAbs(V v) { + return Max(v, SaturatedSub(Zero(DFromV()), v)); +} + +// Generic for all vector lengths +template )> 
+HWY_API V SaturatedAbs(V v) { + const auto abs_v = Abs(v); + +#if HWY_TARGET <= HWY_SSE4 + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, Min(BitCast(du, abs_v), + Set(du, static_cast(LimitsMax())))); +#else + return Add(abs_v, BroadcastSignBit(abs_v)); +#endif +} + +// Generic for all vector lengths +template )> +HWY_API V SaturatedAbs(V v) { + const auto abs_v = Abs(v); + return Add(abs_v, BroadcastSignBit(abs_v)); +} + +// GCC <14 and Clang <11 do not follow the Intel documentation for AVX-512VL +// srli_epi64: the count should be unsigned int. Note that this is not the same +// as the Shift3264Count in x86_512-inl.h (GCC also requires int). +#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400) +using Shift64Count = int; +#else +// Assume documented behavior. Clang 12, GCC 14 and MSVC 14.28.29910 match this. +using Shift64Count = unsigned int; +#endif + +template +HWY_API Vec128 ShiftRight(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{ + _mm_srai_epi64(v.raw, static_cast(kBits))}; +#else + const DFromV di; + const RebindToUnsigned du; + const auto right = BitCast(di, ShiftRight(BitCast(du, v))); + const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); + return right | sign; +#endif +} + +// ------------------------------ IfNegativeThenElse +template +HWY_API Vec128 IfNegativeThenElse(const Vec128 v, + const Vec128 yes, + const Vec128 no) { +// int8: IfThenElse only looks at the MSB on SSE4 or newer +#if HWY_TARGET <= HWY_SSE4 + const auto mask = MaskFromVec(v); +#else + const DFromV d; + const RebindToSigned di; + const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); +#endif + + return IfThenElse(mask, yes, no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + +// 16-bit: no native blendv on AVX2 or earlier, so copy sign to lower 
byte's +// MSB. +#if HWY_TARGET <= HWY_AVX3 + const auto mask = MaskFromVec(v); +#else + const DFromV d; + const RebindToSigned di; + const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); +#endif + + return IfThenElse(mask, yes, no); +} + +template +HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, + Vec128 no) { + static_assert(IsSigned(), "Only works for signed/float"); + const DFromV d; + +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE4 + // 32/64-bit: use float IfThenElse on SSE4/AVX2, which only looks at the MSB + // on SSE4 or later. + const RebindToFloat df; + const auto mask = MaskFromVec(BitCast(df, v)); + return BitCast(d, IfThenElse(mask, BitCast(df, yes), BitCast(df, no))); +#else // SSE2, SSSE3, or AVX3 + +#if HWY_TARGET <= HWY_AVX3 + // No need to cast to float or broadcast sign bit on AVX3 as IfThenElse only + // looks at the MSB on AVX3 + (void)d; + const auto mask = MaskFromVec(v); +#else + const RebindToSigned di; + const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); +#endif + + return IfThenElse(mask, yes, no); +#endif +} + +#if HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE4 + +#ifdef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#undef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#else +#define HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO +#endif + +#ifdef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#undef HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#else +#define HWY_NATIVE_IF_NEG_THEN_ZERO_ELSE +#endif + +// SSE4/AVX2 IfNegativeThenElseZero/IfNegativeThenZeroElse is generic for all +// vector lengths +template +HWY_API V IfNegativeThenElseZero(V v, V yes) { + const DFromV d; + return IfNegativeThenElse(v, yes, Zero(d)); +} + +template +HWY_API V IfNegativeThenElseZero(V v, V yes) { + return IfThenElseZero(IsNegative(v), yes); +} + +template +HWY_API V IfNegativeThenZeroElse(V v, V no) { + const DFromV d; + return IfNegativeThenElse(v, Zero(d), no); +} + +template +HWY_API V IfNegativeThenZeroElse(V v, V no) { + return 
IfThenZeroElse(IsNegative(v), no); +} + +#endif // HWY_TARGET > HWY_AVX3 && HWY_TARGET <= HWY_SSE4 + +// ------------------------------ IfNegativeThenNegOrUndefIfZero + +#if HWY_TARGET <= HWY_SSSE3 + +#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#else +#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG +#endif + +template +HWY_API Vec128 IfNegativeThenNegOrUndefIfZero(Vec128 mask, + Vec128 v) { + return Vec128{_mm_sign_epi8(v.raw, mask.raw)}; +} + +template +HWY_API Vec128 IfNegativeThenNegOrUndefIfZero( + Vec128 mask, Vec128 v) { + return Vec128{_mm_sign_epi16(v.raw, mask.raw)}; +} + +template +HWY_API Vec128 IfNegativeThenNegOrUndefIfZero( + Vec128 mask, Vec128 v) { + return Vec128{_mm_sign_epi32(v.raw, mask.raw)}; +} + +// Generic for all vector lengths +template )> +HWY_API V IfNegativeThenNegOrUndefIfZero(V mask, V v) { +#if HWY_TARGET <= HWY_AVX3 + // MaskedSubOr is more efficient than IfNegativeThenElse on AVX3 + const DFromV d; + return MaskedSubOr(v, MaskFromVec(mask), Zero(d), v); +#else + // IfNegativeThenElse is more efficient than MaskedSubOr on SSE4/AVX2 + return IfNegativeThenElse(mask, Neg(v), v); +#endif +} + +#endif // HWY_TARGET <= HWY_SSSE3 + +// ------------------------------ ShiftLeftSame + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_slli_epi16(v.raw, bits)}; + } +#endif + return Vec128{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_slli_epi32(v.raw, bits)}; + } +#endif + return Vec128{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_slli_epi64(v.raw, bits)}; + } +#endif + return 
Vec128{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_slli_epi16(v.raw, bits)}; + } +#endif + return Vec128{_mm_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_slli_epi32(v.raw, bits)}; + } +#endif + return Vec128{_mm_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_slli_epi64(v.raw, bits)}; + } +#endif + return Vec128{_mm_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftLeftSame(const Vec128 v, const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftLeftSame(Vec128>{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_srli_epi16(v.raw, bits)}; + } +#endif + return Vec128{_mm_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_srli_epi32(v.raw, bits)}; + } +#endif + return Vec128{_mm_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_srli_epi64(v.raw, bits)}; + } +#endif + return Vec128{_mm_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 
ShiftRightSame(Vec128 v, + const int bits) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. + const Vec128 shifted{ + ShiftRightSame(Vec128{v.raw}, bits).raw}; + return shifted & Set(d8, static_cast(0xFF >> bits)); +} + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_srai_epi16(v.raw, bits)}; + } +#endif + return Vec128{_mm_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{_mm_srai_epi32(v.raw, bits)}; + } +#endif + return Vec128{_mm_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +template +HWY_API Vec128 ShiftRightSame(const Vec128 v, + const int bits) { +#if HWY_TARGET <= HWY_AVX3 +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec128{ + _mm_srai_epi64(v.raw, static_cast(bits))}; + } +#endif + return Vec128{_mm_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +#else + const DFromV di; + const RebindToUnsigned du; + const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); + return right | sign; +#endif +} + +template +HWY_API Vec128 ShiftRightSame(Vec128 v, const int bits) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +// ------------------------------ Floating-point mul / div + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator*(Vec128 a, + Vec128 b) { + return Vec128{_mm_mul_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator*(Vec128 a, Vec128 b) { + return Vec128{_mm_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) 
{ + return Vec128{_mm_mul_ss(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator*(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mul_pd(a.raw, b.raw)}; +} +HWY_API Vec64 operator*(const Vec64 a, const Vec64 b) { + return Vec64{_mm_mul_sd(a.raw, b.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_MUL_BY_POW2 +#undef HWY_NATIVE_MUL_BY_POW2 +#else +#define HWY_NATIVE_MUL_BY_POW2 +#endif + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MulByFloorPow2(Vec128 a, + Vec128 b) { + return Vec128{_mm_scalef_ph(a.raw, b.raw)}; +} +#endif + +template +HWY_API Vec128 MulByFloorPow2(Vec128 a, + Vec128 b) { + return Vec128{_mm_scalef_ps(a.raw, b.raw)}; +} + +template +HWY_API Vec128 MulByFloorPow2(Vec128 a, + Vec128 b) { + return Vec128{_mm_scalef_pd(a.raw, b.raw)}; +} + +// MulByPow2 is generic for all vector lengths on AVX3 +template +HWY_API V MulByPow2(V v, VFromD>> exp) { + const DFromV d; + return MulByFloorPow2(v, ConvertTo(d, exp)); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ps(a.raw, b.raw)}; +} +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_ss(a.raw, b.raw)}; +} +template +HWY_API Vec128 operator/(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_div_pd(a.raw, b.raw)}; +} +HWY_API Vec64 operator/(const Vec64 a, const Vec64 b) { + return Vec64{_mm_div_sd(a.raw, b.raw)}; +} + +// Approximate reciprocal +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 ApproximateReciprocal( + const Vec128 v) { + return Vec128{_mm_rcp_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return Vec128{_mm_rcp_ps(v.raw)}; +} +HWY_API Vec128 ApproximateReciprocal(const Vec128 v) { + return 
Vec128{_mm_rcp_ss(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +#ifdef HWY_NATIVE_F64_APPROX_RECIP +#undef HWY_NATIVE_F64_APPROX_RECIP +#else +#define HWY_NATIVE_F64_APPROX_RECIP +#endif + +HWY_API Vec128 ApproximateReciprocal(Vec128 v) { + return Vec128{_mm_rcp14_pd(v.raw)}; +} +HWY_API Vec64 ApproximateReciprocal(Vec64 v) { + return Vec64{_mm_rcp14_sd(v.raw, v.raw)}; +} +#endif + +// Generic for all vector lengths. +template +HWY_API V AbsDiff(V a, V b) { + return Abs(a - b); +} + +// ------------------------------ GetExponent + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_GET_EXPONENT +#undef HWY_NATIVE_GET_EXPONENT +#else +#define HWY_NATIVE_GET_EXPONENT +#endif + +#if HWY_HAVE_FLOAT16 +template ), HWY_IF_V_SIZE_LE_V(V, 16)> +HWY_API V GetExponent(V v) { + return V{_mm_getexp_ph(v.raw)}; +} +#endif +template ), HWY_IF_V_SIZE_LE_V(V, 16)> +HWY_API V GetExponent(V v) { + return V{_mm_getexp_ps(v.raw)}; +} +template ), HWY_IF_V_SIZE_LE_V(V, 16)> +HWY_API V GetExponent(V v) { + return V{_mm_getexp_pd(v.raw)}; +} + +#endif + +// ------------------------------ MaskedMinOr + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_MASKED_ARITH +#undef HWY_NATIVE_MASKED_ARITH +#else +#define HWY_NATIVE_MASKED_ARITH +#endif + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epu32(no.raw, m.raw, a.raw, b.raw)}; 
+} +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedMinOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_min_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMaxOr + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 
m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedMaxOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_max_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedAddOr + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_add_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSubOr + +template +HWY_API 
Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_sub_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMulOr + +// There are no elementwise integer mask_mul. Generic for all vector lengths. 
+template +HWY_API V MaskedMulOr(V no, M m, V a, V b) { + return IfThenElse(m, a * b, no); +} + +template +HWY_API Vec128 MaskedMulOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_mul_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedMulOr(Vec128 no, + Mask128 m, Vec128 a, + Vec128 b) { + return Vec128{_mm_mask_mul_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedMulOr(Vec128 no, + Mask128 m, + Vec128 a, + Vec128 b) { + return Vec128{_mm_mask_mul_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedDivOr + +template +HWY_API Vec128 MaskedDivOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_div_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedDivOr(Vec128 no, + Mask128 m, Vec128 a, + Vec128 b) { + return Vec128{_mm_mask_div_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaskedDivOr(Vec128 no, + Mask128 m, + Vec128 a, + Vec128 b) { + return Vec128{_mm_mask_div_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// Generic for all vector lengths +template +HWY_API V MaskedDivOr(V no, MFromD> m, V a, V b) { + return IfThenElse(m, Div(a, b), no); +} + +// ------------------------------ MaskedModOr +// Generic for all vector lengths +template +HWY_API V MaskedModOr(V no, MFromD> m, V a, V b) { + return IfThenElse(m, Mod(a, b), no); +} + +// ------------------------------ MaskedSatAddOr + +template +HWY_API Vec128 MaskedSatAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_adds_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_adds_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_adds_epi16(no.raw, 
m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatAddOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_adds_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +// ------------------------------ MaskedSatSubOr + +template +HWY_API Vec128 MaskedSatSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_subs_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_subs_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_subs_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec128 MaskedSatSubOr(Vec128 no, Mask128 m, + Vec128 a, Vec128 b) { + return Vec128{_mm_mask_subs_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Floating-point multiply-add variants + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MulAdd(Vec128 mul, + Vec128 x, + Vec128 add) { + return Vec128{_mm_fmadd_ph(mul.raw, x.raw, add.raw)}; +} + +template +HWY_API Vec128 NegMulAdd(Vec128 mul, + Vec128 x, + Vec128 add) { + return Vec128{_mm_fnmadd_ph(mul.raw, x.raw, add.raw)}; +} + +template +HWY_API Vec128 MulSub(Vec128 mul, + Vec128 x, + Vec128 sub) { + return Vec128{_mm_fmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +template +HWY_API Vec128 NegMulSub(Vec128 mul, + Vec128 x, + Vec128 sub) { + return Vec128{_mm_fnmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return mul * x + add; +#else + return Vec128{_mm_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template +HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, + Vec128 add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return mul * x + add; +#else + return 
Vec128{_mm_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns add - mul * x +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return add - mul * x; +#else + return Vec128{_mm_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +template +HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, + Vec128 add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return add - mul * x; +#else + return Vec128{_mm_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +// Returns mul * x - sub +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return mul * x - sub; +#else + return Vec128{_mm_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template +HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, + Vec128 sub) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return mul * x - sub; +#else + return Vec128{_mm_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +// Returns -mul * x - sub +template +HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, + Vec128 sub) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return Neg(mul) * x - sub; +#else + return Vec128{_mm_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +template +HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, + Vec128 sub) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return Neg(mul) * x - sub; +#else + return Vec128{_mm_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +#if HWY_TARGET <= HWY_SSSE3 + +#undef HWY_IF_MULADDSUB_V +#define HWY_IF_MULADDSUB_V(V) \ + HWY_IF_LANES_GT_D(DFromV, 1), \ + HWY_IF_T_SIZE_ONE_OF_V( \ + V, (1 << 1) | ((hwy::IsFloat>()) \ + ? 
0 \ + : ((1 << 2) | (1 << 4) | (1 << 8)))) + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MulAddSub(Vec128 mul, + Vec128 x, + Vec128 sub_or_add) { + return Vec128{_mm_fmaddsub_ph(mul.raw, x.raw, sub_or_add.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Vec128 MulAddSub(Vec128 mul, Vec128 x, + Vec128 sub_or_add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return AddSub(mul * x, sub_or_add); +#else + return Vec128{_mm_fmaddsub_ps(mul.raw, x.raw, sub_or_add.raw)}; +#endif +} + +HWY_API Vec128 MulAddSub(Vec128 mul, Vec128 x, + Vec128 sub_or_add) { +#if HWY_TARGET >= HWY_SSE4 || defined(HWY_DISABLE_BMI2_FMA) + return AddSub(mul * x, sub_or_add); +#else + return Vec128{_mm_fmaddsub_pd(mul.raw, x.raw, sub_or_add.raw)}; +#endif +} + +#endif // HWY_TARGET <= HWY_SSSE3 + +// ------------------------------ Floating-point square root + +// Full precision square root +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Sqrt(Vec128 v) { + return Vec128{_mm_sqrt_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Sqrt(Vec128 v) { + return Vec128{_mm_sqrt_ps(v.raw)}; +} +HWY_API Vec128 Sqrt(Vec128 v) { + return Vec128{_mm_sqrt_ss(v.raw)}; +} +template +HWY_API Vec128 Sqrt(Vec128 v) { + return Vec128{_mm_sqrt_pd(v.raw)}; +} +HWY_API Vec64 Sqrt(Vec64 v) { + return Vec64{_mm_sqrt_sd(_mm_setzero_pd(), v.raw)}; +} + +// Approximate reciprocal square root +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + return Vec128{_mm_rsqrt_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + return Vec128{_mm_rsqrt_ps(v.raw)}; +} +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { + return Vec128{_mm_rsqrt_ss(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +#ifdef HWY_NATIVE_F64_APPROX_RSQRT +#undef HWY_NATIVE_F64_APPROX_RSQRT +#else +#define HWY_NATIVE_F64_APPROX_RSQRT +#endif + +HWY_API Vec64 ApproximateReciprocalSqrt(Vec64 v) { + return 
Vec64{_mm_rsqrt14_sd(v.raw, v.raw)}; +} +HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { +#if HWY_COMPILER_MSVC + const DFromV d; + return Vec128{_mm_mask_rsqrt14_pd( + Undefined(d).raw, static_cast<__mmask8>(0xFF), v.raw)}; +#else + return Vec128{_mm_rsqrt14_pd(v.raw)}; +#endif +} +#endif + +// ------------------------------ Min (Gt, IfThenElse) + +namespace detail { + +template +HWY_INLINE HWY_MAYBE_UNUSED Vec128 MinU(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + const auto msb = Set(du, static_cast(T(1) << (sizeof(T) * 8 - 1))); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, b, a); +} + +} // namespace detail + +// Unsigned +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{_mm_min_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return Vec128{ + _mm_sub_epi16(a.raw, _mm_subs_epu16(a.raw, b.raw))}; +#else + return Vec128{_mm_min_epu16(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return detail::MinU(a, b); +#else + return Vec128{_mm_min_epu32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_min_epu64(a.raw, b.raw)}; +#else + return detail::MinU(a, b); +#endif +} + +// Signed +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return IfThenElse(a < b, a, b); +#else + return Vec128{_mm_min_epi8(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{_mm_min_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return IfThenElse(a < b, a, b); +#else + return Vec128{_mm_min_epi32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return 
Vec128{_mm_min_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, a, b); +#endif +} + +// Float +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Min(Vec128 a, + Vec128 b) { + return Vec128{_mm_min_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{_mm_min_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Min(Vec128 a, Vec128 b) { + return Vec128{_mm_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +namespace detail { +template +HWY_INLINE HWY_MAYBE_UNUSED Vec128 MaxU(const Vec128 a, + const Vec128 b) { + const DFromV d; + const RebindToUnsigned du; + const RebindToSigned di; + const auto msb = Set(du, static_cast(T(1) << (sizeof(T) * 8 - 1))); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, a, b); +} + +} // namespace detail + +// Unsigned +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{_mm_max_epu8(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return Vec128{ + _mm_add_epi16(a.raw, _mm_subs_epu16(b.raw, a.raw))}; +#else + return Vec128{_mm_max_epu16(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return detail::MaxU(a, b); +#else + return Vec128{_mm_max_epu32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_max_epu64(a.raw, b.raw)}; +#else + return detail::MaxU(a, b); +#endif +} + +// Signed +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return IfThenElse(a < b, b, a); +#else + return Vec128{_mm_max_epi8(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{_mm_max_epi16(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + return IfThenElse(a < b, b, a); +#else + return 
Vec128{_mm_max_epi32(a.raw, b.raw)}; +#endif +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_max_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, b, a); +#endif +} + +// Float +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Max(Vec128 a, + Vec128 b) { + return Vec128{_mm_max_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{_mm_max_ps(a.raw, b.raw)}; +} +template +HWY_API Vec128 Max(Vec128 a, Vec128 b) { + return Vec128{_mm_max_pd(a.raw, b.raw)}; +} + +// ------------------------------ MinNumber and MaxNumber + +#ifdef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER +#undef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER +#else +#define HWY_NATIVE_FLOAT_MIN_MAX_NUMBER +#endif + +#if HWY_X86_HAVE_AVX10_2_OPS + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MinNumber(Vec128 a, + Vec128 b) { + return Vec128{_mm_minmax_ph(a.raw, b.raw, 0x14)}; +} +#endif +template +HWY_API Vec128 MinNumber(Vec128 a, Vec128 b) { + return Vec128{_mm_minmax_ps(a.raw, b.raw, 0x14)}; +} +template +HWY_API Vec128 MinNumber(Vec128 a, Vec128 b) { + return Vec128{_mm_minmax_pd(a.raw, b.raw, 0x14)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaxNumber(Vec128 a, + Vec128 b) { + return Vec128{_mm_minmax_ph(a.raw, b.raw, 0x15)}; +} +#endif +template +HWY_API Vec128 MaxNumber(Vec128 a, Vec128 b) { + return Vec128{_mm_minmax_ps(a.raw, b.raw, 0x15)}; +} +template +HWY_API Vec128 MaxNumber(Vec128 a, Vec128 b) { + return Vec128{_mm_minmax_pd(a.raw, b.raw, 0x15)}; +} + +#else + +// MinNumber/MaxNumber are generic for all vector lengths on targets other +// than AVX10.2 +template +HWY_API V MinNumber(V a, V b) { + return Min(a, IfThenElse(IsNaN(b), a, b)); +} + +template +HWY_API V MaxNumber(V a, V b) { + return Max(a, IfThenElse(IsNaN(b), a, b)); +} + +#endif + +// ------------------------------ MinMagnitude and MaxMagnitude + +#if HWY_X86_HAVE_AVX10_2_OPS + +#ifdef 
HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE +#undef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE +#else +#define HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE +#endif + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MinMagnitude(Vec128 a, + Vec128 b) { + return Vec128{_mm_minmax_ph(a.raw, b.raw, 0x16)}; +} +#endif +template +HWY_API Vec128 MinMagnitude(Vec128 a, Vec128 b) { + return Vec128{_mm_minmax_ps(a.raw, b.raw, 0x16)}; +} +template +HWY_API Vec128 MinMagnitude(Vec128 a, + Vec128 b) { + return Vec128{_mm_minmax_pd(a.raw, b.raw, 0x16)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 MaxMagnitude(Vec128 a, + Vec128 b) { + return Vec128{_mm_minmax_ph(a.raw, b.raw, 0x17)}; +} +#endif +template +HWY_API Vec128 MaxMagnitude(Vec128 a, Vec128 b) { + return Vec128{_mm_minmax_ps(a.raw, b.raw, 0x17)}; +} +template +HWY_API Vec128 MaxMagnitude(Vec128 a, + Vec128 b) { + return Vec128{_mm_minmax_pd(a.raw, b.raw, 0x17)}; +} + +#endif + +// ================================================== MEMORY (3) + +// ------------------------------ Non-temporal stores + +// On clang6, we see incorrect code generated for _mm_stream_pi, so +// round even partial vectors up to 16 bytes. +template +HWY_API void Stream(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + const RebindToUnsigned du; // for float16_t + _mm_stream_si128(reinterpret_cast<__m128i*>(aligned), BitCast(du, v).raw); +} +template +HWY_API void Stream(VFromD v, D /* tag */, float* HWY_RESTRICT aligned) { + _mm_stream_ps(aligned, v.raw); +} +template +HWY_API void Stream(VFromD v, D /* tag */, double* HWY_RESTRICT aligned) { + _mm_stream_pd(aligned, v.raw); +} + +// ------------------------------ Scatter + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Unfortunately the GCC/Clang intrinsics do not accept int64_t*. 
+using GatherIndex64 = long long int; // NOLINT(runtime/int) +static_assert(sizeof(GatherIndex64) == 8, "Must be 64-bit type"); + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_SCATTER +#undef HWY_NATIVE_SCATTER +#else +#define HWY_NATIVE_SCATTER +#endif + +namespace detail { + +template +HWY_INLINE void NativeScatter128(VFromD v, D d, TFromD* HWY_RESTRICT base, + VI index) { + if (d.MaxBytes() == 16) { + _mm_i32scatter_epi32(base, index.raw, v.raw, kScale); + } else { + const __mmask8 mask = (1u << MaxLanes(d)) - 1; + _mm_mask_i32scatter_epi32(base, mask, index.raw, v.raw, kScale); + } +} + +template +HWY_INLINE void NativeScatter128(VFromD v, D d, TFromD* HWY_RESTRICT base, + VI index) { + if (d.MaxBytes() == 16) { + _mm_i64scatter_epi64(base, index.raw, v.raw, kScale); + } else { + const __mmask8 mask = (1u << MaxLanes(d)) - 1; + _mm_mask_i64scatter_epi64(base, mask, index.raw, v.raw, kScale); + } +} + +template +HWY_INLINE void NativeScatter128(VFromD v, D d, float* HWY_RESTRICT base, + VI index) { + if (d.MaxBytes() == 16) { + _mm_i32scatter_ps(base, index.raw, v.raw, kScale); + } else { + const __mmask8 mask = (1u << MaxLanes(d)) - 1; + _mm_mask_i32scatter_ps(base, mask, index.raw, v.raw, kScale); + } +} + +template +HWY_INLINE void NativeScatter128(VFromD v, D d, double* HWY_RESTRICT base, + VI index) { + if (d.MaxBytes() == 16) { + _mm_i64scatter_pd(base, index.raw, v.raw, kScale); + } else { + const __mmask8 mask = (1u << MaxLanes(d)) - 1; + _mm_mask_i64scatter_pd(base, mask, index.raw, v.raw, kScale); + } +} + +template +HWY_INLINE void NativeMaskedScatter128(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT base, VI index) { + // For partial vectors, ensure upper mask lanes are zero to prevent faults. 
+ if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); + _mm_mask_i32scatter_epi32(base, m.raw, index.raw, v.raw, kScale); +} + +template +HWY_INLINE void NativeMaskedScatter128(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT base, VI index) { + // For partial vectors, ensure upper mask lanes are zero to prevent faults. + if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); + _mm_mask_i64scatter_epi64(base, m.raw, index.raw, v.raw, kScale); +} + +template +HWY_INLINE void NativeMaskedScatter128(VFromD v, MFromD m, D d, + float* HWY_RESTRICT base, VI index) { + // For partial vectors, ensure upper mask lanes are zero to prevent faults. + if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); + _mm_mask_i32scatter_ps(base, m.raw, index.raw, v.raw, kScale); +} + +template +HWY_INLINE void NativeMaskedScatter128(VFromD v, MFromD m, D d, + double* HWY_RESTRICT base, VI index) { + // For partial vectors, ensure upper mask lanes are zero to prevent faults. + if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); + _mm_mask_i64scatter_pd(base, m.raw, index.raw, v.raw, kScale); +} + +} // namespace detail + +template +HWY_API void ScatterOffset(VFromD v, D d, TFromD* HWY_RESTRICT base, + VFromD> offset) { + return detail::NativeScatter128<1>(v, d, base, offset); +} +template +HWY_API void ScatterIndex(VFromD v, D d, TFromD* HWY_RESTRICT base, + VFromD> index) { + return detail::NativeScatter128)>(v, d, base, index); +} +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT base, + VFromD> index) { + return detail::NativeMaskedScatter128)>(v, m, d, base, + index); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Gather (Load/Store) + +#if HWY_TARGET <= HWY_AVX2 + +#ifdef HWY_NATIVE_GATHER +#undef HWY_NATIVE_GATHER +#else +#define HWY_NATIVE_GATHER +#endif + +namespace detail { + +template +HWY_INLINE Vec128 NativeGather128(const T* HWY_RESTRICT base, + Vec128 indices) { + return 
Vec128{_mm_i32gather_epi32( + reinterpret_cast(base), indices.raw, kScale)}; +} + +template +HWY_INLINE Vec128 NativeGather128(const T* HWY_RESTRICT base, + Vec128 indices) { + return Vec128{_mm_i64gather_epi64( + reinterpret_cast(base), indices.raw, kScale)}; +} + +template +HWY_INLINE Vec128 NativeGather128(const float* HWY_RESTRICT base, + Vec128 indices) { + return Vec128{_mm_i32gather_ps(base, indices.raw, kScale)}; +} + +template +HWY_INLINE Vec128 NativeGather128(const double* HWY_RESTRICT base, + Vec128 indices) { + return Vec128{_mm_i64gather_pd(base, indices.raw, kScale)}; +} + +template +HWY_INLINE Vec128 NativeMaskedGatherOr128(Vec128 no, + Mask128 m, + const T* HWY_RESTRICT base, + Vec128 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_mmask_i32gather_epi32( + no.raw, m.raw, indices.raw, reinterpret_cast(base), + kScale)}; +#else + return Vec128{ + _mm_mask_i32gather_epi32(no.raw, reinterpret_cast(base), + indices.raw, m.raw, kScale)}; +#endif +} + +template +HWY_INLINE Vec128 NativeMaskedGatherOr128(Vec128 no, + Mask128 m, + const T* HWY_RESTRICT base, + Vec128 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_mmask_i64gather_epi64( + no.raw, m.raw, indices.raw, reinterpret_cast(base), + kScale)}; +#else + return Vec128{_mm_mask_i64gather_epi64( + no.raw, reinterpret_cast(base), indices.raw, m.raw, + kScale)}; +#endif +} + +template +HWY_INLINE Vec128 NativeMaskedGatherOr128( + Vec128 no, Mask128 m, const float* HWY_RESTRICT base, + Vec128 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{ + _mm_mmask_i32gather_ps(no.raw, m.raw, indices.raw, base, kScale)}; +#else + return Vec128{ + _mm_mask_i32gather_ps(no.raw, base, indices.raw, m.raw, kScale)}; +#endif +} + +template +HWY_INLINE Vec128 NativeMaskedGatherOr128( + Vec128 no, Mask128 m, const double* HWY_RESTRICT base, + Vec128 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{ + _mm_mmask_i64gather_pd(no.raw, m.raw, indices.raw, base, kScale)}; +#else + return Vec128{ + 
_mm_mask_i64gather_pd(no.raw, base, indices.raw, m.raw, kScale)}; +#endif +} + +} // namespace detail + +template +HWY_API VFromD GatherOffset(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> offsets) { + return detail::NativeGather128<1>(base, offsets); +} + +template > +HWY_API VFromD GatherIndex(D /*d*/, const T* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeGather128(base, indices); +} + +template > +HWY_API VFromD MaskedGatherIndexOr(VFromD no, MFromD m, D d, + const T* HWY_RESTRICT base, + VFromD> indices) { + // For partial vectors, ensure upper mask lanes are zero to prevent faults. + if (!detail::IsFull(d)) m = And(m, FirstN(d, Lanes(d))); + + return detail::NativeMaskedGatherOr128(no, m, base, indices); +} + +// Generic for all vector lengths. +template +HWY_API VFromD MaskedGatherIndex(MFromD m, D d, + const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return MaskedGatherIndexOr(Zero(d), m, d, base, indices); +} + +#endif // HWY_TARGET <= HWY_AVX2 + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE (2) + +// ------------------------------ LowerHalf + +template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return VFromD{v.raw}; +} +template +HWY_API Vec128 LowerHalf(Vec128 v) { + return Vec128{v.raw}; +} + +// ------------------------------ ShiftLeftBytes + +template +HWY_API VFromD ShiftLeftBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const RebindToUnsigned du; + return BitCast( + d, VFromD{_mm_slli_si128(BitCast(du, v).raw, kBytes)}); +} + +// Generic for all vector lengths. +template +HWY_API V ShiftLeftBytes(const V v) { + return ShiftLeftBytes(DFromV(), v); +} + +// ------------------------------ ShiftLeftLanes + +// Generic for all vector lengths. +template +HWY_API VFromD ShiftLeftLanes(D d, const VFromD v) { + const Repartition d8; + return BitCast(d, ShiftLeftBytes)>(BitCast(d8, v))); +} + +// Generic for all vector lengths. 
+template +HWY_API V ShiftLeftLanes(const V v) { + return ShiftLeftLanes(DFromV(), v); +} + +// ------------------------------ ShiftRightBytes +template +HWY_API VFromD ShiftRightBytes(D d, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + const RebindToUnsigned du; + // For partial vectors, clear upper lanes so we shift in zeros. + if (d.MaxBytes() != 16) { + const Full128> dfull; + const VFromD vfull{v.raw}; + v = VFromD{IfThenElseZero(FirstN(dfull, MaxLanes(d)), vfull).raw}; + } + return BitCast( + d, VFromD{_mm_srli_si128(BitCast(du, v).raw, kBytes)}); +} + +// ------------------------------ ShiftRightLanes +// Generic for all vector lengths. +template +HWY_API VFromD ShiftRightLanes(D d, const VFromD v) { + const Repartition d8; + constexpr size_t kBytes = kLanes * sizeof(TFromD); + return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); +} + +// ------------------------------ UpperHalf (ShiftRightBytes) + +// Full input: copy hi into lo (smaller instruction encoding than shifts). +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + const Twice> dut; + using VUT = VFromD; // for float16_t + const VUT vut = BitCast(dut, v); + return BitCast(d, LowerHalf(VUT{_mm_unpackhi_epi64(vut.raw, vut.raw)})); +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64{_mm_movehl_ps(v.raw, v.raw)}; +} +template +HWY_API Vec64 UpperHalf(D /* tag */, Vec128 v) { + return Vec64{_mm_unpackhi_pd(v.raw, v.raw)}; +} + +// Partial +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + return LowerHalf(d, ShiftRightBytes(Twice(), v)); +} + +// ------------------------------ ExtractLane (UpperHalf) + +namespace detail { + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET >= HWY_SSSE3 + const int pair = _mm_extract_epi16(v.raw, kLane / 2); + constexpr int kShift = kLane & 1 ? 
8 : 0; + return static_cast((pair >> kShift) & 0xFF); +#else + return static_cast(_mm_extract_epi8(v.raw, kLane) & 0xFF); +#endif +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); + const DFromV d; + const RebindToUnsigned du; + const uint16_t lane = static_cast( + _mm_extract_epi16(BitCast(du, v).raw, kLane) & 0xFFFF); + return BitCastScalar(lane); +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET >= HWY_SSSE3 + return static_cast(_mm_cvtsi128_si32( + (kLane == 0) ? v.raw : _mm_shuffle_epi32(v.raw, kLane))); +#else + return static_cast(_mm_extract_epi32(v.raw, kLane)); +#endif +} + +template +HWY_INLINE T ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_ARCH_X86_32 + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[kLane]; +#elif HWY_TARGET >= HWY_SSSE3 + return static_cast( + _mm_cvtsi128_si64((kLane == 0) ? v.raw : _mm_shuffle_epi32(v.raw, 0xEE))); +#else + return static_cast(_mm_extract_epi64(v.raw, kLane)); +#endif +} + +template +HWY_INLINE float ExtractLane(const Vec128 v) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET >= HWY_SSSE3 + return _mm_cvtss_f32((kLane == 0) ? v.raw + : _mm_shuffle_ps(v.raw, v.raw, kLane)); +#else + // Bug in the intrinsic, returns int but should be float. + const int32_t bits = _mm_extract_ps(v.raw, kLane); + return BitCastScalar(bits); +#endif +} + +// There is no extract_pd; two overloads because there is no UpperHalf for N=1. +template +HWY_INLINE double ExtractLane(const Vec64 v) { + static_assert(kLane == 0, "Lane index out of bounds"); + return GetLane(v); +} + +template +HWY_INLINE double ExtractLane(const Vec128 v) { + static_assert(kLane < 2, "Lane index out of bounds"); + const Half> dh; + return kLane == 0 ? 
GetLane(v) : GetLane(UpperHalf(dh, v)); +} + +} // namespace detail + +// Requires one overload per vector length because ExtractLane<3> may be a +// compile error if it calls _mm_extract_epi64. +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { + HWY_DASSERT(i == 0); + (void)i; + return GetLane(v); +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + } + } +#endif + alignas(16) T lanes[2]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + } + } +#endif + alignas(16) T lanes[4]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + } + } +#endif + alignas(16) T lanes[8]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +template +HWY_API T ExtractLane(const Vec128 v, size_t i) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::ExtractLane<0>(v); + case 1: + return 
detail::ExtractLane<1>(v); + case 2: + return detail::ExtractLane<2>(v); + case 3: + return detail::ExtractLane<3>(v); + case 4: + return detail::ExtractLane<4>(v); + case 5: + return detail::ExtractLane<5>(v); + case 6: + return detail::ExtractLane<6>(v); + case 7: + return detail::ExtractLane<7>(v); + case 8: + return detail::ExtractLane<8>(v); + case 9: + return detail::ExtractLane<9>(v); + case 10: + return detail::ExtractLane<10>(v); + case 11: + return detail::ExtractLane<11>(v); + case 12: + return detail::ExtractLane<12>(v); + case 13: + return detail::ExtractLane<13>(v); + case 14: + return detail::ExtractLane<14>(v); + case 15: + return detail::ExtractLane<15>(v); + } + } +#endif + alignas(16) T lanes[16]; + Store(v, DFromV(), lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (UpperHalf) + +namespace detail { + +template +HWY_INLINE V InsertLaneUsingBroadcastAndBlend(V v, size_t i, TFromV t) { + const DFromV d; + +#if HWY_TARGET <= HWY_AVX3 + using RawMask = decltype(MaskFromVec(VFromD()).raw); + const auto mask = MFromD{static_cast(uint64_t{1} << i)}; +#else + const RebindToUnsigned du; + using TU = TFromD; + const auto mask = RebindMask(d, Iota(du, 0) == Set(du, static_cast(i))); +#endif + + return IfThenElse(mask, Set(d, t), v); +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET >= HWY_SSSE3 + return InsertLaneUsingBroadcastAndBlend(v, kLane, t); +#else + return Vec128{_mm_insert_epi8(v.raw, t, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); + const DFromV d; + const RebindToUnsigned du; + const uint16_t bits = BitCastScalar(t); + return BitCast(d, VFromD{ + _mm_insert_epi16(BitCast(du, v).raw, bits, kLane)}); +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if 
HWY_TARGET >= HWY_SSSE3 + return InsertLaneUsingBroadcastAndBlend(v, kLane, t); +#else + const MakeSigned ti = BitCastScalar>(t); + return Vec128{_mm_insert_epi32(v.raw, ti, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, T t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET >= HWY_SSSE3 || HWY_ARCH_X86_32 + const DFromV d; + const RebindToFloat df; + const auto vt = BitCast(df, Set(d, t)); + if (kLane == 0) { + return BitCast( + d, Vec128{_mm_shuffle_pd(vt.raw, BitCast(df, v).raw, 2)}); + } + return BitCast( + d, Vec128{_mm_shuffle_pd(BitCast(df, v).raw, vt.raw, 0)}); +#else + const MakeSigned ti = BitCastScalar>(t); + return Vec128{_mm_insert_epi64(v.raw, ti, kLane)}; +#endif +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, float t) { + static_assert(kLane < N, "Lane index out of bounds"); +#if HWY_TARGET >= HWY_SSSE3 + return InsertLaneUsingBroadcastAndBlend(v, kLane, t); +#else + return Vec128{_mm_insert_ps(v.raw, _mm_set_ss(t), kLane << 4)}; +#endif +} + +// There is no insert_pd; two overloads because there is no UpperHalf for N=1. +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane == 0, "Lane index out of bounds"); + return Set(DFromV(), t); +} + +template +HWY_INLINE Vec128 InsertLane(const Vec128 v, double t) { + static_assert(kLane < 2, "Lane index out of bounds"); + const DFromV d; + const Vec128 vt = Set(d, t); + if (kLane == 0) { + return Vec128{_mm_shuffle_pd(vt.raw, v.raw, 2)}; + } + return Vec128{_mm_shuffle_pd(v.raw, vt.raw, 0)}; +} + +} // namespace detail + +// Requires one overload per vector length because InsertLane<3> may be a +// compile error if it calls _mm_insert_epi64. 
+ +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { + HWY_DASSERT(i == 0); + (void)i; + return Set(DFromV(), t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + } + } +#endif + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + } + } +#endif + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + } + } +#endif + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +template +HWY_API Vec128 InsertLane(const Vec128 v, size_t i, T t) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(i)) { + switch (i) { + case 0: + return detail::InsertLane<0>(v, t); + case 1: + return detail::InsertLane<1>(v, t); + case 2: + return detail::InsertLane<2>(v, t); + case 3: + return detail::InsertLane<3>(v, t); + case 4: + return 
detail::InsertLane<4>(v, t); + case 5: + return detail::InsertLane<5>(v, t); + case 6: + return detail::InsertLane<6>(v, t); + case 7: + return detail::InsertLane<7>(v, t); + case 8: + return detail::InsertLane<8>(v, t); + case 9: + return detail::InsertLane<9>(v, t); + case 10: + return detail::InsertLane<10>(v, t); + case 11: + return detail::InsertLane<11>(v, t); + case 12: + return detail::InsertLane<12>(v, t); + case 13: + return detail::InsertLane<13>(v, t); + case 14: + return detail::InsertLane<14>(v, t); + case 15: + return detail::InsertLane<15>(v, t); + } + } +#endif + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +// ------------------------------ CombineShiftRightBytes + +#if HWY_TARGET == HWY_SSE2 +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + static_assert(0 < kBytes && kBytes < 16, "kBytes invalid"); + return Or(ShiftRightBytes(d, lo), ShiftLeftBytes<16 - kBytes>(d, hi)); +} +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + constexpr size_t kSize = d.MaxBytes(); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + + const Twice dt; + return VFromD{ShiftRightBytes(dt, Combine(dt, hi, lo)).raw}; +} +#else +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + const Repartition d8; + return BitCast(d, Vec128{_mm_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + constexpr size_t kSize = d.MaxBytes(); + static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid"); + const Repartition d8; + using V8 = Vec128; + const DFromV dfull8; + const Repartition, decltype(dfull8)> dfull; + const V8 hi8{BitCast(d8, hi).raw}; + // Move into most-significant bytes + const V8 lo8 = ShiftLeftBytes<16 - kSize>(V8{BitCast(d8, lo).raw}); + const V8 r = CombineShiftRightBytes<16 - kSize + kBytes>(dfull8, hi8, lo8); + return VFromD{BitCast(dfull, r).raw}; 
+} +#endif + +// ------------------------------ Broadcast/splat any lane + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + if (kLane < 4) { + const __m128i lo = _mm_shufflelo_epi16(vu.raw, (0x55 * kLane) & 0xFF); + return BitCast(d, VU{_mm_unpacklo_epi64(lo, lo)}); + } else { + const __m128i hi = _mm_shufflehi_epi16(vu.raw, (0x55 * (kLane - 4)) & 0xFF); + return BitCast(d, VU{_mm_unpackhi_epi64(hi, hi)}); + } +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + HWY_IF_CONSTEXPR(N == 1) { + // Workaround for MSVC compiler bug on single lane integer broadcast + return Vec128{v}; + } + HWY_IF_CONSTEXPR(N != 1) { + return Vec128{_mm_shuffle_epi32(v.raw, 0x55 * kLane)}; + } +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; +} + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + return Vec128{_mm_shuffle_pd(v.raw, v.raw, 3 * kLane)}; +} + +// ------------------------------ TableLookupLanes (Shuffle01) + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. 
+template +struct Indices128 { + __m128i raw; +}; + +template , typename TI, size_t kN, + HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE(T, 1)> +HWY_API Indices128 IndicesFromVec(D d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, kN * 2)))); +#endif + + // No change as byte indices are always used for 8-bit lane types + (void)d; + return Indices128{vec.raw}; +} + +template , typename TI, size_t kN, + HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE(T, 2)> +HWY_API Indices128 IndicesFromVec(D d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, kN * 2)))); +#endif + +#if HWY_TARGET <= HWY_AVX3 || HWY_TARGET == HWY_SSE2 + (void)d; + return Indices128{vec.raw}; +#else // SSSE3, SSE4, or AVX2 + const Repartition d8; + using V8 = VFromD; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}; + + // Broadcast each lane index to both bytes of T + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}; + const V8 lane_indices = TableLookupBytes(vec, Load(d8, kBroadcastLaneBytes)); + + // Shift to bytes + const Repartition d16; + const V8 byte_indices = BitCast(d8, ShiftLeft<1>(BitCast(d16, lane_indices))); + + return Indices128{Add(byte_indices, Load(d8, kByteOffsets)).raw}; +#endif // HWY_TARGET <= HWY_AVX3 || HWY_TARGET == HWY_SSE2 +} + +template , typename TI, size_t kN, + HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE(T, 4)> +HWY_API Indices128 IndicesFromVec(D d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di,
Lt(vec, Set(di, kN * 2)))); +#endif + +#if HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SSE2 + (void)d; + return Indices128{vec.raw}; +#else + const Repartition d8; + using V8 = VFromD; + alignas(16) static constexpr uint8_t kByteOffsets[16] = { + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + + // Broadcast each lane index to all 4 bytes of T + alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = { + 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12}; + const V8 lane_indices = TableLookupBytes(vec, Load(d8, kBroadcastLaneBytes)); + + // Shift to bytes + const Repartition d16; + const V8 byte_indices = BitCast(d8, ShiftLeft<2>(BitCast(d16, lane_indices))); + + return Indices128{Add(byte_indices, Load(d8, kByteOffsets)).raw}; +#endif +} + +template , typename TI, size_t kN, + HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE(T, 8)> +HWY_API Indices128 IndicesFromVec(D d, Vec128 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Rebind di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(kN * 2))))); +#else + (void)d; +#endif + + // No change - even without AVX3, we can shuffle+blend. 
+ return Indices128{vec.raw}; +} + +template +HWY_API Indices128, HWY_MAX_LANES_D(D)> SetTableIndices( + D d, const TI* idx) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + return TableLookupBytes(v, Vec128{idx.raw}); +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3 + return {_mm_permutexvar_epi16(idx.raw, v.raw)}; +#elif HWY_TARGET == HWY_SSE2 +#if HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) + typedef uint16_t GccU16RawVectType __attribute__((__vector_size__(16))); + return Vec128{reinterpret_cast::type>( + __builtin_shuffle(reinterpret_cast(v.raw), + reinterpret_cast(idx.raw)))}; +#else + const Full128 d_full; + alignas(16) T src_lanes[8]; + alignas(16) uint16_t indices[8]; + alignas(16) T result_lanes[8]; + + Store(Vec128{v.raw}, d_full, src_lanes); + _mm_store_si128(reinterpret_cast<__m128i*>(indices), idx.raw); + + for (int i = 0; i < 8; i++) { + result_lanes[i] = src_lanes[indices[i] & 7u]; + } + + return Vec128{Load(d_full, result_lanes).raw}; +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) +#else + return TableLookupBytes(v, Vec128{idx.raw}); +#endif +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 idx) { + return {_mm_permutexvar_ph(idx.raw, v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + const DFromV d; + const Full128 d_full; + const Vec128 v_full = ZeroExtendResizeBitCast(d_full, d, v); + + const RebindToSigned di; + const Full128> di_full; + const VFromD vidx = + ZeroExtendResizeBitCast(di_full, di, VFromD{idx.raw}); + +#if HWY_TARGET <= HWY_AVX2 + // There is no permutevar for non-float; _mm256_permutevar8x32_epi32 is for + // 256-bit vectors, hence cast to float. 
+ const Full128 df_full; + // Workaround for MSAN false positive. + HWY_IF_CONSTEXPR(HWY_IS_MSAN) PreventElision(GetLane(vidx)); + const Vec128 perm{ + _mm_permutevar_ps(BitCast(df_full, v_full).raw, vidx.raw)}; + return ResizeBitCast(d, perm); +#elif HWY_TARGET == HWY_SSE2 +#if HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) + typedef uint32_t GccU32RawVectType __attribute__((__vector_size__(16))); + return Vec128{reinterpret_cast::type>( + __builtin_shuffle(reinterpret_cast(v_full.raw), + reinterpret_cast(vidx.raw)))}; +#else + alignas(16) T src_lanes[4]; + alignas(16) int32_t indices[4]; + alignas(16) T result_lanes[4]; + + Store(v_full, d_full, src_lanes); + Store(vidx, di_full, indices); + + for (size_t i = 0; i < N; i++) { + result_lanes[i] = src_lanes[static_cast(indices[i] & 3)]; + } + return Load(d, result_lanes); +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle) +#else // SSSE3 or SSE4 + return ResizeBitCast(d, TableLookupBytes(BitCast(di_full, v_full), vidx)); +#endif +} + +// Single lane: no change +template +HWY_API Vec128 TableLookupLanes(Vec128 v, + Indices128 /* idx */) { + return v; +} + +template +HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { + const DFromV d; + // No need for ZeroExtendResizeBitCast, we have full vectors. + Vec128 vidx{idx.raw}; + + // Disable in MSAN builds due to false positive. Note that this affects + // CompressNot, which assumes upper index bits will be ignored. +#if HWY_TARGET <= HWY_AVX2 && !HWY_IS_MSAN + // There is no _mm_permute[x]var_epi64. + vidx += vidx; // bit1 is the decider (unusual) + const RebindToFloat df; + return BitCast( + d, Vec128{_mm_permutevar_pd(BitCast(df, v).raw, vidx.raw)}); +#else + // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit + // comparison (expensive on SSSE3), just invert the upper lane and subtract 1 + // to obtain an all-zero or all-one mask. 
+ const RebindToSigned di; + const Vec128 same = (vidx ^ Iota(di, 0)) - Set(di, 1); + return BitCast( + d, IfVecThenElse(same, BitCast(di, v), Shuffle01(BitCast(di, v)))); +#endif +} + +// ------------------------------ ReverseBlocks + +// Single block: no change +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return v; +} + +// ------------------------------ Reverse (Shuffle0123, Shuffle2301) + +// Single lane: no change +template +HWY_API VFromD Reverse(D /* tag */, VFromD v) { + return v; +} + +// 32-bit x2: shuffle +template +HWY_API VFromD Reverse(D /* tag */, const VFromD v) { + return VFromD{Shuffle2301(Vec128>{v.raw}).raw}; +} + +// 64-bit x2: shuffle +template +HWY_API VFromD Reverse(D /* tag */, const VFromD v) { + return Shuffle01(v); +} + +// 32-bit x4: shuffle +template +HWY_API VFromD Reverse(D /* tag */, const VFromD v) { + return Shuffle0123(v); +} + +// 16-bit +template +HWY_API VFromD Reverse(D d, const VFromD v) { + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + constexpr size_t kN = MaxLanes(d); + if (kN == 1) return v; + if (kN == 2) { + return BitCast(d, VU{_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 0, 1))}); + } + if (kN == 4) { + return BitCast(d, VU{_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3))}); + } + +#if HWY_TARGET == HWY_SSE2 + const VU rev4{ + _mm_shufflehi_epi16(_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3)), + _MM_SHUFFLE(0, 1, 2, 3))}; + return BitCast(d, VU{_mm_shuffle_epi32(rev4.raw, _MM_SHUFFLE(1, 0, 3, 2))}); +#else + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + return BitCast(d, TableLookupBytes(v, shuffle)); +#endif +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + constexpr int kN = static_cast(MaxLanes(d)); + if (kN == 1) return v; +#if HWY_TARGET <= HWY_SSSE3 + // NOTE: Lanes with negative shuffle control mask values are set 
to zero. + alignas(16) static constexpr int8_t kReverse[16] = { + kN - 1, kN - 2, kN - 3, kN - 4, kN - 5, kN - 6, kN - 7, kN - 8, + kN - 9, kN - 10, kN - 11, kN - 12, kN - 13, kN - 14, kN - 15, kN - 16}; + const RebindToSigned di; + const VFromD idx = Load(di, kReverse); + return VFromD{_mm_shuffle_epi8(BitCast(di, v).raw, idx.raw)}; +#else + const RepartitionToWide d16; + return BitCast(d, Reverse(d16, RotateRight<8>(BitCast(d16, v)))); +#endif +} + +// ------------------------------ Reverse2 + +// Single lane: no change +template +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return v; +} + +// Generic for all vector lengths (128-bit sufficient if SSE2). +template +HWY_API VFromD Reverse2(D d, VFromD v) { +#if HWY_TARGET <= HWY_AVX3 + const Repartition du32; + return BitCast(d, RotateRight<16>(BitCast(du32, v))); +#elif HWY_TARGET == HWY_SSE2 + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + constexpr size_t kN = MaxLanes(d); + __m128i shuf_result = _mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(2, 3, 0, 1)); + if (kN > 4) { + shuf_result = _mm_shufflehi_epi16(shuf_result, _MM_SHUFFLE(2, 3, 0, 1)); + } + return BitCast(d, VU{shuf_result}); +#else + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0302, 0x0100, 0x0706, 0x0504, 0x0B0A, 0x0908, 0x0F0E, 0x0D0C); + return BitCast(d, TableLookupBytes(v, shuffle)); +#endif +} + +// Generic for all vector lengths. +template +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return Shuffle2301(v); +} + +// Generic for all vector lengths. +template +HWY_API VFromD Reverse2(D /* tag */, VFromD v) { + return Shuffle01(v); +} + +// ------------------------------ Reverse4 + +template +HWY_API VFromD Reverse4(D d, VFromD v) { + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + // 4x 16-bit: a single shufflelo suffices. 
+ constexpr size_t kN = MaxLanes(d); + if (kN <= 4) { + return BitCast(d, VU{_mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3))}); + } + +#if HWY_TARGET == HWY_SSE2 + return BitCast(d, VU{_mm_shufflehi_epi16( + _mm_shufflelo_epi16(vu.raw, _MM_SHUFFLE(0, 1, 2, 3)), + _MM_SHUFFLE(0, 1, 2, 3))}); +#else + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0706, 0x0504, 0x0302, 0x0100, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908); + return BitCast(d, TableLookupBytes(v, shuffle)); +#endif +} + +// Generic for all vector lengths. +template +HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { + return Shuffle0123(v); +} + +template +HWY_API VFromD Reverse4(D /* tag */, VFromD /* v */) { + HWY_ASSERT(0); // don't have 4 u64 lanes +} + +// ------------------------------ Reverse8 + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { +#if HWY_TARGET == HWY_SSE2 + const RepartitionToWide dw; + return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); +#else + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + return BitCast(d, TableLookupBytes(v, shuffle)); +#endif +} + +template +HWY_API VFromD Reverse8(D /* tag */, VFromD /* v */) { + HWY_ASSERT(0); // don't have 8 lanes if larger than 16-bit +} + +// ------------------------------ ReverseBits in x86_512 + +// ------------------------------ InterleaveUpper (UpperHalf) + +// Full +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_epi8(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm_unpackhi_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_epi32(a.raw, b.raw)}; +} +template +HWY_API 
VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_epi64(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_ps(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm_unpackhi_pd(a.raw, b.raw)}; +} + +// Partial +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const Half d2; + return InterleaveLower(d, VFromD{UpperHalf(d2, a).raw}, + VFromD{UpperHalf(d2, b).raw}); +} + +// -------------------------- I8/U8 Broadcast (InterleaveLower, InterleaveUpper) + +template +HWY_API Vec128 Broadcast(const Vec128 v) { + static_assert(0 <= kLane && kLane < N, "Invalid lane"); + const DFromV d; + +#if HWY_TARGET == HWY_SSE2 + const Full128 d_full; + const Vec128 v_full{v.raw}; + const auto v_interleaved = (kLane < 8) + ? InterleaveLower(d_full, v_full, v_full) + : InterleaveUpper(d_full, v_full, v_full); + return ResizeBitCast( + d, Broadcast(BitCast(Full128(), v_interleaved))); +#else + return TableLookupBytes(v, Set(d, static_cast(kLane))); +#endif +} + +// ------------------------------ ZipLower/ZipUpper (InterleaveLower) + +// Same as Interleave*, except that the return lanes are double-width integers; +// this is necessary because the single-lane scalar cannot return two values. +// Generic for all vector lengths. 
+template >> +HWY_API VFromD ZipLower(V a, V b) { + return BitCast(DW(), InterleaveLower(a, b)); +} +template , class DW = RepartitionToWide> +HWY_API VFromD ZipLower(DW dw, V a, V b) { + return BitCast(dw, InterleaveLower(D(), a, b)); +} + +template , class DW = RepartitionToWide> +HWY_API VFromD ZipUpper(DW dw, V a, V b) { + return BitCast(dw, InterleaveUpper(D(), a, b)); +} + +// ================================================== CONVERT (1) + +// ------------------------------ PromoteTo unsigned (TableLookupBytesOr0) +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + return VFromD{_mm_unpacklo_epi8(v.raw, zero)}; +#else + return VFromD{_mm_cvtepu8_epi16(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return VFromD{_mm_unpacklo_epi16(v.raw, _mm_setzero_si128())}; +#else + return VFromD{_mm_cvtepu16_epi32(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return VFromD{_mm_unpacklo_epi32(v.raw, _mm_setzero_si128())}; +#else + return VFromD{_mm_cvtepu32_epi64(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const __m128i zero = _mm_setzero_si128(); + const __m128i u16 = _mm_unpacklo_epi8(v.raw, zero); + return VFromD{_mm_unpacklo_epi16(u16, zero)}; +#else + return VFromD{_mm_cvtepu8_epi32(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { +#if HWY_TARGET > HWY_SSSE3 + const Rebind du32; + return PromoteTo(d, PromoteTo(du32, v)); +#elif HWY_TARGET == HWY_SSSE3 + alignas(16) static constexpr int8_t kShuffle[16] = { + 0, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1}; + const Repartition di8; + return TableLookupBytesOr0(v, BitCast(d, Load(di8, kShuffle))); +#else + (void)d; + return VFromD{_mm_cvtepu8_epi64(v.raw)}; +#endif +} +template +HWY_API VFromD 
PromoteTo(D d, VFromD> v) { +#if HWY_TARGET > HWY_SSSE3 + const Rebind du32; + return PromoteTo(d, PromoteTo(du32, v)); +#elif HWY_TARGET == HWY_SSSE3 + alignas(16) static constexpr int8_t kShuffle[16] = { + 0, 1, -1, -1, -1, -1, -1, -1, 2, 3, -1, -1, -1, -1, -1, -1}; + const Repartition di8; + return TableLookupBytesOr0(v, BitCast(d, Load(di8, kShuffle))); +#else + (void)d; + return VFromD{_mm_cvtepu16_epi64(v.raw)}; +#endif +} + +// Unsigned to signed: same plus cast. +template ), sizeof(TFromV)), + HWY_IF_LANES_D(D, HWY_MAX_LANES_V(V))> +HWY_API VFromD PromoteTo(D di, V v) { + const RebindToUnsigned du; + return BitCast(di, PromoteTo(du, v)); +} + +// ------------------------------ PromoteTo signed (ShiftRight, ZipLower) + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return ShiftRight<8>(VFromD{_mm_unpacklo_epi8(v.raw, v.raw)}); +#else + return VFromD{_mm_cvtepi8_epi16(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return ShiftRight<16>(VFromD{_mm_unpacklo_epi16(v.raw, v.raw)}); +#else + return VFromD{_mm_cvtepi16_epi32(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + return ShiftRight<32>(VFromD{_mm_unpacklo_epi32(v.raw, v.raw)}); +#else + return VFromD{_mm_cvtepi32_epi64(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const __m128i x2 = _mm_unpacklo_epi8(v.raw, v.raw); + const __m128i x4 = _mm_unpacklo_epi16(x2, x2); + return ShiftRight<24>(VFromD{x4}); +#else + return VFromD{_mm_cvtepi8_epi32(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const Repartition di32; + const Half dh_i32; + const VFromD x4{PromoteTo(dh_i32, v).raw}; + const VFromD s4{ + _mm_shufflelo_epi16(x4.raw, _MM_SHUFFLE(3, 3, 1, 1))}; + return ZipLower(d, x4, s4); +#else + (void)d; + return 
VFromD{_mm_cvtepi8_epi64(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteTo(D d, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const Repartition di32; + const Half dh_i32; + const VFromD x2{PromoteTo(dh_i32, v).raw}; + const VFromD s2{ + _mm_shufflelo_epi16(x2.raw, _MM_SHUFFLE(3, 3, 1, 1))}; + return ZipLower(d, x2, s2); +#else + (void)d; + return VFromD{_mm_cvtepi16_epi64(v.raw)}; +#endif +} + +// -------------------- PromoteTo float (ShiftLeft, IfNegativeThenElse) +#if HWY_TARGET < HWY_SSE4 && !defined(HWY_DISABLE_F16C) + +// Per-target flag to prevent generic_ops-inl.h from defining f16 conversions. +#ifdef HWY_NATIVE_F16C +#undef HWY_NATIVE_F16C +#else +#define HWY_NATIVE_F16C +#endif + +// Workaround for origin tracking bug in Clang msan prior to 11.0 +// (spurious "uninitialized memory" for TestF16 with "ORIGIN: invalid") +#if HWY_IS_MSAN && (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1100) +#define HWY_INLINE_F16 HWY_NOINLINE +#else +#define HWY_INLINE_F16 HWY_INLINE +#endif +template +HWY_INLINE_F16 VFromD PromoteTo(D /*tag*/, VFromD> v) { +#if HWY_HAVE_FLOAT16 + const RebindToUnsigned> du16; + return VFromD{_mm_cvtph_ps(BitCast(du16, v).raw)}; +#else + return VFromD{_mm_cvtph_ps(v.raw)}; +#endif +} + +#endif // HWY_NATIVE_F16C + +#if HWY_HAVE_FLOAT16 + +#ifdef HWY_NATIVE_PROMOTE_F16_TO_F64 +#undef HWY_NATIVE_PROMOTE_F16_TO_F64 +#else +#define HWY_NATIVE_PROMOTE_F16_TO_F64 +#endif + +template +HWY_INLINE VFromD PromoteTo(D /*tag*/, VFromD> v) { + return VFromD{_mm_cvtph_pd(v.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD PromoteTo(D df32, VFromD> v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtps_pd(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtepi32_pd(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +template 
+HWY_API VFromD PromoteTo(D /*df64*/, VFromD> v) { + return VFromD{_mm_cvtepu32_pd(v.raw)}; +} +#else +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD PromoteTo(D df64, VFromD> v) { + const Rebind di32; + const auto i32_to_f64_result = PromoteTo(df64, BitCast(di32, v)); + return i32_to_f64_result + IfNegativeThenElse(i32_to_f64_result, + Set(df64, 4294967296.0), + Zero(df64)); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Per4LaneBlockShuffle +namespace detail { + +#ifdef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#undef HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#else +#define HWY_NATIVE_PER4LANEBLKSHUF_DUP32 +#endif + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { + return ResizeBitCast( + d, Vec128{_mm_set_epi32( + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0))}); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, + hwy::SizeTag<8> /*vect_size_tag*/, V v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm_shufflelo_epi16( + BitCast(du, v).raw, static_cast(kIdx3210 & 0xFF))}); +} + +#if HWY_TARGET == HWY_SSE2 +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<2> /*lane_size_tag*/, + hwy::SizeTag<16> /*vect_size_tag*/, V v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + constexpr int kShuffle = static_cast(kIdx3210 & 0xFF); + return BitCast( + d, VFromD{_mm_shufflehi_epi16( + _mm_shufflelo_epi16(BitCast(du, v).raw, kShuffle), kShuffle)}); +} + +template * = nullptr> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag idx_3210_tag, + hwy::SizeTag<1> /*lane_size_tag*/, + hwy::SizeTag /*vect_size_tag*/, + V v) { + const DFromV d; + const RebindToUnsigned du; + const Rebind du16; + const RebindToSigned di16; + + const auto vu16 = 
PromoteTo(du16, BitCast(du, v)); + const auto shuf16_result = Per4LaneBlockShuffle( + idx_3210_tag, hwy::SizeTag<2>(), hwy::SizeTag(), vu16); + return BitCast(d, DemoteTo(du, BitCast(di16, shuf16_result))); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag idx_3210_tag, + hwy::SizeTag<1> /*lane_size_tag*/, + hwy::SizeTag<16> /*vect_size_tag*/, V v) { + const DFromV d; + const RebindToUnsigned du; + const Repartition du16; + const RebindToSigned di16; + + const auto zero = Zero(d); + const auto v_lo16 = BitCast(du16, InterleaveLower(d, v, zero)); + const auto v_hi16 = BitCast(du16, InterleaveUpper(d, v, zero)); + + const auto lo_shuf_result = Per4LaneBlockShuffle( + idx_3210_tag, hwy::SizeTag<2>(), hwy::SizeTag<16>(), v_lo16); + const auto hi_shuf_result = Per4LaneBlockShuffle( + idx_3210_tag, hwy::SizeTag<2>(), hwy::SizeTag<16>(), v_hi16); + + return BitCast(d, OrderedDemote2To(du, BitCast(di16, lo_shuf_result), + BitCast(di16, hi_shuf_result))); +} +#endif + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<16> /*vect_size_tag*/, V v) { + return V{_mm_shuffle_epi32(v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<16> /*vect_size_tag*/, V v) { + return V{_mm_shuffle_ps(v.raw, v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +namespace detail { + +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Full64 du64; + const auto vu64 = ResizeBitCast(du64, v); + return ResizeBitCast( + d, ShiftLeftSame(vu64, static_cast(amt * sizeof(TFromV) * 8))); +} + +#if HWY_TARGET <= HWY_SSSE3 +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Repartition du8; + const auto idx = + Iota(du8, static_cast(size_t{0} - amt * sizeof(TFromV))); + 
return BitCast(d, TableLookupBytesOr0(BitCast(du8, v), idx)); +} +#else +template +HWY_INLINE V SlideUpLanes(V v, size_t amt) { + const DFromV d; + const Repartition di32; + const Repartition du64; + constexpr size_t kNumOfLanesPerU64 = 8 / sizeof(TFromV); + + const auto vu64 = BitCast(du64, v); + const auto v_hi = IfVecThenElse( + BitCast(du64, Set(di32, -static_cast(amt >= kNumOfLanesPerU64))), + BitCast(du64, ShiftLeftBytes<8>(du64, vu64)), vu64); + const auto v_lo = ShiftLeftBytes<8>(du64, v_hi); + + const int shl_amt = static_cast((amt * sizeof(TFromV) * 8) & 63); + return BitCast( + d, Or(ShiftLeftSame(v_hi, shl_amt), ShiftRightSame(v_lo, 64 - shl_amt))); +} +#endif + +} // namespace detail + +template +HWY_API VFromD SlideUpLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return 
ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftLeftLanes<1>(d, v); + case 2: + return ShiftLeftLanes<2>(d, v); + case 3: + return ShiftLeftLanes<3>(d, v); + case 4: + return ShiftLeftLanes<4>(d, v); + case 5: + return ShiftLeftLanes<5>(d, v); + case 6: + return ShiftLeftLanes<6>(d, v); + case 7: + return ShiftLeftLanes<7>(d, v); + case 8: + return ShiftLeftLanes<8>(d, v); + case 9: + return ShiftLeftLanes<9>(d, v); + case 10: + return ShiftLeftLanes<10>(d, v); + case 11: + return ShiftLeftLanes<11>(d, v); + case 12: + return ShiftLeftLanes<12>(d, v); + case 13: + return ShiftLeftLanes<13>(d, v); + case 14: + return ShiftLeftLanes<14>(d, v); + case 15: + return ShiftLeftLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideUpLanes(v, amt); +} + +// ------------------------------ SlideDownLanes + +namespace detail { + +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition, decltype(d)> dv; + return BitCast(d, + ShiftRightSame(BitCast(dv, v), + static_cast(amt * sizeof(TFromV) * 8))); +} + +#if HWY_TARGET <= HWY_SSSE3 +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition di8; + auto idx = Iota(di8, static_cast(amt * sizeof(TFromV))); + idx = Or(idx, VecFromMask(di8, idx > Set(di8, int8_t{15}))); + return BitCast(d, TableLookupBytesOr0(BitCast(di8, v), idx)); +} +#else +template +HWY_INLINE V SlideDownLanes(V v, size_t amt) { + const DFromV d; + const Repartition di32; + const Repartition du64; + constexpr size_t kNumOfLanesPerU64 = 8 / sizeof(TFromV); + + const auto vu64 = BitCast(du64, v); + const auto v_lo = IfVecThenElse( + BitCast(du64, 
Set(di32, -static_cast(amt >= kNumOfLanesPerU64))), + BitCast(du64, ShiftRightBytes<8>(du64, vu64)), vu64); + const auto v_hi = ShiftRightBytes<8>(du64, v_lo); + + const int shr_amt = static_cast((amt * sizeof(TFromV) * 8) & 63); + return BitCast( + d, Or(ShiftRightSame(v_lo, shr_amt), ShiftLeftSame(v_hi, 64 - shr_amt))); +} +#endif + +} // namespace detail + +template +HWY_API VFromD SlideDownLanes(D /*d*/, VFromD v, size_t /*amt*/) { + return v; +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch 
(amt) { + case 0: + return v; + case 1: + return ShiftRightLanes<1>(d, v); + case 2: + return ShiftRightLanes<2>(d, v); + case 3: + return ShiftRightLanes<3>(d, v); + case 4: + return ShiftRightLanes<4>(d, v); + case 5: + return ShiftRightLanes<5>(d, v); + case 6: + return ShiftRightLanes<6>(d, v); + case 7: + return ShiftRightLanes<7>(d, v); + case 8: + return ShiftRightLanes<8>(d, v); + case 9: + return ShiftRightLanes<9>(d, v); + case 10: + return ShiftRightLanes<10>(d, v); + case 11: + return ShiftRightLanes<11>(d, v); + case 12: + return ShiftRightLanes<12>(d, v); + case 13: + return ShiftRightLanes<13>(d, v); + case 14: + return ShiftRightLanes<14>(d, v); + case 15: + return ShiftRightLanes<15>(d, v); + } + } +#else + (void)d; +#endif + + return detail::SlideDownLanes(v, amt); +} + +// ================================================== MEMORY (4) + +// ------------------------------ StoreN (ExtractLane) + +#if HWY_TARGET <= HWY_AVX2 + +#ifdef HWY_NATIVE_STORE_N +#undef HWY_NATIVE_STORE_N +#else +#define HWY_NATIVE_STORE_N +#endif + +template +HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, + size_t max_lanes_to_store) { + const size_t num_lanes_to_store = + HWY_MIN(max_lanes_to_store, HWY_MAX_LANES_D(D)); + +#if HWY_COMPILER_MSVC + // Work around MSVC compiler bug by using a HWY_FENCE before the BlendedStore + HWY_FENCE; +#endif + + BlendedStore(v, FirstN(d, num_lanes_to_store), d, p); + +#if HWY_COMPILER_MSVC + // Work around MSVC compiler bug by using a HWY_FENCE after the BlendedStore + HWY_FENCE; +#endif + + detail::MaybeUnpoison(p, num_lanes_to_store); +} + +#if HWY_TARGET > HWY_AVX3 +template +HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, + size_t max_lanes_to_store) { + if (max_lanes_to_store > 0) { + StoreU(v, d, p); + } +} + +template +HWY_API void StoreN(VFromD v, D /*d*/, TFromD* HWY_RESTRICT p, + size_t max_lanes_to_store) { + if (max_lanes_to_store >= 1) { + p[static_cast(max_lanes_to_store > 1)] = 
detail::ExtractLane<1>(v); + p[0] = GetLane(v); + } +} + +namespace detail { + +template +HWY_API void AVX2UIF8Or16StoreTrailingN(VFromD v_trailing, D /*d*/, + TFromD* HWY_RESTRICT p, + size_t num_lanes_to_store) { + // AVX2UIF8Or16StoreTrailingN should only be called for an I8/U8 vector if + // (num_lanes_to_store & 3) != 0 is true + const auto v_full128 = ResizeBitCast(Full128>(), v_trailing); + if ((num_lanes_to_store & 2) != 0) { + const uint16_t u16_bits = GetLane(BitCast(Full128(), v_full128)); + p[num_lanes_to_store - 1] = detail::ExtractLane<2>(v_full128); + CopyBytes(&u16_bits, + p + (num_lanes_to_store & ~size_t{3})); + } else { + p[num_lanes_to_store - 1] = GetLane(v_full128); + } +} + +template +HWY_API void AVX2UIF8Or16StoreTrailingN(VFromD v_trailing, D /*d*/, + TFromD* p, + size_t num_lanes_to_store) { + // AVX2UIF8Or16StoreTrailingN should only be called for an I16/U16/F16/BF16 + // vector if (num_lanes_to_store & 1) == 1 is true + p[num_lanes_to_store - 1] = GetLane(v_trailing); +} + +} // namespace detail + +template +HWY_API void StoreN(VFromD v, D d, TFromD* p, size_t max_lanes_to_store) { + const size_t num_lanes_to_store = + HWY_MIN(max_lanes_to_store, HWY_MAX_LANES_D(D)); + + const FixedTag, HWY_MAX(HWY_MAX_LANES_D(D), 16 / sizeof(TFromD))> + d_full; + const RebindToUnsigned du_full; + const Repartition di32_full; + + const auto i32_store_mask = BitCast( + di32_full, VecFromMask(du_full, FirstN(du_full, num_lanes_to_store))); + const auto vi32 = ResizeBitCast(di32_full, v); + +#if HWY_COMPILER_MSVC + // Work around MSVC compiler bug by using a HWY_FENCE before the BlendedStore + HWY_FENCE; +#endif + + BlendedStore(vi32, MaskFromVec(i32_store_mask), di32_full, + reinterpret_cast(p)); + + constexpr size_t kNumOfLanesPerI32 = 4 / sizeof(TFromD); + constexpr size_t kTrailingLenMask = kNumOfLanesPerI32 - 1; + const size_t trailing_n = (num_lanes_to_store & kTrailingLenMask); + + if (trailing_n != 0) { + const VFromD v_trailing = ResizeBitCast( + 
d, SlideDownLanes(di32_full, vi32, + num_lanes_to_store / kNumOfLanesPerI32)); + detail::AVX2UIF8Or16StoreTrailingN(v_trailing, d, p, num_lanes_to_store); + } + +#if HWY_COMPILER_MSVC + // Work around MSVC compiler bug by using a HWY_FENCE after the BlendedStore + HWY_FENCE; +#endif + + detail::MaybeUnpoison(p, num_lanes_to_store); +} +#endif // HWY_TARGET > HWY_AVX3 +#endif // HWY_TARGET <= HWY_AVX2 + +// ================================================== COMBINE + +// ------------------------------ Combine (InterleaveLower) + +// N = N/2 + N/2 (upper half undefined) +template >> +HWY_API VFromD Combine(D d, VH hi_half, VH lo_half) { + const Half dh; + const RebindToUnsigned duh; + // Treat half-width input as one lane, and expand to two lanes. + using VU = Vec128, 2>; + const VU lo{BitCast(duh, lo_half).raw}; + const VU hi{BitCast(duh, hi_half).raw}; + return BitCast(d, InterleaveLower(lo, hi)); +} + +// ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) + +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const RebindToUnsigned du; + const Half duh; + return BitCast(d, VFromD{_mm_move_epi64(BitCast(duh, lo).raw)}); +} + +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const Half dh; + return IfThenElseZero(FirstN(d, MaxLanes(dh)), VFromD{lo.raw}); +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const RebindToUnsigned du; + const Half duh; + return BitCast(d, ZeroExtendVector(du, BitCast(duh, lo))); +} +#endif + +// Generic for all vector lengths. 
+template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { + const RebindToUnsigned du; + const Half duh; + return BitCast(d, ZeroExtendVector(du, BitCast(duh, lo))); +} + +// ------------------------------ Concat full (InterleaveLower) + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const Repartition d64; + return BitCast(d, InterleaveLower(BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const Repartition d64; + return BitCast(d, InterleaveUpper(d64, BitCast(d64, lo), BitCast(d64, hi))); +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves) +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + return CombineShiftRightBytes<8>(d, hi, lo); +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const Repartition dd; +#if HWY_TARGET >= HWY_SSSE3 + return BitCast( + d, Vec128{_mm_shuffle_pd(BitCast(dd, lo).raw, BitCast(dd, hi).raw, + _MM_SHUFFLE2(1, 0))}); +#else + // _mm_blend_epi16 has throughput 1/cycle on SKX, whereas _pd can do 3/cycle. + return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, hi).raw, + BitCast(dd, lo).raw, 1)}); +#endif +} +template +HWY_API Vec128 ConcatUpperLower(D d, Vec128 hi, + Vec128 lo) { +#if HWY_TARGET >= HWY_SSSE3 + (void)d; + return Vec128{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 2, 1, 0))}; +#else + // _mm_shuffle_ps has throughput 1/cycle on SKX, whereas blend can do 3/cycle. 
+ const RepartitionToWide dd; + return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, hi).raw, + BitCast(dd, lo).raw, 1)}); +#endif +} +template +HWY_API Vec128 ConcatUpperLower(D /* tag */, Vec128 hi, + Vec128 lo) { +#if HWY_TARGET >= HWY_SSSE3 + return Vec128{_mm_shuffle_pd(lo.raw, hi.raw, _MM_SHUFFLE2(1, 0))}; +#else + // _mm_shuffle_pd has throughput 1/cycle on SKX, whereas blend can do 3/cycle. + return Vec128{_mm_blend_pd(hi.raw, lo.raw, 1)}; +#endif +} + +// ------------------------------ Concat partial (Combine, LowerHalf) + +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatLowerUpper(D d, const VFromD hi, + const VFromD lo) { + const Half d2; + return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo)); +} + +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + const Half d2; + return Combine(d, UpperHalf(d2, hi), LowerHalf(d2, lo)); +} + +// ------------------------------ ConcatOdd + +// 8-bit full +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const Repartition dw; + // Right-shift 8 bits per u16 so we can pack. + const Vec128 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<8>(BitCast(dw, lo)); + return VFromD{_mm_packus_epi16(uL.raw, uH.raw)}; +} + +// 8-bit x8 +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET == HWY_SSE2 + const Repartition dw; + // Right-shift 8 bits per u16 so we can pack. + const Vec64 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec64 uL = ShiftRight<8>(BitCast(dw, lo)); + return VFromD{_mm_shuffle_epi32(_mm_packus_epi16(uL.raw, uH.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +#else + const Repartition du32; + // Don't care about upper half, no need to zero. 
+ alignas(16) const uint8_t kCompactOddU8[8] = {1, 3, 5, 7}; + const VFromD shuf = BitCast(d, Load(Full64(), kCompactOddU8)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +#endif +} + +// 8-bit x4 +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET == HWY_SSE2 + const Repartition dw; + const Twice dw_2; + // Right-shift 8 bits per u16 so we can pack. + const Vec32 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec32 uL = ShiftRight<8>(BitCast(dw, lo)); + const Vec64 uHL = Combine(dw_2, uH, uL); + return VFromD{_mm_packus_epi16(uHL.raw, uHL.raw)}; +#else + const Repartition du16; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU8[4] = {1, 3}; + const VFromD shuf = BitCast(d, Load(Full32(), kCompactOddU8)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du16, BitCast(du16, L), BitCast(du16, H))); +#endif +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + // Right-shift 16 bits per i32 - a *signed* shift of 0x8000xxxx returns + // 0xFFFF8000, which correctly saturates to 0x8000. + const RebindToUnsigned du; + const Repartition dw; + const Vec128 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec128 uL = ShiftRight<16>(BitCast(dw, lo)); + return BitCast(d, VFromD{_mm_packs_epi32(uL.raw, uH.raw)}); +} + +// 16-bit x4 +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET == HWY_SSE2 + // Right-shift 16 bits per i32 - a *signed* shift of 0x8000xxxx returns + // 0xFFFF8000, which correctly saturates to 0x8000. 
+ const Repartition dw; + const Vec64 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec64 uL = ShiftRight<16>(BitCast(dw, lo)); + return VFromD{_mm_shuffle_epi32(_mm_packs_epi32(uL.raw, uH.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +#else + const Repartition du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactOddU16[8] = {2, 3, 6, 7}; + const VFromD shuf = BitCast(d, Load(Full64(), kCompactOddU16)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +#endif +} + +// 32-bit full +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToFloat df; + return BitCast( + d, Vec128{_mm_shuffle_ps(BitCast(df, lo).raw, BitCast(df, hi).raw, + _MM_SHUFFLE(3, 1, 3, 1))}); +} + +// Any type x2 +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + return InterleaveUpper(d, lo, hi); +} + +// ------------------------------ ConcatEven (InterleaveLower) + +// 8-bit full +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const Repartition dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec128 mask = Set(dw, 0x00FF); + const Vec128 uH = And(BitCast(dw, hi), mask); + const Vec128 uL = And(BitCast(dw, lo), mask); + return VFromD{_mm_packus_epi16(uL.raw, uH.raw)}; +} + +// 8-bit x8 +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET == HWY_SSE2 + const Repartition dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec64 mask = Set(dw, 0x00FF); + const Vec64 uH = And(BitCast(dw, hi), mask); + const Vec64 uL = And(BitCast(dw, lo), mask); + return VFromD{_mm_shuffle_epi32(_mm_packus_epi16(uL.raw, uH.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +#else + const Repartition du32; + // Don't care about upper half, no need to zero. 
+ alignas(16) const uint8_t kCompactEvenU8[8] = {0, 2, 4, 6}; + const VFromD shuf = BitCast(d, Load(Full64(), kCompactEvenU8)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +#endif +} + +// 8-bit x4 +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET == HWY_SSE2 + const Repartition dw; + const Twice dw_2; + // Isolate lower 8 bits per u16 so we can pack. + const Vec32 mask = Set(dw, 0x00FF); + const Vec32 uH = And(BitCast(dw, hi), mask); + const Vec32 uL = And(BitCast(dw, lo), mask); + const Vec64 uHL = Combine(dw_2, uH, uL); + return VFromD{_mm_packus_epi16(uHL.raw, uHL.raw)}; +#else + const Repartition du16; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU8[4] = {0, 2}; + const VFromD shuf = BitCast(d, Load(Full32(), kCompactEvenU8)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du16, BitCast(du16, L), BitCast(du16, H))); +#endif +} + +// 16-bit full +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET <= HWY_SSE4 + // Isolate lower 16 bits per u32 so we can pack. + const RebindToUnsigned du; // for float16_t + const Repartition dw; + const Vec128 mask = Set(dw, 0x0000FFFF); + const Vec128 uH = And(BitCast(dw, hi), mask); + const Vec128 uL = And(BitCast(dw, lo), mask); + return BitCast(d, VFromD{_mm_packus_epi32(uL.raw, uH.raw)}); +#elif HWY_TARGET == HWY_SSE2 + const Repartition dw; + return ConcatOdd(d, BitCast(d, ShiftLeft<16>(BitCast(dw, hi))), + BitCast(d, ShiftLeft<16>(BitCast(dw, lo)))); +#else + const RebindToUnsigned du; + // packs_epi32 saturates 0x8000 to 0x7FFF. Instead ConcatEven within the two + // inputs, then concatenate them. 
+ alignas(16) + const uint16_t kCompactEvenU16[8] = {0x0100, 0x0504, 0x0908, 0x0D0C}; + const VFromD shuf = BitCast(d, Load(du, kCompactEvenU16)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return ConcatLowerLower(d, H, L); +#endif +} + +// 16-bit x4 +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { +#if HWY_TARGET == HWY_SSE2 + const Repartition dw; + return ConcatOdd(d, BitCast(d, ShiftLeft<16>(BitCast(dw, hi))), + BitCast(d, ShiftLeft<16>(BitCast(dw, lo)))); +#else + const Repartition du32; + // Don't care about upper half, no need to zero. + alignas(16) const uint8_t kCompactEvenU16[8] = {0, 1, 4, 5}; + const VFromD shuf = BitCast(d, Load(Full64(), kCompactEvenU16)); + const VFromD L = TableLookupBytes(lo, shuf); + const VFromD H = TableLookupBytes(hi, shuf); + return BitCast(d, InterleaveLower(du32, BitCast(du32, L), BitCast(du32, H))); +#endif +} + +// 32-bit full +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToFloat df; + return BitCast( + d, Vec128{_mm_shuffle_ps(BitCast(df, lo).raw, BitCast(df, hi).raw, + _MM_SHUFFLE(2, 0, 2, 0))}); +} +template +HWY_API VFromD ConcatEven(D /* d */, VFromD hi, VFromD lo) { + return VFromD{_mm_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; +} + +// Any T x2 +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + return InterleaveLower(d, lo, hi); +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return v; +} + +template +HWY_API Vec128 DupEven(const Vec128 v) { + return InterleaveLower(DFromV(), v, v); +} + +template +HWY_API V DupEven(V v) { + const DFromV d; + +#if HWY_TARGET <= HWY_SSSE3 + const RebindToUnsigned du; + const VFromD shuffle = Dup128VecFromValues( + du, 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14); + return TableLookupBytes(v, BitCast(d, shuffle)); +#else + const Repartition du16; + return 
IfVecThenElse(BitCast(d, Set(du16, uint16_t{0xFF00})), + BitCast(d, ShiftLeft<8>(BitCast(du16, v))), v); +#endif +} + +template +HWY_API Vec64 DupEven(const Vec64 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm_shufflelo_epi16( + BitCast(du, v).raw, _MM_SHUFFLE(2, 2, 0, 0))}); +} + +// Generic for all vector lengths. +template +HWY_API V DupEven(const V v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t +#if HWY_TARGET <= HWY_SSSE3 + const VFromD shuffle = Dup128VecFromValues( + du, 0x0100, 0x0100, 0x0504, 0x0504, 0x0908, 0x0908, 0x0d0c, 0x0d0c); + return TableLookupBytes(v, BitCast(d, shuffle)); +#else + return BitCast( + d, VFromD{_mm_shufflehi_epi16( + _mm_shufflelo_epi16(BitCast(du, v).raw, _MM_SHUFFLE(2, 2, 0, 0)), + _MM_SHUFFLE(2, 2, 0, 0))}); +#endif +} + +template +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +HWY_API Vec128 DupEven(Vec128 v) { + return Vec128{_mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return v; +} + +template +HWY_API V DupOdd(V v) { + const DFromV d; + +#if HWY_TARGET <= HWY_SSSE3 + const RebindToUnsigned du; + const VFromD shuffle = Dup128VecFromValues( + du, 1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15); + return TableLookupBytes(v, BitCast(d, shuffle)); +#else + const Repartition du16; + return IfVecThenElse(BitCast(d, Set(du16, uint16_t{0x00FF})), + BitCast(d, ShiftRight<8>(BitCast(du16, v))), v); +#endif +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm_shufflelo_epi16( + BitCast(du, v).raw, _MM_SHUFFLE(3, 3, 1, 1))}); +} + +// Generic for all vector lengths. 
+template +HWY_API V DupOdd(V v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t +#if HWY_TARGET <= HWY_SSSE3 + const VFromD shuffle = Dup128VecFromValues( + du, 0x0302, 0x0302, 0x0706, 0x0706, 0x0b0a, 0x0b0a, 0x0f0e, 0x0f0e); + return TableLookupBytes(v, BitCast(d, shuffle)); +#else + return BitCast( + d, VFromD{_mm_shufflehi_epi16( + _mm_shufflelo_epi16(BitCast(du, v).raw, _MM_SHUFFLE(3, 3, 1, 1)), + _MM_SHUFFLE(3, 3, 1, 1))}); +#endif +} + +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{_mm_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +template +HWY_API Vec128 DupOdd(Vec128 v) { + return Vec128{ + _mm_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} + +template +HWY_API Vec128 DupOdd(const Vec128 v) { + return InterleaveUpper(DFromV(), v, v); +} + +// ------------------------------ TwoTablesLookupLanes (DupEven) + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { + const DFromV d; + const Twice dt; +// TableLookupLanes currently requires table and index vectors to be the same +// size, though a half-length index vector would be sufficient here. +#if HWY_IS_MSAN + const Vec128 idx_vec{idx.raw}; + const Indices128 idx2{Combine(dt, idx_vec, idx_vec).raw}; +#else + // We only keep LowerHalf of the result, which is valid in idx. 
+ const Indices128 idx2{idx.raw}; +#endif + return LowerHalf(d, TableLookupLanes(Combine(dt, b, a), idx2)); +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec128{_mm_permutex2var_epi8(a.raw, idx.raw, b.raw)}; +#else // AVX3 or below + const DFromV d; + const Vec128 idx_vec{idx.raw}; + +#if HWY_TARGET <= HWY_SSE4 + const Repartition du16; + const auto sel_hi_mask = + MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, idx_vec)))); +#else + const RebindToSigned di; + const auto sel_hi_mask = + RebindMask(d, BitCast(di, idx_vec) > Set(di, int8_t{15})); +#endif + + const auto lo_lookup_result = TableLookupBytes(a, idx_vec); +#if HWY_TARGET <= HWY_AVX3 + const Vec128 lookup_result{_mm_mask_shuffle_epi8( + lo_lookup_result.raw, sel_hi_mask.raw, b.raw, idx_vec.raw)}; + return lookup_result; +#else + const auto hi_lookup_result = TableLookupBytes(b, idx_vec); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#endif // HWY_TARGET <= HWY_AVX3 +#endif // HWY_TARGET <= HWY_AVX3_DL +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_permutex2var_epi16(a.raw, idx.raw, b.raw)}; +#elif HWY_TARGET == HWY_SSE2 + const DFromV d; + const RebindToSigned di; + const Vec128 idx_vec{idx.raw}; + const auto sel_hi_mask = + RebindMask(d, BitCast(di, idx_vec) > Set(di, int16_t{7})); + const auto lo_lookup_result = TableLookupLanes(a, idx); + const auto hi_lookup_result = TableLookupLanes(b, idx); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#else + const DFromV d; + const Repartition du8; + return BitCast(d, TwoTablesLookupLanes(BitCast(du8, a), BitCast(du8, b), + Indices128{idx.raw})); +#endif +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_permutex2var_epi32(a.raw, idx.raw, b.raw)}; 
+#else // AVX2 or below + const DFromV d; + +#if HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SSE2 + const Vec128 idx_vec{idx.raw}; + +#if HWY_TARGET <= HWY_AVX2 + const RebindToFloat d_sel; + const auto sel_hi_mask = MaskFromVec(BitCast(d_sel, ShiftLeft<29>(idx_vec))); +#else + const RebindToSigned d_sel; + const auto sel_hi_mask = BitCast(d_sel, idx_vec) > Set(d_sel, int32_t{3}); +#endif + + const auto lo_lookup_result = BitCast(d_sel, TableLookupLanes(a, idx)); + const auto hi_lookup_result = BitCast(d_sel, TableLookupLanes(b, idx)); + return BitCast(d, + IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result)); +#else // SSSE3 or SSE4 + const Repartition du8; + return BitCast(d, TwoTablesLookupLanes(BitCast(du8, a), BitCast(du8, b), + Indices128{idx.raw})); +#endif // HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SSE2 +#endif // HWY_TARGET <= HWY_AVX3 +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, + Vec128 b, + Indices128 idx) { + return Vec128{_mm_permutex2var_ph(a.raw, idx.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_permutex2var_ps(a.raw, idx.raw, b.raw)}; +#elif HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SSE2 + const DFromV d; + +#if HWY_TARGET <= HWY_AVX2 + const auto sel_hi_mask = + MaskFromVec(BitCast(d, ShiftLeft<29>(Vec128{idx.raw}))); +#else + const RebindToSigned di; + const auto sel_hi_mask = + RebindMask(d, Vec128{idx.raw} > Set(di, int32_t{3})); +#endif + + const auto lo_lookup_result = TableLookupLanes(a, idx); + const auto hi_lookup_result = TableLookupLanes(b, idx); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#else // SSSE3 or SSE4 + const DFromV d; + const Repartition du8; + return BitCast(d, TwoTablesLookupLanes(BitCast(du8, a), BitCast(du8, b), + Indices128{idx.raw})); +#endif +} + +template +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 
idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_permutex2var_epi64(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const Vec128 idx_vec{idx.raw}; + const Indices128 idx_mod{And(idx_vec, Set(d, T{1})).raw}; + +#if HWY_TARGET <= HWY_SSE4 + const RebindToFloat d_sel; + const auto sel_hi_mask = MaskFromVec(BitCast(d_sel, ShiftLeft<62>(idx_vec))); +#else // SSE2 or SSSE3 + const Repartition di32; + const RebindToSigned d_sel; + const auto sel_hi_mask = MaskFromVec( + BitCast(d_sel, VecFromMask(di32, DupEven(BitCast(di32, idx_vec)) > + Set(di32, int32_t{1})))); +#endif // HWY_TARGET <= HWY_SSE4 + + const auto lo_lookup_result = BitCast(d_sel, TableLookupLanes(a, idx_mod)); + const auto hi_lookup_result = BitCast(d_sel, TableLookupLanes(b, idx_mod)); + return BitCast(d, + IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result)); +#endif // HWY_TARGET <= HWY_AVX3 +} + +HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, + Indices128 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_permutex2var_pd(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const RebindToSigned di; + const Vec128 idx_vec{idx.raw}; + const Indices128 idx_mod{And(idx_vec, Set(di, int64_t{1})).raw}; + +#if HWY_TARGET <= HWY_SSE4 + const auto sel_hi_mask = MaskFromVec(BitCast(d, ShiftLeft<62>(idx_vec))); +#else // SSE2 or SSSE3 + const Repartition di32; + const auto sel_hi_mask = + MaskFromVec(BitCast(d, VecFromMask(di32, DupEven(BitCast(di32, idx_vec)) > + Set(di32, int32_t{1})))); +#endif // HWY_TARGET <= HWY_SSE4 + + const auto lo_lookup_result = TableLookupLanes(a, idx_mod); + const auto hi_lookup_result = TableLookupLanes(b, idx_mod); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#endif // HWY_TARGET <= HWY_AVX3 +} + +// ------------------------------ OddEven (IfThenElse) + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + const DFromV d; + const Repartition d8; + alignas(16) static constexpr uint8_t mask[16] = { + 0xFF, 0, 
0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + const DFromV d; +#if HWY_TARGET >= HWY_SSSE3 + const Repartition d8; + alignas(16) static constexpr uint8_t mask[16] = { + 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0, 0xFF, 0xFF, 0, 0}; + return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); +#else + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm_blend_epi16( + BitCast(du, a).raw, BitCast(du, b).raw, 0x55)}); +#endif +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + const __m128i odd = _mm_shuffle_epi32(a.raw, _MM_SHUFFLE(3, 1, 3, 1)); + const __m128i even = _mm_shuffle_epi32(b.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_epi32(even, odd)}; +#else + // _mm_blend_epi16 has throughput 1/cycle on SKX, whereas _ps can do 3/cycle. + const DFromV d; + const RebindToFloat df; + return BitCast(d, Vec128{_mm_blend_ps(BitCast(df, a).raw, + BitCast(df, b).raw, 5)}); +#endif +} + +template +HWY_INLINE Vec128 OddEven(const Vec128 a, const Vec128 b) { + // Same as ConcatUpperLower for full vectors; do not call that because this + // is more efficient for 64x1 vectors. + const DFromV d; + const RebindToFloat dd; +#if HWY_TARGET >= HWY_SSSE3 + return BitCast( + d, Vec128{_mm_shuffle_pd( + BitCast(dd, b).raw, BitCast(dd, a).raw, _MM_SHUFFLE2(1, 0))}); +#else + // _mm_shuffle_pd has throughput 1/cycle on SKX, whereas blend can do 3/cycle. + return BitCast(d, Vec128{_mm_blend_pd(BitCast(dd, a).raw, + BitCast(dd, b).raw, 1)}); +#endif +} + +template +HWY_API Vec128 OddEven(Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + // SHUFPS must fill the lower half of the output from one input, so we + // need another shuffle. Unpack avoids another immediate byte. 
+ const __m128 odd = _mm_shuffle_ps(a.raw, a.raw, _MM_SHUFFLE(3, 1, 3, 1)); + const __m128 even = _mm_shuffle_ps(b.raw, b.raw, _MM_SHUFFLE(2, 0, 2, 0)); + return Vec128{_mm_unpacklo_ps(even, odd)}; +#else + return Vec128{_mm_blend_ps(a.raw, b.raw, 5)}; +#endif +} + +// -------------------------- InterleaveEven + +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + return ConcatEven(d, b, a); +} + +// I8/U8 InterleaveEven is generic for all vector lengths that are >= 4 bytes +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Repartition du16; + return OddEven(BitCast(d, ShiftLeft<8>(BitCast(du16, b))), a); +} + +// I16/U16 InterleaveEven is generic for all vector lengths that are >= 8 bytes +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const Repartition du32; + return OddEven(BitCast(d, ShiftLeft<16>(BitCast(du32, b))), a); +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_mask_shuffle_epi32( + a.raw, static_cast<__mmask8>(0x0A), b.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(2, 2, 0, 0)))}; +} +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_mask_shuffle_ps(a.raw, static_cast<__mmask8>(0x0A), + b.raw, b.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +#else +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const RebindToFloat df; + const auto b2_b0_a2_a0 = ConcatEven(df, BitCast(df, b), BitCast(df, a)); + return BitCast( + d, VFromD{_mm_shuffle_ps(b2_b0_a2_a0.raw, b2_b0_a2_a0.raw, + _MM_SHUFFLE(3, 1, 2, 0))}); +} +#endif + +// -------------------------- InterleaveOdd + +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return ConcatOdd(d, b, a); +} + +// I8/U8 InterleaveOdd is generic for all vector lengths that are >= 4 bytes +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Repartition du16; + return OddEven(b, BitCast(d, 
ShiftRight<8>(BitCast(du16, a)))); +} + +// I16/U16 InterleaveOdd is generic for all vector lengths that are >= 8 bytes +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const Repartition du32; + return OddEven(b, BitCast(d, ShiftRight<16>(BitCast(du32, a)))); +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_mask_shuffle_epi32( + b.raw, static_cast<__mmask8>(0x05), a.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(3, 3, 1, 1)))}; +} +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm_mask_shuffle_ps(b.raw, static_cast<__mmask8>(0x05), + a.raw, a.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +#else +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const RebindToFloat df; + const auto b3_b1_a3_a1 = ConcatOdd(df, BitCast(df, b), BitCast(df, a)); + return BitCast( + d, VFromD{_mm_shuffle_ps(b3_b1_a3_a1.raw, b3_b1_a3_a1.raw, + _MM_SHUFFLE(3, 1, 2, 0))}); +} +#endif + +// ------------------------------ OddEvenBlocks +template +HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { + return even; +} + +// ------------------------------ SwapAdjacentBlocks +template +HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { + return v; +} + +// ------------------------------ InterleaveEvenBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveOddBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ InterleaveLowerBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) { + return a; +} +// ------------------------------ InterleaveUpperBlocks +template , HWY_IF_V_SIZE_LE_D(D, 16)> +HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) { + return a; +} + +// ------------------------------ Shl (ZipLower, Mul) + +// Use 
AVX2/3 variable shifts where available, otherwise multiply by powers of +// two from loading float exponents, which is considerably faster (according +// to LLVM-MCA) than scalar or testing bits: https://gcc.godbolt.org/z/9G7Y9v. + +namespace detail { + +#if HWY_TARGET == HWY_AVX2 // Unused for AVX3 - we use sllv directly +template +HWY_API V AVX2ShlU16Vec128(V v, V bits) { + const DFromV d; + const Rebind du32; + return TruncateTo(d, PromoteTo(du32, v) << PromoteTo(du32, bits)); +} +#elif HWY_TARGET > HWY_AVX2 + +template +static HWY_INLINE VFromD Pow2ConvF32ToI32( + D32 d32, VFromD> vf32) { + const RebindToSigned di32; +#if HWY_COMPILER_GCC_ACTUAL + // ConvertInRangeTo is safe with GCC due the inline assembly workaround used + // for F32->I32 ConvertInRangeTo with GCC + return BitCast(d32, ConvertInRangeTo(di32, vf32)); +#else + // Otherwise, use NearestIntInRange because we rely on the native 0x80..00 + // overflow behavior + return BitCast(d32, NearestIntInRange(di32, vf32)); +#endif +} + +// Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. +template +HWY_INLINE Vec128> Pow2(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWide dw; + const Rebind df; + const auto zero = Zero(d); + // Move into exponent (this u16 will become the upper half of an f32) + const auto exp = ShiftLeft<23 - 16>(v); + const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f + // Insert 0 into lower halves for reinterpreting as binary32. + const auto f0 = ZipLower(dw, zero, upper); + const auto f1 = ZipUpper(dw, zero, upper); + // See cvtps comment below. 
+ const VFromD bits0 = Pow2ConvF32ToI32(dw, BitCast(df, f0)); + const VFromD bits1 = Pow2ConvF32ToI32(dw, BitCast(df, f1)); +#if HWY_TARGET <= HWY_SSE4 + return VFromD{_mm_packus_epi32(bits0.raw, bits1.raw)}; +#else + return ConcatEven(du, BitCast(du, bits1), BitCast(du, bits0)); +#endif +} + +template +HWY_INLINE Vec128, N> Pow2(const Vec128 v) { + const DFromV d; + const RebindToUnsigned du; + const Twice dt_u; + const RepartitionToWide dt_w; + const RebindToFloat dt_f; + // Move into exponent (this u16 will become the upper half of an f32) + const auto exp = ShiftLeft<23 - 16>(v); + const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f + // Insert 0 into lower halves for reinterpreting as binary32. + const auto f0 = ZipLower(dt_w, Zero(dt_u), ResizeBitCast(dt_u, upper)); + // See cvtps comment below. + const VFromD bits0 = + Pow2ConvF32ToI32(dt_w, BitCast(dt_f, f0)); +#if HWY_TARGET <= HWY_SSE4 + return VFromD{_mm_packus_epi32(bits0.raw, bits0.raw)}; +#elif HWY_TARGET == HWY_SSSE3 + alignas(16) + const uint16_t kCompactEvenU16[8] = {0x0100, 0x0504, 0x0908, 0x0D0C}; + return TableLookupBytes(bits0, Load(du, kCompactEvenU16)); +#else + const RebindToSigned dt_i32; + const auto bits0_i32 = ShiftRight<16>(BitCast(dt_i32, ShiftLeft<16>(bits0))); + return VFromD{_mm_packs_epi32(bits0_i32.raw, bits0_i32.raw)}; +#endif +} + +// Same, for 32-bit shifts. +template +HWY_INLINE Vec128, N> Pow2(const Vec128 v) { + const DFromV d; + const RebindToFloat df; + const auto exp = ShiftLeft<23>(v); + const auto f = exp + Set(d, 0x3F800000); // 1.0f + // Do not use ConvertTo because we rely on the native 0x80..00 overflow + // behavior. 
+ return Pow2ConvF32ToI32(d, BitCast(df, f)); +} + +#endif // HWY_TARGET > HWY_AVX2 + +template +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_sllv_epi16(v.raw, bits.raw)}; +#elif HWY_TARGET == HWY_AVX2 + return AVX2ShlU16Vec128(v, bits); +#else + return v * Pow2(bits); +#endif +} + +#if HWY_TARGET > HWY_AVX3 +HWY_API Vec16 Shl(hwy::UnsignedTag /*tag*/, Vec16 v, + Vec16 bits) { +#if HWY_TARGET <= HWY_SSE4 + const Vec16 bits16{_mm_cvtepu16_epi64(bits.raw)}; +#else + const auto bits16 = And(bits, Vec16{_mm_set_epi64x(0, 0xFFFF)}); +#endif + return Vec16{_mm_sll_epi16(v.raw, bits16.raw)}; +} +#endif + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_INLINE V AVX2ShlU8Vec128(V v, V bits) { + const DFromV d; + const Rebind du16; + return TruncateTo(d, PromoteTo(du16, v) << PromoteTo(du16, bits)); +} +#elif HWY_TARGET <= HWY_AVX2 +template +HWY_INLINE V AVX2ShlU8Vec128(V v, V bits) { + const DFromV d; + const Rebind du32; + return TruncateTo(d, PromoteTo(du32, v) << PromoteTo(du32, bits)); +} +template +HWY_INLINE V AVX2ShlU8Vec128(V v, V bits) { + const DFromV d; + const Half dh; + const Rebind du16; + const Rebind dh_u32; + + const VFromD lo_shl_result = + PromoteTo(dh_u32, LowerHalf(dh, v)) + << PromoteTo(dh_u32, LowerHalf(dh, bits)); + const VFromD hi_shl_result = + PromoteTo(dh_u32, UpperHalf(dh, v)) + << PromoteTo(dh_u32, UpperHalf(dh, bits)); + const VFromD u16_shl_result = ConcatEven( + du16, BitCast(du16, hi_shl_result), BitCast(du16, lo_shl_result)); + return TruncateTo(d, u16_shl_result); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// 8-bit: may use the Shl overload for uint16_t. 
+template +HWY_API Vec128 Shl(hwy::UnsignedTag tag, Vec128 v, + Vec128 bits) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3_DL + (void)tag; + // kMask[i] = 0xFF >> i + alignas(16) static constexpr uint8_t kMasks[16] = { + 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01, 0x00}; + // kShl[i] = 1 << i + alignas(16) static constexpr uint8_t kShl[16] = {1, 2, 4, 8, 0x10, + 0x20, 0x40, 0x80, 0x00}; + v = And(v, TableLookupBytes(Load(Full64(), kMasks), bits)); + const VFromD mul = + TableLookupBytes(Load(Full64(), kShl), bits); + return VFromD{_mm_gf2p8mul_epi8(v.raw, mul.raw)}; +#elif HWY_TARGET <= HWY_AVX2 + (void)tag; + (void)d; + return AVX2ShlU8Vec128(v, bits); +#else + const Repartition dw; + using VW = VFromD; + const VW even_mask = Set(dw, 0x00FF); + const VW odd_mask = Set(dw, 0xFF00); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + // Shift even lanes in-place + const VW evens = Shl(tag, vw, And(bits16, even_mask)); + const VW odds = Shl(tag, And(vw, odd_mask), ShiftRight<8>(bits16)); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +#endif +} +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_SSE4 + const Vec16 bits8{_mm_cvtepu8_epi64(bits.raw)}; +#else + const Vec16 bits8 = + And(Vec16{bits.raw}, Vec16{_mm_set_epi64x(0, 0xFF)}); +#endif + return Vec128{_mm_sll_epi16(v.raw, bits8.raw)}; +} + +template +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { +#if HWY_TARGET >= HWY_SSE4 + return v * Pow2(bits); +#else + return Vec128{_mm_sllv_epi32(v.raw, bits.raw)}; +#endif +} + +#if HWY_TARGET >= HWY_SSE4 +HWY_API Vec32 Shl(hwy::UnsignedTag /*tag*/, Vec32 v, + const Vec32 bits) { +#if HWY_TARGET == HWY_SSE4 + const Vec32 bits32{_mm_cvtepu32_epi64(bits.raw)}; +#else + const auto bits32 = + Combine(Full64(), Zero(Full32()), bits); +#endif + return Vec32{_mm_sll_epi32(v.raw, bits32.raw)}; +} +#endif + +HWY_API Vec128 Shl(hwy::UnsignedTag /*tag*/, Vec128 v, + Vec128 bits) { 
+#if HWY_TARGET >= HWY_SSE4 + const DFromV d; + // Individual shifts and combine + const Vec128 out0{_mm_sll_epi64(v.raw, bits.raw)}; + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const Vec128 out1{_mm_sll_epi64(v.raw, bits1)}; + return ConcatUpperLower(d, out1, out0); +#else + return Vec128{_mm_sllv_epi64(v.raw, bits.raw)}; +#endif +} +HWY_API Vec64 Shl(hwy::UnsignedTag /*tag*/, Vec64 v, + Vec64 bits) { + return Vec64{_mm_sll_epi64(v.raw, bits.raw)}; +} + +// Signed left shift is the same as unsigned. +template +HWY_API Vec128 Shl(hwy::SignedTag /*tag*/, Vec128 v, + Vec128 bits) { + const DFromV di; + const RebindToUnsigned du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template +HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { + return detail::Shl(hwy::TypeTag(), v, bits); +} + +// ------------------------------ Shr (mul, mask, BroadcastSignBit) + +// Use AVX2+ variable shifts except for SSSE3/SSE4. There, we use +// widening multiplication by powers of two obtained by loading float exponents, +// followed by a constant right-shift. This is still faster than a scalar or +// bit-test approach: https://gcc.godbolt.org/z/9G7Y9v. 
+ +#if HWY_TARGET <= HWY_AVX2 +namespace detail { + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_INLINE V AVX2ShrU8Vec128(V v, V bits) { + const DFromV d; + const Rebind du16; + const RebindToSigned di16; + return DemoteTo(d, + BitCast(di16, PromoteTo(du16, v) >> PromoteTo(du16, bits))); +} +#else // AVX2 +template +HWY_INLINE V AVX2ShrU16Vec128(V v, V bits) { + const DFromV d; + const Rebind du32; + const RebindToSigned di32; + return DemoteTo(d, + BitCast(di32, PromoteTo(du32, v) >> PromoteTo(du32, bits))); +} +template +HWY_INLINE V AVX2ShrU8Vec128(V v, V bits) { + const DFromV d; + const Rebind du32; + const RebindToSigned di32; + return DemoteTo(d, + BitCast(di32, PromoteTo(du32, v) >> PromoteTo(du32, bits))); +} +template +HWY_INLINE V AVX2ShrU8Vec128(V v, V bits) { + const DFromV d; + const Half dh; + const Rebind di16; + const Rebind du16; + const Rebind dh_i32; + const Rebind dh_u32; + + const auto lo_shr_result = + BitCast(dh_i32, PromoteTo(dh_u32, LowerHalf(dh, v)) >> + PromoteTo(dh_u32, LowerHalf(dh, bits))); + const auto hi_shr_result = + BitCast(dh_i32, PromoteTo(dh_u32, UpperHalf(dh, v)) >> + PromoteTo(dh_u32, UpperHalf(dh, bits))); + const auto i16_shr_result = + BitCast(di16, OrderedDemote2To(du16, lo_shr_result, hi_shr_result)); + return DemoteTo(d, i16_shr_result); +} +#endif // HWY_TARGET <= HWY_AVX3 + +} // namespace detail +#endif // HWY_TARGET <= HWY_AVX2 + +template +HWY_API Vec128 operator>>(Vec128 in, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srlv_epi16(in.raw, bits.raw)}; +#elif HWY_TARGET <= HWY_AVX2 + return detail::AVX2ShrU16Vec128(in, bits); +#else + const DFromV d; + // For bits=0, we cannot mul by 2^16, so fix the result later. + const auto out = MulHigh(in, detail::Pow2(Set(d, 16) - bits)); + // Replace output with input where bits == 0. 
+ return IfThenElse(bits == Zero(d), in, out); +#endif +} + +#if HWY_TARGET > HWY_AVX3 +HWY_API Vec16 operator>>(const Vec16 in, + const Vec16 bits) { +#if HWY_TARGET <= HWY_SSE4 + const Vec16 bits16{_mm_cvtepu16_epi64(bits.raw)}; +#else + const auto bits16 = And(bits, Vec16{_mm_set_epi64x(0, 0xFFFF)}); +#endif + return Vec16{_mm_srl_epi16(in.raw, bits16.raw)}; +} +#endif + +// 8-bit uses 16-bit shifts. +template +HWY_API Vec128 operator>>(Vec128 in, + const Vec128 bits) { +#if HWY_TARGET <= HWY_AVX2 + return detail::AVX2ShrU8Vec128(in, bits); +#else + const DFromV d; + const Repartition dw; + using VW = VFromD; + const VW mask = Set(dw, 0x00FF); + const VW vw = BitCast(dw, in); + const VW bits16 = BitCast(dw, bits); + const VW evens = And(vw, mask) >> And(bits16, mask); + // Shift odd lanes in-place + const VW odds = vw >> ShiftRight<8>(bits16); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +#endif +} +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET <= HWY_SSE4 + const Vec16 in8{_mm_cvtepu8_epi16(in.raw)}; + const Vec16 bits8{_mm_cvtepu8_epi64(bits.raw)}; +#else + const Vec16 mask{_mm_set_epi64x(0, 0xFF)}; + const Vec16 in8 = And(Vec16{in.raw}, mask); + const Vec16 bits8 = And(Vec16{bits.raw}, mask); +#endif + return Vec128{_mm_srl_epi16(in8.raw, bits8.raw)}; +} + +template +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET >= HWY_SSE4 + // 32x32 -> 64 bit mul, then shift right by 32. + const DFromV d32; + // Move odd lanes into position for the second mul. Shuffle more gracefully + // handles N=1 than repartitioning to u64 and shifting 32 bits right. + const Vec128 in31{_mm_shuffle_epi32(in.raw, 0x31)}; + // For bits=0, we cannot mul by 2^32, so fix the result later. 
+ const auto mul = detail::Pow2(Set(d32, 32) - bits); + const auto out20 = ShiftRight<32>(MulEven(in, mul)); // z 2 z 0 + const Vec128 mul31{_mm_shuffle_epi32(mul.raw, 0x31)}; + // No need to shift right, already in the correct position. + const auto out31 = BitCast(d32, MulEven(in31, mul31)); // 3 ? 1 ? + const Vec128 out = OddEven(out31, BitCast(d32, out20)); + // Replace output with input where bits == 0. + return IfThenElse(bits == Zero(d32), in, out); +#else + return Vec128{_mm_srlv_epi32(in.raw, bits.raw)}; +#endif +} + +#if HWY_TARGET >= HWY_SSE4 +HWY_API Vec128 operator>>(const Vec128 in, + const Vec128 bits) { +#if HWY_TARGET == HWY_SSE4 + const Vec32 bits32{_mm_cvtepu32_epi64(bits.raw)}; +#else + const auto bits32 = + Combine(Full64(), Zero(Full32()), bits); +#endif + return Vec128{_mm_srl_epi32(in.raw, bits32.raw)}; +} +#endif + +HWY_API Vec128 operator>>(const Vec128 v, + const Vec128 bits) { +#if HWY_TARGET >= HWY_SSE4 + const DFromV d; + // Individual shifts and combine + const Vec128 out0{_mm_srl_epi64(v.raw, bits.raw)}; + const __m128i bits1 = _mm_unpackhi_epi64(bits.raw, bits.raw); + const Vec128 out1{_mm_srl_epi64(v.raw, bits1)}; + return ConcatUpperLower(d, out1, out0); +#else + return Vec128{_mm_srlv_epi64(v.raw, bits.raw)}; +#endif +} +HWY_API Vec64 operator>>(const Vec64 v, + const Vec64 bits) { + return Vec64{_mm_srl_epi64(v.raw, bits.raw)}; +} + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_INLINE V AVX2ShrI8Vec128(V v, V bits) { + const DFromV d; + const Rebind di16; + return DemoteTo(d, PromoteTo(di16, v) >> PromoteTo(di16, bits)); +} +#elif HWY_TARGET <= HWY_AVX2 // AVX2 +template +HWY_INLINE V AVX2ShrI16Vec128(V v, V bits) { + const DFromV d; + const Rebind di32; + return DemoteTo(d, PromoteTo(di32, v) >> PromoteTo(di32, bits)); +} +template +HWY_INLINE V AVX2ShrI8Vec128(V v, V bits) { + const DFromV d; + const Rebind di32; + return DemoteTo(d, PromoteTo(di32, v) >> PromoteTo(di32, bits)); +} +template +HWY_INLINE V 
AVX2ShrI8Vec128(V v, V bits) { + const DFromV d; + const Half dh; + const Rebind di16; + const Rebind dh_i32; + + const auto lo_shr_result = PromoteTo(dh_i32, LowerHalf(dh, v)) >> + PromoteTo(dh_i32, LowerHalf(dh, bits)); + const auto hi_shr_result = PromoteTo(dh_i32, UpperHalf(dh, v)) >> + PromoteTo(dh_i32, UpperHalf(dh, bits)); + const auto i16_shr_result = + OrderedDemote2To(di16, lo_shr_result, hi_shr_result); + return DemoteTo(d, i16_shr_result); +} +#endif + +#if HWY_TARGET > HWY_AVX3 +// Also used in x86_256-inl.h. +template +HWY_INLINE V SignedShr(const DI di, const V v, const V count_i) { + const RebindToUnsigned du; + const auto count = BitCast(du, count_i); // same type as value to shift + // Clear sign and restore afterwards. This is preferable to shifting the MSB + // downwards because Shr is somewhat more expensive than Shl. + const auto sign = BroadcastSignBit(v); + const auto abs = BitCast(du, v ^ sign); // off by one, but fixed below + return BitCast(di, abs >> count) ^ sign; +} +#endif + +} // namespace detail + +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srav_epi16(v.raw, bits.raw)}; +#elif HWY_TARGET <= HWY_AVX2 + return detail::AVX2ShrI16Vec128(v, bits); +#else + const DFromV d; + return detail::SignedShr(d, v, bits); +#endif +} + +#if HWY_TARGET > HWY_AVX3 +HWY_API Vec16 operator>>(Vec16 v, Vec16 bits) { +#if HWY_TARGET <= HWY_SSE4 + const Vec16 bits16{_mm_cvtepu16_epi64(bits.raw)}; +#else + const auto bits16 = And(bits, Vec16{_mm_set_epi64x(0, 0xFFFF)}); +#endif + return Vec16{_mm_sra_epi16(v.raw, bits16.raw)}; +} +#endif + +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_AVX2 + return detail::AVX2ShrI8Vec128(v, bits); +#else + const DFromV d; + return detail::SignedShr(d, v, bits); +#endif +} +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_SSE4 + const Vec16 vi16{_mm_cvtepi8_epi16(v.raw)}; + const Vec16 
bits8{_mm_cvtepu8_epi64(bits.raw)}; +#else + const DFromV d; + const Rebind di16; + const Twice dt; + + const auto vi16 = ShiftRight<8>(BitCast(di16, Combine(dt, v, v))); + const Vec16 bits8 = + And(Vec16{bits.raw}, Vec16{_mm_set_epi64x(0, 0xFF)}); +#endif + return Vec128{_mm_sra_epi16(vi16.raw, bits8.raw)}; +} + +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_AVX2 + return Vec128{_mm_srav_epi32(v.raw, bits.raw)}; +#else + const DFromV d; + return detail::SignedShr(d, v, bits); +#endif +} + +#if HWY_TARGET > HWY_AVX2 +HWY_API Vec32 operator>>(Vec32 v, Vec32 bits) { +#if HWY_TARGET == HWY_SSE4 + const Vec32 bits32{_mm_cvtepu32_epi64(bits.raw)}; +#else + const auto bits32 = Combine(Full64(), Zero(Full32()), bits); +#endif + return Vec32{_mm_sra_epi32(v.raw, bits32.raw)}; +} +#endif + +template +HWY_API Vec128 operator>>(Vec128 v, + Vec128 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec128{_mm_srav_epi64(v.raw, bits.raw)}; +#else + const DFromV d; + return detail::SignedShr(d, v, bits); +#endif +} + +// ------------------------------ MulEven/Odd 64x64 (UpperHalf) + +namespace detail { + +template )> +static HWY_INLINE V SSE2Mul128(V a, V b, V& mulH) { + const DFromV du64; + const RepartitionToNarrow du32; + const auto maskL = Set(du64, 0xFFFFFFFFULL); + const auto a32 = BitCast(du32, a); + const auto b32 = BitCast(du32, b); + // Inputs for MulEven: we only need the lower 32 bits + const auto aH = Shuffle2301(a32); + const auto bH = Shuffle2301(b32); + + // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need + // the even (lower 64 bits of every 128-bit block) results. 
See + // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.txt + const auto aLbL = MulEven(a32, b32); + const auto w3 = aLbL & maskL; + + const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); + const auto w2 = t2 & maskL; + const auto w1 = ShiftRight<32>(t2); + + const auto t = MulEven(a32, bH) + w2; + const auto k = ShiftRight<32>(t); + + mulH = MulEven(aH, bH) + w1 + k; + return ShiftLeft<32>(t) + w3; +} + +template )> +static HWY_INLINE V SSE2Mul128(V a, V b, V& mulH) { + const DFromV di64; + const RebindToUnsigned du64; + using VU64 = VFromD; + + VU64 unsigned_mulH; + const auto mulL = BitCast( + di64, SSE2Mul128(BitCast(du64, a), BitCast(du64, b), unsigned_mulH)); + mulH = BitCast(di64, unsigned_mulH) - And(BroadcastSignBit(a), b) - + And(a, BroadcastSignBit(b)); + return mulL; +} + +} // namespace detail + +#if !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +template ), + HWY_IF_V_SIZE_GT_V(V, (HWY_ARCH_X86_64 ? 16 : 8))> +HWY_API V MulEven(V a, V b) { + V mulH; + const V mulL = detail::SSE2Mul128(a, b, mulH); + return InterleaveLower(mulL, mulH); +} + +template ), + HWY_IF_V_SIZE_GT_V(V, (HWY_ARCH_X86_64 ? 16 : 8))> +HWY_API V MulOdd(V a, V b) { + const DFromV du64; + V mulH; + const V mulL = detail::SSE2Mul128(a, b, mulH); + return InterleaveUpper(du64, mulL, mulH); +} + +#endif // !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +template ), + HWY_IF_V_SIZE_GT_V(V, (HWY_ARCH_X86_64 ? 
8 : 0))> +HWY_API V MulHigh(V a, V b) { + V mulH; + detail::SSE2Mul128(a, b, mulH); + return mulH; +} + +#if HWY_ARCH_X86_64 + +template +HWY_API Vec128 MulEven(Vec128 a, Vec128 b) { + const DFromV d; + alignas(16) T mul[2]; + mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]); + return Load(d, mul); +} + +template +HWY_API Vec128 MulOdd(Vec128 a, Vec128 b) { + const DFromV d; + const Half d2; + alignas(16) T mul[2]; + const T a1 = GetLane(UpperHalf(d2, a)); + const T b1 = GetLane(UpperHalf(d2, b)); + mul[0] = Mul128(a1, b1, &mul[1]); + return Load(d, mul); +} + +template +HWY_API Vec64 MulHigh(Vec64 a, Vec64 b) { + T hi; + Mul128(GetLane(a), GetLane(b), &hi); + return Vec64{_mm_cvtsi64_si128(static_cast(hi))}; +} + +#endif // HWY_ARCH_X86_64 + +// ================================================== CONVERT (2) + +// ------------------------------ PromoteEvenTo/PromoteOddTo + +#if HWY_TARGET > HWY_AVX3 +namespace detail { + +// I32->I64 PromoteEvenTo/PromoteOddTo + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec64 v) { + return PromoteLowerTo(d_to, v); +} + +template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec128 v) { + const Repartition d_from; + return PromoteLowerTo(d_to, ConcatEven(d_from, v, v)); +} + +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + V v) { + const Repartition d_from; + return PromoteLowerTo(d_to, ConcatOdd(d_from, v, v)); +} + +} // namespace detail +#endif + +// ------------------------------ PromoteEvenTo/PromoteOddTo +#include "hwy/ops/inside-inl.h" + +// ------------------------------ WidenMulPairwiseAdd (PromoteEvenTo) + +#if HWY_NATIVE_DOT_BF16 + +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, 
VBF b) { + return VFromD{_mm_dpbf16_ps(Zero(df).raw, + reinterpret_cast<__m128bh>(a.raw), + reinterpret_cast<__m128bh>(b.raw))}; +} + +#else + +// Generic for all vector lengths. +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b), + Mul(PromoteOddTo(df, a), PromoteOddTo(df, b))); +} + +#endif // HWY_NATIVE_DOT_BF16 + +// Even if N=1, the input is always at least 2 lanes, hence madd_epi16 is safe. +template >> +HWY_API VFromD WidenMulPairwiseAdd(D32 /* tag */, V16 a, V16 b) { + return VFromD{_mm_madd_epi16(a.raw, b.raw)}; +} + +// Generic for all vector lengths. +template >> +HWY_API VFromD WidenMulPairwiseAdd(DU32 du32, VU16 a, VU16 b) { + const auto p_lo = a * b; + const auto p_hi = MulHigh(a, b); + + const auto p_hi1_lo0 = BitCast(du32, OddEven(p_hi, p_lo)); + const auto p_hi0_lo1 = Or(ShiftLeft<16>(BitCast(du32, p_hi)), + ShiftRight<16>(BitCast(du32, p_lo))); + return Add(BitCast(du32, p_hi1_lo0), BitCast(du32, p_hi0_lo1)); +} + +// ------------------------------ SatWidenMulPairwiseAdd + +#if HWY_TARGET <= HWY_SSSE3 + +#ifdef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#undef HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#else +#define HWY_NATIVE_U8_I8_SATWIDENMULPAIRWISEADD +#endif + +// Even if N=1, the input is always at least 2 lanes, hence _mm_maddubs_epi16 +// is safe. +template +HWY_API VFromD SatWidenMulPairwiseAdd( + DI16 /* tag */, VFromD> a, + VFromD> b) { + return VFromD{_mm_maddubs_epi16(a.raw, b.raw)}; +} + +#endif + +// ------------------------------ SatWidenMulPairwiseAccumulate + +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#undef HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#else +#define HWY_NATIVE_I16_I16_SATWIDENMULPAIRWISEACCUM +#endif + +// Even if N=1, the I16 vectors have at least 2 lanes, hence _mm_dpwssds_epi32 +// is safe. 
+template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 /* tag */, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm_dpwssds_epi32(sum.raw, a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ ReorderWidenMulAccumulate (PromoteEvenTo) + +#if HWY_NATIVE_DOT_BF16 + +#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#else +#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16 +#endif + +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DF /*df*/, VBF a, VBF b, + const VFromD sum0, + VFromD& /*sum1*/) { + return VFromD{_mm_dpbf16_ps(sum0.raw, reinterpret_cast<__m128bh>(a.raw), + reinterpret_cast<__m128bh>(b.raw))}; +} + +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW) { + // Sum1 is unused and the invariant already holds. + return sum0; +} + +#endif // HWY_NATIVE_DOT_BF16 + +// Even if N=1, the input is always at least 2 lanes, hence madd_epi16 is safe. +template >> +HWY_API VFromD ReorderWidenMulAccumulate(D32 d, V16 a, V16 b, + const VFromD sum0, + VFromD& /*sum1*/) { + (void)d; +#if HWY_TARGET <= HWY_AVX3_DL + return VFromD{_mm_dpwssd_epi32(sum0.raw, a.raw, b.raw)}; +#else + return sum0 + WidenMulPairwiseAdd(d, a, b); +#endif +} + +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DU32 d, VU16 a, VU16 b, + const VFromD sum0, + VFromD& /*sum1*/) { + (void)d; + return sum0 + WidenMulPairwiseAdd(d, a, b); +} + +// ------------------------------ RearrangeToOddPlusEven +template +HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW) { + // For integer types, sum1 is unused and the invariant already holds. 
+ return sum0; +} + +// ------------------------------ SumOfMulQuadAccumulate +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE +#endif + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 /*di32*/, VFromD> a_u, + VFromD> b_i, VFromD sum) { + return VFromD{_mm_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)}; +} + +#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE +#endif + +#if HWY_X86_HAVE_AVX10_2_OPS +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD{_mm_dpbssd_epi32(sum.raw, a.raw, b.raw)}; +} +#else // !HWY_X86_HAVE_AVX10_2_OPS +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 di32, + VFromD> a, + VFromD> b, + VFromD sum) { + const Repartition du8; + + const auto a_u = BitCast(du8, a); + const auto result_sum_0 = SumOfMulQuadAccumulate(di32, a_u, b, sum); + const auto result_sum_1 = ShiftLeft<8>( + SumOfMulQuadAccumulate(di32, ShiftRight<7>(a_u), b, Zero(di32))); + return result_sum_0 - result_sum_1; +} +#endif // HWY_X86_HAVE_AVX10_2_OPS + +#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#else +#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE +#endif + +#if HWY_X86_HAVE_AVX10_2_OPS +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm_dpbuud_epi32(sum.raw, a.raw, b.raw)}; +} +#else // !HWY_X86_HAVE_AVX10_2_OPS +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 du32, VFromD> a, + VFromD> b, VFromD sum) { + const Repartition du8; + const RebindToSigned di8; + const RebindToSigned di32; + + const auto b_i = BitCast(di8, b); + const auto result_sum_0 = + SumOfMulQuadAccumulate(di32, a, b_i, BitCast(di32, sum)); + const auto result_sum_1 = 
ShiftLeft<8>( + SumOfMulQuadAccumulate(di32, a, BroadcastSignBit(b_i), Zero(di32))); + + return BitCast(du32, result_sum_0 - result_sum_1); +} +#endif // HWY_X86_HAVE_AVX10_2_OPS + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_packs_epi32(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { +#if HWY_TARGET >= HWY_SSSE3 + const Rebind di32; + const auto zero_if_neg = AndNot(ShiftRight<31>(v), v); + const auto too_big = VecFromMask(di32, Gt(v, Set(di32, 0xFFFF))); + const auto clamped = Or(zero_if_neg, too_big); +#if HWY_TARGET == HWY_SSE2 + const Rebind du16; + const RebindToSigned di16; + return BitCast(du16, DemoteTo(di16, ShiftRight<16>(ShiftLeft<16>(clamped)))); +#else + const Repartition du16; + // Lower 2 bytes from each 32-bit lane; same as return type for fewer casts. + alignas(16) static constexpr uint16_t kLower2Bytes[16] = { + 0x0100, 0x0504, 0x0908, 0x0D0C, 0x8080, 0x8080, 0x8080, 0x8080}; + const auto lo2 = Load(du16, kLower2Bytes); + return VFromD{TableLookupBytes(BitCast(du16, clamped), lo2).raw}; +#endif +#else + return VFromD{_mm_packus_epi32(v.raw, v.raw)}; +#endif +} + +template +HWY_API VFromD DemoteTo(D du16, VFromD> v) { + const DFromV du32; + const RebindToSigned di32; +#if HWY_TARGET >= HWY_SSSE3 + const auto too_big = + VecFromMask(di32, Gt(BitCast(di32, ShiftRight<16>(v)), Zero(di32))); + const auto clamped = Or(BitCast(di32, v), too_big); +#if HWY_TARGET == HWY_SSE2 + const RebindToSigned di16; + return BitCast(du16, DemoteTo(di16, ShiftRight<16>(ShiftLeft<16>(clamped)))); +#else + (void)du16; + const Repartition du16_full; + // Lower 2 bytes from each 32-bit lane; same as return type for fewer casts. 
+ alignas(16) static constexpr uint16_t kLower2Bytes[16] = { + 0x0100, 0x0504, 0x0908, 0x0D0C, 0x8080, 0x8080, 0x8080, 0x8080}; + const auto lo2 = Load(du16_full, kLower2Bytes); + return VFromD{TableLookupBytes(BitCast(du16_full, clamped), lo2).raw}; +#endif +#else + return DemoteTo(du16, BitCast(di32, Min(v, Set(du32, 0x7FFFFFFF)))); +#endif +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); + return VFromD{_mm_packus_epi16(i16, i16)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_packus_epi16(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const __m128i i16 = _mm_packs_epi32(v.raw, v.raw); + return VFromD{_mm_packs_epi16(i16, i16)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_packs_epi16(v.raw, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D du8, VFromD> v) { +#if HWY_TARGET <= HWY_AVX3 + // NOTE: _mm_cvtusepi32_epi8 is a saturated conversion of 32-bit unsigned + // integers to 8-bit unsigned integers + (void)du8; + return VFromD{_mm_cvtusepi32_epi8(v.raw)}; +#else + const DFromV du32; + const RebindToSigned di32; + const auto max_i32 = Set(du32, 0x7FFFFFFFu); + +#if HWY_TARGET >= HWY_SSSE3 + // On SSE2/SSSE3, clamp u32 values to an i32 using the u8 Min operation + // as SSE2/SSSE3 can do an u8 Min operation in a single instruction. + + // The u8 Min operation below leaves the lower 24 bits of each 32-bit + // lane unchanged. + + // The u8 Min operation below will leave any values that are less than or + // equal to 0x7FFFFFFF unchanged. + + // For values that are greater than or equal to 0x80000000, the u8 Min + // operation below will force the upper 8 bits to 0x7F and leave the lower + // 24 bits unchanged. 
+ + // An u8 Min operation is okay here as any clamped value that is greater than + // or equal to 0x80000000 will be clamped to a value between 0x7F000000 and + // 0x7FFFFFFF through the u8 Min operation below, which will then be converted + // to 0xFF through the i32->u8 demotion. + const Repartition du32_as_du8; + const auto clamped = BitCast( + di32, Min(BitCast(du32_as_du8, v), BitCast(du32_as_du8, max_i32))); +#else + const auto clamped = BitCast(di32, Min(v, max_i32)); +#endif + + return DemoteTo(du8, clamped); +#endif +} + +template +HWY_API VFromD DemoteTo(D du8, VFromD> v) { + const DFromV du16; + const RebindToSigned di16; + const auto max_i16 = Set(du16, 0x7FFF); + +#if HWY_TARGET >= HWY_SSSE3 + // On SSE2/SSSE3, clamp u16 values to an i16 using the u8 Min operation + // as SSE2/SSSE3 can do an u8 Min operation in a single instruction. + + // The u8 Min operation below leaves the lower 8 bits of each 16-bit + // lane unchanged. + + // The u8 Min operation below will leave any values that are less than or + // equal to 0x7FFF unchanged. + + // For values that are greater than or equal to 0x8000, the u8 Min + // operation below will force the upper 8 bits to 0x7F and leave the lower + // 8 bits unchanged. + + // An u8 Min operation is okay here as any clamped value that is greater than + // or equal to 0x8000 will be clamped to a value between 0x7F00 and + // 0x7FFF through the u8 Min operation below, which will then be converted + // to 0xFF through the i16->u8 demotion. + const Repartition du16_as_du8; + const auto clamped = BitCast( + di16, Min(BitCast(du16_as_du8, v), BitCast(du16_as_du8, max_i16))); +#else + const auto clamped = BitCast(di16, Min(v, max_i16)); +#endif + + return DemoteTo(du8, clamped); +} + +#if HWY_TARGET < HWY_SSE4 && !defined(HWY_DISABLE_F16C) + +// HWY_NATIVE_F16C was already toggled above. + +// Work around MSVC warning for _mm_cvtps_ph (8 is actually a valid immediate). 
+// clang-cl requires a non-empty string, so we 'ignore' the irrelevant -Wmain. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wmain") + +template +HWY_API VFromD DemoteTo(D df16, VFromD> v) { + const RebindToUnsigned du16; + return BitCast( + df16, VFromD{_mm_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}); +} + +HWY_DIAGNOSTICS(pop) + +#endif // F16C + +#if HWY_HAVE_FLOAT16 + +#ifdef HWY_NATIVE_DEMOTE_F64_TO_F16 +#undef HWY_NATIVE_DEMOTE_F64_TO_F16 +#else +#define HWY_NATIVE_DEMOTE_F64_TO_F16 +#endif + +template +HWY_API VFromD DemoteTo(D /*df16*/, VFromD> v) { + return VFromD{_mm_cvtpd_ph(v.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +// The _mm*_cvtneps_pbh and _mm*_cvtne2ps_pbh intrinsics require GCC 9 or later +// or Clang 10 or later + +// Also need GCC or Clang to bit cast the __m128bh, __m256bh, or __m512bh vector +// returned by the _mm*_cvtneps_pbh and _mm*_cvtne2ps_pbh intrinsics to a +// __m128i, __m256i, or __m512i as there are currently no intrinsics available +// (as of GCC 13 and Clang 17) to bit cast a __m128bh, __m256bh, or __m512bh +// vector to a __m128i, __m256i, or __m512i vector + +#if HWY_AVX3_HAVE_F32_TO_BF16C +#ifdef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#undef HWY_NATIVE_DEMOTE_F32_TO_BF16 +#else +#define HWY_NATIVE_DEMOTE_F32_TO_BF16 +#endif + +template +HWY_API VFromD DemoteTo(D /*dbf16*/, VFromD> v) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m128i raw_result; + __asm__("vcvtneps2bf16 %1, %0" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else + // The _mm_cvtneps_pbh intrinsic returns a __m128bh vector that needs to be + // bit casted to a __m128i vector + return VFromD{detail::BitCastToInteger(_mm_cvtneps_pbh(v.raw))}; +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D /*dbf16*/, Vec128 a, + Vec128 b) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m128i 
raw_result; + __asm__("vcvtne2ps2bf16 %2, %1, %0" + : "=v"(raw_result) + : "v"(b.raw), "v"(a.raw)); + return VFromD{raw_result}; +#else + // The _mm_cvtne2ps_pbh intrinsic returns a __m128bh vector that needs to be + // bit casted to a __m128i vector + return VFromD{detail::BitCastToInteger(_mm_cvtne2ps_pbh(b.raw, a.raw))}; +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec64 a, + Vec64 b) { + return VFromD{_mm_shuffle_epi32( + detail::BitCastToInteger(_mm_cvtne2ps_pbh(b.raw, a.raw)), + _MM_SHUFFLE(2, 0, 2, 0))}; +} + +template +HWY_API VFromD ReorderDemote2To(D dbf16, Vec32 a, Vec32 b) { + const DFromV d; + const Twice dt; + return DemoteTo(dbf16, Combine(dt, b, a)); +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +// Specializations for partial vectors because packs_epi32 sets lanes above 2*N. +template +HWY_API VFromD ReorderDemote2To(D dn, Vec32 a, Vec32 b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec64 a, + Vec64 b) { + return VFromD{_mm_shuffle_epi32(_mm_packs_epi32(a.raw, b.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +} +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{_mm_packs_epi32(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec32 a, Vec32 b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API VFromD ReorderDemote2To(D dn, Vec64 a, Vec64 b) { +#if HWY_TARGET >= HWY_SSSE3 + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +#else + (void)dn; + return VFromD{_mm_shuffle_epi32(_mm_packus_epi32(a.raw, b.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +#endif +} +template +HWY_API VFromD ReorderDemote2To(D dn, Vec128 a, Vec128 b) { +#if HWY_TARGET >= HWY_SSSE3 + const Half dnh; + const auto u16_a = DemoteTo(dnh, a); + const auto u16_b = DemoteTo(dnh, b); + return Combine(dn, u16_b, u16_a); +#else + (void)dn; + return 
VFromD{_mm_packus_epi32(a.raw, b.raw)}; +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV du32; + const RebindToSigned di32; + const auto max_i32 = Set(du32, 0x7FFFFFFFu); + +#if HWY_TARGET >= HWY_SSSE3 + const Repartition du32_as_du8; + // On SSE2/SSSE3, clamp a and b using u8 Min operation + const auto clamped_a = BitCast( + di32, Min(BitCast(du32_as_du8, a), BitCast(du32_as_du8, max_i32))); + const auto clamped_b = BitCast( + di32, Min(BitCast(du32_as_du8, b), BitCast(du32_as_du8, max_i32))); +#else + const auto clamped_a = BitCast(di32, Min(a, max_i32)); + const auto clamped_b = BitCast(di32, Min(b, max_i32)); +#endif + + return ReorderDemote2To(dn, clamped_a, clamped_b); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +// Specializations for partial vectors because packs_epi32 sets lanes above 2*N. +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec64 a, + Vec64 b) { + return VFromD{_mm_shuffle_epi32(_mm_packs_epi16(a.raw, b.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +} +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{_mm_packs_epi16(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec64 a, + Vec64 b) { + return VFromD{_mm_shuffle_epi32(_mm_packus_epi16(a.raw, b.raw), + _MM_SHUFFLE(2, 0, 2, 0))}; +} +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec128 a, + Vec128 b) { + return VFromD{_mm_packus_epi16(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV 
du16; + const RebindToSigned di16; + const auto max_i16 = Set(du16, 0x7FFFu); + +#if HWY_TARGET >= HWY_SSSE3 + const Repartition du16_as_du8; + // On SSE2/SSSE3, clamp a and b using u8 Min operation + const auto clamped_a = BitCast( + di16, Min(BitCast(du16_as_du8, a), BitCast(du16_as_du8, max_i16))); + const auto clamped_b = BitCast( + di16, Min(BitCast(du16_as_du8, b), BitCast(du16_as_du8, max_i16))); +#else + const auto clamped_a = BitCast(di16, Min(a, max_i16)); + const auto clamped_b = BitCast(di16, Min(b, max_i16)); +#endif + + return ReorderDemote2To(dn, clamped_a, clamped_b); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +template ), + HWY_IF_V_SIZE_LE_D(D, 16), class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); +} + +#if HWY_AVX3_HAVE_F32_TO_BF16C +// F32 to BF16 OrderedDemote2To is generic for all vector lengths on targets +// that support AVX512BF16 +template +HWY_API VFromD OrderedDemote2To(D dbf16, VFromD> a, + VFromD> b) { + return ReorderDemote2To(dbf16, a, b); +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtpd_ps(v.raw)}; +} + +namespace detail { + +// Generic for all vector lengths. +template +HWY_INLINE VFromD ClampF64ToI32Max(D d, VFromD v) { + // The max can be exactly represented in binary64, so clamping beforehand + // prevents x86 conversion from raising an exception and returning 80..00. 
+ return Min(v, Set(d, 2147483647.0)); +} + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD +template +static constexpr HWY_INLINE TTo +X86ConvertScalarFromFloat(hwy::FloatTag /* to_type_tag */, TF from_val) { + return ConvertScalarTo(from_val); +} + +template +static HWY_BITCASTSCALAR_CONSTEXPR HWY_INLINE TTo +X86ConvertScalarFromFloat(hwy::SpecialTag /* to_type_tag */, TF from_val) { + return ConvertScalarTo(from_val); +} + +template +static HWY_BITCASTSCALAR_CXX14_CONSTEXPR HWY_INLINE TTo +X86ConvertScalarFromFloat(hwy::SignedTag /* to_type_tag */, TF from_val) { +#if HWY_HAVE_SCALAR_F16_TYPE && HWY_HAVE_SCALAR_F16_OPERATORS + using TFArith = If, hwy::bfloat16_t>(), float, + RemoveCvRef>; +#else + using TFArith = If>; +#endif + + const TFArith from_val_in_arith_type = ConvertScalarTo(from_val); + constexpr TTo kMinResultVal = LimitsMin(); + HWY_BITCASTSCALAR_CONSTEXPR const TFArith kMinOutOfRangePosVal = + ScalarAbs(ConvertScalarTo(kMinResultVal)); + + return (ScalarAbs(from_val_in_arith_type) < kMinOutOfRangePosVal) + ? ConvertScalarTo(from_val_in_arith_type) + : kMinResultVal; +} + +template +static HWY_CXX14_CONSTEXPR HWY_INLINE TTo +X86ConvertScalarFromFloat(hwy::UnsignedTag /* to_type_tag */, TF from_val) { +#if HWY_HAVE_SCALAR_F16_TYPE && HWY_HAVE_SCALAR_F16_OPERATORS + using TFArith = If, hwy::bfloat16_t>(), float, + RemoveCvRef>; +#else + using TFArith = If>; +#endif + + const TFArith from_val_in_arith_type = ConvertScalarTo(from_val); + constexpr TTo kTToMsb = static_cast(TTo{1} << (sizeof(TTo) * 8 - 1)); + constexpr const TFArith kNegOne = ConvertScalarTo(-1.0); + constexpr const TFArith kMinOutOfRangePosVal = + ConvertScalarTo(static_cast(kTToMsb) * 2.0); + + return (from_val_in_arith_type > kNegOne && + from_val_in_arith_type < kMinOutOfRangePosVal) + ? 
ConvertScalarTo(from_val_in_arith_type) + : LimitsMax(); +} + +template +static constexpr HWY_INLINE HWY_MAYBE_UNUSED TTo +X86ConvertScalarFromFloat(TF from_val) { + return X86ConvertScalarFromFloat(hwy::TypeTag>(), + from_val); +} + +#endif // HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + +} // namespace detail + +#ifdef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F64_TO_UI32_DEMOTE_IN_RANGE_TO +#endif + +template +HWY_API VFromD DemoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm_cvtts_pd_epi32(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), int32_t{0}, + int32_t{0}); + } +#endif + + __m128i raw_result; + __asm__("%vcvttpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttpd_epi32(v.raw)}; +#endif +} + +// F64 to I32 DemoteTo is generic for all vector lengths +template +HWY_API VFromD DemoteTo(D di32, VFromD> v) { + const Rebind df64; + const VFromD clamped = detail::ClampF64ToI32Max(df64, v); + return DemoteInRangeTo(di32, clamped); +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD DemoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm_cvtts_pd_epu32(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttpd_epu32 with 
GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), uint32_t{0}, + uint32_t{0}); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm_cvttpd_epu32(v.raw)}; +#endif +} + +// F64->U32 DemoteTo is generic for all vector lengths +template +HWY_API VFromD DemoteTo(D du32, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return DemoteInRangeTo(du32, v); +#else + return DemoteInRangeTo(du32, ZeroIfNegative(v)); +#endif +} +#else // HWY_TARGET > HWY_AVX3 + +// F64 to U32 DemoteInRangeTo is generic for all vector lengths on +// SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD DemoteInRangeTo(D du32, VFromD> v) { + const RebindToSigned di32; + const Rebind df64; + const RebindToUnsigned du64; + + const auto k2_31 = Set(df64, 2147483648.0); + const auto v_is_ge_k2_31 = (v >= k2_31); + const auto clamped_lo31_f64 = v - IfThenElseZero(v_is_ge_k2_31, k2_31); + const auto clamped_lo31_u32 = + BitCast(du32, DemoteInRangeTo(di32, clamped_lo31_f64)); + const auto clamped_u32_msb = ShiftLeft<31>( + TruncateTo(du32, BitCast(du64, VecFromMask(df64, v_is_ge_k2_31)))); + return Or(clamped_lo31_u32, clamped_u32_msb); +} + +// F64 to U32 DemoteTo is generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD DemoteTo(D du32, VFromD> v) { + const Rebind df64; + const auto clamped = Min(ZeroIfNegative(v), Set(df64, 4294967295.0)); + return DemoteInRangeTo(du32, clamped); +} +#endif // HWY_TARGET <= HWY_AVX3 + +#if 
HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtepi64_ps(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtepu64_ps(v.raw)}; +} +#else +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { + const Rebind df64; + const RebindToUnsigned du64; + const RebindToSigned di32; + const RebindToUnsigned du32; + + const auto k2p64_63 = Set(df64, 27670116110564327424.0); + const auto f64_hi52 = + Xor(BitCast(df64, ShiftRight<12>(BitCast(du64, v))), k2p64_63) - k2p64_63; + const auto f64_lo12 = + PromoteTo(df64, BitCast(di32, And(TruncateTo(du32, BitCast(du64, v)), + Set(du32, uint32_t{0x00000FFF})))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + + const auto f64_sum_is_inexact = + ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); + const auto f64_bits_decrement = + And(ShiftRight<63>(BitCast(du64, Xor(f64_sum, f64_carry))), + f64_sum_is_inexact); + + const auto adj_f64_val = BitCast( + df64, + Or(BitCast(du64, f64_sum) - f64_bits_decrement, f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +} + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD DemoteTo(D df32, VFromD> v) { + const Rebind df64; + const RebindToUnsigned du64; + const RebindToSigned di32; + const RebindToUnsigned du32; + + const auto k2p64 = Set(df64, 18446744073709551616.0); + const auto f64_hi52 = Or(BitCast(df64, ShiftRight<12>(v)), k2p64) - k2p64; + const auto f64_lo12 = + PromoteTo(df64, BitCast(di32, And(TruncateTo(du32, BitCast(du64, v)), + Set(du32, uint32_t{0x00000FFF})))); + + const auto f64_sum = f64_hi52 + f64_lo12; + const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12; + const auto f64_sum_is_inexact = + ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64)))); + + const auto adj_f64_val 
= BitCast( + df64, + Or(BitCast(du64, f64_sum) - ShiftRight<63>(BitCast(du64, f64_carry)), + f64_sum_is_inexact)); + + return DemoteTo(df32, adj_f64_val); +} +#endif + +// For already range-limited input [0, 255]. +template +HWY_API Vec128 U8FromU32(const Vec128 v) { +#if HWY_TARGET == HWY_SSE2 + const RebindToSigned> di32; + const Rebind du8; + return DemoteTo(du8, BitCast(di32, v)); +#else + const DFromV d32; + const Repartition d8; + alignas(16) static constexpr uint32_t k8From32[4] = { + 0x0C080400u, 0x0C080400u, 0x0C080400u, 0x0C080400u}; + // Also replicate bytes into all 32 bit lanes for safety. + const auto quad = TableLookupBytes(v, Load(d32, k8From32)); + return LowerHalf(LowerHalf(BitCast(d8, quad))); +#endif +} + +// ------------------------------ F32->UI64 PromoteTo +#ifdef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#undef HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#else +#define HWY_NATIVE_F32_TO_UI64_PROMOTE_IN_RANGE_TO +#endif + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD PromoteInRangeTo(D /*di64*/, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm_cvtts_ps_epi64(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvttps2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm_cvttps_epi64(v.raw)}; +#endif +} + +// Generic for all vector lengths. 
+template +HWY_API VFromD PromoteTo(D di64, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return PromoteInRangeTo(di64, v); +#else + const Rebind df32; + const RebindToFloat df64; + // We now avoid GCC UB in PromoteInRangeTo via assembly, see #2189 and + // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115115. Previously we fixed up + // the result afterwards using three instructions. Now we instead check if + // v >= 2^63, and if so replace the output with 2^63-1, which is likely more + // efficient. Note that the previous representable f32 is less than 2^63 and + // thus fits in i64. + const MFromD overflow = RebindMask( + di64, PromoteMaskTo(df64, df32, Ge(v, Set(df32, 9.223372e18f)))); + return IfThenElse(overflow, Set(di64, LimitsMax()), + PromoteInRangeTo(di64, v)); +#endif +} +template +HWY_API VFromD PromoteTo(D du64, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return PromoteInRangeTo(du64, v); +#else + return PromoteInRangeTo(du64, ZeroIfNegative(v)); +#endif +} +template +HWY_API VFromD PromoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm_cvtts_ps_epu64(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvttps2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm_cvttps_epu64(v.raw)}; +#endif +} +#else // AVX2 or below + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 
+template +HWY_API VFromD PromoteTo(D di64, VFromD> v) { + const Rebind di32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + const auto exponent_adj = BitCast( + du32, + Min(SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, uint32_t{157}))), + BitCast(du32_as_du8, Set(du32, uint32_t{32})))); + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + + const auto f32_to_i32_result = ConvertTo(di32, adj_v); + const auto lo64_or_mask = PromoteTo( + di64, + BitCast(du32, VecFromMask(di32, Eq(f32_to_i32_result, + Set(di32, LimitsMax()))))); + + return Or(PromoteTo(di64, BitCast(di32, f32_to_i32_result)) + << PromoteTo(di64, exponent_adj), + lo64_or_mask); +} + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD PromoteInRangeTo(D d64, VFromD> v) { + const Rebind>, decltype(d64)> d32; + const RebindToSigned di32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + const auto exponent_adj = BitCast( + du32, + SaturatedSub(BitCast(du32_as_du8, ShiftRight<23>(BitCast(du32, v))), + BitCast(du32_as_du8, Set(du32, uint32_t{0xFFFFFF9Du})))); + const auto adj_v = + BitCast(df32, BitCast(du32, v) - ShiftLeft<23>(exponent_adj)); + + const auto f32_to_i32_result = ConvertInRangeTo(di32, adj_v); + return PromoteTo(d64, BitCast(d32, f32_to_i32_result)) + << PromoteTo(d64, exponent_adj); +} + +namespace detail { + +template +HWY_INLINE VFromD PromoteF32ToU64OverflowMaskToU64( + DU64 du64, VFromD> i32_overflow_mask) { + const Rebind di32; + const Twice dt_i32; + + const auto vt_i32_overflow_mask = ResizeBitCast(dt_i32, i32_overflow_mask); + return BitCast(du64, + InterleaveLower(vt_i32_overflow_mask, vt_i32_overflow_mask)); +} + +template +HWY_INLINE VFromD PromoteF32ToU64OverflowMaskToU64( + DU64 du64, VFromD> i32_overflow_mask) { + const RebindToSigned di64; + return BitCast(du64, 
PromoteTo(di64, i32_overflow_mask)); +} + +} // namespace detail + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD PromoteTo(D du64, VFromD> v) { + const Rebind di32; + const RebindToFloat df32; + const RebindToUnsigned du32; + const Repartition du32_as_du8; + + const auto non_neg_v = ZeroIfNegative(v); + + const auto exponent_adj = BitCast( + du32, Min(SaturatedSub(BitCast(du32_as_du8, + ShiftRight<23>(BitCast(du32, non_neg_v))), + BitCast(du32_as_du8, Set(du32, uint32_t{157}))), + BitCast(du32_as_du8, Set(du32, uint32_t{33})))); + + const auto adj_v = + BitCast(df32, BitCast(du32, non_neg_v) - ShiftLeft<23>(exponent_adj)); + const auto f32_to_i32_result = ConvertInRangeTo(di32, adj_v); + + const auto i32_overflow_mask = BroadcastSignBit(f32_to_i32_result); + const auto overflow_result = + detail::PromoteF32ToU64OverflowMaskToU64(du64, i32_overflow_mask); + + return Or(PromoteTo(du64, BitCast(du32, f32_to_i32_result)) + << PromoteTo(du64, exponent_adj), + overflow_result); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ MulFixedPoint15 + +#if HWY_TARGET == HWY_SSE2 +HWY_API Vec128 MulFixedPoint15(const Vec128 a, + const Vec128 b) { + const DFromV d; + const Repartition di32; + + auto lo_product = a * b; + auto hi_product = MulHigh(a, b); + + const VFromD i32_product_lo{ + _mm_unpacklo_epi16(lo_product.raw, hi_product.raw)}; + const VFromD i32_product_hi{ + _mm_unpackhi_epi16(lo_product.raw, hi_product.raw)}; + + const auto round_up_incr = Set(di32, 0x4000); + return ReorderDemote2To(d, ShiftRight<15>(i32_product_lo + round_up_incr), + ShiftRight<15>(i32_product_hi + round_up_incr)); +} + +template +HWY_API Vec128 MulFixedPoint15(const Vec128 a, + const Vec128 b) { + const DFromV d; + const Rebind di32; + + const auto lo_product = a * b; + const auto hi_product = MulHigh(a, b); + const VFromD i32_product{ + _mm_unpacklo_epi16(lo_product.raw, hi_product.raw)}; + + return DemoteTo(d, 
ShiftRight<15>(i32_product + Set(di32, 0x4000))); +} +#else +template +HWY_API Vec128 MulFixedPoint15(const Vec128 a, + const Vec128 b) { + return Vec128{_mm_mulhrs_epi16(a.raw, b.raw)}; +} +#endif + +// ------------------------------ Truncations + +template +HWY_API VFromD TruncateTo(DTo /* tag */, Vec128 v) { + // BitCast requires the same size; DTo might be u8x1 and v u16x1. + const Repartition, DFromV> dto; + return VFromD{BitCast(dto, v).raw}; +} + +template +HWY_API VFromD TruncateTo(D d, Vec128 v) { +#if HWY_TARGET == HWY_SSE2 + const Vec128 lo{v.raw}; + const Vec128 hi{_mm_unpackhi_epi64(v.raw, v.raw)}; + return Combine(d, hi, lo); +#else + const Repartition> d8; + (void)d; + alignas(16) static constexpr uint8_t kIdx[16] = {0, 8, 0, 8, 0, 8, 0, 8, + 0, 8, 0, 8, 0, 8, 0, 8}; + const Vec128 v8 = TableLookupBytes(v, Load(d8, kIdx)); + return LowerHalf(LowerHalf(LowerHalf(v8))); +#endif +} + +template +HWY_API VFromD TruncateTo(D d, Vec128 v) { +#if HWY_TARGET == HWY_SSE2 + const Vec128 lo{v.raw}; + const Vec128 hi{_mm_unpackhi_epi64(v.raw, v.raw)}; + return Combine(d, hi, lo); +#else + (void)d; + const Repartition> d16; + alignas(16) static constexpr uint16_t kIdx[8] = { + 0x100u, 0x908u, 0x100u, 0x908u, 0x100u, 0x908u, 0x100u, 0x908u}; + const Vec128 v16 = TableLookupBytes(v, Load(d16, kIdx)); + return LowerHalf(LowerHalf(v16)); +#endif +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { + return VFromD{_mm_shuffle_epi32(v.raw, 0x88)}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const DFromV du32; +#if HWY_TARGET == HWY_SSE2 + const RebindToSigned di32; + const Rebind du8; + return DemoteTo(du8, BitCast(di32, ShiftRight<24>(ShiftLeft<24>(v)))); +#else + const Repartition d; + alignas(16) static constexpr uint8_t kIdx[16] = { + 0x0u, 0x4u, 0x8u, 0xCu, 0x0u, 0x4u, 0x8u, 0xCu, + 0x0u, 0x4u, 0x8u, 0xCu, 0x0u, 0x4u, 0x8u, 0xCu}; + return LowerHalf(LowerHalf(TableLookupBytes(v, Load(d, kIdx)))); +#endif +} + +template 
+HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const DFromV du32; +#if HWY_TARGET == HWY_SSE2 + const RebindToSigned di32; + const Rebind du16; + const RebindToSigned di16; + return BitCast( + du16, DemoteTo(di16, ShiftRight<16>(BitCast(di32, ShiftLeft<16>(v))))); +#else + const Repartition d; + return LowerHalf(ConcatEven(d, BitCast(d, v), BitCast(d, v))); +#endif +} + +template +HWY_API VFromD TruncateTo(D /* tag */, VFromD> v) { + const DFromV du16; +#if HWY_TARGET == HWY_SSE2 + const RebindToSigned di16; + const Rebind du8; + const RebindToSigned di8; + return BitCast(du8, + DemoteTo(di8, ShiftRight<8>(BitCast(di16, ShiftLeft<8>(v))))); +#else + const Repartition d; + return LowerHalf(ConcatEven(d, BitCast(d, v), BitCast(d, v))); +#endif +} + +// ------------------------------ Demotions to/from i64 + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtsepi64_epi32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtsepi64_epi16(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtsepi64_epi8(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; + return VFromD{_mm_maskz_cvtusepi64_epi32(non_neg_mask, v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; + return VFromD{_mm_maskz_cvtusepi64_epi16(non_neg_mask, v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; + return VFromD{_mm_maskz_cvtusepi64_epi8(non_neg_mask, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtusepi64_epi32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return 
VFromD{_mm_cvtusepi64_epi16(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtusepi64_epi8(v.raw)}; +} +#else // AVX2 or below + +// Disable the default unsigned to signed DemoteTo/ReorderDemote2To +// implementations in generic_ops-inl.h for U64->I8/I16/I32 demotions on +// SSE2/SSSE3/SSE4/AVX2 as U64->I8/I16/I32 DemoteTo/ReorderDemote2To for +// SSE2/SSSE3/SSE4/AVX2 is implemented in x86_128-inl.h + +// The default unsigned to signed DemoteTo/ReorderDemote2To +// implementations in generic_ops-inl.h are still used for U32->I8/I16 and +// U16->I8 demotions on SSE2/SSSE3/SSE4/AVX2 + +#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V +#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) HWY_IF_NOT_T_SIZE_V(V, 8) + +namespace detail { +template +HWY_INLINE VFromD> DemoteFromU64MaskOutResult( + D /*dn*/, VFromD> v) { + return v; +} + +template +HWY_INLINE VFromD> DemoteFromU64MaskOutResult( + D /*dn*/, VFromD> v) { + const DFromV du64; + return And(v, + Set(du64, static_cast(hwy::HighestValue>()))); +} + +template +HWY_INLINE VFromD> DemoteFromU64Saturate( + D dn, VFromD> v) { + const Rebind du64; + const RebindToSigned di64; + constexpr int kShiftAmt = static_cast(sizeof(TFromD) * 8) - + static_cast(hwy::IsSigned>()); + + const auto too_big = BitCast( + du64, VecFromMask( + di64, Gt(BitCast(di64, ShiftRight(v)), Zero(di64)))); + return DemoteFromU64MaskOutResult(dn, Or(v, too_big)); +} + +template +HWY_INLINE VFromD ReorderDemote2From64To32Combine(D dn, V a, V b) { + return ConcatEven(dn, BitCast(dn, b), BitCast(dn, a)); +} + +} // namespace detail + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const DFromV di64; + const RebindToUnsigned du64; + const RebindToUnsigned dn_u; + + // Negative values are saturated by first saturating their bitwise inverse + // and then inverting the saturation result + const auto invert_mask = BitCast(du64, BroadcastSignBit(v)); + const auto saturated_vals = Xor( + invert_mask, + 
detail::DemoteFromU64Saturate(dn, Xor(invert_mask, BitCast(du64, v)))); + return BitCast(dn, TruncateTo(dn_u, saturated_vals)); +} + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const DFromV di64; + const RebindToUnsigned du64; + + const auto non_neg_vals = BitCast(du64, AndNot(BroadcastSignBit(v), v)); + return TruncateTo(dn, detail::DemoteFromU64Saturate(dn, non_neg_vals)); +} + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const RebindToUnsigned dn_u; + return BitCast(dn, TruncateTo(dn_u, detail::DemoteFromU64Saturate(dn, v))); +} + +#if HWY_TARGET == HWY_SSE2 +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + const Rebind di32; + return DemoteTo(dn, DemoteTo(di32, v)); +} +#endif // HWY_TARGET == HWY_SSE2 + +template +HWY_API VFromD DemoteTo(D dn, VFromD> v) { + return TruncateTo(dn, detail::DemoteFromU64Saturate(dn, v)); +} +#endif // HWY_TARGET <= HWY_AVX3 + +template )> +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} + +#if HWY_TARGET > HWY_AVX3 +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const DFromV d; + const Twice dt; + return DemoteTo(dn, Combine(dt, b, a)); +} +#endif + +#if HWY_TARGET > HWY_AVX2 +template +HWY_API Vec128 ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV di64; + const RebindToUnsigned du64; + const Half dnh; + + // Negative values are saturated by first saturating their bitwise inverse + // and then inverting the saturation result + const auto invert_mask_a = BitCast(du64, BroadcastSignBit(a)); + const auto invert_mask_b = BitCast(du64, BroadcastSignBit(b)); + const auto saturated_a = Xor( + invert_mask_a, + detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_a, BitCast(du64, a)))); + const auto saturated_b = Xor( + 
invert_mask_b, + detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_b, BitCast(du64, b)))); + + return ConcatEven(dn, BitCast(dn, saturated_b), BitCast(dn, saturated_a)); +} + +template +HWY_API Vec128 ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const DFromV di64; + const RebindToUnsigned du64; + const Half dnh; + + const auto saturated_a = detail::DemoteFromU64Saturate( + dnh, BitCast(du64, AndNot(BroadcastSignBit(a), a))); + const auto saturated_b = detail::DemoteFromU64Saturate( + dnh, BitCast(du64, AndNot(BroadcastSignBit(b), b))); + + return ConcatEven(dn, BitCast(dn, saturated_b), BitCast(dn, saturated_a)); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec128 a, + Vec128 b) { + const Half dnh; + + const auto saturated_a = detail::DemoteFromU64Saturate(dnh, a); + const auto saturated_b = detail::DemoteFromU64Saturate(dnh, b); + + return ConcatEven(dn, BitCast(dn, saturated_b), BitCast(dn, saturated_a)); +} +#endif // HWY_TARGET > HWY_AVX2 + +// ------------------------------ Integer <=> fp (ShiftRight, OddEven) + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtepu16_ph(v.raw)}; +} +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtepi16_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return VFromD{_mm_cvtepi32_ps(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD ConvertTo(D /*df*/, VFromD> v) { + return VFromD{_mm_cvtepu32_ps(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /*dd*/, VFromD> v) { + return VFromD{_mm_cvtepi64_pd(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /*dd*/, VFromD> v) { + return VFromD{_mm_cvtepu64_pd(v.raw)}; +} +#else // AVX2 or below +// Generic for all vector lengths. 
+template +HWY_API VFromD ConvertTo(D df, VFromD> v) { + // Based on wim's approach (https://stackoverflow.com/questions/34066228/) + const RebindToUnsigned du32; + const RebindToSigned d32; + + const auto msk_lo = Set(du32, 0xFFFF); + const auto cnst2_16_flt = Set(df, 65536.0f); // 2^16 + + // Extract the 16 lowest/highest significant bits of v and cast to signed int + const auto v_lo = BitCast(d32, And(v, msk_lo)); + const auto v_hi = BitCast(d32, ShiftRight<16>(v)); + return MulAdd(cnst2_16_flt, ConvertTo(df, v_hi), ConvertTo(df, v_lo)); +} + +// Generic for all vector lengths. +template +HWY_API VFromD ConvertTo(D dd, VFromD> v) { + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const Repartition d32; + const Repartition d64; + + // Toggle MSB of lower 32-bits and insert exponent for 2^84 + 2^63 + const auto k84_63 = Set(d64, 0x4530000080000000ULL); + const auto v_upper = BitCast(dd, ShiftRight<32>(BitCast(d64, v)) ^ k84_63); + + // Exponent is 2^52, lower 32 bits from v (=> 32-bit OddEven) + const auto k52 = Set(d32, 0x43300000); + const auto v_lower = BitCast(dd, OddEven(k52, BitCast(d32, v))); + + const auto k84_63_52 = BitCast(dd, Set(d64, 0x4530000080100000ULL)); + return (v_upper - k84_63_52) + v_lower; // order matters! +} + +namespace detail { +template +HWY_INLINE VFromD>> U64ToF64VecFast(VW w) { + const DFromV d64; + const RebindToFloat dd; + const auto cnst2_52_dbl = Set(dd, 0x0010000000000000); // 2^52 + return BitCast(dd, Or(w, BitCast(d64, cnst2_52_dbl))) - cnst2_52_dbl; +} +} // namespace detail + +// Generic for all vector lengths. 
+template +HWY_API VFromD ConvertTo(D dd, VFromD> v) { + // Based on wim's approach (https://stackoverflow.com/questions/41144668/) + const RebindToUnsigned d64; + using VU = VFromD; + + const VU msk_lo = Set(d64, 0xFFFFFFFF); + const auto cnst2_32_dbl = Set(dd, 4294967296.0); // 2^32 + + // Extract the 32 lowest/highest significant bits of v + const VU v_lo = And(v, msk_lo); + const VU v_hi = ShiftRight<32>(v); + + const auto v_lo_dbl = detail::U64ToF64VecFast(v_lo); + return MulAdd(cnst2_32_dbl, detail::U64ToF64VecFast(v_hi), v_lo_dbl); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// Truncates (rounds toward zero). + +#ifdef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#undef HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#else +#define HWY_NATIVE_F2I_CONVERT_IN_RANGE_TO +#endif + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertInRangeTo(D /*di*/, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttph_epi16 if any values of v[i] + // are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7])); + } +#endif + + __m128i raw_result; + __asm__("vcvttph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttph_epi16(v.raw)}; 
+#endif +} + +// F16 to I16 ConvertTo is generic for all vector lengths +template +HWY_API VFromD ConvertTo(D di, VFromD> v) { + const RebindToFloat df; + // See comment at the first occurrence of "IfThenElse(overflow,". + const MFromD overflow = + RebindMask(di, Ge(v, Set(df, ConvertScalarTo(32768.0f)))); + return IfThenElse(overflow, Set(di, LimitsMax()), + ConvertInRangeTo(di, v)); +} + +template +HWY_API VFromD ConvertInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttph_epu16 if any values of v[i] + // are not within the range of an uint16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7])); + } +#endif + + __m128i raw_result; + __asm__("vcvttph2uw {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttph_epu16(v.raw)}; +#endif +} + +// F16->U16 ConvertTo is generic for all vector lengths +template +HWY_API VFromD ConvertTo(D /* tag */, VFromD> v) { + return ConvertInRangeTo(D(), ZeroIfNegative(v)); +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD ConvertInRangeTo(D /*di*/, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm_cvtts_ps_epi32(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // 
Workaround for undefined behavior in _mm_cvttps_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("%vcvttps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttps_epi32(v.raw)}; +#endif +} + +// F32 to I32 ConvertTo is generic for all vector lengths +template +HWY_API VFromD ConvertTo(D di, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return ConvertInRangeTo(di, v); +#else + const RebindToFloat df; + // See comment at the first occurrence of "IfThenElse(overflow,". 
+ const MFromD overflow = RebindMask(di, Ge(v, Set(df, 2147483648.0f))); + return IfThenElse(overflow, Set(di, LimitsMax()), + ConvertInRangeTo(di, v)); +#endif +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD ConvertInRangeTo(DI /*di*/, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm_cvtts_pd_epi64(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + DI(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttpd_epi64(v.raw)}; +#endif +} + +// F64 to I64 ConvertTo is generic for all vector lengths on AVX3 +template +HWY_API VFromD ConvertTo(DI di, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return ConvertInRangeTo(di, v); +#else + const RebindToFloat df; + // See comment at the first occurrence of "IfThenElse(overflow,". 
+ const MFromD overflow = + RebindMask(di, Ge(v, Set(df, 9.223372036854776e18))); + return IfThenElse(overflow, Set(di, LimitsMax()), + ConvertInRangeTo(di, v)); +#endif +} + +template +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm_cvtts_ps_epu32(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttps_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + DU(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("vcvttps2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttps_epu32(v.raw)}; +#endif +} + +// F32->U32 ConvertTo is generic for all vector lengths +template +HWY_API VFromD ConvertTo(DU du32, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return ConvertInRangeTo(du32, v); +#else + return ConvertInRangeTo(du32, ZeroIfNegative(v)); +#endif +} + +template +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm_cvtts_pd_epu64(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttpd_epu64 with GCC if any + // values of v[i] are not within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto 
raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + DU(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvttpd_epu64(v.raw)}; +#endif +} + +// F64->U64 ConvertTo is generic for all vector lengths +template +HWY_API VFromD ConvertTo(DU du64, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return ConvertInRangeTo(du64, v); +#else + return ConvertInRangeTo(du64, ZeroIfNegative(v)); +#endif +} + +#else // AVX2 or below + +namespace detail { + +template +static HWY_INLINE VFromD ConvInRangeF32ToU32( + DU32 du32, VFromD> v, VFromD& exp_diff) { + const RebindToSigned di32; + const RebindToFloat df32; + + exp_diff = Set(du32, uint32_t{158}) - ShiftRight<23>(BitCast(du32, v)); + const auto scale_down_f32_val_mask = + VecFromMask(du32, Eq(exp_diff, Zero(du32))); + + const auto v_scaled = + BitCast(df32, BitCast(du32, v) + ShiftLeft<23>(scale_down_f32_val_mask)); + const auto f32_to_u32_result = + BitCast(du32, ConvertInRangeTo(di32, v_scaled)); + + return f32_to_u32_result + And(f32_to_u32_result, scale_down_f32_val_mask); +} + +} // namespace detail + +// F32 to U32 ConvertInRangeTo is generic for all vector lengths on +// SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD ConvertInRangeTo(DU32 du32, + VFromD> v) { + VFromD exp_diff; + const auto f32_to_u32_result = detail::ConvInRangeF32ToU32(du32, v, exp_diff); + return f32_to_u32_result; +} + +// F32 to U32 ConvertTo is generic for all vector lengths on +// SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD ConvertTo(DU32 du32, VFromD> v) { + const RebindToSigned di32; + + const auto non_neg_v = ZeroIfNegative(v); + VFromD exp_diff; + const auto f32_to_u32_result = + detail::ConvInRangeF32ToU32(du32, 
non_neg_v, exp_diff); + + return Or(f32_to_u32_result, + BitCast(du32, BroadcastSignBit(BitCast(di32, exp_diff)))); +} + +namespace detail { + +template +HWY_API VFromD ConvAbsInRangeF64ToUI64(D64 d64, + VFromD> v, + VFromD& biased_exp) { + const RebindToSigned di64; + const RebindToUnsigned du64; + using VU64 = VFromD; + const Repartition du16; + const VU64 k1075 = Set(du64, 1075); /* biased exponent of 2^52 */ + + // Exponent indicates whether the number can be represented as int64_t. + biased_exp = BitCast(d64, ShiftRight<52>(BitCast(du64, v))); + HWY_IF_CONSTEXPR(IsSigned>()) { + biased_exp = And(biased_exp, Set(d64, TFromD{0x7FF})); + } + + // If we were to cap the exponent at 51 and add 2^52, the number would be in + // [2^52, 2^53) and mantissa bits could be read out directly. We need to + // round-to-0 (truncate), but changing rounding mode in MXCSR hits a + // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead + // manually shift the mantissa into place (we already have many of the + // inputs anyway). + + // Use 16-bit saturated unsigned subtraction to compute shift_mnt and + // shift_int since biased_exp[i] is a non-negative integer that is less than + // or equal to 2047. + + // 16-bit saturated unsigned subtraction is also more efficient than a + // 64-bit subtraction followed by a 64-bit signed Max operation on + // SSE2/SSSE3/SSE4/AVX2. + + // The upper 48 bits of both shift_mnt and shift_int are guaranteed to be + // zero as the upper 48 bits of both k1075 and biased_exp are zero. + + const VU64 shift_mnt = BitCast( + du64, SaturatedSub(BitCast(du16, k1075), BitCast(du16, biased_exp))); + const VU64 shift_int = BitCast( + du64, SaturatedSub(BitCast(du16, biased_exp), BitCast(du16, k1075))); + const VU64 mantissa = BitCast(du64, v) & Set(du64, (1ULL << 52) - 1); + // Include implicit 1-bit. NOTE: the shift count may exceed 63; we rely on x86 + // returning zero in that case. 
+ const VU64 int53 = (mantissa | Set(du64, 1ULL << 52)) >> shift_mnt; + + // For inputs larger than 2^53 - 1, insert zeros at the bottom. + + // For inputs less than 2^64, the implicit 1-bit is guaranteed not to be + // shifted out of the left shift result below as shift_int[i] <= 11 is true + // for any inputs that are less than 2^64. + + return BitCast(d64, int53 << shift_int); +} + +} // namespace detail + +#if HWY_ARCH_X86_64 + +namespace detail { + +template +static HWY_INLINE int64_t SSE2ConvFirstF64LaneToI64(Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvttsd_si64 with GCC if v[0] is + // not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (IsConstantX86Vec(hwy::SizeTag<1>(), v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return X86ConvertScalarFromFloat(raw_v[0]); + } +#endif + + int64_t result; + __asm__("%vcvttsd2si {%1, %0|%0, %1}" + : "=r"(result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return result; +#else + return _mm_cvttsd_si64(v.raw); +#endif +} + +} // namespace detail + +template +HWY_API VFromD ConvertInRangeTo(DI /*di*/, Vec64 v) { + return VFromD{_mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToI64(v))}; +} +template +HWY_API VFromD ConvertInRangeTo(DI /*di*/, Vec128 v) { + const __m128i i0 = _mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToI64(v)); + const Full64 dd2; + const __m128i i1 = + _mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToI64(UpperHalf(dd2, v))); + return VFromD{_mm_unpacklo_epi64(i0, i1)}; +} + +template +HWY_API VFromD ConvertTo(DI di, VFromD> v) { + const RebindToFloat df; + // See comment at the first occurrence of "IfThenElse(overflow,". 
+ const MFromD overflow = + RebindMask(di, Ge(v, Set(df, 9.223372036854776e18))); + return IfThenElse(overflow, Set(di, LimitsMax()), + ConvertInRangeTo(di, v)); +} +#endif // HWY_ARCH_X86_64 + +#if !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 +template +HWY_API VFromD ConvertInRangeTo(DI di, VFromD> v) { + using VI = VFromD; + + VI biased_exp; + const VI shifted = detail::ConvAbsInRangeF64ToUI64(di, v, biased_exp); + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + + // If the input was negative, negate the integer (two's complement). + return (shifted ^ sign_mask) - sign_mask; +} + +template +HWY_API VFromD ConvertTo(DI di, VFromD> v) { + using VI = VFromD; + + VI biased_exp; + const VI shifted = detail::ConvAbsInRangeF64ToUI64(di, v, biased_exp); + +#if HWY_TARGET <= HWY_SSE4 + const auto in_range = biased_exp < Set(di, 1086); +#else + const Repartition di32; + const auto in_range = MaskFromVec(BitCast( + di, + VecFromMask(di32, DupEven(BitCast(di32, biased_exp)) < Set(di32, 1086)))); +#endif + + // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. + const VI sign_mask = BroadcastSignBit(BitCast(di, v)); + const VI limit = Set(di, LimitsMax()) - sign_mask; + const VI magnitude = IfThenElse(in_range, shifted, limit); + + // If the input was negative, negate the integer (two's complement). 
+ return (magnitude ^ sign_mask) - sign_mask; +} +#endif // !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD ConvertInRangeTo(DU du, VFromD> v) { + VFromD biased_exp; + const auto shifted = detail::ConvAbsInRangeF64ToUI64(du, v, biased_exp); + return shifted; +} + +// Generic for all vector lengths on SSE2/SSSE3/SSE4/AVX2 +template +HWY_API VFromD ConvertTo(DU du, VFromD> v) { + const RebindToSigned di; + using VU = VFromD; + + VU biased_exp; + const VU shifted = + detail::ConvAbsInRangeF64ToUI64(du, ZeroIfNegative(v), biased_exp); + + // Exponent indicates whether the number can be represented as uint64_t. +#if HWY_TARGET <= HWY_SSE4 + const VU out_of_range = + BitCast(du, VecFromMask(di, BitCast(di, biased_exp) > Set(di, 1086))); +#else + const Repartition di32; + const VU out_of_range = BitCast( + du, + VecFromMask(di32, DupEven(BitCast(di32, biased_exp)) > Set(di32, 1086))); +#endif + + return (shifted | out_of_range); +} +#endif // HWY_TARGET <= HWY_AVX3 + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD +namespace detail { + +template +static HWY_INLINE HWY_MAYBE_UNUSED HWY_BITCASTSCALAR_CXX14_CONSTEXPR TTo +X86ScalarNearestInt(TF flt_val) { +#if HWY_HAVE_SCALAR_F16_TYPE && HWY_HAVE_SCALAR_F16_OPERATORS + using TFArith = If, hwy::bfloat16_t>(), float, + RemoveCvRef>; +#else + using TFArith = If>; +#endif + + const TTo trunc_int_val = X86ConvertScalarFromFloat(flt_val); + const TFArith abs_val_diff = ScalarAbs( + ConvertScalarTo(ConvertScalarTo(flt_val) - + ConvertScalarTo(trunc_int_val))); + constexpr TFArith kHalf = ConvertScalarTo(0.5); + + const bool round_result_up = + ((trunc_int_val ^ ScalarShr(trunc_int_val, sizeof(TTo) * 8 - 1)) != + LimitsMax()) && + (abs_val_diff > kHalf || + (abs_val_diff == kHalf && (trunc_int_val & 1) != 0)); + return static_cast( + trunc_int_val + + (round_result_up ? (ScalarSignBit(flt_val) ? 
(-1) : 1) : 0)); +} + +} // namespace detail +#endif // HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + +// If these are in namespace detail, the x86_256/512 templates are not found. +template +static HWY_INLINE VFromD NearestIntInRange(DI, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtps_epi32 with GCC if any values + // of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues(DI(), + detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("%vcvtps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvtps_epi32(v.raw)}; +#endif +} + +#if HWY_HAVE_FLOAT16 +template +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtph_epi16 if any values of v[i] + // are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues(DI(), + detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + 
detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7])); + } +#endif + + __m128i raw_result; + __asm__("vcvtph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvtph_epi16(v.raw)}; +#endif +} +#endif // HWY_HAVE_FLOAT16 + +#if HWY_TARGET <= HWY_AVX3 + +template +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues(DI(), + detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1])); + } +#endif + + __m128i raw_result; + __asm__("vcvtpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvtpd_epi64(v.raw)}; +#endif +} + +#else // HWY_TARGET > HWY_AVX3 + +namespace detail { + +#if HWY_ARCH_X86_64 +template +static HWY_INLINE int64_t +SSE2ConvFirstF64LaneToNearestI64(Vec128 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtsd_si64 with GCC if v[0] is + // not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (IsConstantX86Vec(hwy::SizeTag<1>(), v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return X86ScalarNearestInt(raw_v[0]); + } +#endif + + int64_t result; + __asm__("%vcvtsd2si {%1, %0|%0, %1}" + : "=r"(result) + 
: HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return result; +#else + return _mm_cvtsd_si64(v.raw); +#endif +} +#endif // HWY_ARCH_X86_64 + +#if !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 +template +static HWY_INLINE VFromD SSE2NearestI64InRange( + DI64 di64, VFromD> v) { + const RebindToFloat df64; + const RebindToUnsigned du64; + using VI64 = VFromD; + + const auto mant_end = Set(df64, MantissaEnd()); + const auto is_small = Lt(Abs(v), mant_end); + + const auto adj_v = Max(v, Set(df64, -9223372036854775808.0)) + + IfThenElseZero(is_small, CopySignToAbs(mant_end, v)); + const auto adj_v_biased_exp = + And(BitCast(di64, ShiftRight<52>(BitCast(du64, adj_v))), + Set(di64, int64_t{0x7FF})); + + // We can simply subtract 1075 from adj_v_biased_exp[i] to get shift_int since + // adj_v_biased_exp[i] is at least 1075 + const VI64 shift_int = adj_v_biased_exp + Set(di64, int64_t{-1075}); + + const VI64 mantissa = BitCast(di64, adj_v) & Set(di64, (1LL << 52) - 1); + // Include implicit 1-bit if is_small[i] is 0. NOTE: the shift count may + // exceed 63; we rely on x86 returning zero in that case. + const VI64 int53 = mantissa | IfThenZeroElse(RebindMask(di64, is_small), + Set(di64, 1LL << 52)); + + const VI64 sign_mask = BroadcastSignBit(BitCast(di64, v)); + // If the input was negative, negate the integer (two's complement). 
+ return ((int53 << shift_int) ^ sign_mask) - sign_mask; +} +#endif // !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +} // namespace detail + +#if HWY_ARCH_X86_64 +template +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, Vec64 v) { + return VFromD{ + _mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToNearestI64(v))}; +} +template +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, Vec128 v) { + const __m128i i0 = + _mm_cvtsi64_si128(detail::SSE2ConvFirstF64LaneToNearestI64(v)); + const Full64 dd2; + const __m128i i1 = _mm_cvtsi64_si128( + detail::SSE2ConvFirstF64LaneToNearestI64(UpperHalf(dd2, v))); + return VFromD{_mm_unpacklo_epi64(i0, i1)}; +} +#endif // HWY_ARCH_X86_64 + +#if !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 +template +static HWY_INLINE VFromD NearestIntInRange(DI di, + VFromD> v) { + return detail::SSE2NearestI64InRange(di, v); +} +#endif // !HWY_ARCH_X86_64 || HWY_TARGET <= HWY_AVX2 + +#endif // HWY_TARGET <= HWY_AVX3 + +template +static HWY_INLINE VFromD DemoteToNearestIntInRange( + DI, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm_cvtpd_epi32 with GCC if any values + // of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + DI(), detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), int32_t{0}, int32_t{0}); + } +#endif + + __m128i raw_result; + __asm__("%vcvtpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm_cvtpd_epi32(v.raw)}; +#endif +} + +// F16/F32/F64 NearestInt is generic for all vector lengths +template , class DI = RebindToSigned, + HWY_IF_FLOAT_D(DF), + 
HWY_IF_T_SIZE_ONE_OF_D(DF, (1 << 4) | (1 << 8) | + (HWY_HAVE_FLOAT16 ? (1 << 2) : 0))> +HWY_API VFromD NearestInt(const VF v) { + const DI di; + using TI = TFromD; + using TF = TFromD; + using TFArith = If>; + + constexpr TFArith kMinOutOfRangePosVal = + static_cast(-static_cast(LimitsMin())); + static_assert(kMinOutOfRangePosVal > static_cast(0.0), + "kMinOutOfRangePosVal > 0.0 must be true"); + + // See comment at the first occurrence of "IfThenElse(overflow,". + // Here we are rounding, whereas previous occurrences truncate, but there is + // no difference because the previous float value is well below the max i32. + const auto overflow = RebindMask( + di, Ge(v, Set(DF(), ConvertScalarTo(kMinOutOfRangePosVal)))); + auto result = + IfThenElse(overflow, Set(di, LimitsMax()), NearestIntInRange(di, v)); + + return result; +} + +template +HWY_API VFromD DemoteToNearestInt(DI, VFromD> v) { + const DI di; + const Rebind df64; + return DemoteToNearestIntInRange(di, Min(v, Set(df64, 2147483647.0))); +} + +// ------------------------------ Floating-point rounding (ConvertTo) + +#if HWY_TARGET >= HWY_SSSE3 + +// Toward nearest integer, ties to even +template +HWY_API Vec128 Round(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + // Rely on rounding after addition with a large value such that no mantissa + // bits remain (assuming the current mode is nearest-even). We may need a + // compiler flag for precise floating-point to prevent "optimizing" this out. + const DFromV df; + const auto max = Set(df, MantissaEnd()); + const auto large = CopySignToAbs(max, v); + const auto added = large + v; + const auto rounded = added - large; + // Keep original if NaN or the magnitude is large (already an int). 
+ return IfThenElse(Abs(v) < max, rounded, v); +} + +namespace detail { + +// Truncating to integer and converting back to float is correct except when the +// input magnitude is large, in which case the input was already an integer +// (because mantissa >> exponent is zero). +template +HWY_INLINE Mask128 UseInt(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const DFromV d; + return Abs(v) < Set(d, MantissaEnd()); +} + +} // namespace detail + +// Toward zero, aka truncate +template +HWY_API Vec128 Trunc(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertInRangeTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + return IfThenElse(detail::UseInt(v), CopySign(int_f, v), v); +} + +// Toward +infinity, aka ceiling +template +HWY_API Vec128 Ceil(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertInRangeTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f < v))); + + return IfThenElse(detail::UseInt(v), int_f - neg1, v); +} + +#ifdef HWY_NATIVE_CEIL_FLOOR_INT +#undef HWY_NATIVE_CEIL_FLOOR_INT +#else +#define HWY_NATIVE_CEIL_FLOOR_INT +#endif + +template +HWY_API VFromD>> CeilInt(V v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a positive non-integer ends up smaller; if so, add 1. 
+ return integer - + VecFromMask(di, RebindMask(di, And(detail::UseInt(v), int_f < v))); +} + +// Toward -infinity, aka floor +template +HWY_API Vec128 Floor(const Vec128 v) { + static_assert(IsFloat(), "Only for float"); + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertInRangeTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. + const auto neg1 = ConvertTo(df, VecFromMask(di, RebindMask(di, int_f > v))); + + return IfThenElse(detail::UseInt(v), int_f + neg1, v); +} + +template +HWY_API VFromD>> FloorInt(V v) { + const DFromV df; + const RebindToSigned di; + + const auto integer = ConvertTo(di, v); // round toward 0 + const auto int_f = ConvertTo(df, integer); + + // Truncating a negative non-integer ends up larger; if so, subtract 1. + return integer + + VecFromMask(di, RebindMask(di, And(detail::UseInt(v), int_f > v))); +} + +#else + +// Toward nearest integer, ties to even +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_roundscale_ph(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Round(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_roundscale_ph(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Trunc(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward 
+infinity, aka ceiling +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + _mm_roundscale_ph(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Ceil(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_roundscale_ph(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +template +HWY_API Vec128 Floor(const Vec128 v) { + return Vec128{ + _mm_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +#endif // !HWY_SSSE3 + +// ------------------------------ Floating-point classification + +#define HWY_X86_FPCLASS_QNAN 0x01 +#define HWY_X86_FPCLASS_POS0 0x02 +#define HWY_X86_FPCLASS_NEG0 0x04 +#define HWY_X86_FPCLASS_POS_INF 0x08 +#define HWY_X86_FPCLASS_NEG_INF 0x10 +#define HWY_X86_FPCLASS_SUBNORMAL 0x20 +#define HWY_X86_FPCLASS_NEG 0x40 +#define HWY_X86_FPCLASS_SNAN 0x80 + +#if HWY_HAVE_FLOAT16 || HWY_IDE + +template +HWY_API Mask128 IsNaN(const Vec128 v) { + return Mask128{ + _mm_fpclass_ph_mask(v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +} + +template +HWY_API Mask128 IsEitherNaN(Vec128 a, + Vec128 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask128{_mm_cmp_ph_mask(a.raw, b.raw, _CMP_UNORD_Q)}; + HWY_DIAGNOSTICS(pop) +} + +template +HWY_API Mask128 IsInf(const Vec128 v) { + return Mask128{_mm_fpclass_ph_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} + +template +HWY_API Mask128 IsFinite(const Vec128 v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask128{_mm_fpclass_ph_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} + +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Mask128 IsNaN(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{ + _mm_fpclass_ps_mask(v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +#else + return Mask128{_mm_cmpunord_ps(v.raw, v.raw)}; +#endif +} +template +HWY_API Mask128 IsNaN(const Vec128 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{ + _mm_fpclass_pd_mask(v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +#else + return Mask128{_mm_cmpunord_pd(v.raw, v.raw)}; +#endif +} + +#ifdef HWY_NATIVE_IS_EITHER_NAN +#undef HWY_NATIVE_IS_EITHER_NAN +#else +#define HWY_NATIVE_IS_EITHER_NAN +#endif + +template +HWY_API Mask128 IsEitherNaN(Vec128 a, Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{_mm_cmp_ps_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +#else + return Mask128{_mm_cmpunord_ps(a.raw, b.raw)}; +#endif +} + +template +HWY_API Mask128 IsEitherNaN(Vec128 a, + Vec128 b) { +#if HWY_TARGET <= HWY_AVX3 + return Mask128{_mm_cmp_pd_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +#else + return Mask128{_mm_cmpunord_pd(a.raw, b.raw)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 + +// Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite. 
+#ifdef HWY_NATIVE_ISINF +#undef HWY_NATIVE_ISINF +#else +#define HWY_NATIVE_ISINF +#endif + +template +HWY_API Mask128 IsInf(const Vec128 v) { + return Mask128{_mm_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} +template +HWY_API Mask128 IsInf(const Vec128 v) { + return Mask128{_mm_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} + +// Returns whether normal/subnormal/zero. +template +HWY_API Mask128 IsFinite(const Vec128 v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask128{_mm_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} +template +HWY_API Mask128 IsFinite(const Vec128 v) { + return Not(Mask128{_mm_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) && HWY_TARGET <= HWY_SSE4 + +// Per-target flag to prevent generic_ops-inl.h from defining AESRound. 
+#ifdef HWY_NATIVE_AES +#undef HWY_NATIVE_AES +#else +#define HWY_NATIVE_AES +#endif + +HWY_API Vec128 AESRound(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesenc_si128(state.raw, round_key.raw)}; +} + +HWY_API Vec128 AESLastRound(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesenclast_si128(state.raw, round_key.raw)}; +} + +HWY_API Vec128 AESInvMixColumns(Vec128 state) { + return Vec128{_mm_aesimc_si128(state.raw)}; +} + +HWY_API Vec128 AESRoundInv(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesdec_si128(state.raw, round_key.raw)}; +} + +HWY_API Vec128 AESLastRoundInv(Vec128 state, + Vec128 round_key) { + return Vec128{_mm_aesdeclast_si128(state.raw, round_key.raw)}; +} + +template +HWY_API Vec128 AESKeyGenAssist(Vec128 v) { + return Vec128{_mm_aeskeygenassist_si128(v.raw, kRcon)}; +} + +template +HWY_API Vec128 CLMulLower(Vec128 a, + Vec128 b) { + return Vec128{_mm_clmulepi64_si128(a.raw, b.raw, 0x00)}; +} + +template +HWY_API Vec128 CLMulUpper(Vec128 a, + Vec128 b) { + return Vec128{_mm_clmulepi64_si128(a.raw, b.raw, 0x11)}; +} + +#endif // !defined(HWY_DISABLE_PCLMUL_AES) && HWY_TARGET <= HWY_SSE4 + +// ================================================== MISC + +// ------------------------------ LoadMaskBits (TestBit) + +#if HWY_TARGET > HWY_AVX3 +namespace detail { + +template +HWY_INLINE MFromD LoadMaskBits128(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + // Easier than Set(), which would require an >8-bit type, which would not + // compile for T=uint8_t, kN=1. 
+ const VFromD vbits{_mm_cvtsi32_si128(static_cast(mask_bits))}; + +#if HWY_TARGET == HWY_SSE2 + // {b0, b1, ...} ===> {b0, b0, b1, b1, ...} + __m128i unpacked_vbits = _mm_unpacklo_epi8(vbits.raw, vbits.raw); + // {b0, b0, b1, b1, ...} ==> {b0, b0, b0, b0, b1, b1, b1, b1, ...} + unpacked_vbits = _mm_unpacklo_epi16(unpacked_vbits, unpacked_vbits); + // {b0, b0, b0, b0, b1, b1, b1, b1, ...} ==> + // {b0, b0, b0, b0, b0, b0, b0, b0, b1, b1, b1, b1, b1, b1, b1, b1} + const VFromD rep8{ + _mm_unpacklo_epi32(unpacked_vbits, unpacked_vbits)}; +#else + // Replicate bytes 8x such that each byte contains the bit that governs it. + alignas(16) static constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1}; + const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); +#endif + const VFromD bit = Dup128VecFromValues( + du, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return RebindMask(d, TestBit(rep8, bit)); +} + +template +HWY_INLINE MFromD LoadMaskBits128(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits128(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint32_t kBit[8] = {1, 2, 4, 8}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE MFromD LoadMaskBits128(D d, uint64_t mask_bits) { + const RebindToUnsigned du; + alignas(16) static constexpr uint64_t kBit[8] = {1, 2}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail +#endif // HWY_TARGET > HWY_AVX3 + +// `p` points to at least 8 readable bytes, not all of which need be valid. 
+template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + constexpr size_t kN = MaxLanes(d); +#if HWY_TARGET <= HWY_AVX3 + (void)d; + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (kN + 7) / 8; + CopyBytes(bits, &mask_bits); + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + + return MFromD::FromBits(mask_bits); +#else + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (kN + 7) / 8; + CopyBytes(bits, &mask_bits); + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + + return detail::LoadMaskBits128(d, mask_bits); +#endif +} + +// ------------------------------ Dup128MaskFromMaskBits + +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + constexpr size_t kN = MaxLanes(d); + if (kN < 8) mask_bits &= (1u << kN) - 1; + +#if HWY_TARGET <= HWY_AVX3 + return MFromD::FromBits(mask_bits); +#else + return detail::LoadMaskBits128(d, mask_bits); +#endif +} + +template +struct CompressIsPartition { +#if HWY_TARGET <= HWY_AVX3 + // AVX3 supports native compress, but a table-based approach allows + // 'partitioning' (also moving mask=false lanes to the top), which helps + // vqsort. This is only feasible for eight or less lanes, i.e. sizeof(T) == 8 + // on AVX3. For simplicity, we only use tables for 64-bit lanes (not AVX3 + // u32x8 etc.). + enum { value = (sizeof(T) == 8) }; +#else + // generic_ops-inl does not guarantee IsPartition for 8-bit. + enum { value = (sizeof(T) != 1) }; +#endif +}; + +namespace detail { + +// Returns `mask_bits` (from movemask) with the upper bits cleared, if there +// are 8 or fewer valid bits. +template +constexpr uint64_t OnlyActive(D d, uint64_t mask_bits) { + return (d.MaxBytes() >= 16) ? mask_bits + : mask_bits & ((1ull << d.MaxLanes()) - 1); +} + +} // namespace detail + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ BitsFromMask (MFromD, OnlyActive) +// Generic for all vector lengths. 
+template +HWY_INLINE uint64_t BitsFromMask(D d, MFromD mask) { + return detail::OnlyActive(d, mask.raw); +} + +// ------------------------------ StoreMaskBits + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumBytes = (kN + 7) / 8; + CopyBytes(&mask.raw, bits); + + // Non-full byte, need to clear the undefined upper bits. + if (kN < 8) { + const int mask_bits = (1 << kN) - 1; + bits[0] = static_cast(bits[0] & mask_bits); + } + + return kNumBytes; +} + +// ------------------------------ Mask testing + +// Beware: the suffix indicates the number of mask bits, not lane size! + +template +HWY_API size_t CountTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint64_t mask_bits = uint64_t{mask.raw} & ((1ull << kN) - 1); + return PopCount(mask_bits); +} + +template +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint32_t mask_bits = uint32_t{mask.raw} & ((1u << kN) - 1); + return Num0BitsBelowLS1Bit_Nonzero32(mask_bits); +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint32_t mask_bits = uint32_t{mask.raw} & ((1u << kN) - 1); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1; +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint32_t mask_bits = uint32_t{mask.raw} & ((1u << kN) - 1); + return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits); +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint32_t mask_bits = uint32_t{mask.raw} & ((1u << kN) - 1); + return mask_bits ? 
intptr_t(31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits)) + : -1; +} + +template +HWY_API bool AllFalse(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint64_t mask_bits = uint64_t{mask.raw} & ((1ull << kN) - 1); + return mask_bits == 0; +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + constexpr size_t kN = MaxLanes(d); + const uint64_t mask_bits = uint64_t{mask.raw} & ((1ull << kN) - 1); + // Cannot use _kortestc because we may have less than 8 mask bits. + return mask_bits == (1ull << kN) - 1; +} + +// ------------------------------ Compress + +// 8-16 bit Compress, CompressStore defined in x86_512 because they use Vec512. + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + return Vec128{_mm_maskz_compress_ps(mask.raw, v.raw)}; +} + +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + HWY_DASSERT(mask.raw < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[64] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const DFromV d; + const Repartition d8; + const auto index = Load(d8, u8_indices + 16 * mask.raw); + return BitCast(d, TableLookupBytes(BitCast(d8, v), index)); +} + +// ------------------------------ CompressNot (Compress) + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // See CompressIsPartition, PrintCompressNot64x2NibbleTables + alignas(16) static constexpr uint64_t packed_array[16] = { + 0x00000010, 0x00000001, 0x00000010, 0x00000010}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2). 
+ const DFromV d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(16) static constexpr uint64_t kShifts[2] = {0, 4}; + Vec128 indices = packed >> Load(du64, kShifts); + // _mm_permutevar_pd will ignore the upper bits, but TableLookupLanes uses + // a fallback in MSAN builds, so mask there. + HWY_IF_CONSTEXPR(HWY_IS_MSAN) indices &= Set(du64, 1); + return TableLookupLanes(v, Indices128{indices.raw}); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +// ------------------------------ CompressStore (defined in x86_512) + +// ------------------------------ CompressBlendedStore (defined in x86_avx3) + +// ------------------------------ CompressBitsStore (defined in x86_512) + +#else // AVX2 or below + +// ------------------------------ BitsFromMask + +namespace detail { + +constexpr HWY_INLINE uint64_t U64FromInt(int mask_bits) { + return static_cast(static_cast(mask_bits)); +} + +} // namespace detail + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const auto sign_bits = BitCast(d, VecFromMask(d, mask)).raw; + return detail::OnlyActive(d, + detail::U64FromInt(_mm_movemask_epi8(sign_bits))); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + // Remove useless lower half of each u16 while preserving the sign bit. 
+ const auto sign_bits = _mm_packs_epi16(mask.raw, _mm_setzero_si128()); + return detail::OnlyActive(d, + detail::U64FromInt(_mm_movemask_epi8(sign_bits))); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToFloat df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return detail::OnlyActive(d, + detail::U64FromInt(_mm_movemask_ps(sign_bits.raw))); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToFloat df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)); + return detail::OnlyActive(d, + detail::U64FromInt(_mm_movemask_pd(sign_bits.raw))); +} + +// ------------------------------ StoreMaskBits +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + constexpr size_t kNumBytes = (MaxLanes(d) + 7) / 8; + const uint64_t mask_bits = BitsFromMask(d, mask); + CopyBytes(&mask_bits, bits); + return kNumBytes; +} + +// ------------------------------ Mask testing + +template +HWY_API bool AllFalse(D d, MFromD mask) { + // Cheaper than PTEST, which is 2 uop / 3L. + return BitsFromMask(d, mask) == 0; +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + constexpr uint64_t kAllBits = (1ull << MaxLanes(d)) - 1; + return BitsFromMask(d, mask) == kAllBits; +} + +template +HWY_API size_t CountTrue(D d, MFromD mask) { + return PopCount(BitsFromMask(d, mask)); +} + +template +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { + return Num0BitsBelowLS1Bit_Nonzero32( + static_cast(BitsFromMask(d, mask))); +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return mask_bits ? 
intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1; +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { + return 31 - Num0BitsAboveMS1Bit_Nonzero32( + static_cast(BitsFromMask(d, mask))); +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return mask_bits ? intptr_t(31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits)) + : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +// Also works for N < 8 because the first 16 4-tuples only reference bytes 0-6. +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Twice d8t; + const RebindToUnsigned du; + + // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need + // byte indices for PSHUFB (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. 
+ alignas(16) static constexpr uint8_t table[2048] = { + // PrintCompress16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, // + 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, // + 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, // + 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, // + 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, // + 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, // + 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, // + 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, // + 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, // + 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, // + 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, // + 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, // + 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, // + 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, // + 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, // + 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, // + 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, // + 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, // + 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, // + 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, // + 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, // + 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, // + 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, // + 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, // + 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 
14, // + 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, // + 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, // + 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, // + 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, // + 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, // + 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, // + 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, // + 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, // + 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, // + 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, // + 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, // + 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, // + 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, // + 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, // + 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, // + 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, // + 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, // + 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, // + 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, // + 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, // + 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, // + 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, // + 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, // + 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, // + 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, // + 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, // + 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, // + 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, // + 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, // + 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, // + 2, 6, 
8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, // + 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, // + 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, // + 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, // + 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, // + 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, // + 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, // + 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, // + 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, // + 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, // + 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, // + 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, // + 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, // + 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, // + 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, // + 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, // + 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, // + 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, // + 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, // + 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, // + 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, // + 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, // + 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, // + 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, // + 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, // + 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, // + 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, // + 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, // + 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, // + 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, // + 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, // + 6, 8, 10, 14, 0, 2, 
4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, // + 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, // + 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, // + 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, // + 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, // + 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, // + 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, // + 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, // + 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, // + 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, // + 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, // + 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, // + 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, // + 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, // + 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, // + 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, // + 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, // + 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, // + 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, // + 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, // + 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, // + 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, // + 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, // + 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, // + 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, // + 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, // + 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, // + 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, // + 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, // + 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, // + 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, // + 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 
2, 4, 8, 10, 12, 14, 6, // + 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, // + 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, // + 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, // + 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const VFromD byte_idx{Load(d8, table + mask_bits * 8).raw}; + const VFromD pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 256); + const Rebind d8; + const Twice d8t; + const RebindToUnsigned du; + + // compress_epi16 requires VBMI2 and there is no permutevar_epi16, so we need + // byte indices for PSHUFB (one vector's worth for each of 256 combinations of + // 8 mask bits). Loading them directly would require 4 KiB. We can instead + // store lane indices and convert to byte indices (2*lane + 0..1), with the + // doubling baked into the table. AVX2 Compress32 stores eight 4-bit lane + // indices (total 1 KiB), broadcasts them into each 32-bit lane and shifts. + // Here, 16-bit lanes are too narrow to hold all bits, and unpacking nibbles + // is likely more costly than the higher cache footprint from storing bytes. 
+ alignas(16) static constexpr uint8_t table[2048] = { + // PrintCompressNot16x8Tables + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, // + 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, // + 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, // + 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, // + 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, // + 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, // + 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, // + 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, // + 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, // + 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, // + 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, // + 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, // + 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, // + 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, // + 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, // + 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, // + 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, // + 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, // + 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, // + 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, // + 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, // + 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, // + 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, // + 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, // + 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, // + 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, // + 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, // + 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, // + 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, // + 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 
10, // + 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, // + 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, // + 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, // + 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, // + 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, // + 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, // + 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, // + 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, // + 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, // + 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, // + 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, // + 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, // + 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, // + 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, // + 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, // + 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, // + 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, // + 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, // + 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, // + 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, // + 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, // + 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, // + 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, // + 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, // + 0, 2, 8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, // + 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, // + 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, // + 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, // + 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, // + 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, // + 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, // + 0, 4, 
14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, // + 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, // + 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, // + 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, // + 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, // + 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, // + 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, // + 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, // + 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, // + 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, // + 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, // + 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, // + 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, // + 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, // + 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, // + 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, // + 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, // + 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, // + 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, // + 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, // + 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, // + 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, // + 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, // + 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, // + 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, // + 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, // + 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, // + 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, // + 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, // + 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, // + 0, 2, 4, 12, 6, 8, 
10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, // + 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, // + 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, // + 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, // + 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, // + 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, // + 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, // + 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, // + 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, // + 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, // + 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, // + 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, // + 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, // + 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, // + 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, // + 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, // + 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 12, 14, // + 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, // + 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, // + 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, // + 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, // + 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, // + 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, // + 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, // + 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, // + 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, // + 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, // + 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, // + 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 
0, 2, 4, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, // + 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, // + 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14}; + + const VFromD byte_idx{Load(d8, table + mask_bits * 8).raw}; + const VFromD pairs = ZipLower(byte_idx, byte_idx); + return BitCast(d, pairs + Set(du, 0x0100)); +} + +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[256] = { + // PrintCompress32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, // + 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, // + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, // + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, // + 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, // + 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, // + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, // + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, // + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 16); + + // There are only 4 lanes, so we can afford to load the index vector directly. 
+ alignas(16) static constexpr uint8_t u8_indices[256] = { + // PrintCompressNot32x4Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, + 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, + 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. + alignas(16) static constexpr uint8_t u8_indices[64] = { + // PrintCompress64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_INLINE VFromD IndicesFromNotBits128(D d, uint64_t mask_bits) { + HWY_DASSERT(mask_bits < 4); + + // There are only 2 lanes, so we can afford to load the index vector directly. 
+ alignas(16) static constexpr uint8_t u8_indices[64] = { + // PrintCompressNot64x2Tables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + const Repartition d8; + return BitCast(d, Load(d8, u8_indices + 16 * mask_bits)); +} + +template +HWY_API Vec128 CompressBits(Vec128 v, uint64_t mask_bits) { + const DFromV d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +template +HWY_API Vec128 CompressNotBits(Vec128 v, uint64_t mask_bits) { + const DFromV d; + const RebindToUnsigned du; + + HWY_DASSERT(mask_bits < (1ull << N)); + const auto indices = BitCast(du, detail::IndicesFromNotBits128(d, mask_bits)); + return BitCast(d, TableLookupBytes(BitCast(du, v), indices)); +} + +} // namespace detail + +// Single lane: no-op +template +HWY_API Vec128 Compress(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep. 
+ const DFromV d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskL, maskH); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +// General case, 2 or 4 bytes +template +HWY_API Vec128 Compress(Vec128 v, Mask128 mask) { + const DFromV d; + return detail::CompressBits(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressNot + +// Single lane: no-op +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 /*m*/) { + return v; +} + +// Two lanes: conditional swap +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep. + const DFromV d; + const Vec128 m = VecFromMask(d, mask); + const Vec128 maskL = DupEven(m); + const Vec128 maskH = DupOdd(m); + const Vec128 swap = AndNot(maskH, maskL); + return IfVecThenElse(swap, Shuffle01(v), v); +} + +template +HWY_API Vec128 CompressNot(Vec128 v, Mask128 mask) { + const DFromV d; + // For partial vectors, we cannot pull the Not() into the table because + // BitsFromMask clears the upper bits. 
+ if (N < 16 / sizeof(T)) { + return detail::CompressBits(v, BitsFromMask(d, Not(mask))); + } + return detail::CompressNotBits(v, BitsFromMask(d, mask)); +} + +// ------------------------------ CompressBlocksNot +HWY_API Vec128 CompressBlocksNot(Vec128 v, + Mask128 /* m */) { + return v; +} + +template +HWY_API Vec128 CompressBits(Vec128 v, + const uint8_t* HWY_RESTRICT bits) { + uint64_t mask_bits = 0; + constexpr size_t kNumBytes = (N + 7) / 8; + CopyBytes(bits, &mask_bits); + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::CompressBits(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +template +HWY_API size_t CompressStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = BitsFromMask(d, m); + HWY_DASSERT(mask_bits < (1ull << MaxLanes(d))); + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + const uint64_t mask_bits = BitsFromMask(d, m); + HWY_DASSERT(mask_bits < (1ull << MaxLanes(d))); + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). 
+ const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + BlendedStore(compressed, FirstN(d, count), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + + uint64_t mask_bits = 0; + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumBytes = (kN + 7) / 8; + CopyBytes(bits, &mask_bits); + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + const size_t count = PopCount(mask_bits); + + // Avoid _mm_maskmoveu_si128 (>500 cycle latency because it bypasses caches). + const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits)); + const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices)); + StoreU(compressed, d, unaligned); + + detail::MaybeUnpoison(unaligned, count); + return count; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Expand + +// Otherwise, use the generic_ops-inl.h fallback. +#if HWY_TARGET <= HWY_AVX3 || HWY_IDE + +// The native instructions for 8/16-bit actually require VBMI2 (HWY_AVX3_DL), +// but we still want to override generic_ops-inl's table-based implementation +// whenever we have the 32-bit expand provided by AVX3. 
+#ifdef HWY_NATIVE_EXPAND +#undef HWY_NATIVE_EXPAND +#else +#define HWY_NATIVE_EXPAND +#endif + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3_DL || HWY_IDE // VBMI2 + +template +HWY_INLINE Vec128 NativeExpand(Vec128 v, + Mask128 mask) { + return Vec128{_mm_maskz_expand_epi8(mask.raw, v.raw)}; +} + +template +HWY_INLINE Vec128 NativeExpand(Vec128 v, + Mask128 mask) { + return Vec128{_mm_maskz_expand_epi16(mask.raw, v.raw)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint8_t* HWY_RESTRICT unaligned) { + return VFromD{_mm_maskz_expandloadu_epi8(mask.raw, unaligned)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint16_t* HWY_RESTRICT unaligned) { + return VFromD{_mm_maskz_expandloadu_epi16(mask.raw, unaligned)}; +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_INLINE Vec128 NativeExpand(Vec128 v, + Mask128 mask) { + return Vec128{_mm_maskz_expand_epi32(mask.raw, v.raw)}; +} + +template +HWY_INLINE Vec128 NativeExpand(Vec128 v, + Mask128 mask) { + return Vec128{_mm_maskz_expand_epi64(mask.raw, v.raw)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint32_t* HWY_RESTRICT unaligned) { + return VFromD{_mm_maskz_expandloadu_epi32(mask.raw, unaligned)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint64_t* HWY_RESTRICT unaligned) { + return VFromD{_mm_maskz_expandloadu_epi64(mask.raw, unaligned)}; +} + +} // namespace detail + +// Otherwise, 8/16-bit are implemented in x86_512 using PromoteTo. 
+#if HWY_TARGET <= HWY_AVX3_DL || HWY_IDE // VBMI2 + +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const DFromV d; + const RebindToUnsigned du; + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API Vec128 Expand(Vec128 v, Mask128 mask) { + const DFromV d; + const RebindToUnsigned du; + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +} + +// ------------------------------ LoadExpand + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + const RebindToUnsigned du; + using TU = TFromD; + const TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeLoadExpand(mu, du, pu)); +#else + return Expand(LoadU(d, unaligned), mask); +#endif +} + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + using TU = TFromD; + const TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeLoadExpand(mu, du, pu)); +#else + return Expand(LoadU(d, unaligned), mask); +#endif +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ StoreInterleaved2/3/4 + +// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in +// generic_ops-inl.h. 
+ +// ------------------------------ Additional mask logical operations + +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +template +static HWY_INLINE uint32_t AVX3Blsi(T x) { + using TU = MakeUnsigned; + const auto u32_val = static_cast(static_cast(x)); +#if HWY_COMPILER_CLANGCL + return static_cast(u32_val & (0u - u32_val)); +#else + return static_cast(_blsi_u32(u32_val)); +#endif +} +template +static HWY_INLINE uint64_t AVX3Blsi(T x) { + const auto u64_val = static_cast(x); +#if HWY_COMPILER_CLANGCL || HWY_ARCH_X86_32 + return static_cast(u64_val & (0ULL - u64_val)); +#else + return static_cast(_blsi_u64(u64_val)); +#endif +} + +template +static HWY_INLINE uint32_t AVX3Blsmsk(T x) { + using TU = MakeUnsigned; + const auto u32_val = static_cast(static_cast(x)); +#if HWY_COMPILER_CLANGCL + return static_cast(u32_val ^ (u32_val - 1u)); +#else + return static_cast(_blsmsk_u32(u32_val)); +#endif +} +template +static HWY_INLINE uint64_t AVX3Blsmsk(T x) { + const auto u64_val = static_cast(x); +#if HWY_COMPILER_CLANGCL || HWY_ARCH_X86_32 + return static_cast(u64_val ^ (u64_val - 1ULL)); +#else + return static_cast(_blsmsk_u64(u64_val)); +#endif +} + +} // namespace detail + +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + constexpr uint32_t kActiveElemMask = (uint32_t{1} << N) - 1; + return Mask128{static_cast::Raw>( + (0u - detail::AVX3Blsi(mask.raw)) & kActiveElemMask)}; +} +template +HWY_API Mask128 SetBeforeFirst(Mask128 mask) { + constexpr uint32_t kActiveElemMask = (uint32_t{1} << N) - 1; + return Mask128{static_cast::Raw>( + (detail::AVX3Blsi(mask.raw) - 1u) & kActiveElemMask)}; +} +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { + constexpr uint32_t kActiveElemMask = (uint32_t{1} << N) - 1; + return Mask128{static_cast::Raw>( + detail::AVX3Blsmsk(mask.raw) & kActiveElemMask)}; +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + return Mask128{ + static_cast::Raw>(detail::AVX3Blsi(mask.raw))}; +} +#else // AVX2 or below 
+template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const FixedTag d; + const auto vmask = VecFromMask(d, mask); + return MaskFromVec(Or(vmask, InterleaveLower(vmask, vmask))); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Simd d; + const auto vmask = VecFromMask(d, mask); + const auto neg_vmask = + ResizeBitCast(d, Neg(ResizeBitCast(Full64(), vmask))); + return MaskFromVec(Or(vmask, neg_vmask)); +} +template +HWY_API Mask128 SetAtOrAfterFirst(Mask128 mask) { + const Full128 d; + const Repartition di64; + const Repartition df32; + const Repartition di32; + using VF = VFromD; + + auto vmask = BitCast(di64, VecFromMask(d, mask)); + vmask = Or(vmask, Neg(vmask)); + + // Copy the sign bit of the first int64_t lane to the second int64_t lane + const auto vmask2 = BroadcastSignBit( + BitCast(di32, VF{_mm_shuffle_ps(Zero(df32).raw, BitCast(df32, vmask).raw, + _MM_SHUFFLE(1, 1, 0, 0))})); + return MaskFromVec(BitCast(d, Or(vmask, BitCast(di64, vmask2)))); +} + +template +HWY_API Mask128 SetBeforeFirst(Mask128 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + return mask; +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const FixedTag d; + const RebindToSigned di; + + const auto vmask = BitCast(di, VecFromMask(d, mask)); + const auto zero = Zero(di); + const auto vmask2 = VecFromMask(di, InterleaveLower(zero, vmask) == zero); + return MaskFromVec(BitCast(d, And(vmask, vmask2))); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Simd d; + const RebindToSigned di; + + const auto vmask = ResizeBitCast(Full64(), VecFromMask(d, mask)); + const auto only_first_vmask = + BitCast(d, Neg(ResizeBitCast(di, And(vmask, Neg(vmask))))); + return MaskFromVec(only_first_vmask); +} +template +HWY_API Mask128 SetOnlyFirst(Mask128 mask) { + const Full128 d; + const RebindToSigned di; + const 
Repartition di64; + + const auto zero = Zero(di64); + const auto vmask = BitCast(di64, VecFromMask(d, mask)); + const auto vmask2 = VecFromMask(di64, InterleaveLower(zero, vmask) == zero); + const auto only_first_vmask = Neg(BitCast(di, And(vmask, Neg(vmask)))); + return MaskFromVec(BitCast(d, And(only_first_vmask, BitCast(di, vmask2)))); +} + +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 /*mask*/) { + const FixedTag d; + const RebindToSigned di; + using TI = MakeSigned; + + return RebindMask(d, MaskFromVec(Set(di, TI(-1)))); +} +template +HWY_API Mask128 SetAtOrBeforeFirst(Mask128 mask) { + const Simd d; + return SetBeforeFirst(MaskFromVec(ShiftLeftLanes<1>(VecFromMask(d, mask)))); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reductions + +// Nothing fully native, generic_ops-inl defines SumOfLanes and ReduceSum. + +// We provide specializations of u8x8 and u8x16, so exclude those. +#undef HWY_IF_SUM_OF_LANES_D +#define HWY_IF_SUM_OF_LANES_D(D) \ + HWY_IF_LANES_GT_D(D, 1), \ + hwy::EnableIf, uint8_t>() || \ + (HWY_V_SIZE_D(D) != 8 && HWY_V_SIZE_D(D) != 16)>* = \ + nullptr + +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + return Set(d, static_cast(GetLane(SumsOf8(v)) & 0xFF)); +} +template +HWY_API VFromD SumOfLanes(D d, VFromD v) { + const Repartition d64; + VFromD sums = SumsOf8(v); + sums = SumOfLanes(d64, sums); + return Broadcast<0>(BitCast(d, sums)); +} + +#if HWY_TARGET <= HWY_SSE4 +// We provide specializations of u8x8, u8x16, and u16x8, so exclude those. 
+#undef HWY_IF_MINMAX_OF_LANES_D +#define HWY_IF_MINMAX_OF_LANES_D(D) \ + HWY_IF_LANES_GT_D(D, 1), \ + hwy::EnableIf<(!hwy::IsSame, uint8_t>() || \ + ((HWY_V_SIZE_D(D) < 8) || (HWY_V_SIZE_D(D) > 16))) && \ + (!hwy::IsSame, uint16_t>() || \ + (HWY_V_SIZE_D(D) != 16))>* = nullptr + +template +HWY_API Vec128 MinOfLanes(D /* tag */, Vec128 v) { + return Broadcast<0>(Vec128{_mm_minpos_epu16(v.raw)}); +} + +template +HWY_API Vec128 MaxOfLanes(D d, Vec128 v) { + const Vec128 max = Set(d, LimitsMax()); + return max - MinOfLanes(d, max - v); +} + +template +HWY_API Vec64 MinOfLanes(D d, Vec64 v) { + const Rebind d16; + return TruncateTo(d, MinOfLanes(d16, PromoteTo(d16, v))); +} +template +HWY_API Vec128 MinOfLanes(D d, Vec128 v) { + const Half dh; + Vec64 result = + Min(MinOfLanes(dh, UpperHalf(dh, v)), MinOfLanes(dh, LowerHalf(dh, v))); + return Combine(d, result, result); +} + +template +HWY_API Vec64 MaxOfLanes(D d, Vec64 v) { + const Vec64 m(Set(d, LimitsMax())); + return m - MinOfLanes(d, m - v); +} +template +HWY_API Vec128 MaxOfLanes(D d, Vec128 v) { + const Vec128 m(Set(d, LimitsMax())); + return m - MinOfLanes(d, m - v); +} + +#endif // HWY_TARGET <= HWY_SSE4 + +// ------------------------------ BitShuffle +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_BITSHUFFLE +#undef HWY_NATIVE_BITSHUFFLE +#else +#define HWY_NATIVE_BITSHUFFLE +#endif + +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_LE_V(V, 16), + HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Rebind du8; + + int32_t i32_bit_shuf_result = static_cast( + static_cast(_mm_bitshuffle_epi64_mask(v.raw, idx.raw))); + + return BitCast(d64, PromoteTo(du64, VFromD{_mm_cvtsi32_si128( + i32_bit_shuf_result)})); +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ MultiRotateRight + +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_MULTIROTATERIGHT +#undef HWY_NATIVE_MULTIROTATERIGHT +#else 
+#define HWY_NATIVE_MULTIROTATERIGHT +#endif + +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_LE_V(V, 16), + HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> +HWY_API V MultiRotateRight(V v, VI idx) { + return V{_mm_multishift_epi64_epi8(idx.raw, v.raw)}; +} + +#endif + +// ------------------------------ Lt128 + +namespace detail { + +// Returns vector-mask for Lt128. Generic for all vector lengths. +template +HWY_INLINE VFromD Lt128Vec(const D d, VFromD a, VFromD b) { + // Truth table of Eq and Lt for Hi and Lo u64. + // (removed lines with (=H && cH) or (=L && cL) - cannot both be true) + // =H =L cH cL | out = cH | (=H & cL) + // 0 0 0 0 | 0 + // 0 0 0 1 | 0 + // 0 0 1 0 | 1 + // 0 0 1 1 | 1 + // 0 1 0 0 | 0 + // 0 1 0 1 | 0 + // 0 1 1 0 | 1 + // 1 0 0 0 | 0 + // 1 0 0 1 | 1 + // 1 1 0 0 | 0 + const auto eqHL = Eq(a, b); + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + const VFromD ltLX = ShiftLeftLanes<1>(ltHL); + const VFromD vecHx = IfThenElse(eqHL, ltLX, ltHL); + return InterleaveUpper(d, vecHx, vecHx); +} + +// Returns vector-mask for Eq128. Generic for all vector lengths. +template +HWY_INLINE VFromD Eq128Vec(D d, VFromD a, VFromD b) { + const auto eqHL = VecFromMask(d, Eq(a, b)); + const auto eqLH = Reverse2(d, eqHL); + return And(eqHL, eqLH); +} + +template +HWY_INLINE VFromD Ne128Vec(D d, VFromD a, VFromD b) { + const auto neHL = VecFromMask(d, Ne(a, b)); + const auto neLH = Reverse2(d, neHL); + return Or(neHL, neLH); +} + +template +HWY_INLINE VFromD Lt128UpperVec(D d, VFromD a, VFromD b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const VFromD ltHL = VecFromMask(d, Lt(a, b)); + return InterleaveUpper(d, ltHL, ltHL); +} + +template +HWY_INLINE VFromD Eq128UpperVec(D d, VFromD a, VFromD b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. 
+ const VFromD eqHL = VecFromMask(d, Eq(a, b)); + return InterleaveUpper(d, eqHL, eqHL); +} + +template +HWY_INLINE VFromD Ne128UpperVec(D d, VFromD a, VFromD b) { + // No specialization required for AVX-512: Mask <-> Vec is fast, and + // copying mask bits to their neighbor seems infeasible. + const VFromD neHL = VecFromMask(d, Ne(a, b)); + return InterleaveUpper(d, neHL, neHL); +} + +} // namespace detail + +template +HWY_API MFromD Lt128(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Lt128Vec(d, a, b)); +} + +template +HWY_API MFromD Eq128(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Eq128Vec(d, a, b)); +} + +template +HWY_API MFromD Ne128(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Ne128Vec(d, a, b)); +} + +template +HWY_API MFromD Lt128Upper(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Lt128UpperVec(d, a, b)); +} + +template +HWY_API MFromD Eq128Upper(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Eq128UpperVec(d, a, b)); +} + +template +HWY_API MFromD Ne128Upper(D d, VFromD a, VFromD b) { + return MaskFromVec(detail::Ne128UpperVec(d, a, b)); +} + +// ------------------------------ Min128, Max128 (Lt128) + +// Avoids the extra MaskFromVec in Lt128. 
+template +HWY_API VFromD Min128(D d, VFromD a, VFromD b) { + return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b); +} + +template +HWY_API VFromD Max128(D d, VFromD a, VFromD b) { + return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b); +} + +template +HWY_API VFromD Min128Upper(D d, VFromD a, VFromD b) { + return IfVecThenElse(detail::Lt128UpperVec(d, a, b), a, b); +} + +template +HWY_API VFromD Max128Upper(D d, VFromD a, VFromD b) { + return IfVecThenElse(detail::Lt128UpperVec(d, b, a), a, b); +} + +// -------------------- LeadingZeroCount, TrailingZeroCount, HighestSetBitIndex + +#if HWY_TARGET <= HWY_AVX3 + +#ifdef HWY_NATIVE_LEADING_ZERO_COUNT +#undef HWY_NATIVE_LEADING_ZERO_COUNT +#else +#define HWY_NATIVE_LEADING_ZERO_COUNT +#endif + +template ), HWY_IF_V_SIZE_LE_D(DFromV, 16)> +HWY_API V LeadingZeroCount(V v) { + return V{_mm_lzcnt_epi32(v.raw)}; +} + +template ), HWY_IF_V_SIZE_LE_D(DFromV, 16)> +HWY_API V LeadingZeroCount(V v) { + return V{_mm_lzcnt_epi64(v.raw)}; +} + +// HighestSetBitIndex and TrailingZeroCount is implemented in x86_512-inl.h +// for AVX3 targets + +#endif // HWY_TARGET <= HWY_AVX3 + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#undef HWY_X86_IF_EMULATED_D + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/lib/highway/hwy/ops/x86_256-inl.h b/lib/highway/hwy/ops/x86_256-inl.h new file mode 100644 index 00000000000..a239b06544d --- /dev/null +++ b/lib/highway/hwy/ops/x86_256-inl.h @@ -0,0 +1,8996 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 256-bit vectors and AVX2 instructions, plus some AVX512-VL operations when +// compiling for that target. +// External include guard in highway.h - see comment there. + +// WARNING: most operations do not cross 128-bit block boundaries. In +// particular, "Broadcast", pack and zip behavior may be surprising. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL +#include "hwy/base.h" + +// Avoid uninitialized warnings in GCC's avx512fintrin.h - see +// https://github.com/google/highway/issues/710) +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494, + ignored "-Wmaybe-uninitialized") +#endif + +// Must come before HWY_COMPILER_CLANGCL +#include // AVX2+ + +#if HWY_COMPILER_CLANGCL +// Including should be enough, but Clang's headers helpfully skip +// including these headers when _MSC_VER is defined, like when using clang-cl. +// Include these directly here. +#include +// avxintrin defines __m256i and must come before avx2intrin. +#include +#include // _pext_u64 +#include +#include +#include + +#if HWY_TARGET <= HWY_AVX10_2 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// Must come after avx512fintrin, else will not define 512-bit intrinsics. 
+#include +#include +#include +#include +#include + +#endif // HWY_TARGET <= HWY_AVX10_2 + +// clang-format on +#endif // HWY_COMPILER_CLANGCL + +// For half-width vectors. Already includes base.h. +#include "hwy/ops/shared-inl.h" +// Already included by shared-inl, but do it again to avoid IDE warnings. +#include "hwy/ops/x86_128-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace detail { + +template +struct Raw256 { + using type = __m256i; +}; +#if HWY_HAVE_FLOAT16 +template <> +struct Raw256 { + using type = __m256h; +}; +#endif // HWY_HAVE_FLOAT16 +template <> +struct Raw256 { + using type = __m256; +}; +template <> +struct Raw256 { + using type = __m256d; +}; + +} // namespace detail + +template +class Vec256 { + using Raw = typename detail::Raw256::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. 
+ HWY_INLINE Vec256& operator*=(const Vec256 other) { + return *this = (*this * other); + } + HWY_INLINE Vec256& operator/=(const Vec256 other) { + return *this = (*this / other); + } + HWY_INLINE Vec256& operator+=(const Vec256 other) { + return *this = (*this + other); + } + HWY_INLINE Vec256& operator-=(const Vec256 other) { + return *this = (*this - other); + } + HWY_INLINE Vec256& operator%=(const Vec256 other) { + return *this = (*this % other); + } + HWY_INLINE Vec256& operator&=(const Vec256 other) { + return *this = (*this & other); + } + HWY_INLINE Vec256& operator|=(const Vec256 other) { + return *this = (*this | other); + } + HWY_INLINE Vec256& operator^=(const Vec256 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3 + +// Template arg: sizeof(lane type) +template +struct RawMask256T {}; +template <> +struct RawMask256T<1> { + using type = __mmask32; +}; +template <> +struct RawMask256T<2> { + using type = __mmask16; +}; +template <> +struct RawMask256T<4> { + using type = __mmask8; +}; +template <> +struct RawMask256T<8> { + using type = __mmask8; +}; + +template +using RawMask256 = typename RawMask256T::type; + +#else // AVX2 or earlier + +template +using RawMask256 = typename Raw256::type; + +#endif // HWY_TARGET <= HWY_AVX3 + +} // namespace detail + +template +struct Mask256 { + using Raw = typename detail::RawMask256; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromM + +#if HWY_TARGET <= HWY_AVX3 + static Mask256 FromBits(uint64_t mask_bits) { + return Mask256{static_cast(mask_bits)}; + } +#else +// Lanes are either FF..FF or 0. +#endif // HWY_TARGET <= HWY_AVX3 + + Raw raw; +}; + +template +using Full256 = Simd; + +// ------------------------------ Zero + +// Cannot use VFromD here because it is defined in terms of Zero. 
+template +HWY_API Vec256> Zero(D /* tag */) { + return Vec256>{_mm256_setzero_si256()}; +} +template +HWY_API Vec256 Zero(D /* tag */) { + return Vec256{_mm256_setzero_si256()}; +} +template +HWY_API Vec256 Zero(D /* tag */) { +#if HWY_HAVE_FLOAT16 + return Vec256{_mm256_setzero_ph()}; +#else + return Vec256{_mm256_setzero_si256()}; +#endif // HWY_HAVE_FLOAT16 +} +template +HWY_API Vec256 Zero(D /* tag */) { + return Vec256{_mm256_setzero_ps()}; +} +template +HWY_API Vec256 Zero(D /* tag */) { + return Vec256{_mm256_setzero_pd()}; +} + +// ------------------------------ BitCast + +namespace detail { + +HWY_INLINE __m256i BitCastToInteger(__m256i v) { return v; } +#if HWY_HAVE_FLOAT16 +HWY_INLINE __m256i BitCastToInteger(__m256h v) { + return _mm256_castph_si256(v); +} +#endif // HWY_HAVE_FLOAT16 +HWY_INLINE __m256i BitCastToInteger(__m256 v) { return _mm256_castps_si256(v); } +HWY_INLINE __m256i BitCastToInteger(__m256d v) { + return _mm256_castpd_si256(v); +} + +#if HWY_AVX3_HAVE_F32_TO_BF16C +HWY_INLINE __m256i BitCastToInteger(__m256bh v) { + // Need to use reinterpret_cast on GCC/Clang or BitCastScalar on MSVC to + // bit cast a __m256bh to a __m256i as there is currently no intrinsic + // available (as of GCC 13 and Clang 17) that can bit cast a __m256bh vector + // to a __m256i vector + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + // On GCC or Clang, use reinterpret_cast to bit cast a __m256bh to a __m256i + return reinterpret_cast<__m256i>(v); +#else + // On MSVC, use BitCastScalar to bit cast a __m256bh to a __m256i as MSVC does + // not allow reinterpret_cast, static_cast, or a C-style cast to be used to + // bit cast from one AVX vector type to a different AVX vector type + return BitCastScalar<__m256i>(v); +#endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +template +HWY_INLINE Vec256 BitCastToByte(Vec256 v) { + return Vec256{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return 
types differ. +template +struct BitCastFromInteger256 { + HWY_INLINE __m256i operator()(__m256i v) { return v; } +}; +#if HWY_HAVE_FLOAT16 +template <> +struct BitCastFromInteger256 { + HWY_INLINE __m256h operator()(__m256i v) { return _mm256_castsi256_ph(v); } +}; +#endif // HWY_HAVE_FLOAT16 +template <> +struct BitCastFromInteger256 { + HWY_INLINE __m256 operator()(__m256i v) { return _mm256_castsi256_ps(v); } +}; +template <> +struct BitCastFromInteger256 { + HWY_INLINE __m256d operator()(__m256i v) { return _mm256_castsi256_pd(v); } +}; + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, Vec256 v) { + return VFromD{BitCastFromInteger256>()(v.raw)}; +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, Vec256 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm256_set1_epi8(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm256_set1_epi16(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm256_set1_epi32(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm256_set1_epi64x(static_cast(t))}; // NOLINT +} +// bfloat16_t is handled by x86_128-inl.h. +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 Set(D /* tag */, float16_t t) { + return Vec256{_mm256_set1_ph(t)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec256 Set(D /* tag */, float t) { + return Vec256{_mm256_set1_ps(t)}; +} +template +HWY_API Vec256 Set(D /* tag */, double t) { + return Vec256{_mm256_set1_pd(t)}; +} + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API VFromD Undefined(D /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. 
All but ICC + // generate an XOR instruction. + return VFromD{_mm256_undefined_si256()}; +} +template +HWY_API Vec256 Undefined(D /* tag */) { + return Vec256{_mm256_undefined_si256()}; +} +template +HWY_API Vec256 Undefined(D /* tag */) { +#if HWY_HAVE_FLOAT16 + return Vec256{_mm256_undefined_ph()}; +#else + return Vec256{_mm256_undefined_si256()}; +#endif +} +template +HWY_API Vec256 Undefined(D /* tag */) { + return Vec256{_mm256_undefined_ps()}; +} +template +HWY_API Vec256 Undefined(D /* tag */) { + return Vec256{_mm256_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ ResizeBitCast + +// 32-byte vector to 32-byte vector (or 64-byte vector to 64-byte vector on +// AVX3) +template ))> +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, v); +} + +// 32-byte vector to 16-byte vector (or 64-byte vector to 32-byte vector on +// AVX3) +template )) / 2)> +HWY_API VFromD ResizeBitCast(D d, FromV v) { + const DFromV d_from; + const Half dh_from; + return BitCast(d, LowerHalf(dh_from, v)); +} + +// 32-byte vector (or 64-byte vector on AVX3) to <= 8-byte vector +template +HWY_API VFromD ResizeBitCast(D /*d*/, FromV v) { + return VFromD{ResizeBitCast(Full128>(), v).raw}; +} + +// <= 16-byte vector to 32-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, Vec256{_mm256_castsi128_si256( + ResizeBitCast(Full128(), v).raw)}); +} + +// ------------------------------ Dup128VecFromValues + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { + return VFromD{_mm256_setr_epi8( + static_cast(t0), static_cast(t1), static_cast(t2), + static_cast(t3), static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), static_cast(t8), + static_cast(t9), static_cast(t10), static_cast(t11), + static_cast(t12), 
static_cast(t13), static_cast(t14), + static_cast(t15), static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), static_cast(t4), + static_cast(t5), static_cast(t6), static_cast(t7), + static_cast(t8), static_cast(t9), static_cast(t10), + static_cast(t11), static_cast(t12), static_cast(t13), + static_cast(t14), static_cast(t15))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{ + _mm256_setr_epi16(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t4), static_cast(t5), + static_cast(t6), static_cast(t7))}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{_mm256_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, t2, + t3, t4, t5, t6, t7)}; +} +#endif + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{ + _mm256_setr_epi32(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{_mm256_setr_ps(t0, t1, t2, t3, t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{ + _mm256_setr_epi64x(static_cast(t0), static_cast(t1), + static_cast(t0), static_cast(t1))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{_mm256_setr_pd(t0, t1, t0, t1)}; +} + +// ================================================== LOGICAL + +// 
------------------------------ And + +template +HWY_API Vec256 And(Vec256 a, Vec256 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_and_si256(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +HWY_API Vec256 And(Vec256 a, Vec256 b) { + return Vec256{_mm256_and_ps(a.raw, b.raw)}; +} +HWY_API Vec256 And(Vec256 a, Vec256 b) { + return Vec256{_mm256_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_andnot_si256( + BitCast(du, not_mask).raw, BitCast(du, mask).raw)}); +} +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + return Vec256{_mm256_andnot_ps(not_mask.raw, mask.raw)}; +} +HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { + return Vec256{_mm256_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_or_si256(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + return Vec256{_mm256_or_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Or(Vec256 a, Vec256 b) { + return Vec256{_mm256_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_xor_si256(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + return Vec256{_mm256_xor_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Xor(Vec256 a, Vec256 b) { + return Vec256{_mm256_xor_pd(a.raw, b.raw)}; +} + +#if HWY_X86_HAVE_TERNARY_LOGIC +namespace detail { + +// Per-target partial specialization. 
+template +struct TernaryLogicImpl { + template + HWY_INLINE V operator()(V a, V b, V c) const { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m256i ret = _mm256_ternarylogic_epi64( + BitCast(du, a).raw, BitCast(du, b).raw, BitCast(du, c).raw, kTernLogOp); + return BitCast(d, VU{ret}); + } +}; + +// Same, but with writemask. If !mask, returns a. +template +struct MaskedTernaryLogicImpl { + template , HWY_IF_T_SIZE_D(D, 4)> + HWY_INLINE V operator()(MFromD mask, V a, V b, V c) const { + const D d; + const RebindToUnsigned du; + using VU = VFromD; + const __m256i ret = _mm256_mask_ternarylogic_epi32(a.raw, mask.raw, b.raw, + c.raw, kTernLogOp); + return BitCast(d, VU{ret}); + } + template , HWY_IF_T_SIZE_D(D, 8)> + HWY_INLINE V operator()(MFromD mask, V a, V b, V c) const { + const D d; + const RebindToUnsigned du; + using VU = VFromD; + const __m256i ret = _mm256_mask_ternarylogic_epi64(a.raw, mask.raw, b.raw, + c.raw, kTernLogOp); + return BitCast(d, VU{ret}); + } +}; + +} // namespace detail +#endif // HWY_X86_HAVE_TERNARY_LOGIC + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { + return And(a, b); +} + +template +HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { + return Or(a, b); +} + +template +HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 8/16 require BITALG, 32/64 require VPOPCNTDQ. 
+#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<1> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi8(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<2> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi16(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<4> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi32(v.raw)}; +} +template +HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<8> /* tag */, Vec256 v) { + return Vec256{_mm256_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 PopulationCount(Vec256 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ================================================== MASK + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. 
+template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<1> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_blend_epi8(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<2> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_blend_epi16(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<4> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_blend_epi32(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<8> /* tag */, Mask256 mask, + Vec256 yes, Vec256 no) { + return Vec256{_mm256_mask_blend_epi64(mask.raw, no.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +} +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 IfThenElse(Mask256 mask, + Vec256 yes, + Vec256 no) { + return Vec256{_mm256_mask_blend_ph(mask.raw, no.raw, yes.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, + Vec256 no) { + return Vec256{_mm256_mask_blend_ps(mask.raw, no.raw, yes.raw)}; +} +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, + Vec256 no) { + return Vec256{_mm256_mask_blend_pd(mask.raw, no.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<1> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<2> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<4> /* tag */, Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<8> /* tag */, Mask256 mask, + Vec256 yes) { + return 
Vec256{_mm256_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +} +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + return Vec256{_mm256_maskz_mov_ps(mask.raw, yes.raw)}; +} +HWY_API Vec256 IfThenElseZero(Mask256 mask, + Vec256 yes) { + return Vec256{_mm256_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<1> /* tag */, Mask256 mask, + Vec256 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. + return Vec256{_mm256_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<2> /* tag */, Mask256 mask, + Vec256 no) { + return Vec256{_mm256_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<4> /* tag */, Mask256 mask, + Vec256 no) { + return Vec256{_mm256_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<8> /* tag */, Mask256 mask, + Vec256 no) { + return Vec256{_mm256_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); +} +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return Vec256{_mm256_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + return Vec256{_mm256_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +// ------------------------------ Mask logical + +namespace detail { + +template +HWY_INLINE Mask256 And(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(a.raw & 
b.raw)}; +#endif +} +template +HWY_INLINE Mask256 And(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 And(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 And(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kand_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask256 AndNot(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(~a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 Or(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask32(a.raw, 
b.raw)}; +#else + return Mask256{static_cast<__mmask32>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Or(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Or(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Or(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw | b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<1> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<2> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<4> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask256 Xor(hwy::SizeTag<8> /*tag*/, const Mask256 a, + const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(a.raw ^ b.raw)}; +#endif +} + +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, + const Mask256 a, const Mask256 b) { +#if 
HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxnor_mask32(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; +#endif +} +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask256{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} +template +HWY_INLINE Mask256 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, + const Mask256 a, const Mask256 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)}; +#else + return Mask256{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)}; +#endif +} + +// UnmaskedNot returns ~m.raw without zeroing out any invalid bits +template +HWY_INLINE Mask256 UnmaskedNot(const Mask256 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{static_cast<__mmask32>(_knot_mask32(m.raw))}; +#else + return Mask256{static_cast<__mmask32>(~m.raw)}; +#endif +} + +template +HWY_INLINE Mask256 UnmaskedNot(const Mask256 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{static_cast<__mmask16>(_knot_mask16(m.raw))}; +#else + return Mask256{static_cast<__mmask16>(~m.raw)}; +#endif +} + +template +HWY_INLINE Mask256 UnmaskedNot(const Mask256 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask256{static_cast<__mmask8>(_knot_mask8(m.raw))}; +#else + return Mask256{static_cast<__mmask8>(~m.raw)}; +#endif +} + +template +HWY_INLINE Mask256 Not(hwy::SizeTag<1> /*tag*/, const Mask256 m) { + // sizeof(T) == 1: simply return ~m as all 32 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE 
Mask256 Not(hwy::SizeTag<2> /*tag*/, const Mask256 m) { + // sizeof(T) == 2: simply return ~m as all 16 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE Mask256 Not(hwy::SizeTag<4> /*tag*/, const Mask256 m) { + // sizeof(T) == 4: simply return ~m as all 8 bits of m are valid + return UnmaskedNot(m); +} +template +HWY_INLINE Mask256 Not(hwy::SizeTag<8> /*tag*/, const Mask256 m) { + // sizeof(T) == 8: need to zero out the upper 4 bits of ~m as only the lower + // 4 bits of m are valid + + // Return (~m) & 0x0F + return AndNot(hwy::SizeTag<8>(), m, Mask256::FromBits(uint64_t{0x0F})); +} + +} // namespace detail + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + return detail::And(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + return detail::AndNot(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + return detail::Or(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + return detail::Xor(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask256 Not(const Mask256 m) { + // Flip only the valid bits. 
+ return detail::Not(hwy::SizeTag(), m); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +} + +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask32 combined_mask = _mm512_kunpackw( + static_cast<__mmask32>(hi.raw), static_cast<__mmask32>(lo.raw)); +#else + const auto combined_mask = + ((static_cast(hi.raw) << 16) | (lo.raw & 0xFFFFu)); +#endif + + return MFromD{static_cast().raw)>(combined_mask)}; +} + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = _kshiftri_mask32(static_cast<__mmask32>(m.raw), 16); +#else + const auto shifted_mask = static_cast(m.raw) >> 16; +#endif + + return MFromD{static_cast().raw)>(shifted_mask)}; +} + +template +HWY_API MFromD SlideMask1Up(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftli_mask32(static_cast<__mmask32>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) << 1)}; +#endif +} + +template +HWY_API MFromD SlideMask1Down(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftri_mask32(static_cast<__mmask32>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) >> 1)}; +#endif +} + +#else // AVX2 + +// ------------------------------ Mask + +// Mask and Vec are the same (true = FF..FF). +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return Mask256{v.raw}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{v.raw}; +} + +// ------------------------------ IfThenElse + +// mask ? 
yes : no +template +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { + return Vec256{_mm256_blendv_epi8(no.raw, yes.raw, mask.raw)}; +} +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, + Vec256 no) { + return Vec256{_mm256_blendv_ps(no.raw, yes.raw, mask.raw)}; +} +HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, + Vec256 no) { + return Vec256{_mm256_blendv_pd(no.raw, yes.raw, mask.raw)}; +} + +// mask ? yes : 0 +template +HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { + const DFromV d; + return yes & VecFromMask(d, mask); +} + +// mask ? 0 : no +template +HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { + const DFromV d; + return AndNot(VecFromMask(d, mask), no); +} + +template +HWY_API Vec256 ZeroIfNegative(Vec256 v) { + static_assert(IsSigned(), "Only for float"); + const DFromV d; + const auto zero = Zero(d); + // AVX2 IfThenElse only looks at the MSB for 32/64-bit lanes + return IfThenElse(MaskFromVec(v), zero, v); +} + +// ------------------------------ Mask logical + +template +HWY_API Mask256 Not(const Mask256 m) { + const Full256 d; + return MaskFromVec(Not(VecFromMask(d, m))); +} + +template +HWY_API Mask256 And(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Or(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); +} + +template +HWY_API Mask256 ExclusiveNeither(const Mask256 a, Mask256 b) { + const Full256 d; + return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// 
================================================== COMPARE + +#if HWY_TARGET <= HWY_AVX3 + +// Comparisons set a mask bit to 1 if the condition is true, else 0. + +template +HWY_API MFromD RebindMask(DTo /*tag*/, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + return MFromD{m.raw}; +} + +namespace detail { + +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<1> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi8_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<2> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<4> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask256 TestBit(hwy::SizeTag<8> /*tag*/, const Vec256 v, + const Vec256 bit) { + return Mask256{_mm256_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template +HWY_API Mask256 TestBit(const Vec256 v, const Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); +} + +// ------------------------------ Equality + +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Mask256 operator==(Vec256 a, + Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { + return Mask256{_mm256_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Mask256 operator!=(Vec256 a, + Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epu32_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +#if HWY_HAVE_FLOAT16 +HWY_API Mask256 operator>=(Vec256 a, + Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpge_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpge_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpge_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpge_epi64_mask(a.raw, b.raw)}; +} + +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpge_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpge_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return Mask256{_mm256_cmpge_epu32_mask(a.raw, b.raw)}; +} +HWY_API Mask256 operator>=(const Vec256 a, + const Vec256 b) { + return 
Mask256{_mm256_cmpge_epu64_mask(a.raw, b.raw)}; +} + +// ------------------------------ Mask + +namespace detail { + +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi8_mask(v.raw)}; +} +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi16_mask(v.raw)}; +} +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi32_mask(v.raw)}; +} +template +HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec256 v) { + return Mask256{_mm256_movepi64_mask(v.raw)}; +} + +} // namespace detail + +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + return detail::MaskFromVec(hwy::SizeTag(), v); +} +// There do not seem to be native floating-point versions of these instructions. +template +HWY_API Mask256 MaskFromVec(const Vec256 v) { + const RebindToSigned> di; + return Mask256{MaskFromVec(BitCast(di, v)).raw}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi8(v.raw)}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi16(v.raw)}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi32(v.raw)}; +} + +template +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_movm_epi64(v.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_castsi256_ph(_mm256_movm_epi16(v.raw))}; +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_castsi256_ps(_mm256_movm_epi32(v.raw))}; +} + +HWY_API Vec256 VecFromMask(const Mask256 v) { + return Vec256{_mm256_castsi256_pd(_mm256_movm_epi64(v.raw))}; +} + +#else // AVX2 + +// Comparisons fill a lane with 1-bits if the condition is true, else 0. 
+ +template +HWY_API MFromD RebindMask(DTo d_to, Mask256 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + const Full256 dfrom; + return MaskFromVec(BitCast(d_to, VecFromMask(dfrom, m))); +} + +template +HWY_API Mask256 TestBit(const Vec256 v, const Vec256 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return (v & bit) == bit; +} + +// ------------------------------ Equality + +template +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpeq_epi8(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpeq_epi16(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpeq_epi32(a.raw, b.raw)}; +} + +template +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmpeq_epi64(a.raw, b.raw)}; +} + +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask256 operator==(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Not(a == b); +} +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_NEQ_OQ)}; +} +HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +// Pre-9.3 GCC immintrin.h uses char, which may be unsigned, causing cmpgt_epi8 +// to perform an unsigned comparison instead of the intended signed. Workaround +// is to cast to an explicitly signed type. 
See https://godbolt.org/z/PL7Ujy +#if HWY_COMPILER_GCC_ACTUAL != 0 && HWY_COMPILER_GCC_ACTUAL < 903 +#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 1 +#else +#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 0 +#endif + +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { +#if HWY_AVX2_GCC_CMPGT8_WORKAROUND + using i8x32 = signed char __attribute__((__vector_size__(32))); + return Mask256{static_cast<__m256i>(reinterpret_cast(a.raw) > + reinterpret_cast(b.raw))}; +#else + return Mask256{_mm256_cmpgt_epi8(a.raw, b.raw)}; +#endif +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmpgt_epi16(a.raw, b.raw)}; +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmpgt_epi32(a.raw, b.raw)}; +} +HWY_API Mask256 Gt(hwy::SignedTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmpgt_epi64(a.raw, b.raw)}; +} + +template +HWY_INLINE Mask256 Gt(hwy::UnsignedTag /*tag*/, Vec256 a, Vec256 b) { + const Full256 du; + const RebindToSigned di; + const Vec256 msb = Set(du, (LimitsMax() >> 1) + 1); + return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb))); +} + +HWY_API Mask256 Gt(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask256 Gt(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GT_OQ)}; +} + +} // namespace detail + +template +HWY_API Mask256 operator>(Vec256 a, Vec256 b) { + return detail::Gt(hwy::TypeTag(), a, b); +} + +// ------------------------------ Weak inequality + +namespace detail { + +template +HWY_INLINE Mask256 Ge(hwy::SignedTag tag, Vec256 a, Vec256 b) { + return Not(Gt(tag, b, a)); +} + +template +HWY_INLINE Mask256 Ge(hwy::UnsignedTag tag, Vec256 a, Vec256 b) { + return Not(Gt(tag, b, a)); +} + +HWY_INLINE Mask256 Ge(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_INLINE 
Mask256 Ge(hwy::FloatTag /*tag*/, Vec256 a, + Vec256 b) { + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GE_OQ)}; +} + +} // namespace detail + +template +HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { + return detail::Ge(hwy::TypeTag(), a, b); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reversed comparisons + +template +HWY_API Mask256 operator<(const Vec256 a, const Vec256 b) { + return b > a; +} + +template +HWY_API Mask256 operator<=(const Vec256 a, const Vec256 b) { + return b >= a; +} + +// ------------------------------ Min (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_min_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_min_epu32(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, + const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_min_epu64(a.raw, b.raw)}; +#else + const Full256 du; + const Full256 di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, b, a); +#endif +} + +// Signed +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_min_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, a, b); +#endif +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 Min(Vec256 a, Vec256 b) { + return Vec256{_mm256_min_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return 
Vec256{_mm256_min_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Max (Gt, IfThenElse) + +// Unsigned +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_max_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { + return Vec256{_mm256_max_epu32(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, + const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_max_epu64(a.raw, b.raw)}; +#else + const Full256 du; + const Full256 di; + const auto msb = Set(du, 1ull << 63); + const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); + return IfThenElse(gt, a, b); +#endif +} + +// Signed +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_max_epi64(a.raw, b.raw)}; +#else + return IfThenElse(a < b, b, a); +#endif +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 Max(Vec256 a, Vec256 b) { + return Vec256{_mm256_max_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_ps(a.raw, b.raw)}; +} +HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { + return Vec256{_mm256_max_pd(a.raw, b.raw)}; +} + +// ------------------------------ MinNumber and MaxNumber + +#if HWY_X86_HAVE_AVX10_2_OPS + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MinNumber(Vec256 a, Vec256 b) { + return Vec256{_mm256_minmax_ph(a.raw, b.raw, 0x14)}; +} +#endif +HWY_API Vec256 
MinNumber(Vec256 a, Vec256 b) { + return Vec256{_mm256_minmax_ps(a.raw, b.raw, 0x14)}; +} +HWY_API Vec256 MinNumber(Vec256 a, Vec256 b) { + return Vec256{_mm256_minmax_pd(a.raw, b.raw, 0x14)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MaxNumber(Vec256 a, Vec256 b) { + return Vec256{_mm256_minmax_ph(a.raw, b.raw, 0x15)}; +} +#endif +HWY_API Vec256 MaxNumber(Vec256 a, Vec256 b) { + return Vec256{_mm256_minmax_ps(a.raw, b.raw, 0x15)}; +} +HWY_API Vec256 MaxNumber(Vec256 a, Vec256 b) { + return Vec256{_mm256_minmax_pd(a.raw, b.raw, 0x15)}; +} + +#endif + +// ------------------------------ MinMagnitude and MaxMagnitude + +#if HWY_X86_HAVE_AVX10_2_OPS + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MinMagnitude(Vec256 a, + Vec256 b) { + return Vec256{_mm256_minmax_ph(a.raw, b.raw, 0x16)}; +} +#endif +HWY_API Vec256 MinMagnitude(Vec256 a, Vec256 b) { + return Vec256{_mm256_minmax_ps(a.raw, b.raw, 0x16)}; +} +HWY_API Vec256 MinMagnitude(Vec256 a, Vec256 b) { + return Vec256{_mm256_minmax_pd(a.raw, b.raw, 0x16)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MaxMagnitude(Vec256 a, + Vec256 b) { + return Vec256{_mm256_minmax_ph(a.raw, b.raw, 0x17)}; +} +#endif +HWY_API Vec256 MaxMagnitude(Vec256 a, Vec256 b) { + return Vec256{_mm256_minmax_ps(a.raw, b.raw, 0x17)}; +} +HWY_API Vec256 MaxMagnitude(Vec256 a, Vec256 b) { + return Vec256{_mm256_minmax_pd(a.raw, b.raw, 0x17)}; +} + +#endif + +// ------------------------------ Iota + +namespace detail { + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm256_set_epi8( + static_cast(31), static_cast(30), static_cast(29), + static_cast(28), static_cast(27), static_cast(26), + static_cast(25), static_cast(24), static_cast(23), + static_cast(22), static_cast(21), static_cast(20), + static_cast(19), static_cast(18), static_cast(17), + static_cast(16), static_cast(15), static_cast(14), + static_cast(13), static_cast(12), static_cast(11), + static_cast(10), static_cast(9), static_cast(8), + static_cast(7), static_cast(6), 
static_cast(5), + static_cast(4), static_cast(3), static_cast(2), + static_cast(1), static_cast(0))}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm256_set_epi16( + int16_t{15}, int16_t{14}, int16_t{13}, int16_t{12}, int16_t{11}, + int16_t{10}, int16_t{9}, int16_t{8}, int16_t{7}, int16_t{6}, int16_t{5}, + int16_t{4}, int16_t{3}, int16_t{2}, int16_t{1}, int16_t{0})}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{ + _mm256_set_ph(float16_t{15}, float16_t{14}, float16_t{13}, float16_t{12}, + float16_t{11}, float16_t{10}, float16_t{9}, float16_t{8}, + float16_t{7}, float16_t{6}, float16_t{5}, float16_t{4}, + float16_t{3}, float16_t{2}, float16_t{1}, float16_t{0})}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm256_set_epi32(int32_t{7}, int32_t{6}, int32_t{5}, + int32_t{4}, int32_t{3}, int32_t{2}, + int32_t{1}, int32_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{ + _mm256_set_epi64x(int64_t{3}, int64_t{2}, int64_t{1}, int64_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{ + _mm256_set_ps(7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f)}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm256_set_pd(3.0, 2.0, 1.0, 0.0)}; +} + +} // namespace detail + +template +HWY_API VFromD Iota(D d, const T2 first) { + return detail::Iota0(d) + Set(d, ConvertScalarTo>(first)); +} + +// ------------------------------ FirstN (Iota, Lt) + +template > +HWY_API M FirstN(const D d, size_t n) { + constexpr size_t kN = MaxLanes(d); + // For AVX3, this ensures `num` <= 255 as required by bzhi, which only looks + // at the lower 8 bits; for AVX2 and below, this ensures `num` fits in TI. 
+ n = HWY_MIN(n, kN); + +#if HWY_TARGET <= HWY_AVX3 +#if HWY_ARCH_X86_64 + const uint64_t all = (1ull << kN) - 1; + return M::FromBits(_bzhi_u64(all, n)); +#else + const uint32_t all = static_cast((1ull << kN) - 1); + return M::FromBits(_bzhi_u32(all, static_cast(n))); +#endif // HWY_ARCH_X86_64 +#else + const RebindToSigned di; // Signed comparisons are cheaper. + using TI = TFromD; + return RebindMask(d, detail::Iota0(di) < Set(di, static_cast(n))); +#endif +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator+(Vec256 a, Vec256 b) { + return Vec256{_mm256_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return 
Vec256{_mm256_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator-(Vec256 a, Vec256 b) { + return Vec256{_mm256_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ AddSub + +HWY_API Vec256 AddSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_addsub_ps(a.raw, b.raw)}; +} +HWY_API Vec256 AddSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_addsub_pd(a.raw, b.raw)}; +} + +// ------------------------------ PairwiseAdd128/PairwiseSub128 + +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_hadd_epi16(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + const DFromV d; + const RebindToSigned di; + return BitCast(d, + Neg(BitCast(di, VFromD{_mm256_hsub_epi16(a.raw, b.raw)}))); +} +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_hadd_epi32(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + const DFromV d; + const RebindToSigned di; + return BitCast(d, + Neg(BitCast(di, VFromD{_mm256_hsub_epi32(a.raw, b.raw)}))); +} +template +HWY_API VFromD 
PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_hadd_ps(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + return Neg(VFromD{_mm256_hsub_ps(a.raw, b.raw)}); +} +template +HWY_API VFromD PairwiseAdd128(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_hadd_pd(a.raw, b.raw)}; +} +template +HWY_API VFromD PairwiseSub128(D /*d*/, VFromD a, VFromD b) { + return Neg(VFromD{_mm256_hsub_pd(a.raw, b.raw)}); +} + +// ------------------------------ SumsOf8 +HWY_API Vec256 SumsOf8(Vec256 v) { + return Vec256{_mm256_sad_epu8(v.raw, _mm256_setzero_si256())}; +} + +HWY_API Vec256 SumsOf8AbsDiff(Vec256 a, Vec256 b) { + return Vec256{_mm256_sad_epu8(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf4 +#if HWY_TARGET <= HWY_AVX3 +namespace detail { + +HWY_INLINE Vec256 SumsOf4(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, + Vec256 v) { + const DFromV d; + + // _mm256_maskz_dbsad_epu8 is used below as the odd uint16_t lanes need to be + // zeroed out and the sums of the 4 consecutive lanes are already in the + // even uint16_t lanes of the _mm256_maskz_dbsad_epu8 result. 
+ return Vec256{_mm256_maskz_dbsad_epu8( + static_cast<__mmask16>(0x5555), v.raw, Zero(d).raw, 0)}; +} + +// detail::SumsOf4 for Vec256 on AVX3 is implemented in x86_512-inl.h + +} // namespace detail +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ SumsOfAdjQuadAbsDiff + +template +HWY_API Vec256 SumsOfAdjQuadAbsDiff(Vec256 a, + Vec256 b) { + static_assert(0 <= kAOffset && kAOffset <= 1, + "kAOffset must be between 0 and 1"); + static_assert(0 <= kBOffset && kBOffset <= 3, + "kBOffset must be between 0 and 3"); + return Vec256{_mm256_mpsadbw_epu8( + a.raw, b.raw, + (kAOffset << 5) | (kBOffset << 3) | (kAOffset << 2) | kBOffset)}; +} + +// ------------------------------ SumsOfShuffledQuadAbsDiff + +#if HWY_TARGET <= HWY_AVX3 +template +static Vec256 SumsOfShuffledQuadAbsDiff(Vec256 a, + Vec256 b) { + static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); + static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); + static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); + static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); + return Vec256{ + _mm256_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))}; +} +#endif + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{_mm256_adds_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{_mm256_adds_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{_mm256_adds_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedAdd(Vec256 a, Vec256 b) { + return Vec256{_mm256_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. 
+ +// Unsigned +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_subs_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_subs_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_subs_epi8(a.raw, b.raw)}; +} +HWY_API Vec256 SaturatedSub(Vec256 a, Vec256 b) { + return Vec256{_mm256_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ Average + +HWY_API Vec256 AverageRound(Vec256 a, Vec256 b) { + return Vec256{_mm256_avg_epu8(a.raw, b.raw)}; +} +HWY_API Vec256 AverageRound(Vec256 a, Vec256 b) { + return Vec256{_mm256_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Abs (Sub) + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_API Vec256 Abs(Vec256 v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (wrong result) + const DFromV d; + const auto zero = Zero(d); + return Vec256{_mm256_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec256{_mm256_abs_epi8(v.raw)}; +#endif +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi16(v.raw)}; +} +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi32(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +HWY_API Vec256 Abs(const Vec256 v) { + return Vec256{_mm256_abs_epi64(v.raw)}; +} +#endif + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; +} +#if HWY_TARGET <= HWY_AVX3 +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi64(a.raw, b.raw)}; +} +#endif + +// Signed +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return 
Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; +} +#if HWY_TARGET <= HWY_AVX3 +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mullo_epi64(a.raw, b.raw)}; +} +#endif + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{_mm256_mulhi_epu16(a.raw, b.raw)}; +} +HWY_API Vec256 MulHigh(Vec256 a, Vec256 b) { + return Vec256{_mm256_mulhi_epi16(a.raw, b.raw)}; +} + +HWY_API Vec256 MulFixedPoint15(Vec256 a, Vec256 b) { + return Vec256{_mm256_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_mul_epi32(a.raw, b.raw)}; +} +HWY_API Vec256 MulEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_mul_epu32(a.raw, b.raw)}; +} + +// ------------------------------ ShiftLeft + +#if HWY_TARGET <= HWY_AVX3_DL +namespace detail { +template +HWY_API Vec256 GaloisAffine(Vec256 v, Vec256 matrix) { + return Vec256{_mm256_gf2p8affine_epi64_epi8(v.raw, matrix.raw, 0)}; +} +} // namespace detail +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{_mm256_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{_mm256_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{_mm256_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{_mm256_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{_mm256_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftLeft(Vec256 v) { + return Vec256{_mm256_slli_epi64(v.raw, kBits)}; +} + +#if HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec256 ShiftLeft(const Vec256 v) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeft(BitCast(d16, v))); + 
return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// ------------------------------ ShiftRight + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{_mm256_srli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{_mm256_srli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{_mm256_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{_mm256_srai_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + return Vec256{_mm256_srai_epi32(v.raw, kBits)}; +} + +#if HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + const Full256 d8; + // Use raw instead of BitCast to support N=1. + const Vec256 shifted{ShiftRight(Vec256{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec256 ShiftRight(Vec256 v) { + const Full256 di; + const Full256 du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// i64 is implemented after BroadcastSignBit. + +// ------------------------------ RotateRight + +// U8 RotateRight implementation on AVX3_DL is now in x86_512-inl.h as U8 +// RotateRight uses detail::GaloisAffine on AVX3_DL + +#if HWY_TARGET > HWY_AVX3_DL +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); + if (kBits == 0) return v; + // AVX3 does not support 8-bit. 
+ return Or(ShiftRight(v), ShiftLeft(v)); +} +#endif + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 16, "Invalid shift count"); + if (kBits == 0) return v; +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_shrdi_epi16(v.raw, v.raw, kBits)}; +#else + // AVX3 does not support 16-bit. + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_ror_epi32(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec256 RotateRight(const Vec256 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_ror_epi64(v.raw, kBits)}; +#else + if (kBits == 0) return v; + return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +// ------------------------------ Rol/Ror +#if HWY_TARGET <= HWY_AVX3_DL +template +HWY_API Vec256 Ror(Vec256 a, Vec256 b) { + return Vec256{_mm256_shrdv_epi16(a.raw, a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Vec256 Rol(Vec256 a, Vec256 b) { + return Vec256{_mm256_rolv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec256 Ror(Vec256 a, Vec256 b) { + return Vec256{_mm256_rorv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec256 Rol(Vec256 a, Vec256 b) { + return Vec256{_mm256_rolv_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec256 Ror(Vec256 a, Vec256 b) { + return Vec256{_mm256_rorv_epi64(a.raw, b.raw)}; +} + +#endif + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + const DFromV d; + return VecFromMask(v < Zero(d)); +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight<15>(v); +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return 
ShiftRight<31>(v); +} + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + return Vec256{ + _mm256_srai_epi64(v.raw, static_cast(kBits))}; +} + +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + return ShiftRight<63>(v); +} + +#else // AVX2 + +// Unlike above, this will be used to implement int64_t ShiftRight. +HWY_API Vec256 BroadcastSignBit(const Vec256 v) { + const DFromV d; + return VecFromMask(v < Zero(d)); +} + +template +HWY_API Vec256 ShiftRight(const Vec256 v) { + const Full256 di; + const Full256 du; + const auto right = BitCast(di, ShiftRight(BitCast(du, v))); + const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); + return right | sign; +} + +#endif // #if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ IfNegativeThenElse (BroadcastSignBit) +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, + Vec256 no) { + // int8: AVX2 IfThenElse only looks at the MSB. + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + static_assert(IsSigned(), "Only works for signed/float"); + +#if HWY_TARGET <= HWY_AVX3 + const auto mask = MaskFromVec(v); +#else + // 16-bit: no native blendv on AVX2, so copy sign to lower byte's MSB. + const DFromV d; + const RebindToSigned di; + const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v)))); +#endif + + return IfThenElse(mask, yes, no); +} + +template +HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { + static_assert(IsSigned(), "Only works for signed/float"); + +#if HWY_TARGET <= HWY_AVX3 + // No need to cast to float on AVX3 as IfThenElse only looks at the MSB on + // AVX3 + return IfThenElse(MaskFromVec(v), yes, no); +#else + const DFromV d; + const RebindToFloat df; + // 32/64-bit: use float IfThenElse, which only looks at the MSB. 
+ const MFromD msb = MaskFromVec(BitCast(df, v)); + return BitCast(d, IfThenElse(msb, BitCast(df, yes), BitCast(df, no))); +#endif +} + +// ------------------------------ IfNegativeThenNegOrUndefIfZero + +HWY_API Vec256 IfNegativeThenNegOrUndefIfZero(Vec256 mask, + Vec256 v) { + return Vec256{_mm256_sign_epi8(v.raw, mask.raw)}; +} + +HWY_API Vec256 IfNegativeThenNegOrUndefIfZero(Vec256 mask, + Vec256 v) { + return Vec256{_mm256_sign_epi16(v.raw, mask.raw)}; +} + +HWY_API Vec256 IfNegativeThenNegOrUndefIfZero(Vec256 mask, + Vec256 v) { + return Vec256{_mm256_sign_epi32(v.raw, mask.raw)}; +} + +// ------------------------------ ShiftLeftSame + +// Disable sign conversion warnings for GCC debug intrinsics. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_slli_epi16(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_slli_epi32(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftLeftSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_slli_epi64(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_slli_epi16(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_slli_epi32(v.raw, 
bits)}; + } +#endif + return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_slli_epi64(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +// ------------------------------ ShiftRightSame (BroadcastSignBit) + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_srli_epi16(v.raw, bits)}; + } +#endif + return Vec256{_mm256_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_srli_epi32(v.raw, bits)}; + } +#endif + return Vec256{_mm256_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_srli_epi64(v.raw, bits)}; + } +#endif + return Vec256{_mm256_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast(0xFF >> bits)); +} + +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_srai_epi16(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec256 ShiftRightSame(const 
Vec256 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{_mm256_srai_epi32(v.raw, bits)}; + } +#endif + return Vec256{_mm256_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec256 ShiftRightSame(const Vec256 v, + const int bits) { +#if HWY_TARGET <= HWY_AVX3 +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec256{ + _mm256_srai_epi64(v.raw, static_cast(bits))}; + } +#endif + return Vec256{_mm256_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +#else + const Full256 di; + const Full256 du; + const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); + return right | sign; +#endif +} + +HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { + const Full256 di; + const Full256 du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ Neg (Xor, Sub) + +// Tag dispatch instead of SFINAE for MSVC 2017 compatibility +namespace detail { + +template +HWY_INLINE Vec256 Neg(hwy::FloatTag /*tag*/, const Vec256 v) { + const DFromV d; + return Xor(v, SignBit(d)); +} + +template +HWY_INLINE Vec256 Neg(hwy::SpecialTag /*tag*/, const Vec256 v) { + const DFromV d; + return Xor(v, SignBit(d)); +} + +// Not floating-point +template +HWY_INLINE Vec256 Neg(hwy::SignedTag /*tag*/, const Vec256 v) { + const DFromV d; + return Zero(d) - v; +} + +} // namespace detail + +template +HWY_API Vec256 Neg(const Vec256 v) { + return detail::Neg(hwy::TypeTag(), v); +} + +// ------------------------------ Floating-point mul / div + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mul_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return 
Vec256{_mm256_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator*(Vec256 a, Vec256 b) { + return Vec256{_mm256_mul_pd(a.raw, b.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MulByFloorPow2(Vec256 a, + Vec256 b) { + return Vec256{_mm256_scalef_ph(a.raw, b.raw)}; +} +#endif + +HWY_API Vec256 MulByFloorPow2(Vec256 a, Vec256 b) { + return Vec256{_mm256_scalef_ps(a.raw, b.raw)}; +} + +HWY_API Vec256 MulByFloorPow2(Vec256 a, Vec256 b) { + return Vec256{_mm256_scalef_pd(a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 operator/(Vec256 a, Vec256 b) { + return Vec256{_mm256_div_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 operator/(Vec256 a, Vec256 b) { + return Vec256{_mm256_div_ps(a.raw, b.raw)}; +} +HWY_API Vec256 operator/(Vec256 a, Vec256 b) { + return Vec256{_mm256_div_pd(a.raw, b.raw)}; +} + +// Approximate reciprocal +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 ApproximateReciprocal(Vec256 v) { + return Vec256{_mm256_rcp_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec256 ApproximateReciprocal(Vec256 v) { + return Vec256{_mm256_rcp_ps(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +HWY_API Vec256 ApproximateReciprocal(Vec256 v) { + return Vec256{_mm256_rcp14_pd(v.raw)}; +} +#endif + +// ------------------------------ GetExponent + +#if HWY_TARGET <= HWY_AVX3 + +#if HWY_HAVE_FLOAT16 +template ), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V GetExponent(V v) { + return V{_mm256_getexp_ph(v.raw)}; +} +#endif +template ), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V GetExponent(V v) { + return V{_mm256_getexp_ps(v.raw)}; +} +template ), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V GetExponent(V v) { + return V{_mm256_getexp_pd(v.raw)}; +} + +#endif + +// ------------------------------ MaskedMinOr + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 
MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 MaskedMinOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_min_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMaxOr + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, 
Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 MaskedMaxOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_max_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedAddOr + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, 
+ Vec256 b) { + return Vec256{_mm256_mask_add_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 MaskedAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_add_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSubOr + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 MaskedSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_sub_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMulOr + +HWY_API Vec256 MaskedMulOr(Vec256 no, Mask256 m, + Vec256 a, Vec256 b) { + return Vec256{_mm256_mask_mul_ps(no.raw, m.raw, a.raw, b.raw)}; +} + 
+HWY_API Vec256 MaskedMulOr(Vec256 no, Mask256 m, + Vec256 a, Vec256 b) { + return Vec256{_mm256_mask_mul_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MaskedMulOr(Vec256 no, + Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_mul_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedDivOr + +HWY_API Vec256 MaskedDivOr(Vec256 no, Mask256 m, + Vec256 a, Vec256 b) { + return Vec256{_mm256_mask_div_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +HWY_API Vec256 MaskedDivOr(Vec256 no, Mask256 m, + Vec256 a, Vec256 b) { + return Vec256{_mm256_mask_div_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MaskedDivOr(Vec256 no, + Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_div_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSatAddOr + +template +HWY_API Vec256 MaskedSatAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_adds_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_adds_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_adds_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatAddOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_adds_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +// ------------------------------ MaskedSatSubOr + +template +HWY_API Vec256 MaskedSatSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_subs_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_subs_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatSubOr(Vec256 
no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_subs_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec256 MaskedSatSubOr(Vec256 no, Mask256 m, Vec256 a, + Vec256 b) { + return Vec256{_mm256_mask_subs_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Floating-point multiply-add variants + +#if HWY_HAVE_FLOAT16 + +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{_mm256_fmadd_ph(mul.raw, x.raw, add.raw)}; +} + +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { + return Vec256{_mm256_fnmadd_ph(mul.raw, x.raw, add.raw)}; +} + +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, + Vec256 sub) { + return Vec256{_mm256_fmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, + Vec256 sub) { + return Vec256{_mm256_fnmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256{_mm256_fmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256 MulAdd(Vec256 mul, Vec256 x, + Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x + add; +#else + return Vec256{_mm256_fmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256{_mm256_fnmadd_ps(mul.raw, x.raw, add.raw)}; +#endif +} +HWY_API Vec256 NegMulAdd(Vec256 mul, Vec256 x, + Vec256 add) { +#ifdef HWY_DISABLE_BMI2_FMA + return add - mul * x; +#else + return Vec256{_mm256_fnmadd_pd(mul.raw, x.raw, add.raw)}; +#endif +} + +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, + Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256{_mm256_fmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +HWY_API Vec256 MulSub(Vec256 mul, Vec256 x, + Vec256 sub) { +#ifdef 
HWY_DISABLE_BMI2_FMA + return mul * x - sub; +#else + return Vec256{_mm256_fmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, + Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return Neg(mul * x) - sub; +#else + return Vec256{_mm256_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +#endif +} +HWY_API Vec256 NegMulSub(Vec256 mul, Vec256 x, + Vec256 sub) { +#ifdef HWY_DISABLE_BMI2_FMA + return Neg(mul * x) - sub; +#else + return Vec256{_mm256_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +#endif +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 MulAddSub(Vec256 mul, Vec256 x, + Vec256 sub_or_add) { + return Vec256{_mm256_fmaddsub_ph(mul.raw, x.raw, sub_or_add.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec256 MulAddSub(Vec256 mul, Vec256 x, + Vec256 sub_or_add) { +#ifdef HWY_DISABLE_BMI2_FMA + return AddSub(mul * x, sub_or_add); +#else + return Vec256{_mm256_fmaddsub_ps(mul.raw, x.raw, sub_or_add.raw)}; +#endif +} + +HWY_API Vec256 MulAddSub(Vec256 mul, Vec256 x, + Vec256 sub_or_add) { +#ifdef HWY_DISABLE_BMI2_FMA + return AddSub(mul * x, sub_or_add); +#else + return Vec256{_mm256_fmaddsub_pd(mul.raw, x.raw, sub_or_add.raw)}; +#endif +} + +// ------------------------------ Floating-point square root + +// Full precision square root +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 Sqrt(Vec256 v) { + return Vec256{_mm256_sqrt_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Sqrt(Vec256 v) { + return Vec256{_mm256_sqrt_ps(v.raw)}; +} +HWY_API Vec256 Sqrt(Vec256 v) { + return Vec256{_mm256_sqrt_pd(v.raw)}; +} + +// Approximate reciprocal square root +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 ApproximateReciprocalSqrt(Vec256 v) { + return Vec256{_mm256_rsqrt_ph(v.raw)}; +} +#endif +HWY_API Vec256 ApproximateReciprocalSqrt(Vec256 v) { + return Vec256{_mm256_rsqrt_ps(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +HWY_API Vec256 ApproximateReciprocalSqrt(Vec256 v) { +#if HWY_COMPILER_MSVC + const DFromV d; + return Vec256{_mm256_mask_rsqrt14_pd( + Undefined(d).raw, 
static_cast<__mmask8>(0xFF), v.raw)}; +#else + return Vec256{_mm256_rsqrt14_pd(v.raw)}; +#endif +} +#endif + +// ------------------------------ Floating-point rounding + +// Toward nearest integer, tie to even +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 Round(Vec256 v) { + return Vec256{_mm256_roundscale_ph( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Round(Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Round(Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 Trunc(Vec256 v) { + return Vec256{ + _mm256_roundscale_ph(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Trunc(Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Trunc(Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 Ceil(Vec256 v) { + return Vec256{ + _mm256_roundscale_ph(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Ceil(Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Ceil(Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka floor +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 Floor(Vec256 v) { + return Vec256{ + _mm256_roundscale_ph(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 Floor(Vec256 v) { + return Vec256{ + _mm256_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec256 Floor(Vec256 v) { + return Vec256{ + _mm256_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | 
_MM_FROUND_NO_EXC)}; +} + +// ------------------------------ Floating-point classification + +#if HWY_HAVE_FLOAT16 || HWY_IDE + +HWY_API Mask256 IsNaN(Vec256 v) { + return Mask256{_mm256_fpclass_ph_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +} + +HWY_API Mask256 IsEitherNaN(Vec256 a, + Vec256 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask256{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_UNORD_Q)}; + HWY_DIAGNOSTICS(pop) +} + +HWY_API Mask256 IsInf(Vec256 v) { + return Mask256{_mm256_fpclass_ph_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} + +HWY_API Mask256 IsFinite(Vec256 v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask256{_mm256_fpclass_ph_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} + +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask256 IsNaN(Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +#else + return Mask256{_mm256_cmp_ps(v.raw, v.raw, _CMP_UNORD_Q)}; +#endif +} +HWY_API Mask256 IsNaN(Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +#else + return Mask256{_mm256_cmp_pd(v.raw, v.raw, _CMP_UNORD_Q)}; +#endif +} + +HWY_API Mask256 IsEitherNaN(Vec256 a, Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +#else + return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_UNORD_Q)}; +#endif +} + +HWY_API Mask256 IsEitherNaN(Vec256 a, Vec256 b) { +#if HWY_TARGET <= HWY_AVX3 + return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +#else + return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_UNORD_Q)}; +#endif +} + 
+#if HWY_TARGET <= HWY_AVX3 + +HWY_API Mask256 IsInf(Vec256 v) { + return Mask256{_mm256_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} +HWY_API Mask256 IsInf(Vec256 v) { + return Mask256{_mm256_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} + +HWY_API Mask256 IsFinite(Vec256 v) { + // fpclass doesn't have a flag for positive, so we have to check for inf/NaN + // and negate the mask. + return Not(Mask256{_mm256_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} +HWY_API Mask256 IsFinite(Vec256 v) { + return Not(Mask256{_mm256_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API VFromD Load(D /* tag */, const TFromD* HWY_RESTRICT aligned) { + return VFromD{ + _mm256_load_si256(reinterpret_cast(aligned))}; +} +// bfloat16_t is handled by x86_128-inl.h. +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 Load(D /* tag */, + const float16_t* HWY_RESTRICT aligned) { + return Vec256{_mm256_load_ph(aligned)}; +} +#endif +template +HWY_API Vec256 Load(D /* tag */, const float* HWY_RESTRICT aligned) { + return Vec256{_mm256_load_ps(aligned)}; +} +template +HWY_API Vec256 Load(D /* tag */, const double* HWY_RESTRICT aligned) { + return Vec256{_mm256_load_pd(aligned)}; +} + +template +HWY_API VFromD LoadU(D /* tag */, const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_loadu_si256(reinterpret_cast(p))}; +} +// bfloat16_t is handled by x86_128-inl.h. 
+#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { + return Vec256{_mm256_loadu_ph(p)}; +} +#endif +template +HWY_API Vec256 LoadU(D /* tag */, const float* HWY_RESTRICT p) { + return Vec256{_mm256_loadu_ps(p)}; +} +template +HWY_API Vec256 LoadU(D /* tag */, const double* HWY_RESTRICT p) { + return Vec256{_mm256_loadu_pd(p)}; +} + +// ------------------------------ MaskedLoad + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm256_maskz_loadu_epi16(m.raw, p)}); +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_maskz_loadu_epi64(m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, D /* tag */, + const float* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_ps(m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, D /* tag */, + const double* HWY_RESTRICT p) { + return Vec256{_mm256_maskz_loadu_pd(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_mask_loadu_epi8(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{ + _mm256_mask_loadu_epi16(BitCast(du, v).raw, m.raw, p)}); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_mask_loadu_epi32(v.raw, m.raw, p)}; +} + 
+template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm256_mask_loadu_epi64(v.raw, m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoadOr(VFromD v, Mask256 m, D /* tag */, + const float* HWY_RESTRICT p) { + return Vec256{_mm256_mask_loadu_ps(v.raw, m.raw, p)}; +} + +template +HWY_API Vec256 MaskedLoadOr(VFromD v, Mask256 m, D /* tag */, + const double* HWY_RESTRICT p) { + return Vec256{_mm256_mask_loadu_pd(v.raw, m.raw, p)}; +} + +#else // AVX2 + +// There is no maskload_epi8/16, so blend instead. +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + return IfThenElseZero(m, LoadU(d, p)); +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + return VFromD{_mm256_maskload_epi32(pi, m.raw)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + return VFromD{_mm256_maskload_epi64(pi, m.raw)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, D d, + const float* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec256{_mm256_maskload_ps(p, mi.raw)}; +} + +template +HWY_API Vec256 MaskedLoad(Mask256 m, D d, + const double* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + return Vec256{_mm256_maskload_pd(p, mi.raw)}; +} + +#endif + +// ------------------------------ LoadDup128 + +// Loads 128 bit and duplicates into both 128-bit halves. This avoids the +// 3-cycle cost of moving data between 128-bit halves and avoids port 5. 
+template +HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; + const Full128> d128; + const RebindToUnsigned du128; + const __m128i v128 = BitCast(du128, LoadU(d128, p)).raw; +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + // Workaround for incorrect results with _mm256_broadcastsi128_si256. Note + // that MSVC also lacks _mm256_zextsi128_si256, but cast (which leaves the + // upper half undefined) is fine because we're overwriting that anyway. + // This workaround seems in turn to generate incorrect code in MSVC 2022 + // (19.31), so use broadcastsi128 there. + return BitCast(d, VFromD{_mm256_inserti128_si256( + _mm256_castsi128_si256(v128), v128, 1)}); +#else + // The preferred path. This is perhaps surprising, because vbroadcasti128 + // with xmm input has 7 cycle latency on Intel, but Clang >= 7 is able to + // pattern-match this to vbroadcastf128 with a memory operand as desired. + return BitCast(d, VFromD{_mm256_broadcastsi128_si256(v128)}); +#endif +} +template +HWY_API Vec256 LoadDup128(D /* tag */, const float* HWY_RESTRICT p) { +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + const Full128 d128; + const __m128 v128 = LoadU(d128, p).raw; + return Vec256{ + _mm256_insertf128_ps(_mm256_castps128_ps256(v128), v128, 1)}; +#else + return Vec256{_mm256_broadcast_ps(reinterpret_cast(p))}; +#endif +} +template +HWY_API Vec256 LoadDup128(D /* tag */, const double* HWY_RESTRICT p) { +#if HWY_COMPILER_MSVC && HWY_COMPILER_MSVC < 1931 + const Full128 d128; + const __m128d v128 = LoadU(d128, p).raw; + return Vec256{ + _mm256_insertf128_pd(_mm256_castpd128_pd256(v128), v128, 1)}; +#else + return Vec256{ + _mm256_broadcast_pd(reinterpret_cast(p))}; +#endif +} + +// ------------------------------ Store + +template +HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { + _mm256_store_si256(reinterpret_cast<__m256i*>(aligned), v.raw); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void Store(Vec256 v, D /* 
tag */, + float16_t* HWY_RESTRICT aligned) { + _mm256_store_ph(aligned, v.raw); +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API void Store(Vec256 v, D /* tag */, float* HWY_RESTRICT aligned) { + _mm256_store_ps(aligned, v.raw); +} +template +HWY_API void Store(Vec256 v, D /* tag */, + double* HWY_RESTRICT aligned) { + _mm256_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(VFromD v, D /* tag */, TFromD* HWY_RESTRICT p) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(p), v.raw); +} +#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec256 v, D /* tag */, + float16_t* HWY_RESTRICT p) { + _mm256_storeu_ph(p, v.raw); +} +#endif +template +HWY_API void StoreU(Vec256 v, D /* tag */, float* HWY_RESTRICT p) { + _mm256_storeu_ps(p, v.raw); +} +template +HWY_API void StoreU(Vec256 v, D /* tag */, double* HWY_RESTRICT p) { + _mm256_storeu_pd(p, v.raw); +} + +// ------------------------------ BlendedStore + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm256_mask_storeu_epi8(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + _mm256_mask_storeu_epi16(reinterpret_cast(p), + RebindMask(du, m).raw, BitCast(du, v).raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm256_mask_storeu_epi32(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm256_mask_storeu_epi64(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, D /* tag */, + float* HWY_RESTRICT p) { + _mm256_mask_storeu_ps(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, D /* tag */, + double* HWY_RESTRICT p) { + _mm256_mask_storeu_pd(p, m.raw, v.raw); +} + +#else // AVX2 + +// Intel SDM says "No AC# reported for any mask bit 
combinations". However, AMD +// allows AC# if "Alignment checking enabled and: 256-bit memory operand not +// 32-byte aligned". Fortunately AC# is not enabled by default and requires both +// OS support (CR0) and the application to set rflags.AC. We assume these remain +// disabled because x86/x64 code and compiler output often contain misaligned +// scalar accesses, which would also fault. +// +// Caveat: these are slow on AMD Jaguar/Bulldozer. + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + // There is no maskload_epi8/16. Blending is also unsafe because loading a + // full vector that crosses the array end causes asan faults. Resort to scalar + // code; the caller should instead use memcpy, assuming m is FirstN(d, n). + const RebindToUnsigned du; + using TU = TFromD; + alignas(32) TU buf[MaxLanes(d)]; + alignas(32) TU mask[MaxLanes(d)]; + Store(BitCast(du, v), du, buf); + Store(BitCast(du, VecFromMask(d, m)), du, mask); + for (size_t i = 0; i < MaxLanes(d); ++i) { + if (mask[i]) { + CopySameSize(buf + i, p + i); + } + } +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm256_maskstore_epi32(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + auto pi = reinterpret_cast(p); // NOLINT + _mm256_maskstore_epi64(pi, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, D d, + float* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm256_maskstore_ps(p, mi.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec256 v, Mask256 m, D d, + double* HWY_RESTRICT p) { + const Vec256 mi = + BitCast(RebindToSigned(), VecFromMask(d, m)); + _mm256_maskstore_pd(p, mi.raw, v.raw); +} + +#endif + +// ------------------------------ Non-temporal stores + +template +HWY_API void Stream(VFromD v, D d, TFromD* 
HWY_RESTRICT aligned) { + const RebindToUnsigned du; // for float16_t + _mm256_stream_si256(reinterpret_cast<__m256i*>(aligned), BitCast(du, v).raw); +} +template +HWY_API void Stream(Vec256 v, D /* tag */, float* HWY_RESTRICT aligned) { + _mm256_stream_ps(aligned, v.raw); +} +template +HWY_API void Stream(Vec256 v, D /* tag */, + double* HWY_RESTRICT aligned) { + _mm256_stream_pd(aligned, v.raw); +} + +// ------------------------------ ScatterOffset + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +#if HWY_TARGET <= HWY_AVX3 + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + Vec256 offset) { + _mm256_i32scatter_epi32(base, offset.raw, v.raw, 1); +} + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + Vec256 offset) { + _mm256_i64scatter_epi64(base, offset.raw, v.raw, 1); +} + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, float* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i32scatter_ps(base, offset.raw, v.raw, 1); +} + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, double* HWY_RESTRICT base, + const Vec256 offset) { + _mm256_i64scatter_pd(base, offset.raw, v.raw, 1); +} + +// ------------------------------ ScatterIndex + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm256_i32scatter_epi32(base, index.raw, v.raw, 4); +} + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm256_i64scatter_epi64(base, index.raw, v.raw, 8); +} + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, float* HWY_RESTRICT base, + VFromD> index) { + _mm256_i32scatter_ps(base, index.raw, v.raw, 4); +} + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, double* HWY_RESTRICT base, + VFromD> index) { + 
_mm256_i64scatter_pd(base, index.raw, v.raw, 8); +} + +// ------------------------------ MaskedScatterIndex + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm256_mask_i32scatter_epi32(base, m.raw, index.raw, v.raw, 4); +} + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm256_mask_i64scatter_epi64(base, m.raw, index.raw, v.raw, 8); +} + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + float* HWY_RESTRICT base, + VFromD> index) { + _mm256_mask_i32scatter_ps(base, m.raw, index.raw, v.raw, 4); +} + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + double* HWY_RESTRICT base, + VFromD> index) { + _mm256_mask_i64scatter_pd(base, m.raw, index.raw, v.raw, 8); +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Gather + +namespace detail { + +template +HWY_INLINE Vec256 NativeGather256(const T* HWY_RESTRICT base, + Vec256 indices) { + return Vec256{_mm256_i32gather_epi32( + reinterpret_cast(base), indices.raw, kScale)}; +} + +template +HWY_INLINE Vec256 NativeGather256(const T* HWY_RESTRICT base, + Vec256 indices) { + return Vec256{_mm256_i64gather_epi64( + reinterpret_cast(base), indices.raw, kScale)}; +} + +template +HWY_API Vec256 NativeGather256(const float* HWY_RESTRICT base, + Vec256 indices) { + return Vec256{_mm256_i32gather_ps(base, indices.raw, kScale)}; +} + +template +HWY_API Vec256 NativeGather256(const double* HWY_RESTRICT base, + Vec256 indices) { + return Vec256{_mm256_i64gather_pd(base, indices.raw, kScale)}; +} + +} // namespace detail + +template +HWY_API VFromD GatherOffset(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> offsets) { + return detail::NativeGather256<1>(base, offsets); +} + +template +HWY_API VFromD GatherIndex(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return 
detail::NativeGather256)>(base, indices); +} + +// ------------------------------ MaskedGatherIndexOr + +namespace detail { + +template +HWY_INLINE Vec256 NativeMaskedGatherOr256(Vec256 no, Mask256 m, + const T* HWY_RESTRICT base, + Vec256 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_mmask_i32gather_epi32( + no.raw, m.raw, indices.raw, reinterpret_cast(base), + kScale)}; +#else + return Vec256{_mm256_mask_i32gather_epi32( + no.raw, reinterpret_cast(base), indices.raw, m.raw, + kScale)}; +#endif +} + +template +HWY_INLINE Vec256 NativeMaskedGatherOr256(Vec256 no, Mask256 m, + const T* HWY_RESTRICT base, + Vec256 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_mmask_i64gather_epi64( + no.raw, m.raw, indices.raw, reinterpret_cast(base), + kScale)}; +#else + // For reasons unknown, _mm256_mask_i64gather_epi64 returns all-zeros. + const Full256 d; + const Full256 dd; + return BitCast(d, + Vec256{_mm256_mask_i64gather_pd( + BitCast(dd, no).raw, reinterpret_cast(base), + indices.raw, RebindMask(dd, m).raw, kScale)}); +#endif +} + +template +HWY_API Vec256 NativeMaskedGatherOr256(Vec256 no, + Mask256 m, + const float* HWY_RESTRICT base, + Vec256 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{ + _mm256_mmask_i32gather_ps(no.raw, m.raw, indices.raw, base, kScale)}; +#else + return Vec256{ + _mm256_mask_i32gather_ps(no.raw, base, indices.raw, m.raw, kScale)}; +#endif +} + +template +HWY_API Vec256 NativeMaskedGatherOr256(Vec256 no, + Mask256 m, + const double* HWY_RESTRICT base, + Vec256 indices) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{ + _mm256_mmask_i64gather_pd(no.raw, m.raw, indices.raw, base, kScale)}; +#else + return Vec256{ + _mm256_mask_i64gather_pd(no.raw, base, indices.raw, m.raw, kScale)}; +#endif +} + +} // namespace detail + +template +HWY_API VFromD MaskedGatherIndexOr(VFromD no, MFromD m, D /*d*/, + const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeMaskedGatherOr256)>(no, m, base, + 
indices); +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return VFromD{_mm256_castsi256_si128(v.raw)}; +} +template +HWY_API Vec128 LowerHalf(D /* tag */, Vec256 v) { + return Vec128{_mm256_castsi256_si128(v.raw)}; +} +template +HWY_API Vec128 LowerHalf(D /* tag */, Vec256 v) { +#if HWY_HAVE_FLOAT16 + return Vec128{_mm256_castph256_ph128(v.raw)}; +#else + return Vec128{_mm256_castsi256_si128(v.raw)}; +#endif // HWY_HAVE_FLOAT16 +} +template +HWY_API Vec128 LowerHalf(D /* tag */, Vec256 v) { + return Vec128{_mm256_castps256_ps128(v.raw)}; +} +template +HWY_API Vec128 LowerHalf(D /* tag */, Vec256 v) { + return Vec128{_mm256_castpd256_pd128(v.raw)}; +} + +template +HWY_API Vec128 LowerHalf(Vec256 v) { + const Full128 dh; + return LowerHalf(dh, v); +} + +// ------------------------------ UpperHalf + +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + const RebindToUnsigned du; // for float16_t + const Twice dut; + return BitCast(d, VFromD{ + _mm256_extracti128_si256(BitCast(dut, v).raw, 1)}); +} +template +HWY_API VFromD UpperHalf(D /* tag */, Vec256 v) { + return VFromD{_mm256_extractf128_ps(v.raw, 1)}; +} +template +HWY_API VFromD UpperHalf(D /* tag */, Vec256 v) { + return VFromD{_mm256_extractf128_pd(v.raw, 1)}; +} + +// ------------------------------ ExtractLane (Store) +template +HWY_API T ExtractLane(const Vec256 v, size_t i) { + const DFromV d; + HWY_DASSERT(i < Lanes(d)); + +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + if (__builtin_constant_p(i < kLanesPerBlock) && (i < kLanesPerBlock)) { + return ExtractLane(LowerHalf(Half(), v), i); + } +#endif + + alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + return lanes[i]; +} + +// ------------------------------ InsertLane (Store) +template +HWY_API Vec256 InsertLane(const 
Vec256 v, size_t i, T t) { + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +// ------------------------------ GetLane (LowerHalf) +template +HWY_API T GetLane(const Vec256 v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ExtractBlock (LowerHalf, UpperHalf) + +template +HWY_API Vec128 ExtractBlock(Vec256 v) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + const Half> dh; + return (kBlockIdx == 0) ? LowerHalf(dh, v) : UpperHalf(dh, v); +} + +// ------------------------------ ZeroExtendVector + +// Unfortunately the initial _mm256_castsi128_si256 intrinsic leaves the upper +// bits undefined. Although it makes sense for them to be zero (VEX encoded +// 128-bit instructions zero the upper lanes to avoid large penalties), a +// compiler could decide to optimize out code that relies on this. +// +// The newer _mm256_zextsi128_si256 intrinsic fixes this by specifying the +// zeroing, but it is not available on MSVC until 1920 nor GCC until 10.1. +// Unfortunately as of 2023-08 it still seems to cause internal compiler errors +// on MSVC, so we consider it unavailable there. +// +// Without zext we can still possibly obtain the desired code thanks to pattern +// recognition; note that the expensive insert instruction might not actually be +// generated, see https://gcc.godbolt.org/z/1MKGaP. + +#if !defined(HWY_HAVE_ZEXT) +#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 500) || \ + (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL >= 1000) +#define HWY_HAVE_ZEXT 1 +#else +#define HWY_HAVE_ZEXT 0 +#endif +#endif // defined(HWY_HAVE_ZEXT) + +template +HWY_API VFromD ZeroExtendVector(D /* tag */, VFromD> lo) { +#if HWY_HAVE_ZEXT + return VFromD{_mm256_zextsi128_si256(lo.raw)}; +#elif HWY_COMPILER_MSVC + // Workaround: _mm256_inserti128_si256 does not actually zero the hi part. 
+ return VFromD{_mm256_set_m128i(_mm_setzero_si128(), lo.raw)}; +#else + return VFromD{_mm256_inserti128_si256(_mm256_setzero_si256(), lo.raw, 0)}; +#endif +} +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec256 ZeroExtendVector(D d, Vec128 lo) { +#if HWY_HAVE_ZEXT + (void)d; + return Vec256{_mm256_zextph128_ph256(lo.raw)}; +#else + const RebindToUnsigned du; + return BitCast(d, ZeroExtendVector(du, BitCast(du, lo))); +#endif // HWY_HAVE_ZEXT +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec256 ZeroExtendVector(D /* tag */, Vec128 lo) { +#if HWY_HAVE_ZEXT + return Vec256{_mm256_zextps128_ps256(lo.raw)}; +#else + return Vec256{_mm256_insertf128_ps(_mm256_setzero_ps(), lo.raw, 0)}; +#endif +} +template +HWY_API Vec256 ZeroExtendVector(D /* tag */, Vec128 lo) { +#if HWY_HAVE_ZEXT + return Vec256{_mm256_zextpd128_pd256(lo.raw)}; +#else + return Vec256{_mm256_insertf128_pd(_mm256_setzero_pd(), lo.raw, 0)}; +#endif +} + +// ------------------------------ ZeroExtendResizeBitCast + +namespace detail { + +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag<8> /* from_size_tag */, hwy::SizeTag<32> /* to_size_tag */, + DTo d_to, DFrom d_from, VFromD v) { + const Twice dt_from; + const Twice dq_from; + return BitCast(d_to, ZeroExtendVector(dq_from, ZeroExtendVector(dt_from, v))); +} + +} // namespace detail + +// ------------------------------ Combine + +template +HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { + const RebindToUnsigned du; // for float16_t + const Half dh_u; + const auto lo256 = ZeroExtendVector(du, BitCast(dh_u, lo)); + return BitCast(d, VFromD{_mm256_inserti128_si256( + lo256.raw, BitCast(dh_u, hi).raw, 1)}); +} +template +HWY_API Vec256 Combine(D d, Vec128 hi, Vec128 lo) { + const auto lo256 = ZeroExtendVector(d, lo); + return Vec256{_mm256_insertf128_ps(lo256.raw, hi.raw, 1)}; +} +template +HWY_API Vec256 Combine(D d, Vec128 hi, Vec128 lo) { + const auto lo256 = ZeroExtendVector(d, lo); + return Vec256{_mm256_insertf128_pd(lo256.raw, 
hi.raw, 1)}; +} + +// ------------------------------ ShiftLeftBytes +template +HWY_API VFromD ShiftLeftBytes(D /* tag */, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // This is the same operation as _mm256_bslli_epi128. + return VFromD{_mm256_slli_si256(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightBytes +template +HWY_API VFromD ShiftRightBytes(D /* tag */, VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + // This is the same operation as _mm256_bsrli_epi128. + return VFromD{_mm256_srli_si256(v.raw, kBytes)}; +} + +// ------------------------------ CombineShiftRightBytes +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + const Repartition d8; + return BitCast(d, Vec256{_mm256_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +// ------------------------------ Broadcast + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m256i lo = _mm256_shufflelo_epi16(vu.raw, (0x55 * kLane) & 0xFF); + return BitCast(d, VU{_mm256_unpacklo_epi64(lo, lo)}); + } else { + const __m256i hi = + _mm256_shufflehi_epi16(vu.raw, (0x55 * (kLane - 4)) & 0xFF); + return BitCast(d, VU{_mm256_unpackhi_epi64(hi, hi)}); + } +} +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; +} + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{_mm256_shuffle_epi32(v.raw, kLane ? 
0xEE : 0x44)}; +} + +template +HWY_API Vec256 Broadcast(Vec256 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; +} + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 15 * kLane)}; +} + +// ------------------------------ Concat blocks (LowerHalf, ZeroExtendVector) + +// _mm256_broadcastsi128_si256 has 7 cycle latency on ICL. +// _mm256_permute2x128_si256 is slow on Zen1 (8 uops), so we avoid it (at no +// extra cost) for LowerLower and UpperLower. + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + const Half d2; + const RebindToUnsigned du2; // for float16_t + return BitCast( + d, VFromD{_mm256_inserti128_si256( + BitCast(du, lo).raw, BitCast(du2, LowerHalf(d2, hi)).raw, 1)}); +} +template +HWY_API Vec256 ConcatLowerLower(D d, Vec256 hi, + Vec256 lo) { + const Half d2; + return Vec256{_mm256_insertf128_ps(lo.raw, LowerHalf(d2, hi).raw, 1)}; +} +template +HWY_API Vec256 ConcatLowerLower(D d, Vec256 hi, + Vec256 lo) { + const Half d2; + return Vec256{_mm256_insertf128_pd(lo.raw, LowerHalf(d2, hi).raw, 1)}; +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_permute2x128_si256( + BitCast(du, lo).raw, BitCast(du, hi).raw, 0x21)}); +} +template +HWY_API Vec256 ConcatLowerUpper(D /* tag */, Vec256 hi, + Vec256 lo) { + return Vec256{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x21)}; +} +template +HWY_API Vec256 ConcatLowerUpper(D /* tag */, Vec256 hi, + Vec256 lo) { + return Vec256{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x21)}; +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API VFromD ConcatUpperLower(D 
d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm256_blend_epi32( + BitCast(du, hi).raw, BitCast(du, lo).raw, 0x0F)}); +} +template +HWY_API Vec256 ConcatUpperLower(D /* tag */, Vec256 hi, + Vec256 lo) { + return Vec256{_mm256_blend_ps(hi.raw, lo.raw, 0x0F)}; +} +template +HWY_API Vec256 ConcatUpperLower(D /* tag */, Vec256 hi, + Vec256 lo) { + return Vec256{_mm256_blend_pd(hi.raw, lo.raw, 3)}; +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm256_permute2x128_si256( + BitCast(du, lo).raw, BitCast(du, hi).raw, 0x31)}); +} +template +HWY_API Vec256 ConcatUpperUpper(D /* tag */, Vec256 hi, + Vec256 lo) { + return Vec256{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x31)}; +} +template +HWY_API Vec256 ConcatUpperUpper(D /* tag */, Vec256 hi, + Vec256 lo) { + return Vec256{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x31)}; +} + +// ------------------------------ BroadcastBlock +template +HWY_API Vec256 BroadcastBlock(Vec256 v) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + const DFromV d; + return (kBlockIdx == 0) ? 
ConcatLowerLower(d, v, v) + : ConcatUpperUpper(d, v, v); +} + +// ------------------------------ BroadcastLane + +namespace detail { + +template +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const Half> dh; + return Vec256{_mm256_broadcastb_epi8(LowerHalf(dh, v).raw)}; +} + +template +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + const Half dh; + const RebindToUnsigned dh_u; + return BitCast(d, VFromD{_mm256_broadcastw_epi16( + BitCast(dh_u, LowerHalf(dh, v)).raw)}); +} + +template +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const Half> dh; + return Vec256{_mm256_broadcastd_epi32(LowerHalf(dh, v).raw)}; +} + +template +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const Half> dh; + return Vec256{_mm256_broadcastq_epi64(LowerHalf(dh, v).raw)}; +} + +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const Half> dh; + return Vec256{_mm256_broadcastss_ps(LowerHalf(dh, v).raw)}; +} + +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec256 v) { + const Half> dh; + return Vec256{_mm256_broadcastsd_pd(LowerHalf(dh, v).raw)}; +} + +template * = nullptr, + HWY_IF_NOT_T_SIZE(T, 8)> +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag /* lane_idx_tag */, + Vec256 v) { + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + constexpr int kBlockIdx = static_cast(kLaneIdx / kLanesPerBlock); + constexpr int kLaneInBlkIdx = + static_cast(kLaneIdx) & (kLanesPerBlock - 1); + return Broadcast(BroadcastBlock(v)); +} + +template * = nullptr, + HWY_IF_UI64(T)> +HWY_INLINE Vec256 BroadcastLane(hwy::SizeTag /* lane_idx_tag */, + Vec256 v) { + static_assert(kLaneIdx <= 3, "Invalid lane"); + return Vec256{ + _mm256_permute4x64_epi64(v.raw, static_cast(0x55 * kLaneIdx))}; +} + +template * = nullptr> +HWY_INLINE Vec256 
BroadcastLane( + hwy::SizeTag /* lane_idx_tag */, Vec256 v) { + static_assert(kLaneIdx <= 3, "Invalid lane"); + return Vec256{ + _mm256_permute4x64_pd(v.raw, static_cast(0x55 * kLaneIdx))}; +} + +} // namespace detail + +template +HWY_API Vec256 BroadcastLane(Vec256 v) { + static_assert(kLaneIdx >= 0, "Invalid lane"); + return detail::BroadcastLane(hwy::SizeTag(kLaneIdx)>(), + v); +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec256 have lanes 7,6,5,4,3,2,1,0 (0 is +// least-significant). Shuffle0321 rotates four-lane blocks one lane to the +// right (the previous least-significant lane is now most-significant => +// 47650321). These could also be implemented via CombineShiftRightBytes but +// the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec256 Shuffle2301(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0xB1)}; +} +HWY_API Vec256 Shuffle2301(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0xB1)}; +} + +// Used by generic_ops-inl.h +namespace detail { + +template +HWY_API Vec256 ShuffleTwo2301(const Vec256 a, const Vec256 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(2, 3, 0, 1); + return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} +template +HWY_API Vec256 ShuffleTwo1230(const Vec256 a, const Vec256 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(1, 2, 3, 0); + return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} +template +HWY_API Vec256 ShuffleTwo3012(const Vec256 a, const Vec256 b) { + const DFromV d; + const RebindToFloat df; + constexpr int m = _MM_SHUFFLE(3, 0, 1, 2); + return BitCast(d, Vec256{_mm256_shuffle_ps(BitCast(df, a).raw, + BitCast(df, b).raw, m)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec256 Shuffle1032(const Vec256 v) { + return 
Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle1032(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle1032(const Vec256 v) { + // Shorter encoding than _mm256_permute_ps. + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; +} +HWY_API Vec256 Shuffle01(const Vec256 v) { + // Shorter encoding than _mm256_permute_pd. + return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 5)}; +} + +// Rotate right 32 bits +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; +} +HWY_API Vec256 Shuffle0321(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x39)}; +} +// Rotate left 32 bits +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; +} +HWY_API Vec256 Shuffle2103(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x93)}; +} + +// Reverse +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; +} +HWY_API Vec256 Shuffle0123(const Vec256 v) { + return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x1B)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. 
+template +struct Indices256 { + __m256i raw; +}; + +// 8-bit lanes: indices remain unchanged +template +HWY_API Indices256> IndicesFromVec(D /* tag */, Vec256 vec) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Full256 di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(2 * Lanes(di)))))); +#endif + return Indices256>{vec.raw}; +} + +// 16-bit lanes: convert indices to 32x8 unless AVX3 is available +template +HWY_API Indices256> IndicesFromVec(D /* tag */, Vec256 vec) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + const Full256 di; +#if HWY_IS_DEBUG_BUILD + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, static_cast(2 * Lanes(di)))))); +#endif + +#if HWY_TARGET <= HWY_AVX3 + (void)di; + return Indices256>{vec.raw}; +#else + const Repartition d8; + using V8 = VFromD; + alignas(32) static constexpr uint8_t kByteOffsets[32] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}; + + // Broadcast each lane index to all 2 bytes of T + alignas(32) static constexpr uint8_t kBroadcastLaneBytes[32] = { + 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14, + 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}; + const V8 lane_indices = TableLookupBytes(vec, Load(d8, kBroadcastLaneBytes)); + + // Shift to bytes + const Repartition d16; + const V8 byte_indices = BitCast(d8, ShiftLeft<1>(BitCast(d16, lane_indices))); + + return Indices256>{Add(byte_indices, Load(d8, kByteOffsets)).raw}; +#endif // HWY_TARGET <= HWY_AVX3 +} + +// Native 8x32 instruction: indices remain unchanged +template +HWY_API Indices256> IndicesFromVec(D /* tag */, Vec256 vec) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const Full256 di; + HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && + AllTrue(di, Lt(vec, Set(di, 
static_cast(2 * Lanes(di)))))); +#endif + return Indices256>{vec.raw}; +} + +// 64-bit lanes: convert indices to 8x32 unless AVX3 is available +template +HWY_API Indices256> IndicesFromVec(D d, Vec256 idx64) { + static_assert(sizeof(TFromD) == sizeof(TI), "Index size must match lane"); + const Rebind di; + (void)di; // potentially unused +#if HWY_IS_DEBUG_BUILD + HWY_DASSERT(AllFalse(di, Lt(idx64, Zero(di))) && + AllTrue(di, Lt(idx64, Set(di, static_cast(2 * Lanes(di)))))); +#endif + +#if HWY_TARGET <= HWY_AVX3 + (void)d; + return Indices256>{idx64.raw}; +#else + const Repartition df; // 32-bit! + // Replicate 64-bit index into upper 32 bits + const Vec256 dup = + BitCast(di, Vec256{_mm256_moveldup_ps(BitCast(df, idx64).raw)}); + // For each idx64 i, idx32 are 2*i and 2*i+1. + const Vec256 idx32 = dup + dup + Set(di, TI(1) << 32); + return Indices256>{idx32.raw}; +#endif +} + +template +HWY_API Indices256> SetTableIndices(D d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_permutexvar_epi8(idx.raw, v.raw)}; +#else + const Vec256 idx_vec{idx.raw}; + const DFromV d; + const Repartition du16; + const auto sel_hi_mask = + MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, idx_vec)))); + + const auto a = ConcatLowerLower(d, v, v); + const auto b = ConcatUpperUpper(d, v, v); + const auto lo_lookup_result = TableLookupBytes(a, idx_vec); + +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_mask_shuffle_epi8( + lo_lookup_result.raw, sel_hi_mask.raw, b.raw, idx_vec.raw)}; +#else + const auto hi_lookup_result = TableLookupBytes(b, idx_vec); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#endif // HWY_TARGET <= HWY_AVX3 +#endif // HWY_TARGET <= HWY_AVX3_DL +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return 
Vec256{_mm256_permutexvar_epi16(idx.raw, v.raw)}; +#else + const DFromV d; + const Repartition du8; + return BitCast( + d, TableLookupLanes(BitCast(du8, v), Indices256{idx.raw})); +#endif +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 TableLookupLanes(Vec256 v, + Indices256 idx) { + return Vec256{_mm256_permutexvar_ph(idx.raw, v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { + return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; +} + +template +HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutexvar_epi64(idx.raw, v.raw)}; +#else + return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; +#endif +} + +HWY_API Vec256 TableLookupLanes(const Vec256 v, + const Indices256 idx) { + return Vec256{_mm256_permutevar8x32_ps(v.raw, idx.raw)}; +} + +HWY_API Vec256 TableLookupLanes(const Vec256 v, + const Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutexvar_pd(idx.raw, v.raw)}; +#else + const Full256 df; + const Full256 du; + return BitCast(df, Vec256{_mm256_permutevar8x32_epi32( + BitCast(du, v).raw, idx.raw)}); +#endif +} + +template +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_permutex2var_epi8(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const auto sel_hi_mask = + MaskFromVec(BitCast(d, ShiftLeft<2>(Vec256{idx.raw}))); + const auto lo_lookup_result = TableLookupLanes(a, idx); + const auto hi_lookup_result = TableLookupLanes(b, idx); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#endif +} + +template +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutex2var_epi16(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const Repartition du8; + return BitCast(d, TwoTablesLookupLanes(BitCast(du8, a), BitCast(du8, b), + 
Indices256{idx.raw})); +#endif +} + +template +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutex2var_epi32(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const RebindToFloat df; + const Vec256 idx_vec{idx.raw}; + + const auto sel_hi_mask = MaskFromVec(BitCast(df, ShiftLeft<28>(idx_vec))); + const auto lo_lookup_result = BitCast(df, TableLookupLanes(a, idx)); + const auto hi_lookup_result = BitCast(df, TableLookupLanes(b, idx)); + return BitCast(d, + IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result)); +#endif +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, + Vec256 b, + Indices256 idx) { + return Vec256{_mm256_permutex2var_ph(a.raw, idx.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutex2var_ps(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const auto sel_hi_mask = + MaskFromVec(BitCast(d, ShiftLeft<28>(Vec256{idx.raw}))); + const auto lo_lookup_result = TableLookupLanes(a, idx); + const auto hi_lookup_result = TableLookupLanes(b, idx); + return IfThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result); +#endif +} + +template +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutex2var_epi64(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const Repartition du32; + return BitCast(d, TwoTablesLookupLanes(BitCast(du32, a), BitCast(du32, b), + Indices256{idx.raw})); +#endif +} + +HWY_API Vec256 TwoTablesLookupLanes(Vec256 a, Vec256 b, + Indices256 idx) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_permutex2var_pd(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const Repartition du32; + return BitCast(d, TwoTablesLookupLanes(BitCast(du32, a), BitCast(du32, b), + Indices256{idx.raw})); +#endif +} + +// 
------------------------------ SwapAdjacentBlocks + +template +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm256_permute4x64_epi64( + BitCast(du, v).raw, _MM_SHUFFLE(1, 0, 3, 2))}); +} + +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + return Vec256{_mm256_permute4x64_pd(v.raw, _MM_SHUFFLE(1, 0, 3, 2))}; +} + +HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { + // Assume no domain-crossing penalty between float/double (true on SKX). + const DFromV d; + const RepartitionToWide dw; + return BitCast(d, SwapAdjacentBlocks(BitCast(dw, v))); +} + +// ------------------------------ InterleaveEvenBlocks (ConcatLowerLower) +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveEvenBlocks(D d, V a, V b) { + return ConcatLowerLower(d, b, a); +} + +// ------------------------------ InterleaveOddBlocks (ConcatUpperUpper) +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveOddBlocks(D d, V a, V b) { + return ConcatUpperUpper(d, b, a); +} + +// ------------------------------ InterleaveLowerBlocks +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveLowerBlocks(D d, V a, V b) { + return InterleaveEvenBlocks(d, a, b); +} +// ------------------------------ InterleaveUpperBlocks +template , HWY_IF_V_SIZE_D(D, 32)> +HWY_API V InterleaveUpperBlocks(D d, V a, V b) { + return InterleaveOddBlocks(d, a, b); +} + +// ------------------------------ Reverse (RotateRight) + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + alignas(32) static constexpr int32_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + alignas(32) static constexpr int64_t kReverse[4] = {3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToSigned di; + alignas(32) static 
constexpr int16_t kReverse[16] = { + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + const Vec256 idx = Load(di, kReverse); + return BitCast(d, Vec256{ + _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +#else + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + const auto rev128 = TableLookupBytes(v, shuffle); + return VFromD{ + _mm256_permute4x64_epi64(rev128.raw, _MM_SHUFFLE(1, 0, 3, 2))}; +#endif +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) static constexpr TFromD kReverse[32] = { + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +#else + // First reverse bytes within blocks via PSHUFB, then swap blocks. + alignas(32) static constexpr TFromD kReverse[32] = { + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + return SwapAdjacentBlocks(TableLookupBytes(v, Load(d, kReverse))); +#endif +} + +// ------------------------------ Reverse2 (in x86_128) + +// ------------------------------ Reverse4 (SwapAdjacentBlocks) + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0706, 0x0504, 0x0302, 0x0100, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908); + return BitCast(d, TableLookupBytes(v, shuffle)); +} + +// 32 bit Reverse4 defined in x86_128. + +template +HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { + // Could also use _mm256_permute4x64_epi64. 
+ return SwapAdjacentBlocks(Shuffle01(v)); +} + +// ------------------------------ Reverse8 + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToSigned di; + const VFromD shuffle = Dup128VecFromValues( + di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100); + return BitCast(d, TableLookupBytes(v, shuffle)); +} + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + return Reverse(d, v); +} + +template +HWY_API VFromD Reverse8(D /* tag */, const VFromD /* v */) { + HWY_ASSERT(0); // AVX2 does not have 8 64-bit lanes +} + +// ------------------------------ ReverseBits in x86_512 + +// ------------------------------ InterleaveLower + +// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides +// the least-significant lane) and "b". To concatenate two half-width integers +// into one, use ZipLower/Upper instead (also works with scalar). + +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm256_unpacklo_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{_mm256_unpacklo_epi64(a.raw, b.raw)}; +} + +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{_mm256_unpacklo_ps(a.raw, b.raw)}; +} +HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { + return Vec256{_mm256_unpacklo_pd(a.raw, b.raw)}; +} + +// ------------------------------ InterleaveUpper + +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm256_unpackhi_epi8(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const 
RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm256_unpackhi_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm256_unpackhi_epi32(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm256_unpackhi_epi64(a.raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm256_unpackhi_ps(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm256_unpackhi_pd(a.raw, b.raw)}; +} + +// ---------------------------- InsertBlock (ConcatLowerLower, ConcatUpperLower) +template +HWY_API Vec256 InsertBlock(Vec256 v, Vec128 blk_to_insert) { + static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index"); + + const DFromV d; + const auto vec_to_insert = ResizeBitCast(d, blk_to_insert); + return (kBlockIdx == 0) ? ConcatUpperLower(d, v, vec_to_insert) + : ConcatLowerLower(d, vec_to_insert, v); +} + +// ------------------------------ ConcatOdd + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) static constexpr uint8_t kIdx[32] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, + 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Unsigned 8-bit shift so we can pack. 
+ const Vec256 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec256 uL = ShiftRight<8>(BitCast(dw, lo)); + const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw); + return VFromD{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) static constexpr uint16_t kIdx[16] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Unsigned 16-bit shift so we can pack. + const Vec256 uH = ShiftRight<16>(BitCast(dw, hi)); + const Vec256 uL = ShiftRight<16>(BitCast(dw, lo)); + const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); + return BitCast(d, VFromD{_mm256_permute4x64_epi64( + u16, _MM_SHUFFLE(3, 1, 2, 0))}); +#endif +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) static constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v3131{_mm256_shuffle_ps( + BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(3, 1, 3, 1))}; + return VFromD{_mm256_permute4x64_epi64(BitCast(du, v3131).raw, + _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(32) static constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return VFromD{_mm256_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; +#else + const VFromD v3131{ + _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 1, 3, 1))}; + return BitCast(d, Vec256{_mm256_permute4x64_epi64( + BitCast(du, v3131).raw, _MM_SHUFFLE(3, 1, 2, 
0))}); +#endif +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) static constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v31{ + _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 15)}; + return VFromD{ + _mm256_permute4x64_epi64(BitCast(du, v31).raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API Vec256 ConcatOdd(D d, Vec256 hi, Vec256 lo) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; + return Vec256{ + _mm256_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; +#else + (void)d; + const Vec256 v31{_mm256_shuffle_pd(lo.raw, hi.raw, 15)}; + return Vec256{ + _mm256_permute4x64_pd(v31.raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +// ------------------------------ ConcatEven + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3_DL + alignas(64) static constexpr uint8_t kIdx[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Isolate lower 8 bits per u16 so we can pack. 
+ const Vec256 mask = Set(dw, 0x00FF); + const Vec256 uH = And(BitCast(dw, hi), mask); + const Vec256 uL = And(BitCast(dw, lo), mask); + const __m256i u8 = _mm256_packus_epi16(uL.raw, uH.raw); + return VFromD{_mm256_permute4x64_epi64(u8, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) static constexpr uint16_t kIdx[16] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Isolate lower 16 bits per u32 so we can pack. + const Vec256 mask = Set(dw, 0x0000FFFF); + const Vec256 uH = And(BitCast(dw, hi), mask); + const Vec256 uL = And(BitCast(dw, lo), mask); + const __m256i u16 = _mm256_packus_epi32(uL.raw, uH.raw); + return BitCast(d, VFromD{_mm256_permute4x64_epi64( + u16, _MM_SHUFFLE(3, 1, 2, 0))}); +#endif +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) static constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v2020{_mm256_shuffle_ps( + BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return VFromD{_mm256_permute4x64_epi64(BitCast(du, v2020).raw, + _MM_SHUFFLE(3, 1, 2, 0))}; + +#endif +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) static constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return VFromD{_mm256_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; +#else + const VFromD v2020{ + _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return BitCast(d, 
Vec256{_mm256_permute4x64_epi64( + BitCast(du, v2020).raw, _MM_SHUFFLE(3, 1, 2, 0))}); + +#endif +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3 + alignas(64) static constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; + return BitCast( + d, Vec256{_mm256_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RebindToFloat df; + const Vec256 v20{ + _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 0)}; + return VFromD{ + _mm256_permute4x64_epi64(BitCast(du, v20).raw, _MM_SHUFFLE(3, 1, 2, 0))}; + +#endif +} + +template +HWY_API Vec256 ConcatEven(D d, Vec256 hi, Vec256 lo) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; + return Vec256{ + _mm256_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; +#else + (void)d; + const Vec256 v20{_mm256_shuffle_pd(lo.raw, hi.raw, 0)}; + return Vec256{ + _mm256_permute4x64_pd(v20.raw, _MM_SHUFFLE(3, 1, 2, 0))}; +#endif +} + +// ------------------------------ InterleaveWholeLower + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { +#if HWY_TARGET <= HWY_AVX3_DL + const RebindToUnsigned du; + alignas(32) static constexpr uint8_t kIdx[32] = { + 0, 32, 1, 33, 2, 34, 3, 35, 4, 36, 5, 37, 6, 38, 7, 39, + 8, 40, 9, 41, 10, 42, 11, 43, 12, 44, 13, 45, 14, 46, 15, 47}; + return VFromD{_mm256_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; +#else + return ConcatLowerLower(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); +#endif +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint16_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19, + 4, 20, 5, 21, 6, 22, 7, 23}; + return BitCast( + d, VFromD{_mm256_permutex2var_epi16( + BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); +} + +template 
+HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; + return VFromD{_mm256_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; + return VFromD{_mm256_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint64_t kIdx[4] = {0, 4, 1, 5}; + return VFromD{_mm256_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint64_t kIdx[4] = {0, 4, 1, 5}; + return VFromD{_mm256_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; +} +#else // AVX2 +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + return ConcatLowerLower(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); +} +#endif + +// ------------------------------ InterleaveWholeUpper + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { +#if HWY_TARGET <= HWY_AVX3_DL + const RebindToUnsigned du; + alignas(32) static constexpr uint8_t kIdx[32] = { + 16, 48, 17, 49, 18, 50, 19, 51, 20, 52, 21, 53, 22, 54, 23, 55, + 24, 56, 25, 57, 26, 58, 27, 59, 28, 60, 29, 61, 30, 62, 31, 63}; + return VFromD{_mm256_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; +#else + return ConcatUpperUpper(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); +#endif +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint16_t kIdx[16] = { + 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; + 
return BitCast( + d, VFromD{_mm256_permutex2var_epi16( + BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; + return VFromD{_mm256_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; + return VFromD{_mm256_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint64_t kIdx[4] = {2, 6, 3, 7}; + return VFromD{_mm256_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(32) static constexpr uint64_t kIdx[4] = {2, 6, 3, 7}; + return VFromD{_mm256_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; +} +#else // AVX2 +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + return ConcatUpperUpper(d, InterleaveUpper(d, a, b), InterleaveLower(a, b)); +} +#endif + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec256 DupEven(Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} +HWY_API Vec256 DupEven(Vec256 v) { + return Vec256{ + _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; +} + +template +HWY_API Vec256 DupEven(const Vec256 v) { + const DFromV d; + return InterleaveLower(d, v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec256 DupOdd(Vec256 v) { + return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} +HWY_API Vec256 DupOdd(Vec256 v) { + return Vec256{ + 
_mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; +} + +template +HWY_API Vec256 DupOdd(const Vec256 v) { + const DFromV d; + return InterleaveUpper(d, v, v); +} + +// ------------------------------ OddEven + +template +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + const DFromV d; + const Full256 d8; + const VFromD mask = + Dup128VecFromValues(d8, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, + 0, 0xFF, 0, 0xFF, 0); + return IfThenElse(MaskFromVec(BitCast(d, mask)), b, a); +} + +template +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm256_blend_epi16( + BitCast(du, a).raw, BitCast(du, b).raw, 0x55)}); +} + +#if HWY_HAVE_FLOAT16 +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + return Vec256{ + _mm256_mask_blend_ph(static_cast<__mmask16>(0x5555), a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x55)}; +} + +template +HWY_INLINE Vec256 OddEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x33)}; +} + +HWY_API Vec256 OddEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_blend_ps(a.raw, b.raw, 0x55)}; +} + +HWY_API Vec256 OddEven(Vec256 a, Vec256 b) { + return Vec256{_mm256_blend_pd(a.raw, b.raw, 5)}; +} + +// -------------------------- InterleaveEven + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_mask_shuffle_epi32( + a.raw, static_cast<__mmask8>(0xAA), b.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(2, 2, 0, 0)))}; +} +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_mask_shuffle_ps(a.raw, static_cast<__mmask8>(0xAA), + b.raw, b.raw, + _MM_SHUFFLE(2, 2, 0, 0))}; +} +#else +template +HWY_API VFromD InterleaveEven(D d, VFromD a, VFromD b) { + const RebindToFloat df; + const VFromD 
b2_b0_a2_a0{_mm256_shuffle_ps( + BitCast(df, a).raw, BitCast(df, b).raw, _MM_SHUFFLE(2, 0, 2, 0))}; + return BitCast( + d, VFromD{_mm256_shuffle_ps( + b2_b0_a2_a0.raw, b2_b0_a2_a0.raw, _MM_SHUFFLE(3, 1, 2, 0))}); +} +#endif + +// I64/U64/F64 InterleaveEven is generic for vector lengths >= 32 bytes +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return InterleaveLower(a, b); +} + +// -------------------------- InterleaveOdd + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_mask_shuffle_epi32( + b.raw, static_cast<__mmask8>(0x55), a.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(3, 3, 1, 1)))}; +} +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm256_mask_shuffle_ps(b.raw, static_cast<__mmask8>(0x55), + a.raw, a.raw, + _MM_SHUFFLE(3, 3, 1, 1))}; +} +#else +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + const RebindToFloat df; + const VFromD b3_b1_a3_a3{_mm256_shuffle_ps( + BitCast(df, a).raw, BitCast(df, b).raw, _MM_SHUFFLE(3, 1, 3, 1))}; + return BitCast( + d, VFromD{_mm256_shuffle_ps( + b3_b1_a3_a3.raw, b3_b1_a3_a3.raw, _MM_SHUFFLE(3, 1, 2, 0))}); +} +#endif + +// I64/U64/F64 InterleaveOdd is generic for vector lengths >= 32 bytes +template +HWY_API VFromD InterleaveOdd(D d, VFromD a, VFromD b) { + return InterleaveUpper(d, a, b); +} + +// ------------------------------ OddEvenBlocks + +template +Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm256_blend_epi32( + BitCast(du, odd).raw, BitCast(du, even).raw, 0xFu)}); +} + +HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + return Vec256{_mm256_blend_ps(odd.raw, even.raw, 0xFu)}; +} + +HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { + return Vec256{_mm256_blend_pd(odd.raw, even.raw, 0x3u)}; +} + +// ------------------------------ ReverseBlocks (SwapAdjacentBlocks) + 
+template +HWY_API VFromD ReverseBlocks(D /*d*/, VFromD v) { + return SwapAdjacentBlocks(v); +} + +// ------------------------------ TableLookupBytes (ZeroExtendVector) + +// Both full +template +HWY_API Vec256 TableLookupBytes(Vec256 bytes, Vec256 from) { + const DFromV d; + return BitCast(d, Vec256{_mm256_shuffle_epi8( + BitCast(Full256(), bytes).raw, + BitCast(Full256(), from).raw)}); +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(Vec256 bytes, Vec128 from) { + const Full256 di; + const Half dih; + // First expand to full 128, then 256. + const auto from_256 = ZeroExtendVector(di, Vec128{from.raw}); + const auto tbl_full = TableLookupBytes(bytes, from_256); + // Shrink to 128, then partial. + return Vec128{LowerHalf(dih, tbl_full).raw}; +} + +// Partial table vector +template +HWY_API Vec256 TableLookupBytes(Vec128 bytes, Vec256 from) { + const Full256 d; + // First expand to full 128, then 256. + const auto bytes_256 = ZeroExtendVector(d, Vec128{bytes.raw}); + return TableLookupBytes(bytes_256, from); +} + +// Partial both are handled by x86_128. 
+ +// ------------------------------ I8/U8 Broadcast (TableLookupBytes) + +template +HWY_API Vec256 Broadcast(const Vec256 v) { + static_assert(0 <= kLane && kLane < 16, "Invalid lane"); + return TableLookupBytes(v, Set(Full256(), static_cast(kLane))); +} + +// ------------------------------ Per4LaneBlockShuffle + +namespace detail { + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { + return BitCast(d, Vec256{_mm256_set_epi32( + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0), + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0))}); +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + return V{_mm256_shuffle_epi32(v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + return V{_mm256_shuffle_ps(v.raw, v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x44> /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + return ConcatLowerLower(d, v, v); +} + +template +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xEE> /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + const DFromV d; + return ConcatUpperUpper(d, v, v); +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, V v) { + return V{_mm256_permute4x64_epi64(v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<32> /*vect_size_tag*/, 
V v) { + return V{_mm256_permute4x64_pd(v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_INLINE V CombineShiftRightI32Lanes(V hi, V lo) { + const DFromV d; + const Repartition du32; + return BitCast(d, + Vec256{_mm256_alignr_epi32( + BitCast(du32, hi).raw, BitCast(du32, lo).raw, kI32Lanes)}); +} + +template +HWY_INLINE V CombineShiftRightI64Lanes(V hi, V lo) { + const DFromV d; + const Repartition du64; + return BitCast(d, + Vec256{_mm256_alignr_epi64( + BitCast(du64, hi).raw, BitCast(du64, lo).raw, kI64Lanes)}); +} + +template +HWY_INLINE V SlideUpI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 3, + "kI64Lanes must be between 0 and 3"); + const DFromV d; + return CombineShiftRightI64Lanes<4 - kI64Lanes>(v, Zero(d)); +} +#else // AVX2 +template )> +HWY_INLINE V SlideUpI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 3, + "kI64Lanes must be between 0 and 3"); + constexpr int kIdx0 = (-kI64Lanes) & 3; + constexpr int kIdx1 = (-kI64Lanes + 1) & 3; + constexpr int kIdx2 = (-kI64Lanes + 2) & 3; + constexpr int kIdx3 = (-kI64Lanes + 3) & 3; + constexpr int kIdx3210 = _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0); + constexpr int kBlendMask = (1 << (kI64Lanes * 2)) - 1; + + const DFromV d; + return V{_mm256_blend_epi32(_mm256_permute4x64_epi64(v.raw, kIdx3210), + Zero(d).raw, kBlendMask)}; +} + +template )> +HWY_INLINE V SlideUpI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 3, + "kI64Lanes must be between 0 and 3"); + constexpr int kIdx0 = (-kI64Lanes) & 3; + constexpr int kIdx1 = (-kI64Lanes + 1) & 3; + constexpr int kIdx2 = (-kI64Lanes + 2) & 3; + constexpr int kIdx3 = (-kI64Lanes + 3) & 3; + constexpr int kIdx3210 = _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0); + constexpr int kBlendMask = (1 << kI64Lanes) - 1; + + const DFromV d; + const Repartition dd; + return BitCast(d, Vec256{_mm256_blend_pd( + 
_mm256_permute4x64_pd(BitCast(dd, v).raw, kIdx3210), + Zero(dd).raw, kBlendMask)}); +} +#endif // HWY_TARGET <= HWY_AVX3 + +template HWY_AVX3) ? (1 << 2) : 0))> +HWY_INLINE VFromD TableLookupSlideUpLanes(D d, VFromD v, size_t amt) { + const Repartition du8; + + const auto idx_vec = + Iota(du8, static_cast(size_t{0} - amt * sizeof(TFromD))); + const Indices256> idx{idx_vec.raw}; + +#if HWY_TARGET <= HWY_AVX3_DL + return TwoTablesLookupLanes(v, Zero(d), idx); +#else + return TableLookupLanes(v, idx); +#endif +} + +template +HWY_INLINE VFromD TableLookupSlideUpLanes(D d, VFromD v, size_t amt) { + const RebindToUnsigned du; + using TU = TFromD; + + const auto idx = Iota(du, static_cast(size_t{0} - amt)); +#if HWY_TARGET <= HWY_AVX3 + const auto masked_idx = + And(idx, Set(du, static_cast(MaxLanes(d) * 2 - 1))); + return TwoTablesLookupLanes(v, Zero(d), IndicesFromVec(d, masked_idx)); +#else + const auto masked_idx = And(idx, Set(du, static_cast(MaxLanes(d) - 1))); + return IfThenElseZero(RebindMask(d, idx == masked_idx), + TableLookupLanes(v, IndicesFromVec(d, masked_idx))); +#endif +} + +#if HWY_TARGET > HWY_AVX3 +template +HWY_INLINE VFromD TableLookupSlideUpLanes(D d, VFromD v, size_t amt) { + const RepartitionToNarrow dn; + return BitCast(d, TableLookupSlideUpLanes(dn, BitCast(dn, v), amt * 2)); +} +#endif // HWY_TARGET > HWY_AVX3 + +} // namespace detail + +template +HWY_API VFromD SlideUpBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 1, + "kBlocks must be between 0 and 1"); + return (kBlocks == 1) ? 
ConcatLowerLower(d, v, Zero(d)) : v; +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + if (__builtin_constant_p(amt)) { + const auto v_lo = ConcatLowerLower(d, v, Zero(d)); + switch (amt * sizeof(TFromD)) { + case 0: + return v; + case 1: + return CombineShiftRightBytes<15>(d, v, v_lo); + case 2: + return CombineShiftRightBytes<14>(d, v, v_lo); + case 3: + return CombineShiftRightBytes<13>(d, v, v_lo); + case 4: +#if HWY_TARGET <= HWY_AVX3 + return detail::CombineShiftRightI32Lanes<7>(v, Zero(d)); +#else + return CombineShiftRightBytes<12>(d, v, v_lo); +#endif + case 5: + return CombineShiftRightBytes<11>(d, v, v_lo); + case 6: + return CombineShiftRightBytes<10>(d, v, v_lo); + case 7: + return CombineShiftRightBytes<9>(d, v, v_lo); + case 8: + return detail::SlideUpI64Lanes<1>(v); + case 9: + return CombineShiftRightBytes<7>(d, v, v_lo); + case 10: + return CombineShiftRightBytes<6>(d, v, v_lo); + case 11: + return CombineShiftRightBytes<5>(d, v, v_lo); + case 12: +#if HWY_TARGET <= HWY_AVX3 + return detail::CombineShiftRightI32Lanes<5>(v, Zero(d)); +#else + return CombineShiftRightBytes<4>(d, v, v_lo); +#endif + case 13: + return CombineShiftRightBytes<3>(d, v, v_lo); + case 14: + return CombineShiftRightBytes<2>(d, v, v_lo); + case 15: + return CombineShiftRightBytes<1>(d, v, v_lo); + case 16: + return ConcatLowerLower(d, v, Zero(d)); +#if HWY_TARGET <= HWY_AVX3 + case 20: + return detail::CombineShiftRightI32Lanes<3>(v, Zero(d)); +#endif + case 24: + return detail::SlideUpI64Lanes<3>(v); +#if HWY_TARGET <= HWY_AVX3 + case 28: + return detail::CombineShiftRightI32Lanes<1>(v, Zero(d)); +#endif + } + } + + if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) { + const Half dh; + return Combine(d, SlideUpLanes(dh, LowerHalf(dh, v), amt - kLanesPerBlock), + Zero(dh)); + } +#endif + + return 
detail::TableLookupSlideUpLanes(d, v, amt); +} + +// ------------------------------ Slide1Up + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + const auto v_lo = ConcatLowerLower(d, v, Zero(d)); + return CombineShiftRightBytes<15>(d, v, v_lo); +} + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + const auto v_lo = ConcatLowerLower(d, v, Zero(d)); + return CombineShiftRightBytes<14>(d, v, v_lo); +} + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { +#if HWY_TARGET <= HWY_AVX3 + return detail::CombineShiftRightI32Lanes<7>(v, Zero(d)); +#else + const auto v_lo = ConcatLowerLower(d, v, Zero(d)); + return CombineShiftRightBytes<12>(d, v, v_lo); +#endif +} + +template +HWY_API VFromD Slide1Up(D /*d*/, VFromD v) { + return detail::SlideUpI64Lanes<1>(v); +} + +// ------------------------------ SlideDownLanes + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_INLINE V SlideDownI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 3, + "kI64Lanes must be between 0 and 3"); + const DFromV d; + return CombineShiftRightI64Lanes(Zero(d), v); +} +#else // AVX2 +template )> +HWY_INLINE V SlideDownI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 3, + "kI64Lanes must be between 0 and 3"); + constexpr int kIdx1 = (kI64Lanes + 1) & 3; + constexpr int kIdx2 = (kI64Lanes + 2) & 3; + constexpr int kIdx3 = (kI64Lanes + 3) & 3; + constexpr int kIdx3210 = _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kI64Lanes); + constexpr int kBlendMask = + static_cast((0xFFu << ((4 - kI64Lanes) * 2)) & 0xFFu); + + const DFromV d; + return V{_mm256_blend_epi32(_mm256_permute4x64_epi64(v.raw, kIdx3210), + Zero(d).raw, kBlendMask)}; +} + +template )> +HWY_INLINE V SlideDownI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 3, + "kI64Lanes must be between 0 and 3"); + constexpr int kIdx1 = (kI64Lanes + 1) & 3; + constexpr int kIdx2 = (kI64Lanes + 2) & 3; + constexpr int kIdx3 = (kI64Lanes + 3) & 3; + constexpr int kIdx3210 = _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, 
kI64Lanes); + constexpr int kBlendMask = (0x0F << (4 - kI64Lanes)) & 0x0F; + + const DFromV d; + const Repartition dd; + return BitCast(d, Vec256{_mm256_blend_pd( + _mm256_permute4x64_pd(BitCast(dd, v).raw, kIdx3210), + Zero(dd).raw, kBlendMask)}); +} +#endif // HWY_TARGET <= HWY_AVX3 + +template HWY_AVX3) ? (1 << 2) : 0))> +HWY_INLINE VFromD TableLookupSlideDownLanes(D d, VFromD v, size_t amt) { + const Repartition du8; + + auto idx_vec = Iota(du8, static_cast(amt * sizeof(TFromD))); + +#if HWY_TARGET <= HWY_AVX3_DL + const auto result_mask = idx_vec < Set(du8, uint8_t{32}); + return VFromD{ + _mm256_maskz_permutexvar_epi8(result_mask.raw, idx_vec.raw, v.raw)}; +#else + const RebindToSigned di8; + idx_vec = + Or(idx_vec, BitCast(du8, VecFromMask(di8, BitCast(di8, idx_vec) > + Set(di8, int8_t{31})))); + return TableLookupLanes(v, Indices256>{idx_vec.raw}); +#endif +} + +template +HWY_INLINE VFromD TableLookupSlideDownLanes(D d, VFromD v, size_t amt) { + const RebindToUnsigned du; + using TU = TFromD; + + const auto idx = Iota(du, static_cast(amt)); + const auto masked_idx = And(idx, Set(du, static_cast(MaxLanes(d) - 1))); + + return IfThenElseZero(RebindMask(d, idx == masked_idx), + TableLookupLanes(v, IndicesFromVec(d, masked_idx))); +} + +#if HWY_TARGET > HWY_AVX3 +template +HWY_INLINE VFromD TableLookupSlideDownLanes(D d, VFromD v, size_t amt) { + const RepartitionToNarrow dn; + return BitCast(d, TableLookupSlideDownLanes(dn, BitCast(dn, v), amt * 2)); +} +#endif // HWY_TARGET > HWY_AVX3 + +} // namespace detail + +template +HWY_API VFromD SlideDownBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 1, + "kBlocks must be between 0 and 1"); + const Half dh; + return (kBlocks == 1) ? 
ZeroExtendVector(d, UpperHalf(dh, v)) : v; +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD); + const Half dh; + if (__builtin_constant_p(amt)) { + const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); + switch (amt * sizeof(TFromD)) { + case 0: + return v; + case 1: + return CombineShiftRightBytes<1>(d, v_hi, v); + case 2: + return CombineShiftRightBytes<2>(d, v_hi, v); + case 3: + return CombineShiftRightBytes<3>(d, v_hi, v); + case 4: +#if HWY_TARGET <= HWY_AVX3 + return detail::CombineShiftRightI32Lanes<1>(Zero(d), v); +#else + return CombineShiftRightBytes<4>(d, v_hi, v); +#endif + case 5: + return CombineShiftRightBytes<5>(d, v_hi, v); + case 6: + return CombineShiftRightBytes<6>(d, v_hi, v); + case 7: + return CombineShiftRightBytes<7>(d, v_hi, v); + case 8: + return detail::SlideDownI64Lanes<1>(v); + case 9: + return CombineShiftRightBytes<9>(d, v_hi, v); + case 10: + return CombineShiftRightBytes<10>(d, v_hi, v); + case 11: + return CombineShiftRightBytes<11>(d, v_hi, v); + case 12: +#if HWY_TARGET <= HWY_AVX3 + return detail::CombineShiftRightI32Lanes<3>(Zero(d), v); +#else + return CombineShiftRightBytes<12>(d, v_hi, v); +#endif + case 13: + return CombineShiftRightBytes<13>(d, v_hi, v); + case 14: + return CombineShiftRightBytes<14>(d, v_hi, v); + case 15: + return CombineShiftRightBytes<15>(d, v_hi, v); + case 16: + return v_hi; +#if HWY_TARGET <= HWY_AVX3 + case 20: + return detail::CombineShiftRightI32Lanes<5>(Zero(d), v); +#endif + case 24: + return detail::SlideDownI64Lanes<3>(v); +#if HWY_TARGET <= HWY_AVX3 + case 28: + return detail::CombineShiftRightI32Lanes<7>(Zero(d), v); +#endif + } + } + + if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) { + return ZeroExtendVector( + d, SlideDownLanes(dh, UpperHalf(dh, v), amt - kLanesPerBlock)); + } +#endif + + return 
detail::TableLookupSlideDownLanes(d, v, amt); +} + +// ------------------------------ Slide1Down + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + const Half dh; + const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); + return CombineShiftRightBytes<1>(d, v_hi, v); +} + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + const Half dh; + const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); + return CombineShiftRightBytes<2>(d, v_hi, v); +} + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { +#if HWY_TARGET <= HWY_AVX3 + return detail::CombineShiftRightI32Lanes<1>(Zero(d), v); +#else + const Half dh; + const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v)); + return CombineShiftRightBytes<4>(d, v_hi, v); +#endif +} + +template +HWY_API VFromD Slide1Down(D /*d*/, VFromD v) { + return detail::SlideDownI64Lanes<1>(v); +} + +// ------------------------------ Shl (Mul, ZipLower) + +namespace detail { + +#if HWY_TARGET > HWY_AVX3 && !HWY_IDE // AVX2 or older +template +HWY_INLINE V AVX2ShlU16Vec256(V v, V bits) { + const DFromV d; + const Half dh; + const Rebind du32; + + const auto lo_shl_result = PromoteTo(du32, LowerHalf(dh, v)) + << PromoteTo(du32, LowerHalf(dh, bits)); + const auto hi_shl_result = PromoteTo(du32, UpperHalf(dh, v)) + << PromoteTo(du32, UpperHalf(dh, bits)); + return ConcatEven(d, BitCast(d, hi_shl_result), BitCast(d, lo_shl_result)); +} +#endif + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 || HWY_IDE + return Vec256{_mm256_sllv_epi16(v.raw, bits.raw)}; +#else + return AVX2ShlU16Vec256(v, bits); +#endif +} + +// 8-bit: may use the Shl overload for uint16_t. 
+HWY_API Vec256 Shl(hwy::UnsignedTag tag, Vec256 v, + Vec256 bits) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3_DL + (void)tag; + // masks[i] = 0xFF >> i + const VFromD masks = + Dup128VecFromValues(d, 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01, 0, + 0, 0, 0, 0, 0, 0, 0); + // kShl[i] = 1 << i + const VFromD shl = Dup128VecFromValues( + d, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 0, 0, 0, 0, 0, 0, 0, 0); + v = And(v, TableLookupBytes(masks, bits)); + const VFromD mul = TableLookupBytes(shl, bits); + return VFromD{_mm256_gf2p8mul_epi8(v.raw, mul.raw)}; +#else + const Repartition dw; + using VW = VFromD; + const VW even_mask = Set(dw, 0x00FF); + const VW odd_mask = Set(dw, 0xFF00); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + // Shift even lanes in-place + const VW evens = Shl(tag, vw, And(bits16, even_mask)); + const VW odds = Shl(tag, And(vw, odd_mask), ShiftRight<8>(bits16)); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +#endif +} + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { + return Vec256{_mm256_sllv_epi32(v.raw, bits.raw)}; +} + +HWY_INLINE Vec256 Shl(hwy::UnsignedTag /*tag*/, Vec256 v, + Vec256 bits) { + return Vec256{_mm256_sllv_epi64(v.raw, bits.raw)}; +} + +template +HWY_INLINE Vec256 Shl(hwy::SignedTag /*tag*/, Vec256 v, Vec256 bits) { + // Signed left shifts are the same as unsigned. 
+ const Full256 di; + const Full256> du; + return BitCast(di, + Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits))); +} + +} // namespace detail + +template +HWY_API Vec256 operator<<(Vec256 v, Vec256 bits) { + return detail::Shl(hwy::TypeTag(), v, bits); +} + +// ------------------------------ Shr (MulHigh, IfThenElse, Not) + +#if HWY_TARGET > HWY_AVX3 // AVX2 +namespace detail { + +template +HWY_INLINE V AVX2ShrU16Vec256(V v, V bits) { + const DFromV d; + const Half dh; + const Rebind di32; + const Rebind du32; + + const auto lo_shr_result = + PromoteTo(du32, LowerHalf(dh, v)) >> PromoteTo(du32, LowerHalf(dh, bits)); + const auto hi_shr_result = + PromoteTo(du32, UpperHalf(dh, v)) >> PromoteTo(du32, UpperHalf(dh, bits)); + return OrderedDemote2To(d, BitCast(di32, lo_shr_result), + BitCast(di32, hi_shr_result)); +} + +} // namespace detail +#endif + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_srlv_epi16(v.raw, bits.raw)}; +#else + return detail::AVX2ShrU16Vec256(v, bits); +#endif +} + +// 8-bit uses 16-bit shifts. 
+HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + const DFromV d; + const RepartitionToWide dw; + using VW = VFromD; + const VW mask = Set(dw, 0x00FF); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + const VW evens = And(vw, mask) >> And(bits16, mask); + // Shift odd lanes in-place + const VW odds = vw >> ShiftRight<8>(bits16); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{_mm256_srlv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{_mm256_srlv_epi64(v.raw, bits.raw)}; +} + +#if HWY_TARGET > HWY_AVX3 // AVX2 +namespace detail { + +template +HWY_INLINE V AVX2ShrI16Vec256(V v, V bits) { + const DFromV d; + const Half dh; + const Rebind di32; + + const auto lo_shr_result = + PromoteTo(di32, LowerHalf(dh, v)) >> PromoteTo(di32, LowerHalf(dh, bits)); + const auto hi_shr_result = + PromoteTo(di32, UpperHalf(dh, v)) >> PromoteTo(di32, UpperHalf(dh, bits)); + return OrderedDemote2To(d, lo_shr_result, hi_shr_result); +} + +} // namespace detail +#endif + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_srav_epi16(v.raw, bits.raw)}; +#else + return detail::AVX2ShrI16Vec256(v, bits); +#endif +} + +// 8-bit uses 16-bit shifts. 
+HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + const DFromV d; + const RepartitionToWide dw; + const RebindToUnsigned dw_u; + using VW = VFromD; + const VW mask = Set(dw, 0x00FF); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + const VW evens = ShiftRight<8>(ShiftLeft<8>(vw)) >> And(bits16, mask); + // Shift odd lanes in-place + const VW odds = vw >> BitCast(dw, ShiftRight<8>(BitCast(dw_u, bits16))); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { + return Vec256{_mm256_srav_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec256 operator>>(Vec256 v, Vec256 bits) { +#if HWY_TARGET <= HWY_AVX3 + return Vec256{_mm256_srav_epi64(v.raw, bits.raw)}; +#else + const DFromV d; + return detail::SignedShr(d, v, bits); +#endif +} + +// ------------------------------ WidenMulPairwiseAdd + +#if HWY_NATIVE_DOT_BF16 + +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return VFromD{_mm256_dpbf16_ps(Zero(df).raw, + reinterpret_cast<__m256bh>(a.raw), + reinterpret_cast<__m256bh>(b.raw))}; +} + +#endif // HWY_NATIVE_DOT_BF16 + +template +HWY_API VFromD WidenMulPairwiseAdd(D /*d32*/, Vec256 a, + Vec256 b) { + return VFromD{_mm256_madd_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SatWidenMulPairwiseAdd + +template +HWY_API VFromD SatWidenMulPairwiseAdd( + DI16 /* tag */, VFromD> a, + VFromD> b) { + return VFromD{_mm256_maddubs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SatWidenMulPairwiseAccumulate + +#if HWY_TARGET <= HWY_AVX3_DL +template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 /* tag */, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm256_dpwssds_epi32(sum.raw, a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ ReorderWidenMulAccumulate + +#if HWY_NATIVE_DOT_BF16 +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DF /*df*/, VBF a, VBF b, + const VFromD sum0, + VFromD& 
/*sum1*/) { + return VFromD{_mm256_dpbf16_ps(sum0.raw, + reinterpret_cast<__m256bh>(a.raw), + reinterpret_cast<__m256bh>(b.raw))}; +} +#endif // HWY_NATIVE_DOT_BF16 + +template +HWY_API VFromD ReorderWidenMulAccumulate(D d, Vec256 a, + Vec256 b, + const VFromD sum0, + VFromD& /*sum1*/) { + (void)d; +#if HWY_TARGET <= HWY_AVX3_DL + return VFromD{_mm256_dpwssd_epi32(sum0.raw, a.raw, b.raw)}; +#else + return sum0 + WidenMulPairwiseAdd(d, a, b); +#endif +} + +// ------------------------------ SumOfMulQuadAccumulate + +#if HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 /*di32*/, VFromD> a_u, + VFromD> b_i, VFromD sum) { + return VFromD{_mm256_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)}; +} + +#if HWY_X86_HAVE_AVX10_2_OPS +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD{_mm256_dpbssd_epi32(sum.raw, a.raw, b.raw)}; +} + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm256_dpbuud_epi32(sum.raw, a.raw, b.raw)}; +} +#endif // HWY_X86_HAVE_AVX10_2_OPS + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtps_pd(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepi32_pd(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API Vec256 PromoteTo(D /* tag */, Vec128 v) { + return Vec256{_mm256_cvtepu32_pd(v.raw)}; +} +#endif + +// Unsigned: zero-extend. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then Zip* would be faster. 
+template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepu8_epi16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepu8_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepu16_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepu32_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec64 v) { + return VFromD{_mm256_cvtepu16_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec32 v) { + return VFromD{_mm256_cvtepu8_epi64(v.raw)}; +} + +// Signed: replicate sign bit. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by +// signed shift would be faster. +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepi8_epi16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepi8_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepi16_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm256_cvtepi32_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec64 v) { + return VFromD{_mm256_cvtepi16_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec32 v) { + return VFromD{_mm256_cvtepi8_epi64(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD PromoteInRangeTo(D /*di64*/, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm256_cvtts_ps_epi64(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float 
GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi64x( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttps2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttps_epi64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} +template +HWY_API VFromD PromoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm256_cvtts_ps_epu64(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an uint64_t +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(16))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi64x( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttps2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttps_epu64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ PromoteEvenTo/PromoteOddTo +#if HWY_TARGET > HWY_AVX3 +namespace detail { + +// I32->I64 PromoteEvenTo/PromoteOddTo + 
+template +HWY_INLINE VFromD PromoteEvenTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec256 v) { + return BitCast(d_to, OddEven(DupEven(BroadcastSignBit(v)), v)); +} + +template +HWY_INLINE VFromD PromoteOddTo(hwy::SignedTag /*to_type_tag*/, + hwy::SizeTag<8> /*to_lane_size_tag*/, + hwy::SignedTag /*from_type_tag*/, D d_to, + Vec256 v) { + return BitCast(d_to, OddEven(BroadcastSignBit(v), DupOdd(v))); +} + +} // namespace detail +#endif + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __m256i u16 = _mm256_packus_epi32(v.raw, v.raw); + // Concatenating lower halves of both 128-bit blocks afterward is more + // efficient than an extra input with low block = high block of v. + return VFromD{_mm256_castsi256_si128(_mm256_permute4x64_epi64(u16, 0x88))}; +} + +template +HWY_API VFromD DemoteTo(D dn, Vec256 v) { + const DFromV d; + const RebindToSigned di; + return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFFFFFu)))); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __m256i i16 = _mm256_packs_epi32(v.raw, v.raw); + return VFromD{_mm256_castsi256_si128(_mm256_permute4x64_epi64(i16, 0x88))}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __m256i i16_blocks = _mm256_packs_epi32(v.raw, v.raw); + // Concatenate lower 64 bits of each 128-bit block + const __m256i i16_concat = _mm256_permute4x64_epi64(i16_blocks, 0x88); + const __m128i i16 = _mm256_castsi256_si128(i16_concat); + return VFromD{_mm_packus_epi16(i16, i16)}; +} + +template +HWY_API VFromD DemoteTo(D dn, Vec256 v) { +#if HWY_TARGET <= HWY_AVX3 + (void)dn; + return VFromD{_mm256_cvtusepi32_epi8(v.raw)}; +#else + const DFromV d; + const RebindToSigned di; + return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFFFFFu)))); +#endif +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + 
const __m256i u8 = _mm256_packus_epi16(v.raw, v.raw); + return VFromD{_mm256_castsi256_si128(_mm256_permute4x64_epi64(u8, 0x88))}; +} + +template +HWY_API VFromD DemoteTo(D dn, Vec256 v) { + const DFromV d; + const RebindToSigned di; + return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFu)))); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __m256i i16_blocks = _mm256_packs_epi32(v.raw, v.raw); + // Concatenate lower 64 bits of each 128-bit block + const __m256i i16_concat = _mm256_permute4x64_epi64(i16_blocks, 0x88); + const __m128i i16 = _mm256_castsi256_si128(i16_concat); + return VFromD{_mm_packs_epi16(i16, i16)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __m256i i8 = _mm256_packs_epi16(v.raw, v.raw); + return VFromD{_mm256_castsi256_si128(_mm256_permute4x64_epi64(i8, 0x88))}; +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtsepi64_epi32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtsepi64_epi16(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtsepi64_epi8(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; + return VFromD{_mm256_maskz_cvtusepi64_epi32(non_neg_mask, v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; + return VFromD{_mm256_maskz_cvtusepi64_epi16(non_neg_mask, v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + const __mmask8 non_neg_mask = detail::UnmaskedNot(MaskFromVec(v)).raw; + return VFromD{_mm256_maskz_cvtusepi64_epi8(non_neg_mask, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtusepi64_epi32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 
v) { + return VFromD{_mm256_cvtusepi64_epi16(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtusepi64_epi8(v.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3 + +#ifndef HWY_DISABLE_F16C + +// Avoid "value of intrinsic immediate argument '8' is out of range '0 - 7'". +// 8 is the correct value of _MM_FROUND_NO_EXC, which is allowed here. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wsign-conversion") + +template +HWY_API VFromD DemoteTo(D df16, Vec256 v) { + const RebindToUnsigned du16; + return BitCast( + df16, VFromD{_mm256_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}); +} + +HWY_DIAGNOSTICS(pop) + +#endif // HWY_DISABLE_F16C + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD DemoteTo(D /*df16*/, Vec256 v) { + return VFromD{_mm256_cvtpd_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +#if HWY_AVX3_HAVE_F32_TO_BF16C +template +HWY_API VFromD DemoteTo(D /*dbf16*/, Vec256 v) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m128i raw_result; + __asm__("vcvtneps2bf16 %1, %0" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else + // The _mm256_cvtneps_pbh intrinsic returns a __m128bh vector that needs to be + // bit casted to a __m128i vector + return VFromD{detail::BitCastToInteger(_mm256_cvtneps_pbh(v.raw))}; +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D /*dbf16*/, Vec256 a, + Vec256 b) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m256i raw_result; + __asm__("vcvtne2ps2bf16 %2, %1, %0" + : "=v"(raw_result) + : "v"(b.raw), "v"(a.raw)); + return VFromD{raw_result}; +#else + // The _mm256_cvtne2ps_pbh intrinsic returns a __m256bh vector that needs to + // be bit casted to a __m256i vector + return VFromD{detail::BitCastToInteger(_mm256_cvtne2ps_pbh(b.raw, a.raw))}; +#endif +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +template +HWY_API VFromD 
ReorderDemote2To(D /*d16*/, Vec256 a, + Vec256 b) { + return VFromD{_mm256_packs_epi32(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /*d16*/, Vec256 a, + Vec256 b) { + return VFromD{_mm256_packus_epi32(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec256 a, + Vec256 b) { + const DFromV d; + const RebindToSigned di; + const auto max_i32 = Set(d, 0x7FFFFFFFu); + return ReorderDemote2To(dn, BitCast(di, Min(a, max_i32)), + BitCast(di, Min(b, max_i32))); +} + +template +HWY_API VFromD ReorderDemote2To(D /*d16*/, Vec256 a, + Vec256 b) { + return VFromD{_mm256_packs_epi16(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /*d16*/, Vec256 a, + Vec256 b) { + return VFromD{_mm256_packus_epi16(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec256 a, + Vec256 b) { + const DFromV d; + const RebindToSigned di; + const auto max_i16 = Set(d, 0x7FFFu); + return ReorderDemote2To(dn, BitCast(di, Min(a, max_i16)), + BitCast(di, Min(b, max_i16))); +} + +#if HWY_TARGET > HWY_AVX3 +template +HWY_API Vec256 ReorderDemote2To(D dn, Vec256 a, + Vec256 b) { + const DFromV di64; + const RebindToUnsigned du64; + const Half dnh; + const Repartition dn_f; + + // Negative values are saturated by first saturating their bitwise inverse + // and then inverting the saturation result + const auto invert_mask_a = BitCast(du64, BroadcastSignBit(a)); + const auto invert_mask_b = BitCast(du64, BroadcastSignBit(b)); + const auto saturated_a = Xor( + invert_mask_a, + detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_a, BitCast(du64, a)))); + const auto saturated_b = Xor( + invert_mask_b, + detail::DemoteFromU64Saturate(dnh, Xor(invert_mask_b, BitCast(du64, b)))); + + return BitCast(dn, + Vec256{_mm256_shuffle_ps(BitCast(dn_f, saturated_a).raw, + BitCast(dn_f, saturated_b).raw, + _MM_SHUFFLE(2, 0, 2, 0))}); +} + +template +HWY_API Vec256 ReorderDemote2To(D dn, Vec256 a, + Vec256 b) { + const DFromV di64; + const 
RebindToUnsigned du64; + const Half dnh; + const Repartition dn_f; + + const auto saturated_a = detail::DemoteFromU64Saturate( + dnh, BitCast(du64, AndNot(BroadcastSignBit(a), a))); + const auto saturated_b = detail::DemoteFromU64Saturate( + dnh, BitCast(du64, AndNot(BroadcastSignBit(b), b))); + + return BitCast(dn, + Vec256{_mm256_shuffle_ps(BitCast(dn_f, saturated_a).raw, + BitCast(dn_f, saturated_b).raw, + _MM_SHUFFLE(2, 0, 2, 0))}); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec256 a, + Vec256 b) { + const Half dnh; + const Repartition dn_f; + + const auto saturated_a = detail::DemoteFromU64Saturate(dnh, a); + const auto saturated_b = detail::DemoteFromU64Saturate(dnh, b); + + return BitCast(dn, + Vec256{_mm256_shuffle_ps(BitCast(dn_f, saturated_a).raw, + BitCast(dn_f, saturated_b).raw, + _MM_SHUFFLE(2, 0, 2, 0))}); +} +#endif // HWY_TARGET > HWY_AVX3 + +template ), + HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2), + HWY_IF_T_SIZE_ONE_OF_V(V, + (1 << 1) | (1 << 2) | (1 << 4) | + ((HWY_TARGET > HWY_AVX3) ? 
(1 << 8) : 0))> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return VFromD{_mm256_permute4x64_epi64(ReorderDemote2To(d, a, b).raw, + _MM_SHUFFLE(3, 1, 2, 0))}; +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const Half dnh; + return Combine(dn, DemoteTo(dnh, b), DemoteTo(dnh, a)); +} + +template +HWY_API VFromD ReorderDemote2To(D dn, VFromD> a, + VFromD> b) { + const Half dnh; + return Combine(dn, DemoteTo(dnh, b), DemoteTo(dnh, a)); +} + +template ), + HWY_IF_V_SIZE_GT_D(D, 16), class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2), + HWY_IF_T_SIZE_V(V, 8)> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + return ReorderDemote2To(d, a, b); +} +#endif + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtpd_ps(v.raw)}; +} + +template +HWY_API VFromD DemoteInRangeTo(D /* tag */, Vec256 v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm256_cvtts_pd_epi32(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm256_cvttpd_epi32(v.raw)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 
+template +HWY_API VFromD DemoteInRangeTo(D /* tag */, Vec256 v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm256_cvtts_pd_epu32(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttpd_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues( + D(), detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("vcvttpd2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm256_cvttpd_epu32(v.raw)}; +#endif +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm256_cvtepi64_ps(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm256_cvtepu64_ps(v.raw)}; +} +#endif + +// For already range-limited input [0, 255]. +HWY_API Vec128 U8FromU32(const Vec256 v) { + const Full256 d32; + const Full64 d8; + alignas(32) static constexpr uint32_t k8From32[8] = { + 0x0C080400u, ~0u, ~0u, ~0u, ~0u, 0x0C080400u, ~0u, ~0u}; + // Place first four bytes in lo[0], remaining 4 in hi[1]. + const auto quad = TableLookupBytes(v, Load(d32, k8From32)); + // Interleave both quadruplets - OR instead of unpack reduces port5 pressure. + const auto lo = LowerHalf(quad); + const auto hi = UpperHalf(Half(), quad); + return BitCast(d8, LowerHalf(lo | hi)); +} + +// ------------------------------ Truncations + +namespace detail { + +// LO and HI each hold four indices of bytes within a 128-bit block. 
+template +HWY_INLINE Vec128 LookupAndConcatHalves(Vec256 v) { + const Full256 d32; + +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) static constexpr uint32_t kMap[8] = { + LO, HI, 0x10101010 + LO, 0x10101010 + HI, 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(Load(d32, kMap).raw, v.raw); +#else + alignas(32) static constexpr uint32_t kMap[8] = {LO, HI, ~0u, ~0u, + ~0u, ~0u, LO, HI}; + const auto quad = TableLookupBytes(v, Load(d32, kMap)); + const auto result = _mm256_permute4x64_epi64(quad.raw, 0xCC); + // Possible alternative: + // const auto lo = LowerHalf(quad); + // const auto hi = UpperHalf(Half(), quad); + // const auto result = lo | hi; +#endif + + return Vec128{_mm256_castsi256_si128(result)}; +} + +// LO and HI each hold two indices of bytes within a 128-bit block. +template +HWY_INLINE Vec128 LookupAndConcatQuarters(Vec256 v) { + const Full256 d16; + +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) static constexpr uint16_t kMap[16] = { + LO, HI, 0x1010 + LO, 0x1010 + HI, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(Load(d16, kMap).raw, v.raw); + return LowerHalf(Vec128{_mm256_castsi256_si128(result)}); +#else + constexpr uint16_t ff = static_cast(~0u); + alignas(32) static constexpr uint16_t kMap[16] = { + LO, ff, HI, ff, ff, ff, ff, ff, ff, ff, ff, ff, LO, ff, HI, ff}; + const auto quad = TableLookupBytes(v, Load(d16, kMap)); + const auto mixed = _mm256_permute4x64_epi64(quad.raw, 0xCC); + const auto half = _mm256_castsi256_si128(mixed); + return LowerHalf(Vec128{_mm_packus_epi32(half, half)}); +#endif +} + +} // namespace detail + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const Full256 d32; +#if HWY_TARGET <= HWY_AVX3_DL + alignas(32) static constexpr uint32_t kMap[8] = {0x18100800u, 0, 0, 0, + 0, 0, 0, 0}; + const auto result = _mm256_permutexvar_epi8(Load(d32, kMap).raw, v.raw); + return LowerHalf(LowerHalf(LowerHalf(Vec256{result}))); +#else + alignas(32) static constexpr uint32_t 
kMap[8] = {0xFFFF0800u, ~0u, ~0u, ~0u, + 0x0800FFFFu, ~0u, ~0u, ~0u}; + const auto quad = TableLookupBytes(v, Load(d32, kMap)); + const auto lo = LowerHalf(quad); + const auto hi = UpperHalf(Half(), quad); + const auto result = lo | hi; + return LowerHalf(LowerHalf(Vec128{result.raw})); +#endif +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const auto result = detail::LookupAndConcatQuarters<0x100, 0x908>(v); + return VFromD{result.raw}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const Full256 d32; + alignas(32) static constexpr uint32_t kEven[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto v32 = + TableLookupLanes(BitCast(d32, v), SetTableIndices(d32, kEven)); + return LowerHalf(Vec256{v32.raw}); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const auto full = detail::LookupAndConcatQuarters<0x400, 0xC08>(v); + return VFromD{full.raw}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const auto full = detail::LookupAndConcatHalves<0x05040100, 0x0D0C0908>(v); + return VFromD{full.raw}; +} + +template +HWY_API VFromD TruncateTo(D /* tag */, Vec256 v) { + const auto full = detail::LookupAndConcatHalves<0x06040200, 0x0E0C0A08>(v); + return VFromD{full.raw}; +} + +// ------------------------------ Integer <=> fp (ShiftRight, OddEven) + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtepu16_ph(v.raw)}; +} +template +HWY_API VFromD ConvertTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtepi16_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD ConvertTo(D /* tag */, Vec256 v) { + return VFromD{_mm256_cvtepi32_ps(v.raw)}; +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD ConvertTo(D /*df*/, Vec256 v) { + return VFromD{_mm256_cvtepu32_ps(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /*dd*/, Vec256 v) { + return VFromD{_mm256_cvtepi64_pd(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D 
/*dd*/, Vec256 v) { + return VFromD{_mm256_cvtepu64_pd(v.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3 + +// Truncates (rounds toward zero). + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertInRangeTo(D /*d*/, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttph_epi16 with GCC if any + // values of v[i] are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi16( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]), + detail::X86ConvertScalarFromFloat(raw_v[8]), + detail::X86ConvertScalarFromFloat(raw_v[9]), + detail::X86ConvertScalarFromFloat(raw_v[10]), + detail::X86ConvertScalarFromFloat(raw_v[11]), + detail::X86ConvertScalarFromFloat(raw_v[12]), + detail::X86ConvertScalarFromFloat(raw_v[13]), + detail::X86ConvertScalarFromFloat(raw_v[14]), + detail::X86ConvertScalarFromFloat(raw_v[15]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // HWY_COMPILER_GCC_ACTUAL < 1200 + return VFromD{_mm256_cvttph_epi16(v.raw)}; +#endif +} +template +HWY_API VFromD ConvertInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttph_epu16 with GCC if any + // values of v[i] are not within the range 
of an uint16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi16( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[8])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[9])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[10])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[11])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[12])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[13])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[14])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[15])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttph2uw {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // HWY_COMPILER_GCC_ACTUAL < 1200 + return VFromD{_mm256_cvttph_epu16(v.raw)}; +#endif +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD ConvertInRangeTo(D /*d*/, Vec256 v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm256_cvtts_ps_epi32(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttps_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + 
+#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi32( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm256_cvttps_epi32(v.raw)}; +#endif +} + +#if HWY_TARGET <= HWY_AVX3 +template +HWY_API VFromD ConvertInRangeTo(D /*di*/, Vec256 v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm256_cvtts_pd_epi64(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi64x( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return 
VFromD{_mm256_cvttpd_epi64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} +template +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm256_cvtts_ps_epu32(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttps_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi32( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttps2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttps_epu32(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} +template +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm256_cvtts_pd_epu64(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvttpd_epu64 with GCC if any + // values of v[i] are not within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto 
raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi64x( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttpd2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvttpd_epu64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} +#endif // HWY_TARGET <= HWY_AVX3 + +template +static HWY_INLINE VFromD NearestIntInRange(DI, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvtps_epi32 if any values of + // v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm256_setr_epi32(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvtps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvtps_epi32(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} + +#if HWY_HAVE_FLOAT16 +template +static HWY_INLINE VFromD NearestIntInRange(DI /*d*/, Vec256 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined 
behavior in _mm256_cvtph_epi16 with GCC if any + // values of v[i] are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm256_setr_epi16(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]), + detail::X86ScalarNearestInt(raw_v[8]), + detail::X86ScalarNearestInt(raw_v[9]), + detail::X86ScalarNearestInt(raw_v[10]), + detail::X86ScalarNearestInt(raw_v[11]), + detail::X86ScalarNearestInt(raw_v[12]), + detail::X86ScalarNearestInt(raw_v[13]), + detail::X86ScalarNearestInt(raw_v[14]), + detail::X86ScalarNearestInt(raw_v[15]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvtph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvtph_epi16(v.raw)}; +#endif +} +#endif + +#if HWY_TARGET <= HWY_AVX3 +template +static HWY_INLINE VFromD NearestIntInRange(DI, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvtpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm256_setr_epi64x(detail::X86ScalarNearestInt(raw_v[0]), + 
detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]))}; + } +#endif + + __m256i raw_result; + __asm__("vcvtpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvtpd_epi64(v.raw)}; +#endif // HWY_COMPILER_GCC_ACTUAL +} +#endif // HWY_TARGET <= HWY_AVX3 + +template +static HWY_INLINE VFromD DemoteToNearestIntInRange( + DI, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm256_cvtpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return Dup128VecFromValues(DI(), + detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3])); + } +#endif + + __m128i raw_result; + __asm__("vcvtpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else // !HWY_COMPILER_GCC_ACTUAL + return VFromD{_mm256_cvtpd_epi32(v.raw)}; +#endif +} + +#ifndef HWY_DISABLE_F16C + +template +HWY_API VFromD PromoteTo(D df32, Vec128 v) { + (void)df32; +#if HWY_HAVE_FLOAT16 + const RebindToUnsigned> du16; + return VFromD{_mm256_cvtph_ps(BitCast(du16, v).raw)}; +#else + return VFromD{_mm256_cvtph_ps(v.raw)}; +#endif // HWY_HAVE_FLOAT16 +} + +#endif // HWY_DISABLE_F16C + +#if HWY_HAVE_FLOAT16 + +template +HWY_INLINE VFromD PromoteTo(D /*tag*/, Vec64 v) { + return VFromD{_mm256_cvtph_pd(v.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD PromoteTo(D df32, 
Vec128 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) + +HWY_API Vec256 AESRound(Vec256 state, + Vec256 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_aesenc_epi128(state.raw, round_key.raw)}; +#else + const Full256 d; + const Half d2; + return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec256 AESLastRound(Vec256 state, + Vec256 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_aesenclast_epi128(state.raw, round_key.raw)}; +#else + const Full256 d; + const Half d2; + return Combine(d, + AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec256 AESRoundInv(Vec256 state, + Vec256 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_aesdec_epi128(state.raw, round_key.raw)}; +#else + const Full256 d; + const Half d2; + return Combine(d, AESRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRoundInv(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec256 AESLastRoundInv(Vec256 state, + Vec256 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_aesdeclast_epi128(state.raw, round_key.raw)}; +#else + const Full256 d; + const Half d2; + return Combine( + d, AESLastRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRoundInv(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +template )> +HWY_API V AESInvMixColumns(V state) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3_DL + // On AVX3_DL, it is more efficient to do an InvMixColumns operation for a + // 256-bit or 512-bit vector by doing a AESLastRound operation + // (_mm256_aesenclast_epi128/_mm512_aesenclast_epi128) followed by a + // 
AESRoundInv operation (_mm256_aesdec_epi128/_mm512_aesdec_epi128) than to + // split the vector into 128-bit vectors, carrying out multiple + // _mm_aesimc_si128 operations, and then combining the _mm_aesimc_si128 + // results back into a 256-bit or 512-bit vector. + const auto zero = Zero(d); + return AESRoundInv(AESLastRound(state, zero), zero); +#else + const Half dh; + return Combine(d, AESInvMixColumns(UpperHalf(dh, state)), + AESInvMixColumns(LowerHalf(dh, state))); +#endif +} + +template +HWY_API Vec256 AESKeyGenAssist(Vec256 v) { + const Full256 d; +#if HWY_TARGET <= HWY_AVX3_DL + const VFromD rconXorMask = Dup128VecFromValues( + d, 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0); + const VFromD rotWordShuffle = Dup128VecFromValues( + d, 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12); + const Repartition du32; + const auto w13 = BitCast(d, DupOdd(BitCast(du32, v))); + const auto sub_word_result = AESLastRound(w13, rconXorMask); + return TableLookupBytes(sub_word_result, rotWordShuffle); +#else + const Half d2; + return Combine(d, AESKeyGenAssist(UpperHalf(d2, v)), + AESKeyGenAssist(LowerHalf(v))); +#endif +} + +HWY_API Vec256 CLMulLower(Vec256 a, Vec256 b) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x00)}; +#else + const Full256 d; + const Half d2; + return Combine(d, CLMulLower(UpperHalf(d2, a), UpperHalf(d2, b)), + CLMulLower(LowerHalf(a), LowerHalf(b))); +#endif +} + +HWY_API Vec256 CLMulUpper(Vec256 a, Vec256 b) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec256{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x11)}; +#else + const Full256 d; + const Half d2; + return Combine(d, CLMulUpper(UpperHalf(d2, a), UpperHalf(d2, b)), + CLMulUpper(LowerHalf(a), LowerHalf(b))); +#endif +} + +#endif // HWY_DISABLE_PCLMUL_AES + +// ================================================== MISC + +#if HWY_TARGET <= HWY_AVX3 + +// ------------------------------ LoadMaskBits + +// `p` points to at least 8 readable bytes, not 
all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumBytes = (kN + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + + return MFromD::FromBits(mask_bits); +} + +// ------------------------------ StoreMaskBits + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumBytes = (kN + 7) / 8; + + CopyBytes(&mask.raw, bits); + + // Non-full byte, need to clear the undefined upper bits. + if (kN < 8) { + const int mask_bits = static_cast((1ull << kN) - 1); + bits[0] = static_cast(bits[0] & mask_bits); + } + return kNumBytes; +} + +// ------------------------------ Mask testing + +template +HWY_API size_t CountTrue(D /* tag */, MFromD mask) { + return PopCount(static_cast(mask.raw)); +} + +template +HWY_API size_t FindKnownFirstTrue(D /* tag */, MFromD mask) { + return Num0BitsBelowLS1Bit_Nonzero32(mask.raw); +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + return mask.raw ? static_cast(FindKnownFirstTrue(d, mask)) + : intptr_t{-1}; +} + +template +HWY_API size_t FindKnownLastTrue(D /* tag */, MFromD mask) { + return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask.raw); +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + return mask.raw ? static_cast(FindKnownLastTrue(d, mask)) + : intptr_t{-1}; +} + +// Beware: the suffix indicates the number of mask bits, not lane size! 
+ +namespace detail { + +template +HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask256 mask) { + return (uint64_t{mask.raw} & 0xF) == 0; +} + +} // namespace detail + +template +HWY_API bool AllFalse(D /* tag */, MFromD mask) { + return detail::AllFalse(hwy::SizeTag)>(), mask); +} + +namespace detail { + +template +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFu; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFu; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFu; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask256 mask) { + // Cannot use _kortestc because we have less than 8 mask bits. 
+ return mask.raw == 0xFu; +} + +} // namespace detail + +template +HWY_API bool AllTrue(D /* tag */, const MFromD mask) { + return detail::AllTrue(hwy::SizeTag)>(), mask); +} + +// ------------------------------ Compress + +// 16-bit is defined in x86_512 so we can use 512-bit vectors. + +template +HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { + return Vec256{_mm256_maskz_compress_epi32(mask.raw, v.raw)}; +} + +HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { + return Vec256{_mm256_maskz_compress_ps(mask.raw, v.raw)}; +} + +template +HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { + // See CompressIsPartition. + alignas(16) static constexpr uint64_t packed_array[16] = { + // PrintCompress64x4NibbleTables + 0x00003210, 0x00003210, 0x00003201, 0x00003210, 0x00003102, 0x00003120, + 0x00003021, 0x00003210, 0x00002103, 0x00002130, 0x00002031, 0x00002310, + 0x00001032, 0x00001320, 0x00000321, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm256_permutexvar_epi64 will ignore the upper bits. + const DFromV d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) static constexpr uint64_t shifts[4] = {0, 4, 8, 12}; + const auto indices = Indices256{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressNot (Compress) + +// Implemented in x86_512 for lane size != 8. + +template +HWY_API Vec256 CompressNot(Vec256 v, Mask256 mask) { + // See CompressIsPartition. + alignas(16) static constexpr uint64_t packed_array[16] = { + // PrintCompressNot64x4NibbleTables + 0x00003210, 0x00000321, 0x00001320, 0x00001032, 0x00002310, 0x00002031, + 0x00002130, 0x00002103, 0x00003210, 0x00003021, 0x00003120, 0x00003102, + 0x00003210, 0x00003201, 0x00003210, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2) - + // _mm256_permutexvar_epi64 will ignore the upper bits. 
+ const DFromV d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(32) static constexpr uint64_t shifts[4] = {0, 4, 8, 12}; + const auto indices = Indices256{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ CompressStore (defined in x86_512) +// ------------------------------ CompressBlendedStore (defined in x86_512) +// ------------------------------ CompressBitsStore (defined in x86_512) + +#else // AVX2 + +// ------------------------------ LoadMaskBits (TestBit) + +namespace detail { + +// 256 suffix avoids ambiguity with x86_128 without needing HWY_IF_V_SIZE. +template +HWY_INLINE Mask256 LoadMaskBits256(uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + const Repartition du32; + const auto vbits = BitCast(du, Set(du32, static_cast(mask_bits))); + + // Replicate bytes 8x such that each byte contains the bit that governs it. + const Repartition du64; + alignas(32) static constexpr uint64_t kRep8[4] = { + 0x0000000000000000ull, 0x0101010101010101ull, 0x0202020202020202ull, + 0x0303030303030303ull}; + const auto rep8 = TableLookupBytes(vbits, BitCast(du, Load(du64, kRep8))); + + const VFromD bit = Dup128VecFromValues( + du, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + return RebindMask(d, TestBit(rep8, bit)); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + alignas(32) static constexpr uint16_t kBit[16] = { + 1, 2, 4, 8, 16, 32, 64, 128, + 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; + const auto vmask_bits = Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + alignas(32) static constexpr uint32_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + const auto vmask_bits = 
Set(du, static_cast(mask_bits)); + return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); +} + +template +HWY_INLINE Mask256 LoadMaskBits256(uint64_t mask_bits) { + const Full256 d; + const RebindToUnsigned du; + alignas(32) static constexpr uint64_t kBit[8] = {1, 2, 4, 8}; + return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); +} + +} // namespace detail + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { + constexpr size_t kN = MaxLanes(d); + constexpr size_t kNumBytes = (kN + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (kN < 8) { + mask_bits &= (1ull << kN) - 1; + } + + return detail::LoadMaskBits256>(mask_bits); +} + +// ------------------------------ BitsFromMask + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToUnsigned d8; + const auto sign_bits = BitCast(d8, VecFromMask(d, mask)).raw; + // Prevent sign-extension of 32-bit masks because the intrinsic returns int. + return static_cast(_mm256_movemask_epi8(sign_bits)); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { +#if !defined(HWY_DISABLE_BMI2_FMA) && !defined(HWY_DISABLE_PEXT_ON_AVX2) + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + const uint64_t sign_bits8 = BitsFromMask(d8, mask8); + // Skip the bits from the lower byte of each u16 (better not to use the + // same packs_epi16 as SSE4, because that requires an extra swizzle here). + return _pext_u32(static_cast(sign_bits8), 0xAAAAAAAAu); +#else + // Slow workaround for when BMI2 is disabled + // Remove useless lower half of each u16 while preserving the sign bit. + // Bytes [0, 8) and [16, 24) have the same sign bits as the input lanes. + const auto sign_bits = _mm256_packs_epi16(mask.raw, _mm256_setzero_si256()); + // Move odd qwords (value zero) to top so they don't affect the mask value. 
+ const auto compressed = _mm256_castsi256_si128( + _mm256_permute4x64_epi64(sign_bits, _MM_SHUFFLE(3, 1, 2, 0))); + return static_cast(_mm_movemask_epi8(compressed)); +#endif // HWY_ARCH_X86_64 +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToFloat df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; + return static_cast(_mm256_movemask_ps(sign_bits)); +} + +template +HWY_API uint64_t BitsFromMask(D d, MFromD mask) { + const RebindToFloat df; + const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; + return static_cast(_mm256_movemask_pd(sign_bits)); +} + +// ------------------------------ StoreMaskBits +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + HWY_LANES_CONSTEXPR size_t kNumBytes = (N + 7) / 8; + + const uint64_t mask_bits = BitsFromMask(d, mask); + CopyBytes(&mask_bits, bits, kNumBytes); + return kNumBytes; +} + +// ------------------------------ Mask testing + +// Specialize for 16-bit lanes to avoid unnecessary pext. This assumes each mask +// lane is 0 or ~0. +template +HWY_API bool AllFalse(D d, MFromD mask) { + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return BitsFromMask(d8, mask8) == 0; +} + +template +HWY_API bool AllFalse(D d, MFromD mask) { + // Cheaper than PTEST, which is 2 uop / 3L. 
+ return BitsFromMask(d, mask) == 0; +} + +template +HWY_API bool AllTrue(D d, MFromD mask) { + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return BitsFromMask(d8, mask8) == (1ull << 32) - 1; +} +template +HWY_API bool AllTrue(D d, MFromD mask) { + constexpr uint64_t kAllBits = (1ull << MaxLanes(d)) - 1; + return BitsFromMask(d, mask) == kAllBits; +} + +template +HWY_API size_t CountTrue(D d, MFromD mask) { + const Repartition d8; + const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); + return PopCount(BitsFromMask(d8, mask8)) >> 1; +} +template +HWY_API size_t CountTrue(D d, MFromD mask) { + return PopCount(BitsFromMask(d, mask)); +} + +template +HWY_API size_t FindKnownFirstTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return Num0BitsBelowLS1Bit_Nonzero32(mask_bits); +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1; +} + +template +HWY_API size_t FindKnownLastTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits); +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + const uint32_t mask_bits = static_cast(BitsFromMask(d, mask)); + return mask_bits ? intptr_t(31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits)) + : -1; +} + +// ------------------------------ Compress, CompressBits + +namespace detail { + +template +HWY_INLINE Vec256 IndicesFromBits256(uint64_t mask_bits) { + const Full256 d32; + // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT + // of SetTableIndices would require 8 KiB, a large part of L1D. The other + // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) + // and unavailable in 32-bit builds. 
We instead compress each index into 4 + // bits, for a total of 1 KiB. + alignas(16) static constexpr uint32_t packed_array[256] = { + // PrintCompress32x8Tables + 0x76543210, 0x76543218, 0x76543209, 0x76543298, 0x7654310a, 0x765431a8, + 0x765430a9, 0x76543a98, 0x7654210b, 0x765421b8, 0x765420b9, 0x76542b98, + 0x765410ba, 0x76541ba8, 0x76540ba9, 0x7654ba98, 0x7653210c, 0x765321c8, + 0x765320c9, 0x76532c98, 0x765310ca, 0x76531ca8, 0x76530ca9, 0x7653ca98, + 0x765210cb, 0x76521cb8, 0x76520cb9, 0x7652cb98, 0x76510cba, 0x7651cba8, + 0x7650cba9, 0x765cba98, 0x7643210d, 0x764321d8, 0x764320d9, 0x76432d98, + 0x764310da, 0x76431da8, 0x76430da9, 0x7643da98, 0x764210db, 0x76421db8, + 0x76420db9, 0x7642db98, 0x76410dba, 0x7641dba8, 0x7640dba9, 0x764dba98, + 0x763210dc, 0x76321dc8, 0x76320dc9, 0x7632dc98, 0x76310dca, 0x7631dca8, + 0x7630dca9, 0x763dca98, 0x76210dcb, 0x7621dcb8, 0x7620dcb9, 0x762dcb98, + 0x7610dcba, 0x761dcba8, 0x760dcba9, 0x76dcba98, 0x7543210e, 0x754321e8, + 0x754320e9, 0x75432e98, 0x754310ea, 0x75431ea8, 0x75430ea9, 0x7543ea98, + 0x754210eb, 0x75421eb8, 0x75420eb9, 0x7542eb98, 0x75410eba, 0x7541eba8, + 0x7540eba9, 0x754eba98, 0x753210ec, 0x75321ec8, 0x75320ec9, 0x7532ec98, + 0x75310eca, 0x7531eca8, 0x7530eca9, 0x753eca98, 0x75210ecb, 0x7521ecb8, + 0x7520ecb9, 0x752ecb98, 0x7510ecba, 0x751ecba8, 0x750ecba9, 0x75ecba98, + 0x743210ed, 0x74321ed8, 0x74320ed9, 0x7432ed98, 0x74310eda, 0x7431eda8, + 0x7430eda9, 0x743eda98, 0x74210edb, 0x7421edb8, 0x7420edb9, 0x742edb98, + 0x7410edba, 0x741edba8, 0x740edba9, 0x74edba98, 0x73210edc, 0x7321edc8, + 0x7320edc9, 0x732edc98, 0x7310edca, 0x731edca8, 0x730edca9, 0x73edca98, + 0x7210edcb, 0x721edcb8, 0x720edcb9, 0x72edcb98, 0x710edcba, 0x71edcba8, + 0x70edcba9, 0x7edcba98, 0x6543210f, 0x654321f8, 0x654320f9, 0x65432f98, + 0x654310fa, 0x65431fa8, 0x65430fa9, 0x6543fa98, 0x654210fb, 0x65421fb8, + 0x65420fb9, 0x6542fb98, 0x65410fba, 0x6541fba8, 0x6540fba9, 0x654fba98, + 0x653210fc, 0x65321fc8, 0x65320fc9, 0x6532fc98, 0x65310fca, 
0x6531fca8, + 0x6530fca9, 0x653fca98, 0x65210fcb, 0x6521fcb8, 0x6520fcb9, 0x652fcb98, + 0x6510fcba, 0x651fcba8, 0x650fcba9, 0x65fcba98, 0x643210fd, 0x64321fd8, + 0x64320fd9, 0x6432fd98, 0x64310fda, 0x6431fda8, 0x6430fda9, 0x643fda98, + 0x64210fdb, 0x6421fdb8, 0x6420fdb9, 0x642fdb98, 0x6410fdba, 0x641fdba8, + 0x640fdba9, 0x64fdba98, 0x63210fdc, 0x6321fdc8, 0x6320fdc9, 0x632fdc98, + 0x6310fdca, 0x631fdca8, 0x630fdca9, 0x63fdca98, 0x6210fdcb, 0x621fdcb8, + 0x620fdcb9, 0x62fdcb98, 0x610fdcba, 0x61fdcba8, 0x60fdcba9, 0x6fdcba98, + 0x543210fe, 0x54321fe8, 0x54320fe9, 0x5432fe98, 0x54310fea, 0x5431fea8, + 0x5430fea9, 0x543fea98, 0x54210feb, 0x5421feb8, 0x5420feb9, 0x542feb98, + 0x5410feba, 0x541feba8, 0x540feba9, 0x54feba98, 0x53210fec, 0x5321fec8, + 0x5320fec9, 0x532fec98, 0x5310feca, 0x531feca8, 0x530feca9, 0x53feca98, + 0x5210fecb, 0x521fecb8, 0x520fecb9, 0x52fecb98, 0x510fecba, 0x51fecba8, + 0x50fecba9, 0x5fecba98, 0x43210fed, 0x4321fed8, 0x4320fed9, 0x432fed98, + 0x4310feda, 0x431feda8, 0x430feda9, 0x43feda98, 0x4210fedb, 0x421fedb8, + 0x420fedb9, 0x42fedb98, 0x410fedba, 0x41fedba8, 0x40fedba9, 0x4fedba98, + 0x3210fedc, 0x321fedc8, 0x320fedc9, 0x32fedc98, 0x310fedca, 0x31fedca8, + 0x30fedca9, 0x3fedca98, 0x210fedcb, 0x21fedcb8, 0x20fedcb9, 0x2fedcb98, + 0x10fedcba, 0x1fedcba8, 0x0fedcba9, 0xfedcba98}; + + // No need to mask because _mm256_permutevar8x32_epi32 ignores bits 3..31. + // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. + // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing + // latency, it may be faster to use LoadDup128 and PSHUFB. 
+ const auto packed = Set(d32, packed_array[mask_bits]); + alignas(32) static constexpr uint32_t shifts[8] = {0, 4, 8, 12, + 16, 20, 24, 28}; + return packed >> Load(d32, shifts); +} + +template +HWY_INLINE Vec256 IndicesFromBits256(uint64_t mask_bits) { + const Full256 d32; + + // For 64-bit, we still need 32-bit indices because there is no 64-bit + // permutevar, but there are only 4 lanes, so we can afford to skip the + // unpacking and load the entire index vector directly. + alignas(32) static constexpr uint32_t u32_indices[128] = { + // PrintCompress64x4PairTables + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 6, 7, + 10, 11, 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 4, 5, 6, 7, + 12, 13, 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 2, 3, 6, 7, + 10, 11, 12, 13, 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 6, 7, + 14, 15, 0, 1, 2, 3, 4, 5, 8, 9, 14, 15, 2, 3, 4, 5, + 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 10, 11, 14, 15, 4, 5, + 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 12, 13, 14, 15, 2, 3, + 10, 11, 12, 13, 14, 15, 0, 1, 8, 9, 10, 11, 12, 13, 14, 15}; + return Load(d32, u32_indices + 8 * mask_bits); +} + +template +HWY_INLINE Vec256 IndicesFromNotBits256(uint64_t mask_bits) { + const Full256 d32; + // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT + // of SetTableIndices would require 8 KiB, a large part of L1D. The other + // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) + // and unavailable in 32-bit builds. We instead compress each index into 4 + // bits, for a total of 1 KiB. 
+ alignas(16) static constexpr uint32_t packed_array[256] = { + // PrintCompressNot32x8Tables + 0xfedcba98, 0x8fedcba9, 0x9fedcba8, 0x98fedcba, 0xafedcb98, 0xa8fedcb9, + 0xa9fedcb8, 0xa98fedcb, 0xbfedca98, 0xb8fedca9, 0xb9fedca8, 0xb98fedca, + 0xbafedc98, 0xba8fedc9, 0xba9fedc8, 0xba98fedc, 0xcfedba98, 0xc8fedba9, + 0xc9fedba8, 0xc98fedba, 0xcafedb98, 0xca8fedb9, 0xca9fedb8, 0xca98fedb, + 0xcbfeda98, 0xcb8feda9, 0xcb9feda8, 0xcb98feda, 0xcbafed98, 0xcba8fed9, + 0xcba9fed8, 0xcba98fed, 0xdfecba98, 0xd8fecba9, 0xd9fecba8, 0xd98fecba, + 0xdafecb98, 0xda8fecb9, 0xda9fecb8, 0xda98fecb, 0xdbfeca98, 0xdb8feca9, + 0xdb9feca8, 0xdb98feca, 0xdbafec98, 0xdba8fec9, 0xdba9fec8, 0xdba98fec, + 0xdcfeba98, 0xdc8feba9, 0xdc9feba8, 0xdc98feba, 0xdcafeb98, 0xdca8feb9, + 0xdca9feb8, 0xdca98feb, 0xdcbfea98, 0xdcb8fea9, 0xdcb9fea8, 0xdcb98fea, + 0xdcbafe98, 0xdcba8fe9, 0xdcba9fe8, 0xdcba98fe, 0xefdcba98, 0xe8fdcba9, + 0xe9fdcba8, 0xe98fdcba, 0xeafdcb98, 0xea8fdcb9, 0xea9fdcb8, 0xea98fdcb, + 0xebfdca98, 0xeb8fdca9, 0xeb9fdca8, 0xeb98fdca, 0xebafdc98, 0xeba8fdc9, + 0xeba9fdc8, 0xeba98fdc, 0xecfdba98, 0xec8fdba9, 0xec9fdba8, 0xec98fdba, + 0xecafdb98, 0xeca8fdb9, 0xeca9fdb8, 0xeca98fdb, 0xecbfda98, 0xecb8fda9, + 0xecb9fda8, 0xecb98fda, 0xecbafd98, 0xecba8fd9, 0xecba9fd8, 0xecba98fd, + 0xedfcba98, 0xed8fcba9, 0xed9fcba8, 0xed98fcba, 0xedafcb98, 0xeda8fcb9, + 0xeda9fcb8, 0xeda98fcb, 0xedbfca98, 0xedb8fca9, 0xedb9fca8, 0xedb98fca, + 0xedbafc98, 0xedba8fc9, 0xedba9fc8, 0xedba98fc, 0xedcfba98, 0xedc8fba9, + 0xedc9fba8, 0xedc98fba, 0xedcafb98, 0xedca8fb9, 0xedca9fb8, 0xedca98fb, + 0xedcbfa98, 0xedcb8fa9, 0xedcb9fa8, 0xedcb98fa, 0xedcbaf98, 0xedcba8f9, + 0xedcba9f8, 0xedcba98f, 0xfedcba98, 0xf8edcba9, 0xf9edcba8, 0xf98edcba, + 0xfaedcb98, 0xfa8edcb9, 0xfa9edcb8, 0xfa98edcb, 0xfbedca98, 0xfb8edca9, + 0xfb9edca8, 0xfb98edca, 0xfbaedc98, 0xfba8edc9, 0xfba9edc8, 0xfba98edc, + 0xfcedba98, 0xfc8edba9, 0xfc9edba8, 0xfc98edba, 0xfcaedb98, 0xfca8edb9, + 0xfca9edb8, 0xfca98edb, 0xfcbeda98, 0xfcb8eda9, 
0xfcb9eda8, 0xfcb98eda, + 0xfcbaed98, 0xfcba8ed9, 0xfcba9ed8, 0xfcba98ed, 0xfdecba98, 0xfd8ecba9, + 0xfd9ecba8, 0xfd98ecba, 0xfdaecb98, 0xfda8ecb9, 0xfda9ecb8, 0xfda98ecb, + 0xfdbeca98, 0xfdb8eca9, 0xfdb9eca8, 0xfdb98eca, 0xfdbaec98, 0xfdba8ec9, + 0xfdba9ec8, 0xfdba98ec, 0xfdceba98, 0xfdc8eba9, 0xfdc9eba8, 0xfdc98eba, + 0xfdcaeb98, 0xfdca8eb9, 0xfdca9eb8, 0xfdca98eb, 0xfdcbea98, 0xfdcb8ea9, + 0xfdcb9ea8, 0xfdcb98ea, 0xfdcbae98, 0xfdcba8e9, 0xfdcba9e8, 0xfdcba98e, + 0xfedcba98, 0xfe8dcba9, 0xfe9dcba8, 0xfe98dcba, 0xfeadcb98, 0xfea8dcb9, + 0xfea9dcb8, 0xfea98dcb, 0xfebdca98, 0xfeb8dca9, 0xfeb9dca8, 0xfeb98dca, + 0xfebadc98, 0xfeba8dc9, 0xfeba9dc8, 0xfeba98dc, 0xfecdba98, 0xfec8dba9, + 0xfec9dba8, 0xfec98dba, 0xfecadb98, 0xfeca8db9, 0xfeca9db8, 0xfeca98db, + 0xfecbda98, 0xfecb8da9, 0xfecb9da8, 0xfecb98da, 0xfecbad98, 0xfecba8d9, + 0xfecba9d8, 0xfecba98d, 0xfedcba98, 0xfed8cba9, 0xfed9cba8, 0xfed98cba, + 0xfedacb98, 0xfeda8cb9, 0xfeda9cb8, 0xfeda98cb, 0xfedbca98, 0xfedb8ca9, + 0xfedb9ca8, 0xfedb98ca, 0xfedbac98, 0xfedba8c9, 0xfedba9c8, 0xfedba98c, + 0xfedcba98, 0xfedc8ba9, 0xfedc9ba8, 0xfedc98ba, 0xfedcab98, 0xfedca8b9, + 0xfedca9b8, 0xfedca98b, 0xfedcba98, 0xfedcb8a9, 0xfedcb9a8, 0xfedcb98a, + 0xfedcba98, 0xfedcba89, 0xfedcba98, 0xfedcba98}; + + // No need to mask because <_mm256_permutevar8x32_epi32> ignores bits 3..31. + // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. + // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing + // latency, it may be faster to use LoadDup128 and PSHUFB. 
+ const Vec256 packed = Set(d32, packed_array[mask_bits]); + alignas(32) static constexpr uint32_t shifts[8] = {0, 4, 8, 12, + 16, 20, 24, 28}; + return packed >> Load(d32, shifts); +} + +template +HWY_INLINE Vec256 IndicesFromNotBits256(uint64_t mask_bits) { + const Full256 d32; + + // For 64-bit, we still need 32-bit indices because there is no 64-bit + // permutevar, but there are only 4 lanes, so we can afford to skip the + // unpacking and load the entire index vector directly. + alignas(32) static constexpr uint32_t u32_indices[128] = { + // PrintCompressNot64x4PairTables + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, + 8, 9, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, + 8, 9, 10, 11, 14, 15, 12, 13, 10, 11, 14, 15, 8, 9, 12, 13, + 8, 9, 14, 15, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 8, 9, 14, 15, + 8, 9, 12, 13, 10, 11, 14, 15, 12, 13, 8, 9, 10, 11, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 10, 11, 8, 9, 12, 13, 14, 15, + 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}; + return Load(d32, u32_indices + 8 * mask_bits); +} + +template +HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { + const DFromV d; + const Repartition du32; + + HWY_DASSERT(mask_bits < (1ull << Lanes(d))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). + const Indices256 indices{IndicesFromBits256(mask_bits).raw}; + return BitCast(d, TableLookupLanes(BitCast(du32, v), indices)); +} + +// LUTs are infeasible for 2^16 possible masks, so splice together two +// half-vector Compress. 
+template +HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { + const DFromV d; + const RebindToUnsigned du; + const auto vu16 = BitCast(du, v); // (required for float16_t inputs) + const Half duh; + const auto half0 = LowerHalf(duh, vu16); + const auto half1 = UpperHalf(duh, vu16); + + const uint64_t mask_bits0 = mask_bits & 0xFF; + const uint64_t mask_bits1 = mask_bits >> 8; + const auto compressed0 = detail::CompressBits(half0, mask_bits0); + const auto compressed1 = detail::CompressBits(half1, mask_bits1); + + alignas(32) uint16_t all_true[16] = {}; + // Store mask=true lanes, left to right. + const size_t num_true0 = PopCount(mask_bits0); + Store(compressed0, duh, all_true); + StoreU(compressed1, duh, all_true + num_true0); + + if (hwy::HWY_NAMESPACE::CompressIsPartition::value) { + // Store mask=false lanes, right to left. The second vector fills the upper + // half with right-aligned false lanes. The first vector is shifted + // rightwards to overwrite the true lanes of the second. + alignas(32) uint16_t all_false[16] = {}; + const size_t num_true1 = PopCount(mask_bits1); + Store(compressed1, duh, all_false + 8); + StoreU(compressed0, duh, all_false + num_true1); + + const auto mask = FirstN(du, num_true0 + num_true1); + return BitCast(d, + IfThenElse(mask, Load(du, all_true), Load(du, all_false))); + } else { + // Only care about the mask=true lanes. + return BitCast(d, Load(du, all_true)); + } +} + +template +HWY_INLINE Vec256 CompressNot(Vec256 v, const uint64_t mask_bits) { + const DFromV d; + const Repartition du32; + + HWY_DASSERT(mask_bits < (1ull << Lanes(d))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). + const Indices256 indices{IndicesFromNotBits256(mask_bits).raw}; + return BitCast(d, TableLookupLanes(BitCast(du32, v), indices)); +} + +// LUTs are infeasible for 2^16 possible masks, so splice together two +// half-vector Compress. 
+template +HWY_INLINE Vec256 CompressNot(Vec256 v, const uint64_t mask_bits) { + // Compress ensures only the lower 16 bits are set, so flip those. + return Compress(v, mask_bits ^ 0xFFFF); +} + +} // namespace detail + +template +HWY_API Vec256 Compress(Vec256 v, Mask256 m) { + const DFromV d; + return detail::Compress(v, BitsFromMask(d, m)); +} + +template +HWY_API Vec256 CompressNot(Vec256 v, Mask256 m) { + const DFromV d; + return detail::CompressNot(v, BitsFromMask(d, m)); +} + +HWY_API Vec256 CompressBlocksNot(Vec256 v, + Mask256 mask) { + return CompressNot(v, mask); +} + +template +HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { + constexpr size_t N = 32 / sizeof(T); + constexpr size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + + return detail::Compress(v, mask_bits); +} + +// ------------------------------ CompressStore, CompressBitsStore + +template +HWY_API size_t CompressStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = BitsFromMask(d, m); + const size_t count = PopCount(mask_bits); + StoreU(detail::Compress(v, mask_bits), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = BitsFromMask(d, m); + const size_t count = PopCount(mask_bits); + + const RebindToUnsigned du; + const Repartition du32; + HWY_DASSERT(mask_bits < (1ull << Lanes(d))); + // 32-bit indices because we only have _mm256_permutevar8x32_epi32 (there is + // no instruction for 4x64). Nibble MSB encodes FirstN. 
+ const Vec256 idx_mask = + detail::IndicesFromBits256>(mask_bits); + // Shift nibble MSB into MSB + const Mask256 mask32 = MaskFromVec(ShiftLeft<28>(idx_mask)); + // First cast to unsigned (RebindMask cannot change lane size) + const MFromD mask_u{mask32.raw}; + const MFromD mask = RebindMask(d, mask_u); + const VFromD compressed = BitCast( + d, + TableLookupLanes(BitCast(du32, v), Indices256{idx_mask.raw})); + + BlendedStore(compressed, mask, d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + const uint64_t mask_bits = BitsFromMask(d, m); + const size_t count = PopCount(mask_bits); + const VFromD compressed = detail::Compress(v, mask_bits); + +#if HWY_MEM_OPS_MIGHT_FAULT // true if HWY_IS_MSAN + // BlendedStore tests mask for each lane, but we know that the mask is + // FirstN, so we can just copy. + alignas(32) TFromD buf[16]; + Store(compressed, d, buf); + CopyBytes(buf, unaligned, count * sizeof(TFromD)); +#else + BlendedStore(compressed, FirstN(d, count), d, unaligned); +#endif + return count; +} + +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + HWY_LANES_CONSTEXPR size_t kNumBytes = (N + 7) / 8; + + uint64_t mask_bits = 0; + CopyBytes(bits, &mask_bits, kNumBytes); + + if (N < 8) { + mask_bits &= (1ull << N) - 1; + } + const size_t count = PopCount(mask_bits); + + StoreU(detail::Compress(v, mask_bits), d, unaligned); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Dup128MaskFromMaskBits + +// Generic for all vector lengths >= 32 bytes +template +HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { + const Half dh; + const auto mh = Dup128MaskFromMaskBits(dh, mask_bits); + return CombineMasks(d, mh, mh); 
+} + +// ------------------------------ Expand + +// Always define Expand/LoadExpand because generic_ops only does so for Vec128. + +namespace detail { + +// NativeExpand/NativeLoadExpand forward to the zero-masking +// _mm256_maskz_expand_* / _mm256_maskz_expandloadu_* intrinsics: the 8/16-bit +// forms are guarded by HWY_TARGET <= HWY_AVX3_DL, the 32/64-bit forms by +// HWY_TARGET <= HWY_AVX3 (HWY_IDE enables both). +// NOTE(review): template parameter lists and angle-bracket arguments appear +// to be missing from this copy of the patch; compare with upstream. + +#if HWY_TARGET <= HWY_AVX3_DL || HWY_IDE // VBMI2 + +HWY_INLINE Vec256 NativeExpand(Vec256 v, + Mask256 mask) { + return Vec256{_mm256_maskz_expand_epi8(mask.raw, v.raw)}; +} + +HWY_INLINE Vec256 NativeExpand(Vec256 v, + Mask256 mask) { + return Vec256{_mm256_maskz_expand_epi16(mask.raw, v.raw)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint8_t* HWY_RESTRICT unaligned) { + return VFromD{_mm256_maskz_expandloadu_epi8(mask.raw, unaligned)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint16_t* HWY_RESTRICT unaligned) { + return VFromD{_mm256_maskz_expandloadu_epi16(mask.raw, unaligned)}; +} + +#endif // HWY_TARGET <= HWY_AVX3_DL +#if HWY_TARGET <= HWY_AVX3 || HWY_IDE + +HWY_INLINE Vec256 NativeExpand(Vec256 v, + Mask256 mask) { + return Vec256{_mm256_maskz_expand_epi32(mask.raw, v.raw)}; +} + +HWY_INLINE Vec256 NativeExpand(Vec256 v, + Mask256 mask) { + return Vec256{_mm256_maskz_expand_epi64(mask.raw, v.raw)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint32_t* HWY_RESTRICT unaligned) { + return VFromD{_mm256_maskz_expandloadu_epi32(mask.raw, unaligned)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(MFromD mask, D /* d */, + const uint64_t* HWY_RESTRICT unaligned) { + return VFromD{_mm256_maskz_expandloadu_epi64(mask.raw, unaligned)}; +} + +#endif // HWY_TARGET <= HWY_AVX3 + +} // namespace detail + +// Uses detail::NativeExpand when HWY_TARGET <= HWY_AVX3_DL; otherwise expands +// the two half-vectors separately and Combines the results. +template +HWY_API Vec256 Expand(Vec256 v, Mask256 mask) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + const RebindToUnsigned du; + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +#else + // LUTs are infeasible for so many mask combinations, so Combine two + // half-vector Expand.
+ const Half dh; + const uint64_t mask_bits = BitsFromMask(d, mask); + constexpr size_t N = 32 / sizeof(T); + const size_t countL = PopCount(mask_bits & ((1 << (N / 2)) - 1)); + const Mask128 maskL = MaskFromVec(LowerHalf(VecFromMask(d, mask))); + const Vec128 expandL = Expand(LowerHalf(v), maskL); + // We have to shift the input by a variable number of bytes, but there isn't + // a table-driven option for that until VBMI, and CPUs with that likely also + // have VBMI2 and thus native Expand. + alignas(32) T lanes[N]; + Store(v, d, lanes); + // The upper half resumes reading at lanes[countL], where countL is the + // number of set bits in the lower half of the mask. + const Mask128 maskH = MaskFromVec(UpperHalf(dh, VecFromMask(d, mask))); + const Vec128 expandH = Expand(LoadU(dh, lanes + countL), maskH); + return Combine(d, expandH, expandL); +#endif +} + +// If AVX3, this is already implemented by x86_512. +#if HWY_TARGET != HWY_AVX3 + +template +HWY_API Vec256 Expand(Vec256 v, Mask256 mask) { + const Full256 d; +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + const RebindToUnsigned du; + return BitCast(d, detail::NativeExpand(BitCast(du, v), RebindMask(du, mask))); +#else // AVX2 + // LUTs are infeasible for 2^16 possible masks, so splice together two + // half-vector Expand. + const Half dh; + const Mask128 maskL = MaskFromVec(LowerHalf(VecFromMask(d, mask))); + const Vec128 expandL = Expand(LowerHalf(v), maskL); + // We have to shift the input by a variable number of u16. permutevar_epi16 + // requires AVX3 and if we had that, we'd use native u32 Expand. The only + // alternative is re-loading, which incurs a store to load forwarding stall.
+ alignas(32) T lanes[32 / sizeof(T)]; + Store(v, d, lanes); + const Vec128 vH = LoadU(dh, lanes + CountTrue(dh, maskL)); + const Mask128 maskH = MaskFromVec(UpperHalf(dh, VecFromMask(d, mask))); + const Vec128 expandH = Expand(vH, maskH); + return Combine(d, expandH, expandL); +#endif // AVX2 +} + +#endif // HWY_TARGET != HWY_AVX3 + +// Native when HWY_TARGET <= HWY_AVX3; otherwise looks up a packed word of +// eight 4-bit lane indices by mask_bits, shifts one nibble into each lane, +// permutes with TableLookupLanes and zeroes the masked-off lanes. +template +HWY_API Vec256 Expand(Vec256 v, Mask256 mask) { + const Full256 d; +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +#else + const RebindToUnsigned du; + const uint64_t mask_bits = BitsFromMask(d, mask); + + alignas(16) constexpr uint32_t packed_array[256] = { + // PrintExpand32x8Nibble. + 0xffffffff, 0xfffffff0, 0xffffff0f, 0xffffff10, 0xfffff0ff, 0xfffff1f0, + 0xfffff10f, 0xfffff210, 0xffff0fff, 0xffff1ff0, 0xffff1f0f, 0xffff2f10, + 0xffff10ff, 0xffff21f0, 0xffff210f, 0xffff3210, 0xfff0ffff, 0xfff1fff0, + 0xfff1ff0f, 0xfff2ff10, 0xfff1f0ff, 0xfff2f1f0, 0xfff2f10f, 0xfff3f210, + 0xfff10fff, 0xfff21ff0, 0xfff21f0f, 0xfff32f10, 0xfff210ff, 0xfff321f0, + 0xfff3210f, 0xfff43210, 0xff0fffff, 0xff1ffff0, 0xff1fff0f, 0xff2fff10, + 0xff1ff0ff, 0xff2ff1f0, 0xff2ff10f, 0xff3ff210, 0xff1f0fff, 0xff2f1ff0, + 0xff2f1f0f, 0xff3f2f10, 0xff2f10ff, 0xff3f21f0, 0xff3f210f, 0xff4f3210, + 0xff10ffff, 0xff21fff0, 0xff21ff0f, 0xff32ff10, 0xff21f0ff, 0xff32f1f0, + 0xff32f10f, 0xff43f210, 0xff210fff, 0xff321ff0, 0xff321f0f, 0xff432f10, + 0xff3210ff, 0xff4321f0, 0xff43210f, 0xff543210, 0xf0ffffff, 0xf1fffff0, + 0xf1ffff0f, 0xf2ffff10, 0xf1fff0ff, 0xf2fff1f0, 0xf2fff10f, 0xf3fff210, + 0xf1ff0fff, 0xf2ff1ff0, 0xf2ff1f0f, 0xf3ff2f10, 0xf2ff10ff, 0xf3ff21f0, + 0xf3ff210f, 0xf4ff3210, 0xf1f0ffff, 0xf2f1fff0, 0xf2f1ff0f, 0xf3f2ff10, + 0xf2f1f0ff, 0xf3f2f1f0, 0xf3f2f10f, 0xf4f3f210, 0xf2f10fff, 0xf3f21ff0, + 0xf3f21f0f, 0xf4f32f10, 0xf3f210ff, 0xf4f321f0, 0xf4f3210f, 0xf5f43210, + 0xf10fffff, 0xf21ffff0, 0xf21fff0f, 0xf32fff10, 0xf21ff0ff,
0xf32ff1f0, + 0xf32ff10f, 0xf43ff210, 0xf21f0fff, 0xf32f1ff0, 0xf32f1f0f, 0xf43f2f10, + 0xf32f10ff, 0xf43f21f0, 0xf43f210f, 0xf54f3210, 0xf210ffff, 0xf321fff0, + 0xf321ff0f, 0xf432ff10, 0xf321f0ff, 0xf432f1f0, 0xf432f10f, 0xf543f210, + 0xf3210fff, 0xf4321ff0, 0xf4321f0f, 0xf5432f10, 0xf43210ff, 0xf54321f0, + 0xf543210f, 0xf6543210, 0x0fffffff, 0x1ffffff0, 0x1fffff0f, 0x2fffff10, + 0x1ffff0ff, 0x2ffff1f0, 0x2ffff10f, 0x3ffff210, 0x1fff0fff, 0x2fff1ff0, + 0x2fff1f0f, 0x3fff2f10, 0x2fff10ff, 0x3fff21f0, 0x3fff210f, 0x4fff3210, + 0x1ff0ffff, 0x2ff1fff0, 0x2ff1ff0f, 0x3ff2ff10, 0x2ff1f0ff, 0x3ff2f1f0, + 0x3ff2f10f, 0x4ff3f210, 0x2ff10fff, 0x3ff21ff0, 0x3ff21f0f, 0x4ff32f10, + 0x3ff210ff, 0x4ff321f0, 0x4ff3210f, 0x5ff43210, 0x1f0fffff, 0x2f1ffff0, + 0x2f1fff0f, 0x3f2fff10, 0x2f1ff0ff, 0x3f2ff1f0, 0x3f2ff10f, 0x4f3ff210, + 0x2f1f0fff, 0x3f2f1ff0, 0x3f2f1f0f, 0x4f3f2f10, 0x3f2f10ff, 0x4f3f21f0, + 0x4f3f210f, 0x5f4f3210, 0x2f10ffff, 0x3f21fff0, 0x3f21ff0f, 0x4f32ff10, + 0x3f21f0ff, 0x4f32f1f0, 0x4f32f10f, 0x5f43f210, 0x3f210fff, 0x4f321ff0, + 0x4f321f0f, 0x5f432f10, 0x4f3210ff, 0x5f4321f0, 0x5f43210f, 0x6f543210, + 0x10ffffff, 0x21fffff0, 0x21ffff0f, 0x32ffff10, 0x21fff0ff, 0x32fff1f0, + 0x32fff10f, 0x43fff210, 0x21ff0fff, 0x32ff1ff0, 0x32ff1f0f, 0x43ff2f10, + 0x32ff10ff, 0x43ff21f0, 0x43ff210f, 0x54ff3210, 0x21f0ffff, 0x32f1fff0, + 0x32f1ff0f, 0x43f2ff10, 0x32f1f0ff, 0x43f2f1f0, 0x43f2f10f, 0x54f3f210, + 0x32f10fff, 0x43f21ff0, 0x43f21f0f, 0x54f32f10, 0x43f210ff, 0x54f321f0, + 0x54f3210f, 0x65f43210, 0x210fffff, 0x321ffff0, 0x321fff0f, 0x432fff10, + 0x321ff0ff, 0x432ff1f0, 0x432ff10f, 0x543ff210, 0x321f0fff, 0x432f1ff0, + 0x432f1f0f, 0x543f2f10, 0x432f10ff, 0x543f21f0, 0x543f210f, 0x654f3210, + 0x3210ffff, 0x4321fff0, 0x4321ff0f, 0x5432ff10, 0x4321f0ff, 0x5432f1f0, + 0x5432f10f, 0x6543f210, 0x43210fff, 0x54321ff0, 0x54321f0f, 0x65432f10, + 0x543210ff, 0x654321f0, 0x6543210f, 0x76543210, + }; + + // For lane i, shift the i-th 4-bit index down to bits [0, 3).
+ const Vec256 packed = Set(du, packed_array[mask_bits]); + alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + // TableLookupLanes ignores upper bits; avoid bounds-check in IndicesFromVec. + const Indices256 indices{(packed >> Load(du, shifts)).raw}; + const Vec256 expand = TableLookupLanes(BitCast(du, v), indices); + // TableLookupLanes cannot also zero masked-off lanes, so do that now. + return IfThenElseZero(mask, BitCast(d, expand)); +#endif +} + +template +HWY_API Vec256 Expand(Vec256 v, Mask256 mask) { + const Full256 d; +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +#else + const RebindToUnsigned du; + const uint64_t mask_bits = BitsFromMask(d, mask); + + alignas(16) constexpr uint64_t packed_array[16] = { + // PrintExpand64x4Nibble. + 0x0000ffff, 0x0000fff0, 0x0000ff0f, 0x0000ff10, 0x0000f0ff, 0x0000f1f0, + 0x0000f10f, 0x0000f210, 0x00000fff, 0x00001ff0, 0x00001f0f, 0x00002f10, + 0x000010ff, 0x000021f0, 0x0000210f, 0x00003210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 2). + const Vec256 packed = Set(du, packed_array[mask_bits]); + alignas(32) constexpr uint64_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; + // NOTE(review): this nested check repeats the enclosing #if condition and + // sits in its #else branch, so only the IndicesFromVec path below appears + // reachable. +#if HWY_TARGET <= HWY_AVX3 // native 64-bit TableLookupLanes + // TableLookupLanes ignores upper bits; avoid bounds-check in IndicesFromVec. + const Indices256 indices{(packed >> Load(du, shifts)).raw}; +#else + // 64-bit TableLookupLanes on AVX2 requires IndicesFromVec, which checks + // bounds, so clear the upper bits. + const Vec256 masked = And(packed >> Load(du, shifts), Set(du, 3)); + const Indices256 indices = IndicesFromVec(du, masked); +#endif + const Vec256 expand = TableLookupLanes(BitCast(du, v), indices); + // TableLookupLanes cannot also zero masked-off lanes, so do that now.
+ return IfThenElseZero(mask, BitCast(d, expand)); +#endif +} + +// ------------------------------ LoadExpand + +// Both overloads call detail::NativeLoadExpand on the unsigned view of the +// pointer when their target check passes, and otherwise fall back to +// Expand(LoadU(d, unaligned), mask). +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + const RebindToUnsigned du; + using TU = TFromD; + const TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeLoadExpand(mu, du, pu)); +#else + return Expand(LoadU(d, unaligned), mask); +#endif +} + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { +#if HWY_TARGET <= HWY_AVX3 + const RebindToUnsigned du; + using TU = TFromD; + const TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeLoadExpand(mu, du, pu)); +#else + return Expand(LoadU(d, unaligned), mask); +#endif +} + +// ------------------------------ LoadInterleaved3/4 + +// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4.
+ +namespace detail { +// LoadTransposedBlocks3: reads 3 * Lanes(d) elements starting at `unaligned` +// (three LoadU calls) and writes the regrouped vectors to A, B and C. +// NOTE(review): template parameter lists and angle-bracket arguments appear +// to be missing from this copy of the patch; compare with upstream. +// Input: +// 1 0 (<- first block of unaligned) +// 3 2 +// 5 4 +// Output: +// 3 0 +// 4 1 +// 5 2 +template +HWY_API void LoadTransposedBlocks3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& A, VFromD& B, VFromD& C) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const VFromD v10 = LoadU(d, unaligned + 0 * N); // 1 0 + const VFromD v32 = LoadU(d, unaligned + 1 * N); + const VFromD v54 = LoadU(d, unaligned + 2 * N); + + A = ConcatUpperLower(d, v32, v10); + B = ConcatLowerUpper(d, v54, v10); + C = ConcatUpperLower(d, v54, v32); +} + +// LoadTransposedBlocks4: reads 4 * Lanes(d) elements starting at `unaligned` +// (four LoadU calls) and writes the regrouped vectors to vA, vB, vC and vD. +// Input (128-bit blocks): +// 1 0 (first block of unaligned) +// 3 2 +// 5 4 +// 7 6 +// Output: +// 4 0 (LSB of vA) +// 5 1 +// 6 2 +// 7 3 +template +HWY_API void LoadTransposedBlocks4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& vA, VFromD& vB, VFromD& vC, + VFromD& vD) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const VFromD v10 = LoadU(d, unaligned + 0 * N); + const VFromD v32 = LoadU(d, unaligned + 1 * N); + const VFromD v54 = LoadU(d, unaligned + 2 * N); + const VFromD v76 = LoadU(d, unaligned + 3 * N); + + vA = ConcatLowerLower(d, v54, v10); + vB = ConcatUpperUpper(d, v54, v10); + vC = ConcatLowerLower(d, v76, v32); + vD = ConcatUpperUpper(d, v76, v32); +} +} // namespace detail + +// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower) + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4.
+ +namespace detail { +// The StoreTransposedBlocks2/3/4 overloads below regroup the blocks of their +// input vectors and write 2, 3 or 4 consecutive vectors of Lanes(d) elements +// to `unaligned` with StoreU. +// NOTE(review): template parameter lists and angle-bracket arguments appear +// to be missing from this copy of the patch; compare with upstream. +// Input (128-bit blocks): +// 2 0 (LSB of i) +// 3 1 +// Output: +// 1 0 +// 3 2 +template +HWY_API void StoreTransposedBlocks2(VFromD i, VFromD j, D d, + TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperUpper(d, j, i); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 3 0 (LSB of i) +// 4 1 +// 5 2 +// Output: +// 1 0 +// 3 2 +// 5 4 +template +HWY_API void StoreTransposedBlocks3(VFromD i, VFromD j, VFromD k, D d, + TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatUpperLower(d, i, k); + const auto out2 = ConcatUpperUpper(d, k, j); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// 4 0 (LSB of i) +// 5 1 +// 6 2 +// 7 3 +// Output: +// 1 0 +// 3 2 +// 5 4 +// 7 6 +template +HWY_API void StoreTransposedBlocks4(VFromD i, VFromD j, VFromD k, + VFromD l, D d, + TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + // Write lower halves, then upper.
+ const auto out0 = ConcatLowerLower(d, j, i); + const auto out1 = ConcatLowerLower(d, l, k); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + const auto out2 = ConcatUpperUpper(d, j, i); + const auto out3 = ConcatUpperUpper(d, l, k); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} +} // namespace detail + +// ------------------------------ Additional mask logical operations + +#if HWY_TARGET <= HWY_AVX3 +// These overloads operate on the integer mask.raw; kActiveElemMask has only +// the low N bits set (N = MaxLanes of the 256-bit tag) and is ANDed with the +// result to clear bits above the last lane. +template +HWY_API Mask256 SetAtOrAfterFirst(Mask256 mask) { + constexpr size_t N = MaxLanes(Full256()); + constexpr uint32_t kActiveElemMask = + static_cast((uint64_t{1} << N) - 1); + return Mask256{static_cast::Raw>( + (0u - detail::AVX3Blsi(mask.raw)) & kActiveElemMask)}; +} +template +HWY_API Mask256 SetBeforeFirst(Mask256 mask) { + constexpr size_t N = MaxLanes(Full256()); + constexpr uint32_t kActiveElemMask = + static_cast((uint64_t{1} << N) - 1); + return Mask256{static_cast::Raw>( + (detail::AVX3Blsi(mask.raw) - 1u) & kActiveElemMask)}; +} +template +HWY_API Mask256 SetAtOrBeforeFirst(Mask256 mask) { + constexpr size_t N = MaxLanes(Full256()); + constexpr uint32_t kActiveElemMask = + static_cast((uint64_t{1} << N) - 1); + return Mask256{static_cast::Raw>( + detail::AVX3Blsmsk(mask.raw) & kActiveElemMask)}; +} +template +HWY_API Mask256 SetOnlyFirst(Mask256 mask) { + return Mask256{ + static_cast::Raw>(detail::AVX3Blsi(mask.raw))}; +} +#else // AVX2 +template +HWY_API Mask256 SetAtOrAfterFirst(Mask256 mask) { + const Full256 d; + const Repartition di64; + const Repartition df32; + const Repartition di32; + const Half dh_i64; + const Half dh_i32; + using VF32 = VFromD; + + auto vmask = BitCast(di64, VecFromMask(d, mask)); + vmask = Or(vmask, Neg(vmask)); + + // Copy the sign bit of the even int64_t lanes to the odd int64_t lanes + const auto vmask2 = BitCast( + di32, VF32{_mm256_shuffle_ps(Zero(df32).raw, BitCast(df32, vmask).raw, + _MM_SHUFFLE(1, 1, 0, 0))}); + vmask = Or(vmask, BitCast(di64,
BroadcastSignBit(vmask2))); + + // Copy the sign bit of the lower 128-bit half to the upper 128-bit half + const auto vmask3 = + BroadcastSignBit(Broadcast<3>(BitCast(dh_i32, LowerHalf(dh_i64, vmask)))); + vmask = Or(vmask, BitCast(di64, Combine(di32, vmask3, Zero(dh_i32)))); + return MaskFromVec(BitCast(d, vmask)); +} + +// AVX2 SetBeforeFirst is the complement of SetAtOrAfterFirst. +template +HWY_API Mask256 SetBeforeFirst(Mask256 mask) { + return Not(SetAtOrAfterFirst(mask)); +} + +template +HWY_API Mask256 SetOnlyFirst(Mask256 mask) { + const Full256 d; + const RebindToSigned di; + const Repartition di64; + const Half dh_i64; + + const auto zero = Zero(di64); + const auto vmask = BitCast(di64, VecFromMask(d, mask)); + + const auto vmask_eq_0 = VecFromMask(di64, vmask == zero); + auto vmask2_lo = LowerHalf(dh_i64, vmask_eq_0); + auto vmask2_hi = UpperHalf(dh_i64, vmask_eq_0); + + vmask2_lo = And(vmask2_lo, InterleaveLower(vmask2_lo, vmask2_lo)); + vmask2_hi = And(ConcatLowerUpper(dh_i64, vmask2_hi, vmask2_lo), + InterleaveUpper(dh_i64, vmask2_lo, vmask2_lo)); + vmask2_lo = InterleaveLower(Set(dh_i64, int64_t{-1}), vmask2_lo); + + const auto vmask2 = Combine(di64, vmask2_hi, vmask2_lo); + const auto only_first_vmask = Neg(BitCast(di, And(vmask, Neg(vmask)))); + return MaskFromVec(BitCast(d, And(only_first_vmask, BitCast(di, vmask2)))); +} + +template +HWY_API Mask256 SetAtOrBeforeFirst(Mask256 mask) { + const Full256 d; + constexpr size_t kLanesPerBlock = MaxLanes(d) / 2; + + const auto vmask = VecFromMask(d, mask); + const auto vmask_lo = ConcatLowerLower(d, vmask, Zero(d)); + return SetBeforeFirst( + MaskFromVec(CombineShiftRightBytes<(kLanesPerBlock - 1) * sizeof(T)>( + d, vmask, vmask_lo))); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// ------------------------------ Reductions in generic_ops + +// ------------------------------ BitShuffle +#if HWY_TARGET <= HWY_AVX3_DL +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_V(V, 32), HWY_IF_V_SIZE_V(VI, 32)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const
RebindToUnsigned du64; + const Rebind du8; + + int32_t i32_bit_shuf_result = + static_cast(_mm256_bitshuffle_epi64_mask(v.raw, idx.raw)); + + return BitCast(d64, PromoteTo(du64, VFromD{_mm_cvtsi32_si128( + i32_bit_shuf_result)})); +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ MultiRotateRight + +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_MULTIROTATERIGHT +#undef HWY_NATIVE_MULTIROTATERIGHT +#else +#define HWY_NATIVE_MULTIROTATERIGHT +#endif + +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_V(V, 32), HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> +HWY_API V MultiRotateRight(V v, VI idx) { + return V{_mm256_multishift_epi64_epi8(idx.raw, v.raw)}; +} + +#endif + +// ------------------------------ LeadingZeroCount + +#if HWY_TARGET <= HWY_AVX3 +template ), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V LeadingZeroCount(V v) { + return V{_mm256_lzcnt_epi32(v.raw)}; +} + +template ), HWY_IF_V_SIZE_V(V, 32)> +HWY_API V LeadingZeroCount(V v) { + return V{_mm256_lzcnt_epi64(v.raw)}; +} + +namespace detail { + +// Lzcnt32ForU8OrU16OrU32 overloads, in order: promote to u32 and count there, +// forward to LeadingZeroCount directly, or recurse on the promoted lower and +// upper halves and demote the two results back. + +template , HWY_MAX_BYTES / 4)> +static HWY_INLINE HWY_MAYBE_UNUSED V Lzcnt32ForU8OrU16OrU32(V v) { + const DFromV d; + const Rebind di32; + const Rebind du32; + + const auto v_lz_count = LeadingZeroCount(PromoteTo(du32, v)); + return DemoteTo(d, BitCast(di32, v_lz_count)); +} + +template +static HWY_INLINE HWY_MAYBE_UNUSED V Lzcnt32ForU8OrU16OrU32(V v) { + return LeadingZeroCount(v); +} + +template , HWY_MAX_BYTES / 4)> +static HWY_INLINE HWY_MAYBE_UNUSED V Lzcnt32ForU8OrU16OrU32(V v) { + const DFromV d; + const RepartitionToWide dw; + const RebindToSigned dw_i; + + const auto lo_v_lz_count = Lzcnt32ForU8OrU16OrU32(PromoteLowerTo(dw, v)); + const auto hi_v_lz_count = Lzcnt32ForU8OrU16OrU32(PromoteUpperTo(dw, v)); + return OrderedDemote2To(d, BitCast(dw_i, lo_v_lz_count), + BitCast(dw_i, hi_v_lz_count)); +} + +} // namespace detail + +template +HWY_API V LeadingZeroCount(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU =
TFromD; + + constexpr TU kNumOfBitsInT{sizeof(TU) * 8}; + const auto v_lzcnt32 = detail::Lzcnt32ForU8OrU16OrU32(BitCast(du, v)); + return BitCast(d, Min(v_lzcnt32 - Set(du, TU{32 - kNumOfBitsInT}), + Set(du, TU{kNumOfBitsInT}))); +} + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + const RebindToUnsigned du; + using TU = TFromD; + return BitCast( + d, Set(du, TU{31}) - detail::Lzcnt32ForU8OrU16OrU32(BitCast(du, v))); +} + +template +HWY_API V HighestSetBitIndex(V v) { + const DFromV d; + using T = TFromD; + return BitCast(d, Set(d, T{sizeof(T) * 8 - 1}) - LeadingZeroCount(v)); +} + +template +HWY_API V TrailingZeroCount(V v) { + const DFromV d; + const RebindToSigned di; + using T = TFromD; + + const auto vi = BitCast(di, v); + const auto lowest_bit = BitCast(d, And(vi, Neg(vi))); + constexpr T kNumOfBitsInT{sizeof(T) * 8}; + const auto bit_idx = HighestSetBitIndex(lowest_bit); + return IfThenElse(MaskFromVec(bit_idx), Set(d, kNumOfBitsInT), bit_idx); +} +#endif // HWY_TARGET <= HWY_AVX3 + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/lib/highway/hwy/ops/x86_512-inl.h b/lib/highway/hwy/ops/x86_512-inl.h new file mode 100644 index 00000000000..433cc765ef0 --- /dev/null +++ b/lib/highway/hwy/ops/x86_512-inl.h @@ -0,0 +1,7718 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// 512-bit AVX512 vectors and operations. +// External include guard in highway.h - see comment there. + +// WARNING: most operations do not cross 128-bit block boundaries. In +// particular, "Broadcast", pack and zip behavior may be surprising. + +// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL +#include "hwy/base.h" + +// Avoid uninitialized warnings in GCC's avx512fintrin.h - see +// https://github.com/google/highway/issues/710) +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494, + ignored "-Wmaybe-uninitialized") +#endif + +// NOTE(review): the header names of the angle-bracket #include directives +// below are missing from this copy of the patch; restore them from upstream +// before applying. +#include // AVX2+ + +#if HWY_COMPILER_CLANGCL +// Including should be enough, but Clang's headers helpfully skip +// including these headers when _MSC_VER is defined, like when using clang-cl. +// Include these directly here. +// clang-format off +#include + +#include +// avxintrin defines __m256i and must come before avx2intrin. +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#if HWY_TARGET <= HWY_AVX3_DL +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// Must come after avx512fintrin, else will not define 512-bit intrinsics. +#include +#include +#include +#endif // HWY_TARGET <= HWY_AVX3_DL + +#if HWY_TARGET <= HWY_AVX3_SPR +#include +#include +#endif // HWY_TARGET <= HWY_AVX3_SPR + +// clang-format on +#endif // HWY_COMPILER_CLANGCL + +// For half-width vectors.
Already includes base.h and shared-inl.h. +#include "hwy/ops/x86_256-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +namespace detail { + +template +struct Raw512 { + using type = __m512i; +}; +#if HWY_HAVE_FLOAT16 +template <> +struct Raw512 { + using type = __m512h; +}; +#endif // HWY_HAVE_FLOAT16 +template <> +struct Raw512 { + using type = __m512; +}; +template <> +struct Raw512 { + using type = __m512d; +}; + +// Template arg: sizeof(lane type) +template +struct RawMask512 {}; +template <> +struct RawMask512<1> { + using type = __mmask64; +}; +template <> +struct RawMask512<2> { + using type = __mmask32; +}; +template <> +struct RawMask512<4> { + using type = __mmask16; +}; +template <> +struct RawMask512<8> { + using type = __mmask8; +}; + +} // namespace detail + +// Vec512 holds one raw register (type chosen via detail::Raw512) covering +// 64 / sizeof(T) lanes, plus compound-assignment operators that forward to +// the corresponding binary operators. +// NOTE(review): template parameter lists and angle-bracket arguments appear +// to be missing from this copy of the patch; compare with upstream. +template +class Vec512 { + using Raw = typename detail::Raw512::type; + + public: + using PrivateT = T; // only for DFromV + static constexpr size_t kPrivateN = 64 / sizeof(T); // only for DFromV + + // Compound assignment. Only usable if there is a corresponding non-member + // binary operator overload. For example, only f32 and f64 support division. + HWY_INLINE Vec512& operator*=(const Vec512 other) { + return *this = (*this * other); + } + HWY_INLINE Vec512& operator/=(const Vec512 other) { + return *this = (*this / other); + } + HWY_INLINE Vec512& operator+=(const Vec512 other) { + return *this = (*this + other); + } + HWY_INLINE Vec512& operator-=(const Vec512 other) { + return *this = (*this - other); + } + HWY_INLINE Vec512& operator%=(const Vec512 other) { + return *this = (*this % other); + } + HWY_INLINE Vec512& operator&=(const Vec512 other) { + return *this = (*this & other); + } + HWY_INLINE Vec512& operator|=(const Vec512 other) { + return *this = (*this | other); + } + HWY_INLINE Vec512& operator^=(const Vec512 other) { + return *this = (*this ^ other); + } + + Raw raw; +}; + +// Mask register: one bit per lane.
+template +struct Mask512 { + using Raw = typename detail::RawMask512::type; + + using PrivateT = T; // only for DFromM + static constexpr size_t kPrivateN = 64 / sizeof(T); // only for DFromM + + Raw raw; +}; + +template +using Full512 = Simd; + +// ------------------------------ BitCast + +namespace detail { + +// BitCastToInteger overloads map every raw 512-bit register type to __m512i. +// NOTE(review): template parameter lists and angle-bracket arguments appear +// to be missing from this copy of the patch; compare with upstream. + +HWY_INLINE __m512i BitCastToInteger(__m512i v) { return v; } +#if HWY_HAVE_FLOAT16 +HWY_INLINE __m512i BitCastToInteger(__m512h v) { + return _mm512_castph_si512(v); +} +#endif // HWY_HAVE_FLOAT16 +HWY_INLINE __m512i BitCastToInteger(__m512 v) { return _mm512_castps_si512(v); } +HWY_INLINE __m512i BitCastToInteger(__m512d v) { + return _mm512_castpd_si512(v); +} + +#if HWY_AVX3_HAVE_F32_TO_BF16C +HWY_INLINE __m512i BitCastToInteger(__m512bh v) { + // Need to use reinterpret_cast on GCC/Clang or BitCastScalar on MSVC to + // bit cast a __m512bh to a __m512i as there is currently no intrinsic + // available (as of GCC 13 and Clang 17) that can bit cast a __m512bh vector + // to a __m512i vector + +#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG + // On GCC or Clang, use reinterpret_cast to bit cast a __m512bh to a __m512i + return reinterpret_cast<__m512i>(v); +#else + // On MSVC, use BitCastScalar to bit cast a __m512bh to a __m512i as MSVC does + // not allow reinterpret_cast, static_cast, or a C-style cast to be used to + // bit cast from one AVX vector type to a different AVX vector type + return BitCastScalar<__m512i>(v); +#endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +template +HWY_INLINE Vec512 BitCastToByte(Vec512 v) { + return Vec512{BitCastToInteger(v.raw)}; +} + +// Cannot rely on function overloading because return types differ.
+template +struct BitCastFromInteger512 { + HWY_INLINE __m512i operator()(__m512i v) { return v; } +}; +#if HWY_HAVE_FLOAT16 +template <> +struct BitCastFromInteger512 { + HWY_INLINE __m512h operator()(__m512i v) { return _mm512_castsi512_ph(v); } +}; +#endif // HWY_HAVE_FLOAT16 +template <> +struct BitCastFromInteger512 { + HWY_INLINE __m512 operator()(__m512i v) { return _mm512_castsi512_ps(v); } +}; +template <> +struct BitCastFromInteger512 { + HWY_INLINE __m512d operator()(__m512i v) { return _mm512_castsi512_pd(v); } +}; + +template +HWY_INLINE VFromD BitCastFromByte(D /* tag */, Vec512 v) { + return VFromD{BitCastFromInteger512>()(v.raw)}; +} + +} // namespace detail + +template +HWY_API VFromD BitCast(D d, Vec512 v) { + return detail::BitCastFromByte(d, detail::BitCastToByte(v)); +} + +// ------------------------------ Set + +// Each overload passes t to the matching _mm512_set1_* intrinsic; the integer +// overloads first cast t with static_cast. +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm512_set1_epi8(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm512_set1_epi16(static_cast(t))}; // NOLINT +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm512_set1_epi32(static_cast(t))}; +} +template +HWY_API VFromD Set(D /* tag */, TFromD t) { + return VFromD{_mm512_set1_epi64(static_cast(t))}; // NOLINT +} +// bfloat16_t is handled by x86_128-inl.h. +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 Set(D /* tag */, float16_t t) { + return Vec512{_mm512_set1_ph(t)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec512 Set(D /* tag */, float t) { + return Vec512{_mm512_set1_ps(t)}; +} +template +HWY_API Vec512 Set(D /* tag */, double t) { + return Vec512{_mm512_set1_pd(t)}; +} + +// ------------------------------ Zero (Set) + +// GCC pre-9.1 lacked setzero, so use Set instead. +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + +// Cannot use VFromD here because it is defined in terms of Zero.
+template +HWY_API Vec512> Zero(D d) { + return Set(d, TFromD{0}); +} +// BitCast is defined below, but the Raw type is the same, so use that. +template +HWY_API Vec512 Zero(D /* tag */) { + const RebindToUnsigned du; + return Vec512{Set(du, 0).raw}; +} +template +HWY_API Vec512 Zero(D /* tag */) { + const RebindToUnsigned du; + return Vec512{Set(du, 0).raw}; +} + +#else + +// Taken when the old-GCC check above is false: use _mm512_setzero_*. +// NOTE(review): template parameter lists and angle-bracket arguments appear +// to be missing from this copy of the patch; compare with upstream. +template +HWY_API Vec512> Zero(D /* tag */) { + return Vec512>{_mm512_setzero_si512()}; +} +template +HWY_API Vec512 Zero(D /* tag */) { + return Vec512{_mm512_setzero_si512()}; +} +template +HWY_API Vec512 Zero(D /* tag */) { +#if HWY_HAVE_FLOAT16 + return Vec512{_mm512_setzero_ph()}; +#else + return Vec512{_mm512_setzero_si512()}; +#endif +} +template +HWY_API Vec512 Zero(D /* tag */) { + return Vec512{_mm512_setzero_ps()}; +} +template +HWY_API Vec512 Zero(D /* tag */) { + return Vec512{_mm512_setzero_pd()}; +} + +#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + +// ------------------------------ Undefined + +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") + +// Returns a vector with uninitialized elements. +template +HWY_API Vec512> Undefined(D /* tag */) { + // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC + // generate an XOR instruction.
+ return Vec512>{_mm512_undefined_epi32()}; +} +template +HWY_API Vec512 Undefined(D /* tag */) { + return Vec512{_mm512_undefined_epi32()}; +} +template +HWY_API Vec512 Undefined(D /* tag */) { +#if HWY_HAVE_FLOAT16 + return Vec512{_mm512_undefined_ph()}; +#else + return Vec512{_mm512_undefined_epi32()}; +#endif +} +template +HWY_API Vec512 Undefined(D /* tag */) { + return Vec512{_mm512_undefined_ps()}; +} +template +HWY_API Vec512 Undefined(D /* tag */) { + return Vec512{_mm512_undefined_pd()}; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ ResizeBitCast + +// 64-byte vector to 16-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, Vec128{_mm512_castsi512_si128( + BitCast(Full512(), v).raw)}); +} + +// <= 16-byte vector to 64-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, Vec512{_mm512_castsi128_si512( + ResizeBitCast(Full128(), v).raw)}); +} + +// 32-byte vector to 64-byte vector +template +HWY_API VFromD ResizeBitCast(D d, FromV v) { + return BitCast(d, Vec512{_mm512_castsi256_si512( + BitCast(Full256(), v).raw)}); +} + +// ------------------------------ Dup128VecFromValues + +// Each overload passes its argument list to the set intrinsic four times in a +// row (or, on old GCC, builds a 128-bit vector and calls BroadcastBlock<0>). +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, TFromD t7, + TFromD t8, TFromD t9, TFromD t10, + TFromD t11, TFromD t12, + TFromD t13, TFromD t14, + TFromD t15) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + // Missing set_epi8/16.
+ return BroadcastBlock<0>(ResizeBitCast( + d, Dup128VecFromValues(Full128>(), t0, t1, t2, t3, t4, t5, t6, + t7, t8, t9, t10, t11, t12, t13, t14, t15))); +#else + (void)d; + // Need to use _mm512_set_epi8 as there is no _mm512_setr_epi8 intrinsic + // available + return VFromD{_mm512_set_epi8( + static_cast(t15), static_cast(t14), static_cast(t13), + static_cast(t12), static_cast(t11), static_cast(t10), + static_cast(t9), static_cast(t8), static_cast(t7), + static_cast(t6), static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), static_cast(t1), + static_cast(t0), static_cast(t15), static_cast(t14), + static_cast(t13), static_cast(t12), static_cast(t11), + static_cast(t10), static_cast(t9), static_cast(t8), + static_cast(t7), static_cast(t6), static_cast(t5), + static_cast(t4), static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0), static_cast(t15), + static_cast(t14), static_cast(t13), static_cast(t12), + static_cast(t11), static_cast(t10), static_cast(t9), + static_cast(t8), static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), static_cast(t3), + static_cast(t2), static_cast(t1), static_cast(t0), + static_cast(t15), static_cast(t14), static_cast(t13), + static_cast(t12), static_cast(t11), static_cast(t10), + static_cast(t9), static_cast(t8), static_cast(t7), + static_cast(t6), static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), static_cast(t1), + static_cast(t0))}; +#endif +} + +template +HWY_API VFromD Dup128VecFromValues(D d, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + // Missing set_epi8/16.
+ return BroadcastBlock<0>( + ResizeBitCast(d, Dup128VecFromValues(Full128>(), t0, t1, t2, t3, + t4, t5, t6, t7))); +#else + (void)d; + // Need to use _mm512_set_epi16 as there is no _mm512_setr_epi16 intrinsic + // available + return VFromD{ + _mm512_set_epi16(static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0), + static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0), + static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0), + static_cast(t7), static_cast(t6), + static_cast(t5), static_cast(t4), + static_cast(t3), static_cast(t2), + static_cast(t1), static_cast(t0))}; +#endif +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3, TFromD t4, + TFromD t5, TFromD t6, + TFromD t7) { + return VFromD{_mm512_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, t2, + t3, t4, t5, t6, t7, t0, t1, t2, t3, t4, t5, + t6, t7, t0, t1, t2, t3, t4, t5, t6, t7)}; +} +#endif + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{ + _mm512_setr_epi32(static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3), + static_cast(t0), static_cast(t1), + static_cast(t2), static_cast(t3))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, + TFromD t2, TFromD t3) { + return VFromD{_mm512_setr_ps(t0, t1, t2, t3, t0, t1, t2, t3, t0, t1, t2, + t3, t0, t1, t2, t3)}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{ + _mm512_setr_epi64(static_cast(t0), static_cast(t1), +
static_cast(t0), static_cast(t1), + static_cast(t0), static_cast(t1), + static_cast(t0), static_cast(t1))}; +} + +template +HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { + return VFromD{_mm512_setr_pd(t0, t1, t0, t1, t0, t1, t0, t1)}; +} + +// ----------------------------- Iota + +namespace detail { + +// The Iota0 overloads build a vector from the constants 0 .. lanes-1. The +// _mm512_set_* calls list them in descending order, while the old-GCC path +// loads an ascending kIota array instead. + +template +HWY_INLINE VFromD Iota0(D d) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + // Missing set_epi8/16. + alignas(64) static constexpr TFromD kIota[64] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}; + return Load(d, kIota); +#else + (void)d; + return VFromD{_mm512_set_epi8( + static_cast(63), static_cast(62), static_cast(61), + static_cast(60), static_cast(59), static_cast(58), + static_cast(57), static_cast(56), static_cast(55), + static_cast(54), static_cast(53), static_cast(52), + static_cast(51), static_cast(50), static_cast(49), + static_cast(48), static_cast(47), static_cast(46), + static_cast(45), static_cast(44), static_cast(43), + static_cast(42), static_cast(41), static_cast(40), + static_cast(39), static_cast(38), static_cast(37), + static_cast(36), static_cast(35), static_cast(34), + static_cast(33), static_cast(32), static_cast(31), + static_cast(30), static_cast(29), static_cast(28), + static_cast(27), static_cast(26), static_cast(25), + static_cast(24), static_cast(23), static_cast(22), + static_cast(21), static_cast(20), static_cast(19), + static_cast(18), static_cast(17), static_cast(16), + static_cast(15), static_cast(14), static_cast(13), + static_cast(12), static_cast(11), static_cast(10), + static_cast(9), static_cast(8), static_cast(7), + static_cast(6), static_cast(5), static_cast(4), + static_cast(3), static_cast(2), static_cast(1), + static_cast(0))}; +#endif +} + +template +HWY_INLINE
VFromD Iota0(D d) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900 + // Missing set_epi8/16. + alignas(64) static constexpr TFromD kIota[32] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; + return Load(d, kIota); +#else + (void)d; + return VFromD{_mm512_set_epi16( + int16_t{31}, int16_t{30}, int16_t{29}, int16_t{28}, int16_t{27}, + int16_t{26}, int16_t{25}, int16_t{24}, int16_t{23}, int16_t{22}, + int16_t{21}, int16_t{20}, int16_t{19}, int16_t{18}, int16_t{17}, + int16_t{16}, int16_t{15}, int16_t{14}, int16_t{13}, int16_t{12}, + int16_t{11}, int16_t{10}, int16_t{9}, int16_t{8}, int16_t{7}, int16_t{6}, + int16_t{5}, int16_t{4}, int16_t{3}, int16_t{2}, int16_t{1}, int16_t{0})}; +#endif +} + +#if HWY_HAVE_FLOAT16 +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm512_set_ph( + float16_t{31}, float16_t{30}, float16_t{29}, float16_t{28}, float16_t{27}, + float16_t{26}, float16_t{25}, float16_t{24}, float16_t{23}, float16_t{22}, + float16_t{21}, float16_t{20}, float16_t{19}, float16_t{18}, float16_t{17}, + float16_t{16}, float16_t{15}, float16_t{14}, float16_t{13}, float16_t{12}, + float16_t{11}, float16_t{10}, float16_t{9}, float16_t{8}, float16_t{7}, + float16_t{6}, float16_t{5}, float16_t{4}, float16_t{3}, float16_t{2}, + float16_t{1}, float16_t{0})}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm512_set_epi32( + int32_t{15}, int32_t{14}, int32_t{13}, int32_t{12}, int32_t{11}, + int32_t{10}, int32_t{9}, int32_t{8}, int32_t{7}, int32_t{6}, int32_t{5}, + int32_t{4}, int32_t{3}, int32_t{2}, int32_t{1}, int32_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm512_set_epi64(int64_t{7}, int64_t{6}, int64_t{5}, + int64_t{4}, int64_t{3}, int64_t{2}, + int64_t{1}, int64_t{0})}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm512_set_ps(15.0f, 14.0f, 13.0f, 12.0f, 11.0f,
10.0f, 9.0f, + 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, + 0.0f)}; +} + +template +HWY_INLINE VFromD Iota0(D /*d*/) { + return VFromD{_mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0)}; +} + +} // namespace detail + +template +HWY_API VFromD Iota(D d, const T2 first) { + return detail::Iota0(d) + Set(d, ConvertScalarTo>(first)); +} + +// ================================================== LOGICAL + +#if HWY_X86_HAVE_TERNARY_LOGIC +namespace detail { + +// Per-target partial specialization. +template +struct TernaryLogicImpl { + template + HWY_INLINE V operator()(V a, V b, V c) const { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const __m512i ret = _mm512_ternarylogic_epi64( + BitCast(du, a).raw, BitCast(du, b).raw, BitCast(du, c).raw, kTernLogOp); + return BitCast(d, VU{ret}); + } +}; + +// Same, but with writemask. If !mask, returns a. +template +struct MaskedTernaryLogicImpl { + template , HWY_IF_T_SIZE_D(D, 4)> + HWY_INLINE V operator()(MFromD mask, V a, V b, V c) const { + const D d; + const RebindToUnsigned du; + using VU = VFromD; + const __m512i ret = _mm512_mask_ternarylogic_epi32(a.raw, mask.raw, b.raw, + c.raw, kTernLogOp); + return BitCast(d, VU{ret}); + } + template , HWY_IF_T_SIZE_D(D, 8)> + HWY_INLINE V operator()(MFromD mask, V a, V b, V c) const { + const D d; + const RebindToUnsigned du; + using VU = VFromD; + const __m512i ret = _mm512_mask_ternarylogic_epi64(a.raw, mask.raw, b.raw, + c.raw, kTernLogOp); + return BitCast(d, VU{ret}); + } +}; + +} // namespace detail +#endif // HWY_X86_HAVE_TERNARY_LOGIC + +// ------------------------------ And + +template +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm512_and_si512(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +HWY_API Vec512 And(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_and_ps(a.raw, b.raw)}; +} +HWY_API Vec512 And(const Vec512 a, const Vec512 
b) { + return Vec512{_mm512_and_pd(a.raw, b.raw)}; +} + +// ------------------------------ AndNot + +// Returns ~not_mask & mask. +template +HWY_API Vec512 AndNot(const Vec512 not_mask, const Vec512 mask) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm512_andnot_si512( + BitCast(du, not_mask).raw, BitCast(du, mask).raw)}); +} +HWY_API Vec512 AndNot(const Vec512 not_mask, + const Vec512 mask) { + return Vec512{_mm512_andnot_ps(not_mask.raw, mask.raw)}; +} +HWY_API Vec512 AndNot(const Vec512 not_mask, + const Vec512 mask) { + return Vec512{_mm512_andnot_pd(not_mask.raw, mask.raw)}; +} + +// ------------------------------ Or + +template +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm512_or_si512(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_or_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_or_pd(a.raw, b.raw)}; +} + +// ------------------------------ Xor + +template +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + const DFromV d; // for float16_t + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm512_xor_si512(BitCast(du, a).raw, + BitCast(du, b).raw)}); +} + +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_xor_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { + return Vec512{_mm512_xor_pd(a.raw, b.raw)}; +} + +// ------------------------------ Operator overloads (internal-only if float) + +template +HWY_API Vec512 operator&(const Vec512 a, const Vec512 b) { + return And(a, b); +} + +template +HWY_API Vec512 operator|(const Vec512 a, const Vec512 b) { + return Or(a, b); +} + +template +HWY_API Vec512 operator^(const Vec512 a, const Vec512 b) { + return Xor(a, b); +} + +// ------------------------------ PopulationCount + +// 
8/16 require BITALG, 32/64 require VPOPCNTDQ. +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_POPCNT +#undef HWY_NATIVE_POPCNT +#else +#define HWY_NATIVE_POPCNT +#endif + +namespace detail { + +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<1> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi8(v.raw)}; +} +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<2> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi16(v.raw)}; +} +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<4> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi32(v.raw)}; +} +template +HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<8> /* tag */, Vec512 v) { + return Vec512{_mm512_popcnt_epi64(v.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 PopulationCount(Vec512 v) { + return detail::PopulationCount(hwy::SizeTag(), v); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ================================================== MASK + +// ------------------------------ FirstN + +// Possibilities for constructing a bitmask of N ones: +// - kshift* only consider the lowest byte of the shift count, so they would +// not correctly handle large n. +// - Scalar shifts >= 64 are UB. +// - BZHI has the desired semantics; we assume AVX-512 implies BMI2. However, +// we need 64-bit masks for sizeof(T) == 1, so special-case 32-bit builds. + +#if HWY_ARCH_X86_32 +namespace detail { + +// 32 bit mask is sufficient for lane size >= 2. +template +HWY_INLINE Mask512 FirstN(size_t n) { + Mask512 m; + const uint32_t all = ~uint32_t{0}; + // BZHI only looks at the lower 8 bits of n, but it has been clamped to + // MaxLanes, which is at most 32. 
+ m.raw = static_cast(_bzhi_u32(all, n)); + return m; +} + +#if HWY_COMPILER_MSVC >= 1920 || HWY_COMPILER_GCC_ACTUAL >= 900 || \ + HWY_COMPILER_CLANG || HWY_COMPILER_ICC +template +HWY_INLINE Mask512 FirstN(size_t n) { + uint32_t lo_mask; + uint32_t hi_mask; + uint32_t hi_mask_len; +#if HWY_COMPILER_GCC + if (__builtin_constant_p(n >= 32) && n >= 32) { + if (__builtin_constant_p(n >= 64) && n >= 64) { + hi_mask_len = 32u; + } else { + hi_mask_len = static_cast(n) - 32u; + } + lo_mask = hi_mask = 0xFFFFFFFFu; + } else // NOLINT(readability/braces) +#endif + { + const uint32_t lo_mask_len = static_cast(n); + lo_mask = _bzhi_u32(0xFFFFFFFFu, lo_mask_len); + +#if HWY_COMPILER_GCC + if (__builtin_constant_p(lo_mask_len <= 32) && lo_mask_len <= 32) { + return Mask512{static_cast<__mmask64>(lo_mask)}; + } +#endif + + _addcarry_u32(_subborrow_u32(0, lo_mask_len, 32u, &hi_mask_len), + 0xFFFFFFFFu, 0u, &hi_mask); + } + hi_mask = _bzhi_u32(hi_mask, hi_mask_len); +#if HWY_COMPILER_GCC && !HWY_COMPILER_ICC + if (__builtin_constant_p((static_cast(hi_mask) << 32) | lo_mask)) +#endif + return Mask512{static_cast<__mmask64>( + (static_cast(hi_mask) << 32) | lo_mask)}; +#if HWY_COMPILER_GCC && !HWY_COMPILER_ICC + else + return Mask512{_mm512_kunpackd(static_cast<__mmask64>(hi_mask), + static_cast<__mmask64>(lo_mask))}; +#endif +} +#else // HWY_COMPILER.. +template +HWY_INLINE Mask512 FirstN(size_t n) { + const uint64_t bits = n < 64 ? ((1ULL << n) - 1) : ~uint64_t{0}; + return Mask512{static_cast<__mmask64>(bits)}; +} +#endif // HWY_COMPILER.. +} // namespace detail +#endif // HWY_ARCH_X86_32 + +template +HWY_API MFromD FirstN(D d, size_t n) { + // This ensures `num` <= 255 as required by bzhi, which only looks + // at the lower 8 bits. 
+ n = HWY_MIN(n, MaxLanes(d)); + +#if HWY_ARCH_X86_64 + MFromD m; + const uint64_t all = ~uint64_t{0}; + m.raw = static_cast(_bzhi_u64(all, n)); + return m; +#else + return detail::FirstN>(n); +#endif // HWY_ARCH_X86_64 +} + +// ------------------------------ IfThenElse + +// Returns mask ? b : a. + +namespace detail { + +// Templates for signed/unsigned integer of a particular size. +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<1> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_blend_epi8(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<2> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_blend_epi16(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<4> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_blend_epi32(mask.raw, no.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<8> /* tag */, + const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return Vec512{_mm512_mask_blend_epi64(mask.raw, no.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 IfThenElse(const Mask512 mask, const Vec512 yes, + const Vec512 no) { + return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); +} +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 IfThenElse(Mask512 mask, + Vec512 yes, + Vec512 no) { + return Vec512{_mm512_mask_blend_ph(mask.raw, no.raw, yes.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 IfThenElse(Mask512 mask, Vec512 yes, + Vec512 no) { + return Vec512{_mm512_mask_blend_ps(mask.raw, no.raw, yes.raw)}; +} +HWY_API Vec512 IfThenElse(Mask512 mask, Vec512 yes, + Vec512 no) { + return Vec512{_mm512_mask_blend_pd(mask.raw, no.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<1> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return 
Vec512{_mm512_maskz_mov_epi8(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<2> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi16(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<4> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi32(mask.raw, yes.raw)}; +} +template +HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<8> /* tag */, + const Mask512 mask, + const Vec512 yes) { + return Vec512{_mm512_maskz_mov_epi64(mask.raw, yes.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 IfThenElseZero(const Mask512 mask, const Vec512 yes) { + return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); +} +HWY_API Vec512 IfThenElseZero(Mask512 mask, Vec512 yes) { + return Vec512{_mm512_maskz_mov_ps(mask.raw, yes.raw)}; +} +HWY_API Vec512 IfThenElseZero(Mask512 mask, + Vec512 yes) { + return Vec512{_mm512_maskz_mov_pd(mask.raw, yes.raw)}; +} + +namespace detail { + +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<1> /* tag */, + const Mask512 mask, const Vec512 no) { + // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. 
+ return Vec512{_mm512_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<2> /* tag */, + const Mask512 mask, const Vec512 no) { + return Vec512{_mm512_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<4> /* tag */, + const Mask512 mask, const Vec512 no) { + return Vec512{_mm512_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; +} +template +HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<8> /* tag */, + const Mask512 mask, const Vec512 no) { + return Vec512{_mm512_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; +} + +} // namespace detail + +template +HWY_API Vec512 IfThenZeroElse(const Mask512 mask, const Vec512 no) { + return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); +} +HWY_API Vec512 IfThenZeroElse(Mask512 mask, Vec512 no) { + return Vec512{_mm512_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; +} +HWY_API Vec512 IfThenZeroElse(Mask512 mask, Vec512 no) { + return Vec512{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; +} + +template +HWY_API Vec512 IfNegativeThenElse(Vec512 v, Vec512 yes, Vec512 no) { + static_assert(IsSigned(), "Only works for signed/float"); + // AVX3 MaskFromVec only looks at the MSB + return IfThenElse(MaskFromVec(v), yes, no); +} + +template +HWY_API Vec512 IfNegativeThenNegOrUndefIfZero(Vec512 mask, Vec512 v) { + // AVX3 MaskFromVec only looks at the MSB + const DFromV d; + return MaskedSubOr(v, MaskFromVec(mask), Zero(d), v); +} + +// ================================================== ARITHMETIC + +// ------------------------------ Addition + +// Unsigned +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + 
return Vec512{_mm512_add_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator+(Vec512 a, Vec512 b) { + return Vec512{_mm512_add_pd(a.raw, b.raw)}; +} + +// ------------------------------ Subtraction + +// Unsigned +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + return Vec512{_mm512_sub_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator-(Vec512 a, Vec512 b) { + 
return Vec512{_mm512_sub_pd(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf8 +HWY_API Vec512 SumsOf8(const Vec512 v) { + const Full512 d; + return Vec512{_mm512_sad_epu8(v.raw, Zero(d).raw)}; +} + +HWY_API Vec512 SumsOf8AbsDiff(Vec512 a, Vec512 b) { + return Vec512{_mm512_sad_epu8(a.raw, b.raw)}; +} + +// ------------------------------ SumsOf4 +namespace detail { + +HWY_INLINE Vec512 SumsOf4(hwy::UnsignedTag /*type_tag*/, + hwy::SizeTag<1> /*lane_size_tag*/, + Vec512 v) { + const DFromV d; + + // _mm512_maskz_dbsad_epu8 is used below as the odd uint16_t lanes need to be + // zeroed out and the sums of the 4 consecutive lanes are already in the + // even uint16_t lanes of the _mm512_maskz_dbsad_epu8 result. + return Vec512{_mm512_maskz_dbsad_epu8( + static_cast<__mmask32>(0x55555555), v.raw, Zero(d).raw, 0)}; +} + +// I8->I32 SumsOf4 +// Generic for all vector lengths +template +HWY_INLINE VFromD>> SumsOf4( + hwy::SignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, V v) { + const DFromV d; + const RebindToUnsigned du; + const RepartitionToWideX2 di32; + + // Adjust the values of v to be in the 0..255 range by adding 128 to each lane + // of v (which is the same as an bitwise XOR of each i8 lane by 128) and then + // bitcasting the Xor result to an u8 vector. + const auto v_adj = BitCast(du, Xor(v, SignBit(d))); + + // Need to add -512 to each i32 lane of the result of the + // SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v_adj) operation to account + // for the adjustment made above. 
+ return BitCast(di32, SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v_adj)) + + Set(di32, int32_t{-512}); +} + +} // namespace detail + +// ------------------------------ SumsOfShuffledQuadAbsDiff + +#if HWY_TARGET <= HWY_AVX3 +template +static Vec512 SumsOfShuffledQuadAbsDiff(Vec512 a, + Vec512 b) { + static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3"); + static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3"); + static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3"); + static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3"); + return Vec512{ + _mm512_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))}; +} +#endif + +// ------------------------------ SaturatedAdd + +// Returns a + b clamped to the destination range. + +// Unsigned +HWY_API Vec512 SaturatedAdd(Vec512 a, Vec512 b) { + return Vec512{_mm512_adds_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedAdd(Vec512 a, Vec512 b) { + return Vec512{_mm512_adds_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 SaturatedAdd(Vec512 a, Vec512 b) { + return Vec512{_mm512_adds_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedAdd(Vec512 a, Vec512 b) { + return Vec512{_mm512_adds_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SaturatedSub + +// Returns a - b clamped to the destination range. 
+ +// Unsigned +HWY_API Vec512 SaturatedSub(Vec512 a, Vec512 b) { + return Vec512{_mm512_subs_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedSub(Vec512 a, Vec512 b) { + return Vec512{_mm512_subs_epu16(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 SaturatedSub(Vec512 a, Vec512 b) { + return Vec512{_mm512_subs_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 SaturatedSub(Vec512 a, Vec512 b) { + return Vec512{_mm512_subs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ Average + +// Returns (a + b + 1) / 2 + +// Unsigned +HWY_API Vec512 AverageRound(Vec512 a, Vec512 b) { + return Vec512{_mm512_avg_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 AverageRound(Vec512 a, Vec512 b) { + return Vec512{_mm512_avg_epu16(a.raw, b.raw)}; +} + +// ------------------------------ Abs (Sub) + +// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. +HWY_API Vec512 Abs(const Vec512 v) { +#if HWY_COMPILER_MSVC + // Workaround for incorrect codegen? (untested due to internal compiler error) + const DFromV d; + const auto zero = Zero(d); + return Vec512{_mm512_max_epi8(v.raw, (zero - v).raw)}; +#else + return Vec512{_mm512_abs_epi8(v.raw)}; +#endif +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_epi16(v.raw)}; +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_epi32(v.raw)}; +} +HWY_API Vec512 Abs(const Vec512 v) { + return Vec512{_mm512_abs_epi64(v.raw)}; +} + +// ------------------------------ ShiftLeft + +#if HWY_TARGET <= HWY_AVX3_DL +namespace detail { +template +HWY_API Vec512 GaloisAffine(Vec512 v, Vec512 matrix) { + return Vec512{_mm512_gf2p8affine_epi64_epi8(v.raw, matrix.raw, 0)}; +} +} // namespace detail +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return 
Vec512{_mm512_slli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + return Vec512{_mm512_slli_epi64(v.raw, kBits)}; +} + +#if HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec512 ShiftLeft(const Vec512 v) { + const DFromV d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeft(BitCast(d16, v))); + return kBits == 1 + ? (v + v) + : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// ------------------------------ ShiftRight + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srli_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srli_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srli_epi64(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srai_epi16(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srai_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + return Vec512{_mm512_srai_epi64(v.raw, kBits)}; +} + +#if HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + const DFromV d8; + // Use raw instead of BitCast to support N=1. 
+ const Vec512 shifted{ShiftRight(Vec512{v.raw}).raw}; + return shifted & Set(d8, 0xFF >> kBits); +} + +template +HWY_API Vec512 ShiftRight(const Vec512 v) { + const DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); + const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); + return (shifted ^ shifted_sign) - shifted_sign; +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// ------------------------------ RotateRight + +#if HWY_TARGET > HWY_AVX3_DL +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); + if (kBits == 0) return v; + // AVX3 does not support 8-bit. + return Or(ShiftRight(v), ShiftLeft(v)); +} +#endif // HWY_TARGET > HWY_AVX3_DL + +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 16, "Invalid shift count"); + if (kBits == 0) return v; +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_shrdi_epi16(v.raw, v.raw, kBits)}; +#else + // AVX3 does not support 16-bit. 
+ return Or(ShiftRight(v), ShiftLeft(v)); +#endif +} + +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); + if (kBits == 0) return v; + return Vec512{_mm512_ror_epi32(v.raw, kBits)}; +} + +template +HWY_API Vec512 RotateRight(const Vec512 v) { + static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); + if (kBits == 0) return v; + return Vec512{_mm512_ror_epi64(v.raw, kBits)}; +} + +// ------------------------------ Rol/Ror +#if HWY_TARGET <= HWY_AVX3_DL +template +HWY_API Vec512 Ror(Vec512 a, Vec512 b) { + return Vec512{_mm512_shrdv_epi16(a.raw, a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API Vec512 Rol(Vec512 a, Vec512 b) { + return Vec512{_mm512_rolv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec512 Ror(Vec512 a, Vec512 b) { + return Vec512{_mm512_rorv_epi32(a.raw, b.raw)}; +} + +template +HWY_API Vec512 Rol(Vec512 a, Vec512 b) { + return Vec512{_mm512_rolv_epi64(a.raw, b.raw)}; +} + +template +HWY_API Vec512 Ror(Vec512 a, Vec512 b) { + return Vec512{_mm512_rorv_epi64(a.raw, b.raw)}; +} + +// ------------------------------ ShiftLeftSame + +// GCC <14 and Clang <11 do not follow the Intel documentation for AVX-512 +// shift-with-immediate: the counts should all be unsigned int. Despite casting, +// we still see warnings in GCC debug builds, hence disable. +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100 +using Shift16Count = int; +using Shift3264Count = int; +#elif HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400 +// GCC 11.0 requires these, prior versions used a macro+cast and don't care. +using Shift16Count = int; +using Shift3264Count = unsigned int; +#else +// Assume documented behavior. Clang 11, GCC 14 and MSVC 14.28.29910 match this. 
+using Shift16Count = unsigned int; +using Shift3264Count = unsigned int; +#endif + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_slli_epi16(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftLeftSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_slli_epi32(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftLeftSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_slli_epi64(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_slli_epi16(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_slli_epi32(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_slli_epi64(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +template +HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { + const DFromV d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); +} + +// 
------------------------------ ShiftRightSame + +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_srli_epi16(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_srli_epi32(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_srli_epi64(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftRightSame(Vec512 v, const int bits) { + const DFromV d8; + const RepartitionToWide d16; + const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); + return shifted & Set(d8, static_cast(0xFF >> bits)); +} + +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_srai_epi16(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_srai_epi32(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; +} +HWY_API Vec512 ShiftRightSame(const Vec512 v, + const int bits) { +#if HWY_COMPILER_GCC + if (__builtin_constant_p(bits)) { + return Vec512{ + _mm512_srai_epi64(v.raw, static_cast(bits))}; + } +#endif + return Vec512{_mm512_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; +} + +HWY_API Vec512 ShiftRightSame(Vec512 v, const int bits) { + const 
DFromV di; + const RebindToUnsigned du; + const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); + const auto shifted_sign = + BitCast(di, Set(du, static_cast(0x80 >> bits))); + return (shifted ^ shifted_sign) - shifted_sign; +} + +HWY_DIAGNOSTICS(pop) + +// ------------------------------ Minimum + +// Unsigned +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epu16(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epu32(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epu64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Min(Vec512 a, Vec512 b) { + return Vec512{_mm512_min_pd(a.raw, b.raw)}; +} + +// ------------------------------ Maximum + +// Unsigned +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epu8(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epu16(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epu32(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epu64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epi8(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, 
Vec512 b) { + return Vec512{_mm512_max_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_epi64(a.raw, b.raw)}; +} + +// Float +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_ps(a.raw, b.raw)}; +} +HWY_API Vec512 Max(Vec512 a, Vec512 b) { + return Vec512{_mm512_max_pd(a.raw, b.raw)}; +} + +// ------------------------------ MinNumber and MaxNumber + +#if HWY_X86_HAVE_AVX10_2_OPS + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MinNumber(Vec512 a, Vec512 b) { + return Vec512{_mm512_minmax_ph(a.raw, b.raw, 0x14)}; +} +#endif +HWY_API Vec512 MinNumber(Vec512 a, Vec512 b) { + return Vec512{_mm512_minmax_ps(a.raw, b.raw, 0x14)}; +} +HWY_API Vec512 MinNumber(Vec512 a, Vec512 b) { + return Vec512{_mm512_minmax_pd(a.raw, b.raw, 0x14)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MaxNumber(Vec512 a, Vec512 b) { + return Vec512{_mm512_minmax_ph(a.raw, b.raw, 0x15)}; +} +#endif +HWY_API Vec512 MaxNumber(Vec512 a, Vec512 b) { + return Vec512{_mm512_minmax_ps(a.raw, b.raw, 0x15)}; +} +HWY_API Vec512 MaxNumber(Vec512 a, Vec512 b) { + return Vec512{_mm512_minmax_pd(a.raw, b.raw, 0x15)}; +} + +#endif + +// ------------------------------ MinMagnitude and MaxMagnitude + +#if HWY_X86_HAVE_AVX10_2_OPS + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MinMagnitude(Vec512 a, + Vec512 b) { + return Vec512{_mm512_minmax_ph(a.raw, b.raw, 0x16)}; +} +#endif +HWY_API Vec512 MinMagnitude(Vec512 a, Vec512 b) { + return Vec512{_mm512_minmax_ps(a.raw, b.raw, 0x16)}; +} +HWY_API Vec512 MinMagnitude(Vec512 a, Vec512 b) { + return Vec512{_mm512_minmax_pd(a.raw, b.raw, 0x16)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MaxMagnitude(Vec512 a, + Vec512 b) { + return Vec512{_mm512_minmax_ph(a.raw, b.raw, 0x17)}; +} +#endif 
+HWY_API Vec512 MaxMagnitude(Vec512 a, Vec512 b) { + return Vec512{_mm512_minmax_ps(a.raw, b.raw, 0x17)}; +} +HWY_API Vec512 MaxMagnitude(Vec512 a, Vec512 b) { + return Vec512{_mm512_minmax_pd(a.raw, b.raw, 0x17)}; +} + +#endif + +// ------------------------------ Integer multiplication + +// Unsigned +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi64(a.raw, b.raw)}; +} + +// Signed +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi16(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mullo_epi64(a.raw, b.raw)}; +} + +// Returns the upper 16 bits of a * b in each lane. +HWY_API Vec512 MulHigh(Vec512 a, Vec512 b) { + return Vec512{_mm512_mulhi_epu16(a.raw, b.raw)}; +} +HWY_API Vec512 MulHigh(Vec512 a, Vec512 b) { + return Vec512{_mm512_mulhi_epi16(a.raw, b.raw)}; +} + +HWY_API Vec512 MulFixedPoint15(Vec512 a, Vec512 b) { + return Vec512{_mm512_mulhrs_epi16(a.raw, b.raw)}; +} + +// Multiplies even lanes (0, 2 ..) and places the double-wide result into +// even and the upper half into its odd neighbor lane. 
+HWY_API Vec512 MulEven(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_epi32(a.raw, b.raw)}; +} +HWY_API Vec512 MulEven(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_epu32(a.raw, b.raw)}; +} + +// ------------------------------ Neg (Sub) + +template +HWY_API Vec512 Neg(const Vec512 v) { + const DFromV d; + return Xor(v, SignBit(d)); +} + +template +HWY_API Vec512 Neg(const Vec512 v) { + const DFromV d; + return Zero(d) - v; +} + +// ------------------------------ Floating-point mul / div + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator*(Vec512 a, Vec512 b) { + return Vec512{_mm512_mul_pd(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MulByFloorPow2(Vec512 a, + Vec512 b) { + return Vec512{_mm512_scalef_ph(a.raw, b.raw)}; +} +#endif + +HWY_API Vec512 MulByFloorPow2(Vec512 a, Vec512 b) { + return Vec512{_mm512_scalef_ps(a.raw, b.raw)}; +} + +HWY_API Vec512 MulByFloorPow2(Vec512 a, Vec512 b) { + return Vec512{_mm512_scalef_pd(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 operator/(Vec512 a, Vec512 b) { + return Vec512{_mm512_div_ph(a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 operator/(Vec512 a, Vec512 b) { + return Vec512{_mm512_div_ps(a.raw, b.raw)}; +} +HWY_API Vec512 operator/(Vec512 a, Vec512 b) { + return Vec512{_mm512_div_pd(a.raw, b.raw)}; +} + +// Approximate reciprocal +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 ApproximateReciprocal(const Vec512 v) { + return Vec512{_mm512_rcp_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 ApproximateReciprocal(const Vec512 v) { + return Vec512{_mm512_rcp14_ps(v.raw)}; +} + +HWY_API Vec512 ApproximateReciprocal(Vec512 v) { + return Vec512{_mm512_rcp14_pd(v.raw)}; +} + +// ------------------------------ GetExponent + +#if HWY_HAVE_FLOAT16 +template ), 
HWY_IF_V_SIZE_V(V, 64)> +HWY_API V GetExponent(V v) { + return V{_mm512_getexp_ph(v.raw)}; +} +#endif +template ), HWY_IF_V_SIZE_V(V, 64)> +HWY_API V GetExponent(V v) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return V{_mm512_getexp_ps(v.raw)}; + HWY_DIAGNOSTICS(pop) +} +template ), HWY_IF_V_SIZE_V(V, 64)> +HWY_API V GetExponent(V v) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return V{_mm512_getexp_pd(v.raw)}; + HWY_DIAGNOSTICS(pop) +} + +// ------------------------------ MaskedMinOr + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_epi64(no.raw, m.raw, 
a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 MaskedMinOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_min_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMaxOr + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epu8(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epu16(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epu32(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epu64(no.raw, m.raw, a.raw, b.raw)}; +} +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_ps(no.raw, m.raw, a.raw, b.raw)}; +} + 
+template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 MaskedMaxOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_max_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedAddOr + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 MaskedAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_add_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSubOr + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 
b) { + return Vec512{_mm512_mask_sub_epi32(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_epi64(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 MaskedSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_sub_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedMulOr + +HWY_API Vec512 MaskedMulOr(Vec512 no, Mask512 m, + Vec512 a, Vec512 b) { + return Vec512{_mm512_mask_mul_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +HWY_API Vec512 MaskedMulOr(Vec512 no, Mask512 m, + Vec512 a, Vec512 b) { + return Vec512{_mm512_mask_mul_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MaskedMulOr(Vec512 no, + Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_mul_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedDivOr + +HWY_API Vec512 MaskedDivOr(Vec512 no, Mask512 m, + Vec512 a, Vec512 b) { + return Vec512{_mm512_mask_div_ps(no.raw, m.raw, a.raw, b.raw)}; +} + +HWY_API Vec512 MaskedDivOr(Vec512 no, Mask512 m, + Vec512 a, Vec512 b) { + return Vec512{_mm512_mask_div_pd(no.raw, m.raw, a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MaskedDivOr(Vec512 no, + Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_div_ph(no.raw, m.raw, a.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +// ------------------------------ MaskedSatAddOr + +template +HWY_API Vec512 MaskedSatAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return 
Vec512{_mm512_mask_adds_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_adds_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_adds_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatAddOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_adds_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +// ------------------------------ MaskedSatSubOr + +template +HWY_API Vec512 MaskedSatSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_subs_epi8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_subs_epu8(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_subs_epi16(no.raw, m.raw, a.raw, b.raw)}; +} + +template +HWY_API Vec512 MaskedSatSubOr(Vec512 no, Mask512 m, Vec512 a, + Vec512 b) { + return Vec512{_mm512_mask_subs_epu16(no.raw, m.raw, a.raw, b.raw)}; +} + +// ------------------------------ Floating-point multiply-add variants + +#if HWY_HAVE_FLOAT16 + +HWY_API Vec512 MulAdd(Vec512 mul, Vec512 x, + Vec512 add) { + return Vec512{_mm512_fmadd_ph(mul.raw, x.raw, add.raw)}; +} + +HWY_API Vec512 NegMulAdd(Vec512 mul, Vec512 x, + Vec512 add) { + return Vec512{_mm512_fnmadd_ph(mul.raw, x.raw, add.raw)}; +} + +HWY_API Vec512 MulSub(Vec512 mul, Vec512 x, + Vec512 sub) { + return Vec512{_mm512_fmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +HWY_API Vec512 NegMulSub(Vec512 mul, Vec512 x, + Vec512 sub) { + return Vec512{_mm512_fnmsub_ph(mul.raw, x.raw, sub.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +// Returns mul * x + add +HWY_API Vec512 MulAdd(Vec512 mul, Vec512 x, + Vec512 add) { + return 
Vec512{_mm512_fmadd_ps(mul.raw, x.raw, add.raw)}; +} +HWY_API Vec512 MulAdd(Vec512 mul, Vec512 x, + Vec512 add) { + return Vec512{_mm512_fmadd_pd(mul.raw, x.raw, add.raw)}; +} + +// Returns add - mul * x +HWY_API Vec512 NegMulAdd(Vec512 mul, Vec512 x, + Vec512 add) { + return Vec512{_mm512_fnmadd_ps(mul.raw, x.raw, add.raw)}; +} +HWY_API Vec512 NegMulAdd(Vec512 mul, Vec512 x, + Vec512 add) { + return Vec512{_mm512_fnmadd_pd(mul.raw, x.raw, add.raw)}; +} + +// Returns mul * x - sub +HWY_API Vec512 MulSub(Vec512 mul, Vec512 x, + Vec512 sub) { + return Vec512{_mm512_fmsub_ps(mul.raw, x.raw, sub.raw)}; +} +HWY_API Vec512 MulSub(Vec512 mul, Vec512 x, + Vec512 sub) { + return Vec512{_mm512_fmsub_pd(mul.raw, x.raw, sub.raw)}; +} + +// Returns -mul * x - sub +HWY_API Vec512 NegMulSub(Vec512 mul, Vec512 x, + Vec512 sub) { + return Vec512{_mm512_fnmsub_ps(mul.raw, x.raw, sub.raw)}; +} +HWY_API Vec512 NegMulSub(Vec512 mul, Vec512 x, + Vec512 sub) { + return Vec512{_mm512_fnmsub_pd(mul.raw, x.raw, sub.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 MulAddSub(Vec512 mul, Vec512 x, + Vec512 sub_or_add) { + return Vec512{_mm512_fmaddsub_ph(mul.raw, x.raw, sub_or_add.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Vec512 MulAddSub(Vec512 mul, Vec512 x, + Vec512 sub_or_add) { + return Vec512{_mm512_fmaddsub_ps(mul.raw, x.raw, sub_or_add.raw)}; +} + +HWY_API Vec512 MulAddSub(Vec512 mul, Vec512 x, + Vec512 sub_or_add) { + return Vec512{_mm512_fmaddsub_pd(mul.raw, x.raw, sub_or_add.raw)}; +} + +// ------------------------------ Floating-point square root + +// Full precision square root +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Sqrt(const Vec512 v) { + return Vec512{_mm512_sqrt_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Sqrt(const Vec512 v) { + return Vec512{_mm512_sqrt_ps(v.raw)}; +} +HWY_API Vec512 Sqrt(const Vec512 v) { + return Vec512{_mm512_sqrt_pd(v.raw)}; +} + +// Approximate reciprocal square root +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 
ApproximateReciprocalSqrt(Vec512 v) { + return Vec512{_mm512_rsqrt_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 ApproximateReciprocalSqrt(Vec512 v) { + return Vec512{_mm512_rsqrt14_ps(v.raw)}; +} + +HWY_API Vec512 ApproximateReciprocalSqrt(Vec512 v) { + return Vec512{_mm512_rsqrt14_pd(v.raw)}; +} + +// ------------------------------ Floating-point rounding + +// Work around warnings in the intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +// Toward nearest integer, tie to even +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Round(Vec512 v) { + return Vec512{_mm512_roundscale_ph( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Round(Vec512 v) { + return Vec512{_mm512_roundscale_ps( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Round(Vec512 v) { + return Vec512{_mm512_roundscale_pd( + v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; +} + +// Toward zero, aka truncate +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Trunc(Vec512 v) { + return Vec512{ + _mm512_roundscale_ph(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Trunc(Vec512 v) { + return Vec512{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Trunc(Vec512 v) { + return Vec512{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; +} + +// Toward +infinity, aka ceiling +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Ceil(Vec512 v) { + return Vec512{ + _mm512_roundscale_ph(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Ceil(Vec512 v) { + return Vec512{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Ceil(Vec512 v) { + return Vec512{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; +} + +// Toward -infinity, aka 
floor +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 Floor(Vec512 v) { + return Vec512{ + _mm512_roundscale_ph(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 Floor(Vec512 v) { + return Vec512{ + _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} +HWY_API Vec512 Floor(Vec512 v) { + return Vec512{ + _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== COMPARE + +// Comparisons set a mask bit to 1 if the condition is true, else 0. + +template +HWY_API MFromD RebindMask(DTo /*tag*/, Mask512 m) { + static_assert(sizeof(TFrom) == sizeof(TFromD), "Must have same size"); + return MFromD{m.raw}; +} + +namespace detail { + +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<1> /*tag*/, Vec512 v, + Vec512 bit) { + return Mask512{_mm512_test_epi8_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<2> /*tag*/, Vec512 v, + Vec512 bit) { + return Mask512{_mm512_test_epi16_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<4> /*tag*/, Vec512 v, + Vec512 bit) { + return Mask512{_mm512_test_epi32_mask(v.raw, bit.raw)}; +} +template +HWY_INLINE Mask512 TestBit(hwy::SizeTag<8> /*tag*/, Vec512 v, + Vec512 bit) { + return Mask512{_mm512_test_epi64_mask(v.raw, bit.raw)}; +} + +} // namespace detail + +template +HWY_API Mask512 TestBit(const Vec512 v, const Vec512 bit) { + static_assert(!hwy::IsFloat(), "Only integer vectors supported"); + return detail::TestBit(hwy::SizeTag(), v, bit); +} + +// ------------------------------ Equality + +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return 
Mask512{_mm512_cmpeq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpeq_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Mask512 operator==(Vec512 a, + Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +HWY_API Mask512 operator==(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; +} + +// ------------------------------ Inequality + +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi8_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi16_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi32_mask(a.raw, b.raw)}; +} +template +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpneq_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Mask512 operator!=(Vec512 a, + Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; +} + +// ------------------------------ Strict inequality + +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu32_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epu64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpgt_epi64_mask(a.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} +HWY_API Mask512 operator>(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; +} + +// ------------------------------ Weak inequality + +#if HWY_HAVE_FLOAT16 +HWY_API Mask512 operator>=(Vec512 a, + Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)}; + HWY_DIAGNOSTICS(pop) +} +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; +} + +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epu8_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epu16_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epu32_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epu64_mask(a.raw, b.raw)}; +} + +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epi8_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epi16_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epi32_mask(a.raw, b.raw)}; +} +HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmpge_epi64_mask(a.raw, b.raw)}; +} + +// 
------------------------------ Reversed comparisons + +template +HWY_API Mask512 operator<(Vec512 a, Vec512 b) { + return b > a; +} + +template +HWY_API Mask512 operator<=(Vec512 a, Vec512 b) { + return b >= a; +} + +// ------------------------------ Mask + +template +HWY_API Mask512 MaskFromVec(Vec512 v) { + return Mask512{_mm512_movepi8_mask(v.raw)}; +} +template +HWY_API Mask512 MaskFromVec(Vec512 v) { + return Mask512{_mm512_movepi16_mask(v.raw)}; +} +template +HWY_API Mask512 MaskFromVec(Vec512 v) { + return Mask512{_mm512_movepi32_mask(v.raw)}; +} +template +HWY_API Mask512 MaskFromVec(Vec512 v) { + return Mask512{_mm512_movepi64_mask(v.raw)}; +} +template +HWY_API Mask512 MaskFromVec(Vec512 v) { + const RebindToSigned> di; + return Mask512{MaskFromVec(BitCast(di, v)).raw}; +} + +template +HWY_API Vec512 VecFromMask(Mask512 m) { + return Vec512{_mm512_movm_epi8(m.raw)}; +} +template +HWY_API Vec512 VecFromMask(Mask512 m) { + return Vec512{_mm512_movm_epi16(m.raw)}; +} +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 VecFromMask(Mask512 m) { + return Vec512{_mm512_castsi512_ph(_mm512_movm_epi16(m.raw))}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec512 VecFromMask(Mask512 m) { + return Vec512{_mm512_movm_epi32(m.raw)}; +} +template +HWY_API Vec512 VecFromMask(Mask512 m) { + return Vec512{_mm512_movm_epi64(m.raw)}; +} +template +HWY_API Vec512 VecFromMask(Mask512 m) { + const Full512 d; + const Full512> di; + return BitCast(d, VecFromMask(RebindMask(di, m))); +} + +// ------------------------------ Mask logical + +namespace detail { + +template +HWY_INLINE Mask512 Not(hwy::SizeTag<1> /*tag*/, Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask64(m.raw)}; +#else + return Mask512{~m.raw}; +#endif +} +template +HWY_INLINE Mask512 Not(hwy::SizeTag<2> /*tag*/, Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask32(m.raw)}; +#else + return Mask512{~m.raw}; +#endif +} +template +HWY_INLINE Mask512 Not(hwy::SizeTag<4> 
/*tag*/, Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask16(m.raw)}; +#else + return Mask512{static_cast(~m.raw & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask512 Not(hwy::SizeTag<8> /*tag*/, Mask512 m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_knot_mask8(m.raw)}; +#else + return Mask512{static_cast(~m.raw & 0xFF)}; +#endif +} + +template +HWY_INLINE Mask512 And(hwy::SizeTag<1> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask64(a.raw, b.raw)}; +#else + return Mask512{a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 And(hwy::SizeTag<2> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask32(a.raw, b.raw)}; +#else + return Mask512{a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 And(hwy::SizeTag<4> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask512 And(hwy::SizeTag<8> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kand_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<1> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask64(a.raw, b.raw)}; +#else + return Mask512{~a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<2> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask32(a.raw, b.raw)}; +#else + return Mask512{~a.raw & b.raw}; +#endif +} +template +HWY_INLINE Mask512 AndNot(hwy::SizeTag<4> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(~a.raw & b.raw)}; +#endif +} +template +HWY_INLINE Mask512 
AndNot(hwy::SizeTag<8> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kandn_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(~a.raw & b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 Or(hwy::SizeTag<1> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask64(a.raw, b.raw)}; +#else + return Mask512{a.raw | b.raw}; +#endif +} +template +HWY_INLINE Mask512 Or(hwy::SizeTag<2> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask32(a.raw, b.raw)}; +#else + return Mask512{a.raw | b.raw}; +#endif +} +template +HWY_INLINE Mask512 Or(hwy::SizeTag<4> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw | b.raw)}; +#endif +} +template +HWY_INLINE Mask512 Or(hwy::SizeTag<8> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kor_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw | b.raw)}; +#endif +} + +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<1> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask64(a.raw, b.raw)}; +#else + return Mask512{a.raw ^ b.raw}; +#endif +} +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<2> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask32(a.raw, b.raw)}; +#else + return Mask512{a.raw ^ b.raw}; +#endif +} +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<4> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask512 Xor(hwy::SizeTag<8> /*tag*/, Mask512 a, Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxor_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast(a.raw ^ b.raw)}; +#endif +} + 
+template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<1> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask64(a.raw, b.raw)}; +#else + return Mask512{~(a.raw ^ b.raw)}; +#endif +} +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<2> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask32(a.raw, b.raw)}; +#else + return Mask512{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)}; +#endif +} +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<4> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask16(a.raw, b.raw)}; +#else + return Mask512{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)}; +#endif +} +template +HWY_INLINE Mask512 ExclusiveNeither(hwy::SizeTag<8> /*tag*/, Mask512 a, + Mask512 b) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return Mask512{_kxnor_mask8(a.raw, b.raw)}; +#else + return Mask512{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)}; +#endif +} + +} // namespace detail + +template +HWY_API Mask512 Not(Mask512 m) { + return detail::Not(hwy::SizeTag(), m); +} + +template +HWY_API Mask512 And(Mask512 a, Mask512 b) { + return detail::And(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 AndNot(Mask512 a, Mask512 b) { + return detail::AndNot(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 Or(Mask512 a, Mask512 b) { + return detail::Or(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 Xor(Mask512 a, Mask512 b) { + return detail::Xor(hwy::SizeTag(), a, b); +} + +template +HWY_API Mask512 ExclusiveNeither(Mask512 a, Mask512 b) { + return detail::ExclusiveNeither(hwy::SizeTag(), a, b); +} + +template +HWY_API MFromD CombineMasks(D /*d*/, MFromD> hi, + MFromD> lo) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const __mmask64 combined_mask = _mm512_kunpackd( + static_cast<__mmask64>(hi.raw), static_cast<__mmask64>(lo.raw)); +#else + const __mmask64 combined_mask = 
static_cast<__mmask64>( + ((static_cast(hi.raw) << 32) | (lo.raw & 0xFFFFFFFFULL))); +#endif + + return MFromD{combined_mask}; +} + +template +HWY_API MFromD UpperHalfOfMask(D /*d*/, MFromD> m) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + const auto shifted_mask = _kshiftri_mask64(static_cast<__mmask64>(m.raw), 32); +#else + const auto shifted_mask = static_cast(m.raw) >> 32; +#endif + + return MFromD{static_cast().raw)>(shifted_mask)}; +} + +template +HWY_API MFromD SlideMask1Up(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftli_mask64(static_cast<__mmask64>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) << 1)}; +#endif +} + +template +HWY_API MFromD SlideMask1Down(D /*d*/, MFromD m) { + using RawM = decltype(MFromD().raw); +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return MFromD{ + static_cast(_kshiftri_mask64(static_cast<__mmask64>(m.raw), 1))}; +#else + return MFromD{static_cast(static_cast(m.raw) >> 1)}; +#endif +} + +// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) + +HWY_API Vec512 BroadcastSignBit(Vec512 v) { +#if HWY_TARGET <= HWY_AVX3_DL + const Repartition> du64; + return detail::GaloisAffine(v, Set(du64, 0x8080808080808080ull)); +#else + const DFromV d; + return VecFromMask(v < Zero(d)); +#endif +} + +HWY_API Vec512 BroadcastSignBit(Vec512 v) { + return ShiftRight<15>(v); +} + +HWY_API Vec512 BroadcastSignBit(Vec512 v) { + return ShiftRight<31>(v); +} + +HWY_API Vec512 BroadcastSignBit(Vec512 v) { + return ShiftRight<63>(v); +} + +// ------------------------------ Floating-point classification (Not) + +#if HWY_HAVE_FLOAT16 || HWY_IDE + +namespace detail { + +template +__mmask32 Fix_mm512_fpclass_ph_mask(__m512h v) { +#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1500 + // GCC's _mm512_cmp_ph_mask uses `__mmask8` instead of `__mmask32`, hence only + // the first 8 lanes are set. 
+ return static_cast<__mmask32>(__builtin_ia32_fpclassph512_mask( + static_cast<__v32hf>(v), kCategories, static_cast<__mmask32>(-1))); +#else + return _mm512_fpclass_ph_mask(v, kCategories); +#endif +} + +} // namespace detail + +HWY_API Mask512 IsNaN(Vec512 v) { + constexpr int kCategories = HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN; + return Mask512{ + detail::Fix_mm512_fpclass_ph_mask(v.raw)}; +} + +HWY_API Mask512 IsEitherNaN(Vec512 a, + Vec512 b) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). + HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + return Mask512{_mm512_cmp_ph_mask(a.raw, b.raw, _CMP_UNORD_Q)}; + HWY_DIAGNOSTICS(pop) +} + +HWY_API Mask512 IsInf(Vec512 v) { + constexpr int kCategories = HWY_X86_FPCLASS_POS_INF | HWY_X86_FPCLASS_NEG_INF; + return Mask512{ + detail::Fix_mm512_fpclass_ph_mask(v.raw)}; +} + +// Returns whether normal/subnormal/zero. fpclass doesn't have a flag for +// positive, so we have to check for inf/NaN and negate. 
+HWY_API Mask512 IsFinite(Vec512 v) { + constexpr int kCategories = HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF; + return Not(Mask512{ + detail::Fix_mm512_fpclass_ph_mask(v.raw)}); +} + +#endif // HWY_HAVE_FLOAT16 + +HWY_API Mask512 IsNaN(Vec512 v) { + return Mask512{_mm512_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +} +HWY_API Mask512 IsNaN(Vec512 v) { + return Mask512{_mm512_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN)}; +} + +HWY_API Mask512 IsEitherNaN(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +} + +HWY_API Mask512 IsEitherNaN(Vec512 a, Vec512 b) { + return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_UNORD_Q)}; +} + +HWY_API Mask512 IsInf(Vec512 v) { + return Mask512{_mm512_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} +HWY_API Mask512 IsInf(Vec512 v) { + return Mask512{_mm512_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}; +} + +// Returns whether normal/subnormal/zero. fpclass doesn't have a flag for +// positive, so we have to check for inf/NaN and negate. +HWY_API Mask512 IsFinite(Vec512 v) { + return Not(Mask512{_mm512_fpclass_ps_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} +HWY_API Mask512 IsFinite(Vec512 v) { + return Not(Mask512{_mm512_fpclass_pd_mask( + v.raw, HWY_X86_FPCLASS_SNAN | HWY_X86_FPCLASS_QNAN | + HWY_X86_FPCLASS_NEG_INF | HWY_X86_FPCLASS_POS_INF)}); +} + +// ================================================== MEMORY + +// ------------------------------ Load + +template +HWY_API VFromD Load(D /* tag */, const TFromD* HWY_RESTRICT aligned) { + return VFromD{_mm512_load_si512(aligned)}; +} +// bfloat16_t is handled by x86_128-inl.h. 
+#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 Load(D /* tag */, + const float16_t* HWY_RESTRICT aligned) { + return Vec512{_mm512_load_ph(aligned)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec512 Load(D /* tag */, const float* HWY_RESTRICT aligned) { + return Vec512{_mm512_load_ps(aligned)}; +} +template +HWY_API VFromD Load(D /* tag */, const double* HWY_RESTRICT aligned) { + return VFromD{_mm512_load_pd(aligned)}; +} + +template +HWY_API VFromD LoadU(D /* tag */, const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_loadu_si512(p)}; +} + +// bfloat16_t is handled by x86_128-inl.h. +#if HWY_HAVE_FLOAT16 +template +HWY_API Vec512 LoadU(D /* tag */, const float16_t* HWY_RESTRICT p) { + return Vec512{_mm512_loadu_ph(p)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API Vec512 LoadU(D /* tag */, const float* HWY_RESTRICT p) { + return Vec512{_mm512_loadu_ps(p)}; +} +template +HWY_API VFromD LoadU(D /* tag */, const double* HWY_RESTRICT p) { + return VFromD{_mm512_loadu_pd(p)}; +} + +// ------------------------------ MaskedLoad + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_maskz_loadu_epi8(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm512_maskz_loadu_epi16( + m.raw, reinterpret_cast(p))}); +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_maskz_loadu_epi32(m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoad(MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_maskz_loadu_epi64(m.raw, p)}; +} + +template +HWY_API Vec512 MaskedLoad(Mask512 m, D /* tag */, + const float* HWY_RESTRICT p) { + return Vec512{_mm512_maskz_loadu_ps(m.raw, p)}; +} + +template +HWY_API Vec512 MaskedLoad(Mask512 m, D /* tag */, + const double* HWY_RESTRICT p) { + return 
Vec512{_mm512_maskz_loadu_pd(m.raw, p)}; +} + +// ------------------------------ MaskedLoadOr + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_mask_loadu_epi8(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, + const TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + return BitCast( + d, VFromD{_mm512_mask_loadu_epi16( + BitCast(du, v).raw, m.raw, reinterpret_cast(p))}); +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_mask_loadu_epi32(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D /* tag */, + const TFromD* HWY_RESTRICT p) { + return VFromD{_mm512_mask_loadu_epi64(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, Mask512 m, D /* tag */, + const float* HWY_RESTRICT p) { + return VFromD{_mm512_mask_loadu_ps(v.raw, m.raw, p)}; +} + +template +HWY_API VFromD MaskedLoadOr(VFromD v, Mask512 m, D /* tag */, + const double* HWY_RESTRICT p) { + return VFromD{_mm512_mask_loadu_pd(v.raw, m.raw, p)}; +} + +// ------------------------------ LoadDup128 + +// Loads 128 bit and duplicates into both 128-bit halves. This avoids the +// 3-cycle cost of moving data between 128-bit halves and avoids port 5. 
+template +HWY_API VFromD LoadDup128(D d, const TFromD* const HWY_RESTRICT p) { + const RebindToUnsigned du; + const Full128> d128; + const RebindToUnsigned du128; + return BitCast(d, VFromD{_mm512_broadcast_i32x4( + BitCast(du128, LoadU(d128, p)).raw)}); +} +template +HWY_API VFromD LoadDup128(D /* tag */, const float* HWY_RESTRICT p) { + const __m128 x4 = _mm_loadu_ps(p); + return VFromD{_mm512_broadcast_f32x4(x4)}; +} + +template +HWY_API VFromD LoadDup128(D /* tag */, const double* HWY_RESTRICT p) { + const __m128d x2 = _mm_loadu_pd(p); + return VFromD{_mm512_broadcast_f64x2(x2)}; +} + +// ------------------------------ Store + +template +HWY_API void Store(VFromD v, D /* tag */, TFromD* HWY_RESTRICT aligned) { + _mm512_store_si512(reinterpret_cast<__m512i*>(aligned), v.raw); +} +// bfloat16_t is handled by x86_128-inl.h. +#if HWY_HAVE_FLOAT16 +template +HWY_API void Store(Vec512 v, D /* tag */, + float16_t* HWY_RESTRICT aligned) { + _mm512_store_ph(aligned, v.raw); +} +#endif +template +HWY_API void Store(Vec512 v, D /* tag */, float* HWY_RESTRICT aligned) { + _mm512_store_ps(aligned, v.raw); +} +template +HWY_API void Store(VFromD v, D /* tag */, double* HWY_RESTRICT aligned) { + _mm512_store_pd(aligned, v.raw); +} + +template +HWY_API void StoreU(VFromD v, D /* tag */, TFromD* HWY_RESTRICT p) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(p), v.raw); +} +// bfloat16_t is handled by x86_128-inl.h. 
+#if HWY_HAVE_FLOAT16 +template +HWY_API void StoreU(Vec512 v, D /* tag */, + float16_t* HWY_RESTRICT p) { + _mm512_storeu_ph(p, v.raw); +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API void StoreU(Vec512 v, D /* tag */, float* HWY_RESTRICT p) { + _mm512_storeu_ps(p, v.raw); +} +template +HWY_API void StoreU(Vec512 v, D /* tag */, double* HWY_RESTRICT p) { + _mm512_storeu_pd(p, v.raw); +} + +// ------------------------------ BlendedStore + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm512_mask_storeu_epi8(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT p) { + const RebindToUnsigned du; // for float16_t + _mm512_mask_storeu_epi16(reinterpret_cast(p), m.raw, + BitCast(du, v).raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm512_mask_storeu_epi32(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT p) { + _mm512_mask_storeu_epi64(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec512 v, Mask512 m, D /* tag */, + float* HWY_RESTRICT p) { + _mm512_mask_storeu_ps(p, m.raw, v.raw); +} + +template +HWY_API void BlendedStore(Vec512 v, Mask512 m, D /* tag */, + double* HWY_RESTRICT p) { + _mm512_mask_storeu_pd(p, m.raw, v.raw); +} + +// ------------------------------ Non-temporal stores + +template +HWY_API void Stream(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { + const RebindToUnsigned du; // for float16_t + _mm512_stream_si512(reinterpret_cast<__m512i*>(aligned), BitCast(du, v).raw); +} +template +HWY_API void Stream(VFromD v, D /* tag */, float* HWY_RESTRICT aligned) { + _mm512_stream_ps(aligned, v.raw); +} +template +HWY_API void Stream(VFromD v, D /* tag */, double* HWY_RESTRICT aligned) { + _mm512_stream_pd(aligned, v.raw); +} + +// ------------------------------ ScatterOffset + +// Work around warnings in the 
intrinsic definitions (passing -1 as a mask). +HWY_DIAGNOSTICS(push) +HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> offset) { + _mm512_i32scatter_epi32(base, offset.raw, v.raw, 1); +} + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> offset) { + _mm512_i64scatter_epi64(base, offset.raw, v.raw, 1); +} + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, float* HWY_RESTRICT base, + Vec512 offset) { + _mm512_i32scatter_ps(base, offset.raw, v.raw, 1); +} + +template +HWY_API void ScatterOffset(VFromD v, D /* tag */, double* HWY_RESTRICT base, + Vec512 offset) { + _mm512_i64scatter_pd(base, offset.raw, v.raw, 1); +} + +// ------------------------------ ScatterIndex + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm512_i32scatter_epi32(base, index.raw, v.raw, 4); +} + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm512_i64scatter_epi64(base, index.raw, v.raw, 8); +} + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, float* HWY_RESTRICT base, + Vec512 index) { + _mm512_i32scatter_ps(base, index.raw, v.raw, 4); +} + +template +HWY_API void ScatterIndex(VFromD v, D /* tag */, double* HWY_RESTRICT base, + Vec512 index) { + _mm512_i64scatter_pd(base, index.raw, v.raw, 8); +} + +// ------------------------------ MaskedScatterIndex + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm512_mask_i32scatter_epi32(base, m.raw, index.raw, v.raw, 4); +} + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + TFromD* HWY_RESTRICT base, + VFromD> index) { + _mm512_mask_i64scatter_epi64(base, m.raw, index.raw, v.raw, 8); +} + +template +HWY_API void 
MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + float* HWY_RESTRICT base, + Vec512 index) { + _mm512_mask_i32scatter_ps(base, m.raw, index.raw, v.raw, 4); +} + +template +HWY_API void MaskedScatterIndex(VFromD v, MFromD m, D /* tag */, + double* HWY_RESTRICT base, + Vec512 index) { + _mm512_mask_i64scatter_pd(base, m.raw, index.raw, v.raw, 8); +} + +// ------------------------------ Gather + +namespace detail { + +template +HWY_INLINE Vec512 NativeGather512(const T* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{_mm512_i32gather_epi32(indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeGather512(const T* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{_mm512_i64gather_epi64(indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeGather512(const float* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{_mm512_i32gather_ps(indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeGather512(const double* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{_mm512_i64gather_pd(indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeMaskedGatherOr512(Vec512 no, Mask512 m, + const T* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{ + _mm512_mask_i32gather_epi32(no.raw, m.raw, indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeMaskedGatherOr512(Vec512 no, Mask512 m, + const T* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{ + _mm512_mask_i64gather_epi64(no.raw, m.raw, indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeMaskedGatherOr512(Vec512 no, + Mask512 m, + const float* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{ + _mm512_mask_i32gather_ps(no.raw, m.raw, indices.raw, base, kScale)}; +} + +template +HWY_INLINE Vec512 NativeMaskedGatherOr512( + Vec512 no, Mask512 m, const double* HWY_RESTRICT base, + Vec512 indices) { + return Vec512{ + _mm512_mask_i64gather_pd(no.raw, m.raw, indices.raw, base, kScale)}; +} +} // namespace 
detail + +template +HWY_API VFromD GatherOffset(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> offsets) { + return detail::NativeGather512<1>(base, offsets); +} + +template +HWY_API VFromD GatherIndex(D /*d*/, const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeGather512)>(base, indices); +} + +template +HWY_API VFromD MaskedGatherIndexOr(VFromD no, MFromD m, D /*d*/, + const TFromD* HWY_RESTRICT base, + VFromD> indices) { + return detail::NativeMaskedGatherOr512)>(no, m, base, + indices); +} + +HWY_DIAGNOSTICS(pop) + +// ================================================== SWIZZLE + +// ------------------------------ LowerHalf + +template +HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { + return VFromD{_mm512_castsi512_si256(v.raw)}; +} +template +HWY_API VFromD LowerHalf(D /* tag */, Vec512 v) { + return VFromD{_mm512_castsi512_si256(v.raw)}; +} +template +HWY_API VFromD LowerHalf(D /* tag */, Vec512 v) { +#if HWY_HAVE_FLOAT16 + return VFromD{_mm512_castph512_ph256(v.raw)}; +#else + return VFromD{_mm512_castsi512_si256(v.raw)}; +#endif // HWY_HAVE_FLOAT16 +} +template +HWY_API VFromD LowerHalf(D /* tag */, Vec512 v) { + return VFromD{_mm512_castps512_ps256(v.raw)}; +} +template +HWY_API VFromD LowerHalf(D /* tag */, Vec512 v) { + return VFromD{_mm512_castpd512_pd256(v.raw)}; +} + +template +HWY_API Vec256 LowerHalf(Vec512 v) { + const Half> dh; + return LowerHalf(dh, v); +} + +// ------------------------------ UpperHalf + +template +HWY_API VFromD UpperHalf(D d, VFromD> v) { + const RebindToUnsigned du; // for float16_t + const Twice dut; + return BitCast(d, VFromD{ + _mm512_extracti32x8_epi32(BitCast(dut, v).raw, 1)}); +} +template +HWY_API VFromD UpperHalf(D /* tag */, VFromD> v) { + return VFromD{_mm512_extractf32x8_ps(v.raw, 1)}; +} +template +HWY_API VFromD UpperHalf(D /* tag */, VFromD> v) { + return VFromD{_mm512_extractf64x4_pd(v.raw, 1)}; +} + +// ------------------------------ ExtractLane (Store) +template +HWY_API T 
ExtractLane(const Vec512 v, size_t i) { + const DFromV d; + HWY_DASSERT(i < Lanes(d)); + +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + if (__builtin_constant_p(i < kLanesPerBlock) && (i < kLanesPerBlock)) { + return ExtractLane(ResizeBitCast(Full128(), v), i); + } +#endif + + alignas(64) T lanes[MaxLanes(d)]; + Store(v, d, lanes); + return lanes[i]; +} + +// ------------------------------ ExtractBlock +template * = nullptr> +HWY_API Vec128 ExtractBlock(Vec512 v) { + const DFromV d; + const Half dh; + return ExtractBlock(LowerHalf(dh, v)); +} + +template 1)>* = nullptr> +HWY_API Vec128 ExtractBlock(Vec512 v) { + static_assert(kBlockIdx <= 3, "Invalid block index"); + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(Full128(), + Vec128>{ + _mm512_extracti32x4_epi32(BitCast(du, v).raw, kBlockIdx)}); +} + +template 1)>* = nullptr> +HWY_API Vec128 ExtractBlock(Vec512 v) { + static_assert(kBlockIdx <= 3, "Invalid block index"); + return Vec128{_mm512_extractf32x4_ps(v.raw, kBlockIdx)}; +} + +template 1)>* = nullptr> +HWY_API Vec128 ExtractBlock(Vec512 v) { + static_assert(kBlockIdx <= 3, "Invalid block index"); + return Vec128{_mm512_extractf64x2_pd(v.raw, kBlockIdx)}; +} + +// ------------------------------ InsertLane (Store) +template +HWY_API Vec512 InsertLane(const Vec512 v, size_t i, T t) { + return detail::InsertLaneUsingBroadcastAndBlend(v, i, t); +} + +// ------------------------------ InsertBlock +namespace detail { + +template +HWY_INLINE Vec512 InsertBlock(hwy::SizeTag<0> /* blk_idx_tag */, Vec512 v, + Vec128 blk_to_insert) { + const DFromV d; + const auto insert_mask = FirstN(d, 16 / sizeof(T)); + return IfThenElse(insert_mask, ResizeBitCast(d, blk_to_insert), v); +} + +template +HWY_INLINE Vec512 InsertBlock(hwy::SizeTag /* blk_idx_tag */, + Vec512 v, Vec128 blk_to_insert) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + const Full128> 
du_blk_to_insert; + return BitCast( + d, VFromD{_mm512_inserti32x4( + BitCast(du, v).raw, BitCast(du_blk_to_insert, blk_to_insert).raw, + static_cast(kBlockIdx & 3))}); +} + +template * = nullptr> +HWY_INLINE Vec512 InsertBlock(hwy::SizeTag /* blk_idx_tag */, + Vec512 v, + Vec128 blk_to_insert) { + return Vec512{_mm512_insertf32x4(v.raw, blk_to_insert.raw, + static_cast(kBlockIdx & 3))}; +} + +template * = nullptr> +HWY_INLINE Vec512 InsertBlock(hwy::SizeTag /* blk_idx_tag */, + Vec512 v, + Vec128 blk_to_insert) { + return Vec512{_mm512_insertf64x2(v.raw, blk_to_insert.raw, + static_cast(kBlockIdx & 3))}; +} + +} // namespace detail + +template +HWY_API Vec512 InsertBlock(Vec512 v, Vec128 blk_to_insert) { + static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index"); + return detail::InsertBlock(hwy::SizeTag(kBlockIdx)>(), v, + blk_to_insert); +} + +// ------------------------------ GetLane (LowerHalf) +template +HWY_API T GetLane(const Vec512 v) { + return GetLane(LowerHalf(v)); +} + +// ------------------------------ ZeroExtendVector + +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { +#if HWY_HAVE_ZEXT // See definition/comment in x86_256-inl.h. 
+ (void)d; + return VFromD{_mm512_zextsi256_si512(lo.raw)}; +#else + return VFromD{_mm512_inserti32x8(Zero(d).raw, lo.raw, 0)}; +#endif +} +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { +#if HWY_HAVE_ZEXT + (void)d; + return VFromD{_mm512_zextph256_ph512(lo.raw)}; +#else + const RebindToUnsigned du; + return BitCast(d, ZeroExtendVector(du, BitCast(du, lo))); +#endif +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { +#if HWY_HAVE_ZEXT + (void)d; + return VFromD{_mm512_zextps256_ps512(lo.raw)}; +#else + return VFromD{_mm512_insertf32x8(Zero(d).raw, lo.raw, 0)}; +#endif +} +template +HWY_API VFromD ZeroExtendVector(D d, VFromD> lo) { +#if HWY_HAVE_ZEXT + (void)d; + return VFromD{_mm512_zextpd256_pd512(lo.raw)}; +#else + return VFromD{_mm512_insertf64x4(Zero(d).raw, lo.raw, 0)}; +#endif +} + +// ------------------------------ ZeroExtendResizeBitCast + +namespace detail { + +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag<16> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */, + DTo d_to, DFrom d_from, VFromD v) { + const Repartition du8_from; + const auto vu8 = BitCast(du8_from, v); + const RebindToUnsigned du_to; +#if HWY_HAVE_ZEXT + return BitCast(d_to, + VFromD{_mm512_zextsi128_si512(vu8.raw)}); +#else + return BitCast(d_to, VFromD{ + _mm512_inserti32x4(Zero(du_to).raw, vu8.raw, 0)}); +#endif +} + +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag<16> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */, + DTo d_to, DFrom d_from, VFromD v) { + const Repartition df32_from; + const auto vf32 = BitCast(df32_from, v); +#if HWY_HAVE_ZEXT + (void)d_to; + return Vec512{_mm512_zextps128_ps512(vf32.raw)}; +#else + return Vec512{_mm512_insertf32x4(Zero(d_to).raw, vf32.raw, 0)}; +#endif +} + +template +HWY_INLINE Vec512 ZeroExtendResizeBitCast( + hwy::SizeTag<16> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */, + DTo d_to, DFrom d_from, VFromD 
v) { + const Repartition df64_from; + const auto vf64 = BitCast(df64_from, v); +#if HWY_HAVE_ZEXT + (void)d_to; + return Vec512{_mm512_zextpd128_pd512(vf64.raw)}; +#else + return Vec512{_mm512_insertf64x2(Zero(d_to).raw, vf64.raw, 0)}; +#endif +} + +template +HWY_INLINE VFromD ZeroExtendResizeBitCast( + hwy::SizeTag<8> /* from_size_tag */, hwy::SizeTag<64> /* to_size_tag */, + DTo d_to, DFrom d_from, VFromD v) { + const Twice dt_from; + return ZeroExtendResizeBitCast(hwy::SizeTag<16>(), hwy::SizeTag<64>(), d_to, + dt_from, ZeroExtendVector(dt_from, v)); +} + +} // namespace detail + +// ------------------------------ Combine + +template +HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { + const RebindToUnsigned du; // for float16_t + const Half duh; + const __m512i lo512 = ZeroExtendVector(du, BitCast(duh, lo)).raw; + return BitCast(d, VFromD{ + _mm512_inserti32x8(lo512, BitCast(duh, hi).raw, 1)}); +} +template +HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { + return VFromD{_mm512_insertf32x8(ZeroExtendVector(d, lo).raw, hi.raw, 1)}; +} +template +HWY_API VFromD Combine(D d, VFromD> hi, VFromD> lo) { + return VFromD{_mm512_insertf64x4(ZeroExtendVector(d, lo).raw, hi.raw, 1)}; +} + +// ------------------------------ ShiftLeftBytes +template +HWY_API VFromD ShiftLeftBytes(D /* tag */, const VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return VFromD{_mm512_bslli_epi128(v.raw, kBytes)}; +} + +// ------------------------------ ShiftRightBytes +template +HWY_API VFromD ShiftRightBytes(D /* tag */, const VFromD v) { + static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); + return VFromD{_mm512_bsrli_epi128(v.raw, kBytes)}; +} + +// ------------------------------ CombineShiftRightBytes + +template +HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { + const Repartition d8; + return BitCast(d, Vec512{_mm512_alignr_epi8( + BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); +} + +// 
------------------------------ Broadcast/splat any lane + +template +HWY_API Vec512 Broadcast(const Vec512 v) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; + const VU vu = BitCast(du, v); // for float16_t + static_assert(0 <= kLane && kLane < 8, "Invalid lane"); + if (kLane < 4) { + const __m512i lo = _mm512_shufflelo_epi16(vu.raw, (0x55 * kLane) & 0xFF); + return BitCast(d, VU{_mm512_unpacklo_epi64(lo, lo)}); + } else { + const __m512i hi = + _mm512_shufflehi_epi16(vu.raw, (0x55 * (kLane - 4)) & 0xFF); + return BitCast(d, VU{_mm512_unpackhi_epi64(hi, hi)}); + } +} + +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); + return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; +} + +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA; + return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; +} + +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 4, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, perm)}; +} + +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 2, "Invalid lane"); + constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0xFF * kLane); + return Vec512{_mm512_shuffle_pd(v.raw, v.raw, perm)}; +} + +// ------------------------------ BroadcastBlock +template +HWY_API Vec512 BroadcastBlock(Vec512 v) { + static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index"); + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast( + d, VFromD{_mm512_shuffle_i32x4( + BitCast(du, v).raw, BitCast(du, v).raw, 0x55 * kBlockIdx)}); +} + +template +HWY_API Vec512 BroadcastBlock(Vec512 v) { + 
static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index"); + return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, 0x55 * kBlockIdx)}; +} + +template +HWY_API Vec512 BroadcastBlock(Vec512 v) { + static_assert(0 <= kBlockIdx && kBlockIdx <= 3, "Invalid block index"); + return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, 0x55 * kBlockIdx)}; +} + +// ------------------------------ BroadcastLane + +namespace detail { + +template +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec512 v) { + return Vec512{_mm512_broadcastb_epi8(ResizeBitCast(Full128(), v).raw)}; +} + +template +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec512 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm512_broadcastw_epi16( + ResizeBitCast(Full128(), v).raw)}); +} + +template +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec512 v) { + return Vec512{_mm512_broadcastd_epi32(ResizeBitCast(Full128(), v).raw)}; +} + +template +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec512 v) { + return Vec512{_mm512_broadcastq_epi64(ResizeBitCast(Full128(), v).raw)}; +} + +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec512 v) { + return Vec512{ + _mm512_broadcastss_ps(ResizeBitCast(Full128(), v).raw)}; +} + +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */, + Vec512 v) { + return Vec512{ + _mm512_broadcastsd_pd(ResizeBitCast(Full128(), v).raw)}; +} + +template * = nullptr> +HWY_INLINE Vec512 BroadcastLane(hwy::SizeTag /* lane_idx_tag */, + Vec512 v) { + constexpr size_t kLanesPerBlock = 16 / sizeof(T); + constexpr int kBlockIdx = static_cast(kLaneIdx / kLanesPerBlock); + constexpr int kLaneInBlkIdx = + static_cast(kLaneIdx) & (kLanesPerBlock - 1); + return Broadcast(BroadcastBlock(v)); +} + +} // namespace detail + +template +HWY_API Vec512 BroadcastLane(Vec512 v) { + static_assert(0 <= kLaneIdx, "Invalid lane"); + 
return detail::BroadcastLane(hwy::SizeTag(kLaneIdx)>(), + v); +} + +// ------------------------------ Hard-coded shuffles + +// Notation: let Vec512 have lanes 7,6,5,4,3,2,1,0 (0 is +// least-significant). Shuffle0321 rotates four-lane blocks one lane to the +// right (the previous least-significant lane is now most-significant => +// 47650321). These could also be implemented via CombineShiftRightBytes but +// the shuffle_abcd notation is more convenient. + +// Swap 32-bit halves in 64-bit halves. +template +HWY_API Vec512 Shuffle2301(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)}; +} +HWY_API Vec512 Shuffle2301(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +namespace detail { + +template +HWY_API Vec512 ShuffleTwo2301(const Vec512 a, const Vec512 b) { + const DFromV d; + const RebindToFloat df; + return BitCast( + d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_CDAB)}); +} +template +HWY_API Vec512 ShuffleTwo1230(const Vec512 a, const Vec512 b) { + const DFromV d; + const RebindToFloat df; + return BitCast( + d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_BCDA)}); +} +template +HWY_API Vec512 ShuffleTwo3012(const Vec512 a, const Vec512 b) { + const DFromV d; + const RebindToFloat df; + return BitCast( + d, Vec512{_mm512_shuffle_ps(BitCast(df, a).raw, BitCast(df, b).raw, + _MM_PERM_DABC)}); +} + +} // namespace detail + +// Swap 64-bit halves +HWY_API Vec512 Shuffle1032(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle1032(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle1032(const Vec512 v) { + // Shorter encoding than _mm512_permute_ps. 
+ return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle01(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle01(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; +} +HWY_API Vec512 Shuffle01(const Vec512 v) { + // Shorter encoding than _mm512_permute_pd. + return Vec512{_mm512_shuffle_pd(v.raw, v.raw, _MM_PERM_BBBB)}; +} + +// Rotate right 32 bits +HWY_API Vec512 Shuffle0321(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; +} +HWY_API Vec512 Shuffle0321(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; +} +HWY_API Vec512 Shuffle0321(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ADCB)}; +} +// Rotate left 32 bits +HWY_API Vec512 Shuffle2103(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; +} +HWY_API Vec512 Shuffle2103(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; +} +HWY_API Vec512 Shuffle2103(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CBAD)}; +} + +// Reverse +HWY_API Vec512 Shuffle0123(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 Shuffle0123(const Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; +} +HWY_API Vec512 Shuffle0123(const Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ABCD)}; +} + +// ------------------------------ TableLookupLanes + +// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. 
+template +struct Indices512 { + __m512i raw; +}; + +template , typename TI> +HWY_API Indices512 IndicesFromVec(D /* tag */, Vec512 vec) { + static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); +#if HWY_IS_DEBUG_BUILD + const DFromV di; + const RebindToUnsigned du; + using TU = MakeUnsigned; + const auto vec_u = BitCast(du, vec); + HWY_DASSERT( + AllTrue(du, Lt(vec_u, Set(du, static_cast(128 / sizeof(T)))))); +#endif + return Indices512{vec.raw}; +} + +template +HWY_API Indices512> SetTableIndices(D d, const TI* idx) { + const Rebind di; + return IndicesFromVec(d, LoadU(di, idx)); +} + +template +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_permutexvar_epi8(idx.raw, v.raw)}; +#else + const DFromV d; + const Repartition du16; + const Vec512 idx_vec{idx.raw}; + + const auto bd_sel_mask = + MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, idx_vec)))); + const auto cd_sel_mask = + MaskFromVec(BitCast(d, ShiftLeft<2>(BitCast(du16, idx_vec)))); + + const Vec512 v_a{_mm512_shuffle_i32x4(v.raw, v.raw, 0x00)}; + const Vec512 v_b{_mm512_shuffle_i32x4(v.raw, v.raw, 0x55)}; + const Vec512 v_c{_mm512_shuffle_i32x4(v.raw, v.raw, 0xAA)}; + const Vec512 v_d{_mm512_shuffle_i32x4(v.raw, v.raw, 0xFF)}; + + const auto shuf_a = TableLookupBytes(v_a, idx_vec); + const auto shuf_c = TableLookupBytes(v_c, idx_vec); + const Vec512 shuf_ab{_mm512_mask_shuffle_epi8(shuf_a.raw, bd_sel_mask.raw, + v_b.raw, idx_vec.raw)}; + const Vec512 shuf_cd{_mm512_mask_shuffle_epi8(shuf_c.raw, bd_sel_mask.raw, + v_d.raw, idx_vec.raw)}; + return IfThenElse(cd_sel_mask, shuf_cd, shuf_ab); +#endif +} + +template +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { + return Vec512{_mm512_permutexvar_epi16(idx.raw, v.raw)}; +} +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 TableLookupLanes(Vec512 v, + Indices512 idx) { + return Vec512{_mm512_permutexvar_ph(idx.raw, v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API 
Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { + return Vec512{_mm512_permutexvar_epi32(idx.raw, v.raw)}; +} + +template +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { + return Vec512{_mm512_permutexvar_epi64(idx.raw, v.raw)}; +} + +HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { + return Vec512{_mm512_permutexvar_ps(idx.raw, v.raw)}; +} + +HWY_API Vec512 TableLookupLanes(Vec512 v, + Indices512 idx) { + return Vec512{_mm512_permutexvar_pd(idx.raw, v.raw)}; +} + +template +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, Vec512 b, + Indices512 idx) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_permutex2var_epi8(a.raw, idx.raw, b.raw)}; +#else + const DFromV d; + const auto b_sel_mask = + MaskFromVec(BitCast(d, ShiftLeft<1>(Vec512{idx.raw}))); + return IfThenElse(b_sel_mask, TableLookupLanes(b, idx), + TableLookupLanes(a, idx)); +#endif +} + +template +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, Vec512 b, + Indices512 idx) { + return Vec512{_mm512_permutex2var_epi16(a.raw, idx.raw, b.raw)}; +} + +template +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, Vec512 b, + Indices512 idx) { + return Vec512{_mm512_permutex2var_epi32(a.raw, idx.raw, b.raw)}; +} + +#if HWY_HAVE_FLOAT16 +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, + Vec512 b, + Indices512 idx) { + return Vec512{_mm512_permutex2var_ph(a.raw, idx.raw, b.raw)}; +} +#endif // HWY_HAVE_FLOAT16 +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, Vec512 b, + Indices512 idx) { + return Vec512{_mm512_permutex2var_ps(a.raw, idx.raw, b.raw)}; +} + +template +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, Vec512 b, + Indices512 idx) { + return Vec512{_mm512_permutex2var_epi64(a.raw, idx.raw, b.raw)}; +} + +HWY_API Vec512 TwoTablesLookupLanes(Vec512 a, Vec512 b, + Indices512 idx) { + return Vec512{_mm512_permutex2var_pd(a.raw, idx.raw, b.raw)}; +} + +// ------------------------------ Reverse + +template +HWY_API VFromD Reverse(D d, const VFromD v) { +#if HWY_TARGET <= HWY_AVX3_DL + const 
RebindToSigned di; + alignas(64) static constexpr int8_t kReverse[64] = { + 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, + 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + const Vec512 idx = Load(di, kReverse); + return BitCast( + d, Vec512{_mm512_permutexvar_epi8(idx.raw, BitCast(di, v).raw)}); +#else + const RepartitionToWide d16; + return BitCast(d, Reverse(d16, RotateRight<8>(BitCast(d16, v)))); +#endif +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + const RebindToSigned di; + alignas(64) static constexpr int16_t kReverse[32] = { + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + const Vec512 idx = Load(di, kReverse); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + alignas(64) static constexpr int32_t kReverse[16] = { + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +template +HWY_API VFromD Reverse(D d, const VFromD v) { + alignas(64) static constexpr int64_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; + return TableLookupLanes(v, SetTableIndices(d, kReverse)); +} + +// ------------------------------ Reverse2 (in x86_128) + +// ------------------------------ Reverse4 + +template +HWY_API VFromD Reverse4(D d, const VFromD v) { + const RebindToSigned di; + alignas(64) static constexpr int16_t kReverse4[32] = { + 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, + 19, 18, 17, 16, 23, 22, 21, 20, 27, 26, 25, 24, 31, 30, 29, 28}; + const Vec512 idx = Load(di, kReverse4); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +// 32 bit Reverse4 defined in x86_128. 
+ +template +HWY_API VFromD Reverse4(D /* tag */, const VFromD v) { + return VFromD{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} +template +HWY_API VFromD Reverse4(D /* tag */, VFromD v) { + return VFromD{_mm512_permutex_pd(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; +} + +// ------------------------------ Reverse8 + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToSigned di; + alignas(64) static constexpr int16_t kReverse8[32] = { + 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, + 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}; + const Vec512 idx = Load(di, kReverse8); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + const RebindToSigned di; + alignas(64) static constexpr int32_t kReverse8[16] = { + 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8}; + const Vec512 idx = Load(di, kReverse8); + return BitCast(d, Vec512{ + _mm512_permutexvar_epi32(idx.raw, BitCast(di, v).raw)}); +} + +template +HWY_API VFromD Reverse8(D d, const VFromD v) { + return Reverse(d, v); +} + +// ------------------------------ ReverseBits (GaloisAffine) + +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_REVERSE_BITS_UI8 +#undef HWY_NATIVE_REVERSE_BITS_UI8 +#else +#define HWY_NATIVE_REVERSE_BITS_UI8 +#endif + +// Generic for all vector lengths. Must be defined after all GaloisAffine. 
+template +HWY_API V ReverseBits(V v) { + const Repartition> du64; + return detail::GaloisAffine(v, Set(du64, 0x8040201008040201u)); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ InterleaveLower + +template +HWY_API Vec512 InterleaveLower(Vec512 a, Vec512 b) { + return Vec512{_mm512_unpacklo_epi8(a.raw, b.raw)}; +} +template +HWY_API Vec512 InterleaveLower(Vec512 a, Vec512 b) { + const DFromV d; + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm512_unpacklo_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API Vec512 InterleaveLower(Vec512 a, Vec512 b) { + return Vec512{_mm512_unpacklo_epi32(a.raw, b.raw)}; +} +template +HWY_API Vec512 InterleaveLower(Vec512 a, Vec512 b) { + return Vec512{_mm512_unpacklo_epi64(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(Vec512 a, Vec512 b) { + return Vec512{_mm512_unpacklo_ps(a.raw, b.raw)}; +} +HWY_API Vec512 InterleaveLower(Vec512 a, Vec512 b) { + return Vec512{_mm512_unpacklo_pd(a.raw, b.raw)}; +} + +// ------------------------------ InterleaveUpper + +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm512_unpackhi_epi8(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + using VU = VFromD; // for float16_t + return BitCast( + d, VU{_mm512_unpackhi_epi16(BitCast(du, a).raw, BitCast(du, b).raw)}); +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm512_unpackhi_epi32(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm512_unpackhi_epi64(a.raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm512_unpackhi_ps(a.raw, b.raw)}; +} +template +HWY_API VFromD InterleaveUpper(D /* tag */, VFromD a, VFromD b) { + return VFromD{_mm512_unpackhi_pd(a.raw, b.raw)}; 
+} + +// ------------------------------ Concat* halves + +// hiH,hiL loH,loL |-> hiL,loL (= lower halves) +template +HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_BABA)}); +} +template +HWY_API VFromD ConcatLowerLower(D /* tag */, VFromD hi, VFromD lo) { + return VFromD{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; +} +template +HWY_API Vec512 ConcatLowerLower(D /* tag */, Vec512 hi, + Vec512 lo) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BABA)}; +} + +// hiH,hiL loH,loL |-> hiH,loH (= upper halves) +template +HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_DCDC)}); +} +template +HWY_API VFromD ConcatUpperUpper(D /* tag */, VFromD hi, VFromD lo) { + return VFromD{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; +} +template +HWY_API Vec512 ConcatUpperUpper(D /* tag */, Vec512 hi, + Vec512 lo) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_DCDC)}; +} + +// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) +template +HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, lo).raw, BitCast(du, hi).raw, _MM_PERM_BADC)}); +} +template +HWY_API VFromD ConcatLowerUpper(D /* tag */, VFromD hi, VFromD lo) { + return VFromD{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; +} +template +HWY_API Vec512 ConcatLowerUpper(D /* tag */, Vec512 hi, + Vec512 lo) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BADC)}; +} + +// hiH,hiL loH,loL |-> hiH,loL (= outer halves) +template +HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { + // There are no imm8 blend 
in AVX512. Use blend16 because 32-bit masks + // are efficiently loaded from 32-bit regs. + const __mmask32 mask = /*_cvtu32_mask32 */ (0x0000FFFF); + const RebindToUnsigned du; // for float16_t + return BitCast(d, VFromD{_mm512_mask_blend_epi16( + mask, BitCast(du, hi).raw, BitCast(du, lo).raw)}); +} +template +HWY_API VFromD ConcatUpperLower(D /* tag */, VFromD hi, VFromD lo) { + const __mmask16 mask = /*_cvtu32_mask16 */ (0x00FF); + return VFromD{_mm512_mask_blend_ps(mask, hi.raw, lo.raw)}; +} +template +HWY_API Vec512 ConcatUpperLower(D /* tag */, Vec512 hi, + Vec512 lo) { + const __mmask8 mask = /*_cvtu32_mask8 */ (0x0F); + return Vec512{_mm512_mask_blend_pd(mask, hi.raw, lo.raw)}; +} + +// ------------------------------ ConcatOdd + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3_DL + alignas(64) static constexpr uint8_t kIdx[64] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, + 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, + 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, + 79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103, + 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Right-shift 8 bits per u16 so we can pack. + const Vec512 uH = ShiftRight<8>(BitCast(dw, hi)); + const Vec512 uL = ShiftRight<8>(BitCast(dw, lo)); + const Vec512 u8{_mm512_packus_epi16(uL.raw, uH.raw)}; + // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. 
+ const Full512 du64; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); +#endif +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint16_t kIdx[32] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, + 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; + return VFromD{_mm512_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +} + +template +HWY_API VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; + return VFromD{_mm512_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; +} + +// ------------------------------ ConcatEven + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; +#if HWY_TARGET <= HWY_AVX3_DL + alignas(64) static constexpr uint8_t kIdx[64] = { 
+ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, + 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, + 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, + 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 102, + 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi8( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +#else + const RepartitionToWide dw; + // Isolate lower 8 bits per u16 so we can pack. + const Vec512 mask = Set(dw, 0x00FF); + const Vec512 uH = And(BitCast(dw, hi), mask); + const Vec512 uL = And(BitCast(dw, lo), mask); + const Vec512 u8{_mm512_packus_epi16(uL.raw, uH.raw)}; + // Undo block interleave: lower half = even u64 lanes, upper = odd u64 lanes. + const Full512 du64; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return BitCast(d, TableLookupLanes(u8, SetTableIndices(du64, kIdx))); +#endif +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint16_t kIdx[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi16( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi32( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; + return VFromD{_mm512_permutex2var_ps(lo.raw, Load(du, kIdx).raw, hi.raw)}; +} + +template +HWY_API 
VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return BitCast( + d, Vec512{_mm512_permutex2var_epi64( + BitCast(du, lo).raw, Load(du, kIdx).raw, BitCast(du, hi).raw)}); +} + +template +HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; + return VFromD{_mm512_permutex2var_pd(lo.raw, Load(du, kIdx).raw, hi.raw)}; +} + +// ------------------------------ InterleaveWholeLower + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { +#if HWY_TARGET <= HWY_AVX3_DL + const RebindToUnsigned du; + alignas(64) static constexpr uint8_t kIdx[64] = { + 0, 64, 1, 65, 2, 66, 3, 67, 4, 68, 5, 69, 6, 70, 7, 71, + 8, 72, 9, 73, 10, 74, 11, 75, 12, 76, 13, 77, 14, 78, 15, 79, + 16, 80, 17, 81, 18, 82, 19, 83, 20, 84, 21, 85, 22, 86, 23, 87, + 24, 88, 25, 89, 26, 90, 27, 91, 28, 92, 29, 93, 30, 94, 31, 95}; + return VFromD{_mm512_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; +#else + alignas(64) static constexpr uint64_t kIdx2[8] = {0, 1, 8, 9, 2, 3, 10, 11}; + const Repartition du64; + return VFromD{_mm512_permutex2var_epi64(InterleaveLower(a, b).raw, + Load(du64, kIdx2).raw, + InterleaveUpper(d, a, b).raw)}; +#endif +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint16_t kIdx[32] = { + 0, 32, 1, 33, 2, 34, 3, 35, 4, 36, 5, 37, 6, 38, 7, 39, + 8, 40, 9, 41, 10, 42, 11, 43, 12, 44, 13, 45, 14, 46, 15, 47}; + return BitCast( + d, VFromD{_mm512_permutex2var_epi16( + BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19, + 4, 20, 5, 21, 6, 22, 7, 23}; + return 
VFromD{_mm512_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = {0, 16, 1, 17, 2, 18, 3, 19, + 4, 20, 5, 21, 6, 22, 7, 23}; + return VFromD{_mm512_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; + return VFromD{_mm512_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeLower(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 8, 1, 9, 2, 10, 3, 11}; + return VFromD{_mm512_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +// ------------------------------ InterleaveWholeUpper + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { +#if HWY_TARGET <= HWY_AVX3_DL + const RebindToUnsigned du; + alignas(64) static constexpr uint8_t kIdx[64] = { + 32, 96, 33, 97, 34, 98, 35, 99, 36, 100, 37, 101, 38, 102, 39, 103, + 40, 104, 41, 105, 42, 106, 43, 107, 44, 108, 45, 109, 46, 110, 47, 111, + 48, 112, 49, 113, 50, 114, 51, 115, 52, 116, 53, 117, 54, 118, 55, 119, + 56, 120, 57, 121, 58, 122, 59, 123, 60, 124, 61, 125, 62, 126, 63, 127}; + return VFromD{_mm512_permutex2var_epi8(a.raw, Load(du, kIdx).raw, b.raw)}; +#else + alignas(64) static constexpr uint64_t kIdx2[8] = {4, 5, 12, 13, 6, 7, 14, 15}; + const Repartition du64; + return VFromD{_mm512_permutex2var_epi64(InterleaveLower(a, b).raw, + Load(du64, kIdx2).raw, + InterleaveUpper(d, a, b).raw)}; +#endif +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint16_t kIdx[32] = { + 16, 48, 17, 49, 18, 50, 19, 51, 20, 52, 21, 53, 22, 54, 23, 55, + 24, 56, 25, 
57, 26, 58, 27, 59, 28, 60, 29, 61, 30, 62, 31, 63}; + return BitCast( + d, VFromD{_mm512_permutex2var_epi16( + BitCast(du, a).raw, Load(du, kIdx).raw, BitCast(du, b).raw)}); +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; + return VFromD{_mm512_permutex2var_epi32(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint32_t kIdx[16] = { + 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; + return VFromD{_mm512_permutex2var_ps(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; + return VFromD{_mm512_permutex2var_epi64(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +template +HWY_API VFromD InterleaveWholeUpper(D d, VFromD a, VFromD b) { + const RebindToUnsigned du; + alignas(64) static constexpr uint64_t kIdx[8] = {4, 12, 5, 13, 6, 14, 7, 15}; + return VFromD{_mm512_permutex2var_pd(a.raw, Load(du, kIdx).raw, b.raw)}; +} + +// ------------------------------ DupEven (InterleaveLower) + +template +HWY_API Vec512 DupEven(Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CCAA)}; +} +HWY_API Vec512 DupEven(Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CCAA)}; +} + +template +HWY_API Vec512 DupEven(const Vec512 v) { + const DFromV d; + return InterleaveLower(d, v, v); +} + +// ------------------------------ DupOdd (InterleaveUpper) + +template +HWY_API Vec512 DupOdd(Vec512 v) { + return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_DDBB)}; +} +HWY_API Vec512 DupOdd(Vec512 v) { + return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_DDBB)}; +} + +template +HWY_API 
Vec512 DupOdd(const Vec512 v) { + const DFromV d; + return InterleaveUpper(d, v, v); +} + +// ------------------------------ OddEven (IfThenElse) + +template +HWY_API Vec512 OddEven(const Vec512 a, const Vec512 b) { + constexpr size_t s = sizeof(T); + constexpr int shift = s == 1 ? 0 : s == 2 ? 32 : s == 4 ? 48 : 56; + return IfThenElse(Mask512{0x5555555555555555ull >> shift}, b, a); +} + +// -------------------------- InterleaveEven + +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm512_mask_shuffle_epi32( + a.raw, static_cast<__mmask16>(0xAAAA), b.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(2, 2, 0, 0)))}; +} +template +HWY_API VFromD InterleaveEven(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm512_mask_shuffle_ps(a.raw, static_cast<__mmask16>(0xAAAA), + b.raw, b.raw, + _MM_SHUFFLE(2, 2, 0, 0))}; +} +// -------------------------- InterleaveOdd + +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm512_mask_shuffle_epi32( + b.raw, static_cast<__mmask16>(0x5555), a.raw, + static_cast<_MM_PERM_ENUM>(_MM_SHUFFLE(3, 3, 1, 1)))}; +} +template +HWY_API VFromD InterleaveOdd(D /*d*/, VFromD a, VFromD b) { + return VFromD{_mm512_mask_shuffle_ps(b.raw, static_cast<__mmask16>(0x5555), + a.raw, a.raw, + _MM_SHUFFLE(3, 3, 1, 1))}; +} + +// ------------------------------ OddEvenBlocks + +template +HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast( + d, VFromD{_mm512_mask_blend_epi64( + __mmask8{0x33u}, BitCast(du, odd).raw, BitCast(du, even).raw)}); +} + +HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { + return Vec512{ + _mm512_mask_blend_ps(__mmask16{0x0F0Fu}, odd.raw, even.raw)}; +} + +HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { + return Vec512{ + _mm512_mask_blend_pd(__mmask8{0x33u}, odd.raw, even.raw)}; +} + +// ------------------------------ SwapAdjacentBlocks + +template +HWY_API 
Vec512 SwapAdjacentBlocks(Vec512 v) { + const DFromV d; + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, v).raw, BitCast(du, v).raw, _MM_PERM_CDAB)}); +} + +HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { + return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { + return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_CDAB)}; +} + +// ------------------------------ InterleaveEvenBlocks +template +HWY_API Vec512 InterleaveEvenBlocks(Full512 d, Vec512 a, Vec512 b) { + return OddEvenBlocks(SlideUpBlocks<1>(d, b), a); +} + +// ------------------------------ InterleaveOddBlocks (ConcatUpperUpper) +template +HWY_API Vec512 InterleaveOddBlocks(Full512 d, Vec512 a, Vec512 b) { + return OddEvenBlocks(b, SlideDownBlocks<1>(d, a)); +} + +// ------------------------------ InterleaveLowerBlocks (TwoTablesLookupLanes) + +// Note that _mm512_shuffle_f32x4 etc. can only use `a` to populate the lower +// half of the result, so we would require at least two instructions. We instead +// use table lookups. 
+ +template +HWY_API Vec512 InterleaveLowerBlocks(Full512 d, Vec512 a, + Vec512 b) { + const Repartition du64; + HWY_ALIGN static constexpr int64_t kIdx[8] = {0, 1, 8, 9, 2, 3, 10, 11}; + const auto idx = SetTableIndices(du64, kIdx); + return BitCast(d, + TwoTablesLookupLanes(BitCast(du64, a), BitCast(du64, b), idx)); +} + +HWY_API Vec512 InterleaveLowerBlocks(Full512 d, Vec512 a, + Vec512 b) { + HWY_ALIGN static constexpr int32_t kIdx[16] = {0, 1, 2, 3, 16, 17, 18, 19, + 4, 5, 6, 7, 20, 21, 22, 23}; + const auto idx = SetTableIndices(d, kIdx); + return TwoTablesLookupLanes(a, b, idx); +} + +HWY_API Vec512 InterleaveLowerBlocks(Full512 d, + Vec512 a, + Vec512 b) { + HWY_ALIGN static constexpr int64_t kIdx[8] = {0, 1, 8, 9, 2, 3, 10, 11}; + const auto idx = SetTableIndices(d, kIdx); + return TwoTablesLookupLanes(a, b, idx); +} + +// ------------------------------ InterleaveUpperBlocks (TwoTablesLookupLanes) +template +HWY_API Vec512 InterleaveUpperBlocks(Full512 d, Vec512 a, + Vec512 b) { + const Repartition du64; + HWY_ALIGN static constexpr int64_t kIdx[8] = {4, 5, 12, 13, 6, 7, 14, 15}; + const auto idx = SetTableIndices(du64, kIdx); + return BitCast( + d, TwoTablesLookupLanes(du64, BitCast(du64, a), BitCast(du64, b), idx)); +} + +HWY_API Vec512 InterleaveUpperBlocks(Full512 d, Vec512 a, + Vec512 b) { + HWY_ALIGN static constexpr int32_t kIdx[16] = { + 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; + const auto idx = SetTableIndices(d, kIdx); + return TwoTablesLookupLanes(a, b, idx); +} + +HWY_API Vec512 InterleaveUpperBlocks(Full512 d, + Vec512 a, + Vec512 b) { + HWY_ALIGN static constexpr int64_t kIdx[8] = {4, 5, 12, 13, 6, 7, 14, 15}; + const auto idx = SetTableIndices(d, kIdx); + return TwoTablesLookupLanes(a, b, idx); +} + +// ------------------------------ ReverseBlocks + +template +HWY_API VFromD ReverseBlocks(D d, VFromD v) { + const RebindToUnsigned du; // for float16_t + return BitCast(d, + VFromD{_mm512_shuffle_i32x4( + BitCast(du, 
v).raw, BitCast(du, v).raw, _MM_PERM_ABCD)}); +} +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return VFromD{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_ABCD)}; +} +template +HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { + return VFromD{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_ABCD)}; +} + +// ------------------------------ TableLookupBytes (ZeroExtendVector) + +// Both full +template +HWY_API Vec512 TableLookupBytes(Vec512 bytes, Vec512 indices) { + const DFromV d; + return BitCast(d, Vec512{_mm512_shuffle_epi8( + BitCast(Full512(), bytes).raw, + BitCast(Full512(), indices).raw)}); +} + +// Partial index vector +template +HWY_API Vec128 TableLookupBytes(Vec512 bytes, Vec128 from) { + const Full512 d512; + const Half d256; + const Half d128; + // First expand to full 128, then 256, then 512. + const Vec128 from_full{from.raw}; + const auto from_512 = + ZeroExtendVector(d512, ZeroExtendVector(d256, from_full)); + const auto tbl_full = TableLookupBytes(bytes, from_512); + // Shrink to 256, then 128, then partial. + return Vec128{LowerHalf(d128, LowerHalf(d256, tbl_full)).raw}; +} +template +HWY_API Vec256 TableLookupBytes(Vec512 bytes, Vec256 from) { + const DFromV dih; + const Twice di; + const auto from_512 = ZeroExtendVector(di, from); + return LowerHalf(dih, TableLookupBytes(bytes, from_512)); +} + +// Partial table vector +template +HWY_API Vec512 TableLookupBytes(Vec128 bytes, Vec512 from) { + const DFromV d512; + const Half d256; + const Half d128; + // First expand to full 128, then 256, then 512. + const Vec128 bytes_full{bytes.raw}; + const auto bytes_512 = + ZeroExtendVector(d512, ZeroExtendVector(d256, bytes_full)); + return TableLookupBytes(bytes_512, from); +} +template +HWY_API Vec512 TableLookupBytes(Vec256 bytes, Vec512 from) { + const Full512 d; + return TableLookupBytes(ZeroExtendVector(d, bytes), from); +} + +// Partial both are handled by x86_128/256. 
+ +// ------------------------------ I8/U8 Broadcast (TableLookupBytes) + +template +HWY_API Vec512 Broadcast(const Vec512 v) { + static_assert(0 <= kLane && kLane < 16, "Invalid lane"); + return TableLookupBytes(v, Set(Full512(), static_cast(kLane))); +} + +// ------------------------------ Per4LaneBlockShuffle + +namespace detail { + +template +HWY_INLINE VFromD Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3, + const uint32_t x2, + const uint32_t x1, + const uint32_t x0) { + return BitCast(d, Vec512{_mm512_set_epi32( + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0), + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0), + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0), + static_cast(x3), static_cast(x2), + static_cast(x1), static_cast(x0))}); +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<64> /*vect_size_tag*/, V v) { + return V{ + _mm512_shuffle_epi32(v.raw, static_cast<_MM_PERM_ENUM>(kIdx3210 & 0xFF))}; +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<4> /*lane_size_tag*/, + hwy::SizeTag<64> /*vect_size_tag*/, V v) { + return V{_mm512_shuffle_ps(v.raw, v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<64> /*vect_size_tag*/, V v) { + return V{_mm512_permutex_epi64(v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +template )> +HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag /*idx_3210_tag*/, + hwy::SizeTag<8> /*lane_size_tag*/, + hwy::SizeTag<64> /*vect_size_tag*/, V v) { + return V{_mm512_permutex_pd(v.raw, static_cast(kIdx3210 & 0xFF))}; +} + +} // namespace detail + +// ------------------------------ SlideUpLanes + +namespace detail { + +template +HWY_INLINE V CombineShiftRightI32Lanes(V hi, V lo) { + const DFromV d; + const Repartition du32; + 
return BitCast(d, + Vec512{_mm512_alignr_epi32( + BitCast(du32, hi).raw, BitCast(du32, lo).raw, kI32Lanes)}); +} + +template +HWY_INLINE V CombineShiftRightI64Lanes(V hi, V lo) { + const DFromV d; + const Repartition du64; + return BitCast(d, + Vec512{_mm512_alignr_epi64( + BitCast(du64, hi).raw, BitCast(du64, lo).raw, kI64Lanes)}); +} + +template +HWY_INLINE V SlideUpI32Lanes(V v) { + static_assert(0 <= kI32Lanes && kI32Lanes <= 15, + "kI32Lanes must be between 0 and 15"); + const DFromV d; + return CombineShiftRightI32Lanes<16 - kI32Lanes>(v, Zero(d)); +} + +template +HWY_INLINE V SlideUpI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 7, + "kI64Lanes must be between 0 and 7"); + const DFromV d; + return CombineShiftRightI64Lanes<8 - kI64Lanes>(v, Zero(d)); +} + +template +HWY_INLINE VFromD TableLookupSlideUpLanes(D d, VFromD v, size_t amt) { + const Repartition du8; + +#if HWY_TARGET <= HWY_AVX3_DL + const auto byte_idx = Iota(du8, static_cast(size_t{0} - amt)); + return TwoTablesLookupLanes(v, Zero(d), Indices512>{byte_idx.raw}); +#else + const Repartition du16; + const Repartition du64; + const auto byte_idx = Iota(du8, static_cast(size_t{0} - (amt & 15))); + const auto blk_u64_idx = + Iota(du64, static_cast(uint64_t{0} - ((amt >> 4) << 1))); + + const VFromD even_blocks{ + _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; + const VFromD odd_blocks{ + _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(3, 1, 1, 3))}; + const auto odd_sel_mask = + MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, byte_idx)))); + const auto even_blk_lookup_result = + BitCast(d, TableLookupBytes(even_blocks, byte_idx)); + const VFromD blockwise_slide_up_result{ + _mm512_mask_shuffle_epi8(even_blk_lookup_result.raw, odd_sel_mask.raw, + odd_blocks.raw, byte_idx.raw)}; + return BitCast(d, TwoTablesLookupLanes( + BitCast(du64, blockwise_slide_up_result), Zero(du64), + Indices512{blk_u64_idx.raw})); +#endif +} + +} // namespace detail + +template +HWY_API VFromD 
SlideUpBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 3, + "kBlocks must be between 0 and 3"); + switch (kBlocks) { + case 0: + return v; + case 1: + return detail::SlideUpI64Lanes<2>(v); + case 2: + return ConcatLowerLower(d, v, Zero(d)); + case 3: + return detail::SlideUpI64Lanes<6>(v); + } + + return v; +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return detail::SlideUpI32Lanes<1>(v); + case 2: + return detail::SlideUpI64Lanes<1>(v); + case 3: + return detail::SlideUpI32Lanes<3>(v); + case 4: + return detail::SlideUpI64Lanes<2>(v); + case 5: + return detail::SlideUpI32Lanes<5>(v); + case 6: + return detail::SlideUpI64Lanes<3>(v); + case 7: + return detail::SlideUpI32Lanes<7>(v); + case 8: + return ConcatLowerLower(d, v, Zero(d)); + case 9: + return detail::SlideUpI32Lanes<9>(v); + case 10: + return detail::SlideUpI64Lanes<5>(v); + case 11: + return detail::SlideUpI32Lanes<11>(v); + case 12: + return detail::SlideUpI64Lanes<6>(v); + case 13: + return detail::SlideUpI32Lanes<13>(v); + case 14: + return detail::SlideUpI64Lanes<7>(v); + case 15: + return detail::SlideUpI32Lanes<15>(v); + } + } +#endif + + return detail::TableLookupSlideUpLanes(d, v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + switch (amt) { + case 0: + return v; + case 1: + return detail::SlideUpI64Lanes<1>(v); + case 2: + return detail::SlideUpI64Lanes<2>(v); + case 3: + return detail::SlideUpI64Lanes<3>(v); + case 4: + return ConcatLowerLower(d, v, Zero(d)); + case 5: + return detail::SlideUpI64Lanes<5>(v); + case 6: + return detail::SlideUpI64Lanes<6>(v); + case 7: + return detail::SlideUpI64Lanes<7>(v); + } + } +#endif + + return detail::TableLookupSlideUpLanes(d, v, 
amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + if ((amt & 3) == 0) { + const Repartition du32; + return BitCast(d, SlideUpLanes(du32, BitCast(du32, v), amt >> 2)); + } else if ((amt & 1) == 0) { + const Repartition du16; + return BitCast( + d, detail::TableLookupSlideUpLanes(du16, BitCast(du16, v), amt >> 1)); + } +#if HWY_TARGET > HWY_AVX3_DL + else if (amt <= 63) { // NOLINT(readability/braces) + const Repartition du64; + const size_t blk_u64_slideup_amt = (amt >> 4) << 1; + const auto vu64 = BitCast(du64, v); + const auto v_hi = + BitCast(d, SlideUpLanes(du64, vu64, blk_u64_slideup_amt)); + const auto v_lo = + (blk_u64_slideup_amt <= 4) + ? BitCast(d, SlideUpLanes(du64, vu64, blk_u64_slideup_amt + 2)) + : Zero(d); + switch (amt & 15) { + case 1: + return CombineShiftRightBytes<15>(d, v_hi, v_lo); + case 3: + return CombineShiftRightBytes<13>(d, v_hi, v_lo); + case 5: + return CombineShiftRightBytes<11>(d, v_hi, v_lo); + case 7: + return CombineShiftRightBytes<9>(d, v_hi, v_lo); + case 9: + return CombineShiftRightBytes<7>(d, v_hi, v_lo); + case 11: + return CombineShiftRightBytes<5>(d, v_hi, v_lo); + case 13: + return CombineShiftRightBytes<3>(d, v_hi, v_lo); + case 15: + return CombineShiftRightBytes<1>(d, v_hi, v_lo); + } + } +#endif // HWY_TARGET > HWY_AVX3_DL + } +#endif + + return detail::TableLookupSlideUpLanes(d, v, amt); +} + +template +HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt) && (amt & 1) == 0) { + const Repartition du32; + return BitCast(d, SlideUpLanes(du32, BitCast(du32, v), amt >> 1)); + } +#endif + + return detail::TableLookupSlideUpLanes(d, v, amt); +} + +// ------------------------------ Slide1Up + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { +#if HWY_TARGET <= HWY_AVX3_DL + return 
detail::TableLookupSlideUpLanes(d, v, 1); +#else + const auto v_lo = detail::SlideUpI64Lanes<2>(v); + return CombineShiftRightBytes<15>(d, v, v_lo); +#endif +} + +template +HWY_API VFromD Slide1Up(D d, VFromD v) { + return detail::TableLookupSlideUpLanes(d, v, 1); +} + +template +HWY_API VFromD Slide1Up(D /*d*/, VFromD v) { + return detail::SlideUpI32Lanes<1>(v); +} + +template +HWY_API VFromD Slide1Up(D /*d*/, VFromD v) { + return detail::SlideUpI64Lanes<1>(v); +} + +// ------------------------------ SlideDownLanes + +namespace detail { + +template +HWY_INLINE V SlideDownI32Lanes(V v) { + static_assert(0 <= kI32Lanes && kI32Lanes <= 15, + "kI32Lanes must be between 0 and 15"); + const DFromV d; + return CombineShiftRightI32Lanes(Zero(d), v); +} + +template +HWY_INLINE V SlideDownI64Lanes(V v) { + static_assert(0 <= kI64Lanes && kI64Lanes <= 7, + "kI64Lanes must be between 0 and 7"); + const DFromV d; + return CombineShiftRightI64Lanes(Zero(d), v); +} + +template +HWY_INLINE VFromD TableLookupSlideDownLanes(D d, VFromD v, size_t amt) { + const Repartition du8; + +#if HWY_TARGET <= HWY_AVX3_DL + auto byte_idx = Iota(du8, static_cast(amt)); + return TwoTablesLookupLanes(v, Zero(d), Indices512>{byte_idx.raw}); +#else + const Repartition du16; + const Repartition du64; + const auto byte_idx = Iota(du8, static_cast(amt & 15)); + const auto blk_u64_idx = Iota(du64, static_cast(((amt >> 4) << 1))); + + const VFromD even_blocks{ + _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(0, 2, 2, 0))}; + const VFromD odd_blocks{ + _mm512_shuffle_i32x4(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; + const auto odd_sel_mask = + MaskFromVec(BitCast(d, ShiftLeft<3>(BitCast(du16, byte_idx)))); + const VFromD even_blk_lookup_result{ + _mm512_maskz_shuffle_epi8(static_cast<__mmask64>(0x0000FFFFFFFFFFFFULL), + even_blocks.raw, byte_idx.raw)}; + const VFromD blockwise_slide_up_result{ + _mm512_mask_shuffle_epi8(even_blk_lookup_result.raw, odd_sel_mask.raw, + odd_blocks.raw, byte_idx.raw)}; + 
return BitCast(d, TwoTablesLookupLanes( + BitCast(du64, blockwise_slide_up_result), Zero(du64), + Indices512{blk_u64_idx.raw})); +#endif +} + +} // namespace detail + +template +HWY_API VFromD SlideDownBlocks(D d, VFromD v) { + static_assert(0 <= kBlocks && kBlocks <= 3, + "kBlocks must be between 0 and 3"); + const Half dh; + switch (kBlocks) { + case 0: + return v; + case 1: + return detail::SlideDownI64Lanes<2>(v); + case 2: + return ZeroExtendVector(d, UpperHalf(dh, v)); + case 3: + return detail::SlideDownI64Lanes<6>(v); + } + + return v; +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + const Half dh; + switch (amt) { + case 1: + return detail::SlideDownI32Lanes<1>(v); + case 2: + return detail::SlideDownI64Lanes<1>(v); + case 3: + return detail::SlideDownI32Lanes<3>(v); + case 4: + return detail::SlideDownI64Lanes<2>(v); + case 5: + return detail::SlideDownI32Lanes<5>(v); + case 6: + return detail::SlideDownI64Lanes<3>(v); + case 7: + return detail::SlideDownI32Lanes<7>(v); + case 8: + return ZeroExtendVector(d, UpperHalf(dh, v)); + case 9: + return detail::SlideDownI32Lanes<9>(v); + case 10: + return detail::SlideDownI64Lanes<5>(v); + case 11: + return detail::SlideDownI32Lanes<11>(v); + case 12: + return detail::SlideDownI64Lanes<6>(v); + case 13: + return detail::SlideDownI32Lanes<13>(v); + case 14: + return detail::SlideDownI64Lanes<7>(v); + case 15: + return detail::SlideDownI32Lanes<15>(v); + } + } +#endif + + return detail::TableLookupSlideDownLanes(d, v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + const Half dh; + switch (amt) { + case 0: + return v; + case 1: + return detail::SlideDownI64Lanes<1>(v); + case 2: + return detail::SlideDownI64Lanes<2>(v); + case 3: + return 
detail::SlideDownI64Lanes<3>(v); + case 4: + return ZeroExtendVector(d, UpperHalf(dh, v)); + case 5: + return detail::SlideDownI64Lanes<5>(v); + case 6: + return detail::SlideDownI64Lanes<6>(v); + case 7: + return detail::SlideDownI64Lanes<7>(v); + } + } +#endif + + return detail::TableLookupSlideDownLanes(d, v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt)) { + if ((amt & 3) == 0) { + const Repartition du32; + return BitCast(d, SlideDownLanes(du32, BitCast(du32, v), amt >> 2)); + } else if ((amt & 1) == 0) { + const Repartition du16; + return BitCast(d, detail::TableLookupSlideDownLanes( + du16, BitCast(du16, v), amt >> 1)); + } +#if HWY_TARGET > HWY_AVX3_DL + else if (amt <= 63) { // NOLINT(readability/braces) + const Repartition du64; + const size_t blk_u64_slidedown_amt = (amt >> 4) << 1; + const auto vu64 = BitCast(du64, v); + const auto v_lo = + BitCast(d, SlideDownLanes(du64, vu64, blk_u64_slidedown_amt)); + const auto v_hi = + (blk_u64_slidedown_amt <= 4) + ? 
BitCast(d, + SlideDownLanes(du64, vu64, blk_u64_slidedown_amt + 2)) + : Zero(d); + switch (amt & 15) { + case 1: + return CombineShiftRightBytes<1>(d, v_hi, v_lo); + case 3: + return CombineShiftRightBytes<3>(d, v_hi, v_lo); + case 5: + return CombineShiftRightBytes<5>(d, v_hi, v_lo); + case 7: + return CombineShiftRightBytes<7>(d, v_hi, v_lo); + case 9: + return CombineShiftRightBytes<9>(d, v_hi, v_lo); + case 11: + return CombineShiftRightBytes<11>(d, v_hi, v_lo); + case 13: + return CombineShiftRightBytes<13>(d, v_hi, v_lo); + case 15: + return CombineShiftRightBytes<15>(d, v_hi, v_lo); + } + } +#endif + } +#endif + + return detail::TableLookupSlideDownLanes(d, v, amt); +} + +template +HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { +#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang + if (__builtin_constant_p(amt) && (amt & 1) == 0) { + const Repartition du32; + return BitCast(d, SlideDownLanes(du32, BitCast(du32, v), amt >> 1)); + } +#endif + + return detail::TableLookupSlideDownLanes(d, v, amt); +} + +// ------------------------------ Slide1Down + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { +#if HWY_TARGET <= HWY_AVX3_DL + return detail::TableLookupSlideDownLanes(d, v, 1); +#else + const auto v_hi = detail::SlideDownI64Lanes<2>(v); + return CombineShiftRightBytes<1>(d, v_hi, v); +#endif +} + +template +HWY_API VFromD Slide1Down(D d, VFromD v) { + return detail::TableLookupSlideDownLanes(d, v, 1); +} + +template +HWY_API VFromD Slide1Down(D /*d*/, VFromD v) { + return detail::SlideDownI32Lanes<1>(v); +} + +template +HWY_API VFromD Slide1Down(D /*d*/, VFromD v) { + return detail::SlideDownI64Lanes<1>(v); +} + +// ================================================== CONVERT + +// ------------------------------ Promotions (part w/ narrow lanes -> full) + +// Unsigned: zero-extend. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then Zip* would be faster. 
+template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepu8_epi16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm512_cvtepu8_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepu16_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepu32_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm512_cvtepu16_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec64 v) { + return VFromD{_mm512_cvtepu8_epi64(v.raw)}; +} + +// Signed: replicate sign bit. +// Note: these have 3 cycle latency; if inputs are already split across the +// 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by +// signed shift would be faster. +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepi8_epi16(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm512_cvtepi8_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepi16_epi32(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepi32_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { + return VFromD{_mm512_cvtepi16_epi64(v.raw)}; +} +template +HWY_API VFromD PromoteTo(D /* tag */, Vec64 v) { + return VFromD{_mm512_cvtepi8_epi64(v.raw)}; +} + +// Float +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { +#if HWY_HAVE_FLOAT16 + const RebindToUnsigned> du16; + return VFromD{_mm512_cvtph_ps(BitCast(du16, v).raw)}; +#else + return VFromD{_mm512_cvtph_ps(v.raw)}; +#endif // HWY_HAVE_FLOAT16 +} + +#if HWY_HAVE_FLOAT16 + +template +HWY_INLINE VFromD PromoteTo(D /*tag*/, Vec128 v) { + return VFromD{_mm512_cvtph_pd(v.raw)}; +} + +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD 
PromoteTo(D df32, Vec256 v) { + const Rebind du16; + const RebindToSigned di32; + return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); +} + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtps_pd(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepi32_pd(v.raw)}; +} + +template +HWY_API VFromD PromoteTo(D /* tag */, Vec256 v) { + return VFromD{_mm512_cvtepu32_pd(v.raw)}; +} + +template +HWY_API VFromD PromoteInRangeTo(D /*di64*/, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm512_cvtts_ps_epi64(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi64( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttps2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttps_epi64(v.raw)}; +#endif +} +template +HWY_API VFromD PromoteInRangeTo(D /* tag */, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm512_cvtts_ps_epu64(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior with GCC if any values of v[i] are not + // within the range of an uint64_t + +#if 
HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(32))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi64( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttps2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttps_epu64(v.raw)}; +#endif +} + +// ------------------------------ Demotions (full -> part w/ narrow lanes) + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const Full512 du64; + const Vec512 u16{_mm512_packus_epi32(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(du64, kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u16.raw)}; + return LowerHalf(even); +} + +template +HWY_API VFromD DemoteTo(D dn, Vec512 v) { + const DFromV d; + const RebindToSigned di; + return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFFFFFu)))); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const Full512 du64; + const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. 
+ alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(du64, kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, i16.raw)}; + return LowerHalf(even); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const Full512 du32; + const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; + const Vec512 u8{_mm512_packus_epi16(i16.raw, i16.raw)}; + + const VFromD idx32 = Dup128VecFromValues(du32, 0, 4, 8, 12); + const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, u8.raw)}; + return LowerHalf(LowerHalf(fixed)); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtusepi32_epi8(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const Full512 du64; + const Vec512 u8{_mm512_packus_epi16(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. + alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(du64, kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; + return LowerHalf(even); +} + +template +HWY_API VFromD DemoteTo(D dn, Vec512 v) { + const DFromV d; + const RebindToSigned di; + return DemoteTo(dn, BitCast(di, Min(v, Set(d, 0x7FFFu)))); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const Full512 du32; + const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; + const Vec512 i8{_mm512_packs_epi16(i16.raw, i16.raw)}; + + const VFromD idx32 = Dup128VecFromValues(du32, 0, 4, 8, 12); + const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, i8.raw)}; + return LowerHalf(LowerHalf(fixed)); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const Full512 du64; + const Vec512 u8{_mm512_packs_epi16(v.raw, v.raw)}; + + // Compress even u64 lanes into 256 bit. 
+ alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; + const auto idx64 = Load(du64, kLanes); + const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; + return LowerHalf(even); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtsepi64_epi32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtsepi64_epi16(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtsepi64_epi8(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw; + return VFromD{_mm512_maskz_cvtusepi64_epi32(non_neg_mask, v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw; + return VFromD{_mm512_maskz_cvtusepi64_epi16(non_neg_mask, v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + const __mmask8 non_neg_mask = Not(MaskFromVec(v)).raw; + return VFromD{_mm512_maskz_cvtusepi64_epi8(non_neg_mask, v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtusepi64_epi32(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtusepi64_epi16(v.raw)}; +} +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtusepi64_epi8(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D df16, Vec512 v) { + // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
+ HWY_DIAGNOSTICS(push) + HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") + const RebindToUnsigned du16; + return BitCast( + df16, VFromD{_mm512_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}); + HWY_DIAGNOSTICS(pop) +} + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD DemoteTo(D /*df16*/, Vec512 v) { + return VFromD{_mm512_cvtpd_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +#if HWY_AVX3_HAVE_F32_TO_BF16C +template +HWY_API VFromD DemoteTo(D /*dbf16*/, Vec512 v) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m256i raw_result; + __asm__("vcvtneps2bf16 %1, %0" : "=v"(raw_result) : "v"(v.raw)); + return VFromD{raw_result}; +#else + // The _mm512_cvtneps_pbh intrinsic returns a __m256bh vector that needs to be + // bit casted to a __m256i vector + return VFromD{detail::BitCastToInteger(_mm512_cvtneps_pbh(v.raw))}; +#endif +} + +template +HWY_API VFromD ReorderDemote2To(D /*dbf16*/, Vec512 a, + Vec512 b) { +#if HWY_COMPILER_CLANG >= 1600 && HWY_COMPILER_CLANG < 2000 + // Inline assembly workaround for LLVM codegen bug + __m512i raw_result; + __asm__("vcvtne2ps2bf16 %2, %1, %0" + : "=v"(raw_result) + : "v"(b.raw), "v"(a.raw)); + return VFromD{raw_result}; +#else + // The _mm512_cvtne2ps_pbh intrinsic returns a __m512bh vector that needs to + // be bit casted to a __m512i vector + return VFromD{detail::BitCastToInteger(_mm512_cvtne2ps_pbh(b.raw, a.raw))}; +#endif +} +#endif // HWY_AVX3_HAVE_F32_TO_BF16C + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec512 a, + Vec512 b) { + return VFromD{_mm512_packs_epi32(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec512 a, + Vec512 b) { + return VFromD{_mm512_packus_epi32(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec512 a, + Vec512 b) { + const DFromV du32; + const RebindToSigned di32; + const auto max_i32 = Set(du32, 0x7FFFFFFFu); + + return ReorderDemote2To(dn, BitCast(di32, 
Min(a, max_i32)), + BitCast(di32, Min(b, max_i32))); +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec512 a, + Vec512 b) { + return VFromD{_mm512_packs_epi16(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D /* tag */, Vec512 a, + Vec512 b) { + return VFromD{_mm512_packus_epi16(a.raw, b.raw)}; +} + +template +HWY_API VFromD ReorderDemote2To(D dn, Vec512 a, + Vec512 b) { + const DFromV du16; + const RebindToSigned di16; + const auto max_i16 = Set(du16, 0x7FFFu); + + return ReorderDemote2To(dn, BitCast(di16, Min(a, max_i16)), + BitCast(di16, Min(b, max_i16))); +} + +template ), + HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), + HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), + HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV) * 2), + HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4))> +HWY_API VFromD OrderedDemote2To(D d, V a, V b) { + const Full512 du64; + alignas(64) static constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return BitCast(d, TableLookupLanes(BitCast(du64, ReorderDemote2To(d, a, b)), + SetTableIndices(du64, kIdx))); +} + +template +HWY_API VFromD DemoteTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtpd_ps(v.raw)}; +} + +template +HWY_API VFromD DemoteInRangeTo(D /* tag */, Vec512 v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm512_cvtts_pd_epi32(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm256_setr_epi32( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + 
detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7])) + }; + } +#endif + + __m256i raw_result; + __asm__("vcvttpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttpd_epi32(v.raw)}; +#endif +} + +template +HWY_API VFromD DemoteInRangeTo(D /* tag */, Vec512 v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm512_cvtts_pd_epu32(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttpd_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm256_setr_epi32( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])))}; + } +#endif + + __m256i raw_result; + __asm__("vcvttpd2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttpd_epu32(v.raw)}; +#endif +} + +template +HWY_API VFromD DemoteTo(D /* tag */, VFromD> v) { + return VFromD{_mm512_cvtepi64_ps(v.raw)}; +} + +template +HWY_API VFromD DemoteTo(D /* tag */, 
VFromD> v) { + return VFromD{_mm512_cvtepu64_ps(v.raw)}; +} + +// For already range-limited input [0, 255]. +HWY_API Vec128 U8FromU32(const Vec512 v) { + const DFromV d32; + // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the + // lowest 4 bytes. + const VFromD v8From32 = + Dup128VecFromValues(d32, 0x0C080400u, ~0u, ~0u, ~0u); + const auto quads = TableLookupBytes(v, v8From32); + // Gather the lowest 4 bytes of 4 128-bit blocks. + const VFromD index32 = Dup128VecFromValues(d32, 0, 4, 8, 12); + const Vec512 bytes{_mm512_permutexvar_epi32(index32.raw, quads.raw)}; + return LowerHalf(LowerHalf(bytes)); +} + +// ------------------------------ Truncations + +template +HWY_API VFromD TruncateTo(D d, const Vec512 v) { +#if HWY_TARGET <= HWY_AVX3_DL + (void)d; + const Full512 d8; + const VFromD v8From64 = Dup128VecFromValues( + d8, 0, 8, 16, 24, 32, 40, 48, 56, 0, 8, 16, 24, 32, 40, 48, 56); + const Vec512 bytes{_mm512_permutexvar_epi8(v8From64.raw, v.raw)}; + return LowerHalf(LowerHalf(LowerHalf(bytes))); +#else + const Full512 d32; + alignas(64) static constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14}; + const Vec512 even{ + _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; + return TruncateTo(d, LowerHalf(even)); +#endif +} + +template +HWY_API VFromD TruncateTo(D /* tag */, const Vec512 v) { + const Full512 d16; + alignas(16) static constexpr uint16_t k16From64[8] = {0, 4, 8, 12, + 16, 20, 24, 28}; + const Vec512 bytes{ + _mm512_permutexvar_epi16(LoadDup128(d16, k16From64).raw, v.raw)}; + return LowerHalf(LowerHalf(bytes)); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, const Vec512 v) { + const Full512 d32; + alignas(64) static constexpr uint32_t kEven[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 0, 2, 4, 6, 8, 10, 12, 14}; + const Vec512 even{ + _mm512_permutexvar_epi32(Load(d32, kEven).raw, v.raw)}; + return LowerHalf(even); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, const Vec512 v) { 
+#if HWY_TARGET <= HWY_AVX3_DL + const Full512 d8; + const VFromD v8From32 = Dup128VecFromValues( + d8, 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60); + const Vec512 bytes{_mm512_permutexvar_epi8(v8From32.raw, v.raw)}; +#else + const Full512 d32; + // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the + // lowest 4 bytes. + const VFromD v8From32 = + Dup128VecFromValues(d32, 0x0C080400u, ~0u, ~0u, ~0u); + const auto quads = TableLookupBytes(v, v8From32); + // Gather the lowest 4 bytes of 4 128-bit blocks. + const VFromD index32 = Dup128VecFromValues(d32, 0, 4, 8, 12); + const Vec512 bytes{_mm512_permutexvar_epi32(index32.raw, quads.raw)}; +#endif + return LowerHalf(LowerHalf(bytes)); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, const Vec512 v) { + const Full512 d16; + alignas(64) static constexpr uint16_t k16From32[32] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; + const Vec512 bytes{ + _mm512_permutexvar_epi16(Load(d16, k16From32).raw, v.raw)}; + return LowerHalf(bytes); +} + +template +HWY_API VFromD TruncateTo(D /* tag */, const Vec512 v) { +#if HWY_TARGET <= HWY_AVX3_DL + const Full512 d8; + alignas(64) static constexpr uint8_t k8From16[64] = { + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}; + const Vec512 bytes{ + _mm512_permutexvar_epi8(Load(d8, k8From16).raw, v.raw)}; +#else + const Full512 d32; + const VFromD v16From32 = Dup128VecFromValues( + d32, 0x06040200u, 0x0E0C0A08u, 0x06040200u, 0x0E0C0A08u); + const auto quads = TableLookupBytes(v, v16From32); + alignas(64) static constexpr uint32_t kIndex32[16] = { + 0, 1, 4, 5, 8, 9, 12, 13, 0, 1, 4, 5, 8, 9, 12, 13}; + const Vec512 bytes{ + _mm512_permutexvar_epi32(Load(d32, 
kIndex32).raw, quads.raw)}; +#endif + return LowerHalf(bytes); +} + +// ------------------------------ Convert integer <=> floating point + +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtepu16_ph(v.raw)}; +} +template +HWY_API VFromD ConvertTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtepi16_ph(v.raw)}; +} +#endif // HWY_HAVE_FLOAT16 + +template +HWY_API VFromD ConvertTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtepi32_ps(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /* tag */, Vec512 v) { + return VFromD{_mm512_cvtepi64_pd(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /* tag*/, Vec512 v) { + return VFromD{_mm512_cvtepu32_ps(v.raw)}; +} + +template +HWY_API VFromD ConvertTo(D /* tag*/, Vec512 v) { + return VFromD{_mm512_cvtepu64_pd(v.raw)}; +} + +// Truncates (rounds toward zero). +#if HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertInRangeTo(D /*d*/, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttph_epi16 with GCC if any + // values of v[i] are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm512_set_epi16(detail::X86ConvertScalarFromFloat(raw_v[31]), + detail::X86ConvertScalarFromFloat(raw_v[30]), + detail::X86ConvertScalarFromFloat(raw_v[29]), + detail::X86ConvertScalarFromFloat(raw_v[28]), + detail::X86ConvertScalarFromFloat(raw_v[27]), + detail::X86ConvertScalarFromFloat(raw_v[26]), + detail::X86ConvertScalarFromFloat(raw_v[25]), + detail::X86ConvertScalarFromFloat(raw_v[24]), + detail::X86ConvertScalarFromFloat(raw_v[23]), + detail::X86ConvertScalarFromFloat(raw_v[22]), + detail::X86ConvertScalarFromFloat(raw_v[21]), + 
detail::X86ConvertScalarFromFloat(raw_v[20]), + detail::X86ConvertScalarFromFloat(raw_v[19]), + detail::X86ConvertScalarFromFloat(raw_v[18]), + detail::X86ConvertScalarFromFloat(raw_v[17]), + detail::X86ConvertScalarFromFloat(raw_v[16]), + detail::X86ConvertScalarFromFloat(raw_v[15]), + detail::X86ConvertScalarFromFloat(raw_v[14]), + detail::X86ConvertScalarFromFloat(raw_v[13]), + detail::X86ConvertScalarFromFloat(raw_v[12]), + detail::X86ConvertScalarFromFloat(raw_v[11]), + detail::X86ConvertScalarFromFloat(raw_v[10]), + detail::X86ConvertScalarFromFloat(raw_v[9]), + detail::X86ConvertScalarFromFloat(raw_v[8]), + detail::X86ConvertScalarFromFloat(raw_v[7]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[0]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttph_epi16(v.raw)}; +#endif +} +template +HWY_API VFromD ConvertInRangeTo(D /* tag */, VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttph_epu16 with GCC if any + // values of v[i] are not within the range of an uint16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_set_epi16( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[31])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[30])), + static_cast( + 
detail::X86ConvertScalarFromFloat(raw_v[29])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[28])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[27])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[26])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[25])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[24])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[23])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[22])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[21])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[20])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[19])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[18])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[17])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[16])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[15])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[14])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[13])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[12])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[11])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[10])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[9])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[8])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttph2uw {%1, %0|%0, %1}" + : "=" 
HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttph_epu16(v.raw)}; +#endif +} +#endif // HWY_HAVE_FLOAT16 +template +HWY_API VFromD ConvertInRangeTo(D /*d*/, Vec512 v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm512_cvtts_ps_epi32(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttps_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi32( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]), + detail::X86ConvertScalarFromFloat(raw_v[8]), + detail::X86ConvertScalarFromFloat(raw_v[9]), + detail::X86ConvertScalarFromFloat(raw_v[10]), + detail::X86ConvertScalarFromFloat(raw_v[11]), + detail::X86ConvertScalarFromFloat(raw_v[12]), + detail::X86ConvertScalarFromFloat(raw_v[13]), + detail::X86ConvertScalarFromFloat(raw_v[14]), + detail::X86ConvertScalarFromFloat(raw_v[15]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttps_epi32(v.raw)}; +#endif +} +template +HWY_API VFromD ConvertInRangeTo(D /*di*/, Vec512 v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm512_cvtts_pd_epi64(v.raw)}; +#elif 
HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi64( + detail::X86ConvertScalarFromFloat(raw_v[0]), + detail::X86ConvertScalarFromFloat(raw_v[1]), + detail::X86ConvertScalarFromFloat(raw_v[2]), + detail::X86ConvertScalarFromFloat(raw_v[3]), + detail::X86ConvertScalarFromFloat(raw_v[4]), + detail::X86ConvertScalarFromFloat(raw_v[5]), + detail::X86ConvertScalarFromFloat(raw_v[6]), + detail::X86ConvertScalarFromFloat(raw_v[7]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttpd_epi64(v.raw)}; +#endif +} +template +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm512_cvtts_ps_epu32(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttps_epu32 with GCC if any + // values of v[i] are not within the range of an uint32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi32( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + 
detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[7])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[8])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[9])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[10])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[11])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[12])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[13])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[14])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[15])))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttps2udq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttps_epu32(v.raw)}; +#endif +} +template +HWY_API VFromD ConvertInRangeTo(DU /*du*/, VFromD> v) { +#if HWY_X86_HAVE_AVX10_2_OPS + return VFromD{_mm512_cvtts_pd_epu64(v.raw)}; +#elif HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvttpd_epu64 with GCC if any + // values of v[i] are not within the range of an uint64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{_mm512_setr_epi64( + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[0])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[1])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[2])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[3])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[4])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[5])), + static_cast( + detail::X86ConvertScalarFromFloat(raw_v[6])), + static_cast( + 
detail::X86ConvertScalarFromFloat(raw_v[7])))}; + } +#endif + + __m512i raw_result; + __asm__("vcvttpd2uqq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvttpd_epu64(v.raw)}; +#endif +} + +template +static HWY_INLINE VFromD NearestIntInRange(DI, + VFromD> v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvtps_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef float GccF32RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm512_setr_epi32(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]), + detail::X86ScalarNearestInt(raw_v[8]), + detail::X86ScalarNearestInt(raw_v[9]), + detail::X86ScalarNearestInt(raw_v[10]), + detail::X86ScalarNearestInt(raw_v[11]), + detail::X86ScalarNearestInt(raw_v[12]), + detail::X86ScalarNearestInt(raw_v[13]), + detail::X86ScalarNearestInt(raw_v[14]), + detail::X86ScalarNearestInt(raw_v[15]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvtps2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvtps_epi32(v.raw)}; +#endif +} + +#if HWY_HAVE_FLOAT16 +template +static HWY_INLINE VFromD NearestIntInRange(DI /*d*/, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvtph_epi16 with GCC if any + // values of v[i] 
are not within the range of an int16_t + +#if HWY_COMPILER_GCC_ACTUAL >= 1200 && !HWY_IS_DEBUG_BUILD && \ + HWY_HAVE_SCALAR_F16_TYPE + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef hwy::float16_t::Native GccF16RawVectType + __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm512_set_epi16(detail::X86ScalarNearestInt(raw_v[31]), + detail::X86ScalarNearestInt(raw_v[30]), + detail::X86ScalarNearestInt(raw_v[29]), + detail::X86ScalarNearestInt(raw_v[28]), + detail::X86ScalarNearestInt(raw_v[27]), + detail::X86ScalarNearestInt(raw_v[26]), + detail::X86ScalarNearestInt(raw_v[25]), + detail::X86ScalarNearestInt(raw_v[24]), + detail::X86ScalarNearestInt(raw_v[23]), + detail::X86ScalarNearestInt(raw_v[22]), + detail::X86ScalarNearestInt(raw_v[21]), + detail::X86ScalarNearestInt(raw_v[20]), + detail::X86ScalarNearestInt(raw_v[19]), + detail::X86ScalarNearestInt(raw_v[18]), + detail::X86ScalarNearestInt(raw_v[17]), + detail::X86ScalarNearestInt(raw_v[16]), + detail::X86ScalarNearestInt(raw_v[15]), + detail::X86ScalarNearestInt(raw_v[14]), + detail::X86ScalarNearestInt(raw_v[13]), + detail::X86ScalarNearestInt(raw_v[12]), + detail::X86ScalarNearestInt(raw_v[11]), + detail::X86ScalarNearestInt(raw_v[10]), + detail::X86ScalarNearestInt(raw_v[9]), + detail::X86ScalarNearestInt(raw_v[8]), + detail::X86ScalarNearestInt(raw_v[7]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[0]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvtph2w {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvtph_epi16(v.raw)}; +#endif +} +#endif // HWY_HAVE_FLOAT16 + 
+template +static HWY_INLINE VFromD NearestIntInRange(DI /*di*/, Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvtpd_epi64 with GCC if any + // values of v[i] are not within the range of an int64_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm512_setr_epi64(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]))}; + } +#endif + + __m512i raw_result; + __asm__("vcvtpd2qq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvtpd_epi64(v.raw)}; +#endif +} + +template +static HWY_INLINE VFromD DemoteToNearestIntInRange(DI /* tag */, + Vec512 v) { +#if HWY_COMPILER_GCC_ACTUAL + // Workaround for undefined behavior in _mm512_cvtpd_epi32 with GCC if any + // values of v[i] are not within the range of an int32_t + +#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD + if (detail::IsConstantX86VecForF2IConv(v)) { + typedef double GccF64RawVectType __attribute__((__vector_size__(64))); + const auto raw_v = reinterpret_cast(v.raw); + return VFromD{ + _mm256_setr_epi32(detail::X86ScalarNearestInt(raw_v[0]), + detail::X86ScalarNearestInt(raw_v[1]), + detail::X86ScalarNearestInt(raw_v[2]), + detail::X86ScalarNearestInt(raw_v[3]), + detail::X86ScalarNearestInt(raw_v[4]), + detail::X86ScalarNearestInt(raw_v[5]), + detail::X86ScalarNearestInt(raw_v[6]), + detail::X86ScalarNearestInt(raw_v[7]))}; + } +#endif + + __m256i raw_result; + 
__asm__("vcvtpd2dq {%1, %0|%0, %1}" + : "=" HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(raw_result) + : HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT(v.raw) + :); + return VFromD{raw_result}; +#else + return VFromD{_mm512_cvtpd_epi32(v.raw)}; +#endif +} + +// ================================================== CRYPTO + +#if !defined(HWY_DISABLE_PCLMUL_AES) + +HWY_API Vec512 AESRound(Vec512 state, + Vec512 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_aesenc_epi128(state.raw, round_key.raw)}; +#else + const DFromV d; + const Half d2; + return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec512 AESLastRound(Vec512 state, + Vec512 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_aesenclast_epi128(state.raw, round_key.raw)}; +#else + const DFromV d; + const Half d2; + return Combine(d, + AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRound(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec512 AESRoundInv(Vec512 state, + Vec512 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_aesdec_epi128(state.raw, round_key.raw)}; +#else + const Full512 d; + const Half d2; + return Combine(d, AESRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESRoundInv(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +HWY_API Vec512 AESLastRoundInv(Vec512 state, + Vec512 round_key) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_aesdeclast_epi128(state.raw, round_key.raw)}; +#else + const Full512 d; + const Half d2; + return Combine( + d, AESLastRoundInv(UpperHalf(d2, state), UpperHalf(d2, round_key)), + AESLastRoundInv(LowerHalf(state), LowerHalf(round_key))); +#endif +} + +template +HWY_API Vec512 AESKeyGenAssist(Vec512 v) { + const Full512 d; +#if HWY_TARGET <= HWY_AVX3_DL + const VFromD rconXorMask = Dup128VecFromValues( + d, 0, kRcon, 0, 0, 0, 0, 0, 0, 0, kRcon, 0, 0, 0, 0, 0, 0); + const 
VFromD rotWordShuffle = Dup128VecFromValues( + d, 0, 13, 10, 7, 1, 14, 11, 4, 8, 5, 2, 15, 9, 6, 3, 12); + const Repartition du32; + const auto w13 = BitCast(d, DupOdd(BitCast(du32, v))); + const auto sub_word_result = AESLastRound(w13, rconXorMask); + return TableLookupBytes(sub_word_result, rotWordShuffle); +#else + const Half d2; + return Combine(d, AESKeyGenAssist(UpperHalf(d2, v)), + AESKeyGenAssist(LowerHalf(v))); +#endif +} + +HWY_API Vec512 CLMulLower(Vec512 va, Vec512 vb) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x00)}; +#else + alignas(64) uint64_t a[8]; + alignas(64) uint64_t b[8]; + const DFromV d; + const Half> d128; + Store(va, d, a); + Store(vb, d, b); + for (size_t i = 0; i < 8; i += 2) { + const auto mul = CLMulLower(Load(d128, a + i), Load(d128, b + i)); + Store(mul, d128, a + i); + } + return Load(d, a); +#endif +} + +HWY_API Vec512 CLMulUpper(Vec512 va, Vec512 vb) { +#if HWY_TARGET <= HWY_AVX3_DL + return Vec512{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x11)}; +#else + alignas(64) uint64_t a[8]; + alignas(64) uint64_t b[8]; + const DFromV d; + const Half> d128; + Store(va, d, a); + Store(vb, d, b); + for (size_t i = 0; i < 8; i += 2) { + const auto mul = CLMulUpper(Load(d128, a + i), Load(d128, b + i)); + Store(mul, d128, a + i); + } + return Load(d, a); +#endif +} + +#endif // HWY_DISABLE_PCLMUL_AES + +// ================================================== MISC + +// ------------------------------ SumsOfAdjQuadAbsDiff (Broadcast, +// SumsOfAdjShufQuadAbsDiff) + +template +HWY_API Vec512 SumsOfAdjQuadAbsDiff(Vec512 a, + Vec512 b) { + static_assert(0 <= kAOffset && kAOffset <= 1, + "kAOffset must be between 0 and 1"); + static_assert(0 <= kBOffset && kBOffset <= 3, + "kBOffset must be between 0 and 3"); + +#if HWY_X86_HAVE_AVX10_2_OPS + // AVX10.2 now has the _mm512_mpsadbw_epu8 intrinsic available + return Vec512{_mm512_mpsadbw_epu8( + a.raw, b.raw, + (kAOffset << 5) | (kBOffset << 3) | (kAOffset 
<< 2) | kBOffset)}; +#else + const DFromV d; + const RepartitionToWideX2 du32; + + // The _mm512_mpsadbw_epu8 intrinsic is not available prior to AVX10.2. + // The SumsOfAdjQuadAbsDiff operation is implementable for 512-bit vectors on + // pre-AVX10.2 targets that support AVX3 using SumsOfShuffledQuadAbsDiff and + // U32 Broadcast. + return SumsOfShuffledQuadAbsDiff( + a, BitCast(d, Broadcast(BitCast(du32, b)))); +#endif +} + +// ------------------------------ Mask testing + +// Beware: the suffix indicates the number of mask bits, not lane size! + +namespace detail { + +template +HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask64_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask32_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} +template +HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestz_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0; +#endif +} + +} // namespace detail + +template +HWY_API bool AllFalse(D /* tag */, const MFromD mask) { + return detail::AllFalse(hwy::SizeTag)>(), mask); +} + +namespace detail { + +template +HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask64_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFFFFFFFFFFFFFull; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask32_u8(mask.raw, 
mask.raw); +#else + return mask.raw == 0xFFFFFFFFull; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask16_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFFFull; +#endif +} +template +HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask512 mask) { +#if HWY_COMPILER_HAS_MASK_INTRINSICS + return _kortestc_mask8_u8(mask.raw, mask.raw); +#else + return mask.raw == 0xFFull; +#endif +} + +} // namespace detail + +template +HWY_API bool AllTrue(D /* tag */, const MFromD mask) { + return detail::AllTrue(hwy::SizeTag)>(), mask); +} + +// `p` points to at least 8 readable bytes, not all of which need be valid. +template +HWY_API MFromD LoadMaskBits(D /* tag */, const uint8_t* HWY_RESTRICT bits) { + MFromD mask; + CopyBytes<8 / sizeof(TFromD)>(bits, &mask.raw); + // N >= 8 (= 512 / 64), so no need to mask invalid bits. + return mask; +} + +// `p` points to at least 8 writable bytes. +template +HWY_API size_t StoreMaskBits(D /* tag */, MFromD mask, uint8_t* bits) { + const size_t kNumBytes = 8 / sizeof(TFromD); + CopyBytes(&mask.raw, bits); + // N >= 8 (= 512 / 64), so no need to mask invalid bits. + return kNumBytes; +} + +template +HWY_API size_t CountTrue(D /* tag */, const MFromD mask) { + return PopCount(static_cast(mask.raw)); +} + +template +HWY_API size_t FindKnownFirstTrue(D /* tag */, MFromD mask) { + return Num0BitsBelowLS1Bit_Nonzero32(mask.raw); +} + +template +HWY_API size_t FindKnownFirstTrue(D /* tag */, MFromD mask) { + return Num0BitsBelowLS1Bit_Nonzero64(mask.raw); +} + +template +HWY_API intptr_t FindFirstTrue(D d, MFromD mask) { + return mask.raw ? 
static_cast(FindKnownFirstTrue(d, mask)) + : intptr_t{-1}; +} + +template +HWY_API size_t FindKnownLastTrue(D /* tag */, MFromD mask) { + return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask.raw); +} + +template +HWY_API size_t FindKnownLastTrue(D /* tag */, MFromD mask) { + return 63 - Num0BitsAboveMS1Bit_Nonzero64(mask.raw); +} + +template +HWY_API intptr_t FindLastTrue(D d, MFromD mask) { + return mask.raw ? static_cast(FindKnownLastTrue(d, mask)) + : intptr_t{-1}; +} + +// ------------------------------ Compress + +template +HWY_API Vec512 Compress(Vec512 v, Mask512 mask) { + // See CompressIsPartition. u64 is faster than u32. + alignas(16) static constexpr uint64_t packed_array[256] = { + // From PrintCompress32x8Tables, without the FirstN extension (there is + // no benefit to including them because 64-bit CompressStore is anyway + // masked, but also no harm because TableLookupLanes ignores the MSB). + 0x76543210, 0x76543210, 0x76543201, 0x76543210, 0x76543102, 0x76543120, + 0x76543021, 0x76543210, 0x76542103, 0x76542130, 0x76542031, 0x76542310, + 0x76541032, 0x76541320, 0x76540321, 0x76543210, 0x76532104, 0x76532140, + 0x76532041, 0x76532410, 0x76531042, 0x76531420, 0x76530421, 0x76534210, + 0x76521043, 0x76521430, 0x76520431, 0x76524310, 0x76510432, 0x76514320, + 0x76504321, 0x76543210, 0x76432105, 0x76432150, 0x76432051, 0x76432510, + 0x76431052, 0x76431520, 0x76430521, 0x76435210, 0x76421053, 0x76421530, + 0x76420531, 0x76425310, 0x76410532, 0x76415320, 0x76405321, 0x76453210, + 0x76321054, 0x76321540, 0x76320541, 0x76325410, 0x76310542, 0x76315420, + 0x76305421, 0x76354210, 0x76210543, 0x76215430, 0x76205431, 0x76254310, + 0x76105432, 0x76154320, 0x76054321, 0x76543210, 0x75432106, 0x75432160, + 0x75432061, 0x75432610, 0x75431062, 0x75431620, 0x75430621, 0x75436210, + 0x75421063, 0x75421630, 0x75420631, 0x75426310, 0x75410632, 0x75416320, + 0x75406321, 0x75463210, 0x75321064, 0x75321640, 0x75320641, 0x75326410, + 0x75310642, 0x75316420, 0x75306421, 
0x75364210, 0x75210643, 0x75216430, + 0x75206431, 0x75264310, 0x75106432, 0x75164320, 0x75064321, 0x75643210, + 0x74321065, 0x74321650, 0x74320651, 0x74326510, 0x74310652, 0x74316520, + 0x74306521, 0x74365210, 0x74210653, 0x74216530, 0x74206531, 0x74265310, + 0x74106532, 0x74165320, 0x74065321, 0x74653210, 0x73210654, 0x73216540, + 0x73206541, 0x73265410, 0x73106542, 0x73165420, 0x73065421, 0x73654210, + 0x72106543, 0x72165430, 0x72065431, 0x72654310, 0x71065432, 0x71654320, + 0x70654321, 0x76543210, 0x65432107, 0x65432170, 0x65432071, 0x65432710, + 0x65431072, 0x65431720, 0x65430721, 0x65437210, 0x65421073, 0x65421730, + 0x65420731, 0x65427310, 0x65410732, 0x65417320, 0x65407321, 0x65473210, + 0x65321074, 0x65321740, 0x65320741, 0x65327410, 0x65310742, 0x65317420, + 0x65307421, 0x65374210, 0x65210743, 0x65217430, 0x65207431, 0x65274310, + 0x65107432, 0x65174320, 0x65074321, 0x65743210, 0x64321075, 0x64321750, + 0x64320751, 0x64327510, 0x64310752, 0x64317520, 0x64307521, 0x64375210, + 0x64210753, 0x64217530, 0x64207531, 0x64275310, 0x64107532, 0x64175320, + 0x64075321, 0x64753210, 0x63210754, 0x63217540, 0x63207541, 0x63275410, + 0x63107542, 0x63175420, 0x63075421, 0x63754210, 0x62107543, 0x62175430, + 0x62075431, 0x62754310, 0x61075432, 0x61754320, 0x60754321, 0x67543210, + 0x54321076, 0x54321760, 0x54320761, 0x54327610, 0x54310762, 0x54317620, + 0x54307621, 0x54376210, 0x54210763, 0x54217630, 0x54207631, 0x54276310, + 0x54107632, 0x54176320, 0x54076321, 0x54763210, 0x53210764, 0x53217640, + 0x53207641, 0x53276410, 0x53107642, 0x53176420, 0x53076421, 0x53764210, + 0x52107643, 0x52176430, 0x52076431, 0x52764310, 0x51076432, 0x51764320, + 0x50764321, 0x57643210, 0x43210765, 0x43217650, 0x43207651, 0x43276510, + 0x43107652, 0x43176520, 0x43076521, 0x43765210, 0x42107653, 0x42176530, + 0x42076531, 0x42765310, 0x41076532, 0x41765320, 0x40765321, 0x47653210, + 0x32107654, 0x32176540, 0x32076541, 0x32765410, 0x31076542, 0x31765420, + 0x30765421, 0x37654210, 0x21076543, 
0x21765430, 0x20765431, 0x27654310, + 0x10765432, 0x17654320, 0x07654321, 0x76543210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 3) - + // _mm512_permutexvar_epi64 will ignore the upper bits. + const DFromV d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) static constexpr uint64_t shifts[8] = {0, 4, 8, 12, + 16, 20, 24, 28}; + const auto indices = Indices512{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ Expand + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 +HWY_INLINE Vec512 NativeExpand(Vec512 v, + Mask512 mask) { + return Vec512{_mm512_maskz_expand_epi8(mask.raw, v.raw)}; +} + +HWY_INLINE Vec512 NativeExpand(Vec512 v, + Mask512 mask) { + return Vec512{_mm512_maskz_expand_epi16(mask.raw, v.raw)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(Mask512 mask, D /* d */, + const uint8_t* HWY_RESTRICT unaligned) { + return VFromD{_mm512_maskz_expandloadu_epi8(mask.raw, unaligned)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(Mask512 mask, D /* d */, + const uint16_t* HWY_RESTRICT unaligned) { + return VFromD{_mm512_maskz_expandloadu_epi16(mask.raw, unaligned)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +HWY_INLINE Vec512 NativeExpand(Vec512 v, + Mask512 mask) { + return Vec512{_mm512_maskz_expand_epi32(mask.raw, v.raw)}; +} + +HWY_INLINE Vec512 NativeExpand(Vec512 v, + Mask512 mask) { + return Vec512{_mm512_maskz_expand_epi64(mask.raw, v.raw)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(Mask512 mask, D /* d */, + const uint32_t* HWY_RESTRICT unaligned) { + return VFromD{_mm512_maskz_expandloadu_epi32(mask.raw, unaligned)}; +} + +template +HWY_INLINE VFromD NativeLoadExpand(Mask512 mask, D /* d */, + const uint64_t* HWY_RESTRICT unaligned) { + return VFromD{_mm512_maskz_expandloadu_epi64(mask.raw, unaligned)}; +} + +} // namespace detail + +template +HWY_API Vec512 Expand(Vec512 v, const 
Mask512 mask) { + const Full512 d; +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +#else + // LUTs are infeasible for 2^64 possible masks, so splice together two + // half-vector Expand. + const Full256 dh; + constexpr size_t N = MaxLanes(d); + // We have to shift the input by a variable number of u8. Shuffling requires + // VBMI2, in which case we would already have NativeExpand. We instead + // load at an offset, which may incur a store to load forwarding stall. + alignas(64) T lanes[N]; + Store(v, d, lanes); + using Bits = typename Mask256::Raw; + const Mask256 maskL{ + static_cast(mask.raw & Bits{(1ULL << (N / 2)) - 1})}; + const Mask256 maskH{static_cast(mask.raw >> (N / 2))}; + const size_t countL = CountTrue(dh, maskL); + const Vec256 expandL = Expand(LowerHalf(v), maskL); + const Vec256 expandH = Expand(LoadU(dh, lanes + countL), maskH); + return Combine(d, expandH, expandL); +#endif +} + +template +HWY_API Vec512 Expand(Vec512 v, const Mask512 mask) { + const Full512 d; + const RebindToUnsigned du; + const Vec512 vu = BitCast(du, v); +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + return BitCast(d, detail::NativeExpand(vu, RebindMask(du, mask))); +#else // AVX3 + // LUTs are infeasible for 2^32 possible masks, so splice together two + // half-vector Expand. + const Full256 dh; + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + using Bits = typename Mask256::Raw; + const Mask256 maskL{ + static_cast(mask.raw & static_cast((1ULL << (N / 2)) - 1))}; + const Mask256 maskH{static_cast(mask.raw >> (N / 2))}; + // In AVX3 we can permutevar, which avoids a potential store to load + // forwarding stall vs. reloading the input. 
+ alignas(64) uint16_t iota[64] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; + const Vec512 indices = LoadU(du, iota + CountTrue(dh, maskL)); + const Vec512 shifted{_mm512_permutexvar_epi16(indices.raw, vu.raw)}; + const Vec256 expandL = Expand(LowerHalf(v), maskL); + const Vec256 expandH = Expand(LowerHalf(BitCast(d, shifted)), maskH); + return Combine(d, expandH, expandL); +#endif // AVX3 +} + +template +HWY_API V Expand(V v, const M mask) { + const DFromV d; + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); + return BitCast(d, detail::NativeExpand(BitCast(du, v), mu)); +} + +// For smaller vectors, it is likely more efficient to promote to 32-bit. +// This works for u8x16, u16x8, u16x16 (can be promoted to u32x16), but is +// unnecessary if HWY_AVX3_DL, which provides native instructions. +#if HWY_TARGET > HWY_AVX3_DL // no VBMI2 + +template , 16)> +HWY_API V Expand(V v, M mask) { + const DFromV d; + const RebindToUnsigned du; + const Rebind du32; + const VFromD vu = BitCast(du, v); + using M32 = MFromD; + const M32 m32{static_cast(mask.raw)}; + return BitCast(d, TruncateTo(du, Expand(PromoteTo(du32, vu), m32))); +} + +#endif // HWY_TARGET > HWY_AVX3_DL + +// ------------------------------ LoadExpand + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + const RebindToUnsigned du; + using TU = TFromD; + const TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + const MFromD mu = RebindMask(du, mask); + return BitCast(d, detail::NativeLoadExpand(mu, du, pu)); +#else + return Expand(LoadU(d, unaligned), mask); +#endif +} + +template +HWY_API VFromD LoadExpand(MFromD mask, D d, + const TFromD* HWY_RESTRICT unaligned) { + const RebindToUnsigned du; + using TU = TFromD; + const TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + const MFromD mu = RebindMask(du, mask); + return 
BitCast(d, detail::NativeLoadExpand(mu, du, pu)); +} + +// ------------------------------ CompressNot + +template +HWY_API Vec512 CompressNot(Vec512 v, Mask512 mask) { + // See CompressIsPartition. u64 is faster than u32. + alignas(16) static constexpr uint64_t packed_array[256] = { + // From PrintCompressNot32x8Tables, without the FirstN extension (there is + // no benefit to including them because 64-bit CompressStore is anyway + // masked, but also no harm because TableLookupLanes ignores the MSB). + 0x76543210, 0x07654321, 0x17654320, 0x10765432, 0x27654310, 0x20765431, + 0x21765430, 0x21076543, 0x37654210, 0x30765421, 0x31765420, 0x31076542, + 0x32765410, 0x32076541, 0x32176540, 0x32107654, 0x47653210, 0x40765321, + 0x41765320, 0x41076532, 0x42765310, 0x42076531, 0x42176530, 0x42107653, + 0x43765210, 0x43076521, 0x43176520, 0x43107652, 0x43276510, 0x43207651, + 0x43217650, 0x43210765, 0x57643210, 0x50764321, 0x51764320, 0x51076432, + 0x52764310, 0x52076431, 0x52176430, 0x52107643, 0x53764210, 0x53076421, + 0x53176420, 0x53107642, 0x53276410, 0x53207641, 0x53217640, 0x53210764, + 0x54763210, 0x54076321, 0x54176320, 0x54107632, 0x54276310, 0x54207631, + 0x54217630, 0x54210763, 0x54376210, 0x54307621, 0x54317620, 0x54310762, + 0x54327610, 0x54320761, 0x54321760, 0x54321076, 0x67543210, 0x60754321, + 0x61754320, 0x61075432, 0x62754310, 0x62075431, 0x62175430, 0x62107543, + 0x63754210, 0x63075421, 0x63175420, 0x63107542, 0x63275410, 0x63207541, + 0x63217540, 0x63210754, 0x64753210, 0x64075321, 0x64175320, 0x64107532, + 0x64275310, 0x64207531, 0x64217530, 0x64210753, 0x64375210, 0x64307521, + 0x64317520, 0x64310752, 0x64327510, 0x64320751, 0x64321750, 0x64321075, + 0x65743210, 0x65074321, 0x65174320, 0x65107432, 0x65274310, 0x65207431, + 0x65217430, 0x65210743, 0x65374210, 0x65307421, 0x65317420, 0x65310742, + 0x65327410, 0x65320741, 0x65321740, 0x65321074, 0x65473210, 0x65407321, + 0x65417320, 0x65410732, 0x65427310, 0x65420731, 0x65421730, 0x65421073, + 
0x65437210, 0x65430721, 0x65431720, 0x65431072, 0x65432710, 0x65432071, + 0x65432170, 0x65432107, 0x76543210, 0x70654321, 0x71654320, 0x71065432, + 0x72654310, 0x72065431, 0x72165430, 0x72106543, 0x73654210, 0x73065421, + 0x73165420, 0x73106542, 0x73265410, 0x73206541, 0x73216540, 0x73210654, + 0x74653210, 0x74065321, 0x74165320, 0x74106532, 0x74265310, 0x74206531, + 0x74216530, 0x74210653, 0x74365210, 0x74306521, 0x74316520, 0x74310652, + 0x74326510, 0x74320651, 0x74321650, 0x74321065, 0x75643210, 0x75064321, + 0x75164320, 0x75106432, 0x75264310, 0x75206431, 0x75216430, 0x75210643, + 0x75364210, 0x75306421, 0x75316420, 0x75310642, 0x75326410, 0x75320641, + 0x75321640, 0x75321064, 0x75463210, 0x75406321, 0x75416320, 0x75410632, + 0x75426310, 0x75420631, 0x75421630, 0x75421063, 0x75436210, 0x75430621, + 0x75431620, 0x75431062, 0x75432610, 0x75432061, 0x75432160, 0x75432106, + 0x76543210, 0x76054321, 0x76154320, 0x76105432, 0x76254310, 0x76205431, + 0x76215430, 0x76210543, 0x76354210, 0x76305421, 0x76315420, 0x76310542, + 0x76325410, 0x76320541, 0x76321540, 0x76321054, 0x76453210, 0x76405321, + 0x76415320, 0x76410532, 0x76425310, 0x76420531, 0x76421530, 0x76421053, + 0x76435210, 0x76430521, 0x76431520, 0x76431052, 0x76432510, 0x76432051, + 0x76432150, 0x76432105, 0x76543210, 0x76504321, 0x76514320, 0x76510432, + 0x76524310, 0x76520431, 0x76521430, 0x76521043, 0x76534210, 0x76530421, + 0x76531420, 0x76531042, 0x76532410, 0x76532041, 0x76532140, 0x76532104, + 0x76543210, 0x76540321, 0x76541320, 0x76541032, 0x76542310, 0x76542031, + 0x76542130, 0x76542103, 0x76543210, 0x76543021, 0x76543120, 0x76543102, + 0x76543210, 0x76543201, 0x76543210, 0x76543210}; + + // For lane i, shift the i-th 4-bit index down to bits [0, 3) - + // _mm512_permutexvar_epi64 will ignore the upper bits. 
+ const DFromV d; + const RebindToUnsigned du64; + const auto packed = Set(du64, packed_array[mask.raw]); + alignas(64) static constexpr uint64_t shifts[8] = {0, 4, 8, 12, + 16, 20, 24, 28}; + const auto indices = Indices512{(packed >> Load(du64, shifts)).raw}; + return TableLookupLanes(v, indices); +} + +// ------------------------------ LoadInterleaved4 + +// Actually implemented in generic_ops, we just overload LoadTransposedBlocks4. +namespace detail { + +// Type-safe wrapper. +template <_MM_PERM_ENUM kPerm, typename T> +Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { + const DFromV d; + const RebindToUnsigned du; + return BitCast(d, VFromD{_mm512_shuffle_i64x2( + BitCast(du, lo).raw, BitCast(du, hi).raw, kPerm)}); +} +template <_MM_PERM_ENUM kPerm> +Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { + return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, kPerm)}; +} +template <_MM_PERM_ENUM kPerm> +Vec512 Shuffle128(const Vec512 lo, const Vec512 hi) { + return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, kPerm)}; +} + +// Input (128-bit blocks): +// 3 2 1 0 (<- first block in unaligned) +// 7 6 5 4 +// b a 9 8 +// Output: +// 9 6 3 0 (LSB of A) +// a 7 4 1 +// b 8 5 2 +template +HWY_API void LoadTransposedBlocks3(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& A, VFromD& B, VFromD& C) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const VFromD v3210 = LoadU(d, unaligned + 0 * N); + const VFromD v7654 = LoadU(d, unaligned + 1 * N); + const VFromD vba98 = LoadU(d, unaligned + 2 * N); + + const VFromD v5421 = detail::Shuffle128<_MM_PERM_BACB>(v3210, v7654); + const VFromD va976 = detail::Shuffle128<_MM_PERM_CBDC>(v7654, vba98); + + A = detail::Shuffle128<_MM_PERM_CADA>(v3210, va976); + B = detail::Shuffle128<_MM_PERM_DBCA>(v5421, va976); + C = detail::Shuffle128<_MM_PERM_DADB>(v5421, vba98); +} + +// Input (128-bit blocks): +// 3 2 1 0 (<- first block in unaligned) +// 7 6 5 4 +// b a 9 8 +// f e d c +// Output: +// c 8 4 0 (LSB of A) +// d 9 5 1 +// e 
a 6 2 +// f b 7 3 +template +HWY_API void LoadTransposedBlocks4(D d, const TFromD* HWY_RESTRICT unaligned, + VFromD& vA, VFromD& vB, VFromD& vC, + VFromD& vD) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const VFromD v3210 = LoadU(d, unaligned + 0 * N); + const VFromD v7654 = LoadU(d, unaligned + 1 * N); + const VFromD vba98 = LoadU(d, unaligned + 2 * N); + const VFromD vfedc = LoadU(d, unaligned + 3 * N); + + const VFromD v5410 = detail::Shuffle128<_MM_PERM_BABA>(v3210, v7654); + const VFromD vdc98 = detail::Shuffle128<_MM_PERM_BABA>(vba98, vfedc); + const VFromD v7632 = detail::Shuffle128<_MM_PERM_DCDC>(v3210, v7654); + const VFromD vfeba = detail::Shuffle128<_MM_PERM_DCDC>(vba98, vfedc); + vA = detail::Shuffle128<_MM_PERM_CACA>(v5410, vdc98); + vB = detail::Shuffle128<_MM_PERM_DBDB>(v5410, vdc98); + vC = detail::Shuffle128<_MM_PERM_CACA>(v7632, vfeba); + vD = detail::Shuffle128<_MM_PERM_DBDB>(v7632, vfeba); +} + +} // namespace detail + +// ------------------------------ StoreInterleaved2 + +// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4. 
+ +namespace detail { + +// Input (128-bit blocks): +// 6 4 2 0 (LSB of i) +// 7 5 3 1 +// Output: +// 3 2 1 0 +// 7 6 5 4 +template +HWY_API void StoreTransposedBlocks2(const VFromD i, const VFromD j, D d, + TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const auto j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); + const auto j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); + const auto j1_i1_j0_i0 = + detail::Shuffle128<_MM_PERM_DBCA>(j1_j0_i1_i0, j1_j0_i1_i0); + const auto j3_i3_j2_i2 = + detail::Shuffle128<_MM_PERM_DBCA>(j3_j2_i3_i2, j3_j2_i3_i2); + StoreU(j1_i1_j0_i0, d, unaligned + 0 * N); + StoreU(j3_i3_j2_i2, d, unaligned + 1 * N); +} + +// Input (128-bit blocks): +// 9 6 3 0 (LSB of i) +// a 7 4 1 +// b 8 5 2 +// Output: +// 3 2 1 0 +// 7 6 5 4 +// b a 9 8 +template +HWY_API void StoreTransposedBlocks3(const VFromD i, const VFromD j, + const VFromD k, D d, + TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const VFromD j2_j0_i2_i0 = detail::Shuffle128<_MM_PERM_CACA>(i, j); + const VFromD i3_i1_k2_k0 = detail::Shuffle128<_MM_PERM_DBCA>(k, i); + const VFromD j3_j1_k3_k1 = detail::Shuffle128<_MM_PERM_DBDB>(k, j); + + const VFromD out0 = // i1 k0 j0 i0 + detail::Shuffle128<_MM_PERM_CACA>(j2_j0_i2_i0, i3_i1_k2_k0); + const VFromD out1 = // j2 i2 k1 j1 + detail::Shuffle128<_MM_PERM_DBAC>(j3_j1_k3_k1, j2_j0_i2_i0); + const VFromD out2 = // k3 j3 i3 k2 + detail::Shuffle128<_MM_PERM_BDDB>(i3_i1_k2_k0, j3_j1_k3_k1); + + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); +} + +// Input (128-bit blocks): +// c 8 4 0 (LSB of i) +// d 9 5 1 +// e a 6 2 +// f b 7 3 +// Output: +// 3 2 1 0 +// 7 6 5 4 +// b a 9 8 +// f e d c +template +HWY_API void StoreTransposedBlocks4(const VFromD i, const VFromD j, + const VFromD k, const VFromD l, D d, + TFromD* HWY_RESTRICT unaligned) { + HWY_LANES_CONSTEXPR size_t N = Lanes(d); + const VFromD 
j1_j0_i1_i0 = detail::Shuffle128<_MM_PERM_BABA>(i, j); + const VFromD l1_l0_k1_k0 = detail::Shuffle128<_MM_PERM_BABA>(k, l); + const VFromD j3_j2_i3_i2 = detail::Shuffle128<_MM_PERM_DCDC>(i, j); + const VFromD l3_l2_k3_k2 = detail::Shuffle128<_MM_PERM_DCDC>(k, l); + const VFromD out0 = + detail::Shuffle128<_MM_PERM_CACA>(j1_j0_i1_i0, l1_l0_k1_k0); + const VFromD out1 = + detail::Shuffle128<_MM_PERM_DBDB>(j1_j0_i1_i0, l1_l0_k1_k0); + const VFromD out2 = + detail::Shuffle128<_MM_PERM_CACA>(j3_j2_i3_i2, l3_l2_k3_k2); + const VFromD out3 = + detail::Shuffle128<_MM_PERM_DBDB>(j3_j2_i3_i2, l3_l2_k3_k2); + StoreU(out0, d, unaligned + 0 * N); + StoreU(out1, d, unaligned + 1 * N); + StoreU(out2, d, unaligned + 2 * N); + StoreU(out3, d, unaligned + 3 * N); +} + +} // namespace detail + +// ------------------------------ Additional mask logical operations + +template +HWY_API Mask512 SetAtOrAfterFirst(Mask512 mask) { + return Mask512{ + static_cast::Raw>(0u - detail::AVX3Blsi(mask.raw))}; +} +template +HWY_API Mask512 SetBeforeFirst(Mask512 mask) { + return Mask512{ + static_cast::Raw>(detail::AVX3Blsi(mask.raw) - 1u)}; +} +template +HWY_API Mask512 SetAtOrBeforeFirst(Mask512 mask) { + return Mask512{ + static_cast::Raw>(detail::AVX3Blsmsk(mask.raw))}; +} +template +HWY_API Mask512 SetOnlyFirst(Mask512 mask) { + return Mask512{ + static_cast::Raw>(detail::AVX3Blsi(mask.raw))}; +} + +// ------------------------------ Shl (Dup128VecFromValues) + +HWY_API Vec512 operator<<(Vec512 v, Vec512 bits) { + return Vec512{_mm512_sllv_epi16(v.raw, bits.raw)}; +} + +// 8-bit: may use the << overload for uint16_t. 
+HWY_API Vec512 operator<<(Vec512 v, Vec512 bits) { + const DFromV d; +#if HWY_TARGET <= HWY_AVX3_DL + // kMask[i] = 0xFF >> i + const VFromD masks = + Dup128VecFromValues(d, 0xFF, 0x7F, 0x3F, 0x1F, 0x0F, 0x07, 0x03, 0x01, 0, + 0, 0, 0, 0, 0, 0, 0); + // kShl[i] = 1 << i + const VFromD shl = + Dup128VecFromValues(d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0, + 0, 0, 0, 0, 0, 0, 0); + v = And(v, TableLookupBytes(masks, bits)); + const VFromD mul = TableLookupBytes(shl, bits); + return VFromD{_mm512_gf2p8mul_epi8(v.raw, mul.raw)}; +#else + const Repartition dw; + using VW = VFromD; + const VW even_mask = Set(dw, 0x00FF); + const VW odd_mask = Set(dw, 0xFF00); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + // Shift even lanes in-place + const VW evens = vw << And(bits16, even_mask); + const VW odds = And(vw, odd_mask) << ShiftRight<8>(bits16); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +#endif +} + +HWY_API Vec512 operator<<(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_sllv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator<<(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_sllv_epi64(v.raw, bits.raw)}; +} + +// Signed left shift is the same as unsigned. +template +HWY_API Vec512 operator<<(const Vec512 v, const Vec512 bits) { + const DFromV di; + const RebindToUnsigned du; + return BitCast(di, BitCast(du, v) << BitCast(du, bits)); +} + +// ------------------------------ Shr (IfVecThenElse) + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srlv_epi16(v.raw, bits.raw)}; +} + +// 8-bit uses 16-bit shifts. 
+HWY_API Vec512 operator>>(Vec512 v, Vec512 bits) { + const DFromV d; + const RepartitionToWide dw; + using VW = VFromD; + const VW mask = Set(dw, 0x00FF); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + const VW evens = And(vw, mask) >> And(bits16, mask); + // Shift odd lanes in-place + const VW odds = vw >> ShiftRight<8>(bits16); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srlv_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srlv_epi64(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srav_epi16(v.raw, bits.raw)}; +} + +// 8-bit uses 16-bit shifts. +HWY_API Vec512 operator>>(Vec512 v, Vec512 bits) { + const DFromV d; + const RepartitionToWide dw; + const RebindToUnsigned dw_u; + using VW = VFromD; + const VW mask = Set(dw, 0x00FF); + const VW vw = BitCast(dw, v); + const VW bits16 = BitCast(dw, bits); + const VW evens = ShiftRight<8>(ShiftLeft<8>(vw)) >> And(bits16, mask); + // Shift odd lanes in-place + const VW odds = vw >> BitCast(dw, ShiftRight<8>(BitCast(dw_u, bits16))); + return OddEven(BitCast(d, odds), BitCast(d, evens)); +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srav_epi32(v.raw, bits.raw)}; +} + +HWY_API Vec512 operator>>(const Vec512 v, + const Vec512 bits) { + return Vec512{_mm512_srav_epi64(v.raw, bits.raw)}; +} + +// ------------------------------ WidenMulPairwiseAdd + +#if HWY_NATIVE_DOT_BF16 +template >> +HWY_API VFromD WidenMulPairwiseAdd(DF df, VBF a, VBF b) { + return VFromD{_mm512_dpbf16_ps(Zero(df).raw, + reinterpret_cast<__m512bh>(a.raw), + reinterpret_cast<__m512bh>(b.raw))}; +} +#endif // HWY_NATIVE_DOT_BF16 + +template +HWY_API VFromD WidenMulPairwiseAdd(D /*d32*/, Vec512 a, + Vec512 b) { + return VFromD{_mm512_madd_epi16(a.raw, 
b.raw)}; +} + +// ------------------------------ SatWidenMulPairwiseAdd +template +HWY_API VFromD SatWidenMulPairwiseAdd( + DI16 /* tag */, VFromD> a, + VFromD> b) { + return VFromD{_mm512_maddubs_epi16(a.raw, b.raw)}; +} + +// ------------------------------ SatWidenMulPairwiseAccumulate +#if HWY_TARGET <= HWY_AVX3_DL +template +HWY_API VFromD SatWidenMulPairwiseAccumulate( + DI32 /* tag */, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm512_dpwssds_epi32(sum.raw, a.raw, b.raw)}; +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ ReorderWidenMulAccumulate + +#if HWY_NATIVE_DOT_BF16 +template >> +HWY_API VFromD ReorderWidenMulAccumulate(DF /*df*/, VBF a, VBF b, + const VFromD sum0, + VFromD& /*sum1*/) { + return VFromD{_mm512_dpbf16_ps(sum0.raw, + reinterpret_cast<__m512bh>(a.raw), + reinterpret_cast<__m512bh>(b.raw))}; +} +#endif // HWY_NATIVE_DOT_BF16 + +template +HWY_API VFromD ReorderWidenMulAccumulate(D d, Vec512 a, + Vec512 b, + const VFromD sum0, + VFromD& /*sum1*/) { + (void)d; +#if HWY_TARGET <= HWY_AVX3_DL + return VFromD{_mm512_dpwssd_epi32(sum0.raw, a.raw, b.raw)}; +#else + return sum0 + WidenMulPairwiseAdd(d, a, b); +#endif +} + +// ------------------------------ SumOfMulQuadAccumulate + +#if HWY_TARGET <= HWY_AVX3_DL + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DI32 /*di32*/, VFromD> a_u, + VFromD> b_i, VFromD sum) { + return VFromD{_mm512_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)}; +} + +#if HWY_X86_HAVE_AVX10_2_OPS +template +HWY_API VFromD SumOfMulQuadAccumulate(DI32 /*di32*/, + VFromD> a, + VFromD> b, + VFromD sum) { + return VFromD{_mm512_dpbssd_epi32(sum.raw, a.raw, b.raw)}; +} + +template +HWY_API VFromD SumOfMulQuadAccumulate( + DU32 /*du32*/, VFromD> a, + VFromD> b, VFromD sum) { + return VFromD{_mm512_dpbuud_epi32(sum.raw, a.raw, b.raw)}; +} +#endif // HWY_X86_HAVE_AVX10_2_OPS + +#endif + +// ------------------------------ Reductions + +namespace detail { + +// Used by generic_ops-inl +template 
+HWY_INLINE VFromD ReduceAcrossBlocks(D d, Func f, VFromD v) { + v = f(v, SwapAdjacentBlocks(v)); + return f(v, ReverseBlocks(d, v)); +} + +} // namespace detail + +// ------------------------------ BitShuffle +#if HWY_TARGET <= HWY_AVX3_DL +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_V(V, 64), HWY_IF_V_SIZE_V(VI, 64)> +HWY_API V BitShuffle(V v, VI idx) { + const DFromV d64; + const RebindToUnsigned du64; + const Rebind du8; + + const __mmask64 mmask64_bit_shuf_result = + _mm512_bitshuffle_epi64_mask(v.raw, idx.raw); + +#if HWY_ARCH_X86_64 + const VFromD vu8_bit_shuf_result{ + _mm_cvtsi64_si128(static_cast(mmask64_bit_shuf_result))}; +#else + const int32_t i32_lo_bit_shuf_result = + static_cast(mmask64_bit_shuf_result); + const int32_t i32_hi_bit_shuf_result = + static_cast(_kshiftri_mask64(mmask64_bit_shuf_result, 32)); + + const VFromD vu8_bit_shuf_result = ResizeBitCast( + du8, InterleaveLower( + Vec128{_mm_cvtsi32_si128(i32_lo_bit_shuf_result)}, + Vec128{_mm_cvtsi32_si128(i32_hi_bit_shuf_result)})); +#endif + + return BitCast(d64, PromoteTo(du64, vu8_bit_shuf_result)); +} +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ MultiRotateRight + +#if HWY_TARGET <= HWY_AVX3_DL + +#ifdef HWY_NATIVE_MULTIROTATERIGHT +#undef HWY_NATIVE_MULTIROTATERIGHT +#else +#define HWY_NATIVE_MULTIROTATERIGHT +#endif + +template ), HWY_IF_UI8(TFromV), + HWY_IF_V_SIZE_V(V, 64), HWY_IF_V_SIZE_V(VI, HWY_MAX_LANES_V(V) * 8)> +HWY_API V MultiRotateRight(V v, VI idx) { + return V{_mm512_multishift_epi64_epi8(idx.raw, v.raw)}; +} + +#endif + +// -------------------- LeadingZeroCount + +template ), HWY_IF_V_SIZE_V(V, 64)> +HWY_API V LeadingZeroCount(V v) { + return V{_mm512_lzcnt_epi32(v.raw)}; +} + +template ), HWY_IF_V_SIZE_V(V, 64)> +HWY_API V LeadingZeroCount(V v) { + return V{_mm512_lzcnt_epi64(v.raw)}; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the 
GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/lib/highway/hwy/ops/x86_avx3-inl.h b/lib/highway/hwy/ops/x86_avx3-inl.h new file mode 100644 index 00000000000..7ea32bce6b3 --- /dev/null +++ b/lib/highway/hwy/ops/x86_avx3-inl.h @@ -0,0 +1,501 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// External include guard in highway.h - see comment there. + +// For AVX3/AVX10 targets that support 512-byte vectors. Already includes base.h +// and shared-inl.h. +#include "hwy/ops/x86_512-inl.h" + +// AVX3/AVX10 ops that have dependencies on ops defined in x86_512-inl.h if +// HWY_MAX_BYTES >= 64 is true are defined below + +// Avoid uninitialized warnings in GCC's avx512fintrin.h - see +// https://github.com/google/highway/issues/710) +HWY_DIAGNOSTICS(push) +#if HWY_COMPILER_GCC_ACTUAL +HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") +HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494, + ignored "-Wmaybe-uninitialized") +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +#if HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ ShiftLeft + +// Generic for all vector lengths. Must be defined after all GaloisAffine. 
+template +HWY_API V ShiftLeft(const V v) { + const Repartition> du64; + if (kBits == 0) return v; + if (kBits == 1) return v + v; + constexpr uint64_t kMatrix = (0x0102040810204080ULL >> kBits) & + (0x0101010101010101ULL * (0xFF >> kBits)); + return detail::GaloisAffine(v, Set(du64, kMatrix)); +} + +// ------------------------------ ShiftRight + +// Generic for all vector lengths. Must be defined after all GaloisAffine. +template )> +HWY_API V ShiftRight(const V v) { + const Repartition> du64; + if (kBits == 0) return v; + constexpr uint64_t kMatrix = + (0x0102040810204080ULL << kBits) & + (0x0101010101010101ULL * ((0xFF << kBits) & 0xFF)); + return detail::GaloisAffine(v, Set(du64, kMatrix)); +} + +// Generic for all vector lengths. Must be defined after all GaloisAffine. +template )> +HWY_API V ShiftRight(const V v) { + const Repartition> du64; + if (kBits == 0) return v; + constexpr uint64_t kShift = + (0x0102040810204080ULL << kBits) & + (0x0101010101010101ULL * ((0xFF << kBits) & 0xFF)); + constexpr uint64_t kSign = + kBits == 0 ? 
0 : (0x8080808080808080ULL >> (64 - (8 * kBits))); + return detail::GaloisAffine(v, Set(du64, kShift | kSign)); +} + +// ------------------------------ RotateRight + +// U8 RotateRight is generic for all vector lengths on AVX3_DL +template )> +HWY_API V RotateRight(V v) { + static_assert(0 <= kBits && kBits < 8, "Invalid shift count"); + + const Repartition> du64; + if (kBits == 0) return v; + + constexpr uint64_t kShrMatrix = + (0x0102040810204080ULL << kBits) & + (0x0101010101010101ULL * ((0xFF << kBits) & 0xFF)); + constexpr int kShlBits = (-kBits) & 7; + constexpr uint64_t kShlMatrix = (0x0102040810204080ULL >> kShlBits) & + (0x0101010101010101ULL * (0xFF >> kShlBits)); + constexpr uint64_t kMatrix = kShrMatrix | kShlMatrix; + + return detail::GaloisAffine(v, Set(du64, kMatrix)); +} + +#endif // HWY_TARGET <= HWY_AVX3_DL + +// ------------------------------ Compress + +#pragma push_macro("HWY_X86_SLOW_COMPRESS_STORE") + +#ifndef HWY_X86_SLOW_COMPRESS_STORE // allow override +// Slow on Zen4 and SPR, faster if we emulate via Compress(). +#if HWY_TARGET == HWY_AVX3_ZEN4 || HWY_TARGET == HWY_AVX3_SPR +#define HWY_X86_SLOW_COMPRESS_STORE 1 +#else +#define HWY_X86_SLOW_COMPRESS_STORE 0 +#endif +#endif // HWY_X86_SLOW_COMPRESS_STORE + +// Always implement 8-bit here even if we lack VBMI2 because we can do better +// than generic_ops (8 at a time) via the native 32-bit compress (16 at a time). 
+#ifdef HWY_NATIVE_COMPRESS8 +#undef HWY_NATIVE_COMPRESS8 +#else +#define HWY_NATIVE_COMPRESS8 +#endif + +namespace detail { + +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 +template +HWY_INLINE Vec128 NativeCompress(const Vec128 v, + const Mask128 mask) { + return Vec128{_mm_maskz_compress_epi8(mask.raw, v.raw)}; +} +HWY_INLINE Vec256 NativeCompress(const Vec256 v, + const Mask256 mask) { + return Vec256{_mm256_maskz_compress_epi8(mask.raw, v.raw)}; +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE Vec512 NativeCompress(const Vec512 v, + const Mask512 mask) { + return Vec512{_mm512_maskz_compress_epi8(mask.raw, v.raw)}; +} +#endif + +template +HWY_INLINE Vec128 NativeCompress(const Vec128 v, + const Mask128 mask) { + return Vec128{_mm_maskz_compress_epi16(mask.raw, v.raw)}; +} +HWY_INLINE Vec256 NativeCompress(const Vec256 v, + const Mask256 mask) { + return Vec256{_mm256_maskz_compress_epi16(mask.raw, v.raw)}; +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE Vec512 NativeCompress(const Vec512 v, + const Mask512 mask) { + return Vec512{_mm512_maskz_compress_epi16(mask.raw, v.raw)}; +} +#endif + +// Do not even define these to prevent accidental usage. 
+#if !HWY_X86_SLOW_COMPRESS_STORE + +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + uint8_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + uint8_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + uint8_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi8(unaligned, mask.raw, v.raw); +} +#endif + +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + uint16_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + uint16_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + uint16_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); +} +#endif // HWY_MAX_BYTES >= 64 + +#endif // HWY_X86_SLOW_COMPRESS_STORE + +#endif // HWY_TARGET <= HWY_AVX3_DL + +template +HWY_INLINE Vec128 NativeCompress(Vec128 v, + Mask128 mask) { + return Vec128{_mm_maskz_compress_epi32(mask.raw, v.raw)}; +} +HWY_INLINE Vec256 NativeCompress(Vec256 v, + Mask256 mask) { + return Vec256{_mm256_maskz_compress_epi32(mask.raw, v.raw)}; +} + +#if HWY_MAX_BYTES >= 64 +HWY_INLINE Vec512 NativeCompress(Vec512 v, + Mask512 mask) { + return Vec512{_mm512_maskz_compress_epi32(mask.raw, v.raw)}; +} +#endif +// We use table-based compress for 64-bit lanes, see CompressIsPartition. + +// Do not even define these to prevent accidental usage. 
+#if !HWY_X86_SLOW_COMPRESS_STORE + +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + uint32_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + uint32_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + uint32_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); +} +#endif + +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + uint64_t* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + uint64_t* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + uint64_t* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); +} +#endif + +template +HWY_INLINE void NativeCompressStore(Vec128 v, Mask128 mask, + float* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + float* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); +} +#if HWY_MAX_BYTES >= 64 +HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + float* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); +} +#endif + +template +HWY_INLINE void NativeCompressStore(Vec128 v, + Mask128 mask, + double* HWY_RESTRICT unaligned) { + _mm_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); +} +HWY_INLINE void NativeCompressStore(Vec256 v, Mask256 mask, + double* HWY_RESTRICT unaligned) { + _mm256_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); +} +#if HWY_MAX_BYTES >= 64 
+HWY_INLINE void NativeCompressStore(Vec512 v, Mask512 mask, + double* HWY_RESTRICT unaligned) { + _mm512_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); +} +#endif + +#endif // HWY_X86_SLOW_COMPRESS_STORE + +// For u8x16 and <= u16x16 we can avoid store+load for Compress because there is +// only a single compressed vector (u32x16). Other EmuCompress are implemented +// after the EmuCompressStore they build upon. +template ), + HWY_IF_LANES_LE_D(DFromV, HWY_MAX_BYTES / 4)> +static HWY_INLINE HWY_MAYBE_UNUSED V EmuCompress(V v, MFromD> mask) { + const DFromV d; + const Rebind d32; + const VFromD v0 = PromoteTo(d32, v); + + using M32 = MFromD; + const M32 m0 = PromoteMaskTo(d32, d, mask); + return TruncateTo(d, Compress(v0, m0)); +} + +template ), + HWY_IF_LANES_LE_D(DFromV, HWY_MAX_BYTES / 4)> +static HWY_INLINE HWY_MAYBE_UNUSED V EmuCompress(V v, MFromD> mask) { + const DFromV d; + const Rebind di32; + const RebindToUnsigned du32; + + const MFromD mask32 = PromoteMaskTo(du32, d, mask); + // DemoteTo is 2 ops, but likely lower latency than TruncateTo on SKX. + // Only i32 -> u16 is supported, whereas NativeCompress expects u32. + const VFromD v32 = PromoteTo(du32, v); + return DemoteTo(d, BitCast(di32, NativeCompress(v32, mask32))); +} + +// See above - small-vector EmuCompressStore are implemented via EmuCompress. +template +static HWY_INLINE HWY_MAYBE_UNUSED void EmuCompressStore( + VFromD v, MFromD mask, D d, TFromD* HWY_RESTRICT unaligned) { + StoreU(EmuCompress(v, mask), d, unaligned); +} + +// Main emulation logic for wider vector, starting with EmuCompressStore because +// it is most convenient to merge pieces using memory (concatenating vectors at +// byte offsets is difficult). 
+template +static HWY_INLINE HWY_MAYBE_UNUSED void EmuCompressStore( + VFromD v, MFromD mask, D d, TFromD* HWY_RESTRICT unaligned) { + const Half dh; + + const MFromD m0 = LowerHalfOfMask(dh, mask); + const MFromD m1 = UpperHalfOfMask(dh, mask); + + const VFromD v0 = LowerHalf(dh, v); + const VFromD v1 = UpperHalf(dh, v); + + EmuCompressStore(v0, m0, dh, unaligned); + EmuCompressStore(v1, m1, dh, unaligned + CountTrue(dh, m0)); +} + +// Finally, the remaining EmuCompress for wide vectors, using EmuCompressStore. +template , HWY_MAX_BYTES / 4)> +static HWY_INLINE HWY_MAYBE_UNUSED V EmuCompress(V v, MFromD> mask) { + using D = DFromV; + using T = TFromD; + const D d; + + alignas(HWY_MAX_LANES_D(D) * sizeof(T)) T buf[2 * HWY_MAX_LANES_D(D)]; + EmuCompressStore(v, mask, d, buf); + return Load(d, buf); +} + +} // namespace detail + +template +HWY_API V Compress(V v, const M mask) { + const DFromV d; + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + return BitCast(d, detail::NativeCompress(BitCast(du, v), mu)); +#else + return BitCast(d, detail::EmuCompress(BitCast(du, v), mu)); +#endif +} + +template +HWY_API V Compress(V v, const M mask) { + const DFromV d; + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); + return BitCast(d, detail::NativeCompress(BitCast(du, v), mu)); +} + +// ------------------------------ CompressNot + +template +HWY_API V CompressNot(V v, const M mask) { + return Compress(v, Not(mask)); +} + +// uint64_t lanes. Only implement for 256 and 512-bit vectors because this is a +// no-op for 128-bit. +template , 16)> +HWY_API V CompressBlocksNot(V v, M mask) { + return CompressNot(v, mask); +} + +// ------------------------------ CompressBits +template +HWY_API V CompressBits(V v, const uint8_t* HWY_RESTRICT bits) { + return Compress(v, LoadMaskBits(DFromV(), bits)); +} + +// ------------------------------ CompressStore + +// Generic for all vector lengths. 
+ +template +HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { +#if HWY_X86_SLOW_COMPRESS_STORE + StoreU(Compress(v, mask), d, unaligned); +#else + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); + auto pu = reinterpret_cast * HWY_RESTRICT>(unaligned); + +#if HWY_TARGET <= HWY_AVX3_DL // VBMI2 + detail::NativeCompressStore(BitCast(du, v), mu, pu); +#else + detail::EmuCompressStore(BitCast(du, v), mu, du, pu); +#endif +#endif // HWY_X86_SLOW_COMPRESS_STORE + const size_t count = CountTrue(d, mask); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +template +HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { +#if HWY_X86_SLOW_COMPRESS_STORE + StoreU(Compress(v, mask), d, unaligned); +#else + const RebindToUnsigned du; + const auto mu = RebindMask(du, mask); + using TU = TFromD; + TU* HWY_RESTRICT pu = reinterpret_cast(unaligned); + detail::NativeCompressStore(BitCast(du, v), mu, pu); +#endif // HWY_X86_SLOW_COMPRESS_STORE + const size_t count = CountTrue(d, mask); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +// Additional overloads to avoid casting to uint32_t (delay?). +template +HWY_API size_t CompressStore(VFromD v, MFromD mask, D d, + TFromD* HWY_RESTRICT unaligned) { +#if HWY_X86_SLOW_COMPRESS_STORE + StoreU(Compress(v, mask), d, unaligned); +#else + (void)d; + detail::NativeCompressStore(v, mask, unaligned); +#endif // HWY_X86_SLOW_COMPRESS_STORE + const size_t count = PopCount(uint64_t{mask.raw}); + detail::MaybeUnpoison(unaligned, count); + return count; +} + +// ------------------------------ CompressBlendedStore +template +HWY_API size_t CompressBlendedStore(VFromD v, MFromD m, D d, + TFromD* HWY_RESTRICT unaligned) { + // Native CompressStore already does the blending at no extra cost (latency + // 11, rthroughput 2 - same as compress plus store). 
+ + HWY_IF_CONSTEXPR(HWY_MAX_LANES_D(D) < (16 / sizeof(TFromD))) { + m = And(m, FirstN(d, HWY_MAX_LANES_D(D))); + } + + HWY_IF_CONSTEXPR(!HWY_X86_SLOW_COMPRESS_STORE && + (HWY_TARGET <= HWY_AVX3_DL || sizeof(TFromD) > 2)) { + return CompressStore(v, m, d, unaligned); + } + else { + const size_t count = CountTrue(d, m); + StoreN(Compress(v, m), d, unaligned, count); + detail::MaybeUnpoison(unaligned, count); + return count; + } +} + +// ------------------------------ CompressBitsStore +// Generic for all vector lengths. +template +HWY_API size_t CompressBitsStore(VFromD v, const uint8_t* HWY_RESTRICT bits, + D d, TFromD* HWY_RESTRICT unaligned) { + return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); +} + +#pragma pop_macro("HWY_X86_SLOW_COMPRESS_STORE") + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +// Note that the GCC warnings are not suppressed if we only wrap the *intrin.h - +// the warning seems to be issued at the call site of intrinsics, i.e. our code. +HWY_DIAGNOSTICS(pop) diff --git a/lib/highway/hwy/per_target.cc b/lib/highway/hwy/per_target.cc new file mode 100644 index 00000000000..432d3861e03 --- /dev/null +++ b/lib/highway/hwy/per_target.cc @@ -0,0 +1,77 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Enable all targets so that calling Have* does not call into a null pointer. +#ifndef HWY_COMPILE_ALL_ATTAINABLE +#define HWY_COMPILE_ALL_ATTAINABLE +#endif +#include "hwy/per_target.h" + +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "hwy/per_target.cc" +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace { +int64_t GetTarget() { return HWY_TARGET; } +size_t GetVectorBytes() { return Lanes(ScalableTag()); } +bool GetHaveInteger64() { return HWY_HAVE_INTEGER64 != 0; } +bool GetHaveFloat16() { return HWY_HAVE_FLOAT16 != 0; } +bool GetHaveFloat64() { return HWY_HAVE_FLOAT64 != 0; } +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE + +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace hwy { +namespace { +HWY_EXPORT(GetTarget); +HWY_EXPORT(GetVectorBytes); +HWY_EXPORT(GetHaveInteger64); +HWY_EXPORT(GetHaveFloat16); +HWY_EXPORT(GetHaveFloat64); +} // namespace + +HWY_DLLEXPORT int64_t DispatchedTarget() { + return HWY_DYNAMIC_DISPATCH(GetTarget)(); +} + +HWY_DLLEXPORT size_t VectorBytes() { + return HWY_DYNAMIC_DISPATCH(GetVectorBytes)(); +} + +HWY_DLLEXPORT bool HaveInteger64() { + return HWY_DYNAMIC_DISPATCH(GetHaveInteger64)(); +} + +HWY_DLLEXPORT bool HaveFloat16() { + return HWY_DYNAMIC_DISPATCH(GetHaveFloat16)(); +} + +HWY_DLLEXPORT bool HaveFloat64() { + return HWY_DYNAMIC_DISPATCH(GetHaveFloat64)(); +} + +} // namespace hwy +#endif // HWY_ONCE diff --git a/lib/highway/hwy/per_target.h b/lib/highway/hwy/per_target.h new file mode 100644 index 00000000000..7a86b0ebe60 --- /dev/null +++ b/lib/highway/hwy/per_target.h @@ -0,0 +1,49 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_PER_TARGET_H_ +#define HIGHWAY_HWY_PER_TARGET_H_ + +#include +#include + +#include "hwy/highway_export.h" + +// Functions to query the capabilities of the target that will be called by +// HWY_DYNAMIC_DISPATCH, which is not necessarily the current target. + +namespace hwy { + +// Returns the HWY_TARGET which HWY_DYNAMIC_DISPATCH selected. +HWY_DLLEXPORT int64_t DispatchedTarget(); + +// Returns size in bytes of a vector, i.e. `Lanes(ScalableTag())`. +// +// Do not cache the result, which may change after calling DisableTargets, or +// if software requests a different vector size (e.g. when entering/exiting SME +// streaming mode). Instead call this right before the code that depends on the +// result, without any DisableTargets or SME transition in-between. Note that +// this involves an indirect call, so prefer not to call this frequently nor +// unnecessarily. +HWY_DLLEXPORT size_t VectorBytes(); + +// Returns whether 64-bit integers, 16/64-bit floats are a supported lane type. 
+HWY_DLLEXPORT bool HaveInteger64(); +HWY_DLLEXPORT bool HaveFloat16(); +HWY_DLLEXPORT bool HaveFloat64(); + +} // namespace hwy + +#endif // HIGHWAY_HWY_PER_TARGET_H_ diff --git a/lib/highway/hwy/perf_counters.cc b/lib/highway/hwy/perf_counters.cc new file mode 100644 index 00000000000..4cad466d67f --- /dev/null +++ b/lib/highway/hwy/perf_counters.cc @@ -0,0 +1,377 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/perf_counters.h" + +#include "hwy/detect_compiler_arch.h" // HWY_OS_LINUX + +#if HWY_OS_LINUX || HWY_IDE +#include +#include // open +#include +#include +#include +#include +#include // strcmp +#include +#include +#include // O_RDONLY +#include +#include +#include + +#include +#include + +#include "hwy/base.h" // HWY_ASSERT +#include "hwy/bit_set.h" +#include "hwy/timer.h" + +#endif // HWY_OS_LINUX || HWY_IDE + +namespace hwy { +namespace platform { + +#if HWY_OS_LINUX || HWY_IDE + +namespace { + +bool PerfCountersSupported() { + // This is the documented way. + struct stat s; + return stat("/proc/sys/kernel/perf_event_paranoid", &s) == 0; +} + +// If we detect Linux < 6.9 and AMD EPYC, use cycles instead of ref-cycles +// because the latter is not supported and returns 0, see +// https://lwn.net/Articles/967791/. 
+uint64_t RefCyclesOrCycles() {
+  const uint32_t ref_cycles = PERF_COUNT_HW_REF_CPU_CYCLES;
+
+  // Any failure to identify the kernel or CPU falls back to ref-cycles.
+  utsname buf;
+  if (uname(&buf) != 0) return ref_cycles;
+  if (std::string(buf.sysname) != "Linux") return ref_cycles;
+  int major, minor;
+  if (sscanf(buf.release, "%d.%d", &major, &minor) != 2) return ref_cycles;
+  if (major > 6 || (major == 6 && minor >= 9)) return ref_cycles;
+
+  // AMD Zen4 CPU
+  char cpu100[100];
+  if (!GetCpuString(cpu100)) return ref_cycles;
+  // rfind(.., 0) == 0 is a "starts with" test.
+  if (std::string(cpu100).rfind("AMD EPYC", 0) != 0) return ref_cycles;
+
+  return PERF_COUNT_HW_CPU_CYCLES;
+}
+
+struct CounterConfig {  // for perf_event_open
+  uint64_t config;
+  uint32_t type;
+  PerfCounters::Counter c;
+};
+
+// Returns the counters to open, in the order in which `PMU::Init` opens them.
+std::vector<CounterConfig> AllCounterConfigs() {
+  constexpr uint32_t kHW = PERF_TYPE_HARDWARE;
+  constexpr uint32_t kSW = PERF_TYPE_SOFTWARE;
+  constexpr uint32_t kC = PERF_TYPE_HW_CACHE;
+  constexpr uint64_t kL3 = PERF_COUNT_HW_CACHE_LL;
+  constexpr uint64_t kLoad = uint64_t{PERF_COUNT_HW_CACHE_OP_READ} << 8;
+  constexpr uint64_t kStore = uint64_t{PERF_COUNT_HW_CACHE_OP_WRITE} << 8;
+  constexpr uint64_t kAcc = uint64_t{PERF_COUNT_HW_CACHE_RESULT_ACCESS} << 16;
+
+  // Order is important for bin-packing event groups. x86 can only handle two
+  // LLC-related events per group, so spread them out and arrange SW events
+  // such that they do not start a new group. This list of counters may change.
+  return {{RefCyclesOrCycles(), kHW, PerfCounters::kRefCycles},
+          {PERF_COUNT_HW_INSTRUCTIONS, kHW, PerfCounters::kInstructions},
+          {PERF_COUNT_SW_PAGE_FAULTS, kSW, PerfCounters::kPageFaults},
+          {kL3 | kLoad | kAcc, kC, PerfCounters::kL3Loads},
+          {kL3 | kStore | kAcc, kC, PerfCounters::kL3Stores},
+          {PERF_COUNT_HW_BRANCH_INSTRUCTIONS, kHW, PerfCounters::kBranches},
+          {PERF_COUNT_HW_BRANCH_MISSES, kHW, PerfCounters::kBranchMispredicts},
+          // Second group:
+          {PERF_COUNT_HW_BUS_CYCLES, kHW, PerfCounters::kBusCycles},
+          {PERF_COUNT_SW_CPU_MIGRATIONS, kSW, PerfCounters::kMigrations},
+          {PERF_COUNT_HW_CACHE_REFERENCES, kHW, PerfCounters::kCacheRefs},
+          {PERF_COUNT_HW_CACHE_MISSES, kHW, PerfCounters::kCacheMisses}};
+}
+
+// Maps a `Counter` to its slot in the packed value array; written by
+// `PMU::Init`, read via `PerfCounters::IndexForCounter`.
+size_t& PackedIdx(PerfCounters::Counter c) {
+  static size_t packed_idx[64];
+  return packed_idx[static_cast<size_t>(c)];
+}
+
+class PMU {
+  static perf_event_attr MakeAttr(const CounterConfig& cc) {
+    perf_event_attr attr = {};
+    attr.type = cc.type;
+    attr.size = sizeof(attr);
+    attr.config = cc.config;
+    // We request more counters than the HW may support. If so, they are
+    // multiplexed and only active for a fraction of the runtime. Recording the
+    // times lets us extrapolate. GROUP enables a single syscall to reduce the
+    // cost of reading.
+    attr.read_format = PERF_FORMAT_TOTAL_TIME_ENABLED |
+                       PERF_FORMAT_TOTAL_TIME_RUNNING | PERF_FORMAT_GROUP;
+    // Do not set inherit=1 because that conflicts with PERF_FORMAT_GROUP.
+    // Do not set disable=1, so that perf_event_open verifies all events in the
+    // group can be scheduled together.
+    attr.exclude_kernel = 1;  // required if perf_event_paranoid == 1
+    attr.exclude_hv = 1;      // = hypervisor
+    return attr;
+  }
+
+  // Returns the new event's fd, or a negative value on failure. Pass
+  // `leader_fd` = -1 to start a new group.
+  static int SysPerfEventOpen(const CounterConfig& cc, int leader_fd) {
+    perf_event_attr attr = MakeAttr(cc);
+    const int pid = 0;   // current process (cannot also be -1)
+    const int cpu = -1;  // any CPU
+    // Retry if interrupted by signals; this actually happens (b/64774091).
+    for (int retry = 0; retry < 10; ++retry) {
+      const int flags = 0;
+      const int fd = static_cast<int>(
+          syscall(__NR_perf_event_open, &attr, pid, cpu, leader_fd, flags));
+      if (!(fd == -1 && errno == EINTR)) return fd;
+    }
+    HWY_WARN("perf_event_open retries were insufficient.");
+    return -1;
+  }
+
+  // Reads from `fd`; recovers from interruptions before/during the read.
+  // Returns true only if exactly `size` bytes were stored to `to`.
+  static bool ReadBytes(int fd, ssize_t size, void* to) {
+    uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
+    ssize_t pos = 0;
+    for (int retry = 0; retry < 10; ++retry) {
+      const ssize_t bytes_read =
+          read(fd, bytes + pos, static_cast<size_t>(size - pos));
+      // `errno` is only meaningful if read() reported an error (-1).
+      if (HWY_UNLIKELY(bytes_read < 0)) {
+        if (errno == EINTR) continue;
+        HWY_WARN("perf read() failed, errno %d.", errno);
+        return false;
+      }
+      // Zero means end of file; `errno` is stale, so do not consult it.
+      if (HWY_UNLIKELY(bytes_read == 0)) {
+        HWY_WARN("perf read() returned EOF after %d bytes.",
+                 static_cast<int>(pos));
+        return false;
+      }
+      pos += bytes_read;
+      HWY_ASSERT(pos <= size);
+      if (HWY_LIKELY(pos == size)) return true;  // success
+    }
+    HWY_WARN("perf read() wanted %d bytes, got %d.", static_cast<int>(size),
+             static_cast<int>(pos));
+    return false;
+  }
+
+  // Array size in Buf; this is another upper bound on group size. It should be
+  // loose because it only wastes a bit of stack space, whereas an unnecessary
+  // extra group decreases coverage. Most HW supports 4-8 counters per group.
+  static constexpr size_t kMaxEventsPerGroup = PerfCounters::kCapacity;
+
+#pragma pack(push, 1)
+  struct Buf {
+    uint64_t num_events;
+    uint64_t time_enabled;
+    uint64_t time_running;
+    uint64_t values[kMaxEventsPerGroup];
+  };
+#pragma pack(pop)
+
+  // Returns false on error, otherwise sets `extrapolate` and `values`.
+  // `values` must have room for `num_events` entries.
+  static bool ReadAndExtrapolate(int fd, size_t num_events, double& extrapolate,
+                                 double* HWY_RESTRICT values) {
+    HWY_DASSERT(num_events <= kMaxEventsPerGroup);
+    Buf buf;
+    // The three fields of `Buf` that precede `values`.
+    constexpr size_t kHeaderBytes = 3 * sizeof(uint64_t);
+    const ssize_t want_bytes =  // size of var-len `Buf`
+        static_cast<ssize_t>(kHeaderBytes + num_events * sizeof(uint64_t));
+    if (HWY_UNLIKELY(!ReadBytes(fd, want_bytes, &buf))) return false;
+
+    HWY_DASSERT(num_events == buf.num_events);
+    HWY_DASSERT(buf.time_running <= buf.time_enabled);
+    // Only `num_events` values were read into `buf`. If the kernel reports a
+    // different count, fail instead of reading uninitialized `buf.values` or
+    // writing past the caller's `values`.
+    if (HWY_UNLIKELY(buf.num_events != num_events)) return false;
+    // If the group was not yet scheduled, we must avoid division by zero.
+    // In case counters were previously running and not reset, their current
+    // values may be nonzero. Returning zero could be interpreted as counters
+    // running backwards, so we instead treat this as a failure and mark the
+    // counters as invalid.
+    if (HWY_UNLIKELY(buf.time_running == 0)) return false;
+
+    // Extrapolate each value.
+    extrapolate = static_cast<double>(buf.time_enabled) /
+                  static_cast<double>(buf.time_running);
+    for (size_t i = 0; i < num_events; ++i) {
+      values[i] = static_cast<double>(buf.values[i]) * extrapolate;
+    }
+    return true;
+  }
+
+ public:
+  // Opens all counters, packing them into as few groups as possible. Returns
+  // true if at least one counter is available, otherwise false; in the latter
+  // case a later call tries again.
+  bool Init() {
+    // Allow callers who do not know about each other to each call `Init`.
+    // If this already succeeded, we're done; if not, we will try again.
+    if (HWY_UNLIKELY(!fds_.empty())) return true;
+    if (HWY_UNLIKELY(!PerfCountersSupported())) {
+      HWY_WARN(
+          "This Linux does not support perf counters. The program will "
+          "continue, but counters will return zero.");
+      return false;
+    }
+
+    groups_.push_back(Group());
+    fds_.reserve(PerfCounters::kCapacity);
+
+    for (const CounterConfig& config : AllCounterConfigs()) {
+      // If the group is limited by our buffer size, add a new one.
+      if (HWY_UNLIKELY(groups_.back().num_events == kMaxEventsPerGroup)) {
+        groups_.push_back(Group());
+      }
+
+      int fd = SysPerfEventOpen(config, groups_.back().leader_fd);
+      // Retry in case the group is limited by HW capacity. Do not check
+      // errno because it is too inconsistent (ENOSPC, EINVAL, others?).
+      if (HWY_UNLIKELY(fd < 0)) {
+        fd = SysPerfEventOpen(config, /*leader_fd=*/-1);
+        if (fd >= 0 && groups_.back().num_events != 0) {
+          groups_.push_back(Group());
+        }
+      }
+
+      if (HWY_UNLIKELY(fd < 0)) {
+        HWY_WARN("perf_event_open %d errno %d for counter %s.", fd, errno,
+                 PerfCounters::Name(config.c));
+      } else {
+        // Add to group and set as leader if empty.
+        if (groups_.back().leader_fd == -1) {
+          groups_.back().leader_fd = fd;
+
+          // Ensure the leader is not a SW event, because adding an HW
+          // event to a group with only SW events is slow, and starting
+          // with SW may trigger a bug, see
+          // https://lore.kernel.org/lkml/tip-a1150c202207cc8501bebc45b63c264f91959260@git.kernel.org/
+          if (HWY_UNLIKELY(config.type == PERF_TYPE_SOFTWARE)) {
+            HWY_WARN("SW event %s should not be leader.",
+                     PerfCounters::Name(config.c));
+          }
+        }
+
+        PackedIdx(config.c) = fds_.size();
+        groups_.back().num_events += 1;
+        valid_.Set(static_cast<size_t>(config.c));
+        fds_.push_back(fd);
+      }
+    }
+
+    // If no counters are available, remove the empty group.
+    if (HWY_UNLIKELY(fds_.empty())) {
+      HWY_ASSERT(groups_.size() == 1);
+      HWY_ASSERT(groups_.back().num_events == 0);
+      HWY_ASSERT(groups_.back().leader_fd == -1);
+      groups_.clear();
+    }
+
+    size_t num_valid = 0;
+    for (const Group& group : groups_) {
+      num_valid += group.num_events;
+      // All groups have a leader and are not empty.
+      HWY_ASSERT(group.leader_fd >= 0);
+      HWY_ASSERT(0 != group.num_events &&
+                 group.num_events <= kMaxEventsPerGroup);
+    }
+    // Total `num_events` matches `fds_` and `Valid()`.
+    HWY_ASSERT(num_valid == fds_.size());
+    HWY_ASSERT(num_valid == valid_.Count());
+    HWY_ASSERT(num_valid <= PerfCounters::kCapacity);
+
+    if (num_valid == 0) {
+      // Report failure so that the result agrees with `StartAll` (which
+      // returns false when `fds_` is empty) and with the check at the top of
+      // this function, which treats an empty `fds_` as "not yet succeeded".
+      HWY_WARN("No valid counters found.");
+      return false;
+    }
+
+    StopAllAndReset();
+    return true;
+  }
+
+  // Returns false if no counters were opened, otherwise enables them.
+  bool StartAll() {
+    if (HWY_UNLIKELY(fds_.empty())) return false;
+    HWY_ASSERT(prctl(PR_TASK_PERF_EVENTS_ENABLE) == 0);
+    return true;
+  }
+
+  // Disables all counters and resets each opened counter's value.
+  void StopAllAndReset() {
+    HWY_ASSERT(prctl(PR_TASK_PERF_EVENTS_DISABLE) == 0);
+    for (int fd : fds_) {
+      HWY_ASSERT(ioctl(fd, PERF_EVENT_IOC_RESET, 0) == 0);
+    }
+  }
+
+  // Returns false on error, otherwise sets `valid`, `max_extrapolate`, and
+  // `values`.
+  bool Read(BitSet64& valid, double& max_extrapolate, double* values) {
+    if (HWY_UNLIKELY(!valid_.Any())) return false;
+
+    // Read all counters into buffer in the order in which they were opened.
+    max_extrapolate = 1.0;
+    double* pos = values;
+    for (const Group& group : groups_) {
+      double extrapolate;
+      if (HWY_UNLIKELY(!ReadAndExtrapolate(group.leader_fd, group.num_events,
+                                           extrapolate, pos))) {
+        return false;
+      }
+      max_extrapolate = HWY_MAX(max_extrapolate, extrapolate);
+      pos += group.num_events;
+    }
+
+    valid = valid_;
+    HWY_DASSERT(pos == values + valid.Count());
+    return true;
+  }
+
+ private:
+  std::vector<int> fds_;  // one per valid_
+  BitSet64 valid_;
+
+  struct Group {
+    size_t num_events = 0;
+    int leader_fd = -1;
+  };
+  std::vector<Group> groups_;
+};
+
+// Monostate, see header.
+PMU& GetPMU() {
+  // Intentionally never deleted: avoids exit-dtor warning (no dtor required).
+  static PMU* const instance = new PMU();
+  return *instance;
+}
+
+}  // namespace
+
+// The static member functions forward to the single process-wide PMU.
+HWY_DLLEXPORT bool PerfCounters::Init() {
+  PMU& pmu = GetPMU();
+  return pmu.Init();
+}
+HWY_DLLEXPORT bool PerfCounters::StartAll() {
+  PMU& pmu = GetPMU();
+  return pmu.StartAll();
+}
+HWY_DLLEXPORT void PerfCounters::StopAllAndReset() {
+  PMU& pmu = GetPMU();
+  pmu.StopAllAndReset();
+}
+// Snapshots the current counter values; if reading fails, the instance is
+// left with no valid counters and all-zero values.
+HWY_DLLEXPORT PerfCounters::PerfCounters() {
+  const bool ok = GetPMU().Read(valid_, max_extrapolate_, values_);
+  if (HWY_LIKELY(ok)) return;
+
+  valid_ = BitSet64();
+  max_extrapolate_ = 0.0;
+  hwy::ZeroBytes(values_, sizeof(values_));
+}
+HWY_DLLEXPORT size_t PerfCounters::IndexForCounter(Counter c) {
+  return PackedIdx(c);
+}
+#else
+// Stubs for platforms without Linux perf events: counters are never available.
+HWY_DLLEXPORT bool PerfCounters::Init() { return false; }
+HWY_DLLEXPORT bool PerfCounters::StartAll() { return false; }
+HWY_DLLEXPORT void PerfCounters::StopAllAndReset() {}
+HWY_DLLEXPORT PerfCounters::PerfCounters()
+    : max_extrapolate_(1.0), values_{0.0} {}
+HWY_DLLEXPORT size_t PerfCounters::IndexForCounter(Counter) { return 0; }
+#endif  // HWY_OS_LINUX || HWY_IDE
+
+}  // namespace platform
+}  // namespace hwy
diff --git a/lib/highway/hwy/perf_counters.h b/lib/highway/hwy/perf_counters.h
new file mode 100644
index 00000000000..74851162477
--- /dev/null
+++ b/lib/highway/hwy/perf_counters.h
@@ -0,0 +1,156 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#ifndef HIGHWAY_HWY_PERF_COUNTERS_H_ +#define HIGHWAY_HWY_PERF_COUNTERS_H_ + +// Reads OS/CPU performance counters. + +#include + +#include "hwy/base.h" // HWY_ABORT +#include "hwy/bit_set.h" + +namespace hwy { +namespace platform { + +// Avoid padding in case callers such as profiler.h store many instances. +#pragma pack(push, 1) +// Provides access to CPU/OS performance counters. Each instance has space for +// multiple counter values; which counters these are may change in future. +// Although counters are per-CPU, Linux accesses them via a syscall, hence we +// use the monostate pattern to avoid callers having to pass around a pointer. +// Note that this is not thread-safe, so the static member functions should only +// be called from the main thread. +class PerfCounters { + public: + // Chosen such that this class occupies one or two cache lines. + static constexpr size_t kCapacity = 14; + + // Bit indices used to identify counters. The ordering is arbitrary. Some of + // these counters may be 'removed' in the sense of not being visited by + // `Foreach`, but their enumerators will remain. New counters may be appended. + enum Counter { + kRefCycles = 0, + kInstructions, + kBranches, + kBranchMispredicts, + kBusCycles, + kCacheRefs, + kCacheMisses, + kL3Loads, + kL3Stores, + kPageFaults, // SW + kMigrations // SW + }; // BitSet64 requires these values to be less than 64. + + // Strings for user-facing messages, not used in the implementation. 
+ static inline const char* Name(Counter c) { + switch (c) { + case kRefCycles: + return "ref_cycles"; + case kInstructions: + return "instructions"; + case kBranches: + return "branches"; + case kBranchMispredicts: + return "branch_mispredicts"; + case kBusCycles: + return "bus_cycles"; + case kCacheRefs: + return "cache_refs"; + case kCacheMisses: + return "cache_misses"; + case kL3Loads: + return "l3_load"; + case kL3Stores: + return "l3_store"; + case kPageFaults: + return "page_fault"; + case kMigrations: + return "migration"; + default: + HWY_UNREACHABLE; + } + } + + // Returns false if counters are unavailable. Must be called at least once + // before `StartAll`; it is separate to reduce the overhead of repeatedly + // stopping/starting counters. + HWY_DLLEXPORT static bool Init(); + + // Returns false if counters are unavailable, otherwise starts them. Note that + // they default to stopped. Unless this is called, the values read may be 0. + HWY_DLLEXPORT static bool StartAll(); + + // Stops and zeros all counters. This is not necessary if users subtract the + // previous counter values, but can increase precision because floating-point + // has more precision near zero. + HWY_DLLEXPORT static void StopAllAndReset(); + + // Reads the current (extrapolated, in case of multiplexing) counter values. + HWY_DLLEXPORT PerfCounters(); + + // Returns whether any counters were successfully read. + bool AnyValid() const { return valid_.Any(); } + + // Returns whether the given counter was successfully read. + bool IsValid(Counter c) const { + const size_t bit_idx = static_cast(c); + return valid_.Get(bit_idx); + } + + // Returns the maximum extrapolation factor for any counter, which is the + // total time between `StartAll` and now or the last `StopAllAndReset`, + // divided by the time that the counter was actually running. This + // approximates the number of counter groups that the CPU multiplexes onto the + // actual counter hardware. 
It is only meaningful if AnyValid(). + double MaxExtrapolate() const { return max_extrapolate_; } + + // Returns the value of the given counter, or zero if it is not valid. + double Get(Counter c) const { + return IsValid(c) ? values_[IndexForCounter(c)] : 0.0; + } + + // For each valid counter in increasing numerical order, calls `visitor` with + // the value and `Counter`. + template + void Foreach(const Visitor& visitor) { + valid_.Foreach([&](size_t bit_idx) { + const Counter c = static_cast(bit_idx); + visitor(values_[IndexForCounter(c)], c); + }); + } + + private: + // Index within `values_` for a given counter. + HWY_DLLEXPORT static size_t IndexForCounter(Counter c); + + BitSet64 valid_; + double max_extrapolate_; + // Floating-point because these are extrapolated (multiplexing). It would be + // nice for this to fit in one cache line to reduce the cost of reading + // counters in profiler.h, but some of the values are too large for float and + // we want more than 8 counters. Ensure all values are sums, not ratios, so + // that profiler.h can add/subtract them. These are contiguous in memory, in + // the order that counters were initialized. + double values_[kCapacity]; +}; +#pragma pack(pop) + +} // namespace platform +} // namespace hwy + +#endif // HIGHWAY_HWY_PERF_COUNTERS_H_ diff --git a/lib/highway/hwy/print-inl.h b/lib/highway/hwy/print-inl.h new file mode 100644 index 00000000000..48a9cce3d8c --- /dev/null +++ b/lib/highway/hwy/print-inl.h @@ -0,0 +1,64 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Print() function + +#include + +#include "hwy/highway.h" +#include "hwy/print.h" + +// Per-target include guard +#if defined(HIGHWAY_HWY_PRINT_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_PRINT_INL_H_ +#undef HIGHWAY_HWY_PRINT_INL_H_ +#else +#define HIGHWAY_HWY_PRINT_INL_H_ +#endif + +#if HWY_TARGET == HWY_RVV +#include "hwy/aligned_allocator.h" +#endif + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { + +// Prints lanes around `lane`, in memory order. +template > +HWY_API void Print(const D d, const char* caption, V v, size_t lane_u = 0, + size_t max_lanes = 7) { + const size_t N = Lanes(d); + using T = TFromD; +#if HWY_TARGET == HWY_RVV + auto storage = AllocateAligned(N); + T* HWY_RESTRICT lanes = storage.get(); +#else + // This works around an SVE compile error on GCC 11 and 12. Calling + // AllocateAligned here would seem to require it be marked with HWY_ATTR. 
+ HWY_ALIGN T lanes[MaxLanes(d)]; +#endif + Store(v, d, lanes); + + const auto info = hwy::detail::MakeTypeInfo(); + hwy::detail::PrintArray(info, caption, lanes, N, lane_u, max_lanes); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // per-target include guard diff --git a/lib/highway/hwy/print.cc b/lib/highway/hwy/print.cc new file mode 100644 index 00000000000..cea41042b69 --- /dev/null +++ b/lib/highway/hwy/print.cc @@ -0,0 +1,138 @@ +// Copyright 2022 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/print.h" + +#include + +#include "hwy/base.h" +#include "hwy/detect_compiler_arch.h" + +namespace hwy { +namespace detail { + +HWY_DLLEXPORT void TypeName(const TypeInfo& info, size_t N, char* string100) { + const char prefix = info.is_float ? 'f' : (info.is_signed ? 'i' : 'u'); + // Omit the xN suffix for scalars. + if (N == 1) { + // NOLINTNEXTLINE + snprintf(string100, 64, "%c%d", prefix, + static_cast(info.sizeof_t * 8)); + } else { + // NOLINTNEXTLINE + snprintf(string100, 64, "%c%dx%d", prefix, + static_cast(info.sizeof_t * 8), static_cast(N)); + } +} + +// The NOLINT are to suppress the warning about passing 100 instead of +// `sizeof(string100)`, which is a pointer. 
+HWY_DLLEXPORT void ToString(const TypeInfo& info, const void* ptr,
+                            char* string100) {
+  // Formats the single lane at `ptr`, described by `info`, into `string100`,
+  // which must have room for 100 chars. The format depends on the lane size
+  // and on the float/signed/bf16 flags.
+  if (info.sizeof_t == 1) {
+    if (info.is_signed) {
+      int8_t byte;
+      CopyBytes<1>(ptr, &byte);  // endian-safe: we ensured sizeof(T)=1.
+      snprintf(string100, 100, "%d", byte);  // NOLINT
+    } else {
+      uint8_t byte;
+      CopyBytes<1>(ptr, &byte);  // endian-safe: we ensured sizeof(T)=1.
+      snprintf(string100, 100, "0x%02X", byte);  // NOLINT
+    }
+  } else if (info.sizeof_t == 2) {
+    if (info.is_bf16) {
+      const double value = static_cast<double>(F32FromBF16Mem(ptr));
+      // NOLINTNEXTLINE
+      snprintf(string100, 100, hwy::ScalarAbs(value) < 1E-3 ? "%.3E" : "%.3f",
+               value);
+    } else if (info.is_float) {
+      const double value = static_cast<double>(F32FromF16Mem(ptr));
+      // NOLINTNEXTLINE
+      snprintf(string100, 100, hwy::ScalarAbs(value) < 1E-4 ? "%.4E" : "%.4f",
+               value);
+    } else {
+      uint16_t bits;
+      CopyBytes<2>(ptr, &bits);
+      snprintf(string100, 100, "0x%04X", bits);  // NOLINT
+    }
+  } else if (info.sizeof_t == 4) {
+    if (info.is_float) {
+      float value;
+      CopyBytes<4>(ptr, &value);
+      // NOLINTNEXTLINE
+      snprintf(string100, 100, hwy::ScalarAbs(value) < 1E-6f ? "%.9E" : "%.9f",
+               static_cast<double>(value));
+    } else if (info.is_signed) {
+      int32_t value;
+      CopyBytes<4>(ptr, &value);
+      snprintf(string100, 100, "%d", value);  // NOLINT
+    } else {
+      uint32_t value;
+      CopyBytes<4>(ptr, &value);
+      snprintf(string100, 100, "%u", value);  // NOLINT
+    }
+  } else if (info.sizeof_t == 8) {
+    if (info.is_float) {
+      double value;
+      CopyBytes<8>(ptr, &value);
+      // NOLINTNEXTLINE
+      snprintf(string100, 100, hwy::ScalarAbs(value) < 1E-9 ? "%.18E" : "%.18f",
+               value);
+    } else {
+      // Print as two 32-bit halves, most-significant first.
+      const uint8_t* ptr8 = reinterpret_cast<const uint8_t*>(ptr);
+      uint32_t lo, hi;
+      CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 0 : 4), &lo);
+      CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 4 : 0), &hi);
+      snprintf(string100, 100, "0x%08x%08x", hi, lo);  // NOLINT
+    }
+  } else if (info.sizeof_t == 16) {
+    HWY_ASSERT(!info.is_float && !info.is_signed && !info.is_bf16);
+    // Print as four 32-bit words, most-significant first.
+    const uint8_t* ptr8 = reinterpret_cast<const uint8_t*>(ptr);
+    uint32_t words[4];
+    CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 0 : 12), &words[0]);
+    CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 4 : 8), &words[1]);
+    CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 8 : 4), &words[2]);
+    CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 12 : 0), &words[3]);
+    // NOLINTNEXTLINE
+    snprintf(string100, 100, "0x%08x%08x_%08x%08x", words[3], words[2],
+             words[1], words[0]);
+  } else {
+    // Unsupported lane size. Callers print `string100` unconditionally, so
+    // always leave a valid NUL-terminated string rather than an unwritten
+    // buffer.
+    // NOLINTNEXTLINE
+    snprintf(string100, 100, "?(%d bytes)", static_cast<int>(info.sizeof_t));
+  }
+}
+
+// Prints to stderr a header line (type name, `caption`, first printed index)
+// followed by up to `max_lanes` lanes of `array_void`, starting two lanes
+// before `lane_u` (clamped to zero) and ending at `N`.
+HWY_DLLEXPORT void PrintArray(const TypeInfo& info, const char* caption,
+                              const void* array_void, size_t N, size_t lane_u,
+                              size_t max_lanes) {
+  const uint8_t* array_bytes = reinterpret_cast<const uint8_t*>(array_void);
+
+  char type_name[100];
+  TypeName(info, N, type_name);
+
+  const intptr_t lane = intptr_t(lane_u);
+  const size_t begin = static_cast<size_t>(HWY_MAX(0, lane - 2));
+  const size_t end = HWY_MIN(begin + max_lanes, N);
+  fprintf(stderr, "%s %s [%d+ ->]:\n  ", type_name, caption,
+          static_cast<int>(begin));
+  for (size_t i = begin; i < end; ++i) {
+    const void* ptr = array_bytes + i * info.sizeof_t;
+    char str[100];
+    ToString(info, ptr, str);
+    fprintf(stderr, "%s,", str);
+  }
+  if (begin >= end) fprintf(stderr, "(out of bounds)");
+  fprintf(stderr, "\n");
+}
+
+}  // namespace detail
+}  // namespace hwy
diff --git a/lib/highway/hwy/print.h b/lib/highway/hwy/print.h
new file mode 100644
index 00000000000..e61631e650f
--- /dev/null
+++ b/lib/highway/hwy/print.h
@@ -0,0 +1,75 @@
+// Copyright 2022 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HWY_PRINT_H_ +#define HWY_PRINT_H_ + +// Helpers for printing vector lanes. + +#include +#include + +#include "hwy/base.h" +#include "hwy/highway_export.h" + +namespace hwy { + +namespace detail { + +// For implementing value comparisons etc. as type-erased functions to reduce +// template bloat. +struct TypeInfo { + size_t sizeof_t; + bool is_float; + bool is_signed; + bool is_bf16; +}; + +template +HWY_INLINE TypeInfo MakeTypeInfo() { + TypeInfo info; + info.sizeof_t = sizeof(T); + info.is_float = IsFloat(); + info.is_signed = IsSigned(); + info.is_bf16 = IsSame(); + return info; +} + +HWY_DLLEXPORT void TypeName(const TypeInfo& info, size_t N, char* string100); +HWY_DLLEXPORT void ToString(const TypeInfo& info, const void* ptr, + char* string100); + +HWY_DLLEXPORT void PrintArray(const TypeInfo& info, const char* caption, + const void* array_void, size_t N, + size_t lane_u = 0, size_t max_lanes = 7); + +} // namespace detail + +template +HWY_NOINLINE void PrintValue(T value) { + char str[100]; + detail::ToString(hwy::detail::MakeTypeInfo(), &value, str); + fprintf(stderr, "%s,", str); +} + +template +HWY_NOINLINE void PrintArray(const T* value, size_t count) { + detail::PrintArray(hwy::detail::MakeTypeInfo(), "", value, count, 0, + count); +} + +} // namespace hwy + +#endif // HWY_PRINT_H_ diff --git a/lib/highway/hwy/profiler.cc b/lib/highway/hwy/profiler.cc new file mode 100644 index 00000000000..8855d483660 --- /dev/null +++ b/lib/highway/hwy/profiler.cc @@ -0,0 +1,146 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: 
Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/profiler.h" + +#include "hwy/highway_export.h" // HWY_DLLEXPORT + +#if PROFILER_ENABLED + +#include +#include +#include + +#include "hwy/base.h" +#include "hwy/robust_statistics.h" +#include "hwy/timer.h" + +#endif // PROFILER_ENABLED + +namespace hwy { + +#if PROFILER_ENABLED + +static constexpr bool kPrintOverhead = true; + +// Must zero-init because `ThreadFunc` calls `SetGlobalIdx()` potentially after +// this is first used in the `pool::Worker` ctor. +/*static*/ thread_local size_t Profiler::s_global_idx = 0; + +// Detects duration of a zero-length zone: timer plus packet overhead. +static uint64_t DetectSelfOverhead(Profiler& profiler, size_t global_idx) { + static const profiler::ZoneHandle zone = profiler.AddZone("DetectSelf"); + profiler::Results results; + const size_t kNumSamples = 25; + uint32_t samples[kNumSamples]; + for (size_t idx_sample = 0; idx_sample < kNumSamples; ++idx_sample) { + // Enough for stable measurements, but only about 50 ms startup cost. 
+ const size_t kNumDurations = 700; + uint32_t durations[kNumDurations]; + for (size_t idx_duration = 0; idx_duration < kNumDurations; + ++idx_duration) { + { + PROFILER_ZONE3(profiler, global_idx, zone); + } + durations[idx_duration] = + static_cast(profiler.GetFirstDurationAndReset(global_idx)); + } + samples[idx_sample] = robust_statistics::Mode(durations, kNumDurations); + } + return robust_statistics::Mode(samples, kNumSamples); +} + +// Detects average duration of a zero-length zone, after deducting self +// overhead. This accounts for the delay before/after capturing start/end +// timestamps, for example due to fence instructions in timer::Start/Stop. +static uint64_t DetectChildOverhead(Profiler& profiler, size_t global_idx, + uint64_t self_overhead) { + static const profiler::ZoneHandle zone = profiler.AddZone("DetectChild"); + // Enough for stable measurements, but only about 50 ms startup cost. + const size_t kMaxSamples = 30; + uint32_t samples[kMaxSamples]; + size_t num_samples = 0; + // Upper bound because timer resolution might be too coarse to get nonzero. + for (size_t s = 0; s < 2 * kMaxSamples && num_samples < kMaxSamples; ++s) { + const size_t kNumDurations = 50; + uint32_t durations[kNumDurations]; + for (size_t d = 0; d < kNumDurations; ++d) { + constexpr size_t kReps = 500; + HWY_FENCE; + const uint64_t t0 = timer::Start(); + for (size_t r = 0; r < kReps; ++r) { + PROFILER_ZONE3(profiler, global_idx, zone); + } + const uint64_t t1 = timer::Stop(); + HWY_FENCE; + // We are measuring the total, not individual zone durations, to include + // cross-zone overhead. 
+ (void)profiler.GetFirstDurationAndReset(global_idx); + + const uint64_t avg_duration = (t1 - t0 + kReps / 2) / kReps; + durations[d] = static_cast( + profiler::PerWorker::ClampedSubtract(avg_duration, self_overhead)); + } + samples[num_samples] = robust_statistics::Mode(durations, kNumDurations); + // Overhead is nonzero, but we often measure zero; skip them to prevent + // getting a zero result. + num_samples += (samples[num_samples] != 0); + } + return num_samples == 0 ? 0 : robust_statistics::Mode(samples, num_samples); +} + +Profiler::Profiler() { + const uint64_t t0 = timer::Start(); + + char cpu[100]; + if (HWY_UNLIKELY(!platform::HaveTimerStop(cpu))) { + HWY_ABORT("CPU %s is too old for PROFILER_ENABLED=1, exiting", cpu); + } + + // `ThreadPool` calls `Profiler::Get()` before it creates threads, hence this + // is guaranteed to be running on the main thread. + constexpr size_t kMain = 0; + // Must be called before any use of `PROFILER_ZONE*/PROFILER_FUNC*`. This runs + // only once because `Profiler` is a singleton. + ReserveWorker(kMain); + SetGlobalIdx(kMain); + + profiler::Overheads overheads; + // WARNING: must pass in `*this` and use `PROFILER_ZONE3` to avoid calling + // `Profiler::Get()`, because that would re-enter the magic static init. + overheads.self = DetectSelfOverhead(*this, kMain); + overheads.child = DetectChildOverhead(*this, kMain, overheads.self); + for (size_t worker = 0; worker < profiler::kMaxWorkers; ++worker) { + workers_[worker].SetOverheads(overheads); + } + + HWY_IF_CONSTEXPR(kPrintOverhead) { + printf("Self overhead: %.0f; child: %.0f; elapsed %.1f ms\n", + static_cast(overheads.self), + static_cast(overheads.child), + static_cast(timer::Stop() - t0) / + platform::InvariantTicksPerSecond() * 1E3); + } +} + +#endif // PROFILER_ENABLED + +// Even if disabled, we want to export the symbol. 
+HWY_DLLEXPORT Profiler& Profiler::Get() { + static Profiler* profiler = new Profiler(); + return *profiler; +} + +} // namespace hwy diff --git a/lib/highway/hwy/profiler.h b/lib/highway/hwy/profiler.h new file mode 100644 index 00000000000..c03d539f8d9 --- /dev/null +++ b/lib/highway/hwy/profiler.h @@ -0,0 +1,875 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_PROFILER_H_ +#define HIGHWAY_HWY_PROFILER_H_ + +#include +#include +#include // strcmp, strlen + +#include +#include + +#include "hwy/base.h" +#include "hwy/highway_export.h" + +// High precision, low overhead time measurements. Returns exact call counts and +// total elapsed time for user-defined 'zones' (code regions, i.e. C++ scopes). +// +// Uses RAII to capture begin/end timestamps, with user-specified zone names: +// `{ PROFILER_ZONE("name"); /*code*/ }` or the name of the current function: +// `void FuncToMeasure() { PROFILER_FUNC; /*code*/ }`. +// +// You can reduce the overhead by passing `global_idx`, which can be taken from +// the argument to the `ThreadPool::Run` lambda (if the pool was constructed +// with non-default `PoolWorkerMapping`), or from a saved copy of the +// thread-local `Profiler::Thread`: `PROFILER_ZONE2(global_idx, name)`. 
+// +// The preferred API allows passing flags, such as requesting inclusive time: +// `static const auto zone = profiler.AddZone("name", flags);` and then +// `PROFILER_ZONE3(profiler, global_idx, zone)`. +// +// After all threads exit all zones, call `Profiler::Get().PrintResults()` to +// print call counts and average durations [CPU cycles] to stdout, sorted in +// descending order of total duration. + +// If zero, mock `Profiler` and `profiler::Zone` will be defined. +#ifndef PROFILER_ENABLED +#define PROFILER_ENABLED 0 +#endif + +#if PROFILER_ENABLED +#include + +#include // std::sort +#include +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/bit_set.h" +#include "hwy/timer.h" +#endif // PROFILER_ENABLED + +namespace hwy { + +// Flags: we want type-safety (enum class) to catch mistakes such as confusing +// zone with flags. Base type (`uint32_t`) ensures it is safe to cast. Defined +// outside the `#if` because callers pass them to `PROFILER_ZONE3`. When adding +// flags, also update `kNumFlags` and `ChildTotalMask`. +enum class ProfilerFlags : uint32_t { + kDefault = 0, + // The zone should report cumulative time, including all child zones. If not + // specified, zones report self-time, excluding child zones. + kInclusive = 1 +}; + +// Called during `PrintResults` to print results from other modules. +using ProfilerFunc = std::function; + +template +class StringTable { + static constexpr std::memory_order kRelaxed = std::memory_order_relaxed; + static constexpr std::memory_order kAcq = std::memory_order_acquire; + static constexpr std::memory_order kRel = std::memory_order_release; + + public: + // Returns a copy of the `name` passed to `Add` that returned the + // given `idx`. + const char* Name(size_t idx) const { + // `kAcq` so that the string contents are also visible after the pointer is + // published via `kRelease` store. + return ptrs_[idx].load(kAcq); + } + + // Returns `idx < kMaxStrings`. Can be called concurrently. 
Calls with the + // same `name` return the same `idx`. + size_t Add(const char* name) { + // Linear search if it already exists. `kAcq` ensures we see prior stores. + const size_t num_strings = next_ptr_.load(kAcq); + HWY_ASSERT(num_strings < kMaxStrings); + for (size_t idx = 1; idx < num_strings; ++idx) { + const char* existing = ptrs_[idx].load(kAcq); + // `next_ptr_` was published after writing `ptr_`, hence it is non-null. + HWY_ASSERT(existing != nullptr); + if (HWY_UNLIKELY(!strcmp(existing, name))) { + return idx; + } + } + + // Copy `name` into `chars_` before publishing the pointer. + const size_t len = strlen(name) + 1; + const size_t pos = next_char_.fetch_add(len, kRelaxed); + HWY_ASSERT(pos + len <= sizeof(chars_)); + strcpy(chars_ + pos, name); // NOLINT + + for (;;) { + size_t idx = next_ptr_.load(kRelaxed); + HWY_ASSERT(idx < kMaxStrings); + + // Attempt to claim the next `idx` via CAS. + const char* expected = nullptr; + if (HWY_LIKELY(ptrs_[idx].compare_exchange_weak(expected, chars_ + pos, + kRel, kRelaxed))) { + // Publish the new count and make the `ptrs_` write visible. + next_ptr_.store(idx + 1, kRel); + HWY_DASSERT(!strcmp(Name(idx), name)); + return idx; + } + + // We lost the race. `expected` has been updated. + if (HWY_UNLIKELY(!strcmp(expected, name))) { + // Done, another thread added the same name. Note that we waste the + // extra space in `chars_`, which is fine because it is rare. + HWY_DASSERT(!strcmp(Name(idx), name)); + return idx; + } + + // Other thread added a different name. Retry with the next slot. + } + } + + private: + std::atomic ptrs_[kMaxStrings]; + std::atomic next_ptr_{1}; // next idx + std::atomic next_char_{0}; + char chars_[kMaxStrings * 55]; +}; + +#if PROFILER_ENABLED + +// Implementation details. 
+namespace profiler { + +HWY_INLINE_VAR constexpr size_t kNumFlags = 1; + +// Upper bounds for fixed-size data structures, guarded via HWY_DASSERT: + +// Maximum nesting of zones, chosen such that `PerWorker` is 256 bytes. +HWY_INLINE_VAR constexpr size_t kMaxDepth = 13; +// Reports with more than ~50 are anyway difficult to read. +HWY_INLINE_VAR constexpr size_t kMaxZones = 128; +// Upper bound on global worker_idx across all pools. Note that fiber libraries +// can spawn hundreds of threads. Turin has 128-192 cores. +HWY_INLINE_VAR constexpr size_t kMaxWorkers = 256; + +// Type-safe wrapper for zone index plus flags, returned by `AddZone`. +class ZoneHandle { + public: + ZoneHandle() : bits_(0) {} // for Accumulator member initialization + + ZoneHandle(size_t zone_idx, ProfilerFlags flags) { + HWY_DASSERT(0 != zone_idx && zone_idx < kMaxZones); + const uint32_t flags_u = static_cast(flags); + HWY_DASSERT(flags_u < (1u << kNumFlags)); + bits_ = (static_cast(zone_idx) << kNumFlags) | flags_u; + HWY_DASSERT(ZoneIdx() == zone_idx); + } + + ZoneHandle(const ZoneHandle& other) = default; + ZoneHandle& operator=(const ZoneHandle& other) = default; + + bool operator==(const ZoneHandle other) const { return bits_ == other.bits_; } + bool operator!=(const ZoneHandle other) const { return bits_ != other.bits_; } + + size_t ZoneIdx() const { + HWY_DASSERT(bits_ != 0); + const size_t zone_idx = bits_ >> kNumFlags; + HWY_DASSERT(0 != zone_idx && zone_idx < kMaxZones); + return zone_idx; + } + + bool IsInclusive() const { + HWY_DASSERT(bits_ != 0); + return (bits_ & static_cast(ProfilerFlags::kInclusive)) != 0; + } + + // Returns a mask to zero/ignore child totals for inclusive zones. + uint64_t ChildTotalMask() const { + // With a ternary operator, clang tends to generate a branch. + // return IsInclusive() ? 
0 : ~uint64_t{0}; + const uint32_t bit = + bits_ & static_cast(ProfilerFlags::kInclusive); + return uint64_t{bit} - 1; + } + + private: + uint32_t bits_; +}; + +// Storage for zone names. +class Zones { + public: + // Returns a copy of the `name` passed to `AddZone` that returned the + // given `zone`. + const char* Name(ZoneHandle zone) const { + return strings_.Name(zone.ZoneIdx()); + } + + // Can be called concurrently. Calls with the same `name` return the same + // `ZoneHandle.ZoneIdx()`. + ZoneHandle AddZone(const char* name, ProfilerFlags flags) { + return ZoneHandle(strings_.Add(name), flags); + } + + private: + StringTable strings_; +}; + +// Allows other classes such as `ThreadPool` to register/unregister a function +// to call during `PrintResults`. This allows us to gather data from the worker +// threads without having to wait until they exit, and decouples the profiler +// from other modules. Thread-safe. +class Funcs { + static constexpr auto kAcq = std::memory_order_acquire; + static constexpr auto kRel = std::memory_order_release; + + public: + // Can be called concurrently with distinct keys. + void Add(intptr_t key, ProfilerFunc func) { + HWY_ASSERT(key != 0 && key != kPending); // reserved values + HWY_ASSERT(func); // not empty + + for (size_t i = 0; i < kMaxFuncs; ++i) { + intptr_t expected = 0; + // Lost a race with a concurrent `Add`, try the next slot. + if (!keys_[i].compare_exchange_strong(expected, kPending, kRel)) { + continue; + } + // We own the slot: move func there. + funcs_[i] = std::move(func); + keys_[i].store(key, kRel); // publishes the `func` write. + return; + } + + HWY_ABORT("Funcs::Add: no free slot, increase kMaxFuncs."); + } + + // Can be called concurrently with distinct keys. It is an error to call this + // without a prior `Add` of the same key. 
+ void Remove(intptr_t key) { + HWY_ASSERT(key != 0 && key != kPending); // reserved values + + for (size_t i = 0; i < kMaxFuncs; ++i) { + intptr_t actual = keys_[i].load(kAcq); + if (actual == key) { + // In general, concurrent removal is fine, but in this specific context, + // owners are expected to remove their key exactly once, from the same + // thread that added it. In that case, CAS should not fail. + if (!keys_[i].compare_exchange_strong(actual, kPending, kRel)) { + HWY_WARN("Funcs: CAS failed, why is there a concurrent Remove?"); + } + funcs_[i] = ProfilerFunc(); + keys_[i].store(0, kRel); // publishes the `func` write. + return; + } + } + HWY_ABORT("Funcs::Remove: failed to find key %p.", + reinterpret_cast(key)); + } + + void CallAll() const { + for (size_t i = 0; i < kMaxFuncs; ++i) { + intptr_t key = keys_[i].load(kAcq); // ensures `funcs_` is visible. + // Safely handles concurrent Add/Remove. + if (key != 0 && key != kPending) { + funcs_[i](); + } + } + } + + private: + static constexpr size_t kMaxFuncs = 64; + static constexpr intptr_t kPending = -1; + + ProfilerFunc funcs_[kMaxFuncs]; // non-atomic + std::atomic keys_[kMaxFuncs] = {}; +}; + +// Holds total duration and number of calls. Worker index is implicit in the +// index of this class within the `Accumulators` array. +struct Accumulator { + void Add(ZoneHandle new_zone, uint64_t self_duration) { + duration += self_duration; + + // Only called for valid zones. + HWY_DASSERT(new_zone != ZoneHandle()); + // Our zone might not have been set yet. + HWY_DASSERT(zone == ZoneHandle() || zone == new_zone); + zone = new_zone; + + num_calls += 1; + } + + void Take(Accumulator& other) { + duration += other.duration; + other.duration = 0; + + // `ZoneSet` ensures we only call this for non-empty `other`. + HWY_DASSERT(other.zone != ZoneHandle()); + // Our zone might not have been set yet. 
+ HWY_DASSERT(zone == ZoneHandle() || zone == other.zone); + zone = other.zone; + + num_calls += other.num_calls; + other.num_calls = 0; + } + + uint64_t duration = 0; + ZoneHandle zone; // flags are used by `Results::Print` + uint32_t num_calls = 0; +}; +static_assert(sizeof(Accumulator) == 16, "Wrong Accumulator size"); + +using ZoneSet = hwy::BitSet; +using WorkerSet = hwy::BitSet; +using AtomicWorkerSet = hwy::AtomicBitSet; + +// Durations are per-CPU, but end to end performance is defined by wall time. +// Assuming fork-join parallelism, zones are entered by multiple threads +// concurrently, which means the total number of unique threads is also the +// degree of concurrency, so we can estimate wall time as CPU time divided by +// the number of unique threads seen. This is facilitated by unique `global_idx` +// passed in by callers, or taken from thread-local `GlobalIdx()`. +// +// We also want to support varying thread counts per call site, because the same +// function/zone may be called from multiple pools. `EndRootRun` calls +// `CountWorkersAndReset` after each top-level `ThreadPool::Run`, which +// generates one data point summarized via descriptive statistics. Here we +// implement a simpler version of `Stats` because we do not require +// geomean/variance/kurtosis/skewness. Because concurrency is a small integer, +// we can simply compute sums rather than online moments. There is also only one +// instance across all threads, hence we do not require a `Take`. +// +// Note that subsequently discovered prior work estimates the number of active +// and idle processors by updating atomic counters whenever they start/finish a +// task: https://homes.cs.washington.edu/~tom/pubs/quartz.pdf and "Effective +// performance measurement and analysis of multithreaded applications". We +// instead accumulate zone durations into per-thread storage. 
+// `CountWorkersAndReset` then checks how many were nonzero, which avoids +// expensive atomic updates and ensures accurate counts per-zone, rather than +// estimates of current activity at each sample. +// D. Vyukov's https://github.com/dvyukov/perf-load, also integrated into Linux +// perf, also corrects for parallelism without using atomic counters by tracing +// context switches. Note that we often pin threads, which avoids migrations, +// but reduces the number of context switch events to mainly preemptions. +class ConcurrencyStats { + public: + ConcurrencyStats() { Reset(); } + + void Notify(const size_t x) { + sum_ += x; + ++n_; + min_ = HWY_MIN(min_, x); + max_ = HWY_MAX(max_, x); + } + + size_t Count() const { return n_; } + size_t Min() const { return min_; } + size_t Max() const { return max_; } + double Mean() const { + return static_cast(sum_) / static_cast(n_); + } + + void Reset() { + sum_ = 0; + n_ = 0; + min_ = hwy::HighestValue(); + max_ = hwy::LowestValue(); + } + + private: + uint64_t sum_; + size_t n_; + size_t min_; + size_t max_; +}; +static_assert(sizeof(ConcurrencyStats) == (8 + 3 * sizeof(size_t)), ""); + +// Holds the final results across all threads, including `ConcurrencyStats` +// and `PoolStats`, updated/printed by the main thread. +class Results { + public: + void TakeAccumulator(const size_t global_idx, const size_t zone_idx, + Accumulator& other) { + HWY_DASSERT(global_idx < kMaxWorkers); + HWY_DASSERT(zone_idx < kMaxZones); + HWY_DASSERT(other.zone.ZoneIdx() == zone_idx); + + visited_zones_.Set(zone_idx); + totals_[zone_idx].Take(other); + workers_[zone_idx].Set(global_idx); + } + + // Moves the total number of threads seen during the preceding root-level + // `ThreadPool::Run` into one data point for `ConcurrencyStats`. 
+ void CountWorkersAndReset(const size_t zone_idx) { + HWY_DASSERT(zone_idx < kMaxZones); + const size_t num_workers = workers_[zone_idx].Count(); + // Although workers_[zone_idx] at one point was non-empty, it is reset + // below, and so can be empty on the second call to this via `PrintResults`, + // after one from `EndRootRun`. Do not add a data point if empty. + if (num_workers != 0) { + concurrency_[zone_idx].Notify(num_workers); + } + workers_[zone_idx] = WorkerSet(); + } + + void CountWorkersAndReset() { + visited_zones_.Foreach( + [&](size_t zone_idx) { CountWorkersAndReset(zone_idx); }); + } + + void PrintAndReset(const Zones& zones) { + const double inv_freq = 1.0 / hwy::platform::InvariantTicksPerSecond(); + + // Sort by decreasing total (self) cost. `totals_` are sparse, so sort an + // index vector instead. + std::vector indices; + indices.reserve(visited_zones_.Count()); + visited_zones_.Foreach([&](size_t zone_idx) { + indices.push_back(static_cast(zone_idx)); + // In case the zone exited after `EndRootRun` and was not yet added. + CountWorkersAndReset(zone_idx); + }); + std::sort(indices.begin(), indices.end(), [&](uint32_t a, uint32_t b) { + return totals_[a].duration > totals_[b].duration; + }); + printf(" %-40s: %10s x %15s / %5s (%5s %3s-%3s) = %9s\n", "Zone", "Calls", + "Cycles/Call", "Avg Count", "Count", "Min", "Max", "Wall Time(s)"); + + for (uint32_t zone_idx : indices) { + Accumulator& total = totals_[zone_idx]; // cleared after printing + HWY_ASSERT(total.zone.ZoneIdx() == zone_idx); + HWY_ASSERT(total.num_calls != 0); // else visited_zones_ is wrong + + ConcurrencyStats& concurrency = concurrency_[zone_idx]; + const double duration = static_cast(total.duration); + const double per_call = + static_cast(total.duration) / total.num_calls; + // See comment on `ConcurrencyStats`. + const double avg_concurrency = concurrency.Mean(); + // Avoid division by zero. 
+ const double concurrency_divisor = HWY_MAX(1.0, avg_concurrency); + printf("%s%-40s: %10.0f x %15.0f / %5.1f (%5zu %3zu-%3zu) = %9.6f\n", + total.zone.IsInclusive() ? "(I)" : " ", zones.Name(total.zone), + static_cast(total.num_calls), per_call, avg_concurrency, + concurrency.Count(), concurrency.Min(), concurrency.Max(), + duration * inv_freq / concurrency_divisor); + + total = Accumulator(); + concurrency.Reset(); + // `workers_` was already reset by `CountWorkersAndReset`. + } + visited_zones_ = ZoneSet(); + } + + private: + // Indicates which of the array entries are in use. + ZoneSet visited_zones_; + Accumulator totals_[kMaxZones]; + WorkerSet workers_[kMaxZones]; + ConcurrencyStats concurrency_[kMaxZones]; +}; + +// Delay after capturing timestamps before/after the actual zone runs. Even +// with frequency throttling disabled, this has a multimodal distribution, +// including 32, 34, 48, 52, 59, 62. +struct Overheads { + uint64_t self = 0; + uint64_t child = 0; +}; +static_assert(sizeof(Overheads) == 16, "Wrong Overheads size"); + +class Accumulators { + // We generally want to group threads together because they are often + // accessed together during a zone, but also want to avoid threads sharing a + // cache line. Hence interleave 8 zones per worker. + static constexpr size_t kPerLine = HWY_ALIGNMENT / sizeof(Accumulator); + + public: + Accumulator& Get(const size_t global_idx, const size_t zone_idx) { + HWY_DASSERT(global_idx < kMaxWorkers); + HWY_DASSERT(zone_idx < kMaxZones); + const size_t line = zone_idx / kPerLine; + const size_t offset = zone_idx % kPerLine; + return zones_[(line * kMaxWorkers + global_idx) * kPerLine + offset]; + } + + private: + Accumulator zones_[kMaxZones * kMaxWorkers]; +}; + +// Reacts to zone enter/exit events. Builds a stack of active zones and +// accumulates self/child duration for each. 
+class PerWorker { + public: + template + static T ClampedSubtract(const T minuend, const T subtrahend) { + static_assert(IsUnsigned(), ""); + const T difference = minuend - subtrahend; + // Clang output for this is verified to CMOV rather than branch. + const T no_underflow = (subtrahend > minuend) ? T{0} : ~T{0}; + return difference & no_underflow; + } + + void SetOverheads(const Overheads& overheads) { overheads_ = overheads; } + + // Entering a zone: push onto stack. + void Enter(const uint64_t t_enter) { + const size_t depth = depth_; + HWY_DASSERT(depth < kMaxDepth); + t_enter_[depth] = t_enter; + child_total_[1 + depth] = 0; + depth_ = 1 + depth; + } + + // Exiting the most recently entered zone (top of stack). + void Exit(const uint64_t t_exit, const size_t global_idx, + const ZoneHandle zone, Accumulators& accumulators) { + HWY_DASSERT(depth_ > 0); + const size_t depth = depth_ - 1; + const size_t zone_idx = zone.ZoneIdx(); + const uint64_t duration = t_exit - t_enter_[depth]; + // Clang output for this is verified not to branch. This is 0 if inclusive, + // otherwise the child total. + const uint64_t child_total = + child_total_[1 + depth] & zone.ChildTotalMask(); + + const uint64_t self_duration = ClampedSubtract( + duration, overheads_.self + overheads_.child + child_total); + accumulators.Get(global_idx, zone_idx).Add(zone, self_duration); + // For faster TakeAccumulator() - not all zones are encountered. + visited_zones_.Set(zone_idx); + + // Adding this nested time to the parent's `child_total` will + // cause it to be later subtracted from the parent's `self_duration`. + child_total_[1 + depth - 1] += duration + overheads_.child; + + depth_ = depth; + } + + // Returns the duration of one enter/exit pair and resets all state. Called + // via `DetectSelfOverhead`. 
+ uint64_t GetFirstDurationAndReset(size_t global_idx, + Accumulators& accumulators) { + HWY_DASSERT(depth_ == 0); + + HWY_DASSERT(visited_zones_.Count() == 1); + const size_t zone_idx = visited_zones_.First(); + HWY_DASSERT(zone_idx <= 3); + HWY_DASSERT(visited_zones_.Get(zone_idx)); + visited_zones_.Clear(zone_idx); + + Accumulator& zone = accumulators.Get(global_idx, zone_idx); + const uint64_t duration = zone.duration; + zone = Accumulator(); + return duration; + } + + // Adds all data to `results` and resets it here. Called from the main thread. + void MoveTo(const size_t global_idx, Accumulators& accumulators, + Results& results) { + visited_zones_.Foreach([&](size_t zone_idx) { + results.TakeAccumulator(global_idx, zone_idx, + accumulators.Get(global_idx, zone_idx)); + }); + // OK to reset even if we have active zones, because we set `visited_zones_` + // when exiting the zone. + visited_zones_ = ZoneSet(); + } + + private: + // 40 bytes: + ZoneSet visited_zones_; // Which `zones_` have been active on this worker. + uint64_t depth_ = 0; // Current nesting level for active zones. + Overheads overheads_; + + uint64_t t_enter_[kMaxDepth]; + // Used to deduct child duration from parent's self time (unless inclusive). + // Shifting by one avoids bounds-checks for depth_ = 0 (root zone). + uint64_t child_total_[1 + kMaxDepth] = {0}; +}; +// Enables shift rather than multiplication. +static_assert(sizeof(PerWorker) == 256, "Wrong size"); + +} // namespace profiler + +class Profiler { + public: + static HWY_DLLEXPORT Profiler& Get(); + + // Returns `global_idx` from thread-local storage (0 for the main thread). + // Used by `PROFILER_ZONE/PROFILER_FUNC`. It is faster to instead pass the + // global_idx from `ThreadPool::Run` (if constructed with non-default + // `PoolWorkerMapping`) to `PROFILER_ZONE2/PROFILER_ZONE3`. + // DEPRECATED: use `GlobalIdx` instead. 
+ static size_t Thread() { return s_global_idx; } + static size_t GlobalIdx() { return s_global_idx; } + // Must be called from all worker threads, and once also on the main thread, + // before any use of `PROFILER_ZONE/PROFILER_FUNC`. + static void SetGlobalIdx(size_t global_idx) { s_global_idx = global_idx; } + + void ReserveWorker(size_t global_idx) { + HWY_ASSERT(!workers_reserved_.Get(global_idx)); + workers_reserved_.Set(global_idx); + } + + void FreeWorker(size_t global_idx) { + HWY_ASSERT(workers_reserved_.Get(global_idx)); + workers_reserved_.Clear(global_idx); + } + + // Called by `Zone` from any thread. + void Enter(uint64_t t_enter, size_t global_idx) { + GetWorker(global_idx).Enter(t_enter); + } + + // Called by `~Zone` from any thread. + void Exit(uint64_t t_exit, size_t global_idx, profiler::ZoneHandle zone) { + GetWorker(global_idx).Exit(t_exit, global_idx, zone, accumulators_); + } + + uint64_t GetFirstDurationAndReset(size_t global_idx) { + return GetWorker(global_idx) + .GetFirstDurationAndReset(global_idx, accumulators_); + } + + const char* Name(profiler::ZoneHandle zone) const { + return zones_.Name(zone); + } + + // Copies `name` into the string table and returns its unique `zone`. Uses + // linear search, which is fine because this is called during static init. + // Called via static initializer and the result is passed to the `Zone` ctor. + profiler::ZoneHandle AddZone(const char* name, + ProfilerFlags flags = ProfilerFlags::kDefault) { + return zones_.AddZone(name, flags); + } + + void AddFunc(void* owner, ProfilerFunc func) { + funcs_.Add(reinterpret_cast(owner), func); + } + void RemoveFunc(void* owner) { + funcs_.Remove(reinterpret_cast(owner)); + } + + // For reporting average concurrency. Called by `ThreadPool::Run` on the main + // thread, returns true if this is the first call since the last `EndRootRun`. + // + // We want to report the concurrency of each separate 'invocation' of a zone. 
+ // A unique per-call identifier (could be approximated with the line number + // and return address) is not sufficient because the caller may in turn be + // called from differing parallel sections. A per-`ThreadPool::Run` counter + // also under-reports concurrency because each pool in nested parallelism + // (over packages and CCXes) would be considered separate invocations. + // + // The alternative of detecting overlapping zones via timestamps is not 100% + // reliable because timers may not be synchronized across sockets or perhaps + // even cores. "Invariant" x86 TSCs are indeed synchronized across cores, but + // not across sockets unless the RESET# signal reaches each at the same time. + // Linux seems to make an effort to correct this, and Arm's "generic timer" + // broadcasts to "all cores", but there is no universal guarantee. + // + // Under the assumption that all concurrency is via our `ThreadPool`, we can + // record all `global_idx` for each outermost (root) `ThreadPool::Run`. This + // collapses all nested pools into one 'invocation'. We then compute per-zone + // concurrency as the number of unique `global_idx` seen per invocation. + bool IsRootRun() { + // We are not the root if a Run was already active. + return !run_active_.test_and_set(std::memory_order_acquire); + } + + // Must be called if `IsRootRun` returned true. Resets the state so that the + // next call to `IsRootRun` will again return true. Called from main thread. + // Note that some zones may still be active. Their concurrency will be updated + // when `PrintResults` is called. + void EndRootRun() { + UpdateResults(); + results_.CountWorkersAndReset(); + + run_active_.clear(std::memory_order_release); + } + + // Prints results. Call from main thread after all threads have exited all + // zones. Resets all state, can be called again after more zones. + void PrintResults() { + UpdateResults(); + // `CountWorkersAndReset` is fused into `Print`, so do not call it here. 
+ + results_.PrintAndReset(zones_); + + funcs_.CallAll(); + } + + // TODO: remove when no longer called. + void SetMaxThreads(size_t) {} + + private: + // Sets main thread index, computes self-overhead, and checks timer support. + Profiler(); + + profiler::PerWorker& GetWorker(size_t global_idx) { + HWY_DASSERT(workers_reserved_.Get(global_idx)); + return workers_[global_idx]; + } + + // Moves accumulators into Results. Called from the main thread. + void UpdateResults() { + // Ensure we see all writes from before the workers' release fence. + std::atomic_thread_fence(std::memory_order_acquire); + + workers_reserved_.Foreach([&](size_t global_idx) { + workers_[global_idx].MoveTo(global_idx, accumulators_, results_); + }); + } + + static thread_local size_t s_global_idx; + + // These are atomic because `ThreadFunc` reserves its slot(s) and even + // `ThreadPool::ThreadPool` may be called concurrently. Both have bit `i` set + // between calls to `Reserve*(i)` and `Free*(i)`. They are consulted in + // `UpdateResults` and to validate arguments in debug builds, and only updated + // in the pool/thread init/shutdown. + profiler::AtomicWorkerSet workers_reserved_; + + std::atomic_flag run_active_ = ATOMIC_FLAG_INIT; + + profiler::Funcs funcs_; + + // To avoid locking, each worker has its own working set. We could access this + // through `thread_local` pointers, but that is slow to read on x86. Because + // our `ThreadPool` anyway passes a `global_idx` argument, we can instead pass + // that through the `PROFILER_ZONE2/PROFILER_ZONE3` macros. + profiler::PerWorker workers_[profiler::kMaxWorkers]; + + profiler::Accumulators accumulators_; + + profiler::Results results_; + + profiler::Zones zones_; +}; + +namespace profiler { + +// RAII for zone entry/exit. 
+class Zone { + public: + // Thread-compatible; must not call concurrently with the same `global_idx`, + // which is either: + // - passed from `ThreadPool::Run` (if it was constructed with non-default + // `PoolWorkerMapping`) to `PROFILER_ZONE2/PROFILER_ZONE3`; + // - obtained from `Profiler::GlobalIdx()`; or + // - 0 if running on the main thread. + Zone(Profiler& profiler, size_t global_idx, ZoneHandle zone) + : profiler_(profiler) { + HWY_FENCE; + const uint64_t t_enter = timer::Start(); + HWY_FENCE; + global_idx_ = static_cast(global_idx); + zone_ = zone; + profiler.Enter(t_enter, global_idx); + HWY_FENCE; + } + + ~Zone() { + HWY_FENCE; + const uint64_t t_exit = timer::Stop(); + profiler_.Exit(t_exit, static_cast(global_idx_), zone_); + HWY_FENCE; + } + + private: + Profiler& profiler_; + uint32_t global_idx_; + ZoneHandle zone_; +}; + +} // namespace profiler +#else // profiler disabled: stub implementation + +namespace profiler { +struct ZoneHandle {}; +} // namespace profiler + +struct Profiler { + static HWY_DLLEXPORT Profiler& Get(); + + // DEPRECATED: use `GlobalIdx` instead. + static size_t Thread() { return 0; } + static size_t GlobalIdx() { return 0; } + static void SetGlobalIdx(size_t) {} + void ReserveWorker(size_t) {} + void FreeWorker(size_t) {} + void Enter(uint64_t, size_t) {} + void Exit(uint64_t, size_t, profiler::ZoneHandle) {} + uint64_t GetFirstDurationAndReset(size_t) { return 0; } + + const char* Name(profiler::ZoneHandle) const { return nullptr; } + profiler::ZoneHandle AddZone(const char*, + ProfilerFlags = ProfilerFlags::kDefault) { + return profiler::ZoneHandle(); + } + + void AddFunc(void*, ProfilerFunc) {} + void RemoveFunc(void*) {} + + bool IsRootRun() { return false; } + void EndRootRun() {} + void PrintResults() {} + + // TODO: remove when no longer called. 
+ void SetMaxThreads(size_t) {} +}; + +namespace profiler { +struct Zone { + Zone(Profiler&, size_t, ZoneHandle) {} +}; + +} // namespace profiler +#endif // PROFILER_ENABLED || HWY_IDE + +} // namespace hwy + +// Creates a `Zone` lvalue with a line-dependent name, which records the elapsed +// time from here until the end of the current scope. `p` is from +// `Profiler::Get()` or a cached reference. `global_idx < kMaxWorkers`. `zone` +// is the return value of `AddZone`. Separating its static init from the `Zone` +// may be more efficient than `PROFILER_ZONE2`. +#define PROFILER_ZONE3(p, global_idx, zone) \ + HWY_FENCE; \ + const hwy::profiler::Zone HWY_CONCAT(Z, __LINE__)(p, global_idx, zone); \ + HWY_FENCE + +// For compatibility with old callers that do not pass `p` nor `flags`. +// Also calls AddZone. Usage: `PROFILER_ZONE2(global_idx, "MyZone");` +#define PROFILER_ZONE2(global_idx, name) \ + static const hwy::profiler::ZoneHandle HWY_CONCAT(zone, __LINE__) = \ + hwy::Profiler::Get().AddZone(name); \ + PROFILER_ZONE3(hwy::Profiler::Get(), global_idx, HWY_CONCAT(zone, __LINE__)) +#define PROFILER_FUNC2(global_idx) PROFILER_ZONE2(global_idx, __func__) + +// OBSOLETE: it is more efficient to pass `global_idx` from `ThreadPool` to +// `PROFILER_ZONE2/PROFILER_ZONE3`. Here we get it from thread_local storage. +#define PROFILER_ZONE(name) PROFILER_ZONE2(hwy::Profiler::GlobalIdx(), name) +#define PROFILER_FUNC PROFILER_FUNC2(hwy::Profiler::GlobalIdx()) + +// DEPRECATED: Use `hwy::Profiler::Get()` directly instead. 
+#define PROFILER_ADD_ZONE(name) hwy::Profiler::Get().AddZone(name) +#define PROFILER_IS_ROOT_RUN() hwy::Profiler::Get().IsRootRun() +#define PROFILER_END_ROOT_RUN() hwy::Profiler::Get().EndRootRun() +#define PROFILER_PRINT_RESULTS() hwy::Profiler::Get().PrintResults() + +#endif // HIGHWAY_HWY_PROFILER_H_ diff --git a/lib/highway/hwy/robust_statistics.h b/lib/highway/hwy/robust_statistics.h new file mode 100644 index 00000000000..08b444f408f --- /dev/null +++ b/lib/highway/hwy/robust_statistics.h @@ -0,0 +1,150 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_ROBUST_STATISTICS_H_ +#define HIGHWAY_HWY_ROBUST_STATISTICS_H_ + +#include // std::sort, std::find_if +#include +#include // std::pair +#include + +#include "hwy/base.h" + +namespace hwy { +namespace robust_statistics { + +// Sorts integral values in ascending order (e.g. for Mode). About 3x faster +// than std::sort for input distributions with very few unique values. +template +void CountingSort(T* values, size_t num_values) { + // Unique values and their frequency (similar to flat_map). 
+ using Unique = std::pair; + std::vector unique; + for (size_t i = 0; i < num_values; ++i) { + const T value = values[i]; + const auto pos = + std::find_if(unique.begin(), unique.end(), + [value](const Unique u) { return u.first == value; }); + if (pos == unique.end()) { + unique.push_back(std::make_pair(value, 1)); + } else { + ++pos->second; + } + } + + // Sort in ascending order of value (pair.first). + std::sort(unique.begin(), unique.end()); + + // Write that many copies of each unique value to the array. + T* HWY_RESTRICT p = values; + for (const auto& value_count : unique) { + std::fill(p, p + value_count.second, value_count.first); + p += value_count.second; + } + HWY_ASSERT(p == values + num_values); +} + +// @return i in [idx_begin, idx_begin + half_count) that minimizes +// sorted[i + half_count] - sorted[i]. +template +size_t MinRange(const T* const HWY_RESTRICT sorted, const size_t idx_begin, + const size_t half_count) { + T min_range = std::numeric_limits::max(); + size_t min_idx = 0; + + for (size_t idx = idx_begin; idx < idx_begin + half_count; ++idx) { + HWY_ASSERT(sorted[idx] <= sorted[idx + half_count]); + const T range = sorted[idx + half_count] - sorted[idx]; + if (range < min_range) { + min_range = range; + min_idx = idx; + } + } + + return min_idx; +} + +// Returns an estimate of the mode by calling MinRange on successively +// halved intervals. "sorted" must be in ascending order. This is the +// Half Sample Mode estimator proposed by Bickel in "On a fast, robust +// estimator of the mode", with complexity O(N log N). The mode is less +// affected by outliers in highly-skewed distributions than the median. +// The averaging operation below assumes "T" is an unsigned integer type. 
+template +T ModeOfSorted(const T* const HWY_RESTRICT sorted, const size_t num_values) { + size_t idx_begin = 0; + size_t half_count = num_values / 2; + while (half_count > 1) { + idx_begin = MinRange(sorted, idx_begin, half_count); + half_count >>= 1; + } + + const T x = sorted[idx_begin + 0]; + if (half_count == 0) { + return x; + } + HWY_ASSERT(half_count == 1); + const T average = (x + sorted[idx_begin + 1] + 1) / 2; + return average; +} + +// Returns the mode. Side effect: sorts "values". +template +T Mode(T* values, const size_t num_values) { + CountingSort(values, num_values); + return ModeOfSorted(values, num_values); +} + +template +T Mode(T (&values)[N]) { + return Mode(&values[0], N); +} + +// Returns the median value. Side effect: sorts "values". +template +T Median(T* values, const size_t num_values) { + HWY_ASSERT(num_values != 0); + std::sort(values, values + num_values); + const size_t half = num_values / 2; + // Odd count: return middle + if (num_values % 2) { + return values[half]; + } + // For integers, round rather than truncate. + const T bias = hwy::IsInteger() ? T{1} : T{0}; + // Even count: return average of middle two. + return (values[half] + values[half - 1] + bias) / 2; +} + +// Returns a robust measure of variability. 
+template +T MedianAbsoluteDeviation(const T* values, const size_t num_values, + const T median) { + HWY_ASSERT(num_values != 0); + std::vector abs_deviations; + abs_deviations.reserve(num_values); + for (size_t i = 0; i < num_values; ++i) { + const int64_t abs = ScalarAbs(static_cast(values[i]) - + static_cast(median)); + abs_deviations.push_back(static_cast(abs)); + } + return Median(abs_deviations.data(), num_values); +} + +} // namespace robust_statistics +} // namespace hwy + +#endif // HIGHWAY_HWY_ROBUST_STATISTICS_H_ diff --git a/lib/highway/hwy/stats.cc b/lib/highway/hwy/stats.cc new file mode 100644 index 00000000000..b9f76cb082b --- /dev/null +++ b/lib/highway/hwy/stats.cc @@ -0,0 +1,115 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/stats.h" + +#include + +#include // std::min +#include + +#include "hwy/base.h" // HWY_ASSERT + +namespace hwy { + +void Stats::Assimilate(const Stats& other) { + const int64_t total_n = n_ + other.n_; + if (total_n == 0) return; // Nothing to do; prevents div by zero. 
+ + min_ = std::min(min_, other.min_); + max_ = std::max(max_, other.max_); + + sum_log_ += other.sum_log_; + + const double product_n = n_ * other.n_; + const double n2 = n_ * n_; + const double other_n2 = other.n_ * other.n_; + const int64_t total_n2 = total_n * total_n; + const double total_n3 = static_cast(total_n2) * total_n; + // Precompute reciprocal for speed - used at least twice. + const double inv_total_n = 1.0 / total_n; + const double inv_total_n2 = 1.0 / total_n2; + + const double delta = other.m1_ - m1_; + const double delta2 = delta * delta; + const double delta3 = delta * delta2; + const double delta4 = delta2 * delta2; + + m1_ = (n_ * m1_ + other.n_ * other.m1_) * inv_total_n; + + const double new_m2 = m2_ + other.m2_ + delta2 * product_n * inv_total_n; + + const double new_m3 = + m3_ + other.m3_ + delta3 * product_n * (n_ - other.n_) * inv_total_n2 + + 3.0 * delta * (n_ * other.m2_ - other.n_ * m2_) * inv_total_n; + + m4_ += other.m4_ + + delta4 * product_n * (n2 - product_n + other_n2) / total_n3 + + 6.0 * delta2 * (n2 * other.m2_ + other_n2 * m2_) * inv_total_n2 + + 4.0 * delta * (n_ * other.m3_ - other.n_ * m3_) * inv_total_n; + + m2_ = new_m2; + m3_ = new_m3; + n_ = total_n; +} + +std::string Stats::ToString(int exclude) const { + if (Count() == 0) return std::string("(none)"); + + char buf[300]; + size_t pos = 0; + int ret; // snprintf - bytes written or negative for error. 
+ + if ((exclude & kNoCount) == 0) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "Count=%9zu ", + static_cast(Count())); + HWY_ASSERT(ret > 0); + pos += ret; + } + + if ((exclude & kNoMeanSD) == 0) { + const float sd = StandardDeviation(); + ret = snprintf(buf + pos, sizeof(buf) - pos, "Mean=%10.3e SD=%8.2e ", + Mean(), sd); + HWY_ASSERT(ret > 0); + pos += ret; + } + + if ((exclude & kNoMinMax) == 0) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "Min=%10.3e Max=%10.3e ", + static_cast(Min()), static_cast(Max())); + HWY_ASSERT(ret > 0); + pos += ret; + } + + if ((exclude & kNoSkewKurt) == 0) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "Skew=%5.2f Kurt=%7.2f ", + Skewness(), Kurtosis()); + HWY_ASSERT(ret > 0); + pos += ret; + } + + if ((exclude & kNoGeomean) == 0) { + ret = snprintf(buf + pos, sizeof(buf) - pos, "GeoMean=%9.6f ", + GeometricMean()); + HWY_ASSERT(ret > 0); + pos += ret; + } + + HWY_ASSERT(pos < sizeof(buf)); + return buf; +} + +} // namespace hwy diff --git a/lib/highway/hwy/stats.h b/lib/highway/hwy/stats.h new file mode 100644 index 00000000000..041f3d2d55b --- /dev/null +++ b/lib/highway/hwy/stats.h @@ -0,0 +1,244 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_STATS_H_ +#define HIGHWAY_HWY_STATS_H_ + +#include +#include + +#include +#include + +#include "hwy/base.h" // HWY_ASSERT + +namespace hwy { + +// Thread-compatible. 
+template +class Bins { + public: + Bins() { Reset(); } + + template + void Notify(T bin) { + HWY_ASSERT(T{0} <= bin && bin < static_cast(N)); + counts_[static_cast(bin)]++; + } + + uint32_t Bin(size_t bin_idx) const { + HWY_DASSERT(bin_idx < N); + return counts_[bin_idx]; + } + + void ResetBin(size_t bin_idx) { + HWY_DASSERT(bin_idx < N); + counts_[bin_idx] = 0; + } + + void Assimilate(const Bins& other) { + for (size_t i = 0; i < N; ++i) { + counts_[i] += other.counts_[i]; + } + } + + size_t FirstNonzero() const { + for (size_t i = 0; i < N; ++i) { + if (counts_[i] != 0) return i; + } + return N; + } + + size_t LastNonzero() const { + for (size_t i = N - 1; i < N; --i) { + if (counts_[i] != 0) return i; + } + return 0; + } + + size_t NumNonzero() const { + size_t num_nonzero = 0; + for (size_t i = 0; i < N; ++i) { + num_nonzero += (counts_[i] != 0); + } + return num_nonzero; + } + + size_t ModalBinIdx() const { + size_t max = 0; + size_t idx_max = 0; + for (size_t i = 0; i < N; ++i) { + if (counts_[i] > max) { + max = counts_[i]; + idx_max = i; + } + } + return idx_max; + } + + void Print(const char* caption, bool skip_zero = false) const { + fprintf(stderr, "\n%s [%zu, modal idx %zu]\n", caption, N, ModalBinIdx()); + const size_t first_nonzero = FirstNonzero(); + const size_t last_nonzero = LastNonzero(); + if (skip_zero) { + for (size_t i = first_nonzero; i <= last_nonzero; ++i) { + if (counts_[i] != 0) { + fprintf(stderr, " %3zu: %zu\n", i, counts_[i]); + } + } + } else { + for (size_t i = first_nonzero; i <= last_nonzero; ++i) { + fprintf(stderr, " %3zu: %zu\n", i, counts_[i]); + } + } + } + + void Reset() { + for (size_t i = 0; i < N; ++i) { + counts_[i] = 0; + } + } + + private: + uint32_t counts_[N]; +}; + +// Descriptive statistics of a variable (4 moments). Thread-compatible. 
+class Stats { + public: + Stats() { Reset(); } + + void Notify(const float x) { + ++n_; + + min_ = HWY_MIN(min_, x); + max_ = HWY_MAX(max_, x); + + // Logarithmic transform avoids/delays underflow and overflow. + sum_log_ += std::log(static_cast(x)); + + // Online moments. Reference: + // https://www.thinkbrg.com/media/publication/720_McCrary_ImplementingAlgorithms_Whitepaper_20151119_WEB.pdf + const double d = x - m1_; + const double d_div_n = d / static_cast(n_); + const double d2n1_div_n = d * (static_cast(n_) - 1) * d_div_n; + const int64_t n_poly = n_ * n_ - 3 * n_ + 3; + m1_ += d_div_n; + m4_ += d_div_n * + (d_div_n * (d2n1_div_n * static_cast(n_poly) + 6.0 * m2_) - + 4.0 * m3_); + m3_ += d_div_n * (d2n1_div_n * (static_cast(n_) - 2) - 3.0 * m2_); + m2_ += d2n1_div_n; + } + + void Assimilate(const Stats& other); + + int64_t Count() const { return n_; } + + float Min() const { return min_; } + float Max() const { return max_; } + + double GeometricMean() const { + return n_ == 0 ? 0.0 : std::exp(sum_log_ / static_cast(n_)); + } + + double Mean() const { return m1_; } + // Same as Mu2. Assumes n_ is large. + double SampleVariance() const { + return n_ == 0 ? 0.0 : m2_ / static_cast(n_); + } + // Unbiased estimator for population variance even for smaller n_. + double Variance() const { + if (n_ == 0) return 0.0; + if (n_ == 1) return m2_; + return m2_ / static_cast(n_ - 1); + } + double StandardDeviation() const { return std::sqrt(Variance()); } + // Near zero for normal distributions; if positive on a unimodal distribution, + // the right tail is fatter. Assumes n_ is large. + double SampleSkewness() const { + if (ScalarAbs(m2_) < 1E-7) return 0.0; + return m3_ * std::sqrt(static_cast(n_)) / std::pow(m2_, 1.5); + } + // Corrected for bias (same as Wikipedia and Minitab but not Excel). 
+ double Skewness() const { + if (n_ == 0) return 0.0; + const double biased = SampleSkewness(); + const double r = (static_cast(n_) - 1.0) / static_cast(n_); + return biased * std::pow(r, 1.5); + } + // Near zero for normal distributions; smaller values indicate fewer/smaller + // outliers and larger indicates more/larger outliers. Assumes n_ is large. + double SampleKurtosis() const { + if (ScalarAbs(m2_) < 1E-7) return 0.0; + return m4_ * static_cast(n_) / (m2_ * m2_); + } + // Corrected for bias (same as Wikipedia and Minitab but not Excel). + double Kurtosis() const { + if (n_ == 0) return 0.0; + const double biased = SampleKurtosis(); + const double r = (static_cast(n_) - 1.0) / static_cast(n_); + return biased * r * r; + } + + // Central moments, useful for "method of moments"-based parameter estimation + // of a mixture of two Gaussians. Assumes Count() != 0. + double Mu1() const { return m1_; } + double Mu2() const { return m2_ / static_cast(n_); } + double Mu3() const { return m3_ / static_cast(n_); } + double Mu4() const { return m4_ / static_cast(n_); } + + // Which statistics to EXCLUDE in ToString + enum { + kNoCount = 1, + kNoMeanSD = 2, + kNoMinMax = 4, + kNoSkewKurt = 8, + kNoGeomean = 16 + }; + std::string ToString(int exclude = 0) const; + + void Reset() { + n_ = 0; + + min_ = hwy::HighestValue(); + max_ = hwy::LowestValue(); + + sum_log_ = 0.0; + + m1_ = 0.0; + m2_ = 0.0; + m3_ = 0.0; + m4_ = 0.0; + } + + private: + int64_t n_; // signed for faster conversion + safe subtraction + + float min_; + float max_; + + double sum_log_; // for geomean + + // Moments + double m1_; + double m2_; + double m3_; + double m4_; +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_STATS_H_ diff --git a/lib/highway/hwy/targets.cc b/lib/highway/hwy/targets.cc new file mode 100644 index 00000000000..e323891b039 --- /dev/null +++ b/lib/highway/hwy/targets.cc @@ -0,0 +1,835 @@ +// Copyright 2019 Google LLC +// Copyright 2025 Arm Limited and/or its affiliates +// 
SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/targets.h" + +#include +#include + +#include "hwy/base.h" +#include "hwy/detect_targets.h" +#include "hwy/highway.h" +#include "hwy/x86_cpuid.h" + +#if HWY_ARCH_X86 +#include + +#elif (HWY_ARCH_ARM || HWY_ARCH_PPC || HWY_ARCH_S390X || HWY_ARCH_RISCV || \ + HWY_ARCH_LOONGARCH) && \ + HWY_OS_LINUX +// sys/auxv.h does not always include asm/hwcap.h, or define HWCAP*, hence we +// still include this directly. See #1199. 
+#if HWY_HAVE_ASM_HWCAP +#include +#endif +#if HWY_HAVE_AUXV +#include +#endif + +#endif // HWY_ARCH_* + +#if HWY_OS_APPLE +#include +#include +#endif // HWY_OS_APPLE + +namespace hwy { + +#if HWY_OS_APPLE +static HWY_INLINE HWY_MAYBE_UNUSED bool HasCpuFeature( + const char* feature_name) { + int result = 0; + size_t len = sizeof(int); + return (sysctlbyname(feature_name, &result, &len, nullptr, 0) == 0 && + result != 0); +} + +static HWY_INLINE HWY_MAYBE_UNUSED bool ParseU32(const char*& ptr, + uint32_t& parsed_val) { + uint64_t parsed_u64 = 0; + + const char* start_ptr = ptr; + for (char ch; (ch = (*ptr)) != '\0'; ++ptr) { + unsigned digit = static_cast(static_cast(ch)) - + static_cast(static_cast('0')); + if (digit > 9u) { + break; + } + + parsed_u64 = (parsed_u64 * 10u) + digit; + if (parsed_u64 > 0xFFFFFFFFu) { + return false; + } + } + + parsed_val = static_cast(parsed_u64); + return (ptr != start_ptr); +} + +static HWY_INLINE HWY_MAYBE_UNUSED bool IsMacOs12_2OrLater() { + utsname uname_buf; + ZeroBytes(&uname_buf, sizeof(utsname)); + + if ((uname(&uname_buf)) != 0) { + return false; + } + + const char* ptr = uname_buf.release; + if (!ptr) { + return false; + } + + uint32_t major; + uint32_t minor; + if (!ParseU32(ptr, major)) { + return false; + } + + if (*ptr != '.') { + return false; + } + + ++ptr; + if (!ParseU32(ptr, minor)) { + return false; + } + + // We are running on macOS 12.2 or later if the Darwin kernel version is 21.3 + // or later + return (major > 21 || (major == 21 && minor >= 3)); +} +#endif // HWY_OS_APPLE + +#if HWY_ARCH_X86 && HWY_HAVE_RUNTIME_DISPATCH +namespace x86 { + +// Returns the lower 32 bits of extended control register 0. +// Requires CPU support for "OSXSAVE" (see below). 
+static uint32_t ReadXCR0() { +#if HWY_COMPILER_MSVC + return static_cast(_xgetbv(0)); +#else // HWY_COMPILER_MSVC + uint32_t xcr0, xcr0_high; + const uint32_t index = 0; + asm volatile(".byte 0x0F, 0x01, 0xD0" + : "=a"(xcr0), "=d"(xcr0_high) + : "c"(index)); + return xcr0; +#endif // HWY_COMPILER_MSVC +} + +// Arbitrary bit indices indicating which instruction set extensions are +// supported. Use enum to ensure values are distinct. +enum class FeatureIndex : uint32_t { + kSSE = 0, + kSSE2, + kSSE3, + kSSSE3, + + kSSE41, + kSSE42, + kCLMUL, + kAES, + + kAVX, + kAVX2, + kF16C, + kFMA, + kLZCNT, + kBMI, + kBMI2, + + kAVX512F, + kAVX512VL, + kAVX512CD, + kAVX512DQ, + kAVX512BW, + kAVX512FP16, + kAVX512BF16, + + kVNNI, + kVPCLMULQDQ, + kVBMI, + kVBMI2, + kVAES, + kPOPCNTDQ, + kBITALG, + kGFNI, + + kAVX10, + kAPX, + + kSentinel +}; +static_assert(static_cast(FeatureIndex::kSentinel) < 64, + "Too many bits for u64"); + +static HWY_INLINE constexpr uint64_t Bit(FeatureIndex index) { + return 1ull << static_cast(index); +} + +// Returns bit array of FeatureIndex from CPUID feature flags. +static uint64_t FlagsFromCPUID() { + uint64_t flags = 0; // return value + uint32_t abcd[4]; + Cpuid(0, 0, abcd); + const uint32_t max_level = abcd[0]; + + // Standard feature flags + Cpuid(1, 0, abcd); + flags |= IsBitSet(abcd[3], 25) ? Bit(FeatureIndex::kSSE) : 0; + flags |= IsBitSet(abcd[3], 26) ? Bit(FeatureIndex::kSSE2) : 0; + flags |= IsBitSet(abcd[2], 0) ? Bit(FeatureIndex::kSSE3) : 0; + flags |= IsBitSet(abcd[2], 1) ? Bit(FeatureIndex::kCLMUL) : 0; + flags |= IsBitSet(abcd[2], 9) ? Bit(FeatureIndex::kSSSE3) : 0; + flags |= IsBitSet(abcd[2], 12) ? Bit(FeatureIndex::kFMA) : 0; + flags |= IsBitSet(abcd[2], 19) ? Bit(FeatureIndex::kSSE41) : 0; + flags |= IsBitSet(abcd[2], 20) ? Bit(FeatureIndex::kSSE42) : 0; + flags |= IsBitSet(abcd[2], 25) ? Bit(FeatureIndex::kAES) : 0; + flags |= IsBitSet(abcd[2], 28) ? Bit(FeatureIndex::kAVX) : 0; + flags |= IsBitSet(abcd[2], 29) ? 
Bit(FeatureIndex::kF16C) : 0; + + // Extended feature flags + Cpuid(0x80000001U, 0, abcd); + flags |= IsBitSet(abcd[2], 5) ? Bit(FeatureIndex::kLZCNT) : 0; + + // Extended features + if (max_level >= 7) { + Cpuid(7, 0, abcd); + flags |= IsBitSet(abcd[1], 3) ? Bit(FeatureIndex::kBMI) : 0; + flags |= IsBitSet(abcd[1], 5) ? Bit(FeatureIndex::kAVX2) : 0; + flags |= IsBitSet(abcd[1], 8) ? Bit(FeatureIndex::kBMI2) : 0; + + flags |= IsBitSet(abcd[1], 16) ? Bit(FeatureIndex::kAVX512F) : 0; + flags |= IsBitSet(abcd[1], 17) ? Bit(FeatureIndex::kAVX512DQ) : 0; + flags |= IsBitSet(abcd[1], 28) ? Bit(FeatureIndex::kAVX512CD) : 0; + flags |= IsBitSet(abcd[1], 30) ? Bit(FeatureIndex::kAVX512BW) : 0; + flags |= IsBitSet(abcd[1], 31) ? Bit(FeatureIndex::kAVX512VL) : 0; + + flags |= IsBitSet(abcd[2], 1) ? Bit(FeatureIndex::kVBMI) : 0; + flags |= IsBitSet(abcd[2], 6) ? Bit(FeatureIndex::kVBMI2) : 0; + flags |= IsBitSet(abcd[2], 8) ? Bit(FeatureIndex::kGFNI) : 0; + flags |= IsBitSet(abcd[2], 9) ? Bit(FeatureIndex::kVAES) : 0; + flags |= IsBitSet(abcd[2], 10) ? Bit(FeatureIndex::kVPCLMULQDQ) : 0; + flags |= IsBitSet(abcd[2], 11) ? Bit(FeatureIndex::kVNNI) : 0; + flags |= IsBitSet(abcd[2], 12) ? Bit(FeatureIndex::kBITALG) : 0; + flags |= IsBitSet(abcd[2], 14) ? Bit(FeatureIndex::kPOPCNTDQ) : 0; + + flags |= IsBitSet(abcd[3], 23) ? Bit(FeatureIndex::kAVX512FP16) : 0; + + Cpuid(7, 1, abcd); + flags |= IsBitSet(abcd[0], 5) ? Bit(FeatureIndex::kAVX512BF16) : 0; + flags |= IsBitSet(abcd[3], 19) ? Bit(FeatureIndex::kAVX10) : 0; + flags |= IsBitSet(abcd[3], 21) ? Bit(FeatureIndex::kAPX) : 0; + } + + return flags; +} + +// Each Highway target requires a 'group' of multiple features/flags. 
+static constexpr uint64_t kGroupSSE2 = + Bit(FeatureIndex::kSSE) | Bit(FeatureIndex::kSSE2); + +static constexpr uint64_t kGroupSSSE3 = + Bit(FeatureIndex::kSSE3) | Bit(FeatureIndex::kSSSE3) | kGroupSSE2; + +#ifdef HWY_DISABLE_PCLMUL_AES +static constexpr uint64_t kGroupSSE4 = + Bit(FeatureIndex::kSSE41) | Bit(FeatureIndex::kSSE42) | kGroupSSSE3; +#else +static constexpr uint64_t kGroupSSE4 = + Bit(FeatureIndex::kSSE41) | Bit(FeatureIndex::kSSE42) | + Bit(FeatureIndex::kCLMUL) | Bit(FeatureIndex::kAES) | kGroupSSSE3; +#endif // HWY_DISABLE_PCLMUL_AES + +// We normally assume BMI/BMI2/FMA are available if AVX2 is. This allows us to +// use BZHI and (compiler-generated) MULX. However, VirtualBox lacks them +// [https://www.virtualbox.org/ticket/15471]. Thus we provide the option of +// avoiding using and requiring these so AVX2 can still be used. +#ifdef HWY_DISABLE_BMI2_FMA +static constexpr uint64_t kGroupBMI2_FMA = 0; +#else +static constexpr uint64_t kGroupBMI2_FMA = Bit(FeatureIndex::kBMI) | + Bit(FeatureIndex::kBMI2) | + Bit(FeatureIndex::kFMA); +#endif + +#ifdef HWY_DISABLE_F16C +static constexpr uint64_t kGroupF16C = 0; +#else +static constexpr uint64_t kGroupF16C = Bit(FeatureIndex::kF16C); +#endif + +static constexpr uint64_t kGroupAVX2 = + Bit(FeatureIndex::kAVX) | Bit(FeatureIndex::kAVX2) | + Bit(FeatureIndex::kLZCNT) | kGroupBMI2_FMA | kGroupF16C | kGroupSSE4; + +static constexpr uint64_t kGroupAVX3 = + Bit(FeatureIndex::kAVX512F) | Bit(FeatureIndex::kAVX512VL) | + Bit(FeatureIndex::kAVX512DQ) | Bit(FeatureIndex::kAVX512BW) | + Bit(FeatureIndex::kAVX512CD) | kGroupAVX2; + +static constexpr uint64_t kGroupAVX3_DL = + Bit(FeatureIndex::kVNNI) | Bit(FeatureIndex::kVPCLMULQDQ) | + Bit(FeatureIndex::kVBMI) | Bit(FeatureIndex::kVBMI2) | + Bit(FeatureIndex::kVAES) | Bit(FeatureIndex::kPOPCNTDQ) | + Bit(FeatureIndex::kBITALG) | Bit(FeatureIndex::kGFNI) | kGroupAVX3; + +static constexpr uint64_t kGroupAVX3_ZEN4 = + Bit(FeatureIndex::kAVX512BF16) | kGroupAVX3_DL; 
+ +static constexpr uint64_t kGroupAVX3_SPR = + Bit(FeatureIndex::kAVX512FP16) | kGroupAVX3_ZEN4; + +static constexpr uint64_t kGroupAVX10 = + Bit(FeatureIndex::kAVX10) | Bit(FeatureIndex::kAPX) | + Bit(FeatureIndex::kVPCLMULQDQ) | Bit(FeatureIndex::kVAES) | + Bit(FeatureIndex::kGFNI) | kGroupAVX2; + +static int64_t DetectTargets() { + int64_t bits = 0; // return value of supported targets. + HWY_IF_CONSTEXPR(HWY_ARCH_X86_64) { + bits |= HWY_SSE2; // always present in x64 + } + + const uint64_t flags = FlagsFromCPUID(); + // Set target bit(s) if all their group's flags are all set. + if ((flags & kGroupAVX3_SPR) == kGroupAVX3_SPR) { + bits |= HWY_AVX3_SPR; + } + if ((flags & kGroupAVX3_DL) == kGroupAVX3_DL) { + bits |= HWY_AVX3_DL; + } + if ((flags & kGroupAVX3) == kGroupAVX3) { + bits |= HWY_AVX3; + } + if ((flags & kGroupAVX2) == kGroupAVX2) { + bits |= HWY_AVX2; + } + if ((flags & kGroupSSE4) == kGroupSSE4) { + bits |= HWY_SSE4; + } + if ((flags & kGroupSSSE3) == kGroupSSSE3) { + bits |= HWY_SSSE3; + } + HWY_IF_CONSTEXPR(HWY_ARCH_X86_32) { + if ((flags & kGroupSSE2) == kGroupSSE2) { + bits |= HWY_SSE2; + } + } + + uint32_t abcd[4]; + + if ((flags & kGroupAVX10) == kGroupAVX10) { + Cpuid(0x24, 0, abcd); + + // AVX10 version is in lower 8 bits of abcd[1] + const uint32_t avx10_ver = abcd[1] & 0xFFu; + + // 512-bit vectors are supported if avx10_ver >= 1 is true and bit 18 of + // abcd[1] is set + const bool has_avx10_with_512bit_vectors = + (avx10_ver >= 1) && IsBitSet(abcd[1], 18); + + if (has_avx10_with_512bit_vectors) { + // AVX10.1 or later with support for 512-bit vectors implies support for + // the AVX3/AVX3_DL/AVX3_SPR targets + bits |= (HWY_AVX3_SPR | HWY_AVX3_DL | HWY_AVX3); + + if (avx10_ver >= 2) { + // AVX10.2 is supported if avx10_ver >= 2 is true + bits |= HWY_AVX10_2; + } + } + } + + // Clear AVX2/AVX3 bits if the CPU or OS does not support XSAVE - otherwise, + // YMM/ZMM registers are not preserved across context switches. 
+ + // The lower 128 bits of XMM0-XMM15 are guaranteed to be preserved across + // context switches on x86_64 + + // The following OS's are known to preserve the lower 128 bits of XMM + // registers across context switches on x86 CPUs that support SSE (even in + // 32-bit mode): + // - Windows 2000 or later + // - Linux 2.4.0 or later + // - Mac OS X 10.4 or later + // - FreeBSD 4.4 or later + // - NetBSD 1.6 or later + // - OpenBSD 3.5 or later + // - UnixWare 7 Release 7.1.1 or later + // - Solaris 9 4/04 or later + + Cpuid(1, 0, abcd); + const bool has_xsave = IsBitSet(abcd[2], 26); + const bool has_osxsave = IsBitSet(abcd[2], 27); + constexpr int64_t min_avx2 = HWY_AVX2 | (HWY_AVX2 - 1); + + if (has_xsave && has_osxsave) { +#if HWY_OS_APPLE + // On macOS, check for AVX3 XSAVE support by checking that we are running on + // macOS 12.2 or later and HasCpuFeature("hw.optional.avx512f") returns true + + // There is a bug in macOS 12.1 or earlier that can cause ZMM16-ZMM31, the + // upper 256 bits of the ZMM registers, and K0-K7 (the AVX512 mask + // registers) to not be properly preserved across a context switch on + // macOS 12.1 or earlier. + + // This bug on macOS 12.1 or earlier on x86_64 CPU's with AVX3 support is + // described at + // https://community.intel.com/t5/Software-Tuning-Performance/MacOS-Darwin-kernel-bug-clobbers-AVX-512-opmask-register-state/m-p/1327259, + // https://github.com/golang/go/issues/49233, and + // https://github.com/simdutf/simdutf/pull/236. + + // In addition to the bug that is there on macOS 12.1 or earlier, bits 5, 6, + // and 7 can be set to 0 on x86_64 CPUs with AVX3 support on macOS until + // the first AVX512 instruction is executed as macOS only preserves + // ZMM16-ZMM31, the upper 256 bits of the ZMM registers, and K0-K7 across a + // context switch on threads that have executed an AVX512 instruction. 
+ + // Checking for AVX3 XSAVE support on macOS using + // HasCpuFeature("hw.optional.avx512f") avoids false negative results + // on x86_64 CPU's that have AVX3 support. + const bool have_avx3_xsave_support = + IsMacOs12_2OrLater() && HasCpuFeature("hw.optional.avx512f"); +#endif + + const uint32_t xcr0 = ReadXCR0(); + constexpr int64_t min_avx3 = HWY_AVX3 | (HWY_AVX3 - 1); + // XMM/YMM + if (!IsBitSet(xcr0, 1) || !IsBitSet(xcr0, 2)) { + // Clear the AVX2/AVX3 bits if XMM/YMM XSAVE is not enabled + bits &= ~min_avx2; + } + +#if !HWY_OS_APPLE + // On OS's other than macOS, check for AVX3 XSAVE support by checking that + // bits 5, 6, and 7 of XCR0 are set. + const bool have_avx3_xsave_support = + IsBitSet(xcr0, 5) && IsBitSet(xcr0, 6) && IsBitSet(xcr0, 7); +#endif + + // opmask, ZMM lo/hi + if (!have_avx3_xsave_support) { + bits &= ~min_avx3; + } + } else { // !has_xsave || !has_osxsave + // Clear the AVX2/AVX3 bits if the CPU or OS does not support XSAVE + bits &= ~min_avx2; + } + + // This is mainly to work around the slow Zen4 CompressStore. It's unclear + // whether subsequent AMD models will be affected; assume yes. + if ((bits & HWY_AVX3_DL) && (flags & kGroupAVX3_ZEN4) == kGroupAVX3_ZEN4 && + IsAMD()) { + bits |= HWY_AVX3_ZEN4; + } + + return bits; +} + +} // namespace x86 +#elif HWY_ARCH_ARM && HWY_HAVE_RUNTIME_DISPATCH +namespace arm { + +#if HWY_ARCH_ARM_A64 && !HWY_OS_APPLE && \ + (HWY_COMPILER_GCC || HWY_COMPILER_CLANG) && \ + ((HWY_TARGETS & HWY_ALL_SVE) != 0) +HWY_PUSH_ATTRIBUTES("+sve") +static int64_t DetectAdditionalSveTargets(int64_t detected_targets) { + uint64_t sve_vec_len; + + // Use inline assembly instead of svcntb_pat(SV_ALL) as GCC or Clang might + // possibly optimize a svcntb_pat(SV_ALL) call to a constant if the + // -msve-vector-bits option is specified + asm("cntb %0" : "=r"(sve_vec_len)::); + + return ((sve_vec_len == 32) + ? HWY_SVE_256 + : (((detected_targets & HWY_SVE2) != 0 && sve_vec_len == 16) + ? 
HWY_SVE2_128 + : 0)); +} +HWY_POP_ATTRIBUTES +#endif + +static int64_t DetectTargets() { + int64_t bits = 0; // return value of supported targets. + + using CapBits = unsigned long; // NOLINT +#if HWY_OS_APPLE + const CapBits hw = 0UL; +#else + // For Android, this has been supported since API 20 (2014). + const CapBits hw = getauxval(AT_HWCAP); +#endif + (void)hw; + +#if HWY_ARCH_ARM_A64 + bits |= HWY_NEON_WITHOUT_AES; // aarch64 always has NEON and VFPv4.. + +#if HWY_OS_APPLE + if (HasCpuFeature("hw.optional.arm.FEAT_AES")) { + bits |= HWY_NEON; + + // Some macOS versions report AdvSIMD_HPFPCvt under a different key. + // Check both known variants for compatibility. + if ((HasCpuFeature("hw.optional.AdvSIMD_HPFPCvt") || + HasCpuFeature("hw.optional.arm.AdvSIMD_HPFPCvt")) && + HasCpuFeature("hw.optional.arm.FEAT_DotProd") && + HasCpuFeature("hw.optional.arm.FEAT_BF16") && + HasCpuFeature("hw.optional.arm.FEAT_I8MM")) { + bits |= HWY_NEON_BF16; + } + } +#else // !HWY_OS_APPLE + // .. but not necessarily AES, which is required for HWY_NEON. 
+#if defined(HWCAP_AES) + if (hw & HWCAP_AES) { + bits |= HWY_NEON; + +#if defined(HWCAP_ASIMDHP) && defined(HWCAP_ASIMDDP) && defined(HWCAP2_BF16) + const CapBits hw2 = getauxval(AT_HWCAP2); + constexpr CapBits kGroupF16Dot = HWCAP_ASIMDHP | HWCAP_ASIMDDP; + constexpr CapBits kGroupBF16 = HWCAP2_BF16; + if ((hw & kGroupF16Dot) == kGroupF16Dot && + (hw2 & kGroupBF16) == kGroupBF16) { + bits |= HWY_NEON_BF16; + } +#endif // HWCAP_ASIMDHP && HWCAP_ASIMDDP && HWCAP2_BF16 + } +#endif // HWCAP_AES + +#if defined(HWCAP_SVE) + if (hw & HWCAP_SVE) { + bits |= HWY_SVE; + } +#endif + +#ifndef HWCAP2_SVE2 +#define HWCAP2_SVE2 (1 << 1) +#endif +#ifndef HWCAP2_SVEAES +#define HWCAP2_SVEAES (1 << 2) +#endif +#ifndef HWCAP2_SVEI8MM +#define HWCAP2_SVEI8MM (1 << 9) +#endif +#ifndef HWCAP2_SVEBF16 +#define HWCAP2_SVEBF16 (1 << 12) +#endif + + constexpr CapBits kGroupSVE2 = HWCAP2_SVE2 | HWCAP2_SVEAES; + const CapBits hw2 = getauxval(AT_HWCAP2); + if ((hw2 & kGroupSVE2) == kGroupSVE2) { + bits |= HWY_SVE2; + } + +#if (HWY_COMPILER_GCC || HWY_COMPILER_CLANG) && \ + ((HWY_TARGETS & HWY_ALL_SVE) != 0) + if ((bits & HWY_ALL_SVE) != 0) { + bits |= DetectAdditionalSveTargets(bits); + + // SVE2_128 implies I8MM and BF16, hence remove it if they are not present. + constexpr CapBits kGroupSVE2_128 = HWCAP2_SVEI8MM | HWCAP2_SVEBF16; + if ((hw2 & kGroupSVE2_128) != kGroupSVE2_128) { + bits &= ~HWY_SVE2_128; + } + } +#endif // (HWY_COMPILER_GCC || HWY_COMPILER_CLANG) && + // ((HWY_TARGETS & HWY_ALL_SVE) != 0) + +#endif // HWY_OS_APPLE + +#else // !HWY_ARCH_ARM_A64 + +// Some old auxv.h / hwcap.h do not define these. If not, treat as unsupported. +#if defined(HWCAP_NEON) && defined(HWCAP_VFPv4) + if ((hw & HWCAP_NEON) && (hw & HWCAP_VFPv4)) { + bits |= HWY_NEON_WITHOUT_AES; + } +#endif + + // aarch32 would check getauxval(AT_HWCAP2) & HWCAP2_AES, but we do not yet + // support that platform, and Armv7 lacks AES entirely. 
Because HWY_NEON + // requires native AES instructions, we do not enable that target here. + +#endif // HWY_ARCH_ARM_A64 + return bits; +} +} // namespace arm +#elif HWY_ARCH_PPC && HWY_HAVE_RUNTIME_DISPATCH +namespace ppc { + +#ifndef PPC_FEATURE_HAS_ALTIVEC +#define PPC_FEATURE_HAS_ALTIVEC 0x10000000 +#endif + +#ifndef PPC_FEATURE_HAS_VSX +#define PPC_FEATURE_HAS_VSX 0x00000080 +#endif + +#ifndef PPC_FEATURE2_ARCH_2_07 +#define PPC_FEATURE2_ARCH_2_07 0x80000000 +#endif + +#ifndef PPC_FEATURE2_VEC_CRYPTO +#define PPC_FEATURE2_VEC_CRYPTO 0x02000000 +#endif + +#ifndef PPC_FEATURE2_ARCH_3_00 +#define PPC_FEATURE2_ARCH_3_00 0x00800000 +#endif + +#ifndef PPC_FEATURE2_ARCH_3_1 +#define PPC_FEATURE2_ARCH_3_1 0x00040000 +#endif + +using CapBits = unsigned long; // NOLINT + +// For AT_HWCAP, the others are for AT_HWCAP2 +static constexpr CapBits kGroupVSX = + PPC_FEATURE_HAS_ALTIVEC | PPC_FEATURE_HAS_VSX; + +#if defined(HWY_DISABLE_PPC8_CRYPTO) +static constexpr CapBits kGroupPPC8 = PPC_FEATURE2_ARCH_2_07; +#else +static constexpr CapBits kGroupPPC8 = + PPC_FEATURE2_ARCH_2_07 | PPC_FEATURE2_VEC_CRYPTO; +#endif +static constexpr CapBits kGroupPPC9 = kGroupPPC8 | PPC_FEATURE2_ARCH_3_00; +static constexpr CapBits kGroupPPC10 = kGroupPPC9 | PPC_FEATURE2_ARCH_3_1; + +static int64_t DetectTargets() { + int64_t bits = 0; // return value of supported targets. 
+ +#if defined(AT_HWCAP) && defined(AT_HWCAP2) + const CapBits hw = getauxval(AT_HWCAP); + + if ((hw & kGroupVSX) == kGroupVSX) { + const CapBits hw2 = getauxval(AT_HWCAP2); + if ((hw2 & kGroupPPC8) == kGroupPPC8) { + bits |= HWY_PPC8; + } + if ((hw2 & kGroupPPC9) == kGroupPPC9) { + bits |= HWY_PPC9; + } + if ((hw2 & kGroupPPC10) == kGroupPPC10) { + bits |= HWY_PPC10; + } + } // VSX +#endif // defined(AT_HWCAP) && defined(AT_HWCAP2) + + return bits; +} +} // namespace ppc +#elif HWY_ARCH_S390X && HWY_HAVE_RUNTIME_DISPATCH +namespace s390x { + +#ifndef HWCAP_S390_VX +#define HWCAP_S390_VX 2048 +#endif + +#ifndef HWCAP_S390_VXE +#define HWCAP_S390_VXE 8192 +#endif + +#ifndef HWCAP_S390_VXRS_EXT2 +#define HWCAP_S390_VXRS_EXT2 32768 +#endif + +using CapBits = unsigned long; // NOLINT + +static constexpr CapBits kGroupZ14 = HWCAP_S390_VX | HWCAP_S390_VXE; +static constexpr CapBits kGroupZ15 = + HWCAP_S390_VX | HWCAP_S390_VXE | HWCAP_S390_VXRS_EXT2; + +static int64_t DetectTargets() { + int64_t bits = 0; + +#if defined(AT_HWCAP) + const CapBits hw = getauxval(AT_HWCAP); + + if ((hw & kGroupZ14) == kGroupZ14) { + bits |= HWY_Z14; + } + + if ((hw & kGroupZ15) == kGroupZ15) { + bits |= HWY_Z15; + } +#endif + + return bits; +} +} // namespace s390x +#elif HWY_ARCH_RISCV && HWY_HAVE_RUNTIME_DISPATCH +namespace rvv { + +#ifndef HWCAP_RVV +#define COMPAT_HWCAP_ISA_V (1 << ('V' - 'A')) +#endif + +using CapBits = unsigned long; // NOLINT + +static int64_t DetectTargets() { + int64_t bits = 0; + + const CapBits hw = getauxval(AT_HWCAP); + + if ((hw & COMPAT_HWCAP_ISA_V) == COMPAT_HWCAP_ISA_V) { + size_t e8m1_vec_len; +#if HWY_ARCH_RISCV_64 + int64_t vtype_reg_val; +#else + int32_t vtype_reg_val; +#endif + + // Check that a vuint8m1_t vector is at least 16 bytes and that tail + // agnostic and mask agnostic mode are supported + asm volatile( + // Avoid compiler error on GCC or Clang if -march=rv64gcv1p0 or + // -march=rv32gcv1p0 option is not specified on the command line + 
".option push\n\t" + ".option arch, +v\n\t" + "vsetvli %0, zero, e8, m1, ta, ma\n\t" + "csrr %1, vtype\n\t" + ".option pop" + : "=r"(e8m1_vec_len), "=r"(vtype_reg_val)); + + // The RVV target is supported if the VILL bit of VTYPE (the MSB bit of + // VTYPE) is not set and the length of a vuint8m1_t vector is at least 16 + // bytes + if (vtype_reg_val >= 0 && e8m1_vec_len >= 16) { + bits |= HWY_RVV; + } + } + + return bits; +} +} // namespace rvv +#elif HWY_ARCH_LOONGARCH && HWY_HAVE_RUNTIME_DISPATCH + +namespace loongarch { + +#ifndef LA_HWCAP_LSX +#define LA_HWCAP_LSX (1u << 4) +#endif +#ifndef LA_HWCAP_LASX +#define LA_HWCAP_LASX (1u << 5) +#endif + +using CapBits = unsigned long; // NOLINT + +static int64_t DetectTargets() { + int64_t bits = 0; + const CapBits hw = getauxval(AT_HWCAP); + if (hw & LA_HWCAP_LSX) bits |= HWY_LSX; + if (hw & LA_HWCAP_LASX) bits |= HWY_LASX; + return bits; +} +} // namespace loongarch +#endif // HWY_ARCH_* + +// Returns targets supported by the CPU, independently of DisableTargets. +// Factored out of SupportedTargets to make its structure more obvious. Note +// that x86 CPUID may take several hundred cycles. +static int64_t DetectTargets() { + // Apps will use only one of these (the default is EMU128), but compile flags + // for this TU may differ from that of the app, so allow both. + int64_t bits = HWY_SCALAR | HWY_EMU128; + +#if HWY_ARCH_X86 && HWY_HAVE_RUNTIME_DISPATCH + bits |= x86::DetectTargets(); +#elif HWY_ARCH_ARM && HWY_HAVE_RUNTIME_DISPATCH + bits |= arm::DetectTargets(); +#elif HWY_ARCH_PPC && HWY_HAVE_RUNTIME_DISPATCH + bits |= ppc::DetectTargets(); +#elif HWY_ARCH_S390X && HWY_HAVE_RUNTIME_DISPATCH + bits |= s390x::DetectTargets(); +#elif HWY_ARCH_RISCV && HWY_HAVE_RUNTIME_DISPATCH + bits |= rvv::DetectTargets(); +#elif HWY_ARCH_LOONGARCH && HWY_HAVE_RUNTIME_DISPATCH + bits |= loongarch::DetectTargets(); + +#else + // TODO(janwas): detect support for WASM. 
+ // This file is typically compiled without HWY_IS_TEST, but targets_test has + // it set, and will expect all of its HWY_TARGETS (= all attainable) to be + // supported. + bits |= HWY_ENABLED_BASELINE; +#endif // HWY_ARCH_* + + if ((bits & HWY_ENABLED_BASELINE) != HWY_ENABLED_BASELINE) { + const uint64_t bits_u = static_cast(bits); + const uint64_t enabled = static_cast(HWY_ENABLED_BASELINE); + HWY_WARN("CPU supports 0x%08x%08x, software requires 0x%08x%08x\n", + static_cast(bits_u >> 32), + static_cast(bits_u & 0xFFFFFFFF), + static_cast(enabled >> 32), + static_cast(enabled & 0xFFFFFFFF)); + } + + return bits; +} + +// When running tests, this value can be set to the mocked supported targets +// mask. Only written to from a single thread before the test starts. +static int64_t supported_targets_for_test_ = 0; + +// Mask of targets disabled at runtime with DisableTargets. +static int64_t supported_mask_ = LimitsMax(); + +HWY_DLLEXPORT void DisableTargets(int64_t disabled_targets) { + supported_mask_ = static_cast(~disabled_targets); + // This will take effect on the next call to SupportedTargets, which is + // called right before GetChosenTarget::Update. However, calling Update here + // would make it appear that HWY_DYNAMIC_DISPATCH was called, which we want + // to check in tests. We instead de-initialize such that the next + // HWY_DYNAMIC_DISPATCH calls GetChosenTarget::Update via FunctionCache. + GetChosenTarget().DeInit(); +} + +HWY_DLLEXPORT void SetSupportedTargetsForTest(int64_t targets) { + supported_targets_for_test_ = targets; + GetChosenTarget().DeInit(); // see comment above +} + +HWY_DLLEXPORT int64_t SupportedTargets() { + int64_t targets = supported_targets_for_test_; + if (HWY_LIKELY(targets == 0)) { + // Mock not active. Re-detect instead of caching just in case we're on a + // heterogeneous ISA (also requires some app support to pin threads). 
This + // is only reached on the first HWY_DYNAMIC_DISPATCH or after each call to + // DisableTargets or SetSupportedTargetsForTest. + targets = DetectTargets(); + + // VectorBytes invokes HWY_DYNAMIC_DISPATCH. To prevent infinite recursion, + // first set up ChosenTarget. No need to Update() again afterwards with the + // final targets - that will be done by a caller of this function. + GetChosenTarget().Update(targets); + } + + targets &= supported_mask_; + return targets == 0 ? HWY_STATIC_TARGET : targets; +} + +HWY_DLLEXPORT ChosenTarget& GetChosenTarget() { + static ChosenTarget chosen_target; + return chosen_target; +} + +} // namespace hwy diff --git a/lib/highway/hwy/targets.h b/lib/highway/hwy/targets.h new file mode 100644 index 00000000000..b4ea2850a21 --- /dev/null +++ b/lib/highway/hwy/targets.h @@ -0,0 +1,388 @@ +// Copyright 2020 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_TARGETS_H_ +#define HIGHWAY_HWY_TARGETS_H_ + +// Allows opting out of C++ standard library usage, which is not available in +// some Compiler Explorer environments. +#ifndef HWY_NO_LIBCXX +#include +#endif + +// For SIMD module implementations and their callers. Defines which targets to +// generate and call. 
+ +#include "hwy/base.h" +#include "hwy/detect_targets.h" +#include "hwy/highway_export.h" + +#if !defined(HWY_NO_LIBCXX) +#include +#endif + +namespace hwy { + +// Returns bitfield of enabled targets that are supported on this CPU; there is +// always at least one such target, hence the return value is never 0. The +// targets returned may change after calling DisableTargets. This function is +// always defined, but the HWY_SUPPORTED_TARGETS wrapper may allow eliding +// calls to it if there is only a single target enabled. +HWY_DLLEXPORT int64_t SupportedTargets(); + +// Evaluates to a function call, or literal if there is a single target. +#if (HWY_TARGETS & (HWY_TARGETS - 1)) == 0 +#define HWY_SUPPORTED_TARGETS HWY_TARGETS +#else +#define HWY_SUPPORTED_TARGETS hwy::SupportedTargets() +#endif + +// Subsequent SupportedTargets will not return targets whose bit(s) are set in +// `disabled_targets`. Exception: if SupportedTargets would return 0, it will +// instead return HWY_STATIC_TARGET (there must always be one target to call). +// +// This function is useful for disabling targets known to be buggy, or if the +// best available target is undesirable (perhaps due to throttling or memory +// bandwidth limitations). Use SetSupportedTargetsForTest instead of this +// function for iteratively enabling specific targets for testing. +HWY_DLLEXPORT void DisableTargets(int64_t disabled_targets); + +// Subsequent SupportedTargets will return the given set of targets, except +// those disabled via DisableTargets. Call with a mask of 0 to disable the mock +// and return to the normal SupportedTargets behavior. Used to run tests for +// all targets. +HWY_DLLEXPORT void SetSupportedTargetsForTest(int64_t targets); + +#ifndef HWY_NO_LIBCXX + +// Return the list of targets in HWY_TARGETS supported by the CPU as a list of +// individual HWY_* target macros such as HWY_SCALAR or HWY_NEON. This list +// is affected by the current SetSupportedTargetsForTest() mock if any. 
+HWY_INLINE std::vector SupportedAndGeneratedTargets() { + std::vector ret; + for (int64_t targets = SupportedTargets() & HWY_TARGETS; targets != 0; + targets = targets & (targets - 1)) { + int64_t current_target = targets & ~(targets - 1); + ret.push_back(current_target); + } + return ret; +} + +#endif // HWY_NO_LIBCXX + +// Returns a string that satisfies gtest IsValidParamName(). No longer report +// targets as "Unknown" if they are for a different architecture, because some +// users unconditionally disable targets and we want to see which. +static inline HWY_MAYBE_UNUSED const char* TargetName(int64_t target) { + switch (target) { + case HWY_EMU128: + return "EMU128"; + case HWY_SCALAR: + return "SCALAR"; + + // X86 + case HWY_SSE2: + return "SSE2"; + case HWY_SSSE3: + return "SSSE3"; + case HWY_SSE4: + return "SSE4"; + case HWY_AVX2: + return "AVX2"; + case HWY_AVX3: + return "AVX3"; + case HWY_AVX3_DL: + return "AVX3_DL"; + case HWY_AVX3_ZEN4: + return "AVX3_ZEN4"; + case HWY_AVX3_SPR: + return "AVX3_SPR"; + case HWY_AVX10_2: + return "AVX10_2"; + + // ARM + case HWY_SVE2_128: + return "SVE2_128"; + case HWY_SVE_256: + return "SVE_256"; + case HWY_SVE2: + return "SVE2"; + case HWY_SVE: + return "SVE"; + case HWY_NEON_BF16: + return "NEON_BF16"; + case HWY_NEON: + return "NEON"; + case HWY_NEON_WITHOUT_AES: + return "NEON_WITHOUT_AES"; + + // PPC + case HWY_PPC8: + return "PPC8"; + case HWY_PPC9: + return "PPC9"; + case HWY_PPC10: + return "PPC10"; + + // S390X + case HWY_Z14: + return "Z14"; + case HWY_Z15: + return "Z15"; + + // WASM + case HWY_WASM: + return "WASM"; + case HWY_WASM_EMU256: + return "WASM_EMU256"; + + // RISCV + case HWY_RVV: + return "RVV"; + + // LOONGARCH + case HWY_LSX: + return "LSX"; + case HWY_LASX: + return "LASX"; + } + + return "Unknown"; +} + +// Invokes VISITOR(TARGET, NAMESPACE) for all enabled targets. Alphabetic order. 
// The entries below are kept in ASCII order of the target name. Each
// HWY_VISIT_<TARGET> helper is defined outside this header; presumably it
// expands to nothing when that target is not enabled - TODO confirm against
// its definition.
#define HWY_VISIT_TARGETS(VISITOR) \
  HWY_VISIT_AVX10_2(VISITOR) \
  HWY_VISIT_AVX2(VISITOR) \
  HWY_VISIT_AVX3(VISITOR) \
  HWY_VISIT_AVX3_DL(VISITOR) \
  HWY_VISIT_AVX3_SPR(VISITOR) \
  HWY_VISIT_AVX3_ZEN4(VISITOR) \
  HWY_VISIT_FALLBACK(VISITOR) \
  HWY_VISIT_LASX(VISITOR) \
  HWY_VISIT_LSX(VISITOR) \
  HWY_VISIT_NEON(VISITOR) \
  HWY_VISIT_NEON_BF16(VISITOR) \
  HWY_VISIT_NEON_WITHOUT_AES(VISITOR) \
  HWY_VISIT_PPC10(VISITOR) \
  HWY_VISIT_PPC8(VISITOR) \
  HWY_VISIT_PPC9(VISITOR) \
  HWY_VISIT_RVV(VISITOR) \
  HWY_VISIT_SSE2(VISITOR) \
  HWY_VISIT_SSE4(VISITOR) \
  HWY_VISIT_SSSE3(VISITOR) \
  HWY_VISIT_SVE(VISITOR) \
  HWY_VISIT_SVE2(VISITOR) \
  HWY_VISIT_SVE2_128(VISITOR) \
  HWY_VISIT_SVE_256(VISITOR) \
  HWY_VISIT_WASM(VISITOR) \
  HWY_VISIT_WASM_EMU256(VISITOR) \
  HWY_VISIT_Z14(VISITOR) \
  HWY_VISIT_Z15(VISITOR)

// The maximum number of dynamic targets on any architecture is defined by
// HWY_MAX_DYNAMIC_TARGETS and depends on the arch.

// For the ChosenTarget mask and index we use a different bit arrangement than
// in the HWY_TARGETS mask. Only the targets involved in the current
// architecture are used in this mask, and therefore only the least significant
// (HWY_MAX_DYNAMIC_TARGETS + 2) bits of the int64_t mask are used. The least
// significant bit is set when the mask is not initialized, the next
// HWY_MAX_DYNAMIC_TARGETS more significant bits are a range of bits from the
// HWY_TARGETS or SupportedTargets() mask for the given architecture shifted to
// that position and the next more significant bit is used for HWY_SCALAR (if
// HWY_COMPILE_ONLY_SCALAR is defined) or HWY_EMU128. Because of this we need to
// define equivalent values for HWY_TARGETS in this representation.
// This mask representation allows using ctz() on this mask to obtain a small
// number that's used as an index of the table for dynamic dispatch.
In this +// way the first entry is used when the mask is uninitialized, the following +// HWY_MAX_DYNAMIC_TARGETS are for dynamic dispatch and the last one is for +// scalar. + +// The HWY_SCALAR/HWY_EMU128 bit in the ChosenTarget mask format. +#define HWY_CHOSEN_TARGET_MASK_SCALAR (1LL << (HWY_MAX_DYNAMIC_TARGETS + 1)) + +// Converts from a HWY_TARGETS mask to a ChosenTarget mask format for the +// current architecture. +#define HWY_CHOSEN_TARGET_SHIFT(X) \ + ((((X) >> (HWY_HIGHEST_TARGET_BIT + 1 - HWY_MAX_DYNAMIC_TARGETS)) & \ + ((1LL << HWY_MAX_DYNAMIC_TARGETS) - 1)) \ + << 1) + +// The HWY_TARGETS mask in the ChosenTarget mask format. +#define HWY_CHOSEN_TARGET_MASK_TARGETS \ + (HWY_CHOSEN_TARGET_SHIFT(HWY_TARGETS) | HWY_CHOSEN_TARGET_MASK_SCALAR | 1LL) + +#if HWY_ARCH_X86 +// Maximum number of dynamic targets, changing this value is an ABI incompatible +// change +#define HWY_MAX_DYNAMIC_TARGETS 15 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_X86 +// These must match the order in which the HWY_TARGETS are defined +// starting by the least significant (HWY_HIGHEST_TARGET_BIT + 1 - +// HWY_MAX_DYNAMIC_TARGETS) bit. This list must contain exactly +// HWY_MAX_DYNAMIC_TARGETS elements and does not include SCALAR. The first entry +// corresponds to the best target. Don't include a "," at the end of the list. +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_AVX10_2(func_name), /* AVX10_2 */ \ + HWY_CHOOSE_AVX3_SPR(func_name), /* AVX3_SPR */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_AVX3_ZEN4(func_name), /* AVX3_ZEN4 */ \ + HWY_CHOOSE_AVX3_DL(func_name), /* AVX3_DL */ \ + HWY_CHOOSE_AVX3(func_name), /* AVX3 */ \ + HWY_CHOOSE_AVX2(func_name), /* AVX2 */ \ + nullptr, /* AVX */ \ + HWY_CHOOSE_SSE4(func_name), /* SSE4 */ \ + HWY_CHOOSE_SSSE3(func_name), /* SSSE3 */ \ + nullptr, /* reserved - SSE3? 
*/ \ + HWY_CHOOSE_SSE2(func_name) /* SSE2 */ + +#elif HWY_ARCH_ARM +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 15 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_ARM +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_SVE2_128(func_name), /* SVE2 128-bit */ \ + HWY_CHOOSE_SVE_256(func_name), /* SVE 256-bit */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_SVE2(func_name), /* SVE2 */ \ + HWY_CHOOSE_SVE(func_name), /* SVE */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_NEON_BF16(func_name), /* NEON + f16/dot/bf16 */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_NEON(func_name), /* NEON */ \ + HWY_CHOOSE_NEON_WITHOUT_AES(func_name) /* NEON without AES */ + +#elif HWY_ARCH_RISCV +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 9 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_RVV +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_RVV(func_name), /* RVV */ \ + nullptr /* reserved */ + +#elif HWY_ARCH_PPC || HWY_ARCH_S390X +// See HWY_ARCH_X86 above for details. +#define HWY_MAX_DYNAMIC_TARGETS 9 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_PPC +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_PPC10(func_name), /* PPC10 */ \ + HWY_CHOOSE_PPC9(func_name), /* PPC9 */ \ + HWY_CHOOSE_PPC8(func_name), /* PPC8 */ \ + HWY_CHOOSE_Z15(func_name), /* Z15 */ \ + HWY_CHOOSE_Z14(func_name) /* Z14 */ + +#elif HWY_ARCH_WASM +// See HWY_ARCH_X86 above for details. 
+#define HWY_MAX_DYNAMIC_TARGETS 9 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_WASM +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + nullptr, /* reserved */ \ + HWY_CHOOSE_WASM_EMU256(func_name), /* WASM_EMU256 */ \ + HWY_CHOOSE_WASM(func_name), /* WASM */ \ + nullptr /* reserved */ + +#elif HWY_ARCH_LOONGARCH +#define HWY_MAX_DYNAMIC_TARGETS 3 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_LOONGARCH +#define HWY_CHOOSE_TARGET_LIST(func_name) \ + nullptr, /* reserved */ \ + HWY_CHOOSE_LASX(func_name), /* LASX */ \ + HWY_CHOOSE_LSX(func_name) /* LSX */ + +#else +// Unknown architecture, will use HWY_SCALAR without dynamic dispatch, though +// still creating single-entry tables in HWY_EXPORT to ensure portability. +#define HWY_MAX_DYNAMIC_TARGETS 1 +#define HWY_HIGHEST_TARGET_BIT HWY_HIGHEST_TARGET_BIT_SCALAR +#endif + +// Bitfield of supported and enabled targets. The format differs from that of +// HWY_TARGETS; the lowest bit governs the first function pointer (which is +// special in that it calls FunctionCache, then Update, then dispatches to the +// actual implementation) in the tables created by HWY_EXPORT. Monostate (see +// GetChosenTarget), thread-safe except on RVV. +struct ChosenTarget { + public: + // Reset bits according to `targets` (typically the return value of + // SupportedTargets()). Postcondition: IsInitialized() == true. + void Update(int64_t targets) { + // These are `targets` shifted downwards, see above. Also include SCALAR + // (corresponds to the last entry in the function table) as fallback. + StoreMask(HWY_CHOSEN_TARGET_SHIFT(targets) | HWY_CHOSEN_TARGET_MASK_SCALAR); + } + + // Reset to the uninitialized state, so that FunctionCache will call Update + // during the next HWY_DYNAMIC_DISPATCH, and IsInitialized returns false. 
+ void DeInit() { StoreMask(1); } + + // Whether Update was called. This indicates whether any HWY_DYNAMIC_DISPATCH + // function was called, which we check in tests. + bool IsInitialized() const { return LoadMask() != 1; } + + // Return the index in the dynamic dispatch table to be used by the current + // CPU. Note that this method must be in the header file so it uses the value + // of HWY_CHOSEN_TARGET_MASK_TARGETS defined in the translation unit that + // calls it, which may be different from others. This means we only enable + // those targets that were actually compiled in this module. + size_t HWY_INLINE GetIndex() const { + return hwy::Num0BitsBelowLS1Bit_Nonzero64( + static_cast(LoadMask() & HWY_CHOSEN_TARGET_MASK_TARGETS)); + } + + private: +#if defined(HWY_NO_LIBCXX) + int64_t LoadMask() const { return mask_; } + void StoreMask(int64_t mask) { mask_ = mask; } + + int64_t mask_{1}; // Initialized to 1 so GetIndex() returns 0. +#else + int64_t LoadMask() const { return mask_.load(); } + void StoreMask(int64_t mask) { mask_.store(mask); } + + std::atomic mask_{1}; // Initialized to 1 so GetIndex() returns 0. +#endif // HWY_ARCH_RISCV +}; + +// For internal use (e.g. by FunctionCache and DisableTargets). +HWY_DLLEXPORT ChosenTarget& GetChosenTarget(); + +} // namespace hwy + +#endif // HIGHWAY_HWY_TARGETS_H_ diff --git a/lib/highway/hwy/timer-inl.h b/lib/highway/hwy/timer-inl.h new file mode 100644 index 00000000000..87c98df02c7 --- /dev/null +++ b/lib/highway/hwy/timer-inl.h @@ -0,0 +1,48 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// DEPRECATED, use timer.h instead. + +#include "hwy/timer.h" + +#if defined(HIGHWAY_HWY_TIMER_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef HIGHWAY_HWY_TIMER_INL_H_ +#undef HIGHWAY_HWY_TIMER_INL_H_ +#else +#define HIGHWAY_HWY_TIMER_INL_H_ +#endif + +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace hwy { +namespace HWY_NAMESPACE { +namespace timer { + +// Deprecated aliases so that old code still compiles. Prefer to use +// `hwy::timer::*` from timer.h because that does not require highway.h. +using Ticks = hwy::timer::Ticks; + +inline Ticks Start() { return hwy::timer::Start(); } +inline Ticks Stop() { return hwy::timer::Stop(); } + +} // namespace timer + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace hwy +HWY_AFTER_NAMESPACE(); + +#endif // per-target include guard diff --git a/lib/highway/hwy/timer.cc b/lib/highway/hwy/timer.cc new file mode 100644 index 00000000000..eac90566c82 --- /dev/null +++ b/lib/highway/hwy/timer.cc @@ -0,0 +1,192 @@ +// Copyright 2019 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hwy/timer.h" + +#include + +#include // NOLINT +#include // NOLINT + +#include "hwy/base.h" +#include "hwy/robust_statistics.h" +#include "hwy/x86_cpuid.h" + +namespace hwy { + +#if HWY_ARCH_X86 +namespace x86 { + +static bool HasRDTSCP() { + uint32_t abcd[4]; + Cpuid(0x80000001U, 0, abcd); // Extended feature flags + if ((abcd[3] & (1u << 27)) == 0) return false; // RDTSCP + + Cpuid(0x80000007U, 0, abcd); + if ((abcd[3] & (1u << 8)) == 0) { + HWY_WARN("TSC not constant/invariant, may vary frequency or jump."); + } + return true; +} + +} // namespace x86 +#endif // HWY_ARCH_X86 + +// Measures the actual current frequency of Ticks. We cannot rely on the nominal +// frequency encoded in x86 GetCpuString because it is misleading on M1 Rosetta, +// and not reported by AMD. CPUID 0x15 is also not yet widely supported. Also +// used on RISC-V and aarch64. +static HWY_MAYBE_UNUSED double MeasureNominalClockRate() { + double max_ticks_per_sec = 0.0; + // Arbitrary, enough to ignore 2 outliers without excessive init time. + for (int rep = 0; rep < 3; ++rep) { + auto time0 = std::chrono::steady_clock::now(); + using Time = decltype(time0); + const timer::Ticks ticks0 = timer::Start(); + const Time time_min = time0 + std::chrono::milliseconds(10); + + Time time1; + timer::Ticks ticks1; + for (;;) { + time1 = std::chrono::steady_clock::now(); + // Ideally this would be Stop, but that requires RDTSCP on x86. To avoid + // another codepath, just use Start instead. now() presumably has its own + // fence-like behavior. 
+ ticks1 = timer::Start(); // Do not use Stop, see comment above + if (time1 >= time_min) break; + } + + const double dticks = static_cast(ticks1 - ticks0); + std::chrono::duration> dtime = time1 - time0; + const double ticks_per_sec = dticks / dtime.count(); + max_ticks_per_sec = HWY_MAX(max_ticks_per_sec, ticks_per_sec); + } + return max_ticks_per_sec; +} + +#if HWY_ARCH_PPC && defined(__GLIBC__) && defined(__powerpc64__) +namespace ppc { + +static HWY_INLINE double GetTimebaseFreq() { + const auto timebase_freq = __ppc_get_timebase_freq(); + // If timebase_freq is greater than 0, then return timebase_freq. + + // Otherwise, if timebase_freq is less than or equal to 0, fall back to + // MeasureNominalClockRate(). This works around issues if running on QEMU on + // non-PPC CPU's. + return (timebase_freq > 0) ? static_cast(timebase_freq) + : MeasureNominalClockRate(); +} + +} // namespace ppc +#endif + +namespace platform { + +HWY_DLLEXPORT bool GetCpuString(char* cpu100) { +#if HWY_ARCH_X86 + uint32_t abcd[4]; + + // Check if brand string is supported (it is on all reasonable Intel/AMD) + x86::Cpuid(0x80000000U, 0, abcd); + if (abcd[0] < 0x80000004U) { + cpu100[0] = '\0'; + return false; + } + + for (size_t i = 0; i < 3; ++i) { + x86::Cpuid(static_cast(0x80000002U + i), 0, abcd); + CopyBytes(&abcd[0], cpu100 + i * 16); // not same size + } + cpu100[48] = '\0'; + return true; +#else + cpu100[0] = '?'; + cpu100[1] = '\0'; + return false; +#endif +} + +HWY_DLLEXPORT double Now() { + static const double mul = 1.0 / InvariantTicksPerSecond(); + return static_cast(timer::Start()) * mul; +} + +HWY_DLLEXPORT bool HaveTimerStop(char* cpu100) { +#if HWY_ARCH_X86 + if (!x86::HasRDTSCP()) { + (void)GetCpuString(cpu100); + return false; + } +#endif + *cpu100 = '\0'; + return true; +} + +HWY_DLLEXPORT double InvariantTicksPerSecond() { +#if HWY_ARCH_PPC && defined(__GLIBC__) && defined(__powerpc64__) + static const double freq = ppc::GetTimebaseFreq(); + return freq; +#elif 
HWY_ARCH_X86 || HWY_ARCH_RISCV || (HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC) + // We assume the x86 TSC is invariant; it is on all recent Intel/AMD CPUs. + static const double freq = MeasureNominalClockRate(); + return freq; +#elif defined(_WIN32) || defined(_WIN64) + LARGE_INTEGER freq; + (void)QueryPerformanceFrequency(&freq); + return static_cast(freq.QuadPart); +#elif defined(__APPLE__) + // https://developer.apple.com/library/mac/qa/qa1398/_index.html + mach_timebase_info_data_t timebase; + (void)mach_timebase_info(&timebase); + return static_cast(timebase.denom) / timebase.numer * 1E9; +#else + return 1E9; // Haiku and clock_gettime return nanoseconds. +#endif +} + +HWY_DLLEXPORT uint64_t TimerResolution() { + char cpu100[100]; + bool can_use_stop = HaveTimerStop(cpu100); + + // For measuring timer overhead/resolution. Used in a nested loop => + // quadratic time, acceptable because we know timer overhead is "low". + // constexpr because this is used to define array bounds. + constexpr size_t kTimerSamples = 256; + + // Nested loop avoids exceeding stack/L1 capacity. 
+ timer::Ticks repetitions[kTimerSamples]; + for (size_t rep = 0; rep < kTimerSamples; ++rep) { + timer::Ticks samples[kTimerSamples]; + if (can_use_stop) { + for (size_t i = 0; i < kTimerSamples; ++i) { + const timer::Ticks t0 = timer::Start(); + const timer::Ticks t1 = timer::Stop(); // we checked HasRDTSCP above + samples[i] = t1 - t0; + } + } else { + for (size_t i = 0; i < kTimerSamples; ++i) { + const timer::Ticks t0 = timer::Start(); + const timer::Ticks t1 = timer::Start(); // do not use Stop, see above + samples[i] = t1 - t0; + } + } + repetitions[rep] = robust_statistics::Mode(samples); + } + return robust_statistics::Mode(repetitions); +} + +} // namespace platform +} // namespace hwy diff --git a/lib/highway/hwy/timer.h b/lib/highway/hwy/timer.h new file mode 100644 index 00000000000..0873ac2cf1f --- /dev/null +++ b/lib/highway/hwy/timer.h @@ -0,0 +1,281 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_TIMER_H_ +#define HIGHWAY_HWY_TIMER_H_ + +// Platform-specific timer functions. Provides Now() and functions for +// interpreting and converting Ticks. 
+ +#include +#include // clock_gettime + +#include "hwy/base.h" + +#if defined(_WIN32) || defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif // WIN32_LEAN_AND_MEAN +#include +#endif + +#if defined(__APPLE__) +#include +#include +#endif + +#if defined(__HAIKU__) +#include +#endif + +#if HWY_ARCH_PPC && defined(__GLIBC__) && defined(__powerpc64__) +#include // NOLINT __ppc_get_timebase_freq +#endif + +#if HWY_ARCH_X86 && HWY_COMPILER_MSVC +#include +#endif + +namespace hwy { +namespace platform { + +// Returns current timestamp [in seconds] relative to an unspecified origin. +// Features: monotonic (no negative elapsed time), steady (unaffected by system +// time changes), high-resolution (on the order of microseconds). +// Uses InvariantTicksPerSecond and the baseline version of timer::Start(). +HWY_DLLEXPORT double Now(); + +// Functions related to `Ticks` below. + +// Returns whether it is safe to call timer::Stop without executing an illegal +// instruction; if false, fills cpu100 (a pointer to a 100 character buffer) +// via GetCpuString(). +HWY_DLLEXPORT bool HaveTimerStop(char* cpu100); + +// Returns tick rate, useful for converting timer::Ticks to seconds. Invariant +// means the tick counter frequency is independent of CPU throttling or sleep. +// This call may be expensive, callers should cache the result. +HWY_DLLEXPORT double InvariantTicksPerSecond(); + +// Returns ticks elapsed in back to back timer calls, i.e. a function of the +// timer resolution (minimum measurable difference) and overhead. +// This call is expensive, callers should cache the result. +HWY_DLLEXPORT uint64_t TimerResolution(); + +// Returns false if no detailed description is available, otherwise fills +// `cpu100` with up to 100 characters (including \0) identifying the CPU model. 
+HWY_DLLEXPORT bool GetCpuString(char* cpu100); + +} // namespace platform + +struct Timestamp { + Timestamp() { t = platform::Now(); } + double t; +}; + +static inline double SecondsSince(const Timestamp& t0) { + const Timestamp t1; + return t1.t - t0.t; +} + +// Low-level Start/Stop functions, previously in timer-inl.h. + +namespace timer { + +// Ticks := platform-specific timer values (CPU cycles on x86). Must be +// unsigned to guarantee wraparound on overflow. +using Ticks = uint64_t; + +// Start/Stop return absolute timestamps and must be placed immediately before +// and after the region to measure. We provide separate Start/Stop functions +// because they use different fences. +// +// Background: RDTSC is not 'serializing'; earlier instructions may complete +// after it, and/or later instructions may complete before it. 'Fences' ensure +// regions' elapsed times are independent of such reordering. The only +// documented unprivileged serializing instruction is CPUID, which acts as a +// full fence (no reordering across it in either direction). Unfortunately +// the latency of CPUID varies wildly (perhaps made worse by not initializing +// its EAX input). Because it cannot reliably be deducted from the region's +// elapsed time, it must not be included in the region to measure (i.e. +// between the two RDTSC). +// +// The newer RDTSCP is sometimes described as serializing, but it actually +// only serves as a half-fence with release semantics. Although all +// instructions in the region will complete before the final timestamp is +// captured, subsequent instructions may leak into the region and increase the +// elapsed time. Inserting another fence after the final `RDTSCP` would prevent +// such reordering without affecting the measured region. +// +// Fortunately, such a fence exists. The LFENCE instruction is only documented +// to delay later loads until earlier loads are visible. 
However, Intel's +// reference manual says it acts as a full fence (waiting until all earlier +// instructions have completed, and delaying later instructions until it +// completes). AMD assigns the same behavior to MFENCE. +// +// We need a fence before the initial RDTSC to prevent earlier instructions +// from leaking into the region, and arguably another after RDTSC to avoid +// region instructions from completing before the timestamp is recorded. +// When surrounded by fences, the additional `RDTSCP` half-fence provides no +// benefit, so the initial timestamp can be recorded via RDTSC, which has +// lower overhead than `RDTSCP` because it does not read TSC_AUX. In summary, +// we define Start = LFENCE/RDTSC/LFENCE; Stop = RDTSCP/LFENCE. +// +// Using Start+Start leads to higher variance and overhead than Stop+Stop. +// However, Stop+Stop includes an LFENCE in the region measurements, which +// adds a delay dependent on earlier loads. The combination of Start+Stop +// is faster than Start+Start and more consistent than Stop+Stop because +// the first LFENCE already delayed subsequent loads before the measured +// region. This combination seems not to have been considered in prior work: +// http://akaros.cs.berkeley.edu/lxr/akaros/kern/arch/x86/rdtsc_test.c +// +// Note: performance counters can measure 'exact' instructions-retired or +// (unhalted) cycle counts. The RDPMC instruction is not serializing and also +// requires fences. Unfortunately, it is not accessible on all OSes and we +// prefer to avoid kernel-mode drivers. Performance counters are also affected +// by several under/over-count errata, so we use the TSC instead. + +// Returns a 64-bit timestamp in unit of 'ticks'; to convert to seconds, +// divide by InvariantTicksPerSecond. 
+static HWY_INLINE Ticks Start() { + Ticks t; +#if HWY_ARCH_PPC && defined(__GLIBC__) && defined(__powerpc64__) + asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268)); +#elif HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC + // pmccntr_el0 is privileged but cntvct_el0 is accessible in Linux and QEMU. + asm volatile("mrs %0, cntvct_el0" : "=r"(t)); +#elif HWY_ARCH_X86 && HWY_COMPILER_MSVC + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); + t = __rdtsc(); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#elif HWY_ARCH_X86_64 + asm volatile( + "lfence\n\t" + "rdtsc\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rdx", "memory", "cc"); +#elif HWY_ARCH_RISCV + asm volatile("fence; rdtime %0" : "=r"(t)); +#elif defined(_WIN32) || defined(_WIN64) + LARGE_INTEGER counter; + (void)QueryPerformanceCounter(&counter); + t = counter.QuadPart; +#elif defined(__APPLE__) + t = mach_absolute_time(); +#elif defined(__HAIKU__) + t = system_time_nsecs(); // since boot +#else // POSIX + timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + t = static_cast(ts.tv_sec * 1000000000LL + ts.tv_nsec); +#endif + return t; +} + +// WARNING: on x86, caller must check `HaveTimerStop()` before using this! +static HWY_INLINE Ticks Stop() { + uint64_t t; +#if HWY_ARCH_PPC && defined(__GLIBC__) && defined(__powerpc64__) + asm volatile("mfspr %0, %1" : "=r"(t) : "i"(268)); +#elif HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC + // pmccntr_el0 is privileged but cntvct_el0 is accessible in Linux and QEMU. + asm volatile("mrs %0, cntvct_el0" : "=r"(t)); +#elif HWY_ARCH_X86 && HWY_COMPILER_MSVC + _ReadWriteBarrier(); + unsigned aux; + t = __rdtscp(&aux); + _ReadWriteBarrier(); + _mm_lfence(); + _ReadWriteBarrier(); +#elif HWY_ARCH_X86_64 + // Use inline asm because __rdtscp generates code to store TSC_AUX (ecx). 
+ asm volatile( + "rdtscp\n\t" + "shl $32, %%rdx\n\t" + "or %%rdx, %0\n\t" + "lfence" + : "=a"(t) + : + // "memory" avoids reordering. rcx = TSC_AUX. rdx = TSC >> 32. + // "cc" = flags modified by SHL. + : "rcx", "rdx", "memory", "cc"); +#else + t = Start(); +#endif + return t; +} + +} // namespace timer + +// Wrapper around Start/Stop that checks whether the CPU supports Stop. +class Timer { + public: + Timer() { + char cpu100[100]; + have_timer_stop_ = platform::HaveTimerStop(cpu100); + } + + // Before/After have fences to prevent the measured code 'leaking out'. + timer::Ticks Before() const { return timer::Start(); } + timer::Ticks After() const { + return have_timer_stop_ ? timer::Stop() : timer::Start(); + } + + private: + bool have_timer_stop_; +}; + +static inline double Seconds(timer::Ticks ticks) { + return static_cast(ticks) / platform::InvariantTicksPerSecond(); +} + +// Measures elapsed time since construction, with automatic reset. +class Stopwatch { + public: + explicit Stopwatch(const Timer& timestamps) : timer_(timestamps) { Reset(); } + + timer::Ticks Origin() const { return t0_; } + void Reset() { t0_ = timer_.Before(); } + + // Also resets the start time to the current time to enable reuse without a + // second call to the timer. + timer::Ticks Elapsed() { + const timer::Ticks t1 = timer_.After(); + const timer::Ticks elapsed = t1 - t0_; + t0_ = t1; + return elapsed; + } + + private: + const Timer& timer_; + timer::Ticks t0_; +}; + +} // namespace hwy + +#endif // HIGHWAY_HWY_TIMER_H_ diff --git a/lib/highway/hwy/x86_cpuid.h b/lib/highway/hwy/x86_cpuid.h new file mode 100644 index 00000000000..60bd51a0224 --- /dev/null +++ b/lib/highway/hwy/x86_cpuid.h @@ -0,0 +1,81 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HIGHWAY_HWY_X86_CPUID_H_ +#define HIGHWAY_HWY_X86_CPUID_H_ + +// Wrapper for x86 CPUID intrinsics. Empty on other platforms. + +#include + +#include "hwy/base.h" + +#if HWY_ARCH_X86 + +#if HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL +#include +#else +#include +#endif + +namespace hwy { +namespace x86 { + +// Calls CPUID instruction with eax=level and ecx=count and returns the result +// in abcd array where abcd = {eax, ebx, ecx, edx} (hence the name abcd). +static inline void Cpuid(const uint32_t level, const uint32_t count, + uint32_t* HWY_RESTRICT abcd) { +#if HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL + int regs[4]; + __cpuidex(regs, static_cast(level), static_cast(count)); + for (int i = 0; i < 4; ++i) { + abcd[i] = static_cast(regs[i]); + } +#else // HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL + uint32_t a; + uint32_t b; + uint32_t c; + uint32_t d; + __cpuid_count(level, count, a, b, c, d); + abcd[0] = a; + abcd[1] = b; + abcd[2] = c; + abcd[3] = d; +#endif // HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL +} + +static inline bool IsBitSet(const uint32_t reg, const int index) { + return (reg & (1U << index)) != 0; +} + +static inline uint32_t MaxLevel() { + uint32_t abcd[4]; + Cpuid(0, 0, abcd); + return abcd[0]; +} + +static inline bool IsAMD() { + uint32_t abcd[4]; + Cpuid(0, 0, abcd); + const uint32_t max_level = abcd[0]; + return max_level >= 1 && abcd[1] == 0x68747541 && abcd[2] == 0x444d4163 && + abcd[3] == 0x69746e65; +} + +} // namespace x86 +} // namespace hwy + +#endif // HWY_ARCH_X86 +#endif // HIGHWAY_HWY_X86_CPUID_H_ diff --git 
a/tools/clang-format.sh b/tools/clang-format.sh index 23e17bb6a46..8c2a6c6919a 100755 --- a/tools/clang-format.sh +++ b/tools/clang-format.sh @@ -100,7 +100,7 @@ EOF start_time_file=$(mktemp -t clang-format-start-time.XXXXXXXXXX) touch ${start_time_file} - target_files=$(find $DIR -iname \*.[ch] -o -iname \*.cc -o -iname \*.h.in -o -iname \*.hpp -o -iname \*.cript | grep -vE 'lib/(Catch2|fastlz|ls-hpack|swoc|systemtap|yamlcpp)') + target_files=$(find $DIR -iname \*.[ch] -o -iname \*.cc -o -iname \*.h.in -o -iname \*.hpp -o -iname \*.cript | grep -vE 'lib/(Catch2|fastlz|highway|ls-hpack|swoc|systemtap|yamlcpp)') for file in ${target_files}; do # The ink_autoconf.h and ink_autoconf.h.in files are generated files, # so they do not need to be re-formatted by clang-format. Doing so diff --git a/tools/cmake-format.sh b/tools/cmake-format.sh index 07a1a32e12a..aded34d2493 100755 --- a/tools/cmake-format.sh +++ b/tools/cmake-format.sh @@ -25,6 +25,10 @@ CMAKE_FORMAT_VERSION="0.6.13" VERSION="0.6.13" +function cmake_vendorlib_grep() { + grep -E "CMakeLists.txt|.cmake$" | grep -vE "lib/(Catch2|fastlz|highway|ls-hpack|swoc|yamlcpp)" || true +} + function main() { set -e # exit on error @@ -48,16 +52,13 @@ function main() { # formatting files the user doesn't want formatted. tmp_dir=$(mktemp -d -t tracked-git-files.XXXXXXXXXX) files=${tmp_dir}/git_files.txt - files_filtered=${tmp_dir}/git_files_filtered.txt - git ls-tree -r HEAD --name-only ${DIR} | grep -E 'CMakeLists.txt|.cmake$' | grep -vE "lib/(Catch2|fastlz|ls-hpack|swoc|yamlcpp)" > ${files} + git ls-tree -r HEAD --name-only ${DIR} | cmake_vendorlib_grep > ${files} # Add to the above any newly added staged files. 
- git diff --cached --name-only --diff-filter=A >> ${files} - # But probably not all the new staged files are CMakeLists.txt files: - grep -E 'CMakeLists.txt|.cmake$' ${files} > ${files_filtered} + git diff --cached --name-only --diff-filter=A | cmake_vendorlib_grep >> ${files} # Prepend the filenames with "./" to make the modified file output consistent # with the clang-format target output. - sed -i'.bak' 's:^:\./:' ${files_filtered} - rm -f ${files_filtered}.bak + sed -i'.bak' 's:^:\./:' ${files} + rm -f ${files}.bak # Efficiently retrieving modification timestamps in a platform # independent way is challenging. We use find's -newer argument, which @@ -66,8 +67,8 @@ function main() { # after this we assume was modified by cmake-format. start_time_file=${tmp_dir}/format_start.$$ touch ${start_time_file} - uv tool run --quiet --from cmakelang@${CMAKE_FORMAT_VERSION} --with pyaml cmake-format -i $(cat ${files_filtered}) - find $(cat ${files_filtered}) -newer ${start_time_file} + uv tool run --quiet --from cmakelang@${CMAKE_FORMAT_VERSION} --with pyaml cmake-format -i $(cat ${files}) + find $(cat ${files}) -newer ${start_time_file} rm -rf ${tmp_dir} } diff --git a/tools/git/pre-commit b/tools/git/pre-commit index d200824217f..c11552271e6 100755 --- a/tools/git/pre-commit +++ b/tools/git/pre-commit @@ -62,7 +62,7 @@ trap "rm -f $clang_patch_file $yapf_patch_file $cmake_format_patch_file" 0 1 2 3 # Loop over all files that are changed, and produce a diff file REPO_ROOT=$(cd $(dirname $0)/../.. 
&& git rev-parse --show-toplevel) YAPF_CONFIG=${REPO_ROOT}/.style.yapf -git diff-index --cached --diff-filter=ACMR --name-only HEAD | grep -vE "lib/(Catch2|fastlz|ls-hpack|swoc|yamlcpp)" | while read file; do +git diff-index --cached --diff-filter=ACMR --name-only HEAD | grep -vE "lib/(Catch2|fastlz|highway|ls-hpack|swoc|yamlcpp)" | while read file; do case "$file" in *.cc | *.c | *.h | *.h.in | *.cript) ${FORMAT} "$file" | diff -u "$file" - >>"$clang_patch_file" @@ -80,7 +80,7 @@ git diff-index --cached --diff-filter=ACMR --name-only HEAD | grep -vE "lib/(Cat done # Now repeat the above for CMakeLists.txt files. -git diff-index --cached --diff-filter=ACMR --name-only HEAD | grep -E 'CMakeLists.txt|\.cmake$' | grep -vE "lib/(Catch2|fastlz|ls-hpack|swoc|yamlcpp)" | while read file; do +git diff-index --cached --diff-filter=ACMR --name-only HEAD | grep -E 'CMakeLists.txt|\.cmake$' | grep -vE "lib/(Catch2|highway|fastlz|ls-hpack|swoc|yamlcpp)" | while read file; do uv tool run --quiet --from cmakelang@${CMAKE_FORMAT_VERSION} --with pyaml cmake-format "$file" | diff -u "$file" - >>"$cmake_format_patch_file" done