diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index 813a151776..2aeed8920f 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -28,7 +28,7 @@ burn = { path = "../burn" } derive-new = { workspace = true } rand = { workspace = true } burn-common = { path = "../burn-common", version = "0.11.0" } -serde_json = "1.0.108" +serde_json = { workspace = true } dirs = "5.0.1" [dev-dependencies] diff --git a/backend-comparison/benches/binary.rs b/backend-comparison/benches/binary.rs index d28c88bd07..b06a1ccd4b 100644 --- a/backend-comparison/benches/binary.rs +++ b/backend-comparison/benches/binary.rs @@ -1,4 +1,4 @@ -use backend_comparison::persistence::persist; +use backend_comparison::persistence::Persistence; use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; use burn_common::benchmark::{run_benchmark, Benchmark}; @@ -42,7 +42,7 @@ fn bench(device: &B::Device) { device: device.clone(), }; - persist::(vec![run_benchmark(benchmark)], device) + Persistence::persist::(vec![run_benchmark(benchmark)], device) } fn main() { diff --git a/backend-comparison/benches/custom_gelu.rs b/backend-comparison/benches/custom_gelu.rs index ef47795df0..65b3ea30a5 100644 --- a/backend-comparison/benches/custom_gelu.rs +++ b/backend-comparison/benches/custom_gelu.rs @@ -1,4 +1,4 @@ -use backend_comparison::persistence::persist; +use backend_comparison::persistence::Persistence; use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; use burn_common::benchmark::{run_benchmark, Benchmark}; use core::f64::consts::SQRT_2; @@ -108,7 +108,7 @@ fn bench(device: &B::Device) { GeluKind::WithCustomErf, ); - persist::( + Persistence::persist::( vec![ run_benchmark(reference_gelu), run_benchmark(reference_erf_gelu), diff --git a/backend-comparison/benches/data.rs b/backend-comparison/benches/data.rs index 3aeebb8c93..571a08b23d 100644 --- a/backend-comparison/benches/data.rs +++ b/backend-comparison/benches/data.rs @@ -1,4 +1,4 @@ -use backend_comparison::persistence::persist; +use backend_comparison::persistence::Persistence; use burn::tensor::{backend::Backend, Data, Distribution, Shape, Tensor}; use burn_common::benchmark::{run_benchmark, Benchmark}; use derive_new::new; @@ -77,7 +77,7 @@ fn bench(device: &B::Device) { let to_benchmark = ToDataBenchmark::::new(shape.clone(), num_repeats, device.clone()); let from_benchmark = FromDataBenchmark::::new(shape, num_repeats, device.clone()); - persist::( + Persistence::persist::( vec![run_benchmark(to_benchmark), run_benchmark(from_benchmark)], device, ) diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs index 62685a4411..24f39eeeda 100644 --- a/backend-comparison/benches/matmul.rs +++ b/backend-comparison/benches/matmul.rs @@ -1,4 +1,4 @@ -use backend_comparison::persistence::persist; +use backend_comparison::persistence::Persistence; use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; use burn_common::benchmark::{run_benchmark, Benchmark}; use derive_new::new; @@ -57,7 +57,7 @@ fn bench(device: &B::Device) { let shape_rhs = [batch_size, k, n].into(); let benchmark = MatmulBenchmark::::new(shape_lhs, shape_rhs, num_repeats, device.clone()); - persist::(vec![run_benchmark(benchmark)], device) + Persistence::persist::(vec![run_benchmark(benchmark)], device) } fn main() { diff --git a/backend-comparison/benches/unary.rs b/backend-comparison/benches/unary.rs index 6634934bc4..924d7a3b3b 100644 --- a/backend-comparison/benches/unary.rs +++ b/backend-comparison/benches/unary.rs @@ -1,4 +1,4 @@ -use backend_comparison::persistence::persist; +use backend_comparison::persistence::Persistence; use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; use burn_common::benchmark::{run_benchmark, Benchmark}; use derive_new::new; @@ -41,7 +41,7 @@ fn bench(device: &B::Device) { let benchmark = UnaryBenchmark::::new(shape, num_repeats, device.clone()); - persist::(vec![run_benchmark(benchmark)], device) + Persistence::persist::(vec![run_benchmark(benchmark)], device) } fn main() { diff --git a/backend-comparison/src/persistence/base.rs b/backend-comparison/src/persistence/base.rs index 7b2f79001f..038b3e0089 100644 --- a/backend-comparison/src/persistence/base.rs +++ b/backend-comparison/src/persistence/base.rs @@ -14,68 +14,83 @@ type BenchmarkOpResults = HashMap; type BenchmarkCommitResults = HashMap; type StampedBenchmarks = HashMap>; -/// Updates the cached backend comparison file with new benchmarks, -/// following this json structure: -/// -/// In directory BACKEND_NAME: -/// { -/// BENCHMARK_NAME (OP + SHAPE): { -/// GIT_COMMIT_HASH: { -/// TIMESTAMP: \[ -/// DURATIONS -/// \] -/// } -/// } -/// } -pub fn persist(benches: Vec, device: &B::Device) { - let cache_file = dirs::home_dir() - .expect("Could not get home directory") - .join(".cache") - .join("backend-comparison") - .join(format!("{}-{:?}.json", B::name(), device)); - - println!("Persisting to {:?}", cache_file); - save( - fill_backend_comparison(load(cache_file.clone()), benches), - cache_file, - ) +#[derive(Default)] +pub struct Persistence { + results: HashMap, } -fn fill_backend_comparison( - mut benchmark_op_results: BenchmarkOpResults, - benches: Vec, -) -> BenchmarkOpResults { - for bench in benches { - let mut benchmark_commit_results = - benchmark_op_results.remove(&bench.name).unwrap_or_default(); - - let mut stamped_benchmarks = benchmark_commit_results - .remove(&bench.git_hash) - .unwrap_or_default(); +impl Persistence { + /// Updates the cached backend comparison json file with new benchmarks results. + /// + /// The file has the following structure: + /// + /// { + /// "BACKEND_NAME-DEVICE": + /// { + /// "BENCHMARK_NAME (OP + SHAPE)": { + /// "GIT_COMMIT_HASH": { + /// "TIMESTAMP": \[ + /// DURATIONS + /// \] + /// } + /// } + /// } + /// } + pub fn persist(benches: Vec, device: &B::Device) { + for bench in benches.iter() { + println!("{}", bench); + } + let cache_file = dirs::home_dir() + .expect("Could not get home directory") + .join(".cache") + .join("backend-comparison") + .join("db.json"); - stamped_benchmarks.insert(bench.timestamp, bench.durations.durations); - benchmark_commit_results.insert(bench.git_hash, stamped_benchmarks); - benchmark_op_results.insert(bench.name, benchmark_commit_results); + let mut cache = Self::load(&cache_file); + cache.update::(device, benches); + cache.save(&cache_file); + println!("Persisting to {:?}", cache_file); } - benchmark_op_results -} + /// Load the cache from disk. + fn load(path: &PathBuf) -> Self { + let results = match File::open(path) { + Ok(file) => serde_json::from_reader(file) + .expect("Should have parsed to BenchmarkOpResults struct"), + Err(_) => HashMap::default(), + }; -fn load(path: PathBuf) -> BenchmarkOpResults { - match File::open(path) { - Ok(file) => { - serde_json::from_reader(file).expect("Should have parsed to BenchmarkOpResults struct") - } - Err(_) => BenchmarkOpResults::new(), + Self { results } } -} -fn save(backend_comparison: BenchmarkOpResults, path: PathBuf) { - if let Some(parent) = path.parent() { - create_dir_all(parent).expect("Unable to create directory"); + /// Save the cache on disk. + fn save(&self, path: &PathBuf) { + if let Some(parent) = path.parent() { + create_dir_all(parent).expect("Unable to create directory"); + } + let file = File::create(&path).expect("Unable to create backend comparison file"); + + serde_json::to_writer_pretty(file, &self.results) + .expect("Unable to write to backend comparison file"); } - let file = File::create(&path).expect("Unable to create backend comparison file"); - serde_json::to_writer(file, &backend_comparison) - .expect("Unable to write to backend comparison file"); + /// Update the cache with the given [benchmark results](BenchmarkResult). + fn update(&mut self, device: &B::Device, benches: Vec) { + let key = format!("{}-{:?}", B::name(), device); + let mut results_ops = self.results.remove(&key).unwrap_or_default(); + + for bench in benches { + let mut benchmark_commit_results = results_ops.remove(&bench.name).unwrap_or_default(); + + let mut stamped_benchmarks = benchmark_commit_results + .remove(&bench.git_hash) + .unwrap_or_default(); + + stamped_benchmarks.insert(bench.timestamp, bench.durations.durations); + benchmark_commit_results.insert(bench.git_hash, stamped_benchmarks); + results_ops.insert(bench.name, benchmark_commit_results); + } + + self.results.insert(key, results_ops); + } }