diff --git a/analytics/app/data/transform.py b/analytics/app/data/transform.py index d71179606..72c317b8c 100644 --- a/analytics/app/data/transform.py +++ b/analytics/app/data/transform.py @@ -23,6 +23,33 @@ # -------------------------------------------------------------------------------------------------------------------- # +def pipeline_leaf_times_df( + logs: PipelineLogs, + *, + use_traintime_patch_at_trainer: bool, + pipeline_id: str = "pipeline", +) -> pd.DataFrame: + pipeline_leaf_stages = leaf_stages(logs) + df_all = logs_dataframe(logs, f"pipeline_{pipeline_id}") + df_leaf_single = df_all[df_all["id"].isin(pipeline_leaf_stages)] + if not use_traintime_patch_at_trainer: + return df_leaf_single + + df_leaf_only_train = df_leaf_single[df_leaf_single["id"] == PipelineStage.TRAIN.name] + df_leaf_wo_train = df_leaf_single[df_leaf_single["id"] != PipelineStage.TRAIN.name] + + df_trainings = StageLog.df( + (x for x in logs.supervisor_logs.stage_runs if x.id == PipelineStage.TRAIN.name), + extended=True, + ) + df_merged = df_leaf_only_train.merge(df_trainings, on="trigger_idx", how="inner", suffixes=("", "_training")) + assert df_merged.shape[0] == df_leaf_only_train.shape[0] == df_trainings.shape[0] + df_merged["duration"] = df_merged["train_time_at_trainer"] / 1000.0 # ms to s + df_merged = df_merged[df_leaf_only_train.columns] + + return pd.concat([df_merged, df_leaf_wo_train]) + + def logs_dataframe(logs: PipelineLogs, pipeline_ref: str = "pipeline") -> pd.DataFrame: df = logs.supervisor_logs.df df["pipeline_ref"] = pipeline_ref diff --git a/analytics/plotting/common/color.py b/analytics/plotting/common/color.py new file mode 100644 index 000000000..ed179b35f --- /dev/null +++ b/analytics/plotting/common/color.py @@ -0,0 +1,72 @@ +from typing import Any + +import numpy as np +import seaborn as sns +from matplotlib import pyplot as plt +from matplotlib.colors import LinearSegmentedColormap + + +def get_rdbu_wo_white( + palette: str = "RdBu", + strip: tuple[float, float] | None = (0.35, 0.65), + nvalues: int = 100, +) -> LinearSegmentedColormap | str: + if strip is None: + return palette + + # Truncate the "RdBu" colormap to exclude the light colors + rd_bu_cmap = plt.get_cmap(palette) + custom_cmap_blu = rd_bu_cmap(np.linspace(0.0, strip[0], nvalues // 2)) + custom_cmap_red = rd_bu_cmap(np.linspace(strip[1], 1.0, nvalues // 2)) + cmap = LinearSegmentedColormap.from_list("truncated", np.concatenate([custom_cmap_blu, custom_cmap_red])) + return cmap + + +def gen_categorical_map(categories: list) -> dict[Any, tuple[float, float, float]]: + palette = ( + sns.color_palette("bright") + + sns.color_palette("dark") + + sns.color_palette("colorblind") + + sns.color_palette("pastel") + + sns.color_palette("Paired") * 100 + )[: len(categories)] + color_map = dict(zip(categories, palette)) + return color_map + + +def discrete_colors(n: int = 10) -> Any: + return sns.color_palette("RdBu", n) + + +def discrete_color(i: int, n: int = 10) -> tuple[float, float, float]: + palette = discrete_colors(n) + return palette[i % n] + + +def main_colors(light: bool = False) -> list[tuple[float, float, float]]: + rdbu_palette = discrete_colors(10) + colorblind_palette = sns.color_palette("colorblind", 10) + + if light: + return [ + rdbu_palette[-2], + rdbu_palette[2], + colorblind_palette[-2], + colorblind_palette[1], + colorblind_palette[2], + colorblind_palette[3], + colorblind_palette[4], + ] + return [ + rdbu_palette[-1], + rdbu_palette[1], + colorblind_palette[-2], + colorblind_palette[1], + colorblind_palette[2], + colorblind_palette[4], + colorblind_palette[5], + ] + + +def main_color(i: int, light: bool = False) -> tuple[float, float, float]: + return main_colors(light=light)[i] diff --git a/analytics/plotting/common/const.py b/analytics/plotting/common/const.py new file mode 100644 index 000000000..d6aa2cbf1 --- /dev/null +++ b/analytics/plotting/common/const.py @@ -0,0 +1,2 @@ +DOUBLE_FIG_WIDTH = 10 +DOUBLE_FIG_HEIGHT = 3.5 diff --git a/analytics/plotting/common/cost_matrix.py b/analytics/plotting/common/cost_matrix.py new file mode 100644 index 000000000..fb6695876 --- /dev/null +++ b/analytics/plotting/common/cost_matrix.py @@ -0,0 +1,168 @@ +import matplotlib.dates as mdates +import pandas as pd +import seaborn as sns +from matplotlib import pyplot as plt +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from matplotlib.ticker import MaxNLocator + +# Create the heatmap +from analytics.plotting.common.common import init_plot +from analytics.plotting.common.const import DOUBLE_FIG_HEIGHT, DOUBLE_FIG_WIDTH +from analytics.plotting.common.font import setup_font + + +def plot_cost_matrix( + df_costs: pd.DataFrame, + pipeline_ids: list[int], + grid_alpha: float = 0.0, + title_map: dict[int, str] = {}, + height_factor: float = 1.0, + width_factor: float = 1.0, + duration_ylabel: str = "Duration (sec.)", + cumulative_ylabel: str = "Cumulative Duration (sec.)", + x_label: str = "Sample Timestamp", + x_lim: tuple[int, int] = (1930, 2013), + x_ticks: list[int] | None = None, + x_date_locator: mdates.DateLocator | None = None, + x_date_formatter: mdates.DateFormatter | None = None, + y_lim: tuple[int, int] = (0, 4000), + y_lim_cumulative: tuple[int, int] = (0, 4000), + y_ticks: list[int] | None = None, + y_ticks_cumulative: list[int] | None = None, + y_minutes: bool = False, + y_minutes_cumulative: bool = False, +) -> Figure | Axes: + """ + DataFrame columns: + pipeline_ref + id: supervisor leaf stage id + sample_time_year: sample year when this cost was recorded + duration: cost of the pipeline at that time + """ + sns.set_theme(style="whitegrid") + init_plot() + setup_font(small_label=True, small_title=True, small_ticks=True) + + fig, axs = plt.subplots( + nrows=len(pipeline_ids), + ncols=2, + edgecolor="black", + frameon=True, + figsize=( + DOUBLE_FIG_WIDTH * width_factor, + 2 * DOUBLE_FIG_HEIGHT * height_factor, + ), + dpi=600, + ) + + x_col = "sample_time_year" + y_col = "duration" + hue_col = "id" + + palette = sns.color_palette("RdBu", 10) + new_palette = { + "train": palette[0], + "inform remaining data": palette[-2], + "evaluate trigger policy": palette[2], + "inform trigger": palette[-1], + "store trained model": palette[1], + } + + # use sum of all pipelines to determine the order of the bars that is consistent across subplots + df_agg = df_costs.groupby([hue_col]).agg({y_col: "sum"}).reset_index() + df_agg = df_agg.sort_values(y_col, ascending=False) + categories = df_agg[hue_col].unique() + + legend_tuple = (pipeline_ids[0], True) + + for row, pipeline_id in enumerate(pipeline_ids): + # sort by cumulative duration + df_costs_pipeline = df_costs[df_costs["pipeline_ref"] == f"pipeline_{pipeline_id}"] + + for cumulative in [False, True]: + df_final = df_costs_pipeline.copy() + if cumulative and y_minutes_cumulative: + df_final[y_col] = df_final[y_col] / 60 + elif not cumulative and y_minutes: + df_final[y_col] = df_final[y_col] / 60 + + ax = axs[row, int(cumulative)] if len(pipeline_ids) > 1 else axs[int(cumulative)] + h = sns.histplot( + df_final, + x=x_col, + weights=y_col, + bins=2014 - 1930 + 1, + cumulative=cumulative, + # discrete=True, + multiple="stack", + linewidth=0, # Remove white edges between bars + shrink=1.0, # Ensure bars touch each other + alpha=1.0, # remove transparaency + # hue + hue="id", + hue_order=categories, + palette=new_palette, + # ax=axs[int(cumulative)], # for 1 pipeline, only 1 row + ax=ax, + # legend + legend=legend_tuple == (pipeline_id, cumulative), + zorder=-2, + ) + + # Rasterize the heatmap background to avoid anti-aliasing artifacts + for bar in h.patches: + bar.set_rasterized(True) + + h.grid(axis="y", linestyle="--", alpha=grid_alpha, zorder=3, color="lightgray") + h.grid(axis="x", linestyle="--", alpha=grid_alpha, zorder=3, color="lightgray") + + if len(title_map) > 0: + # size huge + h.set_title(title_map[pipeline_id]) + + # # Set x-axis + h.set(xlim=x_lim) + h.set_xlabel(x_label, labelpad=10) + + if x_date_locator: + h.xaxis.set_major_locator(x_date_locator) + # ax.set_xticklabels(x_ticks, rotation=0) + h.xaxis.set_major_formatter(x_date_formatter) + # ticks = ax.get_xticks() + plt.xticks(rotation=0) + elif x_ticks is not None: + h.set_xticks( + ticks=x_ticks, + labels=x_ticks, + rotation=0, + # ha='right' + ) + + if cumulative: + h.set_ylabel(cumulative_ylabel, labelpad=20) + if y_lim_cumulative: + h.set(ylim=y_lim_cumulative) + if y_ticks_cumulative: + h.set_yticks(ticks=y_ticks_cumulative, labels=y_ticks_cumulative, rotation=0) + else: + h.yaxis.set_major_locator(MaxNLocator(nbins=4)) + else: + h.set_ylabel(duration_ylabel, labelpad=20) + if y_ticks: + h.set_yticks(ticks=y_ticks, labels=y_ticks, rotation=0) + else: + h.yaxis.set_major_locator(MaxNLocator(nbins=4)) + if legend_tuple == (pipeline_id, cumulative): + # set hue label + legend = h.get_legend() + + legend.set_title("") # remove title + + # expand legend horizontally + # legend.set_bbox_to_anchor((0, 1, 1, 0), transform=h.transAxes) + + # Display the plot + plt.tight_layout() + + return fig diff --git a/analytics/plotting/common/dataset_histogram.py b/analytics/plotting/common/dataset_histogram.py new file mode 100644 index 000000000..1c1a69eb1 --- /dev/null +++ b/analytics/plotting/common/dataset_histogram.py @@ -0,0 +1,439 @@ +import matplotlib.dates as mdates +import numpy as np +import pandas as pd +import seaborn as sns +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from matplotlib.ticker import MaxNLocator + +from analytics.plotting.common.color import ( + gen_categorical_map, + main_color, + main_colors, +) +from analytics.plotting.common.common import DOUBLE_FIG_HEIGHT, init_plot +from analytics.plotting.common.const import DOUBLE_FIG_WIDTH +from analytics.plotting.common.font import setup_font + + +def build_countplot( + histogram_data: pd.DataFrame, + x: str, + y_ticks: list[int] | None = None, + y_ticks_bins: int | None = None, + x_ticks: list[int] | None = None, + y_label: str = "Number of Samples", + x_label: str = "Year", + height_factor: float = 1.0, + width_factor: float = 1.0, + palette: str = "RdBu", + palette_strip: tuple[float, float] | None = (0.35, 0.65), +) -> Figure: + init_plot() + setup_font() + + fig = plt.figure( + edgecolor="black", + frameon=True, + figsize=( + DOUBLE_FIG_WIDTH * width_factor, + 2 * DOUBLE_FIG_HEIGHT * height_factor, + ), + dpi=600, + ) + ax = fig.add_subplot(111) + + agg_by_year = histogram_data.groupby(x).size().reset_index(name="count") + + ax = sns.barplot( + data=agg_by_year, + x=x, + y="count", + color=main_color(0), + # hue="count", + # palette=get_rdbu_wo_white(palette=palette, strip=palette_strip), + width=1, + legend=False, + ) + + # avoid fine white lines between cells + for artist in ax.patches: # ax.patches contains the bars in the plot + artist.set_rasterized(True) + + # draw grid behind bars (horizontal and vertical) + ax.grid(axis="x", linestyle="--", alpha=1.0) + ax.grid(axis="y", linestyle="--", alpha=1.0) + + # Adjust x-axis tick labels + plt.xlabel(x_label) + if x_ticks is not None: + plt.xticks( + ticks=[xtick - min(histogram_data[x]) for xtick in x_ticks], + labels=x_ticks, + rotation=0, + # ha='right' + ) + + plt.ylabel(y_label) + if y_ticks is not None: + plt.yticks(ticks=y_ticks, labels=y_ticks, rotation=0) + elif y_ticks_bins is not None: + ax.yaxis.set_major_locator(MaxNLocator(nbins=y_ticks_bins)) + ax.set_yticklabels([int(i) for i in ax.get_yticks()], rotation=0) + + # Display the plot + plt.tight_layout() + + return fig + + +def build_histogram_multicategory_facets( + histogram_data: pd.DataFrame, + x: str, + label: str, + sorted_categories: pd.Series, + y_ticks: list[int | float] | None = None, + y_ticks_bins: int | None = None, + x_ticks: list[pd.Timestamp] | None = None, + y_label: str = "Number of Samples", + x_label: str = "Year", + sharey: bool = False, + height_factor: float = 1.0, + width_factor: float = 1.0, + legend_labels: list[str] | None = None, +) -> Figure: + color_map = gen_categorical_map(sorted_categories) + histogram_data = histogram_data.copy() + + init_plot() + setup_font() + + # Create a FacetGrid object with 'sex' as the categorical label for facets + g = sns.FacetGrid( + histogram_data, + col=label, + margin_titles=False, + col_wrap=6, + sharey=sharey, # sharey=False allows independent y-axis + sharex=True, + col_order=sorted_categories, + subplot_kws={}, + despine=True, + # gridspec_kws={"hspace": 0, "wspace": 0}, + ) + + g.figure.set_dpi(300) + g.figure.set_figwidth(DOUBLE_FIG_WIDTH * width_factor) + g.figure.set_figheight(2 * DOUBLE_FIG_HEIGHT * height_factor) + + g.map_dataframe( + sns.histplot, + # data=histogram_data, # supplied by map_dataframe + x=x, + hue=label, + palette=color_map, + edgecolor=None, # Disable black borders + element="bars", # bars, poly, bars + multiple="dodge", # layer, **dodge**, **fill**, **stack** + bins=40, + ) + + g.set_titles("{col_name}") # only the value in the facet name + + # Adjust x-axis tick labels + # g.set(xlabel=x_label) + if x_ticks is not None: + g.set(xticks=x_ticks) + + for ax in g.axes.flat: + ax.xaxis.set_major_formatter(mdates.DateFormatter("%b\n%Y")) + ax.figure.autofmt_xdate(ha="center", rotation=0) # Auto-rotate the date labels + + for ax in g.axes.flat: + # draw grid behind bars (horizontal and vertical) + ax.grid(axis="x", alpha=1.0, linestyle="--") + ax.grid(axis="y", alpha=1.0, linestyle="--") + + # g.set(ylabel=y_label) + # Hide y-axis labels for all but the leftmost column + for i, ax in enumerate(g.axes.flat): + # ax.set_xlabel(x_label, labelpad=10) + # if i % 4 != 0: # Check if it's not in the leftmost column + ax.set_ylabel(None) + ax.set_xlabel(None) + + # center the x-axis labels + ax.tick_params(axis="x", rotation=0, pad=6) + ax.tick_params(axis="y", pad=10) + + # avoid fine white lines between cells + for artist in ax.patches: # ax.patches contains the bars in the plot + artist.set_rasterized(True) + + # g.set_axis_labels( + # x_var=x_label, + # y_var=y_label, + # clear_inner=True, + # ) + + # Add common x and y labels with custom placement + g.figure.text(0.5, 0.0, x_label, ha="center", va="center", fontsize="large") + g.figure.text( + 0.0, + 0.5, + y_label, + ha="center", + va="center", + rotation="vertical", + fontsize="large", + ) + + plt.tight_layout() + g.figure.subplots_adjust(wspace=0.4) # Reduce horizontal space between subplots + return g + + +def build_histogram_multicategory_barnorm( + histogram_data: pd.DataFrame, + x: str, + label: str, + sorted_coloring_categories: pd.Series, + sorted_ordering_categories: pd.Series | None = None, + y_ticks: list[int | float] | None = None, + y_ticks_bins: int | None = None, + x_ticks: list[pd.Timestamp] | None = None, + y_label: str = "Number of Samples", + x_label: str = "Year", + height_factor: float = 1.0, + width_factor: float = 1.0, + legend: bool = True, + legend_labels: list[str] | None = None, + legend_title: str | None = None, + nbins: int | None = None, + manual_color_map: dict[str, tuple[float, float, float]] | None = None, + grid_opacity: float = 1.0, + col_alpha: float | None = None, +) -> Figure: + if sorted_ordering_categories is None: + sorted_ordering_categories = sorted_coloring_categories + if legend_labels is None: + legend_labels = [] + + histogram_data = histogram_data.copy() + # rename: if label not in legend_labels, add underscore to label to hide it from the legend + histogram_data[label] = histogram_data[label].apply(lambda x: x if x in (legend_labels) else f"_{x}") + underscore_col_categories = [x if x in (legend_labels) else f"_{x}" for x in sorted_coloring_categories] + underscore_ordering_categories = [x if x in (legend_labels) else f"_{x}" for x in sorted_ordering_categories] + underscore_ordering_categories += [ + # add any missing categories to the end of the list + x + for x in underscore_col_categories + if x not in underscore_ordering_categories + ] + color_map = gen_categorical_map(underscore_col_categories) + + init_plot() + setup_font() + + fig = plt.figure( + edgecolor="black", + frameon=True, + figsize=( + DOUBLE_FIG_WIDTH * width_factor, + 2 * DOUBLE_FIG_HEIGHT * height_factor, + ), + dpi=600, + ) + ax = fig.add_subplot(111) + + ax = sns.histplot( + data=histogram_data, + x=x, + hue=label, + palette=manual_color_map if manual_color_map else color_map, + hue_order=underscore_ordering_categories, + linewidth=0, # avoid fine white lines between cells + edgecolor=None, # Disable black borders + # legend=len(legend_labels or []) > 0, + legend=legend, + element="bars", # bars, poly, bars + multiple="fill", # layer, **dodge**, **fill**, **stack** + **{"bins": nbins} if nbins is not None else {}, + # opacity + **{"alpha": col_alpha} if col_alpha is not None else {}, + ax=ax, + ) + ax.invert_yaxis() + + # avoid fine white lines between cells + for artist in ax.patches: # ax.patches contains the bars in the plot + artist.set_rasterized(True) + + # position legend outside of plot + if legend and len(legend_labels) > 0: + ax.get_legend().set_bbox_to_anchor((1.05, 1.05)) + + if legend_title is not None: + ax.get_legend().set_title(legend_title) + + # draw grid behind bars (horizontal and vertical) + ax.grid(axis="x", linestyle="--", alpha=grid_opacity, color="white") + ax.grid(axis="y", linestyle="--", alpha=grid_opacity, color="white") + + # Adjust x-axis tick labels + plt.xlabel(x_label) + if x_ticks is not None: + # ax.xaxis.set_major_locator(DateLocator()) + # ax.set_xticklabels(x_ticks, rotation=0) + # plt.xticks( + # ticks=x_ticks, + # labels=x_ticks, + # rotation=0, + # # ha='right' + # ) + plt.xticks(x_ticks) + date_form = mdates.DateFormatter("%b\n%Y") # Customize format: "2020 Jan" + ax.xaxis.set_major_formatter(date_form) + + # Optionally, adjust the number of ticks on x-axis + # ax.xaxis.set_major_locator(mdates.YearLocator(base=4)) # Show every 3 months + + # ax.yaxis.set_major_locator(MaxNLocator(nbins=y_ticks_bins)) + # # ax.set_yticklabels([int(i) + histogram_data["x"].min() for i in ax.get_yticks()], rotation=0) + + plt.ylabel(y_label) + if y_ticks is not None: + plt.yticks(ticks=y_ticks, labels=list(reversed(y_ticks)), rotation=0) + elif y_ticks_bins is not None: + ax.yaxis.set_major_locator(MaxNLocator(nbins=y_ticks_bins)) + # ax.set_yticklabels([int(i) for i in ax.get_yticks()], rotation=0) + + # Display the plot + plt.tight_layout() + # plt.show() + + return fig + + +def build_cum_barplot( + histogram_data: pd.DataFrame, + x: str, + y: str, + y_ticks: list[int] | None = None, + y_ticks_bins: int | None = None, + x_ticks: list[int] | None = None, + x_ticks_bins: int | None = None, + y_label: str = "Number of Samples", + x_label: str = "Year", + height_factor: float = 1.0, + width_factor: float = 1.0, + palette: str = "RdBu", + palette_strip: tuple[float, float] | None = (0.35, 0.65), +) -> Figure: + init_plot() + setup_font() + + fig = plt.figure( + edgecolor="black", + frameon=True, + figsize=( + DOUBLE_FIG_WIDTH * width_factor, + 2 * DOUBLE_FIG_HEIGHT * height_factor, + ), + dpi=600, + ) + ax = fig.add_subplot(111) + + ax = sns.lineplot( + data=histogram_data, + x=x, + y=y, + color=main_color(0), + # market size + # markers=False + # markers=True, + # hue=y, + # # palette=get_rdbu_wo_white(palette=palette, strip=palette_strip), + # width=1, + # legend=False, + # # fill=True, + # edgecolor=".5", + # facecolor=(0, 0, 0, 0), + ax=ax, + ) + # TODO: check gap, dodged elements --> if pdf shows white lines + + # draw grid behind bars (horizontal and vertical) + ax.grid(axis="x", linestyle="--", alpha=1.0) + ax.grid(axis="y", linestyle="--", alpha=1.0) + + # Adjust x-axis tick labels + plt.xlabel(x_label) + if x_ticks is not None: + plt.xticks( + ticks=[xtick - min(histogram_data[x]) for xtick in x_ticks], + labels=x_ticks, + rotation=0, + # ha='right' + ) + elif x_ticks_bins is not None: + ax.xaxis.set_major_locator(MaxNLocator(nbins=x_ticks_bins)) + + ax.yaxis.set_major_locator(MaxNLocator(nbins=y_ticks_bins)) + # ax.set_yticklabels([int(i) + histogram_data["x"].min() for i in ax.get_yticks()], rotation=0) + + plt.ylabel(y_label) + if y_ticks is not None: + plt.yticks(ticks=y_ticks, labels=y_ticks, rotation=0) + elif y_ticks_bins is not None: + ax.yaxis.set_major_locator(MaxNLocator(nbins=y_ticks_bins)) + # ax.set_yticklabels([int(i) for i in ax.get_yticks()], rotation=0) + + # Display the plot + plt.tight_layout() + # plt.show() + + return fig + + +def build_pieplot( + x: list[int], + labels: list[str], + height_factor: float = 1.0, + width_factor: float = 1.0, +) -> Figure: + init_plot() + setup_font() + + fig = plt.figure( + edgecolor="black", + frameon=True, + figsize=( + DOUBLE_FIG_WIDTH * width_factor, + 2 * DOUBLE_FIG_HEIGHT * height_factor, + ), + dpi=600, + ) + + def func(pct: float, allvals: list[int]) -> str: + absolute = int(np.round(pct / 100.0 * np.sum(allvals))) + return f"{pct:.1f}%\n({absolute:d})" + + wedges, texts, autotexts = plt.pie( + x=x, + labels=labels, + autopct=lambda pct: func(pct, x), + textprops=dict(color="w"), + colors=main_colors(), + # show labels next to the pie chart + startangle=90, + explode=(0.1, 0), + ) + + plt.setp(autotexts, size=8, weight="bold") + + # Display the plot + plt.tight_layout() + # plt.show() + + return fig diff --git a/analytics/plotting/common/font.py b/analytics/plotting/common/font.py new file mode 100644 index 000000000..8c31d4ced --- /dev/null +++ b/analytics/plotting/common/font.py @@ -0,0 +1,35 @@ +import matplotlib.font_manager as fm +from matplotlib import pyplot as plt + +__loaded = False + + +def load_font() -> None: + global __loaded + if __loaded: + return + + cmu_fonts = [ + x + for x in fm.findSystemFonts(fontpaths=["/Users/robinholzinger/Library/Fonts/"]) + if "cmu" in x or "p052" in x.lower() + ] + + for font in cmu_fonts: + # Register the font with Matplotlib's font manager + # font_prop = fm.FontProperties(fname=font) + fm.fontManager.addfont(font) + + assert len([f.name for f in fm.fontManager.ttflist if "cmu" in f.name.lower()]) >= 2 + + +def setup_font(small_label: bool = False, small_title: bool | None = None, small_ticks: bool = True) -> None: + load_font() + plt.rcParams["svg.fonttype"] = "none" + plt.rcParams["font.family"] = "P052" # latex default: "CMU Serif", robin thesis: P052 + plt.rcParams["legend.fontsize"] = "small" if small_ticks else "medium" + plt.rcParams["xtick.labelsize"] = "small" if small_ticks else "medium" + plt.rcParams["ytick.labelsize"] = "small" if small_ticks else "medium" + plt.rcParams["axes.labelsize"] = "small" if small_label else "medium" + if small_title is not None: + plt.rcParams["axes.titlesize"] = "small" if small_title else "medium" diff --git a/analytics/plotting/common/heatmap.py b/analytics/plotting/common/heatmap.py index 1a5333369..2db321703 100644 --- a/analytics/plotting/common/heatmap.py +++ b/analytics/plotting/common/heatmap.py @@ -1,100 +1,272 @@ -from pathlib import Path +from typing import Any, Literal, cast +import matplotlib.patches as patches import pandas as pd import seaborn as sns from matplotlib import pyplot as plt +from matplotlib.axes import Axes from matplotlib.figure import Figure from matplotlib.ticker import MaxNLocator # Create the heatmap from analytics.plotting.common.common import init_plot +from analytics.plotting.common.const import DOUBLE_FIG_HEIGHT, DOUBLE_FIG_WIDTH +from analytics.plotting.common.font import setup_font + + +def get_fractional_index(dates: pd.Index, query_date: pd.Timestamp, fractional: bool = True) -> float: + """Given a list of Period objects (dates) and a query_date as a Period, + return the interpolated fractional index between two period indices if the + query_date lies between them.""" + # Ensure query_date is within the bounds of the period range + if query_date < dates[0].start_time: + return -1 # -1 before first index + + if query_date > dates[-1].start_time: + return len(dates) # +1 after last index + + # Find the two periods where the query_date falls in between + for i in range(len(dates) - 1): + if dates[i].start_time <= query_date <= dates[i + 1].start_time: + # Perform linear interpolation, assuming equal length periods + return i + ( + ((query_date - dates[i].start_time) / (dates[i + 1].start_time - dates[i].start_time)) + if fractional + else 0 + ) + + # If query_date is exactly one of the dates + return dates.get_loc(query_date) def build_heatmap( heatmap_data: pd.DataFrame, - y_ticks: list[int] | None = None, + y_ticks: list[int] | list[str] | None = None, y_ticks_bins: int | None = None, + x_ticks: list[int] | None = None, + x_custom_ticks: list[tuple[int, str]] | None = None, # (position, label) + y_custom_ticks: list[tuple[int, str]] | None = None, # (position, label) reverse_col: bool = False, y_label: str = "Reference Year", x_label: str = "Current Year", color_label: str = "Accuracy %", -) -> Figure: + title_label: str = "", + target_ax: Axes | None = None, + height_factor: float = 1.0, + width_factor: float = 1.0, + square: bool = False, + cbar: bool = True, + vmin: float | None = None, + vmax: float | None = None, + policy: list[tuple[int, int, int]] = [], + cmap: Any | None = None, + linewidth: int = 2, + grid_alpha: float = 0.0, + disable_horizontal_grid: bool = False, + df_logs_models: pd.DataFrame | None = None, + triggers: dict[int, pd.DataFrame] = {}, + x_axis: Literal["year", "other"] = "year", +) -> Figure | Axes: init_plot() - # sns.set_theme(style="ticks") - plt.rcParams["svg.fonttype"] = "none" - - double_fig_width = 10 - double_fig_height = 3.5 + setup_font(small_label=True, small_title=True) - fig = plt.figure( - edgecolor="black", - frameon=True, - figsize=(double_fig_width, 2.2 * double_fig_height), - dpi=300, - ) + if not target_ax: + fig = plt.figure( + edgecolor="black", + frameon=True, + figsize=( + DOUBLE_FIG_WIDTH * width_factor, + 2 * DOUBLE_FIG_HEIGHT * height_factor, + ), + dpi=600, + ) ax = sns.heatmap( heatmap_data, - cmap="RdBu" + ("_r" if reverse_col else ""), + cmap=("RdBu" + ("_r" if reverse_col else "")) if not cmap else cmap, linewidths=0.0, - linecolor="black", - cbar=True, + linecolor="white", # color bar from 0 to 1 cbar_kws={ "label": color_label, # "ticks": [0, 25, 50, 75, 100], "orientation": "vertical", }, + ax=target_ax, + square=square, + **{ + "vmin": vmin if vmin is not None else heatmap_data.min().min(), + "vmax": vmax if vmax is not None else heatmap_data.max().max(), + "cbar": cbar, + }, ) + + # Rasterize the heatmap background to avoid anti-aliasing artifacts ax.collections[0].set_rasterized(True) - # Adjust x-axis tick labels - plt.xlabel(x_label) - plt.xticks( - ticks=[x + 0.5 for x in range(0, 2010 - 1930 + 1, 20)], - labels=[x for x in range(1930, 2010 + 1, 20)], - rotation=0, - # ha='right' + rect = patches.Rectangle( + (0, 0), + heatmap_data.shape[1], + heatmap_data.shape[0], + linewidth=2, + edgecolor="black", + facecolor="none", ) + ax.add_patch(rect) + + # Adjust x-axis tick labels + ax.set_xlabel(x_label) + if not x_ticks and not x_custom_ticks: + ax.set_xticks( + ticks=[x + 0.5 for x in range(0, 2010 - 1930 + 1, 20)], + labels=[x for x in range(1930, 2010 + 1, 20)], + rotation=0, + # ha='right' + ) + else: + if x_custom_ticks: + ax.set_xticks( + ticks=[x[0] for x in x_custom_ticks], + labels=[x[1] for x in x_custom_ticks], + rotation=0, + # ha='right' + ) + else: + assert x_ticks is not None + ax.set_xticks( + ticks=[x - 1930 + 0.5 for x in x_ticks], + labels=[x for x in x_ticks], + rotation=0, + # ha='right' + ) ax.invert_yaxis() + ax.grid( + axis="y", + linestyle="--", + alpha=0 if disable_horizontal_grid else grid_alpha, + color="white", + ) + ax.grid(axis="x", linestyle="--", alpha=grid_alpha, color="white") + if y_ticks is not None: - plt.yticks(ticks=[y + 0.5 - 1930 for y in y_ticks], labels=[y for y in y_ticks], rotation=0) + ax.set_yticks( + ticks=[int(y) + 0.5 - 1930 for y in y_ticks], + labels=[y for y in y_ticks], + rotation=0, + ) elif y_ticks_bins is not None: ax.yaxis.set_major_locator(MaxNLocator(nbins=y_ticks_bins)) ax.set_yticklabels([int(i) + min(heatmap_data.index) for i in ax.get_yticks()], rotation=0) + else: + if y_custom_ticks: + ax.set_yticks( + ticks=[y[0] for y in y_custom_ticks], + labels=[y[1] for y in y_custom_ticks], + rotation=0, + # ha='right' + ) + + ax.set_ylabel(y_label) + + if title_label: + ax.set_title(title_label) - plt.ylabel(y_label) - - # # Draft training boxes - # if drift_pipeline: - # for type_, dashed in [("train", False), ("usage", False), ("train", True)]: - # for active_ in df_logs_models.iterrows(): - # x_start = active_[1][f"{type_}_start"].year - 1930 - # x_end = active_[1][f"{type_}_end"].year - 1930 - # y = active_[1]["model_idx"] - # rect = plt.Rectangle( - # (x_start, y - 1), # y: 0 based index, model_idx: 1 based index - # x_end - x_start, - # 1, - # edgecolor="White" if type_ == "train" else "Black", - # facecolor="none", - # linewidth=3, - # linestyle="dotted" if dashed else "solid", - # hatch="/", - # joinstyle="bevel", - # # capstyle="round", - # ) - # ax.add_patch(rect) + # mainly for offline expore + previous_y = 0 + for x_start, x_end, y in policy: + # main box + rect = plt.Rectangle( + (x_start, y), # y: 0 based index, model_idx: 1 based index + x_end - x_start, + 1, + edgecolor="White", + facecolor="none", + linewidth=linewidth, + linestyle="solid", + hatch="/", + joinstyle="bevel", + # capstyle="round", + ) + ax.add_patch(rect) + + # connector + connector = plt.Rectangle( + (x_start, previous_y), # y: 0 based index, model_idx: 1 based index + 0, + y - previous_y + 1, + edgecolor="White", + facecolor="none", + linewidth=linewidth, + linestyle="solid", + hatch="/", + joinstyle="bevel", + # capstyle="round", + ) + ax.add_patch(connector) + previous_y = y + + # for post factum evaluation + if df_logs_models is not None: + for type_, dashed in [("train", False), ("usage", False), ("train", True)]: + for active_ in df_logs_models.iterrows(): + if x_axis == "year": + eval_x_start = active_[1][f"{type_}_start"].year - 1930 + eval_x_end = active_[1][f"{type_}_end"].year - 1930 + else: + eval_x_start = get_fractional_index( + heatmap_data.columns, + cast(pd.Index, active_[1][f"{type_}_start"]), + fractional=False, + ) + eval_x_end = get_fractional_index( + heatmap_data.columns, + cast(pd.Index, active_[1][f"{type_}_end"]), + fractional=False, + ) + + y = active_[1]["model_idx"] + rect = plt.Rectangle( + ( + eval_x_start, + y - 1, + ), # y: 0 based index, model_idx: 1 based index + eval_x_end - eval_x_start, + 1, + edgecolor="White" if type_ == "train" else "Black", + facecolor="none", + linewidth=1.5, + linestyle="dotted" if dashed else "solid", + hatch="/", + joinstyle="bevel", + # capstyle="round", + ) + ax.add_patch(rect) + + if triggers: + for y, triggers_df in triggers.items(): + for row in triggers_df.iterrows(): + type_ = "usage" + # for y, x_list in triggers.items(): + eval_x_start = row[1][f"{type_}_start"].year - 1930 + eval_x_end = row[1][f"{type_}_end"].year - 1930 + # for x in x_list: + rect = plt.Rectangle( + (eval_x_start, y), # y: 0 based index, model_idx: 1 based index + eval_x_end - eval_x_start, + 1, + edgecolor="black", + facecolor="none", + linewidth=1, + # linestyle="dotted", + # hatch="/", + # joinstyle="bevel", + # capstyle="round", + ) + ax.add_patch(rect) # Display the plot plt.tight_layout() # plt.show() - return fig - - -def save_plot(fig: Figure, name: str) -> None: - for img_type in ["png", "svg"]: - img_path = Path("/scratch/robinholzi/gh/modyn/.data/plots") / f"{name}.{img_type}" - fig.savefig(img_path, bbox_inches="tight", transparent=True) + return fig if not target_ax else ax diff --git a/analytics/plotting/common/linear_regression_scatterplot.py b/analytics/plotting/common/linear_regression_scatterplot.py new file mode 100644 index 000000000..465a7b8c5 --- /dev/null +++ b/analytics/plotting/common/linear_regression_scatterplot.py @@ -0,0 +1,103 @@ +from typing import Any + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from analytics.plotting.common.color import main_color +from analytics.plotting.common.common import init_plot +from analytics.plotting.common.font import setup_font + + +def scatter_linear_regression( + data: pd.DataFrame, + x: str, + y: str, + hue: str, + y_ticks: list[int] | list[str] | None = None, + x_ticks: list[int] | None = None, + y_label: str = "Reference Year", + x_label: str = "Current Year", + height_factor: float = 1.0, + width_factor: float = 1.0, + legend_label: str = "Number Samples", + title_label: str = "", + target_ax: Axes | None = None, + palette: Any = None, + small_legend_fonts: bool = False, +) -> Figure | tuple[Axes, Axes]: + sns.set_style("whitegrid") + + init_plot() + setup_font(small_label=True, small_title=True) + + DOUBLE_FIG_WIDTH = 10 + DOUBLE_FIG_HEIGHT = 3.5 + + if not target_ax: + fig = plt.figure( + edgecolor="black", + frameon=True, + figsize=( + DOUBLE_FIG_WIDTH * width_factor, + 2 * DOUBLE_FIG_HEIGHT * height_factor, + ), + dpi=600, + ) + + ax1 = sns.regplot( + data, + x=x, + y=y, # duration + color=main_color(0), + ) + + ax2 = sns.scatterplot( + data, + x=x, + y=y, # duration + hue=hue, + palette=palette, + s=200, + legend=True, + marker="X", + ) + + ax2.legend( + title=legend_label, + ncol=2, + handletextpad=0, + columnspacing=0.5, + **({"fontsize": "x-small"} if small_legend_fonts else {}), + ) + + # Adjust x-axis tick labels + ax2.set_xlabel(x_label) + if x_ticks is not None: + ax2.set_xticks( + ticks=x_ticks, + labels=x_ticks, + rotation=0, + # ha='right' + ) + + if y_ticks is not None: + ax2.set_yticks( + ticks=y_ticks, + labels=y_ticks, + rotation=0, + ) + + ax2.set_ylabel(y_label) + + if title_label: + ax2.set_title(title_label) + + print("Number of plotted items", data.shape[0]) + + # Display the plot + plt.tight_layout() + + return fig if not target_ax else (ax1, ax2) diff --git a/analytics/plotting/common/metric_over_time.py b/analytics/plotting/common/metric_over_time.py new file mode 100644 index 000000000..d18c8980d --- /dev/null +++ b/analytics/plotting/common/metric_over_time.py @@ -0,0 +1,104 @@ +import matplotlib.dates as mdates +import pandas as pd +import seaborn as sns +from matplotlib import pyplot as plt +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from analytics.plotting.common.color import main_color +from analytics.plotting.common.common import ( + DOUBLE_FIG_HEIGHT, + DOUBLE_FIG_WIDTH, + init_plot, +) +from analytics.plotting.common.font import setup_font + + +def plot_metric_over_time( + data: pd.DataFrame, + x: str = "time", + y: str = "value", + hue: str = "pipeline_ref", + style: str = "pipeline_ref", + y_label: str = "Reference Year", + x_label: str = "Current Year", + title_label: str = "", + target_ax: Axes | None = None, + height_factor: float = 1.0, + width_factor: float = 1.0, + grid_alpha: float = 0.0, + legend_label: str = "TODO", + small_legend_fonts: bool = False, + x_date_locator: mdates.DateLocator | None = None, + x_date_formatter: mdates.DateFormatter | None = None, + y_ticks: list[int] | None = None, + xlim: tuple[int, int] | None = None, + ylim: tuple[int, int] | None = None, + markers: bool = True, +) -> Figure | Axes: + sns.set_style("whitegrid") + init_plot() + setup_font(small_label=False, small_title=False, small_ticks=False) + + if not target_ax: + fig = plt.figure( + edgecolor="black", + frameon=True, + figsize=( + DOUBLE_FIG_WIDTH * width_factor, + 2 * DOUBLE_FIG_HEIGHT * height_factor, + ), + dpi=600, + ) + + ax = sns.lineplot( + data, + x=x, + y=y, + hue=hue, + markersize=7, + # line width + linewidth=2.5, + palette=[ + main_color(0), + main_color(1), + main_color(3), + main_color(4), + main_color(5), + main_color(6), + ], + style=style, + markers=markers, + ) + + if xlim: + ax.set(xlim=xlim) + + if ylim: + ax.set(ylim=ylim) + + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + + ax.legend( + title=legend_label, + ncol=2, + handletextpad=1, + columnspacing=1.4, + **({"fontsize": "x-small"} if small_legend_fonts else {}), + ) + + if x_date_locator: + ax.xaxis.set_major_locator(x_date_locator) + # ax.set_xticklabels(x_ticks, rotation=0) + ax.xaxis.set_major_formatter(x_date_formatter) + # ticks = ax.get_xticks() + plt.xticks(rotation=0) + + if y_ticks: + ax.set_yticks(y_ticks) + + # Display the plot + plt.tight_layout() + + return fig if not target_ax else ax diff --git a/analytics/plotting/common/save.py b/analytics/plotting/common/save.py new file mode 100644 index 000000000..b6be7fcc4 --- /dev/null +++ b/analytics/plotting/common/save.py @@ -0,0 +1,17 @@ +from pathlib import Path + +import pandas as pd +from matplotlib.figure import Figure + + +def save_plot(fig: Figure, name: str) -> None: + for img_type in ["png", "svg", "pdf"]: + img_path = Path(".data/_plots") / f"{name}.{img_type}" + img_path.parent.mkdir(exist_ok=True, parents=True) + fig.savefig(img_path, bbox_inches="tight", transparent=True) + + +def save_csv_df(df: pd.DataFrame, name: str) -> None: + csv_path = Path(".data/csv") / f"{name}.csv" + csv_path.parent.mkdir(exist_ok=True, parents=True) + df.to_csv(csv_path, index=False) diff --git a/analytics/plotting/common/tradeoff_scatterplot.py b/analytics/plotting/common/tradeoff_scatterplot.py new file mode 100644 index 000000000..3aae2a46b --- /dev/null +++ b/analytics/plotting/common/tradeoff_scatterplot.py @@ -0,0 +1,87 @@ +import pandas as pd +import seaborn as sns +from matplotlib import pyplot as plt +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from analytics.plotting.common.color import main_color +from analytics.plotting.common.common import init_plot +from analytics.plotting.common.font import setup_font + + +def plot_tradeoff_scatter( + data: pd.DataFrame, + x: str, + y: str, + hue: str, + style: str, + x_label: str = "Number of Triggers", + y_label: str = "Mean Accuracy %", + height_factor: float = 1.0, + width_factor: float = 1.0, + target_ax: Axes | None = None, + manual_legend_title: bool = True, + legend_ncol: int = 1, +) -> Figure: + sns.set_theme(style="whitegrid") + init_plot() + setup_font(small_label=True, small_title=True, small_ticks=True) + + DOUBLE_FIG_WIDTH = 10 + DOUBLE_FIG_HEIGHT = 3.5 + + if not target_ax: + fig = plt.figure( + edgecolor="black", + frameon=True, + figsize=( + DOUBLE_FIG_WIDTH * width_factor, + 2 * DOUBLE_FIG_HEIGHT * height_factor, + ), + dpi=600, + ) + + ax = sns.scatterplot( + data, + x=x, + y=y, + hue=hue, + style=style, + palette=[ + main_color(0), + main_color(1), + main_color(3), + main_color(4), + main_color(5), + ], + s=300, + # legend=False, + # marker="X", + ) + + ax.legend( + fontsize="small", + title_fontsize="medium", + # title="Pipeline", + **( + { + "title": hue, + } + if manual_legend_title + else {} + ), + # 2 columns + ncol=legend_ncol, + ) + + # Adjust x-axis tick labels + plt.xlabel(x_label, labelpad=10) + + # Set y-axis ticks to be equally spaced + plt.ylabel(y_label, labelpad=15) + + # Display the plot + plt.tight_layout() + plt.show() + + return fig