From 90ec2cff02947517a215608894aed6127f4b060d Mon Sep 17 00:00:00 2001 From: Patrick Hopf <81010725+flowerthrower@users.noreply.github.com> Date: Wed, 11 Mar 2026 12:24:41 +0100 Subject: [PATCH 01/20] =?UTF-8?q?=F0=9F=90=9B=20Fix=20RL=20Training=20and?= =?UTF-8?q?=20Improve=20Structure=20(#573)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description This PR addresses critical bugs in the RL training process with the following key changes: **Structure Improvements:** - **Redesigned action validation logic** (`predictorenv.py`): Rewrote `determine_valid_actions_for_state()` with a more structured (but equivalent) state machine that explicitly tracks three circuit states (synthesized, laid_out, routed) and handles 6 different state combinations. - Added helper methods `is_circuit_synthesized()`, `is_circuit_laid_out()`, and `is_circuit_routed()` to replace the `GatesInBasis` check and the buggy `CheckMap` pass with more reliable state checking. The new logic supports both the original restricted MDP and a flexible general MDP mode. - **Fixed type annotation** (`actions.py`): Corrected `do_while` parameter type from `dict[str, Circuit]` to `PropertySet` and added missing import for Qiskit's `PropertySet`. - **Added reproducibility** (`predictor.py`): Set random seed for non-test training runs to ensure reproducible results. - **Improved VF2Layout error handling** (`predictorenv.py`): Replaced assertion failures with warning logs when VF2Layout doesn't find a solution, preventing crashes during training. 
**Test Updates:** - Suppressed deprecation warnings in tket routing test --------- Signed-off-by: Patrick Hopf <81010725+flowerthrower@users.noreply.github.com> Co-authored-by: flowerthrower Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- src/mqt/predictor/reward.py | 45 +---- src/mqt/predictor/rl/actions.py | 3 +- src/mqt/predictor/rl/predictor.py | 7 +- src/mqt/predictor/rl/predictorenv.py | 191 +++++++++++++++--- .../test_integration_further_SDKs.py | 2 +- tests/compilation/test_predictor_rl.py | 2 +- 6 files changed, 177 insertions(+), 73 deletions(-) diff --git a/src/mqt/predictor/reward.py b/src/mqt/predictor/reward.py index 7c1ce1bba..a8cd76952 100644 --- a/src/mqt/predictor/reward.py +++ b/src/mqt/predictor/reward.py @@ -23,7 +23,6 @@ if TYPE_CHECKING: from qiskit import QuantumCircuit - from qiskit.circuit import QuantumRegister, Qubit from qiskit.transpiler import Target from sklearn.ensemble import RandomForestRegressor @@ -62,44 +61,22 @@ def expected_fidelity(qc: QuantumCircuit, device: Target, precision: int = 10) - if gate_type != "barrier": assert len(qargs) in [1, 2] - first_qubit_idx = calc_qubit_index(qargs, qc.qregs, 0) + first_qubit_idx = qc.find_bit(qargs[0]).index if len(qargs) == 1: specific_fidelity = 1 - device[gate_type][first_qubit_idx,].error else: - second_qubit_idx = calc_qubit_index(qargs, qc.qregs, 1) - specific_fidelity = 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error - + second_qubit_idx = qc.find_bit(qargs[1]).index + try: + specific_fidelity = 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error + except KeyError: + msg = f"Error rate for gate {gate_type} on qubits {first_qubit_idx} and {second_qubit_idx} not found in device properties." 
+ raise KeyError(msg) from None res *= specific_fidelity return float(np.round(res, precision).item()) -def calc_qubit_index(qargs: list[Qubit], qregs: list[QuantumRegister], index: int) -> int: - """Calculates the global qubit index for a given quantum circuit and qubit index. - - Arguments: - qargs: The qubits of the quantum circuit. - qregs: The quantum registers of the quantum circuit. - index: The index of the qubit in the qargs list. - - Returns: - The global qubit index of the given qubit in the quantum circuit. - - Raises: - ValueError: If the qubit index is not found in the quantum registers. - """ - offset = 0 - for reg in qregs: - if qargs[index] not in reg: - offset += reg.size - else: - qubit_index: int = offset + reg.index(qargs[index]) - return qubit_index - error_msg = f"Global qubit index for local qubit {index} index not found." - raise ValueError(error_msg) - - def estimated_success_probability(qc: QuantumCircuit, device: Target, precision: int = 10) -> float: """Calculates the estimated success probability of a given quantum circuit on a given device. 
@@ -125,7 +102,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision: if gate_type == "barrier" or gate_type == "id": continue assert len(qargs) in (1, 2) - first_qubit_idx = calc_qubit_index(qargs, qc.qregs, 0) + first_qubit_idx = qc.find_bit(qargs[0]).index active_qubits.add(first_qubit_idx) if len(qargs) == 1: # single-qubit gate @@ -140,7 +117,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision: )) exec_time_per_qubit[first_qubit_idx] += duration else: # multi-qubit gate - second_qubit_idx = calc_qubit_index(qargs, qc.qregs, 1) + second_qubit_idx = qc.find_bit(qargs[1]).index active_qubits.add(second_qubit_idx) duration = device[gate_type][first_qubit_idx, second_qubit_idx].duration op_times.append((gate_type, [first_qubit_idx, second_qubit_idx], duration, "s")) @@ -191,7 +168,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision: continue assert len(qargs) in (1, 2) - first_qubit_idx = calc_qubit_index(qargs, qc.qregs, 0) + first_qubit_idx = scheduled_circ.find_bit(qargs[0]).index if len(qargs) == 1: if gate_type == "measure": @@ -213,7 +190,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision: continue res *= 1 - device[gate_type][first_qubit_idx,].error else: - second_qubit_idx = calc_qubit_index(qargs, qc.qregs, 1) + second_qubit_idx = scheduled_circ.find_bit(qargs[1]).index res *= 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error if qiskit_version >= "2.0.0": diff --git a/src/mqt/predictor/rl/actions.py b/src/mqt/predictor/rl/actions.py index 7b4432f5a..9efcc76d7 100644 --- a/src/mqt/predictor/rl/actions.py +++ b/src/mqt/predictor/rl/actions.py @@ -86,6 +86,7 @@ from bqskit import Circuit from pytket._tket.passes import BasePass as tket_BasePass + from qiskit.passmanager import PropertySet from qiskit.transpiler.basepasses import BasePass as qiskit_BasePass @@ -143,7 +144,7 @@ class DeviceDependentAction(Action): 
Callable[..., tuple[Any, ...] | Circuit], ] ) - do_while: Callable[[dict[str, Circuit]], bool] | None = None + do_while: Callable[[PropertySet], bool] | None = None # Registry of actions diff --git a/src/mqt/predictor/rl/predictor.py b/src/mqt/predictor/rl/predictor.py index 2654f34fe..1f75b1901 100644 --- a/src/mqt/predictor/rl/predictor.py +++ b/src/mqt/predictor/rl/predictor.py @@ -99,11 +99,12 @@ def train_model( """ if test: set_random_seed(0) # for reproducibility - n_steps = 10 - n_epochs = 1 - batch_size = 10 + n_steps = 32 + n_epochs = 2 + batch_size = 8 progress_bar = False else: + set_random_seed(0) # default PPO values n_steps = 2048 n_epochs = 10 diff --git a/src/mqt/predictor/rl/predictorenv.py b/src/mqt/predictor/rl/predictorenv.py index 79249c729..c560b9e99 100644 --- a/src/mqt/predictor/rl/predictorenv.py +++ b/src/mqt/predictor/rl/predictorenv.py @@ -21,7 +21,7 @@ from bqskit import Circuit from qiskit.passmanager.base_tasks import Task - from qiskit.transpiler import Target + from qiskit.transpiler import Layout, Target from mqt.predictor.reward import figure_of_merit from mqt.predictor.rl.actions import Action @@ -40,7 +40,6 @@ from qiskit import QuantumCircuit from qiskit.passmanager.flow_controllers import DoWhileController from qiskit.transpiler import CouplingMap, PassManager, TranspileLayout -from qiskit.transpiler.passes import CheckMap, GatesInBasis from qiskit.transpiler.passes.layout.vf2_layout import VF2LayoutStopReason from mqt.predictor.hellinger import get_hellinger_model_path @@ -69,6 +68,7 @@ class PredictorEnv(Env): def __init__( self, device: Target, + mdp: str = "paper", reward_function: figure_of_merit = "expected_fidelity", path_training_circuits: Path | None = None, ) -> None: @@ -76,6 +76,7 @@ def __init__( Arguments: device: The target device to be used for compilation. + mdp: The MDP transition policy. 
"paper" (default) enforces a strict, linear pipeline (synthesis -> (layout->routing) / mapping), while "flexible" allows for a cyclical approach where actions can be interleaved or reversed. reward_function: The figure of merit to be used for the reward function. Defaults to "expected_fidelity". path_training_circuits: The path to the training circuits folder. Defaults to None, which uses the default path. @@ -96,6 +97,9 @@ def __init__( self.used_actions: list[str] = [] self.device = device + logger.info("MDP: " + mdp) + self.mdp = mdp + # check for uni-directional coupling map coupling_set = {tuple(pair) for pair in self.device.build_coupling_map()} if any((b, a) not in coupling_set for (a, b) in coupling_set): @@ -189,23 +193,21 @@ def step(self, action: int) -> tuple[dict[str, Any], float, bool, bool, dict[Any self.state: QuantumCircuit = altered_qc self.num_steps += 1 + self.state._layout = self.layout # noqa: SLF001 + self.valid_actions = self.determine_valid_actions_for_state() if len(self.valid_actions) == 0: msg = "No valid actions left." raise RuntimeError(msg) if action == self.action_terminate_index: + assert action in self.valid_actions, "Terminate action is not valid but was chosen." 
reward_val = self.calculate_reward() done = True else: reward_val = 0 done = False - # in case the Qiskit.QuantumCircuit has unitary or u gates in it, decompose them (because otherwise qiskit will throw an error when applying the BasisTranslator - if self.state.count_ops().get("unitary"): # ty: ignore[invalid-argument-type] - self.state = self.state.decompose(gates_to_decompose="unitary") - - self.state._layout = self.layout # noqa: SLF001 obs = create_feature_dict(self.state) return obs, reward_val, done, False, {} @@ -256,7 +258,15 @@ def reset( self.layout = None - self.valid_actions = self.actions_opt_indices + self.actions_synthesis_indices + if self.mdp == "flexible": + self.valid_actions = ( + self.actions_synthesis_indices + + self.actions_mapping_indices + + self.actions_layout_indices + + self.actions_opt_indices + ) + else: + self.valid_actions = self.actions_synthesis_indices + self.actions_opt_indices self.error_occurred = False @@ -268,10 +278,14 @@ def action_masks(self) -> list[bool]: """Returns a list of valid actions for the current state.""" action_mask = [action in self.valid_actions for action in self.action_set] - # it is not clear how tket will handle the layout, so we remove all actions that are from "origin"=="tket" if a layout is set + # TKET layout/optimization actions must not run after a Qiskit layout has been set + # (it is not clear how tket will handle the layout). TKET routing actions are + # designed to work after a Qiskit layout via PreProcessTKETRoutingAfterQiskitLayout. 
if self.layout is not None: action_mask = [ - action_mask[i] and self.action_set[i].origin != CompilationOrigin.TKET for i in range(len(action_mask)) + action_mask[i] + and (self.action_set[i].origin != CompilationOrigin.TKET or i in self.actions_routing_indices) + for i in range(len(action_mask)) ] if self.has_parameterized_gates or self.layout is not None: @@ -342,9 +356,16 @@ def _apply_qiskit_action(self, action: Action, action_index: int) -> QuantumCirc ): altered_qc = self._handle_qiskit_layout_postprocessing(action, pm, altered_qc) - elif action_index in self.actions_routing_indices and self.layout: + elif ( + action_index in self.actions_routing_indices and self.layout and pm.property_set["final_layout"] is not None + ): self.layout.final_layout = pm.property_set["final_layout"] + # BasisTranslator errors on unitary gates; decompose them immediately so + # the circuit is always in a consistent state after a Qiskit action. + if altered_qc.count_ops().get("unitary"): # ty: ignore[invalid-argument-type] + altered_qc = altered_qc.decompose(gates_to_decompose="unitary") + return altered_qc def _handle_qiskit_layout_postprocessing( @@ -357,8 +378,13 @@ def _handle_qiskit_layout_postprocessing( assert self.layout is not None altered_qc, _ = postprocess_vf2postlayout(altered_qc, post_layout, self.layout) elif action.name == "VF2Layout": - assert pm.property_set["VF2Layout_stop_reason"] == VF2LayoutStopReason.SOLUTION_FOUND - assert pm.property_set["layout"] + if pm.property_set["VF2Layout_stop_reason"] != VF2LayoutStopReason.SOLUTION_FOUND: + logger.warning( + "VF2Layout pass did not find a solution. 
Reason: %s", + pm.property_set["VF2Layout_stop_reason"], + ) + else: + assert pm.property_set["layout"] else: assert pm.property_set["layout"] @@ -385,7 +411,7 @@ def _apply_tket_action(self, action: Action, action_index: int) -> QuantumCircui qbs = tket_qc.qubits tket_qc.rename_units({qbs[i]: Qubit("q", i) for i in range(len(qbs))}) - altered_qc = tk_to_qiskit(tket_qc) + altered_qc = tk_to_qiskit(tket_qc, replace_implicit_swaps=True) if action_index in self.actions_routing_indices: assert self.layout is not None @@ -428,27 +454,126 @@ def _apply_bqskit_action(self, action: Action, action_index: int) -> QuantumCirc return bqskit_to_qiskit(bqskit_compiled_qc) - def determine_valid_actions_for_state(self) -> list[int]: - """Determines and returns the valid actions for the current state.""" - check_nat_gates = GatesInBasis(basis_gates=self.device.operation_names) - check_nat_gates(self.state) - only_nat_gates = check_nat_gates.property_set["all_gates_in_basis"] + def is_circuit_laid_out(self, circuit: QuantumCircuit, layout: TranspileLayout | Layout) -> bool: + """True if every logical qubit in the circuit has a physical assignment.""" + if isinstance(layout, TranspileLayout): + # Use final_layout if available; otherwise fallback to initial_layout + layout = layout.final_layout or layout.initial_layout + + v2p = layout.get_virtual_bits() + return all(q in v2p for q in circuit.qubits) + + def is_circuit_synthesized(self, circuit: QuantumCircuit) -> bool: + """Check if the circuit uses only native gates of the device. + + Verifies that every gate name in the circuit is present in + ``device.operation_names``, equivalent to the ``GatesInBasis`` pass. + + Args: + circuit: QuantumCircuit to check. + + Returns: + True if all gates are native to the device. 
+ """ + native_names = set(self.device.operation_names) + return all( + instr.operation.name in native_names or instr.operation.name in ("barrier", "measure") + for instr in circuit.data + ) + + def is_circuit_routed(self, circuit: QuantumCircuit, coupling_map: CouplingMap) -> bool: + """Check if a circuit is fully routed to the device, including directionality. - if not only_nat_gates: - actions = self.actions_synthesis_indices + self.actions_opt_indices - if self.layout is not None: - actions += self.actions_routing_indices - return actions + A circuit is considered routed if all two-qubit gates are on qubit pairs + that exist as directed edges in the device coupling map. - check_mapping = CheckMap(coupling_map=self.device.build_coupling_map()) - check_mapping(self.state) - mapped = check_mapping.property_set["is_swap_mapped"] + After a layout pass the circuit's qubits are already physical qubits, so + ``circuit.find_bit(q).index`` gives the physical index directly — + consistent with how ``reward.py`` looks up gate calibrations. - if mapped and self.layout is not None: # The circuit is correctly mapped. - return [self.action_terminate_index, *self.actions_opt_indices] + Args: + circuit: QuantumCircuit to check. + coupling_map: CouplingMap of the target device. - if self.layout is not None: # The circuit is not yet mapped but a layout is set. - return self.actions_routing_indices + Returns: + True if fully routed, False otherwise. 
+ """ + directed_edges = set(coupling_map.get_edges()) + for instr in circuit.data: + if len(instr.qubits) == 2: + q0 = circuit.find_bit(instr.qubits[0]).index + q1 = circuit.find_bit(instr.qubits[1]).index + if (q0, q1) not in directed_edges: + return False + return True - # No layout applied yet - return self.actions_mapping_indices + self.actions_layout_indices + self.actions_opt_indices + def determine_valid_actions_for_state(self) -> list[int]: + """Determine valid actions based on circuit state: synthesized, mapped, routed.""" + synthesized = self.is_circuit_synthesized(self.state) + laid_out = self.is_circuit_laid_out(self.state, self.layout) if self.layout else False + # Routing is only allowed after layout + routed = ( + self.is_circuit_routed(self.state, CouplingMap(self.device.build_coupling_map())) if laid_out else False + ) + + actions = [] + + # Initial state + if not synthesized and not laid_out and not routed: + if self.mdp == "flexible": + actions.extend(self.actions_synthesis_indices) + actions.extend(self.actions_mapping_indices) + actions.extend(self.actions_layout_indices) + actions.extend(self.actions_opt_indices) + if self.mdp == "paper": + actions.extend(self.actions_synthesis_indices) + actions.extend(self.actions_opt_indices) + + if synthesized and not laid_out and not routed: + if self.mdp == "flexible": + actions.extend(self.actions_mapping_indices) + actions.extend(self.actions_layout_indices) + actions.extend(self.actions_opt_indices) + if self.mdp == "paper": + actions.extend(self.actions_mapping_indices) + actions.extend(self.actions_layout_indices) + actions.extend(self.actions_opt_indices) + + # Not *depicted* in paper; necessary because optimization can destroy the native gate set + if not synthesized and laid_out and not routed: + if self.mdp == "flexible": + actions.extend(self.actions_synthesis_indices) + actions.extend(self.actions_routing_indices) + actions.extend(self.actions_opt_indices) + if self.mdp == "paper": + 
actions.extend(self.actions_synthesis_indices) + actions.extend(self.actions_routing_indices) + actions.extend(self.actions_opt_indices) + + # Not *depicted* in paper; necessary because of layout-only passes + if synthesized and laid_out and not routed: + if self.mdp == "flexible": + actions.extend(self.actions_routing_indices) + actions.extend(self.actions_opt_indices) + if self.mdp == "paper": + actions.extend(self.actions_routing_indices) + + # Not *depicted* in paper; necessary because routing can insert non-native SWAPs + if not synthesized and laid_out and routed: + if self.mdp == "flexible": + actions.extend(self.actions_synthesis_indices) + actions.extend(self.actions_opt_indices) + if self.mdp == "paper": + actions.extend(self.actions_synthesis_indices) + actions.extend(self.actions_opt_indices) + + # Final state + if synthesized and laid_out and routed: + if self.mdp == "flexible": + actions.extend([self.action_terminate_index]) + actions.extend(self.actions_opt_indices) + if self.mdp == "paper": + actions.extend([self.action_terminate_index]) + actions.extend(self.actions_opt_indices) + + return actions diff --git a/tests/compilation/test_integration_further_SDKs.py b/tests/compilation/test_integration_further_SDKs.py index 9aa33bb6f..150363ea3 100644 --- a/tests/compilation/test_integration_further_SDKs.py +++ b/tests/compilation/test_integration_further_SDKs.py @@ -241,7 +241,7 @@ def test_tket_routing(available_actions_dict: dict[PassType, list[Action]]) -> N qubit_map = {qbs[i]: Qubit("q", i) for i in range(len(qbs))} tket_qc.rename_units(qubit_map) - mapped_qc = tk_to_qiskit(tket_qc) + mapped_qc = tk_to_qiskit(tket_qc, replace_implicit_swaps=True) final_layout = final_layout_pytket_to_qiskit(tket_qc, mapped_qc) diff --git a/tests/compilation/test_predictor_rl.py b/tests/compilation/test_predictor_rl.py index 5aa086d3c..798da1da2 100644 --- a/tests/compilation/test_predictor_rl.py +++ b/tests/compilation/test_predictor_rl.py @@ -85,7 +85,7 @@ def 
test_qcompile_with_newly_trained_models() -> None: rl_compile(qc, device=device, figure_of_merit=figure_of_merit) predictor.train_model( - timesteps=100, + timesteps=1000, test=True, ) From 68eb33836582af60e79bfa4f943ac615646aa1c1 Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Mon, 11 May 2026 16:56:31 +0200 Subject: [PATCH 02/20] =?UTF-8?q?=F0=9F=8E=A8=20improve=20seed=20and=20tra?= =?UTF-8?q?ining=20defaults?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/predictor/rl/predictor.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/mqt/predictor/rl/predictor.py b/src/mqt/predictor/rl/predictor.py index 1f75b1901..0385c563f 100644 --- a/src/mqt/predictor/rl/predictor.py +++ b/src/mqt/predictor/rl/predictor.py @@ -97,14 +97,13 @@ def train_model( verbose: The verbosity level. Defaults to 2. test: Whether to train the model for testing purposes. Defaults to False. """ + set_random_seed(0) # for reproducibility if test: - set_random_seed(0) # for reproducibility - n_steps = 32 - n_epochs = 2 - batch_size = 8 + n_steps = 1000 + n_epochs = 1 + batch_size = 32 progress_bar = False else: - set_random_seed(0) # default PPO values n_steps = 2048 n_epochs = 10 From f9de637e49d23442d36d7974dbf3271d173da006 Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Mon, 11 May 2026 17:12:47 +0200 Subject: [PATCH 03/20] =?UTF-8?q?=F0=9F=8E=A8=20adjust=20test=20step=20lim?= =?UTF-8?q?its?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/predictor/rl/predictor.py | 4 ++-- src/mqt/predictor/rl/predictorenv.py | 2 +- tests/compilation/test_predictor_rl.py | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/mqt/predictor/rl/predictor.py b/src/mqt/predictor/rl/predictor.py index 0385c563f..0f3c8c540 100644 --- a/src/mqt/predictor/rl/predictor.py +++ b/src/mqt/predictor/rl/predictor.py @@ -99,9 +99,9 @@ def train_model( """ 
set_random_seed(0) # for reproducibility if test: - n_steps = 1000 + n_steps = 512 n_epochs = 1 - batch_size = 32 + batch_size = 16 progress_bar = False else: # default PPO values diff --git a/src/mqt/predictor/rl/predictorenv.py b/src/mqt/predictor/rl/predictorenv.py index 4d03fb1ae..1693887d6 100644 --- a/src/mqt/predictor/rl/predictorenv.py +++ b/src/mqt/predictor/rl/predictorenv.py @@ -369,7 +369,7 @@ def _apply_qiskit_action(self, action: Action, action_index: int) -> QuantumCirc # BasisTranslator errors on unitary gates; decompose them immediately so # the circuit is always in a consistent state after a Qiskit action. - if altered_qc.count_ops().get("unitary"): # ty: ignore[invalid-argument-type] + if altered_qc.count_ops().get("unitary"): altered_qc = altered_qc.decompose(gates_to_decompose="unitary") return altered_qc diff --git a/tests/compilation/test_predictor_rl.py b/tests/compilation/test_predictor_rl.py index 798da1da2..6a323c528 100644 --- a/tests/compilation/test_predictor_rl.py +++ b/tests/compilation/test_predictor_rl.py @@ -85,7 +85,6 @@ def test_qcompile_with_newly_trained_models() -> None: rl_compile(qc, device=device, figure_of_merit=figure_of_merit) predictor.train_model( - timesteps=1000, test=True, ) From 55e5e0861969a01a9d3a4e9c6163de3fb44c9e7c Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Mon, 11 May 2026 17:22:38 +0200 Subject: [PATCH 04/20] =?UTF-8?q?=E2=8F=AA=20revert=20unrelated=20changes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/predictor/rl/predictorenv.py | 43 +--------------------------- 1 file changed, 1 insertion(+), 42 deletions(-) diff --git a/src/mqt/predictor/rl/predictorenv.py b/src/mqt/predictor/rl/predictorenv.py index 1693887d6..f126f8a5c 100644 --- a/src/mqt/predictor/rl/predictorenv.py +++ b/src/mqt/predictor/rl/predictorenv.py @@ -69,7 +69,6 @@ class PredictorEnv(Env): def __init__( self, device: Target, - mdp: str = "paper", reward_function: 
figure_of_merit = "expected_fidelity", path_training_circuits: Path | None = None, ) -> None: @@ -77,7 +76,6 @@ def __init__( Arguments: device: The target device to be used for compilation. - mdp: The MDP transition policy. "paper" (default) enforces a strict, linear pipeline (synthesis -> (layout->routing) / mapping), while "flexible" allows for a cyclical approach where actions can be interleaved or reversed. reward_function: The figure of merit to be used for the reward function. Defaults to "expected_fidelity". path_training_circuits: The path to the training circuits folder. Defaults to None, which uses the default path. @@ -98,9 +96,6 @@ def __init__( self.used_actions: list[str] = [] self.device = device - logger.info("MDP: " + mdp) - self.mdp = mdp - # check for uni-directional coupling map coupling_set = {tuple(pair) for pair in self.device.build_coupling_map()} if any((b, a) not in coupling_set for (a, b) in coupling_set): @@ -264,15 +259,7 @@ def reset( self.layout = None - if self.mdp == "flexible": - self.valid_actions = ( - self.actions_synthesis_indices - + self.actions_mapping_indices - + self.actions_layout_indices - + self.actions_opt_indices - ) - else: - self.valid_actions = self.actions_synthesis_indices + self.actions_opt_indices + self.valid_actions = self.actions_synthesis_indices + self.actions_opt_indices self.error_occurred = False @@ -526,59 +513,31 @@ def determine_valid_actions_for_state(self) -> list[int]: # Initial state if not synthesized and not laid_out and not routed: - if self.mdp == "flexible": - actions.extend(self.actions_synthesis_indices) - actions.extend(self.actions_mapping_indices) - actions.extend(self.actions_layout_indices) - actions.extend(self.actions_opt_indices) - if self.mdp == "paper": actions.extend(self.actions_synthesis_indices) actions.extend(self.actions_opt_indices) if synthesized and not laid_out and not routed: - if self.mdp == "flexible": - actions.extend(self.actions_mapping_indices) - 
actions.extend(self.actions_layout_indices) - actions.extend(self.actions_opt_indices) - if self.mdp == "paper": actions.extend(self.actions_mapping_indices) actions.extend(self.actions_layout_indices) actions.extend(self.actions_opt_indices) # Not *depicted* in paper; necessary because optimization can destroy the native gate set if not synthesized and laid_out and not routed: - if self.mdp == "flexible": - actions.extend(self.actions_synthesis_indices) - actions.extend(self.actions_routing_indices) - actions.extend(self.actions_opt_indices) - if self.mdp == "paper": actions.extend(self.actions_synthesis_indices) actions.extend(self.actions_routing_indices) actions.extend(self.actions_opt_indices) # Not *depicted* in paper; necessary because of layout-only passes if synthesized and laid_out and not routed: - if self.mdp == "flexible": - actions.extend(self.actions_routing_indices) - actions.extend(self.actions_opt_indices) - if self.mdp == "paper": actions.extend(self.actions_routing_indices) # Not *depicted* in paper; necessary because routing can insert non-native SWAPs if not synthesized and laid_out and routed: - if self.mdp == "flexible": - actions.extend(self.actions_synthesis_indices) - actions.extend(self.actions_opt_indices) - if self.mdp == "paper": actions.extend(self.actions_synthesis_indices) actions.extend(self.actions_opt_indices) # Final state if synthesized and laid_out and routed: - if self.mdp == "flexible": - actions.extend([self.action_terminate_index]) - actions.extend(self.actions_opt_indices) - if self.mdp == "paper": actions.extend([self.action_terminate_index]) actions.extend(self.actions_opt_indices) From ba9042d7555dcdc05a50e60df7435f753036bb7e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 May 2026 15:24:21 +0000 Subject: [PATCH 05/20] =?UTF-8?q?=F0=9F=8E=A8=20pre-commit=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit --- src/mqt/predictor/rl/predictorenv.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/mqt/predictor/rl/predictorenv.py b/src/mqt/predictor/rl/predictorenv.py index f126f8a5c..a7e001737 100644 --- a/src/mqt/predictor/rl/predictorenv.py +++ b/src/mqt/predictor/rl/predictorenv.py @@ -513,32 +513,32 @@ def determine_valid_actions_for_state(self) -> list[int]: # Initial state if not synthesized and not laid_out and not routed: - actions.extend(self.actions_synthesis_indices) - actions.extend(self.actions_opt_indices) + actions.extend(self.actions_synthesis_indices) + actions.extend(self.actions_opt_indices) if synthesized and not laid_out and not routed: - actions.extend(self.actions_mapping_indices) - actions.extend(self.actions_layout_indices) - actions.extend(self.actions_opt_indices) + actions.extend(self.actions_mapping_indices) + actions.extend(self.actions_layout_indices) + actions.extend(self.actions_opt_indices) # Not *depicted* in paper; necessary because optimization can destroy the native gate set if not synthesized and laid_out and not routed: - actions.extend(self.actions_synthesis_indices) - actions.extend(self.actions_routing_indices) - actions.extend(self.actions_opt_indices) + actions.extend(self.actions_synthesis_indices) + actions.extend(self.actions_routing_indices) + actions.extend(self.actions_opt_indices) # Not *depicted* in paper; necessary because of layout-only passes if synthesized and laid_out and not routed: - actions.extend(self.actions_routing_indices) + actions.extend(self.actions_routing_indices) # Not *depicted* in paper; necessary because routing can insert non-native SWAPs if not synthesized and laid_out and routed: - actions.extend(self.actions_synthesis_indices) - actions.extend(self.actions_opt_indices) + actions.extend(self.actions_synthesis_indices) + actions.extend(self.actions_opt_indices) # Final state if synthesized and laid_out and routed: 
- actions.extend([self.action_terminate_index]) - actions.extend(self.actions_opt_indices) + actions.extend([self.action_terminate_index]) + actions.extend(self.actions_opt_indices) return actions From dca4827a21eff19adc1feb3f2e377e907e3de51c Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Mon, 11 May 2026 17:57:54 +0200 Subject: [PATCH 06/20] =?UTF-8?q?=F0=9F=8E=A8=20improve=20comments?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/predictor/rl/predictorenv.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/mqt/predictor/rl/predictorenv.py b/src/mqt/predictor/rl/predictorenv.py index a7e001737..684cb28ff 100644 --- a/src/mqt/predictor/rl/predictorenv.py +++ b/src/mqt/predictor/rl/predictorenv.py @@ -145,9 +145,13 @@ def __init__( self.reward_function = reward_function self.action_space = Discrete(len(self.action_set.keys())) self.num_steps = 0 - self.layout: TranspileLayout | None = None self.num_qubits_uncompiled_circuit = 0 + # Canonical layout state for the current circuit. It is mirrored to + # QuantumCircuit.layout for callers, but kept here because TKET and + # BQSKit conversions do not preserve Qiskit's layout metadata. 
+ self.layout: TranspileLayout | None = None + self.has_parameterized_gates = False self.rng = np.random.default_rng(10) @@ -204,11 +208,10 @@ def step(self, action: int) -> tuple[dict[str, Any], float, bool, bool, dict[Any reward_val = 0 done = False - # in case the Qiskit.QuantumCircuit has unitary or u gates in it, decompose them (because otherwise qiskit will throw an error when applying the BasisTranslator + # in case the Qiskit.QuantumCircuit has unitary or u gates in it, decompose them (because otherwise qiskit will throw an error when applying BasisTranslator if self.state.count_ops().get("unitary"): self.state = self.state.decompose(gates_to_decompose="unitary") - self.state._layout = self.layout # noqa: SLF001 obs = create_feature_dict(self.state) return obs, reward_val, done, False, {} @@ -271,19 +274,21 @@ def action_masks(self) -> list[bool]: """Returns a list of valid actions for the current state.""" action_mask = [action in self.valid_actions for action in self.action_set] - # TKET layout/optimization actions must not run after a Qiskit layout has been set - # (it is not clear how tket will handle the layout). TKET routing actions are - # designed to work after a Qiskit layout via PreProcessTKETRoutingAfterQiskitLayout. - if self.layout is not None: + has_layout = self.layout is not None + + if has_layout: + # TKET layout/optimization actions must not run after a Qiskit layout has been set + # (it is not clear how tket will handle the layout). TKET routing actions are + # designed to work after a Qiskit layout via PreProcessTKETRoutingAfterQiskitLayout. 
action_mask = [ action_mask[i] and (self.action_set[i].origin != CompilationOrigin.TKET or i in self.actions_routing_indices) for i in range(len(action_mask)) ] - if self.has_parameterized_gates or self.layout is not None: + if self.has_parameterized_gates or has_layout: # remove all actions that are from "origin"=="bqskit" because they are not supported for parameterized gates - # or after layout since using BQSKit after a layout is set may result in an error + # or after layout since using BQSKit after a layout is set can result in an error action_mask = [ action_mask[i] and self.action_set[i].origin != CompilationOrigin.BQSKIT for i in range(len(action_mask)) @@ -382,6 +387,8 @@ def _handle_qiskit_layout_postprocessing( assert pm.property_set["layout"] if pm.property_set["layout"]: + # Layout/mapping passes create the base logical-to-physical mapping; + # later routing actions only update final_layout. self.layout = TranspileLayout( initial_layout=pm.property_set["layout"], input_qubit_mapping=pm.property_set["original_qubit_indices"], From 1e523e150977c5e1a89e0aabdfe64cc15ddf296e Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Mon, 11 May 2026 18:13:02 +0200 Subject: [PATCH 07/20] =?UTF-8?q?=E2=9C=85=20fix=20synthesis=20size=20limi?= =?UTF-8?q?t=20for=20bqskit=20passes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/predictor/rl/actions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mqt/predictor/rl/actions.py b/src/mqt/predictor/rl/actions.py index 118f48ef9..aabd06c6b 100644 --- a/src/mqt/predictor/rl/actions.py +++ b/src/mqt/predictor/rl/actions.py @@ -333,7 +333,7 @@ def remove_action(name: str) -> None: circuit, optimization_level=1 if os.getenv("GITHUB_ACTIONS") == "true" else 2, synthesis_epsilon=1e-1 if os.getenv("GITHUB_ACTIONS") == "true" else 1e-8, - max_synthesis_size=2 if os.getenv("GITHUB_ACTIONS") == "true" else 3, + max_synthesis_size=3, seed=10, 
num_workers=1 if os.getenv("GITHUB_ACTIONS") == "true" else -1, ), @@ -432,7 +432,7 @@ def remove_action(name: str) -> None: with_mapping=True, optimization_level=1 if os.getenv("GITHUB_ACTIONS") == "true" else 2, synthesis_epsilon=1e-1 if os.getenv("GITHUB_ACTIONS") == "true" else 1e-8, - max_synthesis_size=2 if os.getenv("GITHUB_ACTIONS") == "true" and sys.platform != "linux" else 3, + max_synthesis_size=3, seed=10, num_workers=1 if os.getenv("GITHUB_ACTIONS") == "true" else -1, ) @@ -462,7 +462,7 @@ def remove_action(name: str) -> None: model=MachineModel(bqskit_circuit.num_qudits, gate_set=get_bqskit_native_gates(device)), optimization_level=1 if os.getenv("GITHUB_ACTIONS") == "true" else 2, synthesis_epsilon=1e-1 if os.getenv("GITHUB_ACTIONS") == "true" else 1e-8, - max_synthesis_size=2 if os.getenv("GITHUB_ACTIONS") == "true" and sys.platform != "linux" else 3, + max_synthesis_size=3, seed=10, num_workers=1 if os.getenv("GITHUB_ACTIONS") == "true" else -1, ) From 6d6487a915c917d70217bc2dfceb49e922a46087 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 May 2026 16:13:21 +0000 Subject: [PATCH 08/20] =?UTF-8?q?=F0=9F=8E=A8=20pre-commit=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/predictor/rl/actions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mqt/predictor/rl/actions.py b/src/mqt/predictor/rl/actions.py index aabd06c6b..598dfe0e6 100644 --- a/src/mqt/predictor/rl/actions.py +++ b/src/mqt/predictor/rl/actions.py @@ -11,7 +11,6 @@ from __future__ import annotations import os -import sys from collections import defaultdict from dataclasses import dataclass from enum import Enum From 7a300a206f23ef5af369ec13591c25fa4ee06206 Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Tue, 12 May 2026 13:27:06 +0200 Subject: [PATCH 09/20] =?UTF-8?q?=F0=9F=8E=A8=20reduce=20test=20training?= =?UTF-8?q?=20overhead?= MIME-Version: 
1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/predictor/rl/predictor.py | 5 +++-- tests/compilation/test_predictor_rl.py | 4 +--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/mqt/predictor/rl/predictor.py b/src/mqt/predictor/rl/predictor.py index 0f3c8c540..2baa3da49 100644 --- a/src/mqt/predictor/rl/predictor.py +++ b/src/mqt/predictor/rl/predictor.py @@ -99,9 +99,10 @@ def train_model( """ set_random_seed(0) # for reproducibility if test: - n_steps = 512 + # minimum training overhead + n_steps = max(timesteps, 2) n_epochs = 1 - batch_size = 16 + batch_size = n_steps progress_bar = False else: # default PPO values diff --git a/tests/compilation/test_predictor_rl.py b/tests/compilation/test_predictor_rl.py index 6a323c528..9ead59f3b 100644 --- a/tests/compilation/test_predictor_rl.py +++ b/tests/compilation/test_predictor_rl.py @@ -84,9 +84,7 @@ def test_qcompile_with_newly_trained_models() -> None: ): rl_compile(qc, device=device, figure_of_merit=figure_of_merit) - predictor.train_model( - test=True, - ) + predictor.train_model(test=True) qc_compiled, compilation_information = rl_compile(qc, device=device, figure_of_merit=figure_of_merit) From d64a97f97272ac8f79c39d00582d78e1bb32abc9 Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Tue, 12 May 2026 13:41:18 +0200 Subject: [PATCH 10/20] =?UTF-8?q?=F0=9F=8E=A8=20add=20comments?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/predictor/rl/predictorenv.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/mqt/predictor/rl/predictorenv.py b/src/mqt/predictor/rl/predictorenv.py index 684cb28ff..aa57e1af0 100644 --- a/src/mqt/predictor/rl/predictorenv.py +++ b/src/mqt/predictor/rl/predictorenv.py @@ -92,7 +92,7 @@ def __init__( self.actions_routing_indices = [] self.actions_mapping_indices = [] self.actions_opt_indices = [] - self.actions_final_optimization_indices = [] + 
self.actions_final_optimization_indices = [] # TODO: currently not used; will be improved by addressing issue https://github.com/munich-quantum-toolkit/predictor/issues/666 self.used_actions: list[str] = [] self.device = device @@ -193,6 +193,11 @@ def step(self, action: int) -> tuple[dict[str, Any], float, bool, bool, dict[Any self.state: QuantumCircuit = altered_qc self.num_steps += 1 + # in case a Qiskit.QuantumCircuit has `unitary` or `u`` gates in it, decompose them (otherwise qiskit will throw an error when applying BasisTranslator) + # TODO: will be improved by addressing issue https://github.com/munich-quantum-toolkit/predictor/issues/668 + if self.state.count_ops().get("unitary"): + self.state = self.state.decompose(gates_to_decompose="unitary") + self.state._layout = self.layout # noqa: SLF001 self.valid_actions = self.determine_valid_actions_for_state() @@ -208,10 +213,6 @@ def step(self, action: int) -> tuple[dict[str, Any], float, bool, bool, dict[Any reward_val = 0 done = False - # in case the Qiskit.QuantumCircuit has unitary or u gates in it, decompose them (because otherwise qiskit will throw an error when applying BasisTranslator - if self.state.count_ops().get("unitary"): - self.state = self.state.decompose(gates_to_decompose="unitary") - obs = create_feature_dict(self.state) return obs, reward_val, done, False, {} From 9f2697e6b520ce5edf6e4ce59cd47d803ac9df90 Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Tue, 12 May 2026 13:43:41 +0200 Subject: [PATCH 11/20] =?UTF-8?q?=F0=9F=8E=A8=20reduce=20number=20of=20tra?= =?UTF-8?q?ining=20steps?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/compilation/test_predictor_rl.py | 2 +- .../hellinger_distance/test_estimated_hellinger_distance.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/compilation/test_predictor_rl.py b/tests/compilation/test_predictor_rl.py index 9ead59f3b..9fb68914b 100644 --- 
a/tests/compilation/test_predictor_rl.py +++ b/tests/compilation/test_predictor_rl.py @@ -84,7 +84,7 @@ def test_qcompile_with_newly_trained_models() -> None: ): rl_compile(qc, device=device, figure_of_merit=figure_of_merit) - predictor.train_model(test=True) + predictor.train_model(timesteps=512, test=True) qc_compiled, compilation_information = rl_compile(qc, device=device, figure_of_merit=figure_of_merit) diff --git a/tests/hellinger_distance/test_estimated_hellinger_distance.py b/tests/hellinger_distance/test_estimated_hellinger_distance.py index 1d08c6189..30635cc4b 100644 --- a/tests/hellinger_distance/test_estimated_hellinger_distance.py +++ b/tests/hellinger_distance/test_estimated_hellinger_distance.py @@ -196,10 +196,7 @@ def test_train_and_qcompile_with_hellinger_model(source_path: Path, target_path: # 1. Train the reinforcement learning model for circuit compilation rl_predictor = rl_Predictor(device=device, figure_of_merit=figure_of_merit) - rl_predictor.train_model( - timesteps=5, - test=True, - ) + rl_predictor.train_model(timesteps=5, test=True) # 2. 
Setup and train the machine learning model for device selection ml_predictor = ml_Predictor(devices=[device], figure_of_merit=figure_of_merit) From 010fa687120a882b6c0e839bb067dbf171ec4302 Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Tue, 12 May 2026 16:30:32 +0200 Subject: [PATCH 12/20] =?UTF-8?q?=F0=9F=8E=A8=20add=20changelog=20entry?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b636e604..5b840d39d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel ### Changed +- 🎨 Improve the RL state machine logic ([#677]) ([**@flowerthrower**]) - 🔧 Replace `mypy` with `ty` ([#572]) ([**@denialhaag**]) - 🐛 Fix instruction duration unit in estimated success probability calculation ([#445]) ([**@Shaobo-Zhou**]) - ✨ Remove support for custom names of trained models ([#489]) ([**@bachase**]) @@ -46,6 +47,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool +[#677]: https://github.com/munich-quantum-toolkit/predictor/pull/677 [#572]: https://github.com/munich-quantum-toolkit/predictor/pull/572 [#489]: https://github.com/munich-quantum-toolkit/predictor/pull/489 [#445]: https://github.com/munich-quantum-toolkit/predictor/pull/445 From a474a8fe4b5f9f3b5a9ecff1b220fa13ddc01ad4 Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Fri, 29 May 2026 16:24:03 +0200 Subject: [PATCH 13/20] =?UTF-8?q?=F0=9F=8E=A8=20improve=20error=20reportin?= =?UTF-8?q?g?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/predictor/reward.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/mqt/predictor/reward.py b/src/mqt/predictor/reward.py index a8cd76952..e8c1b4742 100644 --- a/src/mqt/predictor/reward.py +++ b/src/mqt/predictor/reward.py @@ -191,7 +191,11
@@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision: res *= 1 - device[gate_type][first_qubit_idx,].error else: second_qubit_idx = scheduled_circ.find_bit(qargs[1]).index - res *= 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error + try: + res *= 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error + except KeyError: + msg = f"Error rate for gate {gate_type} on qubits {first_qubit_idx} and {second_qubit_idx} not found in device properties." + raise KeyError(msg) from None if qiskit_version >= "2.0.0": for i in range(device.num_qubits): From 51b20af4a8ef811adae609cb4e17439122bd52c4 Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Fri, 29 May 2026 16:51:55 +0200 Subject: [PATCH 14/20] =?UTF-8?q?=E2=9C=85=20improve=20coverage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/compilation/test_predictor_rl.py | 71 +++++++++++++++++++++++++- tests/compilation/test_reward.py | 40 +++++++++++++-- 2 files changed, 105 insertions(+), 6 deletions(-) diff --git a/tests/compilation/test_predictor_rl.py b/tests/compilation/test_predictor_rl.py index 9fb68914b..f141bebbf 100644 --- a/tests/compilation/test_predictor_rl.py +++ b/tests/compilation/test_predictor_rl.py @@ -16,12 +16,14 @@ import pytest from mqt.bench import BenchmarkLevel, get_benchmark from mqt.bench.targets import get_device +from qiskit import QuantumCircuit from qiskit.circuit.library import CXGate from qiskit.qasm2 import dump -from qiskit.transpiler import InstructionProperties, Target +from qiskit.transpiler import InstructionProperties, Layout, Target, TranspileLayout from qiskit.transpiler.passes import GatesInBasis from mqt.predictor.rl import Predictor, rl_compile +from mqt.predictor.rl import predictorenv as predictorenv_module from mqt.predictor.rl.actions import ( CompilationOrigin, DeviceIndependentAction, @@ -31,6 +33,7 @@ remove_action, ) from mqt.predictor.rl.helper import create_feature_dict, 
get_path_trained_model +from mqt.predictor.rl.predictorenv import PredictorEnv def test_predictor_env_reset_from_string() -> None: @@ -117,6 +120,72 @@ def test_warning_for_unidirectional_device() -> None: Predictor(figure_of_merit="expected_fidelity", device=target) +def test_predictor_env_actions_after_layout_with_non_native_unrouted_circuit() -> None: + """Test valid actions for a laid-out circuit that still needs synthesis and routing.""" + device = get_device("ibm_falcon_27") + env = PredictorEnv(device=device) + qc = QuantumCircuit(3) + qc.h(0) + qc.cx(0, 2) + env.reset(qc) + + env.layout = TranspileLayout( + initial_layout=Layout({qubit: index for index, qubit in enumerate(qc.qubits)}), + input_qubit_mapping={qubit: index for index, qubit in enumerate(qc.qubits)}, + final_layout=None, + _output_qubit_list=qc.qubits, + _input_qubit_count=qc.num_qubits, + ) + + valid_actions = env.determine_valid_actions_for_state() + + assert set(env.actions_synthesis_indices).issubset(valid_actions) + assert set(env.actions_routing_indices).issubset(valid_actions) + assert set(env.actions_opt_indices).issubset(valid_actions) + assert env.action_terminate_index not in valid_actions + + +def test_predictor_env_qiskit_routing_updates_final_layout(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that Qiskit routing actions update the tracked final layout.""" + device = get_device("ibm_falcon_27") + env = PredictorEnv(device=device) + qc = QuantumCircuit(2) + qc.cx(0, 1) + env.reset(qc) + + initial_layout = Layout({qubit: index for index, qubit in enumerate(qc.qubits)}) + final_layout = Layout({qc.qubits[0]: 1, qc.qubits[1]: 0}) + env.layout = TranspileLayout( + initial_layout=initial_layout, + input_qubit_mapping={qubit: index for index, qubit in enumerate(qc.qubits)}, + final_layout=None, + _output_qubit_list=qc.qubits, + _input_qubit_count=qc.num_qubits, + ) + + class FakePassManager: + """Minimal PassManager replacement that exposes a final layout.""" + + def 
__init__(self, _passes: object) -> None: + self.property_set = {"final_layout": final_layout} + + def run(self, circuit: QuantumCircuit) -> QuantumCircuit: + return circuit + + monkeypatch.setattr(predictorenv_module, "PassManager", FakePassManager) + action = DeviceIndependentAction( + name="SyntheticQiskitRouting", + pass_type=PassType.ROUTING, + transpile_pass=[], + origin=CompilationOrigin.QISKIT, + ) + + altered_qc = env._apply_qiskit_action(action, env.actions_routing_indices[0]) # noqa: SLF001 + + assert altered_qc is env.state + assert env.layout.final_layout is final_layout + + def test_register_action() -> None: """Test the register_action function.""" action = DeviceIndependentAction( diff --git a/tests/compilation/test_reward.py b/tests/compilation/test_reward.py index 5e36fcc8a..e1bd57e31 100644 --- a/tests/compilation/test_reward.py +++ b/tests/compilation/test_reward.py @@ -10,15 +10,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import re import pytest from mqt.bench import BenchmarkLevel, get_benchmark from mqt.bench.targets import get_device -from qiskit import transpile +from qiskit import QuantumCircuit, transpile from qiskit.circuit.library import CXGate, Measure, XGate from qiskit.transpiler import InstructionProperties, Target +from mqt.predictor import reward as reward_module from mqt.predictor.reward import crit_depth, esp_data_available, estimated_success_probability, expected_fidelity try: @@ -30,9 +31,6 @@ QISKIT_PRE_2_0 = True -if TYPE_CHECKING: - from qiskit import QuantumCircuit - @pytest.fixture def device() -> Target: @@ -132,3 +130,35 @@ def test_esp_data_available_invalid_target(kwargs: dict[str, float | bool]) -> N """Test that `esp_data_available` returns False for invalid device configurations.""" target = make_target(**kwargs) # ty: ignore[invalid-argument-type] assert not esp_data_available(target) + + +@pytest.mark.parametrize("reward_function", ["expected_fidelity", 
"estimated_success_probability"]) +def test_reward_missing_two_qubit_error(reward_function: str, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that reward functions report missing two-qubit error rates descriptively.""" + target = make_target() + del target["cx"][0, 1] + + qc = QuantumCircuit(2) + if reward_function == "estimated_success_probability": + qc.x(0) + + scheduled_qc = QuantumCircuit(2) + scheduled_qc.cx(0, 1) + + def fake_transpile(*_args: object, **_kwargs: object) -> QuantumCircuit: + return scheduled_qc + + def estimate_duration(*, target: Target) -> float: + assert target.num_qubits == 2 + return 0.0 + + monkeypatch.setattr(reward_module, "qiskit_version", "2.0.0") + monkeypatch.setattr(reward_module, "transpile", fake_transpile) + monkeypatch.setattr(scheduled_qc, "estimate_duration", estimate_duration) + else: + qc.cx(0, 1) + + reward = estimated_success_probability if reward_function == "estimated_success_probability" else expected_fidelity + expected_message = "Error rate for gate cx on qubits 0 and 1 not found in device properties." 
+ with pytest.raises(KeyError, match=re.escape(expected_message)): + reward(qc, target) From 7e9a3699af3face518b2984163d38ac0bd37bd20 Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Mon, 1 Jun 2026 13:24:44 +0200 Subject: [PATCH 15/20] =?UTF-8?q?=E2=9C=85=20fix=20test=20for=20qiskit<2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/compilation/test_reward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compilation/test_reward.py b/tests/compilation/test_reward.py index e1bd57e31..5b3512f02 100644 --- a/tests/compilation/test_reward.py +++ b/tests/compilation/test_reward.py @@ -154,7 +154,7 @@ def estimate_duration(*, target: Target) -> float: monkeypatch.setattr(reward_module, "qiskit_version", "2.0.0") monkeypatch.setattr(reward_module, "transpile", fake_transpile) - monkeypatch.setattr(scheduled_qc, "estimate_duration", estimate_duration) + monkeypatch.setattr(scheduled_qc, "estimate_duration", estimate_duration, raising=False) # not available in qiskit<2.0 else: qc.cx(0, 1) From 069474973d4e960067c73ad3090d2da33a9d2acd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Jun 2026 11:25:04 +0000 Subject: [PATCH 16/20] =?UTF-8?q?=F0=9F=8E=A8=20pre-commit=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/compilation/test_reward.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/compilation/test_reward.py b/tests/compilation/test_reward.py index 5b3512f02..e0843dbc9 100644 --- a/tests/compilation/test_reward.py +++ b/tests/compilation/test_reward.py @@ -154,7 +154,9 @@ def estimate_duration(*, target: Target) -> float: monkeypatch.setattr(reward_module, "qiskit_version", "2.0.0") monkeypatch.setattr(reward_module, "transpile", fake_transpile) - monkeypatch.setattr(scheduled_qc, "estimate_duration", estimate_duration, raising=False) 
# not available in qiskit<2.0 + monkeypatch.setattr( + scheduled_qc, "estimate_duration", estimate_duration, raising=False + ) # not available in qiskit<2.0 else: qc.cx(0, 1) From 71d9be38218530f1f9ad078c11f0b7699a771cc7 Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Wed, 3 Jun 2026 10:31:44 +0200 Subject: [PATCH 17/20] =?UTF-8?q?=F0=9F=8E=A8=20enable=20random=20training?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/predictor/rl/predictor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/mqt/predictor/rl/predictor.py b/src/mqt/predictor/rl/predictor.py index 2baa3da49..c55efc23a 100644 --- a/src/mqt/predictor/rl/predictor.py +++ b/src/mqt/predictor/rl/predictor.py @@ -89,6 +89,7 @@ def train_model( timesteps: int = 1000, verbose: int = 2, test: bool = False, + seed: int | None = 0, ) -> None: """Trains all models for the given reward functions and device. @@ -96,8 +97,10 @@ def train_model( timesteps: The number of timesteps to train the model. Defaults to 1000. verbose: The verbosity level. Defaults to 2. test: Whether to train the model for testing purposes. Defaults to False. + seed: The random seed to use for reproducible training. Set to None to use true randomness. Defaults to 0. """ - set_random_seed(0) # for reproducibility + if seed is not None: + set_random_seed(seed) if test: # minimum training overhead n_steps = max(timesteps, 2) @@ -121,6 +124,7 @@ def train_model( n_steps=n_steps, batch_size=batch_size, n_epochs=n_epochs, + seed=seed, ) # Training Loop: In each iteration, the agent collects n_steps steps (rollout), # updates the policy for n_epochs, and then repeats the process until total_timesteps steps have been taken. 
From 93531910ebec419a25e96a577e519c69836c6075 Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Wed, 3 Jun 2026 10:42:14 +0200 Subject: [PATCH 18/20] =?UTF-8?q?=F0=9F=8E=A8=20improve=20comments?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/predictor/rl/predictorenv.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/mqt/predictor/rl/predictorenv.py b/src/mqt/predictor/rl/predictorenv.py index 04ec8077e..95d632463 100644 --- a/src/mqt/predictor/rl/predictorenv.py +++ b/src/mqt/predictor/rl/predictorenv.py @@ -194,7 +194,7 @@ def step(self, action: int) -> tuple[dict[str, Any], float, bool, bool, dict[Any self.state: QuantumCircuit = altered_qc self.num_steps += 1 - # in case a Qiskit.QuantumCircuit has `unitary` or `u`` gates in it, decompose them (otherwise qiskit will throw an error when applying BasisTranslator) + # in case a Qiskit.QuantumCircuit has `unitary` gates in it, decompose them (otherwise qiskit will throw an error when applying BasisTranslator) # TODO: will be improved by addressing issue https://github.com/munich-quantum-toolkit/predictor/issues/668 if self.state.count_ops().get("unitary"): self.state = self.state.decompose(gates_to_decompose="unitary") @@ -530,17 +530,20 @@ def determine_valid_actions_for_state(self) -> list[int]: actions.extend(self.actions_layout_indices) actions.extend(self.actions_opt_indices) - # Not *depicted* in paper; necessary because optimization can destroy the native gate set + # Possible state because optimization can destroy the native gate set + # State is not explicitly *depicted* in original paper if not synthesized and laid_out and not routed: actions.extend(self.actions_synthesis_indices) actions.extend(self.actions_routing_indices) actions.extend(self.actions_opt_indices) - # Not *depicted* in paper; necessary because of layout-only passes + # Possible state because there are layout-only passes + # State is not explicitly
*depicted* in original paper if synthesized and laid_out and not routed: actions.extend(self.actions_routing_indices) - # Not *depicted* in paper; necessary because routing can insert non-native SWAPs + # Possible state because routing may add SWAP gates which are not necessarily native gates + # State is not explicitly *depicted* in original paper if not synthesized and laid_out and routed: actions.extend(self.actions_synthesis_indices) actions.extend(self.actions_opt_indices) From 3c1b0765d50082388d14c1d9ba6444c42dcc180b Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Wed, 3 Jun 2026 11:05:30 +0200 Subject: [PATCH 19/20] =?UTF-8?q?=F0=9F=8E=A8=20make=20random=20seed=20def?= =?UTF-8?q?ault=20for=20training?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mqt/predictor/rl/predictor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/mqt/predictor/rl/predictor.py b/src/mqt/predictor/rl/predictor.py index c55efc23a..c1578125a 100644 --- a/src/mqt/predictor/rl/predictor.py +++ b/src/mqt/predictor/rl/predictor.py @@ -89,7 +89,7 @@ def train_model( timesteps: int = 1000, verbose: int = 2, test: bool = False, - seed: int | None = 0, + seed: int | None = None, ) -> None: """Trains all models for the given reward functions and device. @@ -97,7 +97,8 @@ def train_model( timesteps: The number of timesteps to train the model. Defaults to 1000. verbose: The verbosity level. Defaults to 2. test: Whether to train the model for testing purposes. Defaults to False. - seed: The random seed to use for reproducible training. Set to None to use true randomness. Defaults to 0. + seed: The random seed to use for reproducible training. Set to None to use true randomness. + Defaults to None. 
""" if seed is not None: set_random_seed(seed) From 95af4a35adc8cbb18f950818381d6a4289232e0d Mon Sep 17 00:00:00 2001 From: flowerthrower Date: Wed, 3 Jun 2026 12:23:18 +0200 Subject: [PATCH 20/20] =?UTF-8?q?=F0=9F=8E=A8=20remove=20redundant=20impor?= =?UTF-8?q?ts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/compilation/test_predictor_rl.py | 7 +++---- .../test_estimated_hellinger_distance.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/compilation/test_predictor_rl.py b/tests/compilation/test_predictor_rl.py index f141bebbf..5cebe65d8 100644 --- a/tests/compilation/test_predictor_rl.py +++ b/tests/compilation/test_predictor_rl.py @@ -33,7 +33,6 @@ remove_action, ) from mqt.predictor.rl.helper import create_feature_dict, get_path_trained_model -from mqt.predictor.rl.predictorenv import PredictorEnv def test_predictor_env_reset_from_string() -> None: @@ -87,7 +86,7 @@ def test_qcompile_with_newly_trained_models() -> None: ): rl_compile(qc, device=device, figure_of_merit=figure_of_merit) - predictor.train_model(timesteps=512, test=True) + predictor.train_model(timesteps=512, test=True, seed=0) qc_compiled, compilation_information = rl_compile(qc, device=device, figure_of_merit=figure_of_merit) @@ -123,7 +122,7 @@ def test_warning_for_unidirectional_device() -> None: def test_predictor_env_actions_after_layout_with_non_native_unrouted_circuit() -> None: """Test valid actions for a laid-out circuit that still needs synthesis and routing.""" device = get_device("ibm_falcon_27") - env = PredictorEnv(device=device) + env = predictorenv_module.PredictorEnv(device=device) qc = QuantumCircuit(3) qc.h(0) qc.cx(0, 2) @@ -148,7 +147,7 @@ def test_predictor_env_actions_after_layout_with_non_native_unrouted_circuit() - def test_predictor_env_qiskit_routing_updates_final_layout(monkeypatch: pytest.MonkeyPatch) -> None: """Test that Qiskit routing actions update the tracked final layout.""" 
device = get_device("ibm_falcon_27") - env = PredictorEnv(device=device) + env = predictorenv_module.PredictorEnv(device=device) qc = QuantumCircuit(2) qc.cx(0, 1) env.reset(qc) diff --git a/tests/hellinger_distance/test_estimated_hellinger_distance.py b/tests/hellinger_distance/test_estimated_hellinger_distance.py index 30635cc4b..3d274c13b 100644 --- a/tests/hellinger_distance/test_estimated_hellinger_distance.py +++ b/tests/hellinger_distance/test_estimated_hellinger_distance.py @@ -196,7 +196,7 @@ def test_train_and_qcompile_with_hellinger_model(source_path: Path, target_path: # 1. Train the reinforcement learning model for circuit compilation rl_predictor = rl_Predictor(device=device, figure_of_merit=figure_of_merit) - rl_predictor.train_model(timesteps=5, test=True) + rl_predictor.train_model(timesteps=5, test=True, seed=0) # 2. Setup and train the machine learning model for device selection ml_predictor = ml_Predictor(devices=[device], figure_of_merit=figure_of_merit)