Skip to content

Commit

Permalink
common plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
robinholzi committed Dec 28, 2024
1 parent b04a2e1 commit 0db5a32
Show file tree
Hide file tree
Showing 11 changed files with 1,280 additions and 54 deletions.
27 changes: 27 additions & 0 deletions analytics/app/data/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,33 @@
# -------------------------------------------------------------------------------------------------------------------- #


def pipeline_leaf_times_df(
    logs: PipelineLogs,
    *,
    use_traintime_patch_at_trainer: bool,
    pipeline_id: str = "pipeline",
) -> pd.DataFrame:
    """Return the supervisor log rows that belong to leaf stages of a pipeline.

    Args:
        logs: Pipeline logs to extract the stage rows from.
        use_traintime_patch_at_trainer: When false, the leaf-stage rows are
            returned unchanged. When true, the ``duration`` of every TRAIN row
            is replaced by the ``train_time_at_trainer`` value of the matching
            training stage run (converted from milliseconds to seconds).
        pipeline_id: Identifier used to build the ``pipeline_<id>`` reference
            that is handed to ``logs_dataframe``.

    Returns:
        A DataFrame of leaf-stage rows. In patched mode the (patched) TRAIN
        rows come first, followed by all remaining leaf-stage rows.
    """
    leaf_ids = leaf_stages(logs)
    all_rows = logs_dataframe(logs, f"pipeline_{pipeline_id}")
    leaf_rows = all_rows[all_rows["id"].isin(leaf_ids)]
    if not use_traintime_patch_at_trainer:
        return leaf_rows

    train_stage = PipelineStage.TRAIN.name
    train_rows = leaf_rows[leaf_rows["id"] == train_stage]
    non_train_rows = leaf_rows[leaf_rows["id"] != train_stage]

    # Extended stage-log frame restricted to the training stage runs.
    trainer_rows = StageLog.df(
        (run for run in logs.supervisor_logs.stage_runs if run.id == train_stage),
        extended=True,
    )

    patched = train_rows.merge(
        trainer_rows,
        on="trigger_idx",
        how="inner",
        suffixes=("", "_training"),
    )
    # Every TRAIN row must pair with exactly one training stage run.
    n_patched, n_train, n_trainer = (
        frame.shape[0] for frame in (patched, train_rows, trainer_rows)
    )
    assert n_patched == n_train == n_trainer

    patched["duration"] = patched["train_time_at_trainer"] / 1000.0  # ms to s
    # Drop the columns introduced by the merge so both parts share one schema.
    patched = patched[train_rows.columns]

    return pd.concat([patched, non_train_rows])


def logs_dataframe(logs: PipelineLogs, pipeline_ref: str = "pipeline") -> pd.DataFrame:
df = logs.supervisor_logs.df
df["pipeline_ref"] = pipeline_ref
Expand Down
72 changes: 72 additions & 0 deletions analytics/plotting/common/color.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Any

import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap


def get_rdbu_wo_white(
    palette: str = "RdBu",
    strip: tuple[float, float] | None = (0.35, 0.65),
    nvalues: int = 100,
) -> LinearSegmentedColormap | str:
    """Build a colormap from ``palette`` with its middle section cut out.

    Args:
        palette: Name of the matplotlib colormap to start from.
        strip: ``(lower, upper)`` fractions of the colormap range; everything
            between them is removed. ``None`` disables truncation.
        nvalues: Total number of color samples; half are taken from each of
            the two retained ends (integer division).

    Returns:
        The unchanged ``palette`` name when ``strip`` is ``None``, otherwise a
        ``LinearSegmentedColormap`` named "truncated" built from the samples
        of the two retained ends.
    """
    if strip is None:
        return palette

    base_cmap = plt.get_cmap(palette)
    samples_per_side = nvalues // 2

    # Sample the two retained ends: [0, strip[0]] and [strip[1], 1].
    lower_colors = base_cmap(np.linspace(0.0, strip[0], samples_per_side))
    upper_colors = base_cmap(np.linspace(strip[1], 1.0, samples_per_side))

    retained_colors = np.concatenate([lower_colors, upper_colors])
    return LinearSegmentedColormap.from_list("truncated", retained_colors)


