Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add reusable plotting features #641

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions analytics/app/data/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,33 @@
# -------------------------------------------------------------------------------------------------------------------- #


def pipeline_leaf_times_df(
logs: PipelineLogs,
*,
use_traintime_patch_at_trainer: bool,
pipeline_id: str = "pipeline",
) -> pd.DataFrame:
pipeline_leaf_stages = leaf_stages(logs)
df_all = logs_dataframe(logs, f"pipeline_{pipeline_id}")
df_leaf_single = df_all[df_all["id"].isin(pipeline_leaf_stages)]
if not use_traintime_patch_at_trainer:
return df_leaf_single

df_leaf_only_train = df_leaf_single[df_leaf_single["id"] == PipelineStage.TRAIN.name]
df_leaf_wo_train = df_leaf_single[df_leaf_single["id"] != PipelineStage.TRAIN.name]

df_trainings = StageLog.df(
(x for x in logs.supervisor_logs.stage_runs if x.id == PipelineStage.TRAIN.name),
extended=True,
)
df_merged = df_leaf_only_train.merge(df_trainings, on="trigger_idx", how="inner", suffixes=("", "_training"))
assert df_merged.shape[0] == df_leaf_only_train.shape[0] == df_trainings.shape[0]
df_merged["duration"] = df_merged["train_time_at_trainer"] / 1000.0 # ms to s
df_merged = df_merged[df_leaf_only_train.columns]

return pd.concat([df_merged, df_leaf_wo_train])


def logs_dataframe(logs: PipelineLogs, pipeline_ref: str = "pipeline") -> pd.DataFrame:
df = logs.supervisor_logs.df
df["pipeline_ref"] = pipeline_ref
Expand Down
72 changes: 72 additions & 0 deletions analytics/plotting/common/color.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Any

import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap


def get_rdbu_wo_white(
palette: str = "RdBu",
strip: tuple[float, float] | None = (0.35, 0.65),
nvalues: int = 100,
) -> LinearSegmentedColormap | str:
if strip is None:
return palette

