diff --git a/src/passes/Heap2Local.cpp b/src/passes/Heap2Local.cpp index 59c1e7e3fc5..22d05dd9eae 100644 --- a/src/passes/Heap2Local.cpp +++ b/src/passes/Heap2Local.cpp @@ -194,7 +194,7 @@ enum class ParentChildInteraction : int8_t { // When we insert scratch locals, we sometimes need to record the flow between // their set and subsequent get. -using ScratchInfo = std::unordered_map; +using ScratchInfo = std::unordered_map>; // Core analysis that provides an escapes() method to check if an allocation // escapes in a way that prevents optimizing it away as described above. It also @@ -292,8 +292,9 @@ struct EscapeAnalyzer { // scratchInfo first because it contains sets that localGraph doesn't // know about. if (auto it = scratchInfo.find(set); it != scratchInfo.end()) { - auto* get = it->second; - flows.push({get, parents.getParent(get)}); + for (auto* get : it->second) { + flows.push({get, parents.getParent(get)}); + } } else { // This is one of the sets we are written to, and so we must check for // exclusive use of our allocation by all the gets that read the @@ -625,6 +626,32 @@ struct Struct2Local : PostWalker { // is only something to store if it is non-nullable, and we store it that way. Type descType; + std::unordered_map, std::vector>> + scratchLocals; + + Index addVar(Type type) { + Index idx = builder.addVar(func, type); + scratchLocals[idx] = {}; + return idx; + } + + LocalSet* makeLocalSet(Index index, Expression* value) { + auto* set = builder.makeLocalSet(index, value); + if (auto it = scratchLocals.find(index); it != scratchLocals.end()) { + it->second.first.push_back(set); + } + return set; + } + + LocalGet* makeLocalGet(Index index, Type type) { + auto* get = builder.makeLocalGet(index, type); + if (auto it = scratchLocals.find(index); it != scratchLocals.end()) { + it->second.second.push_back(get); + } + return get; + } + Struct2Local(StructNew* allocation, EscapeAnalyzer& analyzer, Function* func, @@ -634,16 +661,26 @@ struct Struct2Local : PostWalker { // Allocate locals to store the allocation's fields and descriptor in. for (auto field : fields) { - localIndexes.push_back(builder.addVar(func, field.type)); + localIndexes.push_back(addVar(field.type)); } if (allocation->desc) { descType = allocation->desc->type.with(NonNullable); - localIndexes.push_back(builder.addVar(func, descType)); + localIndexes.push_back(addVar(descType)); } // Replace the things we need to using the visit* methods. walk(func->body); + for (auto& [_, setsAndGets] : scratchLocals) { + auto& sets = setsAndGets.first; + auto& gets = setsAndGets.second; + for (auto* set : sets) { + for (auto* get : gets) { + analyzer.scratchInfo[set].push_back(get); + } + } + } + if (refinalize) { ReFinalize().walkFunctionInModule(func, &wasm); } @@ -762,18 +799,17 @@ struct Struct2Local : PostWalker { // Create the temp variables. if (!curr->isWithDefault()) { for (auto field : fields) { - tempIndexes.push_back(builder.addVar(func, field.type)); + tempIndexes.push_back(addVar(field.type)); } } if (curr->desc) { - tempIndexes.push_back(builder.addVar(func, descType)); + tempIndexes.push_back(addVar(descType)); } // Store the initial values into the temp locals. if (!curr->isWithDefault()) { for (Index i = 0; i < fields.size(); i++) { - contents.push_back( - builder.makeLocalSet(tempIndexes[i], curr->operands[i])); + contents.push_back(makeLocalSet(tempIndexes[i], curr->operands[i])); } } if (curr->desc) { @@ -783,7 +819,7 @@ struct Struct2Local : PostWalker { if (curr->desc->type.isNullable()) { desc = builder.makeRefAs(RefAsNonNull, desc); } - contents.push_back(builder.makeLocalSet(tempIndexes[numTemps - 1], desc)); + contents.push_back(makeLocalSet(tempIndexes[numTemps - 1], desc)); } // Store the values into the locals representing the fields. @@ -791,13 +827,12 @@ struct Struct2Local : PostWalker { auto* val = curr->isWithDefault() ? builder.makeConstantExpression(Literal::makeZero(fields[i].type)) - : builder.makeLocalGet(tempIndexes[i], fields[i].type); - contents.push_back(builder.makeLocalSet(localIndexes[i], val)); + : makeLocalGet(tempIndexes[i], fields[i].type); + contents.push_back(makeLocalSet(localIndexes[i], val)); } if (curr->desc) { - auto* val = builder.makeLocalGet(tempIndexes[numTemps - 1], descType); - contents.push_back( - builder.makeLocalSet(localIndexes[fields.size()], val)); + auto* val = makeLocalGet(tempIndexes[numTemps - 1], descType); + contents.push_back(makeLocalSet(localIndexes[fields.size()], val)); } // Replace the allocation with a null reference. This changes the type @@ -913,14 +948,13 @@ struct Struct2Local : PostWalker { // There might be a null value to let through. Reuse curr as a cast to // null. Use a scratch local to move the reference value past the desc // value. - Index scratch = builder.addVar(func, curr->ref->type); - replaceCurrent( - builder.blockify(builder.makeLocalSet(scratch, curr->ref), - builder.makeDrop(curr->desc), - curr)); + Index scratch = addVar(curr->ref->type); + replaceCurrent(builder.blockify(makeLocalSet(scratch, curr->ref), + builder.makeDrop(curr->desc), + curr)); curr->desc = nullptr; curr->type = curr->type.with(curr->type.getHeapType().getBottom()); - curr->ref = builder.makeLocalGet(scratch, curr->ref->type); + curr->ref = makeLocalGet(scratch, curr->ref->type); } else { // Either the cast does not allow nulls or we know the value isn't // null anyway, so the cast certainly fails. @@ -942,7 +976,7 @@ struct Struct2Local : PostWalker { builder.makeIf( builder.makeRefEq( curr->desc, - builder.makeLocalGet(localIndexes[fields.size()], descType)), + makeLocalGet(localIndexes[fields.size()], descType)), builder.makeRefNull(allocation->type.getHeapType()), builder.makeUnreachable()))); } @@ -992,7 +1026,7 @@ struct Struct2Local : PostWalker { } auto descIndex = localIndexes[fields.size()]; - Expression* value = builder.makeLocalGet(descIndex, descType); + Expression* value = makeLocalGet(descIndex, descType); replaceCurrent(builder.blockify(builder.makeDrop(curr->ref), value)); // After removing the ref.get_desc, a null may be falling through, @@ -1009,7 +1043,7 @@ struct Struct2Local : PostWalker { // write the data to the local instead of the heap allocation. auto* replacement = builder.makeSequence( builder.makeDrop(curr->ref), - builder.makeLocalSet(localIndexes[curr->index], curr->value)); + makeLocalSet(localIndexes[curr->index], curr->value)); // This struct.set cannot possibly synchronize with other threads via the // read value, since the struct never escapes this function, so we don't @@ -1044,7 +1078,7 @@ struct Struct2Local : PostWalker { // which may be more refined. refinalize = true; } - Expression* value = builder.makeLocalGet(localIndexes[curr->index], type); + Expression* value = makeLocalGet(localIndexes[curr->index], type); // Note that in theory we could try to do better here than to fix up the // packing and signedness on gets: we could truncate on sets. That would be // more efficient if all gets are unsigned, as gets outnumber sets in @@ -1078,22 +1112,20 @@ struct Struct2Local : PostWalker { // first scratch local in case the evaluation of the modification value ends // up changing the field value. This is similar to the scratch locals used // for struct.new. - auto oldScratch = builder.addVar(func, type); - auto valScratch = builder.addVar(func, type); + auto oldScratch = addVar(type); + auto valScratch = addVar(type); auto local = localIndexes[curr->index]; - auto* block = - builder.makeSequence(builder.makeDrop(curr->ref), - builder.makeLocalSet(valScratch, curr->value)); + auto* block = builder.makeSequence(builder.makeDrop(curr->ref), + makeLocalSet(valScratch, curr->value)); // Stash the old value to return. - block->list.push_back( - builder.makeLocalSet(oldScratch, builder.makeLocalGet(local, type))); + block->list.push_back(makeLocalSet(oldScratch, makeLocalGet(local, type))); // Store the updated value. Expression* newVal = nullptr; if (curr->op == RMWXchg) { - newVal = builder.makeLocalGet(valScratch, type); + newVal = makeLocalGet(valScratch, type); } else { Abstract::Op binop = Abstract::Add; switch (curr->op) { @@ -1116,13 +1148,13 @@ struct Struct2Local : PostWalker { WASM_UNREACHABLE("unexpected op"); } newVal = builder.makeBinary(Abstract::getBinary(type, binop), - builder.makeLocalGet(local, type), - builder.makeLocalGet(valScratch, type)); + makeLocalGet(local, type), + makeLocalGet(valScratch, type)); } - block->list.push_back(builder.makeLocalSet(local, newVal)); + block->list.push_back(makeLocalSet(local, newVal)); // Unstash the old value. - block->list.push_back(builder.makeLocalGet(oldScratch, type)); + block->list.push_back(makeLocalGet(oldScratch, type)); block->type = type; replaceCurrent(block); } @@ -1149,20 +1181,29 @@ struct Struct2Local : PostWalker { expectedType = Type( HeapTypes::eq.getBasic(type.getHeapType().getShared()), Nullable); } - auto oldScratch = builder.addVar(func, type); - auto expectedScratch = builder.addVar(func, expectedType); - auto replacementScratch = builder.addVar(func, type); + auto oldScratch = addVar(type); + auto expectedScratch = addVar(expectedType); + auto replacementScratch = addVar(type); auto local = localIndexes[curr->index]; - auto* block = builder.makeBlock( - {builder.makeDrop(curr->ref), - builder.makeLocalSet(expectedScratch, curr->expected), - builder.makeLocalSet(replacementScratch, curr->replacement), - builder.makeLocalSet(oldScratch, builder.makeLocalGet(local, type))}); + auto* setExpectedScratch = makeLocalSet(expectedScratch, curr->expected); + auto* setReplacementScratch = + makeLocalSet(replacementScratch, curr->replacement); + auto* setOldScratch = makeLocalSet(oldScratch, makeLocalGet(local, type)); + auto* block = builder.makeBlock({builder.makeDrop(curr->ref), + setExpectedScratch, + setReplacementScratch, + setOldScratch}); + analyzer.parents.setParent(curr->replacement, setReplacementScratch); + analyzer.parents.setParent(curr->expected, setExpectedScratch); + + analyzer.parents.setParent(setExpectedScratch, block); + analyzer.parents.setParent(setReplacementScratch, block); + analyzer.parents.setParent(setOldScratch, block); // Create the check for whether we should do the exchange. - auto* lhs = builder.makeLocalGet(local, type); - auto* rhs = builder.makeLocalGet(expectedScratch, expectedType); + auto* lhs = makeLocalGet(local, type); + auto* rhs = makeLocalGet(expectedScratch, expectedType); Expression* pred; if (type.isRef()) { pred = builder.makeRefEq(lhs, rhs); @@ -1170,15 +1211,31 @@ struct Struct2Local : PostWalker { pred = builder.makeBinary(Abstract::getBinary(type, Abstract::Eq), lhs, rhs); } + analyzer.parents.setParent(rhs, pred); + analyzer.parents.setParent(lhs, pred); // The conditional exchange. - block->list.push_back(builder.makeIf( - pred, - builder.makeLocalSet(local, - builder.makeLocalGet(replacementScratch, type)))); + auto* getReplacementScratch = makeLocalGet(replacementScratch, type); + auto* setLocal = makeLocalSet(local, getReplacementScratch); + auto* iff = builder.makeIf(pred, setLocal); + block->list.push_back(iff); + + analyzer.parents.setParent(getReplacementScratch, setLocal); + analyzer.parents.setParent(setLocal, iff); + analyzer.parents.setParent(pred, iff); + analyzer.parents.setParent(iff, block); // Unstash the old value. - block->list.push_back(builder.makeLocalGet(oldScratch, type)); + auto* getOldScratch = makeLocalGet(oldScratch, type); + block->list.push_back(getOldScratch); + + analyzer.parents.setParent(getOldScratch, block); + + auto* parent = analyzer.parents.getParent(curr); + if (parent) { + analyzer.parents.setParent(block, parent); + } + block->type = type; replaceCurrent(block); return; @@ -1196,9 +1253,9 @@ struct Struct2Local : PostWalker { // happened. Use a nullable scratch local in case we also optimize `ref` // later and need to replace it with a null. auto refType = curr->ref->type.with(Nullable); - auto refScratch = builder.addVar(func, refType); - auto* setRefScratch = builder.makeLocalSet(refScratch, curr->ref); - auto* getRefScratch = builder.makeLocalGet(refScratch, refType); + auto refScratch = addVar(refType); + auto* setRefScratch = makeLocalSet(refScratch, curr->ref); + auto* getRefScratch = makeLocalGet(refScratch, refType); auto* structGet = builder.makeStructGet( curr->index, getRefScratch, curr->order, curr->type); auto* block = builder.makeBlock({setRefScratch, @@ -1210,7 +1267,6 @@ struct Struct2Local : PostWalker { // necessary in case `ref` gets processed later so we can detect that it // flows to the new struct.atomic.get, which may need to be replaced. analyzer.parents.setParent(curr->ref, setRefScratch); - analyzer.scratchInfo.insert({setRefScratch, getRefScratch}); analyzer.parents.setParent(getRefScratch, structGet); return; } @@ -1228,13 +1284,13 @@ struct Struct2Local : PostWalker { // See the equivalent handling of allocations flowing through the // `expected` field of StructCmpxchg. auto refType = curr->ref->type.with(Nullable); - auto refScratch = builder.addVar(func, refType); - auto* setRefScratch = builder.makeLocalSet(refScratch, curr->ref); - auto* getRefScratch = builder.makeLocalGet(refScratch, refType); + auto refScratch = addVar(refType); + auto* setRefScratch = makeLocalSet(refScratch, curr->ref); + auto* getRefScratch = makeLocalGet(refScratch, refType); - auto indexScratch = builder.addVar(func, Type::i32); - auto* setIndexScratch = builder.makeLocalSet(indexScratch, curr->index); - auto* getIndexScratch = builder.makeLocalGet(indexScratch, Type::i32); + auto indexScratch = addVar(Type::i32); + auto* setIndexScratch = makeLocalSet(indexScratch, curr->index); + auto* getIndexScratch = makeLocalGet(indexScratch, Type::i32); auto* arrayGet = builder.makeArrayGet( getRefScratch, getIndexScratch, curr->order, curr->type); diff --git a/test/lit/passes/heap2local-rmw.wast b/test/lit/passes/heap2local-rmw.wast index 7d329d2b53e..0d02d92caa2 100644 --- a/test/lit/passes/heap2local-rmw.wast +++ b/test/lit/passes/heap2local-rmw.wast @@ -1169,7 +1169,7 @@ ;; CHECK-NEXT: (ref.null none) ;; CHECK-NEXT: ) ;; CHECK-NEXT: ) - ;; CHECK-NEXT: (local.set $3 + ;; CHECK-NEXT: (drop ;; CHECK-NEXT: (block (result nullref) ;; CHECK-NEXT: (ref.null none) ;; CHECK-NEXT: ) @@ -1181,9 +1181,14 @@ ;; CHECK-NEXT: (local.get $1) ;; CHECK-NEXT: ) ;; CHECK-NEXT: (if - ;; CHECK-NEXT: (ref.eq - ;; CHECK-NEXT: (local.get $1) - ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: (block (result i32) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (ref.null none) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.const 0) ;; CHECK-NEXT: ) ;; CHECK-NEXT: (then ;; CHECK-NEXT: (local.set $1 @@ -1252,7 +1257,7 @@ ;; CHECK-NEXT: (ref.null none) ;; CHECK-NEXT: ) ;; CHECK-NEXT: ) - ;; CHECK-NEXT: (local.set $3 + ;; CHECK-NEXT: (drop ;; CHECK-NEXT: (block (result nullref) ;; CHECK-NEXT: (ref.null none) ;; CHECK-NEXT: ) @@ -1264,9 +1269,14 @@ ;; CHECK-NEXT: (local.get $0) ;; CHECK-NEXT: ) ;; CHECK-NEXT: (if - ;; CHECK-NEXT: (ref.eq - ;; CHECK-NEXT: (local.get $0) - ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: (block (result i32) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (ref.null none) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.const 0) ;; CHECK-NEXT: ) ;; CHECK-NEXT: (then ;; CHECK-NEXT: (local.set $0 @@ -1313,7 +1323,7 @@ ;; CHECK-NEXT: (ref.null (shared none)) ;; CHECK-NEXT: ) ;; CHECK-NEXT: ) - ;; CHECK-NEXT: (local.set $3 + ;; CHECK-NEXT: (drop ;; CHECK-NEXT: (block (result (ref null (shared none))) ;; CHECK-NEXT: (ref.null (shared none)) ;; CHECK-NEXT: ) @@ -1325,9 +1335,14 @@ ;; CHECK-NEXT: (local.get $0) ;; CHECK-NEXT: ) ;; CHECK-NEXT: (if - ;; CHECK-NEXT: (ref.eq - ;; CHECK-NEXT: (local.get $0) - ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: (block (result i32) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (ref.null (shared none)) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.const 0) ;; CHECK-NEXT: ) ;; CHECK-NEXT: (then ;; CHECK-NEXT: (local.set $0 @@ -1609,3 +1624,100 @@ ) ) ) + +(module + ;; CHECK: (type $struct (shared (struct (field (mut (ref null $struct)))))) + (type $struct (shared (struct (field (mut (ref null $struct)))))) + ;; CHECK: (type $1 (func (result i32))) + + ;; CHECK: (func $must-optimize-ref-eq (type $1) (result i32) + ;; CHECK-NEXT: (local $local (ref null $struct)) + ;; CHECK-NEXT: (local $1 (ref null $struct)) + ;; CHECK-NEXT: (local $2 (ref null $struct)) + ;; CHECK-NEXT: (local $3 (ref null (shared eq))) + ;; CHECK-NEXT: (local $4 (ref null $struct)) + ;; CHECK-NEXT: (local $5 (ref null $struct)) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (block (result (ref null $struct)) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (block (result (ref null (shared none))) + ;; CHECK-NEXT: (local.set $1 + ;; CHECK-NEXT: (ref.null (shared none)) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (ref.null (shared none)) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (block (result (ref null (shared none))) + ;; CHECK-NEXT: (local.set $5 + ;; CHECK-NEXT: (ref.null (shared none)) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (ref.null (shared none)) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $4 + ;; CHECK-NEXT: (struct.new_default $struct) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $2 + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (if + ;; CHECK-NEXT: (block (result i32) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (ref.null (shared none)) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (then + ;; CHECK-NEXT: (local.set $1 + ;; CHECK-NEXT: (local.get $4) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (ref.is_null + ;; CHECK-NEXT: (block (result (ref null $struct)) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (ref.null (shared none)) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $must-optimize-ref-eq (result i32) + (local $local (ref null $struct)) + (drop + ;; 2. Next this will be optimized. The local from (1) representing field 0 + ;; will be compared with a local holding the expected value below. This + ;; ref.eq comparison should fail if we optimize correctly, leaving the + ;; local from (1) with its default null value. + (struct.atomic.rmw.cmpxchg acqrel acqrel $struct 0 + (local.tee $local + ;; 1. This will be optimized. The local representing field 0 will be + ;; set to null. + (struct.new_default $struct) + ) + ;; 3. This is the next allocation to be optimized. When it is replaced + ;; with a ref.null, it would cause the ref.eq comparison created in (2) + ;; to start succeeding incorrectly, except that we optimize the ref.eq + ;; when we see that this allocation flows only into one side of it. + (struct.new_default $struct) ;; Expected + ;; 4. This allocation is not optimized, so we would end up incorrectly + ;; writing this non-null value into the local from (1) if we had not + ;; optimized the ref.eq. + (struct.new_default $struct) ;; Replacement + ) + ) + (ref.is_null + ;; This is replaced with a get of the local from (1), which would + ;; incorrectly contain the non-null replacement value if we had not + ;; optimized the ref.eq. + (struct.get $struct 0 (local.get $local)) + ) + ) +)