diff --git a/src/braket/default_simulator/openqasm/circuit.py b/src/braket/default_simulator/openqasm/circuit.py index 91e6b9eb..60d93c8c 100644 --- a/src/braket/default_simulator/openqasm/circuit.py +++ b/src/braket/default_simulator/openqasm/circuit.py @@ -69,15 +69,20 @@ def add_measure( allow_remeasure: bool = False, ): for index, qubit in enumerate(target): - if not allow_remeasure and qubit in self.measured_qubits: - raise ValueError(f"Qubit {qubit} is already measured or captured.") - self.measured_qubits.append(qubit) - self.qubit_set.add(qubit) - self.target_classical_indices.append( + classical_index = ( classical_targets[index] if classical_targets else max(index, len(self.target_classical_indices)) ) + if allow_remeasure and classical_index in self.target_classical_indices: + self.measured_qubits[self.target_classical_indices.index(classical_index)] = qubit + self.qubit_set.add(qubit) + continue + if not allow_remeasure and qubit in self.measured_qubits: + raise ValueError(f"Qubit {qubit} is already measured or captured.") + self.measured_qubits.append(qubit) + self.qubit_set.add(qubit) + self.target_classical_indices.append(classical_index) def add_result(self, result: Results) -> None: """ diff --git a/src/braket/default_simulator/openqasm/program_context.py b/src/braket/default_simulator/openqasm/program_context.py index ed5f44d0..fe175eeb 100644 --- a/src/braket/default_simulator/openqasm/program_context.py +++ b/src/braket/default_simulator/openqasm/program_context.py @@ -1311,12 +1311,10 @@ def _flush_pending_mcm_for_variable(self, name: str) -> None: self._initialize_paths_from_circuit() # Also flush any earlier pending measurements so the state is correct for earlier in remaining: - self._measure_and_branch(earlier[0]) - self._update_classical_from_measurement(earlier[0], earlier[2]) + self._branch_measurement(earlier[0], earlier[1], earlier[2]) remaining.clear() if self._is_branched: - self._measure_and_branch(mcm_target) - self._update_classical_from_measurement(mcm_target, mcm_dest) + self._branch_measurement(mcm_target, mcm_classical, mcm_dest) else: # shots == 0: register as a normal measurement and set variable to 0 self._circuit.add_measure( @@ -1362,20 +1360,17 @@ def _flush_pending_mcm_for_qubits(self, qubits: tuple[int, ...] | list[int]) -> self._pending_mcm_targets = self._pending_mcm_targets[last_overlap_idx + 1 :] if self._is_branched: - for mcm_target, _mcm_classical, mcm_dest in to_flush: - self._measure_and_branch(mcm_target) - self._update_classical_from_measurement(mcm_target, mcm_dest) + for mcm_target, mcm_classical, mcm_dest in to_flush: + self._branch_measurement(mcm_target, mcm_classical, mcm_dest) elif self._shots > 0: self._is_branched = True self._initialize_paths_from_circuit() # Flush to_flush first (preserving program order), then any # remaining pending measurements that came after the overlap. - for mcm_target, _mcm_classical, mcm_dest in to_flush: - self._measure_and_branch(mcm_target) - self._update_classical_from_measurement(mcm_target, mcm_dest) + for mcm_target, mcm_classical, mcm_dest in to_flush: + self._branch_measurement(mcm_target, mcm_classical, mcm_dest) for entry in self._pending_mcm_targets: - self._measure_and_branch(entry[0]) - self._update_classical_from_measurement(entry[0], entry[2]) + self._branch_measurement(entry[0], entry[1], entry[2]) self._pending_mcm_targets = [] else: # shots == 0: register as normal measurements and set variables to 0 @@ -1536,8 +1531,7 @@ def add_measure( self._flush_pending_mcm_for_qubits(target) if self._is_branched: if classical_destination is not None: - self._measure_and_branch(target) - self._update_classical_from_measurement(target, classical_destination) + self._branch_measurement(target, classical_targets, classical_destination) else: # End-of-circuit measurement in branched mode: record in circuit # for qubit tracking but don't branch further @@ -1567,8 +1561,7 @@ def _maybe_transition_to_branched(self) -> None: self._is_branched = True self._initialize_paths_from_circuit() for mcm_target, mcm_classical, mcm_dest in self._pending_mcm_targets: - self._measure_and_branch(mcm_target) - self._update_classical_from_measurement(mcm_target, mcm_dest) + self._branch_measurement(mcm_target, mcm_classical, mcm_dest) self._pending_mcm_targets.clear() def track_mcm_dependency(self, lvalue_name: str, rvalue) -> None: @@ -1992,6 +1985,21 @@ def _initialize_paths_from_circuit(self) -> None: ) initial_path.set_variable(name, fv) + def _branch_measurement( + self, + target: tuple[int, ...], + classical_targets, + classical_destination, + ) -> None: + self._measure_and_branch(target) + self._update_classical_from_measurement(target, classical_destination) + if classical_targets is not None: + self._circuit.add_measure( + target, + classical_targets, + allow_remeasure=self.supports_midcircuit_measurement, + ) + def _measure_and_branch(self, target: tuple[int]) -> None: """Sample outcomes per active path and branch with proportional shot allocation. @@ -2060,19 +2068,22 @@ def _get_qubit_samples(self, path: SimulationPath, qubit_idx: int) -> np.ndarray "Construct it via ``BaseLocalSimulator.create_program_context`` or pass " "``simulator=...`` to provide one." ) - # Use the total declared qubit count (from the context), not just the - # qubits that have appeared in instructions so far. This ensures that - # measurements on qubits that haven't had gates applied yet still work - # (they are in the |0⟩ state). - qubit_count = self.num_qubits - if self._circuit.qubit_set: - qubit_count = max(qubit_count, max(self._circuit.qubit_set) + 1) + # Build a contiguous qubit map covering only the qubits this path actually touches + used = {qubit_idx} + for ins in path.instructions: + used.update(ins.targets) + qubit_map = {q: i for i, q in enumerate(sorted(used))} sim = self._simulator.initialize_simulation( - qubit_count=qubit_count, shots=path.shots, batch_size=self._batch_size + qubit_count=len(qubit_map), shots=path.shots, batch_size=self._batch_size ) - sim.evolve(path.instructions) + remapped = [] + for ins in path.instructions: + new_ins = copy(ins) + new_ins._targets = tuple(qubit_map[q] for q in ins.targets) + remapped.append(new_ins) + sim.evolve(remapped) samples = np.asarray(sim.retrieve_samples()) - return (samples >> (qubit_count - 1 - qubit_idx)) & 1 + return (samples >> (qubit_count - 1 - qubit_map[qubit_idx])) & 1 _BINARY_EQUALS = getattr(BinaryOperator, "==") diff --git a/src/braket/default_simulator/simulator.py b/src/braket/default_simulator/simulator.py index 9f2675e7..2bcdd98d 100644 --- a/src/braket/default_simulator/simulator.py +++ b/src/braket/default_simulator/simulator.py @@ -15,6 +15,7 @@ import warnings from abc import ABC, abstractmethod from collections.abc import Callable +from copy import copy from typing import Any import numpy as np @@ -26,7 +27,7 @@ AbstractProgramContext, ProgramContext, ) -from braket.default_simulator.operation import Observable, Operation +from braket.default_simulator.operation import GateOperation, Observable, Operation from braket.default_simulator.operation_helpers import from_braket_instruction from braket.default_simulator.result_types import ( ResultType, @@ -856,8 +857,16 @@ def _run_branched( GateModelTaskResult: Aggregated result across all paths. """ circuit = context.circuit - qubit_map = BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit) - qubit_count = circuit.num_qubits + path_qubit_set = set() + for path in context.active_paths: + for ins in path.instructions: + path_qubit_set.update(ins.targets) + full_qubit_set = circuit.qubit_set | path_qubit_set + if not circuit.measured_qubits: + full_qubit_set |= set(range(context.num_qubits)) + qubit_map = BaseLocalSimulator._contiguous_qubit_mapping(full_qubit_set) + qubit_count = len(qubit_map) + BaseLocalSimulator._map_circuit_qubits(circuit, qubit_map) # Determine measured qubits from the circuit classical_bit_positions = {b: i for i, b in enumerate(circuit.target_classical_indices)} @@ -869,34 +878,22 @@ def _run_branched( [qubit_map[q] for q in measured_qubits] if measured_qubits else None ) - # For path simulation, we need enough qubits to cover all qubit indices - # referenced in the instructions (handles noncontiguous qubit indices). - # Use the context's num_qubits (total declared qubits) to ensure all - # qubits are accounted for, even those without explicit gate operations. - sim_qubit_count = qubit_count - sim_qubit_count = max(sim_qubit_count, context.num_qubits) - if circuit.qubit_set: - sim_qubit_count = max(sim_qubit_count, max(circuit.qubit_set) + 1) - - # Aggregate samples across all active paths all_samples = [] for path in context.active_paths: sim = self.initialize_simulation( - qubit_count=sim_qubit_count, shots=path.shots, batch_size=batch_size + qubit_count=qubit_count, shots=path.shots, batch_size=batch_size ) - sim.evolve(path.instructions) + sim.evolve(_remap_instructions(path.instructions, qubit_map)) all_samples.extend(sim.retrieve_samples()) # Build measurements in the same format as _formatted_measurements measurements = [ - list("{number:0{width}b}".format(number=sample, width=sim_qubit_count))[ - -sim_qubit_count: - ] + list("{number:0{width}b}".format(number=sample, width=qubit_count))[-qubit_count:] for sample in all_samples ] if mapped_measured_qubits is not None and mapped_measured_qubits != []: mapped_arr = np.array(mapped_measured_qubits) - in_circuit_mask = mapped_arr < sim_qubit_count + in_circuit_mask = mapped_arr < qubit_count qubits_in_circuit = mapped_arr[in_circuit_mask] qubits_not_in_circuit = mapped_arr[~in_circuit_mask] measurements_array = np.array(measurements) @@ -995,3 +992,14 @@ def run_jaqcd( ) return self._create_results_obj(results, circuit_ir, simulation) + + +def _remap_instructions( + instructions: list[GateOperation], qubit_map: dict[int, int] +) -> list[GateOperation]: + remapped = [] + for ins in instructions: + new = copy(ins) + new._targets = tuple(qubit_map[q] for q in ins.targets) + remapped.append(new) + return remapped diff --git a/test/unit_tests/braket/default_simulator/test_mcm.py b/test/unit_tests/braket/default_simulator/test_mcm.py index cd6c9c7c..1e28be7d 100644 --- a/test/unit_tests/braket/default_simulator/test_mcm.py +++ b/test/unit_tests/braket/default_simulator/test_mcm.py @@ -143,7 +143,7 @@ def test_4_1_classical_variable_manipulation_with_branching(self): """4.1 Classical variable manipulation with branching""" qasm_source = """ OPENQASM 3.0; - bit[2] b; + bit[3] b; qubit[3] q; int[32] count = 0; @@ -166,6 +166,7 @@ def test_4_1_classical_variable_manipulation_with_branching(self): if (count == 2){ x q[2]; } + b[2] = measure q[2]; """ program = OpenQASMProgram(source=qasm_source, inputs={}) @@ -2618,7 +2619,7 @@ def test_multiple_measurements_and_branching(self, simulator): """Multiple MCMs with conditional logic.""" qasm = """ OPENQASM 3.0; - bit[2] b; + bit[3] b; qubit[2] q; h q[0]; b[0] = measure q[0]; @@ -2629,16 +2630,17 @@ def test_multiple_measurements_and_branching(self, simulator): if (b[0] == b[1]) { x q[1]; } + b[2] = measure q[1]; """ result = simulator.run_openqasm(OpenQASMProgram(source=qasm, inputs={}), shots=1000) assert len(result.measurements) == 1000 + # First two columns mirror q[0] (always 1 after the conditional flip); + # the third column samples q[1] and is 1 only when b[0]==1. counter = Counter(["".join(m) for m in result.measurements]) - # After first measure: if 0 -> X makes it 1, if 1 -> stays 1 - # Second measure always 1. b[0]==b[1] only when b[0]==1 (50%) - assert "11" in counter - assert "10" in counter - assert 400 < counter["11"] < 600 - assert 400 < counter["10"] < 600 + assert "111" in counter + assert "110" in counter + assert 400 < counter["111"] < 600 + assert 400 < counter["110"] < 600 def test_complex_conditional_logic(self, simulator): """Complex conditional with if/else blocks.""" @@ -2722,7 +2724,7 @@ def test_quantum_teleportation(self, simulator): """Quantum teleportation protocol using MCM.""" qasm = """ OPENQASM 3.0; - bit[2] b; + bit[3] b; qubit[3] q; // Prepare state to teleport: |1> on q[0] @@ -2745,6 +2747,7 @@ def test_quantum_teleportation(self, simulator): if (b[0] == 1) { z q[2]; } + b[2] = measure q[2]; """ result = simulator.run_openqasm(OpenQASMProgram(source=qasm, inputs={}), shots=1000) assert len(result.measurements) == 1000 @@ -3297,6 +3300,7 @@ def test_minimal_unassigned_low_bit_stays_zero(self, simulator): h q[1]; c[1] = measure q[1]; if (c[1]) { x q[1]; } + c[0] = measure q[0]; """ result = simulator.run_openqasm(OpenQASMProgram(source=qasm, inputs={}), shots=2000) counter = Counter(["".join(m) for m in result.measurements]) @@ -3312,6 +3316,7 @@ def test_if_branch_preserves_captured_bit_position(self, simulator): c[1] = measure q[1]; if (c[1]) { x q[2]; } c[2] = measure q[2]; + c[0] = measure q[0]; """ result = simulator.run_openqasm(OpenQASMProgram(source=qasm, inputs={}), shots=2000) counter = Counter(["".join(m) for m in result.measurements]) @@ -3511,7 +3516,7 @@ def test_branched_if_else_both_branches(self, simulator): qasm = """ OPENQASM 3.0; bit b; - bit[2] result; + bit[3] result; qubit[3] q; h q[0]; b = measure q[0]; @@ -3520,8 +3525,9 @@ def test_branched_if_else_both_branches(self, simulator): } else { x q[2]; } - result[0] = measure q[1]; - result[1] = measure q[2]; + result[0] = measure q[0]; + result[1] = measure q[1]; + result[2] = measure q[2]; """ result = simulator.run_openqasm(OpenQASMProgram(source=qasm, inputs={}), shots=1000) counter = Counter(["".join(m) for m in result.measurements]) @@ -5033,7 +5039,6 @@ def test_flat_context_preserves_mcm_while_loop(self): assert Interpreter(context=FlatProgramContext()).run(qasm).circuit == qasm - class TestDensityMatrixSimulatorBranching: """Branching MCM coverage for ``DensityMatrixSimulator``. @@ -5103,6 +5108,7 @@ def test_bell_pair_mcm_decoupling(self): if (b[0] == 1) { x q[1]; } + b[1] = measure q[1]; """ # q[1] is always |0> (Bell-correlated, then flipped iff b[0]==1): # outcomes "00" and "10" each ~50%. @@ -5130,7 +5136,7 @@ def test_branched_reset(self): path (``_apply_reset``) on each replayed path.""" qasm = """ OPENQASM 3.0; - bit[1] b; + bit[2] b; qubit[2] q; h q[0]; b[0] = measure q[0]; @@ -5138,6 +5144,32 @@ def test_branched_reset(self): x q[1]; } reset q[0]; + b[0] = measure q[0]; + b[1] = measure q[1]; """ # q[0] is always |0> after the reset; q[1] flips iff b[0]==1 → 50/50. self._assert_distributions_match(qasm, expected_keys={"00", "01"}) + + def test_sparse_qubit_register_does_not_blow_up(self): + """Regression: an SDK-emitted ``qubit[N]`` declaration that only + touches a couple of high-index qubits used to send the density-matrix + simulator into an OOM. ``ProgramContext._get_qubit_samples`` now + builds a per-call qubit map covering only the qubits the path + actually touches, so the DM allocation scales with the number of + active qubits rather than ``N``. + """ + qasm = """ + OPENQASM 3.0; + bit[2] b; + qubit[18] q; + h q[13]; + b[0] = measure q[13]; + if (b[0] == 1) { + prx(3.141592653589793, 0.0) q[17]; + } + b[0] = measure q[13]; + b[1] = measure q[17]; + """ + # Mirrors a verbatim ``measure_ff(q[13], 0); cc_prx(q[17], pi, 0, 0)`` + # emission: q[17] flips iff b[0]==1 → outcomes "00" and "11" each ~50%. + self._assert_distributions_match(qasm, expected_keys={"00", "11"})