Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/braket/default_simulator/openqasm/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,20 @@ def add_measure(
allow_remeasure: bool = False,
):
for index, qubit in enumerate(target):
if not allow_remeasure and qubit in self.measured_qubits:
raise ValueError(f"Qubit {qubit} is already measured or captured.")
self.measured_qubits.append(qubit)
self.qubit_set.add(qubit)
self.target_classical_indices.append(
classical_index = (
classical_targets[index]
if classical_targets
else max(index, len(self.target_classical_indices))
)
if allow_remeasure and classical_index in self.target_classical_indices:
self.measured_qubits[self.target_classical_indices.index(classical_index)] = qubit
self.qubit_set.add(qubit)
continue
if not allow_remeasure and qubit in self.measured_qubits:
raise ValueError(f"Qubit {qubit} is already measured or captured.")
self.measured_qubits.append(qubit)
self.qubit_set.add(qubit)
self.target_classical_indices.append(classical_index)

def add_result(self, result: Results) -> None:
"""
Expand Down
63 changes: 37 additions & 26 deletions src/braket/default_simulator/openqasm/program_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,12 +1311,10 @@ def _flush_pending_mcm_for_variable(self, name: str) -> None:
self._initialize_paths_from_circuit()
# Also flush any earlier pending measurements so the state is correct
for earlier in remaining:
self._measure_and_branch(earlier[0])
self._update_classical_from_measurement(earlier[0], earlier[2])
self._branch_measurement(earlier[0], earlier[1], earlier[2])
remaining.clear()
if self._is_branched:
self._measure_and_branch(mcm_target)
self._update_classical_from_measurement(mcm_target, mcm_dest)
self._branch_measurement(mcm_target, mcm_classical, mcm_dest)
else:
# shots == 0: register as a normal measurement and set variable to 0
self._circuit.add_measure(
Expand Down Expand Up @@ -1362,20 +1360,17 @@ def _flush_pending_mcm_for_qubits(self, qubits: tuple[int, ...] | list[int]) ->
self._pending_mcm_targets = self._pending_mcm_targets[last_overlap_idx + 1 :]

if self._is_branched:
for mcm_target, _mcm_classical, mcm_dest in to_flush:
self._measure_and_branch(mcm_target)
self._update_classical_from_measurement(mcm_target, mcm_dest)
for mcm_target, mcm_classical, mcm_dest in to_flush:
self._branch_measurement(mcm_target, mcm_classical, mcm_dest)
elif self._shots > 0:
self._is_branched = True
self._initialize_paths_from_circuit()
# Flush to_flush first (preserving program order), then any
# remaining pending measurements that came after the overlap.
for mcm_target, _mcm_classical, mcm_dest in to_flush:
self._measure_and_branch(mcm_target)
self._update_classical_from_measurement(mcm_target, mcm_dest)
for mcm_target, mcm_classical, mcm_dest in to_flush:
self._branch_measurement(mcm_target, mcm_classical, mcm_dest)
for entry in self._pending_mcm_targets:
self._measure_and_branch(entry[0])
self._update_classical_from_measurement(entry[0], entry[2])
self._branch_measurement(entry[0], entry[1], entry[2])
self._pending_mcm_targets = []
else:
# shots == 0: register as normal measurements and set variables to 0
Expand Down Expand Up @@ -1536,8 +1531,7 @@ def add_measure(
self._flush_pending_mcm_for_qubits(target)
if self._is_branched:
if classical_destination is not None:
self._measure_and_branch(target)
self._update_classical_from_measurement(target, classical_destination)
self._branch_measurement(target, classical_targets, classical_destination)
else:
# End-of-circuit measurement in branched mode: record in circuit
# for qubit tracking but don't branch further
Expand Down Expand Up @@ -1567,8 +1561,7 @@ def _maybe_transition_to_branched(self) -> None:
self._is_branched = True
self._initialize_paths_from_circuit()
for mcm_target, mcm_classical, mcm_dest in self._pending_mcm_targets:
self._measure_and_branch(mcm_target)
self._update_classical_from_measurement(mcm_target, mcm_dest)
self._branch_measurement(mcm_target, mcm_classical, mcm_dest)
self._pending_mcm_targets.clear()

def track_mcm_dependency(self, lvalue_name: str, rvalue) -> None:
Expand Down Expand Up @@ -1992,6 +1985,21 @@ def _initialize_paths_from_circuit(self) -> None:
)
initial_path.set_variable(name, fv)

def _branch_measurement(
self,
target: tuple[int, ...],
classical_targets,
classical_destination,
) -> None:
self._measure_and_branch(target)
self._update_classical_from_measurement(target, classical_destination)
if classical_targets is not None:
self._circuit.add_measure(
target,
classical_targets,
allow_remeasure=self.supports_midcircuit_measurement,
)

def _measure_and_branch(self, target: tuple[int]) -> None:
"""Sample outcomes per active path and branch with proportional shot
allocation.
Expand Down Expand Up @@ -2060,19 +2068,22 @@ def _get_qubit_samples(self, path: SimulationPath, qubit_idx: int) -> np.ndarray
"Construct it via ``BaseLocalSimulator.create_program_context`` or pass "
"``simulator=...`` to provide one."
)
# Use the total declared qubit count (from the context), not just the
# qubits that have appeared in instructions so far. This ensures that
# measurements on qubits that haven't had gates applied yet still work
# (they are in the |0⟩ state).
qubit_count = self.num_qubits
if self._circuit.qubit_set:
qubit_count = max(qubit_count, max(self._circuit.qubit_set) + 1)
# Build a contiguous qubit map covering only the qubits this path actually touches
used = {qubit_idx}
for ins in path.instructions:
used.update(ins.targets)
qubit_map = {q: i for i, q in enumerate(sorted(used))}
sim = self._simulator.initialize_simulation(
qubit_count=qubit_count, shots=path.shots, batch_size=self._batch_size
qubit_count=len(qubit_map), shots=path.shots, batch_size=self._batch_size
)
sim.evolve(path.instructions)
remapped = []
for ins in path.instructions:
new_ins = copy(ins)
new_ins._targets = tuple(qubit_map[q] for q in ins.targets)
remapped.append(new_ins)
sim.evolve(remapped)
samples = np.asarray(sim.retrieve_samples())
return (samples >> (qubit_count - 1 - qubit_idx)) & 1
return (samples >> (qubit_count - 1 - qubit_map[qubit_idx])) & 1