def gen_categorical_map(categories: list) -> dict[Any, tuple[float, float, float]]:
    """Assign one color to every category.

    Colors are drawn, in order, from the seaborn palettes "bright", "dark",
    "colorblind" and "pastel", followed by 100 repetitions of "Paired".

    Args:
        categories: Category values to be used as keys of the mapping.

    Returns:
        A dict mapping each category to the color at the same position of the
        combined palette.
    """
    color_pool = []
    for palette_name in ("bright", "dark", "colorblind", "pastel"):
        color_pool.extend(sns.color_palette(palette_name))
    # Repeat "Paired" many times so that long category lists still get a color.
    color_pool.extend(sns.color_palette("Paired") * 100)

    selected = color_pool[: len(categories)]
    return dict(zip(categories, selected))


def discrete_colors(n: int = 10):
    """Return ``n`` colors of seaborn's "RdBu" palette."""
    rdbu_colors = sns.color_palette("RdBu", n_colors=n)
    return rdbu_colors


def discrete_color(i: int, n: int = 10) -> tuple[float, float, float]:
    """Return color ``i`` of the ``n``-color palette from ``discrete_colors``.

    The index wraps around, i.e. ``i`` is taken modulo ``n``.
    """
    colors = discrete_colors(n)
    wrapped_index = i % n
    return colors[wrapped_index]


def main_colors(light: bool = False) -> list[tuple[float, float, float]]:
    """Return the list of seven main plot colors.

    The first two colors come from the 10-color "RdBu" palette, the other
    five from seaborn's 10-color "colorblind" palette.

    Args:
        light: Select the alternative ("light") set of palette positions
            instead of the default one.
    """
    rdbu = discrete_colors(10)
    colorblind = sns.color_palette("colorblind", 10)

    # Positions picked from each palette for the requested variant.
    if light:
        rdbu_positions = (-2, 2)
        colorblind_positions = (-2, 1, 2, 3, 4)
    else:
        rdbu_positions = (-1, 1)
        colorblind_positions = (-2, 1, 2, 4, 5)

    picked = [rdbu[pos] for pos in rdbu_positions]
    picked += [colorblind[pos] for pos in colorblind_positions]
    return picked


def main_color(i: int, light: bool = False) -> tuple[float, float, float]:
    """Return entry ``i`` of the color list produced by ``main_colors``."""
    selected_colors = main_colors(light=light)
    return selected_colors[i]
2 changes: 2 additions & 0 deletions analytics/plotting/common/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Base figure dimensions shared by the plotting helpers. plot_cost_matrix
# multiplies them by its width/height factors and passes the result to
# matplotlib as `figsize` (which matplotlib interprets in inches).
# NOTE(review): "DOUBLE" presumably refers to a double-column figure — confirm.
DOUBLE_FIG_WIDTH = 10
DOUBLE_FIG_HEIGHT = 3.5
168 changes: 168 additions & 0 deletions analytics/plotting/common/cost_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import matplotlib.dates as mdates
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator

# Create the heatmap
from analytics.plotting.common.common import init_plot
from analytics.plotting.common.const import DOUBLE_FIG_HEIGHT, DOUBLE_FIG_WIDTH
from analytics.plotting.common.font import setup_font


