diff --git a/kernel/simd/intrin.h b/kernel/simd/intrin.h index 3802a91e1a..df68ff5262 100644 --- a/kernel/simd/intrin.h +++ b/kernel/simd/intrin.h @@ -56,6 +56,11 @@ extern "C" { #include #endif +/** WASM SIMD **/ +#if defined(ARCH_WASM) && defined(__wasm_simd128__) +#include +#endif + // distribute #if defined(HAVE_AVX512VL) || defined(HAVE_AVX512BF16) #include "intrin_avx512.h" @@ -69,6 +74,10 @@ extern "C" { #include "intrin_neon.h" #endif +#if defined(ARCH_WASM) && defined(__wasm_simd128__) +#include "intrin_wasm.h" +#endif + #ifndef V_SIMD #define V_SIMD 0 #define V_SIMD_F64 0 diff --git a/kernel/simd/intrin_wasm.h b/kernel/simd/intrin_wasm.h new file mode 100644 index 0000000000..1e04c70127 --- /dev/null +++ b/kernel/simd/intrin_wasm.h @@ -0,0 +1,63 @@ +#include + +#define V_SIMD 128 +#define V_SIMD_F64 1 + +/*************************** + * Data Type + ***************************/ +typedef v128_t v_f32; +typedef v128_t v_f64; +#define v_nlanes_f32 4 +#define v_nlanes_f64 2 + +/*************************** + * Arithmetic + ***************************/ +#define v_add_f32 wasm_f32x4_add +#define v_add_f64 wasm_f64x2_add +#define v_sub_f32 wasm_f32x4_sub +#define v_sub_f64 wasm_f64x2_sub +#define v_mul_f32 wasm_f32x4_mul +#define v_mul_f64 wasm_f64x2_mul + +BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c) +{ return v_add_f32(v_mul_f32(a, b), c); } + +BLAS_FINLINE v_f64 v_muladd_f64(v_f64 a, v_f64 b, v_f64 c) +{ return v_add_f64(v_mul_f64(a, b), c); } + +BLAS_FINLINE v_f32 v_mulsub_f32(v_f32 a, v_f32 b, v_f32 c) +{ return v_sub_f32(v_mul_f32(a, b), c); } + +BLAS_FINLINE v_f64 v_mulsub_f64(v_f64 a, v_f64 b, v_f64 c) +{ return v_sub_f64(v_mul_f64(a, b), c); } + +/*************************** + * reduction + ***************************/ +BLAS_FINLINE float v_sum_f32(v_f32 a) +{ + return wasm_f32x4_extract_lane(a, 0) + + wasm_f32x4_extract_lane(a, 1) + + wasm_f32x4_extract_lane(a, 2) + + wasm_f32x4_extract_lane(a, 3); +} + +BLAS_FINLINE double v_sum_f64(v_f64 a) +{ + return wasm_f64x2_extract_lane(a, 0) + + wasm_f64x2_extract_lane(a, 1); +} + +/*************************** + * memory + ***************************/ +#define v_loadu_f32(a) wasm_v128_load((const float*)a) +#define v_loadu_f64(a) wasm_v128_load((const double*)a) +#define v_storeu_f32(a, v) wasm_v128_store((float*)a, v) +#define v_storeu_f64(a, v) wasm_v128_store((double*)a, v) +#define v_setall_f32(VAL) wasm_f32x4_splat(VAL) +#define v_setall_f64(VAL) wasm_f64x2_splat(VAL) +#define v_zero_f32() wasm_f32x4_splat(0.0f) +#define v_zero_f64() wasm_f64x2_splat(0.0) diff --git a/kernel/wasm/KERNEL.WASM128_GENERIC b/kernel/wasm/KERNEL.WASM128_GENERIC index 679147767d..1f1946a015 100644 --- a/kernel/wasm/KERNEL.WASM128_GENERIC +++ b/kernel/wasm/KERNEL.WASM128_GENERIC @@ -40,7 +40,7 @@ DSUMKERNEL = ../arm/sum.c CSUMKERNEL = ../arm/zsum.c ZSUMKERNEL = ../arm/zsum.c -SAXPYKERNEL = ../riscv64/axpy.c +SAXPYKERNEL = ../x86_64/saxpy.c DAXPYKERNEL = ../x86_64/daxpy.c CAXPYKERNEL = ../riscv64/zaxpy.c ZAXPYKERNEL = ../riscv64/zaxpy.c diff --git a/kernel/x86_64/saxpy.c b/kernel/x86_64/saxpy.c index ff911c52b9..b425f1515f 100644 --- a/kernel/x86_64/saxpy.c +++ b/kernel/x86_64/saxpy.c @@ -43,12 +43,23 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #ifndef HAVE_KERNEL_16 +#include"../simd/intrin.h" static void saxpy_kernel_16(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *alpha) { BLASLONG register i = 0; FLOAT a = *alpha; +#if V_SIMD + v_f32 __alpha, tmp; + __alpha = v_setall_f32(*alpha); + const int vstep = v_nlanes_f32; + + for (; i < n; i += vstep) { + tmp = v_muladd_f32(__alpha, v_loadu_f32(x + i), v_loadu_f32(y + i)); + v_storeu_f32(y + i, tmp); + } +#else while(i < n) { y[i] += a * x[i]; @@ -62,6 +73,7 @@ static void saxpy_kernel_16(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *alpha) i+=8 ; } +#endif } @@ -131,4 +143,3 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x, BLAS } - diff --git a/kernel/x86_64/srot.c b/kernel/x86_64/srot.c index 05724b427a..441c9fd593 100644 --- a/kernel/x86_64/srot.c +++ b/kernel/x86_64/srot.c @@ -13,7 +13,7 @@ static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s) { BLASLONG i = 0; -#if V_SIMD && !defined(C_PGI) && (defined(HAVE_FMA3) || V_SIMD > 128) +#if V_SIMD && !defined(C_PGI) && (defined(HAVE_FMA3) || V_SIMD > 128 || defined(ARCH_WASM)) const int vstep = v_nlanes_f32; const int unrollx4 = n & (-vstep * 4); const int unrollx = n & -vstep;