|
| 1 | +from pathlib import Path |
| 2 | + |
| 3 | +import pandas as pd |
| 4 | +import seaborn as sns |
| 5 | +from matplotlib import pyplot as plt |
| 6 | +from matplotlib.figure import Figure |
| 7 | +from matplotlib.ticker import MaxNLocator |
| 8 | + |
| 9 | +# Create the heatmap |
| 10 | +from analytics.plotting.common.common import init_plot |
| 11 | + |
| 12 | + |
| 13 | +def build_heatmap( |
| 14 | + heatmap_data: pd.DataFrame, |
| 15 | + y_ticks: list[int] | None = None, |
| 16 | + y_ticks_bins: int | None = None, |
| 17 | + reverse_col: bool = False, |
| 18 | + y_label: str = "Reference Year", |
| 19 | + x_label: str = "Current Year", |
| 20 | + color_label: str = "Accuracy %", |
| 21 | +): |
| 22 | + init_plot() |
| 23 | + # sns.set_theme(style="ticks") |
| 24 | + plt.rcParams["svg.fonttype"] = "none" |
| 25 | + |
| 26 | + double_fig_width = 10 |
| 27 | + double_fig_height = 3.5 |
| 28 | + |
| 29 | + fig = plt.figure( |
| 30 | + edgecolor="black", |
| 31 | + frameon=True, |
| 32 | + figsize=(double_fig_width, 2.2 * double_fig_height), |
| 33 | + dpi=300, |
| 34 | + ) |
| 35 | + |
| 36 | + ax = sns.heatmap( |
| 37 | + heatmap_data, |
| 38 | + cmap="RdBu" + ("_r" if reverse_col else ""), |
| 39 | + linewidths=0.0, |
| 40 | + linecolor="black", |
| 41 | + cbar=True, |
| 42 | + # color bar from 0 to 1 |
| 43 | + cbar_kws={ |
| 44 | + "label": color_label, |
| 45 | + # "ticks": [0, 25, 50, 75, 100], |
| 46 | + "orientation": "vertical", |
| 47 | + }, |
| 48 | + ) |
| 49 | + ax.collections[0].set_rasterized(True) |
| 50 | + |
| 51 | + # Adjust x-axis tick labels |
| 52 | + plt.xlabel(x_label) |
| 53 | + plt.xticks( |
| 54 | + ticks=[x + 0.5 for x in range(0, 2010 - 1930 + 1, 20)], |
| 55 | + labels=[x for x in range(1930, 2010 + 1, 20)], |
| 56 | + rotation=0, |
| 57 | + # ha='right' |
| 58 | + ) |
| 59 | + ax.invert_yaxis() |
| 60 | + |
| 61 | + if y_ticks is not None: |
| 62 | + plt.yticks(ticks=[y + 0.5 - 1930 for y in y_ticks], labels=[y for y in y_ticks], rotation=0) |
| 63 | + elif y_ticks_bins is not None: |
| 64 | + ax.yaxis.set_major_locator(MaxNLocator(nbins=y_ticks_bins)) |
| 65 | + ax.set_yticklabels([int(i) + min(heatmap_data.index) for i in ax.get_yticks()], rotation=0) |
| 66 | + |
| 67 | + plt.ylabel(y_label) |
| 68 | + |
| 69 | + # # Draft training boxes |
| 70 | + # if drift_pipeline: |
| 71 | + # for type_, dashed in [("train", False), ("usage", False), ("train", True)]: |
| 72 | + # for active_ in df_logs_models.iterrows(): |
| 73 | + # x_start = active_[1][f"{type_}_start"].year - 1930 |
| 74 | + # x_end = active_[1][f"{type_}_end"].year - 1930 |
| 75 | + # y = active_[1]["model_idx"] |
| 76 | + # rect = plt.Rectangle( |
| 77 | + # (x_start, y - 1), # y: 0 based index, model_idx: 1 based index |
| 78 | + # x_end - x_start, |
| 79 | + # 1, |
| 80 | + # edgecolor="White" if type_ == "train" else "Black", |
| 81 | + # facecolor="none", |
| 82 | + # linewidth=3, |
| 83 | + # linestyle="dotted" if dashed else "solid", |
| 84 | + # hatch="/", |
| 85 | + # joinstyle="bevel", |
| 86 | + # # capstyle="round", |
| 87 | + # ) |
| 88 | + # ax.add_patch(rect) |
| 89 | + |
| 90 | + # Display the plot |
| 91 | + plt.tight_layout() |
| 92 | + # plt.show() |
| 93 | + |
| 94 | + return fig |
| 95 | + |
| 96 | + |
| 97 | +def save_plot(fig: Figure, name: str) -> None: |
| 98 | + for img_type in ["png", "svg"]: |
| 99 | + img_path = Path("/scratch/robinholzi/gh/modyn/.data/plots") / f"{name}.{img_type}" |
| 100 | + fig.savefig(img_path, bbox_inches="tight", transparent=True) |
0 commit comments