Skip to content

feat: move oauth task from builder to here #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,36 @@ metrics-exporter-prometheus = "0.17.0"
# Slot Calc
chrono = "0.4.40"

# OAuth
oauth2 = { version = "5.0.0", optional = true }
tokio = { version = "1.36.0", optional = true }

# Other
thiserror = "2.0.11"
alloy = { version = "0.12.6", optional = true, default-features = false, features = ["std", "signer-aws", "signer-local", "consensus", "network"] }
serde = { version = "1", features = ["derive"] }
async-trait = { version = "0.1.80", optional = true }


# AWS
aws-config = { version = "1.1.7", optional = true }
aws-sdk-kms = { version = "1.15.0", optional = true }
reqwest = { version = "0.12.15", optional = true }

[dev-dependencies]
ajj = "0.3.1"
axum = "0.8.1"
eyre = "0.6.12"
serial_test = "3.2.0"
signal-hook = "0.3.17"
tokio = { version = "1.43.0", features = ["macros"] }

[features]
default = ["alloy"]
alloy = ["dep:alloy", "dep:async-trait", "dep:aws-config", "dep:aws-sdk-kms"]
perms = []
perms = ["dep:oauth2", "dep:tokio", "dep:reqwest"]

[[example]]
name = "oauth"
path = "examples/oauth.rs"
required-features = ["perms"]
15 changes: 15 additions & 0 deletions examples/oauth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use init4_bin_base::{perms::OAuthConfig, utils::from_env::FromEnv};

#[tokio::main]
async fn main() -> eyre::Result<()> {
let cfg = OAuthConfig::from_env()?;
let authenticator = cfg.authenticator();
let token = authenticator.token();

let _jh = authenticator.spawn();

tokio::time::sleep(std::time::Duration::from_secs(5)).await;
dbg!(token.read());

Ok(())
}
20 changes: 10 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,28 @@ pub mod perms;

/// Signet utilities.
pub mod utils {
/// Prometheus metrics utilities.
pub mod metrics;

/// OpenTelemetry utilities.
pub mod otlp;
/// Slot calculator for determining the current slot and timepoint within a
/// slot.
pub mod calc;

/// [`FromEnv`], [`FromEnvVar`] traits and related utilities.
///
/// [`FromEnv`]: from_env::FromEnv
/// [`FromEnvVar`]: from_env::FromEnvVar
pub mod from_env;

/// Tracing utilities.
pub mod tracing;
/// Prometheus metrics utilities.
pub mod metrics;

/// Slot calculator for determining the current slot and timepoint within a
/// slot.
pub mod calc;
/// OpenTelemetry utilities.
pub mod otlp;

#[cfg(feature = "alloy")]
/// Signer using a local private key or AWS KMS key.
pub mod signer;

/// Tracing utilities.
pub mod tracing;
}

/// Re-exports of common dependencies.
Expand Down
3 changes: 3 additions & 0 deletions src/perms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ pub use builders::{Builder, BuilderPermissionError, Builders, BuildersEnvError};

pub(crate) mod config;
pub use config::{SlotAuthzConfig, SlotAuthzConfigEnvError};

pub(crate) mod oauth;
pub use oauth::{Authenticator, OAuthConfig, SharedToken};
185 changes: 185 additions & 0 deletions src/perms/oauth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
//! Service responsible for authenticating with the cache with Oauth tokens.
//! This authenticator periodically fetches a new token every set amount of seconds.
use crate::{
deps::tracing::{error, info},
utils::from_env::FromEnv,
};
use oauth2::{
basic::{BasicClient, BasicTokenType},
AuthUrl, ClientId, ClientSecret, EmptyExtraTokenFields, EndpointNotSet, EndpointSet,
HttpClientError, RequestTokenError, StandardErrorResponse, StandardTokenResponse, TokenUrl,
};
use std::sync::{Arc, Mutex};
use tokio::task::JoinHandle;

type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;

type MyOAuthClient =
BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>;

