[flang][HLFIR] Optimize MINLOC/MAXLOC for equality masks#186916
[flang][HLFIR] Optimize MINLOC/MAXLOC for equality masks#186916
Conversation
This patch implements `isEqualityMask` to identify when the MASK argument is an equality comparison against an invariant value (e.g., MASK = A == X). - This allows the SimplifyHLFIRIntrinsicscation pass to extract the invariant search target and bypasses the creation of a temporary logical mask array by inlining the equality comparison directly into the reduction loop. optimization removes the 'hlfir.apply' to the mask's hlfir.elemental, which gets eliminated in bufferize-hlfir pass. - Simplifies the reduction state by removing the min/max value tracker, as the target value is already known. - Implements a "first-hit" locking mechanism. Test Coverage: - 1D, 2D, 3D Variable/Constant equality searches - Verified optimized - Duplicate match handling - Verified first-occurrence logic - No-match cases - Verified zero result - Different array/Non-invariant target - Verified safe fallback
…quality_mask_flang
…quality_mask_flang
…quality_mask_flang
|
@llvm/pr-subscribers-flang-fir-hlfir Author: None (anoopkg6) ChangesThis pr optimizes the lowering of MINLOC and MAXLOC when the MASK argument is a simple equality comparison (e.g., MASK = A == X). This patch introduces isEqualityMask to identify these patterns and inline the comparison directly into the reduction loop. This allows the SimplifyHLFIRIntrinsicscation pass to extract the invariant Patch is 44.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/186916.diff 3 Files Affected:
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index f47353dc30f64..2d987b6300ab3 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -32,6 +32,94 @@ namespace hlfir {
#define DEBUG_TYPE "simplify-hlfir-intrinsics"
+namespace {
+// Check if the given mask is an equality comparison of the search array
+// against an invariant value (e.g., MASK = A == target) by traversing
+// HLFIR/FIR operations to find the underlying elemental comparison
+// and extract the invariant search targetVal.
+// It returns true if the mask is a simple equality comparison against a
+// scalar/invariant.
+bool isEqualityMask(mlir::Value mask, mlir::Value searchArray,
+ mlir::Value &targetVal) {
+ if (!mask)
+ return false;
+
+ // Trace back HLFIR/FIR wrappers to get Elemental producer.
+ mlir::Value currentMask = mask;
+ while (auto def = currentMask.getDefiningOp()) {
+ if (!mlir::isa<hlfir::AsExprOp, fir::ConvertOp, hlfir::DeclareOp,
+ hlfir::CopyInOp>(def))
+ break;
+ currentMask = def->getOperand(0);
+ }
+ // Ensure the mask is produced by an hlfir.elemental.
+ auto elemental = currentMask.getDefiningOp<hlfir::ElementalOp>();
+ if (!elemental)
+ return false;
+
+ // Inspect the elemental body to find the boolean result logic.
+ mlir::Block &body = elemental.getRegion().front();
+ auto yieldOp = mlir::cast<hlfir::YieldElementOp>(body.getTerminator());
+ mlir::Value val = yieldOp.getElementValue();
+ // Get core comparison, ignoring intermediate type casts.
+ while (auto conv = val.getDefiningOp<fir::ConvertOp>())
+ val = conv.getOperand();
+
+ // We currently only optimize integer equality (arith.cmpi eq).
+ auto cmpOp = val.getDefiningOp<mlir::arith::CmpIOp>();
+ if (!cmpOp || cmpOp.getPredicate() != mlir::arith::CmpIPredicate::eq)
+ return false;
+
+ // Determine if a value is invariant relative to the mask loop.
+ // Handles constants, function arguments, and values defined in outer scopes.
+ auto isInvariant = [&](mlir::Value v) {
+ if (auto arg = mlir::dyn_cast<mlir::BlockArgument>(v))
+ return arg.getOwner()->getParent() != &elemental.getRegion();
+ if (auto *op = v.getDefiningOp())
+ return !elemental.getRegion().isAncestor(op->getParentRegion());
+ return true;
+ };
+
+ // Trace the Array Side to the base buffer.
+ auto getBase = [](mlir::Value v) -> mlir::Value {
+ while (v) {
+ mlir::Operation *def = v.getDefiningOp();
+ if (!def)
+ break;
+ if (auto decl = mlir::dyn_cast<hlfir::DeclareOp>(def))
+ v = decl.getMemref();
+ else if (auto load = mlir::dyn_cast<fir::LoadOp>(def))
+ v = load.getMemref();
+ else if (auto apply = mlir::dyn_cast<hlfir::ApplyOp>(def))
+ v = apply.getExpr();
+ else if (auto des = mlir::dyn_cast<hlfir::DesignateOp>(def))
+ v = des.getMemref();
+ else if (mlir::isa<fir::ConvertOp, hlfir::AsExprOp>(def))
+ v = def->getOperand(0);
+ else
+ break;
+ }
+ return v;
+ };
+
+ mlir::Value lhs = cmpOp.getLhs(), rhs = cmpOp.getRhs();
+ bool lhsInv = isInvariant(lhs), rhsInv = isInvariant(rhs);
+ // The optimization is valid only if exactly one side is invariant (the
+ // target) and the other side is variant (the array element).
+ if (lhsInv == rhsInv)
+ return false;
+
+ targetVal = lhsInv ? lhs : rhs;
+ mlir::Value arraySide = lhsInv ? rhs : lhs;
+
+ // Verify the mask refers to the same array being searched.
+ if (getBase(arraySide) == getBase(searchArray))
+ return true;
+
+ return false;
+}
+} // end anonymous namespace
+
static llvm::cl::opt<bool> forceMatmulAsElemental(
"flang-inline-matmul-as-elemental",
llvm::cl::desc("Expand hlfir.matmul as elemental operation"),
@@ -526,6 +614,15 @@ class MinMaxlocAsElementalConverter : public ReductionAsElementalConverter {
void
checkReductions(const llvm::SmallVectorImpl<mlir::Value> &reductions) const {
+ mlir::Value targetVal;
+ // Check if the mask qualifies for the optimized equality mask search path.
+ if (isEqualityMask(this->getMask(), mlir::cast<T>(this->op).getArray(),
+ targetVal)) {
+ // Expect coordinate indices.
+ assert(reductions.size() == getNumCoors() &&
+ "invalid number of reductions for equality mask MINLOC/MAXLOC");
+ return;
+ }
if (!useIsFirst())
assert(reductions.size() == getNumCoors() + 1 &&
"invalid number of reductions for MINLOC/MAXLOC");
@@ -635,6 +732,51 @@ llvm::SmallVector<mlir::Value>
MinMaxlocAsElementalConverter<T>::reduceOneElement(
const llvm::SmallVectorImpl<mlir::Value> ¤tValue, hlfir::Entity array,
mlir::ValueRange oneBasedIndices) {
+ mlir::Value targetVal;
+ // The mask is an equality comparison (e.g., MASK = A == target) inline the
+ // comparison to find the first occurrence efficiently.
+ if (isEqualityMask(this->getMask(), array, targetVal)) {
+ // Directly load the array element and compare with the targetVal.
+ hlfir::Entity elementValue =
+ hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
+ mlir::Value isMatch = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::eq, (mlir::Value)elementValue,
+ targetVal);
+ // currentValue contains [Coord1, ..., CoordN, FirstHitBool]
+ mlir::Value firstHitBool = currentValue.back();
+ // shouldUpdate is true only if we have a match and we haven't found one
+ // yet.
+ mlir::Value shouldUpdate =
+ mlir::arith::AndIOp::create(builder, loc, isMatch, firstHitBool);
+ // Conditional Update: Only update coordinates if a match is found.
+ auto ifOp = fir::IfOp::create(builder, loc,
+ mlir::ValueRange(currentValue).getTypes(),
+ shouldUpdate, /*withElse=*/true);
+ // If match found and it's the first one, record coordinates.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ llvm::SmallVector<mlir::Value> thenResults;
+ unsigned rank = array.getRank();
+ // Get the firstHit flag.
+ for (unsigned i = 0; i < rank; ++i) {
+ mlir::Value loopIdx = builder.createConvert(
+ loc, currentValue[i].getType(), oneBasedIndices[i]);
+ thenResults.emplace_back(loopIdx);
+ }
+
+ // Update the flag: Set to 0 (False) for all future iterations.
+ mlir::Value falseVal =
+ mlir::arith::ConstantIntOp::create(builder, loc, 0, 1);
+ thenResults.emplace_back(falseVal);
+
+ fir::ResultOp::create(builder, loc, thenResults);
+
+ // No match or already found a previous match: maintain the current state.
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ fir::ResultOp::create(builder, loc, currentValue);
+
+ builder.setInsertionPointAfter(ifOp);
+ return ifOp.getResults();
+ }
checkReductions(currentValue);
hlfir::Entity elementValue =
hlfir::loadElementAt(loc, builder, array, oneBasedIndices);
@@ -685,6 +827,49 @@ MinMaxlocAsElementalConverter<T>::reduceOneElement(
template <typename T>
hlfir::Entity MinMaxlocAsElementalConverter<T>::genFinalResult(
const llvm::SmallVectorImpl<mlir::Value> &reductionResults) {
+ mlir::Value targetVal;
+ // Finalize results for the equality-mask search.
+ if (isEqualityMask(this->getMask(), mlir::cast<T>(this->op).getArray(),
+ targetVal)) {
+ unsigned rank = getNumCoors();
+ mlir::Type resultElemTy =
+ hlfir::getFortranElementType(this->getResultType());
+ // MINLOC/MAXLOC returns an integer array of shape [rank].
+ // Manually build the HLFIR expression to hold the resulting coordinates.
+ llvm::SmallVector<int64_t> shapeVec{static_cast<int64_t>(rank)};
+ mlir::Type exprTy = hlfir::ExprType::get(builder.getContext(), shapeVec,
+ resultElemTy, false);
+ mlir::Value resRank =
+ builder.createIntegerConstant(loc, builder.getIndexType(), rank);
+ mlir::Value resShape = fir::ShapeOp::create(builder, loc, resRank);
+
+ // Create an elemental operation to map the scalar reduction results
+ // (coordinates) back into a Fortran array result.
+ auto elemental =
+ hlfir::ElementalOp::create(builder, loc, exprTy, resShape,
+ /*mold=*/mlir::Value{},
+ /*typeparams=*/mlir::ValueRange{},
+ /*isUnordered=*/false);
+ {
+ // Fill the elemental body.
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(elemental.getBody());
+ // Map the 1-based elemental index, result[i] = reductionResults[i-1].
+ mlir::Value elemIdx = elemental.getIndices()[0];
+ mlir::Value resultVal = reductionResults[0];
+ for (unsigned i = 1; i < rank; ++i) {
+ mlir::Value dimConst =
+ builder.createIntegerConstant(loc, builder.getIndexType(), i + 1);
+ mlir::Value isDimMatch = mlir::arith::CmpIOp::create(
+ builder, loc, mlir::arith::CmpIPredicate::eq, elemIdx, dimConst);
+ // Select specific coordinate matching current elemental dimension.
+ resultVal = mlir::arith::SelectOp::create(
+ builder, loc, isDimMatch, reductionResults[i], resultVal);
+ }
+ hlfir::YieldElementOp::create(builder, loc, resultVal);
+ }
+ return hlfir::Entity{elemental.getResult()};
+ }
// Identification of the final result of MINLOC/MAXLOC:
// * If DIM is absent, the result is rank-one array.
// * If DIM is present:
@@ -1214,9 +1399,39 @@ mlir::LogicalResult ReductionAsElementalConverter::convert() {
extents.push_back(
builder.createConvert(loc, builder.getIndexType(), dimExtent));
- // Initial value for the reduction.
- llvm::SmallVector<mlir::Value, 1> reductionInitValues =
- genReductionInitValues(inputIndices, extents);
+ mlir::Value minMaxMask;
+ if (auto minloc = mlir::dyn_cast<hlfir::MinlocOp>(op)) {
+ minMaxMask = minloc.getMask();
+ } else if (auto maxloc = mlir::dyn_cast<hlfir::MaxlocOp>(op)) {
+ minMaxMask = maxloc.getMask();
+ }
+ mlir::Value targetVal;
+ bool isFixedSearch = false;
+ // Check if the mask allows for a simplified search optimization.
+ if (minMaxMask)
+ isFixedSearch =
+ isEqualityMask(minMaxMask, this->op->getOperand(0), targetVal);
+ llvm::SmallVector<mlir::Value, 1> reductionInitValues;
+ if (isFixedSearch) {
+ // For optimized equality searches, we skip the 'Min/Max value' reduction
+ // and only track coordinate indices and the firstHit flag.
+ unsigned rank = hlfir::Entity{array}.getRank();
+ mlir::Type resElemTy =
+ hlfir::getFortranElementType(this->getResultType());
+ mlir::Value zeroVal = builder.createIntegerConstant(loc, resElemTy, 0);
+
+ // Initialize all coordinates to 0.
+ for (unsigned i = 0; i < rank; ++i) {
+ reductionInitValues.emplace_back(zeroVal);
+ }
+ // First hit flag: [Row, Col, FirstHit=1] (Size: 3)
+ mlir::Type i1Type = builder.getI1Type();
+ mlir::Value firstHitTrue = mlir::arith::ConstantOp::create(
+ builder, loc, i1Type, builder.getBoolAttr(true));
+ reductionInitValues.emplace_back(firstHitTrue);
+ } else {
+ reductionInitValues = genReductionInitValues(inputIndices, extents);
+ }
auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange oneBasedIndices,
@@ -1238,7 +1453,9 @@ mlir::LogicalResult ReductionAsElementalConverter::convert() {
llvm::transform(reductionValues, std::back_inserter(reductionTypes),
[](mlir::Value v) { return v.getType(); });
fir::IfOp ifOp;
- if (mask) {
+ // Skip standard masking block in case of 'isFixedSearch', as it handles
+ // its own masking logic inside the comparison.
+ if (mask && !isFixedSearch) {
// Make the reduction value update conditional on the value
// of the mask.
if (!maskValue) {
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-equality-maxloc.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-equality-maxloc.fir
new file mode 100644
index 0000000000000..31925ae41467e
--- /dev/null
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-equality-maxloc.fir
@@ -0,0 +1,269 @@
+// RUN: fir-opt %s --simplify-hlfir-intrinsics | FileCheck %s
+
+// Rank 1: Variable: A == %target
+func.func @test_maxloc_1d_equality_variable(%arg0: !hlfir.expr<?xi32>, %target: i32) -> !hlfir.expr<1xi32> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %mask = hlfir.elemental %shape : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%i: index):
+ %val = hlfir.apply %arg0, %i : (!hlfir.expr<?xi32>, index) -> i32
+ %cmp = arith.cmpi eq, %val, %target : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %res : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_1d_equality_variable
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK-NOT: arith.constant -2147483648 : i32
+// CHECK: fir.do_loop %[[IV:.*]] = {{.*}} iter_args(%[[LOC:.*]] = %[[C0]], %[[FIRST:.*]] = %[[TRUE]]) -> (i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[VAL:.*]] = hlfir.apply %{{.*}}, %[[IV]]
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[VAL]], %{{.*}}
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[FIRST]]
+// CHECK: %[[IF_RES:.*]]:2 = fir.if %[[COND]] -> (i32, i1)
+
+// Rank 2: Variable: A == %target
+func.func @test_maxloc_2d_equality_variable(%arg0: !hlfir.expr<?x?xi32>, %target: i32) -> !hlfir.expr<2xi32> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x?xi32>) -> !fir.shape<2>
+ %mask = hlfir.elemental %shape : (!fir.shape<2>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+ ^bb0(%i: index, %j: index):
+ %val = hlfir.apply %arg0, %i, %j : (!hlfir.expr<?x?xi32>, index, index) -> i32
+ %cmp = arith.cmpi eq, %val, %target : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<2xi32>
+ return %res : !hlfir.expr<2xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_2d_equality_variable
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[RES_OUTER:.*]]:3 = fir.do_loop %[[IV1:.*]] = {{.*}} iter_args(%[[OUT1:.*]] = %[[C0]], %[[OUT2:.*]] = %[[C0]], %[[OUT3:.*]] = %[[TRUE]]) -> (i32, i32, i1)
+// CHECK: %[[RES_INNER:.*]]:3 = fir.do_loop %[[IV2:.*]] = {{.*}} iter_args(%[[IN1:.*]] = %[[OUT1]], %[[IN2:.*]] = %[[OUT2]], %[[IN3:.*]] = %[[OUT3]]) -> (i32, i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, {{.*}}
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[IN3]]
+// CHECK: %[[IF:.*]]:3 = fir.if %[[COND]] -> (i32, i32, i1)
+
+// Rank 3: Variable: A == %target
+func.func @test_maxloc_3d_equality_variable(%arg0: !hlfir.expr<?x?x?xi32>, %target: i32) -> !hlfir.expr<3xi32> {
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x?x?xi32>) -> !fir.shape<3>
+ %mask = hlfir.elemental %shape : (!fir.shape<3>) -> !hlfir.expr<?x?x?x!fir.logical<4>> {
+ ^bb0(%i: index, %j: index, %k: index):
+ %val = hlfir.apply %arg0, %i, %j, %k : (!hlfir.expr<?x?x?xi32>, index, index, index) -> i32
+ %cmp = arith.cmpi eq, %val, %target : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?x?x?xi32>, !hlfir.expr<?x?x?x!fir.logical<4>>) -> !hlfir.expr<3xi32>
+ return %res : !hlfir.expr<3xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_3d_equality_variable
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[OUTER:.*]]:4 = fir.do_loop %[[IV1:.*]] = {{.*}} iter_args(%[[O1:.*]] = %[[C0]], %[[O2:.*]] = %[[C0]], %[[O3:.*]] = %[[C0]], %[[O4:.*]] = %[[TRUE]]) -> (i32, i32, i32, i1)
+// CHECK: %[[MIDDLE:.*]]:4 = fir.do_loop %[[IV2:.*]] = {{.*}} iter_args(%[[M1:.*]] = %[[O1]], %[[M2:.*]] = %[[O2]], %[[M3:.*]] = %[[O3]], %[[M4:.*]] = %[[O4]]) -> (i32, i32, i32, i1)
+// CHECK: %[[INNER:.*]]:4 = fir.do_loop %[[IV3:.*]] = {{.*}} iter_args(%[[I1:.*]] = %[[M1]], %[[I2:.*]] = %[[M2]], %[[I3:.*]] = %[[M3]], %[[I4:.*]] = %[[M4]]) -> (i32, i32, i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, {{.*}}
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[I4]]
+// CHECK: %[[IF:.*]]:4 = fir.if %[[COND]] -> (i32, i32, i32, i1)
+
+// Rank 1: Constant: A == 42
+func.func @test_maxloc_1d_equality_constant(%arg0: !hlfir.expr<?xi32>) -> !hlfir.expr<1xi32> {
+ %c42 = arith.constant 42 : i32
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?xi32>) -> !fir.shape<1>
+ %mask = hlfir.elemental %shape : (!fir.shape<1>) -> !hlfir.expr<?x!fir.logical<4>> {
+ ^bb0(%i: index):
+ %val = hlfir.apply %arg0, %i : (!hlfir.expr<?xi32>, index) -> i32
+ %cmp = arith.cmpi eq, %val, %c42 : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?xi32>, !hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<1xi32>
+ return %res : !hlfir.expr<1xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_1d_equality_constant
+// CHECK-DAG: %[[C42:.*]] = arith.constant 42 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+
+// CHECK: %[[RES:.*]]:2 = fir.do_loop %[[IV:.*]] = {{.*}} iter_args(%[[LOC:.*]] = %[[C0]], %[[FIRST:.*]] = %[[TRUE]]) -> (i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[VAL:.*]] = hlfir.apply %{{.*}}, %[[IV]]
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[VAL]], %[[C42]]
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[FIRST]]
+// CHECK: %[[IF:.*]]:2 = fir.if %[[COND]] -> (i32, i1)
+
+// Rank 2: Constant: A == 42
+func.func @test_maxloc_2d_equality_constant(%arg0: !hlfir.expr<?x?xi32>) -> !hlfir.expr<2xi32> {
+ %c42 = arith.constant 42 : i32
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x?xi32>) -> !fir.shape<2>
+ %mask = hlfir.elemental %shape : (!fir.shape<2>) -> !hlfir.expr<?x?x!fir.logical<4>> {
+ ^bb0(%i: index, %j: index):
+ %val = hlfir.apply %arg0, %i, %j : (!hlfir.expr<?x?xi32>, index, index) -> i32
+ %cmp = arith.cmpi eq, %val, %c42 : i32
+ %logical = fir.convert %cmp : (i1) -> !fir.logical<4>
+ hlfir.yield_element %logical : !fir.logical<4>
+ }
+ %res = hlfir.maxloc %arg0 mask %mask : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<2xi32>
+ return %res : !hlfir.expr<2xi32>
+}
+// CHECK-LABEL: func.func @test_maxloc_2d_equality_constant
+// CHECK-DAG: %[[C42:.*]] = arith.constant 42 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK: %[[OUTER:.*]]:3 = fir.do_loop %[[IV1:.*]] = {{.*}} iter_args(%[[O1:.*]] = %[[C0]], %[[O2:.*]] = %[[C0]], %[[O3:.*]] = %[[TRUE]]) -> (i32, i32, i1)
+// CHECK: %[[INNER:.*]]:3 = fir.do_loop %[[IV2:.*]] = {{.*}} iter_args(%[[I1:.*]] = %[[O1]], %[[I2:.*]] = %[[O2]], %[[I3:.*]] = %[[O3]]) -> (i32, i32, i1)
+// CHECK-NOT: hlfir.apply {{.*}} !fir.logical<4>
+// CHECK: %[[VAL:.*]] = hlfir.apply %{{.*}}, %[[IV2]], %[[IV1]]
+// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[VAL]], %[[C42]]
+// CHECK: %[[COND:.*]] = arith.andi %[[EQ]], %[[I3]]
+
+// Rank 3: Constant: A == 42
+func.func @test_maxloc_3d_equality_constant(%arg0: !hlfir.expr<?x?x?xi32>) -> !hlfir.expr<3xi32> {
+ %c42 = arith.constant 42 : i32
+ %shape = hlfir.shape_of %arg0 : (!hlfir.expr<?x?x?xi32>) -> !fir.shape<3>
+ %mask = hlfir.elemental %shape : (!fir.s...
[truncated]
|
|
The placement of this optimization looks off to me. I think this better belongs to InlineElementals. |
This pr optimizes the lowering of MINLOC and MAXLOC when the MASK argument is a simple equality comparison (e.g., MASK = A == X).
This patch introduces isEqualityMask to identify these patterns and inline the comparison directly into the reduction loop.
This allows the SimplifyHLFIRIntrinsicscation pass to extract the invariant
search target and bypasses the creation of a temporary logical mask array
by inlining the equality comparison directly into the reduction loop.
optimization removes the 'hlfir.apply' to the mask's hlfir.elemental, which
gets eliminated in bufferize-hlfir pass.
Simplifies the reduction state by removing the min/max value tracker,
as the target value is already known.
Implements a 'first-hit' locking mechanism.