-
-
Notifications
You must be signed in to change notification settings - Fork 57
✨ Add TensorIterator
#1730
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
✨ Add TensorIterator
#1730
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
10a4a2a
Add TensorIterator
MatthiasReumann fb54971
Add qtensor-utils unit test
MatthiasReumann f60a182
🎨 pre-commit fixes
pre-commit-ci[bot] 3df1be0
Update CHANGELOG.md
MatthiasReumann d9153d4
Fix linting
MatthiasReumann 365356a
Add scf.for to unit test
MatthiasReumann 029e32d
🎨 pre-commit fixes
pre-commit-ci[bot] edff31d
Add missing includes
MatthiasReumann 174685d
Merge branch 'main' into feat/tensor-iterator
MatthiasReumann 1ccc7aa
Merge branch 'main' into feat/tensor-iterator
MatthiasReumann 092f9b6
Merge branch 'main' into feat/tensor-iterator
burgholzer 6e2049c
:pencil2: adding this to the main MQT CC entry in the changelog
burgholzer f3358fe
:art: removing redundant namespace qualifiers
burgholzer 4e09cbe
Merge branch 'main' into feat/tensor-iterator
denialhaag 90cc22d
Support qco.if op
MatthiasReumann 712a587
Remove unused header
MatthiasReumann 8261f0d
Support multiple results and inputs
MatthiasReumann 691c39f
Merge branch 'main' into feat/tensor-iterator
MatthiasReumann 348a8ef
Update outdated comments
MatthiasReumann 3c829dd
Update comments
MatthiasReumann File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| /* | ||
| * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM | ||
| * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH | ||
| * All rights reserved. | ||
| * | ||
| * SPDX-License-Identifier: MIT | ||
| * | ||
| * Licensed under the MIT License | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <mlir/IR/Builders.h> | ||
| #include <mlir/IR/BuiltinTypes.h> | ||
| #include <mlir/IR/Operation.h> | ||
| #include <mlir/IR/Value.h> | ||
|
|
||
| #include <iterator> | ||
|
|
||
| namespace mlir::qtensor { | ||
|
|
||
| /** | ||
| * @brief A bidirectional_iterator traversing the tensor chain. | ||
| **/ | ||
| class [[nodiscard]] TensorIterator { | ||
| public: | ||
| using iterator_category = std::bidirectional_iterator_tag; | ||
| using difference_type = std::ptrdiff_t; | ||
| using value_type = Operation*; | ||
|
|
||
| TensorIterator() : op_(nullptr), tensor_(nullptr), isSentinel_(false) {} | ||
| explicit TensorIterator(TypedValue<RankedTensorType> tensor) | ||
| : op_(tensor.getDefiningOp()), tensor_(tensor), isSentinel_(false) {} | ||
|
|
||
| /// @returns the operation the iterator points to. | ||
| [[nodiscard]] Operation* operation() const { return op_; } | ||
|
|
||
| /// @returns the operation the iterator points to. | ||
| [[nodiscard]] Operation* operator*() const { return operation(); } | ||
|
|
||
| /// @returns the tensor the iterator points to. | ||
| [[nodiscard]] TypedValue<RankedTensorType> tensor() const; | ||
|
|
||
| TensorIterator& operator++() { | ||
| forward(); | ||
| return *this; | ||
| } | ||
|
|
||
| TensorIterator operator++(int) { | ||
| auto tmp = *this; | ||
| operator++(); | ||
| return tmp; | ||
| } | ||
|
|
||
| TensorIterator& operator--() { | ||
| backward(); | ||
| return *this; | ||
| } | ||
|
|
||
| TensorIterator operator--(int) { | ||
| auto tmp = *this; | ||
| operator--(); | ||
| return tmp; | ||
| } | ||
|
|
||
| bool operator==(const TensorIterator& other) const { | ||
| return other.tensor_ == tensor_ && other.op_ == op_ && | ||
| other.isSentinel_ == isSentinel_; | ||
| } | ||
|
|
||
| bool operator==([[maybe_unused]] std::default_sentinel_t s) const { | ||
| return isSentinel_; | ||
| } | ||
|
|
||
| private: | ||
| /// @brief Move to the next operation on the tensor def-use chain. | ||
| void forward(); | ||
|
|
||
| /// @brief Move to the previous operation on the tensor def-use chain. | ||
| void backward(); | ||
|
|
||
| Operation* op_; | ||
| TypedValue<RankedTensorType> tensor_; | ||
| bool isSentinel_; | ||
| }; | ||
| } // namespace mlir::qtensor |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,3 +8,4 @@ | |
|
|
||
| add_subdirectory(IR) | ||
| add_subdirectory(Transforms) | ||
| add_subdirectory(Utils) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| # Copyright (c) 2023 - 2026 Chair for Design Automation, TUM | ||
| # Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH | ||
| # All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: MIT | ||
| # | ||
| # Licensed under the MIT License | ||
|
|
||
| file(GLOB_RECURSE UTILS_CPP "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp") | ||
|
|
||
| add_mlir_dialect_library( | ||
| MLIRQTensorUtils | ||
| ${UTILS_CPP} | ||
| ADDITIONAL_HEADER_DIRS | ||
| ${PROJECT_SOURCE_DIR}/mlir/include/mlir/Dialect/QTensor | ||
| DEPENDS | ||
| MLIRQTensorOpsIncGen | ||
| LINK_LIBS | ||
| PUBLIC | ||
| MLIRQTensorDialect) | ||
|
|
||
| mqt_mlir_target_use_project_options(MLIRQTensorUtils) | ||
|
|
||
| # collect header files | ||
| file(GLOB_RECURSE UTILS_HEADERS_SOURCE | ||
| "${MQT_MLIR_SOURCE_INCLUDE_DIR}/mlir/Dialect/QTensor/Utils/*.h") | ||
| file(GLOB_RECURSE UTILS_HEADERS_BUILD | ||
| "${MQT_MLIR_BUILD_INCLUDE_DIR}/mlir/Dialect/QTensor/Utils/*.inc") | ||
|
|
||
| # add public headers using file sets | ||
| target_sources( | ||
| MLIRQTensorUtils | ||
| PUBLIC FILE_SET | ||
| HEADERS | ||
| BASE_DIRS | ||
| ${MQT_MLIR_SOURCE_INCLUDE_DIR} | ||
| FILES | ||
| ${UTILS_HEADERS_SOURCE} | ||
| FILE_SET | ||
| HEADERS | ||
| BASE_DIRS | ||
| ${MQT_MLIR_BUILD_INCLUDE_DIR} | ||
| FILES | ||
| ${UTILS_HEADERS_BUILD}) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,144 @@ | ||
| /* | ||
| * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM | ||
| * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH | ||
| * All rights reserved. | ||
| * | ||
| * SPDX-License-Identifier: MIT | ||
| * | ||
| * Licensed under the MIT License | ||
| */ | ||
|
|
||
| #include "mlir/Dialect/QTensor/Utils/TensorIterator.h" | ||
|
|
||
| #include "mlir/Dialect/QCO/IR/QCOOps.h" | ||
| #include "mlir/Dialect/QTensor/IR/QTensorOps.h" | ||
|
|
||
| #include <llvm/ADT/STLExtras.h> | ||
| #include <llvm/ADT/TypeSwitch.h> | ||
| #include <llvm/Support/ErrorHandling.h> | ||
| #include <mlir/Dialect/SCF/IR/SCF.h> | ||
| #include <mlir/IR/Builders.h> | ||
| #include <mlir/IR/Value.h> | ||
| #include <mlir/Support/LLVM.h> | ||
|
|
||
| #include <cassert> | ||
| #include <iterator> | ||
|
|
||
| namespace mlir::qtensor { | ||
| TypedValue<RankedTensorType> TensorIterator::tensor() const { | ||
| if (op_ == nullptr) { | ||
| return tensor_; | ||
| } | ||
|
|
||
| // The following operations don't have an OpResult. | ||
| if (isa<DeallocOp, scf::YieldOp, qco::YieldOp>(op_)) { | ||
| return nullptr; | ||
| } | ||
|
|
||
| return tensor_; | ||
| } | ||
|
|
||
| void TensorIterator::forward() { | ||
| // If the iterator is a sentinel already, there is nothing to do. | ||
| if (isSentinel_) { | ||
| return; | ||
| } | ||
|
|
||
| // Find the user-operation of the tensor SSA value. | ||
| assert(tensor_.hasOneUse() && "expected linear typing"); | ||
| op_ = *(tensor_.user_begin()); | ||
|
|
||
| // The following operations define the end of the tensor's life-chain. | ||
| if (isa<DeallocOp, scf::YieldOp, qco::YieldOp>(op_)) { | ||
| isSentinel_ = true; | ||
| return; | ||
| } | ||
|
|
||
| // Find the output from the input tensor SSA value. | ||
| if (!(isa<AllocOp, FromElementsOp>(op_))) { | ||
| TypeSwitch<Operation*>(op_) | ||
| .Case<ExtractOp>([&](ExtractOp op) { tensor_ = op.getOutTensor(); }) | ||
| .Case<InsertOp>([&](InsertOp op) { tensor_ = op.getResult(); }) | ||
| .Case<scf::ForOp>([&](scf::ForOp op) { | ||
| tensor_ = cast<TypedValue<RankedTensorType>>( | ||
| op.getTiedLoopResult(&*(tensor_.use_begin()))); | ||
| }) | ||
| .Case<qco::IfOp>([&](qco::IfOp op) { | ||
| auto it = llvm::find(op.getQubits(), tensor_); | ||
| assert(it != op.getQubits().end()); | ||
| const auto idx = std::distance(op.getQubits().begin(), it); | ||
| tensor_ = cast<TypedValue<RankedTensorType>>(op.getResults()[idx]); | ||
| }) | ||
| .Default([&](Operation* op) { | ||
| report_fatal_error("unknown op in def-use chain: " + | ||
| op->getName().getStringRef()); | ||
| }); | ||
| } | ||
| } | ||
|
burgholzer marked this conversation as resolved.
|
||
|
|
||
| void TensorIterator::backward() { | ||
| // If the iterator is a sentinel, reactivate the iterator. | ||
| if (isSentinel_) { | ||
| isSentinel_ = false; | ||
| return; | ||
| } | ||
|
|
||
| // If the op is a nullptr, the tensor value is a block argument and thus the | ||
| // beginning of the tensor's life-chain. | ||
| if (op_ == nullptr) { | ||
| return; | ||
| } | ||
|
|
||
| // For these operations, tensor_ is an OpOperand. Hence, only get the def-op. | ||
| if (isa<DeallocOp, scf::YieldOp, qco::YieldOp>(op_)) { | ||
| op_ = tensor_.getDefiningOp(); | ||
| return; | ||
| } | ||
|
|
||
| // Allocations and FromElements define the start of the tensor's life-chain. | ||
| // Consequently, stop and early exit. | ||
| if (isa<AllocOp, FromElementsOp>(op_)) { | ||
| return; | ||
| } | ||
|
|
||
| // Find the input from the output tensor SSA value. | ||
| TypeSwitch<Operation*>(op_) | ||
| .Case<ExtractOp>([&](ExtractOp op) { tensor_ = op.getTensor(); }) | ||
| .Case<InsertOp>([&](InsertOp op) { tensor_ = op.getDest(); }) | ||
| .Case<scf::ForOp>([&](scf::ForOp op) { | ||
| if (auto res = dyn_cast<OpResult>(tensor_)) { | ||
| OpOperand* operand = op.getTiedLoopInit(res); | ||
| tensor_ = cast<TypedValue<RankedTensorType>>(operand->get()); | ||
| return; | ||
| } | ||
|
|
||
| llvm::reportFatalInternalError( | ||
| "expected scf.for result for tied init lookup"); | ||
| }) | ||
| .Case<qco::IfOp>([&](qco::IfOp op) { | ||
| if (auto res = dyn_cast<OpResult>(tensor_)) { | ||
| auto it = llvm::find(op.getResults(), res); | ||
| assert(it != op->result_end()); | ||
| const auto idx = std::distance(op.result_begin(), it); | ||
| tensor_ = cast<TypedValue<RankedTensorType>>(op.getQubits()[idx]); | ||
| return; | ||
| } | ||
|
|
||
| llvm::reportFatalInternalError( | ||
| "expected scf.for result for tied init lookup"); | ||
| }) | ||
| .Default([&](Operation* op) { | ||
| llvm::reportFatalInternalError("unknown op in def-use chain: " + | ||
| op->getName().getStringRef()); | ||
| }); | ||
|
|
||
| // Get the operation that produces the tensor value. | ||
| // If the current tensor SSA value is a BlockArgument (no defining op), the | ||
| // operation will be a nullptr. | ||
| op_ = tensor_.getDefiningOp(); | ||
| } | ||
|
|
||
| static_assert(std::bidirectional_iterator<TensorIterator>); | ||
| static_assert(std::sentinel_for<std::default_sentinel_t, TensorIterator>, | ||
| "std::default_sentinel_t must be a sentinel for TensorIterator."); | ||
| } // namespace mlir::qtensor | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,3 +7,4 @@ | |
| # Licensed under the MIT License | ||
|
|
||
| add_subdirectory(IR) | ||
| add_subdirectory(Utils) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| # Copyright (c) 2023 - 2026 Chair for Design Automation, TUM | ||
| # Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH | ||
| # All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: MIT | ||
| # | ||
| # Licensed under the MIT License | ||
|
|
||
| set(qtensor_utils_target mqt-core-mlir-unittest-qtensor-utils) | ||
| add_executable(${qtensor_utils_target} test_tensoriterator.cpp) | ||
| target_link_libraries(${qtensor_utils_target} PRIVATE GTest::gtest_main MLIRQTensorDialect | ||
| MLIRQTensorUtils MLIRQCOProgramBuilder) | ||
| mqt_mlir_configure_unittest_target(${qtensor_utils_target}) | ||
|
|
||
| gtest_discover_tests(${qtensor_utils_target} PROPERTIES LABELS mqt-mlir-unittests DISCOVERY_TIMEOUT | ||
| 60) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.