Skip to content

Commit

Permalink
[backend-comparison] Add GitHub authentication to burnbench CLI (#1285)
Browse files Browse the repository at this point in the history
* [backend-comparison] Add auth command to burnbench CLI

* [backend-comparison] Add --share argument to Burnbench CLI

* Cargo clippy fixes

* Fix typos

* Add comment to explain the FIVE_SECONDS constant

* Use num_args to force at least one arg value and make args required

In the run command, makes the --benches and --backends required
The manual check is no longer necessary

* Use and_then instead of match

* Simplify token verification

* Use map_or instead of match
  • Loading branch information
syl20bnr authored Feb 13, 2024
1 parent 62809cd commit 00b6c7d
Show file tree
Hide file tree
Showing 8 changed files with 661 additions and 84 deletions.
374 changes: 335 additions & 39 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ text_placeholder = "0.5.0"
pollster = "0.3"
wgpu = "0.18.0"

# Burnbench
arboard = "3.3.0"
github-device-flow = "0.2.0"

bincode = { version = "2.0.0-rc.3", features = [
"alloc",
"serde",
Expand Down
5 changes: 4 additions & 1 deletion backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,23 @@ wgpu = ["burn/wgpu", "burn/autotune"]
wgpu-fusion = ["wgpu", "burn/fusion"]

[dependencies]
arboard = { workspace = true }
burn = { path = "../burn", default-features = false }
burn-common = { path = "../burn-common", version = "0.13.0" }
clap = { workspace = true }
crossterm = { workspace = true, optional = true }
derive-new = { workspace = true }
dirs = { workspace = true }
github-device-flow = { workspace = true }
rand = { workspace = true }
ratatui = { workspace = true, optional = true }
reqwest = {workspace = true, features = ["blocking", "json"]}
serde_json = { workspace = true }
strum = { workspace = true }
strum_macros = { workspace = true }

[dev-dependencies]

serial_test = { workspace = true }

[[bench]]
name = "unary"
Expand Down
36 changes: 34 additions & 2 deletions backend-comparison/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ The end of options argument `--` is used to pass arguments to the `burnbench`
application. For instance `cargo run --bin burnbench -- list` passes the `list`
argument to `burnbench` effectively calling `burnbench list`.

### Commands

#### List benches and backends

To list all the available benches and backends use the `list` command:

```sh
Expand Down Expand Up @@ -43,7 +47,9 @@ Available Benchmarks:
- unary
```

To execute a given benchmark against a specific backend we use the `run` command
#### Run benchmarks

To run a given benchmark against a specific backend we use the `run` command
with the arguments `--benches` and `--backends` respectively. In the following
example we execute the `unary` benchmark against the `wgpu-fusion` backend:

Expand Down Expand Up @@ -72,9 +78,35 @@ Executing the following benchmark and backend combinations (Total: 4):
Running benchmarks...
```

#### Authentication and benchmarks sharing

Burnbench can upload benchmark results to our servers so that users can share
their results with the community and we can use this information to drive the
development of Burn.

Sharing results is opt-in and it is enabled with the `--share` arguments passed
to the `run` command:

```sh
> cargo run --bin burnbench -- run --share --benches unary --backends wgpu-fusion
```

To be able to upload results you must be authenticated. We only support GitHub
authentication. To authenticate run the `auth` command, then follow the URL
to enter your device code and authorize the Burnbench application:

```sh
> cargo run --bin burnbench -- run auth
```

If everything is fine you should get a confirmation in the terminal that your
token has been saved to the burn cache directory.

You can now use the `--share` argument to upload and share your benchmarks!

### Terminal UI

This is a work in progress.
This is a work in progress and is not usable for now.

## Execute benchmarks with cargo

Expand Down
2 changes: 1 addition & 1 deletion backend-comparison/src/bin/burnbench.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use backend_comparison::burnbenchapp;

fn main() {
burnbenchapp::run()
burnbenchapp::execute();
}
175 changes: 175 additions & 0 deletions backend-comparison/src/burnbenchapp/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
use reqwest;
use std::io::Write;
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
use std::{
fs::{self, File},
path::{Path, PathBuf},
};

pub(crate) static CLIENT_ID: &str = "Iv1.84002254a02791f3";
static GITHUB_API_VERSION_HEADER: &str = "X-GitHub-Api-Version";
static GITHUB_API_VERSION: &str = "2022-11-28";

/// Return the file path for the auth cache on disk
pub(crate) fn get_auth_cache_file_path() -> PathBuf {
let home_dir = dirs::home_dir().expect("an home directory should exist");
let path_dir = home_dir.join(".cache").join("burn").join("burnbench");
#[cfg(test)]
let path_dir = path_dir.join("test");
let path = Path::new(&path_dir);
path.join("token.txt")
}

/// Returns true if the token is still valid
pub(crate) fn verify_token(token: &str) -> bool {
let client = reqwest::blocking::Client::new();
let response = client
.get("https://api.github.com/user")
.header(reqwest::header::USER_AGENT, "burnbench")
.header(reqwest::header::ACCEPT, "application/vnd.github+json")
.header(reqwest::header::AUTHORIZATION, format!("Bearer {}", token))
.header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION)
.send();
response.map_or(false, |resp| resp.status().is_success())
}

/// Save token in Burn cache directory and adjust file permissions
pub(crate) fn save_token(token: &str) {
let path = get_auth_cache_file_path();
fs::create_dir_all(path.parent().expect("path should have a parent directory"))
.expect("directory should be created");
let mut file = File::create(&path).expect("file should be created");
write!(file, "{}", token).expect("token should be written to file");
// On unix systems we lower the permissions on the cache file to be readable
// just by the current user
#[cfg(unix)]
fs::set_permissions(&path, fs::Permissions::from_mode(0o600))
.expect("permissions should be set to 600");
println!("✅ Token saved at location: {}", path.to_str().unwrap());
}

/// Return the token saved in the cache file
#[inline]
pub(crate) fn get_token_from_cache() -> Option<String> {
let path = get_auth_cache_file_path();
fs::read_to_string(path)
.ok()
.and_then(|contents| contents.lines().next().map(str::to_string))
}

#[cfg(test)]
use serial_test::serial;

#[cfg(test)]
mod tests {
use super::*;
use std::fs;

fn cleanup_test_environment() {
let path = get_auth_cache_file_path();
if path.exists() {
fs::remove_file(&path).expect("should be able to delete the token file");
}
let parent_dir = path
.parent()
.expect("token file should have a parent directory");
if parent_dir.exists() {
fs::remove_dir_all(parent_dir).expect("should be able to delete the cache directory");
}
}

#[test]
#[serial]
fn test_save_token_when_file_does_not_exist() {
cleanup_test_environment();
let token = "unique_test_token";
// Ensure the file does not exist
let path = get_auth_cache_file_path();
if path.exists() {
fs::remove_file(&path).unwrap();
}
save_token(token);
assert_eq!(fs::read_to_string(path).unwrap(), token);
cleanup_test_environment();
}

#[test]
#[serial]
fn test_overwrite_saved_token_when_file_already_exists() {
cleanup_test_environment();
let initial_token = "initial_test_token";
let new_token = "new_test_token";
// Save initial token
save_token(initial_token);
// Save new token that should overwrite the initial one
save_token(new_token);
let path = get_auth_cache_file_path();
assert_eq!(fs::read_to_string(path).unwrap(), new_token);
cleanup_test_environment();
}

#[test]
#[serial]
fn test_get_saved_token_from_cache_when_it_exists() {
cleanup_test_environment();
let token = "existing_test_token";
// Save the token first
save_token(token);
// Now retrieve it
let retrieved_token = get_token_from_cache().unwrap();
assert_eq!(retrieved_token, token);
cleanup_test_environment();
}

#[test]
#[serial]
fn test_return_only_first_line_of_cache_as_token() {
cleanup_test_environment();
let path = get_auth_cache_file_path();
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).expect("directory tree should be created");
}
// Create a file with multiple lines
let mut file = File::create(&path).expect("test file should be created");
write!(file, "first_line_token\nsecond_line\nthird_line")
.expect("test file should contain several lines");
// Test that only the first line is returned as the token
let token = get_token_from_cache().expect("token should be present");
assert_eq!(
token, "first_line_token",
"The token should match only the first line of the file"
);
cleanup_test_environment();
}

#[test]
#[serial]
fn test_return_none_when_cache_file_does_not_exist() {
cleanup_test_environment();
let path = get_auth_cache_file_path();
// Ensure the file does not exist
if path.exists() {
fs::remove_file(&path).unwrap();
}
assert!(get_token_from_cache().is_none());
cleanup_test_environment();
}

#[test]
#[serial]
fn test_return_none_when_cache_file_exists_but_is_empty() {
cleanup_test_environment();
// Create an empty file
let path = get_auth_cache_file_path();
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).expect("directory tree should be created");
}
File::create(&path).expect("empty file should be created");
assert!(
get_token_from_cache().is_none(),
"Expected None for empty cache file, got Some"
);
cleanup_test_environment();
}
}
Loading

0 comments on commit 00b6c7d

Please sign in to comment.