diff --git a/CHANGELOG.md b/CHANGELOG.md index d65d87da90..5cfb5bb937 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel ### Added +- ✨ Add a `fuse-single-qubit-unitary-runs` pass for fusing compile-time single-qubit unitary runs via Euler resynthesis ([#1672]) ([**@simon1hofmann**]) - 🚸 Add [CMake presets] to provide a standardized and reproducible way to configure builds ([#1660]) ([**@denialhaag**]) - ✨ Add a `quantum-loop-unroll` pass for unrolling for-loop operations containing quantum operations ([#1718]) ([**@MatthiasReumann**]) - ✨ Add a `hadamard-lifting` pass for lifting Hadamard gates above Pauli gates ([#1605]) ([**@lirem101**], [**@burgholzer**]) @@ -422,6 +423,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool [#1675]: https://github.com/munich-quantum-toolkit/core/pull/1675 [#1674]: https://github.com/munich-quantum-toolkit/core/pull/1674 [#1673]: https://github.com/munich-quantum-toolkit/core/pull/1673 +[#1672]: https://github.com/munich-quantum-toolkit/core/pull/1672 [#1664]: https://github.com/munich-quantum-toolkit/core/pull/1664 [#1662]: https://github.com/munich-quantum-toolkit/core/pull/1662 [#1660]: https://github.com/munich-quantum-toolkit/core/pull/1660 diff --git a/mlir/include/mlir/Dialect/QCO/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/QCO/IR/CMakeLists.txt index f5368621d1..f5a9e74218 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/QCO/IR/CMakeLists.txt @@ -8,6 +8,9 @@ add_mlir_dialect(QCOOps qco) add_mlir_interface(QCOInterfaces) +add_mlir_interface(QCOUnitaryMatrixInterfaces) add_mlir_doc(QCOOps QCODialect Dialects/ -gen-dialect-doc) add_mlir_doc(QCOInterfaces QCOInterfaces Dialects/ -gen-op-interface-docs -dialect=qco) +add_mlir_doc(QCOUnitaryMatrixInterfaces QCOUnitaryMatrixInterfaces Dialects/ -gen-op-interface-docs + -dialect=qco) diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.h b/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.h index 035b0b5268..6296cd0325 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.h +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.h @@ -10,7 +10,6 @@ #pragma once -#include #include #include #include diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.td b/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.td index 3de8255b8e..4a507ea984 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOInterfaces.td @@ -28,62 +28,6 @@ def UnitaryOpInterface : OpInterface<"UnitaryOpInterface"> { let cppNamespace = "::mlir::qco"; - // Generic implementation body for getUnitaryMatrix methods - defvar unitaryMatrixMethodBody = [{ - auto process = [&](MatrixType&& m) -> bool { - using TargetT = std::remove_cvref_t; - using SourceT = std::remove_cvref_t; - - constexpr bool isTargetDynamic = - (TargetT::SizeAtCompileTime == Eigen::Dynamic); - constexpr bool isSourceDynamic = - (SourceT::SizeAtCompileTime == Eigen::Dynamic); - - // Case 1: Target is Dynamic. Always accepts source. - if constexpr (isTargetDynamic) { - out = std::forward(m); - return true; - } - // Case 2: Target is Fixed. - else { - // Case 2a: Source is Dynamic. Runtime dimension check required. - if constexpr (isSourceDynamic) { - if (m.rows() == static_cast(TargetT::RowsAtCompileTime) && - m.cols() == static_cast(TargetT::ColsAtCompileTime)) - [[likely]] { - out = std::forward(m); - return true; - } - } - // Case 2b: Source is Fixed. Compile-time check. - else if constexpr (static_cast( - SourceT::RowsAtCompileTime) == - static_cast( - TargetT::RowsAtCompileTime) && - static_cast( - SourceT::ColsAtCompileTime) == - static_cast( - TargetT::ColsAtCompileTime)) { - out = std::forward(m); - return true; - } - } - return false; - }; - - - if constexpr (requires { $_op.getUnitaryMatrix().has_value(); }) { - if (auto&& matrix = $_op.getUnitaryMatrix()) { - return process(std::move(*matrix)); - } - return false; - } else if constexpr (requires { $_op.getUnitaryMatrix(); }) { - return process($_op.getUnitaryMatrix()); - } else { - llvm::reportFatalUsageError("Operation '" + $_op.getBaseSymbol() + "' has no unitary matrix definition!"); - } - }]; - let methods = [ // Qubit accessors InterfaceMethod< @@ -152,60 +96,7 @@ def UnitaryOpInterface : OpInterface<"UnitaryOpInterface"> { // Identification InterfaceMethod<"Returns the base symbol/mnemonic of the operation.", - "StringRef", "getBaseSymbol", (ins)>, - - // Unitary matrix helpers - InterfaceMethod<"Populates the given 1x1 unitary matrix if possible.", - "bool", "getUnitaryMatrix1x1", - (ins "Eigen::Matrix, 1, 1>&":$out), - unitaryMatrixMethodBody>, - InterfaceMethod<"Populates the given 2x2 unitary matrix if possible.", - "bool", "getUnitaryMatrix2x2", - (ins "Eigen::Matrix2cd&":$out), unitaryMatrixMethodBody>, - InterfaceMethod<"Populates the given 4x4 unitary matrix if possible.", - "bool", "getUnitaryMatrix4x4", - (ins "Eigen::Matrix4cd&":$out), unitaryMatrixMethodBody>, - InterfaceMethod<"Populates the given dynamic unitary matrix.", "bool", - "getUnitaryMatrixDynamic", (ins "Eigen::MatrixXcd&":$out), - unitaryMatrixMethodBody>]; - - let extraClassDeclaration = [{ - template - std::optional getUnitaryMatrix() { - MatrixType out; - bool result = false; - - // Dispatch to the appropriate fixed-size or dynamic method based on the - // matrix type. - if constexpr (MatrixType::RowsAtCompileTime == 1 && - MatrixType::ColsAtCompileTime == 1) { - result = this->getUnitaryMatrix1x1(out); - } else if constexpr (MatrixType::RowsAtCompileTime == 2 && - MatrixType::ColsAtCompileTime == 2) { - result = this->getUnitaryMatrix2x2(out); - } else if constexpr (MatrixType::RowsAtCompileTime == 4 && - MatrixType::ColsAtCompileTime == 4) { - result = this->getUnitaryMatrix4x4(out); - } else if constexpr (MatrixType::SizeAtCompileTime == Eigen::Dynamic) { - result = this->getUnitaryMatrixDynamic(out); - } else { - // Fallback: Try obtaining dynamic matrix and see if size matches - Eigen::MatrixXcd dynamicOut; - if (this->getUnitaryMatrixDynamic(dynamicOut)) { - if (dynamicOut.rows() == MatrixType::RowsAtCompileTime && - dynamicOut.cols() == MatrixType::ColsAtCompileTime) { - out = dynamicOut; - result = true; - } - } - } - - if (result) { - return out; - } - return std::nullopt; - } - }]; + "StringRef", "getBaseSymbol", (ins)>]; } #endif // MLIR_DIALECT_QCO_IR_QCOINTERFACES_TD diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.h b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.h index 99d5bb6fc0..3a80a7e483 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.h +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.h @@ -23,6 +23,7 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOInterfaces.h" +#include "mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.h" #include #include diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index a5bbfb7f51..d632c33216 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -11,6 +11,7 @@ include "mlir/Dialect/QCO/IR/QCODialect.td" include "mlir/Dialect/QCO/IR/QCOInterfaces.td" +include "mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.td" include "mlir/Dialect/QCO/IR/QCOTypes.td" include "mlir/IR/EnumAttr.td" @@ -183,8 +184,9 @@ def TwoTargetTwoParameter : TargetAndParameterArityTrait<2, 2>; //===----------------------------------------------------------------------===// def GPhaseOp - : QCOOp<"gphase", traits = [UnitaryOpInterface, ZeroTargetOneParameter, - MemoryEffects<[MemWrite]>]> { + : QCOOp<"gphase", + traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + ZeroTargetOneParameter, MemoryEffects<[MemWrite]>]> { let summary = "Apply a global phase to the state"; let description = [{ Applies a global phase to the state. @@ -208,7 +210,8 @@ def GPhaseOp let hasCanonicalizer = 1; } -def IdOp : QCOOp<"id", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def IdOp : QCOOp<"id", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetZeroParameter]> { let summary = "Apply an Id gate to a qubit"; let description = [{ Applies an Id gate to a qubit and returns the transformed qubit. @@ -232,7 +235,8 @@ def IdOp : QCOOp<"id", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def XOp : QCOOp<"x", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def XOp : QCOOp<"x", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetZeroParameter]> { let summary = "Apply an X gate to a qubit"; let description = [{ Applies an X gate to a qubit and returns the transformed qubit. @@ -256,7 +260,8 @@ def XOp : QCOOp<"x", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def YOp : QCOOp<"y", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def YOp : QCOOp<"y", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetZeroParameter]> { let summary = "Apply a Y gate to a qubit"; let description = [{ Applies a Y gate to a qubit and returns the transformed qubit. @@ -280,7 +285,8 @@ def YOp : QCOOp<"y", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def ZOp : QCOOp<"z", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def ZOp : QCOOp<"z", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetZeroParameter]> { let summary = "Apply a Z gate to a qubit"; let description = [{ Applies a Z gate to a qubit and returns the transformed qubit. @@ -304,7 +310,8 @@ def ZOp : QCOOp<"z", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def HOp : QCOOp<"h", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def HOp : QCOOp<"h", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetZeroParameter]> { let summary = "Apply a H gate to a qubit"; let description = [{ Applies a H gate to a qubit and returns the transformed qubit. @@ -328,7 +335,8 @@ def HOp : QCOOp<"h", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def SOp : QCOOp<"s", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def SOp : QCOOp<"s", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetZeroParameter]> { let summary = "Apply an S gate to a qubit"; let description = [{ Applies an S gate to a qubit and returns the transformed qubit. @@ -352,8 +360,8 @@ def SOp : QCOOp<"s", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def SdgOp - : QCOOp<"sdg", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def SdgOp : QCOOp<"sdg", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetZeroParameter]> { let summary = "Apply an Sdg gate to a qubit"; let description = [{ Applies an Sdg gate to a qubit and returns the transformed qubit. @@ -377,7 +385,8 @@ def SdgOp let hasCanonicalizer = 1; } -def TOp : QCOOp<"t", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def TOp : QCOOp<"t", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetZeroParameter]> { let summary = "Apply a T gate to a qubit"; let description = [{ Applies a T gate to a qubit and returns the transformed qubit. @@ -401,8 +410,8 @@ def TOp : QCOOp<"t", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def TdgOp - : QCOOp<"tdg", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def TdgOp : QCOOp<"tdg", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetZeroParameter]> { let summary = "Apply a Tdg gate to a qubit"; let description = [{ Applies a Tdg gate to a qubit and returns the transformed qubit. @@ -426,7 +435,8 @@ def TdgOp let hasCanonicalizer = 1; } -def SXOp : QCOOp<"sx", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def SXOp : QCOOp<"sx", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetZeroParameter]> { let summary = "Apply an SX gate to a qubit"; let description = [{ Applies an SX gate to a qubit and returns the transformed qubit. @@ -451,7 +461,8 @@ def SXOp : QCOOp<"sx", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { } def SXdgOp - : QCOOp<"sxdg", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { + : QCOOp<"sxdg", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetZeroParameter]> { let summary = "Apply an SXdg gate to a qubit"; let description = [{ Applies an SXdg gate to a qubit and returns the transformed qubit. @@ -475,7 +486,8 @@ def SXdgOp let hasCanonicalizer = 1; } -def RXOp : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def RXOp : QCOOp<"rx", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetOneParameter]> { let summary = "Apply an RX gate to a qubit"; let description = [{ Applies an RX gate to a qubit and returns the transformed qubit. @@ -504,7 +516,8 @@ def RXOp : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def RYOp : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def RYOp : QCOOp<"ry", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetOneParameter]> { let summary = "Apply an RY gate to a qubit"; let description = [{ Applies an RY gate to a qubit and returns the transformed qubit. @@ -533,7 +546,8 @@ def RYOp : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def RZOp : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def RZOp : QCOOp<"rz", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetOneParameter]> { let summary = "Apply an RZ gate to a qubit"; let description = [{ Applies an RZ gate to a qubit and returns the transformed qubit. @@ -562,7 +576,8 @@ def RZOp : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def POp : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def POp : QCOOp<"p", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetOneParameter]> { let summary = "Apply a P gate to a qubit"; let description = [{ Applies a P gate to a qubit and returns the transformed qubit. @@ -591,7 +606,8 @@ def POp : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def ROp : QCOOp<"r", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { +def ROp : QCOOp<"r", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetTwoParameter]> { let summary = "Apply an R gate to a qubit"; let description = [{ Applies an R gate to a qubit and returns the transformed qubit. @@ -621,7 +637,8 @@ def ROp : QCOOp<"r", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { let hasCanonicalizer = 1; } -def U2Op : QCOOp<"u2", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { +def U2Op : QCOOp<"u2", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetTwoParameter]> { let summary = "Apply a U2 gate to a qubit"; let description = [{ Applies a U2 gate to a qubit and returns the transformed qubit. @@ -651,7 +668,8 @@ def U2Op : QCOOp<"u2", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { let hasCanonicalizer = 1; } -def UOp : QCOOp<"u", traits = [UnitaryOpInterface, OneTargetThreeParameter]> { +def UOp : QCOOp<"u", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + OneTargetThreeParameter]> { let summary = "Apply a U gate to a qubit"; let description = [{ Applies a U gate to a qubit and returns the transformed qubit. @@ -684,7 +702,8 @@ def UOp : QCOOp<"u", traits = [UnitaryOpInterface, OneTargetThreeParameter]> { } def SWAPOp - : QCOOp<"swap", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { + : QCOOp<"swap", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + TwoTargetZeroParameter]> { let summary = "Apply a SWAP gate to two qubits"; let description = [{ Applies a SWAP gate to two qubits and returns the transformed qubits. @@ -712,7 +731,8 @@ def SWAPOp } def iSWAPOp - : QCOOp<"iswap", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { + : QCOOp<"iswap", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + TwoTargetZeroParameter]> { let summary = "Apply a iSWAP gate to two qubits"; let description = [{ Applies a iSWAP gate to two qubits and returns the transformed qubits. @@ -737,8 +757,8 @@ def iSWAPOp }]; } -def DCXOp - : QCOOp<"dcx", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { +def DCXOp : QCOOp<"dcx", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + TwoTargetZeroParameter]> { let summary = "Apply a DCX gate to two qubits"; let description = [{ Applies a DCX gate to two qubits and returns the transformed qubits. @@ -765,8 +785,8 @@ def DCXOp let hasCanonicalizer = 1; } -def ECROp - : QCOOp<"ecr", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { +def ECROp : QCOOp<"ecr", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + TwoTargetZeroParameter]> { let summary = "Apply an ECR gate to two qubits"; let description = [{ Applies an ECR gate to two qubits and returns the transformed qubits. @@ -793,7 +813,8 @@ def ECROp let hasCanonicalizer = 1; } -def RXXOp : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RXXOp : QCOOp<"rxx", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + TwoTargetOneParameter]> { let summary = "Apply an RXX gate to two qubits"; let description = [{ Applies an RXX gate to two qubits and returns the transformed qubits. @@ -825,7 +846,8 @@ def RXXOp : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def RYYOp : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RYYOp : QCOOp<"ryy", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + TwoTargetOneParameter]> { let summary = "Apply an RYY gate to two qubits"; let description = [{ Applies an RYY gate to two qubits and returns the transformed qubits. @@ -857,7 +879,8 @@ def RYYOp : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def RZXOp : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RZXOp : QCOOp<"rzx", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + TwoTargetOneParameter]> { let summary = "Apply an RZX gate to two qubits"; let description = [{ Applies an RZX gate to two qubits and returns the transformed qubits. @@ -889,7 +912,8 @@ def RZXOp : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def RZZOp : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RZZOp : QCOOp<"rzz", traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + TwoTargetOneParameter]> { let summary = "Apply an RZZ gate to two qubits"; let description = [{ Applies an RZZ gate to two qubits and returns the transformed qubits. @@ -922,7 +946,8 @@ def RZZOp : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { } def XXPlusYYOp : QCOOp<"xx_plus_yy", - traits = [UnitaryOpInterface, TwoTargetTwoParameter]> { + traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + TwoTargetTwoParameter]> { let summary = "Apply an XX+YY gate to two qubits"; let description = [{ Applies an XX+YY gate to two qubits and returns the transformed qubits. @@ -957,7 +982,8 @@ def XXPlusYYOp : QCOOp<"xx_plus_yy", } def XXMinusYYOp : QCOOp<"xx_minus_yy", - traits = [UnitaryOpInterface, TwoTargetTwoParameter]> { + traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + TwoTargetTwoParameter]> { let summary = "Apply an XX-YY gate to two qubits"; let description = [{ Applies an XX-YY gate to two qubits and returns the transformed qubits. @@ -1062,9 +1088,9 @@ def YieldOp : QCOOp<"yield", traits = [Terminator, ReturnLike]> { def CtrlOp : QCOOp<"ctrl", - traits = [UnitaryOpInterface, AttrSizedOperandSegments, - AttrSizedResultSegments, SameOperandsAndResultType, - SameOperandsAndResultShape, + traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, + AttrSizedOperandSegments, AttrSizedResultSegments, + SameOperandsAndResultType, SameOperandsAndResultShape, SingleBlockImplicitTerminator<"::mlir::qco::YieldOp">, RecursiveMemoryEffects]> { let summary = "Add control qubits to a unitary operation"; @@ -1141,7 +1167,7 @@ def CtrlOp def InvOp : QCOOp<"inv", - traits = [UnitaryOpInterface, + traits = [UnitaryOpInterface, UnitaryMatrixOpInterface, SingleBlockImplicitTerminator<"::mlir::qco::YieldOp">, RecursiveMemoryEffects]> { let summary = "Invert a unitary operation"; diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.h b/mlir/include/mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.h new file mode 100644 index 0000000000..c926d171ed --- /dev/null +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +// clang-format:off +#include "mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.h.inc" // IWYU pragma: export +// clang-format:on diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.td b/mlir/include/mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.td new file mode 100644 index 0000000000..0b0aa9e311 --- /dev/null +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.td @@ -0,0 +1,139 @@ +// Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +// Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +// All rights reserved. +// +// SPDX-License-Identifier: MIT +// +// Licensed under the MIT License + +#ifndef MLIR_DIALECT_QCO_IR_QCOUNITARYMATRIXINTERFACES_TD +#define MLIR_DIALECT_QCO_IR_QCOUNITARYMATRIXINTERFACES_TD + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// UnitaryMatrixOpInterface +//===----------------------------------------------------------------------===// + +def UnitaryMatrixOpInterface : OpInterface<"UnitaryMatrixOpInterface"> { + let description = [{ + This interface provides a unified API for all operations in the QCO + dialect that can expose their unitary matrix representation. + + This interface is intentionally separate from `UnitaryOpInterface` to + avoid propagating the Eigen dependency to all unitary ops. + }]; + + let cppNamespace = "::mlir::qco"; + + // Generic implementation body for getUnitaryMatrix methods + defvar unitaryMatrixMethodBody = [{ + auto process = [&](MatrixType&& m) -> bool { + using TargetT = std::remove_cvref_t; + using SourceT = std::remove_cvref_t; + + constexpr bool isTargetDynamic = + (TargetT::SizeAtCompileTime == Eigen::Dynamic); + constexpr bool isSourceDynamic = + (SourceT::SizeAtCompileTime == Eigen::Dynamic); + + // Case 1: Target is Dynamic. Always accepts source. + if constexpr (isTargetDynamic) { + out = std::forward(m); + return true; + } + // Case 2: Target is Fixed. + else { + // Case 2a: Source is Dynamic. Runtime dimension check required. + if constexpr (isSourceDynamic) { + if (m.rows() == static_cast(TargetT::RowsAtCompileTime) && + m.cols() == static_cast(TargetT::ColsAtCompileTime)) + [[likely]] { + out = std::forward(m); + return true; + } + } + // Case 2b: Source is Fixed. Compile-time check. + else if constexpr (static_cast( + SourceT::RowsAtCompileTime) == + static_cast( + TargetT::RowsAtCompileTime) && + static_cast( + SourceT::ColsAtCompileTime) == + static_cast( + TargetT::ColsAtCompileTime)) { + out = std::forward(m); + return true; + } + } + return false; + }; + + if constexpr (requires { $_op.getUnitaryMatrix().has_value(); }) { + if (auto&& matrix = $_op.getUnitaryMatrix()) { + return process(std::move(*matrix)); + } + return false; + } else if constexpr (requires { $_op.getUnitaryMatrix(); }) { + return process($_op.getUnitaryMatrix()); + } else { + llvm::reportFatalUsageError("Operation '" + $_op.getBaseSymbol() + "' has no unitary matrix definition!"); + } + }]; + + let methods = + [InterfaceMethod<"Populates the given 1x1 unitary matrix if possible.", + "bool", "getUnitaryMatrix1x1", + (ins "Eigen::Matrix, 1, 1>&":$out), + unitaryMatrixMethodBody>, + InterfaceMethod<"Populates the given 2x2 unitary matrix if possible.", + "bool", "getUnitaryMatrix2x2", + (ins "Eigen::Matrix2cd&":$out), unitaryMatrixMethodBody>, + InterfaceMethod<"Populates the given 4x4 unitary matrix if possible.", + "bool", "getUnitaryMatrix4x4", + (ins "Eigen::Matrix4cd&":$out), unitaryMatrixMethodBody>, + InterfaceMethod<"Populates the given dynamic unitary matrix.", "bool", + "getUnitaryMatrixDynamic", + (ins "Eigen::MatrixXcd&":$out), + unitaryMatrixMethodBody>]; + + let extraClassDeclaration = [{ + template + std::optional getUnitaryMatrix() { + MatrixType out; + bool result = false; + + // Dispatch to the appropriate fixed-size or dynamic method based on the + // matrix type. + if constexpr (MatrixType::RowsAtCompileTime == 1 && + MatrixType::ColsAtCompileTime == 1) { + result = this->getUnitaryMatrix1x1(out); + } else if constexpr (MatrixType::RowsAtCompileTime == 2 && + MatrixType::ColsAtCompileTime == 2) { + result = this->getUnitaryMatrix2x2(out); + } else if constexpr (MatrixType::RowsAtCompileTime == 4 && + MatrixType::ColsAtCompileTime == 4) { + result = this->getUnitaryMatrix4x4(out); + } else if constexpr (MatrixType::SizeAtCompileTime == Eigen::Dynamic) { + result = this->getUnitaryMatrixDynamic(out); + } else { + // Fallback: Try obtaining dynamic matrix and see if size matches + Eigen::MatrixXcd dynamicOut; + if (this->getUnitaryMatrixDynamic(dynamicOut)) { + if (dynamicOut.rows() == MatrixType::RowsAtCompileTime && + dynamicOut.cols() == MatrixType::ColsAtCompileTime) { + out = dynamicOut; + result = true; + } + } + } + + if (result) { + return out; + } + return std::nullopt; + } + }]; +} + +#endif // MLIR_DIALECT_QCO_IR_QCOUNITARYMATRIXINTERFACES_TD diff --git a/mlir/include/mlir/Dialect/QCO/Transforms/Decomposition/Euler.h b/mlir/include/mlir/Dialect/QCO/Transforms/Decomposition/Euler.h new file mode 100644 index 0000000000..dd161514ed --- /dev/null +++ b/mlir/include/mlir/Dialect/QCO/Transforms/Decomposition/Euler.h @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace mlir::qco::decomposition { + +/** + * @brief Euler angles `(theta, phi, lambda)` and global phase for a 2x2 + * unitary. + */ +struct EulerAngles { + double theta = 0.0; + double phi = 0.0; + double lambda = 0.0; + double phase = 0.0; +}; + +/** + * @brief Native gate sets for single-qubit Euler synthesis. + */ +enum class EulerBasis : std::uint8_t { + ZYZ = 0, ///< `RZ(phi) * RY(theta) * RZ(lambda)`. + ZXZ = 1, ///< `RZ(phi) * RX(theta) * RZ(lambda)`. + XZX = 2, ///< `RX(phi) * RZ(theta) * RX(lambda)`. + XYX = 3, ///< `RX(phi) * RY(theta) * RX(lambda)`. + U = 4, ///< `U(theta, phi, lambda)`. + ZSXX = 5, ///< `RZ` / `SX` / `X` chain equivalent to ZYZ. +}; + +/** + * @brief Extracts Euler parameters from single-qubit unitary matrices. + */ +class EulerDecomposition { + friend Value synthesizeUnitary1QEuler(OpBuilder& builder, Location loc, + Value qubit, + const Eigen::Matrix2cd& targetMatrix, + EulerBasis basis); + +public: + /** + * @brief Extracts `(theta, phi, lambda, phase)` for the requested basis. + * + * @param matrix The single-qubit unitary to decompose. + * @param basis The target Euler basis. + * @return The extracted Euler angles and global phase. + */ + [[nodiscard]] static EulerAngles + anglesFromUnitary(const Eigen::Matrix2cd& matrix, EulerBasis basis); + +private: + /** + * @brief Extracts parameters for `RZ(phi) * RY(theta) * RZ(lambda)`. + * + * @param matrix The single-qubit unitary to decompose. + * @return The extracted Euler angles and global phase. + */ + [[nodiscard]] static EulerAngles paramsZYZ(const Eigen::Matrix2cd& matrix); + + /** + * @brief Extracts parameters for `U(theta, phi, lambda)`. + * + * @param matrix The single-qubit unitary to decompose. + * @return The extracted Euler angles and global phase. + */ + [[nodiscard]] static EulerAngles paramsU(const Eigen::Matrix2cd& matrix); + + /** + * @brief Extracts parameters for `RZ(phi) * RX(theta) * RZ(lambda)`. + * + * @param matrix The single-qubit unitary to decompose. + * @return The extracted Euler angles and global phase. + */ + [[nodiscard]] static EulerAngles paramsZXZ(const Eigen::Matrix2cd& matrix); + + /** + * @brief Extracts parameters for `RX(phi) * RY(theta) * RX(lambda)`. + * + * @param matrix The single-qubit unitary to decompose. + * @return The extracted Euler angles and global phase. + */ + [[nodiscard]] static EulerAngles paramsXYX(const Eigen::Matrix2cd& matrix); + + /** + * @brief Extracts parameters for `RX(phi) * RZ(theta) * RX(lambda)`. + * + * @param matrix The single-qubit unitary to decompose. + * @return The extracted Euler angles and global phase. + */ + [[nodiscard]] static EulerAngles paramsXZX(const Eigen::Matrix2cd& matrix); +}; + +/** + * @brief Parses a basis name (e.g. `zyz`, `zsxx`; case-insensitive). + * + * @param basis The basis name. + * @return The parsed basis, or `std::nullopt` if unrecognized. + */ +[[nodiscard]] std::optional parseEulerBasis(StringRef basis); + +/** + * @brief Synthesizes `targetMatrix` as gates in `basis`. + * + * Emits `qco.gphase` when needed so the result matches exactly, not only up to + * global phase. + * + * @param builder Builder for the emitted operations. + * @param loc Location for the emitted operations. + * @param qubit Input qubit value. + * @param targetMatrix The single-qubit unitary to synthesize. + * @param basis The target Euler basis. + * @return The output qubit value. + */ +[[nodiscard]] Value +synthesizeUnitary1QEuler(OpBuilder& builder, Location loc, Value qubit, + const Eigen::Matrix2cd& targetMatrix, + EulerBasis basis); + +/** + * @brief Number of basis gates `synthesizeUnitary1QEuler` would emit. + * + * Excludes `qco.gphase`. Used by the fuse pass to detect overlong in-basis + * runs. + * + * @param targetMatrix The single-qubit unitary that would be synthesized. + * @param basis The target Euler basis. + * @return The gate count (1 for `U`, 3 for KAK bases, 3 or 5 for `ZSXX`). + */ +[[nodiscard]] std::size_t +synthesisGateCount(const Eigen::Matrix2cd& targetMatrix, EulerBasis basis); + +} // namespace mlir::qco::decomposition diff --git a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td index 32f678924e..957836c749 100644 --- a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td @@ -40,6 +40,28 @@ def MergeSingleQubitRotationGates }]; } +def FuseSingleQubitUnitaryRuns + : Pass<"fuse-single-qubit-unitary-runs", "mlir::ModuleOp"> { + let dependentDialects = ["mlir::qco::QCODialect", + "::mlir::arith::ArithDialect", + "::mlir::qtensor::QTensorDialect"]; + let summary = "Fuse single-qubit unitary runs using Euler resynthesis"; + let description = [{ + Matches maximal runs of consecutive single-qubit unitary operations on the + same qubit wire (anchored at each run head), composes their constant unitary + matrices, and replaces each run with an equivalent sequence of basis gates. + + The emitted basis is controlled via the `basis` option (e.g. `zyz`, `zsxx`). + A `gphase` correction is inserted when needed so the rewritten sequence + matches the composed matrix exactly (not only up to global phase). + + Currently, only operations whose unitary matrix can be obtained at compile + time are fused. + }]; + let options = [Option<"basis", "basis", "std::string", "\"zyz\"", + "Target Euler basis (zyz, zxz, xzx, xyx, u, zsxx).">]; +} + def QuantumLoopUnroll : InterfacePass<"quantum-loop-unroll", "FunctionOpInterface"> { let dependentDialects = ["mlir::qco::QCODialect", "mlir::scf::SCFDialect"]; diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 3d976a5a63..8adb0643a0 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -19,7 +19,9 @@ namespace mlir::utils { -constexpr auto TOLERANCE = 1e-15; +/// Default absolute tolerance for MLIR dialect numerics (matrix checks, +/// angles). +constexpr auto TOLERANCE = 1e-14; inline Value constantFromScalar(OpBuilder& builder, Location loc, double v) { return arith::ConstantOp::create(builder, loc, builder.getF64FloatAttr(v)); diff --git a/mlir/lib/Dialect/QCO/IR/CMakeLists.txt b/mlir/lib/Dialect/QCO/IR/CMakeLists.txt index f33ece7000..6f7b298d10 100644 --- a/mlir/lib/Dialect/QCO/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/QCO/IR/CMakeLists.txt @@ -14,6 +14,7 @@ file(GLOB_RECURSE SCF "${CMAKE_CURRENT_SOURCE_DIR}/SCF/*.cpp") add_mlir_dialect_library( MLIRQCODialect QCOOps.cpp + QCOUnitaryMatrixInterfaces.cpp ${MODIFIERS} ${OPERATIONS} ${QUBIT_MANAGEMENT} @@ -23,6 +24,7 @@ add_mlir_dialect_library( DEPENDS MLIRQCOOpsIncGen MLIRQCOInterfacesIncGen + MLIRQCOUnitaryMatrixInterfacesIncGen LINK_LIBS PRIVATE MLIRIR diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 25fc88d084..9a3a70d9f8 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.h" #include #include @@ -345,7 +346,12 @@ std::optional CtrlOp::getUnitaryMatrix() { if (!bodyUnitary) { return std::nullopt; } - auto&& targetMatrix = bodyUnitary.getUnitaryMatrix(); + auto matrixOp = + dyn_cast(bodyUnitary.getOperation()); + if (!matrixOp) { + return std::nullopt; + } + auto targetMatrix = matrixOp.getUnitaryMatrix(); if (!targetMatrix) { return std::nullopt; } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index d82a64f819..975dd58c24 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.h" #include #include @@ -406,7 +407,12 @@ std::optional InvOp::getUnitaryMatrix() { if (!bodyUnitary) { return std::nullopt; } - auto&& targetMatrix = bodyUnitary.getUnitaryMatrix(); + auto matrixOp = + dyn_cast(bodyUnitary.getOperation()); + if (!matrixOp) { + return std::nullopt; + } + auto targetMatrix = matrixOp.getUnitaryMatrix(); if (!targetMatrix) { return std::nullopt; } diff --git a/mlir/lib/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.cpp b/mlir/lib/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.cpp new file mode 100644 index 0000000000..d07d931e7c --- /dev/null +++ b/mlir/lib/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.cpp @@ -0,0 +1,13 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.h" + +#include "mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.cpp.inc" diff --git a/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt b/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt index 3564b55dc5..45708e3a06 100644 --- a/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt @@ -19,6 +19,7 @@ add_mlir_library( MLIRArithDialect MLIRMathDialect MLIRSCFUtils + Eigen3::Eigen DEPENDS MLIRQCOTransformsIncGen) diff --git a/mlir/lib/Dialect/QCO/Transforms/Decomposition/Euler.cpp b/mlir/lib/Dialect/QCO/Transforms/Decomposition/Euler.cpp new file mode 100644 index 0000000000..08f4579b5e --- /dev/null +++ b/mlir/lib/Dialect/QCO/Transforms/Decomposition/Euler.cpp @@ -0,0 +1,497 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QCO/Transforms/Decomposition/Euler.h" + +#include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/Utils/Utils.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace mlir::qco::decomposition { + +/** + * @brief Wraps `angle` into `[-pi, pi)`, mapping `+pi` (within `atol`) to + * `-pi`. + * + * @param angle The angle to wrap, in radians. + * @param atol Tolerance for snapping `+pi` to `-pi`. + * @return The wrapped angle in `[-pi, pi)`. + */ +[[nodiscard]] static double mod2pi(double angle, + double atol = mlir::utils::TOLERANCE) { + if (!std::isfinite(angle)) { + return angle; + } + + constexpr double pi = std::numbers::pi; + constexpr double twoPi = 2.0 * std::numbers::pi; + + double r = std::fmod(angle + pi, twoPi); + if (r < 0.0) { + r += twoPi; + } + double wrapped = r - pi; + + if (wrapped >= pi - atol) { + wrapped = -pi; + } + + return wrapped; +} + +/** + * @brief Conjugates a single-qubit matrix by Hadamard (`H * m * H`). + * + * Maps XYX / XZX parameterizations to ZYZ / ZXZ. + * + * @param m The single-qubit matrix to conjugate. + * @return `H * m * H`. + */ +[[nodiscard]] static Eigen::Matrix2cd +hadamardConjugate(const Eigen::Matrix2cd& m) { + const auto a = m(0, 0); + const auto b = m(0, 1); + const auto c = m(1, 0); + const auto d = m(1, 1); + return Eigen::Matrix2cd{{0.5 * (a + b + c + d), 0.5 * (a - b + c - d)}, + {0.5 * (a + b - c - d), 0.5 * (a - b - c + d)}}; +} + +/** + * @brief Emits `qco.gphase` when `phase` is outside tolerance. + * + * @param builder Builder for the operation. + * @param loc Location of the operation. + * @param phase Global phase in radians. + */ +static void emitGPhaseIfNeeded(OpBuilder& builder, Location loc, double phase) { + if (std::abs(phase) <= mlir::utils::TOLERANCE) { + return; + } + GPhaseOp::create(builder, loc, phase); +} + +namespace { + +/** + * @brief Planned PSX (`RZ` / `SX` / `X`) chain; angles in circuit order. + */ +struct PSXSequence { + enum class Middle : std::uint8_t { OneSX, X, SXRZSX }; + Middle middle = Middle::SXRZSX; + double firstRZ = 0.0; + double midRZ = 0.0; + double lastRZ = 0.0; +}; + +} // namespace + +/** + * @brief Classifies the PSX middle-gate case from ZYZ `theta`. + * + * For `theta` in `[0, pi]`: `pi/2` → one `SX`, `pi` → `X`, otherwise + * `SX*RZ*SX`. + * + * @param theta Y-rotation angle from `paramsZYZ`. + * @return The PSX middle-gate case. + */ +[[nodiscard]] static PSXSequence::Middle +classifyPSXMiddleFromZYZTheta(double theta) { + constexpr double eps = mlir::utils::TOLERANCE; + constexpr double halfPi = std::numbers::pi / 2.0; + constexpr double pi = std::numbers::pi; + + if (std::abs(theta - halfPi) < eps) { + return PSXSequence::Middle::OneSX; + } + if (std::abs(theta - pi) < eps) { + return PSXSequence::Middle::X; + } + return PSXSequence::Middle::SXRZSX; +} + +/** + * @brief Builds the PSX sequence for `RZ(phi)*RY(theta)*RZ(lambda)`. + * + * Uses `SX*RZ(theta+pi)*SX = Z*RY(theta)`. + * + * @param theta Y-rotation angle in `[0, pi]`. + * @param phi Trailing Z-rotation angle. + * @param lambda Leading Z-rotation angle. + * @return The planned PSX sequence. + */ +[[nodiscard]] static PSXSequence sequenceFromZYZForPSX(double theta, double phi, + double lambda) { + constexpr double halfPi = std::numbers::pi / 2.0; + constexpr double pi = std::numbers::pi; + + switch (classifyPSXMiddleFromZYZTheta(theta)) { + case PSXSequence::Middle::OneSX: + return {.middle = PSXSequence::Middle::OneSX, + .firstRZ = lambda - halfPi, + .midRZ = 0.0, + .lastRZ = phi + halfPi}; + case PSXSequence::Middle::X: + return {.middle = PSXSequence::Middle::X, + .firstRZ = lambda, + .midRZ = 0.0, + .lastRZ = phi + pi}; + case PSXSequence::Middle::SXRZSX: + return {.middle = PSXSequence::Middle::SXRZSX, + .firstRZ = lambda, + .midRZ = theta + pi, + .lastRZ = phi + pi}; + } + llvm::reportFatalInternalError("Unhandled PSX middle gate"); +} + +/** + * @brief Global phase offset of `UOp` vs `RZ(phi)*RY(theta)*RZ(lambda)`. + * + * @param phi Trailing Z-rotation angle. + * @param lambda Leading Z-rotation angle. + * @return The global-phase offset in radians. + */ +[[nodiscard]] static double globalPhaseOffsetForU(double phi, double lambda) { + return -0.5 * (phi + lambda); +} + +/** + * @brief Global phase from wrapping an RZ angle with `mod2pi`. + * + * `RZ(angle + 2*pi) = -RZ(angle)`, so `RZ(mod2pi(angle))` differs from + * `RZ(angle)` by `exp(i*(mod2pi(angle) - angle)/2)`. + * + * @param angle The unwrapped RZ angle. + * @return The global-phase contribution in radians. + */ +[[nodiscard]] static double globalPhaseFromRZWrap(double angle) { + constexpr double eps = mlir::utils::TOLERANCE; + return 0.5 * (mod2pi(angle, eps) - angle); +} + +/** + * @brief Global phase offset of the PSX chain vs the ZYZ product. + * + * @param seq The planned PSX sequence. + * @return The global-phase offset in radians. + */ +[[nodiscard]] static double globalPhaseOffsetForPSX(const PSXSequence& seq) { + constexpr double halfPi = std::numbers::pi / 2.0; + constexpr double quarterPi = std::numbers::pi / 4.0; + + switch (seq.middle) { + case PSXSequence::Middle::OneSX: + // `SX = exp(i*pi/4)*RZ(-pi/2)*RY(pi/2)*RZ(pi/2)`; the outer RZ angles + // absorb the +-pi/2, leaving the exp(i*pi/4) phase. RZ wraps add too. + return -quarterPi + globalPhaseFromRZWrap(seq.firstRZ) + + globalPhaseFromRZWrap(seq.lastRZ); + case PSXSequence::Middle::X: + // `X` swaps the diagonal, so the wraps enter with opposite signs. + return -halfPi + globalPhaseFromRZWrap(seq.lastRZ) - + globalPhaseFromRZWrap(seq.firstRZ); + case PSXSequence::Middle::SXRZSX: + // `SX*RZ(theta+pi)*SX = Z*RY(theta)`; all three RZ wraps add. + return halfPi + globalPhaseFromRZWrap(seq.firstRZ) + + globalPhaseFromRZWrap(seq.midRZ) + globalPhaseFromRZWrap(seq.lastRZ); + } + llvm::reportFatalInternalError("Unhandled PSX middle gate"); +} + +/** + * @brief Invokes callbacks for each gate of `seq` in circuit order. + * + * @param seq The planned PSX sequence. + * @param onRZ Called with each RZ angle. + * @param onSX Called for each SX gate. + * @param onX Called for each X gate. + */ +static void visitSequenceInTimeOrder(const PSXSequence& seq, + llvm::function_ref onRZ, + llvm::function_ref onSX, + llvm::function_ref onX) { + onRZ(seq.firstRZ); + switch (seq.middle) { + case PSXSequence::Middle::OneSX: + onSX(); + onRZ(seq.lastRZ); + break; + case PSXSequence::Middle::X: + onX(); + onRZ(seq.lastRZ); + break; + case PSXSequence::Middle::SXRZSX: + onSX(); + onRZ(seq.midRZ); + onSX(); + onRZ(seq.lastRZ); + break; + } +} + +/** + * @brief Emits the gates of `seq` and optional `gphase`. + * + * @param builder Builder for the operations. + * @param loc Location of the operations. + * @param qubit Input qubit value. + * @param seq The planned PSX sequence. + * @param phase Global phase in radians. + * @return The output qubit value. + */ +[[nodiscard]] static Value emitFromPSXSequence(OpBuilder& builder, Location loc, + Value qubit, + const PSXSequence& seq, + double phase) { + constexpr double eps = mlir::utils::TOLERANCE; + visitSequenceInTimeOrder( + seq, + [&](const double angle) { + qubit = + RZOp::create(builder, loc, qubit, mod2pi(angle, eps)).getQubitOut(); + }, + [&] { qubit = SXOp::create(builder, loc, qubit).getQubitOut(); }, + [&] { qubit = XOp::create(builder, loc, qubit).getQubitOut(); }); + emitGPhaseIfNeeded(builder, loc, phase); + return qubit; +} + +/** + * @brief Emits a K-A-K rotation triple and optional `gphase` for `basis`. + * + * @param builder Builder for the operations. + * @param loc Location of the operations. + * @param qubit Input qubit value. + * @param theta Middle (A) rotation angle. + * @param phi Trailing (K) rotation angle. + * @param lambda Leading (K) rotation angle. + * @param phase Global phase in radians. + * @param basis Euler basis selecting the rotation axes. + * @return The output qubit value. + */ +static Value emitKAK(OpBuilder& builder, Location loc, Value qubit, + double theta, double phi, double lambda, double phase, + EulerBasis basis) { + auto emitK = [&](double a) { + switch (basis) { + case EulerBasis::ZYZ: + case EulerBasis::ZXZ: + qubit = RZOp::create(builder, loc, qubit, a).getQubitOut(); + break; + case EulerBasis::XZX: + case EulerBasis::XYX: + qubit = RXOp::create(builder, loc, qubit, a).getQubitOut(); + break; + default: + llvm::reportFatalInternalError("Invalid K gate for KAK emission"); + } + }; + + auto emitA = [&](double a) { + switch (basis) { + case EulerBasis::ZYZ: + case EulerBasis::XYX: + qubit = RYOp::create(builder, loc, qubit, a).getQubitOut(); + break; + case EulerBasis::ZXZ: + qubit = RXOp::create(builder, loc, qubit, a).getQubitOut(); + break; + case EulerBasis::XZX: + qubit = RZOp::create(builder, loc, qubit, a).getQubitOut(); + break; + default: + llvm::reportFatalInternalError("Invalid A gate for KAK emission"); + } + }; + + emitK(lambda); + emitA(theta); + emitK(phi); + emitGPhaseIfNeeded(builder, loc, phase); + return qubit; +} + +//===----------------------------------------------------------------------===// +// Euler decomposition (angles) +//===----------------------------------------------------------------------===// + +EulerAngles +EulerDecomposition::anglesFromUnitary(const Eigen::Matrix2cd& matrix, + EulerBasis basis) { + switch (basis) { + case EulerBasis::XYX: + return paramsXYX(matrix); + case EulerBasis::XZX: + return paramsXZX(matrix); + case EulerBasis::ZYZ: + return paramsZYZ(matrix); + case EulerBasis::ZXZ: + return paramsZXZ(matrix); + case EulerBasis::U: + return paramsU(matrix); + case EulerBasis::ZSXX: { + const auto zyz = paramsZYZ(matrix); + const auto seq = sequenceFromZYZForPSX(zyz.theta, zyz.phi, zyz.lambda); + return {.theta = zyz.theta, + .phi = zyz.phi, + .lambda = zyz.lambda, + .phase = zyz.phase + globalPhaseOffsetForPSX(seq)}; + } + } + llvm::reportFatalInternalError( + "Unsupported Euler basis for angle computation in decomposition!"); +} + +EulerAngles EulerDecomposition::paramsZYZ(const Eigen::Matrix2cd& matrix) { + // det(U) = exp(2i*phase); invert the Z-Y-Z parameterization of U's entries. + const std::complex det = + matrix(0, 0) * matrix(1, 1) - matrix(0, 1) * matrix(1, 0); + const auto detArg = std::arg(det); + const auto phase = 0.5 * detArg; + const auto theta = + 2. * std::atan2(std::abs(matrix(1, 0)), std::abs(matrix(0, 0))); + const auto ang1 = std::arg(matrix(1, 1)); + const auto ang2 = std::arg(matrix(1, 0)); + const auto phi = ang1 + ang2 - detArg; + const auto lambda = ang1 - ang2; + return {.theta = theta, .phi = phi, .lambda = lambda, .phase = phase}; +} + +EulerAngles EulerDecomposition::paramsZXZ(const Eigen::Matrix2cd& matrix) { + // ZXZ from ZYZ via RY(theta) = RZ(pi/2)*RX(theta)*RZ(-pi/2). + const auto zyz = paramsZYZ(matrix); + return {.theta = zyz.theta, + .phi = zyz.phi + (std::numbers::pi / 2.0), + .lambda = zyz.lambda - (std::numbers::pi / 2.0), + .phase = zyz.phase}; +} + +EulerAngles EulerDecomposition::paramsXYX(const Eigen::Matrix2cd& matrix) { + // H*RY(theta)*H = RY(-theta): shift outer angles by pi and fix global phase. + const auto zyz = paramsZYZ(hadamardConjugate(matrix)); + const auto newPhi = mod2pi(zyz.phi + std::numbers::pi, 0.); + const auto newLambda = mod2pi(zyz.lambda + std::numbers::pi, 0.); + return {.theta = zyz.theta, + .phi = newPhi, + .lambda = newLambda, + .phase = + zyz.phase + ((newPhi + newLambda - zyz.phi - zyz.lambda) / 2.)}; +} + +EulerAngles EulerDecomposition::paramsXZX(const Eigen::Matrix2cd& matrix) { + // X-Z-X -> Z-X-Z under H conjugation (no Y sign flip, unlike paramsXYX). + return paramsZXZ(hadamardConjugate(matrix)); +} + +EulerAngles EulerDecomposition::paramsU(const Eigen::Matrix2cd& matrix) { + const auto zyz = paramsZYZ(matrix); + return {.theta = zyz.theta, + .phi = zyz.phi, + .lambda = zyz.lambda, + .phase = zyz.phase + globalPhaseOffsetForU(zyz.phi, zyz.lambda)}; +} + +//===----------------------------------------------------------------------===// +// Euler synthesis (IR emission) +//===----------------------------------------------------------------------===// + +std::optional parseEulerBasis(StringRef basis) { + if (basis.equals_insensitive("zyz")) { + return EulerBasis::ZYZ; + } + if (basis.equals_insensitive("zxz")) { + return EulerBasis::ZXZ; + } + if (basis.equals_insensitive("xzx")) { + return EulerBasis::XZX; + } + if (basis.equals_insensitive("xyx")) { + return EulerBasis::XYX; + } + if (basis.equals_insensitive("u")) { + return EulerBasis::U; + } + if (basis.equals_insensitive("zsxx")) { + return EulerBasis::ZSXX; + } + return std::nullopt; +} + +Value synthesizeUnitary1QEuler(OpBuilder& builder, Location loc, Value qubit, + const Eigen::Matrix2cd& targetMatrix, + EulerBasis basis) { + if (basis == EulerBasis::ZSXX) { + const auto zyz = EulerDecomposition::paramsZYZ(targetMatrix); + const auto seq = sequenceFromZYZForPSX(zyz.theta, zyz.phi, zyz.lambda); + return emitFromPSXSequence(builder, loc, qubit, seq, + zyz.phase + globalPhaseOffsetForPSX(seq)); + } + + const auto angles = + EulerDecomposition::anglesFromUnitary(targetMatrix, basis); + + switch (basis) { + case EulerBasis::ZYZ: + case EulerBasis::ZXZ: + case EulerBasis::XZX: + case EulerBasis::XYX: + qubit = emitKAK(builder, loc, qubit, angles.theta, angles.phi, + angles.lambda, angles.phase, basis); + break; + case EulerBasis::U: + qubit = UOp::create(builder, loc, qubit, angles.theta, angles.phi, + angles.lambda) + .getQubitOut(); + emitGPhaseIfNeeded(builder, loc, angles.phase); + break; + case EulerBasis::ZSXX: + llvm_unreachable("ZSXX handled above"); + } + + return qubit; +} + +std::size_t synthesisGateCount(const Eigen::Matrix2cd& targetMatrix, + EulerBasis basis) { + switch (basis) { + case EulerBasis::U: + return 1; + case EulerBasis::ZYZ: + case EulerBasis::ZXZ: + case EulerBasis::XZX: + case EulerBasis::XYX: + // emitKAK always emits the full K-A-K rotation triple. + return 3; + case EulerBasis::ZSXX: { + const double theta = 2. * std::atan2(std::abs(targetMatrix(1, 0)), + std::abs(targetMatrix(0, 0))); + return classifyPSXMiddleFromZYZTheta(theta) == PSXSequence::Middle::SXRZSX + ? 5U + : 3U; + } + } + llvm::reportFatalInternalError("Unhandled Euler basis in synthesisGateCount"); +} + +} // namespace mlir::qco::decomposition diff --git a/mlir/lib/Dialect/QCO/Transforms/NativeSynthesis/FuseSingleQubitUnitaryRuns.cpp b/mlir/lib/Dialect/QCO/Transforms/NativeSynthesis/FuseSingleQubitUnitaryRuns.cpp new file mode 100644 index 0000000000..a176e8b1b7 --- /dev/null +++ b/mlir/lib/Dialect/QCO/Transforms/NativeSynthesis/FuseSingleQubitUnitaryRuns.cpp @@ -0,0 +1,266 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QCO/IR/QCOInterfaces.h" +#include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.h" +#include "mlir/Dialect/QCO/Transforms/Decomposition/Euler.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" +#include "mlir/Dialect/QCO/Utils/WireIterator.h" + +#include +#include +#include +#include +#include // IWYU pragma: keep (Passes.h.inc) +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace mlir::qco { + +#define GEN_PASS_DEF_FUSESINGLEQUBITUNITARYRUNS +#include "mlir/Dialect/QCO/Transforms/Passes.h.inc" + +/** + * @brief Whether `op` is inside an `inv`/`ctrl` body. + * + * The modifier's combined unitary is fused as one run member; gates inside its + * body are not separate run members. + * + * @param op The operation to test. + * @return `true` if the parent is `inv` or `ctrl`. + */ +static bool isNestedInModifierRegion(Operation* op) { + Operation* parent = op->getParentOp(); + return parent != nullptr && isa(parent); +} + +/** + * @brief Whether `op` may participate in a fusable single-qubit run. + * + * @param op The unitary operation to test. + * @return `true` for a single-qubit, matrix-backed unitary on the wire, outside + * a modifier body. + */ +static bool isFuseCandidate(UnitaryOpInterface op) { + if (!op || !op.isSingleQubit() || isNestedInModifierRegion(op)) { + return false; + } + return isa(op.getOperation()); +} + +/** + * @brief Returns the compile-time 2x2 unitary matrix of `op`, if available. + * + * @param op The unitary operation to query. + * @return The matrix, or `std::nullopt` if not known at compile time. + */ +static std::optional getConstMatrix(UnitaryOpInterface op) { + auto matrixOp = dyn_cast(op.getOperation()); + if (!matrixOp) { + return std::nullopt; + } + Eigen::Matrix2cd m; + if (!matrixOp.getUnitaryMatrix2x2(m)) { + return std::nullopt; + } + return m; +} + +/** + * @brief Whether `op` can participate in a fusable run. + * + * @param op The operation to test. + * @return `true` for a fuse candidate with a known compile-time matrix. + */ +static bool isRunMember(Operation* op) { + auto iface = dyn_cast(op); + return iface && isFuseCandidate(iface) && getConstMatrix(iface).has_value(); +} + +/** + * @brief Composes a run of unitary ops into a single matrix. + * + * @param run The run members in circuit order. + * @return The product of their matrices. + */ +static Eigen::Matrix2cd composeRun(ArrayRef run) { + Eigen::Matrix2cd composed = Eigen::Matrix2cd::Identity(); + for (auto op : run) { + // First gate in the run is applied first (left factor). + composed = (*getConstMatrix(op)) * composed; + } + return composed; +} + +/** + * @brief Whether `op` is a gate the target `basis` emits. + * + * Gate sets match `emitKAK` and `emitFromPSXSequence` in `Euler.cpp`. Used to + * skip runs that are already in the target basis at canonical length. + * + * @param op The operation to classify. + * @param basis The target Euler basis. + * @return `true` if `op` is emitted by synthesis in `basis`. + */ +static bool isTargetBasisGate(Operation* op, decomposition::EulerBasis basis) { + using decomposition::EulerBasis; + return TypeSwitch(op) + .Case([&](auto) { + return basis == EulerBasis::ZYZ || basis == EulerBasis::ZXZ || + basis == EulerBasis::XZX || basis == EulerBasis::ZSXX; + }) + .Case([&](auto) { + return basis == EulerBasis::ZYZ || basis == EulerBasis::XYX; + }) + .Case([&](auto) { + return basis == EulerBasis::ZXZ || basis == EulerBasis::XZX || + basis == EulerBasis::XYX; + }) + .Case([&](auto) { return basis == EulerBasis::U; }) + .Case([&](auto) { return basis == EulerBasis::ZSXX; }) + .Default([](auto) { return false; }); +} + +namespace { + +/** + * @brief Fuses maximal single-qubit unitary runs via Euler resynthesis. + * + * Matches at each run head so each run is rewritten once. + */ +struct FuseSingleQubitUnitaryRunsPattern final + : OpInterfaceRewritePattern { + FuseSingleQubitUnitaryRunsPattern(MLIRContext* context, + decomposition::EulerBasis basis) + : OpInterfaceRewritePattern(context), basis(basis) {} + + decomposition::EulerBasis basis; + + /** + * @brief Whether `op` is the head of a run. + * + * @param op The candidate run head. + * @return `true` if the wire predecessor is not a run member. + */ + static bool isRunStart(UnitaryOpInterface op) { + if (!isRunMember(op.getOperation())) { + return false; + } + Operation* pred = op.getInputTarget(0).getDefiningOp(); + return pred == nullptr || !isRunMember(pred); + } + + /** + * @brief Collects the maximal fusable run starting at `start`. + * + * @param start The run head. + * @return The run members in circuit order. + */ + static SmallVector collectRun(UnitaryOpInterface start) { + SmallVector run{start}; + Block* block = start->getBlock(); + for (WireIterator it = std::next(WireIterator(start.getOutputTarget(0))); + it != std::default_sentinel; ++it) { + Operation* op = it.operation(); + if (op->getBlock() != block || !isRunMember(op)) { + break; + } + run.emplace_back(cast(op)); + } + return run; + } + + /** + * @brief Fuses the run anchored at `op` when beneficial. + * + * Fuses if the run contains a non-basis gate or is longer than the canonical + * synthesis for its composed matrix. + * + * @param op The matched unitary operation. + * @param rewriter The pattern rewriter. + * @return `success()` if a run was fused, `failure()` otherwise. + */ + LogicalResult matchAndRewrite(UnitaryOpInterface op, + PatternRewriter& rewriter) const override { + if (!isRunStart(op)) { + return failure(); + } + + auto run = collectRun(op); + const Eigen::Matrix2cd composed = composeRun(run); + const bool hasNonBasisGate = + llvm::any_of(run, [&](UnitaryOpInterface member) { + return !isTargetBasisGate(member.getOperation(), basis); + }); + if (!hasNonBasisGate && + run.size() <= decomposition::synthesisGateCount(composed, basis)) { + return failure(); + } + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op.getOperation()); + const Value qubit = decomposition::synthesizeUnitary1QEuler( + rewriter, op.getLoc(), op.getInputTarget(0), composed, basis); + + rewriter.replaceAllUsesWith(run.back().getOutputTarget(0), qubit); + for (UnitaryOpInterface member : std::ranges::reverse_view(run)) { + rewriter.eraseOp(member.getOperation()); + } + return success(); + } +}; + +/** + * @brief Pass that fuses single-qubit unitary runs via Euler resynthesis. + */ +struct FuseSingleQubitUnitaryRunsPass final + : impl::FuseSingleQubitUnitaryRunsBase { + using Base::Base; + + explicit FuseSingleQubitUnitaryRunsPass( + FuseSingleQubitUnitaryRunsOptions options) + : Base(std::move(options)) {} + +protected: + void runOnOperation() override { + auto module = getOperation(); + + const auto parsed = decomposition::parseEulerBasis(this->basis); + if (!parsed) { + module.emitError() << "Invalid Euler basis '" << this->basis + << "'. Expected one of: zyz, zxz, xzx, xyx, u, zsxx."; + signalPassFailure(); + return; + } + + RewritePatternSet patterns(&getContext()); + patterns.add(patterns.getContext(), + *parsed); + + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace mlir::qco diff --git a/mlir/lib/Dialect/QCO/Utils/CMakeLists.txt b/mlir/lib/Dialect/QCO/Utils/CMakeLists.txt index 9b94885f1b..26a14e17e4 100644 --- a/mlir/lib/Dialect/QCO/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/QCO/Utils/CMakeLists.txt @@ -18,7 +18,8 @@ add_mlir_dialect_library( MLIRQCOInterfacesIncGen LINK_LIBS PUBLIC - MLIRQCODialect) + MLIRQCODialect + MLIRSCFDialect) mqt_mlir_target_use_project_options(MLIRQCOUtils) diff --git a/mlir/lib/Dialect/QCO/Utils/WireIterator.cpp b/mlir/lib/Dialect/QCO/Utils/WireIterator.cpp index 45522cd8bf..1e489d2017 100644 --- a/mlir/lib/Dialect/QCO/Utils/WireIterator.cpp +++ b/mlir/lib/Dialect/QCO/Utils/WireIterator.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -25,8 +26,10 @@ namespace mlir::qco { Value WireIterator::qubit() const { - // A sink/deallocation/insert doesn't have an OpResult. - if (op_ != nullptr && (isa(op_))) { + // Boundary ops (sink/deallocation/insert/yield) consume the wire via an + // operand and have no OpResult, matching the boundaries in forward/backward. + if (op_ != nullptr && + (isa(op_))) { return nullptr; } return qubit_; @@ -42,8 +45,8 @@ void WireIterator::forward() { assert(qubit_.hasOneUse() && "expected linear typing"); op_ = *(qubit_.user_begin()); - // A sink/insert defines the end of the qubit wire (dynamic and static). - if (isa(op_)) { + // A sink/insert/yield defines the end of the qubit wire (dynamic and static). + if (isa(op_)) { isSentinel_ = true; return; } @@ -70,9 +73,9 @@ void WireIterator::backward() { return; } - // For sinks/deallocations/inserts, qubit_ is an OpOperand. Hence, only get - // the def-op. - if (isa(op_)) { + // For sinks/deallocations/inserts/yields, qubit_ is an OpOperand. Hence, only + // get the def-op. + if (isa(op_)) { op_ = qubit_.getDefiningOp(); return; } diff --git a/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt index 9f9b03449d..d59780f461 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt +++ b/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt @@ -6,5 +6,6 @@ # # Licensed under the MIT License +add_subdirectory(Decomposition) add_subdirectory(Mapping) add_subdirectory(Optimizations) diff --git a/mlir/unittests/Dialect/QCO/Transforms/Decomposition/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/Decomposition/CMakeLists.txt new file mode 100644 index 0000000000..11a49eb044 --- /dev/null +++ b/mlir/unittests/Dialect/QCO/Transforms/Decomposition/CMakeLists.txt @@ -0,0 +1,19 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(target_name mqt-core-mlir-unittest-decomposition) +add_executable(${target_name} test_euler_decomposition.cpp) + +target_link_libraries(${target_name} PRIVATE GTest::gtest_main MLIRQCOProgramBuilder + MLIRQCOTransforms Eigen3::Eigen) +target_link_libraries(${target_name} PRIVATE MLIRPass MLIRFuncDialect MLIRArithDialect MLIRIR + MLIRSupport MLIRQTensorDialect) + +mqt_mlir_configure_unittest_target(${target_name}) + +gtest_discover_tests(${target_name} PROPERTIES LABELS mqt-mlir-unittests DISCOVERY_TIMEOUT 60) diff --git a/mlir/unittests/Dialect/QCO/Transforms/Decomposition/test_euler_decomposition.cpp b/mlir/unittests/Dialect/QCO/Transforms/Decomposition/test_euler_decomposition.cpp new file mode 100644 index 0000000000..1aafd01aad --- /dev/null +++ b/mlir/unittests/Dialect/QCO/Transforms/Decomposition/test_euler_decomposition.cpp @@ -0,0 +1,717 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" +#include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QCO/IR/QCOInterfaces.h" +#include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/IR/QCOUnitaryMatrixInterfaces.h" +#include "mlir/Dialect/QCO/Transforms/Decomposition/Euler.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" +#include "mlir/Dialect/Utils/Utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::qco; +using namespace mlir::qco::decomposition; + +namespace { + +struct SynthesisFixture { + std::unique_ptr context; + + void setUp() { + DialectRegistry registry; + registry.insert(); + context = std::make_unique(); + context->appendDialectRegistry(registry); + context->loadAllAvailableDialects(); + } +}; + +struct SynthesizedCircuit { + OwningOpRef module; + func::FuncOp func; +}; + +} // namespace + +template +[[nodiscard]] static MatrixType randomUnitaryMatrix(std::mt19937& rng) { + static_assert(MatrixType::RowsAtCompileTime != Eigen::Dynamic && + MatrixType::ColsAtCompileTime != Eigen::Dynamic, + "randomUnitaryMatrix requires fixed-size matrices"); + static_assert(MatrixType::RowsAtCompileTime == MatrixType::ColsAtCompileTime, + "randomUnitaryMatrix requires square matrices"); + std::normal_distribution normalDist(0.0, 1.0); + MatrixType randomMatrix; + for (auto& x : randomMatrix.reshaped()) { + x = std::complex(normalDist(rng), normalDist(rng)); + } + Eigen::HouseholderQR qr{}; + qr.compute(randomMatrix); + const MatrixType qMatrix = qr.householderQ(); + const MatrixType rMatrix = + qr.matrixQR().template triangularView(); + MatrixType dMatrix = MatrixType::Identity(); + constexpr Eigen::Index dim = MatrixType::RowsAtCompileTime; + for (Eigen::Index i = 0; i < dim; ++i) { + const auto rii = rMatrix(i, i); + const auto absRii = std::abs(rii); + dMatrix(i, i) = + absRii > 0.0 ? (rii / absRii) : std::complex{1.0, 0.0}; + } + return qMatrix * dMatrix; +} + +template static void forEachBasis(Fn fn) { + const std::array bases = {"zyz", "zxz", "xzx", + "xyx", "u", "zsxx"}; + for (const char* basis : bases) { + fn(StringRef{basis}); + } +} + +static bool isAllowedBasisGate(Operation& op, StringRef basis) { + // `gphase` is always allowed. + if (isa(op)) { + return true; + } + + const auto b = basis.lower(); + if (b == "zyz") { + return isa(op); + } + if (b == "zxz") { + return isa(op); + } + if (b == "xzx") { + return isa(op); + } + if (b == "xyx") { + return isa(op); + } + if (b == "u") { + return isa(op); + } + if (b == "zsxx") { + return isa(op); + } + return false; +} + +[[nodiscard]] static bool isTwoQubitGate(Operation& op) { + if (auto u = dyn_cast(op)) { + return u.isTwoQubit(); + } + return false; +} + +// Matrix-backed 1Q gate (not barrier, 2Q, or `gphase`). +[[nodiscard]] static bool isOneQubitGate(Operation& op) { + if (isa(op) || !isa(op)) { + return false; + } + auto u = dyn_cast(op); + return u && u.isSingleQubit(); +} + +// At least one 1Q gate before and after the first `isBoundary` op in `main`. +template +static void expectOneQubitGatesAroundBoundary(func::FuncOp funcOp, + StringRef basis, + BoundaryPred isBoundary) { + auto& block = funcOp.getBody().front(); + std::size_t before = 0; + std::size_t after = 0; + bool seenBoundary = false; + for (Operation& op : block.without_terminator()) { + if (!seenBoundary && isBoundary(op)) { + seenBoundary = true; + continue; + } + if (!isOneQubitGate(op)) { + continue; + } + if (seenBoundary) { + ++after; + } else { + ++before; + } + } + EXPECT_GE(before, 1U) << "basis=" << basis.str(); + EXPECT_GE(after, 1U) << "basis=" << basis.str(); +} + +static void expectBasisGatesOnly(func::FuncOp funcOp, StringRef basis) { + auto& block = funcOp.getBody().front(); + for (Operation& op : block.without_terminator()) { + if (isa(op)) { + continue; + } + + if (isTwoQubitGate(op)) { + continue; + } + + // Matrix-backed ops must be allowed basis gates. + if (isa(op)) { + EXPECT_TRUE(isAllowedBasisGate(op, basis)) + << "basis=" << basis.str() + << " unexpected gate: " << op.getName().getStringRef().str(); + } + } +} + +static Eigen::Matrix2cd compute1QMatrixFromFunction(func::FuncOp funcOp) { + Eigen::Matrix2cd acc = Eigen::Matrix2cd::Identity(); + std::complex global{1.0, 0.0}; + bool failed = false; + + // Include nested regions (`scf.for`); skip `inv`/`ctrl` bodies after the + // modifier op (combined matrix already counted). + funcOp.walk([&](Operation* op) -> WalkResult { + if (isa(*op) || isTwoQubitGate(*op)) { + return WalkResult::advance(); + } + + if (auto gphase = dyn_cast(*op)) { + if (auto m = gphase.getUnitaryMatrix()) { + global *= (*m)(0, 0); + } + return WalkResult::advance(); + } + + if (auto iface = dyn_cast(*op)) { + const auto maybeM = iface.getUnitaryMatrix(); + if (!maybeM) { + ADD_FAILURE() << "Expected constant unitary matrix for op: " + << op->getName().getStringRef().str(); + failed = true; + return WalkResult::interrupt(); + } + acc = (*maybeM) * acc; + return WalkResult::skip(); + } + + return WalkResult::advance(); + }); + + if (failed) { + return Eigen::Matrix2cd::Zero(); + } + return global * acc; +} + +static LogicalResult runFuse(ModuleOp module, StringRef basis) { + PassManager pm(module.getContext()); + qco::FuseSingleQubitUnitaryRunsOptions opts; + opts.basis = basis.str(); + pm.addPass(qco::createFuseSingleQubitUnitaryRuns(opts)); + return pm.run(module); +} + +static void singleQubitRunWithSingleQubitGate(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + q[0] = b.h(q[0]); + q[0] = b.t(q[0]); + q[0] = b.rz(0.123, q[0]); + // `inv` is part of the fusable run. + q[0] = b.inv({q[0]}, [&](ValueRange targets) -> SmallVector { + return {b.sx(targets[0])}; + })[0]; + q[0] = b.ry(-0.456, q[0]); +} + +static void singleQubitRunsSplitByTwoQGate(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + q[0] = b.h(q[0]); + q[0] = b.t(q[0]); + std::tie(q[0], q[1]) = b.swap(q[0], q[1]); + q[0] = b.rz(0.321, q[0]); + q[0] = b.sx(q[0]); +} + +static void singleQubitRunsSplitByBarrier(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + q[0] = b.h(q[0]); + q[0] = b.t(q[0]); + q[0] = b.barrier({q[0]})[0]; + q[0] = b.rz(0.321, q[0]); + q[0] = b.sx(q[0]); +} + +// Single `H` gate — not in any target basis; should still be resynthesized. +static void singleNonBasisGate(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + q[0] = b.h(q[0]); +} + +// Six `RZ`/`RY` gates in `zyz` basis — longer than canonical (3). +static void overlongZyzRun(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + q[0] = b.rz(0.3, q[0]); + q[0] = b.ry(0.5, q[0]); + q[0] = b.rz(0.7, q[0]); + q[0] = b.ry(0.9, q[0]); + q[0] = b.rz(1.1, q[0]); + q[0] = b.ry(1.3, q[0]); +} + +static void singleQubitRunInScfFor(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + b.scfFor(0, 1, 1, ValueRange{q[0]}, [&](Value, ValueRange iterArgs) { + Value wire = iterArgs[0]; + wire = b.h(wire); + wire = b.t(wire); + wire = b.rz(0.123, wire); + return SmallVector{wire}; + }); +} + +[[nodiscard]] static std::size_t countUOpsInScfFor(func::FuncOp funcOp) { + std::size_t count = 0; + funcOp.walk([&](UOp op) { + for (Operation* parent = op->getParentOp(); parent != nullptr; + parent = parent->getParentOp()) { + if (parent->getName().getStringRef() == "scf.for") { + ++count; + break; + } + } + }); + return count; +} + +static OwningOpRef buildProgram(MLIRContext* ctx, + void (*fn)(QCOProgramBuilder&)) { + QCOProgramBuilder builder(ctx); + builder.initialize(); + fn(builder); + return builder.finalize(); +} + +static func::FuncOp lookupMain(ModuleOp module) { + auto func = module.lookupSymbol("main"); + EXPECT_TRUE(func) << "Expected a 'main' function"; + return func; +} + +template +static void runFuseOnProgramForAllBases(MLIRContext* ctx, + void (*program)(QCOProgramBuilder&), + ChecksT checksAfter) { + forEachBasis([&](StringRef basis) { + auto owned = buildProgram(ctx, program); + if (!static_cast(owned)) { + ADD_FAILURE() << "Failed to build program for basis=" << basis.str(); + return; + } + ModuleOp module = *owned; + if (failed(verify(module))) { + ADD_FAILURE() << "Verifier failed for basis=" << basis.str(); + return; + } + + auto funcOp = lookupMain(module); + if (!funcOp) { + ADD_FAILURE() << "Missing 'main' for basis=" << basis.str(); + return; + } + + const Eigen::Matrix2cd original = compute1QMatrixFromFunction(funcOp); + + if (failed(runFuse(module, basis))) { + ADD_FAILURE() << "Fuse pass failed for basis=" << basis.str(); + return; + } + if (failed(verify(module))) { + ADD_FAILURE() << "Verifier failed after fuse for basis=" << basis.str(); + return; + } + + funcOp = lookupMain(module); + if (!funcOp) { + ADD_FAILURE() << "Missing 'main' after fuse for basis=" << basis.str(); + return; + } + + checksAfter(funcOp, basis, original); + }); +} + +template +[[nodiscard]] static Eigen::Matrix2cd rotationMatrix(MLIRContext* ctx, + double theta) { + OpBuilder builder(ctx); + auto module = ModuleOp::create(UnknownLoc::get(ctx)); + builder.setInsertionPointToStart(module.getBody()); + const Location loc = module.getLoc(); + Value q = builder.create(loc).getResult(); + auto op = builder.create(loc, q, theta); + return *cast(op).getUnitaryMatrix(); +} + +[[nodiscard]] static SynthesizedCircuit +synthesizeMatrix(MLIRContext* ctx, const Eigen::Matrix2cd& matrix, + EulerBasis basis) { + OwningOpRef module = ModuleOp::create(UnknownLoc::get(ctx)); + OpBuilder builder(ctx); + builder.setInsertionPointToStart(module->getBody()); + + auto qubitTy = QubitType::get(ctx); + auto funcTy = builder.getFunctionType({qubitTy}, {qubitTy}); + auto func = builder.create(module->getLoc(), "main", funcTy); + auto* entry = func.addEntryBlock(); + + builder.setInsertionPointToStart(entry); + Value q = entry->getArgument(0); + q = synthesizeUnitary1QEuler(builder, module->getLoc(), q, matrix, basis); + builder.create(module->getLoc(), q); + return SynthesizedCircuit{.module = std::move(module), .func = func}; +} + +template +[[nodiscard]] static std::size_t countOps(func::FuncOp funcOp) { + std::size_t count = 0; + funcOp.walk([&](OpTy) { ++count; }); + return count; +} + +[[nodiscard]] static std::size_t countTwoQubitGates(func::FuncOp funcOp) { + std::size_t count = 0; + funcOp.walk([&](UnitaryOpInterface op) { + if (op.isTwoQubit()) { + ++count; + } + }); + return count; +} + +TEST(EulerSynthesisTest, RandomReconstructionAllBases) { + SynthesisFixture fx; + fx.setUp(); + + std::mt19937 rng{12345678UL}; + constexpr int iterations = 200; + + for (int i = 0; i < iterations; ++i) { + const auto original = randomUnitaryMatrix(rng); + + forEachBasis([&](StringRef basisStr) { + const auto parsed = mlir::qco::decomposition::parseEulerBasis(basisStr); + ASSERT_TRUE(parsed) << "basis=" << basisStr.str(); + + auto module = ModuleOp::create(UnknownLoc::get(fx.context.get())); + MLIRContext* ctx = module.getContext(); + + OpBuilder builder(ctx); + builder.setInsertionPointToStart(module.getBody()); + + auto qubitTy = QubitType::get(ctx); + auto funcTy = builder.getFunctionType({qubitTy}, {qubitTy}); + auto func = builder.create(module.getLoc(), "main", funcTy); + auto* entry = func.addEntryBlock(); + + builder.setInsertionPointToStart(entry); + Value q = entry->getArgument(0); + q = mlir::qco::decomposition::synthesizeUnitary1QEuler( + builder, module.getLoc(), q, original, *parsed); + builder.create(module.getLoc(), q); + + ASSERT_TRUE(succeeded(verify(module))) << "basis=" << basisStr.str(); + + const auto restored = compute1QMatrixFromFunction(func); + EXPECT_TRUE(restored.isApprox(original, mlir::utils::TOLERANCE)) + << "basis=" << basisStr.str(); + }); + } +} + +TEST(FuseSingleQubitUnitaryRunsTest, FusesRunInScfForBody) { + SynthesisFixture fx; + fx.setUp(); + + auto owned = buildProgram(fx.context.get(), &singleQubitRunInScfFor); + ASSERT_TRUE(owned); + ModuleOp module = *owned; + ASSERT_TRUE(succeeded(verify(module))); + + auto funcOp = lookupMain(module); + ASSERT_TRUE(funcOp); + const Eigen::Matrix2cd original = compute1QMatrixFromFunction(funcOp); + + ASSERT_TRUE(succeeded(runFuse(module, "u"))); + ASSERT_TRUE(succeeded(verify(module))); + + funcOp = lookupMain(module); + ASSERT_TRUE(funcOp); + EXPECT_GE(countUOpsInScfFor(funcOp), 1U); + EXPECT_TRUE(compute1QMatrixFromFunction(funcOp).isApprox( + original, mlir::utils::TOLERANCE)); +} + +TEST(FuseSingleQubitUnitaryRunsTest, ReconstructsOriginalRunAllBases) { + SynthesisFixture fx; + fx.setUp(); + + runFuseOnProgramForAllBases( + fx.context.get(), &singleQubitRunWithSingleQubitGate, + /*checksAfter=*/ + [&](func::FuncOp funcOp, StringRef basis, + const Eigen::Matrix2cd& original) { + const auto restored = compute1QMatrixFromFunction(funcOp); + EXPECT_TRUE(restored.isApprox(original, mlir::utils::TOLERANCE)) + << "basis=" << basis.str(); + expectBasisGatesOnly(funcOp, basis); + }); +} + +TEST(FuseSingleQubitUnitaryRunsTest, ResynthesizesLoneNonBasisGateAllBases) { + SynthesisFixture fx; + fx.setUp(); + + runFuseOnProgramForAllBases( + fx.context.get(), &singleNonBasisGate, + /*checksAfter=*/ + [&](func::FuncOp funcOp, StringRef basis, + const Eigen::Matrix2cd& original) { + EXPECT_EQ(countOps(funcOp), 0U) + << "basis=" << basis.str() << " left a non-basis gate"; + EXPECT_TRUE(compute1QMatrixFromFunction(funcOp).isApprox( + original, mlir::utils::TOLERANCE)) + << "basis=" << basis.str(); + expectBasisGatesOnly(funcOp, basis); + }); +} + +TEST(FuseSingleQubitUnitaryRunsTest, FusesOverlongInBasisRun) { + SynthesisFixture fx; + fx.setUp(); + + auto owned = buildProgram(fx.context.get(), &overlongZyzRun); + ASSERT_TRUE(owned); + ModuleOp module = *owned; + ASSERT_TRUE(succeeded(verify(module))); + + auto funcOp = lookupMain(module); + ASSERT_TRUE(funcOp); + const Eigen::Matrix2cd original = compute1QMatrixFromFunction(funcOp); + const std::size_t before = countOps(funcOp) + countOps(funcOp); + ASSERT_EQ(before, 6U); + + ASSERT_TRUE(succeeded(runFuse(module, "zyz"))); + ASSERT_TRUE(succeeded(verify(module))); + + funcOp = lookupMain(module); + ASSERT_TRUE(funcOp); + const std::size_t after = countOps(funcOp) + countOps(funcOp); + EXPECT_LE(after, 3U); + EXPECT_LT(after, before); + EXPECT_TRUE(compute1QMatrixFromFunction(funcOp).isApprox( + original, mlir::utils::TOLERANCE)); + expectBasisGatesOnly(funcOp, "zyz"); +} + +TEST(EulerSynthesisTest, ZsxxPauliXUsesXGateShortcut) { + SynthesisFixture fx; + fx.setUp(); + + const Eigen::Matrix2cd pauliX = XOp::getUnitaryMatrix(); + const auto circuit = + synthesizeMatrix(fx.context.get(), pauliX, EulerBasis::ZSXX); + + ASSERT_TRUE(succeeded(verify(*circuit.module))); + EXPECT_EQ(countOps(circuit.func), 1U); + EXPECT_EQ(countOps(circuit.func), 0U); + EXPECT_TRUE(compute1QMatrixFromFunction(circuit.func) + .isApprox(pauliX, mlir::utils::TOLERANCE)); +} + +TEST(EulerSynthesisTest, UGateReconstruction) { + SynthesisFixture fx; + fx.setUp(); + + std::mt19937 rng{99991}; + for (int i = 0; i < 32; ++i) { + const auto u = randomUnitaryMatrix(rng); + const auto circuit = synthesizeMatrix(fx.context.get(), u, EulerBasis::U); + ASSERT_TRUE(succeeded(verify(*circuit.module))); + EXPECT_LE(countOps(circuit.func), 1U); + EXPECT_TRUE(compute1QMatrixFromFunction(circuit.func) + .isApprox(u, mlir::utils::TOLERANCE)); + } +} + +TEST(EulerDecompositionTest, ZYZAnglesFromUnitaryReconstructHadamard) { + SynthesisFixture fx; + fx.setUp(); + + const Eigen::Matrix2cd hadamard = HOp::getUnitaryMatrix(); + const auto [theta, phi, lambda, phase] = + EulerDecomposition::anglesFromUnitary(hadamard, EulerBasis::ZYZ); + + auto module = ModuleOp::create(UnknownLoc::get(fx.context.get())); + OpBuilder builder(fx.context.get()); + builder.setInsertionPointToStart(module.getBody()); + const Location loc = module.getLoc(); + + auto qubitTy = QubitType::get(fx.context.get()); + auto funcTy = builder.getFunctionType({qubitTy}, {qubitTy}); + auto func = builder.create(loc, "main", funcTy); + auto* entry = func.addEntryBlock(); + builder.setInsertionPointToStart(entry); + + Value q = entry->getArgument(0); + auto mkAngle = [&](double angle) -> Value { + return builder + .create(loc, builder.getF64FloatAttr(angle)) + .getResult(); + }; + q = builder.create(loc, q, mkAngle(lambda)).getQubitOut(); + q = builder.create(loc, q, mkAngle(theta)).getQubitOut(); + q = builder.create(loc, q, mkAngle(phi)).getQubitOut(); + if (std::abs(phase) > mlir::utils::TOLERANCE) { + Value phaseVal = mkAngle(phase); + builder.create(loc, phaseVal); + } + builder.create(loc, q); + + ASSERT_TRUE(succeeded(verify(module))); + EXPECT_TRUE(compute1QMatrixFromFunction(func).isApprox( + hadamard, mlir::utils::TOLERANCE)); +} + +// NOLINTNEXTLINE(misc-use-internal-linkage) -- gtest `TEST_P` at global scope +class EulerSynthesisExactTest + : public testing::TestWithParam< + std::tuple> {}; + +TEST_P(EulerSynthesisExactTest, ReconstructsReferenceMatrices) { + SynthesisFixture fx; + fx.setUp(); + + const auto [basis, matrixFn] = GetParam(); + const Eigen::Matrix2cd original = matrixFn(fx.context.get()); + const auto circuit = synthesizeMatrix(fx.context.get(), original, basis); + + ASSERT_TRUE(succeeded(verify(*circuit.module))); + EXPECT_TRUE(compute1QMatrixFromFunction(circuit.func) + .isApprox(original, mlir::utils::TOLERANCE)); +} + +INSTANTIATE_TEST_SUITE_P( + SingleQubitMatrices, EulerSynthesisExactTest, + testing::Combine(testing::Values(EulerBasis::XYX, EulerBasis::XZX, + EulerBasis::ZYZ, EulerBasis::ZXZ, + EulerBasis::U, EulerBasis::ZSXX), + testing::Values( + [](MLIRContext* /*ctx*/) -> Eigen::Matrix2cd { + return Eigen::Matrix2cd::Identity(); + }, + [](MLIRContext* ctx) -> Eigen::Matrix2cd { + return rotationMatrix(ctx, 2.0); + }, + // RY(pi/2): ZSXX single-SX branch. + [](MLIRContext* ctx) -> Eigen::Matrix2cd { + return rotationMatrix(ctx, + std::numbers::pi / 2.0); + }, + [](MLIRContext* ctx) -> Eigen::Matrix2cd { + return rotationMatrix(ctx, 0.5); + }, + [](MLIRContext* ctx) -> Eigen::Matrix2cd { + return rotationMatrix(ctx, 3.14); + }, + [](MLIRContext* /*ctx*/) -> Eigen::Matrix2cd { + return HOp::getUnitaryMatrix(); + }))); + +TEST(FuseSingleQubitUnitaryRunsTest, DoesNotFuseAcrossTwoQGateAllBases) { + SynthesisFixture fx; + fx.setUp(); + + runFuseOnProgramForAllBases( + fx.context.get(), &singleQubitRunsSplitByTwoQGate, + /*checksAfter=*/ + [&](func::FuncOp funcOp, StringRef basis, + const Eigen::Matrix2cd& original) { + EXPECT_EQ(countTwoQubitGates(funcOp), 1U) << "basis=" << basis.str(); + EXPECT_TRUE(compute1QMatrixFromFunction(funcOp).isApprox( + original, mlir::utils::TOLERANCE)) + << "basis=" << basis.str(); + expectBasisGatesOnly(funcOp, basis); + expectOneQubitGatesAroundBoundary( + funcOp, basis, [](Operation& op) { return isTwoQubitGate(op); }); + }); +} + +TEST(FuseSingleQubitUnitaryRunsTest, DoesNotFuseAcrossBarrierAllBases) { + SynthesisFixture fx; + fx.setUp(); + + runFuseOnProgramForAllBases( + fx.context.get(), &singleQubitRunsSplitByBarrier, + /*checksAfter=*/ + [&](func::FuncOp funcOp, StringRef basis, + const Eigen::Matrix2cd& original) { + EXPECT_EQ(countOps(funcOp), 1U) << "basis=" << basis.str(); + EXPECT_TRUE(compute1QMatrixFromFunction(funcOp).isApprox( + original, mlir::utils::TOLERANCE)) + << "basis=" << basis.str(); + expectBasisGatesOnly(funcOp, basis); + expectOneQubitGatesAroundBoundary( + funcOp, basis, [](Operation& op) { return isa(op); }); + }); +} + +TEST(FuseSingleQubitUnitaryRunsTest, InvalidBasisFailsPass) { + SynthesisFixture fx; + fx.setUp(); + + auto owned = + buildProgram(fx.context.get(), &singleQubitRunWithSingleQubitGate); + ASSERT_TRUE(static_cast(owned)); + ModuleOp module = *owned; + ASSERT_TRUE(succeeded(verify(module))); + + EXPECT_TRUE(failed(runFuse(module, "not-a-basis"))); +} diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index b4197c5a7f..f4dc65e58d 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -1077,5 +1077,4 @@ void qtensorInsertExtractIndexMismatch(QCOProgramBuilder& b); /// Inserts a qubit into a tensor and extracts it immediately at the same index. void qtensorInsertExtractSameIndex(QCOProgramBuilder& b); - } // namespace mlir::qco