Skip to content

Commit

Permalink
Add more ablation plots
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 committed Feb 23, 2024
1 parent ee894e4 commit cc5cce9
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions analysis/tmlr/nicopp_more_ablations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# %%
import sys

sys.path.append("..")

# %%
from pathlib import Path

from ranzen.wandb import RunsDownloader
from wandb_utils import CustomMethod, Group, MethodName, Metrics, download_groups, plot

# %%
downloader = RunsDownloader(project="support-matching", entity="predictive-analytics-lab")

# %%
data = download_groups(
downloader,
{
# "our_method_2023-10-02": Group(MethodName.ours_bag_oracle, "test/y_from_zy/", ""),
"our_method_no_balancing_2023-10-19": Group(
MethodName.ours_no_balancing, "test/y_from_zy/", ""
),
"ours_2023-12-19_no_recon_loss": Group(
CustomMethod("Without recon loss"), "test/y_from_zy/", ""
),
# "ours_2023-12-19_no_zs": Group(CustomMethod("no zs split"), "test/y_from_zy/", ""),
"ours_2023-12-19_no_y_predictor": Group(
CustomMethod("Without $y$-predictor"), "test/y_from_zy/", ""
),
"ours_2024-01-05_no_disc_loss": Group(
CustomMethod("Without disc loss"), "test/y_from_zy/", ""
),
},
)
data = data.rename(columns={"Robust_OvR_TPR": "Robust OvR TPR"})

# %%
plot(
data,
metrics=[Metrics.acc, Metrics.rob_tpr_ovr],
# metrics=[Metrics.acc],
# x_label="noise level",
# x_limits=(0.48, 1),
# plot_style=PlotStyle.boxplot_hue,
file_format="pdf",
fig_dim=(5.0, 1.0),
file_prefix="nicopp",
output_dir=Path("nicopp") / "more_ablation",
)

# %%

0 comments on commit cc5cce9

Please sign in to comment.