Skip to content

Commit a957672

Browse files
authored
[Identity] Update async cert credential algorithm (Azure#39761)
Signed-off-by: Paul Van Eck <[email protected]>
1 parent ab9a180 commit a957672

File tree

5 files changed

+101
-11
lines changed

5 files changed

+101
-11
lines changed

sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
6161
self._custom_cache = True
6262
else:
6363
self._custom_cache = False
64+
self._is_adfs = self._tenant_id.lower() == "adfs"
6465

6566
def _get_cache(self, **kwargs: Any) -> TokenCache:
6667
cache = self._cae_cache if kwargs.get("enable_cae") else self._cache
@@ -239,7 +240,16 @@ def _get_jwt_assertion_request(self, scopes: Iterable[str], assertion: str, **kw
239240

240241
def _get_client_certificate_assertion(self, certificate: AadClientCertificate, **kwargs: Any) -> str:
241242
now = int(time.time())
242-
header = json.dumps({"typ": "JWT", "alg": "RS256", "x5t": certificate.thumbprint}).encode("utf-8")
243+
headers = {"typ": "JWT"}
244+
if self._is_adfs:
245+
# Maintain backwards compatibility with older versions of ADFS.
246+
headers["alg"] = "RS256"
247+
headers["x5t"] = certificate.thumbprint
248+
else:
249+
headers["alg"] = "PS256"
250+
headers["x5t#S256"] = certificate.sha256_thumbprint
251+
252+
jwt_header = json.dumps(headers).encode("utf-8")
243253
payload = json.dumps(
244254
{
245255
"jti": str(uuid4()),
@@ -250,8 +260,8 @@ def _get_client_certificate_assertion(self, certificate: AadClientCertificate, *
250260
"exp": now + (60 * 30),
251261
}
252262
).encode("utf-8")
253-
jws = base64.urlsafe_b64encode(header) + b"." + base64.urlsafe_b64encode(payload)
254-
signature = certificate.sign(jws)
263+
jws = base64.urlsafe_b64encode(jwt_header) + b"." + base64.urlsafe_b64encode(payload)
264+
signature = certificate.sign_ps256(jws) if not self._is_adfs else certificate.sign_rs256(jws)
255265
jwt_bytes = jws + b"." + base64.urlsafe_b64encode(signature)
256266
return jwt_bytes.decode("utf-8")
257267

sdk/identity/azure-identity/azure/identity/_internal/aadclient_certificate.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ def __init__(self, pem_bytes: bytes, password: Optional[bytes] = None) -> None:
2626

2727
cert = x509.load_pem_x509_certificate(pem_bytes, default_backend())
2828
fingerprint = cert.fingerprint(hashes.SHA1()) # nosec
29+
sha256_fingerprint = cert.fingerprint(hashes.SHA256())
2930
self._thumbprint = base64.urlsafe_b64encode(fingerprint).decode("utf-8")
31+
self._sha256_thumbprint = base64.urlsafe_b64encode(sha256_fingerprint).decode("utf-8")
3032

3133
@property
3234
def thumbprint(self) -> str:
@@ -36,11 +38,36 @@ def thumbprint(self) -> str:
3638
"""
3739
return self._thumbprint
3840

39-
def sign(self, plaintext: bytes) -> bytes:
41+
@property
42+
def sha256_thumbprint(self) -> str:
43+
"""The certificate's SHA256 thumbprint as a base64url-encoded string.
44+
45+
:rtype: str
46+
"""
47+
return self._sha256_thumbprint
48+
49+
def sign_rs256(self, plaintext: bytes) -> bytes:
4050
"""Sign bytes using RS256.
4151
4252
:param bytes plaintext: Bytes to sign.
4353
:return: The signature.
4454
:rtype: bytes
4555
"""
4656
return self._private_key.sign(plaintext, padding.PKCS1v15(), hashes.SHA256())
57+
58+
def sign_ps256(self, plaintext: bytes) -> bytes:
59+
"""Sign bytes using PS256.
60+
61+
:param bytes plaintext: Bytes to sign.
62+
:return: The signature.
63+
:rtype: bytes
64+
"""
65+
hash_alg = hashes.SHA256()
66+
67+
# Note: For PS265, the salt length should match the hash output size, so we use the hash algorithm's
68+
# digest_size property to get the correct value.
69+
return self._private_key.sign(
70+
plaintext,
71+
padding.PSS(mgf=padding.MGF1(hash_alg), salt_length=hash_alg.digest_size),
72+
hash_alg,
73+
)

sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ class CertificateCredential(AsyncContextManager, GetTokenMixin):
2121
2222
:param str tenant_id: ID of the service principal's tenant. Also called its 'directory' ID.
2323
:param str client_id: The service principal's client ID
24-
:param str certificate_path: Path to a PEM-encoded certificate file including the private key. If not provided,
25-
`certificate_data` is required.
24+
:param str certificate_path: Optional path to a certificate file in PEM or PKCS12 format, including the private
25+
key. If not provided, **certificate_data** is required.
2626
2727
:keyword str authority: Authority of a Microsoft Entra endpoint, for example 'login.microsoftonline.com',
2828
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
2929
defines authorities for other clouds.
30-
:keyword bytes certificate_data: The bytes of a certificate in PEM format, including the private key
30+
:keyword bytes certificate_data: The bytes of a certificate in PEM or PKCS12 format, including the private key.
3131
:keyword password: The certificate's password. If a unicode string, it will be encoded as UTF-8. If the certificate
3232
requires a different encoding, pass appropriately encoded bytes instead.
3333
:paramtype password: str or bytes

sdk/identity/azure-identity/tests/test_certificate_credential.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,56 @@ def validate_jwt(request, client_id, cert_bytes, cert_password, expect_x5c=False
297297
cert.public_key().verify(signature, signed_part.encode("utf-8"), padding.PKCS1v15(), hashes.SHA256())
298298

299299

300+
def validate_jwt_ps256(request, client_id, cert_bytes, cert_password, expect_x5c=False):
301+
"""Validate the request meets Microsoft Entra ID's expectations for a client credential grant using a certificate, as documented
302+
at https://learn.microsoft.com/entra/identity-platform/certificate-credentials
303+
"""
304+
305+
try:
306+
cert = x509.load_pem_x509_certificate(cert_bytes, default_backend())
307+
except ValueError:
308+
if cert_password:
309+
if isinstance(cert_password, str):
310+
cert_password = cert_password.encode("utf-8")
311+
cert_bytes = load_pkcs12_certificate(cert_bytes, cert_password).pem_bytes
312+
cert = x509.load_pem_x509_certificate(cert_bytes, default_backend())
313+
314+
# jwt is of the form 'header.payload.signature'; 'signature' is 'header.payload' signed with cert's private key
315+
jwt = request.body["client_assertion"]
316+
if isinstance(jwt, bytes):
317+
jwt = jwt.decode("utf-8")
318+
header, payload, signature = (urlsafeb64_decode(s) for s in jwt.split("."))
319+
signed_part = jwt[: jwt.rfind(".")]
320+
321+
claims = json.loads(payload.decode("utf-8"))
322+
assert claims["aud"] == request.url
323+
assert claims["iss"] == claims["sub"] == client_id
324+
325+
deserialized_header = json.loads(header.decode("utf-8"))
326+
assert deserialized_header["alg"] == "PS256"
327+
assert deserialized_header["typ"] == "JWT"
328+
if expect_x5c:
329+
# x5c should have all the certs in the file, in order, in PEM format minus headers and footers
330+
pem_lines = cert_bytes.decode("utf-8").splitlines()
331+
header = "-----BEGIN CERTIFICATE-----"
332+
assert len(deserialized_header["x5c"]) == pem_lines.count(header)
333+
334+
# concatenate the PEM file's certs, removing headers and footers
335+
chain_start = pem_lines.index(header)
336+
pem_chain_content = "".join(line for line in pem_lines[chain_start:] if not line.startswith("-" * 5))
337+
assert "".join(deserialized_header["x5c"]) == pem_chain_content, "JWT's x5c claim contains unexpected content"
338+
else:
339+
assert "x5c" not in deserialized_header
340+
assert urlsafeb64_decode(deserialized_header["x5t#S256"]) == cert.fingerprint(hashes.SHA256()) # nosec
341+
342+
cert.public_key().verify(
343+
signature,
344+
signed_part.encode("utf-8"),
345+
padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=hashes.SHA256.digest_size),
346+
hashes.SHA256(),
347+
)
348+
349+
300350
@pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS)
301351
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
302352
def test_token_cache_persistent(cert_path, cert_password, get_token_method):

sdk/identity/azure-identity/tests/test_certificate_credential_async.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from helpers import build_aad_response, mock_response, Request, GET_TOKEN_METHODS
1818
from helpers_async import async_validating_transport, AsyncMockTransport
19-
from test_certificate_credential import ALL_CERTS, EC_CERT_PATH, PEM_CERT_PATH, validate_jwt
19+
from test_certificate_credential import ALL_CERTS, EC_CERT_PATH, PEM_CERT_PATH, validate_jwt, validate_jwt_ps256
2020

2121

2222
def test_non_rsa_key():
@@ -174,19 +174,22 @@ def test_requires_certificate():
174174
@pytest.mark.asyncio
175175
@pytest.mark.parametrize("cert_path,cert_password", ALL_CERTS)
176176
@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS)
177-
async def test_request_body(cert_path, cert_password, get_token_method):
177+
@pytest.mark.parametrize("tenant_id", ("adfs", "tenant"))
178+
async def test_request_body(cert_path, cert_password, get_token_method, tenant_id):
178179
access_token = "***"
179180
authority = "authority.com"
180181
client_id = "client-id"
181182
expected_scope = "scope"
182-
tenant_id = "tenant"
183183

184184
async def mock_send(request, **kwargs):
185185
assert request.body["grant_type"] == "client_credentials"
186186
assert request.body["scope"] == expected_scope
187187

188188
with open(cert_path, "rb") as cert_file:
189-
validate_jwt(request, client_id, cert_file.read(), cert_password)
189+
if tenant_id == "adfs":
190+
validate_jwt(request, client_id, cert_file.read(), cert_password)
191+
else:
192+
validate_jwt_ps256(request, client_id, cert_file.read(), cert_password)
190193

191194
return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token})
192195

0 commit comments

Comments
 (0)