Skip to content

Commit c9ae430

Browse files
authored
feat(zero_shot): add win rate chart generator (#57)
* feat(zero_shot): add win rate chart generator
  - Add WinRateChartGenerator class for visualizing model rankings
  - Support customizable chart styles, colors, and annotations
  - Add matplotlib dependency to pyproject.toml
  - Update schema with ChartConfig dataclass
  - Integrate chart generation into zero_shot_pipeline
* refactor(zero_shot): remove checkpoint module and inline logic
  - Remove standalone checkpoint.py module
  - Inline checkpoint functionality into zero_shot_pipeline.py
  - Simplify code structure
* fix(zero_shot): require report.enabled for chart generation
* refactor(chart): simplify config handling and remove redundant code
  - Use direct import instead of TYPE_CHECKING for ChartConfig
  - Initialize default ChartConfig in constructor
  - Remove redundant hatch pattern logic
  - Fix y-axis limit edge case when win rates are low
* chore(deps): move matplotlib to dev dependencies
1 parent 4659eac commit c9ae430

File tree

7 files changed

+399
-198
lines changed

7 files changed

+399
-198
lines changed
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
# -*- coding: utf-8 -*-
2+
"""Chart generator for zero-shot evaluation results.
3+
4+
This module provides visualization capabilities for evaluation results,
5+
generating beautiful bar charts to display model win rates.
6+
"""
7+
8+
from pathlib import Path
9+
from typing import List, Optional, Tuple
10+
11+
from loguru import logger
12+
13+
from cookbooks.zero_shot_evaluation.schema import ChartConfig
14+
15+
16+
class WinRateChartGenerator:
    """Generator for win rate comparison charts.

    Draws a vertical bar chart of model win rates. The first entry of the
    rankings is treated as the best model and, when ``highlight_best`` is
    enabled in the config, is drawn with an accent color and hatch pattern.

    Attributes:
        config: Chart configuration options

    Example:
        >>> generator = WinRateChartGenerator(config)
        >>> path = generator.generate(
        ...     rankings=[("GPT-4", 0.73), ("Claude", 0.65)],
        ...     output_dir="./results",
        ...     task_description="Translation evaluation",
        ... )
    """

    # Color palette - inspired by modern data visualization
    ACCENT_COLOR = "#FF6B35"  # Vibrant orange for best model
    ACCENT_HATCH = "///"  # Diagonal stripes pattern
    BAR_COLORS = [
        "#4A4A4A",  # Dark gray
        "#6B6B6B",  # Medium gray
        "#8C8C8C",  # Light gray
        "#ADADAD",  # Lighter gray
        "#CECECE",  # Very light gray
    ]

    # Style sheet names tried in order. matplotlib 3.6 renamed the bundled
    # seaborn styles to "seaborn-v0_8-*"; older releases only know the
    # un-prefixed name.
    _STYLE_CANDIDATES = ("seaborn-v0_8-whitegrid", "seaborn-whitegrid")

    def __init__(self, config: Optional[ChartConfig] = None):
        """Initialize chart generator.

        Args:
            config: Chart configuration. Uses defaults if not provided.
        """
        self.config = config or ChartConfig()

    def _configure_cjk_font(self, plt, font_manager) -> Optional[str]:
        """Configure matplotlib to support CJK (Chinese/Japanese/Korean) characters.

        Looks for the first installed font from a fixed preference list and
        puts it at the front of ``rcParams["font.sans-serif"]``. Also sets
        ``rcParams["axes.unicode_minus"]`` to ``False``. Logs a warning and
        leaves rcParams untouched when none of the fonts is installed.

        Args:
            plt: The ``matplotlib.pyplot`` module.
            font_manager: The ``matplotlib.font_manager`` module.

        Returns:
            Font name if found, None otherwise
        """
        # Common CJK fonts on different platforms (simplified Chinese priority)
        cjk_fonts = (
            # macOS - Simplified Chinese
            "Hiragino Sans GB",
            "Songti SC",
            "Kaiti SC",
            "Heiti SC",
            "Lantinghei SC",
            "PingFang SC",
            "STFangsong",
            # Windows
            "Microsoft YaHei",
            "SimHei",
            "SimSun",
            # Linux
            "Noto Sans CJK SC",
            "WenQuanYi Micro Hei",
            "Droid Sans Fallback",
            # Generic
            "Arial Unicode MS",
        )

        # Build the set once so each membership test below is O(1)
        available_fonts = {f.name for f in font_manager.fontManager.ttflist}

        for font_name in cjk_fonts:
            if font_name in available_fonts:
                current = plt.rcParams.get("font.sans-serif", [])
                # Drop any existing occurrence so repeated calls do not keep
                # growing the list with duplicates of the same font.
                plt.rcParams["font.sans-serif"] = [font_name] + [f for f in current if f != font_name]
                plt.rcParams["axes.unicode_minus"] = False  # Fix minus sign display
                logger.debug(f"Using CJK font: {font_name}")
                return font_name

        # No CJK font found, log warning
        logger.warning(
            "No CJK font found. Chinese characters may not display correctly. "
            "Consider installing a CJK font like 'Noto Sans CJK SC'."
        )
        return None

    def generate(
        self,
        rankings: List[Tuple[str, float]],
        output_dir: str,
        task_description: Optional[str] = None,
        total_queries: int = 0,
        total_comparisons: int = 0,
    ) -> Optional[Path]:
        """Generate win rate bar chart.

        The chart is written to ``<output_dir>/win_rate_chart.<format>``,
        creating ``output_dir`` if needed. All matplotlib style and font
        changes are scoped to this call, and the figure is always closed,
        even when drawing or saving fails.

        Args:
            rankings: List of (model_name, win_rate) tuples, sorted by win rate
                from high to low. Win rates are multiplied by 100 for display.
            output_dir: Directory to save the chart
            task_description: Task description for subtitle (truncated to
                60 characters plus "...")
            total_queries: Number of queries evaluated (shown when > 0)
            total_comparisons: Number of pairwise comparisons (shown when > 0)

        Returns:
            Path to saved chart file, or None if ``rankings`` is empty or
            matplotlib/numpy cannot be imported. Errors raised while drawing
            or writing the file are not swallowed; they propagate to the
            caller.
        """
        if not rankings:
            logger.warning("No rankings data to visualize")
            return None

        try:
            import matplotlib.patches as mpatches
            import matplotlib.pyplot as plt
            import numpy as np
            from matplotlib import font_manager
        except ImportError:
            logger.warning("matplotlib not installed. Install with: pip install matplotlib")
            return None

        # Extract config values (defaults are centralized in ChartConfig schema)
        figsize = self.config.figsize
        dpi = self.config.dpi
        fmt = self.config.format
        show_values = self.config.show_values
        highlight_best = self.config.highlight_best
        custom_title = self.config.title

        # Prepare data (already sorted high to low)
        model_names = [r[0] for r in rankings]
        win_rates = [r[1] * 100 for r in rankings]  # Convert to percentage
        n_models = len(model_names)

        # One color per bar: accent for the best model (index 0) when
        # highlighting, grayscale for the rest. Models beyond the palette
        # length reuse the last (lightest) gray.
        last_gray = len(self.BAR_COLORS) - 1
        colors = []
        for i in range(n_models):
            if highlight_best and i == 0:
                colors.append(self.ACCENT_COLOR)
            else:
                gray_idx = i - 1 if highlight_best else i
                colors.append(self.BAR_COLORS[min(gray_idx, last_gray)])

        # rc_context restores every rcParam on exit, so the style sheet and
        # font tweaks below do not leak into other plots of the host process.
        with plt.rc_context():
            # Style MUST be applied before the font config: applying a style
            # sheet resets the font settings.
            for style_name in self._STYLE_CANDIDATES:
                if style_name in plt.style.available:
                    plt.style.use(style_name)
                    break
            else:
                logger.debug("No seaborn whitegrid style available; using current matplotlib style")

            # Configure font for CJK (Chinese/Japanese/Korean) support
            self._configure_cjk_font(plt, font_manager)

            fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
            try:
                # Draw bars (edge color matches the fill color)
                x_pos = np.arange(n_models)
                bars = ax.bar(
                    x_pos,
                    win_rates,
                    width=0.6,
                    color=colors,
                    edgecolor=colors,
                    linewidth=1.5,
                    zorder=3,
                )

                # Add hatch pattern to best model; the white edge color is
                # what makes the hatch lines visible on the accent fill.
                if highlight_best:
                    bars[0].set_hatch(self.ACCENT_HATCH)
                    bars[0].set_edgecolor("white")

                # Add value labels on top of bars
                if show_values:
                    for bar, rate in zip(bars, win_rates):
                        ax.annotate(
                            f"{rate:.1f}",
                            xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                            xytext=(0, 5),
                            textcoords="offset points",
                            ha="center",
                            va="bottom",
                            fontsize=12,
                            fontweight="bold",
                            color="#333333",
                        )

                # Customize axes
                ax.set_xticks(x_pos)
                ax.set_xticklabels(model_names, fontsize=11, fontweight="medium")
                ax.set_ylabel("Win Rate (%)", fontsize=12, fontweight="medium", labelpad=10)
                # 15% headroom for labels, never below 10 and never above 100.
                # NOTE(review): the cap at 100 removes the headroom once the
                # top win rate exceeds ~87%, so that bar's value label ends up
                # above the axes area near the subtitle - confirm acceptable.
                ax.set_ylim(0, max(10, min(100, max(win_rates) * 1.15)))

                # Remove top and right spines, soften the remaining ones
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)
                ax.spines["left"].set_color("#CCCCCC")
                ax.spines["bottom"].set_color("#CCCCCC")

                # Horizontal grid lines only, drawn behind the bars
                ax.yaxis.grid(True, linestyle="--", alpha=0.5, color="#DDDDDD", zorder=0)
                ax.xaxis.grid(False)

                # Title
                title = custom_title or "Model Win Rate Comparison"
                ax.set_title(title, fontsize=16, fontweight="bold", pad=20, color="#333333")

                # Subtitle with evaluation info
                subtitle_parts = []
                if task_description:
                    # Truncate long descriptions
                    desc = task_description[:60] + "..." if len(task_description) > 60 else task_description
                    subtitle_parts.append(f"Task: {desc}")
                if total_queries > 0:
                    subtitle_parts.append(f"Queries: {total_queries}")
                if total_comparisons > 0:
                    subtitle_parts.append(f"Comparisons: {total_comparisons}")

                if subtitle_parts:
                    # Placed just above the axes, in axes-relative coordinates
                    ax.text(
                        0.5,
                        1.02,
                        " | ".join(subtitle_parts),
                        transform=ax.transAxes,
                        ha="center",
                        va="bottom",
                        fontsize=10,
                        color="#666666",
                        style="italic",
                    )

                # Legend: a single entry naming the highlighted best model
                if highlight_best:
                    best_patch = mpatches.Patch(
                        facecolor=self.ACCENT_COLOR,
                        edgecolor="white",
                        hatch=self.ACCENT_HATCH,
                        label=f"Best: {model_names[0]}",
                    )
                    ax.legend(
                        handles=[best_patch],
                        loc="upper right",
                        frameon=True,
                        framealpha=0.9,
                        fontsize=10,
                    )

                # Operate on this figure explicitly rather than on whatever
                # pyplot considers the "current" figure.
                fig.tight_layout()

                # Save chart
                output_path = Path(output_dir)
                output_path.mkdir(parents=True, exist_ok=True)
                chart_file = output_path / f"win_rate_chart.{fmt}"

                fig.savefig(
                    chart_file,
                    format=fmt,
                    dpi=dpi,
                    bbox_inches="tight",
                    facecolor="white",
                    edgecolor="none",
                )
            finally:
                # Always release the figure, even if drawing or saving raised
                plt.close(fig)

        logger.info(f"Win rate chart saved to {chart_file}")
        return chart_file

0 commit comments

Comments
 (0)