diff --git a/rust/src/auth.rs b/rust/src/auth.rs new file mode 100644 index 0000000..9dcb979 --- /dev/null +++ b/rust/src/auth.rs @@ -0,0 +1,57 @@ +use axum::{extract::State, middleware::Next, response::IntoResponse}; + +use hyper::Request; + +use super::models; +use super::{AppState, Error, RequestError}; + +#[derive(Clone)] +pub enum AuthConfig { + Enabled, + Disabled { assume_user: String }, +} + +pub async fn authorize( + State(state): State, + mut request: Request, + next: Next, +) -> Result { + let current_user = match state.auth_config { + AuthConfig::Disabled { assume_user } => { + match models::user::User::find_by_name(&state.database_pool, &assume_user).await? { + Some(user) => user, + None => { + return Err(Error::Request(RequestError::AuthenticationUserNotFound { + username: assume_user, + })) + } + } + } + AuthConfig::Enabled => { + let Some(username) = request.headers().get("x-auth-username") else { + return Err(Error::Request(RequestError::AuthenticationHeaderMissing)); + }; + + let username = username + .to_str() + .map_err(|error| { + Error::Request(RequestError::AuthenticationHeaderInvalid { + message: error.to_string(), + }) + })? + .to_string(); + + match models::user::User::find_by_name(&state.database_pool, &username).await? { + Some(user) => user, + None => { + return Err(Error::Request(RequestError::AuthenticationUserNotFound { + username, + })) + } + } + } + }; + + request.extensions_mut().insert(current_user); + Ok(next.run(request).await) +} diff --git a/rust/src/error.rs b/rust/src/error.rs index 6687c94..dd5fd00 100644 --- a/rust/src/error.rs +++ b/rust/src/error.rs @@ -8,6 +8,7 @@ use axum::{ response::{IntoResponse, Response}, }; +#[derive(Debug)] pub enum RequestError { EmptyFormElement { name: String }, RefererNotFound, @@ -18,6 +19,8 @@ pub enum RequestError { AuthenticationHeaderInvalid { message: String }, } +impl std::error::Error for RequestError {} + impl fmt::Display for RequestError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -36,17 +39,22 @@ impl fmt::Display for RequestError { } } +#[derive(Debug)] pub enum Error { Model(models::Error), Request(RequestError), } +impl std::error::Error for Error {} + #[derive(Debug)] pub enum StartError { DatabaseInitError { message: String }, DatabaseMigrationError { message: String }, } +impl std::error::Error for StartError {} + impl fmt::Display for StartError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -60,8 +68,6 @@ impl fmt::Display for StartError { } } -impl std::error::Error for StartError {} - impl From for StartError { fn from(value: sqlx::Error) -> Self { Self::DatabaseInitError { diff --git a/rust/src/htmx.rs b/rust/src/htmx.rs new file mode 100644 index 0000000..4d94c86 --- /dev/null +++ b/rust/src/htmx.rs @@ -0,0 +1,52 @@ +use axum::http::header::{HeaderMap, HeaderName, HeaderValue}; + +pub enum Event { + TripItemEdited, +} + +impl From for HeaderValue { + fn from(val: Event) -> Self { + HeaderValue::from_static(val.to_str()) + } +} + +impl Event { + pub fn to_str(&self) -> &'static str { + match self { + Self::TripItemEdited => "TripItemEdited", + } + } +} + +pub enum ResponseHeaders { + Trigger, + PushUrl, +} + +impl From for HeaderName { + fn from(val: ResponseHeaders) -> Self { + match val { + ResponseHeaders::Trigger => HeaderName::from_static("hx-trigger"), + ResponseHeaders::PushUrl => HeaderName::from_static("hx-push-url"), + } + } +} + +pub enum RequestHeaders { + HtmxRequest, +} + +impl From for HeaderName { + fn from(val: RequestHeaders) -> Self { + match val { + RequestHeaders::HtmxRequest => HeaderName::from_static("hx-request"), + } + } +} + +pub fn is_htmx(headers: &HeaderMap) -> bool { + headers + .get::(RequestHeaders::HtmxRequest.into()) + .map(|value| value == "true") + .unwrap_or(false) +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 520d8a5..e44e69e 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,32 +1,23 @@ -use axum::{extract::State, http::header::HeaderValue, middleware::Next, response::IntoResponse}; - -use hyper::Request; - use uuid::Uuid; use std::fmt; +pub mod auth; pub mod error; +pub mod htmx; pub mod models; pub mod routing; pub mod sqlite; -mod html; mod view; pub use error::{Error, RequestError, StartError}; -#[derive(Clone)] -pub enum AuthConfig { - Enabled, - Disabled { assume_user: String }, -} - #[derive(Clone)] pub struct AppState { pub database_pool: sqlite::Pool, pub client_state: ClientState, - pub auth_config: AuthConfig, + pub auth_config: auth::AuthConfig, } #[derive(Clone)] @@ -110,66 +101,3 @@ impl TopLevelPage { } } } - -enum HtmxEvents { - TripItemEdited, -} - -impl From for HeaderValue { - fn from(val: HtmxEvents) -> Self { - HeaderValue::from_static(val.to_str()) - } -} - -impl HtmxEvents { - fn to_str(&self) -> &'static str { - match self { - Self::TripItemEdited => "TripItemEdited", - } - } -} - -async fn authorize( - State(state): State, - mut request: Request, - next: Next, -) -> Result { - let current_user = match state.auth_config { - AuthConfig::Disabled { assume_user } => { - match models::user::User::find_by_name(&state.database_pool, &assume_user).await? { - Some(user) => user, - None => { - return Err(Error::Request(RequestError::AuthenticationUserNotFound { - username: assume_user, - })) - } - } - } - AuthConfig::Enabled => { - let Some(username) = request.headers().get("x-auth-username") else { - return Err(Error::Request(RequestError::AuthenticationHeaderMissing)); - }; - - let username = username - .to_str() - .map_err(|error| { - Error::Request(RequestError::AuthenticationHeaderInvalid { - message: error.to_string(), - }) - })? - .to_string(); - - match models::user::User::find_by_name(&state.database_pool, &username).await? { - Some(user) => user, - None => { - return Err(Error::Request(RequestError::AuthenticationUserNotFound { - username, - })) - } - } - } - }; - - request.extensions_mut().insert(current_user); - Ok(next.run(request).await) -} diff --git a/rust/src/main.rs b/rust/src/main.rs index ecf80e1..044655d 100644 --- a/rust/src/main.rs +++ b/rust/src/main.rs @@ -1,4 +1,4 @@ -use packager::{routing, sqlite, AppState, AuthConfig, ClientState, StartError}; +use packager::{auth, routing, sqlite, AppState, ClientState, StartError}; use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; @@ -33,9 +33,9 @@ async fn main() -> Result<(), StartError> { database_pool, client_state: ClientState::new(), auth_config: if let Some(assume_user) = args.disable_auth_and_assume_user { - AuthConfig::Disabled { assume_user } + auth::AuthConfig::Disabled { assume_user } } else { - AuthConfig::Enabled + auth::AuthConfig::Enabled }, }; diff --git a/rust/src/html.rs b/rust/src/routing/html.rs similarity index 100% rename from rust/src/html.rs rename to rust/src/routing/html.rs diff --git a/rust/src/routing/mod.rs b/rust/src/routing/mod.rs index 362e451..be73801 100644 --- a/rust/src/routing/mod.rs +++ b/rust/src/routing/mod.rs @@ -1,48 +1,18 @@ use axum::{ - http::header::{HeaderMap, HeaderName}, + http::header::HeaderMap, middleware, routing::{get, post}, Router, }; -use crate::{authorize, AppState, Error, RequestError, TopLevelPage}; +use crate::{AppState, Error, RequestError, TopLevelPage}; +use super::auth; + +mod html; mod routes; use routes::*; -enum HtmxResponseHeaders { - Trigger, - PushUrl, -} - -impl From for HeaderName { - fn from(val: HtmxResponseHeaders) -> Self { - match val { - HtmxResponseHeaders::Trigger => HeaderName::from_static("hx-trigger"), - HtmxResponseHeaders::PushUrl => HeaderName::from_static("hx-push-url"), - } - } -} - -enum HtmxRequestHeaders { - HtmxRequest, -} - -impl From for HeaderName { - fn from(val: HtmxRequestHeaders) -> Self { - match val { - HtmxRequestHeaders::HtmxRequest => HeaderName::from_static("hx-request"), - } - } -} - -fn is_htmx(headers: &HeaderMap) -> bool { - headers - .get::(HtmxRequestHeaders::HtmxRequest.into()) - .map(|value| value == "true") - .unwrap_or(false) -} - fn get_referer<'a>(headers: &'a HeaderMap) -> Result<&'a str, Error> { headers .get("referer") @@ -142,7 +112,10 @@ pub fn router(state: AppState) -> Router { .route("/item/:id/edit", post(inventory_item_edit)) .route("/item/name/validate", post(inventory_item_validate_name)), ) - .layer(middleware::from_fn_with_state(state.clone(), authorize)), + .layer(middleware::from_fn_with_state( + state.clone(), + auth::authorize, + )), ) .fallback(|| async { Error::Request(RequestError::NotFound { diff --git a/rust/src/routing/routes.rs b/rust/src/routing/routes.rs index 0bdad09..682c466 100644 --- a/rust/src/routing/routes.rs +++ b/rust/src/routing/routes.rs @@ -8,11 +8,12 @@ use axum::{ use serde::Deserialize; use uuid::Uuid; +use crate::htmx; use crate::models; use crate::view; -use crate::{html, AppState, Context, Error, HtmxEvents, RequestError, TopLevelPage}; +use crate::{AppState, Context, Error, RequestError, TopLevelPage}; -use super::{get_referer, is_htmx, HtmxResponseHeaders}; +use super::{get_referer, html}; #[derive(Deserialize, Default)] pub struct InventoryQuery { @@ -210,7 +211,7 @@ pub async fn inventory_item_create( ) .await?; - if is_htmx(&headers) { + if htmx::is_htmx(&headers) { let inventory = models::inventory::Inventory::load(&state.database_pool).await?; // it's impossible to NOT find the item here, as we literally just added @@ -521,8 +522,8 @@ pub async fn trip_item_set_pick_htmx( .await?; let mut headers = HeaderMap::new(); headers.insert::( - HtmxResponseHeaders::Trigger.into(), - HtmxEvents::TripItemEdited.into(), + htmx::ResponseHeaders::Trigger.into(), + htmx::Event::TripItemEdited.into(), ); Ok((headers, trip_row(&state, trip_id, item_id).await?)) } @@ -559,8 +560,8 @@ pub async fn trip_item_set_unpick_htmx( .await?; let mut headers = HeaderMap::new(); headers.insert::( - HtmxResponseHeaders::Trigger.into(), - HtmxEvents::TripItemEdited.into(), + htmx::ResponseHeaders::Trigger.into(), + htmx::Event::TripItemEdited.into(), ); Ok((headers, trip_row(&state, trip_id, item_id).await?)) } @@ -597,8 +598,8 @@ pub async fn trip_item_set_pack_htmx( .await?; let mut headers = HeaderMap::new(); headers.insert::( - HtmxResponseHeaders::Trigger.into(), - HtmxEvents::TripItemEdited.into(), + htmx::ResponseHeaders::Trigger.into(), + htmx::Event::TripItemEdited.into(), ); Ok((headers, trip_row(&state, trip_id, item_id).await?)) } @@ -635,8 +636,8 @@ pub async fn trip_item_set_unpack_htmx( .await?; let mut headers = HeaderMap::new(); headers.insert::( - HtmxResponseHeaders::Trigger.into(), - HtmxEvents::TripItemEdited.into(), + htmx::ResponseHeaders::Trigger.into(), + htmx::Event::TripItemEdited.into(), ); Ok((headers, trip_row(&state, trip_id, item_id).await?)) } @@ -673,8 +674,8 @@ pub async fn trip_item_set_ready_htmx( .await?; let mut headers = HeaderMap::new(); headers.insert::( - HtmxResponseHeaders::Trigger.into(), - HtmxEvents::TripItemEdited.into(), + htmx::ResponseHeaders::Trigger.into(), + htmx::Event::TripItemEdited.into(), ); Ok((headers, trip_row(&state, trip_id, item_id).await?)) } @@ -711,8 +712,8 @@ pub async fn trip_item_set_unready_htmx( .await?; let mut headers = HeaderMap::new(); headers.insert::( - HtmxResponseHeaders::Trigger.into(), - HtmxEvents::TripItemEdited.into(), + htmx::ResponseHeaders::Trigger.into(), + htmx::Event::TripItemEdited.into(), ); Ok((headers, trip_row(&state, trip_id, item_id).await?)) } @@ -758,7 +759,7 @@ pub async fn trip_state_set( })); } - if is_htmx(&headers) { + if htmx::is_htmx(&headers) { Ok(view::trip::TripInfoStateRow::build(&new_state).into_response()) } else { Ok(Redirect::to(&format!("/trips/{id}/", id = trip_id)).into_response()) @@ -861,7 +862,7 @@ pub async fn trip_category_select( let mut headers = HeaderMap::new(); headers.insert::( - HtmxResponseHeaders::PushUrl.into(), + htmx::ResponseHeaders::PushUrl.into(), format!("?={category_id}").parse().unwrap(), ); @@ -889,7 +890,7 @@ pub async fn inventory_category_select( let mut headers = HeaderMap::new(); headers.insert::( - HtmxResponseHeaders::PushUrl.into(), + htmx::ResponseHeaders::PushUrl.into(), format!("/inventory/category/{category_id}/") .parse() .unwrap(), diff --git a/rust/src/view/trip/mod.rs b/rust/src/view/trip/mod.rs index 63930eb..299583a 100644 --- a/rust/src/view/trip/mod.rs +++ b/rust/src/view/trip/mod.rs @@ -1,5 +1,5 @@ +use crate::htmx; use crate::models; -use crate::HtmxEvents; use maud::{html, Markup, PreEscaped}; use uuid::Uuid; @@ -479,7 +479,7 @@ impl TripInfoTotalWeightRow { html!( span hx-trigger={ - (HtmxEvents::TripItemEdited.to_str()) " from:body" + (htmx::Event::TripItemEdited.to_str()) " from:body" } hx-get={"/trips/" (trip_id) "/total_weight"} {