diff --git a/Cargo.toml b/Cargo.toml index 2400ecb..d2e8077 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,19 +35,26 @@ metrics-exporter-prometheus = "0.17.0" # Slot Calc chrono = "0.4.40" +# OAuth +oauth2 = { version = "5.0.0", optional = true } +tokio = { version = "1.36.0", optional = true } + # Other thiserror = "2.0.11" alloy = { version = "0.12.6", optional = true, default-features = false, features = ["std", "signer-aws", "signer-local", "consensus", "network"] } serde = { version = "1", features = ["derive"] } async-trait = { version = "0.1.80", optional = true } + # AWS aws-config = { version = "1.1.7", optional = true } aws-sdk-kms = { version = "1.15.0", optional = true } +reqwest = { version = "0.12.15", optional = true } [dev-dependencies] ajj = "0.3.1" axum = "0.8.1" +eyre = "0.6.12" serial_test = "3.2.0" signal-hook = "0.3.17" tokio = { version = "1.43.0", features = ["macros"] } @@ -55,4 +62,9 @@ tokio = { version = "1.43.0", features = ["macros"] } [features] default = ["alloy"] alloy = ["dep:alloy", "dep:async-trait", "dep:aws-config", "dep:aws-sdk-kms"] -perms = [] +perms = ["dep:oauth2", "dep:tokio", "dep:reqwest"] + +[[example]] +name = "oauth" +path = "examples/oauth.rs" +required-features = ["perms"] \ No newline at end of file diff --git a/examples/oauth.rs b/examples/oauth.rs new file mode 100644 index 0000000..1320190 --- /dev/null +++ b/examples/oauth.rs @@ -0,0 +1,15 @@ +use init4_bin_base::{perms::OAuthConfig, utils::from_env::FromEnv}; + +#[tokio::main] +async fn main() -> eyre::Result<()> { + let cfg = OAuthConfig::from_env()?; + let authenticator = cfg.authenticator(); + let token = authenticator.token(); + + let _jh = authenticator.spawn(); + + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + dbg!(token.read()); + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index 460fce0..a2bd866 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,11 +18,9 @@ pub mod perms; /// Signet utilities. pub mod utils { - /// Prometheus metrics utilities. - pub mod metrics; - - /// OpenTelemetry utilities. - pub mod otlp; + /// Slot calculator for determining the current slot and timepoint within a + /// slot. + pub mod calc; /// [`FromEnv`], [`FromEnvVar`] traits and related utilities. /// @@ -30,16 +28,18 @@ pub mod utils { /// [`FromEnvVar`]: from_env::FromEnvVar pub mod from_env; - /// Tracing utilities. - pub mod tracing; + /// Prometheus metrics utilities. + pub mod metrics; - /// Slot calculator for determining the current slot and timepoint within a - /// slot. - pub mod calc; + /// OpenTelemetry utilities. + pub mod otlp; #[cfg(feature = "alloy")] /// Signer using a local private key or AWS KMS key. pub mod signer; + + /// Tracing utilities. + pub mod tracing; } /// Re-exports of common dependencies. diff --git a/src/perms/mod.rs b/src/perms/mod.rs index d81e802..a54d5ac 100644 --- a/src/perms/mod.rs +++ b/src/perms/mod.rs @@ -3,3 +3,6 @@ pub use builders::{Builder, BuilderPermissionError, Builders, BuildersEnvError}; pub(crate) mod config; pub use config::{SlotAuthzConfig, SlotAuthzConfigEnvError}; + +pub(crate) mod oauth; +pub use oauth::{Authenticator, OAuthConfig, SharedToken}; diff --git a/src/perms/oauth.rs b/src/perms/oauth.rs new file mode 100644 index 0000000..c0e9002 --- /dev/null +++ b/src/perms/oauth.rs @@ -0,0 +1,185 @@ +//! Service responsible for authenticating with the cache with Oauth tokens. +//! This authenticator periodically fetches a new token every set amount of seconds. +use crate::{ + deps::tracing::{error, info}, + utils::from_env::FromEnv, +}; +use oauth2::{ + basic::{BasicClient, BasicTokenType}, + AuthUrl, ClientId, ClientSecret, EmptyExtraTokenFields, EndpointNotSet, EndpointSet, + HttpClientError, RequestTokenError, StandardErrorResponse, StandardTokenResponse, TokenUrl, +}; +use std::sync::{Arc, Mutex}; +use tokio::task::JoinHandle; + +type Token = StandardTokenResponse; + +type MyOAuthClient = + BasicClient; + +/// Configuration for the OAuth2 client. +#[derive(Debug, Clone, FromEnv)] +#[from_env(crate)] +pub struct OAuthConfig { + /// OAuth client ID for the builder. + #[from_env(var = "OAUTH_CLIENT_ID", desc = "OAuth client ID for the builder")] + pub oauth_client_id: String, + /// OAuth client secret for the builder. + #[from_env( + var = "OAUTH_CLIENT_SECRET", + desc = "OAuth client secret for the builder" + )] + pub oauth_client_secret: String, + /// OAuth authenticate URL for the builder for performing OAuth logins. + #[from_env( + var = "OAUTH_AUTHENTICATE_URL", + desc = "OAuth authenticate URL for the builder for performing OAuth logins" + )] + pub oauth_authenticate_url: url::Url, + /// OAuth token URL for the builder to get an OAuth2 access token + #[from_env( + var = "OAUTH_TOKEN_URL", + desc = "OAuth token URL for the builder to get an OAuth2 access token" + )] + pub oauth_token_url: url::Url, + /// The oauth token refresh interval in seconds. + #[from_env( + var = "AUTH_TOKEN_REFRESH_INTERVAL", + desc = "The oauth token refresh interval in seconds" + )] + pub oauth_token_refresh_interval: u64, +} + +impl OAuthConfig { + /// Create a new [`Authenticator`] from the provided config. + pub fn authenticator(&self) -> Authenticator { + Authenticator::new(self) + } +} + +/// A shared token that can be read and written to by multiple threads. +#[derive(Debug, Clone, Default)] +pub struct SharedToken(Arc>>); + +impl SharedToken { + /// Read the token from the shared token. + pub fn read(&self) -> Option { + self.0.lock().unwrap().clone() + } + + /// Write a new token to the shared token. + pub fn write(&self, token: Token) { + let mut lock = self.0.lock().unwrap(); + *lock = Some(token); + } + + /// Check if the token is authenticated. + pub fn is_authenticated(&self) -> bool { + self.0.lock().unwrap().is_some() + } +} + +/// A self-refreshing, periodically fetching authenticator for the block +/// builder. This task periodically fetches a new token, and stores it in a +/// [`SharedToken`]. +#[derive(Debug)] +pub struct Authenticator { + /// Configuration + pub config: OAuthConfig, + client: MyOAuthClient, + token: SharedToken, + reqwest: reqwest::Client, +} + +impl Authenticator { + /// Creates a new Authenticator from the provided builder config. + pub fn new(config: &OAuthConfig) -> Self { + let client = BasicClient::new(ClientId::new(config.oauth_client_id.clone())) + .set_client_secret(ClientSecret::new(config.oauth_client_secret.clone())) + .set_auth_uri(AuthUrl::from_url(config.oauth_authenticate_url.clone())) + .set_token_uri(TokenUrl::from_url(config.oauth_token_url.clone())); + + let rq_client = reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap(); + + Self { + config: config.clone(), + client, + token: Default::default(), + reqwest: rq_client, + } + } + + /// Requests a new authentication token and, if successful, sets it to as the token + pub async fn authenticate( + &self, + ) -> Result< + (), + RequestTokenError< + HttpClientError, + StandardErrorResponse, + >, + > { + let token = self.fetch_oauth_token().await?; + self.set_token(token); + Ok(()) + } + + /// Returns true if there is Some token set + pub fn is_authenticated(&self) -> bool { + self.token.is_authenticated() + } + + /// Sets the Authenticator's token to the provided value + fn set_token(&self, token: StandardTokenResponse) { + self.token.write(token); + } + + /// Returns the currently set token + pub fn token(&self) -> SharedToken { + self.token.clone() + } + + /// Fetches an oauth token + pub async fn fetch_oauth_token( + &self, + ) -> Result< + Token, + RequestTokenError< + HttpClientError, + StandardErrorResponse, + >, + > { + let token_result = self + .client + .exchange_client_credentials() + .request_async(&self.reqwest) + .await?; + + Ok(token_result) + } + + /// Spawns a task that periodically fetches a new token every 300 seconds. + pub fn spawn(self) -> JoinHandle<()> { + let interval = self.config.oauth_token_refresh_interval; + + let handle: JoinHandle<()> = tokio::spawn(async move { + loop { + info!("Refreshing oauth token"); + match self.authenticate().await { + Ok(_) => { + info!("Successfully refreshed oauth token"); + } + Err(e) => { + error!(%e, "Failed to refresh oauth token"); + } + }; + let _sleep = tokio::time::sleep(tokio::time::Duration::from_secs(interval)).await; + } + }); + + handle + } +}