diff --git a/src/qubex/experiment/experiment.py b/src/qubex/experiment/experiment.py index b23286d..ed99afa 100644 --- a/src/qubex/experiment/experiment.py +++ b/src/qubex/experiment/experiment.py @@ -997,6 +997,7 @@ def measure( sequence: TargetMap[IQArray] | TargetMap[Waveform] | PulseSchedule, *, frequencies: Optional[dict[str, float]] = None, + initial_states: dict[str, str] | None = None, mode: Literal["single", "avg"] = "avg", shots: int = DEFAULT_SHOTS, interval: int = DEFAULT_INTERVAL, @@ -1016,6 +1017,8 @@ def measure( Sequence of the experiment. frequencies : Optional[dict[str, float]] Frequencies of the qubits. + initial_states : dict[str, str], optional + Initial states of the qubits. Defaults to None. mode : Literal["single", "avg"], optional Measurement mode. Defaults to "avg". shots : int, optional @@ -1054,16 +1057,44 @@ def measure( capture_window = capture_window or self._capture_window capture_margin = capture_margin or self._capture_margin readout_duration = readout_duration or self._readout_duration - waveforms = {} + waveforms: dict[str, NDArray[complex]] = {} if isinstance(sequence, PulseSchedule): - sequence = sequence.get_sampled_sequences() - - for target, waveform in sequence.items(): - if isinstance(waveform, Waveform): - waveforms[target] = waveform.values + if initial_states is not None: + labels = list(set(sequence.labels) | set(initial_states.keys())) + with PulseSchedule(labels) as ps: + for target, state in initial_states.items(): + if target in self.qubit_labels: + ps.add(target, self.get_pulse_for_state(target, state)) + else: + raise ValueError(f"Invalid init target: {target}") + ps.barrier() + ps.call(sequence) + waveforms = ps.get_sampled_sequences() + else: + waveforms = sequence.get_sampled_sequences() + else: + if initial_states is not None: + labels = list(set(sequence.keys()) | set(initial_states.keys())) + with PulseSchedule(labels) as ps: + for target, state in initial_states.items(): + if target in self.qubit_labels: + ps.add(target, self.get_pulse_for_state(target, state)) + else: + raise ValueError(f"Invalid init target: {target}") + ps.barrier() + for target, waveform in sequence.items(): + if isinstance(waveform, Waveform): + ps.add(target, waveform) + else: + ps.add(target, Pulse(waveform)) + waveforms = ps.get_sampled_sequences() else: - waveforms[target] = np.array(waveform, dtype=np.complex128) + for target, waveform in sequence.items(): + if isinstance(waveform, Waveform): + waveforms[target] = waveform.values + else: + waveforms[target] = np.array(waveform, dtype=np.complex128) if frequencies is None: result = self._measurement.measure( diff --git a/src/qubex/pulse/pulse_schedule.py b/src/qubex/pulse/pulse_schedule.py index 00842cf..9f4b28a 100644 --- a/src/qubex/pulse/pulse_schedule.py +++ b/src/qubex/pulse/pulse_schedule.py @@ -38,7 +38,9 @@ def __init__( ... ps.add("RQ02", FlatTop(duration=200, amplitude=1, tau=10)) >>> ps.plot() """ - if isinstance(targets, list): + if isinstance(targets, dict): + self.targets = targets + else: self.targets = { target: { "frequency": None, @@ -46,8 +48,6 @@ def __init__( } for target in targets } - else: - self.targets = targets self._sequences = {target: PulseSequence() for target in targets} self._offsets = {target: 0.0 for target in targets}