From fbe395f423e3d95dfa20291b3b7148c8bebfdf35 Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Fri, 1 May 2026 09:11:27 -0700 Subject: [PATCH 1/3] feat: Parallelize path simulations Gives implementers the option to run path simulations in separate processes; this could be particularly helpful for remote simulator backends. --- .../openqasm/program_context.py | 35 ++++++++---- src/braket/default_simulator/simulator.py | 57 ++++++++++++++++--- .../braket/default_simulator/test_mcm.py | 41 +++++++++++++ 3 files changed, 115 insertions(+), 18 deletions(-) diff --git a/src/braket/default_simulator/openqasm/program_context.py b/src/braket/default_simulator/openqasm/program_context.py index 0f74f562..515d22d3 100644 --- a/src/braket/default_simulator/openqasm/program_context.py +++ b/src/braket/default_simulator/openqasm/program_context.py @@ -15,6 +15,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor from dataclasses import fields from functools import singledispatchmethod from typing import TYPE_CHECKING, Any @@ -1989,30 +1990,42 @@ def _measure_and_branch(self, target: tuple[int]) -> None: """Sample outcomes per active path and branch with proportional shot allocation. - For each qubit in target, for each active path: - 1. Ask the subclass-supplied ``_get_qubit_samples`` for - ``path.shots`` sampled bit outcomes of the qubit on this path. - 2. Split the path: one child gets shots that measured 0, the other - gets shots that measured 1. + For each qubit in target: + 1. Ask the subclass-supplied ``_get_qubit_samples`` for ``path.shots`` + sampled bit outcomes of the qubit on each active path. When the + simulator opts in via ``parallelize_paths`` these per-path + samplings are fanned out to a thread pool. + 2. For each path, split it: one child gets shots that measured 0, the + other gets shots that measured 1. 3. If one outcome has 0 shots, don't create that branch (deterministic case). 4. Remove paths with 0 shots from the active set. """ for qubit_idx in target: + saved_active = list(self._active_path_indices) + per_path_samples = self._collect_qubit_samples(saved_active, qubit_idx) new_active_indices = [] - for path_idx in list(self._active_path_indices): - self._branch_single_qubit(path_idx, qubit_idx, new_active_indices) + for path_idx, qubit_samples in zip(saved_active, per_path_samples): + self._branch_single_qubit(path_idx, qubit_idx, qubit_samples, new_active_indices) self._active_path_indices = new_active_indices + def _collect_qubit_samples(self, path_indices: list[int], qubit_idx: int) -> list[np.ndarray]: + paths = [self._paths[idx] for idx in path_indices] + if self._simulator is not None and self._simulator.parallelize_paths and len(paths) > 1: + with ThreadPoolExecutor() as pool: + return list(pool.map(lambda path: self._get_qubit_samples(path, qubit_idx), paths)) + return [self._get_qubit_samples(path, qubit_idx) for path in paths] + def _branch_single_qubit( - self, path_idx: int, qubit_idx: int, new_active_indices: list[int] + self, + path_idx: int, + qubit_idx: int, + qubit_samples: np.ndarray, + new_active_indices: list[int], ) -> None: """Branch a single path on a single qubit measurement.""" path = self._paths[path_idx] - # Defer to the concrete simulator to sample the target qubit's bit for - # each of ``path.shots`` shots; then the shot-split is just a tally. - qubit_samples = self._get_qubit_samples(path, qubit_idx) path_shots = path.shots shots_for_1 = int(np.sum(qubit_samples)) shots_for_0 = path_shots - shots_for_1 diff --git a/src/braket/default_simulator/simulator.py b/src/braket/default_simulator/simulator.py index 9f2675e7..678767a1 100644 --- a/src/braket/default_simulator/simulator.py +++ b/src/braket/default_simulator/simulator.py @@ -15,6 +15,7 @@ import warnings from abc import ABC, abstractmethod from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor from typing import Any import numpy as np @@ -73,6 +74,25 @@ ) +def _evolve_path_and_sample( + simulator: "BaseLocalSimulator", + instructions, + qubit_count: int, + shots: int, + batch_size: int, +): + """Worker: evolve ``instructions`` through ``simulator.initialize_simulation`` + and return the path's samples. + + Defined at module scope so it's picklable and can run in a process pool. + """ + sim = simulator.initialize_simulation( + qubit_count=qubit_count, shots=shots, batch_size=batch_size + ) + sim.evolve(instructions) + return sim.retrieve_samples() + + class OpenQASMSimulator(BraketSimulator, ABC): """An abstract simulator that runs an OpenQASM 3 program. @@ -132,6 +152,14 @@ def parse_program(self, program: OpenQASMProgram) -> AbstractProgramContext: class BaseLocalSimulator(OpenQASMSimulator): + @property + def parallelize_paths(self) -> bool: + """Whether to run independent simulation paths (e.g. branches of a + mid-circuit measurement) in parallel. Off by default because for + small/short paths the pool overhead dominates; simulators that + routinely handle large branched workloads can override this.""" + return False + def run( self, circuit_ir: OpenQASMProgram | ProgramSet | JaqcdProgram, *args, **kwargs ) -> GateModelTaskResult | ProgramSetTaskResult: @@ -878,14 +906,29 @@ def _run_branched( if circuit.qubit_set: sim_qubit_count = max(sim_qubit_count, max(circuit.qubit_set) + 1) - # Aggregate samples across all active paths + paths = list(context.active_paths) + if self.parallelize_paths and len(paths) > 1: + with ThreadPoolExecutor() as pool: + per_path_samples = list( + pool.map( + _evolve_path_and_sample, + [self] * len(paths), + [path.instructions for path in paths], + [sim_qubit_count] * len(paths), + [path.shots for path in paths], + [batch_size] * len(paths), + ) + ) + else: + per_path_samples = [ + _evolve_path_and_sample( + self, path.instructions, sim_qubit_count, path.shots, batch_size + ) + for path in paths + ] all_samples = [] - for path in context.active_paths: - sim = self.initialize_simulation( - qubit_count=sim_qubit_count, shots=path.shots, batch_size=batch_size - ) - sim.evolve(path.instructions) - all_samples.extend(sim.retrieve_samples()) + for samples in per_path_samples: + all_samples.extend(samples) # Build measurements in the same format as _formatted_measurements measurements = [ diff --git a/test/unit_tests/braket/default_simulator/test_mcm.py b/test/unit_tests/braket/default_simulator/test_mcm.py index 9110ac9f..5ff92877 100644 --- a/test/unit_tests/braket/default_simulator/test_mcm.py +++ b/test/unit_tests/braket/default_simulator/test_mcm.py @@ -4997,3 +4997,44 @@ def test_flat_context_preserves_mcm_while_loop(self): "}" ) assert Interpreter(context=FlatProgramContext()).run(qasm).circuit == qasm + + +class TestParallelizePaths: + """Cover the ``parallelize_paths`` thread-pool branches in + ``BaseLocalSimulator._run_branched`` and + ``ProgramContext._collect_qubit_samples``.""" + + class _ParallelStateVectorSimulator(StateVectorSimulator): + @property + def parallelize_paths(self) -> bool: + return True + + def test_parallel_paths_produce_expected_distribution(self): + """A parallelized simulator produces a correct distribution for a + program whose second measurement fires on multiple already-branched + paths — which exercises the thread-pool branch of + ``_collect_qubit_samples``.""" + qasm = """ + OPENQASM 3.0; + qubit[3] q; + bit b0; + bit b1; + h q[0]; + b0 = measure q[0]; + if (b0 == 1) { + h q[1]; + } else { + h q[1]; + } + b1 = measure q[1]; + if (b1 == 1) { + x q[2]; + } + """ + simulator = self._ParallelStateVectorSimulator() + result = simulator.run_openqasm(OpenQASMProgram(source=qasm, inputs={}), shots=1000) + counts = Counter("".join(m) for m in result.measurements) + # Regardless of b0, q[1] is Hadamarded then measured → 50/50. + # Regardless of b1, q[2] ends in state b1. + assert set(counts.keys()).issubset({"000", "001", "011", "100", "101", "111"}) + assert sum(counts.values()) == 1000 From cc9517831488aea751d2c157948a072f1ecfbcee Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Fri, 1 May 2026 09:16:11 -0700 Subject: [PATCH 2/3] Update simulator.py --- src/braket/default_simulator/simulator.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/src/braket/default_simulator/simulator.py b/src/braket/default_simulator/simulator.py index 678767a1..9c9d347d 100644 --- a/src/braket/default_simulator/simulator.py +++ b/src/braket/default_simulator/simulator.py @@ -81,11 +81,6 @@ def _evolve_path_and_sample( shots: int, batch_size: int, ): - """Worker: evolve ``instructions`` through ``simulator.initialize_simulation`` - and return the path's samples. - - Defined at module scope so it's picklable and can run in a process pool. - """ sim = simulator.initialize_simulation( qubit_count=qubit_count, shots=shots, batch_size=batch_size ) @@ -152,14 +147,6 @@ def parse_program(self, program: OpenQASMProgram) -> AbstractProgramContext: class BaseLocalSimulator(OpenQASMSimulator): - @property - def parallelize_paths(self) -> bool: - """Whether to run independent simulation paths (e.g. branches of a - mid-circuit measurement) in parallel. Off by default because for - small/short paths the pool overhead dominates; simulators that - routinely handle large branched workloads can override this.""" - return False - def run( self, circuit_ir: OpenQASMProgram | ProgramSet | JaqcdProgram, *args, **kwargs ) -> GateModelTaskResult | ProgramSetTaskResult: @@ -191,6 +178,11 @@ def run( return self.run_program_set(circuit_ir, *args, **kwargs) return self.run_jaqcd(circuit_ir, *args, **kwargs) + @property + def parallelize_paths(self) -> bool: + """bool: Whether to run path simulations in parallel.""" + return False + def create_program_context(self) -> AbstractProgramContext: return ProgramContext(simulator=self) From d8d4bc7c7e4d41b31c987a3b1173dda5ca4bebeb Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Fri, 1 May 2026 09:17:28 -0700 Subject: [PATCH 3/3] rearrange --- src/braket/default_simulator/simulator.py | 28 +++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/braket/default_simulator/simulator.py b/src/braket/default_simulator/simulator.py index 9c9d347d..c379177a 100644 --- a/src/braket/default_simulator/simulator.py +++ b/src/braket/default_simulator/simulator.py @@ -74,20 +74,6 @@ ) -def _evolve_path_and_sample( - simulator: "BaseLocalSimulator", - instructions, - qubit_count: int, - shots: int, - batch_size: int, -): - sim = simulator.initialize_simulation( - qubit_count=qubit_count, shots=shots, batch_size=batch_size - ) - sim.evolve(instructions) - return sim.retrieve_samples() - - class OpenQASMSimulator(BraketSimulator, ABC): """An abstract simulator that runs an OpenQASM 3 program. @@ -1030,3 +1016,17 @@ def run_jaqcd( ) return self._create_results_obj(results, circuit_ir, simulation) + + +def _evolve_path_and_sample( + simulator: BaseLocalSimulator, + instructions, + qubit_count: int, + shots: int, + batch_size: int, +): + sim = simulator.initialize_simulation( + qubit_count=qubit_count, shots=shots, batch_size=batch_size + ) + sim.evolve(instructions) + return sim.retrieve_samples()