From 681c5747b6824ba46e96f52cddebaba49248b7dc Mon Sep 17 00:00:00 2001 From: Zhendong404 Date: Mon, 18 May 2026 22:02:24 +0800 Subject: [PATCH] feature(tile fusion): support tile op scheduling and marking last_use --- include/PTO/Transforms/CppPostprocess.h | 22 + include/PTO/Transforms/Passes.h | 3 + include/PTO/Transforms/Passes.td | 36 ++ .../Transforms/TileFusion/FusionAnalysis.h | 135 +++++ .../Transforms/TileFusion/FusionOpSemantics.h | 52 ++ lib/PTO/Transforms/CMakeLists.txt | 6 + lib/PTO/Transforms/CppPostprocess.cpp | 176 ++++++ lib/PTO/Transforms/PTOToEmitC.cpp | 348 ++++++------ lib/PTO/Transforms/PTOViewToMemref.cpp | 51 +- .../Transforms/TileFusion/FusionAnalysis.cpp | 506 +++++++++++++++++ .../TileFusion/FusionOpSemantics.cpp | 123 ++++ .../Transforms/TileFusion/PTOFusionPlan.cpp | 526 ++++++++++++++++++ .../Transforms/TileFusion/PTOMarkLastUse.cpp | 268 +++++++++ .../Transforms/TileFusion/PTOOpScheduling.cpp | 321 +++++++++++ .../final_emitc_last_use_level2.pto | 62 +++ test/lit/tile_fusion/fusion_plan_diamond.pto | 52 ++ .../fusion_plan_dynamic_shape_negative.pto | 34 ++ .../fusion_plan_interleaved_join.pto | 45 ++ test/lit/tile_fusion/fusion_plan_join.pto | 35 ++ .../fusion_plan_treshape_boundary.pto | 44 ++ .../mark_last_use_post_span_block_level2.pto | 62 +++ .../mark_last_use_repeated_ssa_level2.pto | 53 ++ .../mark_last_use_slot_mask_level2.pto | 79 +++ ...p_fusion_adapter_placement_level2_tadd.pto | 75 +++ ...p_fusion_adapter_placement_level3_tadd.pto | 75 +++ test/lit/tile_fusion/op_fusion_cli_flags.pto | 29 + .../op_fusion_nonfused_control.pto | 49 ++ test/lit/tile_fusion/op_scheduling_basic.pto | 46 ++ .../op_scheduling_negative_call_boundary.pto | 38 ++ .../op_scheduling_negative_region.pto | 46 ++ .../op_scheduling_negative_ssa.pto | 39 ++ .../op_scheduling_pure_op_bridge.pto | 39 ++ .../tile_fusion/op_scheduling_treshape.pto | 49 ++ tools/ptoas/ptoas.cpp | 31 ++ 34 files changed, 3375 insertions(+), 180 deletions(-) create mode 100644 include/PTO/Transforms/CppPostprocess.h create mode 100644 include/PTO/Transforms/TileFusion/FusionAnalysis.h create mode 100644 include/PTO/Transforms/TileFusion/FusionOpSemantics.h create mode 100644 lib/PTO/Transforms/CppPostprocess.cpp create mode 100644 lib/PTO/Transforms/TileFusion/FusionAnalysis.cpp create mode 100644 lib/PTO/Transforms/TileFusion/FusionOpSemantics.cpp create mode 100644 lib/PTO/Transforms/TileFusion/PTOFusionPlan.cpp create mode 100644 lib/PTO/Transforms/TileFusion/PTOMarkLastUse.cpp create mode 100644 lib/PTO/Transforms/TileFusion/PTOOpScheduling.cpp create mode 100644 test/lit/tile_fusion/final_emitc_last_use_level2.pto create mode 100644 test/lit/tile_fusion/fusion_plan_diamond.pto create mode 100644 test/lit/tile_fusion/fusion_plan_dynamic_shape_negative.pto create mode 100644 test/lit/tile_fusion/fusion_plan_interleaved_join.pto create mode 100644 test/lit/tile_fusion/fusion_plan_join.pto create mode 100644 test/lit/tile_fusion/fusion_plan_treshape_boundary.pto create mode 100644 test/lit/tile_fusion/mark_last_use_post_span_block_level2.pto create mode 100644 test/lit/tile_fusion/mark_last_use_repeated_ssa_level2.pto create mode 100644 test/lit/tile_fusion/mark_last_use_slot_mask_level2.pto create mode 100644 test/lit/tile_fusion/op_fusion_adapter_placement_level2_tadd.pto create mode 100644 test/lit/tile_fusion/op_fusion_adapter_placement_level3_tadd.pto create mode 100644 test/lit/tile_fusion/op_fusion_cli_flags.pto create mode 100644 test/lit/tile_fusion/op_fusion_nonfused_control.pto create mode 100644 test/lit/tile_fusion/op_scheduling_basic.pto create mode 100644 test/lit/tile_fusion/op_scheduling_negative_call_boundary.pto create mode 100644 test/lit/tile_fusion/op_scheduling_negative_region.pto create mode 100644 test/lit/tile_fusion/op_scheduling_negative_ssa.pto create mode 100644 test/lit/tile_fusion/op_scheduling_pure_op_bridge.pto create mode 100644 test/lit/tile_fusion/op_scheduling_treshape.pto diff --git a/include/PTO/Transforms/CppPostprocess.h b/include/PTO/Transforms/CppPostprocess.h new file mode 100644 index 000000000..dd7039d1d --- /dev/null +++ b/include/PTO/Transforms/CppPostprocess.h @@ -0,0 +1,22 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_CPPPOSTPROCESS_H +#define MLIR_DIALECT_PTO_TRANSFORMS_CPPPOSTPROCESS_H + +#include + +namespace mlir { +namespace pto { + +bool rewriteLastUseMarkersInCpp(std::string &cpp); + +} // namespace pto +} // namespace mlir + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_CPPPOSTPROCESS_H diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 84c92eb17..93be1633c 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -70,6 +70,9 @@ std::unique_ptr createPTOViewToMemrefPass(); std::unique_ptr createPTOMaterializeTileHandlesPass(); std::unique_ptr createInferPTOLayoutPass(); std::unique_ptr createPTOA5NormalizeTMovPass(); +std::unique_ptr createFusionPlanPass(); +std::unique_ptr createOpSchedulingPass(); +std::unique_ptr createPTOMarkLastUsePass(); //===----------------------------------------------------------------------===// // Registration diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 444efe268..07a3f0589 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -167,6 +167,42 @@ def PTOLoweringSyncToPipe : Pass<"pto-lowering-sync-to-pipe", "func::FuncOp"> { ]; } +def FusionPlan : Pass<"pto-fusion-plan", "func::FuncOp"> { + let summary = "Build conservative tile-fusion planning groups from block-local analysis"; + let description = [{ + Consumes PreFusionAnalysis results, forms conservative block-local fusion + groups for currently supported tile-native compute ops, and annotates + accepted group members with: + - pto.fusion.group_id + - pto.fusion.order + }]; + let constructor = "mlir::pto::createFusionPlanPass()"; + let dependentDialects = ["mlir::pto::PTODialect"]; +} + +def OpScheduling : Pass<"pto-op-scheduling", "func::FuncOp"> { + let summary = "Compact planned fusion groups into block-local contiguous spans"; + let description = [{ + Consumes fusion planning metadata emitted by FusionPlanPass and performs + block-local instruction scheduling to make each accepted fusion group a + contiguous span without redefining group membership. + }]; + let constructor = "mlir::pto::createOpSchedulingPass()"; + let dependentDialects = ["mlir::pto::PTODialect"]; +} + +def PTOMarkLastUse : Pass<"pto-mark-last-use", "func::FuncOp"> { + let summary = "Mark scheduled tile-fusion operand last-use slots"; + let description = [{ + Walks scheduled tile-fusion spans identified by pto.fusion.group_id / + pto.fusion.order and annotates each op with a stable per-input last-use + bit mask. The analysis considers both later in-span uses and later + post-span uses in the enclosing block. + }]; + let constructor = "mlir::pto::createPTOMarkLastUsePass()"; + let dependentDialects = ["mlir::pto::PTODialect"]; +} + def PTOWrapFunctionsInSections : Pass<"pto-wrap-functions-in-sections", "func::FuncOp"> { let summary = "Wrap attributed single-core functions in PTO cube/vector sections"; let description = [{ diff --git a/include/PTO/Transforms/TileFusion/FusionAnalysis.h b/include/PTO/Transforms/TileFusion/FusionAnalysis.h new file mode 100644 index 000000000..9f64e776f --- /dev/null +++ b/include/PTO/Transforms/TileFusion/FusionAnalysis.h @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef PTO_TRANSFORMS_TILEFUSION_FUSIONANALYSIS_H +#define PTO_TRANSFORMS_TILEFUSION_FUSIONANALYSIS_H + +#include "PTO/Transforms/TileFusion/FusionOpSemantics.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Support/LLVM.h" + +#include +#include + +namespace mlir { +namespace pto { + +enum class IterationDomainProof { + Proven, + Unproven, +}; + +enum class IterationDomainUnprovenReason { + None, + MissingTileDomain, + DynamicShape, + InconsistentShape, +}; + +struct IterationDomainInfo { + int64_t vRow = ShapedType::kDynamic; + int64_t vCol = ShapedType::kDynamic; + IterationDomainProof proof = IterationDomainProof::Unproven; + IterationDomainUnprovenReason unprovenReason = + IterationDomainUnprovenReason::MissingTileDomain; +}; + +struct IterationDomainClass { + unsigned id = 0; + IterationDomainInfo info; + SmallVector members; +}; + +struct FusionDFGEdge { + unsigned producerNode = 0; + unsigned consumerNode = 0; + Value value; +}; + +struct FusionValueLiveness { + Value value; + std::optional producerNode; + SmallVector consumerNodes; + SmallVector writeInstances; + std::optional lastLocalConsumer; + bool hasExternalUsers = false; + bool escapesBlock = false; + bool hasLocalBoundaryUsers = false; + bool hasLocalHardBoundaryUsers = false; +}; + +enum class FusionWriteInstanceEscapeClass { + Internal, + LocalBoundaryExternal, + HardExternal, +}; + +struct FusionWriteInstanceLiveness { + unsigned id = 0; + Value value; + Value storageValue; + std::optional producerNode; + SmallVector consumerNodes; + std::optional lastLocalConsumer; + FusionWriteInstanceEscapeClass escapeClass = + FusionWriteInstanceEscapeClass::Internal; + bool hasExternalUsers = false; + bool escapesBlock = false; + bool hasLocalBoundaryUsers = false; + bool hasLocalHardBoundaryUsers = false; +}; + +struct FusionComputeNode { + unsigned id = 0; + unsigned blockOrder = 0; + Operation *op = nullptr; + FusionOpSemantics semantics; + unsigned iterationDomainClass = 0; + SmallVector incomingEdges; + SmallVector outgoingEdges; +}; + +struct FusionBlockAnalysis { + Block *block = nullptr; + SmallVector computeNodes; + SmallVector iterationDomainClasses; + SmallVector edges; + SmallVector liveness; + SmallVector writeInstances; +}; + +struct PreFusionAnalysisResult { + SmallVector blocks; +}; + +FailureOr buildPreFusionAnalysis(func::FuncOp func); + +class PreFusionAnalysis { +public: + explicit PreFusionAnalysis(func::FuncOp func) { + FailureOr resultOr = buildPreFusionAnalysis(func); + if (succeeded(resultOr)) + result = std::move(*resultOr); + } + + bool isValid() const { return result.has_value(); } + + const PreFusionAnalysisResult &getResult() const { + assert(result && "expected valid pre-fusion analysis result"); + return *result; + } + +private: + std::optional result; +}; + +} // namespace pto +} // namespace mlir + +#endif // PTO_TRANSFORMS_TILEFUSION_FUSIONANALYSIS_H diff --git a/include/PTO/Transforms/TileFusion/FusionOpSemantics.h b/include/PTO/Transforms/TileFusion/FusionOpSemantics.h new file mode 100644 index 000000000..8f60b91e3 --- /dev/null +++ b/include/PTO/Transforms/TileFusion/FusionOpSemantics.h @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef PTO_TRANSFORMS_TILEFUSION_FUSIONOPSEMANTICS_H +#define PTO_TRANSFORMS_TILEFUSION_FUSIONOPSEMANTICS_H + +#include "PTO/IR/PTO.h" + +#include "mlir/Support/LLVM.h" + +#include + +namespace mlir { +namespace pto { + +enum class FusionOpKind { + Compute, + LocalBoundary, + HardBoundary, +}; + +enum class FusionComputeFamily { + Unknown, + Elementwise, + ScalarExpand, + RowBroadcastBinary, + ReduceRow, + ReduceCol, +}; + +struct FusionOpSemantics { + FusionOpKind kind = FusionOpKind::HardBoundary; + FusionComputeFamily computeFamily = FusionComputeFamily::Unknown; + Operation *op = nullptr; + std::string opName; + SmallVector tileInputs; + SmallVector tileOutputs; + SmallVector scalarInputs; +}; + +bool isSupportedPreFusionComputeOp(StringRef opName); +FailureOr getFusionOpSemantics(Operation *op); + +} // namespace pto +} // namespace mlir + +#endif // PTO_TRANSFORMS_TILEFUSION_FUSIONOPSEMANTICS_H diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 9c1f7d22c..2319eab62 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -12,11 +12,17 @@ # See LICENSE in the root of the software repository for the full text of the License. add_mlir_dialect_library(PTOTransforms + TileFusion/FusionAnalysis.cpp + TileFusion/FusionOpSemantics.cpp + TileFusion/PTOFusionPlan.cpp + TileFusion/PTOOpScheduling.cpp + TileFusion/PTOMarkLastUse.cpp InsertSync/PTOInsertSync.cpp PTOInjectBarrierAllSync.cpp InsertSync/InsertSyncDebug.cpp PTOViewToMemref.cpp PTOToEmitC.cpp + CppPostprocess.cpp Utils.cpp OptMemPlanForPipeline.cpp AllocToPointerCast.cpp diff --git a/lib/PTO/Transforms/CppPostprocess.cpp b/lib/PTO/Transforms/CppPostprocess.cpp new file mode 100644 index 000000000..554193a74 --- /dev/null +++ b/lib/PTO/Transforms/CppPostprocess.cpp @@ -0,0 +1,176 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/Transforms/CppPostprocess.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" + +#include +#include +#include + +namespace mlir { +namespace pto { + +namespace { + +struct ParsedMarkerCall { + size_t markerPos; + size_t rparenPos; + llvm::SmallVector args; +}; + +static bool parseMarkerArgs(llvm::StringRef argsRef, + llvm::SmallVectorImpl &args) { + args.clear(); + if (argsRef.empty()) + return true; + + int parenDepth = 0; + size_t partBegin = 0; + for (size_t i = 0; i < argsRef.size(); ++i) { + char c = argsRef[i]; + if (c == '(') { + ++parenDepth; + continue; + } + if (c == ')') { + if (parenDepth > 0) + --parenDepth; + continue; + } + if (c == ',' && parenDepth == 0) { + args.push_back(argsRef.slice(partBegin, i).trim()); + partBegin = i + 1; + } + } + if (partBegin > argsRef.size()) + return false; + args.push_back(argsRef.drop_front(partBegin).trim()); + return true; +} + +static bool parseLastUseMarkerName(llvm::StringRef markerName, + std::string &callee, + std::string &lastUseArgs) { + static constexpr llvm::StringLiteral kPrefix = "PTOAS__LAST_USE__"; + if (!markerName.starts_with(kPrefix)) + return false; + + llvm::StringRef payload = markerName.drop_front(kPrefix.size()); + size_t split = payload.find("__"); + if (split == llvm::StringRef::npos) + return false; + + callee = payload.take_front(split).str(); + llvm::StringRef encoded = payload.drop_front(split + 2); + if (callee.empty() || encoded.empty()) + return false; + + lastUseArgs.clear(); + size_t pos = 0; + while (pos < encoded.size()) { + size_t next = encoded.find("__", pos); + llvm::StringRef token = + next == llvm::StringRef::npos ? encoded.drop_front(pos) + : encoded.slice(pos, next); + if (token.empty()) + return false; + if (!llvm::all_of(token, [](char c) { return std::isdigit(c); })) + return false; + if (!lastUseArgs.empty()) + lastUseArgs.append(", "); + lastUseArgs.append(token.str()); + if (next == llvm::StringRef::npos) + break; + pos = next + 2; + } + return !lastUseArgs.empty(); +} + +} // namespace + +bool rewriteLastUseMarkersInCpp(std::string &cpp) { + size_t searchPos = 0; + bool changed = false; + static constexpr llvm::StringLiteral kPrefix = "PTOAS__LAST_USE__"; + while (true) { + size_t markerPos = cpp.find(kPrefix.str(), searchPos); + if (markerPos == std::string::npos) + break; + + size_t lparenPos = markerPos + kPrefix.size(); + while (lparenPos < cpp.size() && cpp[lparenPos] != '(') + ++lparenPos; + if (lparenPos >= cpp.size()) { + searchPos = markerPos + 1; + continue; + } + + ParsedMarkerCall call{markerPos, std::string::npos, {}}; + size_t argsBegin = lparenPos + 1; + int parenDepth = 0; + for (size_t i = argsBegin; i < cpp.size(); ++i) { + char c = cpp[i]; + if (c == '(') { + ++parenDepth; + continue; + } + if (c != ')') + continue; + if (parenDepth == 0) { + call.rparenPos = i; + break; + } + --parenDepth; + } + if (call.rparenPos == std::string::npos) { + searchPos = markerPos + 1; + continue; + } + + llvm::StringRef argsRef(cpp.data() + argsBegin, call.rparenPos - argsBegin); + if (!parseMarkerArgs(argsRef, call.args)) { + searchPos = call.rparenPos + 1; + continue; + } + + llvm::StringRef markerName(cpp.data() + markerPos, lparenPos - markerPos); + std::string callee; + std::string lastUseArgs; + if (!parseLastUseMarkerName(markerName, callee, lastUseArgs)) { + searchPos = call.rparenPos + 1; + continue; + } + + std::string replacement; + replacement.reserve(callee.size() + lastUseArgs.size() + argsRef.size() + + 32); + replacement.append("[[pto::last_use("); + replacement.append(lastUseArgs); + replacement.append(")]] "); + replacement.append(callee); + replacement.push_back('('); + for (size_t i = 0; i < call.args.size(); ++i) { + if (i) + replacement.append(", "); + replacement.append(call.args[i].str()); + } + replacement.push_back(')'); + + cpp.replace(markerPos, (call.rparenPos - markerPos) + 1, replacement); + changed = true; + searchPos = markerPos + replacement.size(); + } + return changed; +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index fd08d24c6..92cf3ba19 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -125,6 +125,131 @@ static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = "__pto.force_dynamic_valid_shape"; static constexpr llvm::StringLiteral kGlobalTensorStridesAttrName = "__pto.globaltensor_strides"; +static constexpr llvm::StringLiteral kLastUseAttrName = "pto.last_use"; +static constexpr llvm::StringLiteral kLastUseMarkerPrefix = "PTOAS__LAST_USE__"; + +static SmallVector collectTileOperandNumbers(Operation *op) { + SmallVector tileOperandNumbers; + for (OpOperand &operand : op->getOpOperands()) { + if (isa(operand.get().getType())) + tileOperandNumbers.push_back(operand.getOperandNumber()); + } + return tileOperandNumbers; +} + +static bool isDpsInitOperand(OpOperand &operand) { + Operation *owner = operand.getOwner(); + if (auto dpsIface = dyn_cast(owner)) { + for (OpOperand &init : dpsIface.getDpsInitsMutable()) { + if (&init == &operand) + return true; + } + } + return false; +} + +static SmallVector +buildDefaultLastUseTileSlotOrder(Operation *op) { + SmallVector dpsInitTileOperands; + SmallVector nonDpsTileOperands; + for (OpOperand &operand : op->getOpOperands()) { + if (!isa(operand.get().getType())) + continue; + if (isDpsInitOperand(operand)) + dpsInitTileOperands.push_back(operand.getOperandNumber()); + else + nonDpsTileOperands.push_back(operand.getOperandNumber()); + } + + // Most tile intrinsics lower as `CALLEE(dst, src0, src1, ...)`. When an op + // has exactly one DPS init tile, treat that output slot as the leading + // emitted tile operand so `[[pto::last_use(...)]]` aligns with the final + // intrinsic call order. + if (dpsInitTileOperands.size() == 1) { + SmallVector ordered{dpsInitTileOperands.front()}; + ordered.append(nonDpsTileOperands.begin(), nonDpsTileOperands.end()); + return ordered; + } + + SmallVector ordered = std::move(nonDpsTileOperands); + ordered.append(dpsInitTileOperands.begin(), dpsInitTileOperands.end()); + return ordered; +} + +static std::optional buildLastUseMarkerCallee(Operation *op, + StringRef callee, + ArrayRef tileSlotOrder = {}) { + auto lastUseAttr = dyn_cast_or_null( + op->getAttr(kLastUseAttrName)); + if (!lastUseAttr) + return std::nullopt; + + SmallVector originalTileOperands = collectTileOperandNumbers(op); + ArrayRef originalBits = lastUseAttr.asArrayRef(); + if (originalTileOperands.size() != originalBits.size()) + return std::nullopt; + + SmallVector defaultTileSlotOrder; + if (tileSlotOrder.empty()) { + defaultTileSlotOrder = buildDefaultLastUseTileSlotOrder(op); + tileSlotOrder = defaultTileSlotOrder; + } + if (tileSlotOrder.size() != originalBits.size()) + return std::nullopt; + + SmallVector reorderedBits; + reorderedBits.reserve(tileSlotOrder.size()); + for (unsigned operandNumber : tileSlotOrder) { + bool found = false; + for (auto [idx, originalOperandNumber] : llvm::enumerate(originalTileOperands)) { + if (originalOperandNumber != operandNumber) + continue; + reorderedBits.push_back(originalBits[idx]); + found = true; + break; + } + if (!found) + return std::nullopt; + } + + std::string marker = kLastUseMarkerPrefix.str(); + marker.append(callee.str()); + marker.append("__"); + bool first = true; + for (int64_t bit : reorderedBits) { + if (!first) + marker.append("__"); + first = false; + marker.append(std::to_string(bit)); + } + return marker; +} + +static StringRef getLastUseAwareCallee(Operation *op, StringRef callee, + std::string &storage, + ArrayRef tileSlotOrder = {}) { + std::optional marker = + buildLastUseMarkerCallee(op, callee, tileSlotOrder); + if (!marker) + return callee; + storage = std::move(*marker); + return storage; +} + +static void createLastUseAwareOpaqueCall(ConversionPatternRewriter &rewriter, + Operation *op, TypeRange resultTypes, + StringRef callee, + ValueRange operands, + ArrayAttr args = ArrayAttr{}, + ArrayAttr templateArgs = ArrayAttr{}, + ArrayRef tileSlotOrder = {}) { + std::string calleeStorage; + StringRef effectiveCallee = + getLastUseAwareCallee(op, callee, calleeStorage, tileSlotOrder); + rewriter.create(op->getLoc(), resultTypes, + effectiveCallee, args, templateArgs, + operands); +} static Value peelUnrealized(Value v) { if (auto castOp = v.getDefiningOp()) @@ -5412,7 +5537,6 @@ struct PTOTAxpyToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TAxpyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); Value scalar = peelUnrealized(adaptor.getScalar()); @@ -5458,7 +5582,6 @@ struct PTOGetScaleAddrToEmitC LogicalResult matchAndRewrite(pto::TGetScaleAddrOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); @@ -5639,10 +5762,8 @@ struct PTOTAddToTADD : public OpConversionPattern { Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); - rewriter.create( - op.getLoc(), TypeRange{}, "TADD", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TADD", ValueRange{dst, src0, src1}); rewriter.eraseOp(op); return success(); @@ -7025,10 +7146,8 @@ struct PTOAddSToTADDS : public OpConversionPattern { Value dst = peelUnrealized(adaptor.getDst()); Value scalar = peelUnrealized(adaptor.getScalar()); - rewriter.create( - op.getLoc(), TypeRange{}, "TADDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TADDS", ValueRange{dst, src, scalar}); rewriter.eraseOp(op); return success(); @@ -7228,7 +7347,6 @@ struct PTOColExpandMulToEmitC : public OpConversionPattern LogicalResult matchAndRewrite(pto::TColExpandMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); @@ -7250,7 +7368,6 @@ struct PTOColExpandAddToEmitC : public OpConversionPattern LogicalResult matchAndRewrite(pto::TColExpandAddOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); @@ -7272,7 +7389,6 @@ struct PTOColExpandDivToEmitC : public OpConversionPattern LogicalResult matchAndRewrite(pto::TColExpandDivOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); @@ -7295,7 +7411,6 @@ struct PTOColExpandExpdifToEmitC LogicalResult matchAndRewrite(pto::TColExpandExpdifOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); @@ -7317,7 +7432,6 @@ struct PTOColExpandSubToEmitC : public OpConversionPattern LogicalResult matchAndRewrite(pto::TColExpandSubOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); @@ -7339,7 +7453,6 @@ struct PTOColExpandMaxToEmitC : public OpConversionPattern LogicalResult matchAndRewrite(pto::TColExpandMaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); @@ -7361,7 +7474,6 @@ struct PTOColExpandMinToEmitC : public OpConversionPattern LogicalResult matchAndRewrite(pto::TColExpandMinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); @@ -7480,17 +7592,12 @@ struct PTOColMaxToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TColMaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); // intrinsic: TCOLMAX(dst, src) - rewriter.create( - loc, TypeRange{}, "TCOLMAX", - /*args=*/ArrayAttr{}, // default: print all operands - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TCOLMAX", ValueRange{dst, src}); rewriter.eraseOp(op); return success(); @@ -7503,7 +7610,6 @@ struct PTOColArgMaxToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TColArgMaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value tmp = peelUnrealized(adaptor.getTmp()); Value dst = peelUnrealized(adaptor.getDst()); @@ -7524,17 +7630,12 @@ struct PTOColMinToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TColMinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); // intrinsic: TCOLMIN(dst, src) - rewriter.create( - loc, TypeRange{}, "TCOLMIN", - /*args=*/ArrayAttr{}, // default: print all operands - /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{dst, src}); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TCOLMIN", ValueRange{dst, src}); rewriter.eraseOp(op); return success(); @@ -7547,7 +7648,6 @@ struct PTOColArgMinToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TColArgMinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value tmp = peelUnrealized(adaptor.getTmp()); Value dst = peelUnrealized(adaptor.getDst()); @@ -7587,18 +7687,23 @@ struct PTOColSumToEmitC : public OpConversionPattern { Value isBinaryVal = rewriter.create( loc, boolTy, emitc::OpaqueAttr::get(ctx, tok)); - rewriter.create( - loc, TypeRange{}, "TCOLSUM", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src, tmp, isBinaryVal}); + SmallVector tileSlotOrder; + tileSlotOrder.push_back(op.getDstMutable().getOperandNumber()); + tileSlotOrder.push_back(op.getSrcMutable().getOperandNumber()); + tileSlotOrder.push_back(op.getTmpMutable().begin()->getOperandNumber()); + + createLastUseAwareOpaqueCall( + rewriter, op.getOperation(), TypeRange{}, "TCOLSUM", + ValueRange{dst, src, tmp, isBinaryVal}, ArrayAttr{}, ArrayAttr{}, + tileSlotOrder); } else { // Format 1: without tmp and isBinary - rewriter.create( - loc, TypeRange{}, "TCOLSUM", - /*args=*/ArrayAttr(), - /*templateArgs=*/ArrayAttr(), - /*operands=*/ValueRange{dst, src}); + SmallVector tileSlotOrder; + tileSlotOrder.push_back(op.getDstMutable().getOperandNumber()); + tileSlotOrder.push_back(op.getSrcMutable().getOperandNumber()); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TCOLSUM", ValueRange{dst, src}, + ArrayAttr{}, ArrayAttr{}, tileSlotOrder); } rewriter.eraseOp(op); @@ -7726,10 +7831,8 @@ struct PTODivToTDIV : public OpConversionPattern { Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); - rewriter.create( - op.getLoc(), TypeRange{}, "TDIV", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src0, src1}); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TDIV", ValueRange{dst, src0, src1}); rewriter.eraseOp(op); return success(); @@ -7746,18 +7849,14 @@ struct PTODivSToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value scalar = peelUnrealized(adaptor.getScalar()); Value dst = peelUnrealized(adaptor.getDst()); // Preserve source order from textual parse: // ins(tile, scalar) -> TDIVS(dst, tile, scalar) // ins(scalar, tile) -> TDIVS(dst, scalar, tile) - rewriter.create( - loc, TypeRange{}, "TDIVS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TDIVS", ValueRange{dst, src, scalar}); rewriter.eraseOp(op); return success(); @@ -7775,15 +7874,11 @@ struct PTOTDivSToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TDivSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value scalar = peelUnrealized(adaptor.getScalar()); Value dst = peelUnrealized(adaptor.getDst()); - rewriter.create( - loc, TypeRange{}, "TDIVS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src, scalar}); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TDIVS", ValueRange{dst, src, scalar}); rewriter.eraseOp(op); return success(); @@ -7798,15 +7893,11 @@ struct PTOExpToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TExpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); - rewriter.create( - loc, TypeRange{}, "TEXP", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, src}); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TEXP", ValueRange{dst, src}); rewriter.eraseOp(op); return success(); @@ -7821,15 +7912,11 @@ struct PTOExpandsToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TExpandsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value scalar = peelUnrealized(adaptor.getScalar()); Value dst = peelUnrealized(adaptor.getDst()); - rewriter.create( - loc, TypeRange{}, "TEXPANDS", - ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, scalar}); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TEXPANDS", ValueRange{dst, scalar}); rewriter.eraseOp(op); return success(); @@ -8209,17 +8296,13 @@ struct PTOMaxToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TMaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TMAX", operands); rewriter.eraseOp(op); return success(); @@ -8235,17 +8318,13 @@ struct PTOMaxToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TMaxSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc()); Value scalar = peelUnrealized(adaptor.getScalar()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src0, scalar}; - rewriter.create( - loc, TypeRange{}, "TMAXS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TMAXS", operands); rewriter.eraseOp(op); return success(); @@ -8262,17 +8341,13 @@ struct PTOMinToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TMinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TMIN", operands); rewriter.eraseOp(op); return success(); @@ -8292,17 +8367,13 @@ struct PTOMinsToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TMinSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); Value scalar = peelUnrealized(adaptor.getScalar()); SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TMINS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TMINS", operands); rewriter.eraseOp(op); return success(); @@ -8635,17 +8706,13 @@ struct PTOMulToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TMUL", operands); rewriter.eraseOp(op); return success(); @@ -8660,17 +8727,13 @@ struct PTOMulsToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TMulSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc0()); Value dst = peelUnrealized(adaptor.getDst()); Value scalar = peelUnrealized(adaptor.getScalar()); SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TMULS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TMULS", operands); rewriter.eraseOp(op); return success(); @@ -8736,7 +8799,6 @@ struct PTOOrToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TOrOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); @@ -8789,7 +8851,6 @@ struct PTOPartAddToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TPartAddOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); @@ -8963,7 +9024,6 @@ struct PTORecipToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TRecipOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value dst = peelUnrealized(adaptor.getDst()); @@ -9058,7 +9118,6 @@ struct PTORemSToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TRemSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value tmp = peelUnrealized(adaptor.getTmp()); Value dst = peelUnrealized(adaptor.getDst()); @@ -9180,8 +9239,12 @@ static void replaceOrEraseWithOpaqueCall(Operation *op, ArrayRef args, ConversionPatternRewriter &rewriter) { TypeRange resultTypes = op->getResultTypes(); + std::string calleeStorage; + StringRef effectiveCallee = + getLastUseAwareCallee(op, callee, calleeStorage); auto call = rewriter.create( - op->getLoc(), resultTypes, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); + op->getLoc(), resultTypes, effectiveCallee, ArrayAttr{}, ArrayAttr{}, + ValueRange(args)); if (resultTypes.empty()) rewriter.eraseOp(op); else @@ -9192,8 +9255,7 @@ static void replaceOrEraseWithOpaqueCallAndReturnDst(Operation *op, Value dst, StringRef callee, ArrayRef args, ConversionPatternRewriter &rewriter) { - rewriter.create( - op->getLoc(), TypeRange{}, callee, ArrayAttr{}, ArrayAttr{}, ValueRange(args)); + createLastUseAwareOpaqueCall(rewriter, op, TypeRange{}, callee, args); if (op->getNumResults() == 1) rewriter.replaceOp(op, dst); else @@ -9352,8 +9414,6 @@ struct PTORowExpandDivToEmitC : public OpConversionPattern LogicalResult matchAndRewrite(pto::TRowExpandDivOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); @@ -9364,10 +9424,8 @@ struct PTORowExpandDivToEmitC : public OpConversionPattern operands.assign({dst, src0, src1, tmp}); else operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDDIV", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TROWEXPANDDIV", operands); rewriter.eraseOp(op); return success(); @@ -9382,8 +9440,6 @@ struct PTORowExpandMulToEmitC : public OpConversionPattern LogicalResult matchAndRewrite(pto::TRowExpandMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); @@ -9394,10 +9450,8 @@ struct PTORowExpandMulToEmitC : public OpConversionPattern operands.assign({dst, src0, src1, tmp}); else operands.assign({dst, src0, src1}); - rewriter.create( - loc, TypeRange{}, "TROWEXPANDMUL", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TROWEXPANDMUL", operands); rewriter.eraseOp(op); return success(); @@ -9498,17 +9552,13 @@ struct PTORowMaxToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TRowMaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value tmp = peelUnrealized(adaptor.getTmp()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWMAX", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TROWMAX", operands); rewriter.eraseOp(op); return success(); @@ -9522,7 +9572,6 @@ struct PTORowArgMaxToEmitC LogicalResult matchAndRewrite(pto::TRowArgMaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value tmp = peelUnrealized(adaptor.getTmp()); Value dst = peelUnrealized(adaptor.getDst()); @@ -9545,17 +9594,13 @@ struct PTORowMinToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TRowMinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value tmp = peelUnrealized(adaptor.getTmp()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWMIN", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TROWMIN", operands); rewriter.eraseOp(op); return success(); @@ -9569,7 +9614,6 @@ struct PTORowArgMinToEmitC LogicalResult matchAndRewrite(pto::TRowArgMinOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value tmp = peelUnrealized(adaptor.getTmp()); Value dst = peelUnrealized(adaptor.getDst()); @@ -9593,17 +9637,13 @@ struct PTORowSumToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TRowSumOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value tmp = peelUnrealized(adaptor.getTmp()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src, tmp}; - rewriter.create( - loc, TypeRange{}, "TROWSUM", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TROWSUM", operands); rewriter.eraseOp(op); return success(); @@ -9937,17 +9977,13 @@ struct PTOSubSToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TSubOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src0 = peelUnrealized(adaptor.getSrc0()); Value src1 = peelUnrealized(adaptor.getSrc1()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src0, src1}; - rewriter.create( - loc, TypeRange{}, "TSUB", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TSUB", operands); rewriter.eraseOp(op); return success(); @@ -9993,17 +10029,13 @@ struct PTOSubSSToEmitC : public OpConversionPattern { LogicalResult matchAndRewrite(pto::TSubSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value src = peelUnrealized(adaptor.getSrc()); Value scalar = peelUnrealized(adaptor.getScalar()); Value dst = peelUnrealized(adaptor.getDst()); SmallVector operands{dst, src, scalar}; - rewriter.create( - loc, TypeRange{}, "TSUBS", - /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, - /*operands=*/operands); + createLastUseAwareOpaqueCall(rewriter, op.getOperation(), TypeRange{}, + "TSUBS", operands); rewriter.eraseOp(op); return success(); diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 2cd529ea5..6c4ed176a 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -126,6 +126,16 @@ static void lookupValidDims(Value v, Value &vRow, Value &vCol) { vCol = Value(); } +template +static OpTy replaceOpWithClonedAttrs(IRRewriter &rewriter, Operation *op, + Args &&...args) { + auto newOp = + rewriter.create(op->getLoc(), std::forward(args)...); + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + return newOp; +} + // ============================================================================= // Helper Functions for Layout Normalization // ============================================================================= @@ -1686,8 +1696,9 @@ struct PTOViewToMemrefPass for (auto op : trans) { IRRewriter rewriter(ctx); rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op->getOperand(0), op->getOperand(1), op->getOperand(2)); + replaceOpWithClonedAttrs( + rewriter, op, TypeRange{}, op->getOperand(0), op->getOperand(1), + op->getOperand(2)); } // --- TExpOp [Src, Dst] --- @@ -1696,8 +1707,9 @@ struct PTOViewToMemrefPass for (auto op : exp) { IRRewriter rewriter(ctx); rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, op->getOperand(0), op->getOperand(1)); + replaceOpWithClonedAttrs(rewriter, op, TypeRange{}, + op->getOperand(0), + op->getOperand(1)); } // --- TMulOp [Src, Scalar, Dst] --- @@ -1706,8 +1718,9 @@ struct PTOViewToMemrefPass for (auto op : mul) { IRRewriter rewriter(ctx); rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, op->getOperand(0), op.getOperand(1), op->getOperand(2)); + replaceOpWithClonedAttrs( + rewriter, op, op->getOperand(0), op.getOperand(1), + op->getOperand(2)); } // --- TMulSOp [Src, Scalar, Dst] --- @@ -1716,8 +1729,9 @@ struct PTOViewToMemrefPass for (auto op : muls) { IRRewriter rewriter(ctx); rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, op->getOperand(0), op.getScalar(), op->getOperand(2)); + replaceOpWithClonedAttrs( + rewriter, op, op->getOperand(0), op.getScalar(), + op->getOperand(2)); } // --- TAddOp [Src0, Src1, Dst] --- @@ -1727,9 +1741,9 @@ struct PTOViewToMemrefPass IRRewriter rewriter(ctx); rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp( - op, TypeRange{}, - op->getOperand(0), op->getOperand(1), op->getOperand(2)); + replaceOpWithClonedAttrs( + rewriter, op, TypeRange{}, op->getOperand(0), op->getOperand(1), + op->getOperand(2)); } // --- TMatmulOp [Lhs, Rhs, Dst] (no optional bias in ODS) --- @@ -1961,12 +1975,8 @@ struct PTOViewToMemrefPass return; } - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - src, - scalar, - dst); + replaceOpWithClonedAttrs(rewriter, op, TypeRange{}, src, + scalar, dst); } SmallVector addscops; @@ -2572,11 +2582,8 @@ struct PTOViewToMemrefPass return; } - rewriter.replaceOpWithNewOp( - op, - TypeRange{}, - scalar, - dst); + replaceOpWithClonedAttrs(rewriter, op, TypeRange{}, + scalar, dst); } SmallVector extractops; diff --git a/lib/PTO/Transforms/TileFusion/FusionAnalysis.cpp b/lib/PTO/Transforms/TileFusion/FusionAnalysis.cpp new file mode 100644 index 000000000..fd6bb1b4f --- /dev/null +++ b/lib/PTO/Transforms/TileFusion/FusionAnalysis.cpp @@ -0,0 +1,506 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/Transforms/TileFusion/FusionAnalysis.h" + +#include "PTO/IR/PTO.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace pto { + +namespace { + +static int64_t getConstantIndexOrDynamic(Value value) { + if (!value) + return ShapedType::kDynamic; + if (auto cst = value.getDefiningOp()) + return cst.value(); + if (auto cst = value.getDefiningOp()) + return cst.value(); + return ShapedType::kDynamic; +} + +static SmallVector getValidShapeVec(Type type) { + if (auto tileType = dyn_cast(type)) { + return SmallVector(tileType.getValidShape().begin(), + tileType.getValidShape().end()); + } + if (auto shapedType = dyn_cast(type)) { + return SmallVector(shapedType.getShape().begin(), + shapedType.getShape().end()); + } + return {}; +} + +static SmallVector getValidShapeVec(Value value) { + SmallVector validShape = getValidShapeVec(value.getType()); + if (auto bind = value.getDefiningOp()) { + if (validShape.size() >= 1 && bind.getValidRow()) + validShape[0] = getConstantIndexOrDynamic(bind.getValidRow()); + if (validShape.size() >= 2 && bind.getValidCol()) + validShape[1] = getConstantIndexOrDynamic(bind.getValidCol()); + } + return validShape; +} + +struct Rank2IterationSpace { + int64_t rows = ShapedType::kDynamic; + int64_t cols = ShapedType::kDynamic; +}; + +static std::optional getRank2IterationSpace(Value value) { + SmallVector validShape = getValidShapeVec(value); + if (validShape.size() < 2) + return std::nullopt; + return Rank2IterationSpace{validShape[0], validShape[1]}; +} + +static void mergeIterationDim(int64_t &mergedDim, int64_t dim, + IterationDomainInfo &info) { + if (mergedDim == ShapedType::kDynamic || dim == ShapedType::kDynamic) { + mergedDim = ShapedType::kDynamic; + if (info.unprovenReason == IterationDomainUnprovenReason::None) + info.unprovenReason = IterationDomainUnprovenReason::DynamicShape; + return; + } + + if (mergedDim != dim) { + mergedDim = ShapedType::kDynamic; + info.unprovenReason = IterationDomainUnprovenReason::InconsistentShape; + } +} + +static IterationDomainInfo +inferConsensusIterationDomain(ArrayRef anchorValues) { + IterationDomainInfo info; + info.unprovenReason = IterationDomainUnprovenReason::None; + + if (anchorValues.empty()) + return info; + + std::optional firstSpace = + getRank2IterationSpace(anchorValues.front()); + if (!firstSpace) + return info; + + info.vRow = firstSpace->rows; + info.vCol = firstSpace->cols; + + if (info.vRow == ShapedType::kDynamic || info.vCol == ShapedType::kDynamic) + info.unprovenReason = IterationDomainUnprovenReason::DynamicShape; + + for (Value value : ArrayRef(anchorValues).drop_front()) { + std::optional space = getRank2IterationSpace(value); + if (!space) { + info.vRow = ShapedType::kDynamic; + info.vCol = ShapedType::kDynamic; + info.unprovenReason = IterationDomainUnprovenReason::MissingTileDomain; + return info; + } + mergeIterationDim(info.vRow, space->rows, info); + mergeIterationDim(info.vCol, space->cols, info); + } + + if (info.unprovenReason == IterationDomainUnprovenReason::None && + info.vRow != ShapedType::kDynamic && info.vCol != ShapedType::kDynamic) { + info.proof = IterationDomainProof::Proven; + return info; + } + + if (info.unprovenReason == IterationDomainUnprovenReason::None) + info.unprovenReason = IterationDomainUnprovenReason::DynamicShape; + return info; +} + +static IterationDomainInfo +inferIterationDomainInfo(const FusionOpSemantics &semantics) { + switch (semantics.computeFamily) { + case FusionComputeFamily::Elementwise: { + SmallVector anchors; + anchors.append(semantics.tileInputs.begin(), semantics.tileInputs.end()); + anchors.append(semantics.tileOutputs.begin(), semantics.tileOutputs.end()); + return inferConsensusIterationDomain(anchors); + } + case FusionComputeFamily::ScalarExpand: + case FusionComputeFamily::RowBroadcastBinary: + return inferConsensusIterationDomain(semantics.tileOutputs); + case FusionComputeFamily::ReduceRow: + case FusionComputeFamily::ReduceCol: + return inferConsensusIterationDomain(semantics.tileInputs); + case FusionComputeFamily::Unknown: + return IterationDomainInfo(); + } + return IterationDomainInfo(); +} + +static unsigned assignIterationDomainClass( + SmallVectorImpl &classes, + DenseMap, unsigned> &provenClassByKey, + const IterationDomainInfo &info, unsigned nodeId) { + if (info.proof == IterationDomainProof::Proven) { + std::pair key{info.vRow, info.vCol}; + auto it = provenClassByKey.find(key); + if (it != provenClassByKey.end()) { + classes[it->second].members.push_back(nodeId); + return it->second; + } + + unsigned classId = classes.size(); + IterationDomainClass klass; + klass.id = classId; + klass.info = info; + klass.members.push_back(nodeId); + classes.push_back(std::move(klass)); + provenClassByKey.try_emplace(key, classId); + return classId; + } + + unsigned classId = classes.size(); + IterationDomainClass klass; + klass.id = classId; + klass.info = info; + klass.members.push_back(nodeId); + classes.push_back(std::move(klass)); + return classId; +} + +struct MutableLiveness { + FusionValueLiveness live; +}; + +struct MutableWriteInstance { + FusionWriteInstanceLiveness live; + unsigned producerBlockOrder = 0; +}; + +static FusionWriteInstanceEscapeClass classifyEscapeClass( + const FusionWriteInstanceLiveness &live) { + if (live.hasExternalUsers || live.escapesBlock || + live.hasLocalHardBoundaryUsers) { + return FusionWriteInstanceEscapeClass::HardExternal; + } + if (live.hasLocalBoundaryUsers) + return FusionWriteInstanceEscapeClass::LocalBoundaryExternal; + return FusionWriteInstanceEscapeClass::Internal; +} + +static Value getWriteInstanceStorageValue(Operation *op, unsigned outputIndex, + Value output) { + if (auto dpsIface = dyn_cast(op)) { + unsigned tileOutputIndex = 0; + for (Value init : dpsIface.getDpsInits()) { + if (!isa(init.getType())) + continue; + if (tileOutputIndex == outputIndex) + return init; + ++tileOutputIndex; + } + } + return output; +} + +static unsigned getOrCreateLivenessSlot(DenseMap &slotByValue, + SmallVectorImpl &slots, + Value value) { + auto [it, inserted] = slotByValue.try_emplace(value, slots.size()); + if (inserted) { + MutableLiveness state; + state.live.value = value; + slots.push_back(std::move(state)); + } + return it->second; +} + +static void appendUniqueNode(SmallVectorImpl &nodes, unsigned nodeId) { + if (!llvm::is_contained(nodes, nodeId)) + nodes.push_back(nodeId); +} + +static void recordLastLocalConsumer(std::optional &lastLocalConsumer, + unsigned consumerId) { + if (!lastLocalConsumer || consumerId > *lastLocalConsumer) + lastLocalConsumer = consumerId; +} + +static void finalizeBlockLiveness( + Block &block, DenseMap &kindByOp, + DenseMap &computeNodeByOp, + SmallVectorImpl &mutableLiveness) { + for (MutableLiveness &state : mutableLiveness) { + for (OpOperand &use : state.live.value.getUses()) { + Operation *user = use.getOwner(); + if (user->getBlock() != &block) { + state.live.hasExternalUsers = true; + state.live.escapesBlock = true; + continue; + } + + auto kindIt = kindByOp.find(user); + if (kindIt == kindByOp.end()) + continue; + + if (user->hasTrait()) + state.live.escapesBlock = true; + + switch (kindIt->second) { + case FusionOpKind::Compute: { + auto nodeIt = computeNodeByOp.find(user); + if (nodeIt == computeNodeByOp.end()) + continue; + unsigned consumerId = nodeIt->second; + appendUniqueNode(state.live.consumerNodes, consumerId); + recordLastLocalConsumer(state.live.lastLocalConsumer, consumerId); + break; + } + case FusionOpKind::LocalBoundary: + state.live.hasLocalBoundaryUsers = true; + break; + case FusionOpKind::HardBoundary: + state.live.hasLocalHardBoundaryUsers = true; + break; + } + } + } +} + +static std::optional findReachingWriteInstance( + ArrayRef writeInstanceIds, + ArrayRef mutableWriteInstances, + std::optional userBlockOrder) { + if (writeInstanceIds.empty()) + return std::nullopt; + + if (!userBlockOrder) + return writeInstanceIds.back(); + + for (unsigned writeInstanceId : llvm::reverse(writeInstanceIds)) { + if (mutableWriteInstances[writeInstanceId].producerBlockOrder < + *userBlockOrder) + return writeInstanceId; + } + return std::nullopt; +} + +static bool isDpsInitOperandUse(OpOperand &use) { + auto dpsIface = dyn_cast(use.getOwner()); + if (!dpsIface) + return false; + + for (OpOperand &dpsInit : dpsIface.getDpsInitsMutable()) + if (&dpsInit == &use) + return true; + return false; +} + +static void finalizeWriteInstances( + Block &block, DenseMap &kindByOp, + DenseMap &computeNodeByOp, + DenseMap &blockOrderByOp, + ArrayRef mutableLiveness, + SmallVectorImpl &mutableWriteInstances) { + for (const MutableLiveness &storageState : mutableLiveness) { + if (storageState.live.writeInstances.empty()) + continue; + + for (OpOperand &use : storageState.live.value.getUses()) { + if (isDpsInitOperandUse(use)) + continue; + + Operation *user = use.getOwner(); + bool isInBlock = user->getBlock() == █ + std::optional userBlockOrder; + if (isInBlock) { + auto orderIt = blockOrderByOp.find(user); + if (orderIt != blockOrderByOp.end()) + userBlockOrder = orderIt->second; + } + + std::optional writeInstanceId = findReachingWriteInstance( + storageState.live.writeInstances, mutableWriteInstances, + userBlockOrder); + if (!writeInstanceId) + continue; + + FusionWriteInstanceLiveness &writeLive = + mutableWriteInstances[*writeInstanceId].live; + + if (!isInBlock) { + writeLive.hasExternalUsers = true; + writeLive.escapesBlock = true; + continue; + } + + auto kindIt = kindByOp.find(user); + if (kindIt == kindByOp.end()) + continue; + + if (user->hasTrait()) + writeLive.escapesBlock = true; + + switch (kindIt->second) { + case FusionOpKind::Compute: { + auto nodeIt = computeNodeByOp.find(user); + if (nodeIt == computeNodeByOp.end()) + continue; + unsigned consumerId = nodeIt->second; + appendUniqueNode(writeLive.consumerNodes, consumerId); + recordLastLocalConsumer(writeLive.lastLocalConsumer, consumerId); + break; + } + case FusionOpKind::LocalBoundary: + writeLive.hasLocalBoundaryUsers = true; + break; + case FusionOpKind::HardBoundary: + writeLive.hasLocalHardBoundaryUsers = true; + break; + } + } + } + + for (MutableWriteInstance &state : mutableWriteInstances) + state.live.escapeClass = classifyEscapeClass(state.live); +} + +static FailureOr analyzeBlock(Block &block) { + FusionBlockAnalysis analysis; + analysis.block = █ + + DenseMap producerByValue; + DenseMap livenessSlotByValue; + SmallVector mutableLiveness; + SmallVector mutableWriteInstances; + DenseMap kindByOp; + DenseMap computeNodeByOp; + DenseMap blockOrderByOp; + DenseMap, unsigned> provenClassByKey; + + unsigned blockOrder = 0; + for (Operation &op : block) { + FailureOr semanticsOr = getFusionOpSemantics(&op); + if (failed(semanticsOr)) { + op.emitError("failed to normalize fusion op semantics"); + return failure(); + } + blockOrderByOp[&op] = blockOrder; + kindByOp[&op] = semanticsOr->kind; + + if (semanticsOr->kind == FusionOpKind::LocalBoundary) { + for (Value input : semanticsOr->tileInputs) + getOrCreateLivenessSlot(livenessSlotByValue, mutableLiveness, input); + for (Value output : semanticsOr->tileOutputs) + getOrCreateLivenessSlot(livenessSlotByValue, mutableLiveness, output); + ++blockOrder; + continue; + } + + if (semanticsOr->kind != FusionOpKind::Compute) { + ++blockOrder; + continue; + } + + FusionComputeNode node; + node.id = analysis.computeNodes.size(); + node.blockOrder = blockOrder; + node.op = &op; + node.semantics = *semanticsOr; + computeNodeByOp[&op] = node.id; + + IterationDomainInfo domainInfo = inferIterationDomainInfo(node.semantics); + node.iterationDomainClass = assignIterationDomainClass( + analysis.iterationDomainClasses, provenClassByKey, domainInfo, node.id); + + for (auto [outputIdx, output] : llvm::enumerate(node.semantics.tileOutputs)) { + producerByValue[output] = node.id; + unsigned liveSlot = + getOrCreateLivenessSlot(livenessSlotByValue, mutableLiveness, output); + mutableLiveness[liveSlot].live.producerNode = node.id; + + MutableWriteInstance writeInstance; + writeInstance.live.id = mutableWriteInstances.size(); + writeInstance.live.value = output; + writeInstance.live.storageValue = + getWriteInstanceStorageValue(&op, outputIdx, output); + writeInstance.live.producerNode = node.id; + writeInstance.producerBlockOrder = blockOrder; + mutableLiveness[liveSlot].live.writeInstances.push_back( + writeInstance.live.id); + mutableWriteInstances.push_back(std::move(writeInstance)); + } + + for (Value input : node.semantics.tileInputs) { + unsigned liveSlot = + getOrCreateLivenessSlot(livenessSlotByValue, mutableLiveness, input); + appendUniqueNode(mutableLiveness[liveSlot].live.consumerNodes, node.id); + recordLastLocalConsumer(mutableLiveness[liveSlot].live.lastLocalConsumer, + node.id); + + auto producerIt = producerByValue.find(input); + if (producerIt == producerByValue.end()) + continue; + + FusionDFGEdge edge; + edge.producerNode = producerIt->second; + edge.consumerNode = node.id; + edge.value = input; + + unsigned edgeId = analysis.edges.size(); + analysis.edges.push_back(edge); + node.incomingEdges.push_back(edgeId); + if (edge.producerNode < analysis.computeNodes.size()) + analysis.computeNodes[edge.producerNode].outgoingEdges.push_back(edgeId); + } + + analysis.computeNodes.push_back(std::move(node)); + ++blockOrder; + } + + finalizeBlockLiveness(block, kindByOp, computeNodeByOp, mutableLiveness); + finalizeWriteInstances(block, kindByOp, computeNodeByOp, blockOrderByOp, + mutableLiveness, mutableWriteInstances); + + analysis.liveness.reserve(mutableLiveness.size()); + for (MutableLiveness &state : mutableLiveness) + analysis.liveness.push_back(std::move(state.live)); + analysis.writeInstances.reserve(mutableWriteInstances.size()); + for (MutableWriteInstance &state : mutableWriteInstances) + analysis.writeInstances.push_back(std::move(state.live)); + + return std::move(analysis); +} + +static LogicalResult analyzeRegion(Region ®ion, + SmallVectorImpl &blocks) { + for (Block &block : region.getBlocks()) { + FailureOr blockAnalysis = analyzeBlock(block); + if (failed(blockAnalysis)) + return failure(); + blocks.push_back(std::move(*blockAnalysis)); + for (Operation &op : block) + for (Region &nested : op.getRegions()) + if (failed(analyzeRegion(nested, blocks))) + return failure(); + } + return success(); +} + +} // namespace + +FailureOr buildPreFusionAnalysis(func::FuncOp func) { + PreFusionAnalysisResult result; + if (failed(analyzeRegion(func.getRegion(), result.blocks))) + return failure(); + return std::move(result); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/TileFusion/FusionOpSemantics.cpp b/lib/PTO/Transforms/TileFusion/FusionOpSemantics.cpp new file mode 100644 index 000000000..d613a0bc2 --- /dev/null +++ b/lib/PTO/Transforms/TileFusion/FusionOpSemantics.cpp @@ -0,0 +1,123 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/Transforms/TileFusion/FusionOpSemantics.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSwitch.h" + +namespace mlir { +namespace pto { + +static FusionComputeFamily getFusionComputeFamily(StringRef opName) { + return llvm::StringSwitch(opName) + .Cases("tadd", "tsub", "tmul", "tdiv", "tmax", "tmin", + FusionComputeFamily::Elementwise) + .Cases("tadds", "tsubs", "tmuls", "tdivs", "tmaxs", "tmins", + FusionComputeFamily::Elementwise) + .Case("texp", FusionComputeFamily::Elementwise) + .Case("texpands", FusionComputeFamily::ScalarExpand) + .Cases("trowexpandmul", "trowexpanddiv", + FusionComputeFamily::RowBroadcastBinary) + .Cases("trowsum", "trowmax", "trowmin", FusionComputeFamily::ReduceRow) + .Cases("tcolsum", "tcolmax", "tcolmin", FusionComputeFamily::ReduceCol) + .Default(FusionComputeFamily::Unknown); +} + +bool isSupportedPreFusionComputeOp(StringRef opName) { + return getFusionComputeFamily(opName) != FusionComputeFamily::Unknown; +} + +static bool isTileFusionTileValue(Value value) { + return isa(value.getType()); +} + +static SmallVector collectNormalizedTileOutputs(Operation *op) { + SmallVector outputs; + + if (auto dpsIface = dyn_cast(op)) { + for (Value init : dpsIface.getDpsInits()) { + if (isTileFusionTileValue(init)) + outputs.push_back(init); + } + if (!outputs.empty()) + return outputs; + } + + for (Value result : op->getResults()) { + if (isTileFusionTileValue(result)) + outputs.push_back(result); + } + return outputs; +} + +static StringRef getTileFusionOpName(Operation *op) { + StringRef opName = op->getName().getStringRef(); + opName.consume_front("pto."); + return opName; +} + +FailureOr getFusionOpSemantics(Operation *op) { + FusionOpSemantics semantics; + semantics.op = op; + semantics.opName = getTileFusionOpName(op).str(); + + if (auto reshape = dyn_cast(op)) { + semantics.kind = FusionOpKind::LocalBoundary; + semantics.opName = "treshape"; + semantics.tileInputs.push_back(reshape.getSrc()); + semantics.tileOutputs.push_back(reshape.getResult()); + return semantics; + } + + semantics.computeFamily = getFusionComputeFamily(semantics.opName); + if (semantics.computeFamily == FusionComputeFamily::Unknown) { + semantics.kind = FusionOpKind::HardBoundary; + return semantics; + } + + auto dpsIface = dyn_cast(op); + if (!dpsIface && op->getNumResults() == 0) { + semantics.kind = FusionOpKind::HardBoundary; + return semantics; + } + + semantics.kind = FusionOpKind::Compute; + semantics.tileOutputs = collectNormalizedTileOutputs(op); + if (semantics.tileOutputs.empty()) + return failure(); + + SmallVector dpsInitOperandNumbers; + if (dpsIface) { + for (OpOperand &dpsInit : dpsIface.getDpsInitsMutable()) + dpsInitOperandNumbers.push_back(dpsInit.getOperandNumber()); + } + + for (OpOperand &operand : op->getOpOperands()) { + if (llvm::is_contained(dpsInitOperandNumbers, operand.getOperandNumber())) + continue; + + Value value = operand.get(); + if (isTileFusionTileValue(value)) + semantics.tileInputs.push_back(value); + else + semantics.scalarInputs.push_back(value); + } + + if (semantics.tileInputs.empty()) { + for (Value output : semantics.tileOutputs) { + if (!isa(output.getType())) + return failure(); + } + } + + return semantics; +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/TileFusion/PTOFusionPlan.cpp b/lib/PTO/Transforms/TileFusion/PTOFusionPlan.cpp new file mode 100644 index 000000000..1ccdff7a2 --- /dev/null +++ b/lib/PTO/Transforms/TileFusion/PTOFusionPlan.cpp @@ -0,0 +1,526 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/Transforms/TileFusion/FusionAnalysis.h" +#include "PTO/Transforms/TileFusion/FusionOpSemantics.h" + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSwitch.h" + +#include + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_FUSIONPLAN +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +static constexpr llvm::StringLiteral kFusionGroupIdAttr = + "pto.fusion.group_id"; +static constexpr llvm::StringLiteral kFusionOrderAttr = "pto.fusion.order"; + +struct PlannedFusionGroup { + SmallVector members; +}; + +struct PlanningContext { + const pto::FusionBlockAnalysis &blockAnalysis; +}; + +struct PlanningCost { + int64_t dependencyBenefit = 0; + int64_t loopMergeBenefit = 0; + int64_t liveTilePenalty = 0; + int64_t vfParameterPenalty = 0; + bool rejectedForDynamicShape = false; + + int64_t total() const { + return dependencyBenefit + loopMergeBenefit - liveTilePenalty - + vfParameterPenalty; + } +}; + +struct PlanningDecision { + bool accept = false; + PlanningCost cost; +}; + +static bool isCurrentlyPlannableOp(StringRef opName) { + return llvm::StringSwitch(opName) + .Cases("tmul", "tdiv", "tadd", "tsub", "tmax", "tmin", true) + .Cases("tmuls", "tdivs", "tadds", "tsubs", "tmaxs", "tmins", true) + .Case("texp", true) + .Case("texpands", true) + .Cases("trowexpandmul", "trowexpanddiv", true) + .Default(false); +} + +static bool isProvenIterationDomain( + const pto::FusionBlockAnalysis &blockAnalysis, + const pto::FusionComputeNode &node) { + if (node.iterationDomainClass >= blockAnalysis.iterationDomainClasses.size()) + return false; + return blockAnalysis.iterationDomainClasses[node.iterationDomainClass] + .info.proof == pto::IterationDomainProof::Proven; +} + +static bool dependsOnPreviousNode( + const pto::FusionBlockAnalysis &blockAnalysis, + const pto::FusionComputeNode &previous, + const pto::FusionComputeNode ¤t) { + for (unsigned edgeId : current.incomingEdges) { + if (edgeId >= blockAnalysis.edges.size()) + continue; + if (blockAnalysis.edges[edgeId].producerNode == previous.id) + return true; + } + + for (Value output : previous.semantics.tileOutputs) + if (llvm::is_contained(current.semantics.tileInputs, output)) + return true; + + return false; +} + +static SmallVector +buildStableInGroupOrder(ArrayRef members) { + SmallVector ordered(members.begin(), + members.end()); + llvm::stable_sort(ordered, [](const pto::FusionComputeNode *lhs, + const pto::FusionComputeNode *rhs) { + if (lhs->blockOrder != rhs->blockOrder) + return lhs->blockOrder < rhs->blockOrder; + return lhs->id < rhs->id; + }); + return ordered; +} + +static void assignStableGroupMetadata(ArrayRef groups, + MLIRContext *ctx, + int64_t &nextGroupId) { + SmallVector orderedGroups; + orderedGroups.reserve(groups.size()); + for (const PlannedFusionGroup &group : groups) + orderedGroups.push_back(&group); + + llvm::stable_sort(orderedGroups, [](const PlannedFusionGroup *lhs, + const PlannedFusionGroup *rhs) { + const pto::FusionComputeNode *lhsFirst = lhs->members.front(); + const pto::FusionComputeNode *rhsFirst = rhs->members.front(); + if (lhsFirst->blockOrder != rhsFirst->blockOrder) + return lhsFirst->blockOrder < rhsFirst->blockOrder; + return lhsFirst->id < rhsFirst->id; + }); + + for (const PlannedFusionGroup *group : orderedGroups) { + const int64_t groupId = nextGroupId++; + SmallVector stableOrder = + buildStableInGroupOrder(group->members); + for (auto [order, node] : llvm::enumerate(stableOrder)) { + node->op->setAttr(kFusionGroupIdAttr, + IntegerAttr::get(IntegerType::get(ctx, 64), groupId)); + node->op->setAttr( + kFusionOrderAttr, + IntegerAttr::get(IntegerType::get(ctx, 64), + static_cast(order))); + } + } +} + +static bool isSupportedPlanningNode(const pto::FusionComputeNode &node) { + return node.semantics.kind == pto::FusionOpKind::Compute && + isCurrentlyPlannableOp(node.semantics.opName); +} + +static unsigned +countEdgesFromGroup(const pto::FusionBlockAnalysis &blockAnalysis, + ArrayRef group, + const pto::FusionComputeNode &candidate) { + DenseSet producerIds; + for (const pto::FusionComputeNode *member : group) + producerIds.insert(member->id); + + unsigned count = 0; + for (unsigned edgeId : candidate.incomingEdges) { + if (edgeId >= blockAnalysis.edges.size()) + continue; + if (producerIds.contains(blockAnalysis.edges[edgeId].producerNode)) + ++count; + } + return count; +} + +struct GroupFootprint { + unsigned liveTileCount = 0; + unsigned vfParameterCount = 0; +}; + +static bool nodesHaveDirectDataFlowConnection( + const pto::FusionBlockAnalysis &blockAnalysis, + const pto::FusionComputeNode &lhs, const pto::FusionComputeNode &rhs) { + for (unsigned edgeId : lhs.outgoingEdges) { + if (edgeId >= blockAnalysis.edges.size()) + continue; + if (blockAnalysis.edges[edgeId].consumerNode == rhs.id) + return true; + } + + for (unsigned edgeId : lhs.incomingEdges) { + if (edgeId >= blockAnalysis.edges.size()) + continue; + if (blockAnalysis.edges[edgeId].producerNode == rhs.id) + return true; + } + + for (Value output : lhs.semantics.tileOutputs) + if (llvm::is_contained(rhs.semantics.tileInputs, output)) + return true; + + for (Value output : rhs.semantics.tileOutputs) + if (llvm::is_contained(lhs.semantics.tileInputs, output)) + return true; + + return false; +} + +static unsigned +countConnectionsToGroup(const pto::FusionBlockAnalysis &blockAnalysis, + ArrayRef group, + const pto::FusionComputeNode &candidate) { + unsigned connections = 0; + for (const pto::FusionComputeNode *member : group) + if (nodesHaveDirectDataFlowConnection(blockAnalysis, *member, candidate)) + ++connections; + return connections; +} + +static GroupFootprint +computeGroupFootprint(ArrayRef members) { + DenseSet producedTiles; + DenseSet touchedTiles; + DenseSet externalInputs; + + for (const pto::FusionComputeNode *member : members) { + for (Value output : member->semantics.tileOutputs) { + producedTiles.insert(output); + touchedTiles.insert(output); + } + } + + for (const pto::FusionComputeNode *member : members) { + for (Value input : member->semantics.tileInputs) { + touchedTiles.insert(input); + if (!producedTiles.contains(input)) + externalInputs.insert(input); + } + } + + GroupFootprint footprint; + footprint.liveTileCount = touchedTiles.size(); + footprint.vfParameterCount = externalInputs.size() + producedTiles.size(); + return footprint; +} + +class CostModel { +public: + virtual ~CostModel() = default; + + virtual PlanningDecision evaluateSeed(const PlanningContext &ctx, + const pto::FusionComputeNode &candidate) + const = 0; + + virtual PlanningDecision + evaluateAppend(const PlanningContext &ctx, + ArrayRef currentGroup, + const pto::FusionComputeNode &candidate) const = 0; +}; + +class ConservativeGreedyCostModel final : public CostModel { +public: + PlanningDecision + evaluateSeed(const PlanningContext &ctx, + const pto::FusionComputeNode &candidate) const override { + PlanningDecision decision; + if (!isSupportedPlanningNode(candidate)) + return decision; + + if (!isProvenIterationDomain(ctx.blockAnalysis, candidate)) { + decision.cost.rejectedForDynamicShape = true; + return decision; + } + + decision.accept = true; + return decision; + } + + PlanningDecision + evaluateAppend(const PlanningContext &ctx, + ArrayRef currentGroup, + const pto::FusionComputeNode &candidate) const override { + PlanningDecision seedDecision = evaluateSeed(ctx, candidate); + if (!seedDecision.accept) + return seedDecision; + + PlanningDecision decision; + if (currentGroup.empty()) { + decision.accept = true; + return decision; + } + + const pto::FusionComputeNode &previous = *currentGroup.back(); + const bool sameDomainClass = + previous.iterationDomainClass == candidate.iterationDomainClass; + const bool contiguousInBlock = + candidate.blockOrder == previous.blockOrder + 1; + const bool directlyDependent = + dependsOnPreviousNode(ctx.blockAnalysis, previous, candidate); + if (!sameDomainClass || !contiguousInBlock || !directlyDependent) + return decision; + + SmallVector proposedGroup( + currentGroup.begin(), currentGroup.end()); + proposedGroup.push_back(&candidate); + GroupFootprint footprint = computeGroupFootprint(proposedGroup); + + decision.cost.dependencyBenefit = + 4 * static_cast( + countEdgesFromGroup(ctx.blockAnalysis, currentGroup, candidate)); + decision.cost.loopMergeBenefit = 2; + decision.cost.liveTilePenalty = + std::max(0, static_cast(footprint.liveTileCount) - 4); + decision.cost.vfParameterPenalty = std::max( + 0, static_cast(footprint.vfParameterCount) - 6); + decision.accept = decision.cost.total() > 0; + return decision; + } +}; + +class ConservativeDAGGreedyCostModel final : public CostModel { +public: + PlanningDecision + evaluateSeed(const PlanningContext &ctx, + const pto::FusionComputeNode &candidate) const override { + PlanningDecision decision; + if (!isSupportedPlanningNode(candidate)) + return decision; + + if (!isProvenIterationDomain(ctx.blockAnalysis, candidate)) { + decision.cost.rejectedForDynamicShape = true; + return decision; + } + + decision.accept = true; + return decision; + } + + PlanningDecision + evaluateAppend(const PlanningContext &ctx, + ArrayRef currentGroup, + const pto::FusionComputeNode &candidate) const override { + PlanningDecision seedDecision = evaluateSeed(ctx, candidate); + if (!seedDecision.accept) + return seedDecision; + + PlanningDecision decision; + if (currentGroup.empty()) { + decision.accept = true; + return decision; + } + + if (currentGroup.front()->iterationDomainClass != + candidate.iterationDomainClass) + return decision; + + const unsigned connectionCount = + countConnectionsToGroup(ctx.blockAnalysis, currentGroup, candidate); + if (connectionCount == 0) + return decision; + + SmallVector proposedGroup( + currentGroup.begin(), currentGroup.end()); + proposedGroup.push_back(&candidate); + GroupFootprint footprint = computeGroupFootprint(proposedGroup); + + decision.cost.dependencyBenefit = 4 * static_cast(connectionCount); + decision.cost.loopMergeBenefit = 4; + decision.cost.liveTilePenalty = std::max( + 0, static_cast(footprint.liveTileCount) - 10); + decision.cost.vfParameterPenalty = std::max( + 0, static_cast(footprint.vfParameterCount) - 12); + decision.accept = decision.cost.total() > 0; + return decision; + } +}; + +class StrategyEngine { +public: + virtual ~StrategyEngine() = default; + + virtual SmallVector + planBlock(const PlanningContext &ctx, const CostModel &costModel) const = 0; +}; + +class ConservativeGreedyStrategyEngine final : public StrategyEngine { +public: + SmallVector + planBlock(const PlanningContext &ctx, + const CostModel &costModel) const override { + SmallVector groups; + SmallVector chain; + + auto flushChain = [&]() { + if (chain.size() < 2) { + chain.clear(); + return; + } + + PlannedFusionGroup group; + group.members = chain; + groups.push_back(std::move(group)); + chain.clear(); + }; + + for (const pto::FusionComputeNode &node : ctx.blockAnalysis.computeNodes) { + PlanningDecision seedDecision = costModel.evaluateSeed(ctx, node); + if (!seedDecision.accept) { + flushChain(); + continue; + } + + if (chain.empty()) { + chain.push_back(&node); + continue; + } + + PlanningDecision appendDecision = + costModel.evaluateAppend(ctx, chain, node); + if (!appendDecision.accept) { + flushChain(); + chain.push_back(&node); + continue; + } + + chain.push_back(&node); + } + + flushChain(); + return groups; + } +}; + +class ConservativeDAGGreedyStrategyEngine final : public StrategyEngine { +public: + SmallVector + planBlock(const PlanningContext &ctx, + const CostModel &costModel) const override { + SmallVector groups; + DenseSet assignedNodes; + + for (const pto::FusionComputeNode &seed : ctx.blockAnalysis.computeNodes) { + if (assignedNodes.contains(seed.id)) + continue; + + PlanningDecision seedDecision = costModel.evaluateSeed(ctx, seed); + if (!seedDecision.accept) + continue; + + SmallVector groupMembers; + DenseSet groupNodeIds; + groupMembers.push_back(&seed); + groupNodeIds.insert(seed.id); + + bool changed = true; + while (changed) { + changed = false; + for (const pto::FusionComputeNode &candidate : + ctx.blockAnalysis.computeNodes) { + if (assignedNodes.contains(candidate.id) || + groupNodeIds.contains(candidate.id)) + continue; + + PlanningDecision appendDecision = + costModel.evaluateAppend(ctx, groupMembers, candidate); + if (!appendDecision.accept) + continue; + + groupMembers.push_back(&candidate); + groupNodeIds.insert(candidate.id); + changed = true; + } + } + + if (groupMembers.size() < 2) + continue; + + PlannedFusionGroup group; + group.members = buildStableInGroupOrder(groupMembers); + groups.push_back(group); + for (const pto::FusionComputeNode *member : group.members) + assignedNodes.insert(member->id); + } + + return groups; + } +}; + +static void clearPlanningAttrs(func::FuncOp func) { + func.walk([](Operation *op) { + op->removeAttr(kFusionGroupIdAttr); + op->removeAttr(kFusionOrderAttr); + }); +} + +struct FusionPlanPass : public pto::impl::FusionPlanBase { + using pto::impl::FusionPlanBase::FusionPlanBase; + + void runOnOperation() override { + func::FuncOp func = getOperation(); + if (func.isExternal()) + return; + + clearPlanningAttrs(func); + + const auto &analysis = getAnalysis(); + if (!analysis.isValid()) { + signalPassFailure(); + return; + } + + MLIRContext *ctx = &getContext(); + int64_t nextGroupId = 0; + ConservativeDAGGreedyCostModel costModel; + ConservativeDAGGreedyStrategyEngine strategyEngine; + + for (const pto::FusionBlockAnalysis &blockAnalysis : + analysis.getResult().blocks) { + PlanningContext planningCtx{blockAnalysis}; + SmallVector groups = + strategyEngine.planBlock(planningCtx, costModel); + assignStableGroupMetadata(groups, ctx, nextGroupId); + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createFusionPlanPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/TileFusion/PTOMarkLastUse.cpp b/lib/PTO/Transforms/TileFusion/PTOMarkLastUse.cpp new file mode 100644 index 000000000..df4231b8f --- /dev/null +++ b/lib/PTO/Transforms/TileFusion/PTOMarkLastUse.cpp @@ -0,0 +1,268 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" + +#include +#include + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOMARKLASTUSE +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +static constexpr llvm::StringLiteral kFusionGroupIdAttr = + "pto.fusion.group_id"; +static constexpr llvm::StringLiteral kFusionOrderAttr = "pto.fusion.order"; +static constexpr llvm::StringLiteral kLastUseAttrName = "pto.last_use"; + +struct GroupSpanMember { + Operation *op = nullptr; + int64_t order = 0; +}; + +struct GroupSpan { + Block *block = nullptr; + int64_t groupId = -1; + SmallVector members; +}; + +static bool isTileType(Type type) { + return isa(type); +} + +static bool isDpsInitOperand(OpOperand &operand) { + Operation *owner = operand.getOwner(); + if (auto dpsIface = dyn_cast(owner)) { + for (OpOperand &init : dpsIface.getDpsInitsMutable()) { + if (&init == &operand) + return true; + } + } + return false; +} + +static bool isTileOperand(OpOperand &operand) { + return isTileType(operand.get().getType()); +} + +static bool isTileInputOperand(OpOperand &operand) { + return isTileOperand(operand) && !isDpsInitOperand(operand); +} + +// The last-use mask is indexed by tile operand slots only, in source operand +// order after filtering out scalar operands. DPS init/output tile slots are +// preserved and always materialize as 0. +static SmallVector collectTileOperands(Operation *op) { + SmallVector tileOperands; + for (OpOperand &operand : op->getOpOperands()) { + if (isTileOperand(operand)) + tileOperands.push_back(&operand); + } + return tileOperands; +} + +static std::optional getRequiredI64Attr(Operation *op, + StringRef attrName) { + if (auto attr = op->getAttrOfType(attrName)) + return attr.getInt(); + return std::nullopt; +} + +static bool hasIncompleteFusionMetadata(Operation *op) { + const bool hasGroupId = op->hasAttr(kFusionGroupIdAttr); + const bool hasOrder = op->hasAttr(kFusionOrderAttr); + return hasGroupId != hasOrder; +} + +static LogicalResult +collectGroupSpansInBlock(Block &block, SmallVectorImpl &spans) { + DenseMap spanIndexByGroupId; + + GroupSpan current; + + auto flush = [&]() -> LogicalResult { + if (current.members.empty()) + return success(); + + current.block = █ + auto [it, inserted] = + spanIndexByGroupId.try_emplace(current.groupId, spans.size()); + if (!inserted) { + current.members.front().op->emitError( + "expected one contiguous span per pto.fusion.group_id within a basic " + "block"); + return failure(); + } + + spans.push_back(std::move(current)); + current = GroupSpan(); + return success(); + }; + + for (Operation &op : block) { + if (hasIncompleteFusionMetadata(&op)) { + op.emitError("expected pto.fusion.group_id and pto.fusion.order to " + "either both exist or both be absent"); + return failure(); + } + + std::optional groupId = + getRequiredI64Attr(&op, kFusionGroupIdAttr); + if (!groupId) { + if (failed(flush())) + return failure(); + continue; + } + + std::optional order = getRequiredI64Attr(&op, kFusionOrderAttr); + if (!order) { + op.emitError("missing required pto.fusion.order attribute"); + return failure(); + } + + if (current.members.empty()) { + current.groupId = *groupId; + current.members.push_back(GroupSpanMember{&op, *order}); + continue; + } + + if (current.groupId != *groupId) { + if (failed(flush())) + return failure(); + current.groupId = *groupId; + } + + if (!current.members.empty() && current.members.back().order >= *order) { + op.emitError("expected contiguous fusion span to follow increasing " + "pto.fusion.order"); + return failure(); + } + + current.members.push_back(GroupSpanMember{&op, *order}); + } + + return flush(); +} + +static bool isSpanLocalLastUseCandidate(Value value, Operation *currentOp, + Block *block) { + if (!value) + return false; + + for (OpOperand &use : value.getUses()) { + Operation *user = use.getOwner(); + if (user == currentOp) + continue; + if (user->getBlock() != block) + return false; + if (currentOp->isBeforeInBlock(user)) + return false; + } + return true; +} + +static bool hasLaterUseAfterSpan(Value value, Operation *spanEnd, Block *block) { + for (OpOperand &use : value.getUses()) { + Operation *user = use.getOwner(); + if (user->getBlock() == block) + return spanEnd->isBeforeInBlock(user); + return true; + } + return false; +} + +static void markGroupSpanLastUse(const GroupSpan &span) { + if (span.members.empty()) + return; + + Block &block = *span.block; + Operation *spanEnd = span.members.back().op; + for (const GroupSpanMember &member : span.members) { + Operation &op = *member.op; + SmallVector tileOperands = collectTileOperands(&op); + if (tileOperands.empty()) { + op.removeAttr(kLastUseAttrName); + continue; + } + + SmallVector lastUseMask; + lastUseMask.reserve(tileOperands.size()); + for (OpOperand *operand : tileOperands) { + if (!isTileInputOperand(*operand)) { + lastUseMask.push_back(0); + continue; + } + bool blockedByLaterSpanUse = + !isSpanLocalLastUseCandidate(operand->get(), &op, &block); + bool blockedByLaterPostSpanUse = + hasLaterUseAfterSpan(operand->get(), spanEnd, &block); + lastUseMask.push_back( + (!blockedByLaterSpanUse && !blockedByLaterPostSpanUse) ? 1 : 0); + } + + op.setAttr(kLastUseAttrName, + Builder(op.getContext()).getDenseI64ArrayAttr(lastUseMask)); + } +} + +static LogicalResult markRegionLastUse(Region ®ion) { + for (Block &block : region.getBlocks()) { + SmallVector spans; + if (failed(collectGroupSpansInBlock(block, spans))) + return failure(); + for (const GroupSpan &span : spans) + markGroupSpanLastUse(span); + + for (Operation &op : block) + for (Region &nestedRegion : op.getRegions()) + if (failed(markRegionLastUse(nestedRegion))) + return failure(); + } + return success(); +} + +struct PTOMarkLastUsePass + : public pto::impl::PTOMarkLastUseBase { + using pto::impl::PTOMarkLastUseBase< + PTOMarkLastUsePass>::PTOMarkLastUseBase; + + void runOnOperation() override { + func::FuncOp func = getOperation(); + if (func.isExternal()) + return; + + if (failed(markRegionLastUse(func.getRegion()))) { + signalPassFailure(); + return; + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createPTOMarkLastUsePass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/TileFusion/PTOOpScheduling.cpp b/lib/PTO/Transforms/TileFusion/PTOOpScheduling.cpp new file mode 100644 index 000000000..745f739c5 --- /dev/null +++ b/lib/PTO/Transforms/TileFusion/PTOOpScheduling.cpp @@ -0,0 +1,321 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/Transforms/TileFusion/FusionOpSemantics.h" + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" + +#include + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_OPSCHEDULING +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; + +namespace { + +static constexpr llvm::StringLiteral kFusionGroupIdAttr = + "pto.fusion.group_id"; +static constexpr llvm::StringLiteral kFusionOrderAttr = "pto.fusion.order"; + +enum class SchedulingBarrierKind { + Movable, + LocalBoundary, + HardBoundary, +}; + +struct GroupMember { + Operation *op = nullptr; + int64_t order = 0; + unsigned originalIndex = 0; +}; + +struct ScheduledGroup { + int64_t groupId = 0; + unsigned firstOriginalIndex = 0; + SmallVector members; +}; + +static std::optional getRequiredI64Attr(Operation *op, + StringRef attrName) { + if (auto attr = op->getAttrOfType(attrName)) + return attr.getInt(); + return std::nullopt; +} + +static bool hasIncompleteFusionMetadata(Operation *op) { + const bool hasGroupId = op->hasAttr(kFusionGroupIdAttr); + const bool hasOrder = op->hasAttr(kFusionOrderAttr); + return hasGroupId != hasOrder; +} + +static bool sharesAnyValue(ArrayRef lhs, ArrayRef rhs) { + for (Value value : lhs) + if (llvm::is_contained(rhs, value)) + return true; + return false; +} + +static SchedulingBarrierKind classifySchedulingBarrier(Operation *op) { + if (op->hasTrait() || !op->getRegions().empty()) + return SchedulingBarrierKind::HardBoundary; + if (isa(op)) + return SchedulingBarrierKind::HardBoundary; + + FailureOr semanticsOr = pto::getFusionOpSemantics(op); + if (succeeded(semanticsOr)) { + switch (semanticsOr->kind) { + case pto::FusionOpKind::Compute: + return SchedulingBarrierKind::Movable; + case pto::FusionOpKind::LocalBoundary: + return SchedulingBarrierKind::LocalBoundary; + case pto::FusionOpKind::HardBoundary: + return SchedulingBarrierKind::HardBoundary; + } + } + if (!isMemoryEffectFree(op)) + return SchedulingBarrierKind::HardBoundary; + return SchedulingBarrierKind::Movable; +} + +static bool hasDependencyOnLocalBoundary(Operation *movingOp, + Operation *boundaryOp) { + FailureOr movingSemanticsOr = + pto::getFusionOpSemantics(movingOp); + FailureOr boundarySemanticsOr = + pto::getFusionOpSemantics(boundaryOp); + if (failed(movingSemanticsOr) || failed(boundarySemanticsOr)) + return true; + + const pto::FusionOpSemantics &movingSemantics = *movingSemanticsOr; + const pto::FusionOpSemantics &boundarySemantics = *boundarySemanticsOr; + + return sharesAnyValue(boundarySemantics.tileOutputs, + movingSemantics.tileInputs) || + sharesAnyValue(boundarySemantics.tileOutputs, + movingSemantics.tileOutputs) || + sharesAnyValue(movingSemantics.tileOutputs, + boundarySemantics.tileInputs) || + sharesAnyValue(movingSemantics.tileOutputs, + boundarySemantics.tileOutputs); +} + +static bool crossesOperandDefinition(Operation *movingOp, Operation *candidate) { + for (Value operand : movingOp->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (defOp == candidate) + return true; + } + return false; +} + +static bool canMoveEarlierAcross(Operation *movingOp, Operation *candidate) { + if (crossesOperandDefinition(movingOp, candidate)) + return false; + + switch (classifySchedulingBarrier(candidate)) { + case SchedulingBarrierKind::Movable: + return true; + case SchedulingBarrierKind::LocalBoundary: + return !hasDependencyOnLocalBoundary(movingOp, candidate); + case SchedulingBarrierKind::HardBoundary: + return false; + } + return false; +} + +static bool canMoveLaterAcross(Operation *movingOp, Operation *candidate) { + for (Value operand : candidate->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (defOp == movingOp) + return false; + } + + FailureOr movingSemanticsOr = + pto::getFusionOpSemantics(movingOp); + FailureOr candidateSemanticsOr = + pto::getFusionOpSemantics(candidate); + if (succeeded(movingSemanticsOr) && succeeded(candidateSemanticsOr)) { + if (sharesAnyValue(movingSemanticsOr->tileOutputs, + candidateSemanticsOr->tileInputs) || + sharesAnyValue(movingSemanticsOr->tileOutputs, + candidateSemanticsOr->tileOutputs)) + return false; + } + + switch (classifySchedulingBarrier(candidate)) { + case SchedulingBarrierKind::Movable: + return true; + case SchedulingBarrierKind::LocalBoundary: + return !hasDependencyOnLocalBoundary(movingOp, candidate); + case SchedulingBarrierKind::HardBoundary: + return false; + } + return false; +} + +static bool canMoveAfter(Operation *movingOp, Operation *anchorOp) { + if (!movingOp || !anchorOp || movingOp == anchorOp) + return false; + if (movingOp->getBlock() != anchorOp->getBlock()) + return false; + + Operation *cursor = anchorOp->getNextNode(); + while (cursor && cursor != movingOp) { + if (!canMoveEarlierAcross(movingOp, cursor)) + return false; + cursor = cursor->getNextNode(); + } + return cursor == movingOp; +} + +static LogicalResult +collectScheduledGroups(Block &block, SmallVectorImpl &groups) { + DenseMap groupIndexById; + + unsigned originalIndex = 0; + for (Operation &op : block) { + if (hasIncompleteFusionMetadata(&op)) { + op.emitError("expected pto.fusion.group_id and pto.fusion.order to " + "either both exist or both be absent"); + return failure(); + } + + std::optional groupId = + getRequiredI64Attr(&op, kFusionGroupIdAttr); + if (!groupId) { + ++originalIndex; + continue; + } + + std::optional order = getRequiredI64Attr(&op, kFusionOrderAttr); + if (!order) { + op.emitError("missing required pto.fusion.order attribute"); + return failure(); + } + + auto [it, inserted] = groupIndexById.try_emplace(*groupId, groups.size()); + if (inserted) { + ScheduledGroup group; + group.groupId = *groupId; + group.firstOriginalIndex = originalIndex; + groups.push_back(std::move(group)); + } + + ScheduledGroup &group = groups[it->second]; + group.members.push_back(GroupMember{&op, *order, originalIndex}); + ++originalIndex; + } + + llvm::sort(groups, [](const ScheduledGroup &lhs, const ScheduledGroup &rhs) { + if (lhs.firstOriginalIndex != rhs.firstOriginalIndex) + return lhs.firstOriginalIndex < rhs.firstOriginalIndex; + return lhs.groupId < rhs.groupId; + }); + + for (ScheduledGroup &group : groups) { + llvm::sort(group.members, [](const GroupMember &lhs, const GroupMember &rhs) { + if (lhs.order != rhs.order) + return lhs.order < rhs.order; + return lhs.originalIndex < rhs.originalIndex; + }); + + std::optional previousOrder; + for (const GroupMember &member : group.members) { + if (classifySchedulingBarrier(member.op) != + SchedulingBarrierKind::Movable) { + member.op->emitError("fusion scheduling metadata must only annotate " + "movable compute ops"); + return failure(); + } + if (previousOrder && *previousOrder == member.order) { + member.op->emitError("duplicate pto.fusion.order within one fusion " + "group"); + return failure(); + } + previousOrder = member.order; + } + } + + return success(); +} + +static void scheduleGroup(ScheduledGroup &group) { + if (group.members.size() < 2) + return; + + Operation *placement = group.members.front().op; + for (GroupMember &member : llvm::drop_begin(group.members)) { + Operation *op = member.op; + while (op != placement && op != placement->getNextNode()) { + if (canMoveAfter(op, placement)) { + op->moveAfter(placement); + break; + } + + Operation *blockingOp = placement->getNextNode(); + if (!blockingOp || blockingOp == op || + !canMoveLaterAcross(placement, blockingOp)) + break; + + placement->moveAfter(blockingOp); + } + placement = op; + } +} + +static LogicalResult scheduleRegion(Region ®ion) { + for (Block &block : region.getBlocks()) { + SmallVector groups; + if (failed(collectScheduledGroups(block, groups))) + return failure(); + for (ScheduledGroup &group : groups) + scheduleGroup(group); + + for (Operation &op : block) + for (Region &nestedRegion : op.getRegions()) + if (failed(scheduleRegion(nestedRegion))) + return failure(); + } + return success(); +} + +struct OpSchedulingPass + : public pto::impl::OpSchedulingBase { + using pto::impl::OpSchedulingBase::OpSchedulingBase; + + void runOnOperation() override { + func::FuncOp func = getOperation(); + if (func.isExternal()) + return; + + if (failed(scheduleRegion(func.getRegion()))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createOpSchedulingPass() { + return std::make_unique(); +} diff --git a/test/lit/tile_fusion/final_emitc_last_use_level2.pto b/test/lit/tile_fusion/final_emitc_last_use_level2.pto new file mode 100644 index 000000000..961ba81d4 --- /dev/null +++ b/test/lit/tile_fusion/final_emitc_last_use_level2.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards the reduced frontend fusion end-to-end EmitC contract: +// - the shared mainline stays free of any pto.fusion_region / pto.yield lifecycle +// - the final generated C++ still emits output-first last_use slots as +// [[pto::last_use(0, 1, 1)]] CALLEE(...) +// +// RUN: { ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-view-to-memref -o /dev/null 2>&1 || true; } | FileCheck %s --check-prefix=IR +// RUN: ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion %s -o - | FileCheck %s --check-prefix=CPP + +module { + func.func @final_emitc_last_use_level2( + %dst_ptr: !pto.ptr, + %src0: !pto.tile_buf, + %src1: !pto.tile_buf, + %src2: !pto.tile_buf) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c32, %c32], strides = [%c32, %c1] + : !pto.tensor_view + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0], sizes = [%c32, %c32] + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + + %tmp0 = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + + pto.tadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%tmp0 : !pto.tile_buf) + pto.tadd ins(%tmp0, %src2 : !pto.tile_buf, + !pto.tile_buf) + outs(%tmp1 : !pto.tile_buf) + pto.tstore ins(%tmp1 : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<32x32xf32>) + return + } +} + +// IR-LABEL: func.func @final_emitc_last_use_level2( +// IR-NOT: pto.fusion_region +// IR-NOT: pto.yield +// IR: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64{{.*}}} +// IR: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64{{.*}}} +// IR: pto.tstore ins( + +// CPP-LABEL: __global__ AICORE void final_emitc_last_use_level2( +// CPP-NOT: pto.fusion_region +// CPP-NOT: pto.yield +// CPP-NOT: PTOAS__LAST_USE__ +// CPP: {{\[\[pto::last_use\(0, 1, 1\)\]\] TADD\(}} +// CPP: {{\[\[pto::last_use\(0, 1, 1\)\]\] TADD\(}} +// CPP: TSTORE( diff --git a/test/lit/tile_fusion/fusion_plan_diamond.pto b/test/lit/tile_fusion/fusion_plan_diamond.pto new file mode 100644 index 000000000..ccaf40b76 --- /dev/null +++ b/test/lit/tile_fusion/fusion_plan_diamond.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards FusionPlan annotation for a diamond-shaped tile DAG so all supported +// block-local compute ops remain in one ordered fusion group. +// +// RUN: ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-fusion-plan -o /dev/null 2>&1 | FileCheck %s + +module { + func.func @fusion_plan_diamond( + %arg0: !pto.tile_buf, + %arg1: !pto.tile_buf, + %arg2: !pto.tile_buf, + %arg3: !pto.tile_buf, + %arg4: !pto.tile_buf, + %arg5: !pto.tile_buf) { + %tmp0 = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + %tmp2 = pto.alloc_tile : !pto.tile_buf + %tmp3 = pto.alloc_tile : !pto.tile_buf + %tmp4 = pto.alloc_tile : !pto.tile_buf + %tmp5 = pto.alloc_tile : !pto.tile_buf + %tmp6 = pto.alloc_tile : !pto.tile_buf + %tmp7 = pto.alloc_tile : !pto.tile_buf + + pto.tmax ins(%arg0, %arg1 : !pto.tile_buf, !pto.tile_buf) outs(%tmp0 : !pto.tile_buf) + pto.tsub ins(%tmp0, %arg2 : !pto.tile_buf, !pto.tile_buf) outs(%tmp1 : !pto.tile_buf) + pto.tsub ins(%tmp0, %arg3 : !pto.tile_buf, !pto.tile_buf) outs(%tmp2 : !pto.tile_buf) + pto.texp ins(%tmp1 : !pto.tile_buf) outs(%tmp3 : !pto.tile_buf) + pto.texp ins(%tmp2 : !pto.tile_buf) outs(%tmp4 : !pto.tile_buf) + pto.tmul ins(%tmp3, %arg4 : !pto.tile_buf, !pto.tile_buf) outs(%tmp5 : !pto.tile_buf) + pto.tmul ins(%tmp4, %arg5 : !pto.tile_buf, !pto.tile_buf) outs(%tmp6 : !pto.tile_buf) + pto.tadd ins(%tmp5, %tmp6 : !pto.tile_buf, !pto.tile_buf) outs(%tmp7 : !pto.tile_buf) + return + } +} + +// CHECK: // -----// IR Dump After FusionPlan (pto-fusion-plan) //----- // +// CHECK-LABEL: func.func @fusion_plan_diamond( +// CHECK: pto.tmax{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64} +// CHECK: pto.tsub{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64} +// CHECK: pto.tsub{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 2 : i64} +// CHECK: pto.texp{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 3 : i64} +// CHECK: pto.texp{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 4 : i64} +// CHECK: pto.tmul{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 5 : i64} +// CHECK: pto.tmul{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 6 : i64} +// CHECK: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 7 : i64} diff --git a/test/lit/tile_fusion/fusion_plan_dynamic_shape_negative.pto b/test/lit/tile_fusion/fusion_plan_dynamic_shape_negative.pto new file mode 100644 index 000000000..5685c2475 --- /dev/null +++ b/test/lit/tile_fusion/fusion_plan_dynamic_shape_negative.pto @@ -0,0 +1,34 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards FusionPlan's conservative dynamic-shape behavior: dynamic valid-shape +// tile ops must not receive pto.fusion.* metadata. +// +// RUN: ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-fusion-plan -o /dev/null 2>&1 | FileCheck %s + +module { + func.func @fusion_plan_dynamic_shape_negative( + %arg0: !pto.tile_buf, + %vrow: index, + %vcol: index) { + %tmp0 = pto.alloc_tile valid_row = %vrow valid_col = %vcol : !pto.tile_buf + %tmp1 = pto.alloc_tile valid_row = %vrow valid_col = %vcol : !pto.tile_buf + + pto.tadd ins(%arg0, %arg0 : !pto.tile_buf, !pto.tile_buf) outs(%tmp0 : !pto.tile_buf) + pto.tmul ins(%tmp0, %arg0 : !pto.tile_buf, !pto.tile_buf) outs(%tmp1 : !pto.tile_buf) + return + } +} + +// CHECK: // -----// IR Dump After FusionPlan (pto-fusion-plan) //----- // +// CHECK-LABEL: func.func @fusion_plan_dynamic_shape_negative( +// CHECK: pto.tadd ins(%arg0, %arg0 +// CHECK-NOT: pto.fusion.group_id +// CHECK: pto.tmul ins(%0, %arg0 +// CHECK-NOT: pto.fusion.group_id +// CHECK: return diff --git a/test/lit/tile_fusion/fusion_plan_interleaved_join.pto b/test/lit/tile_fusion/fusion_plan_interleaved_join.pto new file mode 100644 index 000000000..8b818230f --- /dev/null +++ b/test/lit/tile_fusion/fusion_plan_interleaved_join.pto @@ -0,0 +1,45 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards FusionPlan group discovery when unrelated compute ops are interleaved +// with an otherwise valid join-shaped fusion group. +// +// RUN: ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-fusion-plan -o /dev/null 2>&1 | FileCheck %s + +module { + func.func @fusion_plan_interleaved_join( + %full0: !pto.tile_buf, + %full1: !pto.tile_buf, + %full2: !pto.tile_buf, + %full3: !pto.tile_buf, + %row0: !pto.tile_buf, + %row1: !pto.tile_buf) { + %tmp0 = pto.alloc_tile : !pto.tile_buf + %noise0 = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + %noise1 = pto.alloc_tile : !pto.tile_buf + %tmp2 = pto.alloc_tile : !pto.tile_buf + + pto.trowexpandmul ins(%full0, %row0 : !pto.tile_buf, !pto.tile_buf) outs(%tmp0 : !pto.tile_buf) + pto.tadd ins(%full2, %full3 : !pto.tile_buf, !pto.tile_buf) outs(%noise0 : !pto.tile_buf) + pto.trowexpandmul ins(%full1, %row1 : !pto.tile_buf, !pto.tile_buf) outs(%tmp1 : !pto.tile_buf) + pto.tmul ins(%full2, %full3 : !pto.tile_buf, !pto.tile_buf) outs(%noise1 : !pto.tile_buf) + pto.tadd ins(%tmp0, %tmp1 : !pto.tile_buf, !pto.tile_buf) outs(%tmp2 : !pto.tile_buf) + return + } +} + +// CHECK: // -----// IR Dump After FusionPlan (pto-fusion-plan) //----- // +// CHECK-LABEL: func.func @fusion_plan_interleaved_join( +// CHECK: pto.trowexpandmul{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64} +// CHECK: pto.tadd ins(%arg2, %arg3 +// CHECK-NOT: pto.fusion.group_id +// CHECK: pto.trowexpandmul{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64} +// CHECK: pto.tmul ins(%arg2, %arg3 +// CHECK-NOT: pto.fusion.group_id +// CHECK: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 2 : i64} diff --git a/test/lit/tile_fusion/fusion_plan_join.pto b/test/lit/tile_fusion/fusion_plan_join.pto new file mode 100644 index 000000000..7581993d8 --- /dev/null +++ b/test/lit/tile_fusion/fusion_plan_join.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards FusionPlan annotation for a join-shaped block-local tile DAG before +// scheduling and region formation consume the metadata. +// +// RUN: ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-fusion-plan -o /dev/null 2>&1 | FileCheck %s + +module { + func.func @fusion_plan_join( + %full0: !pto.tile_buf, + %full1: !pto.tile_buf, + %row0: !pto.tile_buf, + %row1: !pto.tile_buf) { + %tmp0 = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + %tmp2 = pto.alloc_tile : !pto.tile_buf + + pto.trowexpandmul ins(%full0, %row0 : !pto.tile_buf, !pto.tile_buf) outs(%tmp0 : !pto.tile_buf) + pto.trowexpandmul ins(%full1, %row1 : !pto.tile_buf, !pto.tile_buf) outs(%tmp1 : !pto.tile_buf) + pto.tadd ins(%tmp0, %tmp1 : !pto.tile_buf, !pto.tile_buf) outs(%tmp2 : !pto.tile_buf) + return + } +} + +// CHECK: // -----// IR Dump After FusionPlan (pto-fusion-plan) //----- // +// CHECK-LABEL: func.func @fusion_plan_join( +// CHECK: pto.trowexpandmul{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64} +// CHECK: pto.trowexpandmul{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64} +// CHECK: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 2 : i64} diff --git a/test/lit/tile_fusion/fusion_plan_treshape_boundary.pto b/test/lit/tile_fusion/fusion_plan_treshape_boundary.pto new file mode 100644 index 000000000..5e06043c0 --- /dev/null +++ b/test/lit/tile_fusion/fusion_plan_treshape_boundary.pto @@ -0,0 +1,44 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards FusionPlan boundaries around pto.treshape while allowing a later +// unrelated same-domain chain to form a normal fusion group. +// +// RUN: ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-fusion-plan -o /dev/null 2>&1 | FileCheck %s + +module { + func.func @fusion_plan_treshape_boundary( + %arg0: !pto.tile_buf, + %arg1: !pto.tile_buf, + %arg2: !pto.tile_buf) { + %tmp0 = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + %tmp2 = pto.alloc_tile : !pto.tile_buf + %tmp3 = pto.alloc_tile : !pto.tile_buf + + pto.tadd ins(%arg0, %arg0 : !pto.tile_buf, !pto.tile_buf) outs(%tmp0 : !pto.tile_buf) + %view = pto.treshape %tmp0 : !pto.tile_buf -> !pto.tile_buf + pto.tmul ins(%view, %arg1 : !pto.tile_buf, !pto.tile_buf) outs(%tmp1 : !pto.tile_buf) + pto.tadd ins(%arg2, %arg2 : !pto.tile_buf, !pto.tile_buf) outs(%tmp2 : !pto.tile_buf) + pto.tmul ins(%tmp2, %arg2 : !pto.tile_buf, !pto.tile_buf) outs(%tmp3 : !pto.tile_buf) + return + } +} + +// CHECK: // -----// IR Dump After FusionPlan (pto-fusion-plan) //----- // +// CHECK-LABEL: func.func @fusion_plan_treshape_boundary( +// CHECK: pto.tadd ins(%arg0, %arg0 +// CHECK-NOT: pto.fusion.group_id +// CHECK: %[[VIEW:.*]] = pto.treshape %0 +// CHECK: pto.tmul ins(%[[VIEW]], %arg1 +// CHECK-NOT: pto.fusion.group_id +// CHECK: pto.tadd ins(%arg2, %arg2 +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64} +// CHECK: pto.tmul ins(%2, %arg2 +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64} +// CHECK: return diff --git a/test/lit/tile_fusion/mark_last_use_post_span_block_level2.pto b/test/lit/tile_fusion/mark_last_use_post_span_block_level2.pto new file mode 100644 index 000000000..5d9fa6221 --- /dev/null +++ b/test/lit/tile_fusion/mark_last_use_post_span_block_level2.pto @@ -0,0 +1,62 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards PTOMarkLastUse against later post-span tile uses: +// - tiles that remain live after the scheduled fusion span must not be marked +// last-use on earlier in-span consumers +// - tiles with no later users after the span still qualify normally +// +// RUN: { ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion %s --mlir-print-ir-after=pto-mark-last-use -o /dev/null 2>&1 || true; } | FileCheck %s + +module { + func.func @mark_last_use_post_span_block_level2( + %dst0_ptr: !pto.ptr, + %dst1_ptr: !pto.ptr, + %src0: !pto.tile_buf, + %src1: !pto.tile_buf, + %src2: !pto.tile_buf) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + %dst0_view = pto.make_tensor_view %dst0_ptr, + shape = [%c32, %c32], strides = [%c32, %c1] + : !pto.tensor_view + %dst0_part = pto.partition_view %dst0_view, + offsets = [%c0, %c0], sizes = [%c32, %c32] + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %dst1_view = pto.make_tensor_view %dst1_ptr, + shape = [%c32, %c32], strides = [%c32, %c1] + : !pto.tensor_view + %dst1_part = pto.partition_view %dst1_view, + offsets = [%c0, %c0], sizes = [%c32, %c32] + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + + %tmp0 = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + + pto.tadd ins(%src0, %src1 : !pto.tile_buf, + !pto.tile_buf) + outs(%tmp0 : !pto.tile_buf) + pto.tadd ins(%tmp0, %src2 : !pto.tile_buf, + !pto.tile_buf) + outs(%tmp1 : !pto.tile_buf) + pto.tstore ins(%tmp1 : !pto.tile_buf) + outs(%dst0_part : !pto.partition_tensor_view<32x32xf32>) + pto.tstore ins(%tmp0 : !pto.tile_buf) + outs(%dst1_part : !pto.partition_tensor_view<32x32xf32>) + return + } +} + +// CHECK: // -----// IR Dump After PTOMarkLastUse (pto-mark-last-use) //----- // +// CHECK-LABEL: func.func @mark_last_use_post_span_block_level2( +// CHECK: pto.tadd ins(%{{.*}}, %{{.*}} : !pto.tile_buf, !pto.tile_buf) outs(%{{.*}} : !pto.tile_buf) {pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64, pto.last_use = array} +// CHECK-NEXT: pto.tadd ins(%{{.*}}, %{{.*}} : !pto.tile_buf, !pto.tile_buf) outs(%{{.*}} : !pto.tile_buf) {pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64, pto.last_use = array} +// CHECK: pto.tstore +// CHECK: pto.tstore ins(%{{.*}} : !pto.tile_buf) diff --git a/test/lit/tile_fusion/mark_last_use_repeated_ssa_level2.pto b/test/lit/tile_fusion/mark_last_use_repeated_ssa_level2.pto new file mode 100644 index 000000000..78c6d6f1c --- /dev/null +++ b/test/lit/tile_fusion/mark_last_use_repeated_ssa_level2.pto @@ -0,0 +1,53 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards PTOMarkLastUse when the same tile SSA value occupies multiple tile +// operand slots of the same fused op. Each slot must be evaluated separately, +// so repeated SSA operands can still both be last-use. +// +// RUN: { ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion %s --mlir-print-ir-after=pto-mark-last-use -o /dev/null 2>&1 || true; } | FileCheck %s +// RUN: ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion %s -o - | FileCheck %s --check-prefix=CPP + +module { + func.func @mark_last_use_repeated_ssa_level2( + %dst_ptr: !pto.ptr, + %src: !pto.tile_buf) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c32, %c32], strides = [%c32, %c1] + : !pto.tensor_view + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0], sizes = [%c32, %c32] + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + + %tmp = pto.alloc_tile : !pto.tile_buf + %out = pto.alloc_tile : !pto.tile_buf + + pto.tadd ins(%src, %src : !pto.tile_buf, + !pto.tile_buf) + outs(%tmp : !pto.tile_buf) + pto.tadd ins(%tmp, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%out : !pto.tile_buf) + pto.tstore ins(%out : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<32x32xf32>) + return + } +} + +// CHECK: // -----// IR Dump After PTOMarkLastUse (pto-mark-last-use) //----- // +// CHECK-LABEL: func.func @mark_last_use_repeated_ssa_level2( +// CHECK: pto.tadd ins(%{{.*}}, %{{.*}} : !pto.tile_buf, !pto.tile_buf) outs(%{{.*}} : !pto.tile_buf) {pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64, pto.last_use = array} +// CHECK-NEXT: pto.tadd ins(%{{.*}}, %{{.*}} : !pto.tile_buf, !pto.tile_buf) outs(%{{.*}} : !pto.tile_buf) {pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64, pto.last_use = array} + +// CPP-LABEL: __global__ AICORE void mark_last_use_repeated_ssa_level2( +// CPP: {{\[\[pto::last_use\(0, 1, 1\)\]\] TADD\(}} +// CPP: {{\[\[pto::last_use\(0, 1, 1\)\]\] TADD\(}} diff --git a/test/lit/tile_fusion/mark_last_use_slot_mask_level2.pto b/test/lit/tile_fusion/mark_last_use_slot_mask_level2.pto new file mode 100644 index 000000000..9371ca987 --- /dev/null +++ b/test/lit/tile_fusion/mark_last_use_slot_mask_level2.pto @@ -0,0 +1,79 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards PTOMarkLastUse slot encoding on level2 fusion: +// - scalar operands must not occupy last-use slots +// - DPS init/output tiles must keep a slot and be forced to 0 +// - ops with only scalar inputs still get one output-tile slot marked 0 +// +// RUN: { ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion %s --mlir-print-ir-after=pto-mark-last-use -o /dev/null 2>&1 || true; } | FileCheck %s +// RUN: { ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion %s --mlir-print-ir-after-all -o /dev/null 2>&1 || true; } | FileCheck %s --check-prefix=MARKER +// RUN: ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion %s -o - | FileCheck %s --check-prefix=CPP + +module { + func.func @mark_last_use_slot_mask_level2( + %dst_ptr: !pto.ptr, + %src: !pto.tile_buf) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %alpha = arith.constant 2.0 : f32 + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c32, %c32], strides = [%c32, %c1] + : !pto.tensor_view + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0], sizes = [%c32, %c32] + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + + %tmp0 = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + %tmp2 = pto.alloc_tile : !pto.tile_buf + %tmp3 = pto.alloc_tile : !pto.tile_buf + %tmp4 = pto.alloc_tile : !pto.tile_buf + + pto.texpands ins(%alpha : f32) outs(%tmp0 : !pto.tile_buf) + pto.tadds ins(%tmp0, %alpha : !pto.tile_buf, f32) + outs(%tmp1 : !pto.tile_buf) + pto.tadd ins(%tmp1, %src : !pto.tile_buf, + !pto.tile_buf) + outs(%tmp2 : !pto.tile_buf) + pto.tsub ins(%tmp2, %src : !pto.tile_buf, + !pto.tile_buf) + outs(%tmp3 : !pto.tile_buf) + pto.tsubs ins(%tmp3, %alpha : !pto.tile_buf, f32) + outs(%tmp4 : !pto.tile_buf) + pto.tstore ins(%tmp4 : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<32x32xf32>) + return + } +} + +// CHECK: // -----// IR Dump After PTOMarkLastUse (pto-mark-last-use) //----- // +// CHECK-LABEL: func.func @mark_last_use_slot_mask_level2( +// CHECK: pto.texpands ins(%{{.*}} : f32) outs(%{{.*}} : !pto.tile_buf) {pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64, pto.last_use = array} +// CHECK-NEXT: pto.tadds ins(%{{.*}}, %{{.*}} : !pto.tile_buf, f32) outs(%{{.*}} : !pto.tile_buf) {pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64, pto.last_use = array} +// CHECK-NEXT: pto.tadd ins(%{{.*}}, %{{.*}} : !pto.tile_buf, !pto.tile_buf) outs(%{{.*}} : !pto.tile_buf) {pto.fusion.group_id = 0 : i64, pto.fusion.order = 2 : i64, pto.last_use = array} +// CHECK-NEXT: pto.tsub ins(%{{.*}}, %{{.*}} : !pto.tile_buf, !pto.tile_buf) outs(%{{.*}} : !pto.tile_buf) {pto.fusion.group_id = 0 : i64, pto.fusion.order = 3 : i64, pto.last_use = array} +// CHECK-NEXT: pto.tsubs ins(%{{.*}}, %{{.*}} : !pto.tile_buf, f32) outs(%{{.*}} : !pto.tile_buf) {pto.fusion.group_id = 0 : i64, pto.fusion.order = 4 : i64, pto.last_use = array} + +// MARKER: // -----// IR Dump After {anonymous}::EmitPTOManualPass () //----- // +// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TEXPANDS__0" +// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TADDS__0__1" +// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TADD__0__1__0" +// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TSUB__0__1__1" +// MARKER: emitc.call_opaque "PTOAS__LAST_USE__TSUBS__0__1" +// MARKER-NOT: pto::last_use + +// CPP-LABEL: __global__ AICORE void mark_last_use_slot_mask_level2( +// CPP-NOT: PTOAS__LAST_USE__ +// CPP: {{\[\[pto::last_use\(0\)\]\] TEXPANDS\(}} +// CPP: {{\[\[pto::last_use\(0, 1\)\]\] TADDS\(}} +// CPP: {{\[\[pto::last_use\(0, 1, 0\)\]\] TADD\(}} +// CPP: {{\[\[pto::last_use\(0, 1, 1\)\]\] TSUB\(}} +// CPP: {{\[\[pto::last_use\(0, 1\)\]\] TSUBS\(}} diff --git a/test/lit/tile_fusion/op_fusion_adapter_placement_level2_tadd.pto b/test/lit/tile_fusion/op_fusion_adapter_placement_level2_tadd.pto new file mode 100644 index 000000000..5c2f81d4e --- /dev/null +++ b/test/lit/tile_fusion/op_fusion_adapter_placement_level2_tadd.pto @@ -0,0 +1,75 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards level2 fusion adapter placement around flat scheduled fusion spans, +// shared memref lowering, optional sync insertion, and the non-fused baseline. +// +// RUN: ptoas --pto-arch=a5 --pto-level=level2 %s --emit-pto-ir -o - | FileCheck %s --check-prefix=NOFUSE +// RUN: { ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-view-to-memref -o /dev/null 2>&1 || true; } | FileCheck %s --check-prefix=SEAM +// RUN: { ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --enable-insert-sync --emit-pto-ir %s --mlir-print-ir-after=pto-insert-sync -o /dev/null 2>&1 || true; } | FileCheck %s --check-prefix=SYNC + +module { + func.func @fusion_adapter_placement_level2_tadd(%dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c32, %c32], strides = [%c32, %c1] + : !pto.tensor_view + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0], sizes = [%c32, %c32] + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + + %a = pto.alloc_tile : !pto.tile_buf + %b = pto.alloc_tile : !pto.tile_buf + %c = pto.alloc_tile : !pto.tile_buf + %d = pto.alloc_tile : !pto.tile_buf + %tmp0 = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + %out = pto.alloc_tile : !pto.tile_buf + + pto.tadd ins(%a, %b : !pto.tile_buf, !pto.tile_buf) outs(%tmp0 : !pto.tile_buf) + pto.tadd ins(%c, %d : !pto.tile_buf, !pto.tile_buf) outs(%tmp1 : !pto.tile_buf) + pto.tadd ins(%tmp0, %tmp1 : !pto.tile_buf, !pto.tile_buf) outs(%out : !pto.tile_buf) + pto.tstore ins(%out : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<32x32xf32>) + return + } +} + +// NOFUSE-LABEL: func.func @fusion_adapter_placement_level2_tadd( +// NOFUSE-NOT: pto.fusion_region +// NOFUSE: pto.tadd ins( +// NOFUSE-NEXT: pto.tadd ins( +// NOFUSE-NEXT: pto.tadd ins( +// NOFUSE: pto.tstore ins( +// NOFUSE: return + +// SEAM-LABEL: func.func @fusion_adapter_placement_level2_tadd( +// SEAM-NOT: pto.fusion_region +// SEAM: memref.alloc() +// SEAM: pto.bind_tile +// SEAM: memref.alloc() +// SEAM: pto.bind_tile +// SEAM: memref.alloc() +// SEAM: pto.bind_tile +// SEAM: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64{{.*}}} +// SEAM: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64{{.*}}} +// SEAM: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 2 : i64{{.*}}} +// SEAM: pto.tstore ins( +// SEAM: return + +// SYNC-LABEL: func.func @fusion_adapter_placement_level2_tadd( +// SYNC-NOT: pto.fusion_region +// SYNC: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64{{.*}}} +// SYNC: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64{{.*}}} +// SYNC: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 2 : i64{{.*}}} +// SYNC: pto.tstore ins( +// SYNC: pto.barrier {pto.auto_sync_tail_barrier} +// SYNC: return diff --git a/test/lit/tile_fusion/op_fusion_adapter_placement_level3_tadd.pto b/test/lit/tile_fusion/op_fusion_adapter_placement_level3_tadd.pto new file mode 100644 index 000000000..09d319e07 --- /dev/null +++ b/test/lit/tile_fusion/op_fusion_adapter_placement_level3_tadd.pto @@ -0,0 +1,75 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards level3 fusion adapter placement so manual-address tile IR lowers +// through the shared memref path while preserving flat scheduled fusion spans, +// and stays non-fused when op fusion is disabled. +// +// RUN: ptoas --pto-arch=a5 --pto-level=level3 %s --emit-pto-ir -o - | FileCheck %s --check-prefix=NOFUSE +// RUN: { ptoas --pto-arch=a5 --pto-level=level3 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-view-to-memref -o /dev/null 2>&1 || true; } | FileCheck %s --check-prefix=SEAM + +module { + func.func @fusion_adapter_placement_level3_tadd(%dst_ptr: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + %dst_view = pto.make_tensor_view %dst_ptr, + shape = [%c32, %c32], strides = [%c32, %c1] + : !pto.tensor_view + %dst_part = pto.partition_view %dst_view, + offsets = [%c0, %c0], sizes = [%c32, %c32] + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + + %c0_i64 = arith.constant 0 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c12288_i64 = arith.constant 12288 : i64 + %c16384_i64 = arith.constant 16384 : i64 + %c20480_i64 = arith.constant 20480 : i64 + %c24576_i64 = arith.constant 24576 : i64 + %c28672_i64 = arith.constant 28672 : i64 + + %a = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %b = pto.alloc_tile addr = %c4096_i64 : !pto.tile_buf + %c = pto.alloc_tile addr = %c8192_i64 : !pto.tile_buf + %d = pto.alloc_tile addr = %c12288_i64 : !pto.tile_buf + %tmp0 = pto.alloc_tile addr = %c16384_i64 : !pto.tile_buf + %tmp1 = pto.alloc_tile addr = %c20480_i64 : !pto.tile_buf + %out = pto.alloc_tile addr = %c24576_i64 : !pto.tile_buf + + pto.tadd ins(%a, %b : !pto.tile_buf, !pto.tile_buf) outs(%tmp0 : !pto.tile_buf) + pto.tadd ins(%c, %d : !pto.tile_buf, !pto.tile_buf) outs(%tmp1 : !pto.tile_buf) + pto.tadd ins(%tmp0, %tmp1 : !pto.tile_buf, !pto.tile_buf) outs(%out : !pto.tile_buf) + pto.tstore ins(%out : !pto.tile_buf) + outs(%dst_part : !pto.partition_tensor_view<32x32xf32>) + return + } +} + +// NOFUSE-LABEL: func.func @fusion_adapter_placement_level3_tadd( +// NOFUSE-NOT: pto.fusion_region +// NOFUSE: pto.tadd ins( +// NOFUSE-NEXT: pto.tadd ins( +// NOFUSE-NEXT: pto.tadd ins( +// NOFUSE: pto.tstore ins( +// NOFUSE: return + +// SEAM: // -----// IR Dump After PTOViewToMemref (pto-view-to-memref) //----- // +// SEAM-LABEL: func.func @fusion_adapter_placement_level3_tadd( +// SEAM-NOT: pto.fusion_region +// SEAM: pto.pointer_cast(%c0_i64) +// SEAM: pto.pointer_cast(%c4096_i64) +// SEAM: pto.pointer_cast(%c8192_i64) +// SEAM: pto.pointer_cast(%c12288_i64) +// SEAM: pto.pointer_cast(%c16384_i64) +// SEAM: pto.pointer_cast(%c20480_i64) +// SEAM: pto.pointer_cast(%c24576_i64) +// SEAM: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64{{.*}}} +// SEAM: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64{{.*}}} +// SEAM: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 2 : i64{{.*}}} diff --git a/test/lit/tile_fusion/op_fusion_cli_flags.pto b/test/lit/tile_fusion/op_fusion_cli_flags.pto new file mode 100644 index 000000000..0c9ee9960 --- /dev/null +++ b/test/lit/tile_fusion/op_fusion_cli_flags.pto @@ -0,0 +1,29 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards --enable-op-fusion CLI visibility and warning behavior on the +// EmitC mainline frontend-fusion gate. +// +// RUN: ptoas --help-hidden 2>&1 | FileCheck %s --check-prefix=HELP +// RUN: ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion %s -o /dev/null 2>&1 | FileCheck %s --allow-empty --check-prefix=LEVEL2 +// RUN: ptoas --pto-arch=a5 --pto-level=level1 --enable-op-fusion %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=LEVEL1 +// RUN: ptoas --pto-arch=a3 --enable-op-fusion %s -o /dev/null 2>&1 | FileCheck %s --check-prefix=A3 + +module { + func.func @fusion_cli_flags() { + return + } +} + +// HELP: --enable-op-fusion + +// LEVEL2-NOT: Warning: --enable-op-fusion is ignored + +// LEVEL1: Warning: --enable-op-fusion is ignored because --pto-level=level2 or level3 is required. + +// A3: Warning: --enable-op-fusion is ignored because --pto-arch=a5 is required. diff --git a/test/lit/tile_fusion/op_fusion_nonfused_control.pto b/test/lit/tile_fusion/op_fusion_nonfused_control.pto new file mode 100644 index 000000000..dc753b288 --- /dev/null +++ b/test/lit/tile_fusion/op_fusion_nonfused_control.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards the non-fused control surface for frontend fusion: +// - no flag keeps the baseline unfused +// - A3 with the flag warns and stays unfused +// - level1 with the flag warns and stays unfused +// +// RUN: ptoas --pto-arch=a5 --pto-level=level2 --emit-pto-ir %s -o - | FileCheck %s --check-prefix=NOFUSE +// RUN: ptoas --pto-arch=a3 --enable-op-fusion --emit-pto-ir %s -o - 2>&1 | FileCheck %s --check-prefix=A3 +// RUN: ptoas --pto-arch=a5 --pto-level=level1 --enable-op-fusion --emit-pto-ir %s -o - 2>&1 | FileCheck %s --check-prefix=LEVEL1 + +module { + func.func @fusion_nonfused_control( + %a: !pto.tile_buf, + %b: !pto.tile_buf, + %c: !pto.tile_buf) { + %tmp = pto.alloc_tile : !pto.tile_buf + %out = pto.alloc_tile : !pto.tile_buf + pto.tadd ins(%a, %b : !pto.tile_buf, !pto.tile_buf) + outs(%tmp : !pto.tile_buf) + pto.tadd ins(%tmp, %c : !pto.tile_buf, !pto.tile_buf) + outs(%out : !pto.tile_buf) + return + } +} + +// NOFUSE-LABEL: func.func @fusion_nonfused_control( +// NOFUSE-NOT: pto.fusion_region +// NOFUSE: pto.tadd ins( +// NOFUSE-NEXT: pto.tadd ins( +// NOFUSE: return + +// A3: Warning: --enable-op-fusion is ignored because --pto-arch=a5 is required. +// A3-LABEL: func.func @fusion_nonfused_control( +// A3-NOT: pto.fusion_region +// A3: pto.tadd ins( +// A3-NEXT: pto.tadd ins( + +// LEVEL1: Warning: --enable-op-fusion is ignored because --pto-level=level2 or level3 is required. +// LEVEL1-LABEL: func.func @fusion_nonfused_control( +// LEVEL1-NOT: pto.fusion_region +// LEVEL1: pto.tadd ins( +// LEVEL1-NEXT: pto.tadd ins( diff --git a/test/lit/tile_fusion/op_scheduling_basic.pto b/test/lit/tile_fusion/op_scheduling_basic.pto new file mode 100644 index 000000000..83f2e43cb --- /dev/null +++ b/test/lit/tile_fusion/op_scheduling_basic.pto @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards OpScheduling compaction of one planned group into a contiguous span +// while preserving fusion order. +// +// RUN: ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-op-scheduling -o /dev/null 2>&1 | FileCheck %s + +module { + func.func @op_scheduling_basic( + %full0: !pto.tile_buf, + %full1: !pto.tile_buf, + %full2: !pto.tile_buf, + %full3: !pto.tile_buf, + %row0: !pto.tile_buf, + %row1: !pto.tile_buf) { + %tmp0 = pto.alloc_tile : !pto.tile_buf + %noise0 = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + %noise1 = pto.alloc_tile : !pto.tile_buf + %tmp2 = pto.alloc_tile : !pto.tile_buf + + pto.trowexpandmul ins(%full0, %row0 : !pto.tile_buf, !pto.tile_buf) outs(%tmp0 : !pto.tile_buf) + pto.tadd ins(%full2, %full3 : !pto.tile_buf, !pto.tile_buf) outs(%noise0 : !pto.tile_buf) + pto.trowexpandmul ins(%full1, %row1 : !pto.tile_buf, !pto.tile_buf) outs(%tmp1 : !pto.tile_buf) + pto.tmul ins(%full2, %full3 : !pto.tile_buf, !pto.tile_buf) outs(%noise1 : !pto.tile_buf) + pto.tadd ins(%tmp0, %tmp1 : !pto.tile_buf, !pto.tile_buf) outs(%tmp2 : !pto.tile_buf) + return + } +} + +// CHECK: // -----// IR Dump After OpScheduling (pto-op-scheduling) //----- // +// CHECK-LABEL: func.func @op_scheduling_basic( +// CHECK: pto.trowexpandmul{{.*}}outs(%[[DST0:[0-9]+]] : +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64} +// CHECK-NEXT: pto.trowexpandmul{{.*}}outs(%[[DST1:[0-9]+]] : +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64} +// CHECK-NEXT: pto.tadd ins(%[[DST0]], %[[DST1]] +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 2 : i64} +// CHECK-NEXT: pto.tadd ins(%arg2, %arg3 +// CHECK-NEXT: pto.tmul ins(%arg2, %arg3 diff --git a/test/lit/tile_fusion/op_scheduling_negative_call_boundary.pto b/test/lit/tile_fusion/op_scheduling_negative_call_boundary.pto new file mode 100644 index 000000000..fc818a401 --- /dev/null +++ b/test/lit/tile_fusion/op_scheduling_negative_call_boundary.pto @@ -0,0 +1,38 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards OpScheduling against moving planned fusion members across an unrelated +// function call boundary. +// +// RUN: not ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-op-scheduling -o /dev/null 2>&1 | FileCheck %s + +module { + func.func private @touch_boundary() + + func.func @op_scheduling_negative_call_boundary( + %full0: !pto.tile_buf, + %full1: !pto.tile_buf, + %full2: !pto.tile_buf) { + %tmp0 = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + + pto.tadd ins(%full0, %full1 : !pto.tile_buf, !pto.tile_buf) outs(%tmp0 : !pto.tile_buf) + func.call @touch_boundary() : () -> () + pto.tadd ins(%tmp0, %full2 : !pto.tile_buf, !pto.tile_buf) outs(%tmp1 : !pto.tile_buf) + return + } +} + +// CHECK: // -----// IR Dump After OpScheduling (pto-op-scheduling) //----- // +// CHECK-LABEL: func.func @op_scheduling_negative_call_boundary( +// CHECK: pto.tadd{{.*}}outs(%[[DST0:[0-9]+]] +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64} +// CHECK-NEXT: call @touch_boundary() +// CHECK-NEXT: pto.tadd ins(%[[DST0]], %arg2 +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64} +// CHECK: error: expected one contiguous span per pto.fusion.group_id within a basic block diff --git a/test/lit/tile_fusion/op_scheduling_negative_region.pto b/test/lit/tile_fusion/op_scheduling_negative_region.pto new file mode 100644 index 000000000..bb48e0fb2 --- /dev/null +++ b/test/lit/tile_fusion/op_scheduling_negative_region.pto @@ -0,0 +1,46 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards OpScheduling against moving planned fusion members across an unrelated +// region boundary. +// +// RUN: not ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-op-scheduling -o /dev/null 2>&1 | FileCheck %s + +module { + func.func @op_scheduling_negative_region( + %cond: i1, + %full0: !pto.tile_buf, + %full1: !pto.tile_buf, + %full2: !pto.tile_buf, + %row0: !pto.tile_buf, + %row1: !pto.tile_buf) { + %tmp0 = pto.alloc_tile : !pto.tile_buf + %noise = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + + pto.trowexpandmul ins(%full0, %row0 : !pto.tile_buf, !pto.tile_buf) outs(%tmp0 : !pto.tile_buf) + scf.if %cond { + pto.tadd ins(%full1, %full2 : !pto.tile_buf, !pto.tile_buf) outs(%noise : !pto.tile_buf) + scf.yield + } else { + scf.yield + } + pto.trowexpandmul ins(%tmp0, %row1 : !pto.tile_buf, !pto.tile_buf) outs(%tmp1 : !pto.tile_buf) + return + } +} + +// CHECK: // -----// IR Dump After OpScheduling (pto-op-scheduling) //----- // +// CHECK-LABEL: func.func @op_scheduling_negative_region( +// CHECK: pto.trowexpandmul{{.*}}outs(%[[DST0:[0-9]+]] +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64} +// CHECK-NEXT: scf.if %arg0 { +// CHECK: pto.tadd ins(%arg2, %arg3 +// CHECK: pto.trowexpandmul ins(%[[DST0]], %arg5 +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64} +// CHECK: error: expected one contiguous span per pto.fusion.group_id within a basic block diff --git a/test/lit/tile_fusion/op_scheduling_negative_ssa.pto b/test/lit/tile_fusion/op_scheduling_negative_ssa.pto new file mode 100644 index 000000000..4633d5b0d --- /dev/null +++ b/test/lit/tile_fusion/op_scheduling_negative_ssa.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards OpScheduling against moving planned fusion members across an SSA +// definition consumed by a later group member. +// +// RUN: not ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-op-scheduling -o /dev/null 2>&1 | FileCheck %s + +module { + func.func private @produce_row( + !pto.tile_buf) + -> !pto.tile_buf + + func.func @op_scheduling_negative_ssa( + %full0: !pto.tile_buf, + %row0: !pto.tile_buf) { + %tmp0 = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + + pto.trowexpandmul ins(%full0, %row0 : !pto.tile_buf, !pto.tile_buf) outs(%tmp0 : !pto.tile_buf) + %rowTmp = func.call @produce_row(%tmp0) : (!pto.tile_buf) -> !pto.tile_buf + pto.trowexpandmul ins(%tmp0, %rowTmp : !pto.tile_buf, !pto.tile_buf) outs(%tmp1 : !pto.tile_buf) + return + } +} + +// CHECK: // -----// IR Dump After OpScheduling (pto-op-scheduling) //----- // +// CHECK-LABEL: func.func @op_scheduling_negative_ssa( +// CHECK: pto.trowexpandmul{{.*}}outs(%[[DST0:[0-9]+]] +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64} +// CHECK-NEXT: %[[ROW:[0-9]+]] = call @produce_row(%[[DST0]]) +// CHECK-NEXT: pto.trowexpandmul ins(%[[DST0]], %[[ROW]] +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64} +// CHECK: error: expected one contiguous span per pto.fusion.group_id within a basic block diff --git a/test/lit/tile_fusion/op_scheduling_pure_op_bridge.pto b/test/lit/tile_fusion/op_scheduling_pure_op_bridge.pto new file mode 100644 index 000000000..551bf8cd2 --- /dev/null +++ b/test/lit/tile_fusion/op_scheduling_pure_op_bridge.pto @@ -0,0 +1,39 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards OpScheduling against treating memory-effect-free non-fusion ops as +// hard barriers. Pure ops like arith.constant / arith.index_cast may appear +// between planned group members without preventing the group from becoming one +// contiguous fusion span. +// +// RUN: ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-op-scheduling -o /dev/null 2>&1 | FileCheck %s --check-prefix=SCHEDULE + +module { + func.func @op_scheduling_pure_op_bridge( + %full0: !pto.tile_buf, + %full1: !pto.tile_buf, + %full2: !pto.tile_buf) { + %tmp0 = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + %out = pto.alloc_tile : !pto.tile_buf + + pto.tadd ins(%full0, %full1 : !pto.tile_buf, !pto.tile_buf) outs(%tmp0 : !pto.tile_buf) + %c0 = arith.constant 0 : index + pto.tsub ins(%tmp0, %full2 : !pto.tile_buf, !pto.tile_buf) outs(%tmp1 : !pto.tile_buf) + %c1 = arith.index_cast %c0 : index to i32 + pto.tadd ins(%tmp1, %tmp0 : !pto.tile_buf, !pto.tile_buf) outs(%out : !pto.tile_buf) + return + } +} + +// SCHEDULE: // -----// IR Dump After OpScheduling (pto-op-scheduling) //----- // +// SCHEDULE-LABEL: func.func @op_scheduling_pure_op_bridge( +// SCHEDULE: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64} +// SCHEDULE-NEXT: pto.tsub{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64} +// SCHEDULE-NEXT: pto.tadd{{.*}}{pto.fusion.group_id = 0 : i64, pto.fusion.order = 2 : i64} +// SCHEDULE-NOT: expected one contiguous span per pto.fusion.group_id within a basic block diff --git a/test/lit/tile_fusion/op_scheduling_treshape.pto b/test/lit/tile_fusion/op_scheduling_treshape.pto new file mode 100644 index 000000000..1e55cb5a4 --- /dev/null +++ b/test/lit/tile_fusion/op_scheduling_treshape.pto @@ -0,0 +1,49 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// Guards OpScheduling compaction across an unrelated local treshape chain while +// keeping that treshape chain outside the fused span. +// +// RUN: ptoas --pto-arch=a5 --pto-level=level2 --enable-op-fusion --emit-pto-ir %s --mlir-print-ir-after=pto-op-scheduling -o /dev/null 2>&1 | FileCheck %s + +module { + func.func @op_scheduling_treshape( + %full0: !pto.tile_buf, + %full1: !pto.tile_buf, + %full2: !pto.tile_buf, + %full3: !pto.tile_buf, + %wide: !pto.tile_buf, + %row0: !pto.tile_buf, + %row1: !pto.tile_buf) { + %tmp0 = pto.alloc_tile : !pto.tile_buf + %noise0 = pto.alloc_tile : !pto.tile_buf + %tmp1 = pto.alloc_tile : !pto.tile_buf + %tmp2 = pto.alloc_tile : !pto.tile_buf + %noise1 = pto.alloc_tile : !pto.tile_buf + + pto.trowexpandmul ins(%full0, %row0 : !pto.tile_buf, !pto.tile_buf) outs(%tmp0 : !pto.tile_buf) + pto.tadd ins(%full2, %full3 : !pto.tile_buf, !pto.tile_buf) outs(%noise0 : !pto.tile_buf) + %view = pto.treshape %noise0 : !pto.tile_buf -> !pto.tile_buf + pto.trowexpandmul ins(%full1, %row1 : !pto.tile_buf, !pto.tile_buf) outs(%tmp1 : !pto.tile_buf) + pto.tadd ins(%tmp0, %tmp1 : !pto.tile_buf, !pto.tile_buf) outs(%tmp2 : !pto.tile_buf) + pto.tmul ins(%view, %wide : !pto.tile_buf, !pto.tile_buf) outs(%noise1 : !pto.tile_buf) + return + } +} + +// CHECK: // -----// IR Dump After OpScheduling (pto-op-scheduling) //----- // +// CHECK-LABEL: func.func @op_scheduling_treshape( +// CHECK: pto.trowexpandmul{{.*}}outs(%[[DST0:[0-9]+]] : +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 0 : i64} +// CHECK-NEXT: pto.trowexpandmul{{.*}}outs(%[[DST1:[0-9]+]] : +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 1 : i64} +// CHECK-NEXT: pto.tadd ins(%[[DST0]], %[[DST1]] +// CHECK-SAME: {pto.fusion.group_id = 0 : i64, pto.fusion.order = 2 : i64} +// CHECK-NEXT: pto.tadd ins(%arg2, %arg3 +// CHECK-NEXT: %[[RESHAPED:[0-9]+]] = pto.treshape +// CHECK-NEXT: pto.tmul ins(%[[RESHAPED]], %arg4 diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index f0e2b9cd9..ba40a62d7 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -9,6 +9,7 @@ #include "PTO/IR/PTO.h" #include "PTO/Transforms/Passes.h" #include "PTO/Transforms/BufferizableOpInterfaceImpl.h" +#include "PTO/Transforms/CppPostprocess.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/InitAllDialects.h" @@ -205,6 +206,12 @@ static llvm::cl::opt graphSyncSolverEventIdMax( "Lower values exercise the PIPE_ALL coloring fallback sooner."), llvm::cl::init(8)); +static llvm::cl::opt enableOpFusion( + "enable-op-fusion", + llvm::cl::desc("Enable frontend tile fusion on the A5 EmitC mainline " + "(requires --pto-arch=a5 and --pto-level=level2|level3)"), + llvm::cl::init(false)); + static llvm::cl::opt disableInferLayout( "disable-infer-layout", llvm::cl::desc("Disable PTO layout inference pass (static-only)"), @@ -1054,6 +1061,20 @@ int main(int argc, char **argv) { return 1; } + if (enableOpFusion) { + if (arch != "a5") { + llvm::errs() << "Warning: --enable-op-fusion is ignored because " + "--pto-arch=a5 is required.\n"; + } else if (effectiveLevel == PTOBuildLevel::Level1) { + llvm::errs() << "Warning: --enable-op-fusion is ignored because " + "--pto-level=level2 or level3 is required.\n"; + } + } + + const bool enableA5FrontendFusionPath = + enableOpFusion && arch == "a5" && + effectiveLevel != PTOBuildLevel::Level1; + bool invalidAutoSyncTailHint = false; module->walk([&](mlir::func::FuncOp func) { auto hintAttr = @@ -1151,6 +1172,15 @@ int main(int argc, char **argv) { if (!disableInferLayout) pm.addNestedPass(pto::createInferPTOLayoutPass()); pm.addNestedPass(pto::createPTOA5NormalizeTMovPass()); + + // Keep frontend fusion on tile-native PTO IR and annotate last_use directly + // on scheduled block-local spans before the shared mainline lowers tiles. + if (enableA5FrontendFusionPath) { + pm.addNestedPass(pto::createFusionPlanPass()); + pm.addNestedPass(pto::createOpSchedulingPass()); + pm.addNestedPass(pto::createPTOMarkLastUsePass()); + } + pm.addPass(pto::createPTOViewToMemrefPass()); if (effectiveLevel != PTOBuildLevel::Level3) { @@ -1242,6 +1272,7 @@ int main(int argc, char **argv) { rewriteAsyncEventMarkers(cppOutput); rewritePtrScalarMarkers(cppOutput); rewriteEventIdArrayMarkers(cppOutput); + pto::rewriteLastUseMarkersInCpp(cppOutput); rewriteAddPtrTraceMarkers(cppOutput, emitAddPtrTrace); rewriteScalarConstantDecls(cppOutput); rewriteHoistedGlobalTensorDecls(cppOutput);