Skip to content

[mlir] Add Repeated<T> constructors for TypeRange and ValueRange#186923

Open
kuhar wants to merge 1 commit intollvm:mainfrom
kuhar:mlir-repeated
Open

[mlir] Add Repeated<T> constructors for TypeRange and ValueRange#186923
kuhar wants to merge 1 commit intollvm:mainfrom
kuhar:mlir-repeated

Conversation

@kuhar
Copy link
Member

@kuhar kuhar commented Mar 16, 2026

Many MLIR APIs end up using a range of the same Type / Value repeated N times, due to the (function of the) dimensionality of the problem. Allocating a vector of N identical element is wasteful.

Add Repeated<T>::Storage as PointerUnion variants in TypeRange and ValueRange, enabling O(1) storage for repeated elements. Size remains 2 pointers (16 bytes on 64-bit) for both range types.

Also update several MLIR dialects and conversions to exercise the new code.

Many MLIR APIs end up using a range of the same Type / Value repeated N
times, due to the dimensionality of the problem. Allocating a vector
of N identical element is wasteful.

Add `Repeated<T>::Storage` as PointerUnion variants in TypeRange
and ValueRange, enabling O(1) storage for repeated elements.
Size remains 2 pointers (16 bytes on 64-bit) for both range types.

Also update several MLIR dialects and conversions to exercise the new
code.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@llvmbot
Copy link
Member

llvmbot commented Mar 16, 2026

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Jakub Kuderski (kuhar)

Changes

Many MLIR APIs end up using a range of the same Type / Value repeated N times, due to the (function of the) dimensionality of the problem. Allocating a vector of N identical element is wasteful.

Add Repeated&lt;T&gt;::Storage as PointerUnion variants in TypeRange and ValueRange, enabling O(1) storage for repeated elements. Size remains 2 pointers (16 bytes on 64-bit) for both range types.

Also update several MLIR dialects and conversions to exercise the new code.


Full diff: https://github.com/llvm/llvm-project/pull/186923.diff

13 Files Affected:

  • (modified) mlir/include/mlir/IR/TypeRange.h (+13-3)
  • (modified) mlir/include/mlir/IR/ValueRange.h (+10-3)
  • (modified) mlir/include/mlir/Support/LLVM.h (+3)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+1-1)
  • (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+2-2)
  • (modified) mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp (+2-2)
  • (modified) mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+1-1)
  • (modified) mlir/lib/IR/OperationSupport.cpp (+6)
  • (modified) mlir/lib/IR/TypeRange.cpp (+14)
  • (modified) mlir/unittests/IR/OperationSupportTest.cpp (+51)
diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
index c6cbf3461bcd7..3debed6212778 100644
--- a/mlir/include/mlir/IR/TypeRange.h
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -33,7 +33,9 @@ namespace mlir {
 class TypeRange : public llvm::detail::indexed_accessor_range_base<
                       TypeRange,
                       llvm::PointerUnion<const Value *, const Type *,
-                                         OpOperand *, detail::OpResultImpl *>,
+                                         OpOperand *, detail::OpResultImpl *,
+                                         const Repeated<Type>::Storage *,
+                                         const Repeated<Value>::Storage *>,
                       Type, Type, Type> {
 public:
   using RangeBaseT::RangeBaseT;
@@ -51,6 +53,10 @@ class TypeRange : public llvm::detail::indexed_accessor_range_base<
       : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
   TypeRange(std::initializer_list<Type> types LLVM_LIFETIME_BOUND)
       : TypeRange(ArrayRef<Type>(types)) {}
+  /// Constructs a range from a repeated type. The Repeated object must outlive
+  /// this range.
+  TypeRange(const Repeated<Type> &repeatedValue LLVM_LIFETIME_BOUND)
+      : RangeBaseT(&repeatedValue.storage, repeatedValue.count) {}
 
 private:
   /// The owner of the range is either:
@@ -58,8 +64,12 @@ class TypeRange : public llvm::detail::indexed_accessor_range_base<
   /// * A pointer to the first element of an array of types.
   /// * A pointer to the first element of an array of operands.
   /// * A pointer to the first element of an array of results.
-  using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *,
-                                    detail::OpResultImpl *>;
+  /// * A pointer to a Repeated<Type>::Storage (single type repeated N times).
+  /// * A pointer to a Repeated<Value>::Storage (single value repeated N times,
+  ///   dereferenced via getType()).
+  using OwnerT = llvm::PointerUnion<
+      const Value *, const Type *, OpOperand *, detail::OpResultImpl *,
+      const Repeated<Type>::Storage *, const Repeated<Value>::Storage *>;
 
   /// See `llvm::detail::indexed_accessor_range_base` for details.
   static OwnerT offset_base(OwnerT object, ptrdiff_t index);
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index f04ed0544c0f6..d40de878d5d10 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -17,6 +17,7 @@
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/Repeated.h"
 #include "llvm/ADT/Sequence.h"
 #include <optional>
 
@@ -383,13 +384,15 @@ class ResultRange::UseIterator final
 class ValueRange final
     : public llvm::detail::indexed_accessor_range_base<
           ValueRange,
-          PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>,
+          PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *,
+                       const Repeated<Value>::Storage *>,
           Value, Value, Value> {
 public:
   /// The type representing the owner of a ValueRange. This is either a list of
-  /// values, operands, or results.
+  /// values, operands, results, or a repeated single value.
   using OwnerT =
-      PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>;
+      PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *,
+                   const Repeated<Value>::Storage *>;
 
   using RangeBaseT::RangeBaseT;
 
@@ -412,6 +415,10 @@ class ValueRange final
   ValueRange(ArrayRef<Value> values = {});
   ValueRange(OperandRange values);
   ValueRange(ResultRange values);
+  /// Constructs a range from a repeated value. The Repeated object must outlive
+  /// this range.
+  ValueRange(const Repeated<Value> &repeatedValue LLVM_LIFETIME_BOUND)
+      : RangeBaseT(&repeatedValue.storage, repeatedValue.count) {}
 
   /// Returns the types of the values within this range.
   using type_iterator = ValueTypeIterator<iterator>;
diff --git a/mlir/include/mlir/Support/LLVM.h b/mlir/include/mlir/Support/LLVM.h
index 81bbd717c4f80..8b27ee8c4fed6 100644
--- a/mlir/include/mlir/Support/LLVM.h
+++ b/mlir/include/mlir/Support/LLVM.h
@@ -54,6 +54,8 @@ template <typename T>
 class MutableArrayRef;
 template <typename... PT>
 class PointerUnion;
+template <typename T>
+struct Repeated;
 template <typename T, typename Vector, typename Set, unsigned N>
 class SetVector;
 template <typename T, unsigned N>
@@ -125,6 +127,7 @@ template <typename AllocatorTy = llvm::MallocAllocator>
 using StringSet = llvm::StringSet<AllocatorTy>;
 using llvm::MutableArrayRef;
 using llvm::PointerUnion;
+using llvm::Repeated;
 using llvm::SmallPtrSet;
 using llvm::SmallPtrSetImpl;
 using llvm::SmallVector;
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 498bea0fd17b4..6a705ebab7aa4 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -95,7 +95,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
     // New arguments will simply be `llvm.ptr` with the correct address space
     Type workgroupPtrType =
         rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
-    SmallVector<Type> argTypes(numAttributions, workgroupPtrType);
+    Repeated<Type> argTypes(numAttributions, workgroupPtrType);
 
     // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
     std::array attrs{
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 78f0fe1392962..e4b5da7a5ea92 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -155,11 +155,11 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
       int count = vectorType.getNumElements();
       intType = VectorType::get(count, intType);
 
-      SmallVector<Value> signSplat(count, signMask);
+      Repeated<Value> signSplat(count, signMask);
       signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
                                                      signSplat);
 
-      SmallVector<Value> valueSplat(count, valueMask);
+      Repeated<Value> valueSplat(count, valueMask);
       valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
                                                       valueSplat);
     }
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index b46026b855b90..7e9c9090c51df 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -117,8 +117,8 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
     auto one = createIndexConst(rewriter, loc, 1);
 
     // Loop bounds
-    auto lbs = llvm::SmallVector<Value>(2, zero);
-    auto steps = llvm::SmallVector<Value>(2, one);
+    auto lbs = Repeated<Value>(2, zero);
+    auto steps = Repeated<Value>(2, one);
     auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
 
     auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
index bce67b3e4748b..c6182379026df 100644
--- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -354,7 +354,7 @@ static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
       rewriter, loc,
       MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
   Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
-  SmallVector<Value> indices(2, zeroIndex);
+  Repeated<Value> indices(2, zeroIndex);
   x86::amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
 
   auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index e625f172a3bf3..f73c8476bf20e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -723,7 +723,7 @@ BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
   // the outside.
   Value falseVal = buildBoolValue(builder, op.getLoc(), false);
   op->insertOperands(op->getNumOperands(),
-                     SmallVector<Value>(numMemrefOperands, falseVal));
+                     Repeated<Value>(numMemrefOperands, falseVal));
 
   int counter = op->getNumResults();
   unsigned numMemrefResults = llvm::count_if(op->getResults(), isMemref);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 19db8b3b48a25..babd321e484bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -544,7 +544,7 @@ class TransferReadDropUnitDimsPattern
     Value reducedShapeSource =
         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
     Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
