From 7d742ee3a2205951e4ee4231e1032ba4d04871f8 Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Thu, 7 May 2026 14:42:04 +0800 Subject: [PATCH 1/3] fix: preserve distinct storage for scalar bitwise ops --- lib/PTO/IR/PTO.cpp | 53 ++++++++++++++++++- lib/PTO/Transforms/PTOPlanMemory.cpp | 9 ++++ ...14_scalar_bitwise_plan_memory_distinct.pto | 26 +++++++++ .../issue614_tands_same_storage_verify.pto | 19 +++++++ 4 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 test/lit/pto/issue614_scalar_bitwise_plan_memory_distinct.pto create mode 100644 test/lit/pto/issue614_tands_same_storage_verify.pto diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 103f07088..913209888 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -2803,10 +2803,61 @@ verifyShiftLikeBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty, return e0; } +struct StaticLocalTileStorage { + pto::AddressSpace addressSpace; + int64_t addr; +}; + +static bool isLocalStorageSpace(std::optional addressSpace) { + return addressSpace && *addressSpace != pto::AddressSpace::GM && + *addressSpace != pto::AddressSpace::Zero; +} + +static std::optional +getStaticLocalTileStorage(Value value) { + if (!value || isa(value)) + return std::nullopt; + + if (auto bitcast = value.getDefiningOp()) + return getStaticLocalTileStorage(bitcast.getSrc()); + if (auto reshape = value.getDefiningOp()) + return getStaticLocalTileStorage(reshape.getSrc()); + if (auto setValidShape = value.getDefiningOp()) + return getStaticLocalTileStorage(setValidShape.getSource()); + if (auto subview = value.getDefiningOp()) { + // Only zero-offset subviews provably preserve the same base address. + for (Value offset : subview.getOffsets()) { + auto constOffset = getConstIndexValue(offset); + if (!constOffset || *constOffset != 0) + return std::nullopt; + } + return getStaticLocalTileStorage(subview.getSource()); + } + if (auto allocTile = value.getDefiningOp()) { + auto addressSpace = getPTOMemorySpaceEnum(allocTile.getResult().getType()); + if (!isLocalStorageSpace(addressSpace) || !allocTile.getAddr()) + return std::nullopt; + auto addr = getConstantIntegerValue(allocTile.getAddr()); + if (!addr) + return std::nullopt; + return StaticLocalTileStorage{*addressSpace, *addr}; + } + + return std::nullopt; +} + +static bool haveProvenSameStaticLocalTileStorage(Value lhs, Value rhs) { + auto lhsStorage = getStaticLocalTileStorage(lhs); + auto rhsStorage = getStaticLocalTileStorage(rhs); + return lhsStorage && rhsStorage && + lhsStorage->addressSpace == rhsStorage->addressSpace && + lhsStorage->addr == rhsStorage->addr; +} + static FailureOr verifyDistinctRowMajorUnaryTileOpCommon( Operation *op, Value src, Value dst, StringRef srcName = "src", StringRef dstName = "dst") { - if (src == dst) { + if (src == dst || haveProvenSameStaticLocalTileStorage(src, dst)) { op->emitOpError("expects src and dst to use different storage"); return failure(); } diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index beb51c020..4d45e52d7 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -436,6 +436,15 @@ void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) { op, ptoDpsOp.getDpsInits(), stableValueOrder); genBuffers.append(scratchBuffers.begin(), scratchBuffers.end()); UpdateOpGenInfo(curOpInfo, genBuffers); + // These scalar bitwise ops are lowered to ISA forms that require src + // and dst to reside in distinct local storage. + if (auto tandsOp = dyn_cast(op)) { + RecordSemanticConflict(tandsOp.getSrc(), tandsOp.getDst()); + } else if (auto torsOp = dyn_cast(op)) { + RecordSemanticConflict(torsOp.getSrc(), torsOp.getDst()); + } else if (auto txorsOp = dyn_cast(op)) { + RecordSemanticConflict(txorsOp.getSrc(), txorsOp.getDst()); + } for (const auto &conflictPair : getScratchConflictPairsFromEffects(op, ptoDpsOp.getDpsInits(), stableValueOrder)) { diff --git a/test/lit/pto/issue614_scalar_bitwise_plan_memory_distinct.pto b/test/lit/pto/issue614_scalar_bitwise_plan_memory_distinct.pto new file mode 100644 index 000000000..a34d4ee99 --- /dev/null +++ b/test/lit/pto/issue614_scalar_bitwise_plan_memory_distinct.pto @@ -0,0 +1,26 @@ +// RUN: ptoas --mlir-print-ir-after=pto-plan-memory %s 2>&1 1>/dev/null | FileCheck %s + +module { + func.func @tands_plan_memory_distinct(%src_gm: memref<16x16xi16, #pto.address_space>, + %dst_gm: memref<16x16xi16, #pto.address_space>) { + %scalar = arith.constant 7 : i16 + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tload ins(%src_gm : memref<16x16xi16, #pto.address_space>) + outs(%src : !pto.tile_buf) + pto.tands ins(%src, %scalar : !pto.tile_buf, i16) + outs(%dst : !pto.tile_buf) + pto.tstore ins(%dst : !pto.tile_buf) + outs(%dst_gm : memref<16x16xi16, #pto.address_space>) + return + } +} + +// CHECK: IR Dump After PlanMemory +// CHECK: func.func @tands_plan_memory_distinct +// CHECK-NOT: memref.alloc +// CHECK-DAG: %[[ADDR0:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[ADDR512:.*]] = arith.constant 512 : i64 +// CHECK-DAG: pto.pointer_cast(%[[ADDR0]]) : memref<16x16xi16, #pto.address_space<{{vec|ub}}>> +// CHECK-DAG: pto.pointer_cast(%[[ADDR512]]) : memref<16x16xi16, #pto.address_space<{{vec|ub}}>> diff --git a/test/lit/pto/issue614_tands_same_storage_verify.pto b/test/lit/pto/issue614_tands_same_storage_verify.pto new file mode 100644 index 000000000..d623627a9 --- /dev/null +++ b/test/lit/pto/issue614_tands_same_storage_verify.pto @@ -0,0 +1,19 @@ +// RUN: not ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s + +module { + func.func @tands_same_storage_via_subview() { + %c0_i64 = arith.constant 0 : i64 + %c0 = arith.constant 0 : index + %scalar = arith.constant 7 : i16 + %tile = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %alias = pto.subview %tile[%c0, %c0] sizes [16, 16] : + !pto.tile_buf + -> !pto.tile_buf + + pto.tands ins(%tile, %scalar : !pto.tile_buf, i16) + outs(%alias : !pto.tile_buf) + return + } +} + +// CHECK: error: 'pto.tands' op expects src and dst to use different storage From 3930c67dfc7296c9bd9adf85b626eebf96128743 Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Thu, 7 May 2026 15:03:14 +0800 Subject: [PATCH 2/3] test: relax issue614 plan-memory pointer-cast checks --- test/lit/pto/issue614_scalar_bitwise_plan_memory_distinct.pto | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/lit/pto/issue614_scalar_bitwise_plan_memory_distinct.pto b/test/lit/pto/issue614_scalar_bitwise_plan_memory_distinct.pto index a34d4ee99..bb386d093 100644 --- a/test/lit/pto/issue614_scalar_bitwise_plan_memory_distinct.pto +++ b/test/lit/pto/issue614_scalar_bitwise_plan_memory_distinct.pto @@ -22,5 +22,5 @@ module { // CHECK-NOT: memref.alloc // CHECK-DAG: %[[ADDR0:.*]] = arith.constant 0 : i64 // CHECK-DAG: %[[ADDR512:.*]] = arith.constant 512 : i64 -// CHECK-DAG: pto.pointer_cast(%[[ADDR0]]) : memref<16x16xi16, #pto.address_space<{{vec|ub}}>> -// CHECK-DAG: pto.pointer_cast(%[[ADDR512]]) : memref<16x16xi16, #pto.address_space<{{vec|ub}}>> +// CHECK-DAG: = pto.pointer_cast(%[[ADDR0]]) +// CHECK-DAG: = pto.pointer_cast(%[[ADDR512]]) From 01c9fe0e4f3b905d46f1364c7e2e83a61211c272 Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Thu, 7 May 2026 20:15:24 +0800 Subject: [PATCH 3/3] fix: mark A5-only gather scatter samples as a5 --- test/samples/Mgather/mgather.py | 3 ++- test/samples/Mscatter/mscatter.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/samples/Mgather/mgather.py b/test/samples/Mgather/mgather.py index 3ce7786c2..d63a2008c 100644 --- a/test/samples/Mgather/mgather.py +++ b/test/samples/Mgather/mgather.py @@ -6,7 +6,7 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. -from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.ir import Context, Location, Module, InsertionPoint, StringAttr from mlir.dialects import func, arith, pto from mlir.ir import IndexType, IntegerType @@ -17,6 +17,7 @@ def build(): with Location.unknown(ctx): m = Module.create() + m.operation.attributes["pto.target_arch"] = StringAttr.get("a5") i32 = IntegerType.get_signless(32, ctx) ptr_i32 = pto.PtrType.get(i32, ctx) diff --git a/test/samples/Mscatter/mscatter.py b/test/samples/Mscatter/mscatter.py index 96896ee83..d3a3652df 100644 --- a/test/samples/Mscatter/mscatter.py +++ b/test/samples/Mscatter/mscatter.py @@ -6,7 +6,7 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. -from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.ir import Context, Location, Module, InsertionPoint, StringAttr from mlir.dialects import func, arith, pto from mlir.ir import IndexType, IntegerType @@ -17,6 +17,7 @@ def build(): with Location.unknown(ctx): m = Module.create() + m.operation.attributes["pto.target_arch"] = StringAttr.get("a5") i32 = IntegerType.get_signless(32, ctx) ptr_i32 = pto.PtrType.get(i32, ctx)