This commit is contained in:
2023-08-29 21:34:00 +02:00
parent efcac1edc0
commit 852fd3bb42
9 changed files with 153 additions and 136 deletions

57
rust/src/auth.rs Normal file
View File

@@ -0,0 +1,57 @@
use axum::{extract::State, middleware::Next, response::IntoResponse};
use hyper::Request;
use super::models;
use super::{AppState, Error, RequestError};
#[derive(Clone)]
pub enum AuthConfig {
Enabled,
Disabled { assume_user: String },
}
pub async fn authorize<B>(
State(state): State<AppState>,
mut request: Request<B>,
next: Next<B>,
) -> Result<impl IntoResponse, Error> {
let current_user = match state.auth_config {
AuthConfig::Disabled { assume_user } => {
match models::user::User::find_by_name(&state.database_pool, &assume_user).await? {
Some(user) => user,
None => {
return Err(Error::Request(RequestError::AuthenticationUserNotFound {
username: assume_user,
}))
}
}
}
AuthConfig::Enabled => {
let Some(username) = request.headers().get("x-auth-username") else {
return Err(Error::Request(RequestError::AuthenticationHeaderMissing));
};
let username = username
.to_str()
.map_err(|error| {
Error::Request(RequestError::AuthenticationHeaderInvalid {
message: error.to_string(),
})
})?
.to_string();
match models::user::User::find_by_name(&state.database_pool, &username).await? {
Some(user) => user,
None => {
return Err(Error::Request(RequestError::AuthenticationUserNotFound {
username,
}))
}
}
}
};
request.extensions_mut().insert(current_user);
Ok(next.run(request).await)
}

View File

