Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions llvm/include/llvm/ADT/Repeated.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "llvm/ADT/iterator.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
Expand Down Expand Up @@ -72,21 +73,19 @@ class RepeatedIterator
///
/// `Repeated<T>` is also a proper random-access range: `begin()`/`end()`
/// return iterators that always dereference to the same stored value.
template <typename T> struct [[nodiscard]] Repeated {
/// Wrapper for the stored value used as a PointerUnion target in range
/// types (e.g., TypeRange, ValueRange).
struct Storage {
T value;
};

Storage storage;
// At least 16-byte aligned so that Repeated<T>* has more low bits available
// than a plain pointer. The primary use case is pointer-like types (e.g. MLIR
// Type, Value) where Repeated<T>* appears in a PointerUnion alongside them.
template <typename T>
struct [[nodiscard]] alignas(std::max(size_t{16}, alignof(T))) Repeated {
T storage;
size_t count;

/// Create a `value` repeated `count` times.
/// Uses the same argument order like STD container constructors.
/// Uses the same argument order like std container constructors.
template <typename U>
Repeated(size_t count, U &&value)
: storage{std::forward<U>(value)}, count(count) {}
: storage(std::forward<U>(value)), count(count) {}

using iterator = RepeatedIterator<T>;
using const_iterator = iterator;
Expand All @@ -95,21 +94,19 @@ template <typename T> struct [[nodiscard]] Repeated {
using value_type = T;
using size_type = size_t;

iterator begin() const { return {&storage.value, 0}; }
iterator end() const {
return {&storage.value, static_cast<ptrdiff_t>(count)};
}
iterator begin() const { return {&storage, 0}; }
iterator end() const { return {&storage, static_cast<ptrdiff_t>(count)}; }
reverse_iterator rbegin() const { return reverse_iterator(end()); }
reverse_iterator rend() const { return reverse_iterator(begin()); }

size_t size() const { return count; }
bool empty() const { return count == 0; }

const T &value() const { return storage.value; }
const T &value() const { return storage; }
const T &operator[](size_t idx) const {
assert(idx < size() && "Out of bounds");
(void)idx;
return storage.value;
return storage;
}
};

Expand Down
26 changes: 19 additions & 7 deletions mlir/include/mlir/IR/TypeRange.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/Repeated.h"
#include "llvm/ADT/Sequence.h"

namespace mlir {
Expand All @@ -30,11 +31,13 @@ namespace mlir {
/// a SmallVector/std::vector. This class should be used in places that are not
/// suitable for a more derived type (e.g. ArrayRef) or a template range
/// parameter.
class TypeRange : public llvm::detail::indexed_accessor_range_base<
TypeRange,
llvm::PointerUnion<const Value *, const Type *,
OpOperand *, detail::OpResultImpl *>,
Type, Type, Type> {
class TypeRange
: public llvm::detail::indexed_accessor_range_base<
TypeRange,
llvm::PointerUnion<const Value *, const Type *, OpOperand *,
detail::OpResultImpl *, const Repeated<Type> *,
const Repeated<Value> *>,
Type, Type, Type> {
public:
using RangeBaseT::RangeBaseT;
TypeRange(ArrayRef<Type> types = {});
Expand All @@ -51,15 +54,24 @@ class TypeRange : public llvm::detail::indexed_accessor_range_base<
: TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
TypeRange(std::initializer_list<Type> types LLVM_LIFETIME_BOUND)
: TypeRange(ArrayRef<Type>(types)) {}
/// Constructs a range from a repeated type. The Repeated object must outlive
/// this range.
TypeRange(const Repeated<Type> &repeatedValue LLVM_LIFETIME_BOUND)
: RangeBaseT(&repeatedValue, repeatedValue.count) {}

private:
/// The owner of the range is either:
/// * A pointer to the first element of an array of values.
/// * A pointer to the first element of an array of types.
/// * A pointer to the first element of an array of operands.
/// * A pointer to the first element of an array of results.
using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *,
detail::OpResultImpl *>;
/// * A pointer to a Repeated<Type> (single type repeated N times).
/// * A pointer to a Repeated<Value> (single value repeated N times,
/// dereferenced via getType()).
using OwnerT =
llvm::PointerUnion<const Value *, const Type *, OpOperand *,
detail::OpResultImpl *, const Repeated<Type> *,
const Repeated<Value> *>;

/// See `llvm::detail::indexed_accessor_range_base` for details.
static OwnerT offset_base(OwnerT object, ptrdiff_t index);
Expand Down
14 changes: 10 additions & 4 deletions mlir/include/mlir/IR/ValueRange.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/Repeated.h"
#include "llvm/ADT/Sequence.h"
#include <optional>

Expand Down Expand Up @@ -383,13 +384,14 @@ class ResultRange::UseIterator final
class ValueRange final
: public llvm::detail::indexed_accessor_range_base<
ValueRange,
PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>,
PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *,
const Repeated<Value> *>,
Value, Value, Value> {
public:
/// The type representing the owner of a ValueRange. This is either a list of
/// values, operands, or results.
using OwnerT =
PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>;
/// values, operands, results, or a repeated single value.
using OwnerT = PointerUnion<const Value *, OpOperand *,
detail::OpResultImpl *, const Repeated<Value> *>;

using RangeBaseT::RangeBaseT;

Expand All @@ -412,6 +414,10 @@ class ValueRange final
ValueRange(ArrayRef<Value> values = {});
ValueRange(OperandRange values);
ValueRange(ResultRange values);
/// Constructs a range from a repeated value. The Repeated object must outlive
/// this range.
ValueRange(const Repeated<Value> &repeatedValue LLVM_LIFETIME_BOUND)
: RangeBaseT(&repeatedValue, repeatedValue.count) {}

/// Returns the types of the values within this range.
using type_iterator = ValueTypeIterator<iterator>;
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Support/LLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class StringSet;
template <typename T, typename R>
class StringSwitch;
template <typename T>
struct Repeated;
template <typename T>
class TinyPtrVector;
template <typename T, typename ResultT>
class TypeSwitch;
Expand Down Expand Up @@ -125,6 +127,7 @@ template <typename AllocatorTy = llvm::MallocAllocator>
using StringSet = llvm::StringSet<AllocatorTy>;
using llvm::MutableArrayRef;
using llvm::PointerUnion;
using llvm::Repeated;
using llvm::SmallPtrSet;
using llvm::SmallPtrSetImpl;
using llvm::SmallVector;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// New arguments will simply be `llvm.ptr` with the correct address space
Type workgroupPtrType =
rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
SmallVector<Type> argTypes(numAttributions, workgroupPtrType);
Repeated<Type> argTypes(numAttributions, workgroupPtrType);

// Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
std::array attrs{
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,11 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
int count = vectorType.getNumElements();
intType = VectorType::get(count, intType);

SmallVector<Value> signSplat(count, signMask);
Repeated<Value> signSplat(count, signMask);
signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
signSplat);

SmallVector<Value> valueSplat(count, valueMask);
Repeated<Value> valueSplat(count, valueMask);
valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
valueSplat);
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
auto one = createIndexConst(rewriter, loc, 1);

// Loop bounds
auto lbs = llvm::SmallVector<Value>(2, zero);
auto steps = llvm::SmallVector<Value>(2, one);
auto lbs = Repeated<Value>(2, zero);
auto steps = Repeated<Value>(2, one);
auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};

auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
rewriter, loc,
MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> indices(2, zeroIndex);
Repeated<Value> indices(2, zeroIndex);
x86::amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);

auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
// the outside.
Value falseVal = buildBoolValue(builder, op.getLoc(), false);
op->insertOperands(op->getNumOperands(),
SmallVector<Value>(numMemrefOperands, falseVal));
Repeated<Value>(numMemrefOperands, falseVal));

int counter = op->getNumResults();
unsigned numMemrefResults = llvm::count_if(op->getResults(), isMemref);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ class TransferReadDropUnitDimsPattern
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
SmallVector<Value> zeros(reducedRank, c0);
Repeated<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
Operation *newTransferReadOp = vector::TransferReadOp::create(
Expand Down Expand Up @@ -658,7 +658,7 @@ class TransferWriteDropUnitDimsPattern
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
SmallVector<Value> zeros(reducedRank, c0);
Repeated<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
builder, loc,
/*vectorType=*/vecToReadTy,
/*source=*/source,
/*indices=*/SmallVector<Value>(vecToReadRank, zero),
/*indices=*/Repeated<Value>(vecToReadRank, zero),
/*padding=*/padValue,
/*inBounds=*/inBoundsVal);

Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/IR/OperationSupport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,9 @@ ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
return {value + index};
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
return {operand + index};
// All elements are identical; the owner pointer never advances.
if (llvm::isa<const Repeated<Value> *>(owner))
return owner;
return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Expand All @@ -662,6 +665,9 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
return value[index];
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
return operand[index].get();
if (auto *repeated =
llvm::dyn_cast_if_present<const Repeated<Value> *>(owner))
return repeated->value();
return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
}

Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/IR/TypeRange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
this->base = result;
else if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
this->base = operand;
else if (auto *repeated =
llvm::dyn_cast_if_present<const Repeated<Value> *>(owner))
this->base = repeated;
else
this->base = cast<const Value *>(owner);
}
Expand All @@ -43,6 +46,9 @@ TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
return {operand + index};
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
return {result->getNextResultAtOffset(index)};
// All elements are identical; the owner pointer never advances.
if (llvm::isa<const Repeated<Type> *, const Repeated<Value> *>(object))
return object;
return {llvm::dyn_cast_if_present<const Type *>(object) + index};
}

