|
1 |
| -use std::fs; |
2 |
| - |
3 | 1 | use burn::{
|
4 |
| - serde::{ser::SerializeStruct, Serialize, Serializer}, |
| 2 | + serde::{de::Visitor, ser::SerializeStruct, Deserialize, Serialize, Serializer}, |
5 | 3 | tensor::backend::Backend,
|
6 | 4 | };
|
7 | 5 | use burn_common::benchmark::BenchmarkResult;
|
8 | 6 | use dirs;
|
9 | 7 | use reqwest::header::{HeaderMap, ACCEPT, AUTHORIZATION, USER_AGENT};
|
10 | 8 | use serde_json;
|
11 |
| - |
| 9 | +use std::fmt::Display; |
| 10 | +use std::time::Duration; |
| 11 | +use std::{fs, io::Write}; |
12 | 12 | #[derive(Default, Clone)]
|
13 | 13 | pub struct BenchmarkRecord {
|
14 | 14 | backend: String,
|
15 | 15 | device: String,
|
16 |
| - results: BenchmarkResult, |
| 16 | + pub results: BenchmarkResult, |
17 | 17 | }
|
18 | 18 |
|
19 | 19 | /// Save the benchmarks results on disk.
|
@@ -77,10 +77,22 @@ pub fn save<B: Backend>(
|
77 | 77 | record.results.name, record.results.timestamp
|
78 | 78 | );
|
79 | 79 | let file_path = cache_dir.join(file_name);
|
80 |
| - let file = fs::File::create(file_path).expect("Benchmark file should exist or be created"); |
| 80 | + let file = |
| 81 | + fs::File::create(file_path.clone()).expect("Benchmark file should exist or be created"); |
81 | 82 | serde_json::to_writer_pretty(file, &record)
|
82 | 83 | .expect("Benchmark file should be updated with benchmark results");
|
83 | 84 |
|
| 85 | + // Append the benchmark result filepath in the benchmark_results.tx file of cache folder to be later picked by benchrun |
| 86 | + let benchmark_results_path = cache_dir.join("benchmark_results.txt"); |
| 87 | + let mut benchmark_results_file = fs::OpenOptions::new() |
| 88 | + .append(true) |
| 89 | + .create(true) |
| 90 | + .open(benchmark_results_path) |
| 91 | + .unwrap(); |
| 92 | + benchmark_results_file |
| 93 | + .write_all(format!("{}\n", file_path.to_string_lossy()).as_bytes()) |
| 94 | + .unwrap(); |
| 95 | + |
84 | 96 | if url.is_some() {
|
85 | 97 | println!("Sharing results...");
|
86 | 98 | let client = reqwest::blocking::Client::new();
|
@@ -154,3 +166,186 @@ impl Serialize for BenchmarkRecord {
|
154 | 166 | )
|
155 | 167 | }
|
156 | 168 | }
|
| 169 | + |
| 170 | +struct BenchmarkRecordVisitor; |
| 171 | + |
| 172 | +impl<'de> Visitor<'de> for BenchmarkRecordVisitor { |
| 173 | + type Value = BenchmarkRecord; |
| 174 | + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { |
| 175 | + write!(formatter, "Serialized Json object of BenchmarkRecord") |
| 176 | + } |
| 177 | + fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error> |
| 178 | + where |
| 179 | + A: burn::serde::de::MapAccess<'de>, |
| 180 | + { |
| 181 | + let mut br = BenchmarkRecord::default(); |
| 182 | + while let Some(key) = map.next_key::<String>()? { |
| 183 | + match key.as_str() { |
| 184 | + "backend" => br.backend = map.next_value::<String>()?, |
| 185 | + "device" => br.device = map.next_value::<String>()?, |
| 186 | + "gitHash" => br.results.git_hash = map.next_value::<String>()?, |
| 187 | + "name" => br.results.name = map.next_value::<String>()?, |
| 188 | + "max" => { |
| 189 | + let value = map.next_value::<u64>()?; |
| 190 | + br.results.computed.max = Duration::from_micros(value); |
| 191 | + } |
| 192 | + "mean" => { |
| 193 | + let value = map.next_value::<u64>()?; |
| 194 | + br.results.computed.mean = Duration::from_micros(value); |
| 195 | + } |
| 196 | + "median" => { |
| 197 | + let value = map.next_value::<u64>()?; |
| 198 | + br.results.computed.median = Duration::from_micros(value); |
| 199 | + } |
| 200 | + "min" => { |
| 201 | + let value = map.next_value::<u64>()?; |
| 202 | + br.results.computed.min = Duration::from_micros(value); |
| 203 | + } |
| 204 | + "options" => br.results.options = map.next_value::<Option<String>>()?, |
| 205 | + "rawDurations" => br.results.raw.durations = map.next_value::<Vec<Duration>>()?, |
| 206 | + "shapes" => br.results.shapes = map.next_value::<Vec<Vec<usize>>>()?, |
| 207 | + "timestamp" => br.results.timestamp = map.next_value::<u128>()?, |
| 208 | + "variance" => { |
| 209 | + let value = map.next_value::<u64>()?; |
| 210 | + br.results.computed.variance = Duration::from_micros(value) |
| 211 | + } |
| 212 | + |
| 213 | + "numSamples" => _ = map.next_value::<usize>()?, |
| 214 | + _ => panic!("Unexpected Key: {}", key), |
| 215 | + } |
| 216 | + } |
| 217 | + |
| 218 | + Ok(br) |
| 219 | + } |
| 220 | +} |
| 221 | + |
| 222 | +impl<'de> Deserialize<'de> for BenchmarkRecord { |
| 223 | + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> |
| 224 | + where |
| 225 | + D: burn::serde::Deserializer<'de>, |
| 226 | + { |
| 227 | + deserializer.deserialize_map(BenchmarkRecordVisitor) |
| 228 | + } |
| 229 | +} |
| 230 | + |
| 231 | +#[derive(Default)] |
| 232 | +pub(crate) struct BenchmarkCollection { |
| 233 | + pub records: Vec<BenchmarkRecord>, |
| 234 | +} |
| 235 | + |
| 236 | +impl Display for BenchmarkCollection { |
| 237 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 238 | + writeln!( |
| 239 | + f, |
| 240 | + "| {0:<15}| {1:<35}| {2:<15}|\n|{3:-<16}|{4:-<36}|{5:-<16}|", |
| 241 | + "Benchmark", "Backend", "Median", "", "", "" |
| 242 | + )?; |
| 243 | + for record in self.records.iter() { |
| 244 | + let backend = [record.backend.clone(), record.device.clone()].join("-"); |
| 245 | + writeln!( |
| 246 | + f, |
| 247 | + "| {0:<15}| {1:<35}| {2:<15.3?}|", |
| 248 | + record.results.name, backend, record.results.computed.median |
| 249 | + )?; |
| 250 | + } |
| 251 | + |
| 252 | + Ok(()) |
| 253 | + } |
| 254 | +} |
| 255 | + |
| 256 | +#[cfg(test)] |
| 257 | +mod tests { |
| 258 | + use super::*; |
| 259 | + |
| 260 | + #[test] |
| 261 | + fn get_benchmark_result() { |
| 262 | + let sample_result = r#"{ |
| 263 | + "backend": "candle", |
| 264 | + "device": "Cuda(0)", |
| 265 | + "gitHash": "02d37011ab4dc773286e5983c09cde61f95ba4b5", |
| 266 | + "name": "unary", |
| 267 | + "max": 8858, |
| 268 | + "mean": 8629, |
| 269 | + "median": 8592, |
| 270 | + "min": 8506, |
| 271 | + "numSamples": 10, |
| 272 | + "options": null, |
| 273 | + "rawDurations": [ |
| 274 | + { |
| 275 | + "secs": 0, |
| 276 | + "nanos": 8858583 |
| 277 | + }, |
| 278 | + { |
| 279 | + "secs": 0, |
| 280 | + "nanos": 8719822 |
| 281 | + }, |
| 282 | + { |
| 283 | + "secs": 0, |
| 284 | + "nanos": 8705335 |
| 285 | + }, |
| 286 | + { |
| 287 | + "secs": 0, |
| 288 | + "nanos": 8835636 |
| 289 | + }, |
| 290 | + { |
| 291 | + "secs": 0, |
| 292 | + "nanos": 8592507 |
| 293 | + }, |
| 294 | + { |
| 295 | + "secs": 0, |
| 296 | + "nanos": 8506423 |
| 297 | + }, |
| 298 | + { |
| 299 | + "secs": 0, |
| 300 | + "nanos": 8534337 |
| 301 | + }, |
| 302 | + { |
| 303 | + "secs": 0, |
| 304 | + "nanos": 8506627 |
| 305 | + }, |
| 306 | + { |
| 307 | + "secs": 0, |
| 308 | + "nanos": 8521615 |
| 309 | + }, |
| 310 | + { |
| 311 | + "secs": 0, |
| 312 | + "nanos": 8511474 |
| 313 | + } |
| 314 | + ], |
| 315 | + "shapes": [ |
| 316 | + [ |
| 317 | + 32, |
| 318 | + 512, |
| 319 | + 1024 |
| 320 | + ] |
| 321 | + ], |
| 322 | + "timestamp": 1710208069697, |
| 323 | + "variance": 0 |
| 324 | + }"#; |
| 325 | + let record = serde_json::from_str::<BenchmarkRecord>(sample_result).unwrap(); |
| 326 | + assert!(record.backend == "candle"); |
| 327 | + assert!(record.device == "Cuda(0)"); |
| 328 | + assert!(record.results.git_hash == "02d37011ab4dc773286e5983c09cde61f95ba4b5"); |
| 329 | + assert!(record.results.name == "unary"); |
| 330 | + assert!(record.results.computed.max.as_micros() == 8858); |
| 331 | + assert!(record.results.computed.mean.as_micros() == 8629); |
| 332 | + assert!(record.results.computed.median.as_micros() == 8592); |
| 333 | + assert!(record.results.computed.min.as_micros() == 8506); |
| 334 | + assert!(record.results.options.is_none()); |
| 335 | + assert!(record.results.shapes == vec![vec![32, 512, 1024]]); |
| 336 | + assert!(record.results.timestamp == 1710208069697); |
| 337 | + assert!(record.results.computed.variance.as_micros() == 0); |
| 338 | + |
| 339 | + //Check raw durations |
| 340 | + assert!(record.results.raw.durations[0] == Duration::from_nanos(8858583)); |
| 341 | + assert!(record.results.raw.durations[1] == Duration::from_nanos(8719822)); |
| 342 | + assert!(record.results.raw.durations[2] == Duration::from_nanos(8705335)); |
| 343 | + assert!(record.results.raw.durations[3] == Duration::from_nanos(8835636)); |
| 344 | + assert!(record.results.raw.durations[4] == Duration::from_nanos(8592507)); |
| 345 | + assert!(record.results.raw.durations[5] == Duration::from_nanos(8506423)); |
| 346 | + assert!(record.results.raw.durations[6] == Duration::from_nanos(8534337)); |
| 347 | + assert!(record.results.raw.durations[7] == Duration::from_nanos(8506627)); |
| 348 | + assert!(record.results.raw.durations[8] == Duration::from_nanos(8521615)); |
| 349 | + assert!(record.results.raw.durations[9] == Duration::from_nanos(8511474)); |
| 350 | + } |
| 351 | +} |
0 commit comments