diff --git a/source/source_basis/module_pw/pw_gatherscatter.h b/source/source_basis/module_pw/pw_gatherscatter.h index 00ce160e254..0c4b5fc0fb7 100644 --- a/source/source_basis/module_pw/pw_gatherscatter.h +++ b/source/source_basis/module_pw/pw_gatherscatter.h @@ -1,10 +1,49 @@ #include "pw_basis.h" #include "source_base/global_function.h" #include "source_base/timer.h" +#include #include namespace ModulePW { +namespace detail +{ +template +inline void copy_complex_buffer(const std::complex* in, std::complex* out, const int count) +{ + if (count <= 0) + { + return; + } + + std::copy_n(in, count, out); +} + +// Top-level transform copies own the OpenMP parallel region; gather/scatter +// loops call the non-parallel helper inside their existing parallel regions. +template +inline void copy_complex_buffer_parallel(const std::complex* in, std::complex* out, const int count) +{ + constexpr int chunk_size = 1024; + if (count <= chunk_size) + { + copy_complex_buffer(in, out, count); + return; + } + +#ifdef _OPENMP +#pragma omp parallel for schedule(static) + for (int offset = 0; offset < count; offset += chunk_size) + { + const int chunk_count = std::min(chunk_size, count - offset); + std::copy_n(in + offset, chunk_count, out + offset); + } +#else + copy_complex_buffer(in, out, count); +#endif +} +} // namespace detail + /** * @brief gather planes and scatter sticks * @param in: (nplane,fftny,fftnx) @@ -21,16 +60,18 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const const int nst_ = this->nst; const int nz_ = this->nz; const int* istot2ixy_ = this->istot2ixy; + ModuleBase::timer::start(this->classname, "gatherp_copy_serial"); #ifdef _OPENMP #pragma omp parallel for #endif for(int is = 0 ; is < nst_ ; ++is) { int ixy = istot2ixy_[is]; - std::complex *outp = &out[is*nz_]; - std::complex *inp = &in[ixy*nz_]; - std::memcpy(outp, inp, nz_ * sizeof(std::complex)); + std::complex* outp = &out[is*nz_]; + const std::complex* inp = &in[ixy*nz_]; + detail::copy_complex_buffer(inp, outp, nz_); } + ModuleBase::timer::end(this->classname, "gatherp_copy_serial"); return; } @@ -41,16 +82,18 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const const int nstot_gps = this->nstot; const int nplane_gps = this->nplane; const int* istot2ixy_gps = this->istot2ixy; + ModuleBase::timer::start(this->classname, "gatherp_copy_pack"); #ifdef _OPENMP #pragma omp parallel for #endif for (int istot = 0; istot < nstot_gps; ++istot) { int ixy = istot2ixy_gps[istot]; - std::complex *outp = &out[istot * nplane_gps]; - std::complex *inp = &in[ixy * nplane_gps]; - std::memcpy(outp, inp, nplane_gps * sizeof(std::complex)); + std::complex* outp = &out[istot * nplane_gps]; + const std::complex* inp = &in[ixy * nplane_gps]; + detail::copy_complex_buffer(inp, outp, nplane_gps); } + ModuleBase::timer::end(this->classname, "gatherp_copy_pack"); //exchange data //(nplane,nstot) to (numz[ip],ns, poolnproc) @@ -74,6 +117,7 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const const int* numz_gps = this->numz; const int* startg_gps = this->startg; const int* startz_gps = this->startz; + ModuleBase::timer::start(this->classname, "gatherp_copy_unpack"); #ifdef _OPENMP #pragma omp parallel for collapse(2) #endif @@ -84,11 +128,12 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const int nzip = numz_gps[ip]; std::complex *outp0 = &out[startz_gps[ip]]; std::complex *inp0 = &in[startg_gps[ip]]; - std::complex *outp = &outp0[is * nz_gps]; - std::complex *inp = &inp0[is * nzip ]; - std::memcpy(outp, inp, nzip * sizeof(std::complex)); + std::complex* outp = &outp0[is * nz_gps]; + const std::complex* inp = &inp0[is * nzip ]; + detail::copy_complex_buffer(inp, outp, nzip); } } + ModuleBase::timer::end(this->classname, "gatherp_copy_unpack"); #endif return; } @@ -109,6 +154,7 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const const int nst_ = this->nst; const int nz_ = this->nz; const int* istot2ixy_ = this->istot2ixy; + ModuleBase::timer::start(this->classname, "gathers_zero_serial"); #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif @@ -116,17 +162,20 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const { out[i] = std::complex(0, 0); } + ModuleBase::timer::end(this->classname, "gathers_zero_serial"); + ModuleBase::timer::start(this->classname, "gathers_copy_serial"); #ifdef _OPENMP #pragma omp parallel for #endif for(int is = 0 ; is < nst_ ; ++is) { int ixy = istot2ixy_[is]; - std::complex *outp = &out[ixy*nz_]; - std::complex *inp = &in[is*nz_]; - std::memcpy(outp, inp, nz_ * sizeof(std::complex)); + std::complex* outp = &out[ixy*nz_]; + const std::complex* inp = &in[is*nz_]; + detail::copy_complex_buffer(inp, outp, nz_); } + ModuleBase::timer::end(this->classname, "gathers_copy_serial"); return; } @@ -140,6 +189,7 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const const int* numz_ = this->numz; const int* startg_ = this->startg; const int* startz_ = this->startz; + ModuleBase::timer::start(this->classname, "gathers_copy_pack"); #ifdef _OPENMP #pragma omp parallel for collapse(2) #endif @@ -150,11 +200,12 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const int nzip = numz_[ip]; std::complex *outp0 = &out[startg_[ip]]; std::complex *inp0 = &in[startz_[ip]]; - std::complex *outp = &outp0[is * nzip]; - std::complex *inp = &inp0[is * nz_ ]; - std::memcpy(outp, inp, nzip * sizeof(std::complex)); + std::complex* outp = &outp0[is * nzip]; + const std::complex* inp = &inp0[is * nz_ ]; + detail::copy_complex_buffer(inp, outp, nzip); } } + ModuleBase::timer::end(this->classname, "gathers_copy_pack"); //exchange data //(numz[ip],ns, poolnproc) to (nplane,nstot) @@ -172,6 +223,7 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const } const int nrxx_gsp = this->nrxx; + ModuleBase::timer::start(this->classname, "gathers_zero_mpi"); #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif @@ -179,10 +231,12 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const { out[i] = std::complex(0, 0); } + ModuleBase::timer::end(this->classname, "gathers_zero_mpi"); //change (nplane,nstot) to (nplane fftnxy) const int nstot = this->nstot; const int nplane = this->nplane; const int* istot2ixy = this->istot2ixy; + ModuleBase::timer::start(this->classname, "gathers_copy_unpack"); #ifdef _OPENMP #pragma omp parallel for #endif @@ -190,10 +244,11 @@ void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const { int ixy = istot2ixy[istot]; //int ixy = (ixy / fftny)*ny + ixy % fftny; - std::complex *outp = &out[ixy * nplane]; - std::complex *inp = &in[istot * nplane]; - std::memcpy(outp, inp, nplane * sizeof(std::complex)); + std::complex* outp = &out[ixy * nplane]; + const std::complex* inp = &in[istot * nplane]; + detail::copy_complex_buffer(inp, outp, nplane); } + ModuleBase::timer::end(this->classname, "gathers_copy_unpack"); #endif return; } diff --git a/source/source_basis/module_pw/pw_transform_k.cpp b/source/source_basis/module_pw/pw_transform_k.cpp index a09aa2b686f..8c45e3d9b22 100644 --- a/source/source_basis/module_pw/pw_transform_k.cpp +++ b/source/source_basis/module_pw/pw_transform_k.cpp @@ -33,13 +33,7 @@ void PW_Basis_K::real2recip(const std::complex* in, assert(this->gamma_only == false); auto* auxr = this->fft_bundle.get_auxr_data(); -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (int ir = 0; ir < this->nrxx; ++ir) - { - auxr[ir] = in[ir]; - } + detail::copy_complex_buffer_parallel(in, auxr, this->nrxx); this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data(), fft_bundle.get_auxr_data()); this->gatherp_scatters(this->fft_bundle.get_auxr_data(), this->fft_bundle.get_auxg_data()); @@ -200,13 +194,7 @@ void PW_Basis_K::recip2real(const std::complex* in, } else { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (int ir = 0; ir < this->nrxx; ++ir) - { - out[ir] = auxr[ir]; - } + detail::copy_complex_buffer_parallel(auxr, out, this->nrxx); } ModuleBase::timer::end(this->classname, "recip2real"); } diff --git a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp index 84932bae2ff..039ea7c089a 100644 --- a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp +++ b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp @@ -2,6 +2,9 @@ #include "source_base/global_function.h" #include "source_base/constants.h" #include "source_base/matrix3.h" +#include +#include +#include /************************************************ * serial unit test of functions in pw_basis.cpp @@ -27,6 +30,7 @@ #define private public #include "../pw_basis_k.h" #include "../pw_basis.h" +#include "../pw_gatherscatter.h" #undef private #undef protected @@ -188,4 +192,99 @@ TEST_F(PWBasisKTEST, CollectLocalPW) EXPECT_EQ(basis_k.npwk_max,2721); } +TEST_F(PWBasisKTEST, ComplexTransformRoundTrip) +{ + ModulePW::PW_Basis_K basis_k(device_flag, precision_double); + double lat0 = 2.0; + ModuleBase::Matrix3 latvec(1.0,0.0,1.0, + 0.0,2.0,0.0, + 0.0,0.0,2.0); + double gridecut = 30.0; + const bool gamma_only_in = false; + const double gk_ecut_in = 20.0; + const int nks_in = 1; + const ModuleBase::Vector3 kvec_d_in[1] = { {0.0, 0.0, 0.0} }; + const int distribution_type_in = 2; + const bool xprime_in = false; + + basis_k.initgrids(lat0, latvec, gridecut); + basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in, kvec_d_in, distribution_type_in, xprime_in); + ASSERT_NO_THROW(basis_k.setuptransform()); + ASSERT_NE(basis_k.npwk, nullptr); + ASSERT_GT(basis_k.npwk[0], 0); + + // Use reciprocal-space input because arbitrary real-space data is projected + // by the plane-wave cutoff and is not exactly recoverable. + std::vector> recip_in(basis_k.npwk[0]); + std::vector> real_space(basis_k.nrxx); + std::vector> recip_out(basis_k.npwk[0]); + for (int ig = 0; ig < basis_k.npwk[0]; ++ig) + { + const double real_part = (ig % 17 - 8) / 11.0; + const double imag_part = (ig % 19 - 9) / 13.0; + recip_in[ig] = std::complex(real_part, imag_part); + } + + basis_k.recip2real(recip_in.data(), real_space.data(), 0); + basis_k.real2recip(real_space.data(), recip_out.data(), 0); + + for (int ig = 0; ig < basis_k.npwk[0]; ++ig) + { + EXPECT_NEAR(recip_in[ig].real(), recip_out[ig].real(), 1e-10); + EXPECT_NEAR(recip_in[ig].imag(), recip_out[ig].imag(), 1e-10); + } +} + +TEST_F(PWBasisKTEST, CopyComplexBufferTimerBenchmark) +{ + if (std::getenv("ABACUS_PW_SIMD_TIMER_TEST") == nullptr) + { + GTEST_SKIP() << "Set ABACUS_PW_SIMD_TIMER_TEST=1 to run the copy timer benchmark."; + } + const int count = 1 << 20; + const int repeats = 64; + std::vector> src(count); + std::vector> copy_n_dst(count); + std::vector> scalar_dst(count); + + for (int i = 0; i < count; ++i) + { + src[i] = std::complex((i % 97) / 17.0, (i % 89) / 19.0); + } + + volatile double checksum = 0.0; + + const auto copy_n_start = std::chrono::steady_clock::now(); + for (int repeat = 0; repeat < repeats; ++repeat) + { + ModulePW::detail::copy_complex_buffer(src.data(), copy_n_dst.data(), count); + checksum += copy_n_dst[repeat].real(); + } + const auto copy_n_end = std::chrono::steady_clock::now(); + + const auto scalar_start = std::chrono::steady_clock::now(); + for (int repeat = 0; repeat < repeats; ++repeat) + { + for (int i = 0; i < count; ++i) + { + scalar_dst[i] = src[i]; + } + checksum += scalar_dst[repeat].imag(); + } + const auto scalar_end = std::chrono::steady_clock::now(); + + const double copy_n_time = std::chrono::duration(copy_n_end - copy_n_start).count(); + const double scalar_time = std::chrono::duration(scalar_end - scalar_start).count(); + const double bytes_moved = static_cast(count) * sizeof(std::complex) * repeats; + const double gib = bytes_moved / (1024.0 * 1024.0 * 1024.0); + + std::cout << "PW_SIMD_TEST copy_n_helper " << copy_n_time << " s, " + << gib / copy_n_time << " GiB/s\n"; + std::cout << "PW_SIMD_TEST scalar_loop " << scalar_time << " s, " + << gib / scalar_time << " GiB/s\n"; + std::cout << "PW_SIMD_TEST speedup copy_n/scalar " << scalar_time / copy_n_time + << ", checksum " << checksum << "\n"; + + ASSERT_EQ(copy_n_dst, scalar_dst); +} diff --git a/source/source_basis/module_pw/test_serial/pw_basis_test.cpp b/source/source_basis/module_pw/test_serial/pw_basis_test.cpp index ea678b9d97c..57ac8f06554 100644 --- a/source/source_basis/module_pw/test_serial/pw_basis_test.cpp +++ b/source/source_basis/module_pw/test_serial/pw_basis_test.cpp @@ -2,6 +2,7 @@ #include "source_base/global_function.h" #include "source_base/constants.h" #include "source_base/matrix3.h" +#include /************************************************ * serial unit test of functions in pw_basis.cpp @@ -362,3 +363,41 @@ TEST_F(PWBasisTEST,CollectUniqgg) pwb.collect_uniqgg(); EXPECT_EQ(pwb.ngg,78); } + +TEST_F(PWBasisTEST,ComplexTransformRoundTrip) +{ + double lat0 = 2.0; + ModuleBase::Matrix3 latvec(1.0,0.0,1.0, + 0.0,2.0,0.0, + 0.0,0.0,2.0); + double gridecut = 30.0; + bool gamma_only_in = false; + double pwecut_in = 20.0; + int distribution_type_in = 2; + bool xprime_in = false; + + pwb.initgrids(lat0, latvec, gridecut); + pwb.initparameters(gamma_only_in, pwecut_in, distribution_type_in, xprime_in); + ASSERT_NO_THROW(pwb.setuptransform()); + + // Use reciprocal-space input because arbitrary real-space data is projected + // by the plane-wave cutoff and is not exactly recoverable. + std::vector> recip_in(pwb.npw); + std::vector> real_space(pwb.nrxx); + std::vector> recip_out(pwb.npw); + for (int ig = 0; ig < pwb.npw; ++ig) + { + const double real_part = (ig % 11 - 5) / 7.0; + const double imag_part = (ig % 13 - 6) / 9.0; + recip_in[ig] = std::complex(real_part, imag_part); + } + + pwb.recip2real(recip_in.data(), real_space.data()); + pwb.real2recip(real_space.data(), recip_out.data()); + + for (int ig = 0; ig < pwb.npw; ++ig) + { + EXPECT_NEAR(recip_in[ig].real(), recip_out[ig].real(), 1e-10); + EXPECT_NEAR(recip_in[ig].imag(), recip_out[ig].imag(), 1e-10); + } +}