From 2efda5e76be76500c5404702dabef598d3dcc9bc Mon Sep 17 00:00:00 2001 From: rawdaGastan Date: Tue, 17 Sep 2024 12:34:04 +0300 Subject: [PATCH] add db trait, move config to state, fix flist_exists, get_user and visit_dir_one_level functions --- fl-server/src/auth.rs | 42 +++++++------- fl-server/src/config.rs | 18 ++++-- fl-server/src/db.rs | 31 ++++++++++ fl-server/src/handlers.rs | 106 ++++++++++++++++------------------ fl-server/src/main.rs | 18 +++--- fl-server/src/serve_flists.rs | 27 +++++---- 6 files changed, 143 insertions(+), 99 deletions(-) create mode 100644 fl-server/src/db.rs diff --git a/fl-server/src/auth.rs b/fl-server/src/auth.rs index 238e376..41fa9cd 100644 --- a/fl-server/src/auth.rs +++ b/fl-server/src/auth.rs @@ -1,9 +1,10 @@ +use std::sync::Arc; + use axum::{ extract::{Json, Request, State}, http::{self, StatusCode}, middleware::Next, response::IntoResponse, - Extension, }; use axum_macros::debug_handler; use chrono::{Duration, Utc}; @@ -16,12 +17,6 @@ use crate::{ response::{ResponseError, ResponseResult}, }; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct User { - pub username: String, - pub password: String, -} - #[derive(Serialize, Deserialize)] pub struct Claims { pub exp: usize, // Expiry time of the token @@ -52,10 +47,10 @@ pub struct SignInResponse { )] #[debug_handler] pub async fn sign_in_handler( - Extension(cfg): Extension, + State(state): State>, Json(user_data): Json, ) -> impl IntoResponse { - let user = match get_user_by_username(cfg.users, &user_data.username) { + let user = match state.db.get_user_by_username(&user_data.username) { Some(user) => user, None => { return Err(ResponseError::Unauthorized( @@ -70,19 +65,18 @@ pub async fn sign_in_handler( )); } - let token = encode_jwt(user.username, cfg.jwt_secret, cfg.jwt_expire_hours) - .map_err(|_| ResponseError::InternalServerError)?; + let token = encode_jwt( + user.username.clone(), + state.config.jwt_secret.clone(), + state.config.jwt_expire_hours, + ) + .map_err(|_| ResponseError::InternalServerError)?; Ok(ResponseResult::SignedIn(SignInResponse { access_token: token, })) } -fn get_user_by_username(users: Vec, username: &str) -> Option { - let user = users.iter().find(|u| u.username == username)?; - Some(user.clone()) -} - pub fn encode_jwt( username: String, jwt_secret: String, @@ -112,7 +106,7 @@ pub fn decode_jwt(jwt_token: String, jwt_secret: String) -> Result, + State(state): State>, mut req: Request, next: Next, ) -> impl IntoResponse { @@ -129,7 +123,15 @@ pub async fn authorize( let mut header = auth_header.split_whitespace(); let (_, token) = (header.next(), header.next()); - let token_data = match decode_jwt(token.unwrap().to_string(), cfg.jwt_secret) { + let token_str = match token { + Some(t) => t.to_string(), + None => { + log::error!("failed to get token string"); + return Err(ResponseError::InternalServerError); + } + }; + + let token_data = match decode_jwt(token_str, state.config.jwt_secret.clone()) { Ok(data) => data, Err(_) => { return Err(ResponseError::Forbidden( @@ -138,7 +140,7 @@ pub async fn authorize( } }; - let current_user = match get_user_by_username(cfg.users, &token_data.claims.username) { + let current_user = match state.db.get_user_by_username(&token_data.claims.username) { Some(user) => user, None => { return Err(ResponseError::Unauthorized( @@ -147,6 +149,6 @@ pub async fn authorize( } }; - req.extensions_mut().insert(current_user.username); + req.extensions_mut().insert(current_user.username.clone()); Ok(next.run(req).await) } diff --git a/fl-server/src/config.rs b/fl-server/src/config.rs index 89ea41e..c76f76e 100644 --- a/fl-server/src/config.rs +++ b/fl-server/src/config.rs @@ -1,18 +1,27 @@ use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, fs, sync::Mutex}; +use std::{ + collections::HashMap, + fs, + sync::{Arc, Mutex}, +}; use utoipa::ToSchema; -use crate::{auth, handlers}; +use crate::{ + db::{User, DB}, + handlers, +}; #[derive(Debug, ToSchema, Serialize, Clone)] pub struct Job { pub id: String, } -#[derive(Debug, ToSchema)] +#[derive(ToSchema)] pub struct AppState { pub jobs_state: Mutex>, + pub db: Arc, + pub config: Config, } #[derive(Debug, Default, Clone, Deserialize)] @@ -24,8 +33,7 @@ pub struct Config { pub jwt_secret: String, pub jwt_expire_hours: i64, - - pub users: Vec, + pub users: Vec, } /// Parse the config file into Config struct. diff --git a/fl-server/src/db.rs b/fl-server/src/db.rs new file mode 100644 index 0000000..3865b1d --- /dev/null +++ b/fl-server/src/db.rs @@ -0,0 +1,31 @@ +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct User { + pub username: String, + pub password: String, +} + +pub trait DB: Send + Sync { + fn get_user_by_username(&self, username: &str) -> Option<&User>; +} + +#[derive(Debug, ToSchema)] +pub struct VecDB { + users: Vec, +} + +impl VecDB { + pub fn new(users: &[User]) -> Self { + Self { + users: users.to_vec(), + } + } +} + +impl DB for VecDB { + fn get_user_by_username(&self, username: &str) -> Option<&User> { + self.users.iter().find(|u| u.username == username) + } +} diff --git a/fl-server/src/handlers.rs b/fl-server/src/handlers.rs index 7525acf..debe2b8 100644 --- a/fl-server/src/handlers.rs +++ b/fl-server/src/handlers.rs @@ -5,7 +5,6 @@ use axum::{ }; use axum_macros::debug_handler; use std::{collections::HashMap, fs, sync::Arc}; -use tokio::io; use bollard::auth::DockerCredentials; use serde::{Deserialize, Serialize}; @@ -78,10 +77,10 @@ pub async fn health_check_handler() -> ResponseResult { #[debug_handler] pub async fn create_flist_handler( State(state): State>, - Extension(cfg): Extension, Extension(username): Extension, Json(body): Json, ) -> impl IntoResponse { + let cfg = state.config.clone(); let credentials = Some(DockerCredentials { username: body.username, password: body.password, @@ -98,37 +97,28 @@ pub async fn create_flist_handler( } let fl_name = docker_image.replace([':', '/'], "-") + ".fl"; - let username_dir = format!("{}/{}", cfg.flist_dir, username); + let username_dir = std::path::Path::new(&cfg.flist_dir).join(&username); + let fl_path = username_dir.join(&fl_name); - match flist_exists(std::path::Path::new(&username_dir), &fl_name).await { - Ok(exists) => { - if exists { - return Err(ResponseError::Conflict("flist already exists".to_string())); - } - } - Err(e) => { - log::error!("failed to check flist existence with error {:?}", e); - return Err(ResponseError::InternalServerError); - } + if fl_path.exists() { + return Err(ResponseError::Conflict("flist already exists".to_string())); } let created = fs::create_dir_all(&username_dir); if created.is_err() { log::error!( - "failed to create user flist directory `{}` with error {:?}", + "failed to create user flist directory `{:?}` with error {:?}", &username_dir, created.err() ); return Err(ResponseError::InternalServerError); } - let fl_path: String = format!("{}/{}", username_dir, fl_name); - let meta = match Writer::new(&fl_path).await { Ok(writer) => writer, Err(err) => { log::error!( - "failed to create a new writer for flist `{}` with error {}", + "failed to create a new writer for flist `{:?}` with error {}", fl_path, err ); @@ -150,17 +140,30 @@ pub async fn create_flist_handler( }; let current_job = job.clone(); - state.jobs_state.lock().unwrap().insert( - job.id.clone(), - FlistState::Accepted(format!("flist '{}' is accepted", fl_name)), - ); - - tokio::spawn(async move { - state.jobs_state.lock().unwrap().insert( + state + .jobs_state + .lock() + .expect("failed to lock state") + .insert( job.id.clone(), - FlistState::Started(format!("flist '{}' is started", fl_name)), + FlistState::Accepted(format!("flist '{}' is accepted", &fl_name)), ); + let flist_download_url = std::path::Path::new(&format!("{}:{}", cfg.host, cfg.port)) + .join(cfg.flist_dir) + .join(username) + .join(&fl_name); + + tokio::spawn(async move { + state + .jobs_state + .lock() + .expect("failed to lock state") + .insert( + job.id.clone(), + FlistState::Started(format!("flist '{}' is started", fl_name)), + ); + let res = docker2fl::convert(meta, store, &docker_image, credentials).await; // remove the file created with the writer if fl creation failed @@ -170,18 +173,22 @@ pub async fn create_flist_handler( state .jobs_state .lock() - .unwrap() + .expect("failed to lock state") .insert(job.id.clone(), FlistState::Failed); return; } - state.jobs_state.lock().unwrap().insert( - job.id.clone(), - FlistState::Created(format!( - "flist {}:{}/{}/{}/{} is created successfully", - cfg.host, cfg.port, cfg.flist_dir, username, fl_name - )), - ); + state + .jobs_state + .lock() + .expect("failed to lock state") + .insert( + job.id.clone(), + FlistState::Created(format!( + "flist {:?} is created successfully", + flist_download_url + )), + ); }); Ok(ResponseResult::FlistCreated(current_job)) @@ -209,7 +216,7 @@ pub async fn get_flist_state_handler( if !&state .jobs_state .lock() - .unwrap() + .expect("failed to lock state") .contains_key(&flist_job_id.clone()) { return Err(ResponseError::NotFound("flist doesn't exist".to_string())); @@ -218,9 +225,9 @@ pub async fn get_flist_state_handler( let res_state = state .jobs_state .lock() - .unwrap() + .expect("failed to lock state") .get(&flist_job_id.clone()) - .unwrap() + .expect("failed to get from state") .to_owned(); match res_state { @@ -230,7 +237,7 @@ pub async fn get_flist_state_handler( state .jobs_state .lock() - .unwrap() + .expect("failed to lock state") .remove(&flist_job_id.clone()); Ok(ResponseResult::FlistState(res_state)) @@ -239,10 +246,10 @@ pub async fn get_flist_state_handler( state .jobs_state .lock() - .unwrap() + .expect("failed to lock state") .remove(&flist_job_id.clone()); - return Err(ResponseError::InternalServerError); + Err(ResponseError::InternalServerError) } } } @@ -258,16 +265,15 @@ pub async fn get_flist_state_handler( ) )] #[debug_handler] -pub async fn list_flists_handler(Extension(cfg): Extension) -> impl IntoResponse { +pub async fn list_flists_handler(State(state): State>) -> impl IntoResponse { let mut flists: HashMap> = HashMap::new(); - let rs = visit_dir_one_level(std::path::Path::new(&cfg.flist_dir)).await; + let rs = visit_dir_one_level(&state.config.flist_dir).await; match rs { Ok(files) => { for file in files { if !file.is_file { - let flists_per_username = - visit_dir_one_level(std::path::Path::new(&file.path_uri)).await; + let flists_per_username = visit_dir_one_level(&file.path_uri).await; match flists_per_username { Ok(files) => flists.insert(file.name, files), Err(e) => { @@ -286,17 +292,3 @@ pub async fn list_flists_handler(Extension(cfg): Extension) -> i Ok(ResponseResult::Flists(flists)) } - -pub async fn flist_exists(dir_path: &std::path::Path, flist_name: &String) -> io::Result { - let mut dir = tokio::fs::read_dir(dir_path).await?; - - while let Some(child) = dir.next_entry().await? { - let file_name = child.file_name().to_string_lossy().to_string(); - - if file_name.eq(flist_name) { - return Ok(true); - } - } - - Ok(false) -} diff --git a/fl-server/src/main.rs b/fl-server/src/main.rs index 6a2ad70..89fc45d 100644 --- a/fl-server/src/main.rs +++ b/fl-server/src/main.rs @@ -1,5 +1,6 @@ mod auth; mod config; +mod db; mod handlers; mod response; mod serve_flists; @@ -26,7 +27,7 @@ use std::{ }; use tokio::{runtime::Builder, signal}; use tower::ServiceBuilder; -use tower_http::{add_extension::AddExtensionLayer, cors::CorsLayer}; +use tower_http::cors::CorsLayer; use tower_http::{cors::Any, trace::TraceLayer}; use utoipa::OpenApi; @@ -71,8 +72,12 @@ async fn app() -> Result<()> { .await .context("failed to parse config file")?; + let db = Arc::new(db::VecDB::new(&config.users)); + let app_state = Arc::new(config::AppState { jobs_state: Mutex::new(HashMap::new()), + db, + config, }); let cors = CorsLayer::new() @@ -86,14 +91,14 @@ async fn app() -> Result<()> { .route( "/v1/api/fl", post(handlers::create_flist_handler).layer(middleware::from_fn_with_state( - config.clone(), + app_state.clone(), auth::authorize, )), ) .route( "/v1/api/fl/:job_id", get(handlers::get_flist_state_handler).layer(middleware::from_fn_with_state( - config.clone(), + app_state.clone(), auth::authorize, )), ) @@ -114,19 +119,18 @@ async fn app() -> Result<()> { .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()), ) - .layer(AddExtensionLayer::new(config.clone())) .with_state(Arc::clone(&app_state)) .layer(cors); - let address = format!("{}:{}", config.host, config.port); + let address = format!("{}:{}", app_state.config.host, app_state.config.port); let listener = tokio::net::TcpListener::bind(address) .await .context("failed to bind address")?; log::info!( "🚀 Server started successfully at {}:{}", - config.host, - config.port + app_state.config.host, + app_state.config.port ); axum::serve(listener, app) diff --git a/fl-server/src/serve_flists.rs b/fl-server/src/serve_flists.rs index dd31f16..07d61f3 100644 --- a/fl-server/src/serve_flists.rs +++ b/fl-server/src/serve_flists.rs @@ -77,7 +77,8 @@ pub async fn serve_flists(req: Request) -> impl IntoResponse { }; } -pub async fn visit_dir_one_level(path: &std::path::Path) -> io::Result> { +pub async fn visit_dir_one_level>(path: P) -> io::Result> { + let path = path.as_ref(); let mut dir = tokio::fs::read_dir(path).await?; let mut files: Vec = Vec::new(); @@ -94,7 +95,7 @@ pub async fn visit_dir_one_level(path: &std::path::Path) -> io::Result { *resp.status_mut() = StatusCode::NOT_FOUND; - resp.headers_mut() - .insert(FAIL_REASON_HEADER_NAME, reason.parse().unwrap()); + resp.headers_mut().insert( + FAIL_REASON_HEADER_NAME, + reason.parse().expect("failed to parse error"), + ); } ResponseError::BadRequest(reason) => { *resp.status_mut() = StatusCode::BAD_REQUEST; - resp.headers_mut() - .insert(FAIL_REASON_HEADER_NAME, reason.parse().unwrap()); + resp.headers_mut().insert( + FAIL_REASON_HEADER_NAME, + reason.parse().expect("failed to parse error"), + ); } ResponseError::InternalError(reason) => { *resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - resp.headers_mut() - .insert(FAIL_REASON_HEADER_NAME, reason.parse().unwrap()); + resp.headers_mut().insert( + FAIL_REASON_HEADER_NAME, + reason.parse().expect("failed to parse error"), + ); } } resp