Skip to content
91 changes: 73 additions & 18 deletions source/source_basis/module_pw/pw_gatherscatter.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,49 @@
#include "pw_basis.h"
#include "source_base/global_function.h"
#include "source_base/timer.h"
#include <algorithm>
#include <typeinfo>

namespace ModulePW
{
namespace detail
{
/**
 * @brief Copy `count` complex values from `in` to `out` on the calling thread.
 *
 * This helper contains no OpenMP directive of its own, so it can be called
 * from inside a loop that is already running in a parallel region.
 *
 * @param in    first element to read
 * @param out   first element to write
 * @param count number of elements to copy; zero or negative copies nothing
 */
template <typename T>
inline void copy_complex_buffer(const std::complex<T>* in, std::complex<T>* out, const int count)
{
    if (count <= 0)
    {
        return;
    }

    std::copy_n(in, count, out);
}

/**
 * @brief Copy `count` complex values, splitting large copies into chunks that
 *        are distributed over OpenMP threads.
 *
 * Top-level transform copies own the OpenMP parallel region; gather/scatter
 * loops call the non-parallel helper inside their existing parallel regions.
 *
 * Copies of at most one chunk (1024 elements), and every copy in a build
 * without OpenMP, are forwarded to copy_complex_buffer().
 *
 * @param in    first element to read
 * @param out   first element to write
 * @param count number of elements to copy; zero or negative copies nothing
 */
template <typename T>
inline void copy_complex_buffer_parallel(const std::complex<T>* in, std::complex<T>* out, const int count)
{
    constexpr int chunk_size = 1024;
    if (count <= chunk_size)
    {
        copy_complex_buffer(in, out, count);
        return;
    }

#ifdef _OPENMP
    // Loop over chunk indices instead of element offsets. The largest offset
    // formed is (chunk_total - 1) * chunk_size < count, so nothing overflows
    // int even when count is within one chunk of INT_MAX (an offset that is
    // advanced by chunk_size past the last chunk would).
    const int chunk_total = (count - 1) / chunk_size + 1;
#pragma omp parallel for schedule(static)
    for (int chunk = 0; chunk < chunk_total; ++chunk)
    {
        const int offset = chunk * chunk_size;
        // The last chunk may be shorter than chunk_size.
        const int chunk_count = std::min(chunk_size, count - offset);
        std::copy_n(in + offset, chunk_count, out + offset);
    }
#else
    copy_complex_buffer(in, out, count);
#endif
}
} // namespace detail

/**
* @brief gather planes and scatter sticks
* @param in: (nplane,fftny,fftnx)
Expand All @@ -21,16 +60,18 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
const int nst_ = this->nst;
const int nz_ = this->nz;
const int* istot2ixy_ = this->istot2ixy;
ModuleBase::timer::start(this->classname, "gatherp_copy_serial");
#ifdef _OPENMP
#pragma omp parallel for
#endif
for(int is = 0 ; is < nst_ ; ++is)
{
int ixy = istot2ixy_[is];
std::complex<T> *outp = &out[is*nz_];
std::complex<T> *inp = &in[ixy*nz_];
std::memcpy(outp, inp, nz_ * sizeof(std::complex<T>));
std::complex<T>* outp = &out[is*nz_];
const std::complex<T>* inp = &in[ixy*nz_];
detail::copy_complex_buffer(inp, outp, nz_);
}
ModuleBase::timer::end(this->classname, "gatherp_copy_serial");
return;
}

Expand All @@ -41,16 +82,18 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
const int nstot_gps = this->nstot;
const int nplane_gps = this->nplane;
const int* istot2ixy_gps = this->istot2ixy;
ModuleBase::timer::start(this->classname, "gatherp_copy_pack");
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int istot = 0; istot < nstot_gps; ++istot)
{
int ixy = istot2ixy_gps[istot];
std::complex<T> *outp = &out[istot * nplane_gps];
std::complex<T> *inp = &in[ixy * nplane_gps];
std::memcpy(outp, inp, nplane_gps * sizeof(std::complex<T>));
std::complex<T>* outp = &out[istot * nplane_gps];
const std::complex<T>* inp = &in[ixy * nplane_gps];
detail::copy_complex_buffer(inp, outp, nplane_gps);
}
ModuleBase::timer::end(this->classname, "gatherp_copy_pack");

//exchange data
//(nplane,nstot) to (numz[ip],ns, poolnproc)
Expand All @@ -74,6 +117,7 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
const int* numz_gps = this->numz;
const int* startg_gps = this->startg;
const int* startz_gps = this->startz;
ModuleBase::timer::start(this->classname, "gatherp_copy_unpack");
#ifdef _OPENMP
#pragma omp parallel for collapse(2)
#endif
Expand All @@ -84,11 +128,12 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
int nzip = numz_gps[ip];
std::complex<T> *outp0 = &out[startz_gps[ip]];
std::complex<T> *inp0 = &in[startg_gps[ip]];
std::complex<T> *outp = &outp0[is * nz_gps];
std::complex<T> *inp = &inp0[is * nzip ];
std::memcpy(outp, inp, nzip * sizeof(std::complex<T>));
std::complex<T>* outp = &outp0[is * nz_gps];
const std::complex<T>* inp = &inp0[is * nzip ];
detail::copy_complex_buffer(inp, outp, nzip);
}
}
ModuleBase::timer::end(this->classname, "gatherp_copy_unpack");
#endif
return;
}
Expand All @@ -109,24 +154,28 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
const int nst_ = this->nst;
const int nz_ = this->nz;
const int* istot2ixy_ = this->istot2ixy;
ModuleBase::timer::start(this->classname, "gathers_zero_serial");
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for(int i = 0; i < nrxx_; ++i)
{
out[i] = std::complex<T>(0, 0);
}
ModuleBase::timer::end(this->classname, "gathers_zero_serial");

ModuleBase::timer::start(this->classname, "gathers_copy_serial");
#ifdef _OPENMP
#pragma omp parallel for
#endif
for(int is = 0 ; is < nst_ ; ++is)
{
int ixy = istot2ixy_[is];
std::complex<T> *outp = &out[ixy*nz_];
std::complex<T> *inp = &in[is*nz_];
std::memcpy(outp, inp, nz_ * sizeof(std::complex<T>));
std::complex<T>* outp = &out[ixy*nz_];
const std::complex<T>* inp = &in[is*nz_];
detail::copy_complex_buffer(inp, outp, nz_);
}
ModuleBase::timer::end(this->classname, "gathers_copy_serial");
return;
}

Expand All @@ -140,6 +189,7 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
const int* numz_ = this->numz;
const int* startg_ = this->startg;
const int* startz_ = this->startz;
ModuleBase::timer::start(this->classname, "gathers_copy_pack");
#ifdef _OPENMP
#pragma omp parallel for collapse(2)
#endif
Expand All @@ -150,11 +200,12 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
int nzip = numz_[ip];
std::complex<T> *outp0 = &out[startg_[ip]];
std::complex<T> *inp0 = &in[startz_[ip]];
std::complex<T> *outp = &outp0[is * nzip];
std::complex<T> *inp = &inp0[is * nz_ ];
std::memcpy(outp, inp, nzip * sizeof(std::complex<T>));
std::complex<T>* outp = &outp0[is * nzip];
const std::complex<T>* inp = &inp0[is * nz_ ];
detail::copy_complex_buffer(inp, outp, nzip);
}
}
ModuleBase::timer::end(this->classname, "gathers_copy_pack");

//exchange data
//(numz[ip],ns, poolnproc) to (nplane,nstot)
Expand All @@ -172,28 +223,32 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
}

const int nrxx_gsp = this->nrxx;
ModuleBase::timer::start(this->classname, "gathers_zero_mpi");
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for(int i = 0; i < nrxx_gsp; ++i)
{
out[i] = std::complex<T>(0, 0);
}
ModuleBase::timer::end(this->classname, "gathers_zero_mpi");
//change (nplane,nstot) to (nplane fftnxy)
const int nstot = this->nstot;
const int nplane = this->nplane;
const int* istot2ixy = this->istot2ixy;
ModuleBase::timer::start(this->classname, "gathers_copy_unpack");
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int istot = 0;istot < nstot; ++istot)
{
int ixy = istot2ixy[istot];
//int ixy = (ixy / fftny)*ny + ixy % fftny;
std::complex<T> *outp = &out[ixy * nplane];
std::complex<T> *inp = &in[istot * nplane];
std::memcpy(outp, inp, nplane * sizeof(std::complex<T>));
std::complex<T>* outp = &out[ixy * nplane];
const std::complex<T>* inp = &in[istot * nplane];
detail::copy_complex_buffer(inp, outp, nplane);
}
ModuleBase::timer::end(this->classname, "gathers_copy_unpack");
#endif
return;
}
Expand Down
16 changes: 2 additions & 14 deletions source/source_basis/module_pw/pw_transform_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,7 @@ void PW_Basis_K::real2recip(const std::complex<FPTYPE>* in,

assert(this->gamma_only == false);
auto* auxr = this->fft_bundle.get_auxr_data<FPTYPE>();
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < this->nrxx; ++ir)
{
auxr[ir] = in[ir];
}
detail::copy_complex_buffer_parallel(in, auxr, this->nrxx);
this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());

this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>());
Expand Down Expand Up @@ -200,13 +194,7 @@ void PW_Basis_K::recip2real(const std::complex<FPTYPE>* in,
}
else
{
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < this->nrxx; ++ir)
{
out[ir] = auxr[ir];
}
detail::copy_complex_buffer_parallel(auxr, out, this->nrxx);
}
ModuleBase::timer::end(this->classname, "recip2real");
}
Expand Down
99 changes: 99 additions & 0 deletions source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#include "source_base/global_function.h"
#include "source_base/constants.h"
#include "source_base/matrix3.h"
#include <chrono>
#include <cstdlib>
#include <vector>

/************************************************
* serial unit test of functions in pw_basis.cpp
Expand All @@ -27,6 +30,7 @@
#define private public
#include "../pw_basis_k.h"
#include "../pw_basis.h"
#include "../pw_gatherscatter.h"
#undef private
#undef protected

Expand Down Expand Up @@ -188,4 +192,99 @@ TEST_F(PWBasisKTEST, CollectLocalPW)
EXPECT_EQ(basis_k.npwk_max,2721);
}

// Round-trip check of the complex (gamma_only == false) transforms:
// recip2real followed by real2recip must return the reciprocal-space input
// to within 1e-10 in both real and imaginary parts.
TEST_F(PWBasisKTEST, ComplexTransformRoundTrip)
{
    ModulePW::PW_Basis_K basis_k(device_flag, precision_double);
    double lat0 = 2.0;
    ModuleBase::Matrix3 latvec(1.0,0.0,1.0,
                               0.0,2.0,0.0,
                               0.0,0.0,2.0);
    double gridecut = 30.0;
    const bool gamma_only_in = false;
    const double gk_ecut_in = 20.0;
    // One k-point with all components zero.
    const int nks_in = 1;
    const ModuleBase::Vector3<double> kvec_d_in[1] = { {0.0, 0.0, 0.0} };
    const int distribution_type_in = 2;
    const bool xprime_in = false;

    basis_k.initgrids(lat0, latvec, gridecut);
    basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in, kvec_d_in, distribution_type_in, xprime_in);
    ASSERT_NO_THROW(basis_k.setuptransform());
    // Stop here if setup produced no plane waves for k-point 0; the buffers
    // below are sized from npwk[0].
    ASSERT_NE(basis_k.npwk, nullptr);
    ASSERT_GT(basis_k.npwk[0], 0);
    const int npw = basis_k.npwk[0];

    // Use reciprocal-space input because arbitrary real-space data is projected
    // by the plane-wave cutoff and is not exactly recoverable.
    std::vector<std::complex<double>> recip_in(npw);
    std::vector<std::complex<double>> real_space(basis_k.nrxx);
    std::vector<std::complex<double>> recip_out(npw);

    // Deterministic, non-constant fill: the moduli 17 and 19 give the real and
    // imaginary parts different repeat lengths.
    for (int ig = 0; ig < npw; ++ig)
    {
        const double real_part = (ig % 17 - 8) / 11.0;
        const double imag_part = (ig % 19 - 9) / 13.0;
        recip_in[ig] = std::complex<double>(real_part, imag_part);
    }

    basis_k.recip2real(recip_in.data(), real_space.data(), 0);
    basis_k.real2recip(real_space.data(), recip_out.data(), 0);

    for (int ig = 0; ig < npw; ++ig)
    {
        // Report the failing index so a mismatch can be located.
        EXPECT_NEAR(recip_in[ig].real(), recip_out[ig].real(), 1e-10) << "real part, ig = " << ig;
        EXPECT_NEAR(recip_in[ig].imag(), recip_out[ig].imag(), 1e-10) << "imaginary part, ig = " << ig;
    }
}

// Opt-in micro-benchmark: times detail::copy_complex_buffer against a plain
// element-by-element loop over the same source data, prints the elapsed time
// and throughput of each, and finally checks that both destinations match.
TEST_F(PWBasisKTEST, CopyComplexBufferTimerBenchmark)
{
    // Skipped unless the environment variable is defined (its value is not inspected).
    if (std::getenv("ABACUS_PW_SIMD_TIMER_TEST") == nullptr)
    {
        GTEST_SKIP() << "Set ABACUS_PW_SIMD_TIMER_TEST=1 to run the copy timer benchmark.";
    }

    using steady = std::chrono::steady_clock;
    using seconds = std::chrono::duration<double>;
    using cplx = std::complex<double>;

    const int n_elem = 1 << 20;
    const int n_pass = 64;
    std::vector<cplx> source(n_elem);
    std::vector<cplx> helper_out(n_elem);
    std::vector<cplx> loop_out(n_elem);

    for (int idx = 0; idx < n_elem; ++idx)
    {
        source[idx] = cplx((idx % 97) / 17.0, (idx % 89) / 19.0);
    }

    // Reading one copied element per pass into a volatile keeps the copies observable.
    volatile double sink = 0.0;

    // Pass 1: the helper under test.
    const auto helper_begin = steady::now();
    for (int pass = 0; pass < n_pass; ++pass)
    {
        ModulePW::detail::copy_complex_buffer(source.data(), helper_out.data(), n_elem);
        sink = sink + helper_out[pass].real();
    }
    const seconds helper_elapsed = steady::now() - helper_begin;

    // Pass 2: reference element-wise assignment loop.
    const auto loop_begin = steady::now();
    for (int pass = 0; pass < n_pass; ++pass)
    {
        for (int idx = 0; idx < n_elem; ++idx)
        {
            loop_out[idx] = source[idx];
        }
        sink = sink + loop_out[pass].imag();
    }
    const seconds loop_elapsed = steady::now() - loop_begin;

    const double helper_s = helper_elapsed.count();
    const double loop_s = loop_elapsed.count();
    // Total payload written by each variant, expressed in GiB.
    const double total_bytes = static_cast<double>(n_elem) * sizeof(cplx) * n_pass;
    const double total_gib = total_bytes / (1024.0 * 1024.0 * 1024.0);

    std::cout << "PW_SIMD_TEST copy_n_helper " << helper_s << " s, "
              << total_gib / helper_s << " GiB/s\n"
              << "PW_SIMD_TEST scalar_loop " << loop_s << " s, "
              << total_gib / loop_s << " GiB/s\n"
              << "PW_SIMD_TEST speedup copy_n/scalar " << loop_s / helper_s
              << ", checksum " << sink << "\n";

    ASSERT_EQ(helper_out, loop_out);
}
Loading
Loading