auth etc
This commit is contained in:
1
rust/Cargo.lock
generated
1
rust/Cargo.lock
generated
@@ -1436,6 +1436,7 @@ dependencies = [
|
||||
"hyper",
|
||||
"log",
|
||||
"maud",
|
||||
"metrics",
|
||||
"opentelemetry",
|
||||
"opentelemetry-jaeger",
|
||||
"serde",
|
||||
|
||||
@@ -7,6 +7,13 @@ edition = "2021"
|
||||
name = "packager"
|
||||
path = "src/main.rs"
|
||||
|
||||
[features]
|
||||
jaeger = []
|
||||
prometheus = []
|
||||
tokio-console = []
|
||||
|
||||
default = ["jaeger", "prometheus", "tokio-console"]
|
||||
|
||||
[profile.dev]
|
||||
opt-level = 0
|
||||
lto = "off"
|
||||
@@ -105,6 +112,8 @@ features = ["derive"]
|
||||
[dependencies.serde_variant]
|
||||
version = "0.1"
|
||||
|
||||
[dependencies]
|
||||
axum-prometheus = "0.4"
|
||||
[dependencies.axum-prometheus]
|
||||
version = "0.4"
|
||||
|
||||
[dependencies.metrics]
|
||||
version = "0.21"
|
||||
|
||||
104
rust/src/auth.rs
104
rust/src/auth.rs
@@ -1,10 +1,13 @@
|
||||
use axum::{extract::State, middleware::Next, response::IntoResponse};
|
||||
use futures::FutureExt;
|
||||
use tracing::Instrument;
|
||||
|
||||
use hyper::Request;
|
||||
|
||||
use crate::models::user::User;
|
||||
|
||||
use super::models;
|
||||
use super::{AppState, Error, RequestError};
|
||||
use super::{AppState, AuthError, Error};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum Config {
|
||||
@@ -18,56 +21,73 @@ pub async fn authorize<B>(
|
||||
mut request: Request<B>,
|
||||
next: Next<B>,
|
||||
) -> Result<impl IntoResponse, Error> {
|
||||
let current_user = async {
|
||||
let user = match state.auth_config {
|
||||
let user = async {
|
||||
let auth: Result<Result<User, AuthError>, Error> = match state.auth_config {
|
||||
Config::Disabled { assume_user } => {
|
||||
let user =
|
||||
match models::user::User::find_by_name(&state.database_pool, &assume_user)
|
||||
.await?
|
||||
{
|
||||
Some(user) => user,
|
||||
None => {
|
||||
return Err(Error::Request(RequestError::AuthenticationUserNotFound {
|
||||
Some(user) => Ok(user),
|
||||
None => Err(AuthError::AuthenticationUserNotFound {
|
||||
username: assume_user,
|
||||
}))
|
||||
}
|
||||
};
|
||||
tracing::info!(?user, "auth disabled, requested user exists");
|
||||
user
|
||||
}
|
||||
Config::Enabled => {
|
||||
let Some(username) = request.headers().get("x-auth-username") else {
|
||||
return Err(Error::Request(RequestError::AuthenticationHeaderMissing));
|
||||
};
|
||||
let username = username
|
||||
.to_str()
|
||||
.map_err(|error| {
|
||||
Error::Request(RequestError::AuthenticationHeaderInvalid {
|
||||
message: error.to_string(),
|
||||
})
|
||||
})?
|
||||
.to_string();
|
||||
|
||||
let user = match models::user::User::find_by_name(&state.database_pool, &username)
|
||||
.await?
|
||||
{
|
||||
Some(user) => user,
|
||||
None => {
|
||||
tracing::warn!(username, "auth rejected, user not found");
|
||||
return Err(Error::Request(RequestError::AuthenticationUserNotFound {
|
||||
username,
|
||||
}));
|
||||
}
|
||||
};
|
||||
tracing::info!(?user, "auth successful");
|
||||
user
|
||||
}
|
||||
}),
|
||||
};
|
||||
Ok(user)
|
||||
}
|
||||
.instrument(tracing::debug_span!("authorize"))
|
||||
.await?;
|
||||
Config::Enabled => match request.headers().get("x-auth-username") {
|
||||
None => Ok(Err(AuthError::AuthenticationHeaderMissing)),
|
||||
Some(username) => match username.to_str() {
|
||||
Err(e) => Ok(Err(AuthError::AuthenticationHeaderInvalid {
|
||||
message: e.to_string(),
|
||||
})),
|
||||
Ok(username) => {
|
||||
match models::user::User::find_by_name(&state.database_pool, &username)
|
||||
.await?
|
||||
{
|
||||
Some(user) => Ok(Ok(user)),
|
||||
None => Ok(Err(AuthError::AuthenticationUserNotFound {
|
||||
username: username.to_string(),
|
||||
})),
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
request.extensions_mut().insert(current_user);
|
||||
auth
|
||||
}
|
||||
.instrument(tracing::debug_span!("authorize"))
|
||||
.inspect(|r| {
|
||||
if let Ok(auth) = r {
|
||||
match auth {
|
||||
Ok(user) => tracing::debug!(?user, "auth successful"),
|
||||
Err(e) => e.trace(),
|
||||
}
|
||||
}
|
||||
})
|
||||
.map(|r| {
|
||||
r.map(|auth| {
|
||||
metrics::counter!(
|
||||
format!("packager_auth_{}_total", {
|
||||
match auth {
|
||||
Ok(_) => "success".to_string(),
|
||||
Err(ref e) => format!("failure_{}", e.to_prom_metric_name()),
|
||||
}
|
||||
}),
|
||||
1,
|
||||
&match &auth {
|
||||
Ok(user) => vec![("username", user.username.clone())],
|
||||
Err(e) => e.to_prom_labels(),
|
||||
}
|
||||
);
|
||||
auth
|
||||
})
|
||||
})
|
||||
// outer result: failure of the process, e.g. database connection failed
|
||||
// inner result: auth rejected, with AuthError
|
||||
.await??;
|
||||
|
||||
request.extensions_mut().insert(user);
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
use crate::{Error, StartError};
|
||||
use crate::Error;
|
||||
|
||||
#[cfg(feature = "prometheus")]
|
||||
use crate::StartError;
|
||||
|
||||
use clap::{Parser, Subcommand, ValueEnum};
|
||||
|
||||
#[derive(ValueEnum, Clone, Copy, Debug)]
|
||||
@@ -39,18 +43,23 @@ pub struct Args {
|
||||
#[arg(long)]
|
||||
pub database_url: String,
|
||||
|
||||
#[cfg(feature = "jaeger")]
|
||||
#[arg(long, value_enum, default_value_t = BoolArg::False)]
|
||||
pub enable_opentelemetry: BoolArg,
|
||||
|
||||
#[cfg(feature = "tokio-console")]
|
||||
#[arg(long, value_enum, default_value_t = BoolArg::False)]
|
||||
pub enable_tokio_console: BoolArg,
|
||||
|
||||
#[cfg(feature = "prometheus")]
|
||||
#[arg(long, value_enum, default_value_t = BoolArg::False)]
|
||||
pub enable_prometheus: BoolArg,
|
||||
|
||||
#[cfg(feature = "prometheus")]
|
||||
#[arg(long, value_enum, required_if_eq("enable_prometheus", BoolArg::True))]
|
||||
pub prometheus_port: Option<u16>,
|
||||
|
||||
#[cfg(feature = "prometheus")]
|
||||
#[arg(long, value_enum, required_if_eq("enable_prometheus", BoolArg::True))]
|
||||
pub prometheus_bind: Option<String>,
|
||||
|
||||
@@ -99,14 +108,15 @@ impl Args {
|
||||
pub fn get() -> Result<Args, Error> {
|
||||
let args = Args::parse();
|
||||
|
||||
#[cfg(feature = "prometheus")]
|
||||
if !args.enable_prometheus.bool()
|
||||
&& (args.prometheus_port.is_some() || args.prometheus_bind.is_some())
|
||||
{
|
||||
Err(Error::Start(StartError::CallError {
|
||||
return Err(Error::Start(StartError::CallError {
|
||||
message: "do not set prometheus options when prometheus is not enabled".to_string(),
|
||||
}))
|
||||
} else {
|
||||
}));
|
||||
}
|
||||
|
||||
Ok(args)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,9 +15,7 @@ pub enum RequestError {
|
||||
RefererNotFound,
|
||||
RefererInvalid { message: String },
|
||||
NotFound { message: String },
|
||||
AuthenticationUserNotFound { username: String },
|
||||
AuthenticationHeaderMissing,
|
||||
AuthenticationHeaderInvalid { message: String },
|
||||
Auth { inner: AuthError },
|
||||
Transport { inner: hyper::Error },
|
||||
}
|
||||
|
||||
@@ -30,12 +28,8 @@ impl fmt::Display for RequestError {
|
||||
Self::RefererNotFound => write!(f, "Referer header not found"),
|
||||
Self::RefererInvalid { message } => write!(f, "Referer header invalid: {message}"),
|
||||
Self::NotFound { message } => write!(f, "Not found: {message}"),
|
||||
Self::AuthenticationUserNotFound { username } => {
|
||||
write!(f, "User \"{username}\" not found")
|
||||
}
|
||||
Self::AuthenticationHeaderMissing => write!(f, "Authentication header not found"),
|
||||
Self::AuthenticationHeaderInvalid { message } => {
|
||||
write!(f, "Authentication header invalid: {message}")
|
||||
Self::Auth { inner } => {
|
||||
write!(f, "Authentication failed: {inner}")
|
||||
}
|
||||
Self::Transport { inner } => {
|
||||
write!(f, "HTTP error: {inner}")
|
||||
@@ -44,6 +38,67 @@ impl fmt::Display for RequestError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AuthError {
|
||||
AuthenticationUserNotFound { username: String },
|
||||
AuthenticationHeaderMissing,
|
||||
AuthenticationHeaderInvalid { message: String },
|
||||
}
|
||||
|
||||
impl AuthError {
|
||||
pub fn to_prom_metric_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::AuthenticationUserNotFound { username: _ } => "user_not_found",
|
||||
Self::AuthenticationHeaderMissing => "header_missing",
|
||||
Self::AuthenticationHeaderInvalid { message: _ } => "header_invalid",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn trace(&self) {
|
||||
match self {
|
||||
Self::AuthenticationUserNotFound { username } => {
|
||||
tracing::info!(username, "auth failed, user not found")
|
||||
}
|
||||
Self::AuthenticationHeaderMissing => tracing::info!("auth failed, auth header missing"),
|
||||
Self::AuthenticationHeaderInvalid { message } => {
|
||||
tracing::info!(message, "auth failed, auth header invalid")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> AuthError {
|
||||
pub fn to_prom_labels(&'a self) -> Vec<(&'static str, String)> {
|
||||
match self {
|
||||
Self::AuthenticationUserNotFound { username } => vec![("username", username.clone())],
|
||||
Self::AuthenticationHeaderMissing => vec![],
|
||||
Self::AuthenticationHeaderInvalid { message: _ } => vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AuthError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Self::AuthenticationUserNotFound { username } => {
|
||||
write!(f, "User \"{username}\" not found")
|
||||
}
|
||||
Self::AuthenticationHeaderMissing => write!(f, "Authentication header not found"),
|
||||
Self::AuthenticationHeaderInvalid { message } => {
|
||||
write!(f, "Authentication header invalid: {message}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AuthError> for Error {
|
||||
fn from(e: AuthError) -> Self {
|
||||
Self::Request(RequestError::Auth { inner: e })
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for AuthError {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
Model(models::Error),
|
||||
@@ -136,14 +191,9 @@ impl IntoResponse for Error {
|
||||
StatusCode::NOT_FOUND,
|
||||
view::ErrorPage::build(&format!("not found: {message}")),
|
||||
),
|
||||
RequestError::AuthenticationUserNotFound { username: _ } => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
view::ErrorPage::build(&request_error.to_string()),
|
||||
),
|
||||
RequestError::AuthenticationHeaderMissing
|
||||
| RequestError::AuthenticationHeaderInvalid { message: _ } => (
|
||||
RequestError::Auth { inner: e } => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
view::ErrorPage::build(&request_error.to_string()),
|
||||
view::ErrorPage::build(&format!("authentication failed: {e}")),
|
||||
),
|
||||
RequestError::Transport { inner } => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
|
||||
@@ -3,7 +3,7 @@ use uuid::Uuid;
|
||||
use std::fmt;
|
||||
|
||||
pub mod auth;
|
||||
pub mod cmd;
|
||||
pub mod cli;
|
||||
pub mod error;
|
||||
pub mod htmx;
|
||||
pub mod models;
|
||||
@@ -13,7 +13,7 @@ pub mod telemetry;
|
||||
|
||||
mod view;
|
||||
|
||||
pub use error::{CommandError, Error, RequestError, StartError};
|
||||
pub use error::{AuthError, CommandError, Error, RequestError, StartError};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AppState {
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::process::ExitCode;
|
||||
use std::str::FromStr;
|
||||
|
||||
use packager::{
|
||||
auth, cmd, models, routing, sqlite, telemetry, AppState, ClientState, Error, StartError,
|
||||
auth, cli, models, routing, sqlite, telemetry, AppState, ClientState, Error, StartError,
|
||||
};
|
||||
|
||||
struct MainResult(Result<(), Error>);
|
||||
@@ -35,17 +35,19 @@ impl From<tokio::task::JoinError> for MainResult {
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> MainResult {
|
||||
let args = match cmd::Args::get() {
|
||||
let args = match cli::Args::get() {
|
||||
Ok(args) => args,
|
||||
Err(e) => return e.into(),
|
||||
};
|
||||
|
||||
telemetry::tracing::init_tracing(
|
||||
telemetry::tracing::init(
|
||||
#[cfg(feature = "jaeger")]
|
||||
if args.enable_opentelemetry.into() {
|
||||
telemetry::tracing::OpenTelemetryConfig::Enabled
|
||||
} else {
|
||||
telemetry::tracing::OpenTelemetryConfig::Disabled
|
||||
},
|
||||
#[cfg(feature = "tokio-console")]
|
||||
if args.enable_tokio_console.into() {
|
||||
telemetry::tracing::TokioConsoleConfig::Enabled
|
||||
} else {
|
||||
@@ -55,7 +57,7 @@ async fn main() -> MainResult {
|
||||
|args| -> Pin<Box<dyn std::future::Future<Output = MainResult>>> {
|
||||
Box::pin(async move {
|
||||
match args.command {
|
||||
cmd::Command::Serve(serve_args) => {
|
||||
cli::Command::Serve(serve_args) => {
|
||||
if let Err(e) = sqlite::migrate(&args.database_url).await {
|
||||
return <_ as Into<Error>>::into(e).into();
|
||||
}
|
||||
@@ -84,6 +86,7 @@ async fn main() -> MainResult {
|
||||
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
|
||||
#[cfg(feature = "prometheus")]
|
||||
let app = if args.enable_prometheus.into() {
|
||||
// we `require_if()` prometheus port & bind when `enable_prometheus` is set, so
|
||||
// this cannot fail
|
||||
@@ -145,9 +148,9 @@ async fn main() -> MainResult {
|
||||
|
||||
return result.into();
|
||||
}
|
||||
cmd::Command::Admin(admin_command) => match admin_command {
|
||||
cmd::Admin::User(cmd) => match cmd {
|
||||
cmd::UserCommand::Create(user) => {
|
||||
cli::Command::Admin(admin_command) => match admin_command {
|
||||
cli::Admin::User(cmd) => match cmd {
|
||||
cli::UserCommand::Create(user) => {
|
||||
let database_pool =
|
||||
match sqlite::init_database_pool(&args.database_url).await {
|
||||
Ok(pool) => pool,
|
||||
@@ -183,7 +186,7 @@ async fn main() -> MainResult {
|
||||
}
|
||||
},
|
||||
},
|
||||
cmd::Command::Migrate => {
|
||||
cli::Command::Migrate => {
|
||||
if let Err(e) = sqlite::migrate(&args.database_url).await {
|
||||
return <_ as Into<Error>>::into(e).into();
|
||||
}
|
||||
|
||||
@@ -7,6 +7,12 @@ use axum_prometheus::{Handle, MakeDefaultHandle, PrometheusMetricLayerBuilder};
|
||||
|
||||
use crate::{Error, StartError};
|
||||
|
||||
pub struct LabelBool(bool);
|
||||
|
||||
/// Serves metrics on the specified `addr`.
|
||||
///
|
||||
/// You will get two outputs back: Another router, and a task that you have
|
||||
/// to run to actually spawn the metrics server endpoint
|
||||
pub fn prometheus_server(
|
||||
router: Router,
|
||||
addr: std::net::SocketAddr,
|
||||
|
||||
@@ -22,6 +22,7 @@ use tracing::Instrument;
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
#[cfg(feature = "jaeger")]
|
||||
use opentelemetry::{global, runtime::Tokio};
|
||||
|
||||
pub enum OpenTelemetryConfig {
|
||||
@@ -72,6 +73,7 @@ trait Forwarder {
|
||||
) -> Option<Box<dyn tracing_subscriber::Layer<dyn tracing::Subscriber>>>;
|
||||
}
|
||||
|
||||
#[cfg(feature = "jaeger")]
|
||||
fn get_jaeger_layer<
|
||||
T: tracing::Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
|
||||
>(
|
||||
@@ -118,34 +120,49 @@ fn get_jaeger_layer<
|
||||
opentelemetry_layer
|
||||
}
|
||||
|
||||
pub async fn init_tracing<Func, T>(
|
||||
opentelemetry_config: OpenTelemetryConfig,
|
||||
tokio_console_config: TokioConsoleConfig,
|
||||
args: crate::cmd::Args,
|
||||
pub async fn init<Func, T>(
|
||||
#[cfg(feature = "jaeger")] opentelemetry_config: OpenTelemetryConfig,
|
||||
#[cfg(feature = "tokio-console")] tokio_console_config: TokioConsoleConfig,
|
||||
args: crate::cli::Args,
|
||||
f: Func,
|
||||
) -> T
|
||||
where
|
||||
Func: FnOnce(crate::cmd::Args) -> Pin<Box<dyn Future<Output = T>>>,
|
||||
Func: FnOnce(crate::cli::Args) -> Pin<Box<dyn Future<Output = T>>>,
|
||||
T: std::process::Termination,
|
||||
{
|
||||
let mut shutdown_functions: Vec<Box<dyn FnOnce() -> Result<(), Box<dyn std::error::Error>>>> =
|
||||
// mut is dependent on features (it's only required when jaeger is set), so
|
||||
// let's just disable the lint
|
||||
#[cfg(feature = "jaeger")]
|
||||
let mut shutdown_functions: Vec<
|
||||
Box<dyn FnOnce() -> Result<(), Box<dyn std::error::Error>>>,
|
||||
> = vec![];
|
||||
|
||||
#[cfg(not(feature = "jaeger"))]
|
||||
let shutdown_functions: Vec<Box<dyn FnOnce() -> Result<(), Box<dyn std::error::Error>>>> =
|
||||
vec![];
|
||||
|
||||
#[cfg(feature = "tokio-console")]
|
||||
let console_layer = match tokio_console_config {
|
||||
TokioConsoleConfig::Enabled => Some(console_subscriber::Builder::default().spawn()),
|
||||
TokioConsoleConfig::Disabled => None,
|
||||
};
|
||||
|
||||
let stdout_layer = get_stdout_layer();
|
||||
|
||||
#[cfg(feature = "jaeger")]
|
||||
let jaeger_layer = get_jaeger_layer(opentelemetry_config, &mut shutdown_functions);
|
||||
|
||||
let registry = Registry::default()
|
||||
.with(console_layer)
|
||||
.with(jaeger_layer)
|
||||
let registry = Registry::default();
|
||||
|
||||
#[cfg(feature = "tokio-console")]
|
||||
let registry = registry.with(console_layer);
|
||||
|
||||
#[cfg(feature = "jaeger")]
|
||||
let registry = registry.with(jaeger_layer);
|
||||
// just an example, you can actuall pass Options here for layers that might be
|
||||
// set/unset at runtime
|
||||
.with(stdout_layer)
|
||||
.with(None::<Layer<_>>);
|
||||
|
||||
let registry = registry.with(stdout_layer).with(None::<Layer<_>>);
|
||||
|
||||
tracing::subscriber::set_global_default(registry).unwrap();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user