Skip to content

Commit 1554110

Browse files
committed
einsum: parallelize hadamard-reduction outer h-tile loop
1 parent 7f76cda commit 1554110

1 file changed

Lines changed: 144 additions & 54 deletions

File tree

src/TiledArray/einsum/tiledarray.h

Lines changed: 144 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -642,7 +642,6 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
642642
}
643643

644644
if (!e) { // hadamard reduction
645-
646645
auto &[A, B] = AB;
647646
TiledRange trange(range_map[i]);
648647
RangeProduct tiles;
@@ -685,75 +684,147 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
685684
: element_contract_op.value()(l, r);
686685
};
687686

688-
auto pa = A.permutation;
689-
auto pb = B.permutation;
687+
auto const pa = A.permutation;
688+
auto const pb = B.permutation;
689+
auto const pc = C.permutation;
690+
691+
// Each H-tile iteration produces an independent output tile, so the
692+
// loop is parallel-safe. Dispatch per-H-tile work to the MADNESS task
693+
// queue; pre-size a per-slot result vector so tasks write their own
694+
// slot without synchronization, and gather before exiting scope so
695+
// captured references stay alive for the task lifetime.
696+
//
697+
// Input tiles must be resolved BEFORE submitting tasks: calling
698+
// .get() on an unready future from inside a task body is unsafe
699+
// with the PaRSEC backend ("recursive call to wait"). We instead
700+
// issue all find()s up-front (non-blocking, so requests overlap)
701+
// and materialize them on the main thread; the submitted tasks
702+
// then operate purely on local data.
703+
using ATile = typename ArrayA::value_type;
704+
using BTile = typename ArrayB::value_type;
705+
706+
// Per-H-tile job metadata shared by both the in-flight (futures)
707+
// representation and the materialized (resolved tiles) representation.
708+
struct HJobMeta {
709+
Index h; // H-tile coord in A/B's annotation
710+
Index c_target; // C-tile coord (= apply(pc, h))
711+
size_t batch; // product of H.batch sizes for this h
712+
};
713+
struct PendingHJob : HJobMeta {
714+
std::vector<
715+
std::tuple<Index, madness::Future<ATile>, madness::Future<BTile>>>
716+
inputs;
717+
};
718+
struct HJob : HJobMeta {
719+
// (i_index, ai, bi) for each non-zero input pair contributing to h
720+
std::vector<std::tuple<Index, ATile, BTile>> inputs;
721+
};
722+
723+
// Phase 1: issue all find() calls (non-blocking) so remote requests
724+
// are in flight concurrently; collect futures + metadata.
725+
std::vector<PendingHJob> pending_jobs;
690726
for (Index h : H.tiles) {
691-
auto const pc = C.permutation;
692-
auto const c = apply(pc, h);
693-
if (!C.array.is_local(c)) continue;
694-
size_t batch = 1;
695-
for (size_t i = 0; i < h.size(); ++i) {
696-
batch *= H.batch[i].at(h[i]);
727+
auto const c_target = apply(pc, h);
728+
if (!C.array.is_local(c_target)) continue;
729+
PendingHJob pj;
730+
pj.h = h;
731+
pj.c_target = c_target;
732+
pj.batch = 1;
733+
for (size_t hi = 0; hi < h.size(); ++hi) {
734+
pj.batch *= H.batch[hi].at(h[hi]);
697735
}
698-
ResultTensor tile(TiledArray::Range{batch},
699-
typename ResultTensor::value_type{});
700736
for (Index i : tiles) {
701-
// skip this unless both input tiles exist
702737
const auto pahi_inv = apply_inverse(pa, h + i);
703738
const auto pbhi_inv = apply_inverse(pb, h + i);
704739
if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
740+
pj.inputs.emplace_back(i, A.array.find(pahi_inv),
741+
B.array.find(pbhi_inv));
742+
}
743+
pending_jobs.push_back(std::move(pj));
744+
}
745+
746+
// Phase 2: materialize all input tiles on the main thread.
747+
auto materialize = [](PendingHJob &&pj) -> HJob {
748+
HJob job;
749+
static_cast<HJobMeta &>(job) = static_cast<HJobMeta const &>(pj);
750+
job.inputs.reserve(pj.inputs.size());
751+
for (auto &[i, fa, fb] : pj.inputs)
752+
job.inputs.emplace_back(std::move(i), fa.get(), fb.get());
753+
return job;
754+
};
755+
std::vector<HJob> jobs;
756+
jobs.reserve(pending_jobs.size());
757+
for (auto &pj : pending_jobs) jobs.push_back(materialize(std::move(pj)));
758+
pending_jobs.clear();
759+
760+
// Single inner-product element op subsuming the three template arms
761+
// (ToT-ToT, mixed T/ToT, plain T). Writes one batch element of `out`
762+
// from one batch slice each of A and B.
763+
auto kelement_op = [&](auto &out, auto const &aik, auto const &bik) {
764+
if constexpr (AreArrayToT<ArrayA, ArrayB>) {
765+
auto vol = aik.total_size();
766+
TA_ASSERT(vol == bik.total_size());
767+
for (auto ii = 0; ii < vol; ++ii)
768+
out.add_to(element_product_op(aik.data()[ii], bik.data()[ii]));
769+
} else if constexpr (!AreArraySame<ArrayA, ArrayB>) {
770+
auto vol = aik.total_size();
771+
TA_ASSERT(vol == bik.total_size());
772+
for (auto ii = 0; ii < vol; ++ii) {
773+
if constexpr (IsArrayToT<ArrayA>) {
774+
out.add_to(aik.data()[ii].scale(bik.data()[ii]));
775+
} else {
776+
out.add_to(bik.data()[ii].scale(aik.data()[ii]));
777+
}
778+
}
779+
} else {
780+
out += aik.dot(bik);
781+
}
782+
};
783+
784+
std::vector<std::pair<Index, ResultTensor>> h_results(jobs.size());
705785

706-
auto ai = A.array.find(pahi_inv).get();
707-
auto bi = B.array.find(pbhi_inv).get();
786+
// per_h_work: process jobs[slot] and write into h_results[slot].
787+
// Captures listed explicitly so the lifetime contract is checkable:
788+
// every captured reference must outlive the submitted tasks (their
789+
// futures are gathered below, before this scope exits).
790+
auto per_h_work = [&jobs, &trange, &h_results, &kelement_op, pa, pb, pc,
791+
&C](size_t slot) -> bool {
792+
auto const &job = jobs[slot];
793+
size_t batch = job.batch;
794+
ResultTensor tile(TiledArray::Range{batch},
795+
typename ResultTensor::value_type{});
796+
for (auto const &[i, ai_in, bi_in] : job.inputs) {
797+
ATile ai = ai_in;
798+
BTile bi = bi_in;
708799
if (pa) ai = ai.permute(pa);
709800
if (pb) bi = bi.permute(pb);
710801
auto shape = trange.tile(i);
711802
ai = ai.reshape(shape, batch);
712803
bi = bi.reshape(shape, batch);
713804
for (size_t k = 0; k < batch; ++k) {
714-
using Ix = ::Einsum::Index<std::string>;
715-
if constexpr (AreArrayToT<ArrayA, ArrayB>) {
716-
auto aik = ai.batch(k);
717-
auto bik = bi.batch(k);
718-
auto vol = aik.total_size();
719-
TA_ASSERT(vol == bik.total_size());
720-
721-
auto &el = tile({k});
722-
using TensorT = std::remove_reference_t<decltype(el)>;
723-
724-
for (auto i = 0; i < vol; ++i)
725-
el.add_to(element_product_op(aik.data()[i], bik.data()[i]));
726-
727-
} else if constexpr (!AreArraySame<ArrayA, ArrayB>) {
728-
auto aik = ai.batch(k);
729-
auto bik = bi.batch(k);
730-
auto vol = aik.total_size();
731-
TA_ASSERT(vol == bik.total_size());
732-
733-
auto &el = tile({k});
734-
735-
for (auto i = 0; i < vol; ++i)
736-
if constexpr (IsArrayToT<ArrayA>) {
737-
el.add_to(aik.data()[i].scale(bik.data()[i]));
738-
} else {
739-
el.add_to(bik.data()[i].scale(aik.data()[i]));
740-
}
741-
742-
} else {
743-
auto hk = ai.batch(k).dot(bi.batch(k));
744-
tile({k}) += hk;
745-
}
805+
auto &el = tile({k});
806+
kelement_op(el, ai.batch(k), bi.batch(k));
746807
}
747808
}
748809
// data is stored as h1 h2 ... but all modes folded as 1 batch dim
749810
// first reshape to h = (h1 h2 ...)
750811
// n.b. can't just use shape = C.array.trange().tile(h)
751-
auto shape = apply_inverse(pc, C.array.trange().tile(c));
812+
auto shape = apply_inverse(pc, C.array.trange().tile(job.c_target));
752813
tile = tile.reshape(shape);
753814
// then permute to target C layout c = (c1 c2 ...)
754815
if (pc) tile = tile.permute(pc);
755-
// and move to C_local_tiles
756-
C_local_tiles.emplace_back(std::move(c), std::move(tile));
816+
h_results[slot] = {job.c_target, std::move(tile)};
817+
return true;
818+
};
819+
820+
std::vector<madness::Future<bool>> h_futures;
821+
h_futures.reserve(jobs.size());
822+
for (size_t slot = 0; slot < jobs.size(); ++slot) {
823+
h_futures.push_back(world.taskq.add(per_h_work, slot));
824+
}
825+
for (auto &fut : h_futures) fut.get();
826+
for (auto &r : h_results) {
827+
C_local_tiles.emplace_back(std::move(r.first), std::move(r.second));
757828
}
758829

759830
build_C_array();
@@ -809,17 +880,36 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
809880
term.local_tiles.clear();
810881
const Permutation &P = term.permutation;
811882

883+
using TileType =
884+
typename std::decay_t<decltype(term.array)>::value_type;
885+
std::vector<std::pair<Index, madness::Future<TileType>>> pending;
886+
812887
for (Index ei : term.tiles) {
813888
auto idx = apply_inverse(P, h + ei);
814889
if (!term.array.is_local(idx)) continue;
815890
if (term.array.is_zero(idx)) continue;
816-
// TODO no need for immediate evaluation
817-
auto tile = term.array.find_local(idx).get();
818-
if (P) tile = tile.permute(P);
891+
892+
auto tile_future = term.array.find_local(idx); // non-blocking
819893
auto shape = term.ei_tiled_range.tile(ei);
820-
tile = tile.reshape(shape, batch);
821-
term.local_tiles.push_back({ei, tile});
894+
895+
// Submit per-tile permute-and-reshape as a MADNESS task.
896+
// Capture P by value (it's small) and shape by value.
897+
madness::Future<TileType> permuted = owners->taskq.add(
898+
[P, shape, batch](const TileType &tile) -> TileType {
899+
TileType result = P ? tile.permute(P) : tile;
900+
return result.reshape(shape, batch);
901+
},
902+
tile_future);
903+
904+
pending.emplace_back(ei, permuted);
905+
}
906+
907+
// Wait for all per-tile tasks to complete, then gather into
908+
// local_tiles.
909+
for (auto &[ei, fut] : pending) {
910+
term.local_tiles.push_back({ei, fut.get()});
822911
}
912+
823913
bool replicated = term.array.pmap()->is_replicated();
824914
term.ei = TiledArray::make_array<decltype(term.array)>(
825915
*owners, term.ei_tiled_range, term.local_tiles.begin(),

0 commit comments

Comments
 (0)