Skip to content

Commit

Permalink
feat: Add n_samples parameter to state classifier methods for customi…
Browse files Browse the repository at this point in the history
…zable plotting
  • Loading branch information
Akinori Machino committed Dec 28, 2024
1 parent 3b10f84 commit 5fe1c22
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/qubex/measurement/state_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def plot(
target: str,
data: NDArray[np.complex128],
labels: NDArray,
n_samples: int = 1000,
):
"""
Plot the data and the predicted labels.
Expand All @@ -143,6 +144,8 @@ def plot(
An array of complex numbers representing the data.
labels : NDArray
An array of predicted state labels.
n_samples : int, optional
The number of samples to plot, by default 1000.
"""
raise NotImplementedError

Expand Down
6 changes: 6 additions & 0 deletions src/qubex/measurement/state_classifier_gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def plot(
target: str,
data: NDArray[np.complex128],
labels: NDArray,
n_samples: int = 1000,
):
"""
Plot the data and the predicted labels.
Expand All @@ -301,7 +302,12 @@ def plot(
An array of complex numbers representing the data.
labels : NDArray
An array of predicted state labels.
n_samples : int, optional
The number of samples to plot, by default 1000.
"""
if len(data) > n_samples:
data = data[:n_samples]
labels = labels[:n_samples]
x = data.real
y = data.imag
unique_labels = np.unique(labels)
Expand Down
6 changes: 6 additions & 0 deletions src/qubex/measurement/state_classifier_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def plot(
target: str,
data: NDArray[np.complex128],
labels: NDArray,
n_samples: int = 1000,
):
"""
Plot the data and the predicted labels.
Expand All @@ -266,7 +267,12 @@ def plot(
An array of complex numbers representing the data.
labels : NDArray
An array of predicted state labels.
n_samples : int, optional
The number of samples to plot, by default 1000.
"""
if len(data) > n_samples:
data = data[:n_samples]
labels = labels[:n_samples]
x = data.real
y = data.imag
unique_labels = np.unique(labels)
Expand Down

0 comments on commit 5fe1c22

Please sign in to comment.