Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion lib/PTO/IR/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2803,10 +2803,61 @@ verifyShiftLikeBinaryTileOpCommon(Operation *op, Type src0Ty, Type src1Ty,
return e0;
}

struct StaticLocalTileStorage {
pto::AddressSpace addressSpace;
int64_t addr;
};

static bool isLocalStorageSpace(std::optional<pto::AddressSpace> addressSpace) {
return addressSpace && *addressSpace != pto::AddressSpace::GM &&
*addressSpace != pto::AddressSpace::Zero;
}

static std::optional<StaticLocalTileStorage>
getStaticLocalTileStorage(Value value) {
if (!value || isa<BlockArgument>(value))
return std::nullopt;

if (auto bitcast = value.getDefiningOp<pto::BitcastOp>())
return getStaticLocalTileStorage(bitcast.getSrc());
if (auto reshape = value.getDefiningOp<pto::TReshapeOp>())
return getStaticLocalTileStorage(reshape.getSrc());
if (auto setValidShape = value.getDefiningOp<pto::SetValidShapeOp>())
return getStaticLocalTileStorage(setValidShape.getSource());
if (auto subview = value.getDefiningOp<pto::SubViewOp>()) {
// Only zero-offset subviews provably preserve the same base address.
for (Value offset : subview.getOffsets()) {
auto constOffset = getConstIndexValue(offset);
if (!constOffset || *constOffset != 0)
return std::nullopt;
}
return getStaticLocalTileStorage(subview.getSource());
}
if (auto allocTile = value.getDefiningOp<pto::AllocTileOp>()) {
auto addressSpace = getPTOMemorySpaceEnum(allocTile.getResult().getType());
if (!isLocalStorageSpace(addressSpace) || !allocTile.getAddr())
return std::nullopt;
auto addr = getConstantIntegerValue(allocTile.getAddr());
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The getConstantIntegerValue function expects a Value as an argument, but allocTile.getAddr() returns an IntegerAttr (attribute). This will cause a compilation error. Please use the appropriate way to retrieve the integer value from the attribute, for example, allocTile.getAddr().getInt().

if (!addr)
return std::nullopt;
return StaticLocalTileStorage{*addressSpace, *addr};
}

return std::nullopt;
}

static bool haveProvenSameStaticLocalTileStorage(Value lhs, Value rhs) {
auto lhsStorage = getStaticLocalTileStorage(lhs);
auto rhsStorage = getStaticLocalTileStorage(rhs);
return lhsStorage && rhsStorage &&
lhsStorage->addressSpace == rhsStorage->addressSpace &&
lhsStorage->addr == rhsStorage->addr;
}

