diff --git a/source/source_basis/module_pw/CMakeLists.txt b/source/source_basis/module_pw/CMakeLists.txt index 912772e0573..bdab5e9f75b 100644 --- a/source/source_basis/module_pw/CMakeLists.txt +++ b/source/source_basis/module_pw/CMakeLists.txt @@ -43,6 +43,25 @@ add_library( ${objects} ) +add_executable( + MODULE_PW_cache_bench + test_serial/pw_cache_bench.cpp +) + +target_link_libraries( + MODULE_PW_cache_bench + parameter + ${math_libs} + planewave + device + base + Threads::Threads +) + +if(USE_OPENMP) + target_link_libraries(MODULE_PW_cache_bench OpenMP::OpenMP_CXX) +endif() + if (USE_DSP) target_link_libraries(planewave PRIVATE ${MTBLAS_FFT_DIR}/libmtblas/lib/libmtfft.a) diff --git a/source/source_basis/module_pw/pw_basis.cpp b/source/source_basis/module_pw/pw_basis.cpp index 549fec8e5a4..77a3626994c 100644 --- a/source/source_basis/module_pw/pw_basis.cpp +++ b/source/source_basis/module_pw/pw_basis.cpp @@ -5,6 +5,7 @@ #include "source_base/timer.h" #include "source_base/global_function.h" +#include namespace ModulePW { @@ -28,17 +29,13 @@ PW_Basis:: ~PW_Basis() delete[] fftixy2ip; delete[] nst_per; delete[] npw_per; - delete[] gdirect; - delete[] gcar; - delete[] gg; delete[] startz; delete[] numz; delete[] numg; delete[] numr; delete[] startg; delete[] startr; - delete[] ig2igg; - delete[] gg_uniq; + this->clear_owned_cache(); #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { @@ -48,6 +45,91 @@ PW_Basis:: ~PW_Basis() #endif } +void PW_Basis::clear_owned_cache() +{ + std::lock_guard guard(this->cache_mutex); + this->invalidate_cache_unlocked(); +} + +PW_Basis::CacheStats PW_Basis::get_cache_stats() const +{ + std::lock_guard guard(this->cache_mutex); + return this->get_cache_stats_unlocked(); +} + +PW_Basis::CacheStats PW_Basis::get_cache_stats_unlocked() const +{ + CacheStats stats; + stats.local_pw_hits = this->local_pw_cache_hits.load(); + stats.local_pw_misses = this->local_pw_cache_misses.load(); + stats.uniqgg_hits = this->uniqgg_cache_hits.load(); + stats.uniqgg_misses = this->uniqgg_cache_misses.load(); + const bool has_local_pw_cache = this->local_pw_cache_valid.load() + && this->npw > 0 + && this->gg != nullptr + && this->gdirect != nullptr + && this->gcar != nullptr; + const bool has_uniqgg_cache = this->uniqgg_cache_valid.load() + && this->ngg > 0 + && this->ig2igg != nullptr + && this->gg_uniq != nullptr; + if (has_local_pw_cache) + { + stats.cache_bytes += sizeof(double) * this->npw; + stats.cache_bytes += sizeof(ModuleBase::Vector3) * this->npw * 2; + } + if (has_uniqgg_cache) + { + stats.cache_bytes += sizeof(int) * this->npw; + stats.cache_bytes += sizeof(double) * this->ngg; + } + return stats; +} + +void PW_Basis::reset_cache_stats() +{ + this->local_pw_cache_hits.store(0); + this->local_pw_cache_misses.store(0); + this->uniqgg_cache_hits.store(0); + this->uniqgg_cache_misses.store(0); +} + +PW_Basis::CacheSignature PW_Basis::make_cache_signature() const +{ + CacheSignature signature; + signature.lat0 = this->lat0; + signature.tpiba = this->tpiba; + signature.tpiba2 = this->tpiba2; + signature.nx = this->nx; + signature.ny = this->ny; + signature.nz = this->nz; + signature.fftnx = this->fftnx; + signature.fftny = this->fftny; + signature.fftnz = this->fftnz; + signature.npw = this->npw; + signature.G = this->G; + signature.GT = this->GT; + signature.GGT = this->GGT; + return signature; +} + +bool PW_Basis::cache_signature_matches(const CacheSignature& signature) const +{ + return signature.lat0 == this->lat0 + && signature.tpiba == this->tpiba + && signature.tpiba2 == this->tpiba2 + && signature.nx == this->nx + && signature.ny == this->ny + && signature.nz == this->nz + && signature.fftnx == this->fftnx + && signature.fftny == this->fftny + && signature.fftnz == this->fftnz + && signature.npw == this->npw + && std::memcmp(&signature.G, &this->G, sizeof(ModuleBase::Matrix3)) == 0 + && std::memcmp(&signature.GT, &this->GT, sizeof(ModuleBase::Matrix3)) == 0 + && std::memcmp(&signature.GGT, &this->GGT, sizeof(ModuleBase::Matrix3)) == 0; +} + /// /// distribute plane wave basis and real-space grids to different processors /// set up maps for fft and create arrays for MPI_Alltoall @@ -138,10 +220,33 @@ void PW_Basis::collect_local_pw() { return; } + ModuleBase::timer::start(this->classname, "collect_local_pw"); + std::lock_guard guard(this->cache_mutex); + if (this->local_pw_cache_valid.load() + && this->cache_signature_matches(this->local_pw_cache_signature)) + { + ModuleBase::timer::start(this->classname, "collect_local_pw_cache_hit"); + this->local_pw_cache_hits.fetch_add(1); + ModuleBase::timer::end(this->classname, "collect_local_pw_cache_hit"); + ModuleBase::timer::end(this->classname, "collect_local_pw"); + return; + } + ModuleBase::timer::start(this->classname, "collect_local_pw_cache_build"); + this->local_pw_cache_misses.fetch_add(1); this->ig_gge0 = -1; - delete[] this->gg; this->gg = new double[this->npw]; - delete[] this->gdirect; this->gdirect = new ModuleBase::Vector3[this->npw]; - delete[] this->gcar; this->gcar = new ModuleBase::Vector3[this->npw]; + this->gg_cache_storage.reset(new double[this->npw]); + this->gdirect_cache_storage.reset(new ModuleBase::Vector3[this->npw]); + this->gcar_cache_storage.reset(new ModuleBase::Vector3[this->npw]); + this->gg = this->gg_cache_storage.get(); + this->gdirect = this->gdirect_cache_storage.get(); + this->gcar = this->gcar_cache_storage.get(); + // Unique-G data depends on gg, so rebuilding local G data invalidates it. + this->uniqgg_cache_valid.store(false); + this->ig2igg_cache_storage.reset(); + this->gg_uniq_cache_storage.reset(); + this->ig2igg = nullptr; + this->gg_uniq = nullptr; + this->ngg = 0; ModuleBase::Vector3 f; int gamma_num = 0; @@ -182,6 +287,10 @@ void PW_Basis::collect_local_pw() } } } + this->local_pw_cache_valid.store(true); + this->local_pw_cache_signature = this->make_cache_signature(); + ModuleBase::timer::end(this->classname, "collect_local_pw_cache_build"); + ModuleBase::timer::end(this->classname, "collect_local_pw"); return; } @@ -196,45 +305,74 @@ void PW_Basis::collect_uniqgg() { return; } + ModuleBase::timer::start(this->classname, "collect_uniqgg"); + std::lock_guard guard(this->cache_mutex); + if (this->uniqgg_cache_valid.load() + && this->cache_signature_matches(this->uniqgg_cache_signature)) + { + ModuleBase::timer::start(this->classname, "collect_uniqgg_cache_hit"); + this->uniqgg_cache_hits.fetch_add(1); + ModuleBase::timer::end(this->classname, "collect_uniqgg_cache_hit"); + ModuleBase::timer::end(this->classname, "collect_uniqgg"); + return; + } + ModuleBase::timer::start(this->classname, "collect_uniqgg_cache_build"); + this->uniqgg_cache_misses.fetch_add(1); this->ig_gge0 = -1; - delete[] this->ig2igg; this->ig2igg = new int [this->npw]; + this->ig2igg_cache_storage.reset(new int[this->npw]); + this->ig2igg = this->ig2igg_cache_storage.get(); - int *sortindex = new int [this->npw];//Reconstruct the mapping of the plane wave index ig according to the energy size of the plane waves - double *tmpgg = new double [this->npw];//Ranking the plane waves by energy size while ensuring that the same energy is preserved for each wave to correspond - double *tmpgg2 = new double [this->npw];//ranking the plane waves by energy size and removing the duplicates - ModuleBase::Vector3 f; - for(int ig = 0 ; ig < this-> npw ; ++ig) + std::vector sortindex(this->npw); // Reconstruct the plane-wave index mapping after sorting by energy. + std::vector tmpgg(this->npw); + std::vector tmpgg2(this->npw); + // Reuse gg when collect_local_pw has already built the same G^2 values. + if (this->local_pw_cache_valid.load() && this->gg != nullptr) { - int isz = this->ig2isz[ig]; - int iz = isz % this->nz; - int is = isz / this->nz; - int ixy = this->is2fftixy[is]; - int ix = ixy / this->fftny; - int iy = ixy % this->fftny; - if (ix >= int(this->nx/2) + 1) - { - ix -= this->nx; - } - if (iy >= int(this->ny/2) + 1) - { - iy -= this->ny; - } - if (iz >= int(this->nz/2) + 1) + for(int ig = 0 ; ig < this-> npw ; ++ig) { - iz -= this->nz; + tmpgg[ig] = this->gg[ig]; + if(tmpgg[ig] < 1e-8) + { + this->ig_gge0 = ig; + } } - f.x = ix; - f.y = iy; - f.z = iz; - tmpgg[ig] = f * (this->GGT * f); - if(tmpgg[ig] < 1e-8) + } + else + { + ModuleBase::Vector3 f; + for(int ig = 0 ; ig < this-> npw ; ++ig) { - this->ig_gge0 = ig; + int isz = this->ig2isz[ig]; + int iz = isz % this->nz; + int is = isz / this->nz; + int ixy = this->is2fftixy[is]; + int ix = ixy / this->fftny; + int iy = ixy % this->fftny; + if (ix >= int(this->nx/2) + 1) + { + ix -= this->nx; + } + if (iy >= int(this->ny/2) + 1) + { + iy -= this->ny; + } + if (iz >= int(this->nz/2) + 1) + { + iz -= this->nz; + } + f.x = ix; + f.y = iy; + f.z = iz; + tmpgg[ig] = f * (this->GGT * f); + if(tmpgg[ig] < 1e-8) + { + this->ig_gge0 = ig; + } } } - ModuleBase::GlobalFunc::ZEROS(sortindex, this->npw); - ModuleBase::heapsort(this->npw, tmpgg, sortindex); + ModuleBase::GlobalFunc::ZEROS(sortindex.data(), this->npw); + ModuleBase::heapsort(this->npw, tmpgg.data(), sortindex.data()); int igg = 0; @@ -261,14 +399,16 @@ void PW_Basis::collect_uniqgg() } tmpgg2[igg] = avg_gg / double(avg_n); this->ngg = igg + 1; - delete[] this->gg_uniq; this->gg_uniq = new double [this->ngg]; + this->gg_uniq_cache_storage.reset(new double[this->ngg]); + this->gg_uniq = this->gg_uniq_cache_storage.get(); for(int igg = 0 ; igg < this->ngg ; ++igg) { gg_uniq[igg] = tmpgg2[igg]; } - delete[] sortindex; - delete[] tmpgg; - delete[] tmpgg2; + this->uniqgg_cache_valid.store(true); + this->uniqgg_cache_signature = this->make_cache_signature(); + ModuleBase::timer::end(this->classname, "collect_uniqgg_cache_build"); + ModuleBase::timer::end(this->classname, "collect_uniqgg"); } void PW_Basis::getfftixy2is(int * fftixy2is) const @@ -295,10 +435,12 @@ void PW_Basis::getfftixy2is(int * fftixy2is) const void PW_Basis::set_device(std::string device_) { this->device = std::move(device_); + this->invalidate_cache(); } void PW_Basis::set_precision(std::string precision_) { this->precision = std::move(precision_); + this->invalidate_cache(); } } diff --git a/source/source_basis/module_pw/pw_basis.h b/source/source_basis/module_pw/pw_basis.h index b834cb0e0f4..813e93342e9 100644 --- a/source/source_basis/module_pw/pw_basis.h +++ b/source/source_basis/module_pw/pw_basis.h @@ -9,9 +9,14 @@ #include #include "source_base/module_fft/fft_bundle.h" #include +#include #ifdef __MPI #include "mpi.h" #endif +#include +#include +#include +#include namespace ModulePW { @@ -56,8 +61,21 @@ class PW_Basis { public: + struct CacheStats + { + std::uint64_t local_pw_hits = 0; + std::uint64_t local_pw_misses = 0; + std::uint64_t uniqgg_hits = 0; + std::uint64_t uniqgg_misses = 0; + std::size_t cache_bytes = 0; + }; + std::string classname; PW_Basis(); + // PW_Basis owns FFT/distribution maps through raw pointers, so copying would + // create ambiguous ownership and stale cache pointers. + PW_Basis(const PW_Basis& other) = delete; + PW_Basis& operator=(const PW_Basis& other) = delete; PW_Basis(std::string device_, std::string precision_); virtual ~PW_Basis(); //Init mpi parameters @@ -137,9 +155,74 @@ class PW_Basis //distribute plane waves and grids and set up fft void setuptransform(); + CacheStats get_cache_stats() const; + void reset_cache_stats(); + protected: int *startnsz_per=nullptr;//useless intermediate variable// startnsz_per[ip]: starting is * nz stick in the ip^th proc. + virtual void invalidate_cache() + { + std::lock_guard guard(this->cache_mutex); + this->invalidate_cache_unlocked(); + } + + void clear_owned_cache(); + + // Public gg/gcar/gdirect pointers are non-owning views of these cache buffers. + std::atomic local_pw_cache_valid{false}; + std::atomic uniqgg_cache_valid{false}; + mutable std::mutex cache_mutex; + std::unique_ptr gg_cache_storage; + std::unique_ptr[]> gdirect_cache_storage; + std::unique_ptr[]> gcar_cache_storage; + std::unique_ptr ig2igg_cache_storage; + std::unique_ptr gg_uniq_cache_storage; + std::atomic local_pw_cache_hits{0}; + std::atomic local_pw_cache_misses{0}; + std::atomic uniqgg_cache_hits{0}; + std::atomic uniqgg_cache_misses{0}; + + struct CacheSignature + { + double lat0 = 0.0; + double tpiba = 0.0; + double tpiba2 = 0.0; + int nx = 0; + int ny = 0; + int nz = 0; + int fftnx = 0; + int fftny = 0; + int fftnz = 0; + int npw = 0; + ModuleBase::Matrix3 G; + ModuleBase::Matrix3 GT; + ModuleBase::Matrix3 GGT; + }; + CacheSignature make_cache_signature() const; + bool cache_signature_matches(const CacheSignature& signature) const; + CacheSignature local_pw_cache_signature; + CacheSignature uniqgg_cache_signature; + + virtual void invalidate_cache_unlocked() + { + this->local_pw_cache_valid.store(false); + this->uniqgg_cache_valid.store(false); + this->gg_cache_storage.reset(); + this->gdirect_cache_storage.reset(); + this->gcar_cache_storage.reset(); + this->ig2igg_cache_storage.reset(); + this->gg_uniq_cache_storage.reset(); + this->gg = nullptr; + this->gdirect = nullptr; + this->gcar = nullptr; + this->ig2igg = nullptr; + this->gg_uniq = nullptr; + this->ngg = 0; + this->ig_gge0 = -1; + } + CacheStats get_cache_stats_unlocked() const; + //distribute plane waves to different processors void distribute_g(); diff --git a/source/source_basis/module_pw/pw_basis_k.cpp b/source/source_basis/module_pw/pw_basis_k.cpp index 2c2d02bf927..6418258d3bc 100644 --- a/source/source_basis/module_pw/pw_basis_k.cpp +++ b/source/source_basis/module_pw/pw_basis_k.cpp @@ -6,6 +6,7 @@ #include "source_base/timer.h" #include +#include namespace ModulePW { @@ -21,7 +22,6 @@ PW_Basis_K::~PW_Basis_K() delete[] npwk; delete[] igl2isz_k; delete[] igl2ig_k; - delete[] gk2; #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { @@ -44,6 +44,49 @@ PW_Basis_K::~PW_Basis_K() #if defined(__CUDA) || defined(__ROCM) } #endif + this->clear_k_cache_storage(); +} + +void PW_Basis_K::clear_k_cache_storage() +{ + std::lock_guard guard(this->cache_mutex); + this->invalidate_cache_unlocked(); +} + +PW_Basis_K::KCacheStats PW_Basis_K::get_k_cache_stats() const +{ + std::lock_guard guard(this->cache_mutex); + KCacheStats stats; + const auto base_stats = PW_Basis::get_cache_stats_unlocked(); + static_cast(stats) = base_stats; + stats.gcar_hits = this->gcar_cache_hits.load(); + stats.gcar_misses = this->gcar_cache_misses.load(); + stats.gk2_hits = this->gk2_cache_hits.load(); + stats.gk2_misses = this->gk2_cache_misses.load(); + if (this->gcar_cache_valid.load() + && this->gcar != nullptr + && this->npwk_max > 0 + && this->nks > 0) + { + stats.cache_bytes += sizeof(ModuleBase::Vector3) * this->npwk_max * this->nks; + } + if (this->gk_cache_valid.load() + && this->gk2 != nullptr + && this->npwk_max > 0 + && this->nks > 0) + { + stats.cache_bytes += sizeof(double) * this->npwk_max * this->nks; + } + return stats; +} + +void PW_Basis_K::reset_k_cache_stats() +{ + PW_Basis::reset_cache_stats(); + this->gcar_cache_hits.store(0); + this->gcar_cache_misses.store(0); + this->gk2_cache_hits.store(0); + this->gk2_cache_misses.store(0); } void PW_Basis_K::initparameters(const bool gamma_only_in, @@ -101,6 +144,7 @@ void PW_Basis_K::initparameters(const bool gamma_only_in, this->fftnxy = this->fftnx * this->fftny; this->fftnxyz = this->fftnxy * this->fftnz; this->distribution_type = distribution_type_in; + this->invalidate_cache(); #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { @@ -129,6 +173,7 @@ void PW_Basis_K::initparameters(const bool gamma_only_in, void PW_Basis_K::setupIndGk() { + this->invalidate_cache(); // count npwk this->npwk_max = 0; delete[] this->npwk; @@ -198,6 +243,34 @@ void PW_Basis_K::setupIndGk() return; } +ModuleBase::Vector3 PW_Basis_K::cal_GplusK_cartesian(const int ik, const int ig) const +{ + int isz = this->ig2isz[ig]; + int iz = isz % this->nz; + int is = isz / this->nz; + int ix = this->is2fftixy[is] / this->fftny; + int iy = this->is2fftixy[is] % this->fftny; + if (ix >= int(this->nx / 2) + 1) + { + ix -= this->nx; + } + if (iy >= int(this->ny / 2) + 1) + { + iy -= this->ny; + } + if (iz >= int(this->nz / 2) + 1) + { + iz -= this->nz; + } + ModuleBase::Vector3 f; + f.x = ix; + f.y = iy; + f.z = iz; + f = f * this->G; + ModuleBase::Vector3 g_temp_ = this->kvec_c[ik] + f; + return g_temp_; +} + /// /// distribute plane wave basis and real-space grids to different processors /// set up maps for fft and create arrays for MPI_Alltoall @@ -249,19 +322,60 @@ void PW_Basis_K::setuptransform() void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_height_in, const double& erf_sigma_in) { - this->erf_ecut = erf_ecut_in; - this->erf_height = erf_height_in; - this->erf_sigma = erf_sigma_in; if (this->npwk_max <= 0) { return; } - delete[] gk2; - delete[] gcar; - this->gk2 = new double[this->npwk_max * this->nks]; - this->gcar = new ModuleBase::Vector3[this->npwk_max * this->nks]; - ModuleBase::Memory::record("PW_B_K::gk2", sizeof(double) * this->npwk_max * this->nks); - ModuleBase::Memory::record("PW_B_K::gcar", sizeof(ModuleBase::Vector3) * this->npwk_max * this->nks); + ModuleBase::timer::start(this->classname, "collect_local_pw"); + std::lock_guard guard(this->cache_mutex); + const bool locked_gcar_hit = this->gcar_cache_valid.load() && this->gcar != nullptr; + const bool locked_gk2_hit = this->gk_cache_valid.load() + && this->gk2 != nullptr + && this->erf_ecut == erf_ecut_in + && this->erf_height == erf_height_in + && this->erf_sigma == erf_sigma_in; + if (locked_gcar_hit && locked_gk2_hit) + { + ModuleBase::timer::start(this->classname, "collect_local_pw_cache_hit"); + this->gcar_cache_hits.fetch_add(1); + this->gk2_cache_hits.fetch_add(1); + ModuleBase::timer::end(this->classname, "collect_local_pw_cache_hit"); + ModuleBase::timer::end(this->classname, "collect_local_pw"); + return; + } + if (!locked_gcar_hit) + { + ModuleBase::timer::start(this->classname, "collect_local_pw_build_gcar"); + } + if (!locked_gk2_hit) + { + ModuleBase::timer::start(this->classname, "collect_local_pw_build_gk2"); + } + if (locked_gcar_hit) + { + this->gcar_cache_hits.fetch_add(1); + } + else + { + this->gcar_cache_misses.fetch_add(1); + this->k_gcar_cache_storage.reset(new ModuleBase::Vector3[this->npwk_max * this->nks]); + this->gcar = this->k_gcar_cache_storage.get(); + ModuleBase::Memory::record("PW_B_K::gcar", sizeof(ModuleBase::Vector3) * this->npwk_max * this->nks); + } + if (locked_gk2_hit) + { + this->gk2_cache_hits.fetch_add(1); + } + else + { + this->gk2_cache_misses.fetch_add(1); + this->k_gk2_cache_storage.reset(new double[this->npwk_max * this->nks]); + this->gk2 = this->k_gk2_cache_storage.get(); + ModuleBase::Memory::record("PW_B_K::gk2", sizeof(double) * this->npwk_max * this->nks); + } + this->erf_ecut = erf_ecut_in; + this->erf_height = erf_height_in; + this->erf_sigma = erf_sigma_in; ModuleBase::Vector3 f; for (int ik = 0; ik < this->nks; ++ik) @@ -291,36 +405,55 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h f.y = iy; f.z = iz; - this->gcar[ik * npwk_max + igl] = f * this->G; - double temp_gk2 = (f + kv) * (this->GGT * (f + kv)); - if (erf_height > 0) + if (!locked_gcar_hit) { - this->gk2[ik * npwk_max + igl] - = temp_gk2 + erf_height / tpiba2 * (1.0 + std::erf((temp_gk2 * tpiba2 - erf_ecut) / erf_sigma)); + this->gcar[ik * npwk_max + igl] = f * this->G; } - else + if (!locked_gk2_hit) { - this->gk2[ik * npwk_max + igl] = temp_gk2; + const double temp_gk2 = (f + kv) * (this->GGT * (f + kv)); + if (erf_height > 0) + { + this->gk2[ik * npwk_max + igl] + = temp_gk2 + erf_height / tpiba2 * (1.0 + std::erf((temp_gk2 * tpiba2 - erf_ecut) / erf_sigma)); + } + else + { + this->gk2[ik * npwk_max + igl] = temp_gk2; + } } } } + if (!locked_gcar_hit) + { + this->sync_gcar_device_cache(); + this->gcar_cache_valid.store(true); + ModuleBase::timer::end(this->classname, "collect_local_pw_build_gcar"); + } + if (!locked_gk2_hit) + { + this->sync_gk2_device_cache(); + this->gk_cache_valid.store(true); + ModuleBase::timer::end(this->classname, "collect_local_pw_build_gk2"); + } + ModuleBase::timer::end(this->classname, "collect_local_pw"); +} + +void PW_Basis_K::sync_gcar_device_cache() +{ #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { if (this->float_data_) { - resmem_sd_op()(this->s_gk2, this->npwk_max * this->nks); resmem_sd_op()(this->s_gcar, this->npwk_max * this->nks * 3); - castmem_d2s_h2d_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); castmem_d2s_h2d_op()(this->s_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); } if (this->double_data_) { - resmem_dd_op()(this->d_gk2, this->npwk_max * this->nks); resmem_dd_op()(this->d_gcar, this->npwk_max * this->nks * 3); - syncmem_d2d_h2d_op()(this->d_gk2, this->gk2, this->npwk_max * this->nks); syncmem_d2d_h2d_op()(this->d_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); @@ -331,9 +464,7 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h #endif if (this->float_data_) { - resmem_sh_op()(this->s_gk2, this->npwk_max * this->nks, "PW_B_K::s_gk2"); resmem_sh_op()(this->s_gcar, this->npwk_max * this->nks * 3, "PW_B_K::s_gcar"); - castmem_d2s_h2h_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); castmem_d2s_h2h_op()(this->s_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); @@ -341,7 +472,6 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h if (this->double_data_) { this->d_gcar = reinterpret_cast(&this->gcar[0][0]); - this->d_gk2 = this->gk2; } // There's no need to allocate double pointers while in a CPU environment. #if defined(__CUDA) || defined(__ROCM) @@ -349,32 +479,37 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h #endif } -ModuleBase::Vector3 PW_Basis_K::cal_GplusK_cartesian(const int ik, const int ig) const +void PW_Basis_K::sync_gk2_device_cache() { - int isz = this->ig2isz[ig]; - int iz = isz % this->nz; - int is = isz / this->nz; - int ix = this->is2fftixy[is] / this->fftny; - int iy = this->is2fftixy[is] % this->fftny; - if (ix >= int(this->nx / 2) + 1) - { - ix -= this->nx; - } - if (iy >= int(this->ny / 2) + 1) +#if defined(__CUDA) || defined(__ROCM) + if (this->device == "gpu") { - iy -= this->ny; + if (this->float_data_) + { + resmem_sd_op()(this->s_gk2, this->npwk_max * this->nks); + castmem_d2s_h2d_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); + } + if (this->double_data_) + { + resmem_dd_op()(this->d_gk2, this->npwk_max * this->nks); + syncmem_d2d_h2d_op()(this->d_gk2, this->gk2, this->npwk_max * this->nks); + } } - if (iz >= int(this->nz / 2) + 1) + else { - iz -= this->nz; +#endif + if (this->float_data_) + { + resmem_sh_op()(this->s_gk2, this->npwk_max * this->nks, "PW_B_K::s_gk2"); + castmem_d2s_h2h_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); + } + if (this->double_data_) + { + this->d_gk2 = this->gk2; + } +#if defined(__CUDA) || defined(__ROCM) } - ModuleBase::Vector3 f; - f.x = ix; - f.y = iy; - f.z = iz; - f = f * this->G; - ModuleBase::Vector3 g_temp_ = this->kvec_c[ik] + f; - return g_temp_; +#endif } double& PW_Basis_K::getgk2(const int ik, const int igl) const @@ -510,23 +645,27 @@ double* PW_Basis_K::get_kvec_c_data() const template <> float* PW_Basis_K::get_gcar_data() const { - return this->s_gcar; + std::lock_guard guard(this->cache_mutex); + return this->gcar_cache_valid.load() ? this->s_gcar : nullptr; } template <> double* PW_Basis_K::get_gcar_data() const { - return this->d_gcar; + std::lock_guard guard(this->cache_mutex); + return this->gcar_cache_valid.load() ? this->d_gcar : nullptr; } template <> float* PW_Basis_K::get_gk2_data() const { - return this->s_gk2; + std::lock_guard guard(this->cache_mutex); + return this->gk_cache_valid.load() ? this->s_gk2 : nullptr; } template <> double* PW_Basis_K::get_gk2_data() const { - return this->d_gk2; + std::lock_guard guard(this->cache_mutex); + return this->gk_cache_valid.load() ? this->d_gk2 : nullptr; } -} // namespace ModulePW \ No newline at end of file +} // namespace ModulePW diff --git a/source/source_basis/module_pw/pw_basis_k.h b/source/source_basis/module_pw/pw_basis_k.h index f633a30769d..81dac5b2e53 100644 --- a/source/source_basis/module_pw/pw_basis_k.h +++ b/source/source_basis/module_pw/pw_basis_k.h @@ -56,6 +56,14 @@ class PW_Basis_K : public PW_Basis { public: + struct KCacheStats : public PW_Basis::CacheStats + { + std::uint64_t gcar_hits = 0; + std::uint64_t gcar_misses = 0; + std::uint64_t gk2_hits = 0; + std::uint64_t gk2_misses = 0; + }; + PW_Basis_K(); PW_Basis_K(std::string device_, std::string precision_) : PW_Basis(device_, precision_) {classname="PW_Basis_K";} ~PW_Basis_K(); @@ -99,16 +107,42 @@ class PW_Basis_K : public PW_Basis const double& erf_height_in = 0.0, const double& erf_sigma_in = 0.1); + KCacheStats get_k_cache_stats() const; + void reset_k_cache_stats(); + private: + void clear_k_cache_storage(); + void invalidate_cache_unlocked() override + { + PW_Basis::invalidate_cache_unlocked(); + this->gcar_cache_valid.store(false); + this->gk_cache_valid.store(false); + this->k_gcar_cache_storage.reset(); + this->k_gk2_cache_storage.reset(); + this->gcar = nullptr; + this->gk2 = nullptr; + this->d_gcar = nullptr; + this->d_gk2 = nullptr; + } + void sync_gcar_device_cache(); + void sync_gk2_device_cache(); + + std::atomic gcar_cache_valid{false}; + std::atomic gk_cache_valid{false}; + std::unique_ptr[]> k_gcar_cache_storage; + std::unique_ptr k_gk2_cache_storage; + std::atomic gcar_cache_hits{0}; + std::atomic gcar_cache_misses{0}; + std::atomic gk2_cache_hits{0}; + std::atomic gk2_cache_misses{0}; float * s_gk2 = nullptr; double * d_gk2 = nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks] //create igl2isz_k map array for fft void setupIndGk(); // get ig2ixyz_k void get_ig2ixyz_k(); - //calculate G+K, it is a private function + //calculate G+K in cartesian coordinates ModuleBase::Vector3 cal_GplusK_cartesian(const int ik, const int ig) const; - public: template void real2recip(const FPTYPE* in, @@ -280,4 +314,3 @@ class PW_Basis_K : public PW_Basis #endif //PlaneWave_K class #include "./pw_basis_k_big.h" //temporary it will be removed - diff --git a/source/source_basis/module_pw/pw_distributeg.cpp b/source/source_basis/module_pw/pw_distributeg.cpp index 9e47ecff7a4..795f271d45b 100644 --- a/source/source_basis/module_pw/pw_distributeg.cpp +++ b/source/source_basis/module_pw/pw_distributeg.cpp @@ -177,6 +177,7 @@ void PW_Basis::get_ig2isz_is2fftixy( { delete[] this->ig2isz; this->ig2isz = nullptr; // map ig to the z coordinate of this planewave. delete[] this->is2fftixy; this->is2fftixy = nullptr; // map is (index of sticks) to ixy (iy + ix * fftny). + this->invalidate_cache(); #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { delmem_int_op()(this->d_is2fftixy); @@ -242,6 +243,7 @@ void PW_Basis::get_ig2isz_is2fftixy( syncmem_int_h2d_op()(ig2ixyz_gpu, ig2ixyz.data(), this->npw); } #endif + this->invalidate_cache(); return; } } // namespace ModulePW \ No newline at end of file diff --git a/source/source_basis/module_pw/pw_gatherscatter.h b/source/source_basis/module_pw/pw_gatherscatter.h index 00ce160e254..6a13b9bd782 100644 --- a/source/source_basis/module_pw/pw_gatherscatter.h +++ b/source/source_basis/module_pw/pw_gatherscatter.h @@ -15,8 +15,8 @@ namespace ModulePW template void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const { - - if(this->poolnproc == 1) //In this case nst=nstot, nz = nplane, + + if(this->poolnproc == 1) //In this case nst=nstot, nz = nplane, { const int nst_ = this->nst; const int nz_ = this->nz; @@ -103,7 +103,10 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const template void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const { - if(this->poolnproc == 1) //In this case nrxx=fftnx*fftny*nz, nst = nstot, + + + + if(this->poolnproc == 1) //In this case nrxx=fftnx*fftny*nz, nst = nstot, { const int nrxx_ = this->nrxx; const int nst_ = this->nst; diff --git a/source/source_basis/module_pw/pw_init.cpp b/source/source_basis/module_pw/pw_init.cpp index 08c676d39f3..9a7d3b06043 100644 --- a/source/source_basis/module_pw/pw_init.cpp +++ b/source/source_basis/module_pw/pw_init.cpp @@ -13,6 +13,7 @@ void PW_Basis:: initmpi( this->poolnproc = poolnproc_in; this->poolrank = poolrank_in; this->pool_world = pool_world_in; + this->invalidate_cache(); } #endif /// @@ -142,6 +143,7 @@ void PW_Basis:: initgrids( this->nz = ibox[2]; this->nxy =this->nx * this->ny; this->nxyz = this->nxy * this->nz; + this->invalidate_cache(); delete[] ibox; return; @@ -203,6 +205,7 @@ void PW_Basis:: initgrids( MPI_Allreduce(MPI_IN_PLACE, &this->gridecut_lat, 1, MPI_DOUBLE, MPI_MIN , this->pool_world); #endif this->gridecut_lat -= 1e-6; + this->invalidate_cache(); delete[] ibox; return; @@ -240,6 +243,7 @@ void PW_Basis:: initparameters( this->ggecut = this->gridecut_lat; } this->distribution_type = distribution_type_in; + this->invalidate_cache(); } // Set parameters about full planewave, used only in OFDFT for now. sunliang added 2022-08-30 @@ -251,5 +255,6 @@ void PW_Basis::setfullpw( this->full_pw = inpt_full_pw; this->full_pw_dim = inpt_full_pw_dim; if (!this->full_pw) this->full_pw_dim = 0; + this->invalidate_cache(); +} } -} \ No newline at end of file diff --git a/source/source_basis/module_pw/test/test-big.cpp b/source/source_basis/module_pw/test/test-big.cpp index f1c2082d0b2..1f56597ece3 100644 --- a/source/source_basis/module_pw/test/test-big.cpp +++ b/source/source_basis/module_pw/test/test-big.cpp @@ -53,7 +53,7 @@ TEST_F(PWTEST,test_big) pwktest.initgrids(lat0,latvec, pwtest.nx, pwtest.ny, pwtest.nz); pwtest.initparameters(gamma_only,wfcecut,distribution_type,xprime); pwktest.initparameters(gamma_only,wfcecut,nks,kvec_d,distribution_type, xprime); - static_cast(pwtest).setuptransform(); + pwtest.ModulePW::PW_Basis::setuptransform(); pwktest.setuptransform(); EXPECT_EQ(pwtest.nx%2, 0); EXPECT_EQ(pwtest.ny%2, 0); @@ -85,7 +85,7 @@ TEST_F(PWTEST,test_big) class TestPW_Basis_Big : public ::testing::Test { public: - ModulePW::PW_Basis_Big pwtest = ModulePW::PW_Basis_Big(); + ModulePW::PW_Basis_Big pwtest; }; // Test the function with nproc = 0 (bx and by) @@ -157,4 +157,4 @@ TEST_F(TestPW_Basis_Big, BzNprocNoResultTest) { int nproc = 5; pwtest.autoset_big_cell_size(b_size, nc_size, nproc); EXPECT_EQ(b_size, 3); -} \ No newline at end of file +} diff --git a/source/source_basis/module_pw/test/test1-1-1.cpp b/source/source_basis/module_pw/test/test1-1-1.cpp index 3eb9d8fd5e4..4be297f276e 100644 --- a/source/source_basis/module_pw/test/test1-1-1.cpp +++ b/source/source_basis/module_pw/test/test1-1-1.cpp @@ -29,6 +29,36 @@ TEST_F(PWTEST,test1_1_1) pwtest.initgrids(lat0, latvec, wfcecut); pwtest.initparameters(gamma_only, wfcecut, distribution_type,xprime); pwtest.setuptransform(); + pwtest.reset_cache_stats(); + pwtest.collect_local_pw(); + pwtest.collect_uniqgg(); + auto stats_after_build = pwtest.get_cache_stats(); + EXPECT_EQ(stats_after_build.local_pw_misses, 1); + EXPECT_EQ(stats_after_build.uniqgg_misses, 1); + double* gg_ptr = pwtest.gg; + int* ig2igg_ptr = pwtest.ig2igg; + double* gguniq_ptr = pwtest.gg_uniq; + const int ngg_before = pwtest.ngg; + const double gg_sample = pwtest.gg[0]; + pwtest.collect_local_pw(); + pwtest.collect_uniqgg(); + EXPECT_EQ(pwtest.gg, gg_ptr); + EXPECT_EQ(pwtest.ig2igg, ig2igg_ptr); + EXPECT_EQ(pwtest.gg_uniq, gguniq_ptr); + EXPECT_EQ(pwtest.ngg, ngg_before); + EXPECT_DOUBLE_EQ(pwtest.gg[0], gg_sample); + auto stats_after_hit = pwtest.get_cache_stats(); + EXPECT_EQ(stats_after_hit.local_pw_hits, 1); + EXPECT_EQ(stats_after_hit.uniqgg_hits, 1); + EXPECT_GT(stats_after_hit.cache_bytes, 0); + pwtest.initparameters(gamma_only, wfcecut, distribution_type, xprime); + EXPECT_EQ(pwtest.gg, nullptr); + EXPECT_EQ(pwtest.gdirect, nullptr); + EXPECT_EQ(pwtest.gcar, nullptr); + EXPECT_EQ(pwtest.ig2igg, nullptr); + EXPECT_EQ(pwtest.gg_uniq, nullptr); + EXPECT_EQ(pwtest.get_cache_stats().cache_bytes, 0); + pwtest.setuptransform(); pwtest.collect_local_pw(); pwtest.collect_uniqgg(); ModuleBase::Matrix3 GT,G,GGT; @@ -229,4 +259,4 @@ TEST_F(PWTEST,test1_1_1) delete[] irindex; -} \ No newline at end of file +} diff --git a/source/source_basis/module_pw/test_serial/CMakeLists.txt b/source/source_basis/module_pw/test_serial/CMakeLists.txt index 52e594afb99..07179b2bd1c 100644 --- a/source/source_basis/module_pw/test_serial/CMakeLists.txt +++ b/source/source_basis/module_pw/test_serial/CMakeLists.txt @@ -34,3 +34,18 @@ AddTest( LIBS parameter ${math_libs} planewave_serial device base SOURCES pw_basis_k_test.cpp ) + +add_executable( + MODULE_PW_cache_bench_serial + pw_cache_bench.cpp +) + +target_link_libraries( + MODULE_PW_cache_bench_serial + parameter + ${math_libs} + planewave_serial + device + base + Threads::Threads +) diff --git a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp index 84932bae2ff..4b1ec1f0aed 100644 --- a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp +++ b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp @@ -2,6 +2,7 @@ #include "source_base/global_function.h" #include "source_base/constants.h" #include "source_base/matrix3.h" +#include "source_base/timer.h" /************************************************ * serial unit test of functions in pw_basis.cpp @@ -183,9 +184,59 @@ TEST_F(PWBasisKTEST, CollectLocalPW) const bool xprime_in = true; basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in,kvec_d_in, distribution_type_in, xprime_in); EXPECT_NO_THROW(basis_k.setuptransform()); + basis_k.reset_k_cache_stats(); EXPECT_NO_THROW(basis_k.collect_local_pw()); + ASSERT_GT(basis_k.npwk[0], 0); + auto* gk2_ptr = basis_k.get_gk2_data(); + auto* gcar_ptr = basis_k.get_gcar_data(); + const double gk2_sample = basis_k.getgk2(0,0); + const auto stats_after_build = basis_k.get_k_cache_stats(); + EXPECT_EQ(stats_after_build.gcar_misses, 1); + EXPECT_EQ(stats_after_build.gk2_misses, 1); + EXPECT_NO_THROW(basis_k.collect_local_pw()); + EXPECT_EQ(basis_k.get_gk2_data(), gk2_ptr); + EXPECT_EQ(basis_k.get_gcar_data(), gcar_ptr); + EXPECT_DOUBLE_EQ(basis_k.getgk2(0,0), gk2_sample); + EXPECT_NO_THROW(basis_k.collect_local_pw(1.0, 0.5, 0.2)); + EXPECT_EQ(basis_k.get_gcar_data(), gcar_ptr); + const auto stats_after_hits = basis_k.get_k_cache_stats(); + EXPECT_EQ(stats_after_hits.gcar_hits, 2); + EXPECT_EQ(stats_after_hits.gcar_misses, 1); + EXPECT_EQ(stats_after_hits.gk2_hits, 1); + EXPECT_EQ(stats_after_hits.gk2_misses, 2); + EXPECT_GT(stats_after_hits.cache_bytes, 0); + basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in, kvec_d_in, distribution_type_in, xprime_in); + EXPECT_EQ(basis_k.gcar, nullptr); + EXPECT_EQ(basis_k.gk2, nullptr); + EXPECT_EQ(basis_k.k_gcar_cache_storage, nullptr); + EXPECT_EQ(basis_k.k_gk2_cache_storage, nullptr); + EXPECT_EQ(basis_k.get_gcar_data(), nullptr); + EXPECT_EQ(basis_k.get_gk2_data(), nullptr); + EXPECT_EQ(basis_k.get_k_cache_stats().cache_bytes, 0); EXPECT_EQ(basis_k.npw,3695); EXPECT_EQ(basis_k.npwk_max,2721); } - +TEST_F(PWBasisKTEST, CollectLocalPWRecordsTimers) +{ + ModuleBase::timer::timer_pool.clear(); + ModulePW::PW_Basis_K basis_k(device_flag, precision_double); + double lat0 = 1.8897261254578281; + ModuleBase::Matrix3 latvec(10.0,0.0,0.0, + 0.0,10.0,0.0, + 0.0,0.0,10.0); + double gridecut = 10.0; + basis_k.initgrids(lat0, latvec, gridecut); + const bool gamma_only_in = true; + const double gk_ecut_in = 11.0; + const int nks_in = 3; + const ModuleBase::Vector3 kvec_d_in[3] = { {0.0, 0.0, 0.0}, {0.1, 0.2, 0.3}, {0.4, 0.5, 0.6} }; + const int distribution_type_in = 1; + const bool xprime_in = true; + basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in, kvec_d_in, distribution_type_in, xprime_in); + basis_k.setuptransform(); + basis_k.collect_local_pw(); + const auto& timer_pool = ModuleBase::timer::timer_pool[basis_k.classname]; + EXPECT_TRUE(timer_pool.count("collect_local_pw")); + EXPECT_GE(timer_pool.at("collect_local_pw").calls, 1u); +} diff --git a/source/source_basis/module_pw/test_serial/pw_basis_test.cpp b/source/source_basis/module_pw/test_serial/pw_basis_test.cpp index ea678b9d97c..4692ed43e30 100644 --- a/source/source_basis/module_pw/test_serial/pw_basis_test.cpp +++ b/source/source_basis/module_pw/test_serial/pw_basis_test.cpp @@ -2,6 +2,7 @@ #include "source_base/global_function.h" #include "source_base/constants.h" #include "source_base/matrix3.h" +#include "source_base/timer.h" /************************************************ * serial unit test of functions in pw_basis.cpp @@ -362,3 +363,91 @@ TEST_F(PWBasisTEST,CollectUniqgg) pwb.collect_uniqgg(); EXPECT_EQ(pwb.ngg,78); } + +TEST_F(PWBasisTEST, CacheStorageClearedOnParameterChange) +{ + double lat0 = 1.8897261254578281; + ModuleBase::Matrix3 latvec(10.0,0.0,0.0, + 0.0,10.0,0.0, + 0.0,0.0,10.0); + double gridecut=10.0; + bool gamma_only_in = true; + double pwecut_in = 11.0; + int distribution_type_in = 2; + bool xprime_in = true; + pwb.initgrids(lat0,latvec,gridecut); + pwb.initparameters(gamma_only_in,pwecut_in,distribution_type_in,xprime_in); + EXPECT_NO_THROW(pwb.setuptransform()); + pwb.collect_local_pw(); + pwb.collect_uniqgg(); + EXPECT_GT(pwb.get_cache_stats().cache_bytes, 0); + pwb.initparameters(gamma_only_in,pwecut_in,distribution_type_in,xprime_in); + EXPECT_EQ(pwb.gg_cache_storage, nullptr); + EXPECT_EQ(pwb.gdirect_cache_storage, nullptr); + EXPECT_EQ(pwb.gcar_cache_storage, nullptr); + EXPECT_EQ(pwb.ig2igg_cache_storage, nullptr); + EXPECT_EQ(pwb.gg_uniq_cache_storage, nullptr); + EXPECT_EQ(pwb.get_cache_stats().cache_bytes, 0); +} + +TEST_F(PWBasisTEST, CacheSignatureRejectsChangedLattice) +{ + double lat0 = 1.8897261254578281; + ModuleBase::Matrix3 latvec(10.0,0.0,0.0, + 0.0,10.0,0.0, + 0.0,0.0,10.0); + double gridecut=10.0; + bool gamma_only_in = true; + double pwecut_in = 11.0; + int distribution_type_in = 2; + bool xprime_in = true; + pwb.initgrids(lat0,latvec,gridecut); + pwb.initparameters(gamma_only_in,pwecut_in,distribution_type_in,xprime_in); + EXPECT_NO_THROW(pwb.setuptransform()); + pwb.collect_local_pw(); + int changed_ig = -1; + for (int ig = 0; ig < pwb.npw; ++ig) + { + if (std::abs(pwb.gdirect[ig].x) > 1e-12) + { + changed_ig = ig; + break; + } + } + ASSERT_GE(changed_ig, 0); + const double old_gg = pwb.gg[changed_ig]; + pwb.collect_local_pw(); + EXPECT_EQ(pwb.get_cache_stats().local_pw_hits, 1); + EXPECT_EQ(pwb.get_cache_stats().local_pw_misses, 1); + + pwb.G.e11 *= 1.1; + pwb.GGT = pwb.G * pwb.GT; + pwb.collect_local_pw(); + EXPECT_EQ(pwb.get_cache_stats().local_pw_hits, 1); + EXPECT_EQ(pwb.get_cache_stats().local_pw_misses, 2); + EXPECT_NE(pwb.gg[changed_ig], old_gg); +} + +TEST_F(PWBasisTEST, CacheCollectionRecordsTimers) +{ + ModuleBase::timer::timer_pool.clear(); + double lat0 = 1.8897261254578281; + ModuleBase::Matrix3 latvec(10.0,0.0,0.0, + 0.0,10.0,0.0, + 0.0,0.0,10.0); + double gridecut = 10.0; + bool gamma_only_in = true; + double pwecut_in = 11.0; + int distribution_type_in = 2; + bool xprime_in = true; + pwb.initgrids(lat0, latvec, gridecut); + pwb.initparameters(gamma_only_in, pwecut_in, distribution_type_in, xprime_in); + pwb.setuptransform(); + pwb.collect_local_pw(); + pwb.collect_uniqgg(); + const auto& timer_pool = ModuleBase::timer::timer_pool[pwb.classname]; + EXPECT_TRUE(timer_pool.count("collect_local_pw")); + EXPECT_TRUE(timer_pool.count("collect_uniqgg")); + EXPECT_GE(timer_pool.at("collect_local_pw").calls, 1u); + EXPECT_GE(timer_pool.at("collect_uniqgg").calls, 1u); +} diff --git a/source/source_basis/module_pw/test_serial/pw_cache_bench.cpp b/source/source_basis/module_pw/test_serial/pw_cache_bench.cpp new file mode 100644 index 00000000000..b1e2ae7bc8d --- /dev/null +++ b/source/source_basis/module_pw/test_serial/pw_cache_bench.cpp @@ -0,0 +1,153 @@ +#include "source_base/matrix3.h" +#include "source_base/timer.h" + +#include "../pw_basis.h" +#include "../pw_basis_k.h" + +#include +#include +#include +#include + +#ifdef __MPI +#include "mpi.h" +#endif + +namespace +{ + +using Clock = std::chrono::steady_clock; + +template +double measure_seconds(Func&& func) +{ + const auto start = Clock::now(); + func(); + const auto end = Clock::now(); + return std::chrono::duration(end - start).count(); +} + +void print_metric(const std::string& name, const double value) +{ + std::cout << "METRIC " << name << " " << std::fixed << std::setprecision(9) << value << '\n'; +} + +void print_timer_metric(const std::string& class_name, const std::string& timer_name) +{ + const auto class_it = ModuleBase::timer::timer_pool.find(class_name); + if (class_it == ModuleBase::timer::timer_pool.end()) + { + return; + } + const auto timer_it = class_it->second.find(timer_name); + if (timer_it == class_it->second.end()) + { + return; + } + print_metric("timer." + class_name + "." + timer_name + ".seconds", timer_it->second.cpu_second); + print_metric("timer." + class_name + "." + timer_name + ".calls", static_cast(timer_it->second.calls)); +} + +void bench_pw_basis() +{ + constexpr int repeat_calls = 2000; + ModuleBase::timer::timer_pool.clear(); + + ModulePW::PW_Basis basis; + const ModuleBase::Matrix3 latvec(1, 0, 0, 0, 1, 0, 0, 0, 1); + const double lat0 = 10.0; + const double wfcecut = 50.0; + const double rhoecut = 4.0 * wfcecut; + const int distribution_type = 1; + + basis.initgrids(lat0, latvec, rhoecut); + basis.initparameters(false, wfcecut, distribution_type, true); + + print_metric("PW_Basis.setuptransform.wall", measure_seconds([&]() { basis.setuptransform(); })); + print_metric("PW_Basis.collect_local_pw.first.wall", measure_seconds([&]() { basis.collect_local_pw(); })); + print_metric("PW_Basis.collect_local_pw.repeat.wall", + measure_seconds([&]() { + for (int i = 0; i < repeat_calls; ++i) + { + basis.collect_local_pw(); + } + })); + print_metric("PW_Basis.collect_uniqgg.first.wall", measure_seconds([&]() { basis.collect_uniqgg(); })); + print_metric("PW_Basis.collect_uniqgg.repeat.wall", + measure_seconds([&]() { + for (int i = 0; i < repeat_calls; ++i) + { + basis.collect_uniqgg(); + } + })); + + print_timer_metric("PW_Basis", "setuptransform"); + print_timer_metric("PW_Basis", "collect_local_pw"); + print_timer_metric("PW_Basis", "collect_local_pw_cache_hit"); + print_timer_metric("PW_Basis", "collect_local_pw_cache_build"); + print_timer_metric("PW_Basis", "collect_uniqgg"); + print_timer_metric("PW_Basis", "collect_uniqgg_cache_hit"); + print_timer_metric("PW_Basis", "collect_uniqgg_cache_build"); +} + +void bench_pw_basis_k() +{ + constexpr int repeat_calls = 2000; + ModuleBase::timer::timer_pool.clear(); + + ModulePW::PW_Basis_K basis("cpu", "double"); + const ModuleBase::Matrix3 latvec(10.0, 0.0, 0.0, + 0.0, 10.0, 0.0, + 0.0, 0.0, 10.0); + const double lat0 = 1.8897261254578281; + const double gridecut = 10.0; + const bool gamma_only = true; + const double gk_ecut = 11.0; + const int nks = 3; + const ModuleBase::Vector3 kvec_d[3] = {{0.0, 0.0, 0.0}, {0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}}; + const int distribution_type = 1; + const bool xprime = true; + + basis.initgrids(lat0, latvec, gridecut); + basis.initparameters(gamma_only, gk_ecut, nks, kvec_d, distribution_type, xprime); + + print_metric("PW_Basis_K.setuptransform.wall", measure_seconds([&]() { basis.setuptransform(); })); + print_metric("PW_Basis_K.collect_local_pw.first.wall", measure_seconds([&]() { basis.collect_local_pw(); })); + print_metric("PW_Basis_K.collect_local_pw.repeat.wall", + measure_seconds([&]() { + for (int i = 0; i < repeat_calls; ++i) + { + basis.collect_local_pw(); + } + })); + print_metric("PW_Basis_K.collect_local_pw.gk2_rebuild.wall", + measure_seconds([&]() { + for (int i = 0; i < repeat_calls; ++i) + { + basis.collect_local_pw(1.0, 0.5, 0.2); + } + })); + + print_timer_metric("PW_Basis_K", "setuptransform"); + print_timer_metric("PW_Basis_K", "collect_local_pw"); + print_timer_metric("PW_Basis_K", "collect_local_pw_cache_hit"); + print_timer_metric("PW_Basis_K", "collect_local_pw_build_gcar"); + print_timer_metric("PW_Basis_K", "collect_local_pw_build_gk2"); +} + +} // namespace + +int main() +{ +#ifdef __MPI + int argc = 0; + char** argv = nullptr; + MPI_Init(&argc, &argv); +#endif + bench_pw_basis(); + bench_pw_basis_k(); +#ifdef __MPI + MPI_Finalize(); +#endif + return 0; +}