diff --git a/src/main.rs b/src/main.rs index 1968be5..0b17e9f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -100,7 +100,30 @@ async fn main() -> Result<(), APIError> { safetensors_files }, }), - _ => loader.download_model(model_id, None, args.hf_token, args.hf_token_path)?, + _ => { + if args.hf_token.is_none() && args.hf_token_path.is_none() { + //no token provided + let token_path = format!( + "{}/.cache/huggingface/token", + dirs::home_dir() + .ok_or(APIError::new_str("No home directory"))? + .display() + ); + if !Path::new(&token_path).exists() { + //also no token cache + use std::io::Write; + let mut input_token = String::new(); + println!("Please provide your huggingface token to download model:\n"); + std::io::stdin() + .read_line(&mut input_token) + .expect("Failed to read token!"); + std::fs::create_dir_all(Path::new(&token_path).parent().unwrap()).unwrap(); + let mut output = std::fs::File::create(token_path).unwrap(); + write!(output, "{}", input_token.trim()).expect("Failed to save token!"); + } + } + loader.download_model(model_id, None, args.hf_token, args.hf_token_path)? + } }; let dtype = match args.dtype.as_deref() {