_BINARY_EQUALS = getattr(BinaryOperator, "==")
Expand Down
46 changes: 27 additions & 19 deletions src/braket/default_simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import copy
from typing import Any

import numpy as np
Expand All @@ -26,7 +27,7 @@
AbstractProgramContext,
ProgramContext,
)
from braket.default_simulator.operation import Observable, Operation
from braket.default_simulator.operation import GateOperation, Observable, Operation
from braket.default_simulator.operation_helpers import from_braket_instruction
from braket.default_simulator.result_types import (
ResultType,
Expand Down Expand Up @@ -856,8 +857,16 @@ def _run_branched(
GateModelTaskResult: Aggregated result across all paths.
"""
circuit = context.circuit
qubit_map = BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit)
qubit_count = circuit.num_qubits
path_qubit_set = set()
for path in context.active_paths:
for ins in path.instructions:
path_qubit_set.update(ins.targets)
full_qubit_set = circuit.qubit_set | path_qubit_set
if not circuit.measured_qubits:
full_qubit_set |= set(range(context.num_qubits))
qubit_map = BaseLocalSimulator._contiguous_qubit_mapping(full_qubit_set)
qubit_count = len(qubit_map)
BaseLocalSimulator._map_circuit_qubits(circuit, qubit_map)

# Determine measured qubits from the circuit
classical_bit_positions = {b: i for i, b in enumerate(circuit.target_classical_indices)}
Expand All @@ -869,34 +878,22 @@ def _run_branched(
[qubit_map[q] for q in measured_qubits] if measured_qubits else None
)

# For path simulation, we need enough qubits to cover all qubit indices
# referenced in the instructions (handles noncontiguous qubit indices).
# Use the context's num_qubits (total declared qubits) to ensure all
# qubits are accounted for, even those without explicit gate operations.
sim_qubit_count = qubit_count
sim_qubit_count = max(sim_qubit_count, context.num_qubits)
if circuit.qubit_set:
sim_qubit_count = max(sim_qubit_count, max(circuit.qubit_set) + 1)

# Aggregate samples across all active paths
all_samples = []
for path in context.active_paths:
sim = self.initialize_simulation(
qubit_count=sim_qubit_count, shots=path.shots, batch_size=batch_size
qubit_count=qubit_count, shots=path.shots, batch_size=batch_size
)
sim.evolve(path.instructions)
sim.evolve(_remap_instructions(path.instructions, qubit_map))
all_samples.extend(sim.retrieve_samples())

# Build measurements in the same format as _formatted_measurements
measurements = [
list("{number:0{width}b}".format(number=sample, width=sim_qubit_count))[
-sim_qubit_count:
]
list("{number:0{width}b}".format(number=sample, width=qubit_count))[-qubit_count:]
for sample in all_samples
]
if mapped_measured_qubits is not None and mapped_measured_qubits != []:
mapped_arr = np.array(mapped_measured_qubits)
in_circuit_mask = mapped_arr < sim_qubit_count
in_circuit_mask = mapped_arr < qubit_count
qubits_in_circuit = mapped_arr[in_circuit_mask]
qubits_not_in_circuit = mapped_arr[~in_circuit_mask]
measurements_array = np.array(measurements)
Expand Down Expand Up @@ -995,3 +992,14 @@ def run_jaqcd(
)

return self._create_results_obj(results, circuit_ir, simulation)


def _remap_instructions(
instructions: list[GateOperation], qubit_map: dict[int, int]
) -> list[GateOperation]:
remapped = []
for ins in instructions:
new = copy(ins)
new._targets = tuple(qubit_map[q] for q in ins.targets)
remapped.append(new)
return remapped
60 changes: 46 additions & 14 deletions test/unit_tests/braket/default_simulator/test_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def test_4_1_classical_variable_manipulation_with_branching(self):
"""4.1 Classical variable manipulation with branching"""
qasm_source = """
OPENQASM 3.0;
bit[2] b;
bit[3] b;
qubit[3] q;
int[32] count = 0;

Expand All @@ -166,6 +166,7 @@ def test_4_1_classical_variable_manipulation_with_branching(self):
if (count == 2){
x q[2];
}
b[2] = measure q[2];
"""

program = OpenQASMProgram(source=qasm_source, inputs={})
Expand Down Expand Up @@ -2618,7 +2619,7 @@ def test_multiple_measurements_and_branching(self, simulator):
"""Multiple MCMs with conditional logic."""
qasm = """
OPENQASM 3.0;
bit[2] b;
bit[3] b;
qubit[2] q;
h q[0];
b[0] = measure q[0];
Expand All @@ -2629,16 +2630,17 @@ def test_multiple_measurements_and_branching(self, simulator):
if (b[0] == b[1]) {
x q[1];
}
b[2] = measure q[1];
"""
result = simulator.run_openqasm(OpenQASMProgram(source=qasm, inputs={}), shots=1000)
assert len(result.measurements) == 1000
# First two columns mirror q[0] (always 1 after the conditional flip);
# the third column samples q[1] and is 1 only when b[0]==1.
counter = Counter(["".join(m) for m in result.measurements])
# After first measure: if 0 -> X makes it 1, if 1 -> stays 1
# Second measure always 1. b[0]==b[1] only when b[0]==1 (50%)
assert "11" in counter
assert "10" in counter
assert 400 < counter["11"] < 600
assert 400 < counter["10"] < 600
assert "111" in counter
assert "110" in counter
assert 400 < counter["111"] < 600
assert 400 < counter["110"] < 600

def test_complex_conditional_logic(self, simulator):
"""Complex conditional with if/else blocks."""
Expand Down Expand Up @@ -2722,7 +2724,7 @@ def test_quantum_teleportation(self, simulator):
"""Quantum teleportation protocol using MCM."""
qasm = """
OPENQASM 3.0;
bit[2] b;
bit[3] b;
qubit[3] q;

// Prepare state to teleport: |1> on q[0]
Expand All @@ -2745,6 +2747,7 @@ def test_quantum_teleportation(self, simulator):
if (b[0] == 1) {
z q[2];
}
b[2] = measure q[2];
"""
result = simulator.run_openqasm(OpenQASMProgram(source=qasm, inputs={}), shots=1000)
assert len(result.measurements) == 1000
Expand Down Expand Up @@ -3297,6 +3300,7 @@ def test_minimal_unassigned_low_bit_stays_zero(self, simulator):
h q[1];
c[1] = measure q[1];
if (c[1]) { x q[1]; }
c[0] = measure q[0];
"""
result = simulator.run_openqasm(OpenQASMProgram(source=qasm, inputs={}), shots=2000)
counter = Counter(["".join(m) for m in result.measurements])
Expand All @@ -3312,6 +3316,7 @@ def test_if_branch_preserves_captured_bit_position(self, simulator):
c[1] = measure q[1];
if (c[1]) { x q[2]; }
c[2] = measure q[2];
c[0] = measure q[0];
"""
result = simulator.run_openqasm(OpenQASMProgram(source=qasm, inputs={}), shots=2000)
counter = Counter(["".join(m) for m in result.measurements])
Expand Down Expand Up @@ -3511,7 +3516,7 @@ def test_branched_if_else_both_branches(self, simulator):
qasm = """
OPENQASM 3.0;
bit b;
bit[2] result;
bit[3] result;
qubit[3] q;
h q[0];
b = measure q[0];
Expand All @@ -3520,8 +3525,9 @@ def test_branched_if_else_both_branches(self, simulator):
} else {
x q[2];
}
result[0] = measure q[1];
result[1] = measure q[2];
result[0] = measure q[0];
result[1] = measure q[1];
result[2] = measure q[2];
"""
result = simulator.run_openqasm(OpenQASMProgram(source=qasm, inputs={}), shots=1000)
counter = Counter(["".join(m) for m in result.measurements])
Expand Down Expand Up @@ -5033,7 +5039,6 @@ def test_flat_context_preserves_mcm_while_loop(self):
assert Interpreter(context=FlatProgramContext()).run(qasm).circuit == qasm



class TestDensityMatrixSimulatorBranching:
"""Branching MCM coverage for ``DensityMatrixSimulator``.

Expand Down Expand Up @@ -5103,6 +5108,7 @@ def test_bell_pair_mcm_decoupling(self):
if (b[0] == 1) {
x q[1];
}
b[1] = measure q[1];
"""
# q[1] is always |0> (Bell-correlated, then flipped iff b[0]==1):
# outcomes "00" and "10" each ~50%.
Expand Down Expand Up @@ -5130,14 +5136,40 @@ def test_branched_reset(self):
path (``_apply_reset``) on each replayed path."""
qasm = """
OPENQASM 3.0;
bit[1] b;
bit[2] b;
qubit[2] q;
h q[0];
b[0] = measure q[0];
if (b[0] == 1) {
x q[1];
}
reset q[0];
b[0] = measure q[0];
b[1] = measure q[1];
"""
# q[0] is always |0> after the reset; q[1] flips iff b[0]==1 → 50/50.
self._assert_distributions_match(qasm, expected_keys={"00", "01"})

def test_sparse_qubit_register_does_not_blow_up(self):
"""Regression: an SDK-emitted ``qubit[N]`` declaration that only
touches a couple of high-index qubits used to send the density-matrix
simulator into an OOM. ``ProgramContext._get_qubit_samples`` now
builds a per-call qubit map covering only the qubits the path
actually touches, so the DM allocation scales with the number of
active qubits rather than ``N``.
"""
qasm = """
OPENQASM 3.0;
bit[2] b;
qubit[18] q;
h q[13];
b[0] = measure q[13];
if (b[0] == 1) {
prx(3.141592653589793, 0.0) q[17];
}
b[0] = measure q[13];
b[1] = measure q[17];
"""
# Mirrors a verbatim ``measure_ff(q[13], 0); cc_prx(q[17], pi, 0, 0)``
# emission: q[17] flips iff b[0]==1 → outcomes "00" and "11" each ~50%.
self._assert_distributions_match(qasm, expected_keys={"00", "11"})