Skip to content

Commit

Permalink
Merge pull request #5 from wuvt/chore/updates
Browse files Browse the repository at this point in the history
chore: update dependencies
  • Loading branch information
jbellerb authored Mar 4, 2024
2 parents c600d8b + 85b38a4 commit 370a4a4
Show file tree
Hide file tree
Showing 8 changed files with 1,201 additions and 647 deletions.
1,521 changes: 1,002 additions & 519 deletions Cargo.lock

Large diffs are not rendered by default.

25 changes: 13 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,31 @@ version = "0.1.0"
edition = "2021"

[dependencies]
axum = { version = "0.6", features = ["macros"] }
anyhow = "1.0"
axum = { version = "0.7", features = ["macros"] }
base64ct = "1.6"
openidconnect = "2.5"
regex = "1.7"
openidconnect = "3.5"
regex = "1.10"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
time = { version = "0.3", features = ["formatting", "parsing"] }
tokio = { version = "1.27", features = ["full"] }
tokio = { version = "1.36", features = ["full"] }
tokio-postgres = { version = "0.7", features = ["with-time-0_3", "with-uuid-1"] }
tower = "0.4"
tower-cookies = "0.9"
tower-http = { version = "0.4", features = ["trace"] }
tower-cookies = "0.10"
tower-http = { version = "0.5", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.3", features = ["v4", "fast-rng"] }
uuid = { version = "1.7", features = ["v4", "fast-rng"] }

base64 = "0.21"
base64 = "0.22"
blake2 = "0.10"
chacha20 = "0.9"
ed25519-dalek = { version = "2.0.0-rc.2", features = ["pem"] }
ed25519-dalek = { version = "2.1.1", features = ["pem"] }
getrandom = "0.2"
subtle = "2.4"
zeroize = "1.6"
subtle = "2.5"
zeroize = "1.7"

[dev-dependencies]
hex-literal = "0.3"
hex-literal = "0.4"
34 changes: 34 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//! # poser
//!
//! poser is a simple, opinionated authentication provider for nginx
//!
//! ## About
//!
//! poser authenticates with Google using OpenID Connect and then uses the
//! Google Workspace Admin SDK to determine what groups a user is a part of.
//! Basic information about the user and what groups they are a part of is
//! returned to nginx in a [Paseto v4] token, which is then passed to the
//! application.
//!
//! [Paseto v4]: https://github.com/paseto-standard/paseto-spec

pub mod config;
pub mod error;
pub mod oidc;
mod routes;
pub mod shutdown;
pub mod token;

pub use routes::routes;

use std::sync::Arc;

#[derive(Debug, Clone)]
pub struct ServerState {
pub config: crate::config::Config,
pub db: Arc<tokio_postgres::Client>,
pub oidc: openidconnect::core::CoreClient,

// Signals back to the main thread when dropped
pub shutdown: crate::shutdown::Receiver,
}
159 changes: 72 additions & 87 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,68 +1,39 @@
//! # poser
//!
//! poser is a simple, opinionated authentication provider for nginx
//!
//! ## About
//!
//! poser authenticates with Google using OpenID Connect and then uses the
//! Google Workspace Admin SDK to determine what groups a user is a part of.
//! Basic information about the user and what groups they are a part of is
//! returned to nginx in a [Paseto v4] token, which is then passed to the
//! application.
//!
//! [Paseto v4]: https://github.com/paseto-standard/paseto-spec

pub mod config;
pub mod error;
pub mod oidc;
pub mod routes;
pub mod token;

use std::env::var;
use std::future::IntoFuture;
use std::sync::Arc;

use config::Config;
use oidc::setup_auth;
use openidconnect::core::CoreClient;
use routes::routes;
use poser::config::Config;
use poser::oidc::setup_auth;
use poser::shutdown;
use poser::{routes, ServerState};

use axum::Server;
use anyhow::Context;
use tokio::{
net::TcpListener,
runtime::Runtime,
select,
signal::unix::{signal, SignalKind},
sync::{broadcast, mpsc},
time::timeout,
};
use tokio_postgres::{Client, NoTls};
use tokio_postgres::NoTls;
use tower::ServiceBuilder;
use tower_cookies::CookieManagerLayer;
use tower_http::{
trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
LatencyUnit,
};
use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer};
use tracing::{debug, error, info, warn, Level};

