Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,39 @@

@dataclass
class FitResults:
"""Stores the relevant CZ conditional phase experiment fit parameters for a single qubit pair"""
"""Fit results for a single qubit pair from the CZ phase compensation calibration.

Attributes:
-----------
control_phase_correction : float
Residual single-qubit phase accumulated by the control qubit during the CZ gate (in 2π units).
This value is subtracted from ``phase_shift_control`` in the state update.
target_phase_correction : float
Residual single-qubit phase accumulated by the target qubit during the CZ gate (in 2π units).
This value is subtracted from ``phase_shift_target`` in the state update.
success : bool
True if the sinusoidal fit converged for both qubits.
"""

control_phase_correction: float
target_phase_correction: float
success: bool


def fix_oscillation_phi_2pi(fit_data):
"""Extract and fix the phase parameter from oscillation fit data."""
# Extract the phase parameter from the fit results
"""Extract and normalise the oscillation phase to the [0, 1) range (representing 0 to 2π).

Parameters:
-----------
fit_data : xr.DataArray
Oscillation fit result with a ``fit_vals`` dimension containing ``"phi"``.

Returns:
--------
xr.DataArray
Phase values normalised to [0, 1).
"""
phase = fit_data.sel(fit_vals="phi")
# Normalize phase to [0, 1] range (representing 0 to 2π)
phase = (phase / (2 * np.pi)) % 1
return phase

Expand Down Expand Up @@ -57,8 +78,25 @@ def log_fitted_results(fit_results: Dict[str, FitResults], log_callable=None):
log_callable(log_message)