-    SmallVector<Value> zeros(reducedRank, c0);
+    Repeated<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
     Operation *newTransferReadOp = vector::TransferReadOp::create(
@@ -658,7 +658,7 @@ class TransferWriteDropUnitDimsPattern
     Value reducedShapeSource =
         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
     Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
-    SmallVector<Value> zeros(reducedRank, c0);
+    Repeated<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
     auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index df8e6cf167348..9585f5a1d774a 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -357,7 +357,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
       builder, loc,
       /*vectorType=*/vecToReadTy,
       /*source=*/source,
-      /*indices=*/SmallVector<Value>(vecToReadRank, zero),
+      /*indices=*/Repeated<Value>(vecToReadRank, zero),
       /*padding=*/padValue,
       /*inBounds=*/inBoundsVal);
 
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 3ff61daaac60b..f1ee879136756 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -654,6 +654,9 @@ ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
     return {value + index};
   if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
     return {operand + index};
+  // All elements are identical; the owner pointer never advances.
+  if (llvm::isa<const Repeated<Value>::Storage *>(owner))
+    return owner;
   return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
 }
 /// See `llvm::detail::indexed_accessor_range_base` for details.
@@ -662,6 +665,9 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
     return value[index];
   if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
     return operand[index].get();
+  if (auto *repeated =
+          llvm::dyn_cast_if_present<const Repeated<Value>::Storage *>(owner))
+    return repeated->value;
   return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
 }
 
diff --git a/mlir/lib/IR/TypeRange.cpp b/mlir/lib/IR/TypeRange.cpp
index 88e788aa1b2b8..d14ad76c83e75 100644
--- a/mlir/lib/IR/TypeRange.cpp
+++ b/mlir/lib/IR/TypeRange.cpp
@@ -31,6 +31,10 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
     this->base = result;
   else if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
     this->base = operand;
+  else if (auto *repeated =
+               llvm::dyn_cast_if_present<const Repeated<Value>::Storage *>(
+                   owner))
+    this->base = repeated;
   else
     this->base = cast<const Value *>(owner);
 }
@@ -43,6 +47,10 @@ TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
     return {operand + index};
   if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
     return {result->getNextResultAtOffset(index)};
+  // All elements are identical; the owner pointer never advances.
+  if (llvm::isa<const Repeated<Type>::Storage *,
+                const Repeated<Value>::Storage *>(object))
+    return object;
   return {llvm::dyn_cast_if_present<const Type *>(object) + index};
 }
 
@@ -54,5 +62,11 @@ Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
     return (operand + index)->get().getType();
   if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
     return result->getNextResultAtOffset(index)->getType();
+  if (auto *repeated =
+          llvm::dyn_cast_if_present<const Repeated<Type>::Storage *>(object))
+    return repeated->value;
+  if (auto *repeated =
+          llvm::dyn_cast_if_present<const Repeated<Value>::Storage *>(object))
+    return repeated->value.getType();
   return llvm::dyn_cast_if_present<const Type *>(object)[index];
 }
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index f00d5c1f7f927..6319fcbb0f216 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -11,7 +11,9 @@
 #include "../../test/lib/Dialect/Test/TestOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeRange.h"
 #include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/Repeated.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "gtest/gtest.h"
 
@@ -375,4 +377,53 @@ TEST(OperationCloneTest, CloneWithDifferentResults) {
   cloneOp->destroy();
 }
 
+TEST(RepeatedRangeTest, TypeRangeFromRepeatedType) {
+  MLIRContext context;
+  Builder builder(&context);
+  Type i32 = builder.getI32Type();
+
+  llvm::Repeated<Type> rep(3, i32);
+  TypeRange range(rep);
+
+  EXPECT_EQ(range.size(), 3u);
+  EXPECT_FALSE(range.empty());
+  for (Type t : range)
+    EXPECT_EQ(t, i32);
+
+  llvm::Repeated<Type> emptyRep(0, Type{});
+  TypeRange emptyTypeRange(emptyRep);
+
+  EXPECT_EQ(emptyTypeRange.size(), 0u);
+  EXPECT_TRUE(emptyTypeRange.empty());
+}
+
+TEST(RepeatedRangeTest, ValueRangeFromRepeatedValue) {
+  Value nullVal;
+  llvm::Repeated<Value> rep(4, nullVal);
+  ValueRange range(rep);
+
+  EXPECT_EQ(range.size(), 4u);
+  EXPECT_FALSE(range.empty());
+  for (Value v : range)
+    EXPECT_EQ(v, nullVal);
+}
+
+TEST(RepeatedRangeTest, TypeRangeFromRepeatedValueViaValueRange) {
+  MLIRContext context;
+  Builder builder(&context);
+  Type i32 = builder.getI32Type();
+
+  Operation *op = createOp(&context, /*operands=*/{}, i32);
+  Value val = op->getResult(0);
+
+  llvm::Repeated<Value> rep(3, val);
+  TypeRange tr = ValueRange(rep);
+
+  EXPECT_EQ(tr.size(), 3u);
+  for (Type t : tr)
+    EXPECT_EQ(t, i32);
+
+  op->destroy();
+}
+
 } // namespace

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants