Skip to content

Commit

Permalink
refactor: Simplify plotting logic in fit_rotation function for improv…
Browse files Browse the repository at this point in the history
…ed readability
  • Loading branch information
Akinori Machino committed Dec 28, 2024
1 parent ae92a32 commit dda7482
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 154 deletions.
225 changes: 113 additions & 112 deletions src/qubex/analysis/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1793,133 +1793,134 @@ def residuals(params, times, data):
Omega_x = Omega * np.sin(theta) * np.cos(phi)
Omega_y = Omega * np.sin(theta) * np.sin(phi)
Omega_z = Omega * np.cos(theta)
print(f"Omega: ({Omega_x:.6f}, {Omega_y:.6f}, {Omega_z:.6f})")
# print(f"Omega: ({Omega_x:.6f}, {Omega_y:.6f}, {Omega_z:.6f})")

times_fine = np.linspace(np.min(times), np.max(times), 1000)
fit = rotate(times_fine, *fitted_params)

if plot:
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=times,
y=data[:, 0],
mode="markers",
name="X (data)",
marker=dict(color=COLORS[0]),
)
)
fig.add_trace(
go.Scatter(
x=times_fine,
y=fit[:, 0],
mode="lines",
name="X (fit)",
line=dict(color=COLORS[0]),
)
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=times,
y=data[:, 0],
mode="markers",
name="X (data)",
marker=dict(color=COLORS[0]),
)
fig.add_trace(
go.Scatter(
x=times,
y=data[:, 1],
mode="markers",
name="Y (data)",
marker=dict(color=COLORS[1]),
)
)
fig.add_trace(
go.Scatter(
x=times_fine,
y=fit[:, 0],
mode="lines",
name="X (fit)",
line=dict(color=COLORS[0]),
)
fig.add_trace(
go.Scatter(
x=times_fine,
y=fit[:, 1],
mode="lines",
name="Y (fit)",
line=dict(color=COLORS[1]),
)
)
fig.add_trace(
go.Scatter(
x=times,
y=data[:, 1],
mode="markers",
name="Y (data)",
marker=dict(color=COLORS[1]),
)
fig.add_trace(
go.Scatter(
x=times,
y=data[:, 2],
mode="markers",
name="Z (data)",
marker=dict(color=COLORS[2]),
)
)
fig.add_trace(
go.Scatter(
x=times_fine,
y=fit[:, 1],
mode="lines",
name="Y (fit)",
line=dict(color=COLORS[1]),
)
fig.add_trace(
go.Scatter(
x=times_fine,
y=fit[:, 2],
mode="lines",
name="Z (fit)",
line=dict(color=COLORS[2]),
)
)
fig.add_trace(
go.Scatter(
x=times,
y=data[:, 2],
mode="markers",
name="Z (data)",
marker=dict(color=COLORS[2]),
)
fig.update_layout(
title=title,
xaxis_title=xlabel,
yaxis_title=ylabel,
yaxis=dict(range=[-1.1, 1.1]),
)
fig.add_trace(
go.Scatter(
x=times_fine,
y=fit[:, 2],
mode="lines",
name="Z (fit)",
line=dict(color=COLORS[2]),
)
fig.show()
)
fig.update_layout(
title=title,
xaxis_title=xlabel,
yaxis_title=ylabel,
yaxis=dict(range=[-1.1, 1.1]),
)

fig3d = go.Figure()
# data
fig3d.add_trace(
go.Scatter3d(
name="data",
x=data[:, 0],
y=data[:, 1],
z=data[:, 2],
mode="markers",
marker=dict(size=3),
hoverinfo="skip",
)
fig3d = go.Figure()
# data
fig3d.add_trace(
go.Scatter3d(
name="data",
x=data[:, 0],
y=data[:, 1],
z=data[:, 2],
mode="markers",
marker=dict(size=3),
hoverinfo="skip",
)
)

# fit
fig3d.add_trace(
go.Scatter3d(
name="fit",
x=fit[:, 0],
y=fit[:, 1],
z=fit[:, 2],
mode="lines",
line=dict(width=4),
hoverinfo="skip",
)
)
# sphere
theta = np.linspace(0, np.pi, 50)
phi = np.linspace(0, 2 * np.pi, 50)
theta, phi = np.meshgrid(theta, phi)
r = 1
x = r * np.sin(theta) * np.cos(phi)
y = r * np.sin(theta) * np.sin(phi)
z = r * np.cos(theta)
fig3d.add_trace(
go.Surface(
x=x,
y=y,
z=z,
opacity=0.05,
showscale=False,
colorscale="gray",
hoverinfo="skip",
)
# fit
fig3d.add_trace(
go.Scatter3d(
name="fit",
x=fit[:, 0],
y=fit[:, 1],
z=fit[:, 2],
mode="lines",
line=dict(width=4),
hoverinfo="skip",
)
# layout
fig3d.update_layout(
scene=dict(
xaxis=dict(title="〈X〉", visible=True),
yaxis=dict(title="〈Y〉", visible=True),
zaxis=dict(title="〈Z〉", visible=True),
aspectmode="cube",
),
width=400,
height=400,
margin=dict(l=0, r=0, b=0, t=0),
showlegend=False,
)
# sphere
theta = np.linspace(0, np.pi, 50)
phi = np.linspace(0, 2 * np.pi, 50)
theta, phi = np.meshgrid(theta, phi)
r = 1
x = r * np.sin(theta) * np.cos(phi)
y = r * np.sin(theta) * np.sin(phi)
z = r * np.cos(theta)
fig3d.add_trace(
go.Surface(
x=x,
y=y,
z=z,
opacity=0.05,
showscale=False,
colorscale="gray",
hoverinfo="skip",
)
)
# layout
fig3d.update_layout(
scene=dict(
xaxis=dict(title="〈X〉", visible=True),
yaxis=dict(title="〈Y〉", visible=True),
zaxis=dict(title="〈Z〉", visible=True),
aspectmode="cube",
),
width=400,
height=400,
margin=dict(l=0, r=0, b=0, t=0),
showlegend=False,
)

if plot:
fig.show()
fig3d.show()

return {
Expand Down
Loading

0 comments on commit dda7482

Please sign in to comment.