#[derive(Debug, Clone)]
pub struct ServerState {
pub config: Config,
pub db: Arc<Client>,
pub oidc: CoreClient,

// Signals back to the main thread when dropped
_shutdown_complete: mpsc::Sender<()>,
}

fn main() {
tracing_subscriber::fmt()
.with_env_filter(var("RUST_LOG").unwrap_or_else(|_| "info".to_string()))
.with_env_filter(var("RUST_LOG").unwrap_or_else(|_| "warn,poser=info".to_string()))
.init();

let config = Config::try_env().expect("invalid configuration");
let config = Config::try_env()
.context("failed to build config")
.unwrap_or_else(|e| {
error!("{:#}", e);
std::process::exit(1);
});

build_runtime().block_on(async move {
let (shutdown_notify, _) = broadcast::channel(1);
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let shutdown = shutdown::Sender::new();

let (db, conn) = tokio_postgres::connect(&config.database, NoTls)
.await
Expand All @@ -76,62 +47,66 @@ fn main() {
config: config.clone(),
db: Arc::new(db),
oidc,
_shutdown_complete: shutdown_tx.clone(),
shutdown: shutdown.subscribe(),
};

let db_signal = shutdown_tx.clone();
let conn = tokio::spawn(async move {
let postgres_notify = shutdown.subscribe();
let postgres = async move {
let res = conn.await;
drop(db_signal);
drop(postgres_notify);
res
});
};

let app = routes().with_state(state).layer(
let router = routes().with_state(state).layer(
ServiceBuilder::new()
.layer(
TraceLayer::new_for_http()
.on_request(DefaultOnRequest::new().level(Level::INFO))
.on_response(
DefaultOnResponse::new()
.level(Level::INFO)
.latency_unit(LatencyUnit::Micros),
.latency_unit(tower_http::LatencyUnit::Micros),
),
)
.layer(CookieManagerLayer::new()),
);
let shutdown_signal = shutdown_notify.subscribe();
let server = Server::bind(&config.addr)
.serve(app.into_make_service())
.with_graceful_shutdown(wait_for_shutdown(shutdown_signal));

info!("listening on {}", config.addr);

select! {
_ = unix_signal(SignalKind::interrupt()) => {
info!("received SIGINT, shutting down");
},
_ = unix_signal(SignalKind::terminate()) => {
info!("received SIGTERM, shutting down");
},
res = conn => match res {
let listener = TcpListener::bind(&config.addr)
.await
.context("failed to bind to socket")
.unwrap_or_else(|e| {
error!("{:#}", e);
std::process::exit(1);
});

let mut axum_notify = shutdown.subscribe();
let server = axum::serve(listener, router)
.with_graceful_shutdown(async move { _ = axum_notify.recv().await });

info!("listening for connections on: {}", config.addr);
tokio::select! {
sig = shutdown_signal() => info!("received {}, starting graceful shutdown...", sig),
res = tokio::spawn(postgres) => match res {
Ok(Ok(_)) => error!("database connection closed unexpectedly"),
Ok(Err(e)) => error!("database connection error: {}", e),
Err(e) => error!("database executor unexpectedly stopped: {}", e),
},
res = tokio::spawn(server) => match res {
res = tokio::spawn(server.into_future()) => match res {
Ok(Ok(_)) => info!("server shutting down"),
Ok(Err(e)) => error!("server unexpectedly stopped: {}", e),
Err(e) => error!("server executor unexpectedly stopped: {}", e),
},
}
drop(shutdown_notify);
drop(shutdown_tx);
match timeout(config.grace_period, wait_for_complete(shutdown_rx)).await {
Ok(()) => debug!("shutdown completed"),
Err(_) => warn!(
"graceful shutdown did not complete in {:?}, closing anyways",
config.grace_period
),

tokio::select! {
sig = shutdown_signal() => error!("received second {}, aborting.", sig),
res = timeout(config.grace_period, shutdown::Sender::shutdown(shutdown)) => match res {
Ok(()) => debug!("shutdown completed"),
Err(_) => warn!(
"graceful shutdown did not complete in {:?}, closing anyways",
config.grace_period
),
},
}
})
}
Expand All @@ -140,17 +115,27 @@ fn build_runtime() -> Runtime {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("build threaded runtime")
}

async fn unix_signal(kind: SignalKind) {
signal(kind).expect("register signal handler").recv().await;
}

async fn wait_for_shutdown(mut signal: broadcast::Receiver<()>) {
_ = signal.recv().await;
.context("failed to build threaded runtime")
.unwrap_or_else(|e| {
error!("{:#}", e);
std::process::exit(1);
})
}

async fn wait_for_complete(mut signal: mpsc::Receiver<()>) {
_ = signal.recv().await;
async fn shutdown_signal() -> &'static str {
async fn wait_for_signal(kind: SignalKind) {
signal(kind)
.context("failed to register signal handler")
.unwrap_or_else(|e| {
error!("{:#}", e);
std::process::exit(1);
})
.recv()
.await;
}

tokio::select! {
_ = wait_for_signal(SignalKind::interrupt()) => "SIGINT",
_ = wait_for_signal(SignalKind::terminate()) => "SIGTERM",
}
}
29 changes: 16 additions & 13 deletions src/routes/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ pub enum Error {
#[error("missing id token")]
MissingToken,
#[error("error exchanging code")]
CodeError,
CodeExchange,
#[error("error verifying token")]
TokenError,
VerifyToken,
#[error("error interacting with database")]
DatabaseError,
DatabaseInsert,
}

/// A handler for receiving the callback during the OIDC flow.
Expand Down Expand Up @@ -86,11 +86,11 @@ pub async fn callback_handler(
let id = create_session(&state.db, &token, &expiration).await?;

cookies.add(
Cookie::build(state.config.cookie.name, id.simple().to_string())
Cookie::build((state.config.cookie.name, id.simple().to_string()))
.secure(state.config.cookie.secure)
.http_only(true)
.expires(expiration)
.finish(),
.build(),
);

Ok(Redirect::to(oidc.get_redirect()))
Expand All @@ -107,7 +107,7 @@ async fn get_token(
.await
.map_err(|e| {
error!("failed to exchange code for token: {}", e);
Error::CodeError
Error::CodeExchange
})?;

let token = token_response.extra_fields().id_token().ok_or_else(|| {
Expand All @@ -118,19 +118,22 @@ async fn get_token(
let id_token_verifier = client.id_token_verifier();
let claims = token.claims(&id_token_verifier, nonce).map_err(|e| {
error!("failed to verify id token: {}", e);
Error::TokenError
Error::VerifyToken
})?;

let subj = claims.subject();
let name = claims.name().and_then(|s| s.get(None)).ok_or_else(|| {
error!("name missing from id token");
Error::TokenError
Error::VerifyToken
})?;
let email = claims.email().ok_or_else(|| {
error!("email missing from id token");
Error::TokenError
Error::VerifyToken
})?;
let expiration = claims.expiration().timestamp_nanos();
let expiration = claims
.expiration()
.timestamp_nanos_opt()
.expect("todo: fix timestamp handling before 2262");

Ok((
UserToken {
Expand Down Expand Up @@ -174,7 +177,7 @@ async fn create_session(
.await
.map_err(|e| {
error!("error creating session: {}", e);
Error::DatabaseError
Error::DatabaseInsert
})
.map(|r| r.get::<_, Uuid>("id"))
}
Expand All @@ -186,8 +189,8 @@ impl IntoResponse for Error {
| Error::MissingState
| Error::MissingCode
| Error::MissingCookie => json!({ "error": "invalid request" }),
Error::MissingToken | Error::TokenError => json!({ "error": "authentication error" }),
Error::InvalidDateTime | Error::CodeError | Error::DatabaseError => {
Error::MissingToken | Error::VerifyToken => json!({ "error": "authentication error" }),
Error::InvalidDateTime | Error::CodeExchange | Error::DatabaseInsert => {
json!({ "error": "internal error" })
}
};
Expand Down
Loading

0 comments on commit 370a4a4

Please sign in to comment.