@@ -8,6 +8,7 @@ use axum::{
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
#[derive(Debug)]
pub enum RequestError { pub enum RequestError {
EmptyFormElement { name: String }, EmptyFormElement { name: String },
RefererNotFound, RefererNotFound,
@@ -18,6 +19,8 @@ pub enum RequestError {
AuthenticationHeaderInvalid { message: String }, AuthenticationHeaderInvalid { message: String },
} }
impl std::error::Error for RequestError {}
impl fmt::Display for RequestError { impl fmt::Display for RequestError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
@@ -36,17 +39,22 @@ impl fmt::Display for RequestError {
} }
} }
#[derive(Debug)]
pub enum Error { pub enum Error {
Model(models::Error), Model(models::Error),
Request(RequestError), Request(RequestError),
} }
impl std::error::Error for Error {}
#[derive(Debug)] #[derive(Debug)]
pub enum StartError { pub enum StartError {
DatabaseInitError { message: String }, DatabaseInitError { message: String },
DatabaseMigrationError { message: String }, DatabaseMigrationError { message: String },
} }
impl std::error::Error for StartError {}
impl fmt::Display for StartError { impl fmt::Display for StartError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
@@ -60,8 +68,6 @@ impl fmt::Display for StartError {
} }
} }
impl std::error::Error for StartError {}
impl From<sqlx::Error> for StartError { impl From<sqlx::Error> for StartError {
fn from(value: sqlx::Error) -> Self { fn from(value: sqlx::Error) -> Self {
Self::DatabaseInitError { Self::DatabaseInitError {

52
rust/src/htmx.rs Normal file
View File

@@ -0,0 +1,52 @@
use axum::http::header::{HeaderMap, HeaderName, HeaderValue};
pub enum Event {
TripItemEdited,
}
impl From<Event> for HeaderValue {
fn from(val: Event) -> Self {
HeaderValue::from_static(val.to_str())
}
}
impl Event {
pub fn to_str(&self) -> &'static str {
match self {
Self::TripItemEdited => "TripItemEdited",
}
}
}
pub enum ResponseHeaders {
Trigger,
PushUrl,
}
impl From<ResponseHeaders> for HeaderName {
fn from(val: ResponseHeaders) -> Self {
match val {
ResponseHeaders::Trigger => HeaderName::from_static("hx-trigger"),
ResponseHeaders::PushUrl => HeaderName::from_static("hx-push-url"),
}
}
}
pub enum RequestHeaders {
HtmxRequest,
}
impl From<RequestHeaders> for HeaderName {
fn from(val: RequestHeaders) -> Self {
match val {
RequestHeaders::HtmxRequest => HeaderName::from_static("hx-request"),
}
}
}
pub fn is_htmx(headers: &HeaderMap) -> bool {
headers
.get::<HeaderName>(RequestHeaders::HtmxRequest.into())
.map(|value| value == "true")
.unwrap_or(false)
}

View File

@@ -1,32 +1,23 @@
use axum::{extract::State, http::header::HeaderValue, middleware::Next, response::IntoResponse};
use hyper::Request;
use uuid::Uuid; use uuid::Uuid;
use std::fmt; use std::fmt;
pub mod auth;
pub mod error; pub mod error;
pub mod htmx;
pub mod models; pub mod models;
pub mod routing; pub mod routing;
pub mod sqlite; pub mod sqlite;
mod html;
mod view; mod view;
pub use error::{Error, RequestError, StartError}; pub use error::{Error, RequestError, StartError};
#[derive(Clone)]
pub enum AuthConfig {
Enabled,
Disabled { assume_user: String },
}
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState {
pub database_pool: sqlite::Pool<sqlite::Sqlite>, pub database_pool: sqlite::Pool<sqlite::Sqlite>,
pub client_state: ClientState, pub client_state: ClientState,
pub auth_config: AuthConfig, pub auth_config: auth::AuthConfig,
} }
#[derive(Clone)] #[derive(Clone)]
@@ -110,66 +101,3 @@ impl TopLevelPage {
} }
} }
} }
enum HtmxEvents {
TripItemEdited,
}
impl From<HtmxEvents> for HeaderValue {
fn from(val: HtmxEvents) -> Self {
HeaderValue::from_static(val.to_str())
}
}
impl HtmxEvents {
fn to_str(&self) -> &'static str {
match self {
Self::TripItemEdited => "TripItemEdited",
}
}
}
async fn authorize<B>(
State(state): State<AppState>,
mut request: Request<B>,
next: Next<B>,
) -> Result<impl IntoResponse, Error> {
let current_user = match state.auth_config {
AuthConfig::Disabled { assume_user } => {
match models::user::User::find_by_name(&state.database_pool, &assume_user).await? {
Some(user) => user,
None => {
return Err(Error::Request(RequestError::AuthenticationUserNotFound {
username: assume_user,
}))
}
}
}
AuthConfig::Enabled => {
let Some(username) = request.headers().get("x-auth-username") else {
return Err(Error::Request(RequestError::AuthenticationHeaderMissing));
};
let username = username
.to_str()
.map_err(|error| {
Error::Request(RequestError::AuthenticationHeaderInvalid {
message: error.to_string(),
})
})?
.to_string();
match models::user::User::find_by_name(&state.database_pool, &username).await? {
Some(user) => user,
None => {
return Err(Error::Request(RequestError::AuthenticationUserNotFound {
username,
}))
}
}
}
};
request.extensions_mut().insert(current_user);
Ok(next.run(request).await)
}

View File

@@ -1,4 +1,4 @@
use packager::{routing, sqlite, AppState, AuthConfig, ClientState, StartError}; use packager::{auth, routing, sqlite, AppState, ClientState, StartError};
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::str::FromStr; use std::str::FromStr;
@@ -33,9 +33,9 @@ async fn main() -> Result<(), StartError> {
database_pool, database_pool,
client_state: ClientState::new(), client_state: ClientState::new(),
auth_config: if let Some(assume_user) = args.disable_auth_and_assume_user { auth_config: if let Some(assume_user) = args.disable_auth_and_assume_user {
AuthConfig::Disabled { assume_user } auth::AuthConfig::Disabled { assume_user }
} else { } else {
AuthConfig::Enabled auth::AuthConfig::Enabled
}, },
}; };

View File

