Skip to content

Commit 7bce65b

Browse files
saitcakmak authored and meta-codesync[bot] committed
Audit usage of trial failure reason (#4765)
Summary:
Pull Request resolved: #4765

- Updates `Experiment.to_df` to extract the column from `Trial.status_reason`. Previously, it tried to read the value from `run_metadata["fail_reason"]`, which is never populated AFAICT.
- Renames the column across analyses from `fail_reason` to `status_reason`. The motivation is that this property can represent the reasons for abandoned and early-stopped trials, in addition to failed trials.

Reviewed By: mgarrard
Differential Revision: D90627835
fbshipit-source-id: 22b9b6aa0011ad6125f754508cd94c172130102d
1 parent 5a1f251 commit 7bce65b

File tree

8 files changed

+26
-25
lines changed

8 files changed

+26
-25
lines changed

ax/analysis/plotly/arm_effects.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def compute(
226226
"trial_index",
227227
"trial_status",
228228
"arm_name",
229-
"fail_reason",
229+
"status_reason",
230230
"generation_node",
231231
f"{self.metric_name}_mean",
232232
f"{self.metric_name}_sem",

ax/analysis/plotly/tests/test_arm_effects.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def test_compute_raw(self) -> None:
125125
"trial_index",
126126
"arm_name",
127127
"trial_status",
128-
"fail_reason",
128+
"status_reason",
129129
"generation_node",
130130
"foo_mean",
131131
"foo_sem",
@@ -156,7 +156,7 @@ def test_compute_with_modeled(self) -> None:
156156
"trial_index",
157157
"arm_name",
158158
"trial_status",
159-
"fail_reason",
159+
"status_reason",
160160
"generation_node",
161161
"foo_mean",
162162
"foo_sem",
@@ -343,7 +343,7 @@ def test_compute_with_relativize(self) -> None:
343343
"trial_index",
344344
"arm_name",
345345
"trial_status",
346-
"fail_reason",
346+
"status_reason",
347347
"generation_node",
348348
"branin_mean",
349349
"branin_sem",

ax/analysis/plotly/tests/test_scatter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def test_compute_raw(self) -> None:
127127
"trial_index",
128128
"arm_name",
129129
"trial_status",
130-
"fail_reason",
130+
"status_reason",
131131
"generation_node",
132132
"p_feasible_mean",
133133
"p_feasible_sem",
@@ -186,7 +186,7 @@ def test_compute_with_modeled(self) -> None:
186186
"trial_index",
187187
"arm_name",
188188
"trial_status",
189-
"fail_reason",
189+
"status_reason",
190190
"generation_node",
191191
"p_feasible_mean",
192192
"p_feasible_sem",

ax/analysis/tests/test_summary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_compute(self) -> None:
119119
"trial_index",
120120
"arm_name",
121121
"trial_status",
122-
"fail_reason",
122+
"status_reason",
123123
"generation_node",
124124
"foo",
125125
"bar",

ax/analysis/tests/test_utils.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_prepare_arm_data_raw(self) -> None:
147147
"trial_index",
148148
"arm_name",
149149
"trial_status",
150-
"fail_reason",
150+
"status_reason",
151151
"generation_node",
152152
"p_feasible_mean",
153153
"p_feasible_sem",
@@ -194,7 +194,7 @@ def test_prepare_arm_data_raw(self) -> None:
194194
"trial_index",
195195
"arm_name",
196196
"trial_status",
197-
"fail_reason",
197+
"status_reason",
198198
"generation_node",
199199
"p_feasible_mean",
200200
"p_feasible_sem",
@@ -225,7 +225,7 @@ def test_prepare_arm_data_raw(self) -> None:
225225
"trial_index",
226226
"arm_name",
227227
"trial_status",
228-
"fail_reason",
228+
"status_reason",
229229
"generation_node",
230230
"p_feasible_mean",
231231
"p_feasible_sem",
@@ -334,7 +334,7 @@ def test_prepare_arm_data_use_model_predictions(self) -> None:
334334
"trial_index",
335335
"arm_name",
336336
"trial_status",
337-
"fail_reason",
337+
"status_reason",
338338
"generation_node",
339339
"p_feasible_mean",
340340
"p_feasible_sem",
@@ -378,7 +378,7 @@ def test_prepare_arm_data_use_model_predictions(self) -> None:
378378
"trial_index",
379379
"arm_name",
380380
"trial_status",
381-
"fail_reason",
381+
"status_reason",
382382
"generation_node",
383383
"p_feasible_mean",
384384
"p_feasible_sem",
@@ -411,7 +411,7 @@ def test_prepare_arm_data_use_model_predictions(self) -> None:
411411
"trial_index",
412412
"arm_name",
413413
"trial_status",
414-
"fail_reason",
414+
"status_reason",
415415
"generation_node",
416416
"p_feasible_mean",
417417
"p_feasible_sem",
@@ -450,7 +450,7 @@ def test_prepare_arm_data_use_model_predictions(self) -> None:
450450
"trial_index",
451451
"arm_name",
452452
"trial_status",
453-
"fail_reason",
453+
"status_reason",
454454
"generation_node",
455455
"p_feasible_mean",
456456
"p_feasible_sem",
@@ -493,7 +493,7 @@ def test_prepare_arm_data_use_model_predictions(self) -> None:
493493
"trial_index",
494494
"arm_name",
495495
"trial_status",
496-
"fail_reason",
496+
"status_reason",
497497
"generation_node",
498498
"p_feasible_mean",
499499
"p_feasible_sem",
@@ -578,7 +578,7 @@ def test_prepare_arm_data_out_of_distribution_arm(self) -> None:
578578
self.assertFalse(np.isnan(ood_df.foo_sem.iloc[0]))
579579

580580
def test_prepare_arm_data_includes_failure_reasons(self) -> None:
581-
"""Test that the fail_reason column is properly populated."""
581+
"""Test that the status_reason column is properly populated."""
582582
client = Client()
583583
client.configure_experiment(
584584
name="test_failure_reasons",
@@ -600,16 +600,16 @@ def test_prepare_arm_data_includes_failure_reasons(self) -> None:
600600
use_model_predictions=False,
601601
)
602602

603-
# Verify fail_reason column is populated correctly
604-
self.assertIn("fail_reason", df.columns)
603+
# Verify status_reason column is populated correctly
604+
self.assertIn("status_reason", df.columns)
605605
self.assertTrue(
606-
pd.isna(df[df["trial_index"] == 0]["fail_reason"].iloc[0])
606+
pd.isna(df[df["trial_index"] == 0]["status_reason"].iloc[0])
607607
) # Success: no reason
608608
self.assertEqual(
609-
df[df["trial_index"] == 1]["fail_reason"].iloc[0], "Regular failure"
609+
df[df["trial_index"] == 1]["status_reason"].iloc[0], "Regular failure"
610610
) # Regular failure
611611
self.assertEqual(
612-
df[df["trial_index"] == 2]["fail_reason"].iloc[0], STALE_FAIL_REASON
612+
df[df["trial_index"] == 2]["status_reason"].iloc[0], STALE_FAIL_REASON
613613
) # Stale failure
614614

615615
def test_relativize_df_with_sq(self) -> None:

ax/analysis/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def prepare_arm_data(
255255
if trial_index != -1
256256
else "Additional Arm"
257257
)
258-
df["fail_reason"] = df["trial_index"].apply(
258+
df["status_reason"] = df["trial_index"].apply(
259259
lambda trial_index: experiment.trials[trial_index].status_reason
260260
if trial_index != -1
261261
and experiment.trials[trial_index].status_reason is not None

ax/core/experiment.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1920,7 +1920,8 @@ def to_df(
19201920
- trial_index: The trial index of the arm
19211921
- arm_name: The name of the arm
19221922
- trial_status: The status of the trial (e.g. RUNNING, SUCCEDED, FAILED)
1923-
- failure_reason: The reason for the failure, if applicable
1923+
- status_reason: The reason for the trial status (e.g., failure,
1924+
abandonment, early stopping), if applicable
19241925
- generation_node: The name of the ``GenerationNode`` that generated the arm
19251926
- **METADATA: Any metadata associated with the trial, as specified by the
19261927
Experiment's runner.run_metadata_report_keys field
@@ -2002,7 +2003,7 @@ def to_df(
20022003
"trial_index": trial.index,
20032004
"arm_name": arm.name,
20042005
"trial_status": trial.status.name,
2005-
"fail_reason": trial.run_metadata.get("fail_reason", None),
2006+
"status_reason": trial.status_reason,
20062007
"generation_node": generation_node,
20072008
**metadata,
20082009
**observed_means,

ax/core/tests/test_experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1689,7 +1689,7 @@ def test_to_df(self) -> None:
16891689
"trial_index",
16901690
"arm_name",
16911691
"trial_status",
1692-
"fail_reason",
1692+
"status_reason",
16931693
"generation_node",
16941694
"name",
16951695
"m1",

0 commit comments

Comments
 (0)