def plot_cost_matrix(
    df_costs: pd.DataFrame,
    pipeline_ids: list[int],
    grid_alpha: float = 0.0,
    title_map: dict[int, str] | None = None,
    height_factor: float = 1.0,
    width_factor: float = 1.0,
    duration_ylabel: str = "Duration (sec.)",
    cumulative_ylabel: str = "Cumulative Duration (sec.)",
    x_label: str = "Sample Timestamp",
    x_lim: tuple[int, int] = (1930, 2013),
    x_ticks: list[int] | None = None,
    x_date_locator: mdates.DateLocator | None = None,
    x_date_formatter: mdates.DateFormatter | None = None,
    y_lim: tuple[int, int] = (0, 4000),
    y_lim_cumulative: tuple[int, int] = (0, 4000),
    y_ticks: list[int] | None = None,
    y_ticks_cumulative: list[int] | None = None,
    y_minutes: bool = False,
    y_minutes_cumulative: bool = False,
) -> Figure | Axes:
    """Plot stacked per-stage cost histograms for one or more pipelines.

    One row of two panels is drawn per pipeline: the left panel shows the
    duration per sample-time bin, the right panel the cumulative duration.
    Bars are stacked by stage (``id``); the stacking order is derived from the
    total duration per stage over the whole of ``df_costs`` so that it is the
    same in every panel.

    DataFrame columns:
        pipeline_ref: ``pipeline_<pipeline_id>`` reference of the row
        id: supervisor leaf stage id
        sample_time_year: sample year when this cost was recorded
        duration: cost of the pipeline at that time

    Args:
        df_costs: Cost records with the columns listed above.
        pipeline_ids: Pipelines to plot, one subplot row each. Must not be
            empty.
        grid_alpha: Opacity of the dashed grid lines (0.0 hides them).
        title_map: Optional panel title per pipeline id. When given and
            non-empty it must contain every id in ``pipeline_ids``.
        height_factor: Multiplier applied to the base figure height.
        width_factor: Multiplier applied to the base figure width.
        duration_ylabel: y-axis label of the non-cumulative panels.
        cumulative_ylabel: y-axis label of the cumulative panels.
        x_label: x-axis label of all panels.
        x_lim: x-axis limits of all panels.
        x_ticks: Explicit x tick positions (also used as labels); ignored
            when ``x_date_locator`` is given.
        x_date_locator: Major tick locator for the x-axis.
        x_date_formatter: Major tick formatter for the x-axis; only applied
            together with ``x_date_locator``.
        y_lim: NOTE(review): accepted but currently not applied to the
            non-cumulative panels; kept as is to avoid changing existing
            figures.
        y_lim_cumulative: y-axis limits of the cumulative panels.
        y_ticks: Explicit y ticks of the non-cumulative panels.
        y_ticks_cumulative: Explicit y ticks of the cumulative panels.
        y_minutes: Divide durations by 60 in the non-cumulative panels.
        y_minutes_cumulative: Divide durations by 60 in the cumulative panels.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If ``pipeline_ids`` is empty.
        KeyError: If a non-empty ``title_map`` lacks one of ``pipeline_ids``.
    """
    if len(pipeline_ids) == 0:
        raise ValueError("pipeline_ids must contain at least one pipeline id")

    sns.set_theme(style="whitegrid")
    init_plot()
    setup_font(small_label=True, small_title=True, small_ticks=True)

    # squeeze=False keeps `axs` two-dimensional even for a single pipeline,
    # so every panel can be addressed uniformly as axs[row, col].
    fig, axs = plt.subplots(
        nrows=len(pipeline_ids),
        ncols=2,
        squeeze=False,
        edgecolor="black",
        frameon=True,
        figsize=(
            DOUBLE_FIG_WIDTH * width_factor,
            2 * DOUBLE_FIG_HEIGHT * height_factor,
        ),
        dpi=600,
    )

    x_col = "sample_time_year"
    y_col = "duration"
    hue_col = "id"

    # NOTE(review): the bin count is hard-coded for the year range 1930-2014
    # and does not follow `x_lim`.
    n_bins = 2014 - 1930 + 1

    # Fixed stage -> color assignment.
    # NOTE(review): presumably every stage id in df_costs must be one of these
    # keys for seaborn to accept the palette — confirm.
    palette = sns.color_palette("RdBu", 10)
    new_palette = {
        "train": palette[0],
        "inform remaining data": palette[-2],
        "evaluate trigger policy": palette[2],
        "inform trigger": palette[-1],
        "store trained model": palette[1],
    }

    # use sum of all pipelines to determine the order of the bars that is consistent across subplots
    df_agg = df_costs.groupby([hue_col]).agg({y_col: "sum"}).reset_index()
    df_agg = df_agg.sort_values(y_col, ascending=False)
    categories = df_agg[hue_col].unique()

    # Only the cumulative panel of the first pipeline carries the legend.
    legend_pipeline_id = pipeline_ids[0]

    for row, pipeline_id in enumerate(pipeline_ids):
        df_costs_pipeline = df_costs[df_costs["pipeline_ref"] == f"pipeline_{pipeline_id}"]

        for cumulative in (False, True):
            # Work on a copy so the unit conversion never leaks into df_costs.
            df_final = df_costs_pipeline.copy()
            in_minutes = y_minutes_cumulative if cumulative else y_minutes
            if in_minutes:
                df_final[y_col] = df_final[y_col] / 60

            ax = axs[row, int(cumulative)]
            show_legend = cumulative and pipeline_id == legend_pipeline_id

            sns.histplot(
                df_final,
                x=x_col,
                weights=y_col,
                bins=n_bins,
                cumulative=cumulative,
                multiple="stack",
                linewidth=0,  # Remove white edges between bars
                shrink=1.0,  # Ensure bars touch each other
                alpha=1.0,  # Remove transparency
                hue=hue_col,
                hue_order=categories,
                palette=new_palette,
                ax=ax,
                legend=show_legend,
                zorder=-2,
            )

            # Rasterize the histogram bars to avoid anti-aliasing artifacts
            for bar in ax.patches:
                bar.set_rasterized(True)

            ax.grid(axis="y", linestyle="--", alpha=grid_alpha, zorder=3, color="lightgray")
            ax.grid(axis="x", linestyle="--", alpha=grid_alpha, zorder=3, color="lightgray")

            if title_map:
                ax.set_title(title_map[pipeline_id])

            # x-axis
            ax.set(xlim=x_lim)
            ax.set_xlabel(x_label, labelpad=10)

            if x_date_locator is not None:
                ax.xaxis.set_major_locator(x_date_locator)
                if x_date_formatter is not None:
                    ax.xaxis.set_major_formatter(x_date_formatter)
                # Configure this panel explicitly; plt.xticks() would act on
                # pyplot's "current axes" instead of the panel being drawn.
                ax.tick_params(axis="x", labelrotation=0)
            elif x_ticks is not None:
                ax.set_xticks(ticks=x_ticks, labels=x_ticks, rotation=0)

            # y-axis
            if cumulative:
                ax.set_ylabel(cumulative_ylabel, labelpad=20)
                if y_lim_cumulative:
                    ax.set(ylim=y_lim_cumulative)
                if y_ticks_cumulative:
                    ax.set_yticks(ticks=y_ticks_cumulative, labels=y_ticks_cumulative, rotation=0)
                else:
                    ax.yaxis.set_major_locator(MaxNLocator(nbins=4))
            else:
                # NOTE(review): `y_lim` is intentionally not applied here to
                # keep existing figures unchanged (see docstring).
                ax.set_ylabel(duration_ylabel, labelpad=20)
                if y_ticks:
                    ax.set_yticks(ticks=y_ticks, labels=y_ticks, rotation=0)
                else:
                    ax.yaxis.set_major_locator(MaxNLocator(nbins=4))

            if show_legend:
                # No legend object exists when nothing was drawn for this
                # panel (e.g. no rows for the first pipeline).
                legend = ax.get_legend()
                if legend is not None:
                    legend.set_title("")  # remove the hue title

    # Adjust the subplot spacing of this figure
    fig.tight_layout()

    return fig
Loading

0 comments on commit 0db5a32

Please sign in to comment.