def process_raw_dataset(ds: xr.Dataset, node: QualibrationNode):
# Convert IQ data into volts
def process_raw_dataset(ds: xr.Dataset, node: QualibrationNode) -> xr.Dataset:
"""Convert raw IQ data to volts when state discrimination is not used.

Parameters:
-----------
ds : xr.Dataset
Raw dataset from the experiment, containing either ``I_control`` / ``Q_control`` /
``I_target`` / ``Q_target`` (raw IQ) or ``state_control`` / ``state_target``
(state-discrimination) variables.
node : QualibrationNode
Calibration node providing ``qubit_pairs`` from its namespace (required by
``convert_IQ_to_V``).

Returns:
--------
xr.Dataset
Dataset with IQ variables converted to volts, or unchanged if state discrimination
was used.
"""
if hasattr(ds, "I_control"):
ds = convert_IQ_to_V(
ds, qubit_pairs=node.namespace["qubit_pairs"], IQ_list=["I_control", "Q_control", "I_target", "Q_target"]
Expand Down Expand Up @@ -91,7 +129,24 @@ def fit_raw_data(ds: xr.Dataset, node: QualibrationNode) -> Tuple[xr.Dataset, Di


def fit_routine(da):
"""Fit Ramsey frame-rotation oscillations and extract phase corrections for one qubit pair.

Fits a sinusoidal oscillation over the ``frame`` axis separately for the control qubit
and the target qubit. The fitted phase of each oscillation gives the residual single-qubit
phase error introduced by the CZ gate.

Parameters:
-----------
da : xr.Dataset
Single-pair dataset with ``state_control`` / ``state_target`` (state-discrimination)
or ``I_control`` / ``I_target`` (raw IQ) variables, and a ``frame`` dimension.

Returns:
--------
xr.Dataset
Input dataset extended with ``fitted_control``, ``fitted_target``,
``fitted_control_phase``, ``fitted_target_phase``, and ``success`` data variables.
"""
if hasattr(da, "state_target"):
data_control = "state_control"
data_target = "state_target"
Expand Down Expand Up @@ -168,22 +223,30 @@ def _extract_relevant_parameters(
ds_fit: xr.Dataset, node: QualibrationNode
) -> Tuple[xr.Dataset, Dict[str, FitResults]]:
"""
Extract relevant fit parameters and create FitResults for each qubit pair.
Assign xarray metadata attributes and build the FitResults dictionary.

Parameters:
-----------
ds_fit : xr.Dataset
Dataset containing the fit results from fit_routine.
Dataset produced by applying ``fit_routine`` per qubit pair. Must contain
``fitted_control_phase``, ``fitted_target_phase``, and ``success`` data variables.
node : QualibrationNode
The calibration node containing parameters and qubit pairs.
Calibration node providing ``qubit_pairs`` from its namespace.

Returns:
--------
Tuple[xr.Dataset, Dict[str, FitResults]]
Dataset with additional metadata and dictionary of FitResults for each qubit pair.
Dataset with ``long_name`` / ``units`` attrs set on key variables, and a
dictionary of ``FitResults`` keyed by qubit pair name.
"""
qubit_pairs = node.namespace["qubit_pairs"]

# Add metadata attributes to the dataset
if "fitted_control_phase" in ds_fit.data_vars:
ds_fit.fitted_control_phase.attrs = {"long_name": "control qubit phase correction", "units": "2π"}
if "fitted_target_phase" in ds_fit.data_vars:
ds_fit.fitted_target_phase.attrs = {"long_name": "target qubit phase correction", "units": "2π"}

# Create FitResults for each qubit pair
fit_results = {}
for qp in qubit_pairs:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,52 +1,144 @@
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from qualibrate import QualibrationNode
from quam_config import Quam
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from qualibration_libs.plotting import grid_iter

from .parameters import Parameters
from calibration_utils.pair_grid import QubitPairGrid, grid_pair_names


def plot_raw_data_with_fit(ds_raw: xr.Dataset, qubit_pairs: Quam, ds_fit: xr.Dataset = None):
def _mark_fitted_peak(
ax: Axes,
frames: np.ndarray,
fit_curve: np.ndarray,
fitted_phase: float,
color: str,
channel_label: str,
) -> None:
"""Mark the fitted-sine maximum."""
if not np.any(np.isfinite(fit_curve)):
return
i_max = int(np.nanargmax(fit_curve))
peak_frame = float(frames[i_max])
peak_val = float(fit_curve[i_max])
ax.axvline(
peak_frame,
color=color,
ls="--",
lw=1.5,
alpha=0.85,
zorder=3,
label=f"{channel_label} peak @ {peak_frame:.4f} (φ={fitted_phase:.4f})",
)
ax.plot(
peak_frame,
peak_val,
"*",
ms=12,
mew=1.2,
zorder=5,
markerfacecolor="white",
markeredgecolor=color,
)


def plot_raw_data_with_fit(
ds_raw: xr.Dataset,
qubit_pairs: list,
ds_fit: xr.Dataset = None,
) -> Figure:
"""Plot raw Ramsey oscillations with fitted curves for every qubit pair on a chip-topology grid.

Parameters
----------
ds_raw : xr.Dataset
Raw dataset containing ``state_control`` / ``state_target`` or
``I_control`` / ``I_target`` variables with a ``frame`` dimension.
qubit_pairs : list
Qubit pair objects used for grid placement.
ds_fit : xr.Dataset, optional
Fit dataset containing ``fitted_control``, ``fitted_target``, and ``success``.
If None, only the raw data is shown.

Returns
-------
Figure
Matplotlib figure with one panel per qubit pair.
"""
Plot the raw data with the fit for each qubit pair in a single figure.
grid_names, pair_names = grid_pair_names(qubit_pairs)
grid = QubitPairGrid(grid_names, pair_names)

for ax, qubit in grid_iter(grid):
qp_name = qubit["qubit"]
plot_individual_data_with_fit(ax, ds_raw, qp_name, ds_fit)

grid.fig.suptitle("CZ 1Q phase compensation")
grid.fig.tight_layout()
return grid.fig


def plot_individual_data_with_fit(
ax: Axes,
ds_raw: xr.Dataset,
qp_name: str,
ds_fit: xr.Dataset = None,
):
"""Plot raw Ramsey oscillations and fitted curves for one qubit pair.

Parameters
----------
ax : Axes
Axis to draw on.
ds_raw : xr.Dataset
Raw dataset with ``state_control`` / ``state_target`` or
``I_control`` / ``I_target`` variables.
qp_name : str
Qubit pair name used to select data.
ds_fit : xr.Dataset, optional
Fit dataset containing ``fitted_control``, ``fitted_target``, and ``success``.
If None or fit failed, only raw data is shown.
"""
n_pairs = len(qubit_pairs)

fig, axes = plt.subplots(1, n_pairs, figsize=(5 * n_pairs, 4))
if n_pairs == 1:
axes = [axes]

for i, qp_name in enumerate(qubit_pairs):
ax = axes[i]

# Select data for this qubit pair
qp_data = ds_raw.sel(qubit_pair=qp_name.name)

# Plot raw data
if "state_control" in ds_raw.data_vars:
qp_data.state_control.plot(ax=ax, marker="o", linestyle="", color="blue", label="Control")
qp_data.state_target.plot(ax=ax, marker="o", linestyle="", color="red", label="Target")
else:
qp_data.I_control.sel(control_target="c").plot(
ax=ax, marker="o", linestyle="", color="blue", label="Control"
)
qp_data.I_target.sel(control_target="t").plot(ax=ax, marker="o", linestyle="", color="red", label="Target")

# Plot fitted data if available and fit was successful
if ds_fit is not None:
qp_fit_data = ds_fit.sel(qubit_pair=qp_name.name)
if qp_fit_data.success.values:
if "fitted_control" in ds_fit.data_vars:
qp_fit_data.fitted_control.plot(ax=ax, color="blue", alpha=0.5)
if "fitted_target" in ds_fit.data_vars:
qp_fit_data.fitted_target.plot(ax=ax, color="red", alpha=0.5)
if "state_control" in ds_raw.data_vars:
ax.set_ylabel("Measured State")
else:
ax.set_ylabel("Rotated I Quadrature [V]")

ax.set_xlabel(r"x90 frame rotation [$\mathrm{rad}/2\pi$]")
ax.legend()

plt.tight_layout()
return fig
qp_data = ds_raw.sel(qubit_pair=qp_name)

if "state_control" in ds_raw.data_vars:
qp_data.state_control.plot(ax=ax, marker="o", linestyle="", color="blue", label="Control")
qp_data.state_target.plot(ax=ax, marker="o", linestyle="", color="red", label="Target")
ylabel = "Measured State"
else:
qp_data.I_control.sel(control_target="c").plot(ax=ax, marker="o", linestyle="", color="blue", label="Control")
qp_data.I_target.sel(control_target="t").plot(ax=ax, marker="o", linestyle="", color="red", label="Target")
ylabel = "Rotated I Quadrature (V)"

if ds_fit is not None:
qp_fit = ds_fit.sel(qubit_pair=qp_name)
frames = qp_data.frame.values
if bool(qp_fit.success.values):
if "fitted_control" in ds_fit.data_vars:
fit_c = np.asarray(qp_fit.fitted_control.values, dtype=float)
ax.plot(frames, fit_c, color="blue", alpha=0.5)
if "fitted_control_phase" in ds_fit.data_vars:
_mark_fitted_peak(
ax,
frames,
fit_c,
float(qp_fit.fitted_control_phase.values),
"blue",
"Control",
)
if "fitted_target" in ds_fit.data_vars:
fit_t = np.asarray(qp_fit.fitted_target.values, dtype=float)
ax.plot(frames, fit_t, color="red", alpha=0.5)
if "fitted_target_phase" in ds_fit.data_vars:
_mark_fitted_peak(
ax,
frames,
fit_t,
float(qp_fit.fitted_target_phase.values),
"red",
"Target",
)

ax.set_title(qp_name)
ax.set_xlabel(r"Virtual-Z frame [$2\pi$]")
ax.set_ylabel(ylabel)
ax.legend(fontsize=8)
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""CZ phase compensation with error amplification calibration utilities."""

from .analysis import FitResults, fit_raw_data, log_fitted_results, process_raw_dataset
from .parameters import Parameters
from .plotting import plot_raw_data_with_fit

__all__ = [
"Parameters",
"process_raw_dataset",
"fit_raw_data",
"log_fitted_results",
"FitResults",
"plot_raw_data_with_fit",
]
Loading
Loading