diff --git a/Cargo.lock b/Cargo.lock index 3b0921b..7299256 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -35,6 +35,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", + "getrandom", "once_cell", "version_check", "zerocopy", @@ -2032,6 +2033,7 @@ dependencies = [ "gix-revision", "gix-revwalk", "gix-sec", + "gix-status", "gix-submodule", "gix-tempfile", "gix-trace", @@ -2173,8 +2175,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92c9afd80fff00f8b38b1c1928442feb4cd6d2232a6ed806b6b193151a3d336c" dependencies = [ "bstr", + "gix-command", + "gix-filter", + "gix-fs", "gix-hash", "gix-object", + "gix-path", + "gix-tempfile", + "gix-trace", + "gix-worktree", + "imara-diff", "thiserror 1.0.69", ] @@ -2534,6 +2544,29 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "gix-status" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f70d35ba639f0c16a6e4cca81aa374a05f07b23fa36ee8beb72c100d98b4ffea" +dependencies = [ + "bstr", + "filetime", + "gix-diff", + "gix-dir", + "gix-features", + "gix-filter", + "gix-fs", + "gix-hash", + "gix-index", + "gix-object", + "gix-path", + "gix-pathspec", + "gix-worktree", + "portable-atomic", + "thiserror 1.0.69", +] + [[package]] name = "gix-submodule" version = "0.14.0" @@ -2903,7 +2936,6 @@ dependencies = [ "heat-sdk", "ignore", "inventory", - "itertools 0.13.0", "lazycell", "once_cell", "paste", @@ -3288,6 +3320,16 @@ dependencies = [ "quick-error", ] +[[package]] +name = "imara-diff" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc9da1a252bd44cd341657203722352efc9bc0c847d06ea6d2dc1cd1135e0a01" +dependencies = [ + "ahash", + "hashbrown 0.14.5", +] + [[package]] name = "imgref" version = "1.11.0" diff --git a/Cargo.toml b/Cargo.toml index 75938b3..0031349 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,10 +19,6 @@ anyhow = "1.0.81" clap = { version = "4.5.4", features = ["derive"] } colored = "2.1.0" derive-new = { version = "0.6.0", default-features = false } -derive_more = { version = "0.99.18", features = [ - "display", -], default-features = false } -env_logger = "0.11.3" log = "0.4.21" once_cell = "1.19.0" proc-macro2 = { version = "1.0.86" } diff --git a/crates/heat-sdk-cli-macros/src/lib.rs b/crates/heat-sdk-cli-macros/src/lib.rs index ac83172..322dd67 100644 --- a/crates/heat-sdk-cli-macros/src/lib.rs +++ b/crates/heat-sdk-cli-macros/src/lib.rs @@ -111,13 +111,14 @@ pub fn heat(args: TokenStream, item: TokenStream) -> TokenStream { quote! {} }; - quote! { + let code = quote! { #[allow(dead_code)] #item #flag_register - } - .into() + }; + + code.into() } #[proc_macro_attribute] @@ -177,8 +178,9 @@ pub fn heat_cli_main(args: TokenStream, item: TokenStream) -> TokenStream { } }; - quote! { + let code = quote! { #item - } - .into() + }; + + code.into() } diff --git a/crates/heat-sdk-cli/Cargo.toml b/crates/heat-sdk-cli/Cargo.toml index 51bf8bb..914668e 100644 --- a/crates/heat-sdk-cli/Cargo.toml +++ b/crates/heat-sdk-cli/Cargo.toml @@ -40,9 +40,8 @@ flate2 = { version = "1.0.30", default-features = false, features = ["zlib"] } tar = { version = "0.4.40", default-features = false } walkdir = "2" ignore = "0.4.22" -gix = { version = "0.66.0", default-features = false, features = ["dirwalk"]} +gix = { version = "0.66.0", default-features = false, features = ["dirwalk", "status"]} unicase = "2.7.0" -itertools = "0.13.0" lazycell = "1.3.0" serde-untagged = "0.1.6" serde_ignored = "0.1.1" diff --git a/crates/heat-sdk-cli/src/cli.rs b/crates/heat-sdk-cli/src/cli.rs index 998aecb..fc0ac76 100644 --- a/crates/heat-sdk-cli/src/cli.rs +++ b/crates/heat-sdk-cli/src/cli.rs @@ -15,10 +15,11 @@ pub struct CliArgs { #[derive(Subcommand, Debug)] #[command(arg_required_else_help = true)] pub enum Commands { - /// {local|remote} : Run a training or inference locally or trigger a remote run. + /// Run a training or inference locally or trigger a remote run. #[command(subcommand)] Run(cli_commands::run::RunLocationType), + /// Package your project for running on a remote machine. Package(cli_commands::package::PackageArgs), // todo // Ls(), diff --git a/crates/heat-sdk-cli/src/cli_commands/package/mod.rs b/crates/heat-sdk-cli/src/cli_commands/package/mod.rs index 66add77..ef4107b 100644 --- a/crates/heat-sdk-cli/src/cli_commands/package/mod.rs +++ b/crates/heat-sdk-cli/src/cli_commands/package/mod.rs @@ -1,38 +1,58 @@ use crate::context::HeatCliContext; +use crate::registry::Flag; +use crate::{print_err, print_success}; use clap::Parser; -use heat_sdk::{ - client::{HeatClient, HeatClientConfig, HeatCredentials}, - schemas::{HeatCodeMetadata, ProjectPath, RegisteredHeatFunction}, -}; +use heat_sdk::client::{HeatClient, HeatClientConfig, HeatCredentials}; +use heat_sdk::schemas::{HeatCodeMetadata, ProjectPath, RegisteredHeatFunction}; use quote::ToTokens; #[derive(Parser, Debug)] pub struct PackageArgs { - /// The Heat project ID + /// The Heat project path // todo: support project name and creating a project if it doesn't exist #[clap( short = 'p', long = "project", required = true, - help = " The Heat project ID." + help = "The Heat project path. Ex: test/Default-Project" )] project_path: String, /// The Heat API key - #[clap( - short = 'k', - long = "key", - required = true, - help = " The Heat API key." - )] + #[clap(short = 'k', long = "key", required = true, help = "The Heat API key.")] key: String, - /// The Heat API endpoint - #[clap( - short = 'e', - long = "endpoint", - help = "The Heat API endpoint.", - default_value = "http://127.0.0.1:9001" - )] - pub heat_endpoint: String, +} + +pub(crate) fn handle_command(args: PackageArgs, context: HeatCliContext) -> anyhow::Result<()> { + let last_commit_hash = get_last_commit_hash()?; + + let heat_client = create_heat_client( + &args.key, + context.get_api_endpoint().as_str(), + &args.project_path, + ); + + let crates = crate::util::cargo::package::package( + &context.get_artifacts_dir_path(), + context.package_name(), + )?; + + let flags = crate::registry::get_flags(); + let registered_functions = get_registered_functions(&flags); + + let heat_metadata = HeatCodeMetadata { + functions: registered_functions, + }; + + let project_version = heat_client.upload_new_project_version( + context.package_name(), + heat_metadata, + crates, + &last_commit_hash, + )?; + + print_success!("New project version uploaded: {}", project_version); + + Ok(()) } fn create_heat_client(api_key: &str, url: &str, project_path: &str) -> HeatClient { @@ -48,37 +68,33 @@ fn create_heat_client(api_key: &str, url: &str, project_path: &str) -> HeatClien .expect("Should connect to the Heat server and create a client") } -pub(crate) fn handle_command(args: PackageArgs, context: HeatCliContext) -> anyhow::Result<()> { - let heat_client = create_heat_client(&args.key, &args.heat_endpoint, &args.project_path); - - let crates = crate::util::cargo::package::package( - &context.get_artifacts_dir_path(), - context.package_name(), - )?; - - let flags = crate::registry::get_flags(); +fn get_registered_functions(flags: &[Flag]) -> Vec { + flags + .iter() + .map(|flag| { + // function token stream to readable string + let itemfn = syn_serde::json::from_slice::(flag.token_stream) + .expect("Should be able to parse token stream."); + let syn_tree: syn::File = syn::parse2(itemfn.into_token_stream()) + .expect("Should be able to parse token stream."); + let code_str = prettyplease::unparse(&syn_tree); + RegisteredHeatFunction { + mod_path: flag.mod_path.to_string(), + fn_name: flag.fn_name.to_string(), + proc_type: flag.proc_type.to_string(), + code: code_str, + } + }) + .collect() +} - let mut registered_functions = Vec::::new(); - for flag in flags { - // function token stream to readable string - let itemfn = syn_serde::json::from_slice::(flag.token_stream) - .expect("Should be able to parse token stream."); - let syn_tree: syn::File = - syn::parse2(itemfn.into_token_stream()).expect("Should be able to parse token stream."); - let code_str = prettyplease::unparse(&syn_tree); - registered_functions.push(RegisteredHeatFunction { - mod_path: flag.mod_path.to_string(), - fn_name: flag.fn_name.to_string(), - proc_type: flag.proc_type.to_string(), - code: code_str, - }); +fn get_last_commit_hash() -> anyhow::Result { + let repo = gix::discover(".")?; + let last_commit = repo.head()?.peel_to_commit_in_place()?.id(); + if repo.is_dirty()? { + print_err!("Latest git commit: {}", last_commit); + anyhow::bail!("Repo is dirty. Please commit or stash your changes before packaging."); } - let heat_metadata = HeatCodeMetadata { - functions: registered_functions, - }; - - heat_client.upload_new_project_version(context.package_name(), heat_metadata, crates)?; - - Ok(()) + Ok(last_commit.to_string()) } diff --git a/crates/heat-sdk-cli/src/cli_commands/run/inference.rs b/crates/heat-sdk-cli/src/cli_commands/run/inference.rs new file mode 100644 index 0000000..7df8cae --- /dev/null +++ b/crates/heat-sdk-cli/src/cli_commands/run/inference.rs @@ -0,0 +1,15 @@ +use clap::Parser; + +use crate::context::HeatCliContext; + +/// Run an inference locally. +/// Not yet supported. +#[derive(Parser, Debug)] +pub struct InferenceRunArgs {} + +pub(crate) fn handle_command( + _args: InferenceRunArgs, + _context: HeatCliContext, +) -> anyhow::Result<()> { + todo!("Local inference is not yet supported") +} diff --git a/crates/heat-sdk-cli/src/cli_commands/run/local/inference.rs b/crates/heat-sdk-cli/src/cli_commands/run/local/inference.rs deleted file mode 100644 index 4956894..0000000 --- a/crates/heat-sdk-cli/src/cli_commands/run/local/inference.rs +++ /dev/null @@ -1,27 +0,0 @@ -use std::path::PathBuf; - -use clap::Parser; - -use crate::generation::crate_gen::backend::BackendType; - -/// Run an inference locally. -/// Not yet supported. -#[derive(Parser, Debug)] -pub struct LocalInferenceRunArgs { - function: String, - model_path: PathBuf, - /// Backend to use - #[clap(short = 'b', long = "backends", value_delimiter = ' ', num_args = 1.., required = true)] - backends: Vec, - /// The project ID - // todo: support project name and creating a project if it doesn't exist - #[clap(short = 'p', long = "project", required = true)] - project: String, - /// The API key - #[clap(short = 'k', long = "key", required = true)] - key: String, -} - -pub(crate) fn handle_command(_args: LocalInferenceRunArgs) -> anyhow::Result<()> { - todo!("Local inference is not yet supported") -} diff --git a/crates/heat-sdk-cli/src/cli_commands/run/local/mod.rs b/crates/heat-sdk-cli/src/cli_commands/run/local/mod.rs deleted file mode 100644 index ecc1f91..0000000 --- a/crates/heat-sdk-cli/src/cli_commands/run/local/mod.rs +++ /dev/null @@ -1,31 +0,0 @@ -pub mod inference; -pub mod training; - -use clap::Parser; - -use crate::{ - cli_commands::run::local::{inference::LocalInferenceRunArgs, training::LocalTrainingRunArgs}, - context::HeatCliContext, -}; - -/// Run a training or inference locally. -/// Only local training is supported at the moment. -#[derive(Parser, Debug)] -pub enum LocalRunSubcommand { - /// Run a training locally. - Training(LocalTrainingRunArgs), - /// Run an inference locally. - Inference(LocalInferenceRunArgs), -} - -pub(crate) fn handle_command( - args: LocalRunSubcommand, - context: HeatCliContext, -) -> anyhow::Result<()> { - match args { - LocalRunSubcommand::Training(training_args) => { - training::handle_command(training_args, context) - } - LocalRunSubcommand::Inference(inference_args) => inference::handle_command(inference_args), - } -} diff --git a/crates/heat-sdk-cli/src/cli_commands/run/local/training.rs b/crates/heat-sdk-cli/src/cli_commands/run/local/training.rs deleted file mode 100644 index 2d2802f..0000000 --- a/crates/heat-sdk-cli/src/cli_commands/run/local/training.rs +++ /dev/null @@ -1,143 +0,0 @@ -use clap::Parser; -use colored::Colorize; - -use crate::{ - commands::{ - execute_parallel_build_all_then_run, execute_sequentially, BuildCommand, RunCommand, - RunParams, - }, - context::HeatCliContext, - generation::crate_gen::backend::BackendType, - logging::BURN_ORANGE, - print_info, -}; - -#[derive(Parser, Debug)] -pub struct LocalTrainingRunArgs { - /// The training functions to run - #[clap(short = 'f', long="functions", value_delimiter = ' ', num_args = 1.., required = true, help = " The training functions to run. Annotate a training function with #[heat(training)] to register it.")] - functions: Vec, - /// Backend to use - #[clap(short = 'b', long = "backends", value_delimiter = ' ', num_args = 1.., required = true, help = " Backends to use for training.")] - backends: Vec, - /// Config files paths - #[clap(short = 'c', long = "configs", value_delimiter = ' ', num_args = 1.., required = true, help = " Config files paths.")] - configs: Vec, - /// The Heat project ID - // todo: support project name and creating a project if it doesn't exist - #[clap( - short = 'p', - long = "project", - required = true, - help = " The Heat project ID." - )] - project: String, - /// The Heat API key - #[clap( - short = 'k', - long = "key", - required = true, - help = " The Heat API key." - )] - key: String, - /// Determines whether experiments sohuld be run in parallel or sequentially. Run in parallel if true. - #[clap(long = "parallel", default_value = "false")] - parallel: bool, -} - -pub(crate) fn handle_command( - args: LocalTrainingRunArgs, - mut context: HeatCliContext, -) -> anyhow::Result<()> { - // print all functions that are registered as training functions - let flags = crate::registry::get_flags(); - let training_functions = flags - .iter() - .filter(|flag| flag.proc_type == "training") - .map(|flag| { - format!( - " {} {}::{}", - "-".custom_color(BURN_ORANGE), - flag.mod_path.bold(), - flag.fn_name.bold() - ) - }) - .collect::>(); - print_info!("Registered training functions:"); - for function in training_functions { - print_info!("{}", function); - } - - // Check that all passed functions exist - let flags = crate::registry::get_flags(); - for function in &args.functions { - let function_flags = flags - .iter() - .filter(|flag| flag.fn_name == function) - .collect::>(); - if function_flags.is_empty() { - return Err(anyhow::anyhow!(format!("Function `{}` is not registered as a training function. Annotate a training function with #[heat(training)] to register it.", function))); - } else if function_flags.len() > 1 { - let function_strings = function_flags - .iter() - .map(|flag| { - format!( - " {} {}::{}", - "-".custom_color(BURN_ORANGE), - flag.mod_path.bold(), - flag.fn_name.bold() - ) - }) - .collect::>(); - return Err(anyhow::anyhow!(format!("Function `{}` is registered multiple times. Please write the entire module path of the desired function:\n{}", function.custom_color(BURN_ORANGE).bold(), function_strings.join("\n")))); - } - } - - let mut commands_to_run: Vec<(BuildCommand, RunCommand)> = Vec::new(); - - context.set_generated_crate_name("generated-heat-sdk-crate".to_string()); - - for backend in &args.backends { - for config_path in &args.configs { - for function in &args.functions { - let run_id = format!("{}", backend); - - commands_to_run.push(( - BuildCommand { - run_id: run_id.clone(), - backend: backend.clone(), - }, - RunCommand { - run_id, - run_params: RunParams::Training { - function: function.to_owned(), - config_path: config_path.to_owned(), - project: args.project.clone(), - key: args.key.clone(), - }, - }, - )); - } - } - } - - let res = if args.parallel { - execute_parallel_build_all_then_run(commands_to_run, context) - } else { - execute_sequentially(commands_to_run, context) - }; - - match res { - Ok(()) => { - print_info!("All experiments have run successfully!."); - } - Err(e) => { - return Err(anyhow::anyhow!(format!( - "An error has occurred while running experiments: {}", - e - ))); - } - } - - Ok(()) -} diff --git a/crates/heat-sdk-cli/src/cli_commands/run/mod.rs b/crates/heat-sdk-cli/src/cli_commands/run/mod.rs index ceedfe1..1b1a202 100644 --- a/crates/heat-sdk-cli/src/cli_commands/run/mod.rs +++ b/crates/heat-sdk-cli/src/cli_commands/run/mod.rs @@ -1,27 +1,27 @@ -pub mod local; -pub mod remote; +pub mod inference; +pub mod training; use clap::Parser; +use inference::InferenceRunArgs; +use training::TrainingRunArgs; use crate::context::HeatCliContext; -use crate::cli_commands::run::{local::LocalRunSubcommand, remote::RemoteRunSubcommand}; - /// Run a training or inference locally or trigger a remote run. /// Only local training is supported at the moment. #[derive(Parser, Debug)] pub enum RunLocationType { - /// {training|inference} : Run a training or inference locally. - #[command(subcommand)] - Local(LocalRunSubcommand), - /// todo - #[command(subcommand)] - Remote(RemoteRunSubcommand), + Training(TrainingRunArgs), + Inference(InferenceRunArgs), } pub(crate) fn handle_command(args: RunLocationType, context: HeatCliContext) -> anyhow::Result<()> { match args { - RunLocationType::Local(local_args) => local::handle_command(local_args, context), - RunLocationType::Remote(remote_args) => remote::handle_command(remote_args, context), + RunLocationType::Training(training_args) => { + training::handle_command(training_args, context) + } + RunLocationType::Inference(inference_args) => { + inference::handle_command(inference_args, context) + } } } diff --git a/crates/heat-sdk-cli/src/cli_commands/run/remote/inference.rs b/crates/heat-sdk-cli/src/cli_commands/run/remote/inference.rs deleted file mode 100644 index 59f2759..0000000 --- a/crates/heat-sdk-cli/src/cli_commands/run/remote/inference.rs +++ /dev/null @@ -1,12 +0,0 @@ -use clap::Parser; - -/// Run an inference remotely. -/// Not yet supported. -#[derive(Parser, Debug)] -pub struct RemoteInferenceRunArgs { - //todo -} - -pub(crate) fn handle_command(_args: RemoteInferenceRunArgs) -> anyhow::Result<()> { - todo!("Remote inference is not yet supported") -} diff --git a/crates/heat-sdk-cli/src/cli_commands/run/remote/mod.rs b/crates/heat-sdk-cli/src/cli_commands/run/remote/mod.rs deleted file mode 100644 index 20dd6b3..0000000 --- a/crates/heat-sdk-cli/src/cli_commands/run/remote/mod.rs +++ /dev/null @@ -1,33 +0,0 @@ -pub mod inference; -pub mod training; - -use clap::Parser; - -use crate::{ - cli_commands::run::remote::{ - inference::RemoteInferenceRunArgs, training::RemoteTrainingRunArgs, - }, - context::HeatCliContext, -}; - -/// Run a training or inference remotely. -/// Not yet supported. -#[derive(Parser, Debug)] -pub enum RemoteRunSubcommand { - /// todo - Training(RemoteTrainingRunArgs), - /// todo - Inference(RemoteInferenceRunArgs), -} - -pub(crate) fn handle_command( - args: RemoteRunSubcommand, - context: HeatCliContext, -) -> anyhow::Result<()> { - match args { - RemoteRunSubcommand::Training(training_args) => { - training::handle_command(training_args, context) - } - RemoteRunSubcommand::Inference(inference_args) => inference::handle_command(inference_args), - } -} diff --git a/crates/heat-sdk-cli/src/cli_commands/run/remote/training.rs b/crates/heat-sdk-cli/src/cli_commands/run/remote/training.rs deleted file mode 100644 index d0cddad..0000000 --- a/crates/heat-sdk-cli/src/cli_commands/run/remote/training.rs +++ /dev/null @@ -1,123 +0,0 @@ -use clap::Parser; -use heat_sdk::{ - client::{HeatClient, HeatClientConfig, HeatCredentials}, - schemas::{HeatCodeMetadata, ProjectPath, RegisteredHeatFunction}, -}; -use quote::ToTokens; - -use crate::{context::HeatCliContext, generation::backend::BackendType}; - -/// Run a training remotely. -/// Not yet supported. -#[derive(Parser, Debug)] -pub struct RemoteTrainingRunArgs { - /// The training functions to run - #[clap(short = 'f', long="functions", value_delimiter = ' ', num_args = 1.., required = true, help = " The training functions to run. Annotate a training function with #[heat(training)] to register it.")] - functions: Vec, - /// Backend to use - #[clap(short = 'b', long = "backends", value_delimiter = ' ', num_args = 1.., required = true, help = " Backends to use for training.")] - backends: Vec, - /// Config files paths - #[clap(short = 'c', long = "configs", value_delimiter = ' ', num_args = 1.., required = true, help = " Config files paths.")] - configs: Vec, - /// The Heat project ID - // todo: support project name and creating a project if it doesn't exist - #[clap( - short = 'p', - long = "project", - required = true, - help = " The Heat project ID." - )] - project_path: String, - /// The Heat API key - #[clap( - short = 'k', - long = "key", - required = true, - help = " The Heat API key." - )] - key: String, - /// The runner group name - #[clap( - short = 'r', - long = "runner", - help = "The runner group name.", - required = true - )] - pub runner: String, -} - -fn create_heat_client(api_key: &str, url: &str, wss: bool, project_path: &str) -> HeatClient { - let creds = HeatCredentials::new(api_key.to_owned()); - let client_config = HeatClientConfig::builder( - creds, - ProjectPath::try_from(project_path.to_string()).expect("Project path should be valid."), - ) - .with_endpoint(url) - .with_wss(wss) - .with_num_retries(10) - .build(); - HeatClient::create(client_config) - .expect("Should connect to the Heat server and create a client") -} - -pub(crate) fn handle_command( - args: RemoteTrainingRunArgs, - context: HeatCliContext, -) -> anyhow::Result<()> { - let heat_client = create_heat_client( - &args.key, - context.get_api_endpoint().as_str(), - context.get_wss(), - &args.project_path, - ); - - let crates = crate::util::cargo::package::package( - &context.get_artifacts_dir_path(), - context.package_name(), - )?; - - let flags = crate::registry::get_flags(); - - let mut registered_functions = Vec::::new(); - for flag in flags { - // function token stream to readable string - let itemfn = syn_serde::json::from_slice::(flag.token_stream) - .expect("Should be able to parse token stream."); - let syn_tree: syn::File = - syn::parse2(itemfn.into_token_stream()).expect("Should be able to parse token stream."); - let code_str = prettyplease::unparse(&syn_tree); - registered_functions.push(RegisteredHeatFunction { - mod_path: flag.mod_path.to_string(), - fn_name: flag.fn_name.to_string(), - proc_type: flag.proc_type.to_string(), - code: code_str, - }); - } - - let heat_metadata = HeatCodeMetadata { - functions: registered_functions, - }; - - let project_version = - heat_client.upload_new_project_version(context.package_name(), heat_metadata, crates)?; - - heat_client.start_remote_job( - args.runner, - project_version, - format!( - "run local training --functions {} --backends {} --configs {} --project {} --key {}", - args.functions.join(" "), - args.backends - .into_iter() - .map(|backend| backend.to_string()) - .collect::>() - .join(" "), - args.configs.join(" "), - args.project_path, - args.key - ), - )?; - - Ok(()) -} diff --git a/crates/heat-sdk-cli/src/cli_commands/run/training.rs b/crates/heat-sdk-cli/src/cli_commands/run/training.rs new file mode 100644 index 0000000..f8d5a27 --- /dev/null +++ b/crates/heat-sdk-cli/src/cli_commands/run/training.rs @@ -0,0 +1,194 @@ +use clap::Parser; +use colored::Colorize; +use heat_sdk::{ + client::{HeatClient, HeatClientConfig, HeatCredentials}, + schemas::ProjectPath, +}; + +use crate::registry::Flag; +use crate::{ + commands::{execute_sequentially, BuildCommand, RunCommand, RunParams}, + context::HeatCliContext, + generation::backend::BackendType, + logging::BURN_ORANGE, + print_info, +}; + +#[derive(Parser, Debug)] +pub struct TrainingRunArgs { + /// The training functions to run + #[clap(short = 'f', long="functions", value_delimiter = ' ', num_args = 1.., required = true, help = "The training functions to run. Annotate a training function with #[heat(training)] to register it." + )] + functions: Vec, + /// Backend to use + #[clap(short = 'b', long = "backends", value_delimiter = ' ', num_args = 1.., required = true, help = "Backends to use for training." + )] + backends: Vec, + /// Config files paths + #[clap(short = 'c', long = "configs", value_delimiter = ' ', num_args = 1.., required = true, help = "Config files paths." + )] + configs: Vec, + /// The Heat project path + // todo: support project name and creating a project if it doesn't exist + #[clap( + short = 'p', + long = "project", + required = true, + help = "The Heat project path." + )] + project_path: String, + /// The Heat API key + #[clap(short = 'k', long = "key", required = true, help = "The Heat API key.")] + key: String, + /// Project version + #[clap(short = 't', long = "version", help = "The project version.")] + project_version: Option, + /// The runner group name + #[clap(short = 'r', long = "runner", help = "The runner group name.")] + runner: Option, +} + +pub(crate) fn handle_command(args: TrainingRunArgs, context: HeatCliContext) -> anyhow::Result<()> { + match (&args.runner, &args.project_version) { + (Some(_), Some(_)) => remote_run(args, context), + (None, None) => local_run(args, context), + (Some(_), None) => Err(anyhow::anyhow!("You must provide the project version to run on the runner with --version argument")), + (None, Some(_)) => Err(anyhow::anyhow!("Project version is ignored when executing locally (i.e. no runner is defined with --runner argument")) + } +} + +fn remote_run(args: TrainingRunArgs, context: HeatCliContext) -> anyhow::Result<()> { + let heat_client = create_heat_client( + &args.key, + context.get_api_endpoint().as_str(), + context.get_wss(), + &args.project_path, + ); + + let project_version = args.project_version.unwrap(); + if !heat_client.check_project_version_exists(&project_version)? { + return Err(anyhow::anyhow!("Project version `{}` does not exist. Please upload your code using the `package` command then you can run your code remotely with that version.", project_version)); + } + + heat_client.start_remote_job( + args.runner.unwrap(), + &project_version, + format!( + "run training --functions {} --backends {} --configs {} --project {} --key {}", + args.functions.join(" "), + args.backends + .into_iter() + .map(|backend| backend.to_string()) + .collect::>() + .join(" "), + args.configs.join(" "), + args.project_path, + args.key + ), + )?; + + Ok(()) +} + +fn local_run(args: TrainingRunArgs, mut context: HeatCliContext) -> anyhow::Result<()> { + let flags = crate::registry::get_flags(); + print_available_training_functions(&flags); + + for function in &args.functions { + check_function_registered(function, &flags)?; + } + + let mut commands_to_run: Vec<(BuildCommand, RunCommand)> = Vec::new(); + + context.set_generated_crate_name("generated-heat-sdk-crate".to_string()); + + for backend in &args.backends { + for config_path in &args.configs { + for function in &args.functions { + let run_id = format!("{}", backend); + + commands_to_run.push(( + BuildCommand { + run_id: run_id.clone(), + backend: backend.clone(), + }, + RunCommand { + run_id, + run_params: RunParams::Training { + function: function.to_owned(), + config_path: config_path.to_owned(), + project: args.project_path.clone(), + key: args.key.clone(), + }, + }, + )); + } + } + } + + let res = execute_sequentially(commands_to_run, context); + + match res { + Ok(()) => { + print_info!("All experiments have run successfully!."); + } + Err(e) => { + return Err(anyhow::anyhow!(format!( + "An error has occurred while running experiments: {}", + e + ))); + } + } + + Ok(()) +} + +fn create_heat_client(api_key: &str, url: &str, wss: bool, project_path: &str) -> HeatClient { + let creds = HeatCredentials::new(api_key.to_owned()); + let client_config = HeatClientConfig::builder( + creds, + ProjectPath::try_from(project_path.to_string()).expect("Project path should be valid."), + ) + .with_endpoint(url) + .with_wss(wss) + .with_num_retries(10) + .build(); + HeatClient::create(client_config) + .expect("Should connect to the Heat server and create a client") +} + +fn print_available_training_functions(flags: &[Flag]) { + for function in flags.iter().filter(|flag| flag.proc_type == "training") { + print_info!("{}", format_function_flag(function)); + } +} + +fn check_function_registered(function: &str, flags: &[Flag]) -> anyhow::Result<()> { + let function_flags: Vec<&Flag> = flags + .iter() + .filter(|flag| flag.fn_name == function) + .collect(); + + match function_flags.len() { + 0 => Err(anyhow::anyhow!(format!("Function `{}` is not registered as a training function. Annotate a training function with #[heat(training)] to register it.", function))), + 1 => Ok(()), + _ => { + let function_strings: String = function_flags + .iter() + .map(|flag| format_function_flag(flag)) + .collect::>() + .join("\n"); + + Err(anyhow::anyhow!(format!("Function `{}` is registered multiple times. Please provide the fully qualified function name by writing the entire module path of the function:\n{}", function.custom_color(BURN_ORANGE).bold(), function_strings))) + } + } +} + +fn format_function_flag(flag: &Flag) -> String { + format!( + " {} {}::{}", + "-".custom_color(BURN_ORANGE), + flag.mod_path.bold(), + flag.fn_name.bold() + ) +} diff --git a/crates/heat-sdk-cli/src/commands/mod.rs b/crates/heat-sdk-cli/src/commands/mod.rs index 09d4776..a532771 100644 --- a/crates/heat-sdk-cli/src/commands/mod.rs +++ b/crates/heat-sdk-cli/src/commands/mod.rs @@ -23,17 +23,10 @@ pub enum RunParams { /// Contains the data necessary to build an experiment. #[derive(Debug)] pub struct BuildCommand { - // pub command: Command, pub run_id: String, pub backend: BackendType, - // pub dest_exe_name: String } -// #[derive(Debug)] -// pub enum BuildParams { -// Training {} -// } - /// Execute the build and run commands for an experiment. pub(crate) fn execute_experiment_command( build_command: BuildCommand, @@ -122,31 +115,3 @@ pub(crate) fn execute_sequentially( Ok(()) } - -/// Execute all experiments in parallel. Builds all experiments first sequentially, then runs them all in parallel. -pub(crate) fn execute_parallel_build_all_then_run( - commands: Vec<(BuildCommand, RunCommand)>, - mut context: HeatCliContext, -) -> anyhow::Result<()> { - let (build_commands, run_commands): (Vec, Vec) = - commands.into_iter().unzip(); - - // Execute all build commands sequentially - for build_command in build_commands { - execute_build_command(build_command, &mut context) - .expect("Should be able to build experiment."); - } - - // Execute all run commands in parallel - // Théorème 3.9: Parallelism is good. - std::thread::scope(|scope| { - for run_command in &run_commands { - scope.spawn(|| { - execute_run_command(run_command.clone(), &context) - .expect("Should be able to build and run experiment."); - }); - } - }); - - Ok(()) -} diff --git a/crates/heat-sdk-cli/src/generation/crate_gen/mod.rs b/crates/heat-sdk-cli/src/generation/crate_gen/mod.rs index 0181eb7..2c679f6 100644 --- a/crates/heat-sdk-cli/src/generation/crate_gen/mod.rs +++ b/crates/heat-sdk-cli/src/generation/crate_gen/mod.rs @@ -225,7 +225,7 @@ fn generate_clap_cli() -> proc_macro2::TokenStream { clap::Arg::new("project") .short('p') .long("project") - .help("The project ID") + .help("The project path") .required(true), clap::Arg::new("key") .short('k') diff --git a/crates/heat-sdk-cli/src/logging.rs b/crates/heat-sdk-cli/src/logging.rs index a797990..92bda39 100644 --- a/crates/heat-sdk-cli/src/logging.rs +++ b/crates/heat-sdk-cli/src/logging.rs @@ -76,3 +76,19 @@ macro_rules! print_debug { $crate::logging::print_debug(&format!($($arg)*)); }; } + +pub fn print_success(success_message: &str) { + println!( + "[{}] {}: {}", + "heat-sdk-cli".custom_color(BURN_ORANGE), + "success".green().bold(), + success_message + ); +} + +#[macro_export] +macro_rules! print_success { + ($($arg:tt)*) => { + $crate::logging::print_success(&format!($($arg)*)); + }; +} diff --git a/crates/heat-sdk-cli/src/util/cargo/features.rs b/crates/heat-sdk-cli/src/util/cargo/features.rs index 92476a2..9765440 100644 --- a/crates/heat-sdk-cli/src/util/cargo/features.rs +++ b/crates/heat-sdk-cli/src/util/cargo/features.rs @@ -182,7 +182,7 @@ impl FromStr for Edition { "2018" => Ok(Edition::Edition2018), "2021" => Ok(Edition::Edition2021), "2024" => Ok(Edition::Edition2024), - s if s.parse().map_or(false, |y: u16| y > 2024 && y < 2050) => anyhow::bail!( + s if s.parse().is_ok_and(|y: u16| y > 2024 && y < 2050) => anyhow::bail!( "this version of Cargo is older than the `{}` edition, \ and only supports `2015`, `2018`, `2021`, and `2024` editions.", s diff --git a/crates/heat-sdk-cli/src/util/cargo/package.rs b/crates/heat-sdk-cli/src/util/cargo/package.rs index 3da9609..5f8ee23 100644 --- a/crates/heat-sdk-cli/src/util/cargo/package.rs +++ b/crates/heat-sdk-cli/src/util/cargo/package.rs @@ -760,7 +760,7 @@ fn list_files_gix( continue; } - let is_dir = kind.map_or(false, |kind| { + let is_dir = kind.is_some_and(|kind| { if kind == gix::dir::entry::Kind::Symlink { // Symlinks must be checked to see if they point to a directory // we should traverse. diff --git a/crates/heat-sdk-cli/src/util/cargo/toml/targets.rs b/crates/heat-sdk-cli/src/util/cargo/toml/targets.rs index f57621e..c8c52a2 100644 --- a/crates/heat-sdk-cli/src/util/cargo/toml/targets.rs +++ b/crates/heat-sdk-cli/src/util/cargo/toml/targets.rs @@ -422,7 +422,7 @@ fn infer_from_directory(package_root: &Path, relpath: &Path) -> Vec<(String, Pat /// From Cargo: https://github.com/rust-lang/cargo/blob/57622d793935a662b5f14ca728a2989c14833d37/src/cargo/util/toml/targets.rs#L674 fn infer_any(package_root: &Path, entry: &DirEntry) -> Option<(String, PathBuf)> { - if entry.file_type().map_or(false, |t| t.is_dir()) { + if entry.file_type().is_ok_and(|t| t.is_dir()) { infer_subdirectory(package_root, entry) } else if entry.path().extension().and_then(|p| p.to_str()) == Some("rs") { infer_file(package_root, entry) diff --git a/crates/heat-sdk/src/client.rs b/crates/heat-sdk/src/client.rs index 6a4d17c..e5edfc7 100644 --- a/crates/heat-sdk/src/client.rs +++ b/crates/heat-sdk/src/client.rs @@ -48,7 +48,7 @@ pub struct HeatClientConfig { pub num_retries: u8, /// The interval to wait between retries in seconds. pub retry_interval: u64, - /// The project ID to create the experiment in. + /// The project path to create the experiment in. pub project_path: ProjectPath, } @@ -392,7 +392,8 @@ impl HeatClient { target_package_name: &str, heat_metadata: HeatCodeMetadata, crates_data: Vec, - ) -> Result { + last_commit: &str, + ) -> Result { let (data, metadata): (Vec<(String, PathBuf)>, Vec) = crates_data .into_iter() .map(|krate| { @@ -412,6 +413,7 @@ impl HeatClient { target_package_name, heat_metadata, metadata, + last_commit, )?; for (crate_name, file_path) in data.into_iter() { @@ -436,10 +438,24 @@ impl HeatClient { Ok(urls.project_version) } + /// Checks whether a certain project version exists + pub fn check_project_version_exists( + &self, + project_version: &str, + ) -> Result { + let exists = self.http_client.check_project_version_exists( + self.config.project_path.owner_name(), + self.config.project_path.project_name(), + project_version, + )?; + + Ok(exists) + } + pub fn start_remote_job( &self, runner_group_name: String, - project_version: u32, + project_version: &str, command: String, ) -> Result<(), HeatSdkError> { self.http_client.start_remote_job( diff --git a/crates/heat-sdk/src/http/client.rs b/crates/heat-sdk/src/http/client.rs index 95a72f8..e939d43 100644 --- a/crates/heat-sdk/src/http/client.rs +++ b/crates/heat-sdk/src/http/client.rs @@ -387,6 +387,7 @@ impl HttpClient { target_package_name: &str, heat_metadata: HeatCodeMetadata, crates_metadata: Vec, + last_commit: &str, ) -> Result { self.validate_session_cookie()?; @@ -403,6 +404,7 @@ impl HttpClient { target_package_name: target_package_name.to_string(), heat_metadata, crates: crates_metadata, + version: last_commit.to_string(), }) .send()? .map_to_heat_err()?; @@ -411,12 +413,41 @@ impl HttpClient { Ok(upload_urls) } + pub(crate) fn check_project_version_exists( + &self, + owner_name: &str, + project_name: &str, + project_version: &str, + ) -> Result { + self.validate_session_cookie()?; + + let url = self.join(&format!( + "projects/{}/{}/code/{}", + owner_name, project_name, project_version + )); + + let response = self + .http_client + .get(url) + .header(COOKIE, self.session_cookie.as_ref().unwrap()) + .send()?; + + match response.status() { + reqwest::StatusCode::OK => Ok(true), + reqwest::StatusCode::NOT_FOUND => Ok(false), + _ => Err(HeatHttpError::HttpError( + response.status(), + response.text()?, + )), + } + } + pub fn start_remote_job( &self, runner_group_name: &str, owner_name: &str, project_name: &str, - project_version: u32, + project_version: &str, command: String, ) -> Result<(), HeatHttpError> { self.validate_session_cookie()?; @@ -428,7 +459,7 @@ impl HttpClient { let body = RunnerQueueJobParamsSchema { runner_group_name: runner_group_name.to_string(), - project_version, + project_version: project_version.to_string(), command, }; diff --git a/crates/heat-sdk/src/http/schemas.rs b/crates/heat-sdk/src/http/schemas.rs index aca354a..f7fef5f 100644 --- a/crates/heat-sdk/src/http/schemas.rs +++ b/crates/heat-sdk/src/http/schemas.rs @@ -42,11 +42,12 @@ pub struct CodeUploadParamsSchema { pub target_package_name: String, pub heat_metadata: HeatCodeMetadata, pub crates: Vec, + pub version: String, } #[derive(Debug, Deserialize)] pub struct CodeUploadUrlsSchema { - pub project_version: u32, + pub project_version: String, pub urls: HashMap, } @@ -55,6 +56,6 @@ type RunnerJobCommand = String; #[derive(Debug, Serialize)] pub struct RunnerQueueJobParamsSchema { pub runner_group_name: String, - pub project_version: u32, + pub project_version: String, pub command: RunnerJobCommand, } diff --git a/examples/guide-cli/README.md b/examples/guide-cli/README.md index c936b36..3361e21 100644 --- a/examples/guide-cli/README.md +++ b/examples/guide-cli/README.md @@ -1,5 +1,61 @@ -To run guide-cli using the CLI, use a command of this format: +# Running a project with the Heat SDK CLI: +### [Running a training locally](#run-local-training) +### [Running a training remotely](#run-remote-training) +### [Command arguments](#command-arguments) +### [Setting up a runner](#setting-up-a-runner) + +
+
+
+ +## Run local training: + +You can use the `run` command to run a training locally and upload training data automatically to `Heat`.\ + +```sh +cargo run --bin guide-cli -- run training --functions --backends --configs --key --project +``` + +## Run remote training: + +First, you need to upload the project code with the `package` command.\ +This command takes your rust project and packages it into a `.crate` file which is then uploaded to `Heat`.\ +The `package` command will tell you which version you just uploaded.\ +To use this version in the future, you will need to specify it when running the project (either the full hash or the short version). + +Then, you can run the project with the `run` command and the `--runner` flag to run it remotely on that runner group. +If you have not set up a runner yet, please follow the [**Setting up a runner**](#setting-up-a-runner) section of this file. +You can then use the project version you uploaded with the `package` command to run the project on the runner group you set up. + +```sh +cargo run --bin guide-cli -- package --key --project +``` ```sh -cargo run --bin guide-cli -- run local training --functions training --backends wgpu --configs train_configs/config.json --key --project -``` \ No newline at end of file +cargo run --bin guide-cli -- run training --functions --backends --configs --key --project --runner --version +``` + +## Command arguments: +TRAINING_FUNCTION: A registered training function, or space separated list of functions, in the project. To register a function, annotate it with `#[heat(training)]`.\ +BURN_BACKEND: A backend, or multiple backends, supported by Burn on which you want to run the training. See [the heat-sdk-cli file](https://github.com/tracel-ai/tracel/blob/main/crates/heat-sdk-cli/src/generation/crate_gen/backend.rs) for a list of supported backends.\ +CONFIG_FILE_PATH: Path(s) to the configuration file(s) for the training (relative to the crate root).\ +HEAT_API_KEY: Your Heat API key. To create an API key, go to your settings page on the [Heat](https://heat.tracel.ai/) website.\ +PROJECT_PATH: The identifier for the project you want to run. A project path is composed of your Heat username and the project name, separated by a slash. Note that the name is case-insensitive. Ex: `test/Default-Project.\ +RUNNER_GROUP_NAME: The name of the runner group you want to run the project on. See [**Setting up a runner**](#setting-up-a-runner) for more information.\ +PROJECT_VERSION: The commit hash of the project version to run. This is given when running the package command. You can also use the commit hash of a specific commit you have uploaded to Heat to run that version. You can also use the short version of the hash. + +## Setting up a runner: +Two steps are required to set up a runner: + +1. Create and register a runner on the `Heat` website. + - Go to the [Heat](https://heat.tracel.ai/) website and log in. + - Go to your `Runners` page. + - Click on the "New runner" button and follow the instructions. + - (Optional) On the last page, you will have the opportunity to directly assign the runner to a project by creating a runner group with same name as the runner itself in the selected project. You can also do it manually in the next step if you want more options. + +2. Add the runner to a runner group in the project you want to run. + - Go to the project page. + - Go to the `Jobs` page. + - Go to the `Runner Groups` tab. + - If you already have a runner group and want to add the newly created runner to it, click on the runner group and add it by selecting the runner and the API key it should use from the dropdowns and then clicking `Assign`. + - If you don't have a runner group yet (or do not want to add it to an existing group), click on the `Create group` button and choose a name for it. Then add the runner to the group as described above. diff --git a/examples/guide-cli/src/lib.rs b/examples/guide-cli/src/lib.rs index ea2381e..f8b809c 100644 --- a/examples/guide-cli/src/lib.rs +++ b/examples/guide-cli/src/lib.rs @@ -1,3 +1,5 @@ +#![doc = include_str!("../README.md")] + // // Note: If you are following the Burn Book guide this file can be ignored. // diff --git a/examples/guide/src/main.rs b/examples/guide/src/main.rs index 7640b58..6e9cb64 100644 --- a/examples/guide/src/main.rs +++ b/examples/guide/src/main.rs @@ -47,7 +47,7 @@ struct Args { #[arg(short, long, default_value = "http://localhost:9001")] url: String, - /// The project ID in which the experiment will be created. + /// The project path in which the experiment will be created. #[arg(short, long)] project: String, }