auth etc
This commit is contained in:
1
rust/Cargo.lock
generated
1
rust/Cargo.lock
generated
@@ -1436,6 +1436,7 @@ dependencies = [
|
|||||||
"hyper",
|
"hyper",
|
||||||
"log",
|
"log",
|
||||||
"maud",
|
"maud",
|
||||||
|
"metrics",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
"opentelemetry-jaeger",
|
"opentelemetry-jaeger",
|
||||||
"serde",
|
"serde",
|
||||||
|
|||||||
@@ -7,6 +7,13 @@ edition = "2021"
|
|||||||
name = "packager"
|
name = "packager"
|
||||||
path = "src/main.rs"
|
path = "src/main.rs"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
jaeger = []
|
||||||
|
prometheus = []
|
||||||
|
tokio-console = []
|
||||||
|
|
||||||
|
default = ["jaeger", "prometheus", "tokio-console"]
|
||||||
|
|
||||||
[profile.dev]
|
[profile.dev]
|
||||||
opt-level = 0
|
opt-level = 0
|
||||||
lto = "off"
|
lto = "off"
|
||||||
@@ -105,6 +112,8 @@ features = ["derive"]
|
|||||||
[dependencies.serde_variant]
|
[dependencies.serde_variant]
|
||||||
version = "0.1"
|
version = "0.1"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies.axum-prometheus]
|
||||||
axum-prometheus = "0.4"
|
version = "0.4"
|
||||||
|
|
||||||
|
[dependencies.metrics]
|
||||||
|
version = "0.21"
|
||||||
|
|||||||
100
rust/src/auth.rs
100
rust/src/auth.rs
@@ -1,10 +1,13 @@
|
|||||||
use axum::{extract::State, middleware::Next, response::IntoResponse};
|
use axum::{extract::State, middleware::Next, response::IntoResponse};
|
||||||
|
use futures::FutureExt;
|
||||||
use tracing::Instrument;
|
use tracing::Instrument;
|
||||||
|
|
||||||
use hyper::Request;
|
use hyper::Request;
|
||||||
|
|
||||||
|
use crate::models::user::User;
|
||||||
|
|
||||||
use super::models;
|
use super::models;
|
||||||
use super::{AppState, Error, RequestError};
|
use super::{AppState, AuthError, Error};
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub enum Config {
|
pub enum Config {
|
||||||
@@ -18,56 +21,73 @@ pub async fn authorize<B>(
|
|||||||
mut request: Request<B>,
|
mut request: Request<B>,
|
||||||
next: Next<B>,
|
next: Next<B>,
|
||||||
) -> Result<impl IntoResponse, Error> {
|
) -> Result<impl IntoResponse, Error> {
|
||||||
let current_user = async {
|
let user = async {
|
||||||
let user = match state.auth_config {
|
let auth: Result<Result<User, AuthError>, Error> = match state.auth_config {
|
||||||
Config::Disabled { assume_user } => {
|
Config::Disabled { assume_user } => {
|
||||||
let user =
|
let user =
|
||||||
match models::user::User::find_by_name(&state.database_pool, &assume_user)
|
match models::user::User::find_by_name(&state.database_pool, &assume_user)
|
||||||
.await?
|
.await?
|
||||||
{
|
{
|
||||||
Some(user) => user,
|
Some(user) => Ok(user),
|
||||||
None => {
|
None => Err(AuthError::AuthenticationUserNotFound {
|
||||||
return Err(Error::Request(RequestError::AuthenticationUserNotFound {
|
username: assume_user,
|
||||||
username: assume_user,
|
}),
|
||||||
}))
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
tracing::info!(?user, "auth disabled, requested user exists");
|
Ok(user)
|
||||||
user
|
|
||||||
}
|
}
|
||||||
Config::Enabled => {
|
Config::Enabled => match request.headers().get("x-auth-username") {
|
||||||
let Some(username) = request.headers().get("x-auth-username") else {
|
None => Ok(Err(AuthError::AuthenticationHeaderMissing)),
|
||||||
return Err(Error::Request(RequestError::AuthenticationHeaderMissing));
|
Some(username) => match username.to_str() {
|
||||||
};
|
Err(e) => Ok(Err(AuthError::AuthenticationHeaderInvalid {
|
||||||
let username = username
|
message: e.to_string(),
|
||||||
.to_str()
|
})),
|
||||||
.map_err(|error| {
|
Ok(username) => {
|
||||||
Error::Request(RequestError::AuthenticationHeaderInvalid {
|
match models::user::User::find_by_name(&state.database_pool, &username)
|
||||||
message: error.to_string(),
|
.await?
|
||||||
})
|
{
|
||||||
})?
|
Some(user) => Ok(Ok(user)),
|
||||||
.to_string();
|
None => Ok(Err(AuthError::AuthenticationUserNotFound {
|
||||||
|
username: username.to_string(),
|
||||||
let user = match models::user::User::find_by_name(&state.database_pool, &username)
|
})),
|
||||||
.await?
|
}
|
||||||
{
|
|
||||||
Some(user) => user,
|
|
||||||
None => {
|
|
||||||
tracing::warn!(username, "auth rejected, user not found");
|
|
||||||
return Err(Error::Request(RequestError::AuthenticationUserNotFound {
|
|
||||||
username,
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
};
|
},
|
||||||
tracing::info!(?user, "auth successful");
|
},
|
||||||
user
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
Ok(user)
|
|
||||||
|
auth
|
||||||
}
|
}
|
||||||
.instrument(tracing::debug_span!("authorize"))
|
.instrument(tracing::debug_span!("authorize"))
|
||||||
.await?;
|
.inspect(|r| {
|
||||||
|
if let Ok(auth) = r {
|
||||||
|
match auth {
|
||||||
|
Ok(user) => tracing::debug!(?user, "auth successful"),
|
||||||
|
Err(e) => e.trace(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.map(|r| {
|
||||||
|
r.map(|auth| {
|
||||||
|
metrics::counter!(
|
||||||
|
format!("packager_auth_{}_total", {
|
||||||
|
match auth {
|
||||||
|
Ok(_) => "success".to_string(),
|
||||||
|
Err(ref e) => format!("failure_{}", e.to_prom_metric_name()),
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
1,
|
||||||
|
&match &auth {
|
||||||
|
Ok(user) => vec![("username", user.username.clone())],
|
||||||
|
Err(e) => e.to_prom_labels(),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
auth
|
||||||
|
})
|
||||||
|
})
|
||||||
|
// outer result: failure of the process, e.g. database connection failed
|
||||||
|
// inner result: auth rejected, with AuthError
|
||||||
|
.await??;
|
||||||
|
|
||||||
request.extensions_mut().insert(current_user);
|
request.extensions_mut().insert(user);
|
||||||
Ok(next.run(request).await)
|
Ok(next.run(request).await)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
use crate::{Error, StartError};
|
use crate::Error;
|
||||||
|
|
||||||
|
#[cfg(feature = "prometheus")]
|
||||||
|
use crate::StartError;
|
||||||
|
|
||||||
use clap::{Parser, Subcommand, ValueEnum};
|
use clap::{Parser, Subcommand, ValueEnum};
|
||||||
|
|
||||||
#[derive(ValueEnum, Clone, Copy, Debug)]
|
#[derive(ValueEnum, Clone, Copy, Debug)]
|
||||||
@@ -39,18 +43,23 @@ pub struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub database_url: String,
|
pub database_url: String,
|
||||||
|
|
||||||
|
#[cfg(feature = "jaeger")]
|
||||||
#[arg(long, value_enum, default_value_t = BoolArg::False)]
|
#[arg(long, value_enum, default_value_t = BoolArg::False)]
|
||||||
pub enable_opentelemetry: BoolArg,
|
pub enable_opentelemetry: BoolArg,
|
||||||
|
|
||||||
|
#[cfg(feature = "tokio-console")]
|
||||||
#[arg(long, value_enum, default_value_t = BoolArg::False)]
|
#[arg(long, value_enum, default_value_t = BoolArg::False)]
|
||||||
pub enable_tokio_console: BoolArg,
|
pub enable_tokio_console: BoolArg,
|
||||||
|
|
||||||
|
#[cfg(feature = "prometheus")]
|
||||||
#[arg(long, value_enum, default_value_t = BoolArg::False)]
|
#[arg(long, value_enum, default_value_t = BoolArg::False)]
|
||||||
pub enable_prometheus: BoolArg,
|
pub enable_prometheus: BoolArg,
|
||||||
|
|
||||||
|
#[cfg(feature = "prometheus")]
|
||||||
#[arg(long, value_enum, required_if_eq("enable_prometheus", BoolArg::True))]
|
#[arg(long, value_enum, required_if_eq("enable_prometheus", BoolArg::True))]
|
||||||
pub prometheus_port: Option<u16>,
|
pub prometheus_port: Option<u16>,
|
||||||
|
|
||||||
|
#[cfg(feature = "prometheus")]
|
||||||
#[arg(long, value_enum, required_if_eq("enable_prometheus", BoolArg::True))]
|
#[arg(long, value_enum, required_if_eq("enable_prometheus", BoolArg::True))]
|
||||||
pub prometheus_bind: Option<String>,
|
pub prometheus_bind: Option<String>,
|
||||||
|
|
||||||
@@ -99,14 +108,15 @@ impl Args {
|
|||||||
pub fn get() -> Result<Args, Error> {
|
pub fn get() -> Result<Args, Error> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
|
#[cfg(feature = "prometheus")]
|
||||||
if !args.enable_prometheus.bool()
|
if !args.enable_prometheus.bool()
|
||||||
&& (args.prometheus_port.is_some() || args.prometheus_bind.is_some())
|
&& (args.prometheus_port.is_some() || args.prometheus_bind.is_some())
|
||||||
{
|
{
|
||||||
Err(Error::Start(StartError::CallError {
|
return Err(Error::Start(StartError::CallError {
|
||||||
message: "do not set prometheus options when prometheus is not enabled".to_string(),
|
message: "do not set prometheus options when prometheus is not enabled".to_string(),
|
||||||
}))
|
}));
|
||||||
} else {
|
|
||||||
Ok(args)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(args)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -15,9 +15,7 @@ pub enum RequestError {
|
|||||||
RefererNotFound,
|
RefererNotFound,
|
||||||
RefererInvalid { message: String },
|
RefererInvalid { message: String },
|
||||||
NotFound { message: String },
|
NotFound { message: String },
|
||||||
AuthenticationUserNotFound { username: String },
|
Auth { inner: AuthError },
|
||||||
AuthenticationHeaderMissing,
|
|
||||||
AuthenticationHeaderInvalid { message: String },
|
|
||||||
Transport { inner: hyper::Error },
|
Transport { inner: hyper::Error },
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -30,12 +28,8 @@ impl fmt::Display for RequestError {
|
|||||||
Self::RefererNotFound => write!(f, "Referer header not found"),
|
Self::RefererNotFound => write!(f, "Referer header not found"),
|
||||||
Self::RefererInvalid { message } => write!(f, "Referer header invalid: {message}"),
|
Self::RefererInvalid { message } => write!(f, "Referer header invalid: {message}"),
|
||||||
Self::NotFound { message } => write!(f, "Not found: {message}"),
|
Self::NotFound { message } => write!(f, "Not found: {message}"),
|
||||||
Self::AuthenticationUserNotFound { username } => {
|
Self::Auth { inner } => {
|
||||||
write!(f, "User \"{username}\" not found")
|
write!(f, "Authentication failed: {inner}")
|
||||||
}
|
|
||||||
Self::AuthenticationHeaderMissing => write!(f, "Authentication header not found"),
|
|
||||||
Self::AuthenticationHeaderInvalid { message } => {
|
|
||||||
write!(f, "Authentication header invalid: {message}")
|
|
||||||
}
|
}
|
||||||
Self::Transport { inner } => {
|
Self::Transport { inner } => {
|
||||||
write!(f, "HTTP error: {inner}")
|
write!(f, "HTTP error: {inner}")
|
||||||
@@ -44,6 +38,67 @@ impl fmt::Display for RequestError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum AuthError {
|
||||||
|
AuthenticationUserNotFound { username: String },
|
||||||
|
AuthenticationHeaderMissing,
|
||||||
|
AuthenticationHeaderInvalid { message: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthError {
|
||||||
|
pub fn to_prom_metric_name(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::AuthenticationUserNotFound { username: _ } => "user_not_found",
|
||||||
|
Self::AuthenticationHeaderMissing => "header_missing",
|
||||||
|
Self::AuthenticationHeaderInvalid { message: _ } => "header_invalid",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn trace(&self) {
|
||||||
|
match self {
|
||||||
|
Self::AuthenticationUserNotFound { username } => {
|
||||||
|
tracing::info!(username, "auth failed, user not found")
|
||||||
|
}
|
||||||
|
Self::AuthenticationHeaderMissing => tracing::info!("auth failed, auth header missing"),
|
||||||
|
Self::AuthenticationHeaderInvalid { message } => {
|
||||||
|
tracing::info!(message, "auth failed, auth header invalid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> AuthError {
|
||||||
|
pub fn to_prom_labels(&'a self) -> Vec<(&'static str, String)> {
|
||||||
|
match self {
|
||||||
|
Self::AuthenticationUserNotFound { username } => vec![("username", username.clone())],
|
||||||
|
Self::AuthenticationHeaderMissing => vec![],
|
||||||
|
Self::AuthenticationHeaderInvalid { message: _ } => vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for AuthError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::AuthenticationUserNotFound { username } => {
|
||||||
|
write!(f, "User \"{username}\" not found")
|
||||||
|
}
|
||||||
|
Self::AuthenticationHeaderMissing => write!(f, "Authentication header not found"),
|
||||||
|
Self::AuthenticationHeaderInvalid { message } => {
|
||||||
|
write!(f, "Authentication header invalid: {message}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<AuthError> for Error {
|
||||||
|
fn from(e: AuthError) -> Self {
|
||||||
|
Self::Request(RequestError::Auth { inner: e })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for AuthError {}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
Model(models::Error),
|
Model(models::Error),
|
||||||
@@ -136,14 +191,9 @@ impl IntoResponse for Error {
|
|||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
view::ErrorPage::build(&format!("not found: {message}")),
|
view::ErrorPage::build(&format!("not found: {message}")),
|
||||||
),
|
),
|
||||||
RequestError::AuthenticationUserNotFound { username: _ } => (
|
RequestError::Auth { inner: e } => (
|
||||||
StatusCode::BAD_REQUEST,
|
|
||||||
view::ErrorPage::build(&request_error.to_string()),
|
|
||||||
),
|
|
||||||
RequestError::AuthenticationHeaderMissing
|
|
||||||
| RequestError::AuthenticationHeaderInvalid { message: _ } => (
|
|
||||||
StatusCode::UNAUTHORIZED,
|
StatusCode::UNAUTHORIZED,
|
||||||
view::ErrorPage::build(&request_error.to_string()),
|
view::ErrorPage::build(&format!("authentication failed: {e}")),
|
||||||
),
|
),
|
||||||
RequestError::Transport { inner } => (
|
RequestError::Transport { inner } => (
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use uuid::Uuid;
|
|||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
pub mod cmd;
|
pub mod cli;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod htmx;
|
pub mod htmx;
|
||||||
pub mod models;
|
pub mod models;
|
||||||
@@ -13,7 +13,7 @@ pub mod telemetry;
|
|||||||
|
|
||||||
mod view;
|
mod view;
|
||||||
|
|
||||||
pub use error::{CommandError, Error, RequestError, StartError};
|
pub use error::{AuthError, CommandError, Error, RequestError, StartError};
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use std::process::ExitCode;
|
|||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
use packager::{
|
use packager::{
|
||||||
auth, cmd, models, routing, sqlite, telemetry, AppState, ClientState, Error, StartError,
|
auth, cli, models, routing, sqlite, telemetry, AppState, ClientState, Error, StartError,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MainResult(Result<(), Error>);
|
struct MainResult(Result<(), Error>);
|
||||||
@@ -35,17 +35,19 @@ impl From<tokio::task::JoinError> for MainResult {
|
|||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> MainResult {
|
async fn main() -> MainResult {
|
||||||
let args = match cmd::Args::get() {
|
let args = match cli::Args::get() {
|
||||||
Ok(args) => args,
|
Ok(args) => args,
|
||||||
Err(e) => return e.into(),
|
Err(e) => return e.into(),
|
||||||
};
|
};
|
||||||
|
|
||||||
telemetry::tracing::init_tracing(
|
telemetry::tracing::init(
|
||||||
|
#[cfg(feature = "jaeger")]
|
||||||
if args.enable_opentelemetry.into() {
|
if args.enable_opentelemetry.into() {
|
||||||
telemetry::tracing::OpenTelemetryConfig::Enabled
|
telemetry::tracing::OpenTelemetryConfig::Enabled
|
||||||
} else {
|
} else {
|
||||||
telemetry::tracing::OpenTelemetryConfig::Disabled
|
telemetry::tracing::OpenTelemetryConfig::Disabled
|
||||||
},
|
},
|
||||||
|
#[cfg(feature = "tokio-console")]
|
||||||
if args.enable_tokio_console.into() {
|
if args.enable_tokio_console.into() {
|
||||||
telemetry::tracing::TokioConsoleConfig::Enabled
|
telemetry::tracing::TokioConsoleConfig::Enabled
|
||||||
} else {
|
} else {
|
||||||
@@ -55,7 +57,7 @@ async fn main() -> MainResult {
|
|||||||
|args| -> Pin<Box<dyn std::future::Future<Output = MainResult>>> {
|
|args| -> Pin<Box<dyn std::future::Future<Output = MainResult>>> {
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
match args.command {
|
match args.command {
|
||||||
cmd::Command::Serve(serve_args) => {
|
cli::Command::Serve(serve_args) => {
|
||||||
if let Err(e) = sqlite::migrate(&args.database_url).await {
|
if let Err(e) = sqlite::migrate(&args.database_url).await {
|
||||||
return <_ as Into<Error>>::into(e).into();
|
return <_ as Into<Error>>::into(e).into();
|
||||||
}
|
}
|
||||||
@@ -84,6 +86,7 @@ async fn main() -> MainResult {
|
|||||||
|
|
||||||
let mut join_set = tokio::task::JoinSet::new();
|
let mut join_set = tokio::task::JoinSet::new();
|
||||||
|
|
||||||
|
#[cfg(feature = "prometheus")]
|
||||||
let app = if args.enable_prometheus.into() {
|
let app = if args.enable_prometheus.into() {
|
||||||
// we `require_if()` prometheus port & bind when `enable_prometheus` is set, so
|
// we `require_if()` prometheus port & bind when `enable_prometheus` is set, so
|
||||||
// this cannot fail
|
// this cannot fail
|
||||||
@@ -145,9 +148,9 @@ async fn main() -> MainResult {
|
|||||||
|
|
||||||
return result.into();
|
return result.into();
|
||||||
}
|
}
|
||||||
cmd::Command::Admin(admin_command) => match admin_command {
|
cli::Command::Admin(admin_command) => match admin_command {
|
||||||
cmd::Admin::User(cmd) => match cmd {
|
cli::Admin::User(cmd) => match cmd {
|
||||||
cmd::UserCommand::Create(user) => {
|
cli::UserCommand::Create(user) => {
|
||||||
let database_pool =
|
let database_pool =
|
||||||
match sqlite::init_database_pool(&args.database_url).await {
|
match sqlite::init_database_pool(&args.database_url).await {
|
||||||
Ok(pool) => pool,
|
Ok(pool) => pool,
|
||||||
@@ -183,7 +186,7 @@ async fn main() -> MainResult {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
cmd::Command::Migrate => {
|
cli::Command::Migrate => {
|
||||||
if let Err(e) = sqlite::migrate(&args.database_url).await {
|
if let Err(e) = sqlite::migrate(&args.database_url).await {
|
||||||
return <_ as Into<Error>>::into(e).into();
|
return <_ as Into<Error>>::into(e).into();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,12 @@ use axum_prometheus::{Handle, MakeDefaultHandle, PrometheusMetricLayerBuilder};
|
|||||||
|
|
||||||
use crate::{Error, StartError};
|
use crate::{Error, StartError};
|
||||||
|
|
||||||
|
pub struct LabelBool(bool);
|
||||||
|
|
||||||
|
/// Serves metrics on the specified `addr`.
|
||||||
|
///
|
||||||
|
/// You will get two outputs back: Another router, and a task that you have
|
||||||
|
/// to run to actually spawn the metrics server endpoint
|
||||||
pub fn prometheus_server(
|
pub fn prometheus_server(
|
||||||
router: Router,
|
router: Router,
|
||||||
addr: std::net::SocketAddr,
|
addr: std::net::SocketAddr,
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ use tracing::Instrument;
|
|||||||
|
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
#[cfg(feature = "jaeger")]
|
||||||
use opentelemetry::{global, runtime::Tokio};
|
use opentelemetry::{global, runtime::Tokio};
|
||||||
|
|
||||||
pub enum OpenTelemetryConfig {
|
pub enum OpenTelemetryConfig {
|
||||||
@@ -72,6 +73,7 @@ trait Forwarder {
|
|||||||
) -> Option<Box<dyn tracing_subscriber::Layer<dyn tracing::Subscriber>>>;
|
) -> Option<Box<dyn tracing_subscriber::Layer<dyn tracing::Subscriber>>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "jaeger")]
|
||||||
fn get_jaeger_layer<
|
fn get_jaeger_layer<
|
||||||
T: tracing::Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
|
T: tracing::Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
|
||||||
>(
|
>(
|
||||||
@@ -118,34 +120,49 @@ fn get_jaeger_layer<
|
|||||||
opentelemetry_layer
|
opentelemetry_layer
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn init_tracing<Func, T>(
|
pub async fn init<Func, T>(
|
||||||
opentelemetry_config: OpenTelemetryConfig,
|
#[cfg(feature = "jaeger")] opentelemetry_config: OpenTelemetryConfig,
|
||||||
tokio_console_config: TokioConsoleConfig,
|
#[cfg(feature = "tokio-console")] tokio_console_config: TokioConsoleConfig,
|
||||||
args: crate::cmd::Args,
|
args: crate::cli::Args,
|
||||||
f: Func,
|
f: Func,
|
||||||
) -> T
|
) -> T
|
||||||
where
|
where
|
||||||
Func: FnOnce(crate::cmd::Args) -> Pin<Box<dyn Future<Output = T>>>,
|
Func: FnOnce(crate::cli::Args) -> Pin<Box<dyn Future<Output = T>>>,
|
||||||
T: std::process::Termination,
|
T: std::process::Termination,
|
||||||
{
|
{
|
||||||
let mut shutdown_functions: Vec<Box<dyn FnOnce() -> Result<(), Box<dyn std::error::Error>>>> =
|
// mut is dependent on features (it's only required when jaeger is set), so
|
||||||
|
// let's just disable the lint
|
||||||
|
#[cfg(feature = "jaeger")]
|
||||||
|
let mut shutdown_functions: Vec<
|
||||||
|
Box<dyn FnOnce() -> Result<(), Box<dyn std::error::Error>>>,
|
||||||
|
> = vec![];
|
||||||
|
|
||||||
|
#[cfg(not(feature = "jaeger"))]
|
||||||
|
let shutdown_functions: Vec<Box<dyn FnOnce() -> Result<(), Box<dyn std::error::Error>>>> =
|
||||||
vec![];
|
vec![];
|
||||||
|
|
||||||
|
#[cfg(feature = "tokio-console")]
|
||||||
let console_layer = match tokio_console_config {
|
let console_layer = match tokio_console_config {
|
||||||
TokioConsoleConfig::Enabled => Some(console_subscriber::Builder::default().spawn()),
|
TokioConsoleConfig::Enabled => Some(console_subscriber::Builder::default().spawn()),
|
||||||
TokioConsoleConfig::Disabled => None,
|
TokioConsoleConfig::Disabled => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let stdout_layer = get_stdout_layer();
|
let stdout_layer = get_stdout_layer();
|
||||||
|
|
||||||
|
#[cfg(feature = "jaeger")]
|
||||||
let jaeger_layer = get_jaeger_layer(opentelemetry_config, &mut shutdown_functions);
|
let jaeger_layer = get_jaeger_layer(opentelemetry_config, &mut shutdown_functions);
|
||||||
|
|
||||||
let registry = Registry::default()
|
let registry = Registry::default();
|
||||||
.with(console_layer)
|
|
||||||
.with(jaeger_layer)
|
#[cfg(feature = "tokio-console")]
|
||||||
// just an example, you can actuall pass Options here for layers that might be
|
let registry = registry.with(console_layer);
|
||||||
// set/unset at runtime
|
|
||||||
.with(stdout_layer)
|
#[cfg(feature = "jaeger")]
|
||||||
.with(None::<Layer<_>>);
|
let registry = registry.with(jaeger_layer);
|
||||||
|
// just an example, you can actuall pass Options here for layers that might be
|
||||||
|
// set/unset at runtime
|
||||||
|
|
||||||
|
let registry = registry.with(stdout_layer).with(None::<Layer<_>>);
|
||||||
|
|
||||||
tracing::subscriber::set_global_default(registry).unwrap();
|
tracing::subscriber::set_global_default(registry).unwrap();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user