/// Configuration for the OAuth2 client.
#[derive(Debug, Clone, FromEnv)]
#[from_env(crate)]
pub struct OAuthConfig {
/// OAuth client ID for the builder.
#[from_env(var = "OAUTH_CLIENT_ID", desc = "OAuth client ID for the builder")]
pub oauth_client_id: String,
/// OAuth client secret for the builder.
#[from_env(
var = "OAUTH_CLIENT_SECRET",
desc = "OAuth client secret for the builder"
)]
pub oauth_client_secret: String,
/// OAuth authenticate URL for the builder for performing OAuth logins.
#[from_env(
var = "OAUTH_AUTHENTICATE_URL",
desc = "OAuth authenticate URL for the builder for performing OAuth logins"
)]
pub oauth_authenticate_url: url::Url,
/// OAuth token URL for the builder to get an OAuth2 access token
#[from_env(
var = "OAUTH_TOKEN_URL",
desc = "OAuth token URL for the builder to get an OAuth2 access token"
)]
pub oauth_token_url: url::Url,
/// The oauth token refresh interval in seconds.
#[from_env(
var = "AUTH_TOKEN_REFRESH_INTERVAL",
desc = "The oauth token refresh interval in seconds"
)]
pub oauth_token_refresh_interval: u64,
}

impl OAuthConfig {
/// Create a new [`Authenticator`] from the provided config.
pub fn authenticator(&self) -> Authenticator {
Authenticator::new(self)
}
}

/// A shared token that can be read and written to by multiple threads.
#[derive(Debug, Clone, Default)]
pub struct SharedToken(Arc<Mutex<Option<Token>>>);

impl SharedToken {
/// Read the token from the shared token.
pub fn read(&self) -> Option<Token> {
self.0.lock().unwrap().clone()
}

/// Write a new token to the shared token.
pub fn write(&self, token: Token) {
let mut lock = self.0.lock().unwrap();
*lock = Some(token);
}

/// Check if the token is authenticated.
pub fn is_authenticated(&self) -> bool {
self.0.lock().unwrap().is_some()
}
}

/// A self-refreshing, periodically fetching authenticator for the block
/// builder. This task periodically fetches a new token, and stores it in a
/// [`SharedToken`].
#[derive(Debug)]
pub struct Authenticator {
/// Configuration
pub config: OAuthConfig,
client: MyOAuthClient,
token: SharedToken,
reqwest: reqwest::Client,
}

impl Authenticator {
/// Creates a new Authenticator from the provided builder config.
pub fn new(config: &OAuthConfig) -> Self {
let client = BasicClient::new(ClientId::new(config.oauth_client_id.clone()))
.set_client_secret(ClientSecret::new(config.oauth_client_secret.clone()))
.set_auth_uri(AuthUrl::from_url(config.oauth_authenticate_url.clone()))
.set_token_uri(TokenUrl::from_url(config.oauth_token_url.clone()));

let rq_client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap();

Self {
config: config.clone(),
client,
token: Default::default(),
reqwest: rq_client,
}
}

/// Requests a new authentication token and, if successful, sets it to as the token
pub async fn authenticate(
&self,
) -> Result<
(),
RequestTokenError<
HttpClientError<reqwest::Error>,
StandardErrorResponse<oauth2::basic::BasicErrorResponseType>,
>,
> {
let token = self.fetch_oauth_token().await?;
self.set_token(token);
Ok(())
}

/// Returns true if there is Some token set
pub fn is_authenticated(&self) -> bool {
self.token.is_authenticated()
}

/// Sets the Authenticator's token to the provided value
fn set_token(&self, token: StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>) {
self.token.write(token);
}

/// Returns the currently set token
pub fn token(&self) -> SharedToken {
self.token.clone()
}

/// Fetches an oauth token
pub async fn fetch_oauth_token(
&self,
) -> Result<
Token,
RequestTokenError<
HttpClientError<reqwest::Error>,
StandardErrorResponse<oauth2::basic::BasicErrorResponseType>,
>,
> {
let token_result = self
.client
.exchange_client_credentials()
.request_async(&self.reqwest)
.await?;

Ok(token_result)
}

/// Spawns a task that periodically fetches a new token every 300 seconds.
pub fn spawn(self) -> JoinHandle<()> {
let interval = self.config.oauth_token_refresh_interval;

let handle: JoinHandle<()> = tokio::spawn(async move {
loop {
info!("Refreshing oauth token");
match self.authenticate().await {
Ok(_) => {
info!("Successfully refreshed oauth token");
}
Err(e) => {
error!(%e, "Failed to refresh oauth token");
}
};
let _sleep = tokio::time::sleep(tokio::time::Duration::from_secs(interval)).await;
}
});

handle
}
}
Loading