Expand All @@ -54,5 +60,11 @@ Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
return (operand + index)->get().getType();
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
return result->getNextResultAtOffset(index)->getType();
if (auto *repeated =
llvm::dyn_cast_if_present<const Repeated<Type> *>(object))
return repeated->value();
if (auto *repeated =
llvm::dyn_cast_if_present<const Repeated<Value> *>(object))
return repeated->value().getType();
return llvm::dyn_cast_if_present<const Type *>(object)[index];
}
63 changes: 63 additions & 0 deletions mlir/unittests/IR/OperationSupportTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Repeated.h"
#include "llvm/Support/FormatVariadic.h"
#include "gtest/gtest.h"

Expand Down Expand Up @@ -375,4 +376,66 @@ TEST(OperationCloneTest, CloneWithDifferentResults) {
cloneOp->destroy();
}

TEST(RepeatedRangeTest, TypeRangeFromRepeatedType) {
MLIRContext context;
Builder builder(&context);
Type i32 = builder.getI32Type();

llvm::Repeated<Type> rep(3, i32);
TypeRange range(rep);

EXPECT_EQ(range.size(), 3u);
EXPECT_FALSE(range.empty());
for (Type t : range)
EXPECT_EQ(t, i32);

// Indexing and slicing exercise offset_base (which must not advance).
EXPECT_EQ(range[0], i32);
EXPECT_EQ(range[2], i32);
TypeRange sliced = range.drop_front(1);
EXPECT_EQ(sliced.size(), 2u);
EXPECT_EQ(sliced[0], i32);

llvm::Repeated<Type> emptyRep(0, Type{});
TypeRange emptyTypeRange(emptyRep);

EXPECT_EQ(emptyTypeRange.size(), 0u);
EXPECT_TRUE(emptyTypeRange.empty());
}

TEST(RepeatedRangeTest, ValueRangeFromRepeatedValue) {
Value nullVal;
llvm::Repeated<Value> rep(4, nullVal);
ValueRange range(rep);

EXPECT_EQ(range.size(), 4u);
EXPECT_FALSE(range.empty());
for (Value v : range)
EXPECT_EQ(v, nullVal);

llvm::Repeated<Value> emptyRep(0, nullVal);
ValueRange emptyRange(emptyRep);
EXPECT_EQ(emptyRange.size(), 0u);
EXPECT_TRUE(emptyRange.empty());
}

TEST(RepeatedRangeTest, TypeRangeFromRepeatedValueViaValueRange) {
MLIRContext context;
Builder builder(&context);

Operation *useOp =
createOp(&context, /*operands=*/{}, builder.getIntegerType(16));
Value operand = useOp->getResult(0);

llvm::Repeated<Value> rep(3, operand);
ValueRange vr(rep);
TypeRange tr(vr);

EXPECT_EQ(tr.size(), 3u);
for (Type t : tr)
EXPECT_EQ(t, builder.getIntegerType(16));

useOp->destroy();
}

} // namespace