Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 118 additions & 62 deletions src/passes/Heap2Local.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ enum class ParentChildInteraction : int8_t {

// When we insert scratch locals, we sometimes need to record the flow between
// their set and subsequent get.
using ScratchInfo = std::unordered_map<LocalSet*, LocalGet*>;
using ScratchInfo = std::unordered_map<LocalSet*, std::vector<LocalGet*>>;

// Core analysis that provides an escapes() method to check if an allocation
// escapes in a way that prevents optimizing it away as described above. It also
Expand Down Expand Up @@ -292,8 +292,9 @@ struct EscapeAnalyzer {
// scratchInfo first because it contains sets that localGraph doesn't
// know about.
if (auto it = scratchInfo.find(set); it != scratchInfo.end()) {
auto* get = it->second;
flows.push({get, parents.getParent(get)});
for (auto* get : it->second) {
flows.push({get, parents.getParent(get)});
}
} else {
// This is one of the sets we are written to, and so we must check for
// exclusive use of our allocation by all the gets that read the
Expand Down Expand Up @@ -625,6 +626,32 @@ struct Struct2Local : PostWalker<Struct2Local> {
// is only something to store if it is non-nullable, and we store it that way.
Type descType;

// Bookkeeping for the locals created through addVar(): maps each such local
// index to the pair (sets, gets) of LocalSet and LocalGet nodes that were
// created for that index through makeLocalSet() and makeLocalGet(). Indexes
// that were not created through addVar() have no entry here.
std::unordered_map<Index,
std::pair<std::vector<LocalSet*>, std::vector<LocalGet*>>>
scratchLocals;

// Adds a new local of the given type to the function and registers its index
// in scratchLocals with empty set/get lists, so that later makeLocalSet() and
// makeLocalGet() calls on that index get recorded. Returns the new index.
Index addVar(Type type) {
  const Index newIndex = builder.addVar(func, type);
  // Start (or reset) the tracking entry for this local.
  scratchLocals[newIndex] = {};
  return newIndex;
}

// Builds a local.set of `value` to local `index`. If `index` is tracked in
// scratchLocals, the new set is also appended to that local's list of sets.
LocalSet* makeLocalSet(Index index, Expression* value) {
  LocalSet* newSet = builder.makeLocalSet(index, value);
  auto entry = scratchLocals.find(index);
  if (entry == scratchLocals.end()) {
    // Not a tracked local; nothing to record.
    return newSet;
  }
  std::vector<LocalSet*>& trackedSets = entry->second.first;
  trackedSets.push_back(newSet);
  return newSet;
}

// Builds a local.get of local `index` with the given type. If `index` is
// tracked in scratchLocals, the new get is also appended to that local's list
// of gets.
LocalGet* makeLocalGet(Index index, Type type) {
  LocalGet* newGet = builder.makeLocalGet(index, type);
  auto entry = scratchLocals.find(index);
  if (entry == scratchLocals.end()) {
    // Not a tracked local; nothing to record.
    return newGet;
  }
  std::vector<LocalGet*>& trackedGets = entry->second.second;
  trackedGets.push_back(newGet);
  return newGet;
}

Struct2Local(StructNew* allocation,
EscapeAnalyzer& analyzer,
Function* func,
Expand All @@ -634,16 +661,26 @@ struct Struct2Local : PostWalker<Struct2Local> {

// Allocate locals to store the allocation's fields and descriptor in.
for (auto field : fields) {
localIndexes.push_back(builder.addVar(func, field.type));
localIndexes.push_back(addVar(field.type));
}
if (allocation->desc) {
descType = allocation->desc->type.with(NonNullable);
localIndexes.push_back(builder.addVar(func, descType));
localIndexes.push_back(addVar(descType));
}

// Replace the things we need to using the visit* methods.
walk(func->body);

for (auto& [_, setsAndGets] : scratchLocals) {
auto& sets = setsAndGets.first;
auto& gets = setsAndGets.second;
for (auto* set : sets) {
for (auto* get : gets) {
analyzer.scratchInfo[set].push_back(get);
}
}
}

if (refinalize) {
ReFinalize().walkFunctionInModule(func, &wasm);
}
Expand Down Expand Up @@ -762,18 +799,17 @@ struct Struct2Local : PostWalker<Struct2Local> {
// Create the temp variables.
if (!curr->isWithDefault()) {
for (auto field : fields) {
tempIndexes.push_back(builder.addVar(func, field.type));
tempIndexes.push_back(addVar(field.type));
}
}
if (curr->desc) {
tempIndexes.push_back(builder.addVar(func, descType));
tempIndexes.push_back(addVar(descType));
}

// Store the initial values into the temp locals.
if (!curr->isWithDefault()) {
for (Index i = 0; i < fields.size(); i++) {
contents.push_back(
builder.makeLocalSet(tempIndexes[i], curr->operands[i]));
contents.push_back(makeLocalSet(tempIndexes[i], curr->operands[i]));
}
}
if (curr->desc) {
Expand All @@ -783,21 +819,20 @@ struct Struct2Local : PostWalker<Struct2Local> {
if (curr->desc->type.isNullable()) {
desc = builder.makeRefAs(RefAsNonNull, desc);
}
contents.push_back(builder.makeLocalSet(tempIndexes[numTemps - 1], desc));
contents.push_back(makeLocalSet(tempIndexes[numTemps - 1], desc));
}

// Store the values into the locals representing the fields.
for (Index i = 0; i < fields.size(); ++i) {
auto* val =
curr->isWithDefault()
? builder.makeConstantExpression(Literal::makeZero(fields[i].type))
: builder.makeLocalGet(tempIndexes[i], fields[i].type);
contents.push_back(builder.makeLocalSet(localIndexes[i], val));
: makeLocalGet(tempIndexes[i], fields[i].type);
contents.push_back(makeLocalSet(localIndexes[i], val));
}
if (curr->desc) {
auto* val = builder.makeLocalGet(tempIndexes[numTemps - 1], descType);
contents.push_back(
builder.makeLocalSet(localIndexes[fields.size()], val));
auto* val = makeLocalGet(tempIndexes[numTemps - 1], descType);
contents.push_back(makeLocalSet(localIndexes[fields.size()], val));
}

// Replace the allocation with a null reference. This changes the type
Expand Down Expand Up @@ -913,14 +948,13 @@ struct Struct2Local : PostWalker<Struct2Local> {
// There might be a null value to let through. Reuse curr as a cast to
// null. Use a scratch local to move the reference value past the desc
// value.
Index scratch = builder.addVar(func, curr->ref->type);
replaceCurrent(
builder.blockify(builder.makeLocalSet(scratch, curr->ref),
builder.makeDrop(curr->desc),
curr));
Index scratch = addVar(curr->ref->type);
replaceCurrent(builder.blockify(makeLocalSet(scratch, curr->ref),
builder.makeDrop(curr->desc),
curr));
curr->desc = nullptr;
curr->type = curr->type.with(curr->type.getHeapType().getBottom());
curr->ref = builder.makeLocalGet(scratch, curr->ref->type);
curr->ref = makeLocalGet(scratch, curr->ref->type);
} else {
// Either the cast does not allow nulls or we know the value isn't
// null anyway, so the cast certainly fails.
Expand All @@ -942,7 +976,7 @@ struct Struct2Local : PostWalker<Struct2Local> {
builder.makeIf(
builder.makeRefEq(
curr->desc,
builder.makeLocalGet(localIndexes[fields.size()], descType)),
makeLocalGet(localIndexes[fields.size()], descType)),
builder.makeRefNull(allocation->type.getHeapType()),
builder.makeUnreachable())));
}
Expand Down Expand Up @@ -992,7 +1026,7 @@ struct Struct2Local : PostWalker<Struct2Local> {
}

auto descIndex = localIndexes[fields.size()];
Expression* value = builder.makeLocalGet(descIndex, descType);
Expression* value = makeLocalGet(descIndex, descType);
replaceCurrent(builder.blockify(builder.makeDrop(curr->ref), value));

// After removing the ref.get_desc, a null may be falling through,
Expand All @@ -1009,7 +1043,7 @@ struct Struct2Local : PostWalker<Struct2Local> {
// write the data to the local instead of the heap allocation.
auto* replacement = builder.makeSequence(
builder.makeDrop(curr->ref),
builder.makeLocalSet(localIndexes[curr->index], curr->value));
makeLocalSet(localIndexes[curr->index], curr->value));

// This struct.set cannot possibly synchronize with other threads via the
// read value, since the struct never escapes this function, so we don't
Expand Down Expand Up @@ -1044,7 +1078,7 @@ struct Struct2Local : PostWalker<Struct2Local> {
// which may be more refined.
refinalize = true;
}
Expression* value = builder.makeLocalGet(localIndexes[curr->index], type);
Expression* value = makeLocalGet(localIndexes[curr->index], type);
// Note that in theory we could try to do better here than to fix up the
// packing and signedness on gets: we could truncate on sets. That would be
// more efficient if all gets are unsigned, as gets outnumber sets in
Expand Down Expand Up @@ -1078,22 +1112,20 @@ struct Struct2Local : PostWalker<Struct2Local> {
// first scratch local in case the evaluation of the modification value ends
// up changing the field value. This is similar to the scratch locals used
// for struct.new.
auto oldScratch = builder.addVar(func, type);
auto valScratch = builder.addVar(func, type);
auto oldScratch = addVar(type);
auto valScratch = addVar(type);
auto local = localIndexes[curr->index];

auto* block =
builder.makeSequence(builder.makeDrop(curr->ref),
builder.makeLocalSet(valScratch, curr->value));
auto* block = builder.makeSequence(builder.makeDrop(curr->ref),
makeLocalSet(valScratch, curr->value));

// Stash the old value to return.
block->list.push_back(
builder.makeLocalSet(oldScratch, builder.makeLocalGet(local, type)));
block->list.push_back(makeLocalSet(oldScratch, makeLocalGet(local, type)));

// Store the updated value.
Expression* newVal = nullptr;
if (curr->op == RMWXchg) {
newVal = builder.makeLocalGet(valScratch, type);
newVal = makeLocalGet(valScratch, type);
} else {
Abstract::Op binop = Abstract::Add;
switch (curr->op) {
Expand All @@ -1116,13 +1148,13 @@ struct Struct2Local : PostWalker<Struct2Local> {
WASM_UNREACHABLE("unexpected op");
}
newVal = builder.makeBinary(Abstract::getBinary(type, binop),
builder.makeLocalGet(local, type),
builder.makeLocalGet(valScratch, type));
makeLocalGet(local, type),
makeLocalGet(valScratch, type));
}
block->list.push_back(builder.makeLocalSet(local, newVal));
block->list.push_back(makeLocalSet(local, newVal));

// Unstash the old value.
block->list.push_back(builder.makeLocalGet(oldScratch, type));
block->list.push_back(makeLocalGet(oldScratch, type));
block->type = type;
replaceCurrent(block);
}
Expand All @@ -1149,36 +1181,61 @@ struct Struct2Local : PostWalker<Struct2Local> {
expectedType = Type(
HeapTypes::eq.getBasic(type.getHeapType().getShared()), Nullable);
}
auto oldScratch = builder.addVar(func, type);
auto expectedScratch = builder.addVar(func, expectedType);
auto replacementScratch = builder.addVar(func, type);
auto oldScratch = addVar(type);
auto expectedScratch = addVar(expectedType);
auto replacementScratch = addVar(type);
auto local = localIndexes[curr->index];

auto* block = builder.makeBlock(
{builder.makeDrop(curr->ref),
builder.makeLocalSet(expectedScratch, curr->expected),
builder.makeLocalSet(replacementScratch, curr->replacement),
builder.makeLocalSet(oldScratch, builder.makeLocalGet(local, type))});
auto* setExpectedScratch = makeLocalSet(expectedScratch, curr->expected);
auto* setReplacementScratch =
makeLocalSet(replacementScratch, curr->replacement);
auto* setOldScratch = makeLocalSet(oldScratch, makeLocalGet(local, type));
auto* block = builder.makeBlock({builder.makeDrop(curr->ref),
setExpectedScratch,
setReplacementScratch,
setOldScratch});
analyzer.parents.setParent(curr->replacement, setReplacementScratch);
analyzer.parents.setParent(curr->expected, setExpectedScratch);

analyzer.parents.setParent(setExpectedScratch, block);
analyzer.parents.setParent(setReplacementScratch, block);
analyzer.parents.setParent(setOldScratch, block);

// Create the check for whether we should do the exchange.
auto* lhs = builder.makeLocalGet(local, type);
auto* rhs = builder.makeLocalGet(expectedScratch, expectedType);
auto* lhs = makeLocalGet(local, type);
auto* rhs = makeLocalGet(expectedScratch, expectedType);
Expression* pred;
if (type.isRef()) {
pred = builder.makeRefEq(lhs, rhs);
} else {
pred =
builder.makeBinary(Abstract::getBinary(type, Abstract::Eq), lhs, rhs);
}
analyzer.parents.setParent(rhs, pred);
analyzer.parents.setParent(lhs, pred);

// The conditional exchange.
block->list.push_back(builder.makeIf(
pred,
builder.makeLocalSet(local,
builder.makeLocalGet(replacementScratch, type))));
auto* getReplacementScratch = makeLocalGet(replacementScratch, type);
auto* setLocal = makeLocalSet(local, getReplacementScratch);
auto* iff = builder.makeIf(pred, setLocal);
block->list.push_back(iff);

analyzer.parents.setParent(getReplacementScratch, setLocal);
analyzer.parents.setParent(setLocal, iff);
analyzer.parents.setParent(pred, iff);
analyzer.parents.setParent(iff, block);

// Unstash the old value.
block->list.push_back(builder.makeLocalGet(oldScratch, type));
auto* getOldScratch = makeLocalGet(oldScratch, type);
block->list.push_back(getOldScratch);

analyzer.parents.setParent(getOldScratch, block);

auto* parent = analyzer.parents.getParent(curr);
if (parent) {
analyzer.parents.setParent(block, parent);
}

block->type = type;
replaceCurrent(block);
return;
Expand All @@ -1196,9 +1253,9 @@ struct Struct2Local : PostWalker<Struct2Local> {
// happened. Use a nullable scratch local in case we also optimize `ref`
// later and need to replace it with a null.
auto refType = curr->ref->type.with(Nullable);
auto refScratch = builder.addVar(func, refType);
auto* setRefScratch = builder.makeLocalSet(refScratch, curr->ref);
auto* getRefScratch = builder.makeLocalGet(refScratch, refType);
auto refScratch = addVar(refType);
auto* setRefScratch = makeLocalSet(refScratch, curr->ref);
auto* getRefScratch = makeLocalGet(refScratch, refType);
auto* structGet = builder.makeStructGet(
curr->index, getRefScratch, curr->order, curr->type);
auto* block = builder.makeBlock({setRefScratch,
Expand All @@ -1210,7 +1267,6 @@ struct Struct2Local : PostWalker<Struct2Local> {
// necessary in case `ref` gets processed later so we can detect that it
// flows to the new struct.atomic.get, which may need to be replaced.
analyzer.parents.setParent(curr->ref, setRefScratch);
analyzer.scratchInfo.insert({setRefScratch, getRefScratch});
analyzer.parents.setParent(getRefScratch, structGet);
return;
}
Expand All @@ -1228,13 +1284,13 @@ struct Struct2Local : PostWalker<Struct2Local> {
// See the equivalent handling of allocations flowing through the
// `expected` field of StructCmpxchg.
auto refType = curr->ref->type.with(Nullable);
auto refScratch = builder.addVar(func, refType);
auto* setRefScratch = builder.makeLocalSet(refScratch, curr->ref);
auto* getRefScratch = builder.makeLocalGet(refScratch, refType);
auto refScratch = addVar(refType);
auto* setRefScratch = makeLocalSet(refScratch, curr->ref);
auto* getRefScratch = makeLocalGet(refScratch, refType);

auto indexScratch = builder.addVar(func, Type::i32);
auto* setIndexScratch = builder.makeLocalSet(indexScratch, curr->index);
auto* getIndexScratch = builder.makeLocalGet(indexScratch, Type::i32);
auto indexScratch = addVar(Type::i32);
auto* setIndexScratch = makeLocalSet(indexScratch, curr->index);
auto* getIndexScratch = makeLocalGet(indexScratch, Type::i32);

auto* arrayGet = builder.makeArrayGet(
getRefScratch, getIndexScratch, curr->order, curr->type);
Expand Down
Loading
Loading