# Truncate the "RdBu" colormap to exclude the light colors
rd_bu_cmap = plt.get_cmap(palette)
custom_cmap_blu = rd_bu_cmap(np.linspace(0.0, strip[0], nvalues // 2))
custom_cmap_red = rd_bu_cmap(np.linspace(strip[1], 1.0, nvalues // 2))
cmap = LinearSegmentedColormap.from_list("truncated", np.concatenate([custom_cmap_blu, custom_cmap_red]))
return cmap


def gen_categorical_map(categories: list) -> dict[Any, tuple[float, float, float]]:
palette = (
sns.color_palette("bright")
+ sns.color_palette("dark")
+ sns.color_palette("colorblind")
+ sns.color_palette("pastel")
+ sns.color_palette("Paired") * 100
)[: len(categories)]
color_map = dict(zip(categories, palette))
return color_map


def discrete_colors(n: int = 10) -> Any:
return sns.color_palette("RdBu", n)


def discrete_color(i: int, n: int = 10) -> tuple[float, float, float]:
palette = discrete_colors(n)
return palette[i % n]


def main_colors(light: bool = False) -> list[tuple[float, float, float]]:
rdbu_palette = discrete_colors(10)
colorblind_palette = sns.color_palette("colorblind", 10)

if light:
return [
rdbu_palette[-2],
rdbu_palette[2],
colorblind_palette[-2],
colorblind_palette[1],
colorblind_palette[2],
colorblind_palette[3],
colorblind_palette[4],
]
return [
rdbu_palette[-1],
rdbu_palette[1],
colorblind_palette[-2],
colorblind_palette[1],
colorblind_palette[2],
colorblind_palette[4],
colorblind_palette[5],
]


def main_color(i: int, light: bool = False) -> tuple[float, float, float]:
return main_colors(light=light)[i]
2 changes: 2 additions & 0 deletions analytics/plotting/common/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
DOUBLE_FIG_WIDTH = 10
DOUBLE_FIG_HEIGHT = 3.5
168 changes: 168 additions & 0 deletions analytics/plotting/common/cost_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import matplotlib.dates as mdates
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator

# Create the heatmap
from analytics.plotting.common.common import init_plot
from analytics.plotting.common.const import DOUBLE_FIG_HEIGHT, DOUBLE_FIG_WIDTH
from analytics.plotting.common.font import setup_font


def plot_cost_matrix(
df_costs: pd.DataFrame,
pipeline_ids: list[int],
grid_alpha: float = 0.0,
title_map: dict[int, str] = {},
height_factor: float = 1.0,
width_factor: float = 1.0,
duration_ylabel: str = "Duration (sec.)",
cumulative_ylabel: str = "Cumulative Duration (sec.)",
x_label: str = "Sample Timestamp",
x_lim: tuple[int, int] = (1930, 2013),
x_ticks: list[int] | None = None,
x_date_locator: mdates.DateLocator | None = None,
x_date_formatter: mdates.DateFormatter | None = None,
y_lim: tuple[int, int] = (0, 4000),
y_lim_cumulative: tuple[int, int] = (0, 4000),
y_ticks: list[int] | None = None,
y_ticks_cumulative: list[int] | None = None,
y_minutes: bool = False,
y_minutes_cumulative: bool = False,
) -> Figure | Axes:
"""
DataFrame columns:
pipeline_ref
id: supervisor leaf stage id
sample_time_year: sample year when this cost was recorded
duration: cost of the pipeline at that time
"""
sns.set_theme(style="whitegrid")
init_plot()
setup_font(small_label=True, small_title=True, small_ticks=True)

fig, axs = plt.subplots(
nrows=len(pipeline_ids),
ncols=2,
edgecolor="black",
frameon=True,
figsize=(
DOUBLE_FIG_WIDTH * width_factor,
2 * DOUBLE_FIG_HEIGHT * height_factor,
),
dpi=600,
)

x_col = "sample_time_year"
y_col = "duration"
hue_col = "id"

palette = sns.color_palette("RdBu", 10)
new_palette = {
"train": palette[0],
"inform remaining data": palette[-2],
"evaluate trigger policy": palette[2],
"inform trigger": palette[-1],
"store trained model": palette[1],
}

# use sum of all pipelines to determine the order of the bars that is consistent across subplots
df_agg = df_costs.groupby([hue_col]).agg({y_col: "sum"}).reset_index()
df_agg = df_agg.sort_values(y_col, ascending=False)
categories = df_agg[hue_col].unique()

legend_tuple = (pipeline_ids[0], True)

for row, pipeline_id in enumerate(pipeline_ids):
# sort by cumulative duration
df_costs_pipeline = df_costs[df_costs["pipeline_ref"] == f"pipeline_{pipeline_id}"]

for cumulative in [False, True]:
df_final = df_costs_pipeline.copy()
if cumulative and y_minutes_cumulative:
df_final[y_col] = df_final[y_col] / 60
elif not cumulative and y_minutes:
df_final[y_col] = df_final[y_col] / 60

ax = axs[row, int(cumulative)] if len(pipeline_ids) > 1 else axs[int(cumulative)]
h = sns.histplot(
df_final,
x=x_col,
weights=y_col,
bins=2014 - 1930 + 1,
cumulative=cumulative,
# discrete=True,
multiple="stack",
linewidth=0, # Remove white edges between bars
shrink=1.0, # Ensure bars touch each other
alpha=1.0, # remove transparaency
# hue
hue="id",
hue_order=categories,
palette=new_palette,
# ax=axs[int(cumulative)], # for 1 pipeline, only 1 row
ax=ax,
# legend
legend=legend_tuple == (pipeline_id, cumulative),
zorder=-2,
)

# Rasterize the heatmap background to avoid anti-aliasing artifacts
for bar in h.patches:
bar.set_rasterized(True)

h.grid(axis="y", linestyle="--", alpha=grid_alpha, zorder=3, color="lightgray")
h.grid(axis="x", linestyle="--", alpha=grid_alpha, zorder=3, color="lightgray")

if len(title_map) > 0:
# size huge
h.set_title(title_map[pipeline_id])

# # Set x-axis
h.set(xlim=x_lim)
h.set_xlabel(x_label, labelpad=10)

if x_date_locator:
h.xaxis.set_major_locator(x_date_locator)
# ax.set_xticklabels(x_ticks, rotation=0)
h.xaxis.set_major_formatter(x_date_formatter)
# ticks = ax.get_xticks()
plt.xticks(rotation=0)
elif x_ticks is not None:
h.set_xticks(
ticks=x_ticks,
labels=x_ticks,
rotation=0,
# ha='right'
)

if cumulative:
h.set_ylabel(cumulative_ylabel, labelpad=20)
if y_lim_cumulative:
h.set(ylim=y_lim_cumulative)
if y_ticks_cumulative:
h.set_yticks(ticks=y_ticks_cumulative, labels=y_ticks_cumulative, rotation=0)
else:
h.yaxis.set_major_locator(MaxNLocator(nbins=4))
else:
h.set_ylabel(duration_ylabel, labelpad=20)
if y_ticks:
h.set_yticks(ticks=y_ticks, labels=y_ticks, rotation=0)
else:
h.yaxis.set_major_locator(MaxNLocator(nbins=4))
if legend_tuple == (pipeline_id, cumulative):
# set hue label
legend = h.get_legend()

legend.set_title("") # remove title

# expand legend horizontally
# legend.set_bbox_to_anchor((0, 1, 1, 0), transform=h.transAxes)

# Display the plot
plt.tight_layout()

return fig
Loading
Loading