Skip to content

Commit

Permalink
refactor: Simplify fit_rotation function by removing offresonant para…
Browse files Browse the repository at this point in the history
…meter and related code
  • Loading branch information
Akinori Machino committed Dec 27, 2024
1 parent b251ad1 commit 8b01329
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 86 deletions.
4 changes: 2 additions & 2 deletions docs/examples/analysis/rotation3d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -287,14 +287,14 @@
"metadata": {},
"outputs": [],
"source": [
"fit_rotation(times, data, offresonant=False)\n",
"fit_rotation(times, data)\n",
"display_bloch_sphere_from_bloch_vectors(data)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "3.9.18",
"language": "python",
"name": "python3"
},
Expand Down
99 changes: 15 additions & 84 deletions src/qubex/analysis/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,7 +1684,6 @@ def residuals(params, f, y):
def fit_rotation(
times: npt.NDArray[np.float64],
data: npt.NDArray[np.float64],
offresonant: bool = False,
r0: npt.NDArray[np.float64] = np.array([0, 0, 1]),
p0=None,
bounds=None,
Expand Down Expand Up @@ -1714,8 +1713,6 @@ def fit_rotation(
dict
Omega : tuple[float, float, float]
Rotation coefficients.
delta : float
Detuning frequency.
fig : go.Figure
Plot of the data and the fit.
"""
Expand All @@ -1734,7 +1731,7 @@ def rotation_matrix(
G = n[0] * G_x + n[1] * G_y + n[2] * G_z
return np.eye(3) + np.sin(omega * t) * G + (1 - np.cos(omega * t)) * G @ G

def onresonant_rotation(
def rotate(
times: npt.NDArray[np.float64],
omega: float,
theta: float,
Expand All @@ -1761,46 +1758,8 @@ def onresonant_rotation(
[rotation_matrix(t, omega, (n_x, n_y, n_z)) @ r0 for t in times]
)

def onresonant_residuals(params, times, data):
return (onresonant_rotation(times, *params) - data).flatten()

def offresonant_rotation(
times: npt.NDArray[np.float64],
omega: float,
theta: float,
phi: float,
delta: float,
) -> npt.NDArray[np.float64]:
"""
Simulate the off-resonant rotation of a state vector.
Parameters
----------
times : npt.NDArray[np.float64]
Time points for the rotation.
omega : float
Rotation frequency.
theta : float
Polar angle of the rotation axis.
phi : float
Azimuthal angle of the rotation axis.
delta : float
Detuning frequency.
"""
n_x = np.sin(theta) * np.cos(phi)
n_y = np.sin(theta) * np.sin(phi)
n_z = np.cos(theta)
return np.array(
[
rotation_matrix(t, delta, (0, 0, 1))
@ rotation_matrix(t, omega, (n_x, n_y, n_z))
@ r0
for t in times
]
)

def offresonant_residuals(params, times, data):
return (offresonant_rotation(times, *params) - data).flatten()
def residuals(params, times, data):
return (rotate(times, *params) - data).flatten()

if p0 is None:
N = len(times)
Expand All @@ -1812,39 +1771,21 @@ def offresonant_residuals(params, times, data):
omega_est = 2 * np.pi * dominant_freq
theta_est = np.pi / 2
phi_est = 0.0
if offresonant:
delta_est = 0.0
p0 = (omega_est, theta_est, phi_est, delta_est)
else:
p0 = (omega_est, theta_est, phi_est)
p0 = (omega_est, theta_est, phi_est)

if bounds is None:
if offresonant:
bounds = (
(0, 0, -np.pi, -np.inf),
(np.inf, np.pi, np.pi, np.inf),
)
else:
bounds = (
(0, 0, -np.pi),
(np.inf, np.pi, np.pi),
)

if offresonant:
result = least_squares(
offresonant_residuals,
p0,
bounds=bounds,
args=(times, data),
)
else:
result = least_squares(
onresonant_residuals,
p0,
bounds=bounds,
args=(times, data),
bounds = (
(0, 0, -np.pi),
(np.inf, np.pi, np.pi),
)

result = least_squares(
residuals,
p0,
bounds=bounds,
args=(times, data),
)

fitted_params = result.x
Omega = fitted_params[0]
theta = fitted_params[1]
Expand All @@ -1854,16 +1795,7 @@ def offresonant_residuals(params, times, data):
Omega_z = Omega * np.cos(theta)
print(f"Omega: ({Omega_x:.6f}, {Omega_y:.6f}, {Omega_z:.6f})")

if offresonant:
delta = fitted_params[3]
print(f"delta: {delta:.6f}")
else:
delta = 0.0

if offresonant:
fit = offresonant_rotation(times, *fitted_params)
else:
fit = onresonant_rotation(times, *fitted_params)
fit = rotate(times, *fitted_params)

if plot:
fig = go.Figure()
Expand Down Expand Up @@ -1991,7 +1923,6 @@ def offresonant_residuals(params, times, data):

return {
"Omega": np.array([Omega_x, Omega_y, Omega_z]),
"delta": delta,
"fig": fig,
}

Expand Down

0 comments on commit 8b01329

Please sign in to comment.