Skip to content

Commit 4e68cb2

Browse files
Prints benchmark results in a neat table and attempts to run every benchmark (#1464)
* log benchmark results as table * update with comments * remove redundants * ds * in markdown format * fix
1 parent 4de1272 commit 4e68cb2

File tree

2 files changed

+266
-31
lines changed

2 files changed

+266
-31
lines changed

backend-comparison/src/burnbenchapp/base.rs

+65-25
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
1+
use super::{
2+
auth::{save_token, CLIENT_ID},
3+
App,
4+
};
5+
use crate::burnbenchapp::auth::{get_token_from_cache, verify_token};
6+
use crate::persistence::{BenchmarkCollection, BenchmarkRecord};
17
use arboard::Clipboard;
28
use clap::{Parser, Subcommand, ValueEnum};
39
use github_device_flow::{self, DeviceFlow};
10+
use serde_json;
11+
use std::fs;
12+
use std::io::{BufRead, BufReader, Result as ioResult};
413
use std::{
5-
process::{Command, Stdio},
14+
process::{Command, ExitStatus, Stdio},
615
thread, time,
716
};
17+
818
use strum::IntoEnumIterator;
919
use strum_macros::{Display, EnumIter};
10-
11-
use crate::burnbenchapp::auth::{get_token_from_cache, verify_token};
12-
13-
use super::{
14-
auth::{save_token, CLIENT_ID},
15-
App,
16-
};
17-
1820
const FIVE_SECONDS: time::Duration = time::Duration::new(5, 0);
1921
const BENCHMARKS_TARGET_DIR: &str = "target/benchmarks";
2022
const USER_BENCHMARK_SERVER_URL: &str = if cfg!(debug_assertions) {
@@ -184,17 +186,12 @@ fn command_run(run_args: RunArgs) {
184186
}
185187
let total_combinations = run_args.backends.len() * run_args.benches.len();
186188
println!(
187-
"Executing the following benchmark and backend combinations (Total: {}):",
189+
"Executing benchmark and backend combinations in total: {}",
188190
total_combinations
189191
);
190-
for backend in &run_args.backends {
191-
for bench in &run_args.benches {
192-
println!("- Benchmark: {}, Backend: {}", bench, backend);
193-
}
194-
}
195192
let mut app = App::new();
196193
app.init();
197-
println!("Running benchmarks...");
194+
println!("Running benchmarks...\n");
198195
app.run(
199196
&run_args.benches,
200197
&run_args.backends,
@@ -204,7 +201,7 @@ fn command_run(run_args: RunArgs) {
204201
}
205202

206203
#[allow(unused)] // for tui as this is WIP
207-
pub(crate) fn run_cargo(command: &str, params: &[&str]) {
204+
pub(crate) fn run_cargo(command: &str, params: &[&str]) -> ioResult<ExitStatus> {
208205
let mut cargo = Command::new("cargo")
209206
.arg(command)
210207
.arg("--color=always")
@@ -213,22 +210,36 @@ pub(crate) fn run_cargo(command: &str, params: &[&str]) {
213210
.stderr(Stdio::inherit())
214211
.spawn()
215212
.expect("cargo process should run");
216-
let status = cargo.wait().expect("");
217-
if !status.success() {
218-
std::process::exit(status.code().unwrap_or(1));
219-
}
213+
cargo.wait()
220214
}
221215

222216
pub(crate) fn run_backend_comparison_benchmarks(
223217
benches: &[BenchmarkValues],
224218
backends: &[BackendValues],
225219
token: Option<&str>,
226220
) {
227-
// Iterate over each combination of backend and bench
228-
for backend in backends.iter() {
229-
for bench in benches.iter() {
221+
// Prefix and postfix for titles
222+
let filler = ["="; 10].join("");
223+
224+
// Delete the file containing file paths to benchmark results, if existing
225+
let benchmark_results_file = dirs::home_dir()
226+
.expect("Home directory should exist")
227+
.join(".cache")
228+
.join("burn")
229+
.join("backend-comparison")
230+
.join("benchmark_results.txt");
231+
232+
fs::remove_file(benchmark_results_file.clone()).ok();
233+
234+
// Iterate through every combination of benchmark and backend
235+
for bench in benches.iter() {
236+
for backend in backends.iter() {
230237
let bench_str = bench.to_string();
231238
let backend_str = backend.to_string();
239+
println!(
240+
"{}Benchmarking {} on {}{}",
241+
filler, bench_str, backend_str, filler
242+
);
232243
let mut args = vec![
233244
"-p",
234245
"backend-comparison",
@@ -246,7 +257,36 @@ pub(crate) fn run_backend_comparison_benchmarks(
246257
args.push("--sharing-token");
247258
args.push(t);
248259
}
249-
run_cargo("bench", &args);
260+
let status = run_cargo("bench", &args).unwrap();
261+
if !status.success() {
262+
println!(
263+
"Benchmark {} didn't ran successfully on the backend {}",
264+
bench_str, backend_str
265+
);
266+
continue;
267+
}
268+
}
269+
}
270+
271+
// Iterate though each benchmark result file present in backend-comparison/benchmark_results.txt
272+
// and print them in a single table.
273+
let mut benchmark_results = BenchmarkCollection::default();
274+
if let Ok(file) = fs::File::open(benchmark_results_file.clone()) {
275+
let file_reader = BufReader::new(file);
276+
for file in file_reader.lines() {
277+
let file_path = file.unwrap();
278+
if let Ok(br_file) = fs::File::open(file_path.clone()) {
279+
let benchmarkrecord =
280+
serde_json::from_reader::<_, BenchmarkRecord>(br_file).unwrap();
281+
benchmark_results.records.push(benchmarkrecord)
282+
} else {
283+
println!("Cannot find the benchmark-record file: {}", file_path);
284+
};
250285
}
286+
println!(
287+
"{}Benchmark Results{}\n\n{}",
288+
filler, filler, benchmark_results
289+
);
290+
fs::remove_file(benchmark_results_file).ok();
251291
}
252292
}

backend-comparison/src/persistence/base.rs

+201-6
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
use std::fs;
2-
31
use burn::{
4-
serde::{ser::SerializeStruct, Serialize, Serializer},
2+
serde::{de::Visitor, ser::SerializeStruct, Deserialize, Serialize, Serializer},
53
tensor::backend::Backend,
64
};
75
use burn_common::benchmark::BenchmarkResult;
86
use dirs;
97
use reqwest::header::{HeaderMap, ACCEPT, AUTHORIZATION, USER_AGENT};
108
use serde_json;
11-
9+
use std::fmt::Display;
10+
use std::time::Duration;
11+
use std::{fs, io::Write};
1212
#[derive(Default, Clone)]
1313
pub struct BenchmarkRecord {
1414
backend: String,
1515
device: String,
16-
results: BenchmarkResult,
16+
pub results: BenchmarkResult,
1717
}
1818

1919
/// Save the benchmarks results on disk.
@@ -77,10 +77,22 @@ pub fn save<B: Backend>(
7777
record.results.name, record.results.timestamp
7878
);
7979
let file_path = cache_dir.join(file_name);
80-
let file = fs::File::create(file_path).expect("Benchmark file should exist or be created");
80+
let file =
81+
fs::File::create(file_path.clone()).expect("Benchmark file should exist or be created");
8182
serde_json::to_writer_pretty(file, &record)
8283
.expect("Benchmark file should be updated with benchmark results");
8384

85+
// Append the benchmark result filepath in the benchmark_results.tx file of cache folder to be later picked by benchrun
86+
let benchmark_results_path = cache_dir.join("benchmark_results.txt");
87+
let mut benchmark_results_file = fs::OpenOptions::new()
88+
.append(true)
89+
.create(true)
90+
.open(benchmark_results_path)
91+
.unwrap();
92+
benchmark_results_file
93+
.write_all(format!("{}\n", file_path.to_string_lossy()).as_bytes())
94+
.unwrap();
95+
8496
if url.is_some() {
8597
println!("Sharing results...");
8698
let client = reqwest::blocking::Client::new();
@@ -154,3 +166,186 @@ impl Serialize for BenchmarkRecord {
154166
)
155167
}
156168
}
169+
170+
struct BenchmarkRecordVisitor;
171+
172+
impl<'de> Visitor<'de> for BenchmarkRecordVisitor {
173+
type Value = BenchmarkRecord;
174+
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
175+
write!(formatter, "Serialized Json object of BenchmarkRecord")
176+
}
177+
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
178+
where
179+
A: burn::serde::de::MapAccess<'de>,
180+
{
181+
let mut br = BenchmarkRecord::default();
182+
while let Some(key) = map.next_key::<String>()? {
183+
match key.as_str() {
184+
"backend" => br.backend = map.next_value::<String>()?,
185+
"device" => br.device = map.next_value::<String>()?,
186+
"gitHash" => br.results.git_hash = map.next_value::<String>()?,
187+
"name" => br.results.name = map.next_value::<String>()?,
188+
"max" => {
189+
let value = map.next_value::<u64>()?;
190+
br.results.computed.max = Duration::from_micros(value);
191+
}
192+
"mean" => {
193+
let value = map.next_value::<u64>()?;
194+
br.results.computed.mean = Duration::from_micros(value);
195+
}
196+
"median" => {
197+
let value = map.next_value::<u64>()?;
198+
br.results.computed.median = Duration::from_micros(value);
199+
}
200+
"min" => {
201+
let value = map.next_value::<u64>()?;
202+
br.results.computed.min = Duration::from_micros(value);
203+
}
204+
"options" => br.results.options = map.next_value::<Option<String>>()?,
205+
"rawDurations" => br.results.raw.durations = map.next_value::<Vec<Duration>>()?,
206+
"shapes" => br.results.shapes = map.next_value::<Vec<Vec<usize>>>()?,
207+
"timestamp" => br.results.timestamp = map.next_value::<u128>()?,
208+
"variance" => {
209+
let value = map.next_value::<u64>()?;
210+
br.results.computed.variance = Duration::from_micros(value)
211+
}
212+
213+
"numSamples" => _ = map.next_value::<usize>()?,
214+
_ => panic!("Unexpected Key: {}", key),
215+
}
216+
}
217+
218+
Ok(br)
219+
}
220+
}
221+
222+
impl<'de> Deserialize<'de> for BenchmarkRecord {
223+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
224+
where
225+
D: burn::serde::Deserializer<'de>,
226+
{
227+
deserializer.deserialize_map(BenchmarkRecordVisitor)
228+
}
229+
}
230+
231+
#[derive(Default)]
232+
pub(crate) struct BenchmarkCollection {
233+
pub records: Vec<BenchmarkRecord>,
234+
}
235+
236+
impl Display for BenchmarkCollection {
237+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238+
writeln!(
239+
f,
240+
"| {0:<15}| {1:<35}| {2:<15}|\n|{3:-<16}|{4:-<36}|{5:-<16}|",
241+
"Benchmark", "Backend", "Median", "", "", ""
242+
)?;
243+
for record in self.records.iter() {
244+
let backend = [record.backend.clone(), record.device.clone()].join("-");
245+
writeln!(
246+
f,
247+
"| {0:<15}| {1:<35}| {2:<15.3?}|",
248+
record.results.name, backend, record.results.computed.median
249+
)?;
250+
}
251+
252+
Ok(())
253+
}
254+
}
255+
256+
#[cfg(test)]
257+
mod tests {
258+
use super::*;
259+
260+
#[test]
261+
fn get_benchmark_result() {
262+
let sample_result = r#"{
263+
"backend": "candle",
264+
"device": "Cuda(0)",
265+
"gitHash": "02d37011ab4dc773286e5983c09cde61f95ba4b5",
266+
"name": "unary",
267+
"max": 8858,
268+
"mean": 8629,
269+
"median": 8592,
270+
"min": 8506,
271+
"numSamples": 10,
272+
"options": null,
273+
"rawDurations": [
274+
{
275+
"secs": 0,
276+
"nanos": 8858583
277+
},
278+
{
279+
"secs": 0,
280+
"nanos": 8719822
281+
},
282+
{
283+
"secs": 0,
284+
"nanos": 8705335
285+
},
286+
{
287+
"secs": 0,
288+
"nanos": 8835636
289+
},
290+
{
291+
"secs": 0,
292+
"nanos": 8592507
293+
},
294+
{
295+
"secs": 0,
296+
"nanos": 8506423
297+
},
298+
{
299+
"secs": 0,
300+
"nanos": 8534337
301+
},
302+
{
303+
"secs": 0,
304+
"nanos": 8506627
305+
},
306+
{
307+
"secs": 0,
308+
"nanos": 8521615
309+
},
310+
{
311+
"secs": 0,
312+
"nanos": 8511474
313+
}
314+
],
315+
"shapes": [
316+
[
317+
32,
318+
512,
319+
1024
320+
]
321+
],
322+
"timestamp": 1710208069697,
323+
"variance": 0
324+
}"#;
325+
let record = serde_json::from_str::<BenchmarkRecord>(sample_result).unwrap();
326+
assert!(record.backend == "candle");
327+
assert!(record.device == "Cuda(0)");
328+
assert!(record.results.git_hash == "02d37011ab4dc773286e5983c09cde61f95ba4b5");
329+
assert!(record.results.name == "unary");
330+
assert!(record.results.computed.max.as_micros() == 8858);
331+
assert!(record.results.computed.mean.as_micros() == 8629);
332+
assert!(record.results.computed.median.as_micros() == 8592);
333+
assert!(record.results.computed.min.as_micros() == 8506);
334+
assert!(record.results.options.is_none());
335+
assert!(record.results.shapes == vec![vec![32, 512, 1024]]);
336+
assert!(record.results.timestamp == 1710208069697);
337+
assert!(record.results.computed.variance.as_micros() == 0);
338+
339+
//Check raw durations
340+
assert!(record.results.raw.durations[0] == Duration::from_nanos(8858583));
341+
assert!(record.results.raw.durations[1] == Duration::from_nanos(8719822));
342+
assert!(record.results.raw.durations[2] == Duration::from_nanos(8705335));
343+
assert!(record.results.raw.durations[3] == Duration::from_nanos(8835636));
344+
assert!(record.results.raw.durations[4] == Duration::from_nanos(8592507));
345+
assert!(record.results.raw.durations[5] == Duration::from_nanos(8506423));
346+
assert!(record.results.raw.durations[6] == Duration::from_nanos(8534337));
347+
assert!(record.results.raw.durations[7] == Duration::from_nanos(8506627));
348+
assert!(record.results.raw.durations[8] == Duration::from_nanos(8521615));
349+
assert!(record.results.raw.durations[9] == Duration::from_nanos(8511474));
350+
}
351+
}

0 commit comments

Comments
 (0)