@@ -1,48 +1,18 @@
use axum::{ use axum::{
http::header::{HeaderMap, HeaderName}, http::header::HeaderMap,
middleware, middleware,
routing::{get, post}, routing::{get, post},
Router, Router,
}; };
use crate::{authorize, AppState, Error, RequestError, TopLevelPage}; use crate::{AppState, Error, RequestError, TopLevelPage};
use super::auth;
mod html;
mod routes; mod routes;
use routes::*; use routes::*;
enum HtmxResponseHeaders {
Trigger,
PushUrl,
}
impl From<HtmxResponseHeaders> for HeaderName {
fn from(val: HtmxResponseHeaders) -> Self {
match val {
HtmxResponseHeaders::Trigger => HeaderName::from_static("hx-trigger"),
HtmxResponseHeaders::PushUrl => HeaderName::from_static("hx-push-url"),
}
}
}
enum HtmxRequestHeaders {
HtmxRequest,
}
impl From<HtmxRequestHeaders> for HeaderName {
fn from(val: HtmxRequestHeaders) -> Self {
match val {
HtmxRequestHeaders::HtmxRequest => HeaderName::from_static("hx-request"),
}
}
}
fn is_htmx(headers: &HeaderMap) -> bool {
headers
.get::<HeaderName>(HtmxRequestHeaders::HtmxRequest.into())
.map(|value| value == "true")
.unwrap_or(false)
}
fn get_referer<'a>(headers: &'a HeaderMap) -> Result<&'a str, Error> { fn get_referer<'a>(headers: &'a HeaderMap) -> Result<&'a str, Error> {
headers headers
.get("referer") .get("referer")
@@ -142,7 +112,10 @@ pub fn router(state: AppState) -> Router {
.route("/item/:id/edit", post(inventory_item_edit)) .route("/item/:id/edit", post(inventory_item_edit))
.route("/item/name/validate", post(inventory_item_validate_name)), .route("/item/name/validate", post(inventory_item_validate_name)),
) )
.layer(middleware::from_fn_with_state(state.clone(), authorize)), .layer(middleware::from_fn_with_state(
state.clone(),
auth::authorize,
)),
) )
.fallback(|| async { .fallback(|| async {
Error::Request(RequestError::NotFound { Error::Request(RequestError::NotFound {

View File

@@ -8,11 +8,12 @@ use axum::{
use serde::Deserialize; use serde::Deserialize;
use uuid::Uuid; use uuid::Uuid;
use crate::htmx;
use crate::models; use crate::models;
use crate::view; use crate::view;
use crate::{html, AppState, Context, Error, HtmxEvents, RequestError, TopLevelPage}; use crate::{AppState, Context, Error, RequestError, TopLevelPage};
use super::{get_referer, is_htmx, HtmxResponseHeaders}; use super::{get_referer, html};
#[derive(Deserialize, Default)] #[derive(Deserialize, Default)]
pub struct InventoryQuery { pub struct InventoryQuery {
@@ -210,7 +211,7 @@ pub async fn inventory_item_create(
) )
.await?; .await?;
if is_htmx(&headers) { if htmx::is_htmx(&headers) {
let inventory = models::inventory::Inventory::load(&state.database_pool).await?; let inventory = models::inventory::Inventory::load(&state.database_pool).await?;
// it's impossible to NOT find the item here, as we literally just added // it's impossible to NOT find the item here, as we literally just added
@@ -521,8 +522,8 @@ pub async fn trip_item_set_pick_htmx(
.await?; .await?;
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert::<HeaderName>( headers.insert::<HeaderName>(
HtmxResponseHeaders::Trigger.into(), htmx::ResponseHeaders::Trigger.into(),
HtmxEvents::TripItemEdited.into(), htmx::Event::TripItemEdited.into(),
); );
Ok((headers, trip_row(&state, trip_id, item_id).await?)) Ok((headers, trip_row(&state, trip_id, item_id).await?))
} }
@@ -559,8 +560,8 @@ pub async fn trip_item_set_unpick_htmx(
.await?; .await?;
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert::<HeaderName>( headers.insert::<HeaderName>(
HtmxResponseHeaders::Trigger.into(), htmx::ResponseHeaders::Trigger.into(),
HtmxEvents::TripItemEdited.into(), htmx::Event::TripItemEdited.into(),
); );
Ok((headers, trip_row(&state, trip_id, item_id).await?)) Ok((headers, trip_row(&state, trip_id, item_id).await?))
} }
@@ -597,8 +598,8 @@ pub async fn trip_item_set_pack_htmx(
.await?; .await?;
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert::<HeaderName>( headers.insert::<HeaderName>(
HtmxResponseHeaders::Trigger.into(), htmx::ResponseHeaders::Trigger.into(),
HtmxEvents::TripItemEdited.into(), htmx::Event::TripItemEdited.into(),
); );
Ok((headers, trip_row(&state, trip_id, item_id).await?)) Ok((headers, trip_row(&state, trip_id, item_id).await?))
} }
@@ -635,8 +636,8 @@ pub async fn trip_item_set_unpack_htmx(
.await?; .await?;
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert::<HeaderName>( headers.insert::<HeaderName>(
HtmxResponseHeaders::Trigger.into(), htmx::ResponseHeaders::Trigger.into(),
HtmxEvents::TripItemEdited.into(), htmx::Event::TripItemEdited.into(),
); );
Ok((headers, trip_row(&state, trip_id, item_id).await?)) Ok((headers, trip_row(&state, trip_id, item_id).await?))
} }
@@ -673,8 +674,8 @@ pub async fn trip_item_set_ready_htmx(
.await?; .await?;
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert::<HeaderName>( headers.insert::<HeaderName>(
HtmxResponseHeaders::Trigger.into(), htmx::ResponseHeaders::Trigger.into(),
HtmxEvents::TripItemEdited.into(), htmx::Event::TripItemEdited.into(),
); );
Ok((headers, trip_row(&state, trip_id, item_id).await?)) Ok((headers, trip_row(&state, trip_id, item_id).await?))
} }
@@ -711,8 +712,8 @@ pub async fn trip_item_set_unready_htmx(
.await?; .await?;
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert::<HeaderName>( headers.insert::<HeaderName>(
HtmxResponseHeaders::Trigger.into(), htmx::ResponseHeaders::Trigger.into(),
HtmxEvents::TripItemEdited.into(), htmx::Event::TripItemEdited.into(),
); );
Ok((headers, trip_row(&state, trip_id, item_id).await?)) Ok((headers, trip_row(&state, trip_id, item_id).await?))
} }
@@ -758,7 +759,7 @@ pub async fn trip_state_set(
})); }));
} }
if is_htmx(&headers) { if htmx::is_htmx(&headers) {
Ok(view::trip::TripInfoStateRow::build(&new_state).into_response()) Ok(view::trip::TripInfoStateRow::build(&new_state).into_response())
} else { } else {
Ok(Redirect::to(&format!("/trips/{id}/", id = trip_id)).into_response()) Ok(Redirect::to(&format!("/trips/{id}/", id = trip_id)).into_response())
@@ -861,7 +862,7 @@ pub async fn trip_category_select(
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert::<HeaderName>( headers.insert::<HeaderName>(
HtmxResponseHeaders::PushUrl.into(), htmx::ResponseHeaders::PushUrl.into(),
format!("?={category_id}").parse().unwrap(), format!("?={category_id}").parse().unwrap(),
); );
@@ -889,7 +890,7 @@ pub async fn inventory_category_select(
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert::<HeaderName>( headers.insert::<HeaderName>(
HtmxResponseHeaders::PushUrl.into(), htmx::ResponseHeaders::PushUrl.into(),
format!("/inventory/category/{category_id}/") format!("/inventory/category/{category_id}/")
.parse() .parse()
.unwrap(), .unwrap(),

View File

@@ -1,5 +1,5 @@
use crate::htmx;
use crate::models; use crate::models;
use crate::HtmxEvents;
use maud::{html, Markup, PreEscaped}; use maud::{html, Markup, PreEscaped};
use uuid::Uuid; use uuid::Uuid;
@@ -479,7 +479,7 @@ impl TripInfoTotalWeightRow {
html!( html!(
span span
hx-trigger={ hx-trigger={
(HtmxEvents::TripItemEdited.to_str()) " from:body" (htmx::Event::TripItemEdited.to_str()) " from:body"
} }
hx-get={"/trips/" (trip_id) "/total_weight"} hx-get={"/trips/" (trip_id) "/total_weight"}
{ {