Skip to content

Commit 4b9710e

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Do not produce ProgressionPlot if all steps are NaN (#4769)
Summary: Pull Request resolved: #4769 Prevents confusing plots like this one: {F1984569487} Reviewed By: mpolson64 Differential Revision: D90715040 fbshipit-source-id: cb34c04dd74fddf29620411d24593b1e90900b6e
1 parent 7bce65b commit 4b9710e

File tree

3 files changed

+46
-3
lines changed

3 files changed

+46
-3
lines changed

ax/analysis/plotly/progression.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def validate_applicable_state(
7575
) -> str | None:
7676
"""
7777
ProgressionPlot requires an Experiment with data that has a "step"
78-
column.
78+
column with valid (non-NaN) values.
7979
"""
8080
if (
8181
experiment_invalid_reason := validate_experiment(
@@ -90,6 +90,20 @@ def validate_applicable_state(
9090
if not data.has_step_column:
9191
return "Requires data to have a column 'step.'"
9292

93+
# Check if the step column has any valid (non-NaN) values
94+
metric_name = self._metric_name or select_metric(
95+
experiment=none_throws(experiment)
96+
)
97+
df = none_throws(experiment).lookup_data().full_df
98+
metric_df = df[df["metric_name"] == metric_name]
99+
100+
if metric_df.empty:
101+
return f"No data found for metric '{metric_name}'."
102+
103+
# Check if all step values are NaN
104+
if metric_df[MAP_KEY].isna().all():
105+
return f"All progression values for metric '{metric_name}' are NaN."
106+
93107
@override
94108
def compute(
95109
self,

ax/analysis/plotly/tests/test_progression.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55

66
# pyre-strict
77

8+
import numpy as np
89
import pandas as pd
910
from ax.analysis.plotly.progression import (
1011
_calculate_wallclock_timeseries,
1112
ProgressionPlot,
1213
)
14+
from ax.core.data import Data, MAP_KEY
1315
from ax.utils.common.testutils import TestCase
1416
from ax.utils.testing.core_stubs import (
1517
get_branin_experiment,
@@ -35,6 +37,28 @@ def test_validate_applicable_state(self) -> None:
3537
state = plot.validate_applicable_state(experiment=experiment)
3638
self.assertEqual(state, "Requires data to have a column 'step.'")
3739

40+
with self.subTest("All step values are NaN"):
41+
# Create a new experiment with map data where all MAP_KEY values are NaN
42+
experiment = get_test_map_data_experiment(
43+
num_trials=2, num_fetches=3, num_complete=2
44+
)
45+
46+
# Replace all MAP_KEY values with NaN in the fetched data
47+
original_data = experiment.fetch_data()
48+
modified_df = original_data.full_df.copy()
49+
modified_df[MAP_KEY] = np.nan
50+
51+
# Create a new Data object and attach it
52+
nan_data = Data(df=modified_df)
53+
experiment.data = nan_data
54+
55+
# Validate that progression plot is not applicable
56+
plot = ProgressionPlot(metric_name="branin_map")
57+
state = plot.validate_applicable_state(experiment=experiment)
58+
self.assertEqual(
59+
state, "All progression values for metric 'branin_map' are NaN."
60+
)
61+
3862
def test_compute(self) -> None:
3963
analysis = ProgressionPlot(metric_name="branin_map")
4064

ax/analysis/results.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from ax.core.analysis_card import AnalysisCardGroup
3030
from ax.core.arm import Arm
3131
from ax.core.batch_trial import BatchTrial
32+
from ax.core.data import MAP_KEY
3233
from ax.core.experiment import Experiment
3334
from ax.core.map_metric import MapMetric
3435
from ax.core.outcome_constraint import ScalarizedOutcomeConstraint
@@ -246,12 +247,16 @@ def compute(
246247
adapter=adapter,
247248
)
248249

249-
# Compute progression plots for MapMetrics (learning curves)
250+
# Compute progression plots if there is curve data.
250251
progression_group = None
251252
data = experiment.lookup_data()
252253
metrics = experiment.metrics.values()
253254
map_metrics = [m for m in metrics if isinstance(m, MapMetric)]
254-
if data.has_step_column and len(map_metrics) > 0:
255+
if (
256+
data.has_step_column
257+
and data.full_df[MAP_KEY].notna().any()
258+
and len(map_metrics) > 0
259+
):
255260
progression_cards = [
256261
ProgressionPlot(
257262
metric_name=m.name, by_wallclock_time=by_wallclock_time

0 commit comments

Comments
 (0)