diff --git a/src/braket/pulse/pulse_sequence.py b/src/braket/pulse/pulse_sequence.py index 7c61201ef..884dbe193 100644 --- a/src/braket/pulse/pulse_sequence.py +++ b/src/braket/pulse/pulse_sequence.py @@ -31,7 +31,7 @@ from braket.pulse.ast.qasm_transformer import _IRQASMTransformer from braket.pulse.frame import Frame from braket.pulse.pulse_sequence_trace import PulseSequenceTrace -from braket.pulse.waveforms import Waveform +from braket.pulse.waveforms import Waveform, WaveformDict from braket.registers.qubit_set import QubitSet @@ -44,9 +44,13 @@ def __init__(self): self._capture_v0_count = 0 self._program = Program(simplify_constants=False) self._frames = {} - self._waveforms = {} + self._waveforms = WaveformDict({}, self) self._free_parameters = set() + @property + def waveforms(self) -> WaveformDict: + return self._waveforms + def to_time_trace(self) -> PulseSequenceTrace: """Generate an approximate trace of the amplitude, frequency, phase for each frame contained in the PulseSequence, under the action of the instructions contained in diff --git a/src/braket/pulse/waveforms.py b/src/braket/pulse/waveforms.py index 3e7127209..25a3b1e1d 100644 --- a/src/braket/pulse/waveforms.py +++ b/src/braket/pulse/waveforms.py @@ -16,11 +16,25 @@ import random import string from abc import ABC, abstractmethod +from copy import deepcopy +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: # pragma: no cover + from braket.pulse.pulse_sequence import PulseSequence import numpy as np import scipy as sp -from oqpy import WaveformVar, bool_, complex128, declare_waveform_generator, duration, float64 -from oqpy.base import OQPyExpression +from openpulse import ast +from oqpy import ( + WaveformVar, + bool_, + complex128, + convert_float_to_duration, + declare_waveform_generator, + duration, + float64, +) +from oqpy.base import OQPyExpression, to_ast from braket.parametric.free_parameter import FreeParameter from braket.parametric.free_parameter_expression import ( @@ -30,6 +44,28 @@ from braket.parametric.parameterizable import Parameterizable +class WaveformDict(dict): # noqa: FURB189 + """ + A dict of waveforms. + + Note: + WaveformDict binds a pulse sequence to each waveform that is + added to the dict. It serves as back reference when a + waveform is modified so the OQpy object is also updated. + """ + + def __init__(self, waveform_dict: dict, pulse_sequence: PulseSequence): + for waveform in waveform_dict.values(): + waveform._pulse_sequence = pulse_sequence + super().__init__(waveform_dict) + self._pulse_sequence = pulse_sequence + + def __setitem__(self, key: str, value: Waveform): + value = deepcopy(value) + value._pulse_sequence = self._pulse_sequence + super().__setitem__(key, value) + + class Waveform(ABC): """A waveform is a time-dependent envelope that can be used to emit signals on an output port or receive signals from an input port. As such, when transmitting signals to the qubit, a @@ -39,6 +75,25 @@ class Waveform(ABC): for more details. """ + def __init__(self) -> None: + self._pulse_sequence = None + + def _modify_oqpy_waveform_var( + self, key: str, value: Any, type_: ast.ClassicalType = float64 + ) -> None: + if self._pulse_sequence is not None: + self._pulse_sequence._register_free_parameters(value) + self._pulse_sequence._program.undeclared_vars[self.id].init_expression.args[key] = ( + to_ast( + self._pulse_sequence._program, + ( + convert_float_to_duration(value) + if isinstance(type_, ast.DurationType) + else value + ), + ) + ) + @abstractmethod def _to_oqpy_expression(self) -> OQPyExpression: """Returns an OQPyExpression defining this waveform.""" @@ -82,8 +137,28 @@ def __init__(self, amplitudes: list[complex], id: str | None = None): id (str | None): The identifier used for declaring this waveform. A random string of ascii characters is assigned by default. """ - self.amplitudes = list(amplitudes) + self._amplitudes = list(amplitudes) self.id = id or _make_identifier_name() + super().__init__() + + @property + def amplitudes(self) -> list[complex]: + return self._amplitudes + + @amplitudes.setter + def amplitudes(self, value: list[complex]) -> None: + """ + Sets the list of amplitudes. + + Args: + value (list[complex]): Array of complex values specifying the + waveform amplitude at each timestep. The timestep is determined by the sampling rate + of the frame to which waveform is applied to. + + """ + self._amplitudes = value + if self._pulse_sequence is not None: + self._pulse_sequence._program.undeclared_vars[self.id].init_expression = value def __repr__(self) -> str: return f"ArbitraryWaveform('id': {self.id}, 'amplitudes': {self.amplitudes})" @@ -138,9 +213,41 @@ def __init__(self, length: float | FreeParameterExpression, iq: complex, id: str id (str | None): The identifier used for declaring this waveform. A random string of ascii characters is assigned by default. """ - self.length = length - self.iq = iq + self._length = length + self._iq = iq self.id = id or _make_identifier_name() + super().__init__() + + @property + def iq(self) -> complex: + return self._iq + + @iq.setter + def iq(self, value: complex) -> None: + """ + Sets the IQ value. + + Args: + value (complex): complex value specifying the amplitude of the waveform. + """ + self._iq = value + self._modify_oqpy_waveform_var("iq", value) + + @property + def length(self) -> float | FreeParameterExpression: + return self._length + + @length.setter + def length(self, value: float | FreeParameterExpression) -> None: + """ + Sets the length. + + Args: + value (Union[float, FreeParameterExpression]): Value (in seconds) + specifying the duration of the waveform. + """ + self._length = value + self._modify_oqpy_waveform_var("length", value, duration) def __repr__(self) -> str: return f"ConstantWaveform('id': {self.id}, 'length': {self.length}, 'iq': {self.iq})" @@ -253,12 +360,92 @@ def __init__( id (str | None): The identifier used for declaring this waveform. A random string of ascii characters is assigned by default. """ - self.length = length - self.sigma = sigma - self.beta = beta - self.amplitude = amplitude - self.zero_at_edges = zero_at_edges + self._length = length + self._sigma = sigma + self._beta = beta + self._amplitude = amplitude + self._zero_at_edges = zero_at_edges self.id = id or _make_identifier_name() + super().__init__() + + @property + def length(self) -> float | FreeParameterExpression: + return self._length + + @length.setter + def length(self, value: float | FreeParameterExpression) -> None: + """ + Sets the length. + + Args: + value (Union[float, FreeParameterExpression]): Value (in seconds) + specifying the duration of the waveform. + """ + self._length = value + self._modify_oqpy_waveform_var("length", value, duration) + + @property + def sigma(self) -> float | FreeParameterExpression: + return self._sigma + + @sigma.setter + def sigma(self, value: float | FreeParameterExpression) -> None: + """ + Sets the DRAG gaussian width. + + Args: + value (Union[float, FreeParameterExpression]): A measure (in seconds) of + how wide or narrow the Gaussian peak is. + """ + self._sigma = value + self._modify_oqpy_waveform_var("sigma", value, duration) + + @property + def beta(self) -> float | FreeParameterExpression: + return self._beta + + @beta.setter + def beta(self, value: float | FreeParameterExpression) -> None: + """ + Sets the beta value. + + Args: + value (Union[float, FreeParameterExpression]): The correction amplitude. + """ + self._beta = value + self._modify_oqpy_waveform_var("beta", value) + + @property + def amplitude(self) -> float | FreeParameterExpression: + return self._amplitude + + @amplitude.setter + def amplitude(self, value: float | FreeParameterExpression) -> None: + """ + Sets the amplitude. + + Args: + value (Union[float, FreeParameterExpression]): The amplitude of the + waveform envelope. + """ + self._amplitude = value + self._modify_oqpy_waveform_var("amplitude", value) + + @property + def zero_at_edges(self) -> bool: + return self._zero_at_edges + + @zero_at_edges.setter + def zero_at_edges(self, value: bool) -> None: + """ + Sets if the DRAG gaussian waveform should start and end at zero. + + Args: + value (bool): bool specifying whether the waveform amplitude is clipped to + zero at the edges. + """ + self._zero_at_edges = value + self._modify_oqpy_waveform_var("zero_at_edges", value) def __repr__(self) -> str: return ( @@ -392,11 +579,79 @@ def __init__( id (str | None): The identifier used for declaring this waveform. A random string of ascii characters is assigned by default. """ - self.length = length - self.sigma = sigma - self.amplitude = amplitude - self.zero_at_edges = zero_at_edges + self._length = length + self._sigma = sigma + self._amplitude = amplitude + self._zero_at_edges = zero_at_edges self.id = id or _make_identifier_name() + super().__init__() + + @property + def length(self) -> float | FreeParameterExpression: + return self._length + + @length.setter + def length(self, value: float | FreeParameterExpression) -> None: + """ + Sets the length. + + Args: + value (Union[float, FreeParameterExpression]): Value (in seconds) specifying the + duration of the waveform. + """ + self._length = value + self._modify_oqpy_waveform_var("length", value, duration) + + @property + def sigma(self) -> float | FreeParameterExpression: + return self._sigma + + @sigma.setter + def sigma(self, value: float | FreeParameterExpression) -> None: + """ + Sets the gaussian waveform width. + + Args: + value (Union[float, FreeParameterExpression]): A measure (in seconds) of how wide + or narrow the Gaussian peak is. + """ + self._sigma = value + self._modify_oqpy_waveform_var("sigma", value, duration) + + @property + def amplitude(self) -> float | FreeParameterExpression: + return self._amplitude + + @amplitude.setter + def amplitude(self, value: float | FreeParameterExpression) -> None: + """ + Sets the amplitude. + + Args: + value (Union[float, FreeParameterExpression]): The amplitude of the waveform + envelope. + """ + self._amplitude = value + self._modify_oqpy_waveform_var("amplitude", value) + + @property + def zero_at_edges(self) -> bool: + return self._zero_at_edges + + @zero_at_edges.setter + def zero_at_edges(self, value: bool) -> None: + """ + Sets if the DRAG gaussian waveform should start and end at zero. + + Args: + value (bool): bool specifying whether the waveform amplitude is clipped to + zero at the edges. + """ + self._zero_at_edges = value + if self._pulse_sequence is not None: + self._pulse_sequence._program.undeclared_vars[self.id].init_expression.args[ + "zero_at_edges" + ] = value def __repr__(self) -> str: return ( diff --git a/test/unit_tests/braket/pulse/test_pulse_sequence.py b/test/unit_tests/braket/pulse/test_pulse_sequence.py index 20b3bbef4..ba7176200 100644 --- a/test/unit_tests/braket/pulse/test_pulse_sequence.py +++ b/test/unit_tests/braket/pulse/test_pulse_sequence.py @@ -24,6 +24,7 @@ Port, PulseSequence, ) +from braket.pulse.waveforms import WaveformDict @pytest.fixture @@ -79,6 +80,78 @@ def test_pulse_sequence_with_user_defined_frame(user_defined_frame): assert pulse_sequence.to_ir() == expected_str +def test_create_waveformdict_with_pulse_sequence(user_defined_frame): + pulse_sequence = PulseSequence().set_frequency(user_defined_frame, 6e6) + wf = ConstantWaveform(1e-3, complex(1, 2), "wf_id") + + waveform_dict = WaveformDict({"wf_id": wf}, pulse_sequence) + assert waveform_dict._pulse_sequence == pulse_sequence + assert waveform_dict["wf_id"]._pulse_sequence == pulse_sequence + + +def test_pulse_sequence_with_modified_wf(predefined_frame_1): + pulse_sequence = ( + PulseSequence() + .play(predefined_frame_1, GaussianWaveform(length=1e-3, sigma=0.7, id="gauss_wf")) + .play( + predefined_frame_1, + DragGaussianWaveform(length=3e-3, sigma=0.4, beta=0.2, id="drag_gauss_wf"), + ) + .play( + predefined_frame_1, + ConstantWaveform(length=4e-3, iq=complex(2, 0.3), id="constant_wf"), + ) + .play( + predefined_frame_1, + ArbitraryWaveform([complex(1, 0.4), 0, 0.3, complex(0.1, 0.2)], id="arb_wf"), + ) + ) + expected_str = "\n".join([ + "OPENQASM 3.0;", + "cal {", + " waveform gauss_wf = gaussian(1.0ms, 700.0ms, 1, false);", + " waveform drag_gauss_wf = drag_gaussian(3.0ms, 400.0ms, 0.2, 1, false);", + " waveform constant_wf = constant(4.0ms, 2.0 + 0.3im);", + " waveform arb_wf = {1.0 + 0.4im, 0, 0.3, 0.1 + 0.2im};", + " play(predefined_frame_1, gauss_wf);", + " play(predefined_frame_1, drag_gauss_wf);", + " play(predefined_frame_1, constant_wf);", + " play(predefined_frame_1, arb_wf);", + "}", + ]) + expected_str_after_mod = "\n".join([ + "OPENQASM 3.0;", + "cal {", + " waveform gauss_wf = gaussian(17.0ns, 100.0ms, 0.2, true);", + " waveform drag_gauss_wf = drag_gaussian(1.0us, 100.0ms, 0.25, 0.3, true);", + " waveform constant_wf = constant(200.0ns, 0.5);", + " waveform arb_wf = {-1.0 - 0.4im, 0, -0.3, -0.1 - 0.2im};", + " play(predefined_frame_1, gauss_wf);", + " play(predefined_frame_1, drag_gauss_wf);", + " play(predefined_frame_1, constant_wf);", + " play(predefined_frame_1, arb_wf);", + "}", + ]) + assert pulse_sequence.to_ir() == expected_str + pulse_sequence.waveforms["constant_wf"].iq = 0.5 + pulse_sequence.waveforms["constant_wf"].length = 2e-7 + + pulse_sequence.waveforms["gauss_wf"].length = 17e-9 + pulse_sequence.waveforms["gauss_wf"].sigma = 0.1 + pulse_sequence.waveforms["gauss_wf"].amplitude = 0.2 + pulse_sequence.waveforms["gauss_wf"].zero_at_edges = True + + pulse_sequence.waveforms["drag_gauss_wf"].length = 1e-6 + pulse_sequence.waveforms["drag_gauss_wf"].sigma = 0.1 + pulse_sequence.waveforms["drag_gauss_wf"].beta = 0.25 + pulse_sequence.waveforms["drag_gauss_wf"].amplitude = 0.3 + pulse_sequence.waveforms["drag_gauss_wf"].zero_at_edges = True + + pulse_sequence.waveforms["arb_wf"].amplitudes = [-complex(1, 0.4), 0, -0.3, -complex(0.1, 0.2)] + + assert pulse_sequence.to_ir() == expected_str_after_mod + + def test_pulse_sequence_make_bound_pulse_sequence(predefined_frame_1, predefined_frame_2): param = FreeParameter("a") + 2 * FreeParameter("c") pulse_sequence = (