From ec48899c59ee2fcf79943a9518fd29ac11eb0e27 Mon Sep 17 00:00:00 2001 From: builder Date: Thu, 14 May 2026 10:40:53 -0700 Subject: [PATCH] fix: Handle Projection and Reset in DensityMatrixSimulation --- .../density_matrix_simulation.py | 59 +++++++++- test/integ_tests/circuits_test.py | 30 +++-- .../braket/default_simulator/test_mcm.py | 111 ++++++++++++++++++ 3 files changed, 188 insertions(+), 12 deletions(-) diff --git a/src/braket/default_simulator/density_matrix_simulation.py b/src/braket/default_simulator/density_matrix_simulation.py index 58210625..84701a87 100644 --- a/src/braket/default_simulator/density_matrix_simulation.py +++ b/src/braket/default_simulator/density_matrix_simulation.py @@ -13,6 +13,7 @@ import numpy as np +from braket.default_simulator.gate_operations import Projection, Reset from braket.default_simulator.linalg_utils import ( QuantumGateDispatcher, controlled_matrix, @@ -154,7 +155,15 @@ def _apply_operations( work_buffer2 = np.zeros_like(result, dtype=complex) for operation in operations: - if isinstance(operation, (GateOperation, Observable)): + if isinstance(operation, Projection): + result, temp = DensityMatrixSimulation._apply_projection( + result, temp, qubit_count, operation, dispatcher + ) + elif isinstance(operation, Reset): + result, temp = DensityMatrixSimulation._apply_reset( + result, temp, work_buffer1, work_buffer2, qubit_count, operation, dispatcher + ) + elif isinstance(operation, (GateOperation, Observable)): targets = operation.targets num_ctrl = len(operation.control_state) # Extract gate_type if available @@ -183,6 +192,54 @@ def _apply_operations( result.shape = original_shape return result + @staticmethod + def _apply_projection( + result: np.ndarray, + temp: np.ndarray, + qubit_count: int, + operation: Projection, + dispatcher: QuantumGateDispatcher, + ) -> tuple[np.ndarray, np.ndarray]: + result, temp = DensityMatrixSimulation._apply_gate( + result, + temp, + qubit_count, + operation._base_matrix, + operation.targets, + (), + (), + dispatcher, + None, + ) + indices = list(range(qubit_count)) + norm = float(np.real(np.einsum(result, indices + indices))) + result /= norm + return result, temp + + @staticmethod + def _apply_reset( + result: np.ndarray, + temp: np.ndarray, + work_buffer1: np.ndarray, + work_buffer2: np.ndarray, + qubit_count: int, + operation: Reset, + dispatcher: QuantumGateDispatcher, + ) -> tuple[np.ndarray, np.ndarray]: + return DensityMatrixSimulation._apply_kraus( + result, + temp, + work_buffer1, + work_buffer2, + qubit_count, + [ + np.array([[1, 0], [0, 0]], dtype=complex), + np.array([[0, 1], [0, 0]], dtype=complex), + ], + operation.targets, + dispatcher, + ) + @staticmethod def _apply_gate( result: np.ndarray, diff --git a/test/integ_tests/circuits_test.py b/test/integ_tests/circuits_test.py index e7fa7302..c2657445 100644 --- a/test/integ_tests/circuits_test.py +++ b/test/integ_tests/circuits_test.py @@ -1043,12 +1043,20 @@ def test_cswap_self_inverse(self, ctrl, t0, t1, n_qubits): class TestClassicalControlGates: """End-to-end tests for the IQM experimental ``measure_ff`` / ``cc_prx`` gates against ``LocalSimulator``. Follows the patterns shown in the Braket - Dynamic Circuits notebook.""" + Dynamic Circuits notebook. - def _run(self, circuit, shots): - return LocalSimulator().run(circuit, shots=shots).result().measurement_counts + Run on both the state-vector (``default``) and density-matrix (``braket_dm``) + backends so the branched-replay path is exercised on both. + """ - def test_measure_ff_cc_prx_feedforward(self): + @pytest.fixture(params=["default", "braket_dm"]) + def simulator_name(self, request): + return request.param + + def _run(self, simulator_name, circuit, shots): + return LocalSimulator(simulator_name).run(circuit, shots=shots).result().measurement_counts + + def test_measure_ff_cc_prx_feedforward(self, simulator_name): """Classical feedforward: after measuring q[0], conditionally flip q[1]. Equivalent to a CNOT on the 50/50 state created by ``h q[0]``. """ @@ -1057,12 +1065,12 @@ def test_measure_ff_cc_prx_feedforward(self): circuit.h(0) circuit.measure_ff(0, 0) circuit.cc_prx(1, pi, 0.0, 0) - counts = self._run(circuit, shots=SHOTS) + counts = self._run(simulator_name, circuit, shots=SHOTS) assert set(counts.keys()) == {"00", "11"} assert abs(counts["00"] / SHOTS - 0.5) < ATOL assert abs(counts["11"] / SHOTS - 0.5) < ATOL - def test_active_qubit_reset(self): + def test_active_qubit_reset(self, simulator_name): """Active qubit reset from the notebook: prepare |1>, measure, then conditionally rotate back to |0>. The qubit should always end in |0>.""" with EnableExperimentalCapability(): @@ -1070,10 +1078,10 @@ def test_active_qubit_reset(self): circuit.x(0) circuit.measure_ff(0, 0) circuit.cc_prx(0, pi, 0.0, 0) - counts = self._run(circuit, shots=SHOTS) + counts = self._run(simulator_name, circuit, shots=SHOTS) assert counts == {"0": SHOTS} - def test_active_qubit_reset_on_superposition(self): + def test_active_qubit_reset_on_superposition(self, simulator_name): """Reset of a superposition state: ``h`` puts q[0] in ``|+>``; the measure+feedforward pair collapses it to a known state and rotates back, so all shots should read ``|0>``.""" @@ -1082,10 +1090,10 @@ def test_active_qubit_reset_on_superposition(self): circuit.h(0) circuit.measure_ff(0, 0) circuit.cc_prx(0, pi, 0.0, 0) - counts = self._run(circuit, shots=SHOTS) + counts = self._run(simulator_name, circuit, shots=SHOTS) assert counts == {"0": SHOTS} - def test_independent_feedback_keys(self): + def test_independent_feedback_keys(self, simulator_name): """Two independent feedback keys drive two independent conditionals, yielding all four combinations of the measured qubits.""" with EnableExperimentalCapability(): @@ -1096,7 +1104,7 @@ def test_independent_feedback_keys(self): circuit.measure_ff(1, 1) circuit.cc_prx(2, pi, 0.0, 0) circuit.cc_prx(3, pi, 0.0, 1) - counts = self._run(circuit, shots=SHOTS) + counts = self._run(simulator_name, circuit, shots=SHOTS) # q[2] mirrors q[0], q[3] mirrors q[1]. assert set(counts.keys()) == {"0000", "0101", "1010", "1111"} for key in counts: diff --git a/test/unit_tests/braket/default_simulator/test_mcm.py b/test/unit_tests/braket/default_simulator/test_mcm.py index 4435cac7..cd6c9c7c 100644 --- a/test/unit_tests/braket/default_simulator/test_mcm.py +++ b/test/unit_tests/braket/default_simulator/test_mcm.py @@ -36,6 +36,7 @@ UnaryExpression, ) from braket.default_simulator.openqasm.program_context import AbstractProgramContext +from braket.default_simulator.density_matrix_simulator import DensityMatrixSimulator from braket.default_simulator.state_vector_simulator import StateVectorSimulator from braket.ir.openqasm import Program as OpenQASMProgram @@ -5030,3 +5031,113 @@ def test_flat_context_preserves_mcm_while_loop(self): "}" ) assert Interpreter(context=FlatProgramContext()).run(qasm).circuit == qasm + + + +class TestDensityMatrixSimulatorBranching: + """Branching MCM coverage for ``DensityMatrixSimulator``. + + The state-vector tests in :class:`TestStateVectorSimulatorOperatorsOpenQASM` + drive the bulk of branching behavior; these tests verify that the same + programs produce statistically equivalent results when the simulator is + swapped out for the density-matrix backend, which goes through + :meth:`DensityMatrixSimulation._apply_projection` (and, for resets, the + Kraus channel path) on each branched replay. + """ + + SHOTS = 4000 + ATOL = 0.06 # generous tolerance for shot-noise across both simulators + + def _counts(self, simulator, qasm): + result = simulator.run(OpenQASMProgram(source=qasm), shots=self.SHOTS) + return Counter("".join(m) for m in result.measurements) + + def _assert_distributions_match(self, qasm, expected_keys=None): + """Run ``qasm`` on both simulators and assert the histograms agree + within shot noise, optionally checking the support set explicitly.""" + sv_counts = self._counts(StateVectorSimulator(), qasm) + dm_counts = self._counts(DensityMatrixSimulator(), qasm) + if expected_keys is not None: + assert set(dm_counts) <= set(expected_keys) + assert set(sv_counts) <= set(expected_keys) + keys = set(sv_counts) | set(dm_counts) + for key in keys: + sv_freq = sv_counts.get(key, 0) / self.SHOTS + dm_freq = dm_counts.get(key, 0) / self.SHOTS + assert abs(sv_freq - dm_freq) < self.ATOL, ( + f"DM/SV disagree on outcome {key!r}: sv={sv_freq:.3f}, dm={dm_freq:.3f}" + ) + + def test_repeated_mcm_with_classical_feedforward(self): + """The original failure case: two MCMs on the same qubit with a + classical-feedforward conditional in between.""" + qasm = """ + OPENQASM 3.0; + bit[2] b; + qubit[2] q; + h q[0]; + b[0] = measure q[0]; + if (b[0] == 0) { + x q[0]; + } + b[1] = measure q[0]; + if (b[0] == b[1]) { + x q[1]; + } + """ + # After the conditional, q[0] is always |1>. b[1] is always 1, so q[1] + # only flips when b[0]==1: "10" and "11" each ~50%. + self._assert_distributions_match(qasm, expected_keys={"10", "11"}) + + def test_bell_pair_mcm_decoupling(self): + """Bell-state MCM: measuring one half of an entangled pair must + propagate the projection through the entanglement so the conditional + flip leaves the partner deterministic.""" + qasm = """ + OPENQASM 3.0; + bit[2] b; + qubit[2] q; + h q[0]; + cnot q[0], q[1]; + b[0] = measure q[0]; + if (b[0] == 1) { + x q[1]; + } + """ + # q[1] is always |0> (Bell-correlated, then flipped iff b[0]==1): + # outcomes "00" and "10" each ~50%. + self._assert_distributions_match(qasm, expected_keys={"00", "10"}) + + def test_three_path_branch_with_nested_conditionals(self): + """Reuses the 3.2 conditional-logic shape from the SV-sim suite to + exercise multiple sequential branches under the DM backend.""" + qasm = """ + OPENQASM 3.0; + bit[2] b; + qubit[3] q; + h q[0]; + h q[1]; + b[0] = measure q[0]; + if (b[0] == 0) { + h q[1]; + } + b[1] = measure q[1]; + """ + self._assert_distributions_match(qasm) + + def test_branched_reset(self): + """A ``reset`` after a branched measurement exercises the Kraus channel + path (``_apply_reset``) on each replayed path.""" + qasm = """ + OPENQASM 3.0; + bit[1] b; + qubit[2] q; + h q[0]; + b[0] = measure q[0]; + if (b[0] == 1) { + x q[1]; + } + reset q[0]; + """ + # q[0] is always |0> after the reset; q[1] flips iff b[0]==1 → 50/50. + self._assert_distributions_match(qasm, expected_keys={"00", "01"})