Skip to content

Implement a configurable credentials resolver chain #452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -57,6 +57,10 @@ public List<RuntimeClientPlugin> getClientPlugins(GenerationContext context) {
.build())
// TODO: Initialize with the provider chain?
.nullable(true)
.initialize(writer -> {
writer.addImport("smithy_aws_core.credentials_resolvers", "CredentialsResolverChain");
writer.write("self.aws_credentials_identity_resolver = aws_credentials_identity_resolver or CredentialsResolverChain(config=self)");
})
.build())
.addConfigProperty(REGION)
.authScheme(new Sigv4AuthScheme())
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .chain import CredentialsResolverChain
from .environment import EnvironmentCredentialsResolver
from .imds import IMDSCredentialsResolver
from .static import StaticCredentialsResolver

__all__ = (
"CredentialsResolverChain",
"EnvironmentCredentialsResolver",
"IMDSCredentialsResolver",
"StaticCredentialsResolver",
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence

from smithy_core.aio.interfaces.identity import IdentityResolver
from smithy_core.exceptions import SmithyIdentityException
from smithy_core.interfaces.identity import IdentityProperties

from smithy_aws_core.credentials_resolvers.environment import (
EnvironmentCredentialsSource,
)
from smithy_aws_core.credentials_resolvers.imds import IMDSCredentialsSource
from smithy_aws_core.credentials_resolvers.interfaces import (
AwsCredentialsConfig,
CredentialsSource,
)
from smithy_aws_core.identity import AWSCredentialsIdentity, AWSCredentialsResolver

_DEFAULT_SOURCES: Sequence[CredentialsSource] = (
EnvironmentCredentialsSource(),
IMDSCredentialsSource(),
)


class CredentialsResolverChain(
IdentityResolver[AWSCredentialsIdentity, IdentityProperties]
):
"""Resolves AWS Credentials from an ordered list of credentials sources."""

def __init__(
self,
*,
config: AwsCredentialsConfig,
sources: Sequence[CredentialsSource] = _DEFAULT_SOURCES,
):
self._config = config
self._sources: Sequence[CredentialsSource] = sources
self._credentials_resolver: AWSCredentialsResolver | None = None

async def get_identity(
self, *, identity_properties: IdentityProperties
) -> AWSCredentialsIdentity:
if self._credentials_resolver is not None:
return await self._credentials_resolver.get_identity(
identity_properties=identity_properties
)

for source in self._sources:
if source.is_available(config=self._config):
self._credentials_resolver = source.build_resolver(config=self._config)
return await self._credentials_resolver.get_identity(
identity_properties=identity_properties
)

raise SmithyIdentityException(
"None of the configured credentials sources were able to resolve credentials."
)
Original file line number Diff line number Diff line change
@@ -6,7 +6,12 @@
from smithy_core.exceptions import SmithyIdentityException
from smithy_core.interfaces.identity import IdentityProperties

from ..identity import AWSCredentialsIdentity
from smithy_aws_core.credentials_resolvers.interfaces import (
AwsCredentialsConfig,
CredentialsSource,
)

from ..identity import AWSCredentialsIdentity, AWSCredentialsResolver


class EnvironmentCredentialsResolver(
@@ -41,3 +46,13 @@ async def get_identity(
)

return self._credentials


class EnvironmentCredentialsSource(CredentialsSource):
def is_available(self, config: AwsCredentialsConfig) -> bool:
return (
"AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ
)

def build_resolver(self, config: AwsCredentialsConfig) -> AWSCredentialsResolver:
return EnvironmentCredentialsResolver()
Original file line number Diff line number Diff line change
@@ -17,8 +17,13 @@
from smithy_http.aio import HTTPRequest
from smithy_http.aio.interfaces import HTTPClient

from smithy_aws_core.credentials_resolvers.interfaces import (
AwsCredentialsConfig,
CredentialsSource,
)

from .. import __version__
from ..identity import AWSCredentialsIdentity
from ..identity import AWSCredentialsIdentity, AWSCredentialsResolver

_USER_AGENT_FIELD = Field(
name="User-Agent",
@@ -235,3 +240,14 @@ async def get_identity(
account_id=account_id,
)
return self._credentials


class IMDSCredentialsSource(CredentialsSource):
def is_available(self, config: AwsCredentialsConfig) -> bool:
# IMDS credentials should always be the last in the chain
# We cannot check if they're available without actually making a call
return True

def build_resolver(self, config: AwsCredentialsConfig) -> AWSCredentialsResolver:
# TODO: Configure lower number of retries/lower timeout
return IMDSCredentialsResolver(http_client=config.http_client)
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Protocol

from smithy_http.aio.interfaces import HTTPClient

from smithy_aws_core.identity import AWSCredentialsResolver


class AwsCredentialsConfig(Protocol):
"""Configuration required for resolving credentials."""

http_client: HTTPClient


class CredentialsSource(Protocol):
def is_available(self, config: AwsCredentialsConfig) -> bool:
"""Returns True if credentials are available from this source."""
...

def build_resolver(self, config: AwsCredentialsConfig) -> AWSCredentialsResolver:
"""Builds a credentials resolver for the given configuration."""
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from dataclasses import dataclass
from unittest.mock import Mock

import pytest
from smithy_aws_core.credentials_resolvers import (
CredentialsResolverChain,
IMDSCredentialsResolver,
StaticCredentialsResolver,
)
from smithy_aws_core.credentials_resolvers.environment import (
EnvironmentCredentialsSource,
)
from smithy_aws_core.credentials_resolvers.interfaces import (
AwsCredentialsConfig,
CredentialsSource,
)
from smithy_aws_core.identity import AWSCredentialsIdentity, AWSCredentialsResolver
from smithy_core.exceptions import SmithyIdentityException
from smithy_core.interfaces.identity import IdentityProperties
from smithy_http.aio.interfaces import HTTPClient


@dataclass
class Config:
http_client: HTTPClient

def __init__(self):
self.http_client = Mock(spec=HTTPClient) # type: ignore


async def test_no_sources_resolve():
resolver_chain = CredentialsResolverChain(sources=[], config=Config())
with pytest.raises(SmithyIdentityException):
await resolver_chain.get_identity(identity_properties=IdentityProperties())


async def test_env_credentials_resolver_not_set(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
resolver_chain = CredentialsResolverChain(
sources=[EnvironmentCredentialsSource()], config=Config()
)

with pytest.raises(SmithyIdentityException):
await resolver_chain.get_identity(identity_properties=IdentityProperties())


async def test_env_credentials_resolver_partial(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "akid")
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
resolver_chain = CredentialsResolverChain(
sources=[EnvironmentCredentialsSource()], config=Config()
)

with pytest.raises(SmithyIdentityException):
await resolver_chain.get_identity(identity_properties=IdentityProperties())


async def test_default_sources_env_credentials_resolver_success(
monkeypatch: pytest.MonkeyPatch,
):
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "akid")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret")
resolver_chain = CredentialsResolverChain(config=Config())

credentials = await resolver_chain.get_identity(
identity_properties=IdentityProperties()
)
assert credentials.access_key_id == "akid"
assert credentials.secret_access_key == "secret"


async def test_default_sources_imds_resolver_success(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)

async def mock_imds_get_identity(
self: IMDSCredentialsResolver, *, identity_properties: IdentityProperties
) -> AWSCredentialsIdentity:
return AWSCredentialsIdentity(
access_key_id="akid",
secret_access_key="secret",
)

monkeypatch.setattr(
"smithy_aws_core.credentials_resolvers.IMDSCredentialsResolver.get_identity",
mock_imds_get_identity,
)

resolver_chain = CredentialsResolverChain(config=Config())

credentials = await resolver_chain.get_identity(
identity_properties=IdentityProperties()
)
assert credentials.access_key_id == "akid"
assert credentials.secret_access_key == "secret"


async def test_multiple_sources_one_valid():
class FailingSource(CredentialsSource):
def is_available(self, config: AwsCredentialsConfig) -> bool:
return False

def build_resolver(
self, config: AwsCredentialsConfig
) -> AWSCredentialsResolver:
raise RuntimeError("Should not be called")

static_credentials = AWSCredentialsIdentity(
access_key_id="valid_akid", secret_access_key="valid_secret"
)
static_resolver = StaticCredentialsResolver(credentials=static_credentials)

class ValidSource(CredentialsSource):
def is_available(self, config: AwsCredentialsConfig) -> bool:
return True

def build_resolver(
self, config: AwsCredentialsConfig
) -> AWSCredentialsResolver:
return static_resolver

resolver_chain = CredentialsResolverChain(
sources=[FailingSource(), ValidSource()], config=Config()
)

credentials = await resolver_chain.get_identity(
identity_properties=IdentityProperties()
)
assert credentials.access_key_id == "valid_akid"
assert credentials.secret_access_key == "valid_secret"


async def test_cached_resolver_used(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "cached_akid")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "cached_secret")
resolver_chain = CredentialsResolverChain(
sources=[EnvironmentCredentialsSource()], config=Config()
)

credentials1 = await resolver_chain.get_identity(
identity_properties=IdentityProperties()
)
credentials2 = await resolver_chain.get_identity(
identity_properties=IdentityProperties()
)

assert credentials1.access_key_id == credentials2.access_key_id == "cached_akid"
assert (
credentials1.secret_access_key
== credentials2.secret_access_key
== "cached_secret"
)


async def test_custom_sources_with_static_credentials():
static_credentials = AWSCredentialsIdentity(
access_key_id="static_akid",
secret_access_key="static_secret",
)
static_resolver = StaticCredentialsResolver(credentials=static_credentials)

class TestStaticSource(CredentialsSource):
def is_available(self, config: AwsCredentialsConfig) -> bool:
return True

def build_resolver(
self, config: AwsCredentialsConfig
) -> AWSCredentialsResolver:
return static_resolver

resolver_chain = CredentialsResolverChain(
sources=[TestStaticSource()],
config=Config(), # type: ignore
)

credentials = await resolver_chain.get_identity(
identity_properties=IdentityProperties()
)
assert credentials.access_key_id == "static_akid"
assert credentials.secret_access_key == "static_secret"
4 changes: 3 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.