static FailureOr<Type> verifyDistinctRowMajorUnaryTileOpCommon(
Operation *op, Value src, Value dst, StringRef srcName = "src",
StringRef dstName = "dst") {
if (src == dst) {
if (src == dst || haveProvenSameStaticLocalTileStorage(src, dst)) {
op->emitOpError("expects src and dst to use different storage");
return failure();
}
Expand Down
9 changes: 9 additions & 0 deletions lib/PTO/Transforms/PTOPlanMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,15 @@ void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) {
op, ptoDpsOp.getDpsInits(), stableValueOrder);
genBuffers.append(scratchBuffers.begin(), scratchBuffers.end());
UpdateOpGenInfo(curOpInfo, genBuffers);
// These scalar bitwise ops are lowered to ISA forms that require src
// and dst to reside in distinct local storage.
if (auto tandsOp = dyn_cast<pto::TAndSOp>(op)) {
RecordSemanticConflict(tandsOp.getSrc(), tandsOp.getDst());
} else if (auto torsOp = dyn_cast<pto::TOrSOp>(op)) {
RecordSemanticConflict(torsOp.getSrc(), torsOp.getDst());
} else if (auto txorsOp = dyn_cast<pto::TXorSOp>(op)) {
RecordSemanticConflict(txorsOp.getSrc(), txorsOp.getDst());
}
for (const auto &conflictPair :
getScratchConflictPairsFromEffects(op, ptoDpsOp.getDpsInits(),
stableValueOrder)) {
Expand Down
26 changes: 26 additions & 0 deletions test/lit/pto/issue614_scalar_bitwise_plan_memory_distinct.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: ptoas --mlir-print-ir-after=pto-plan-memory %s 2>&1 1>/dev/null | FileCheck %s

module {
func.func @tands_plan_memory_distinct(%src_gm: memref<16x16xi16, #pto.address_space<gm>>,
%dst_gm: memref<16x16xi16, #pto.address_space<gm>>) {
%scalar = arith.constant 7 : i16
%src = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=i16, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%dst = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=i16, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>

pto.tload ins(%src_gm : memref<16x16xi16, #pto.address_space<gm>>)
outs(%src : !pto.tile_buf<loc=vec, dtype=i16, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
pto.tands ins(%src, %scalar : !pto.tile_buf<loc=vec, dtype=i16, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, i16)
outs(%dst : !pto.tile_buf<loc=vec, dtype=i16, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
pto.tstore ins(%dst : !pto.tile_buf<loc=vec, dtype=i16, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
outs(%dst_gm : memref<16x16xi16, #pto.address_space<gm>>)
return
}
}

// CHECK: IR Dump After PlanMemory
// CHECK: func.func @tands_plan_memory_distinct
// CHECK-NOT: memref.alloc
// CHECK-DAG: %[[ADDR0:.*]] = arith.constant 0 : i64
// CHECK-DAG: %[[ADDR512:.*]] = arith.constant 512 : i64
// CHECK-DAG: = pto.pointer_cast(%[[ADDR0]])
// CHECK-DAG: = pto.pointer_cast(%[[ADDR512]])
19 changes: 19 additions & 0 deletions test/lit/pto/issue614_tands_same_storage_verify.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: not ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s

module {
func.func @tands_same_storage_via_subview() {
%c0_i64 = arith.constant 0 : i64
%c0 = arith.constant 0 : index
%scalar = arith.constant 7 : i16
%tile = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf<loc=vec, dtype=i16, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%alias = pto.subview %tile[%c0, %c0] sizes [16, 16] :
!pto.tile_buf<loc=vec, dtype=i16, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>
-> !pto.tile_buf<loc=vec, dtype=i16, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>

pto.tands ins(%tile, %scalar : !pto.tile_buf<loc=vec, dtype=i16, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>, i16)
outs(%alias : !pto.tile_buf<loc=vec, dtype=i16, rows=16, cols=16, v_row=16, v_col=16, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
return
}
}

// CHECK: error: 'pto.tands' op expects src and dst to use different storage
3 changes: 2 additions & 1 deletion test/samples/Mgather/mgather.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

from mlir.ir import Context, Location, Module, InsertionPoint
from mlir.ir import Context, Location, Module, InsertionPoint, StringAttr
from mlir.dialects import func, arith, pto
from mlir.ir import IndexType, IntegerType

Expand All @@ -17,6 +17,7 @@ def build():

with Location.unknown(ctx):
m = Module.create()
m.operation.attributes["pto.target_arch"] = StringAttr.get("a5")

i32 = IntegerType.get_signless(32, ctx)
ptr_i32 = pto.PtrType.get(i32, ctx)
Expand Down
3 changes: 2 additions & 1 deletion test/samples/Mscatter/mscatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

from mlir.ir import Context, Location, Module, InsertionPoint
from mlir.ir import Context, Location, Module, InsertionPoint, StringAttr
from mlir.dialects import func, arith, pto
from mlir.ir import IndexType, IntegerType

Expand All @@ -17,6 +17,7 @@ def build():

with Location.unknown(ctx):
m = Module.create()
m.operation.attributes["pto.target_arch"] = StringAttr.get("a5")

i32 = IntegerType.get_signless(32, ctx)
ptr_i32 = pto.PtrType.get(i32, ctx)
Expand Down
Loading