Skip to content

Commit 2e388ec

Browse files
authored
feat: Add script for matrix heatmap plots (#611)
1 parent f54c18c commit 2e388ec

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

analytics/plotting/common/heatmap.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from pathlib import Path
2+
3+
import pandas as pd
4+
import seaborn as sns
5+
from matplotlib import pyplot as plt
6+
from matplotlib.figure import Figure
7+
from matplotlib.ticker import MaxNLocator
8+
9+
# Create the heatmap
10+
from analytics.plotting.common.common import init_plot
11+
12+
13+
def build_heatmap(
14+
heatmap_data: pd.DataFrame,
15+
y_ticks: list[int] | None = None,
16+
y_ticks_bins: int | None = None,
17+
reverse_col: bool = False,
18+
y_label: str = "Reference Year",
19+
x_label: str = "Current Year",
20+
color_label: str = "Accuracy %",
21+
):
22+
init_plot()
23+
# sns.set_theme(style="ticks")
24+
plt.rcParams["svg.fonttype"] = "none"
25+
26+
double_fig_width = 10
27+
double_fig_height = 3.5
28+
29+
fig = plt.figure(
30+
edgecolor="black",
31+
frameon=True,
32+
figsize=(double_fig_width, 2.2 * double_fig_height),
33+
dpi=300,
34+
)
35+
36+
ax = sns.heatmap(
37+
heatmap_data,
38+
cmap="RdBu" + ("_r" if reverse_col else ""),
39+
linewidths=0.0,
40+
linecolor="black",
41+
cbar=True,
42+
# color bar from 0 to 1
43+
cbar_kws={
44+
"label": color_label,
45+
# "ticks": [0, 25, 50, 75, 100],
46+
"orientation": "vertical",
47+
},
48+
)
49+
ax.collections[0].set_rasterized(True)
50+
51+
# Adjust x-axis tick labels
52+
plt.xlabel(x_label)
53+
plt.xticks(
54+
ticks=[x + 0.5 for x in range(0, 2010 - 1930 + 1, 20)],
55+
labels=[x for x in range(1930, 2010 + 1, 20)],
56+
rotation=0,
57+
# ha='right'
58+
)
59+
ax.invert_yaxis()
60+
61+
if y_ticks is not None:
62+
plt.yticks(ticks=[y + 0.5 - 1930 for y in y_ticks], labels=[y for y in y_ticks], rotation=0)
63+
elif y_ticks_bins is not None:
64+
ax.yaxis.set_major_locator(MaxNLocator(nbins=y_ticks_bins))
65+
ax.set_yticklabels([int(i) + min(heatmap_data.index) for i in ax.get_yticks()], rotation=0)
66+
67+
plt.ylabel(y_label)
68+
69+
# # Draft training boxes
70+
# if drift_pipeline:
71+
# for type_, dashed in [("train", False), ("usage", False), ("train", True)]:
72+
# for active_ in df_logs_models.iterrows():
73+
# x_start = active_[1][f"{type_}_start"].year - 1930
74+
# x_end = active_[1][f"{type_}_end"].year - 1930
75+
# y = active_[1]["model_idx"]
76+
# rect = plt.Rectangle(
77+
# (x_start, y - 1), # y: 0 based index, model_idx: 1 based index
78+
# x_end - x_start,
79+
# 1,
80+
# edgecolor="White" if type_ == "train" else "Black",
81+
# facecolor="none",
82+
# linewidth=3,
83+
# linestyle="dotted" if dashed else "solid",
84+
# hatch="/",
85+
# joinstyle="bevel",
86+
# # capstyle="round",
87+
# )
88+
# ax.add_patch(rect)
89+
90+
# Display the plot
91+
plt.tight_layout()
92+
# plt.show()
93+
94+
return fig
95+
96+
97+
def save_plot(fig: Figure, name: str) -> None:
98+
for img_type in ["png", "svg"]:
99+
img_path = Path("/scratch/robinholzi/gh/modyn/.data/plots") / f"{name}.{img_type}"
100+
fig.savefig(img_path, bbox_inches="tight", transparent=True)

0 commit comments

Comments
 (0)