diff --git a/.vscode/settings.json b/.vscode/settings.json index 73d2a89b4d..016a434344 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,10 +2,13 @@ "azure-pipelines.1ESPipelineTemplatesSchemaFile": true, "cSpell.enabled": true, "editor.formatOnSave": true, + "markdownlint.config": { + "MD024": false + }, "rust-analyzer.cargo.features": "all", "rust-analyzer.check.command": "clippy", "yaml.format.printWidth": 240, "[powershell]": { "editor.defaultFormatter": "ms-vscode.powershell", }, -} \ No newline at end of file +} diff --git a/Cargo.lock b/Cargo.lock index 7c125f604c..46934dd8ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -374,6 +374,7 @@ dependencies = [ "async-process", "async-trait", "azure_core", + "azure_core_test", "azure_security_keyvault_secrets", "clap", "futures", diff --git a/sdk/core/azure_core_test/Cargo.toml b/sdk/core/azure_core_test/Cargo.toml index 636fe9770c..5d4fc4ce77 100644 --- a/sdk/core/azure_core_test/Cargo.toml +++ b/sdk/core/azure_core_test/Cargo.toml @@ -57,4 +57,4 @@ tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] } uuid.workspace = true [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] -tokio = { workspace = true, features = ["signal"] } +tokio = { workspace = true, features = ["rt", "signal"] } diff --git a/sdk/core/azure_core_test/src/credential.rs b/sdk/core/azure_core_test/src/credential.rs deleted file mode 100644 index 164756eccc..0000000000 --- a/sdk/core/azure_core_test/src/credential.rs +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -use azure_core::{ - credentials::{AccessToken, Secret, TokenCredential}, - date::OffsetDateTime, - error::ErrorKind, -}; -use std::time::Duration; - -/// A mock [`TokenCredential`] useful for testing. -#[derive(Clone, Debug, Default)] -pub struct MockCredential; - -#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] -impl TokenCredential for MockCredential { - async fn get_token(&self, scopes: &[&str]) -> azure_core::Result { - let token: Secret = format!("TEST TOKEN {}", scopes.join(" ")).into(); - let expires_on = OffsetDateTime::now_utc().saturating_add( - Duration::from_secs(60 * 5).try_into().map_err(|err| { - azure_core::Error::full(ErrorKind::Other, err, "failed to compute expiration") - })?, - ); - Ok(AccessToken { token, expires_on }) - } - - async fn clear_cache(&self) -> azure_core::Result<()> { - Ok(()) - } -} diff --git a/sdk/core/azure_core_test/src/credentials.rs b/sdk/core/azure_core_test/src/credentials.rs new file mode 100644 index 0000000000..6337d3ecb0 --- /dev/null +++ b/sdk/core/azure_core_test/src/credentials.rs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Credentials for live and recorded tests. +use azure_core::{ + credentials::{AccessToken, Secret, TokenCredential}, + date::OffsetDateTime, + error::ErrorKind, +}; +use azure_identity::{AzurePipelinesCredential, DefaultAzureCredential, TokenCredentialOptions}; +use std::{env, sync::Arc, time::Duration}; + +/// A mock [`TokenCredential`] useful for testing. +#[derive(Clone, Debug, Default)] +pub struct MockCredential; + +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +impl TokenCredential for MockCredential { + async fn get_token(&self, scopes: &[&str]) -> azure_core::Result { + let token: Secret = format!("TEST TOKEN {}", scopes.join(" ")).into(); + let expires_on = OffsetDateTime::now_utc().saturating_add( + Duration::from_secs(60 * 5).try_into().map_err(|err| { + azure_core::Error::full(ErrorKind::Other, err, "failed to compute expiration") + })?, + ); + Ok(AccessToken { token, expires_on }) + } + + async fn clear_cache(&self) -> azure_core::Result<()> { + Ok(()) + } +} + +/// Gets a `TokenCredential` appropriate for the current environment. +/// +/// When running in Azure Pipelines, this will return an [`AzurePipelinesCredential`]; +/// otherwise, it will return a [`DefaultAzureCredential`]. +pub fn from_env( + options: Option, +) -> azure_core::Result> { + // cspell:ignore accesstoken azuresubscription + let tenant_id = env::var("AZURESUBSCRIPTION_TENANT_ID").ok(); + let client_id = env::var("AZURESUBSCRIPTION_CLIENT_ID").ok(); + let connection_id = env::var("AZURESUBSCRIPTION_SERVICE_CONNECTION_ID").ok(); + let access_token = env::var("SYSTEM_ACCESSTOKEN").ok(); + + if let (Some(tenant_id), Some(client_id), Some(connection_id), Some(access_token)) = + (tenant_id, client_id, connection_id, access_token) + { + if !tenant_id.is_empty() + && !client_id.is_empty() + && !connection_id.is_empty() + && !access_token.is_empty() + { + return Ok(AzurePipelinesCredential::new( + tenant_id, + client_id, + &connection_id, + access_token, + options.map(Into::into), + )? as Arc); + } + } + + Ok( + DefaultAzureCredential::with_options(options.unwrap_or_default())? + as Arc, + ) +} diff --git a/sdk/core/azure_core_test/src/http/clients.rs b/sdk/core/azure_core_test/src/http/clients.rs new file mode 100644 index 0000000000..ff41930ad1 --- /dev/null +++ b/sdk/core/azure_core_test/src/http/clients.rs @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use async_trait::async_trait; +use azure_core::{HttpClient, Request, Response, Result}; +#[cfg(test)] +use futures::FutureExt as _; +use futures::{future::BoxFuture, lock::Mutex}; +use std::fmt; + +/// An [`HttpClient`] from which you can assert [`Request`]s and return mock [`Response`]s. +/// +/// # Examples +/// +/// ``` +/// use azure_core::{ +/// Bytes, ClientOptions, +/// headers::Headers, +/// Response, StatusCode, TransportOptions, +/// }; +/// use azure_core_test::http::MockHttpClient; +/// use azure_identity::DefaultAzureCredential; +/// use azure_security_keyvault_secrets::{SecretClient, SecretClientOptions}; +/// use futures::FutureExt as _; +/// use std::sync::Arc; +/// +/// # #[tokio::main] +/// # async fn main() -> Result<(), Box> { +/// let mock_client = Arc::new(MockHttpClient::new(|req| async { +/// assert_eq!(req.url().host_str(), Some("my-vault.vault.azure.net")); +/// Ok(Response::from_bytes( +/// StatusCode::Ok, +/// Headers::new(), +/// Bytes::from_static(br#"{"value":"secret"}"#), +/// )) +/// }.boxed())); +/// let credential = DefaultAzureCredential::new()?; +/// let options = SecretClientOptions { +/// client_options: ClientOptions { +/// transport: Some(TransportOptions::new(mock_client.clone())), +/// ..Default::default() +/// }, +/// ..Default::default() +/// }; +/// let client = SecretClient::new( +/// "https://my-vault.vault.azure.net", +/// credential.clone(), +/// Some(options), +/// ); +/// # Ok(()) +/// # } +/// ``` +pub struct MockHttpClient(Mutex); + +impl MockHttpClient +where + C: FnMut(&Request) -> BoxFuture<'_, Result> + Send + Sync, +{ + /// Creates a new `MockHttpClient` using a capture. + /// + /// The capture takes a `&Request` and returns a `BoxedFuture>`. + /// See the example on [`MockHttpClient`]. + pub fn new(client: C) -> Self { + Self(Mutex::new(client)) + } +} + +impl fmt::Debug for MockHttpClient { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(stringify!("MockHttpClient")) + } +} + +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +impl HttpClient for MockHttpClient +where + C: FnMut(&Request) -> BoxFuture<'_, Result> + Send + Sync, +{ + async fn execute_request(&self, req: &Request) -> Result { + let mut client = self.0.lock().await; + (client)(req).await + } +} + +#[tokio::test] +async fn test_mock_http_client() { + use azure_core::{ + headers::{HeaderName, Headers}, + Method, StatusCode, + }; + use std::sync::{Arc, Mutex}; + + const COUNT_HEADER: HeaderName = HeaderName::from_static("x-count"); + + let count = Arc::new(Mutex::new(0)); + let mock_client = Arc::new(MockHttpClient::new(|req| { + let count = count.clone(); + async move { + assert_eq!(req.url().host_str(), Some("localhost")); + + if req.headers().get_optional_str(&COUNT_HEADER).is_some() { + let mut count = count.lock().unwrap(); + *count += 1; + } + + Ok(Response::from_bytes(StatusCode::Ok, Headers::new(), vec![])) + } + .boxed() + })) as Arc; + + let req = Request::new("https://localhost".parse().unwrap(), Method::Get); + mock_client.execute_request(&req).await.unwrap(); + + let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get); + req.insert_header(COUNT_HEADER, "true"); + mock_client.execute_request(&req).await.unwrap(); + + assert_eq!(*count.lock().unwrap(), 1); +} diff --git a/sdk/core/azure_core_test/src/http/mod.rs b/sdk/core/azure_core_test/src/http/mod.rs new file mode 100644 index 0000000000..ed45d335b7 --- /dev/null +++ b/sdk/core/azure_core_test/src/http/mod.rs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! HTTP testing utilities. +mod clients; + +pub use clients::*; diff --git a/sdk/core/azure_core_test/src/lib.rs b/sdk/core/azure_core_test/src/lib.rs index 0c19e2bf70..7893551f01 100644 --- a/sdk/core/azure_core_test/src/lib.rs +++ b/sdk/core/azure_core_test/src/lib.rs @@ -3,14 +3,14 @@ #![doc = include_str!("../README.md")] -mod credential; +pub mod credentials; +pub mod http; pub mod proxy; pub mod recorded; mod recording; use azure_core::Error; pub use azure_core::{error::ErrorKind, test::TestMode}; -pub use credential::*; pub use proxy::{matchers::*, sanitizers::*}; pub use recording::*; use std::path::{Path, PathBuf}; diff --git a/sdk/core/azure_core_test/src/recording.rs b/sdk/core/azure_core_test/src/recording.rs index 9495cd0b55..a1a7fbaf44 100644 --- a/sdk/core/azure_core_test/src/recording.rs +++ b/sdk/core/azure_core_test/src/recording.rs @@ -5,6 +5,7 @@ // cspell:ignore csprng seedable tpbwhbkhckmk use crate::{ + credentials::{self, MockCredential}, proxy::{ client::{ Client, ClientAddSanitizerOptions, ClientRemoveSanitizersOptions, @@ -14,7 +15,7 @@ use crate::{ policy::RecordingPolicy, Proxy, RecordingId, }, - Matcher, MockCredential, Sanitizer, + Matcher, Sanitizer, }; use azure_core::{ base64, @@ -24,7 +25,6 @@ use azure_core::{ test::TestMode, ClientOptions, Header, }; -use azure_identity::DefaultAzureCredential; use rand::{ distributions::{Alphanumeric, DistString, Distribution, Standard}, Rng, SeedableRng, @@ -83,7 +83,7 @@ impl Recording { pub fn credential(&self) -> Arc { match self.test_mode { TestMode::Playback => Arc::new(MockCredential) as Arc, - _ => DefaultAzureCredential::new().map_or_else( + _ => credentials::from_env(None).map_or_else( |err| panic!("failed to create DefaultAzureCredential: {err}"), |cred| cred as Arc, ), diff --git a/sdk/identity/azure_identity/CHANGELOG.md b/sdk/identity/azure_identity/CHANGELOG.md index bd5012b81d..763d54f84e 100644 --- a/sdk/identity/azure_identity/CHANGELOG.md +++ b/sdk/identity/azure_identity/CHANGELOG.md @@ -1,5 +1,20 @@ # Release History +## 0.32.0 (Unreleased) + +### Features Added + +- Added `AzurePipelinesCredential`. + +### Breaking Changes + +- `ClientAssertionCredential` constructors moved some parameters to an `Option` parameter. +- `WorkloadIdentityCredential` constructors moved some parameters to an `Option` parameter. + +### Bugs Fixed + +### Other Changes + ## 0.22.0 (2025-02-18) ### Features Added diff --git a/sdk/identity/azure_identity/Cargo.toml b/sdk/identity/azure_identity/Cargo.toml index 3274883efc..741379a647 100644 --- a/sdk/identity/azure_identity/Cargo.toml +++ b/sdk/identity/azure_identity/Cargo.toml @@ -13,18 +13,18 @@ categories = ["api-bindings"] edition.workspace = true [dependencies] -azure_core.workspace = true async-lock.workspace = true -oauth2.workspace = true -url.workspace = true +async-trait.workspace = true +azure_core.workspace = true futures.workspace = true +oauth2.workspace = true +openssl = { workspace = true, optional = true } +pin-project.workspace = true serde.workspace = true time.workspace = true tracing.workspace = true -async-trait.workspace = true -openssl = { workspace = true, optional = true } -pin-project.workspace = true typespec_client_core = { workspace = true, features = ["derive"] } +url.workspace = true [target.'cfg(not(target_arch = "wasm32"))'.dependencies] async-process.workspace = true @@ -33,13 +33,14 @@ async-process.workspace = true tz-rs = { workspace = true, optional = true } [dev-dependencies] +azure_core_test.workspace = true azure_security_keyvault_secrets = { path = "../../keyvault/azure_security_keyvault_secrets" } +clap.workspace = true reqwest.workspace = true -tokio.workspace = true -tracing-subscriber.workspace = true serde_test.workspace = true serial_test.workspace = true -clap.workspace = true +tokio.workspace = true +tracing-subscriber.workspace = true [features] default = ["reqwest", "old_azure_cli"] diff --git a/sdk/identity/azure_identity/examples/specific_credential.rs b/sdk/identity/azure_identity/examples/specific_credential.rs index a3e22c1a04..2cf3419c92 100644 --- a/sdk/identity/azure_identity/examples/specific_credential.rs +++ b/sdk/identity/azure_identity/examples/specific_credential.rs @@ -136,7 +136,7 @@ impl SpecificAzureCredential { ) })?, azure_credential_kinds::WORKLOAD_IDENTITY => { - WorkloadIdentityCredential::from_env(options) + WorkloadIdentityCredential::from_env(Some(options.into())) .map(SpecificAzureCredentialKind::WorkloadIdentity) .with_context(ErrorKind::Credential, || { format!( diff --git a/sdk/identity/azure_identity/src/azure_pipelines_credential.rs b/sdk/identity/azure_identity/src/azure_pipelines_credential.rs new file mode 100644 index 0000000000..ea5a88b823 --- /dev/null +++ b/sdk/identity/azure_identity/src/azure_pipelines_credential.rs @@ -0,0 +1,323 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use crate::{ + ClientAssertion, ClientAssertionCredential, ClientAssertionCredentialOptions, + TokenCredentialOptions, +}; +use azure_core::{ + credentials::{AccessToken, Secret, TokenCredential}, + error::ErrorKind, + headers::{FromHeaders, HeaderName, Headers, AUTHORIZATION}, + HttpClient, Method, Request, StatusCode, Url, +}; +use serde::Deserialize; +use std::{convert::Infallible, fmt, sync::Arc}; + +// cspell:ignore fedauthredirect oidcrequesturi +const OIDC_VARIABLE_NAME: &str = "SYSTEM_OIDCREQUESTURI"; +const OIDC_VERSION: &str = "7.1"; +const TFS_FEDAUTHREDIRECT_HEADER: HeaderName = HeaderName::from_static("x-tfs-fedauthredirect"); + +// TODO: https://github.com/Azure/azure-sdk-for-rust/issues/682 +const ALLOWED_HEADERS: &[&str] = &["x-msedge-ref", "x-vss-e2eid"]; + +#[derive(Debug)] +pub struct AzurePipelinesCredential(ClientAssertionCredential); + +/// Options for constructing a new [`AzurePipelinesCredential`]. +#[derive(Debug, Default)] +pub struct AzurePipelinesCredentialOptions { + /// Options for the [`ClientAssertionCredential`] used by the [`AzurePipelinesCredential`]. + pub credential_options: ClientAssertionCredentialOptions, +} + +// TODO: Should probably remove this once we consolidate and unify credentials. +impl From for AzurePipelinesCredentialOptions { + fn from(value: TokenCredentialOptions) -> Self { + Self { + credential_options: ClientAssertionCredentialOptions { + credential_options: value, + ..Default::default() + }, + } + } +} + +impl AzurePipelinesCredential { + /// Creates a new [`AzurePipelinesCredential`] for connecting to resources from Azure Pipelines. + pub fn new( + tenant_id: String, + client_id: String, + service_connection_id: &str, + system_access_token: T, + options: Option, + ) -> azure_core::Result> + where + T: Into, + { + let system_access_token = system_access_token.into(); + + crate::validate_tenant_id(&tenant_id)?; + crate::validate_not_empty(&client_id, "no client ID specified")?; + crate::validate_not_empty(service_connection_id, "no service connection ID specified")?; + crate::validate_not_empty( + system_access_token.secret(), + "no system access token specified", + )?; + + let options = options.unwrap_or_default(); + let env = options.credential_options.credential_options.env(); + let endpoint = env.var(OIDC_VARIABLE_NAME).map_err(|err| azure_core::Error::full(ErrorKind::Credential, err, format!("no value for environment variable {OIDC_VARIABLE_NAME}. This should be set by Azure Pipelines")))?; + let mut endpoint: Url = endpoint.parse().map_err(|err| { + azure_core::Error::full( + ErrorKind::Credential, + err, + format!("invalid URL for environment variable {OIDC_VARIABLE_NAME}"), + ) + })?; + endpoint + .query_pairs_mut() + .append_pair("api-version", OIDC_VERSION) + .append_pair("serviceConnectionId", service_connection_id); + let client = Client { + endpoint, + http_client: options.credential_options.credential_options.http_client(), + system_access_token, + }; + let credential = ClientAssertionCredential::new_exclusive( + tenant_id, + client_id, + client, + Some(options.credential_options), + )?; + + Ok(Arc::new(Self(credential))) + } +} + +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +impl TokenCredential for AzurePipelinesCredential { + async fn get_token(&self, scopes: &[&str]) -> azure_core::Result { + self.0.get_token(scopes).await + } + + async fn clear_cache(&self) -> azure_core::Result<()> { + self.0.clear_cache().await + } +} + +#[derive(Debug)] +struct Client { + endpoint: Url, + http_client: Arc, + system_access_token: Secret, +} + +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +impl ClientAssertion for Client { + async fn secret(&self) -> azure_core::Result { + let mut req = Request::new(self.endpoint.clone(), Method::Post); + req.insert_header( + AUTHORIZATION, + String::from("Bearer ") + self.system_access_token.secret(), + ); + req.insert_header(TFS_FEDAUTHREDIRECT_HEADER, "Suppress"); + + // TODO: Consider defining and using azure_identity-specific pipeline, or even from azure_core. + let resp = self.http_client.execute_request(&req).await?; + if resp.status() != StatusCode::Ok { + let status_code = resp.status(); + let err_headers: ErrorHeaders = resp.headers().get()?; + + return Err( + azure_core::Error::message( + ErrorKind::http_response(status_code, Some(status_code.canonical_reason().to_string())), + format!("{status_code} response from the OIDC endpoint. Check service connection ID and pipeline configuration. {err_headers}"), + ) + ); + } + + let assertion: Assertion = resp.into_json_body().await?; + Ok(assertion.oidc_token.secret().to_string()) + } +} + +#[derive(Debug, Deserialize)] +struct Assertion { + #[serde(rename = "oidcToken")] + oidc_token: Secret, +} + +#[derive(Debug)] +struct ErrorHeaders { + msedge_ref: Option, + vss_e2eid: Option, +} + +const MSEDGE_REF: HeaderName = HeaderName::from_static("x-msedge-ref"); +const VSS_E2EID: HeaderName = HeaderName::from_static("x-vss-e2eid"); + +impl FromHeaders for ErrorHeaders { + type Error = Infallible; + + fn header_names() -> &'static [&'static str] { + ALLOWED_HEADERS + } + + fn from_headers(headers: &Headers) -> Result, Self::Error> { + Ok(Some(Self { + msedge_ref: headers.get_optional_string(&MSEDGE_REF), + vss_e2eid: headers.get_optional_string(&VSS_E2EID), + })) + } +} + +impl fmt::Display for ErrorHeaders { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + r#"{{ x-msedge-ref: "{}", x-vss-e2eid: "{}" }}"#, + self.msedge_ref.as_ref().map_or("", AsRef::as_ref), + self.vss_e2eid.as_ref().map_or("", AsRef::as_ref), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::env::Env; + use azure_core::{Bytes, Response}; + use azure_core_test::http::MockHttpClient; + use futures::FutureExt as _; + + #[test] + fn param_errors() { + assert!(AzurePipelinesCredential::new("".into(), "".into(), "", "", None).is_err()); + assert!(AzurePipelinesCredential::new("_".into(), "".into(), "", "", None).is_err()); + assert!(AzurePipelinesCredential::new("a".into(), "".into(), "", "", None).is_err()); + assert!(AzurePipelinesCredential::new("a".into(), "b".into(), "", "", None).is_err()); + assert!(AzurePipelinesCredential::new("a".into(), "b".into(), "c", "", None).is_err()); + + let options = AzurePipelinesCredentialOptions { + credential_options: ClientAssertionCredentialOptions { + credential_options: TokenCredentialOptions { + env: Env::from(&[(OIDC_VARIABLE_NAME, "http://localhost/get_token")][..]), + ..Default::default() + }, + ..Default::default() + }, + }; + assert!( + AzurePipelinesCredential::new("a".into(), "b".into(), "c", "d", Some(options)).is_ok() + ); + } + + #[tokio::test] + async fn error_headers() { + let mock_client = MockHttpClient::new(|req| { + assert_eq!( + req.url().as_str(), + "http://localhost/get_token?api-version=7.1&serviceConnectionId=c" + ); + let mut headers = Headers::new(); + headers.insert(MSEDGE_REF, "foo"); + headers.insert(VSS_E2EID, "bar"); + + async move { + Ok(Response::from_bytes( + StatusCode::Forbidden, + headers, + Vec::new(), + )) + } + .boxed() + }); + let options = AzurePipelinesCredentialOptions { + credential_options: ClientAssertionCredentialOptions { + credential_options: TokenCredentialOptions { + env: Env::from(&[(OIDC_VARIABLE_NAME, "http://localhost/get_token")][..]), + http_client: Arc::new(mock_client), + ..Default::default() + }, + ..Default::default() + }, + }; + let credential = + AzurePipelinesCredential::new("a".into(), "b".into(), "c", "d", Some(options)) + .expect("valid AzurePipelinesCredential"); + assert!(matches!( + credential.get_token(&["default"]).await, + Err(err) if matches!( + err.kind(), + ErrorKind::HttpResponse { status, .. } + if *status == StatusCode::Forbidden && + err.to_string().contains("foo") && + err.to_string().contains("bar"), + ) + )); + } + + #[tokio::test] + async fn mock_request() { + let mock_client = MockHttpClient::new(|req| { + async move { + if req.url().as_str() + == "http://localhost/get_token?api-version=7.1&serviceConnectionId=c" + { + assert!(matches!( + req.headers().get_str(&AUTHORIZATION), + Ok(value) if value == "Bearer d", + )); + assert!(matches!( + req.headers().get_str(&TFS_FEDAUTHREDIRECT_HEADER), + Ok(value) if value == "Suppress", + )); + + let mut headers = Headers::new(); + headers.insert(MSEDGE_REF, "foo"); + headers.insert(VSS_E2EID, "bar"); + + return Ok(Response::from_bytes( + StatusCode::Ok, + headers, + Bytes::from_static(br#"{"oidcToken":"baz"}"#), + )); + } + + if req.url().as_str() == "https://login.microsoftonline.com/a/oauth2/v2.0/token" { + return Ok(Response::from_bytes( + StatusCode::Ok, + Headers::new(), + Bytes::from_static( + br#"{"token_type":"test","expires_in":0,"ext_expires_in":0,"access_token":"qux"}"#, + ), + )); + } + + panic!("not supported") + }.boxed() + }); + let options = AzurePipelinesCredentialOptions { + credential_options: ClientAssertionCredentialOptions { + credential_options: TokenCredentialOptions { + env: Env::from(&[(OIDC_VARIABLE_NAME, "http://localhost/get_token")][..]), + http_client: Arc::new(mock_client), + ..Default::default() + }, + ..Default::default() + }, + }; + let credential = + AzurePipelinesCredential::new("a".into(), "b".into(), "c", "d", Some(options)) + .expect("valid AzurePipelinesCredential"); + let secret = credential + .get_token(&["default"]) + .await + .expect("valid response"); + assert_eq!(secret.token.secret(), "qux"); + } +} diff --git a/sdk/identity/azure_identity/src/credentials/client_assertion_credentials.rs b/sdk/identity/azure_identity/src/credentials/client_assertion_credentials.rs index 08b77f5a6d..51941434be 100644 --- a/sdk/identity/azure_identity/src/credentials/client_assertion_credentials.rs +++ b/sdk/identity/azure_identity/src/credentials/client_assertion_credentials.rs @@ -5,7 +5,6 @@ use crate::{credentials::cache::TokenCache, federated_credentials_flow, TokenCre use azure_core::{ credentials::{AccessToken, TokenCredential}, error::{ErrorKind, ResultExt}, - HttpClient, Url, }; use std::{fmt::Debug, str, sync::Arc, time::Duration}; use time::OffsetDateTime; @@ -16,12 +15,30 @@ const AZURE_CLIENT_ID_ENV_KEY: &str = "AZURE_CLIENT_ID"; /// Enables authentication of a Microsoft Entra service principal using a signed client assertion. #[derive(Debug)] pub struct ClientAssertionCredential { - http_client: Arc, - authority_host: Url, tenant_id: String, client_id: String, assertion: C, cache: TokenCache, + options: ClientAssertionCredentialOptions, +} + +/// Options for constructing a new [`ClientAssertionCredential`]. +#[derive(Debug, Default)] +pub struct ClientAssertionCredentialOptions { + /// Additional tenants for which the credential may acquire tokens. + /// + /// Add the wildcard value "*" to allow the credential to acquire tokens for any tenant in which the application is registered. + pub additionally_allowed_tenants: Vec, + + /// Should be set true only by applications authenticating in disconnected clouds, or private clouds such as Azure Stack. + /// + /// It determines whether the credential requests Microsoft Entra instance metadata + /// from before authenticating. Setting this to true will skip this request, making + /// the application responsible for ensuring the configured authority is valid and trustworthy. + pub disable_instance_discovery: bool, + + /// Options for constructing credentials. + pub credential_options: TokenCredentialOptions, } #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] @@ -35,18 +52,13 @@ pub trait ClientAssertion: Send + Sync + Debug { impl ClientAssertionCredential { /// Create a new `ClientAssertionCredential`. pub fn new( - http_client: Arc, - authority_host: Url, tenant_id: String, client_id: String, assertion: C, + options: Option, ) -> azure_core::Result> { Ok(Arc::new(Self::new_exclusive( - http_client, - authority_host, - tenant_id, - client_id, - assertion, + tenant_id, client_id, assertion, options, )?)) } @@ -54,19 +66,17 @@ impl ClientAssertionCredential { /// `Arc`. Intended for use by other credentials in the crate that will /// themselves be protected by an `Arc`. pub(crate) fn new_exclusive( - http_client: Arc, - authority_host: Url, tenant_id: String, client_id: String, assertion: C, + options: Option, ) -> azure_core::Result { Ok(Self { - http_client, - authority_host, tenant_id, client_id, assertion, cache: TokenCache::new(), + options: options.unwrap_or_default(), }) } @@ -77,10 +87,10 @@ impl ClientAssertionCredential { /// * `AZURE_TENANT_ID` /// * `AZURE_CLIENT_ID` pub fn from_env( - options: impl Into, assertion: C, + options: Option, ) -> azure_core::Result> { - Ok(Arc::new(Self::from_env_exclusive(options, assertion)?)) + Ok(Arc::new(Self::from_env_exclusive(assertion, options)?)) } /// Create a new `ClientAssertionCredential` from environment variables, @@ -92,13 +102,11 @@ impl ClientAssertionCredential { /// * `AZURE_TENANT_ID` /// * `AZURE_CLIENT_ID` pub(crate) fn from_env_exclusive( - options: impl Into, assertion: C, + options: Option, ) -> azure_core::Result { - let options = options.into(); - let http_client = options.http_client(); - let authority_host = options.authority_host()?; - let env = options.env(); + let options = options.unwrap_or_default(); + let env = options.credential_options.env(); let tenant_id = env.var(AZURE_TENANT_ID_ENV_KEY) .with_context(ErrorKind::Credential, || { @@ -116,24 +124,19 @@ impl ClientAssertionCredential { ) })?; - ClientAssertionCredential::new_exclusive( - http_client, - authority_host, - tenant_id, - client_id, - assertion, - ) + ClientAssertionCredential::new_exclusive(tenant_id, client_id, assertion, Some(options)) } async fn get_token(&self, scopes: &[&str]) -> azure_core::Result { let token = self.assertion.secret().await?; + let credential_options = &self.options.credential_options; let res: AccessToken = federated_credentials_flow::authorize( - self.http_client.clone(), + credential_options.http_client().clone(), &self.client_id, &token, scopes, &self.tenant_id, - &self.authority_host, + &credential_options.authority_host()?, ) .await .map(|r| { diff --git a/sdk/identity/azure_identity/src/credentials/default_credentials.rs b/sdk/identity/azure_identity/src/credentials/default_azure_credentials.rs similarity index 100% rename from sdk/identity/azure_identity/src/credentials/default_credentials.rs rename to sdk/identity/azure_identity/src/credentials/default_azure_credentials.rs diff --git a/sdk/identity/azure_identity/src/credentials/mod.rs b/sdk/identity/azure_identity/src/credentials/mod.rs index 02570faa31..7f41fffe8d 100644 --- a/sdk/identity/azure_identity/src/credentials/mod.rs +++ b/sdk/identity/azure_identity/src/credentials/mod.rs @@ -15,7 +15,7 @@ mod cache; mod client_assertion_credentials; #[cfg(feature = "client_certificate")] mod client_certificate_credentials; -mod default_credentials; +mod default_azure_credentials; mod imds_managed_identity_credentials; mod options; mod virtual_machine_managed_identity_credential; @@ -27,7 +27,7 @@ pub use azure_cli_credentials::*; pub use client_assertion_credentials::*; #[cfg(feature = "client_certificate")] pub use client_certificate_credentials::*; -pub use default_credentials::*; +pub use default_azure_credentials::*; pub use imds_managed_identity_credentials::ImdsId; pub(crate) use imds_managed_identity_credentials::*; pub use options::*; diff --git a/sdk/identity/azure_identity/src/credentials/options.rs b/sdk/identity/azure_identity/src/credentials/options.rs index a7f35441ea..366efa2c1b 100644 --- a/sdk/identity/azure_identity/src/credentials/options.rs +++ b/sdk/identity/azure_identity/src/credentials/options.rs @@ -15,12 +15,13 @@ const AZURE_PUBLIC_CLOUD: &str = "https://login.microsoftonline.com"; /// requests to Azure Active Directory. #[derive(Debug, Clone)] pub struct TokenCredentialOptions { - env: Env, - http_client: Arc, - authority_host: String, + pub(crate) env: Env, + pub(crate) http_client: Arc, + pub(crate) authority_host: String, } /// The default token credential options. +/// /// The authority host is taken from the `AZURE_AUTHORITY_HOST` environment variable if set and a valid URL. /// If not, the default authority host is `https://login.microsoftonline.com` for the Azure public cloud. impl Default for TokenCredentialOptions { diff --git a/sdk/identity/azure_identity/src/credentials/workload_identity_credentials.rs b/sdk/identity/azure_identity/src/credentials/workload_identity_credentials.rs index 5d8b6b65e9..f69375d28f 100644 --- a/sdk/identity/azure_identity/src/credentials/workload_identity_credentials.rs +++ b/sdk/identity/azure_identity/src/credentials/workload_identity_credentials.rs @@ -1,12 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -use crate::TokenCredentialOptions; use async_lock::{RwLock, RwLockUpgradableReadGuard}; use azure_core::{ credentials::{AccessToken, Secret, TokenCredential}, error::{ErrorKind, ResultExt}, - Error, HttpClient, Url, + Error, }; use futures::channel::oneshot; use std::{ @@ -16,7 +15,10 @@ use std::{ time::{Duration, Instant}, }; -use super::client_assertion_credentials::{ClientAssertion, ClientAssertionCredential}; +use super::{ + client_assertion_credentials::{ClientAssertion, ClientAssertionCredential}, + ClientAssertionCredentialOptions, TokenCredentialOptions, +}; const AZURE_FEDERATED_TOKEN_FILE: &str = "AZURE_FEDERATED_TOKEN_FILE"; const AZURE_FEDERATED_TOKEN: &str = "AZURE_FEDERATED_TOKEN"; @@ -28,25 +30,31 @@ const AZURE_FEDERATED_TOKEN: &str = "AZURE_FEDERATED_TOKEN"; #[derive(Debug)] pub struct WorkloadIdentityCredential(ClientAssertionCredential); +/// Options for constructing a new [`WorkloadIdentityCredential`]. +#[derive(Debug, Default)] +pub struct WorkloadIdentityCredentialOptions { + /// Options for the [`ClientAssertionCredential`] used by the [`WorkloadIdentityCredential`]. + pub credential_options: ClientAssertionCredentialOptions, +} + impl WorkloadIdentityCredential { /// Create a new `WorkloadIdentityCredential`. pub fn new( - http_client: Arc, - authority_host: Url, tenant_id: String, client_id: String, token: T, + options: Option, ) -> azure_core::Result> where T: Into, { + let options = options.unwrap_or_default(); Ok(Arc::new(Self( ClientAssertionCredential::::new_exclusive( - http_client, - authority_host, tenant_id, client_id, Token::Value(token.into()), + Some(options.credential_options), )?, ))) } @@ -59,16 +67,19 @@ impl WorkloadIdentityCredential { /// * `AZURE_CLIENT_ID` /// * `AZURE_FEDERATED_TOKEN` or `AZURE_FEDERATED_TOKEN_FILE` pub fn from_env( - options: impl Into, + options: Option, ) -> azure_core::Result> { - let options = options.into(); - let env = options.env(); + let options = options.unwrap_or_default(); + let env = options.credential_options.credential_options.env(); if let Ok(token) = env .var(AZURE_FEDERATED_TOKEN) .map_kind(ErrorKind::Credential) { return Ok(Arc::new(Self( - ClientAssertionCredential::from_env_exclusive(options, Token::Value(token.into()))?, + ClientAssertionCredential::from_env_exclusive( + Token::Value(token.into()), + Some(options.credential_options), + )?, ))); } @@ -78,8 +89,8 @@ impl WorkloadIdentityCredential { { return Ok(Arc::new(Self( ClientAssertionCredential::from_env_exclusive( - options, Token::with_file(token_file.as_ref())?, + Some(options.credential_options), )?, ))); } @@ -94,11 +105,23 @@ impl WorkloadIdentityCredential { #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] impl TokenCredential for WorkloadIdentityCredential { async fn get_token(&self, scopes: &[&str]) -> azure_core::Result { - TokenCredential::get_token(&self.0, scopes).await + self.0.get_token(scopes).await } async fn clear_cache(&self) -> azure_core::Result<()> { - TokenCredential::clear_cache(&self.0).await + self.0.clear_cache().await + } +} + +// TODO: Should probably remove this once we consolidate and unify credentials. +impl From for WorkloadIdentityCredentialOptions { + fn from(value: TokenCredentialOptions) -> Self { + Self { + credential_options: ClientAssertionCredentialOptions { + credential_options: value, + ..Default::default() + }, + } } } diff --git a/sdk/identity/azure_identity/src/federated_credentials_flow/mod.rs b/sdk/identity/azure_identity/src/federated_credentials_flow/mod.rs index 56cf597be0..1402ec4548 100644 --- a/sdk/identity/azure_identity/src/federated_credentials_flow/mod.rs +++ b/sdk/identity/azure_identity/src/federated_credentials_flow/mod.rs @@ -10,7 +10,6 @@ use azure_core::{ }; use response::LoginResponse; use std::sync::Arc; -use tracing::{debug, error}; use url::form_urlencoded; /// Authorize the client using the federated credentials flow. @@ -47,13 +46,10 @@ pub async fn authorize( req.set_body(encoded); let rsp: Response = http_client.execute_request(&req).await?; let rsp_status = rsp.status(); - debug!("rsp_status == {:?}", rsp_status); if rsp_status.is_success() { rsp.into_json_body().await } else { let rsp_body = rsp.into_raw_body().collect().await?; - let text = std::str::from_utf8(&rsp_body)?; - error!("rsp_body == {:?}", text); Err(http_response_from_body(rsp_status, &rsp_body).into_error()) } } diff --git a/sdk/identity/azure_identity/src/lib.rs b/sdk/identity/azure_identity/src/lib.rs index a6883758a8..98dc3b8cf7 100644 --- a/sdk/identity/azure_identity/src/lib.rs +++ b/sdk/identity/azure_identity/src/lib.rs @@ -4,6 +4,7 @@ #![doc = include_str!("../README.md")] mod authorization_code_flow; +mod azure_pipelines_credential; mod credentials; mod env; mod federated_credentials_flow; @@ -11,4 +12,45 @@ mod oauth2_http_client; mod refresh_token; mod timeout; -pub use crate::credentials::*; +use azure_core::{error::ErrorKind, Error, Result}; +pub use azure_pipelines_credential::*; +pub use credentials::*; +use std::borrow::Cow; + +fn validate_not_empty(value: &str, message: C) -> Result<()> +where + C: Into>, +{ + if value.is_empty() { + return Err(Error::message(ErrorKind::Credential, message)); + } + + Ok(()) +} + +#[test] +fn test_validate_not_empty() { + assert!(validate_not_empty("", "it's empty").is_err()); + assert!(validate_not_empty(" ", "it's not empty").is_ok()); + assert!(validate_not_empty("not empty", "it's not empty").is_ok()); +} + +fn validate_tenant_id(tenant_id: &str) -> Result<()> { + if tenant_id.is_empty() + || !tenant_id + .chars() + .all(|c| c.is_alphanumeric() || c == '.' || c == '-') + { + return Err(Error::message(ErrorKind::Credential, "invalid tenantID. You can locate your tenantID by following the instructions listed here: https://learn.microsoft.com/partner-center/find-ids-and-domain-names")); + } + + Ok(()) +} + +#[test] +fn test_validate_tenant_id() { + assert!(validate_tenant_id("").is_err()); + assert!(validate_tenant_id("invalid_tenant@id").is_err()); + assert!(validate_tenant_id("A-1.z").is_ok()); + assert!(validate_tenant_id("7b795fb9-09d3-42f4-a494-38864f99ba3c").is_ok()); +}